Source code for maldiamrkit.data.loader

"""Load datasets into MaldiSet objects.

The :class:`DatasetLoader` uses a :class:`DatasetLayout` to navigate
different dataset structures and load spectra into a
:class:`~maldiamrkit.dataset.MaldiSet`.
"""

from __future__ import annotations

import logging
from pathlib import Path

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm.auto import tqdm

from ..dataset import MaldiSet
from ..spectrum import MaldiSpectrum, _infer_id
from .dataset_layouts import DatasetLayout

logger = logging.getLogger(__name__)


def _load_single(path: Path) -> MaldiSpectrum:
    """Build a :class:`MaldiSpectrum` from one spectrum file.

    The path is passed straight to the ``MaldiSpectrum`` constructor;
    nothing else is done to the spectrum here.
    """
    spectrum = MaldiSpectrum(path)
    return spectrum


class DatasetLoader:
    """Load a dataset into a :class:`~maldiamrkit.dataset.MaldiSet`.

    Parameters
    ----------
    layout : DatasetLayout
        Dataset navigation adapter (e.g. :class:`DRIAMSLayout` or
        :class:`MARISMaLayout`).
    stage : str or None
        Processing stage to load. ``None`` triggers auto-detection
        via the layout.
    n_jobs : int, default=-1
        Number of parallel workers for spectrum loading.
    verbose : bool, default=False
        If True, show tqdm progress bars during spectrum loading
        and pass ``verbose`` through to
        :class:`~maldiamrkit.dataset.MaldiSet`.

    Examples
    --------
    >>> from maldiamrkit.data import DatasetLoader, DRIAMSLayout
    >>> layout = DRIAMSLayout("output/my_dataset")
    >>> loader = DatasetLoader(layout)
    >>> ds = loader.load(aggregate_by=dict(antibiotics="Ceftriaxone"))
    """

    def __init__(
        self,
        layout: DatasetLayout,
        *,
        stage: str | None = None,
        n_jobs: int = -1,
        verbose: bool = False,
    ) -> None:
        self.layout = layout
        self.stage = stage
        self.n_jobs = n_jobs
        self.verbose = verbose

    def load(
        self,
        aggregate_by: dict[str, str | list[str]] | None = None,
    ) -> MaldiSet:
        """Load the dataset.

        Parameters
        ----------
        aggregate_by : dict, optional
            Passed through to :class:`~maldiamrkit.dataset.MaldiSet`.
            Also used to pre-filter the metadata so that spectra which
            would be discarded anyway are never read from disk.

        Returns
        -------
        MaldiSet
            Dataset with loaded spectra and metadata.

        Raises
        ------
        FileNotFoundError
            If the layout returns no spectrum files for the stage.
        ValueError
            If none of the spectrum files matches a metadata ID.
        """
        # 1. Resolve stage
        stage_name = (
            self.stage if self.stage is not None else self.layout.detect_stage()
        )

        # 2. Load metadata
        meta = self.layout.discover_metadata()

        # 3. Pre-filter metadata by aggregate_by criteria
        n_meta_before = len(meta)
        meta = self._prefilter_metadata(meta, aggregate_by)
        if len(meta) < n_meta_before:
            logger.info(
                "Pre-filtered metadata: %d -> %d rows",
                n_meta_before,
                len(meta),
            )

        # 4. Collect spectrum files
        year = getattr(self.layout, "year", None)
        spectrum_files = self.layout.collect_spectrum_files(stage_name, year)
        if not spectrum_files:
            raise FileNotFoundError(f"No spectrum files found for stage '{stage_name}'")

        # 5. Match spectrum files to metadata IDs.  The inferred ID of every
        # matched file is kept alongside the file so that later steps can
        # pair spectra with IDs without relying on metadata row order.
        meta_ids = set(meta["ID"].astype(str))
        matched_files: list[Path] = []
        matched_ids: list[str] = []
        for f in spectrum_files:
            fid = _infer_id(f)
            if fid in meta_ids:
                matched_files.append(f)
                matched_ids.append(fid)

        if not matched_files:
            raise ValueError(
                f"No spectrum files matched metadata IDs. "
                f"Found {len(spectrum_files)} files and "
                f"{len(meta_ids)} metadata IDs."
            )

        n_total = len(spectrum_files)
        n_matched = len(matched_files)
        if n_matched < n_total:
            logger.info(
                "Loading %d/%d spectra (others not in metadata)",
                n_matched,
                n_total,
            )

        # 6. Load spectra.  The progress bar wraps the task iterable in both
        # the sequential and the parallel path; with joblib it advances as
        # tasks are dispatched to the workers.
        if self.verbose:
            files_iter = tqdm(matched_files, desc="Loading spectra", unit="file")
        else:
            files_iter = matched_files

        if self.n_jobs == 1:
            spectra = [_load_single(p) for p in files_iter]
        else:
            spectra = Parallel(n_jobs=self.n_jobs, prefer="threads")(
                delayed(_load_single)(p) for p in files_iter
            )

        # 6a. Dataset-specific post-processing (e.g. DRIAMS binned_N/ files
        # store bin_index, not m/z, and need rewriting into real m/z).
        spectra = [
            self.layout.postprocess_spectrum(s, stage=stage_name) for s in spectra
        ]

        # 6b. Average replicates when the layout used strategy="average"
        if "_original_id" in meta.columns:
            spectra, meta = self._average_replicates(
                spectra, meta, spectrum_ids=matched_ids
            )

        # 7. Build MaldiSet
        return MaldiSet(
            spectra,
            meta,
            aggregate_by=aggregate_by,
            verbose=self.verbose,
        )

    @staticmethod
    def _prefilter_metadata(
        meta: pd.DataFrame,
        aggregate_by: dict[str, str | list[str]] | None,
    ) -> pd.DataFrame:
        """Filter metadata rows before spectrum loading to reduce I/O.

        Parameters
        ----------
        meta : pd.DataFrame
            Metadata table.
        aggregate_by : dict or None
            Optional ``"species"`` and ``"antibiotics"`` entries, each a
            single string or a list of strings.

        Returns
        -------
        pd.DataFrame
            ``meta`` itself when ``aggregate_by`` is empty or ``None``.
            Otherwise the rows whose ``Species`` is one of the requested
            species (when that column exists) and that have a non-null
            value in at least one requested antibiotic column (columns
            that are absent from ``meta`` are ignored).
        """
        if not aggregate_by:
            return meta

        species_val = aggregate_by.get("species")
        if species_val and "Species" in meta.columns:
            # Accept one species or several; ``isin`` handles both, whereas
            # ``==`` against a list compares element-wise by position.
            wanted = [species_val] if isinstance(species_val, str) else list(species_val)
            meta = meta[meta["Species"].isin(wanted)]

        antibiotics_val = aggregate_by.get("antibiotics")
        if antibiotics_val is not None:
            if isinstance(antibiotics_val, str):
                antibiotics_val = [antibiotics_val]
            available = [ab for ab in antibiotics_val if ab in meta.columns]
            if available:
                meta = meta[meta[available].notna().any(axis=1)]

        return meta

    @staticmethod
    def _average_replicates(
        spectra: list[MaldiSpectrum],
        meta: pd.DataFrame,
        spectrum_ids: list[str] | None = None,
    ) -> tuple[list[MaldiSpectrum], pd.DataFrame]:
        """Average replicate spectra that share an ``_original_id``.

        Groups spectra by the ``_original_id`` metadata column,
        interpolates each group onto a common m/z grid, and averages
        the intensities.

        Parameters
        ----------
        spectra : list of MaldiSpectrum
            Loaded spectra.
        meta : pd.DataFrame
            Metadata with ``ID`` and ``_original_id`` columns.
        spectrum_ids : list of str, optional
            ID of each entry of ``spectra``, in the same order.  When
            omitted, ``spectra`` is assumed to be aligned row-by-row
            with ``meta`` and ``meta["ID"]`` is used instead.

        Returns
        -------
        tuple of (list of MaldiSpectrum, pd.DataFrame)
            One spectrum per ``_original_id`` that has at least one
            loaded member, and the metadata reduced to the first row of
            each such group, with ``ID`` replaced by ``_original_id``
            and the ``_original_id`` column dropped.

        Raises
        ------
        ValueError
            If the ID list and ``spectra`` differ in length.
        """
        if spectrum_ids is None:
            ids = meta["ID"].astype(str).tolist()
        else:
            ids = [str(sid) for sid in spectrum_ids]
        id_to_spec: dict[str, MaldiSpectrum] = dict(zip(ids, spectra, strict=True))

        # Single pass over the metadata: collect the member IDs of every
        # group and remember the position of the group's first row.
        groups: dict[object, list[str]] = {}
        first_pos: dict[object, int] = {}
        row_pairs = zip(
            meta["_original_id"].tolist(),
            meta["ID"].astype(str).tolist(),
        )
        for pos, (orig, sid) in enumerate(row_pairs):
            if orig not in groups:
                groups[orig] = []
                first_pos[orig] = pos
            groups[orig].append(sid)

        averaged_spectra: list[MaldiSpectrum] = []
        keep_positions: list[int] = []

        for orig_id, member_ids in groups.items():
            members = [id_to_spec[m] for m in member_ids if m in id_to_spec]
            if not members:
                # No spectrum was loaded for this group; drop its metadata.
                continue

            if len(members) == 1:
                # NOTE(review): the lone spectrum keeps its own ``id`` while
                # its metadata row is renamed to ``_original_id`` below --
                # confirm downstream matching tolerates a differing id.
                averaged_spectra.append(members[0])
            else:
                # Common grid: the m/z range shared by all members, sampled
                # at the finest median spacing found among them.
                mz_min = max(s.raw["mass"].min() for s in members)
                mz_max = min(s.raw["mass"].max() for s in members)
                step = min(np.median(np.diff(s.raw["mass"].values)) for s in members)
                common_mz = np.arange(mz_min, mz_max, step)

                intensities = np.vstack(
                    [
                        np.interp(
                            common_mz,
                            s.raw["mass"].values,
                            s.raw["intensity"].values,
                        )
                        for s in members
                    ]
                )
                avg_intensity = intensities.mean(axis=0)

                avg_df = pd.DataFrame({"mass": common_mz, "intensity": avg_intensity})
                avg_spec = MaldiSpectrum(avg_df)
                avg_spec.id = orig_id
                averaged_spectra.append(avg_spec)

            keep_positions.append(first_pos[orig_id])

        # Positional selection stays correct even if the index has
        # duplicate labels.
        meta = meta.iloc[keep_positions].copy()
        meta["ID"] = meta["_original_id"]
        meta = meta.drop(columns=["_original_id"]).reset_index(drop=True)

        logger.info(
            "Averaged replicates: %d spectra -> %d unique IDs.",
            len(spectra),
            len(averaged_spectra),
        )

        return averaged_spectra, meta