"""AMR-specific evaluation metrics for clinical microbiology.
Provides Very Major Error (VME), Major Error (ME), sensitivity, specificity,
and categorical agreement metrics following EUCAST conventions.
In AMR prediction:
- VME (Very Major Error): resistant isolates classified as susceptible (dangerous)
- ME (Major Error): susceptible isolates classified as resistant (wasteful)
"""
from __future__ import annotations
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, make_scorer
def _get_confusion_values(
y_true: np.ndarray,
y_pred: np.ndarray,
resistant_label: int = 1,
) -> tuple[int, int, int, int]:
"""Extract TP, TN, FP, FN from predictions.
Parameters
----------
y_true : array-like
True labels.
y_pred : array-like
Predicted labels.
resistant_label : int, default=1
Label value representing the resistant class.
Returns
-------
tp, tn, fp, fn : int
Confusion matrix values where positive = resistant.
"""
y_true = np.asarray(y_true)
y_pred = np.asarray(y_pred)
labels = sorted(set(y_true) | set(y_pred))
if len(labels) < 2:
# Single-class edge case
if labels[0] == resistant_label:
tp = np.sum(y_pred == resistant_label)
fn = np.sum(y_pred != resistant_label)
return tp, 0, 0, fn
else:
tn = np.sum(y_pred != resistant_label)
fp = np.sum(y_pred == resistant_label)
return 0, tn, fp, 0
if len(labels) > 2:
raise ValueError(
f"Expected binary labels, got {len(labels)} unique values: {labels}. "
f"Encode labels to binary before using AMR metrics."
)
if resistant_label not in labels:
# resistant_label never appears - treat as all-susceptible
tn = len(y_true)
return 0, tn, 0, 0
cm = confusion_matrix(y_true, y_pred, labels=labels)
# Find index of resistant label
r_idx = labels.index(resistant_label)
s_idx = 1 - r_idx
tp = cm[r_idx, r_idx]
fn = cm[r_idx, s_idx]
fp = cm[s_idx, r_idx]
tn = cm[s_idx, s_idx]
return tp, tn, fp, fn
[docs]
def very_major_error_rate(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> float:
    """Very Major Error rate: resistant isolates classified as susceptible.

    VME = FN / (FN + TP), i.e., the miss rate for resistant samples.
    This is the most dangerous error type in clinical microbiology.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class; any other label is
        treated as susceptible.

    Returns
    -------
    float
        VME rate in [0, 1]. Returns 0.0 if no resistant samples exist.

    Raises
    ------
    ValueError
        If the inputs contain more than two distinct labels.

    Examples
    --------
    >>> very_major_error_rate([1, 1, 0, 0], [0, 1, 0, 0])
    0.5
    """
    tp, _, _, fn = _get_confusion_values(y_true, y_pred, resistant_label)
    n_resistant = tp + fn
    if n_resistant == 0:
        return 0.0
    # Always hand back a plain Python float: NumPy >= 2 reprs np.float64 as
    # "np.float64(0.5)", which would also break the doctest above.
    return float(fn / n_resistant)
[docs]
def major_error_rate(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> float:
    """Major Error rate: susceptible isolates classified as resistant.

    ME = FP / (FP + TN), i.e., the false alarm rate for susceptible samples.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class; any other label is
        treated as susceptible.

    Returns
    -------
    float
        ME rate in [0, 1]. Returns 0.0 if no susceptible samples exist.

    Raises
    ------
    ValueError
        If the inputs contain more than two distinct labels.

    Examples
    --------
    >>> major_error_rate([1, 1, 0, 0], [1, 1, 1, 0])
    0.5
    """
    _, tn, fp, _ = _get_confusion_values(y_true, y_pred, resistant_label)
    n_susceptible = fp + tn
    if n_susceptible == 0:
        return 0.0
    # Plain Python float for a consistent return type (and a NumPy-2-proof
    # doctest repr).
    return float(fp / n_susceptible)
[docs]
def sensitivity_score(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> float:
    """Sensitivity (recall) for the resistant class.

    Sensitivity = TP / (TP + FN) = 1 - VME whenever at least one resistant
    sample exists. With no resistant samples both this function and
    ``very_major_error_rate`` return 0.0, so the identity does not hold
    in that case.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    float
        Sensitivity in [0, 1]. Returns 0.0 if no resistant samples exist.

    Raises
    ------
    ValueError
        If the inputs contain more than two distinct labels.
    """
    tp, _, _, fn = _get_confusion_values(y_true, y_pred, resistant_label)
    n_resistant = tp + fn
    if n_resistant == 0:
        return 0.0
    # Plain Python float rather than a NumPy scalar, matching the 0.0 branch.
    return float(tp / n_resistant)
[docs]
def specificity_score(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> float:
    """Specificity (true negative rate) for the susceptible class.

    Specificity = TN / (TN + FP) = 1 - ME whenever at least one susceptible
    sample exists. With no susceptible samples both this function and
    ``major_error_rate`` return 0.0, so the identity does not hold in that
    case.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    float
        Specificity in [0, 1]. Returns 0.0 if no susceptible samples exist.

    Raises
    ------
    ValueError
        If the inputs contain more than two distinct labels.
    """
    _, tn, fp, _ = _get_confusion_values(y_true, y_pred, resistant_label)
    n_susceptible = tn + fp
    if n_susceptible == 0:
        return 0.0
    # Plain Python float rather than a NumPy scalar, matching the 0.0 branch.
    return float(tn / n_susceptible)
[docs]
def categorical_agreement(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> float:
    """Categorical agreement (accuracy) as reported in AST studies.

    CA = (TP + TN) / N, computed as the fraction of positions where the
    predicted label equals the true label.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels, same shape as ``y_true``.

    Returns
    -------
    float
        Agreement rate in [0, 1]. Returns 0.0 for empty input.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have different shapes. Without this
        check NumPy would silently broadcast e.g. a length-1 ``y_pred``.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, got "
            f"{y_true.shape} and {y_pred.shape}."
        )
    if y_true.size == 0:
        return 0.0
    return float(np.mean(y_true == y_pred))
[docs]
def vme_me_curve(
    y_true: np.ndarray,
    y_score: np.ndarray,
    resistant_label: int = 1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """VME and ME rates at varying decision thresholds.

    Useful for selecting an optimal threshold balancing VME against ME.
    At threshold ``t`` a sample is predicted resistant when its score is
    ``>= t`` and susceptible otherwise. Every distinct score is used as a
    threshold, so the lowest threshold predicts every sample resistant.

    Parameters
    ----------
    y_true : array-like
        True binary labels.
    y_score : array-like
        Predicted scores (e.g., probabilities for the resistant class),
        same shape as ``y_true``.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    vme_rates : np.ndarray
        VME rates at each threshold (all 0.0 if no resistant samples exist).
    me_rates : np.ndarray
        ME rates at each threshold (all 0.0 if no susceptible samples exist).
    thresholds : np.ndarray
        Distinct scores used as decision thresholds, sorted ascending.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_score`` have different shapes.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    if y_true.shape != y_score.shape:
        raise ValueError(
            f"y_true and y_score must have the same shape, got "
            f"{y_true.shape} and {y_score.shape}."
        )
    is_resistant = (y_true == resistant_label).astype(np.int64)
    total_pos = int(is_resistant.sum())
    total_neg = int(is_resistant.size - total_pos)

    # np.unique already returns its values sorted ascending.
    thresholds = np.unique(y_score)
    if thresholds.size == 0:
        empty = np.empty(0, dtype=float)
        return empty, empty, thresholds

    order = np.argsort(y_score, kind="mergesort")
    sorted_score = y_score[order]
    sorted_pos = is_resistant[order]
    cum_pos = np.concatenate(([0], np.cumsum(sorted_pos)))
    cum_neg = np.concatenate(([0], np.cumsum(1 - sorted_pos)))

    # Number of samples strictly below each threshold (predicted susceptible).
    below = np.searchsorted(sorted_score, thresholds, side="left")
    fn = cum_pos[below]
    tn = cum_neg[below]
    fp = total_neg - tn

    # FN + TP == total_pos and FP + TN == total_neg at every threshold.
    vme_rates = fn / total_pos if total_pos > 0 else np.zeros(thresholds.size)
    me_rates = fp / total_neg if total_neg > 0 else np.zeros(thresholds.size)
    return vme_rates.astype(float), me_rates.astype(float), thresholds
[docs]
def amr_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    resistant_label: int = 1,
) -> dict:
    """Full AMR classification report.

    Returns all clinical metrics in a single dictionary. The confusion
    counts are computed once and every rate is derived from them:
    VME = FN / (FN + TP), ME = FP / (FP + TN), sensitivity = TP / (TP + FN),
    specificity = TN / (TN + FP).

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    resistant_label : int, default=1
        Label value representing the resistant class.

    Returns
    -------
    dict
        Dictionary with keys vme, me, sensitivity, specificity and
        categorical_agreement (plain Python floats; a rate whose denominator
        is zero is reported as 0.0), plus n_resistant, n_susceptible and
        n_total (ints).

    Raises
    ------
    ValueError
        If the inputs contain more than two distinct labels.

    Examples
    --------
    >>> report = amr_classification_report([1, 1, 0, 0], [1, 0, 0, 1])
    >>> report["vme"]
    0.5
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp, tn, fp, fn = _get_confusion_values(y_true, y_pred, resistant_label)
    n_resistant = int(tp + fn)
    n_susceptible = int(tn + fp)

    def _rate(numerator, denominator) -> float:
        # Zero denominator -> 0.0, the same convention as the metric functions.
        return float(numerator / denominator) if denominator > 0 else 0.0

    return {
        "vme": _rate(fn, n_resistant),
        "me": _rate(fp, n_susceptible),
        "sensitivity": _rate(tp, n_resistant),
        "specificity": _rate(tn, n_susceptible),
        "categorical_agreement": float(categorical_agreement(y_true, y_pred)),
        "n_resistant": n_resistant,
        "n_susceptible": n_susceptible,
        "n_total": len(y_true),
    }
[docs]
def amr_multilabel_report(
    y_true: pd.DataFrame,
    y_pred: pd.DataFrame,
    *,
    resistant_label: int = 1,
    as_dataframe: bool = False,
) -> dict | pd.DataFrame:
    """AMR classification report for multiple antibiotics.

    Computes per-drug VME, ME, sensitivity, specificity, and categorical
    agreement, plus a macro-average across all drugs.

    Rows of ``y_true`` and ``y_pred`` are paired by position (as in
    scikit-learn), not by index label, so the two frames may carry different
    indexes (e.g. sample IDs vs. a RangeIndex). Reindex ``y_pred`` to
    ``y_true.index`` beforehand if label alignment is wanted. For each drug,
    rows where either side is missing (NaN/None) are dropped. Drugs with no
    remaining rows are omitted from the report.

    Parameters
    ----------
    y_true : pd.DataFrame
        True binary labels with one column per antibiotic.
    y_pred : pd.DataFrame
        Predicted binary labels with matching columns and the same number
        of rows.
    resistant_label : int, default=1
        Label value representing the resistant class.
    as_dataframe : bool, default=False
        If ``True``, return a :class:`~pandas.DataFrame` instead of a
        nested dict.

    Returns
    -------
    dict or pd.DataFrame
        Per-drug metrics plus a ``"macro_avg"`` entry. When
        *as_dataframe* is ``True``, rows are drugs + ``"macro_avg"`` and
        columns are metric names.

    Raises
    ------
    ValueError
        If the frames share no columns or have different numbers of rows.

    Examples
    --------
    >>> y_true = pd.DataFrame({"AMP": [1, 1, 0, 0], "CIP": [1, 0, 0, 0]})
    >>> y_pred = pd.DataFrame({"AMP": [1, 0, 0, 0], "CIP": [1, 0, 1, 0]})
    >>> report = amr_multilabel_report(y_true, y_pred, as_dataframe=True)
    >>> float(report.loc["macro_avg", "vme"])
    0.25
    """
    drugs = [c for c in y_true.columns if c in y_pred.columns]
    if not drugs:
        raise ValueError("No common columns between y_true and y_pred.")
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"y_true and y_pred have different numbers of rows: "
            f"{len(y_true)} != {len(y_pred)}."
        )
    reports: dict[str, dict] = {}
    for drug in drugs:
        # Work on positional arrays so the missing-value mask and the values
        # are always paired the same way (no label-vs-position mix-up).
        yt = y_true[drug].to_numpy()
        yp = y_pred[drug].to_numpy()
        valid = ~(pd.isna(yt) | pd.isna(yp))
        if not valid.any():
            continue
        reports[drug] = amr_classification_report(
            yt[valid], yp[valid], resistant_label=resistant_label
        )
    reports["macro_avg"] = _compute_macro_average(reports)
    if as_dataframe:
        return pd.DataFrame(reports).T
    return reports
def _compute_macro_average(reports: dict[str, dict]) -> dict:
"""Compute macro-averaged metrics across per-drug reports."""
metric_keys = ["vme", "me", "sensitivity", "specificity", "categorical_agreement"]
macro: dict[str, float | int] = {}
for key in metric_keys:
values = [r[key] for r in reports.values()]
macro[key] = float(np.mean(values)) if values else 0.0
for key in ["n_resistant", "n_susceptible", "n_total"]:
macro[key] = sum(r[key] for r in reports.values())
return macro
# Pre-built sklearn scorers for cross_val_score / GridSearchCV.
# Both wrap the metric functions without extra kwargs, so the default
# resistant_label=1 applies: targets must use 1 for the resistant class.
# Per scikit-learn's make_scorer documentation, greater_is_better=False makes
# the scorer return the negated metric, so cross-validation reports -VME / -ME
# (values closer to 0 are better) and "higher is better" model selection works.
vme_scorer = make_scorer(very_major_error_rate, greater_is_better=False)
"""Scorer that minimizes VME. Use with ``cross_val_score`` or ``GridSearchCV``."""
me_scorer = make_scorer(major_error_rate, greater_is_better=False)
"""Scorer that minimizes ME. Use with ``cross_val_score`` or ``GridSearchCV``."""