# Source code for maldiamrkit.evaluation.splitting

"""Stratified splitting utilities for AMR datasets.

Provides species-drug stratified and case-based (patient-grouped) splitting
to prevent data leakage and ensure balanced evaluation of AMR classifiers.
"""

from __future__ import annotations

from collections.abc import Iterator

import numpy as np
import pandas as pd
from sklearn.model_selection import (
    GroupShuffleSplit,
    StratifiedGroupKFold,
    StratifiedKFold,
    StratifiedShuffleSplit,
)


def _build_strata(
    y: np.ndarray,
    species: np.ndarray,
    min_count: int = 2,
) -> np.ndarray:
    """Build stratification labels from species + drug resistance.

    Combines species and resistance label into a single stratum key.
    Rare strata (< min_count samples) are merged into an "other" group.

    Parameters
    ----------
    y : array-like
        Resistance labels.
    species : array-like
        Species labels.
    min_count : int, default=2
        Minimum samples per stratum. Smaller groups are merged.

    Returns
    -------
    np.ndarray
        Array of stratum keys (strings).
    """
    y = np.asarray(y, dtype=str)
    species = np.asarray(species, dtype=str)
    strata = np.array([f"{s}__{lab}" for s, lab in zip(species, y, strict=True)])

    # Merge rare strata, preserving resistance label to maintain
    # class balance across the merged group.
    unique, counts = np.unique(strata, return_counts=True)
    rare = set(unique[counts < min_count])
    if rare:
        strata = np.array(
            [
                s if s not in rare else f"__rare_{lab}__"
                for s, lab in zip(strata, y, strict=True)
            ]
        )

    return strata


def stratified_species_drug_split(
    X: pd.DataFrame | np.ndarray,
    y: np.ndarray,
    species: np.ndarray,
    test_size: float = 0.2,
    random_state: int | None = None,
    min_count: int = 2,
) -> tuple:
    """Split once into train and test sets, stratified on species and label.

    The stratification target is built by ``_build_strata`` from ``species``
    and ``y`` and handed to a single-iteration ``StratifiedShuffleSplit``.

    Parameters
    ----------
    X : pd.DataFrame or np.ndarray
        Feature matrix.
    y : array-like
        Resistance labels.
    species : array-like
        Species labels aligned with X.
    test_size : float, default=0.2
        Fraction of samples for the test set.
    random_state : int or None, default=None
        Random seed for reproducibility.
    min_count : int, default=2
        Minimum samples per species-drug stratum. Smaller groups are merged.

    Returns
    -------
    X_train, X_test, y_train, y_test : arrays
        Split data. ``X`` parts keep the input container type; ``y`` parts
        are NumPy arrays.
    """
    labels = np.asarray(y)
    strata = _build_strata(labels, np.asarray(species), min_count)

    shuffle_split = StratifiedShuffleSplit(
        n_splits=1, test_size=test_size, random_state=random_state
    )
    train_rows, test_rows = next(shuffle_split.split(X, strata))

    # DataFrames need positional selection via ``.iloc``; arrays are indexed
    # directly.
    take = X.iloc if isinstance(X, pd.DataFrame) else X
    return (
        take[train_rows],
        take[test_rows],
        labels[train_rows],
        labels[test_rows],
    )


def case_based_split(
    X: pd.DataFrame | np.ndarray,
    y: np.ndarray,
    case_ids: np.ndarray,
    test_size: float = 0.2,
    random_state: int | None = None,
) -> tuple:
    """Split once into train and test sets without splitting any case.

    Uses a single-iteration ``GroupShuffleSplit`` with ``case_ids`` as the
    groups, so all samples sharing a case identifier land on the same side
    of the split. This prevents leakage from having the same patient in
    both train and test.

    Parameters
    ----------
    X : pd.DataFrame or np.ndarray
        Feature matrix.
    y : array-like
        Resistance labels.
    case_ids : array-like
        Patient/case identifiers aligned with X.
    test_size : float, default=0.2
        Fraction of groups for the test set.
    random_state : int or None, default=None
        Random seed for reproducibility.

    Returns
    -------
    X_train, X_test, y_train, y_test : arrays
        Split data. ``X`` parts keep the input container type; ``y`` parts
        are NumPy arrays.
    """
    labels = np.asarray(y)
    cases = np.asarray(case_ids)

    group_split = GroupShuffleSplit(
        n_splits=1, test_size=test_size, random_state=random_state
    )
    train_rows, test_rows = next(group_split.split(X, labels, groups=cases))

    # DataFrames need positional selection via ``.iloc``; arrays are indexed
    # directly.
    take = X.iloc if isinstance(X, pd.DataFrame) else X
    return (
        take[train_rows],
        take[test_rows],
        labels[train_rows],
        labels[test_rows],
    )


[docs] class SpeciesDrugStratifiedKFold: """K-fold cross-validation with species-drug stratification. Ensures each fold preserves the distribution of species-drug combinations. Implements the sklearn splitter interface. Parameters ---------- n_splits : int, default=5 Number of folds. shuffle : bool, default=True Whether to shuffle before splitting. random_state : int or None, default=None Random seed for reproducibility. min_count : int, default=2 Minimum samples per stratum before merging. Examples -------- >>> cv = SpeciesDrugStratifiedKFold(n_splits=5) >>> for train_idx, test_idx in cv.split(X, y, species=species): ... X_train, X_test = X[train_idx], X[test_idx] """
[docs] def __init__( self, n_splits: int = 5, shuffle: bool = True, random_state: int | None = None, min_count: int = 2, ): self.n_splits = n_splits self.shuffle = shuffle self.random_state = random_state self.min_count = min_count
[docs] def get_n_splits( self, X: pd.DataFrame | np.ndarray | None = None, y: np.ndarray | None = None, groups: np.ndarray | None = None, ) -> int: """Return the number of splits.""" return self.n_splits
[docs] def split( self, X: pd.DataFrame | np.ndarray, y: np.ndarray, species: np.ndarray | None = None, groups: np.ndarray | None = None, ) -> Iterator[tuple[np.ndarray, np.ndarray]]: """Generate train/test indices for each fold. Parameters ---------- X : array-like Feature matrix. y : array-like Resistance labels. species : array-like Species labels. If None, falls back to plain stratified KFold. groups : ignored Not used, present for API compatibility. Yields ------ train_idx, test_idx : np.ndarray Indices for train and test sets. """ y = np.asarray(y) if species is not None: species = np.asarray(species) strata = _build_strata(y, species, self.min_count) else: strata = y skf = StratifiedKFold( n_splits=self.n_splits, shuffle=self.shuffle, random_state=self.random_state, ) yield from skf.split(X, strata)
[docs] class CaseGroupedKFold: """K-fold cross-validation keeping patient cases together and stratified by ``y``. All samples from the same case/patient are always in the same fold, and folds are stratified on the resistance label to preserve class balance. Wraps :class:`sklearn.model_selection.StratifiedGroupKFold`. Parameters ---------- n_splits : int, default=5 Number of folds. shuffle : bool, default=True Whether to shuffle group order before splitting. random_state : int or None, default=None Random seed (used only when ``shuffle=True``). Examples -------- >>> cv = CaseGroupedKFold(n_splits=5) >>> for train_idx, test_idx in cv.split(X, y, groups=case_ids): ... X_train, X_test = X[train_idx], X[test_idx] """
[docs] def __init__( self, n_splits: int = 5, shuffle: bool = True, random_state: int | None = None, ): self.n_splits = n_splits self.shuffle = shuffle self.random_state = random_state
[docs] def get_n_splits( self, X: pd.DataFrame | np.ndarray | None = None, y: np.ndarray | None = None, groups: np.ndarray | None = None, ) -> int: """Return the number of splits.""" return self.n_splits
[docs] def split( self, X: pd.DataFrame | np.ndarray, y: np.ndarray | None = None, groups: np.ndarray | None = None, ) -> Iterator[tuple[np.ndarray, np.ndarray]]: """Generate stratified, group-preserving train/test indices for each fold. Parameters ---------- X : array-like Feature matrix. y : array-like Resistance labels. Required for stratification. groups : array-like Case/patient identifiers. Required. Yields ------ train_idx, test_idx : np.ndarray Indices for train and test sets. Raises ------ ValueError If ``groups`` or ``y`` is None. """ if groups is None: raise ValueError("groups (case_ids) must be provided for CaseGroupedKFold") if y is None: raise ValueError("y must be provided for CaseGroupedKFold stratification") sgkf = StratifiedGroupKFold( n_splits=self.n_splits, shuffle=self.shuffle, random_state=self.random_state, ) yield from sgkf.split(X, y, groups=groups)