# Source code for drippy.regression

"""Plotting functions for regression models (y = f(x) + e, x continuous)."""

from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from drippy.onefactor import scatter_plot  # shared — not re-implemented
from drippy.utilities import get_figure_and_axes

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    from drippy.data import EDAData

# Public API of this module. ``scatter_plot`` is imported from
# ``drippy.onefactor`` and re-exported here rather than re-implemented.
__all__ = [
    "linear_correlation_plot",
    "linear_intercept_plot",
    "linear_residual_sd_plot",
    "linear_slope_plot",
    "scatter_plot",
    "six_plot",
]


def _rolling_linregress(
    x: np.ndarray,
    y: np.ndarray,
    window: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Fit a least-squares line to every run of ``window`` consecutive points.

    Window ``k`` covers indices ``k`` through ``k + window - 1`` for
    ``k = 0 .. len(y) - window``; each window is fitted independently with
    ``scipy.stats.linregress``.

    Args:
        x: Predictor values; each window slice is cast to float before fitting.
        y: Response values, aligned index-for-index with ``x``.
        window: Observations per window; must satisfy
            ``2 <= window <= len(y)``.

    Returns:
        ``(midpoints, slopes, intercepts, residual_sds)``: four float arrays
        with one entry per window. ``midpoints[k]`` is ``k + window // 2``.
        ``residual_sds`` is the standard deviation of each window's fit
        residuals computed with ``ddof=2`` (one degree of freedom per fitted
        parameter), so it has no degrees of freedom left when ``window == 2``.

    Raises:
        ValueError: If ``window`` lies outside ``[2, len(y)]``.
    """
    total = len(y)
    if window < 2 or window > total:  # noqa: PLR2004
        msg = (
            "window must satisfy 2 <= window <= len(y); "
            f"got window={window}, len(y)={total}"
        )
        raise ValueError(msg)

    slope_list = []
    intercept_list = []
    sd_list = []
    for start in range(total - window + 1):
        stop = start + window
        x_part = x[start:stop].astype(float)
        y_part = y[start:stop]
        fit = scipy.stats.linregress(x_part, y_part)
        fitted = fit.slope * x_part + fit.intercept
        slope_list.append(fit.slope)
        intercept_list.append(fit.intercept)
        sd_list.append((y_part - fitted).std(ddof=2))

    # Integer-valued positions stored as floats, one per window.
    centres = np.arange(len(slope_list), dtype=float) + window // 2
    return (
        centres,
        np.array(slope_list, dtype=float),
        np.array(intercept_list, dtype=float),
        np.array(sd_list, dtype=float),
    )


def _six_plot_resolve_axes(
    fig: Figure | None,
    axes: np.ndarray | None,
) -> tuple[Figure, np.ndarray]:
    """Return the figure and a validated 2x3 axes grid for ``six_plot``."""
    if axes is None:
        if fig is None:
            fig, axes = plt.subplots(2, 3)
            return fig, axes
        return fig, fig.subplots(2, 3)

    # Accept nested sequences of Axes as well as ndarrays.
    axes = np.asarray(axes)
    # Check the shape before touching axes.flat[0], so an empty or
    # mis-shaped grid is reported as ValueError rather than IndexError.
    if axes.shape != (2, 3):
        msg = "axes must have shape (2, 3)"
        raise ValueError(msg)

    owner = axes.flat[0].get_figure()
    if fig is None:
        return owner, axes
    if owner is not fig:
        msg = (
            "axes must belong to provided fig; "
            f"axes belong to fig_id={id(owner)}, "
            f"provided fig_id={id(fig)}"
        )
        raise ValueError(msg)
    return fig, axes


def _six_plot_fit_row(
    row: np.ndarray,
    x: np.ndarray,
    y: np.ndarray,
    res: np.ndarray,
    slope: float,
    intercept: float,
) -> None:
    """Draw the scatter-with-fit, residual-vs-x and lag panels on ``row``."""
    # Scatter plot with the fitted line drawn across the observed x range.
    ax = row[0]
    ax.scatter(x, y, s=10)
    x_line = np.array([x.min(), x.max()])
    ax.plot(x_line, slope * x_line + intercept, color="r")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_title("Scatter Plot")

    # Residuals vs x, with a reference line at zero.
    ax = row[1]
    ax.scatter(x, res, s=10)
    ax.axhline(0, color="r", linestyle="--")
    ax.set_xlabel("X")
    ax.set_ylabel("Residuals")
    ax.set_title("Residual Plot")

    # Lag-1 plot: each residual against the next one.
    ax = row[2]
    ax.scatter(res[:-1], res[1:], s=10)
    ax.set_xlabel("Residual[i]")
    ax.set_ylabel("Residual[i+1]")
    ax.set_title("Lag Plot")


def _six_plot_residual_row(row: np.ndarray, res: np.ndarray) -> None:
    """Draw the histogram, normal-probability and run-sequence panels."""
    ax = row[0]
    ax.hist(res, bins="auto")
    ax.set_xlabel("Residuals")
    ax.set_ylabel("Frequency")
    ax.set_title("Histogram")

    # Normal probability plot; the dashed line is probplot's least-squares
    # fit through the ordered residuals.
    ax = row[1]
    (osm, osr), (fit_slope, fit_intercept, _) = scipy.stats.probplot(
        res, dist="norm"
    )
    ax.scatter(osm, osr, s=10)
    ax.plot(
        [osm[0], osm[-1]],
        [
            fit_slope * osm[0] + fit_intercept,
            fit_slope * osm[-1] + fit_intercept,
        ],
        color="r",
        linestyle="--",
    )
    ax.set_xlabel("Theoretical Quantiles")
    ax.set_ylabel("Ordered Residuals")
    ax.set_title("Normal Probability Plot")

    # Run sequence of residuals in index order, with their mean.
    ax = row[2]
    ax.plot(res)
    ax.axhline(res.mean(), color="r", linestyle="--")
    ax.set_xlabel("Index")
    ax.set_ylabel("Residuals")
    ax.set_title("Run Sequence")


def six_plot(
    data: EDAData,
    fig: Figure | None = None,
    axes: np.ndarray | None = None,
) -> tuple[Figure, np.ndarray]:
    """Creates a 2x3 composite regression diagnostic plot.

    Fits ``y = slope * x + intercept`` with ``scipy.stats.linregress`` and
    draws six panels from the fit:

    * [0, 0] scatter of y vs x with the regression line,
    * [0, 1] residuals vs x,
    * [0, 2] lag-1 plot of the residuals,
    * [1, 0] histogram of the residuals,
    * [1, 1] normal probability plot of the residuals,
    * [1, 2] run sequence of the residuals.

    Args:
        data: EDAData container. Requires x.
        fig: Matplotlib figure. If None, it is taken from ``axes``, or a new
            figure is created when ``axes`` is also None.
        axes: 2x3 grid of Axes (ndarray or nested sequence). If None, new
            axes are created.

    Returns:
        (fig, axes) where axes has shape (2, 3).

    Raises:
        ValueError: If ``data.x`` is None, if ``axes`` is not a 2x3 grid,
            or if ``axes`` does not belong to the provided ``fig``.
    """
    if data.x is None:
        msg = "six_plot requires x"
        raise ValueError(msg)
    fig, axes = _six_plot_resolve_axes(fig, axes)

    x = data.x.astype(float)
    y = data.y
    result = scipy.stats.linregress(x, y)
    slope, intercept = result.slope, result.intercept
    res = y - (slope * x + intercept)

    _six_plot_fit_row(axes[0], x, y, res, slope, intercept)
    _six_plot_residual_row(axes[1], res)

    fig.tight_layout()
    return fig, axes
def linear_correlation_plot(
    data: EDAData,
    fig: Figure | None = None,
    ax: Axes | None = None,
    window: int = 10,
) -> tuple[Figure, Axes]:
    """Plots Pearson correlation coefficient for rolling windows of data.

    Window ``k`` covers observations ``k`` through ``k + window - 1``. Its
    Pearson r (``scipy.stats.pearsonr``) is plotted at position
    ``k + window // 2``, with a dashed red reference line at r = 0.

    Args:
        data: EDAData container. Requires x.
        fig: Matplotlib figure. If None, creates new figure.
        ax: Matplotlib axes. If None, creates new axes.
        window: Number of observations per rolling window
            (``2 <= window <= len(data.y)``).

    Returns:
        The figure and axes containing the plot.

    Raises:
        ValueError: If ``data.x`` is None or ``window`` is out of range.
    """
    if data.x is None:
        msg = "linear_correlation_plot requires x"
        raise ValueError(msg)
    x_vals = data.x.astype(float)
    y_vals = data.y
    n_obs = len(y_vals)
    if window < 2 or window > n_obs:  # noqa: PLR2004
        msg = (
            "window must satisfy 2 <= window <= len(y); "
            f"got window={window}, len(y)={n_obs}"
        )
        raise ValueError(msg)
    fig, ax = get_figure_and_axes(fig, ax)

    starts = range(n_obs - window + 1)
    centres = np.array([s + window // 2 for s in starts], dtype=float)
    r_values = np.array(
        [
            scipy.stats.pearsonr(
                x_vals[s : s + window], y_vals[s : s + window]
            )[0]
            for s in starts
        ],
        dtype=float,
    )

    ax.plot(centres, r_values, "o-", markersize=4)
    ax.axhline(0, color="r", linestyle="--")
    ax.set_xlabel("Subset")
    ax.set_ylabel("Correlation r")
    ax.set_title("Linear Correlation Plot")
    fig.tight_layout()
    return fig, ax
def linear_intercept_plot(
    data: EDAData,
    fig: Figure | None = None,
    ax: Axes | None = None,
    window: int = 10,
) -> tuple[Figure, Axes]:
    """Plots OLS regression intercept for rolling windows of data.

    Every run of ``window`` consecutive observations gets its own
    least-squares fit. The fitted intercepts are plotted against the window
    positions, with a dashed red line at their mean.

    Args:
        data: EDAData container. Requires x.
        fig: Matplotlib figure. If None, creates new figure.
        ax: Matplotlib axes. If None, creates new axes.
        window: Number of observations per rolling window
            (``2 <= window <= len(data.y)``).

    Returns:
        The figure and axes containing the plot.

    Raises:
        ValueError: If ``data.x`` is None or ``window`` is out of range.
    """
    if data.x is None:
        msg = "linear_intercept_plot requires x"
        raise ValueError(msg)
    n_obs = len(data.y)
    if window < 2 or window > n_obs:  # noqa: PLR2004
        msg = (
            "window must satisfy 2 <= window <= len(y); "
            f"got window={window}, len(y)={n_obs}"
        )
        raise ValueError(msg)
    fig, ax = get_figure_and_axes(fig, ax)

    fits = _rolling_linregress(data.x.astype(float), data.y, window)
    positions, intercept_values = fits[0], fits[2]

    ax.plot(positions, intercept_values, "o-", markersize=4)
    ax.axhline(intercept_values.mean(), color="r", linestyle="--")
    ax.set_xlabel("Subset")
    ax.set_ylabel("Intercept")
    ax.set_title("Linear Intercept Plot")
    fig.tight_layout()
    return fig, ax
def linear_slope_plot(
    data: EDAData,
    fig: Figure | None = None,
    ax: Axes | None = None,
    window: int = 10,
) -> tuple[Figure, Axes]:
    """Plots OLS regression slope for rolling windows of data.

    Each window of ``window`` consecutive observations is fitted with its
    own least-squares line. The slopes are plotted against the window
    positions, with a dashed red line at their mean.

    Args:
        data: EDAData container. Requires x.
        fig: Matplotlib figure. If None, creates new figure.
        ax: Matplotlib axes. If None, creates new axes.
        window: Number of observations per rolling window
            (``2 <= window <= len(data.y)``).

    Returns:
        The figure and axes containing the plot.

    Raises:
        ValueError: If ``data.x`` is None or ``window`` is out of range.
    """
    if data.x is None:
        msg = "linear_slope_plot requires x"
        raise ValueError(msg)
    count = len(data.y)
    if window < 2 or window > count:  # noqa: PLR2004
        msg = (
            "window must satisfy 2 <= window <= len(y); "
            f"got window={window}, len(y)={count}"
        )
        raise ValueError(msg)
    fig, ax = get_figure_and_axes(fig, ax)

    positions, slope_values, *_ = _rolling_linregress(
        data.x.astype(float), data.y, window
    )

    ax.plot(positions, slope_values, "o-", markersize=4)
    ax.axhline(slope_values.mean(), color="r", linestyle="--")
    ax.set_xlabel("Subset")
    ax.set_ylabel("Slope")
    ax.set_title("Linear Slope Plot")
    fig.tight_layout()
    return fig, ax
def linear_residual_sd_plot(
    data: EDAData,
    fig: Figure | None = None,
    ax: Axes | None = None,
    window: int = 10,
) -> tuple[Figure, Axes]:
    """Plots residual standard deviation for rolling windows of data.

    Each window of ``window`` consecutive observations gets its own
    least-squares line. The standard deviation of that fit's residuals is
    plotted against the window position, with a dashed red line at the mean.

    Args:
        data: EDAData container. Requires x.
        fig: Matplotlib figure. If None, creates new figure.
        ax: Matplotlib axes. If None, creates new axes.
        window: Number of observations per rolling window. Must satisfy
            ``3 <= window <= len(data.y)``. A two-point window is fitted
            exactly and leaves no degrees of freedom for the residual
            standard deviation.

    Returns:
        The figure and axes containing the plot.

    Raises:
        ValueError: If ``data.x`` is None or ``window`` is outside
            ``[3, len(data.y)]``.
    """
    if data.x is None:
        msg = "linear_residual_sd_plot requires x"
        raise ValueError(msg)
    n = len(data.y)
    # The rolling fit computes the residual SD with ddof=2 (two fitted
    # parameters), so it needs window - 2 >= 1. A window of 2 would divide
    # by zero and plot NaN/inf values.
    min_window = 3
    if window < min_window or window > n:
        msg = (
            f"window must satisfy {min_window} <= window <= len(y); "
            f"got window={window}, len(y)={n}"
        )
        raise ValueError(msg)
    fig, ax = get_figure_and_axes(fig, ax)
    midpoints, _, _, residual_sds = _rolling_linregress(
        data.x.astype(float), data.y, window
    )
    ax.plot(midpoints, residual_sds, "o-", markersize=4)
    ax.axhline(residual_sds.mean(), color="r", linestyle="--")
    ax.set_xlabel("Subset")
    ax.set_ylabel("Residual SD")
    ax.set_title("Linear Residual SD Plot")
    fig.tight_layout()
    return fig, ax