# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Plotting functions for the CLV module."""
from collections.abc import Sequence
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from pymc_marketing.clv import BetaGeoModel, ParetoNBDModel
# Names exported by ``from <this module> import *``.
# NOTE(review): ``force_aspect`` is defined in this module without a leading
# underscore but is not listed here -- confirm whether that is intentional.
__all__ = [
"plot_customer_exposure",
"plot_frequency_recency_matrix",
"plot_probability_alive_matrix",
]
[docs]
def plot_customer_exposure(
    df: pd.DataFrame,
    linewidth: float | None = None,
    size: float | None = None,
    labels: Sequence[str] | None = None,
    colors: Sequence[str] | None = None,
    padding: float = 0.25,
    ax: plt.Axes | None = None,
) -> plt.Axes:
    """Plot the recency and T of DataFrame of customers.

    Plots customers as horizontal lines with markers representing their
    recency and T. Order is the same as the DataFrame and plotted from the
    bottom up.

    Each line is drawn in two segments: from 0 to recency in the first color,
    and from recency to T in the second color.

    All arguments are validated before anything is drawn, so an invalid call
    never leaves a partially drawn axes (or a newly created figure) behind.

    Parameters
    ----------
    df : pd.DataFrame
        A DataFrame with columns "recency" and "T" representing the recency
        and age of customers. Must contain at least one row.
    linewidth : float, optional
        The width of the horizontal lines in the plot.
    size : float, optional
        The size of the markers in the plot.
    labels : Sequence[str], optional
        A sequence of two labels for the legend. Default is ["Recency", "T"].
    colors : Sequence[str], optional
        A sequence of two colors, used for both the lines and the legend.
        Default is ["C0", "C1"].
    padding : float, optional
        The padding around the plot. Default is 0.25.
    ax : plt.Axes, optional
        A matplotlib axes instance to plot on. If None, the current axes
        (``plt.gca()``) is used.

    Returns
    -------
    plt.Axes
        The matplotlib axes instance.

    Raises
    ------
    ValueError
        If ``padding``, ``size`` or ``linewidth`` is negative, if ``colors``
        or ``labels`` does not have exactly two entries, or if ``df`` is
        empty.

    Examples
    --------
    Plot customer exposure

    .. code-block:: python

        df = pd.DataFrame({
            "recency": [0, 1, 2, 3, 4],
            "T": [5, 5, 5, 5, 5]
        })

        plot_customer_exposure(df)

    Plot customer exposure ordered by recency and T

    .. code-block:: python

        (
            df
            .sort_values(["recency", "T"])
            .pipe(plot_customer_exposure)
        )

    Plot exposure for only those with time until last purchase is less than 3

    .. code-block:: python

        (
            df
            .query("T - recency < 3")
            .pipe(plot_customer_exposure)
        )

    """
    # --- validate every argument up front, before any drawing side effects ---
    if padding < 0:
        raise ValueError("padding must be non-negative")

    if size is not None and size < 0:
        raise ValueError("size must be non-negative")

    if linewidth is not None and linewidth < 0:
        raise ValueError("linewidth must be non-negative")

    n = len(df)
    customer_idx = np.arange(1, n + 1)

    recency = df["recency"].to_numpy()
    T = df["T"].to_numpy()

    if colors is None:
        colors = ["C0", "C1"]

    if len(colors) != 2:
        raise ValueError("colors must be a sequence of length 2")

    recency_color, T_color = colors

    if labels is None:
        labels = ["Recency", "T"]

    if len(labels) != 2:
        raise ValueError("labels must be a sequence of length 2")

    recency_label, T_label = labels

    # ``T.max()`` below is undefined for an empty array; fail early with a
    # message that names the actual problem instead of numpy's reduction error.
    if n == 0:
        raise ValueError("df must contain at least one customer")

    # Only touch pyplot state once the inputs are known to be valid.
    if ax is None:
        ax = plt.gca()

    # --- draw ---
    ax.hlines(
        y=customer_idx, xmin=0, xmax=recency, linewidth=linewidth, color=recency_color
    )
    ax.hlines(y=customer_idx, xmin=recency, xmax=T, linewidth=linewidth, color=T_color)

    ax.scatter(x=recency, y=customer_idx, linewidth=linewidth, s=size, c=recency_color)
    ax.scatter(x=T, y=customer_idx, linewidth=linewidth, s=size, c=T_color)

    ax.set(
        xlabel="Time since first purchase",
        ylabel="Customer",
        xlim=(0 - padding, T.max() + padding),
        ylim=(1 - padding, n + padding),
        title="Customer Exposure",
    )

    # Proxy artists: hlines/scatter were drawn without labels, so the legend
    # entries are built explicitly from the two colors.
    legend_elements = [
        Line2D([0], [0], color=recency_color, label=recency_label),
        Line2D([0], [0], color=T_color, label=T_label),
    ]

    ax.legend(handles=legend_elements, loc="best")

    return ax
def _create_frequency_recency_meshes(
    max_frequency: int,
    max_recency: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build the integer frequency/recency coordinate grids for the heatmaps.

    Parameters
    ----------
    max_frequency : int
        Largest frequency value on the grid (inclusive); values start at 0.
    max_recency : int
        Largest recency value on the grid (inclusive); values start at 0.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(mesh_frequency, mesh_recency)``, each of shape
        ``(max_recency + 1, max_frequency + 1)``. Frequency varies along the
        columns and recency along the rows.
    """
    freq_axis = np.arange(max_frequency + 1)
    rec_axis = np.arange(max_recency + 1)
    # "xy" indexing (numpy's default): first input spans columns, second rows.
    freq_grid, rec_grid = np.meshgrid(freq_axis, rec_axis, indexing="xy")
    return freq_grid, rec_grid
[docs]
def plot_frequency_recency_matrix(
    model: BetaGeoModel | ParetoNBDModel,
    future_t: int = 1,
    max_frequency: int | None = None,
    max_recency: int | None = None,
    title: str | None = None,
    xlabel: str = "Customer's Historical Frequency",
    ylabel: str = "Customer's Recency",
    ax: plt.Axes | None = None,
    **kwargs,
) -> plt.Axes:
    """Draw a heatmap of expected future purchases by *frequency* and *recency*.

    A synthetic customer is created for every (frequency, recency) pair on an
    integer grid, all with ``T`` equal to ``max_recency``. The model's
    ``expected_purchases`` output for those customers, averaged over the
    "draw" and "chain" dimensions, is shown with ``imshow`` together with a
    colorbar.

    Parameters
    ----------
    model: CLV model
        A fitted CLV model.
    future_t: int, optional
        Number of future time periods passed to ``model.expected_purchases``.
    max_frequency: int, optional
        The maximum *frequency* to plot. Defaults to the maximum of
        ``model.data["frequency"]``.
    max_recency: int, optional
        The maximum *recency* to plot. It is also used as ``T`` for every
        synthetic customer. Defaults to the maximum of
        ``model.data["recency"]``.
    title: str, optional
        Figure title. When None, a title mentioning *future_t* is generated.
    xlabel: str, optional
        Figure xlabel
    ylabel: str, optional
        Figure ylabel
    ax: plt.Axes, optional
        A matplotlib axes instance. When None, ``plt.subplot(111)`` is used.
    kwargs
        Passed into the matplotlib ``imshow`` call.

    Returns
    -------
    plt.Axes
        The axes the heatmap was drawn on.
    """
    # Fall back to the observed maxima only for the limits not supplied.
    if max_frequency is None:
        max_frequency = int(model.data["frequency"].max())
    if max_recency is None:
        max_recency = int(model.data["recency"].max())

    frequency_grid, recency_grid = _create_frequency_recency_meshes(
        max_frequency=max_frequency,
        max_recency=max_recency,
    )
    grid_shape = recency_grid.shape

    # One row per grid cell; customer_id is only a placeholder index.
    synthetic_customers = pd.DataFrame(
        {
            "customer_id": np.arange(recency_grid.size),
            "frequency": frequency_grid.ravel(),
            "recency": recency_grid.ravel(),
            "T": max_recency,
        }
    )

    # Posterior-mean prediction per cell, folded back into the grid layout.
    predicted = model.expected_purchases(
        data=synthetic_customers,
        future_t=future_t,
    )
    heat_values = predicted.mean(("draw", "chain")).values.reshape(grid_shape)

    if ax is None:
        ax = plt.subplot(111)

    image = ax.imshow(heat_values, **kwargs)

    if title is None:
        # "Unit" for exactly one period, "Units" otherwise.
        plural_suffix = "" if future_t == 1 else "s"
        headline = "Expected Number of Future Purchases for {} Unit{} of Time,".format(
            future_t, plural_suffix
        )
        title = headline + "\nby Frequency and Recency of a Customer"

    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
    force_aspect(ax)

    # Colorbar is attached beside the matrix.
    plt.colorbar(image, ax=ax)

    return ax
[docs]
def plot_probability_alive_matrix(
    model: BetaGeoModel | ParetoNBDModel,
    max_frequency: int | None = None,
    max_recency: int | None = None,
    title: str = "Probability Customer is Alive,\nby Frequency and Recency of a Customer",
    xlabel: str = "Customer's Historical Frequency",
    ylabel: str = "Customer's Recency",
    ax: plt.Axes | None = None,
    **kwargs,
) -> plt.Axes:
    """Draw a heatmap of the probability of being alive by *frequency* and *recency*.

    A synthetic customer is created for every (frequency, recency) pair on an
    integer grid, all with ``T`` equal to ``max_recency``. The model's
    ``expected_probability_alive`` output for those customers, averaged over
    the "draw" and "chain" dimensions, is shown with ``imshow`` together with
    a colorbar.

    Parameters
    ----------
    model: CLV model
        A fitted CLV model.
    max_frequency: int, optional
        The maximum *frequency* to plot. Defaults to the maximum of
        ``model.data["frequency"]``.
    max_recency: int, optional
        The maximum *recency* to plot. It is also used as ``T`` for every
        synthetic customer. Defaults to the maximum of
        ``model.data["recency"]``.
    title: str, optional
        Figure title
    xlabel: str, optional
        Figure xlabel
    ylabel: str, optional
        Figure ylabel
    ax: plt.Axes, optional
        A matplotlib axes instance. When None, ``plt.subplot(111)`` is used.
    kwargs
        Passed into the matplotlib ``imshow`` call. ``interpolation``
        defaults to "none" when not supplied.

    Returns
    -------
    plt.Axes
        The axes the heatmap was drawn on.
    """
    # Fall back to the observed maxima only for the limits not supplied.
    if max_frequency is None:
        max_frequency = int(model.data["frequency"].max())
    if max_recency is None:
        max_recency = int(model.data["recency"].max())

    frequency_grid, recency_grid = _create_frequency_recency_meshes(
        max_frequency=max_frequency,
        max_recency=max_recency,
    )
    grid_shape = recency_grid.shape

    # One row per grid cell; customer_id is only a placeholder index.
    synthetic_customers = pd.DataFrame(
        {
            "customer_id": np.arange(recency_grid.size),
            "frequency": frequency_grid.ravel(),
            "recency": recency_grid.ravel(),
            "T": max_recency,
        }
    )

    # Posterior-mean probability per cell, folded back into the grid layout.
    alive_probability = model.expected_probability_alive(data=synthetic_customers)
    heat_values = alive_probability.mean(("draw", "chain")).values.reshape(grid_shape)

    # Pull interpolation out of kwargs so it is not passed to imshow twice.
    interpolation = kwargs.pop("interpolation", "none")

    if ax is None:
        ax = plt.subplot(111)

    image = ax.imshow(heat_values, interpolation=interpolation, **kwargs)

    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
    force_aspect(ax)

    # Colorbar is attached beside the matrix.
    plt.colorbar(image, ax=ax)

    return ax
[docs]
def force_aspect(ax: plt.Axes, aspect=1):
    """Set the aspect of *ax* from the extent of its first image.

    The extent of the first image returned by ``ax.get_images()`` is read,
    and the axes aspect is set to ``abs(width / height) / aspect``, where
    width and height are the extent's horizontal and vertical spans.

    Parameters
    ----------
    ax : plt.Axes
        Axes that already contains at least one image (e.g. from ``imshow``).
    aspect : float, optional
        Divisor applied to the width/height ratio. Default is 1.
    """
    first_image = ax.get_images()[0]
    x_start, x_end, y_start, y_end = first_image.get_extent()
    # abs(): a span can be negative depending on the image's orientation.
    width_over_height = abs((x_end - x_start) / (y_end - y_start))
    ax.set_aspect(width_over_height / aspect)