TimeSliceCrossValidator

class pymc_marketing.mmm.time_slice_cross_validation.TimeSliceCrossValidator(n_init, forecast_horizon, date_column, step_size=1, sampler_config=None)

Time-Slice Cross Validator for Media Mix Models (MMM).

Provides a scikit-learn-style API for performing rolling time-slice cross-validation on media mix models. This is useful for evaluating model stability and out-of-sample prediction performance.

Parameters:
n_init : int

Number of initial time periods to use for the first training fold. Must be a positive integer.

forecast_horizon : int

Number of time periods to forecast in each fold. Must be a positive integer.

date_column : str

Name of the column in X containing date values.

step_size : int, optional

Number of time periods to step forward between consecutive folds. Default is 1. Must be a positive integer.

sampler_config : dict, optional

Configuration dictionary for the PyMC sampler. Can include keys like ‘tune’, ‘draws’, ‘chains’, ‘random_seed’, ‘target_accept’, etc. Can be overridden per-run via the run() method.
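To make the interplay of n_init, forecast_horizon, and step_size concrete, the sketch below enumerates fold boundaries under an assumed expanding-window scheme: fold k trains on the first n_init + k * step_size periods and forecasts the next forecast_horizon periods. This is an illustrative assumption, not the library's implementation (the real split() may, for instance, use a sliding rather than an expanding training window), and fold_boundaries is a hypothetical helper.

```python
def fold_boundaries(n_periods, n_init, forecast_horizon, step_size=1):
    """Return (train_end, test_end) period indices for each fold.

    Hypothetical helper that ASSUMES an expanding training window:
    fold k trains on periods [0, train_end) and tests on [train_end, test_end).
    """
    # Mirror the documented constraint: all three values must be positive integers.
    for name, value in [
        ("n_init", n_init),
        ("forecast_horizon", forecast_horizon),
        ("step_size", step_size),
    ]:
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")

    folds = []
    train_end = n_init
    # Stop once the forecast window would run past the last available period.
    while train_end + forecast_horizon <= n_periods:
        folds.append((train_end, train_end + forecast_horizon))
        train_end += step_size
    return folds


# Hypothetical series of 130 periods, using the first example's
# n_init=100, forecast_horizon=10, step_size=5.
print(fold_boundaries(130, n_init=100, forecast_horizon=10, step_size=5))
# [(100, 110), (105, 115), (110, 120), (115, 125), (120, 130)]
```

Under this assumption, increasing step_size thins out the folds (fewer, more widely spaced refits), which is the main lever for controlling total sampling cost.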

Attributes:
n_init : int

Number of initial training periods.

forecast_horizon : int

Number of forecast periods per fold.

date_column : str

Name of the date column.

step_size : int

Step size between folds.

sampler_config : dict or None

Sampler configuration dictionary.

See also

pymc_marketing.mmm.MMM

The Media Mix Model class.

pymc_marketing.mmm.plot.MMMPlotSuite

Plotting utilities for CV results.

Notes

This validator does not retain a fitted MMM instance; models are either built for each fold from a YAML configuration or supplied to run().

Each fold stores its full InferenceData, which can consume significant memory for large models with many folds.
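As a rough, back-of-the-envelope guide (an assumption, not a measurement of actual InferenceData overhead), the posterior group alone holds about chains × draws × n_values × 8 bytes per fold for float64 draws, where n_values is the total number of stored scalar entries across all variables, including deterministics such as per-date contributions. The posterior_bytes helper and the model size below are hypothetical.

```python
def posterior_bytes(chains, draws, n_values, n_folds=1, bytes_per_value=8):
    """Approximate size of float64 posterior draws summed over all folds.

    Ignores other InferenceData groups (sample_stats, posterior_predictive,
    observed_data, ...) and container overhead, so the real footprint is larger.
    """
    return chains * draws * n_values * bytes_per_value * n_folds


# Hypothetical model storing 50,000 scalar values per draw, sampled with the
# second example's chains=4 and draws=200, over 5 folds.
size = posterior_bytes(chains=4, draws=200, n_values=50_000, n_folds=5)
print(f"{size / 1e9:.1f} GB")  # 1.6 GB
```

Because the estimate scales linearly in the number of folds, a larger step_size or fewer draws are the most direct ways to keep memory in check.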

Examples

Basic usage with a YAML configuration:

>>> from pymc_marketing.mmm.time_slice_cross_validation import TimeSliceCrossValidator
>>> cv = TimeSliceCrossValidator(
...     n_init=100,
...     forecast_horizon=10,
...     date_column="date",
...     step_size=5,
... )
>>> combined_idata = cv.run(X, y, yaml_path="model_config.yml")

With custom sampler configuration:

>>> cv = TimeSliceCrossValidator(
...     n_init=158,
...     forecast_horizon=10,
...     date_column="date",
...     step_size=50,
...     sampler_config={
...         "tune": 500,
...         "draws": 200,
...         "chains": 4,
...         "random_seed": 123,
...     },
... )
>>> combined_idata = cv.run(X, y, mmm=mmm_builder)

Methods

TimeSliceCrossValidator.__init__(n_init, ...)

TimeSliceCrossValidator.get_n_splits(X[, y])

Return the number of possible rolling splits.

TimeSliceCrossValidator.run(X, y[, ...])

Run the complete time-slice cross-validation loop.

TimeSliceCrossValidator.split(X[, y])

Generate train/test indices for each time-slice split.
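Because folds are defined over time periods taken from date_column, a splitter in this scikit-learn style has to map periods back to row positions, which matters when several rows share a date (for example, one row per geography). The class below is a hypothetical, simplified sketch of a get_n_splits/split pair under the same expanding-window assumption; it is not the library's code, and its fold count may differ from the real get_n_splits.

```python
import numpy as np
import pandas as pd


class ExpandingDateSplitter:
    """Hypothetical scikit-learn-style splitter over unique dates (expanding window)."""

    def __init__(self, n_init, forecast_horizon, date_column, step_size=1):
        self.n_init = n_init
        self.forecast_horizon = forecast_horizon
        self.date_column = date_column
        self.step_size = step_size

    def _unique_dates(self, X):
        # np.unique returns the distinct dates already sorted; each is one time period.
        return np.unique(pd.to_datetime(X[self.date_column]).to_numpy())

    def get_n_splits(self, X, y=None):
        n_periods = len(self._unique_dates(X))
        available = n_periods - self.n_init - self.forecast_horizon
        return 0 if available < 0 else available // self.step_size + 1

    def split(self, X, y=None):
        dates = pd.to_datetime(X[self.date_column]).to_numpy()
        unique_dates = self._unique_dates(X)
        for k in range(self.get_n_splits(X)):
            train_end = self.n_init + k * self.step_size
            test_end = train_end + self.forecast_horizon
            # Map each fold's date sets back to positional row indices of X.
            train_idx = np.flatnonzero(np.isin(dates, unique_dates[:train_end]))
            test_idx = np.flatnonzero(np.isin(dates, unique_dates[train_end:test_end]))
            yield train_idx, test_idx


# Toy panel: 8 weekly dates, two geographies per date (16 rows).
X = pd.DataFrame(
    {
        "date": pd.date_range("2024-01-07", periods=8, freq="W").repeat(2),
        "geo": ["A", "B"] * 8,
    }
)
splitter = ExpandingDateSplitter(n_init=4, forecast_horizon=2, date_column="date", step_size=2)
print(splitter.get_n_splits(X))  # 2
for train_idx, test_idx in splitter.split(X):
    print(len(train_idx), len(test_idx))  # 8 4, then 12 4
```

Splitting on unique dates rather than raw row positions keeps all rows for a given date in the same side of each fold, so no geography "sees" a test-period date during training.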

Attributes

plot

Use the MMMPlotSuite to plot the results.