# Source code for pymc_marketing.mmm.time_slice_cross_validation

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Time-slice cross-validation utilities for PyMC-Marketing MMM.

This module provides ``TimeSliceCrossValidator``, which runs rolling
(expanding-window) time-slice cross-validation for media-mix models built
with the library. The validator does not retain a fitted MMM instance; a
fresh model is built for every fold, either from a YAML configuration or
from a builder object passed to ``run()``.
"""

from collections.abc import Generator
from dataclasses import dataclass
from typing import Any

import arviz as az
import numpy as np
import pandas as pd
import xarray as xr
from tqdm.auto import tqdm

from pymc_marketing.mmm.builders.yaml import build_mmm_from_yaml
from pymc_marketing.mmm.plot import MMMPlotSuite
from pymc_marketing.mmm.types import MMMBuilder


@dataclass
class TimeSliceCrossValidationResult:
    """Outcome of a single fold of time-slice cross-validation.

    One instance bundles the train/test partitions used for a fold together
    with the InferenceData of the model fitted on that fold's training data.

    Attributes
    ----------
    X_train : pd.DataFrame
        Features the fold's model was fitted on.
    y_train : pd.Series
        Target values the fold's model was fitted on.
    X_test : pd.DataFrame
        Held-out features evaluated for this fold.
    y_test : pd.Series
        Held-out target values evaluated for this fold.
    idata : az.InferenceData
        ArviZ InferenceData produced by the fold's fitted model. When
        populated by ``TimeSliceCrossValidator`` it also carries posterior
        predictive samples for the combined train + test rows.
    """

    # Training partition of the fold.
    X_train: pd.DataFrame
    y_train: pd.Series
    # Held-out partition of the fold.
    X_test: pd.DataFrame
    y_test: pd.Series
    # Model output for the fold.
    idata: az.InferenceData
class TimeSliceCrossValidator:
    """Time-Slice Cross Validator for Media Mix Models (MMM).

    Provides a scikit-learn-style API for performing rolling time-slice
    cross-validation on media mix models. This is useful for evaluating
    model stability and out-of-sample prediction performance.

    Folds use an expanding training window. Let ``udates`` be the sorted
    unique values of ``date_column``. Fold ``i`` trains on every row dated
    before ``udates[n_init + i * step_size]`` and tests on the next
    ``forecast_horizon`` unique dates. All rows sharing a date (e.g. several
    geos) always land in the same partition.

    Parameters
    ----------
    n_init : int
        Number of initial time periods (unique dates) to use for the first
        training fold. Must be a positive integer (Python or NumPy integer).
    forecast_horizon : int
        Number of time periods (unique dates) to forecast in each fold. Must
        be a positive integer (Python or NumPy integer).
    date_column : str
        Name of the column in X containing date values.
    step_size : int, optional
        Number of time periods to step forward between consecutive folds.
        Default is 1. Must be a positive integer (Python or NumPy integer).
    sampler_config : dict, optional
        Configuration dictionary for the PyMC sampler. Can include keys like
        'tune', 'draws', 'chains', 'random_seed', 'target_accept', etc.
        Can be overridden per-run via the ``run()`` method.

    Attributes
    ----------
    n_init : int
        Number of initial training periods.
    forecast_horizon : int
        Number of forecast periods per fold.
    date_column : str
        Name of the date column.
    step_size : int
        Step size between folds.
    sampler_config : dict or None
        Sampler configuration dictionary.

    See Also
    --------
    pymc_marketing.mmm.MMM : The Media Mix Model class.
    pymc_marketing.mmm.plot.MMMPlotSuite : Plotting utilities for CV results.

    Notes
    -----
    This validator does not retain a fitted MMM instance. A fresh model is
    built for each fold, either from a YAML configuration or from a builder
    object passed to ``run()``. Each fold stores its full InferenceData,
    which can consume significant memory for large models with many folds.

    Examples
    --------
    Basic usage with a YAML configuration:

    >>> cv = TimeSliceCrossValidator(
    ...     n_init=100,
    ...     forecast_horizon=10,
    ...     date_column="date",
    ...     step_size=5,
    ... )
    >>> combined_idata = cv.run(X, y, yaml_path="model_config.yml")

    With custom sampler configuration:

    >>> cv = TimeSliceCrossValidator(
    ...     n_init=158,
    ...     forecast_horizon=10,
    ...     date_column="date",
    ...     step_size=50,
    ...     sampler_config={
    ...         "tune": 500,
    ...         "draws": 200,
    ...         "chains": 4,
    ...         "random_seed": 123,
    ...     },
    ... )
    >>> combined_idata = cv.run(X, y, mmm=mmm_builder)
    """

    # Raised by run() when it has no way to build a model for each fold.
    _MISSING_MODEL_MSG = (
        "Either provide an `mmm` instance to run(...) or a `yaml_path` to "
        "build the model per-fold."
    )

    def __init__(
        self,
        n_init: int,
        forecast_horizon: int,
        date_column: str,
        step_size: int = 1,
        sampler_config: dict[str, Any] | None = None,
    ) -> None:
        # Validate in a fixed order so the first offending argument is the
        # one reported. NumPy integers (e.g. ``np.int64``) are accepted
        # alongside Python ints and normalised to ``int`` below.
        for arg_name, value in (
            ("step_size", step_size),
            ("n_init", n_init),
            ("forecast_horizon", forecast_horizon),
        ):
            if not isinstance(value, (int, np.integer)) or value <= 0:
                raise ValueError(f"{arg_name} must be a positive integer")

        self.n_init = int(n_init)
        self.forecast_horizon = int(forecast_horizon)
        self.date_column = date_column
        self.step_size = int(step_size)
        # Optional sampler configuration applied to each fold's MMM before
        # fitting; ``run(sampler_config=...)`` overrides it for one run.
        self.sampler_config = sampler_config

    @property
    def plot(self) -> MMMPlotSuite:
        """Use the MMMPlotSuite to plot the results.

        The suite is built on the InferenceData of the last CV fold.

        Raises
        ------
        ValueError
            If ``run()`` has not produced results yet, or no InferenceData
            is available on the validator.
        """
        self._validate_model_was_built()
        self._validate_idata_exists()
        return MMMPlotSuite(idata=self.idata)

    def _validate_model_was_built(self) -> None:
        """Validate that at least one CV run has produced results.

        Ensures ``self._cv_results`` exists and is non-empty. If the last
        result carries an InferenceData, it is exposed as ``self.idata``
        for compatibility with the MMMPlotSuite API.

        Raises
        ------
        ValueError
            If no CV results are stored on the instance.
        """
        cv_results = getattr(self, "_cv_results", None)
        if not cv_results:
            raise ValueError(
                "No CV results available. Run `TimeSliceCrossValidator.run(...)` first."
            )
        last_idata = getattr(cv_results[-1], "idata", None)
        if last_idata is not None:
            # make idata accessible for plotting helpers
            self.idata = last_idata

    def _validate_idata_exists(self) -> None:
        """Validate that ``self.idata`` is present and not None.

        Raises
        ------
        ValueError
            If ``self.idata`` is missing or None.
        """
        if getattr(self, "idata", None) is None:
            raise ValueError(
                "No InferenceData available on the validator. Run `TimeSliceCrossValidator.run(...)` first."
            )

    def _create_metadata(
        self,
        cv_coord: pd.Index,
        results: list[TimeSliceCrossValidationResult] | None = None,
    ) -> xr.Dataset:
        """Build a cv_metadata Dataset that stores per-fold metadata.

        The dataset stores per-fold metadata as Python objects (dicts of
        DataFrames/Series) under a single DataArray named 'metadata',
        indexed by the same 'cv' labels. Consumers can access fold metadata
        via ``cv_idata.cv_metadata.metadata.sel(cv=...)``.

        Parameters
        ----------
        cv_coord : pd.Index
            The coordinate index for the 'cv' dimension. Its length must
            equal the number of fold results.
        results : list of TimeSliceCrossValidationResult, optional
            Fold results to describe. Defaults to ``self._cv_results``.

        Returns
        -------
        xr.Dataset
            Dataset containing per-fold metadata.
        """
        fold_results = self._cv_results if results is None else results
        metadata_list = [
            {
                "X_train": getattr(r, "X_train", None),
                "y_train": getattr(r, "y_train", None),
                "X_test": getattr(r, "X_test", None),
                "y_test": getattr(r, "y_test", None),
            }
            for r in fold_results
        ]
        # Fill an object-dtype array element by element so xarray holds the
        # dicts as opaque Python objects instead of trying to convert them.
        meta_arr = np.empty((len(metadata_list),), dtype=object)
        for i, meta in enumerate(metadata_list):
            meta_arr[i] = meta
        ds_meta = xr.Dataset(
            {"metadata": ("cv", meta_arr)},
            coords={"cv": cv_coord},
        )
        # persist on instance for convenience
        self.cv_metadata = metadata_list
        return ds_meta

    @staticmethod
    def _discover_groups(idata: az.InferenceData) -> list[str]:
        """Return the group names to combine, read from ``idata``.

        Falls back to a list of common groups when ``idata`` does not expose
        ArviZ's ``_groups`` attribute.
        """
        try:
            return list(idata._groups)
        except (AttributeError, TypeError):
            return [
                "posterior",
                "posterior_predictive",
                "observed_data",
                "sample_stats",
                "prior",
            ]

    @staticmethod
    def _concat_along_cv(pairs: list[tuple[str, xr.Dataset]]) -> xr.Dataset:
        """Concatenate ``(label, dataset)`` pairs along a new 'cv' dimension."""
        labels = [label for label, _ in pairs]
        datasets = [ds for _, ds in pairs]
        try:
            return xr.concat(datasets, dim=pd.Index(labels, name="cv"))
        except Exception:
            # Best-effort fallback: give every dataset an explicit length-1
            # 'cv' dimension with its own label, then concatenate on it.
            return xr.concat(
                [ds.assign_coords({"cv": [label]}) for label, ds in pairs],
                dim="cv",
            )

    def _combine_idata(
        self,
        results: list[TimeSliceCrossValidationResult],
        model_names: list[str],
    ) -> az.InferenceData:
        """Combine InferenceData objects from multiple CV results.

        Each group is concatenated along a new 'cv' coordinate. Labels are
        attached to folds before any filtering, so a fold that has no
        InferenceData (or lacks a given group) is skipped without shifting
        the labels of the folds after it. Also stores the result as
        ``self.cv_idata`` and, when available, the last fold's idata as
        ``self.idata``.

        Parameters
        ----------
        results : list of TimeSliceCrossValidationResult
            List of CV results from each fold.
        model_names : list of str
            Names for each CV fold; must be the same length as ``results``.

        Returns
        -------
        az.InferenceData
            Combined InferenceData with folds concatenated along the 'cv'
            coordinate, plus a 'cv_metadata' group.

        Raises
        ------
        ValueError
            If no InferenceData group could be combined (e.g. no fold
            produced idata), or if ``model_names`` and ``results`` differ in
            length.
        """
        labels = [str(name) for name in model_names]
        labelled_idatas = [
            (label, result.idata)
            for label, result in zip(labels, results, strict=True)
            if result.idata is not None
        ]

        combined_kwargs: dict[str, xr.Dataset] = {}
        if labelled_idatas:
            for group in self._discover_groups(labelled_idatas[0][1]):
                pairs: list[tuple[str, xr.Dataset]] = []
                for label, idata in labelled_idatas:
                    try:
                        ds = idata[group]
                    except KeyError:
                        # this fold does not have the group
                        continue
                    if ds is not None:
                        pairs.append((label, ds))
                if pairs:
                    combined_kwargs[group] = self._concat_along_cv(pairs)

        # Check before adding cv_metadata: metadata alone does not count as
        # a successful combination.
        if not combined_kwargs:
            raise ValueError(
                "No InferenceData objects were produced during CV; ensure models produce idata."
            )

        cv_coord = pd.Index(labels, name="cv")
        combined_kwargs["cv_metadata"] = self._create_metadata(cv_coord, results)

        cv_idata = az.InferenceData(**combined_kwargs)
        # persist for plot helpers
        self.cv_idata = cv_idata
        # Also expose the last fold's idata for compatibility with MMMPlotSuite
        last_idata = getattr(results[-1], "idata", None)
        if last_idata is not None:
            self.idata = last_idata
        return cv_idata

    def _fit_mmm(
        self,
        mmm: Any,
        X: pd.DataFrame,
        y: pd.Series,
        sampler_config: dict[str, Any] | None = None,
    ) -> Any:
        """Fit the MMM model.

        Parameters
        ----------
        mmm : object
            MMM instance to fit.
        X : pd.DataFrame
            Feature matrix.
        y : pd.Series
            Target variable.
        sampler_config : dict, optional
            Sampler configuration to apply before fitting. Takes precedence
            over ``self.sampler_config``.

        Returns
        -------
        object
            The fitted MMM instance (the same object that was passed in).
        """
        effective_sampler_config = (
            sampler_config if sampler_config is not None else self.sampler_config
        )
        if effective_sampler_config is not None:
            # Give each model its own copy so that one fold's model cannot
            # mutate the configuration seen by later folds or by the validator.
            mmm.sampler_config = dict(effective_sampler_config)
        mmm.fit(X, y, progressbar=True)
        return mmm

    def _time_slice_step(
        self,
        mmm: Any,
        X_train: pd.DataFrame,
        y_train: pd.Series,
        X_test: pd.DataFrame,
        y_test: pd.Series,
        sampler_config: dict[str, Any] | None = None,
    ) -> TimeSliceCrossValidationResult:
        """Run one CV step and return results.

        Fits ``mmm`` on the training fold, then samples the posterior
        predictive for the concatenated train + test features into the
        model's idata.

        Parameters
        ----------
        mmm : object
            MMM instance to fit.
        X_train : pd.DataFrame
            Training feature matrix.
        y_train : pd.Series
            Training target variable.
        X_test : pd.DataFrame
            Test feature matrix.
        y_test : pd.Series
            Test target variable.
        sampler_config : dict, optional
            Sampler configuration to apply before fitting.

        Returns
        -------
        TimeSliceCrossValidationResult
            Results container with fitted model data and predictions.
        """
        mmm = self._fit_mmm(mmm, X_train, y_train, sampler_config=sampler_config)

        X_combined = pd.concat([X_train, X_test], ignore_index=True)

        # Drop any existing posterior predictive groups so extending idata
        # with the new predictions does not conflict with them.
        if mmm.idata is not None:
            if "posterior_predictive" in mmm.idata.groups():
                del mmm.idata.posterior_predictive
            if "posterior_predictive_constant_data" in mmm.idata.groups():
                del mmm.idata.posterior_predictive_constant_data

        mmm.sample_posterior_predictive(
            X=X_combined,
            include_last_observations=False,
            extend_idata=True,
            progressbar=False,
        )

        return TimeSliceCrossValidationResult(
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            y_test=y_test,
            idata=mmm.idata,
        )

    def _unique_dates(self, X: pd.DataFrame) -> np.ndarray:
        """Return the sorted unique parsed dates of ``X[self.date_column]``.

        Shared by ``get_n_splits`` and ``split`` so that counting and
        indexing always agree on what a distinct date is.
        """
        return np.unique(pd.to_datetime(X[self.date_column].to_numpy()))

    def _count_splits(self, total_dates: int) -> int:
        """Return how many folds fit into ``total_dates`` unique dates.

        Split ``i`` needs ``i * step_size + n_init + forecast_horizon``
        dates, so the count is
        ``floor((total_dates - n_init - forecast_horizon) / step_size) + 1``,
        clipped at 0.
        """
        max_splits = (
            total_dates - self.n_init - self.forecast_horizon
        ) // self.step_size + 1
        return max(0, max_splits)

    def get_n_splits(self, X: pd.DataFrame, y: pd.Series | None = None) -> int:
        """Return the number of possible rolling splits.

        Parameters
        ----------
        X : pd.DataFrame
            Feature matrix containing the date column.
        y : pd.Series, optional
            Target variable. Not used but included for scikit-learn API
            compatibility.

        Returns
        -------
        int
            Number of cross-validation splits that can be generated.
        """
        return self._count_splits(len(self._unique_dates(X)))

    def split(
        self, X: pd.DataFrame, y: pd.Series | None = None
    ) -> Generator[tuple[np.ndarray, np.ndarray], None, None]:
        """Generate train/test indices for each time-slice split.

        This implementation selects rows by date masks, so all coordinate
        levels (e.g. multiple geos) for the selected date ranges are
        included in each fold. It returns integer positions suitable for use
        with ``DataFrame.iloc``.

        Parameters
        ----------
        X : pd.DataFrame
            Feature matrix containing the date column.
        y : pd.Series, optional
            Target variable. Not used but included for scikit-learn API
            compatibility.

        Yields
        ------
        train_idx : np.ndarray
            Integer indices for training rows in this fold.
        test_idx : np.ndarray
            Integer indices for test rows in this fold.

        Raises
        ------
        ValueError
            If no splits are possible with the given parameters. Because
            this is a generator, the error surfaces when iteration starts.

        Examples
        --------
        >>> cv = TimeSliceCrossValidator(
        ...     n_init=10, forecast_horizon=5, date_column="date"
        ... )
        >>> for train_idx, test_idx in cv.split(X, y):
        ...     X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        ...     y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
        """
        udates = self._unique_dates(X)
        n_splits = self._count_splits(len(udates))
        if n_splits <= 0:
            raise ValueError(
                "No splits possible with the given n_init, forecast_horizon and step_size"
            )

        # Parse the date column once; every fold's masks reuse it.
        dates = pd.to_datetime(X[self.date_column])
        for fold in range(n_splits):
            first_test = fold * self.step_size + self.n_init
            start_date = udates[first_test]
            end_date = udates[first_test + self.forecast_horizon - 1]
            # boolean masks selecting rows for train and test ranges (preserve all geos)
            train_mask = dates < start_date
            test_mask = (dates >= start_date) & (dates <= end_date)
            yield (
                np.flatnonzero(train_mask.to_numpy()),
                np.flatnonzero(test_mask.to_numpy()),
            )

    @classmethod
    def _build_fold_model(
        cls,
        X_train: pd.DataFrame,
        y_train: pd.Series,
        yaml_path: str | None,
        mmm: Any,
    ) -> Any:
        """Build the MMM for one training fold (it is fitted later).

        A YAML configuration takes precedence when both sources are given.

        Raises
        ------
        ValueError
            If neither ``yaml_path`` nor ``mmm`` is provided.
        """
        if yaml_path is not None:
            return build_mmm_from_yaml(config_path=yaml_path, X=X_train, y=y_train)
        if mmm is not None:
            # use provided builder (do not store it on self)
            return mmm.build_model(X_train, y_train)
        raise ValueError(cls._MISSING_MODEL_MSG)

    @staticmethod
    def _default_fold_name(fold_idx: int, fold_mmm: Any, mmm: Any) -> str:
        """Derive a fold label from the model's ``_model_name``, if any."""
        base_name = (
            getattr(fold_mmm, "_model_name", None)
            or getattr(mmm, "_model_name", None)
            or "Iteration"
        )
        if base_name == "Iteration":
            return f"Iteration {fold_idx}"
        return f"{base_name}_{fold_idx}"

    def run(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        sampler_config: dict[str, Any] | None = None,
        yaml_path: str | None = None,
        mmm: MMMBuilder | None = None,
        model_names: list[str] | None = None,
    ) -> az.InferenceData:
        """Run the complete time-slice cross-validation loop.

        Executes cross-validation by iterating through all folds, building
        and fitting a model for each training set, and generating posterior
        predictions on the combined train+test data.

        Parameters
        ----------
        X : pd.DataFrame
            Feature matrix containing the date column and predictor
            variables.
        y : pd.Series
            Target variable.
        sampler_config : dict, optional
            Sampler configuration to override the validator-level
            configuration for all folds in this run. If provided, takes
            precedence over the configuration passed at construction time.
        yaml_path : str, optional
            Path to a YAML configuration file for building the MMM model per
            fold. Intended to be mutually exclusive with ``mmm``; if both are
            given, ``yaml_path`` is used.
        mmm : object, optional
            An object with a ``build_model(X, y)`` method that returns an MMM
            instance for the training fold. The validator then fits that
            instance. Intended to be mutually exclusive with ``yaml_path``.
        model_names : list of str, optional
            Names to assign to each CV fold in the combined InferenceData.
            If provided, length must match the number of splits. If not
            provided, names are generated as ``'{_model_name}_{i}'`` from
            each model's ``_model_name`` attribute, or as
            ``'Iteration {i}'``.

        Returns
        -------
        arviz.InferenceData
            Combined InferenceData where each fold is concatenated along a
            new coordinate named 'cv'. Includes a 'cv_metadata' group with
            per-fold train/test data.

        Raises
        ------
        ValueError
            If neither ``yaml_path`` nor ``mmm`` is provided (checked before
            any fold is fitted). If ``model_names`` length doesn't match the
            number of splits. If no splits are possible. If no InferenceData
            objects are produced during CV.

        See Also
        --------
        split : Generate train/test indices for cross-validation.
        get_n_splits : Return the number of splits.

        Notes
        -----
        Per-fold results are also stored in ``self._cv_results`` after
        calling this method.

        Examples
        --------
        Using a YAML configuration:

        >>> cv = TimeSliceCrossValidator(
        ...     n_init=100, forecast_horizon=10, date_column="date"
        ... )
        >>> combined_idata = cv.run(X, y, yaml_path="model_config.yml")

        Using a model builder object:

        >>> cv = TimeSliceCrossValidator(
        ...     n_init=100, forecast_horizon=10, date_column="date"
        ... )
        >>> combined_idata = cv.run(X, y, mmm=mmm_builder)
        """
        n_splits = self.get_n_splits(X, y)
        if model_names is not None and len(model_names) != n_splits:
            raise ValueError(
                f"`model_names` length ({len(model_names)}) must match the number "
                f"of CV splits ({n_splits})."
            )
        # Fail fast, before any (expensive) fold is fitted.
        if yaml_path is None and mmm is None:
            raise ValueError(self._MISSING_MODEL_MSG)

        results: list[TimeSliceCrossValidationResult] = []
        fold_labels: list[str] = []
        folds = tqdm(self.split(X, y), total=n_splits)
        for fold_idx, (train_idx, test_idx) in enumerate(folds):
            X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
            X_test, y_test = X.iloc[test_idx], y.iloc[test_idx]

            fold_mmm = self._build_fold_model(X_train, y_train, yaml_path, mmm)

            if model_names is not None:
                # Length was validated upfront, so direct indexing is safe.
                fold_labels.append(model_names[fold_idx])
            else:
                fold_labels.append(self._default_fold_name(fold_idx, fold_mmm, mmm))

            results.append(
                self._time_slice_step(
                    fold_mmm,
                    X_train,
                    y_train,
                    X_test,
                    y_test,
                    sampler_config=sampler_config,
                )
            )

        # Persist results on the instance so plotting helpers can access them
        self._cv_results = results

        return self._combine_idata(results, fold_labels)