# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear trend using change points.

Examples
--------
Define a linear trend with 8 changepoints:

.. code-block:: python

    from pymc_marketing.mmm import LinearTrend

    trend = LinearTrend(n_changepoints=8)

Sample the prior for the trend parameters and curve:

.. code-block:: python

    import numpy as np

    seed = sum(map(ord, "Linear Trend"))
    rng = np.random.default_rng(seed)

    prior = trend.sample_prior(random_seed=rng)
    curve = trend.sample_curve(prior)

Plot the curve samples:

.. code-block:: python

    _, axes = trend.plot_curve(curve, sample_kwargs={"rng": rng})
    ax = axes[0]
    ax.set(
        xlabel="Time",
        ylabel="Trend",
        title=f"Linear Trend with {trend.n_changepoints} Change Points",
    )

.. image:: /_static/linear-trend-prior.png
    :alt: LinearTrend prior

"""
from typing import Any, cast
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import BaseModel, Field, InstanceOf, model_validator
from pymc.distributions.shape_utils import Dims
from typing_extensions import Self
from pymc_marketing.mmm.plot import plot_curve
from pymc_marketing.prior import Prior, create_dim_handler
[docs]
class LinearTrend(BaseModel):
    r"""Piecewise-linear trend component using change points.

    The trend is a continuous, piecewise-linear function of time whose
    slope changes at each change point:

    .. math::

        f(t) = k + \sum_{m=1}^{M} \delta_m (t - s_m) I(t > s_m)

    where:

    - :math:`k` is the intercept (only present when `include_intercept`
      is True),
    - :math:`\delta_m` is the change in the slope at change point :math:`m`,
    - :math:`I` is the indicator function,
    - :math:`s_m` is the change point.

    The change points are defined as:

    .. math::

        s_m = \frac{m}{M+1} \max(t)

    where :math:`M` is the number of change points.

    The default priors for the trend parameters are:

    - :math:`k \sim \text{Normal}(0, 0.05)`
    - :math:`\delta_m \sim \text{Laplace}(0, 0.25)`

    Parameters
    ----------
    priors : dict[str, Prior], optional
        Dictionary with the priors for the trend parameters. The
        dictionary must have 'delta' key. If `include_intercept` is
        True, the 'k' key is also required. By default None, or
        the default priors.
    dims : Dims, optional
        Dimensions of the parameters, by default None or empty.
    n_changepoints : int, optional
        Number of changepoints, by default 10.
    include_intercept : bool, optional
        Include an intercept in the trend, by default False

    Examples
    --------
    Linear trend with 10 changepoints:

    .. code-block:: python

        from pymc_marketing.mmm import LinearTrend

        trend = LinearTrend(n_changepoints=10)

    Use the trend in a model:

    .. code-block:: python

        import pymc as pm
        import numpy as np
        import pandas as pd

        n_years = 3
        n_dates = 52 * n_years
        first_date = "2020-01-01"
        dates = pd.date_range(first_date, periods=n_dates, freq="W-MON")
        dayofyear = dates.dayofyear.to_numpy()
        t = (dates - dates[0]).days.to_numpy()
        t = t / 365.25

        coords = {"date": dates}
        with pm.Model(coords=coords) as model:
            intercept = pm.Normal("intercept", mu=0, sigma=1)
            mu = intercept + trend.apply(t)

            sigma = pm.Gamma("sigma", mu=0.1, sigma=0.025)

            pm.Normal("obs", mu=mu, sigma=sigma, dims="date")

    Hierarchical LinearTrend via hierarchical prior:

    .. code-block:: python

        from pymc_marketing.prior import Prior

        hierarchical_delta = Prior(
            "Laplace",
            mu=Prior("Normal", dims="changepoint"),
            b=Prior("HalfNormal", dims="changepoint"),
            dims=("changepoint", "geo"),
        )
        priors = dict(delta=hierarchical_delta)

        hierarchical_trend = LinearTrend(
            priors=priors,
            n_changepoints=10,
            dims="geo",
        )

    Sample the hierarchical trend:

    .. code-block:: python

        seed = sum(map(ord, "Hierarchical LinearTrend"))
        rng = np.random.default_rng(seed)

        coords = {"geo": ["A", "B"]}
        prior = hierarchical_trend.sample_prior(
            coords=coords,
            random_seed=rng,
        )
        curve = hierarchical_trend.sample_curve(prior)

    Plot the curve HDI and samples:

    .. code-block:: python

        sample_kwargs = {"n": 3, "rng": rng}
        fig, axes = hierarchical_trend.plot_curve(
            curve,
            sample_kwargs=sample_kwargs,
        )
        fig.suptitle("Hierarchical Linear Trend")
        axes[0].set(ylabel="Trend", xlabel="Time")
        axes[1].set(xlabel="Time")

    .. image:: /_static/hierarchical-linear-trend-prior.png
        :alt: Hierarchical LinearTrend prior

    References
    ----------
    Adapted from MBrouns/timeseers package:
        https://github.com/MBrouns/timeseers/blob/master/src/timeseers/linear_trend.py

    """

    priors: InstanceOf[dict[str, Prior]] = Field(
        None,
        description="Priors for the trend parameters.",
    )
    dims: tuple[str] | InstanceOf[Dims] | str | None = Field(
        None,
        description="The additional dimensions for the trend.",
    )
    n_changepoints: int = Field(
        10,
        description="Number of changepoints.",
        ge=1,
    )
    include_intercept: bool = Field(
        False,
        description="Include an intercept in the trend.",
    )

    @model_validator(mode="after")
    def _dims_is_tuple(self) -> Self:
        """Normalize `dims`: a lone string becomes a 1-tuple, None becomes ()."""
        dims = self.dims

        if isinstance(dims, str):
            self.dims = (dims,)

        self.dims: tuple[str] = self.dims or ()

        return self

    @model_validator(mode="after")
    def _priors_are_set(self) -> Self:
        """Fall back to the default priors when none (or an empty dict) is given."""
        self.priors = self.priors or self.default_priors.copy()

        return self

    @model_validator(mode="after")
    def _check_parameters(self) -> Self:
        """Ensure the priors cover exactly the parameters the trend uses.

        Raises
        ------
        ValueError
            If a required prior is missing or an unexpected one is present.

        """
        required_parameters = set(self.default_priors.keys())

        # Require an exact match. A strict-superset test would accept a
        # dictionary with missing keys, which only fails later with a
        # KeyError when `apply` looks the prior up.
        if set(self.priors.keys()) != required_parameters:
            msg = f"Invalid priors. The required parameters are {required_parameters}."
            raise ValueError(msg)

        return self

    @model_validator(mode="after")
    def _check_dims_are_subsets(self) -> Self:
        """Ensure every prior only uses 'changepoint' and the trend's own dims.

        Raises
        ------
        ValueError
            If any prior declares a dimension outside the allowed set.

        """
        allowed_dims = {"changepoint"}.union(cast(Dims, self.dims))

        if not all(set(prior.dims) <= allowed_dims for prior in self.priors.values()):
            msg = "Invalid dimensions in the priors."
            raise ValueError(msg)

        return self

    @property
    def default_priors(self) -> dict[str, Prior]:
        """Default priors for the trend parameters.

        A new dictionary is built on every access. The 'k' entry is only
        included when `include_intercept` is True.

        Returns
        -------
        dict[str, Prior]
            Dictionary with the default priors.

        """
        priors = {
            "delta": Prior(
                "Laplace",
                mu=0,
                b=0.25,
                dims="changepoint",
            ),
        }

        if self.include_intercept:
            priors["k"] = Prior("Normal", mu=0, sigma=0.05)

        return priors

    def apply(self, t: pt.TensorLike) -> pt.TensorVariable:
        """Create the linear trend for the given x values.

        The method looks up the current model with `pm.modelcontext`,
        adds the "changepoint" coordinate to it and creates the prior
        variables on it, so it has to be called inside a model context.

        Parameters
        ----------
        t : pt.TensorLike
            Input values for the trend.

        Returns
        -------
        pt.TensorVariable
            TensorVariable with the trend values.

        """
        dims = cast(Dims, self.dims)
        model = pm.modelcontext(None)
        model.add_coord("changepoint", range(self.n_changepoints))

        # Placeholder name for the leading (time) axis of `t`; only used to
        # line up broadcasting through the dim handlers.
        DUMMY_DIM = "DATE"

        out_dims = (DUMMY_DIM, "changepoint", *dims)
        dim_handler = create_dim_handler(desired_dims=out_dims)

        # (changepoints, )
        # Interior points s_m = m / (M + 1) * max(t) for m = 1..M. The end
        # points are dropped: a change point at max(t) could never activate
        # because `t > s` is never true there.
        s = pt.linspace(0, pt.max(t), self.n_changepoints + 2)[1:-1]
        # Record the statically known length on the tensor type.
        s.type.shape = (self.n_changepoints,)
        s = dim_handler(
            s,
            ("changepoint",),
        )

        # (dates, changepoints)
        # Indicator of which change points are already active at each t.
        A = (dim_handler(t, (DUMMY_DIM,)) > s) * 1.0

        delta_dist = self.priors["delta"]
        delta = dim_handler(
            delta_dist.create_variable("delta"),
            delta_dist.dims,
        )

        k_dim_handler = create_dim_handler((DUMMY_DIM, *dims))

        # Accumulated slope at each t, multiplied by t.
        first = (A * delta).sum(axis=1) * k_dim_handler(t, (DUMMY_DIM,))

        if self.include_intercept:
            # (additional_groups)
            k_dist = self.priors["k"]
            k = k_dim_handler(
                k_dist.create_variable("k"),
                k_dist.dims,
            )
            first += k

        # Offset of -s_m * delta_m per active change point, so each term
        # becomes delta_m * (t - s_m) and the trend stays continuous.
        gamma = -s * delta
        second = (A * gamma).sum(axis=1)

        return first + second

    def sample_prior(
        self,
        coords: dict[str, Any] | None = None,
        **sample_prior_predictive_kwargs,
    ) -> xr.Dataset:
        """Sample the prior for the parameters used in the trend.

        Parameters
        ----------
        coords : dict, optional
            Coordinates in the priors, by default includes the changepoints.
            The dictionary passed in is not modified.
        sample_prior_predictive_kwargs : dict, optional
            Keyword arguments for the `pm.sample_prior_predictive` function.

        Returns
        -------
        xr.Dataset
            Dataset with the prior samples.

        """
        # Build a new mapping so the caller's dictionary is left untouched.
        coords = {
            **(coords or {}),
            "changepoint": range(self.n_changepoints),
        }

        with pm.Model(coords=coords):
            for key, param in self.priors.items():
                param.create_variable(key)

            return pm.sample_prior_predictive(**sample_prior_predictive_kwargs).prior

    def sample_curve(
        self,
        parameters: xr.Dataset,
        max_value: float = 1.0,
    ) -> xr.DataArray:
        """Sample the curve given parameters.

        Parameters
        ----------
        parameters : xr.Dataset
            Dataset with the parameters to condition on. Would be
            either the prior or the posterior.
        max_value : float, optional
            Upper end of the time grid. The curve is evaluated at 100
            evenly spaced points between 0 and `max_value`, by default 1.0.

        Returns
        -------
        xr.DataArray
            DataArray with the curve samples.

        """
        t = np.linspace(0, max_value, 100)

        # Reuse the coordinates of the sampled parameters, skipping the
        # sampling dimensions.
        coords: dict[str, Any] = {"t": t}
        for name in self.priors:
            for key, values in parameters[name].coords.items():
                if key in {"chain", "draw"}:
                    continue

                coords[key] = values.to_numpy()

        var_name = "trend"
        with pm.Model(coords=coords):
            pm.Deterministic(
                var_name,
                self.apply(t),
                dims=("t", *cast(Dims, self.dims)),
            )

            return pm.sample_posterior_predictive(
                parameters,
                var_names=[var_name],
            ).posterior_predictive[var_name]

    def plot_curve(
        self,
        curve: xr.DataArray,
        subplot_kwargs: dict | None = None,
        sample_kwargs: dict | None = None,
        hdi_kwargs: dict | None = None,
        include_changepoints: bool = True,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot the curve samples from the trend.

        Parameters
        ----------
        curve : xr.DataArray
            DataArray with the curve samples.
        subplot_kwargs : dict, optional
            Keyword arguments for the subplots, by default None.
        sample_kwargs : dict, optional
            Keyword arguments for the samples, by default None.
        hdi_kwargs : dict, optional
            Keyword arguments for the HDI, by default None.
        include_changepoints : bool, optional
            Include the change points in the plot, by default True.

        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
            Tuple with the figure and the axes.

        """
        fig, axes = plot_curve(
            curve,
            {"t"},
            subplot_kwargs=subplot_kwargs,
            sample_kwargs=sample_kwargs,
            hdi_kwargs=hdi_kwargs,
        )

        if not include_changepoints:
            return fig, axes

        max_value = curve.coords["t"].max().item()

        # M change points split [0, max_value] into M + 1 equal segments.
        n_segments = self.n_changepoints + 1
        changepoints = [max_value * i / n_segments for i in range(1, n_segments)]

        for ax in np.ravel(axes):
            for changepoint in changepoints:
                ax.axvline(
                    changepoint,
                    color="gray",
                    linestyle="--",
                )

        return fig, axes