"""Peak detection algorithms for MALDI-TOF spectra."""
from __future__ import annotations
from enum import Enum
import gudhi
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from scipy.signal import find_peaks
from sklearn.base import BaseEstimator, TransformerMixin
[docs]
class PeakMethod(str, Enum):
    """Enumeration of the peak detection strategies.

    Each member is also a ``str``, so it compares equal to its plain
    string value (``PeakMethod.local == "local"``).

    Attributes
    ----------
    local : str
        Local maxima detection via ``scipy.signal.find_peaks``.
    ph : str
        Persistent homology based peak detection.
    """

    local = "local"
    ph = "ph"
[docs]
class MaldiPeakDetector(BaseEstimator, TransformerMixin):
    """
    Peak detector for MALDI-TOF spectra with local maxima and topological methods.

    The transformer maintains the original feature dimension; all non-peak
    positions are set to 0. Peaks can be returned as binary flags or with
    their original intensities.

    Parameters
    ----------
    method : {"local", "ph"} or PeakMethod, default="local"
        Detection method to use:

        - "local" : Local maxima detection using scipy.signal.find_peaks
        - "ph" : Persistent homology based detection using gudhi
    binary : bool, default=True
        If True, peaks are marked with 1; otherwise, original intensity is kept.
    persistence_threshold : float, default=1e-6
        Minimum persistence (death - birth) required for a peak when using
        method="ph". For normalized spectra (sum=1), typical values are 1e-6
        to 1e-4. Higher values detect fewer, more prominent peaks.
    n_jobs : int, default=1
        Number of parallel jobs for peak detection. Use -1 for all available
        cores, 1 for sequential processing. Parallelization is particularly
        beneficial for the "ph" method which is CPU-intensive.
    prominence : float or None, default=None
        Minimum prominence of peaks (recommended: 1e-5 to 1e-2).
        Passed to :func:`scipy.signal.find_peaks` when ``method="local"``.
    height : float or None, default=None
        Minimum height of peaks.
        Passed to :func:`scipy.signal.find_peaks` when ``method="local"``.
    distance : int or None, default=None
        Minimum distance between peaks in bins.
        Passed to :func:`scipy.signal.find_peaks` when ``method="local"``.
    width : float or None, default=None
        Minimum width of peaks.
        Passed to :func:`scipy.signal.find_peaks` when ``method="local"``.
    **kwargs :
        Additional keyword arguments passed to
        :func:`scipy.signal.find_peaks` when ``method="local"``.

    Notes
    -----
    For MALDI-TOF spectra normalized to sum=1:

    - prominence=1e-5 to 1e-3 typically works well for local maxima
    - persistence_threshold=1e-6 to 1e-4 for persistent homology

    Raises
    ------
    ValueError
        If ``method`` is not one of 'local' or 'ph' (raised by the
        constructor when the value is converted to :class:`PeakMethod`).

    Examples
    --------
    >>> # Local maxima detection with prominence filter
    >>> detector = MaldiPeakDetector(method="local", prominence=0.01)
    >>> peaks = detector.fit_transform(spectra_df)

    >>> # Persistent homology based detection
    >>> detector = MaldiPeakDetector(method="ph", persistence_threshold=1e-6)
    >>> peaks = detector.fit_transform(spectra_df)
    """
[docs]
def __init__(
    self,
    method: str | PeakMethod = PeakMethod.local,
    binary: bool = True,
    persistence_threshold: float = 1e-6,
    n_jobs: int = 1,
    prominence: float | None = None,
    height: float | None = None,
    distance: int | None = None,
    width: float | None = None,
    **kwargs,
) -> None:
    """
    Store the detector configuration.

    ``method`` is converted to a :class:`PeakMethod`, which raises
    ``ValueError`` for values other than ``"local"`` or ``"ph"``. The
    remaining named parameters are stored unchanged under their own
    names, and ``self.kwargs`` collects the options forwarded to
    ``scipy.signal.find_peaks``.
    """
    self.method = PeakMethod(method)
    self.binary = binary
    self.persistence_threshold = persistence_threshold
    self.n_jobs = n_jobs

    # The four named find_peaks filters, in the order they are merged below.
    named_filters = (
        ("prominence", prominence),
        ("height", height),
        ("distance", distance),
        ("width", width),
    )
    for attr_name, attr_value in named_filters:
        setattr(self, attr_name, attr_value)

    # Extra keyword arguments come first; a named filter is added only
    # when it was actually provided (not None) and is not already present.
    find_peaks_options = dict(kwargs)
    for attr_name, attr_value in named_filters:
        if attr_value is not None:
            find_peaks_options.setdefault(attr_name, attr_value)
    self.kwargs = find_peaks_options
[docs]
def fit(self, X: pd.DataFrame, y=None):
"""
Fit the peak detector (no learning performed).
Parameters
----------
X : pd.DataFrame
Input spectra with shape (n_samples, n_bins).
y : array-like, optional
Target values (ignored).
Returns
-------
self : MaldiPeakDetector
Fitted transformer.
Raises
------
ValueError
If the input DataFrame is empty.
"""
if X.empty:
raise ValueError("Input DataFrame X is empty")
return self
def _detect_peaks_local(self, row: np.ndarray) -> np.ndarray:
"""
Detect peaks using local maxima detection.
Uses scipy.signal.find_peaks with configurable parameters.
Parameters
----------
row : np.ndarray
1D spectrum intensity array.
Returns
-------
peaks : np.ndarray
Array of peak indices.
"""
peaks, _ = find_peaks(row, **self.kwargs)
return peaks
def _detect_peaks_ph(self, row: np.ndarray) -> np.ndarray:
"""
Detect peaks using persistent homology (0D persistence).
Computes the 0D persistence diagram of the signal treated as a
1D cubical complex. Peaks correspond to local maxima with sufficient
persistence (death - birth) above the threshold.
Parameters
----------
row : np.ndarray
1D spectrum intensity array.
Returns
-------
peaks : np.ndarray
Array of peak indices corresponding to persistent maxima.
Notes
-----
The algorithm:
1. Negates the signal so that ``row``'s maxima become sub-level
minima (0D component births).
2. Builds a 1D cubical complex and computes 0D persistence.
3. Recovers the exact birth-cell index for each pair.
4. Filters pairs by ``persistence >= persistence_threshold``.
"""
if np.allclose(row, row[0]):
return np.array([], dtype=int)
# Negate signal for sub-level-set filtration on the negated signal
# (so that peaks of ``row`` become births of 0D features).
signal = -row
signal = signal - signal.min()
cc = gudhi.CubicalComplex(top_dimensional_cells=signal[np.newaxis, :])
cc.persistence()
regular_pairs, essential_pairs = cc.cofaces_of_persistence_pairs()
regular = regular_pairs[0] if len(regular_pairs) else np.empty((0, 2), int)
essential = essential_pairs[0] if len(essential_pairs) else np.empty(0, int)
signal_max = float(np.max(signal))
peak_indices: list[int] = []
if regular.size:
births = signal[regular[:, 0]]
deaths = signal[regular[:, 1]]
persistences = deaths - births
keep = persistences >= self.persistence_threshold
peak_indices.extend(int(i) for i in regular[keep, 0].tolist())
if essential.size:
essential_births = signal[essential]
persistences = signal_max - essential_births
keep = persistences >= self.persistence_threshold
peak_indices.extend(int(i) for i in essential[keep].tolist())
return np.array(sorted(set(peak_indices)), dtype=int)
def _process_single_row(self, row: np.ndarray) -> np.ndarray:
"""Process a single row and return masked array (helper for parallelization)."""
if self.method == "local":
peaks = self._detect_peaks_local(row)
elif self.method == "ph":
peaks = self._detect_peaks_ph(row)
else:
raise ValueError(f"Unknown method: {self.method}")
masked = np.zeros_like(row, dtype=row.dtype)
if self.binary:
masked[peaks] = 1
else:
masked[peaks] = row[peaks]
return masked
[docs]
def get_peak_statistics(self, X: pd.DataFrame) -> pd.DataFrame:
"""
Get statistics about detected peaks for each spectrum.
Parameters
----------
X : pd.DataFrame or pd.Series
Input spectra with shape (n_samples, n_bins).
Returns
-------
stats : pd.DataFrame
DataFrame with columns:
- n_peaks: number of peaks detected
- mean_intensity: mean intensity of detected peaks
- max_intensity: maximum intensity of detected peaks
"""
input_is_series = isinstance(X, pd.Series)
if input_is_series:
X = X.to_frame().T
stats = []
for i in range(X.shape[0]):
row = X.iloc[i].values
if self.method == "local":
peaks = self._detect_peaks_local(row)
elif self.method == "ph":
peaks = self._detect_peaks_ph(row)
else:
raise ValueError(f"Unknown method: {self.method}")
n_peaks = len(peaks)
if n_peaks > 0:
peak_intensities = row[peaks]
mean_intensity = np.mean(peak_intensities)
max_intensity = np.max(peak_intensities)
else:
mean_intensity = 0.0
max_intensity = 0.0
stats.append(
{
"n_peaks": n_peaks,
"mean_intensity": mean_intensity,
"max_intensity": max_intensity,
}
)
return pd.DataFrame(stats, index=X.index)