Source code for maldiamrkit.evaluation.mic_regression

"""Regression-style metrics for MIC prediction.

Reports the metrics clinicians and ML practitioners actually look at when
predicting MIC values on dilution series: RMSE in log2 dilutions, essential
agreement (within ±1 dilution), and -- if breakpoints are provided --
categorical agreement after re-binning to S/I/R.
"""

from __future__ import annotations

from typing import Sequence

import numpy as np
import pandas as pd

from ..susceptibility.breakpoint import BreakpointTable


[docs] def mic_regression_report( y_true: Sequence[float] | np.ndarray | pd.Series, y_pred: Sequence[float] | np.ndarray | pd.Series, *, breakpoints: BreakpointTable | None = None, species: str | Sequence[str] | None = None, drug: str | Sequence[str] | None = None, sample_weight: Sequence[float] | None = None, ) -> dict[str, float | int]: """Compute MIC regression metrics on log2-MIC predictions. Parameters ---------- y_true : array-like True ``log2(MIC)`` values. y_pred : array-like Predicted ``log2(MIC)`` values. breakpoints : BreakpointTable or None, default=None When provided, the report also includes categorical agreement after re-binning both ``y_true`` and ``y_pred`` to S/I/R. Requires ``species`` and ``drug``. species : str or array-like, optional Species per sample (or a single species applied to all). Required when ``breakpoints`` is provided. drug : str or array-like, optional Drug per sample (or a single drug applied to all). Required when ``breakpoints`` is provided. sample_weight : array-like, optional Per-sample weights for the regression metrics. Ignored for categorical agreement. Returns ------- dict Keys: ``n``, ``rmse_log2``, ``mae_log2``, ``bias_log2``, ``essential_agreement`` (fraction within ±1 dilution), and when breakpoints are provided also ``categorical_agreement``, ``very_major_error_rate`` (R predicted as S), ``major_error_rate`` (S predicted as R), and per-category sample counts. Notes ----- "Essential agreement" is the standard clinical benchmark for MIC prediction accuracy: a prediction is essential-agreement-correct if it is within one log2 dilution of the true value. """ y_true = np.asarray(y_true, dtype=float) y_pred = np.asarray(y_pred, dtype=float) if y_true.shape != y_pred.shape: raise ValueError( f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}." ) mask = ~(np.isnan(y_true) | np.isnan(y_pred)) y_true_ok = y_true[mask] y_pred_ok = y_pred[mask] if sample_weight is not None: weights = np.asarray(sample_weight, dtype=float)[mask] else: weights = np.ones_like(y_true_ok) n = int(mask.sum()) if n == 0: return { "n": 0, "rmse_log2": float("nan"), "mae_log2": float("nan"), "bias_log2": float("nan"), "essential_agreement": float("nan"), } diff = y_pred_ok - y_true_ok rmse = float(np.sqrt(np.average(diff**2, weights=weights))) mae = float(np.average(np.abs(diff), weights=weights)) bias = float(np.average(diff, weights=weights)) essential = float(np.average((np.abs(diff) <= 1.0).astype(float), weights=weights)) report: dict[str, float | int] = { "n": n, "rmse_log2": rmse, "mae_log2": mae, "bias_log2": bias, "essential_agreement": essential, } if breakpoints is not None: if species is None or drug is None: raise ValueError( "species and drug are required when breakpoints is provided." ) cat_metrics = _categorical_block( breakpoints=breakpoints, species=species, drug=drug, log2_true=y_true, log2_pred=y_pred, mask=mask, ) report.update(cat_metrics) return report
def _categorical_block( *, breakpoints: BreakpointTable, species, drug, log2_true: np.ndarray, log2_pred: np.ndarray, mask: np.ndarray, ) -> dict[str, float | int]: mic_true = np.power(2.0, log2_true) mic_pred = np.power(2.0, log2_pred) cat_true = breakpoints.apply_batch(species, drug, mic_true)["category"].to_numpy() cat_pred = breakpoints.apply_batch(species, drug, mic_pred)["category"].to_numpy() both = ( mask & np.array([t is not None for t in cat_true]) & np.array([p is not None for p in cat_pred]) ) n_cat = int(both.sum()) if n_cat == 0: return { "categorical_agreement": float("nan"), "very_major_error_rate": float("nan"), "major_error_rate": float("nan"), "n_categorical": 0, } t = cat_true[both] p = cat_pred[both] agreement = float(np.mean(t == p)) resistant_mask = t == "R" susceptible_mask = t == "S" vme = ( float(np.mean(p[resistant_mask] == "S")) if resistant_mask.any() else float("nan") ) me = ( float(np.mean(p[susceptible_mask] == "R")) if susceptible_mask.any() else float("nan") ) return { "categorical_agreement": agreement, "very_major_error_rate": vme, "major_error_rate": me, "n_categorical": n_cat, "n_resistant_true": int(resistant_mask.sum()), "n_susceptible_true": int(susceptible_mask.sum()), }