# Source code for causalpy.experiments.piecewise_its

#   Copyright 2022 - 2026 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""
Piecewise Interrupted Time Series Analysis (Segmented Regression)
"""

import re
from typing import Any

import arviz as az
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib import pyplot as plt
from patsy import dmatrices
from sklearn.base import RegressorMixin

from causalpy.custom_exceptions import FormulaException
from causalpy.plot_utils import plot_xY
from causalpy.pymc_models import PyMCModel
from causalpy.transforms import ramp, step  # noqa: F401
from causalpy.utils import round_num

from .base import BaseExperiment

# Font size applied to the legends drawn by the experiment's plotting methods.
LEGEND_FONT_SIZE: int = 12


class PiecewiseITS(BaseExperiment):
    """
    Piecewise Interrupted Time Series (Segmented Regression) experiment.

    This class implements segmented-regression / piecewise linear models for
    Interrupted Time Series analysis with **known** interruption dates. Unlike the
    standard :class:`InterruptedTimeSeries`, which fits a model to pre-intervention
    data and forecasts a counterfactual, `PiecewiseITS` fits **one model to the full
    time series** and estimates explicit level and/or slope changes at each
    interruption.

    The model uses patsy formulas with custom `step()` and `ramp()` transforms:

    - ``step(time, threshold)``: Creates a binary indicator (1 if time >= threshold)
      for level changes
    - ``ramp(time, threshold)``: Creates a ramp function (max(0, time - threshold))
      for slope changes

    Parameters
    ----------
    data : pd.DataFrame
        A pandas DataFrame containing the time series data.
    formula : str
        A patsy formula specifying the model. Must include at least one ``step()``
        or ``ramp()`` term. Example: ``"y ~ 1 + t + step(t, 50) + ramp(t, 50)"``
    model : PyMCModel or RegressorMixin, optional
        A PyMC (Bayesian) or sklearn (OLS) model. If None, defaults to a PyMC
        LinearRegression model.
    **kwargs
        Accepted for backward compatibility; currently not used.

    Attributes
    ----------
    formula : str
        The patsy formula used for the model.
    interruption_times : list
        Unique interruption times extracted from the formula, in order of first
        appearance in the formula.
    time_col : str
        Name of the time column (the first variable passed to ``step()`` /
        ``ramp()``), falling back to ``"t"`` if none can be found.
    labels : list[str]
        Names of all coefficients in the design matrix.
    effect : xr.DataArray or np.ndarray
        Pointwise causal effect (fitted - counterfactual).
    cumulative_effect : xr.DataArray or np.ndarray
        Cumulative causal effect over time.
    datapost, post_impact, post_pred
        Rows, effects and counterfactual predictions at or after the earliest
        interruption time, for compatibility with ``effect_summary()``.

    Examples
    --------
    >>> import causalpy as cp
    >>> import pandas as pd
    >>> import numpy as np
    >>> # Generate simple piecewise data
    >>> np.random.seed(42)
    >>> t = np.arange(100)
    >>> y = (
    ...     10
    ...     + 0.1 * t
    ...     + 5 * (t >= 50)
    ...     + 0.2 * np.maximum(0, t - 50)
    ...     + np.random.normal(0, 1, 100)
    ... )
    >>> df = pd.DataFrame({"t": t, "y": y})
    >>> result = cp.PiecewiseITS(
    ...     df,
    ...     formula="y ~ 1 + t + step(t, 50) + ramp(t, 50)",
    ...     model=cp.pymc_models.LinearRegression(
    ...         sample_kwargs={"random_seed": 42, "progressbar": False}
    ...     ),
    ... )

    **Different effects per intervention:**

    >>> # Level change only at t=50, level + slope change at t=100
    >>> result = cp.PiecewiseITS(
    ...     df,
    ...     formula="y ~ 1 + t + step(t, 50) + step(t, 100) + ramp(t, 100)",
    ...     model=...,
    ... )  # doctest: +SKIP

    **With datetime thresholds:**

    >>> df["date"] = pd.date_range("2020-01-01", periods=100, freq="D")
    >>> result = cp.PiecewiseITS(
    ...     df,
    ...     formula="y ~ 1 + date + step(date, '2020-02-20') + ramp(date, '2020-02-20')",
    ...     model=...,
    ... )  # doctest: +SKIP

    Notes
    -----
    The counterfactual is computed by setting all step/ramp terms to zero,
    representing what would have happened without the interventions.

    The `step` and `ramp` transforms are patsy stateful transforms that handle both
    numeric and datetime time columns. For datetime, thresholds can be specified as
    strings (e.g., '2020-06-01') or pd.Timestamp objects.

    References
    ----------
    - Wagner AK, et al. (2002). Segmented regression analysis of interrupted time
      series studies in medication use research. Journal of Clinical Pharmacy and
      Therapeutics.
    - Lopez Bernal J, et al. (2017). Interrupted time series regression for the
      evaluation of public health interventions: a tutorial. Int J Epidemiol.
    """

    expt_type = "Piecewise Interrupted Time Series"
    supports_ols = True
    supports_bayes = True

    # A call to one of the transforms, e.g. "step(" or "ramp (" — the word boundary
    # keeps names such as "mystep(" from being mistaken for the transforms.
    _TRANSFORM_NAME = re.compile(r"\b(?:step|ramp)\s*\(")
    # The head of a transform call up to the first comma; group 1 is the time
    # variable. The threshold argument is extracted separately so that nested
    # parentheses (e.g. pd.Timestamp('2020-01-01')) are handled.
    _TRANSFORM_CALL = re.compile(r"\b(?:step|ramp)\s*\(\s*(\w+)\s*,")
    # A threshold written as pd.Timestamp('...') / Timestamp("...").
    _TIMESTAMP_CALL = re.compile(
        r"(?:pd\.|pandas\.)?Timestamp\(\s*(['\"])(.*?)\1\s*\)"
    )
    # Exceptions raised when a string cannot be parsed/compared as a timestamp.
    _TIMESTAMP_ERRORS = (ValueError, TypeError, OverflowError)

    def __init__(
        self,
        data: pd.DataFrame,
        formula: str,
        model: PyMCModel | RegressorMixin | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model=model)

        # Store configuration
        self.formula = formula
        self.data = data.copy()
        # Rename the index to "obs_ind" for consistency
        self.data.index.name = "obs_ind"

        # Fail fast, before any formula parsing or model fitting
        self._validate_inputs()

        # Extract interruption times from formula for plotting
        self.interruption_times = self._extract_interruption_times()
        # Detect time column from step/ramp usage
        self.time_col = self._extract_time_column()

        # patsy resolves `step` and `ramp` from the calling frame's namespace,
        # which is why they are imported at module level (noqa: F401). The call
        # is kept directly in __init__ so that frame lookup stays unchanged.
        # NOTE(review): patsy drops rows containing missing values by default,
        # which would misalign X/y with self.data (used for time values and the
        # post-period mask) — confirm inputs are NaN-free.
        y, X = dmatrices(formula, self.data)
        self.outcome_variable_name = y.design_info.column_names[0]
        self._y_design_info = y.design_info
        self._x_design_info = X.design_info
        self.labels = list(X.design_info.column_names)

        # Convert to numpy arrays, then wrap in labelled DataArrays
        y_array = np.asarray(y)
        X_array = np.asarray(X)
        n_obs = X_array.shape[0]

        self.X = xr.DataArray(
            X_array,
            dims=["obs_ind", "coeffs"],
            coords={"obs_ind": np.arange(n_obs), "coeffs": self.labels},
        )
        self.y = xr.DataArray(
            y_array,
            dims=["obs_ind", "treated_units"],
            coords={"obs_ind": np.arange(n_obs), "treated_units": ["unit_0"]},
        )

        # Track which columns are interruption-related (for counterfactual)
        self._interruption_cols = self._get_interruption_column_indices()

        # Fit the model to the full time series
        if isinstance(self.model, PyMCModel):
            coords: dict[str, Any] = {
                "coeffs": self.labels,
                "obs_ind": np.arange(self.X.shape[0]),
                "treated_units": ["unit_0"],
            }
            self.model.fit(X=self.X, y=self.y, coords=coords)
        elif isinstance(self.model, RegressorMixin):
            # The patsy design matrix already carries the intercept column (unless
            # the formula removes it), so the estimator must not add another.
            if hasattr(self.model, "fit_intercept"):
                self.model.fit_intercept = False
            self.model.fit(X=self.X, y=self.y.isel(treated_units=0))
        else:
            raise ValueError("Model type not recognized")

        # Compute predictions (fitted values)
        self.y_pred = self.model.predict(X=self.X)

        # Score the model fit
        if isinstance(self.model, PyMCModel):
            self.score = self.model.score(X=self.X, y=self.y)
        elif isinstance(self.model, RegressorMixin):
            self.score = self.model.score(X=self.X, y=self.y.isel(treated_units=0))

        # Compute counterfactual and effects
        self._compute_counterfactual_and_effects()

    # ------------------------------------------------------------------
    # Formula parsing
    # ------------------------------------------------------------------

    def _validate_inputs(self) -> None:
        """Validate the formula.

        Raises
        ------
        FormulaException
            If the formula contains no ``step()`` or ``ramp()`` term.
        """
        if not self._TRANSFORM_NAME.search(self.formula):
            raise FormulaException(
                "Formula must contain at least one step() or ramp() term. "
                "Example: 'y ~ 1 + t + step(t, 50) + ramp(t, 50)'"
            )

    @classmethod
    def _threshold_expressions(cls, formula: str) -> list[str]:
        """Return the raw threshold argument of every step()/ramp() call.

        Parentheses are balanced, so ``step(t, pd.Timestamp('2020-01-01'))``
        yields ``"pd.Timestamp('2020-01-01')"``. Calls without a closing
        parenthesis are ignored.
        """
        expressions: list[str] = []
        for call in cls._TRANSFORM_CALL.finditer(formula):
            start = call.end()
            depth = 1  # we are inside the transform's opening parenthesis
            for pos in range(start, len(formula)):
                char = formula[pos]
                if char == "(":
                    depth += 1
                elif char == ")":
                    depth -= 1
                    if depth == 0:
                        expressions.append(formula[start:pos].strip())
                        break
        return expressions

    @classmethod
    def _parse_threshold(cls, expression: str) -> int | float | str:
        """Convert a threshold expression to int, float, or a string.

        Integers and floats (including scientific notation such as ``1e3``) are
        returned as numbers. ``pd.Timestamp('...')`` calls and quoted literals are
        returned as the bare string. Anything else is returned with surrounding
        quotes stripped.
        """
        for convert in (int, float):
            try:
                return convert(expression)
            except ValueError:
                pass
        timestamp_call = cls._TIMESTAMP_CALL.fullmatch(expression)
        if timestamp_call:
            return timestamp_call.group(2)
        return expression.strip("'\"")

    def _extract_interruption_times(self) -> list[int | float | str]:
        """Extract interruption times from step() and ramp() calls in the formula.

        Returns
        -------
        list[int | float | str]
            Unique threshold values, in order of first appearance in the formula.
        """
        thresholds: list[int | float | str] = []
        for expression in self._threshold_expressions(self.formula):
            value = self._parse_threshold(expression)
            if value not in thresholds:
                thresholds.append(value)
        return thresholds

    def _extract_time_column(self) -> str:
        """Return the time variable of the first step()/ramp() call, or ``"t"``."""
        call = self._TRANSFORM_CALL.search(self.formula)
        return call.group(1) if call else "t"

    def _get_interruption_column_indices(self) -> list[int]:
        """Return indices of design-matrix columns built from step()/ramp() terms."""
        # Patsy labels such terms like "step(t, 50)" or "ramp(t, 50)"
        return [
            i
            for i, label in enumerate(self.labels)
            if self._TRANSFORM_NAME.search(label)
        ]

    # ------------------------------------------------------------------
    # Effects
    # ------------------------------------------------------------------

    @staticmethod
    def _posterior_mu(idata: Any) -> xr.DataArray:
        """Return ``posterior_predictive.mu``, dropping a ``treated_units`` dim."""
        mu = idata["posterior_predictive"]["mu"]
        if "treated_units" in mu.dims:
            mu = mu.isel(treated_units=0)
        return mu

    def _compute_counterfactual_and_effects(self) -> None:
        """
        Compute the counterfactual (no intervention) and causal effects.

        The counterfactual is computed by setting step/ramp terms to zero. The
        effect is fitted minus counterfactual. Also creates post_impact, datapost,
        and post_pred attributes for compatibility with effect_summary() from
        BaseExperiment.
        """
        # Design matrix for the counterfactual: zero out interruption columns
        X_cf = self.X.copy()
        for idx in self._interruption_cols:
            X_cf[:, idx] = 0

        if isinstance(self.model, PyMCModel):
            self.y_counterfactual = self.model.predict(X=X_cf)
            self.effect = self._posterior_mu(self.y_pred) - self._posterior_mu(
                self.y_counterfactual
            )
            self.cumulative_effect = self.effect.cumsum(dim="obs_ind")
        elif isinstance(self.model, RegressorMixin):
            self.y_counterfactual = self.model.predict(X=X_cf)
            self.effect = np.squeeze(self.y_pred) - np.squeeze(self.y_counterfactual)
            self.cumulative_effect = np.cumsum(self.effect)

        self._create_post_intervention_attributes()

    @classmethod
    def _at_or_after(
        cls, time_values: pd.Series, threshold: int | float | str
    ) -> pd.Series:
        """Return the boolean mask ``time_values >= threshold``.

        String thresholds are first parsed with ``pd.Timestamp``. If parsing or
        the comparison fails, the raw string is compared directly.
        """
        if isinstance(threshold, str):
            try:
                return time_values >= pd.Timestamp(threshold)
            except cls._TIMESTAMP_ERRORS:
                return time_values >= threshold
        return time_values >= threshold

    def _create_post_intervention_attributes(self) -> None:
        """
        Create post_impact, datapost, and post_pred attributes for effect_summary().

        These attributes make PiecewiseITS compatible with the effect_summary()
        method inherited from BaseExperiment, which expects ITS-like attributes.
        The "post-intervention" portion is all observations at or after the
        *earliest* interruption time. This is computed as the union of
        ``time >= threshold`` over every threshold, so it does not depend on the
        order in which the thresholds appear in the formula. When there are no
        interruption times, all three attributes are empty.
        """
        post_mask = np.zeros(len(self.data), dtype=bool)
        if self.interruption_times:
            time_values = self.data[self.time_col]
            for threshold in self.interruption_times:
                post_mask |= np.asarray(
                    self._at_or_after(time_values, threshold), dtype=bool
                )

        # Create datapost - the post-intervention data
        self.datapost = self.data.loc[post_mask].copy()
        self.datapost.index.name = "obs_ind"

        # Positional indices of the post-intervention observations
        post_indices = np.flatnonzero(post_mask)

        if isinstance(self.model, PyMCModel):
            # Re-label obs_ind with datapost.index so effects and predictions line
            # up with the post-intervention rows.
            post_cf_mu = (
                self._posterior_mu(self.y_counterfactual)
                .isel(obs_ind=post_indices)
                .assign_coords(obs_ind=self.datapost.index)
            )
            # InferenceData-like dict structure expected downstream
            self.post_pred = {"posterior_predictive": {"mu": post_cf_mu}}
            self.post_impact = self.effect.isel(obs_ind=post_indices).assign_coords(
                obs_ind=self.datapost.index
            )
        elif isinstance(self.model, RegressorMixin):
            # For OLS models, effect and counterfactual are numpy arrays
            self.post_impact = self.effect[post_indices]
            self.post_pred = np.squeeze(self.y_counterfactual)[post_indices]

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def summary(self, round_to: int | None = None) -> None:
        """Print summary of main results and model coefficients.

        Parameters
        ----------
        round_to : int, optional
            Number of decimals used to round results; passed through to
            ``print_coefficients``. Defaults to None.
        """
        print(f"{self.expt_type:=^80}")
        print(f"Formula: {self.formula}")
        print(f"Interruption times: {self.interruption_times}")
        self.print_coefficients(round_to)

    def _convert_threshold_for_plotting(
        self, threshold: int | float | str
    ) -> int | float | pd.Timestamp:
        """Convert a threshold to a value matplotlib can place on the x-axis.

        String thresholds are parsed with ``pd.Timestamp`` when possible. If that
        fails, the string is returned unchanged.
        """
        if isinstance(threshold, str):
            try:
                return pd.Timestamp(threshold)
            except self._TIMESTAMP_ERRORS:
                return threshold  # type: ignore[return-value]
        return threshold

    def _add_interruption_lines(self, axes: Any) -> None:
        """Draw a red vertical line for every interruption on each axis.

        Only the lines on the first axis are labelled.
        """
        for i, threshold in enumerate(self.interruption_times):
            x_pos = self._convert_threshold_for_plotting(threshold)
            for j, axis in enumerate(axes):
                axis.axvline(
                    x=x_pos,
                    ls="-",
                    lw=2,
                    color="red",
                    alpha=0.7,
                    label=f"Interruption {i}" if j == 0 else None,
                )

    def _bayesian_plot(
        self, round_to: int | None = 2, **kwargs: Any
    ) -> tuple[plt.Figure, list[plt.Axes]]:
        """
        Plot the results for Bayesian models.

        Parameters
        ----------
        round_to : int, optional
            Number of decimals for rounding. Defaults to 2.

        Returns
        -------
        fig : plt.Figure
            The matplotlib figure.
        ax : list[plt.Axes]
            List of axes objects.
        """
        time_values = self.data[self.time_col].values
        fig, ax = plt.subplots(3, 1, sharex=True, figsize=(10, 10))

        # TOP PLOT: Observed, Fitted, and Counterfactual
        (h_obs,) = ax[0].plot(
            time_values,
            self.y.isel(treated_units=0),
            "k.",
            label="Observations",
        )
        h_line_fit, h_patch_fit = plot_xY(
            time_values,
            self._posterior_mu(self.y_pred),
            ax=ax[0],
            plot_hdi_kwargs={"color": "C0"},
        )
        h_line_cf, h_patch_cf = plot_xY(
            time_values,
            self._posterior_mu(self.y_counterfactual),
            ax=ax[0],
            plot_hdi_kwargs={"color": "C1"},
        )

        # Title with R^2, when the score exposes it
        r2_val = None
        if isinstance(self.score, pd.Series):
            for key in ("unit_0_r2", "r2"):
                if key in self.score.index:
                    r2_val = self.score[key]
                    break
        title_str = "Piecewise ITS: Bayesian $R^2$"
        if r2_val is not None:
            title_str += f" = {round_num(r2_val, round_to)}"
        ax[0].set(title=title_str, ylabel=self.outcome_variable_name)

        handles = [h_obs, (h_line_fit, h_patch_fit), (h_line_cf, h_patch_cf)]
        labels_legend = ["Observations", "Fitted", "Counterfactual"]

        # MIDDLE PLOT: Causal Effect
        plot_xY(
            time_values,
            self.effect,
            ax=ax[1],
            plot_hdi_kwargs={"color": "C2"},
        )
        ax[1].axhline(y=0, c="k", linestyle="--", alpha=0.5)
        ax[1].fill_between(
            time_values,
            y1=self.effect.mean(dim=["chain", "draw"]).values,
            alpha=0.25,
            color="C2",
        )
        ax[1].set(title="Causal Effect", ylabel="Effect")

        # BOTTOM PLOT: Cumulative Effect
        plot_xY(
            time_values,
            self.cumulative_effect,
            ax=ax[2],
            plot_hdi_kwargs={"color": "C3"},
        )
        ax[2].axhline(y=0, c="k", linestyle="--", alpha=0.5)
        ax[2].set(title="Cumulative Causal Effect", ylabel="Cumulative Effect")

        # Vertical lines for interruptions, plus matching legend entries
        self._add_interruption_lines(ax)
        for i in range(len(self.interruption_times)):
            handles.append(plt.Line2D([0], [0], color="red", lw=2))
            labels_legend.append(f"Interruption {i}")

        ax[0].legend(handles=handles, labels=labels_legend, fontsize=LEGEND_FONT_SIZE)
        plt.tight_layout()
        return fig, ax

    def _ols_plot(
        self, round_to: int | None = 2, **kwargs: Any
    ) -> tuple[plt.Figure, list[plt.Axes]]:
        """
        Plot the results for OLS models.

        Parameters
        ----------
        round_to : int, optional
            Number of decimals for rounding. Defaults to 2.

        Returns
        -------
        fig : plt.Figure
            The matplotlib figure.
        ax : list[plt.Axes]
            List of axes objects.
        """
        time_values = self.data[self.time_col].values
        fig, ax = plt.subplots(3, 1, sharex=True, figsize=(10, 10))

        # TOP PLOT: Observed, Fitted, and Counterfactual
        ax[0].plot(time_values, self.y.values, "k.", label="Observations")
        ax[0].plot(time_values, self.y_pred, "C0-", label="Fitted", linewidth=2)
        ax[0].plot(
            time_values,
            self.y_counterfactual,
            "C1--",
            label="Counterfactual",
            linewidth=2,
        )
        title_str = f"Piecewise ITS: $R^2$ = {round_num(float(self.score), round_to)}"
        ax[0].set(title=title_str, ylabel=self.outcome_variable_name)

        # MIDDLE PLOT: Causal Effect
        ax[1].plot(time_values, self.effect, "C2-", linewidth=2)
        ax[1].fill_between(time_values, y1=self.effect, alpha=0.25, color="C2")
        ax[1].axhline(y=0, c="k", linestyle="--", alpha=0.5)
        ax[1].set(title="Causal Effect", ylabel="Effect")

        # BOTTOM PLOT: Cumulative Effect
        ax[2].plot(time_values, self.cumulative_effect, "C3-", linewidth=2)
        ax[2].axhline(y=0, c="k", linestyle="--", alpha=0.5)
        ax[2].set(title="Cumulative Causal Effect", ylabel="Cumulative Effect")

        # Vertical lines for interruptions (labelled on the top axis)
        self._add_interruption_lines(ax)

        ax[0].legend(fontsize=LEGEND_FONT_SIZE)
        plt.tight_layout()
        return fig, ax

    @staticmethod
    def _hdi_bounds(
        draws: xr.DataArray, hdi_prob: float
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return flattened (lower, upper) HDI bounds of ``draws``.

        ``az.hdi`` returns a Dataset. Its first data variable is read at the
        ``hdi="lower"`` and ``hdi="higher"`` coordinates.
        """
        hdi_result = az.hdi(draws, hdi_prob=hdi_prob)
        hdi_data = hdi_result[list(hdi_result.data_vars)[0]]
        lower = hdi_data.sel(hdi="lower").values.flatten()
        upper = hdi_data.sel(hdi="higher").values.flatten()
        return lower, upper

    def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
        """
        Recover the data of the experiment along with prediction and effect information.

        Parameters
        ----------
        hdi_prob : float, default=0.94
            Probability for the highest density interval.

        Returns
        -------
        pd.DataFrame
            DataFrame containing observed data, predictions, and effects. For each
            of ``fitted``, ``counterfactual``, ``effect`` and ``cumulative_effect``
            it holds the posterior mean plus ``<name>_hdi_lower_<pct>`` /
            ``<name>_hdi_upper_<pct>`` columns. The result is also stored on
            ``self.plot_data``.
        """
        hdi_pct = int(round(hdi_prob * 100))

        posterior_quantities = {
            "fitted": self._posterior_mu(self.y_pred),
            "counterfactual": self._posterior_mu(self.y_counterfactual),
            "effect": self.effect,
            "cumulative_effect": self.cumulative_effect,
        }

        columns: dict[str, Any] = {
            self.time_col: self.data[self.time_col].values,
            self.outcome_variable_name: self.y.isel(treated_units=0).values,
        }
        for name, draws in posterior_quantities.items():
            lower, upper = self._hdi_bounds(draws, hdi_prob)
            columns[name] = draws.mean(dim=["chain", "draw"]).values
            columns[f"{name}_hdi_lower_{hdi_pct}"] = lower
            columns[f"{name}_hdi_upper_{hdi_pct}"] = upper

        result = pd.DataFrame(columns)
        self.plot_data = result
        return result

    def get_plot_data_ols(self) -> pd.DataFrame:
        """
        Recover the data of the experiment along with prediction and effect information.

        Returns
        -------
        pd.DataFrame
            DataFrame containing observed data, predictions, and effects. The
            result is also stored on ``self.plot_data``.
        """
        result = pd.DataFrame(
            {
                self.time_col: self.data[self.time_col].values,
                self.outcome_variable_name: self.y.values.flatten(),
                "fitted": np.squeeze(self.y_pred),
                "counterfactual": np.squeeze(self.y_counterfactual),
                "effect": self.effect,
                "cumulative_effect": self.cumulative_effect,
            }
        )
        self.plot_data = result
        return result