Source code for drippy.onefactor

"""Plotting functions for 1-factor models (y = f(x) + e, x categorical)."""

from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
from drippy.utilities import get_figure_and_axes

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    from drippy.data import EDAData

# Shared x-axis label for the plots in this module that group y by the
# levels of x (passed to set_xlabel by box_plot, mean_plot and sd_plot).
[docs] _FACTOR_LEVEL = "Factor Level"
def scatter_plot(
    data: EDAData,
    fig: Figure | None = None,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Draws y against x as a scatter plot.

    Also used in regression context (see drippy.regression).

    Args:
        data: EDAData container. Its ``x`` attribute must be set.
        fig: Matplotlib figure. If None, creates new figure.
        ax: Matplotlib axes. If None, creates new axes.

    Returns:
        The figure and axes containing the plot.

    Raises:
        ValueError: If ``data.x`` is None.
    """
    # Guard clause: there is nothing to plot against without x.
    if data.x is None:
        msg = "scatter_plot requires x"
        raise ValueError(msg)

    fig, ax = get_figure_and_axes(fig, ax)
    ax.scatter(data.x, data.y)
    # Labels and title applied in a single call.
    ax.set(xlabel="X", ylabel="Y", title="Scatter Plot")
    fig.tight_layout()
    return fig, ax
def box_plot(
    data: EDAData,
    fig: Figure | None = None,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Draws one box of y per distinct factor level found in x.

    Args:
        data: EDAData container. Requires x (categorical).
        fig: Matplotlib figure. If None, creates new figure.
        ax: Matplotlib axes. If None, creates new axes.

    Returns:
        The figure and axes containing the plot.

    Raises:
        ValueError: If ``data.x`` is None.
    """
    if data.x is None:
        msg = "box_plot requires x"
        raise ValueError(msg)

    fig, ax = get_figure_and_axes(fig, ax)

    # np.unique yields the distinct levels in sorted order; collect the y
    # values belonging to each level in that same order.
    levels = np.unique(data.x)
    groups = []
    for level in levels:
        in_level = data.x == level
        groups.append(data.y[in_level])

    ax.boxplot(groups, tick_labels=levels)
    ax.set(xlabel=_FACTOR_LEVEL, ylabel="Y", title="Box Plot")
    fig.tight_layout()
    return fig, ax
def bihistogram(
    data: EDAData,
    fig: Figure | None = None,
    axes: np.ndarray | None = None,
    bins: int | str = "auto",
) -> tuple[Figure, np.ndarray]:
    """Creates side-by-side histograms for exactly 2 factor levels.

    The left panel shows y for the first level returned by ``np.unique``
    and the right panel shows y for the second. Each panel is binned
    independently with the same ``bins`` argument.

    Args:
        data: EDAData container. Requires x with exactly 2 unique levels.
        fig: Matplotlib figure. If None, a new figure is created, or the
            figure owning ``axes[0]`` is used when ``axes`` is supplied.
        axes: Array of 2 Axes. If None, creates new axes.
        bins: Number of bins or bin strategy.

    Returns:
        (fig, axes) where axes is a 1-D array of 2 Axes.

    Raises:
        ValueError: If ``data.x`` is None, if x does not have exactly 2
            unique levels, or if ``axes`` does not have shape ``(2,)``.
    """
    if data.x is None:
        msg = "bihistogram requires x"
        raise ValueError(msg)

    levels = np.unique(data.x)
    if len(levels) != 2:  # noqa: PLR2004
        n_levels = len(levels)
        msg = f"bihistogram requires exactly 2 factor levels, got {n_levels}"
        raise ValueError(msg)

    if axes is None:
        if fig is None:
            fig, axes = plt.subplots(1, 2)
        else:
            axes = fig.subplots(1, 2)

    # Validate the shape BEFORE dereferencing axes[0]: an empty or
    # mis-shaped array must produce the documented ValueError rather than
    # an IndexError from the figure lookup below.
    if axes.shape != (2,):
        msg = "axes must have shape (2,)"
        raise ValueError(msg)

    if fig is None:
        fig = axes[0].get_figure()

    # One panel per level, in np.unique order.
    for panel, level in zip(axes, levels):
        panel.hist(data.y[data.x == level], bins=bins)
        panel.set_title(f"Histogram: {level}")
        panel.set_xlabel("Value")
        panel.set_ylabel("Frequency")

    fig.tight_layout()
    return fig, axes
def qq_plot(
    data: EDAData,
    fig: Figure | None = None,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Plots the quantiles of one factor level's y against the other's.

    Both groups are evaluated at the same evenly spaced probabilities on
    [0, 1]; the number of probabilities equals the size of the smaller
    group. A dashed ``y=x`` reference line spans the combined range of the
    plotted quantiles.

    Args:
        data: EDAData container. Requires x with exactly 2 unique levels.
        fig: Matplotlib figure. If None, creates new figure.
        ax: Matplotlib axes. If None, creates new axes.

    Returns:
        The figure and axes containing the plot.

    Raises:
        ValueError: If ``data.x`` is None or x does not have exactly 2
            unique levels.
    """
    if data.x is None:
        msg = "qq_plot requires x"
        raise ValueError(msg)

    levels = np.unique(data.x)
    n_levels = len(levels)
    if n_levels != 2:  # noqa: PLR2004
        msg = f"qq_plot requires exactly 2 factor levels, got {n_levels}"
        raise ValueError(msg)

    fig, ax = get_figure_and_axes(fig, ax)

    first_level, second_level = levels
    sample_first = data.y[data.x == first_level]
    sample_second = data.y[data.x == second_level]

    # Shared probability grid, sized by the smaller of the two samples.
    n_points = min(len(sample_first), len(sample_second))
    probabilities = np.linspace(0, 1, n_points)
    q_first = np.quantile(sample_first, probabilities)
    q_second = np.quantile(sample_second, probabilities)
    ax.scatter(q_first, q_second, label="Quantiles")

    # Reference diagonal covering the full range of both quantile sets.
    lower = min(q_first.min(), q_second.min())
    upper = max(q_first.max(), q_second.max())
    ax.plot([lower, upper], [lower, upper], "r--", label="y=x")

    ax.set(
        xlabel=f"Quantiles: {first_level}",
        ylabel=f"Quantiles: {second_level}",
        title="Quantile-Quantile Plot",
    )
    ax.legend()
    fig.tight_layout()
    return fig, ax
def mean_plot(
    data: EDAData,
    fig: Figure | None = None,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Plots the mean of y for each factor level in x.

    The per-level means are drawn as markers joined by a line, and the
    mean of all of y is drawn as a dashed horizontal reference line
    labelled "Grand mean".

    Args:
        data: EDAData container. Requires x.
        fig: Matplotlib figure. If None, creates new figure.
        ax: Matplotlib axes. If None, creates new axes.

    Returns:
        The figure and axes containing the plot.

    Raises:
        ValueError: If ``data.x`` is None.
    """
    if data.x is None:
        msg = "mean_plot requires x"
        raise ValueError(msg)

    fig, ax = get_figure_and_axes(fig, ax)

    levels = np.unique(data.x)
    level_means = []
    for level in levels:
        members = data.y[data.x == level]
        level_means.append(members.mean())

    ax.plot(levels, level_means, "o-")
    grand_mean = data.y.mean()
    ax.axhline(grand_mean, color="r", linestyle="--", label="Grand mean")

    ax.set_title("Mean Plot")
    ax.set_ylabel("Mean of Y")
    ax.set_xlabel(_FACTOR_LEVEL)
    ax.legend()
    fig.tight_layout()
    return fig, ax
def sd_plot(
    data: EDAData,
    fig: Figure | None = None,
    ax: Axes | None = None,
) -> tuple[Figure, Axes]:
    """Plots the standard deviation of y for each factor level in x.

    The per-level standard deviations are drawn as markers joined by a
    line, and the standard deviation of all of y is drawn as a dashed
    horizontal reference line labelled "Overall SD". Every value comes from
    calling ``.std()`` with its default arguments.

    Args:
        data: EDAData container. Requires x.
        fig: Matplotlib figure. If None, creates new figure.
        ax: Matplotlib axes. If None, creates new axes.

    Returns:
        The figure and axes containing the plot.

    Raises:
        ValueError: If ``data.x`` is None.
    """
    if data.x is None:
        msg = "sd_plot requires x"
        raise ValueError(msg)

    fig, ax = get_figure_and_axes(fig, ax)

    levels = np.unique(data.x)
    spreads = []
    for level in levels:
        members = data.y[data.x == level]
        spreads.append(members.std())

    ax.plot(levels, spreads, "o-")
    overall_spread = data.y.std()
    ax.axhline(overall_spread, color="r", linestyle="--", label="Overall SD")

    ax.set(
        xlabel=_FACTOR_LEVEL,
        ylabel="Standard Deviation of Y",
        title="Standard Deviation Plot",
    )
    ax.legend()
    fig.tight_layout()
    return fig, ax