Source code for maldiamrkit.preprocessing.preprocessing_pipeline

"""Composable preprocessing pipeline for MALDI-TOF spectra.

Similar to :class:`sklearn.pipeline.Pipeline` but designed for spectrum
DataFrames with ``mass`` and ``intensity`` columns.

Examples
--------
>>> from maldiamrkit.preprocessing import PreprocessingPipeline
>>> from maldiamrkit.preprocessing.transformers import *
>>>
>>> # Default pipeline (standard preprocessing)
>>> pipe = PreprocessingPipeline.default()
>>> preprocessed = pipe(raw_df)
>>>
>>> # Custom pipeline
>>> pipe = PreprocessingPipeline([
...     ("clip", ClipNegatives()),
...     ("log", LogTransform()),
...     ("smooth", SavitzkyGolaySmooth(window_length=15)),
...     ("baseline", SNIPBaseline(half_window=30)),
...     ("trim", MzTrimmer(mz_min=2000, mz_max=20000)),
...     ("norm", TICNormalizer()),
... ])
>>> preprocessed = pipe(raw_df)
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Protocol, runtime_checkable

import pandas as pd

from .transformers import (
    TRANSFORMER_REGISTRY,
    ClipNegatives,
    MzTrimmer,
    SavitzkyGolaySmooth,
    SNIPBaseline,
    SqrtTransform,
    TICNormalizer,
)


@runtime_checkable
class PreprocessingStep(Protocol):
    """Protocol for preprocessing step transformers."""

    def __call__(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply the preprocessing step to a spectrum DataFrame."""
        ...

    def to_dict(self) -> dict:
        """Serialize the step to a dictionary."""
        ...


[docs] class PreprocessingPipeline: """Composable pipeline of preprocessing steps for MALDI-TOF spectra. Parameters ---------- steps : list of (str, transformer) tuples Named preprocessing steps. Each transformer must be callable, accepting and returning a ``pd.DataFrame`` with ``mass`` and ``intensity`` columns. Examples -------- >>> pipe = PreprocessingPipeline.default() >>> preprocessed = pipe(raw_spectrum_df) """
[docs] def __init__(self, steps: list[tuple[str, PreprocessingStep]]): self.steps = list(steps)
[docs] def __call__(self, df: pd.DataFrame) -> pd.DataFrame: """Apply all preprocessing steps sequentially. Parameters ---------- df : pd.DataFrame Raw spectrum with ``mass`` and ``intensity`` columns. Returns ------- pd.DataFrame Preprocessed spectrum. """ for _name, step in self.steps: df = step(df) return df
[docs] @classmethod def default(cls) -> PreprocessingPipeline: """Return the standard preprocessing pipeline. Steps: clip negatives -> sqrt transform -> Savitzky-Golay smoothing -> SNIP baseline -> m/z trim (2000-20000 Da) -> TIC normalization. Returns ------- PreprocessingPipeline Default pipeline instance. """ return cls( [ ("clip", ClipNegatives()), ("sqrt", SqrtTransform()), ("smooth", SavitzkyGolaySmooth(window_length=21, polyorder=2)), ("baseline", SNIPBaseline(half_window=40)), ("trim", MzTrimmer(mz_min=2000, mz_max=20000)), ("normalize", TICNormalizer()), ] )
[docs] def get_step(self, name: str) -> object: """Get a step by name. Parameters ---------- name : str Step name. Returns ------- object The transformer associated with that name. Raises ------ KeyError If no step with that name exists. """ for step_name, step in self.steps: if step_name == name: return step raise KeyError(f"Step '{name}' not found. Available: {self.step_names}")
@property def step_names(self) -> list[str]: """Return the names of all steps.""" return [name for name, _ in self.steps] @property def mz_range(self) -> tuple[int, int]: """Extract (mz_min, mz_max) from the MzTrimmer step. Returns ------- tuple[int, int] The m/z range from the MzTrimmer step, or the default (2000, 20000) if no MzTrimmer is present. """ for _, step in self.steps: if isinstance(step, MzTrimmer): return step.mz_min, step.mz_max return 2000, 20000
[docs] def to_dict(self) -> dict: """Serialize the pipeline to a dictionary. Returns ------- dict Dictionary representation suitable for JSON/YAML serialization. """ return { "steps": [ {"step_name": name, **step.to_dict()} for name, step in self.steps ] }
[docs] @classmethod def from_dict(cls, d: dict) -> PreprocessingPipeline: """Reconstruct a pipeline from a dictionary. Parameters ---------- d : dict Dictionary as produced by :meth:`to_dict`. Returns ------- PreprocessingPipeline Reconstructed pipeline. """ steps = [] for step_dict in d["steps"]: step_name = step_dict["step_name"] transformer_name = step_dict["name"] transformer_cls = TRANSFORMER_REGISTRY[transformer_name] # Extract constructor kwargs (everything except step_name and name) kwargs = { k: v for k, v in step_dict.items() if k not in ("step_name", "name") } steps.append((step_name, transformer_cls(**kwargs))) return cls(steps)
[docs] def to_json(self, path: str | Path) -> None: """Save the pipeline configuration to a JSON file. Parameters ---------- path : str or Path Output file path. """ with open(path, "w") as f: json.dump(self.to_dict(), f, indent=2)
[docs] @classmethod def from_json(cls, path: str | Path) -> PreprocessingPipeline: """Load a pipeline from a JSON file. Parameters ---------- path : str or Path Input file path. Returns ------- PreprocessingPipeline Reconstructed pipeline. """ with open(path) as f: return cls.from_dict(json.load(f))
[docs] def to_yaml(self, path: str | Path) -> None: """Save the pipeline configuration to a YAML file. Requires ``pyyaml`` to be installed. Parameters ---------- path : str or Path Output file path. """ import yaml with open(path, "w") as f: yaml.dump(self.to_dict(), f, default_flow_style=False)
[docs] @classmethod def from_yaml(cls, path: str | Path) -> PreprocessingPipeline: """Load a pipeline from a YAML file. Requires ``pyyaml`` to be installed. Parameters ---------- path : str or Path Input file path. Returns ------- PreprocessingPipeline Reconstructed pipeline. """ import yaml with open(path) as f: return cls.from_dict(yaml.safe_load(f))
def __repr__(self) -> str: steps_repr = ",\n ".join(f"('{name}', {step!r})" for name, step in self.steps) return f"PreprocessingPipeline([\n {steps_repr}\n])" def __len__(self) -> int: return len(self.steps)