# Source code for pymc_marketing.mmm.fourier

#   Copyright 2024 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Fourier seasonality transformations.

This module provides Fourier seasonality transformations for use in
Marketing Mix Models. The Fourier seasonality is a set of sine and cosine
functions that can be used to model periodic patterns in the data.

There are two types of Fourier seasonality transformations available:

- Yearly Fourier: A yearly seasonality with a period of 365.25 days
- Monthly Fourier: A monthly seasonality with a period of 365.25 / 12 days

.. plot::
    :context: close-figs

    import matplotlib.pyplot as plt
    import numpy as np
    import arviz as az
    from pymc_marketing.mmm import YearlyFourier
    from pymc_marketing.prior import Prior

    plt.style.use('arviz-darkgrid')

    prior = Prior(
        "Normal",
        mu=[0, 0, -1, 0],
        sigma=Prior("Gamma", mu=0.10, sigma=0.1, dims="fourier"),
        dims=("hierarchy", "fourier"),
    )
    yearly = YearlyFourier(n_order=2, prior=prior)
    coords = {"hierarchy": ["A", "B"]}
    prior = yearly.sample_prior(coords=coords)
    curve = yearly.sample_curve(prior)
    fig, _ = yearly.plot_curve(curve, subplot_kwargs={"ncols": 1})
    fig.suptitle("Yearly Fourier Seasonality")
    plt.show()

Examples
--------
Use yearly fourier seasonality in a custom Marketing Mix Model.

.. code-block:: python

    import pandas as pd
    import pymc as pm

    from pymc_marketing.mmm import YearlyFourier

    yearly = YearlyFourier(n_order=3)

    dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")

    dayofyear = dates.dayofyear.to_numpy()

    with pm.Model() as model:
        fourier_trend = yearly.apply(dayofyear)

Plot the prior fourier seasonality trend.

.. code-block:: python

    import matplotlib.pyplot as plt

    prior = yearly.sample_prior()
    curve = yearly.sample_curve(prior)
    yearly.plot_curve(curve)
    plt.show()

Change the prior distribution of the fourier seasonality.

.. code-block:: python

    from pymc_marketing.mmm import YearlyFourier
    from pymc_marketing.prior import Prior

    prior = Prior("Normal", mu=0, sigma=0.10)
    yearly = YearlyFourier(n_order=6, prior=prior)

Even make it hierarchical...

.. code-block:: python

    from pymc_marketing.mmm import YearlyFourier
    from pymc_marketing.prior import Prior

    # "fourier" is the default prefix!
    prior = Prior(
        "Laplace",
        mu=Prior("Normal", dims="fourier"),
        b=Prior("HalfNormal", sigma=0.1, dims="fourier"),
        dims=("fourier", "hierarchy"),
    )
    yearly = YearlyFourier(n_order=3, prior=prior)

All the plotting will still work! Just pass any coords.

.. code-block:: python

    import matplotlib.pyplot as plt

    coords = {"hierarchy": ["A", "B", "C"]}
    prior = yearly.sample_prior(coords=coords)
    curve = yearly.sample_curve(prior)
    yearly.plot_curve(curve)
    plt.show()

Out of sample predictions with fourier seasonality by changing the day of year
used in the model.

.. code-block:: python

    import pandas as pd
    import pymc as pm

    from pymc_marketing.mmm import YearlyFourier

    periods = 52 * 3
    dates = pd.date_range("2022-01-01", periods=periods, freq="W-MON")

    training_dates = dates[:52 * 2]
    testing_dates = dates[52 * 2:]

    yearly = YearlyFourier(n_order=3)

    coords = {
        "date": training_dates,
    }
    with pm.Model(coords=coords) as model:
        dayofyear = pm.Data(
            "dayofyear",
            training_dates.dayofyear.to_numpy(),
            dims="date",
        )

        trend = pm.Deterministic(
            "trend",
            yearly.apply(dayofyear),
            dims="date",
        )

        idata = pm.sample_prior_predictive().prior

    with model:
        pm.set_data(
            {"dayofyear": testing_dates.dayofyear.to_numpy()},
            coords={"date": testing_dates},
        )

        out_of_sample = pm.sample_posterior_predictive(
            idata,
            var_names=["trend"],
        ).posterior_predictive["trend"]


Use yearly and monthly fourier seasonality together.

By default, the prefix of the fourier seasonality is set to "fourier". However,
the prefix can be changed upon initialization in order to avoid variable name
conflicts.

.. code-block:: python

    import pandas as pd
    import pymc as pm

    from pymc_marketing.mmm import (
        MonthlyFourier,
        YearlyFourier,
    )

    yearly = YearlyFourier(n_order=6, prefix="yearly")
    monthly = MonthlyFourier(n_order=3, prefix="monthly")

    dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")
    dayofyear = dates.dayofyear.to_numpy()

    coords = {
        "date": dates,
    }

    with pm.Model(coords=coords) as model:
        yearly_trend = yearly.apply(dayofyear)
        monthly_trend = monthly.apply(dayofyear)

        trend = pm.Deterministic(
            "trend",
            yearly_trend + monthly_trend,
            dims="date",
        )

    with model:
        prior_samples = pm.sample_prior_predictive().prior

"""

from collections.abc import Callable
from typing import Any

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import BaseModel, Field, InstanceOf, field_serializer, model_validator
from typing_extensions import Self

from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR
from pymc_marketing.mmm.plot import plot_curve, plot_hdi, plot_samples
from pymc_marketing.prior import Prior, create_dim_handler

# Dimension name for the within-period x-axis of sampled seasonality curves.
X_NAME: str = "day"
# Dimension names handed to the plotting helpers as ``non_grid_names``;
# presumably these are not used to build the subplot grid (see mmm.plot).
NON_GRID_NAMES: frozenset[str] = frozenset([X_NAME])


def generate_fourier_modes(
    periods: pt.TensorLike,
    n_order: int,
) -> pt.TensorVariable:
    """Build the sine and cosine Fourier basis for the given periods.

    Each entry of ``periods`` is scaled by ``2 * pi`` and multiplied by the
    harmonic orders ``1 .. n_order``. The sines of all orders come first,
    followed by the cosines.

    Parameters
    ----------
    periods : pt.TensorLike
        Periods to generate fourier modes for (indexed as a 1-D sequence).
    n_order : int
        Number of fourier modes (harmonics) to generate.

    Returns
    -------
    pt.TensorVariable
        Fourier modes with ``2 * n_order`` columns, laid out as
        ``[sin_1, ..., sin_n, cos_1, ..., cos_n]``.

    """
    harmonic_orders = pt.arange(1, n_order + 1)
    # Column vector of angles broadcast against the row of harmonic orders.
    angles = (2 * pt.pi * periods)[:, None] * harmonic_orders
    return pt.concatenate([pt.sin(angles), pt.cos(angles)], axis=1)
class FourierBase(BaseModel):
    """Base class for Fourier seasonality transformations.

    Parameters
    ----------
    n_order : int
        Number of fourier modes to use. Must be positive.
    days_in_period : float
        Number of days in a period. Must be positive.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default "fourier".
        It is also the name of the fourier dimension/coordinate.
    prior : Prior, optional
        Prior distribution for the fourier seasonality beta parameters, by
        default `Prior("Laplace", mu=0, b=1)`. If the prior has no dims, a copy
        with dims set to `prefix` is used. Otherwise its dims must contain
        `prefix`.
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes. By default None,
        in which case it is set to `{prefix}_beta`. It cannot equal `prefix`.

    """

    n_order: int = Field(..., gt=0)
    days_in_period: float = Field(..., gt=0)
    prefix: str = Field("fourier")
    prior: InstanceOf[Prior] = Field(Prior("Laplace", mu=0, b=1))
    variable_name: str | None = Field(None)

    def model_post_init(self, __context: Any) -> None:
        """Model post initialization for a Pydantic model.

        Fills in the default ``variable_name`` and, when the prior has no dims,
        gives it the ``prefix`` dimension.
        """
        if self.variable_name is None:
            self.variable_name = f"{self.prefix}_beta"

        if not self.prior.dims:
            # Copy before mutating so a shared (e.g. the default) Prior instance
            # is never modified in place.
            self.prior = self.prior.deepcopy()
            self.prior.dims = self.prefix

    @model_validator(mode="after")
    def _check_variable_name(self) -> Self:
        if self.variable_name == self.prefix:
            raise ValueError("Variable name cannot be the same as the prefix")
        return self

    @model_validator(mode="after")
    def _check_prior_has_right_dimensions(self) -> Self:
        if self.prefix not in self.prior.dims:
            raise ValueError(f"Prior distribution must have dimension {self.prefix}")
        return self

    @field_serializer("prior", when_used="json")
    def serialize_prior(prior: Prior) -> dict[str, Any]:
        """Serialize the prior distribution.

        Parameters
        ----------
        prior : Prior
            The prior distribution to serialize.

        Returns
        -------
        dict[str, Any]
            The serialized prior distribution.

        """
        return prior.to_json()

    @property
    def nodes(self) -> list[str]:
        """Fourier node names for model coordinates.

        Ordered as ``sin_1, ..., sin_n, cos_1, ..., cos_n`` to match the column
        layout produced by ``generate_fourier_modes``.
        """
        return [
            f"{func}_{i}" for func in ["sin", "cos"] for i in range(1, self.n_order + 1)
        ]

    def apply(
        self,
        dayofyear: pt.TensorLike,
        result_callback: Callable[[pt.TensorVariable], None] | None = None,
    ) -> pt.TensorVariable:
        """Apply fourier seasonality to day of year.

        Must be used within a PyMC model context. The method adds the `prefix`
        coordinate to the current model, creates the beta variable named
        `variable_name`, and returns the seasonality summed over the `prefix`
        dimension.

        Parameters
        ----------
        dayofyear : pt.TensorLike
            Day of year.
        result_callback : Callable[[pt.TensorVariable], None], optional
            Callback applied to the per-mode result before it is summed over
            the prefix dimension, by default None

        Returns
        -------
        pt.TensorVariable
            Fourier seasonality

        Examples
        --------
        Save off the result before summing through the prefix dimension.

        .. code-block:: python

            import pandas as pd
            import pymc as pm

            from pymc_marketing.mmm import YearlyFourier

            fourier = YearlyFourier(n_order=3)

            def callback(result):
                pm.Deterministic("fourier_trend", result, dims=("date", "fourier"))

            dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")

            coords = {
                "date": dates,
            }
            with pm.Model(coords=coords) as model:
                dayofyear = dates.dayofyear.to_numpy()
                fourier.apply(dayofyear, result_callback=callback)

        """
        periods = dayofyear / self.days_in_period

        model = pm.modelcontext(None)
        model.add_coord(self.prefix, self.nodes)

        beta = self.prior.create_variable(self.variable_name)

        fourier_modes = generate_fourier_modes(periods=periods, n_order=self.n_order)

        # Placeholder name for the leading (observation) axis of the modes; it is
        # only used to align the modes with the prior's dims.
        DUMMY_DIM = "DATE"

        prefix_idx = self.prior.dims.index(self.prefix)
        result_dims = (DUMMY_DIM, *self.prior.dims)
        dim_handler = create_dim_handler(result_dims)

        result = dim_handler(fourier_modes, (DUMMY_DIM, self.prefix)) * dim_handler(
            beta, self.prior.dims
        )

        if result_callback is not None:
            result_callback(result)

        # +1 accounts for the leading DUMMY_DIM axis in ``result_dims``.
        return result.sum(axis=prefix_idx + 1)

    def sample_prior(self, coords: dict | None = None, **kwargs) -> xr.Dataset:
        """Sample the prior distributions.

        Parameters
        ----------
        coords : dict, optional
            Coordinates for the prior distribution, by default None. The dict
            passed in is not modified.
        kwargs
            Additional keywords for sample_prior_predictive

        Returns
        -------
        xr.Dataset
            Prior distribution.

        """
        # Build a new mapping instead of writing into the caller's dict; the
        # prefix coordinate always reflects this instance's nodes.
        coords = {**(coords or {}), self.prefix: self.nodes}
        return self.prior.sample_prior(coords=coords, name=self.variable_name, **kwargs)

    def sample_curve(self, parameters: az.InferenceData | xr.Dataset) -> xr.DataArray:
        """Create full period of the fourier seasonality.

        Parameters
        ----------
        parameters : az.InferenceData | xr.Dataset
            Inference data or dataset containing the fourier parameters.
            Can be posterior or prior. For InferenceData, the ``posterior``
            group is used if present, otherwise the ``prior`` group.

        Returns
        -------
        xr.DataArray
            Full period of the fourier seasonality.

        """
        if isinstance(parameters, az.InferenceData):
            # Indexing InferenceData selects a *group*, not a variable, so pull
            # out the dataset that holds the fourier parameters first.
            group = "posterior" if "posterior" in parameters.groups() else "prior"
            parameters = parameters[group]

        full_period = np.arange(self.days_in_period + 1)
        coords = {
            X_NAME: full_period,
        }
        for key, values in parameters[self.variable_name].coords.items():
            if key in {"chain", "draw", self.prefix}:
                continue
            coords[key] = values.to_numpy()

        with pm.Model(coords=coords):
            name = f"{self.prefix}_trend"
            pm.Deterministic(
                name,
                self.apply(dayofyear=full_period),
                dims=tuple(coords.keys()),
            )

            return pm.sample_posterior_predictive(
                parameters,
                var_names=[name],
            ).posterior_predictive[name]

    def plot_curve(
        self,
        curve: xr.DataArray,
        subplot_kwargs: dict | None = None,
        sample_kwargs: dict | None = None,
        hdi_kwargs: dict | None = None,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot the seasonality for one full period.

        Parameters
        ----------
        curve : xr.DataArray
            Sampled full period of the fourier seasonality.
        subplot_kwargs : dict, optional
            Keyword arguments for the subplot, by default None
        sample_kwargs : dict, optional
            Keyword arguments for the plot_full_period_samples method, by default None
        hdi_kwargs : dict, optional
            Keyword arguments for the plot_full_period_hdi method, by default None

        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
            Matplotlib figure and axes.

        """
        return plot_curve(
            curve,
            non_grid_names=set(NON_GRID_NAMES),
            subplot_kwargs=subplot_kwargs,
            sample_kwargs=sample_kwargs,
            hdi_kwargs=hdi_kwargs,
        )

    def plot_curve_hdi(
        self,
        curve: xr.DataArray,
        hdi_kwargs: dict | None = None,
        subplot_kwargs: dict[str, Any] | None = None,
        plot_kwargs: dict[str, Any] | None = None,
        axes: npt.NDArray[plt.Axes] | None = None,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot full period of the fourier seasonality.

        Parameters
        ----------
        curve : xr.DataArray
            The curve to plot.
        hdi_kwargs : dict, optional
            Keyword arguments for the az.hdi function. Defaults to None.
        subplot_kwargs : dict, optional
            Keyword arguments for plt.subplots
        plot_kwargs : dict, optional
            Keyword arguments for the fill_between function. Defaults to None.
        axes : npt.NDArray[plt.Axes], optional
            The exact axes to plot on. Overrides any subplot_kwargs

        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]

        """
        return plot_hdi(
            curve,
            non_grid_names=set(NON_GRID_NAMES),
            hdi_kwargs=hdi_kwargs,
            subplot_kwargs=subplot_kwargs,
            plot_kwargs=plot_kwargs,
            axes=axes,
        )

    def plot_curve_samples(
        self,
        curve: xr.DataArray,
        n: int = 10,
        rng: np.random.Generator | None = None,
        plot_kwargs: dict[str, Any] | None = None,
        subplot_kwargs: dict[str, Any] | None = None,
        axes: npt.NDArray[plt.Axes] | None = None,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot samples from the curve.

        Parameters
        ----------
        curve : xr.DataArray
            Samples from the curve.
        n : int, optional
            Number of samples to plot, by default 10
        rng : np.random.Generator, optional
            Random number generator, by default None
        plot_kwargs : dict, optional
            Keyword arguments for the plot function, by default None
        subplot_kwargs : dict, optional
            Keyword arguments for the subplot, by default None
        axes : npt.NDArray[plt.Axes], optional
            Matplotlib axes, by default None

        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
            Matplotlib figure and axes.

        """
        return plot_samples(
            curve,
            non_grid_names=set(NON_GRID_NAMES),
            n=n,
            rng=rng,
            axes=axes,
            subplot_kwargs=subplot_kwargs,
            plot_kwargs=plot_kwargs,
        )
class YearlyFourier(FourierBase):
    """Yearly fourier seasonality.

    The period is fixed to ``DAYS_IN_YEAR`` days.

    .. plot::
        :context: close-figs

        import arviz as az
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import YearlyFourier
        from pymc_marketing.prior import Prior

        az.style.use("arviz-white")

        seed = sum(map(ord, "Yearly"))
        rng = np.random.default_rng(seed)

        mu = np.array([0, 0, -1, 0])
        b = 0.15
        dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
        yearly = YearlyFourier(n_order=2, prior=dist)
        prior = yearly.sample_prior(random_seed=rng)
        curve = yearly.sample_curve(prior)
        _, axes = yearly.plot_curve(curve)
        axes[0].set(title="Yearly Fourier Seasonality")
        plt.show()

    Parameters
    ----------
    n_order : int
        Number of fourier modes to use.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default "fourier"
    prior : Prior, optional
        Prior distribution for the fourier seasonality beta parameters, by
        default `Prior("Laplace", mu=0, b=1)`
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes. By default None,
        in which case it is set to `{prefix}_beta`.

    """

    days_in_period: float = DAYS_IN_YEAR
class MonthlyFourier(FourierBase):
    """Monthly fourier seasonality.

    The period is fixed to ``DAYS_IN_MONTH`` days.

    .. plot::
        :context: close-figs

        import arviz as az
        import matplotlib.pyplot as plt
        import numpy as np
        from pymc_marketing.mmm import MonthlyFourier
        from pymc_marketing.prior import Prior

        az.style.use("arviz-white")

        seed = sum(map(ord, "Monthly"))
        rng = np.random.default_rng(seed)

        mu = np.array([0, 0, 0.5, 0])
        b = 0.075
        dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
        monthly = MonthlyFourier(n_order=2, prior=dist)
        prior = monthly.sample_prior(samples=100, random_seed=rng)
        curve = monthly.sample_curve(prior)
        _, axes = monthly.plot_curve(curve)
        axes[0].set(title="Monthly Fourier Seasonality")
        plt.show()

    Parameters
    ----------
    n_order : int
        Number of fourier modes to use.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default "fourier"
    prior : Prior, optional
        Prior distribution for the fourier seasonality beta parameters, by
        default `Prior("Laplace", mu=0, b=1)`
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes. By default None,
        in which case it is set to `{prefix}_beta`.

    """

    days_in_period: float = DAYS_IN_MONTH