Source code for maldiamrkit.evaluation.metrics

"""AMR-specific evaluation metrics for clinical microbiology.

Provides Very Major Error (VME), Major Error (ME), sensitivity, specificity,
and categorical agreement metrics following EUCAST conventions.

In AMR prediction:
- VME (Very Major Error): resistant isolates classified as susceptible (dangerous)
- ME (Major Error): susceptible isolates classified as resistant (wasteful)
"""

from __future__ import annotations

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, make_scorer


def _get_confusion_values(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> tuple[int, int, int, int]:
    """Extract TP, TN, FP, FN from predictions.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    tp, tn, fp, fn : int
        Confusion matrix values where positive = resistant.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    labels = sorted(set(y_true) | set(y_pred))

    if len(labels) < 2:
        # Single-class edge case
        if labels[0] == resistant_label:
            tp = np.sum(y_pred == resistant_label)
            fn = np.sum(y_pred != resistant_label)
            return tp, 0, 0, fn
        else:
            tn = np.sum(y_pred != resistant_label)
            fp = np.sum(y_pred == resistant_label)
            return 0, tn, fp, 0

    if len(labels) > 2:
        raise ValueError(
            f"Expected binary labels, got {len(labels)} unique values: {labels}. "
            f"Encode labels to binary before using AMR metrics."
        )

    if resistant_label not in labels:
        # resistant_label never appears - treat as all-susceptible
        tn = len(y_true)
        return 0, tn, 0, 0

    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Find index of resistant label
    r_idx = labels.index(resistant_label)
    s_idx = 1 - r_idx

    tp = cm[r_idx, r_idx]
    fn = cm[r_idx, s_idx]
    fp = cm[s_idx, r_idx]
    tn = cm[s_idx, s_idx]

    return tp, tn, fp, fn


def very_major_error_rate(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> float:
    """Very Major Error rate: resistant isolates classified as susceptible.

    VME = FN / (FN + TP), i.e., the miss rate for resistant samples.
    This is the most dangerous error type in clinical microbiology.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    float
        VME rate in [0, 1]. Returns 0.0 if no resistant samples exist.

    Examples
    --------
    >>> very_major_error_rate([1, 1, 0, 0], [0, 1, 0, 0])
    0.5
    """
    tp, _, _, fn = _get_confusion_values(y_true, y_pred, resistant_label)
    denom = fn + tp
    # Coerce to a builtin float: the counts may be NumPy integers, and a
    # np.float64 result reprs as ``np.float64(0.5)`` under NumPy >= 2.
    return float(fn / denom) if denom > 0 else 0.0


def major_error_rate(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> float:
    """Major Error rate: susceptible isolates classified as resistant.

    ME = FP / (FP + TN), i.e., the false alarm rate for susceptible samples.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    float
        ME rate in [0, 1]. Returns 0.0 if no susceptible samples exist.

    Examples
    --------
    >>> major_error_rate([1, 1, 0, 0], [1, 1, 1, 0])
    0.5
    """
    _, tn, fp, _ = _get_confusion_values(y_true, y_pred, resistant_label)
    denom = fp + tn
    # Coerce to a builtin float: the counts may be NumPy integers, and a
    # np.float64 result reprs as ``np.float64(0.5)`` under NumPy >= 2.
    return float(fp / denom) if denom > 0 else 0.0


def sensitivity_score(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> float:
    """Sensitivity (recall) for the resistant class.

    Sensitivity = TP / (TP + FN) = 1 - VME.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    float
        Sensitivity in [0, 1]. Returns 0.0 if no resistant samples exist.
    """
    tp, _, _, fn = _get_confusion_values(y_true, y_pred, resistant_label)
    denom = tp + fn
    # Builtin float keeps the ``-> float`` contract when counts are NumPy ints.
    return float(tp / denom) if denom > 0 else 0.0


def specificity_score(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> float:
    """Specificity (true negative rate) for the susceptible class.

    Specificity = TN / (TN + FP) = 1 - ME.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    float
        Specificity in [0, 1]. Returns 0.0 if no susceptible samples exist.
    """
    _, tn, fp, _ = _get_confusion_values(y_true, y_pred, resistant_label)
    denom = tn + fp
    # Builtin float keeps the ``-> float`` contract when counts are NumPy ints.
    return float(tn / denom) if denom > 0 else 0.0


def categorical_agreement(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> float:
    """Categorical agreement (accuracy) as reported in AST studies.

    CA = (TP + TN) / N, computed as the fraction of positions where the
    predicted label equals the true label.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.

    Returns
    -------
    float
        Agreement rate in [0, 1]. Returns 0.0 for empty input.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have different lengths.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Explicit check: NumPy would otherwise broadcast a length-1 array
    # against a length-n one and report a meaningless agreement.
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"y_true and y_pred have inconsistent lengths: "
            f"{len(y_true)} != {len(y_pred)}."
        )
    if len(y_true) == 0:
        return 0.0
    return float(np.mean(y_true == y_pred))


def vme_me_curve(
    y_true: np.ndarray,
    y_score: np.ndarray,
    resistant_label: int = 1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """VME and ME rates at varying decision thresholds.

    Useful for selecting an optimal threshold balancing VME against ME.
    A sample is predicted resistant at threshold ``t`` when its score is
    ``>= t``; each distinct score is used as one threshold.

    Parameters
    ----------
    y_true : array-like
        True binary labels.
    y_score : array-like
        Predicted scores (e.g., probabilities for the resistant class).
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    vme_rates : np.ndarray
        VME rates at each threshold (0.0 where no resistant samples exist).
    me_rates : np.ndarray
        ME rates at each threshold (0.0 where no susceptible samples exist).
    thresholds : np.ndarray
        Decision thresholds (sorted ascending).

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_score`` have different lengths.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    if len(y_true) != len(y_score):
        raise ValueError(
            f"y_true and y_score have inconsistent lengths: "
            f"{len(y_true)} != {len(y_score)}."
        )

    y_binary = (y_true == resistant_label).astype(np.int64)
    total_pos = int(y_binary.sum())
    total_neg = int(y_binary.size - total_pos)

    # np.unique already returns its values sorted ascending.
    thresholds = np.unique(y_score)
    if thresholds.size == 0:
        # Separate arrays so callers mutating one rate array don't alias the other.
        return np.empty(0, dtype=float), np.empty(0, dtype=float), thresholds

    # Stable sort keeps label order within tied scores deterministic.
    order = np.argsort(y_score, kind="mergesort")
    sorted_score = y_score[order]
    sorted_pos = y_binary[order]
    sorted_neg = 1 - sorted_pos

    # Prefix sums with a leading 0 so cum_x[k] counts the first k samples.
    cum_pos = np.concatenate(([0], np.cumsum(sorted_pos)))
    cum_neg = np.concatenate(([0], np.cumsum(sorted_neg)))

    # Number of samples strictly below each threshold (predicted negative).
    left = np.searchsorted(sorted_score, thresholds, side="left")
    fn = cum_pos[left]
    tn = cum_neg[left]
    tp = total_pos - fn
    fp = total_neg - tn

    # np.maximum(..., 1) avoids division by zero; np.where then masks those
    # entries back to 0.0.
    vme_rates = np.where((fn + tp) > 0, fn / np.maximum(fn + tp, 1), 0.0)
    me_rates = np.where((fp + tn) > 0, fp / np.maximum(fp + tn, 1), 0.0)

    return vme_rates.astype(float), me_rates.astype(float), thresholds


def amr_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> dict:
    """Full AMR classification report.

    Returns all clinical metrics in a single dictionary.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    dict
        Dictionary with keys: vme, me, sensitivity, specificity,
        categorical_agreement (builtin ``float``), and n_resistant,
        n_susceptible, n_total (builtin ``int``).

    Examples
    --------
    >>> report = amr_classification_report([1, 1, 0, 0], [1, 0, 0, 1])
    >>> report["vme"]
    0.5
    """
    y_true = np.asarray(y_true)
    tp, tn, fp, fn = _get_confusion_values(y_true, y_pred, resistant_label)
    # Coerce to builtin scalars so the report holds plain Python values;
    # NumPy >= 2 would otherwise repr rates as ``np.float64(...)``.
    return {
        "vme": float(very_major_error_rate(y_true, y_pred, resistant_label)),
        "me": float(major_error_rate(y_true, y_pred, resistant_label)),
        "sensitivity": float(sensitivity_score(y_true, y_pred, resistant_label)),
        "specificity": float(specificity_score(y_true, y_pred, resistant_label)),
        "categorical_agreement": float(categorical_agreement(y_true, y_pred)),
        "n_resistant": int(tp + fn),
        "n_susceptible": int(tn + fp),
        "n_total": len(y_true),
    }


[docs] def amr_multilabel_report( y_true: pd.DataFrame, y_pred: pd.DataFrame, *, resistant_label: int = 1, as_dataframe: bool = False, ) -> dict | pd.DataFrame: """AMR classification report for multiple antibiotics. Computes per-drug VME, ME, sensitivity, specificity, and categorical agreement, plus a macro-average across all drugs. Parameters ---------- y_true : pd.DataFrame True binary labels with one column per antibiotic. y_pred : pd.DataFrame Predicted binary labels with matching columns. resistant_label : int, default=1 Label value representing the resistant class. as_dataframe : bool, default=False If ``True``, return a :class:`~pandas.DataFrame` instead of a nested dict. Returns ------- dict or pd.DataFrame Per-drug metrics plus a ``"macro_avg"`` entry. When *as_dataframe* is ``True``, rows are drugs + ``"macro_avg"`` and columns are metric names. Examples -------- >>> report = amr_multilabel_report(y_true, y_pred, as_dataframe=True) >>> report.loc["macro_avg", "vme"] 0.15 """ drugs = [c for c in y_true.columns if c in y_pred.columns] if not drugs: raise ValueError("No common columns between y_true and y_pred.") reports: dict[str, dict] = {} for drug in drugs: yt = y_true[drug] yp = y_pred[drug] # Drop rows where either side is NaN valid = yt.notna() & yp.notna() yt = yt[valid] yp = yp[valid] if len(yt) == 0: continue reports[drug] = amr_classification_report( yt.to_numpy(), yp.to_numpy(), resistant_label=resistant_label ) reports["macro_avg"] = _compute_macro_average(reports) if as_dataframe: return pd.DataFrame(reports).T return reports
def _compute_macro_average(reports: dict[str, dict]) -> dict:
    """Collapse per-drug reports into one macro-averaged summary.

    Each rate metric is the unweighted mean over drugs (0.0 when ``reports``
    is empty); the sample-count fields are summed across drugs.
    """
    rate_keys = ("vme", "me", "sensitivity", "specificity", "categorical_agreement")
    count_keys = ("n_resistant", "n_susceptible", "n_total")
    per_drug = list(reports.values())

    macro: dict[str, float | int] = {
        key: (float(np.mean([rep[key] for rep in per_drug])) if per_drug else 0.0)
        for key in rate_keys
    }
    macro.update({key: sum(rep[key] for rep in per_drug) for key in count_keys})
    return macro


# Pre-built sklearn scorers for cross_val_score / GridSearchCV
vme_scorer = make_scorer(very_major_error_rate, greater_is_better=False)
"""Scorer that minimizes VME. Use with ``cross_val_score`` or ``GridSearchCV``."""

me_scorer = make_scorer(major_error_rate, greater_is_better=False)
"""Scorer that minimizes ME. Use with ``cross_val_score`` or ``GridSearchCV``."""