TimeSliceCrossValidator
- class pymc_marketing.mmm.time_slice_cross_validation.TimeSliceCrossValidator(n_init, forecast_horizon, date_column, step_size=1, sampler_config=None)
Time-Slice Cross Validator for Media Mix Models (MMM).
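Provides a scikit-learn-style API for performing rolling time-slice cross-validation on media mix models. Each fold fits the model on a training window and forecasts the forecast_horizon periods that follow it. Comparing the folds shows how stable the fitted model is over time and how well it predicts out of sample.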
- Parameters:
  - n_init : int
    Number of initial time periods to use for the first training fold. Must be a positive integer.
  - forecast_horizon : int
    Number of time periods to forecast in each fold. Must be a positive integer.
  - date_column : str
    Name of the column in X containing date values.
  - step_size : int, optional
    Number of time periods to step forward between consecutive folds. Must be a positive integer. Default is 1.
  - sampler_config : dict, optional
    Configuration dictionary for the PyMC sampler. Can include keys such as "tune", "draws", "chains", "random_seed" and "target_accept". Can be overridden per run via the run() method.
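The four integer parameters determine where each fold's training window ends and its test window begins. The sketch below assumes an expanding training window (fold i trains on the first n_init + i * step_size unique dates and tests on the next forecast_horizon dates) and assumes that folds whose test window would run past the end of the data are dropped. These are assumptions about the splitting rule, not a statement of this class's implementation, and fold_boundaries is a hypothetical helper.

```python
from typing import List, Tuple


def fold_boundaries(
    n_dates: int, n_init: int, forecast_horizon: int, step_size: int = 1
) -> List[Tuple[int, int]]:
    """Return (train_end, test_end) positions into the sorted unique dates.

    Assumed rule: fold i trains on dates [0, train_end) and tests on
    [train_end, test_end), with train_end = n_init + i * step_size.
    """
    # Mirror the documented constraint: each parameter must be a positive integer
    for name, value in (
        ("n_init", n_init),
        ("forecast_horizon", forecast_horizon),
        ("step_size", step_size),
    ):
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")

    boundaries = []
    train_end = n_init
    # Stop once the test window would extend past the last available date
    while train_end + forecast_horizon <= n_dates:
        boundaries.append((train_end, train_end + forecast_horizon))
        train_end += step_size
    return boundaries


# Parameters from the custom sampler example, with a hypothetical 300 unique dates
print(fold_boundaries(n_dates=300, n_init=158, forecast_horizon=10, step_size=50))
# [(158, 168), (208, 218), (258, 268)]
```

Under the same assumptions the fold count is max(0, (n_dates - n_init - forecast_horizon) // step_size + 1), which gives 3 for the numbers above.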
See also
pymc_marketing.mmm.MMM: The Media Mix Model class.
pymc_marketing.mmm.plot.MMMPlotSuite: Plotting utilities for CV results.
Notes
This validator does not retain a fitted MMM instance; a model is built for each fold, either from a YAML configuration or from the model supplied to run(). Each fold stores its full InferenceData, which can consume significant memory for large models with many folds.
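To get a feel for how quickly per-fold InferenceData adds up, the back-of-the-envelope sketch below counts only the posterior group stored as float64 (8 bytes per value). It ignores posterior predictive samples, sample statistics, observed data and container overhead, so real usage will be higher. The figure of 5,000 stored values per draw and the 3 folds are hypothetical; deterministics indexed by date and channel can easily reach that size.

```python
def posterior_bytes(
    chains: int, draws: int, n_values_per_draw: int, bytes_per_value: int = 8
) -> int:
    """Approximate size of one fold's posterior group (float64 by default)."""
    return chains * draws * n_values_per_draw * bytes_per_value


def cv_posterior_bytes(
    n_folds: int, chains: int, draws: int, n_values_per_draw: int
) -> int:
    """Every fold keeps its own InferenceData, so the total scales linearly with n_folds."""
    return n_folds * posterior_bytes(chains, draws, n_values_per_draw)


# chains/draws from the custom sampler example; folds and values per draw are hypothetical
per_fold = posterior_bytes(chains=4, draws=200, n_values_per_draw=5_000)
total = cv_posterior_bytes(n_folds=3, chains=4, draws=200, n_values_per_draw=5_000)
print(per_fold / 1e6, total / 1e6)  # 32.0 96.0 (megabytes)
```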
Examples
Basic usage with a YAML configuration:
>>> from pymc_marketing.mmm.time_slice_cross_validation import TimeSliceCrossValidator
>>> cv = TimeSliceCrossValidator(
...     n_init=100,
...     forecast_horizon=10,
...     date_column="date",
...     step_size=5,
... )
>>> combined_idata = cv.run(X, y, yaml_path="model_config.yml")
With custom sampler configuration:
>>> cv = TimeSliceCrossValidator(
...     n_init=158,
...     forecast_horizon=10,
...     date_column="date",
...     step_size=50,
...     sampler_config={
...         "tune": 500,
...         "draws": 200,
...         "chains": 4,
...         "random_seed": 123,
...     },
... )
>>> combined_idata = cv.run(X, y, mmm=mmm_builder)
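The sampler_config parameter can be overridden per run, but this page does not say how the constructor and run-time dictionaries combine. The sketch below assumes a key-by-key merge in which run-time values win. resolve_sampler_config is a hypothetical helper; if run() instead replaces the constructor's dictionary wholesale, keys such as "tune" would not carry over.

```python
from typing import Any, Dict, Optional


def resolve_sampler_config(
    default: Optional[Dict[str, Any]], override: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Hypothetical helper: per-run keys take precedence over constructor keys."""
    resolved = dict(default or {})  # copy so the constructor config is not mutated
    resolved.update(override or {})
    return resolved


constructor_cfg = {"tune": 500, "draws": 200, "chains": 4, "random_seed": 123}
run_cfg = {"draws": 1000, "target_accept": 0.9}
print(resolve_sampler_config(constructor_cfg, run_cfg))
# {'tune': 500, 'draws': 1000, 'chains': 4, 'random_seed': 123, 'target_accept': 0.9}
```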
Methods
TimeSliceCrossValidator.__init__(n_init, ...)
TimeSliceCrossValidator.get_n_splits(X[, y]): Return the number of possible rolling splits.
TimeSliceCrossValidator.run(X, y[, ...]): Run the complete time-slice cross-validation loop.
TimeSliceCrossValidator.split(X[, y]): Generate train/test indices for each time-slice split.
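The split() method is described only as generating train/test indices for each time-slice split. Below is a minimal standard-library sketch of one way to do that, assuming an expanding training window and assuming folds are cut on the sorted unique dates so that panel rows sharing a date (for example several geos in the same week) stay on the same side of a split. split_indices is a hypothetical stand-in, not the class's method, and the real split() may yield a different index type.

```python
from datetime import date, timedelta
from typing import Iterator, List, Sequence, Tuple


def split_indices(
    dates: Sequence[date],
    n_init: int,
    forecast_horizon: int,
    step_size: int = 1,
) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_rows, test_rows) row positions for each rolling fold."""
    # Guard against a zero or negative step, which would otherwise loop forever
    if min(n_init, forecast_horizon, step_size) < 1:
        raise ValueError("n_init, forecast_horizon and step_size must be positive")
    unique_dates = sorted(set(dates))
    train_end = n_init
    while train_end + forecast_horizon <= len(unique_dates):
        train_dates = set(unique_dates[:train_end])
        test_dates = set(unique_dates[train_end:train_end + forecast_horizon])
        # Map the date windows back to row positions in the original order
        train_rows = [i for i, d in enumerate(dates) if d in train_dates]
        test_rows = [i for i, d in enumerate(dates) if d in test_dates]
        yield train_rows, test_rows
        train_end += step_size


# Hypothetical panel: 6 weekly dates x 2 geos = 12 rows, ordered by week then geo
start = date(2024, 1, 1)
dates = [start + timedelta(weeks=w) for w in range(6) for _geo in ("A", "B")]
for train_rows, test_rows in split_indices(dates, n_init=3, forecast_horizon=2, step_size=1):
    print(train_rows, test_rows)
# [0, 1, 2, 3, 4, 5] [6, 7, 8, 9]
# [0, 1, 2, 3, 4, 5, 6, 7] [8, 9, 10, 11]
```

With a pandas DataFrame, the date values could come from X[date_column].tolist(), and the yielded positions could then be passed to X.iloc.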
Attributes
plot: Use the MMMPlotSuite to plot the results.