Source code for maldiamrkit.susceptibility.breakpoint

"""Clinical breakpoint tables for MIC interpretation.

A :class:`BreakpointTable` maps each ``(species, drug)`` pair to ``S ≤ s_le``
and ``R > r_gt`` thresholds, optionally with an Area of Technical Uncertainty
(ATU) range. Categorisation:

- ``mic ≤ s_le``  →  ``"S"`` (Susceptible, standard dosing)
- ``mic > r_gt``  →  ``"R"`` (Resistant)
- otherwise       →  ``"I"`` (Susceptible, increased exposure -- modern EUCAST)

The ATU flag is *orthogonal* to S/I/R: it marks MICs that fall in a zone where
assay variability can flip the call. Treat it as an "investigate further"
warning, not a third clinical category.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from importlib import resources
from pathlib import Path
from typing import Iterable, Sequence

import numpy as np
import pandas as pd
import yaml

# Matches bundled table filenames such as "eucast_v16.0.yaml"; the named
# group "version" captures the digits-and-dots version string.
_VERSION_FILENAME_RE = re.compile(r"^eucast_v(?P<version>[\d.]+)\.yaml$")
# Columns every breakpoint row must provide.
_REQUIRED_ROW_FIELDS = ("species", "drug", "s_le", "r_gt")
# Columns that may be omitted; they are filled with NaN when absent.
_OPTIONAL_ROW_FIELDS = ("atu_low", "atu_high")


@dataclass(frozen=True)
class BreakpointResult:
    """Outcome of checking one MIC value against a clinical breakpoint.

    Instances are immutable (``frozen=True``) and are built positionally or
    by keyword in the field order ``category, atu, source``.

    Attributes
    ----------
    category : {"S", "I", "R"} or None
        Clinical call: ``"S"`` (Susceptible, standard dosing), ``"I"``
        (Susceptible, increased exposure -- modern EUCAST) or ``"R"``
        (Resistant). ``None`` signals a failed lookup: either no row exists
        for the ``(species, drug)`` pair or the MIC was NaN.
    atu : bool
        ``True`` when the MIC lies inside the species/drug Area of Technical
        Uncertainty. Independent of ``category``; it is a warning flag, not
        a third clinical category.
    source : str or None
        Provenance label such as ``"EUCAST v16.0"``. ``None`` when the
        lookup failed.
    """

    category: str | None
    atu: bool
    source: str | None
class BreakpointTable:
    """Clinical breakpoint table for MIC interpretation.

    Holds a set of ``(species, drug) → (s_le, r_gt, [atu_low, atu_high])``
    rows from a single guideline release (e.g. EUCAST v16.0). Use
    :meth:`apply` for single MICs and :meth:`apply_batch` for arrays;
    :class:`~maldiamrkit.susceptibility.MICEncoder` consumes the batch API.

    Parameters
    ----------
    rows : pd.DataFrame
        DataFrame with at least the columns ``species``, ``drug``, ``s_le``,
        ``r_gt``. Optional columns: ``atu_low``, ``atu_high``.
    guideline : str, default="EUCAST"
        e.g. ``"EUCAST"``.
    version : str, default=""
        Guideline version, e.g. ``"16.0"``.
    year : int or None, default=None
        Calendar year the guideline was published.
    source : str or None, default=None
        Free-text provenance, e.g.
        ``"EUCAST Clinical Breakpoints v16.0 (2026-01-01)"``. When falsy, a
        label is derived from ``guideline`` and ``version``.

    Raises
    ------
    ValueError
        If required columns are missing, threshold values cannot be parsed
        as numbers, or any row violates ``s_le ≤ r_gt``.

    Notes
    -----
    EUCAST's literal table format is preserved: ``s_le`` is the largest MIC
    classified as ``S`` and ``r_gt`` is the largest MIC *not* classified as
    ``R``. When ``s_le == r_gt`` there is no ``I`` zone.

    Lookups are case-insensitive and ignore surrounding whitespace. If the
    same normalised ``(species, drug)`` pair appears more than once, the
    last row wins.
    """

    def __init__(
        self,
        rows: pd.DataFrame,
        *,
        guideline: str = "EUCAST",
        version: str = "",
        year: int | None = None,
        source: str | None = None,
    ) -> None:
        self._rows = self._validate_rows(rows)
        self.guideline = guideline
        self.version = version
        self.year = year
        # guideline/version must be set before the default label is derived.
        self.source = source or self._default_source()
        # Normalised (species, drug) -> row label; later duplicates overwrite
        # earlier ones, exactly as a plain dict build would.
        self._lookup: dict[tuple[str, str], int] = {
            self._make_key(sp, dr): idx
            for idx, sp, dr in zip(
                self._rows.index, self._rows["species"], self._rows["drug"]
            )
        }

    def __repr__(self) -> str:
        n = len(self._rows)
        return (
            f"BreakpointTable({self.guideline} v{self.version}, "
            f"{n} row{'s' if n != 1 else ''})"
        )

    def __len__(self) -> int:
        return len(self._rows)

    @property
    def rows(self) -> pd.DataFrame:
        """Return a copy of the underlying breakpoint rows."""
        return self._rows.copy()

    def species(self) -> list[str]:
        """List unique species present in the table, sorted."""
        return sorted(self._rows["species"].unique().tolist())

    def drugs(self) -> list[str]:
        """List unique drugs present in the table, sorted."""
        return sorted(self._rows["drug"].unique().tolist())

    def apply(self, species: str, drug: str, mic: float | None) -> BreakpointResult:
        """Categorise a single MIC value against the table.

        Parameters
        ----------
        species : str
            Bacterial species, e.g. ``"Klebsiella pneumoniae"``. Matched
            case-insensitively against the table.
        drug : str
            Antibiotic name. Matched case-insensitively.
        mic : float or None
            MIC value in mg/L (linear scale, not ``log2``). ``None``,
            ``pd.NA`` and NaN of any numeric type return a result with
            ``category=None``.

        Returns
        -------
        BreakpointResult
            ``category``/``source`` are both ``None`` when no row matches
            the ``(species, drug)`` pair; ``category`` alone is ``None``
            when the row exists but the MIC is missing.
        """
        idx = self._lookup.get(self._make_key(species, drug))
        if idx is None:
            return BreakpointResult(category=None, atu=False, source=None)

        value = self._coerce_mic(mic)
        if value is None:
            return BreakpointResult(category=None, atu=False, source=self.source)

        row = self._rows.loc[idx]
        return BreakpointResult(
            category=self._categorise(value, float(row.s_le), float(row.r_gt)),
            atu=self._in_atu(value, row.atu_low, row.atu_high),
            source=self.source,
        )

    def apply_batch(
        self,
        species: str | Sequence[str] | np.ndarray | pd.Series,
        drug: str | Sequence[str] | np.ndarray | pd.Series,
        mic: Sequence[float] | np.ndarray | pd.Series,
    ) -> pd.DataFrame:
        """Categorise an array of MIC values.

        ``species`` and ``drug`` may be scalars (broadcast to all rows) or
        arrays of the same length as ``mic``.

        Parameters
        ----------
        species : str or array-like
            Species per sample, or a single species applied to all.
        drug : str or array-like
            Drug per sample, or a single drug applied to all.
        mic : array-like
            MIC values in mg/L (linear scale).

        Returns
        -------
        pd.DataFrame
            Columns: ``category`` (``"S"``/``"I"``/``"R"`` or ``None``),
            ``atu`` (bool), ``source`` (``None`` for unmatched rows).

        Raises
        ------
        ValueError
            If an array-valued ``species`` or ``drug`` does not have the
            same length as ``mic``.
        """
        mic_arr = pd.Series(mic).astype(float).to_numpy()
        n = len(mic_arr)
        species_arr = _broadcast(species, n, "species")
        drug_arr = _broadcast(drug, n, "drug")

        # Pre-allocated arrays keep the column dtypes fixed even when n == 0.
        categories = np.full(n, None, dtype=object)
        atu_flags = np.zeros(n, dtype=bool)
        sources = np.full(n, None, dtype=object)
        for i, (sp, dr, value) in enumerate(zip(species_arr, drug_arr, mic_arr)):
            res = self.apply(sp, dr, value)
            categories[i] = res.category
            atu_flags[i] = res.atu
            sources[i] = res.source
        return pd.DataFrame(
            {"category": categories, "atu": atu_flags, "source": sources}
        )

    @classmethod
    def from_yaml(cls, path: str | Path) -> BreakpointTable:
        """Load a breakpoint table from a YAML file.

        The YAML must have keys ``guideline``, ``version``, optional ``year``
        and ``source``, and a ``rows`` list whose entries carry
        ``species, drug, s_le, r_gt`` and optionally ``atu_low, atu_high``.

        Raises
        ------
        ValueError
            If the document is not a mapping or has no ``rows`` entries.
        """
        path = Path(path)
        with path.open("r", encoding="utf-8") as f:
            # An empty document parses to None; treat it as an empty mapping.
            payload = yaml.safe_load(f) or {}
        return cls._from_payload(payload)

    @classmethod
    def from_version(cls, version: str) -> BreakpointTable:
        """Load a bundled EUCAST table by version string, e.g. ``"16.0"``.

        Surrounding whitespace and a leading ``v``/``V`` are ignored.

        Raises
        ------
        FileNotFoundError
            If no bundled table has that version.
        """
        version = str(version).strip().lstrip("vV")
        available = cls.list_available()
        if version not in available:
            raise FileNotFoundError(
                f"No bundled EUCAST v{version} table found. "
                f"Available versions: {available or '[]'}. "
                f"Generate vendored YAMLs by running the gitignored "
                f"eucast_converter/ tooling on the official EUCAST workbook."
            )
        return cls._load_bundled(f"eucast_v{version}.yaml")

    @classmethod
    def from_year(cls, year: int) -> BreakpointTable:
        """Load a bundled EUCAST table by calendar year of publication.

        EUCAST publishes annually but the version-to-year mapping isn't a
        clean function (mid-year dot releases exist). When several bundled
        versions match the same year, the highest version is returned.

        Raises
        ------
        FileNotFoundError
            If no bundled table was published in ``year``.
        """
        # Load every bundled table exactly once; the same list serves both
        # the match and the "available years" hint in the error message.
        bundled = list(cls._iter_bundled())
        candidates = [(v, table) for v, table in bundled if table.year == year]
        if not candidates:
            raise FileNotFoundError(
                f"No bundled EUCAST table found for year {year}. "
                f"Available years: {sorted({t.year for _, t in bundled if t.year})}."
            )
        return max(candidates, key=lambda item: _version_tuple(item[0]))[1]

    @classmethod
    def from_latest(cls) -> BreakpointTable:
        """Load the highest-numbered bundled EUCAST table.

        Raises
        ------
        FileNotFoundError
            If no bundled tables are available.
        """
        available = cls.list_available()
        if not available:
            raise FileNotFoundError(
                "No bundled EUCAST tables shipped with this install. "
                "Generate vendored YAMLs by running the gitignored "
                "eucast_converter/ tooling on the official EUCAST workbook."
            )
        latest = max(available, key=_version_tuple)
        return cls._load_bundled(f"eucast_v{latest}.yaml")

    @classmethod
    def list_available(cls) -> list[str]:
        """List bundled EUCAST version strings, sorted numerically.

        Returns an empty list when the bundled data directory cannot be
        found.
        """
        versions: list[str] = []
        try:
            with resources.as_file(_BUNDLED_EUCAST_DIR) as eucast_dir:
                for entry in eucast_dir.iterdir():
                    m = _VERSION_FILENAME_RE.match(entry.name)
                    if m:
                        versions.append(m.group("version"))
        except (FileNotFoundError, ModuleNotFoundError):
            # Deliberate best-effort: a missing data directory simply means
            # that no tables are bundled.
            pass
        return sorted(versions, key=_version_tuple)

    @classmethod
    def _iter_bundled(cls) -> Iterable[tuple[str, BreakpointTable]]:
        """Yield ``(version, table)`` for every bundled table, in version order."""
        for v in cls.list_available():
            yield v, cls._load_bundled(f"eucast_v{v}.yaml")

    @classmethod
    def _load_bundled(cls, filename: str) -> BreakpointTable:
        """Load one bundled YAML file by its filename."""
        with resources.as_file(_BUNDLED_EUCAST_DIR / filename) as path:
            return cls.from_yaml(path)

    @classmethod
    def _from_payload(cls, payload: dict) -> BreakpointTable:
        """Build a table from an already-parsed YAML mapping.

        Raises
        ------
        ValueError
            If ``payload`` is not a mapping or carries no ``rows`` entries.
        """
        if not isinstance(payload, dict):
            # e.g. a YAML document whose top level is a list or a scalar.
            raise ValueError(
                f"YAML payload must be a mapping, got {type(payload).__name__}."
            )
        rows_payload = payload.get("rows") or []
        if not rows_payload:
            raise ValueError("YAML payload contains no 'rows' entries.")
        # Missing optional columns are filled in by _validate_rows, which the
        # constructor always runs.
        return cls(
            rows=pd.DataFrame(rows_payload),
            guideline=payload.get("guideline", "EUCAST"),
            version=str(payload.get("version", "")),
            year=payload.get("year"),
            source=payload.get("source"),
        )

    @staticmethod
    def _validate_rows(rows: pd.DataFrame) -> pd.DataFrame:
        """Return a cleaned copy of ``rows`` with a fixed column layout.

        The copy has numeric thresholds, stripped string ``species``/``drug``
        columns, NaN-filled optional columns and a fresh ``0..n-1`` index.

        Raises
        ------
        ValueError
            If a required column is missing, a threshold cannot be parsed as
            a number, or any row has ``s_le > r_gt``.
        """
        missing = [c for c in _REQUIRED_ROW_FIELDS if c not in rows.columns]
        if missing:
            raise ValueError(
                f"Breakpoint rows missing required columns: {missing}. "
                f"Expected: {list(_REQUIRED_ROW_FIELDS)}."
            )
        out = rows.copy()
        for field in _OPTIONAL_ROW_FIELDS:
            if field not in out.columns:
                out[field] = np.nan
        # Thresholds must parse; ATU bounds are optional, so unparsable
        # values there silently become NaN ("no ATU").
        out["s_le"] = pd.to_numeric(out["s_le"], errors="raise")
        out["r_gt"] = pd.to_numeric(out["r_gt"], errors="raise")
        out["atu_low"] = pd.to_numeric(out["atu_low"], errors="coerce")
        out["atu_high"] = pd.to_numeric(out["atu_high"], errors="coerce")
        out["species"] = out["species"].astype(str).str.strip()
        out["drug"] = out["drug"].astype(str).str.strip()
        # NOTE(review): rows whose s_le or r_gt is NaN pass this check,
        # because comparisons with NaN evaluate to False.
        bad = out[out["s_le"] > out["r_gt"]]
        if not bad.empty:
            sample = bad.head(3).to_dict(orient="records")
            raise ValueError(
                f"Found {len(bad)} row(s) with s_le > r_gt (invalid). "
                f"First offenders: {sample}"
            )
        out = out.reset_index(drop=True)
        return out[["species", "drug", "s_le", "r_gt", "atu_low", "atu_high"]]

    def _default_source(self) -> str:
        """Provenance label built from ``guideline`` and, if set, ``version``."""
        return f"{self.guideline} v{self.version}" if self.version else self.guideline

    @staticmethod
    def _make_key(species, drug) -> tuple[str, str]:
        """Normalise a ``(species, drug)`` pair into a lookup key."""
        return (str(species).strip().lower(), str(drug).strip().lower())

    @staticmethod
    def _coerce_mic(mic) -> float | None:
        """Return ``mic`` as a float, or ``None`` when it is missing.

        ``None``, ``pd.NA`` and anything that converts to NaN count as
        missing. The NaN check runs after the float conversion so that
        non-``float`` NaNs (e.g. ``np.float32``) are caught too.
        """
        if mic is None or mic is pd.NA:
            return None
        value = float(mic)
        return None if np.isnan(value) else value

    @staticmethod
    def _categorise(mic: float, s_le: float, r_gt: float) -> str:
        """Map an MIC to ``"S"``, ``"I"`` or ``"R"`` using the two thresholds."""
        if mic <= s_le:
            return "S"
        if mic > r_gt:
            return "R"
        return "I"

    @staticmethod
    def _in_atu(mic: float, atu_low: float, atu_high: float) -> bool:
        """Tell whether ``mic`` falls in the ATU range.

        A missing ``atu_low`` means there is no ATU. A missing ``atu_high``
        means the ATU is the single value ``atu_low``. Otherwise the range
        is inclusive at both ends.
        """
        if pd.isna(atu_low):
            return False
        if pd.isna(atu_high):
            return bool(mic == atu_low)
        return bool((mic >= atu_low) and (mic <= atu_high))
def _broadcast(value, n: int, name: str) -> np.ndarray:
    """Expand ``value`` to an object array of length ``n``.

    Scalars (including ``str`` and ``bytes``) are repeated ``n`` times.
    Anything else is converted to an object array and must already have
    ``n`` entries along its first axis.

    Parameters
    ----------
    value : scalar or array-like
        Value to repeat, or per-sample values.
    n : int
        Required length (the length of the MIC array).
    name : str
        Argument name used in the error message.

    Raises
    ------
    ValueError
        If an array-like ``value`` does not have length ``n``.
    """
    treat_as_scalar = isinstance(value, (str, bytes)) or np.isscalar(value)
    if treat_as_scalar:
        return np.full(n, value, dtype=object)

    values = np.asarray(value, dtype=object)
    if values.shape[0] == n:
        return values
    raise ValueError(
        f"{name!r} length {values.shape[0]} does not match MIC array length {n}."
    )


def _version_tuple(version: str) -> tuple[int, ...]:
    """Turn a dotted version such as ``"16.0"`` into a sortable int tuple.

    Dot-separated parts that are not purely digits are skipped.
    """
    numeric_parts = [part for part in version.split(".") if part.isdigit()]
    return tuple(map(int, numeric_parts))


# Resource handle for the directory holding the bundled EUCAST YAML tables.
_BUNDLED_EUCAST_DIR = resources.files("maldiamrkit") / "data" / "breakpoints" / "eucast"