"""Shared alignment strategy classes for spectral warping.
Each strategy implements one alignment algorithm that can operate on either
binned (index-based) or raw (m/z-based) coordinate systems.
"""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from enum import Enum
import numpy as np
from scipy.ndimage import gaussian_filter1d
from tslearn.metrics import dtw_path
class AlignmentMethod(str, Enum):
    """Names of the alignment/warping algorithms offered by this module.

    The enum mixes in ``str``, so each member compares equal to the plain
    string holding its value (``AlignmentMethod.dtw == "dtw"``).

    Attributes
    ----------
    shift : str
        Rigid global shift alignment.
    linear : str
        Linear (affine) recalibration.
    piecewise : str
        Piecewise-linear recalibration.
    dtw : str
        Dynamic time warping alignment.
    quadratic : str
        Quadratic polynomial recalibration.
    cubic : str
        Cubic polynomial recalibration.
    lowess : str
        Non-linear LOWESS (Cleveland 1979) recalibration.
    """

    shift = "shift"
    linear = "linear"
    piecewise = "piecewise"
    dtw = "dtw"
    quadratic = "quadratic"
    cubic = "cubic"
    lowess = "lowess"
class AlignmentStrategy(ABC):
    """Abstract interface shared by all alignment strategies.

    A concrete strategy must implement both entry points: one working on a
    binned spectrum (positions are bin indices) and one working on a raw
    spectrum (positions are m/z values).
    """

    @abstractmethod
    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Align a binned spectrum row to the reference.

        Parameters
        ----------
        row : np.ndarray
            Intensity values of the spectrum to align.
        peaks : np.ndarray
            Detected peak indices in ``row``.
        ref_peaks : np.ndarray
            Detected peak indices in the reference spectrum.
        mz_axis : np.ndarray
            Array of bin positions (e.g. ``np.arange(len(row))``).

        Returns
        -------
        np.ndarray
            Aligned intensity array with the same length as ``row``.
        """

    @abstractmethod
    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Align a raw spectrum to the reference.

        Parameters
        ----------
        mz : np.ndarray
            m/z values of the spectrum to align.
        intensity : np.ndarray
            Intensity values of the spectrum to align.
        peaks_mz : np.ndarray
            Detected peak m/z positions in the sample spectrum.
        ref_peaks_mz : np.ndarray
            Detected peak m/z positions in the reference spectrum.
        ref_mz : np.ndarray
            m/z values of the reference spectrum.
        ref_intensity : np.ndarray
            Intensity values of the reference spectrum.

        Returns
        -------
        aligned_mz : np.ndarray
            Aligned m/z values.
        aligned_intensity : np.ndarray
            Aligned intensity values.
        """
def monotonic_interp(
    mz_axis: np.ndarray, new_positions: np.ndarray, row: np.ndarray
) -> np.ndarray:
    """Interpolate a spectrum onto *mz_axis* after a warping transform.

    When *new_positions* is strictly increasing a single ``np.interp``
    call is used. Otherwise a warning is emitted, intensities that landed
    on the same warped position are averaged, and the interpolation runs
    on the resulting sorted, duplicate-free positions. Grid points outside
    the warped range receive ``0.0``.

    Parameters
    ----------
    mz_axis : np.ndarray
        Target m/z (or index) grid.
    new_positions : np.ndarray
        Warped positions corresponding to each element of *row*.
    row : np.ndarray
        Intensity values to be re-mapped.

    Returns
    -------
    np.ndarray
        Interpolated intensity array on *mz_axis*.

    Warns
    -----
    UserWarning
        If *new_positions* is not strictly increasing.
    """
    if len(new_positions) == 0:
        # np.interp rejects an empty sample grid. With nothing to map, every
        # target point is "outside the warped range", i.e. zero-filled.
        return np.zeros_like(mz_axis, dtype=float)

    if np.all(np.diff(new_positions) > 0):
        return np.interp(mz_axis, new_positions, row, left=0.0, right=0.0)

    warnings.warn(
        "Warping produced non-monotonic m/z mapping for a sample. "
        "This may indicate poor alignment quality. "
        "Consider adjusting alignment parameters (e.g., reduce max_shift_da "
        "or increase n_segments).",
        UserWarning,
        stacklevel=4,
    )

    # np.unique already returns the distinct positions in ascending order and
    # `inverse` maps every original element to its group, so no separate
    # argsort of the input is required.
    unique_pos, inverse = np.unique(new_positions, return_inverse=True)
    n_groups = len(unique_pos)
    counts = np.bincount(inverse, minlength=n_groups)
    sums = np.bincount(inverse, weights=row, minlength=n_groups)
    # Every group has at least one member, so the division is safe.
    return np.interp(mz_axis, unique_pos, sums / counts, left=0.0, right=0.0)
def _nearest_ref_indices(peaks: np.ndarray, ref_peaks: np.ndarray) -> np.ndarray:
    """Return, for every peak, the index of its closest reference peak.

    ``ref_peaks`` is searched with ``np.searchsorted`` and is therefore
    expected to be sorted in ascending order. When a peak is equally far
    from its two neighbouring reference peaks, the right-hand one is kept.
    """
    last = len(ref_peaks) - 1
    # Candidate on the right of (or at) each peak, kept inside the array.
    chosen = np.clip(np.searchsorted(ref_peaks, peaks), 0, last)
    # Candidate immediately to the left of that one.
    before = np.clip(chosen - 1, 0, last)
    gap_before = np.abs(ref_peaks[before] - peaks)
    gap_chosen = np.abs(ref_peaks[chosen] - peaks)
    # Strict comparison: ties stay with the right-hand candidate.
    prefer_before = gap_before < gap_chosen
    chosen[prefer_before] = before[prefer_before]
    return chosen
def _match_peaks_to_ref(peaks: np.ndarray, ref_peaks: np.ndarray) -> np.ndarray:
    """Return the signed offset from each peak to its nearest reference peak.

    The result has one entry per element of ``peaks`` and equals
    ``nearest_reference - peak``.
    """
    nearest = _nearest_ref_indices(peaks, ref_peaks)
    closest_ref = ref_peaks[nearest]
    return closest_ref - peaks
def _match_peak_pairs(
    peaks: np.ndarray, ref_peaks: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Pair every sample peak with its nearest reference peak.

    Returns
    -------
    sample : np.ndarray
        A copy of ``peaks``.
    ref : np.ndarray
        The reference peak closest to each entry of ``sample``.
    """
    nearest = _nearest_ref_indices(peaks, ref_peaks)
    sample_side = peaks.copy()
    ref_side = ref_peaks[nearest]
    return sample_side, ref_side
class ShiftStrategy(AlignmentStrategy):
    """Global median shift alignment.

    Every sample peak is matched to its nearest reference peak; the median
    of the resulting offsets, clamped to ``[-max_shift, max_shift]``, is
    applied to the whole spectrum.

    Parameters
    ----------
    max_shift : float
        Largest shift magnitude that may be applied (bins for
        ``align_binned``, m/z units for ``align_raw``).
    """

    def __init__(self, max_shift: float) -> None:
        self.max_shift = max_shift

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Apply global median shift to a binned spectrum.

        Bins shifted in from outside the spectrum are zero-filled. When
        either peak list is empty, ``row`` itself is returned unchanged.
        """
        if len(peaks) == 0 or len(ref_peaks) == 0:
            return row

        shifts = _match_peaks_to_ref(peaks, ref_peaks)
        # Round to whole bins, clamp, and only then convert to int:
        # np.clip promotes to float when max_shift is a float, and a float
        # cannot be used as a slice index below. int() truncates toward
        # zero, so the magnitude never exceeds max_shift.
        shift = int(
            np.clip(np.round(np.median(shifts)), -self.max_shift, self.max_shift)
        )

        if shift == 0:
            return row.copy()

        aligned = np.zeros_like(row)
        if shift > 0:
            aligned[shift:] = row[:-shift]
        else:
            aligned[:shift] = row[-shift:]
        return aligned

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply global m/z shift to a raw spectrum.

        The intensities are returned untouched; only the m/z axis moves.
        When either peak list is empty the inputs are returned as they are.
        """
        if len(peaks_mz) == 0 or len(ref_peaks_mz) == 0:
            return mz, intensity

        shifts = _match_peaks_to_ref(peaks_mz, ref_peaks_mz)
        shift_da = np.clip(np.median(shifts), -self.max_shift, self.max_shift)
        return mz + shift_da, intensity
def _robust_linear_fit(
    sample: np.ndarray,
    ref: np.ndarray,
    residual_threshold: float = 3.0,
) -> tuple[float, float]:
    """Fit ``ref = a * sample + b`` by least squares with one rejection pass.

    After an initial fit on all pairs, the residuals are centred on their
    median and any pair whose centred residual exceeds
    ``residual_threshold * 1.4826 * MAD`` is discarded; the line is then
    refitted on the remaining pairs. The rejection pass is skipped when
    there are at most two pairs, when the MAD is not positive, or when
    fewer than two pairs would survive.

    Parameters
    ----------
    sample, ref : np.ndarray
        Matched peak positions (sample and reference).
    residual_threshold : float, default=3.0
        Number of MAD-based sigma units for outlier rejection.

    Returns
    -------
    a, b : float
        Slope and intercept of the fitted line.
    """

    def solve(design_rows: np.ndarray, targets: np.ndarray):
        # Least-squares slope and intercept for the given rows.
        slope, intercept = np.linalg.lstsq(design_rows, targets, rcond=1e-10)[0]
        return slope, intercept

    design = np.vstack([sample, np.ones_like(sample)]).T
    a, b = solve(design, ref)

    if len(sample) <= 2:
        return a, b

    residuals = ref - (a * sample + b)
    deviation = np.abs(residuals - np.median(residuals))
    mad = np.median(deviation)
    if not mad > 0:
        return a, b

    keep = deviation <= residual_threshold * 1.4826 * mad
    if keep.sum() >= 2:
        a, b = solve(design[keep], ref[keep])
    return a, b
class LinearStrategy(AlignmentStrategy):
    """Affine recalibration fitted by least squares with outlier rejection.

    With fewer than two sample or reference peaks a line cannot be
    fitted, so the work is delegated to a :class:`ShiftStrategy` built
    with the same ``max_shift``.
    """

    def __init__(self, max_shift: float) -> None:
        self._fallback = ShiftStrategy(max_shift)
        self.max_shift = max_shift

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Warp a binned spectrum with the fitted line and resample it."""
        enough_peaks = len(peaks) >= 2 and len(ref_peaks) >= 2
        if not enough_peaks:
            return self._fallback.align_binned(row, peaks, ref_peaks, mz_axis)

        matched_sample, matched_ref = _match_peak_pairs(peaks, ref_peaks)
        slope, intercept = _robust_linear_fit(matched_sample, matched_ref)
        warped_axis = slope * mz_axis + intercept
        return monotonic_interp(mz_axis, warped_axis, row)

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply the fitted line to the m/z axis; intensities are untouched."""
        enough_peaks = len(peaks_mz) >= 2 and len(ref_peaks_mz) >= 2
        if not enough_peaks:
            return self._fallback.align_raw(
                mz, intensity, peaks_mz, ref_peaks_mz, ref_mz, ref_intensity
            )

        matched_sample, matched_ref = _match_peak_pairs(peaks_mz, ref_peaks_mz)
        slope, intercept = _robust_linear_fit(matched_sample, matched_ref)
        return slope * mz + intercept, intensity
class PiecewiseStrategy(AlignmentStrategy):
    """Piecewise linear alignment with smoothed local shifts.

    Matched peaks are split into ``n_segments`` quantile segments. Each
    non-empty segment contributes one control point (median position,
    median shift); the shift is linearly interpolated between control
    points, held constant beyond the outermost ones, and optionally
    smoothed with a Gaussian filter.

    Parameters
    ----------
    n_segments : int
        Number of quantile segments the matched peaks are divided into.
    smooth_sigma : float
        Gaussian smoothing width. ``align_binned`` passes it to the filter
        as-is (i.e. in samples of ``mz_axis``); ``align_raw`` divides it by
        the median m/z spacing first. ``0`` disables smoothing.
    max_shift : float
        Stored on the instance; not used by the methods of this class.
    """

    def __init__(self, n_segments: int, smooth_sigma: float, max_shift: float) -> None:
        self.n_segments = n_segments
        self.smooth_sigma = smooth_sigma
        self.max_shift = max_shift

    def _compute_segment_shifts(
        self, sample: np.ndarray, ref: np.ndarray
    ) -> tuple[list, list]:
        """Compute per-segment median positions and shifts.

        Returns two parallel lists with one entry per non-empty segment.
        """
        quantiles = np.linspace(0, 1, self.n_segments + 1)
        boundaries = np.quantile(sample, quantiles)
        seg_x, seg_shift = [], []
        for q in range(self.n_segments):
            lower, upper = boundaries[q], boundaries[q + 1]
            if q == self.n_segments - 1:
                # The last segment is closed on the right so that the
                # largest peak is not dropped.
                mask = (sample >= lower) & (sample <= upper)
            else:
                mask = (sample >= lower) & (sample < upper)
            if mask.sum() > 0:
                seg_x.append(np.median(sample[mask]))
                seg_shift.append(np.median(ref[mask] - sample[mask]))
        return seg_x, seg_shift

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Apply piecewise alignment to a binned spectrum.

        ``row`` is returned unchanged when there are no peaks to match or
        no segment received any peak.
        """
        if len(peaks) == 0 or len(ref_peaks) == 0:
            return row
        sample, ref = _match_peak_pairs(peaks, ref_peaks)
        seg_x, seg_shift = self._compute_segment_shifts(sample, ref)
        if len(seg_x) == 0:
            return row
        shift_interp = np.interp(
            mz_axis, seg_x, seg_shift, left=seg_shift[0], right=seg_shift[-1]
        )
        if self.smooth_sigma > 0:
            shift_interp = gaussian_filter1d(
                shift_interp, sigma=self.smooth_sigma, mode="nearest"
            )
        new_positions = mz_axis + shift_interp
        return monotonic_interp(mz_axis, new_positions, row)

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply piecewise m/z transformation to a raw spectrum.

        Intensities are returned untouched. Smoothing is skipped when the
        spectrum has fewer than two points or its median m/z spacing is not
        positive, because no filter width in samples can be derived then.
        """
        if len(peaks_mz) == 0 or len(ref_peaks_mz) == 0:
            return mz, intensity
        sample, ref = _match_peak_pairs(peaks_mz, ref_peaks_mz)
        seg_x, seg_shift = self._compute_segment_shifts(sample, ref)
        if len(seg_x) == 0:
            return mz, intensity
        shift_interp = np.interp(
            mz, seg_x, seg_shift, left=seg_shift[0], right=seg_shift[-1]
        )
        # len(mz) > 1 guards np.diff/np.median against an empty difference
        # array (NaN spacing -> int(nan) raises).
        if self.smooth_sigma > 0 and len(mz) > 1:
            mz_spacing = np.median(np.diff(mz))
            # A zero spacing would give an infinite width (int(inf) raises);
            # a negative one yields no usable width either.
            if mz_spacing > 0:
                sigma_points = min(
                    int(self.smooth_sigma / mz_spacing), len(mz) // 4
                )
                if sigma_points > 1:
                    shift_interp = gaussian_filter1d(
                        shift_interp, sigma=sigma_points, mode="nearest"
                    )
        return mz + shift_interp, intensity
class DTWStrategy(AlignmentStrategy):
    """Dynamic time warping alignment.

    Parameters
    ----------
    dtw_radius : int
        Radius of the Sakoe-Chiba band passed to ``dtw_path``.
    """

    def __init__(self, dtw_radius: int) -> None:
        self.dtw_radius = dtw_radius

    def _dtw_align(self, query: np.ndarray, reference: np.ndarray) -> np.ndarray:
        """Core DTW alignment returning aligned intensity.

        For every reference position the query intensities that the DTW
        path maps onto it are averaged; positions the path never visits
        stay at zero.

        Returns
        -------
        np.ndarray
            Array shaped like ``reference``. It keeps the dtype of
            ``reference`` when that is floating (or complex) and is
            ``float`` otherwise.
        """
        path, _ = dtw_path(
            query,
            reference,
            global_constraint="sakoe_chiba",
            sakoe_chiba_radius=self.dtw_radius,
        )
        # Accumulate in floating point. Inheriting an integer dtype from
        # `reference` would silently truncate both the running sums of the
        # (float) query intensities and the final averages.
        acc_dtype = reference.dtype if reference.dtype.kind in "fc" else float
        aligned_sum = np.zeros_like(reference, dtype=acc_dtype)
        aligned_count = np.zeros_like(reference, dtype=acc_dtype)
        for i, j in path:
            if 0 <= j < len(aligned_sum):
                aligned_sum[j] += query[i]
                aligned_count[j] += 1
        aligned = np.zeros_like(reference, dtype=acc_dtype)
        mask = aligned_count > 0
        aligned[mask] = aligned_sum[mask] / aligned_count[mask]
        return aligned

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Raise because DTW binned alignment requires the full reference spectrum.

        DTW binned alignment is handled directly by the Warping class
        which has access to the stored reference spectrum.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "DTW binned alignment is handled directly by the Warping class"
        )

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply DTW alignment to a raw spectrum.

        The sample is first resampled onto ``ref_mz`` by linear
        interpolation and then warped against ``ref_intensity``. The peak
        lists are not used.

        Returns
        -------
        aligned_mz : np.ndarray
            ``ref_mz`` (the aligned spectrum lives on the reference grid).
        aligned_intensity : np.ndarray
            Warped intensities on that grid.
        """
        query_intensity = np.interp(ref_mz, mz, intensity)
        aligned_intensity = self._dtw_align(query_intensity, ref_intensity)
        return ref_mz, aligned_intensity
class PolynomialStrategy(AlignmentStrategy):
    """Polynomial recalibration fitted with ``numpy.polyfit``.

    One class serves both the ``"quadratic"`` (``degree=2``) and the
    ``"cubic"`` (``degree=3``) warping methods. A polynomial of degree
    ``d`` needs at least ``d + 1`` points; when either peak list is
    shorter than that, the work is handed to a :class:`ShiftStrategy`
    instead of attempting an under-determined fit.

    Parameters
    ----------
    max_shift : float
        Maximum allowed shift passed through to the shift fallback.
    degree : int
        Polynomial degree. ``2`` for quadratic, ``3`` for cubic.

    Raises
    ------
    ValueError
        If ``degree`` is smaller than 1.
    """

    def __init__(self, max_shift: float, degree: int) -> None:
        if degree < 1:
            raise ValueError(f"degree must be >= 1, got {degree}")
        self.degree = degree
        self.max_shift = max_shift
        self._fallback = ShiftStrategy(max_shift)

    def _fit_polynomial(self, sample: np.ndarray, ref: np.ndarray) -> np.ndarray:
        """Return ``np.polyfit`` coefficients mapping sample to reference."""
        coefficients = np.polyfit(sample, ref, self.degree)
        return coefficients

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Warp a binned spectrum with the fitted polynomial and resample it."""
        required = self.degree + 1
        if len(peaks) < required or len(ref_peaks) < required:
            return self._fallback.align_binned(row, peaks, ref_peaks, mz_axis)

        matched_sample, matched_ref = _match_peak_pairs(peaks, ref_peaks)
        coefficients = self._fit_polynomial(matched_sample, matched_ref)
        warped_axis = np.polyval(coefficients, mz_axis)
        return monotonic_interp(mz_axis, warped_axis, row)

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Evaluate the fitted polynomial on the m/z axis; intensities are untouched."""
        required = self.degree + 1
        if len(peaks_mz) < required or len(ref_peaks_mz) < required:
            return self._fallback.align_raw(
                mz, intensity, peaks_mz, ref_peaks_mz, ref_mz, ref_intensity
            )

        matched_sample, matched_ref = _match_peak_pairs(peaks_mz, ref_peaks_mz)
        coefficients = self._fit_polynomial(matched_sample, matched_ref)
        warped_mz = np.polyval(coefficients, mz)
        return warped_mz, intensity
def _lowess_fit(
    x: np.ndarray,
    y: np.ndarray,
    frac: float,
    it: int,
) -> np.ndarray:
    """Locally-weighted scatterplot smoothing (Cleveland 1979).

    Lightweight pure-numpy implementation that fits a local linear
    regression at each data point using a tricube kernel on the
    ``frac * n`` nearest neighbours (at least 2, at most ``n``), with
    ``it`` robustness reweighting iterations using bisquare residual
    weights. Input ``x`` must be sorted in ascending order.

    Parameters
    ----------
    x, y : np.ndarray
        One-dimensional arrays of the same length; ``x`` sorted.
    frac : float
        Smoothing bandwidth, fraction of data used for each local fit,
        in the half-open interval ``(0, 1]``.
    it : int
        Number of robustness reweighting iterations. ``0`` disables it.

    Returns
    -------
    np.ndarray
        Smoothed ``y`` values at each ``x``.
    """
    n = len(x)
    if n == 0:
        return np.empty(0, dtype=float)

    k = min(max(int(np.ceil(frac * n)), 2), n)
    fitted = np.empty(n, dtype=float)
    residual_weights = np.ones(n, dtype=float)

    for _ in range(it + 1):
        for i in range(n):
            distances = np.abs(x - x[i])
            nearest = np.argsort(distances)[:k]
            h = distances[nearest].max()

            if h == 0:
                # All k neighbours share x[i], so no slope can be estimated;
                # use the robustness-weighted mean. np.average raises
                # ZeroDivisionError when every weight is zero, so that case
                # falls back to the observation itself, as below.
                local_w = residual_weights[nearest]
                if local_w.sum() > 0:
                    fitted[i] = np.average(y[nearest], weights=local_w)
                else:
                    fitted[i] = y[i]
                continue

            # Tricube kernel on distances scaled by the local bandwidth,
            # multiplied by the current robustness weights.
            u = np.clip(distances[nearest] / h, 0.0, 1.0)
            w = (1.0 - u**3) ** 3 * residual_weights[nearest]
            total = w.sum()
            if total <= 0:
                fitted[i] = y[i]
                continue

            # Weighted least-squares line through the neighbourhood,
            # evaluated at x[i].
            xn = x[nearest]
            yn = y[nearest]
            mean_x = np.sum(w * xn) / total
            mean_y = np.sum(w * yn) / total
            dx = xn - mean_x
            denom = np.sum(w * dx * dx)
            if denom <= 0:
                fitted[i] = mean_y
            else:
                beta = np.sum(w * dx * (yn - mean_y)) / denom
                fitted[i] = mean_y + beta * (x[i] - mean_x)

        # Bisquare robustness weights from the residuals of this pass.
        residuals = y - fitted
        s = np.median(np.abs(residuals))
        if s == 0:
            break
        u = np.clip(residuals / (6.0 * s), -1.0, 1.0)
        residual_weights = (1.0 - u**2) ** 2

    return fitted
class LOWESSStrategy(AlignmentStrategy):
    """Non-linear LOWESS alignment on matched peak pairs.

    Fits a LOWESS (Cleveland 1979) regression of the form
    ``ref = f(sample)`` on the matched peak pairs, then applies the
    fitted function to the full m/z axis by linear interpolation
    between the fitted peak positions. Positions outside the range of
    matched peaks are extrapolated with unit slope, i.e. they keep the
    shift fitted at the nearest outermost peak.

    With fewer than three sample or reference peaks the work is handed
    to a :class:`ShiftStrategy`.

    Parameters
    ----------
    max_shift : float
        Maximum allowed shift passed through to the shift fallback.
    frac : float
        LOWESS smoothing bandwidth (fraction of matched peaks used for
        each local fit). Must be in ``(0, 1]``.
    it : int
        Number of LOWESS robustness iterations. Must be ``>= 0``.

    Raises
    ------
    ValueError
        If ``frac`` or ``it`` is outside its allowed range.
    """

    _MIN_PEAKS = 3

    def __init__(self, max_shift: float, frac: float, it: int) -> None:
        if not (0.0 < frac <= 1.0):
            raise ValueError(f"frac must be in (0, 1], got {frac}")
        if it < 0:
            raise ValueError(f"it must be >= 0, got {it}")
        self.max_shift = max_shift
        self.frac = frac
        self.it = it
        self._fallback = ShiftStrategy(max_shift)

    def _fit_positions(
        self, sample: np.ndarray, ref: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return (sorted_sample, fitted_ref) on the matched peak pairs."""
        order = np.argsort(sample)
        sample_sorted = sample[order].astype(float)
        ref_sorted = ref[order].astype(float)
        fitted = _lowess_fit(sample_sorted, ref_sorted, self.frac, self.it)
        return sample_sorted, fitted

    def _warp_positions(
        self,
        positions: np.ndarray,
        sample_sorted: np.ndarray,
        fitted: np.ndarray,
    ) -> np.ndarray:
        """Map ``positions`` through the fitted curve.

        Inside ``[sample_sorted[0], sample_sorted[-1]]`` the curve is
        interpolated linearly. Outside, each position is moved by the shift
        of the nearest outermost peak. ``np.interp``'s ``left``/``right``
        arguments cannot express this: they are single constants and would
        map every out-of-range position onto the same value.
        """
        warped = np.interp(positions, sample_sorted, fitted)
        below = positions < sample_sorted[0]
        above = positions > sample_sorted[-1]
        warped[below] = fitted[0] + (positions[below] - sample_sorted[0])
        warped[above] = fitted[-1] + (positions[above] - sample_sorted[-1])
        return warped

    def align_binned(
        self,
        row: np.ndarray,
        peaks: np.ndarray,
        ref_peaks: np.ndarray,
        mz_axis: np.ndarray,
    ) -> np.ndarray:
        """Apply LOWESS warping to a binned spectrum."""
        if len(peaks) < self._MIN_PEAKS or len(ref_peaks) < self._MIN_PEAKS:
            return self._fallback.align_binned(row, peaks, ref_peaks, mz_axis)
        sample, ref = _match_peak_pairs(peaks, ref_peaks)
        sample_sorted, fitted = self._fit_positions(sample, ref)
        new_positions = self._warp_positions(mz_axis, sample_sorted, fitted)
        return monotonic_interp(mz_axis, new_positions, row)

    def align_raw(
        self,
        mz: np.ndarray,
        intensity: np.ndarray,
        peaks_mz: np.ndarray,
        ref_peaks_mz: np.ndarray,
        ref_mz: np.ndarray,
        ref_intensity: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Apply LOWESS m/z warping to a raw spectrum.

        Intensities are returned untouched; only the m/z axis is warped.
        """
        if len(peaks_mz) < self._MIN_PEAKS or len(ref_peaks_mz) < self._MIN_PEAKS:
            return self._fallback.align_raw(
                mz, intensity, peaks_mz, ref_peaks_mz, ref_mz, ref_intensity
            )
        sample, ref = _match_peak_pairs(peaks_mz, ref_peaks_mz)
        sample_sorted, fitted = self._fit_positions(sample, ref)
        new_mz = self._warp_positions(mz, sample_sorted, fitted)
        return new_mz, intensity
# Lookup table from warping-method name to the strategy class implementing it.
# "quadratic" and "cubic" both resolve to PolynomialStrategy; the registry
# stores classes only, so constructor arguments are left to the caller.
ALIGNMENT_REGISTRY: dict[str, type[AlignmentStrategy]] = {
    "shift": ShiftStrategy,
    "linear": LinearStrategy,
    "piecewise": PiecewiseStrategy,
    "dtw": DTWStrategy,
    "quadratic": PolynomialStrategy,
    "cubic": PolynomialStrategy,
    "lowess": LOWESSStrategy,
}