"""Spectral alignment and warping transformers for binned spectra."""
from __future__ import annotations
import warnings
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator, TransformerMixin
from ..detection.peak_detector import MaldiPeakDetector
from .strategies import ALIGNMENT_REGISTRY, AlignmentMethod, DTWStrategy
[docs]
class Warping(BaseEstimator, TransformerMixin):
"""
Align MALDI-TOF spectra to a reference using different strategies.
Supports multiple alignment methods for correcting mass calibration drift
in binned spectra.
Parameters
----------
peak_detector : MaldiPeakDetector, optional
Peak detector used to find peaks in spectra. If None, a default
detector is created with binary=True and prominence=1e-5.
reference : str or int, default="median"
How to choose the reference spectrum:
- "median" : median spectrum across all samples
- int : use the row at that 0-based position (``X.iloc[reference]``) as reference; must satisfy ``0 <= reference < n_samples``
method : str or AlignmentMethod, default=AlignmentMethod.shift
Alignment method:
- "shift" : global median shift
- "linear" : least-squares linear transform
- "piecewise" : local median shifts across segments
- "dtw" : dynamic time warping
- "quadratic" : quadratic polynomial fit on matched peak pairs
- "cubic" : cubic polynomial fit on matched peak pairs
- "lowess" : LOWESS (Cleveland 1979) non-linear warping
n_segments : int, default=5
Number of segments for piecewise warping.
max_shift : int, default=50
Maximum allowed shift in bins (must be >= 0). Passed to every strategy
except "dtw" (including "piecewise"); for the shift / linear / polynomial /
LOWESS methods it is also the fallback when too few peaks match.
dtw_radius : int, default=10
Radius constraint for DTW to limit warping path search space.
smooth_sigma : float, default=2.0
Gaussian smoothing parameter for piecewise segment shifts.
lowess_frac : float, default=0.3
LOWESS smoothing bandwidth (fraction of matched peaks used for
each local fit). Applies when ``method="lowess"``.
lowess_it : int, default=3
Number of LOWESS robustness iterations. Applies when
``method="lowess"``.
min_reference_peaks : int, default=5
Minimum number of peaks expected in the reference spectrum. ``fit()``
emits a UserWarning (it does not raise) when fewer peaks are detected.
n_jobs : int, default=1
Number of parallel jobs for transform. Use -1 for all available
cores, 1 for sequential processing. Parallelization is particularly
beneficial for the "dtw" method which is CPU-intensive.
Attributes
----------
ref_spec_ : np.ndarray
The fitted reference spectrum (stored after fit()).
Examples
--------
>>> from maldiamrkit.alignment import Warping
>>> warper = Warping(method="shift")
>>> warper.fit(X_train)
>>> X_aligned = warper.transform(X_test)
"""
[docs]
def __init__(
    self,
    peak_detector: MaldiPeakDetector | None = None,
    reference: str | int = "median",
    method: str | AlignmentMethod = AlignmentMethod.shift,
    n_segments: int = 5,
    max_shift: int = 50,
    dtw_radius: int = 10,
    smooth_sigma: float = 2.0,
    lowess_frac: float = 0.3,
    lowess_it: int = 3,
    min_reference_peaks: int = 5,
    n_jobs: int = 1,
) -> None:
    """Store the hyper-parameters; see the class docstring for their meaning."""
    # Substitute the default detector only when the caller passed nothing.
    # ``peak_detector or ...`` would also silently discard any detector
    # object that happens to evaluate as falsy.
    if peak_detector is None:
        peak_detector = MaldiPeakDetector(binary=True, prominence=1e-5)
    self.peak_detector = peak_detector
    self.reference = reference
    # Normalise str / enum input to an AlignmentMethod member.
    # NOTE(review): values assigned later via ``set_params`` bypass this
    # conversion, which is why ``fit`` re-validates the method.
    self.method = AlignmentMethod(method)
    self.n_segments = n_segments
    self.max_shift = max_shift
    self.dtw_radius = dtw_radius
    self.smooth_sigma = smooth_sigma
    self.lowess_frac = lowess_frac
    self.lowess_it = lowess_it
    self.min_reference_peaks = min_reference_peaks
    self.n_jobs = n_jobs
[docs]
def fit(self, X: pd.DataFrame, y=None):
    """
    Fit the transformer by selecting or computing the reference spectrum.

    Every hyper-parameter is validated before any fitted attribute is
    written, so a failed call never leaves a stale ``ref_spec_`` behind
    (which would otherwise make the estimator look fitted).

    Parameters
    ----------
    X : pd.DataFrame
        Input spectra with shape (n_samples, n_bins).
    y : array-like, optional
        Target values (ignored).

    Returns
    -------
    self : Warping
        Fitted transformer.

    Raises
    ------
    ValueError
        If the input DataFrame is empty, a numeric parameter is out of
        range, the reference index is out of bounds, the reference
        specifier is unsupported, or the warping method is unknown.

    Warns
    -----
    UserWarning
        If the reference spectrum has fewer than ``min_reference_peaks``
        detected peaks.
    """
    if X.empty:
        raise ValueError("Input DataFrame X is empty")

    # Validate numeric parameters up front, before any state is stored.
    if self.n_segments < 1:
        raise ValueError(f"n_segments must be >= 1, got {self.n_segments}")
    if self.max_shift < 0:
        raise ValueError(f"max_shift must be >= 0, got {self.max_shift}")
    if not (0.0 < self.lowess_frac <= 1.0):
        raise ValueError(f"lowess_frac must be in (0, 1], got {self.lowess_frac}")
    if self.lowess_it < 0:
        raise ValueError(f"lowess_it must be >= 0, got {self.lowess_it}")

    # Resolve the reference into a local; only commit it once all checks pass.
    reference = self.reference
    if isinstance(reference, str) and reference == "median":
        ref_spec = X.median(axis=0).to_numpy()
    elif isinstance(reference, (int, np.integer)) and not isinstance(
        reference, bool
    ):
        # Positional (iloc) row index. NumPy integers are accepted; bool is
        # rejected even though it subclasses int (True would mean row 1).
        if reference < 0 or reference >= len(X):
            raise ValueError(
                f"Reference index {self.reference} is out of bounds "
                f"for X with {len(X)} samples"
            )
        ref_spec = X.iloc[int(reference)].to_numpy()
    else:
        raise ValueError(
            f"Unsupported reference specifier: {self.reference}. "
            f"Must be 'median' or int."
        )

    # ``self.method`` may have been replaced through ``set_params`` after
    # construction, bypassing the conversion performed in ``__init__``.
    try:
        AlignmentMethod(self.method)
    except ValueError as err:
        raise ValueError(f"Unknown warping method: {self.method!r}") from err

    self.ref_spec_ = ref_spec
    # Warns (does not raise) when the reference has too few peaks.
    self._validate_reference_quality(X)
    return self
def _validate_reference_quality(self, X: pd.DataFrame):
"""Validate that the reference spectrum has sufficient quality."""
ref_peaks_df = self.peak_detector.transform(
pd.DataFrame(self.ref_spec_[np.newaxis, :], columns=X.columns)
)
n_peaks = ref_peaks_df.iloc[0].to_numpy().nonzero()[0].size
if n_peaks < self.min_reference_peaks:
warnings.warn(
f"Reference spectrum has only {n_peaks} peaks detected. "
f"Expected at least {self.min_reference_peaks}. "
f"This may result in poor alignment quality. "
f"Consider adjusting peak detection parameters or "
f"choosing a different reference.",
UserWarning,
stacklevel=2,
)
def _get_strategy(self):
    """Instantiate the alignment strategy selected by ``self.method``.

    The strategy class comes from ``ALIGNMENT_REGISTRY`` and receives only
    the hyper-parameters relevant to that method; a registered method not
    handled below is constructed with no arguments.
    """
    strategy_cls = ALIGNMENT_REGISTRY[self.method]
    method = self.method

    if method == "shift" or method == "linear":
        kwargs = {"max_shift": self.max_shift}
    elif method == "piecewise":
        kwargs = {
            "n_segments": self.n_segments,
            "smooth_sigma": self.smooth_sigma,
            "max_shift": self.max_shift,
        }
    elif method == "dtw":
        kwargs = {"dtw_radius": self.dtw_radius}
    elif method == "quadratic" or method == "cubic":
        # Both are polynomial fits; only the degree differs.
        kwargs = {
            "max_shift": self.max_shift,
            "degree": 2 if method == "quadratic" else 3,
        }
    elif method == "lowess":
        kwargs = {
            "max_shift": self.max_shift,
            "frac": self.lowess_frac,
            "it": self.lowess_it,
        }
    else:
        kwargs = {}

    return strategy_cls(**kwargs)
def _align_single_row(
    self,
    row: np.ndarray,
    peaks: np.ndarray | None,
    ref_peaks: np.ndarray,
    mz_axis: np.ndarray,
) -> tuple[np.ndarray, bool]:
    """Align a single row (helper for parallelization).

    A fresh strategy is built with :meth:`_get_strategy`. A DTW strategy
    aligns ``row`` directly against ``self.ref_spec_``; every other
    strategy uses ``align_binned`` with the supplied peaks and m/z axis.

    Returns ``(aligned_row, was_non_monotonic)``. The flag is True when a
    warning containing "non-monotonic m/z mapping" was raised during
    alignment (per the original design, from ``monotonic_interp``'s
    sort-and-deduplicate fallback). Those warnings are consumed here so
    :meth:`transform` can emit one aggregated warning instead of one per
    sample. Any *other* warning raised during alignment is re-emitted so
    it is not silently discarded.
    """
    strategy = self._get_strategy()
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always", UserWarning)
        if isinstance(strategy, DTWStrategy):
            aligned = strategy._dtw_align(row, self.ref_spec_)
        else:
            aligned = strategy.align_binned(row, peaks, ref_peaks, mz_axis)

    non_monotonic = False
    # Use the same per-module registry ``warnings.warn`` uses, so the active
    # filters (e.g. "default" = once per location) still de-duplicate.
    registry = globals().setdefault("__warningregistry__", {})
    for record in caught:
        if "non-monotonic m/z mapping" in str(record.message):
            non_monotonic = True
            continue
        warnings.warn_explicit(
            record.message,
            record.category,
            record.filename,
            record.lineno,
            registry=registry,
            source=record.source,
        )
    return aligned, non_monotonic
[docs]
def get_alignment_quality(
self, X_original: pd.DataFrame, X_aligned: pd.DataFrame | None = None
) -> pd.DataFrame:
"""
Compute alignment quality metrics.
Parameters
----------
X_original : pd.DataFrame
Original (unaligned) spectra.
X_aligned : pd.DataFrame, optional
Aligned spectra. If None, will compute by calling transform().
Returns
-------
pd.DataFrame
Quality metrics with columns:
- correlation_before: Pearson correlation with reference (before)
- correlation_after: Pearson correlation with reference (after)
- improvement: correlation_after - correlation_before
- rmse_before: RMSE with reference (before)
- rmse_after: RMSE with reference (after)
Raises
------
RuntimeError
If the transformer has not been fitted.
"""
if not hasattr(self, "ref_spec_"):
raise RuntimeError("Warping must be fitted before computing quality")
if X_aligned is None:
X_aligned = self.transform(X_original)
metrics = []
for i in range(len(X_original)):
original = X_original.iloc[i].to_numpy()
aligned = X_aligned.iloc[i].to_numpy()
# Correlation with reference (NaN when a signal has zero variance)
corr_before = np.corrcoef(original, self.ref_spec_)[0, 1]
corr_after = np.corrcoef(aligned, self.ref_spec_)[0, 1]
if np.isnan(corr_before) or np.isnan(corr_after):
warnings.warn(
f"Correlation undefined for sample {X_original.index[i]} "
f"(constant signal); defaulting to 0.0",
UserWarning,
stacklevel=2,
)
corr_before = 0.0 if np.isnan(corr_before) else corr_before
corr_after = 0.0 if np.isnan(corr_after) else corr_after
# RMSE with reference
rmse_before = np.sqrt(np.mean((original - self.ref_spec_) ** 2))
rmse_after = np.sqrt(np.mean((aligned - self.ref_spec_) ** 2))
metrics.append(
{
"correlation_before": corr_before,
"correlation_after": corr_after,
"improvement": corr_after - corr_before,
"rmse_before": rmse_before,
"rmse_after": rmse_after,
}
)
return pd.DataFrame(metrics, index=X_original.index)