# Source code for pymc_marketing.mmm.plot

#   Copyright 2024 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Plotting functions for the MMM."""

import warnings
from collections.abc import Generator, MutableMapping, Sequence
from itertools import product
from typing import Any

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import xarray as xr

from pymc_marketing.mmm.utils import drop_scalar_coords

# Values of a single coordinate: any sequence or a NumPy array.
Values = Sequence[Any] | npt.NDArray[Any]
# Mapping from coordinate name to the values of that coordinate.
Coords = dict[str, Values]


def get_plot_coords(coords: Coords, non_grid_names: set[str]) -> Coords:
    """Select the coordinates that make up the plot grid.

    Parameters
    ----------
    coords : Coords
        Mapping of coordinate name to coordinate values.
    non_grid_names : set[str]
        Coordinate names to leave out of the result.

    Returns
    -------
    Coords
        The remaining coordinates, in their original order, with each set of
        values converted to a NumPy array.

    """
    grid_coords: Coords = {}
    for name in coords.keys():
        if name in non_grid_names:
            continue
        grid_coords[name] = np.array(coords[name])
    return grid_coords
def get_total_coord_size(coords: Coords) -> int:
    """Get the total size of the coordinates.

    The total size is the product of the lengths of all coordinate values,
    i.e. the number of distinct selections across the coordinates.

    Parameters
    ----------
    coords : Coords
        The coordinates to get the total size of.

    Returns
    -------
    int
        The total size of the coordinates as a builtin ``int``. An empty
        mapping has a total size of 1.

    Warns
    -----
    UserWarning
        If the total size is 12 or more.

    """
    sizes = [len(values) for values in coords.values()]
    # ``np.prod`` of an empty list is the float 1.0 and otherwise a NumPy
    # scalar; handle the empty case explicitly and convert so that the
    # declared ``int`` return type actually holds.
    total_size = int(np.prod(sizes)) if sizes else 1

    if total_size >= 12:
        warnings.warn("Large number of coordinates!", stacklevel=2)

    return total_size
def set_subplot_kwargs_defaults(
    subplot_kwargs: MutableMapping[str, Any],
    total_size: int,
) -> None:
    """Set the defaults for the subplot kwargs.

    Fills in whichever of ``ncols`` / ``nrows`` is missing, in place, so that
    the resulting grid has at least ``total_size`` cells. When neither is
    given, a single row with ``total_size`` columns is used.

    Parameters
    ----------
    subplot_kwargs : MutableMapping[str, Any]
        The subplot kwargs to set the defaults for. Modified in place.
    total_size : int
        The total size of the coordinates.

    Raises
    ------
    ValueError
        If both `ncols` and `nrows` are specified.

    """
    has_ncols = "ncols" in subplot_kwargs
    has_nrows = "nrows" in subplot_kwargs
    if has_ncols and has_nrows:
        raise ValueError("Only specify one")

    # Ceiling division (written as negated floor division) so that a
    # ``total_size`` that is not an exact multiple still gets enough cells;
    # plain floor division would produce a grid that is too small.
    if has_nrows:
        nrows = subplot_kwargs["nrows"]
        subplot_kwargs["ncols"] = -(-total_size // nrows)
    else:
        ncols = subplot_kwargs.setdefault("ncols", total_size)
        subplot_kwargs["nrows"] = -(-total_size // ncols)
def selections(
    coords: Coords,
) -> Generator[dict[str, Any], None, None]:
    """Iterate over every combination of coordinate values.

    Parameters
    ----------
    coords : Coords
        Mapping of coordinate name to the values of that coordinate.

    Yields
    ------
    dict[str, Any]
        One mapping of coordinate name to a single value for each element of
        the Cartesian product of the coordinate values.

    """
    names = coords.keys()
    value_lists = coords.values()
    for combo in product(*value_lists):
        yield dict(zip(names, combo, strict=True))
def plot_hdi(
    curve: xr.DataArray,
    non_grid_names: set[str],
    hdi_kwargs: dict | None = None,
    subplot_kwargs: dict[str, Any] | None = None,
    plot_kwargs: dict[str, Any] | None = None,
    axes: npt.NDArray[plt.Axes] | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
    """Plot hdi of the curve across coords.

    Parameters
    ----------
    curve : xr.DataArray
        Curve to plot
    non_grid_names : set[str]
        The names to exclude from the grid. hdi is
        excluded automatically
    hdi_kwargs : dict, optional
        Keyword arguments passed to :func:`arviz.hdi`
    subplot_kwargs : dict, optional
        Additional kwargs used while creating the fig and axes.
        Only used when ``axes`` is not provided
    plot_kwargs : dict, optional
        Kwargs for the ``fill_between`` call on each axes
    axes : npt.NDArray[plt.Axes], optional
        Axes to plot on

    Returns
    -------
    tuple[plt.Figure, npt.NDArray[plt.Axes]]
        Figure and the axes

    """
    curve = drop_scalar_coords(curve)

    hdi_kwargs = hdi_kwargs or {}
    conf = az.hdi(curve, **hdi_kwargs)[curve.name]

    plot_coords = get_plot_coords(
        conf.coords,
        non_grid_names=non_grid_names.union({"hdi"}),
    )
    total_size = get_total_coord_size(plot_coords)

    if axes is None:
        # Merge into a fresh dict so the caller's mapping is never modified.
        subplot_kwargs = {"sharey": True, "sharex": True, **(subplot_kwargs or {})}
        set_subplot_kwargs_defaults(subplot_kwargs, total_size)
        fig, axes = plt.subplots(**subplot_kwargs)
    else:
        # Return the figure that owns the supplied axes; the "current"
        # figure may be a different one if the caller created other figures
        # in between.
        flat_axes = np.ravel(axes)
        fig = flat_axes[0].figure if flat_axes.size else plt.gcf()

    plot_kwargs = {"alpha": 0.25, **(plot_kwargs or {})}

    # ``strict=False``: stop at whichever of axes / selections runs out first.
    for i, (ax, sel) in enumerate(
        zip(np.ravel(axes), selections(plot_coords), strict=False)
    ):
        color = f"C{i}"
        df_conf = conf.sel(sel).to_series().unstack()

        ax.fill_between(
            x=df_conf.index,
            y1=df_conf["lower"],
            y2=df_conf["higher"],
            color=color,
            **plot_kwargs,
        )
        title = ", ".join(f"{name}={value}" for name, value in sel.items())
        ax.set_title(title)

    # A single Axes is wrapped so the return value is always an array.
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])

    return fig, axes
def random_samples(
    rng: np.random.Generator,
    n: int,
    n_chains: int,
    n_draws: int,
) -> list[tuple[int, int]]:
    """Draw distinct (chain, draw) index pairs at random.

    Parameters
    ----------
    rng : np.random.Generator
        Random number generator
    n : int
        Number of pairs to draw
    n_chains : int
        Number of chains
    n_draws : int
        Number of draws

    Returns
    -------
    list[tuple[int, int]]
        ``n`` pairs of (chain index, draw index), sampled without replacement
        from all chain / draw combinations

    """
    all_pairs = list(product(range(n_chains), range(n_draws)))
    # Sampling without replacement keeps every returned pair unique.
    chosen = rng.choice(all_pairs, size=n, replace=False)
    return [tuple(row) for row in chosen.tolist()]
def plot_samples(
    curve: xr.DataArray,
    non_grid_names: set[str],
    n: int = 10,
    rng: np.random.Generator | None = None,
    axes: npt.NDArray[plt.Axes] | None = None,
    subplot_kwargs: dict[str, Any] | None = None,
    plot_kwargs: dict[str, Any] | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
    """Plot n samples of the curve across coords.

    Parameters
    ----------
    curve : xr.DataArray
        Curve to plot
    non_grid_names : set[str]
        The names to exclude from the grid. chain and draw are
        excluded automatically
    n : int, optional
        Number of samples to plot
    rng : np.random.Generator, optional
        Random number generator. A fresh default generator is
        created when omitted
    axes : npt.NDArray[plt.Axes], optional
        Axes to plot on
    subplot_kwargs : dict, optional
        Additional kwargs used while creating the fig and axes.
        Only used when ``axes`` is not provided
    plot_kwargs : dict, optional
        Kwargs for the plot function

    Returns
    -------
    tuple[plt.Figure, npt.NDArray[plt.Axes]]
        Figure and the axes

    """
    curve = drop_scalar_coords(curve)

    plot_coords = get_plot_coords(
        curve.coords,
        non_grid_names=non_grid_names.union({"chain", "draw"}),
    )
    total_size = get_total_coord_size(plot_coords)

    if axes is None:
        # Merge into a fresh dict so the caller's mapping is never modified.
        subplot_kwargs = {"sharey": True, "sharex": True, **(subplot_kwargs or {})}
        set_subplot_kwargs_defaults(subplot_kwargs, total_size)
        fig, axes = plt.subplots(**subplot_kwargs)
    else:
        # Return the figure that owns the supplied axes; the "current"
        # figure may be a different one if the caller created other figures
        # in between.
        flat_axes = np.ravel(axes)
        fig = flat_axes[0].figure if flat_axes.size else plt.gcf()

    plot_kwargs = {"alpha": 0.3, "legend": False, **(plot_kwargs or {})}

    rng = rng or np.random.default_rng()
    # The same (chain, draw) pairs are used for every subplot.
    idx = random_samples(
        rng, n=n, n_chains=curve.sizes["chain"], n_draws=curve.sizes["draw"]
    )

    # ``strict=False``: stop at whichever of axes / selections runs out first.
    for i, (ax, sel) in enumerate(
        zip(np.ravel(axes), selections(plot_coords), strict=False)
    ):
        color = f"C{i}"

        df_curve = curve.sel(sel).to_series().unstack()
        df_sample = df_curve.loc[idx, :]

        # Transpose so each sampled (chain, draw) row becomes one line.
        df_sample.T.plot(ax=ax, color=color, **plot_kwargs)
        title = ", ".join(f"{name}={value}" for name, value in sel.items())
        ax.set_title(title)

    # A single Axes is wrapped so the return value is always an array.
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])

    return fig, axes
def plot_curve(
    curve: xr.DataArray,
    non_grid_names: set[str],
    subplot_kwargs: dict | None = None,
    sample_kwargs: dict | None = None,
    hdi_kwargs: dict | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
    """Plot HDI with samples of the curve across coords.

    The samples are drawn first; the HDI is then drawn on the same axes.

    Parameters
    ----------
    curve : xr.DataArray
        Curve to plot
    non_grid_names : set[str]
        The names to exclude from the grid. HDI and samples both
        have defaults of hdi and chain, draw, respectively
    subplot_kwargs : dict, optional
        Additional kwargs used while creating the fig and axes.
        Ignored when ``sample_kwargs`` already has a
        ``"subplot_kwargs"`` entry
    sample_kwargs : dict, optional
        Kwargs for the :func:`plot_samples` function
    hdi_kwargs : dict, optional
        Kwargs for the :func:`plot_hdi` function

    Returns
    -------
    tuple[plt.Figure, npt.NDArray[plt.Axes]]
        Figure and the axes

    """
    curve = drop_scalar_coords(curve)

    hdi_kwargs = hdi_kwargs or {}
    # Build a new dict rather than inserting into the caller's
    # ``sample_kwargs``; an explicit "subplot_kwargs" entry supplied there
    # still takes precedence over the ``subplot_kwargs`` argument.
    sample_kwargs = {"subplot_kwargs": subplot_kwargs, **(sample_kwargs or {})}

    fig, axes = plot_samples(
        curve,
        non_grid_names=non_grid_names,
        **sample_kwargs,
    )
    fig, axes = plot_hdi(
        curve,
        non_grid_names=non_grid_names,
        axes=axes,
        **hdi_kwargs,
    )

    return fig, axes