# Source code for maldiamrkit.alignment.strategies

"""Shared alignment strategy classes for spectral warping.

Each strategy implements one alignment algorithm that can operate on either
binned (index-based) or raw (m/z-based) coordinate systems.
"""

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from enum import Enum

import numpy as np
from scipy.ndimage import gaussian_filter1d
from tslearn.metrics import dtw_path


class AlignmentMethod(str, Enum):
    """Supported alignment/warping methods.

    Each member is also a ``str`` whose value equals its name, so members
    compare equal to the plain method-name strings.

    Attributes
    ----------
    shift : str
        Rigid global shift alignment.
    linear : str
        Linear (affine) recalibration.
    piecewise : str
        Piecewise-linear recalibration.
    dtw : str
        Dynamic time warping alignment.
    quadratic : str
        Quadratic polynomial recalibration.
    cubic : str
        Cubic polynomial recalibration.
    lowess : str
        Non-linear LOWESS (Cleveland 1979) recalibration.
    """

    shift = "shift"
    linear = "linear"
    piecewise = "piecewise"
    dtw = "dtw"
    quadratic = "quadratic"
    cubic = "cubic"
    lowess = "lowess"
class AlignmentStrategy(ABC):
    """Base class for alignment strategies.

    Concrete strategies must implement both :meth:`align_binned` (operating
    on bin indices) and :meth:`align_raw` (operating on m/z values).
    """

    @abstractmethod
    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Align a binned spectrum row to the reference.

        Parameters
        ----------
        row : np.ndarray
            Intensity values of the spectrum to align.
        peaks : np.ndarray
            Detected peak indices in ``row``.
        ref_peaks : np.ndarray
            Detected peak indices in the reference spectrum.
        mz_axis : np.ndarray
            Array of bin positions (e.g. ``np.arange(len(row))``).

        Returns
        -------
        np.ndarray
            Aligned intensity array with the same length as ``row``.
        """

    @abstractmethod
    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Align a raw spectrum to the reference.

        Parameters
        ----------
        mz : np.ndarray
            m/z values of the spectrum to align.
        intensity : np.ndarray
            Intensity values of the spectrum to align.
        peaks_mz : np.ndarray
            Detected peak m/z positions in the sample spectrum.
        ref_peaks_mz : np.ndarray
            Detected peak m/z positions in the reference spectrum.
        ref_mz : np.ndarray
            m/z values of the reference spectrum.
        ref_intensity : np.ndarray
            Intensity values of the reference spectrum.

        Returns
        -------
        aligned_mz : np.ndarray
            Aligned m/z values.
        aligned_intensity : np.ndarray
            Aligned intensity values.
        """
def monotonic_interp(
    mz_axis: np.ndarray, new_positions: np.ndarray, row: np.ndarray
) -> np.ndarray:
    """Interpolate a spectrum onto *mz_axis* after a warping transform.

    When *new_positions* is strictly increasing a simple ``np.interp`` call
    is used. Otherwise a warning is emitted, intensities that land on the
    same warped position are averaged, and interpolation is performed on the
    sorted unique positions.

    Parameters
    ----------
    mz_axis : np.ndarray
        Target m/z (or index) grid.
    new_positions : np.ndarray
        Warped positions corresponding to each element of *row*.
    row : np.ndarray
        Intensity values to be re-mapped.

    Returns
    -------
    np.ndarray
        Interpolated intensity array on *mz_axis*; grid points outside the
        warped range are set to ``0.0``.
    """
    if np.all(np.diff(new_positions) > 0):
        return np.interp(mz_axis, new_positions, row, left=0.0, right=0.0)

    warnings.warn(
        "Warping produced non-monotonic m/z mapping for a sample. "
        "This may indicate poor alignment quality. "
        "Consider adjusting alignment parameters (e.g., reduce max_shift_da "
        "or increase n_segments).",
        UserWarning,
        stacklevel=4,
    )

    # np.unique already returns the positions sorted, so no separate argsort
    # is needed; ``inverse`` maps every original sample to its unique slot.
    unique_pos, inverse = np.unique(new_positions, return_inverse=True)
    n_unique = len(unique_pos)
    counts = np.bincount(inverse, minlength=n_unique)
    sums = np.bincount(inverse, weights=row, minlength=n_unique)
    unique_intensities = sums / counts
    return np.interp(mz_axis, unique_pos, unique_intensities, left=0.0, right=0.0)


def _nearest_ref_indices(peaks: np.ndarray, ref_peaks: np.ndarray) -> np.ndarray:
    """For each peak, find the index of the nearest reference peak (O(P log R)).

    ``ref_peaks`` is searched with ``np.searchsorted`` and therefore must be
    sorted in ascending order and non-empty. On an exact tie between the left
    and right neighbour the right neighbour is kept.
    """
    last = len(ref_peaks) - 1
    idx = np.clip(np.searchsorted(ref_peaks, peaks), 0, last)
    left = np.clip(idx - 1, 0, last)
    # Strict "<" so the insertion-point (right) neighbour wins ties.
    use_left = np.abs(ref_peaks[left] - peaks) < np.abs(ref_peaks[idx] - peaks)
    idx[use_left] = left[use_left]
    return idx


def _match_peaks_to_ref(peaks: np.ndarray, ref_peaks: np.ndarray) -> np.ndarray:
    """For each peak, compute shift to nearest reference peak."""
    matched = ref_peaks[_nearest_ref_indices(peaks, ref_peaks)]
    return matched - peaks


def _match_peak_pairs(
    peaks: np.ndarray, ref_peaks: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Match peaks to nearest reference peaks. Returns (sample, ref) arrays."""
    matched = ref_peaks[_nearest_ref_indices(peaks, ref_peaks)]
    return peaks.copy(), matched


class ShiftStrategy(AlignmentStrategy):
    """Global median shift alignment.

    Parameters
    ----------
    max_shift : float
        Maximum absolute shift that may be applied (bins for binned data,
        m/z units for raw data).
    """

    def __init__(self, max_shift: float) -> None:
        self.max_shift = max_shift

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Apply global median shift to a binned spectrum.

        The median peak-to-reference offset is rounded to a whole number of
        bins, limited to ``±max_shift`` and applied by shifting ``row`` with
        zero fill. ``row`` itself is returned when either peak list is empty.
        """
        if len(peaks) == 0 or len(ref_peaks) == 0:
            return row

        shifts = _match_peaks_to_ref(peaks, ref_peaks)
        shift = int(np.round(np.median(shifts))) if len(shifts) else 0
        # np.clip returns a float when max_shift is a float, and a float is
        # not a valid slice index. int() truncates toward zero, so the result
        # never exceeds the +/- max_shift bound.
        shift = int(np.clip(shift, -self.max_shift, self.max_shift))

        if shift > 0:
            aligned = np.zeros_like(row)
            aligned[shift:] = row[:-shift]
        elif shift < 0:
            aligned = np.zeros_like(row)
            aligned[:shift] = row[-shift:]
        else:
            aligned = row.copy()
        return aligned

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply global m/z shift to a raw spectrum.

        The median peak-to-reference offset, limited to ``±max_shift``, is
        added to ``mz``; ``intensity`` is passed through unchanged.
        """
        if len(peaks_mz) == 0 or len(ref_peaks_mz) == 0:
            return mz, intensity

        shifts = _match_peaks_to_ref(peaks_mz, ref_peaks_mz)
        shift_da = np.median(shifts) if len(shifts) else 0.0
        shift_da = np.clip(shift_da, -self.max_shift, self.max_shift)
        return mz + shift_da, intensity


def _robust_linear_fit(
    sample: np.ndarray,
    ref: np.ndarray,
    residual_threshold: float = 3.0,
) -> tuple[float, float]:
    """Fit mz' = a*mz + b with MAD-based outlier rejection.

    Performs a least-squares fit, then rejects peak pairs whose centred
    residual exceeds ``residual_threshold * 1.4826 * MAD`` and refits.
    The refit is skipped when there are at most two pairs, when the MAD is
    zero, or when fewer than two inliers remain.

    Parameters
    ----------
    sample, ref : np.ndarray
        Matched peak positions (sample and reference).
    residual_threshold : float, default=3.0
        Number of MAD-based sigma units for outlier rejection.

    Returns
    -------
    a, b : float
        Coefficients of the linear fit ``ref = a * sample + b``.
    """
    design = np.vstack([sample, np.ones_like(sample)]).T
    a, b = np.linalg.lstsq(design, ref, rcond=1e-10)[0]

    if len(sample) > 2:
        residuals = ref - (a * sample + b)
        centred = np.abs(residuals - np.median(residuals))
        mad = np.median(centred)
        if mad > 0:
            cutoff = residual_threshold * 1.4826 * mad
            inlier_mask = centred <= cutoff
            if inlier_mask.sum() >= 2:
                a, b = np.linalg.lstsq(
                    design[inlier_mask], ref[inlier_mask], rcond=1e-10
                )[0]
    return a, b


class LinearStrategy(AlignmentStrategy):
    """Least-squares linear transformation alignment with outlier rejection.

    Falls back to :class:`ShiftStrategy` when fewer than two sample or
    reference peaks are available.
    """

    def __init__(self, max_shift: float) -> None:
        self.max_shift = max_shift
        self._fallback = ShiftStrategy(max_shift)

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Apply linear transformation alignment to a binned spectrum."""
        if len(peaks) < 2 or len(ref_peaks) < 2:
            return self._fallback.align_binned(row, peaks, ref_peaks, mz_axis)

        sample, ref = _match_peak_pairs(peaks, ref_peaks)
        a, b = _robust_linear_fit(sample, ref)
        new_positions = a * mz_axis + b
        return monotonic_interp(mz_axis, new_positions, row)

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply linear m/z transformation to a raw spectrum."""
        if len(peaks_mz) < 2 or len(ref_peaks_mz) < 2:
            return self._fallback.align_raw(
                mz, intensity, peaks_mz, ref_peaks_mz, ref_mz, ref_intensity
            )

        sample, ref = _match_peak_pairs(peaks_mz, ref_peaks_mz)
        a, b = _robust_linear_fit(sample, ref)
        return a * mz + b, intensity


class PiecewiseStrategy(AlignmentStrategy):
    """Piecewise linear alignment with smoothed local shifts.

    Parameters
    ----------
    n_segments : int
        Number of quantile-based segments the matched peaks are split into.
    smooth_sigma : float
        Gaussian smoothing width for the interpolated shift curve; values
        ``<= 0`` disable smoothing.
    max_shift : float
        Stored on the instance; not used by the methods of this class.
    """

    def __init__(self, n_segments: int, smooth_sigma: float, max_shift: float) -> None:
        self.n_segments = n_segments
        self.smooth_sigma = smooth_sigma
        self.max_shift = max_shift

    def _compute_segment_shifts(
        self, sample: np.ndarray, ref: np.ndarray
    ) -> tuple[list, list]:
        """Compute per-segment median positions and shifts.

        Segment boundaries are the quantiles of ``sample``; empty segments
        are skipped, so the returned lists may be shorter than
        ``n_segments``.
        """
        quantiles = np.linspace(0, 1, self.n_segments + 1)
        boundaries = np.quantile(sample, quantiles)

        seg_x, seg_shift = [], []
        for q in range(self.n_segments):
            lower, upper = boundaries[q], boundaries[q + 1]
            if q == self.n_segments - 1:
                # Last segment is closed on the right so the maximum is kept.
                mask = (sample >= lower) & (sample <= upper)
            else:
                mask = (sample >= lower) & (sample < upper)
            if mask.sum() > 0:
                seg_x.append(np.median(sample[mask]))
                seg_shift.append(np.median(ref[mask] - sample[mask]))
        return seg_x, seg_shift

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Apply piecewise alignment to a binned spectrum.

        ``row`` itself is returned when no peaks or no segments are
        available.
        """
        if len(peaks) == 0 or len(ref_peaks) == 0:
            return row

        sample, ref = _match_peak_pairs(peaks, ref_peaks)
        seg_x, seg_shift = self._compute_segment_shifts(sample, ref)
        if len(seg_x) == 0:
            return row

        shift_interp = np.interp(
            mz_axis, seg_x, seg_shift, left=seg_shift[0], right=seg_shift[-1]
        )
        if self.smooth_sigma > 0:
            shift_interp = gaussian_filter1d(
                shift_interp, sigma=self.smooth_sigma, mode="nearest"
            )

        new_positions = mz_axis + shift_interp
        return monotonic_interp(mz_axis, new_positions, row)

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply piecewise m/z transformation to a raw spectrum.

        ``smooth_sigma`` is divided by the median m/z spacing to obtain a
        smoothing width in samples. Smoothing is skipped when that spacing
        cannot be determined (fewer than two points) or is not positive.
        """
        if len(peaks_mz) == 0 or len(ref_peaks_mz) == 0:
            return mz, intensity

        sample, ref = _match_peak_pairs(peaks_mz, ref_peaks_mz)
        seg_x, seg_shift = self._compute_segment_shifts(sample, ref)
        if len(seg_x) == 0:
            return mz, intensity

        shift_interp = np.interp(
            mz, seg_x, seg_shift, left=seg_shift[0], right=seg_shift[-1]
        )

        # Guard: with < 2 points the median spacing is NaN, and with zero
        # spacing the ratio is infinite; either would make int() raise.
        if self.smooth_sigma > 0 and len(mz) > 1:
            mz_spacing = np.median(np.diff(mz))
            if np.isfinite(mz_spacing) and mz_spacing > 0:
                sigma_points = min(
                    int(self.smooth_sigma / mz_spacing), len(mz) // 4
                )
                if sigma_points > 1:
                    shift_interp = gaussian_filter1d(
                        shift_interp, sigma=sigma_points, mode="nearest"
                    )

        return mz + shift_interp, intensity


class DTWStrategy(AlignmentStrategy):
    """Dynamic time warping alignment.

    Parameters
    ----------
    dtw_radius : int
        Radius passed to ``dtw_path`` as ``sakoe_chiba_radius``.
    """

    def __init__(self, dtw_radius: int) -> None:
        self.dtw_radius = dtw_radius

    def _dtw_align(self, query: np.ndarray, reference: np.ndarray) -> np.ndarray:
        """Core DTW alignment returning aligned intensity.

        Every query sample is assigned to the reference index it is paired
        with on the warping path; samples sharing a reference index are
        averaged, and reference indices that receive no sample stay zero.
        """
        path, _ = dtw_path(
            query,
            reference,
            global_constraint="sakoe_chiba",
            sakoe_chiba_radius=self.dtw_radius,
        )

        # Accumulate in a floating dtype: with an integer reference array the
        # sums and the averaged result would otherwise be silently truncated.
        ref_dtype = np.asarray(reference).dtype
        acc_dtype = ref_dtype if np.issubdtype(ref_dtype, np.floating) else np.float64
        aligned_sum = np.zeros_like(reference, dtype=acc_dtype)
        aligned_count = np.zeros_like(reference, dtype=acc_dtype)

        for i, j in path:
            if 0 <= j < len(aligned_sum):
                aligned_sum[j] += query[i]
                aligned_count[j] += 1

        aligned = np.zeros_like(reference, dtype=acc_dtype)
        mask = aligned_count > 0
        aligned[mask] = aligned_sum[mask] / aligned_count[mask]
        return aligned

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Raise because DTW binned alignment requires the full reference spectrum.

        DTW binned alignment is handled directly by the Warping class which
        has access to the stored reference spectrum.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "DTW binned alignment is handled directly by the Warping class"
        )

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply DTW alignment to a raw spectrum.

        The sample is first resampled onto ``ref_mz`` by linear
        interpolation, then warped against ``ref_intensity``. The returned
        m/z axis is ``ref_mz``.
        """
        query_intensity = np.interp(ref_mz, mz, intensity)
        aligned_intensity = self._dtw_align(query_intensity, ref_intensity)
        return ref_mz, aligned_intensity


class PolynomialStrategy(AlignmentStrategy):
    """Polynomial alignment via ``numpy.polyfit`` on matched peak pairs.

    Used for both the ``"quadratic"`` (``degree=2``) and ``"cubic"``
    (``degree=3``) warping methods. When fewer than ``degree + 1`` sample or
    reference peaks are available, falls back to :class:`ShiftStrategy` to
    avoid an under-determined fit.

    Parameters
    ----------
    max_shift : float
        Maximum allowed shift passed through to the shift fallback.
    degree : int
        Polynomial degree. ``2`` for quadratic, ``3`` for cubic.

    Raises
    ------
    ValueError
        If ``degree`` is smaller than 1.
    """

    def __init__(self, max_shift: float, degree: int) -> None:
        if degree < 1:
            raise ValueError(f"degree must be >= 1, got {degree}")
        self.max_shift = max_shift
        self.degree = degree
        self._fallback = ShiftStrategy(max_shift)

    def _fit_polynomial(self, sample: np.ndarray, ref: np.ndarray) -> np.ndarray:
        """Fit a polynomial of ``self.degree`` mapping sample to reference."""
        return np.polyfit(sample, ref, self.degree)

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Apply polynomial warping to a binned spectrum."""
        min_peaks = self.degree + 1
        if len(peaks) < min_peaks or len(ref_peaks) < min_peaks:
            return self._fallback.align_binned(row, peaks, ref_peaks, mz_axis)

        sample, ref = _match_peak_pairs(peaks, ref_peaks)
        coeffs = self._fit_polynomial(sample, ref)
        new_positions = np.polyval(coeffs, mz_axis)
        return monotonic_interp(mz_axis, new_positions, row)

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply polynomial m/z warping to a raw spectrum."""
        min_peaks = self.degree + 1
        if len(peaks_mz) < min_peaks or len(ref_peaks_mz) < min_peaks:
            return self._fallback.align_raw(
                mz, intensity, peaks_mz, ref_peaks_mz, ref_mz, ref_intensity
            )

        sample, ref = _match_peak_pairs(peaks_mz, ref_peaks_mz)
        coeffs = self._fit_polynomial(sample, ref)
        return np.polyval(coeffs, mz), intensity


def _lowess_fit(
    x: np.ndarray,
    y: np.ndarray,
    frac: float,
    it: int,
) -> np.ndarray:
    """Locally-weighted scatterplot smoothing (Cleveland 1979).

    Lightweight pure-numpy implementation that fits a local linear
    regression at each data point using a tricube kernel on the
    ``frac * n`` nearest neighbours (at least two, at most ``n``), with
    ``it`` robustness reweighting iterations using bisquare residual
    weights. Input ``x`` must be sorted in ascending order.

    Parameters
    ----------
    x, y : np.ndarray
        One-dimensional arrays of the same length; ``x`` sorted.
    frac : float
        Smoothing bandwidth, fraction of data used for each local fit, in
        the half-open interval ``(0, 1]``.
    it : int
        Number of robustness reweighting iterations. ``0`` disables it.

    Returns
    -------
    np.ndarray
        Smoothed ``y`` values at each ``x``.
    """
    n = len(x)
    if n == 0:
        return np.empty(0, dtype=float)

    k = max(int(np.ceil(frac * n)), 2)
    k = min(k, n)

    fitted = np.empty(n, dtype=float)
    residual_weights = np.ones(n, dtype=float)

    for _ in range(it + 1):
        for i in range(n):
            distances = np.abs(x - x[i])
            nearest = np.argsort(distances)[:k]
            h = distances[nearest].max()

            if h == 0:
                # All neighbours share x[i]: use a weighted mean. If every
                # robustness weight is zero np.average would raise, so fall
                # back to the observed value like the branch below.
                local_weights = residual_weights[nearest]
                if local_weights.sum() > 0:
                    fitted[i] = np.average(y[nearest], weights=local_weights)
                else:
                    fitted[i] = y[i]
                continue

            u = np.clip(distances[nearest] / h, 0.0, 1.0)
            kernel = (1.0 - u**3) ** 3  # tricube kernel
            w = kernel * residual_weights[nearest]
            total = w.sum()
            if total <= 0:
                fitted[i] = y[i]
                continue

            xn = x[nearest]
            yn = y[nearest]
            mean_x = np.sum(w * xn) / total
            mean_y = np.sum(w * yn) / total
            dx = xn - mean_x
            denom = np.sum(w * dx * dx)
            if denom <= 0:
                fitted[i] = mean_y
            else:
                beta = np.sum(w * dx * (yn - mean_y)) / denom
                fitted[i] = mean_y + beta * (x[i] - mean_x)

        residuals = y - fitted
        s = np.median(np.abs(residuals))
        if s == 0:
            # Perfect fit (in the median sense): reweighting is undefined.
            break
        u = np.clip(residuals / (6.0 * s), -1.0, 1.0)
        residual_weights = (1.0 - u**2) ** 2  # bisquare weights

    return fitted


class LOWESSStrategy(AlignmentStrategy):
    """Non-linear LOWESS alignment on matched peak pairs.

    Fits a LOWESS (Cleveland 1979) regression of the form
    ``ref = f(sample)`` on the matched peak pairs, then applies the fitted
    function to the full m/z axis by linear interpolation between the fitted
    peak positions. Outside the matched peak range the edge values are
    offset by the distance of the axis endpoint from the outermost peak.
    Falls back to :class:`ShiftStrategy` when fewer than ``_MIN_PEAKS``
    sample or reference peaks are available.

    Parameters
    ----------
    max_shift : float
        Maximum allowed shift passed through to the shift fallback.
    frac : float
        LOWESS smoothing bandwidth (fraction of matched peaks used for each
        local fit). Must be in ``(0, 1]``.
    it : int
        Number of LOWESS robustness iterations. Must be ``>= 0``.

    Raises
    ------
    ValueError
        If ``frac`` is outside ``(0, 1]`` or ``it`` is negative.
    """

    _MIN_PEAKS = 3

    def __init__(self, max_shift: float, frac: float, it: int) -> None:
        if not (0.0 < frac <= 1.0):
            raise ValueError(f"frac must be in (0, 1], got {frac}")
        if it < 0:
            raise ValueError(f"it must be >= 0, got {it}")
        self.max_shift = max_shift
        self.frac = frac
        self.it = it
        self._fallback = ShiftStrategy(max_shift)

    def _fit_positions(
        self, sample: np.ndarray, ref: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return (sorted_sample, fitted_ref) on the matched peak pairs."""
        order = np.argsort(sample)
        sample_sorted = sample[order].astype(float)
        ref_sorted = ref[order].astype(float)
        fitted = _lowess_fit(sample_sorted, ref_sorted, self.frac, self.it)
        return sample_sorted, fitted

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Apply LOWESS warping to a binned spectrum."""
        if len(peaks) < self._MIN_PEAKS or len(ref_peaks) < self._MIN_PEAKS:
            return self._fallback.align_binned(row, peaks, ref_peaks, mz_axis)

        sample, ref = _match_peak_pairs(peaks, ref_peaks)
        sample_sorted, fitted = self._fit_positions(sample, ref)
        new_positions = np.interp(
            mz_axis,
            sample_sorted,
            fitted,
            left=fitted[0] + (mz_axis[0] - sample_sorted[0]),
            right=fitted[-1] + (mz_axis[-1] - sample_sorted[-1]),
        )
        return monotonic_interp(mz_axis, new_positions, row)

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply LOWESS m/z warping to a raw spectrum."""
        if len(peaks_mz) < self._MIN_PEAKS or len(ref_peaks_mz) < self._MIN_PEAKS:
            return self._fallback.align_raw(
                mz, intensity, peaks_mz, ref_peaks_mz, ref_mz, ref_intensity
            )

        sample, ref = _match_peak_pairs(peaks_mz, ref_peaks_mz)
        sample_sorted, fitted = self._fit_positions(sample, ref)
        new_mz = np.interp(
            mz,
            sample_sorted,
            fitted,
            left=fitted[0] + (mz[0] - sample_sorted[0]),
            right=fitted[-1] + (mz[-1] - sample_sorted[-1]),
        )
        return new_mz, intensity


# Maps each alignment method name to the strategy class implementing it;
# "quadratic" and "cubic" share PolynomialStrategy and differ only in the
# ``degree`` supplied at construction time.
ALIGNMENT_REGISTRY: dict[str, type[AlignmentStrategy]] = {
    "shift": ShiftStrategy,
    "linear": LinearStrategy,
    "piecewise": PiecewiseStrategy,
    "dtw": DTWStrategy,
    "quadratic": PolynomialStrategy,
    "cubic": PolynomialStrategy,
    "lowess": LOWESSStrategy,
}