{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MaldiAMRKit - Spectral Alignment\n",
    "\n",
    "This notebook demonstrates spectral alignment (warping) methods that correct for mass calibration drift between spectra."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "from maldiamrkit import MaldiSet\n",
    "from maldiamrkit.alignment import RawWarping, Warping, create_raw_input"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Features shape: (29, 6000)\n"
     ]
    }
   ],
   "source": [
    "data = MaldiSet.from_directory(\n",
    "    \"../data/\",\n",
    "    \"../data/metadata/metadata.csv\",\n",
    "    aggregate_by=dict(antibiotics=\"Drug\"),\n",
    ")\n",
    "X = data.X\n",
    "# Binary target: susceptible (S) = 0, intermediate (I) or resistant (R) = 1\n",
    "y = data.y[\"Drug\"].map({\"S\": 0, \"I\": 1, \"R\": 1})\n",
    "\n",
    "print(f\"Features shape: {X.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Warping Methods\n",
    "\n",
    "MaldiAMRKit supports several alignment methods, selected with the `method` argument of `Warping`:\n",
    "\n",
    "- **shift**: one global median shift for the whole spectrum (fast, simple)\n",
    "- **linear**: least-squares linear transformation\n",
    "- **piecewise**: local shifts across spectrum segments, with `n_segments` setting the number of segments (flexible)\n",
    "- **dtw**: Dynamic Time Warping (best for non-linear drift, slowest)\n",
    "\n",
    "The cells below compare `shift`, `linear`, and `piecewise` as the first step of a cross-validated classification pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shift - CV ROC AUC: 0.400 +/- 0.255\n"
     ]
    }
   ],
   "source": [
    "cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)\n",
    "\n",
    "# Shift method (fastest)\n",
    "pipe_shift = Pipeline(\n",
    "    [\n",
    "        (\"warp\", Warping(method=\"shift\")),\n",
    "        (\"scaler\", StandardScaler()),\n",
    "        (\"clf\", LogisticRegression()),\n",
    "    ]\n",
    ")\n",
    "\n",
    "scores = cross_val_score(pipe_shift, X, y, cv=cv, scoring=\"roc_auc\")\n",
    "print(f\"Shift - CV ROC AUC: {scores.mean():.3f} +/- {scores.std():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear - CV ROC AUC: 0.400 +/- 0.289\n"
     ]
    }
   ],
   "source": [
    "# Linear method\n",
    "pipe_linear = Pipeline(\n",
    "    [\n",
    "        (\"warp\", Warping(method=\"linear\")),\n",
    "        (\"scaler\", StandardScaler()),\n",
    "        (\"clf\", LogisticRegression()),\n",
    "    ]\n",
    ")\n",
    "\n",
    "scores = cross_val_score(pipe_linear, X, y, cv=cv, scoring=\"roc_auc\")\n",
    "print(f\"Linear - CV ROC AUC: {scores.mean():.3f} +/- {scores.std():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Piecewise - CV ROC AUC: 0.400 +/- 0.289\n"
     ]
    }
   ],
   "source": [
    "# Piecewise method (often the best speed/accuracy trade-off)\n",
    "pipe_piecewise = Pipeline(\n",
    "    [\n",
    "        (\"warp\", Warping(method=\"piecewise\", n_segments=10)),\n",
    "        (\"scaler\", StandardScaler()),\n",
    "        (\"clf\", LogisticRegression()),\n",
    "    ]\n",
    ")\n",
    "\n",
    "scores = cross_val_score(pipe_piecewise, X, y, cv=cv, scoring=\"roc_auc\")\n",
    "print(f\"Piecewise - CV ROC AUC: {scores.mean():.3f} +/- {scores.std():.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Alignment Quality Assessment\n",
    "\n",
    "Use `get_alignment_quality()` to measure how well spectra were aligned to the reference. For each spectrum it reports the correlation and RMSE against the reference before and after warping, plus the change in correlation (`improvement`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean correlation improvement: 0.0056\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "     correlation_before  correlation_after  improvement  rmse_before  \\\n",
       "10s            0.850781           0.850781     0.000000     0.000137   \n",
       "11s            0.854397           0.854397     0.000000     0.000185   \n",
       "12s            0.898360           0.898360     0.000000     0.000192   \n",
       "13s            0.817404           0.817404     0.000000     0.000240   \n",
       "14s            0.825112           0.825087    -0.000024     0.000177   \n",
       "\n",
       "     rmse_after  \n",
       "10s    0.000137  \n",
       "11s    0.000185  \n",
       "12s    0.000192  \n",
       "13s    0.000240  \n",
       "14s    0.000177  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the warper and align the spectra\n",
    "warper = Warping(method=\"piecewise\", n_segments=10)\n",
    "warper.fit(X)\n",
    "X_aligned = warper.transform(X)\n",
    "\n",
    "# Per-spectrum alignment quality metrics\n",
    "quality = warper.get_alignment_quality(X, X_aligned)\n",
    "print(f\"Mean correlation improvement: {quality['improvement'].mean():.4f}\")\n",
    "quality.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Raw Spectra Warping\n",
    "\n",
    "`RawWarping` performs alignment at full m/z resolution (before binning) for higher precision. It loads the raw spectrum files during `fit` and `transform` and returns binned spectra.\n",
    "\n",
    "**Key workflow:**\n",
    "1. Use `create_raw_input()` to build an input DataFrame with one file path per spectrum\n",
    "2. Pass this DataFrame to `RawWarping`, either directly or as a pipeline step\n",
    "3. Receive the aligned spectra, binned using `bin_width`\n",
    "\n",
    "Because both its input and its output are DataFrames, `RawWarping` can be used as a step in scikit-learn pipelines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input DataFrame shape: (29, 1)\n",
      "Columns: ['path']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "                path\n",
       "10s  ../data/10s.txt\n",
       "11s  ../data/11s.txt\n",
       "12s  ../data/12s.txt\n",
       "13s  ../data/13s.txt\n",
       "14s  ../data/14s.txt"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create the input DataFrame from the raw spectra directory\n",
    "X_raw = create_raw_input(\"../data/\")\n",
    "print(f\"Input DataFrame shape: {X_raw.shape}\")\n",
    "print(f\"Columns: {X_raw.columns.tolist()}\")\n",
    "X_raw.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shape: (29, 1) (single 'path' column)\n",
      "Output shape: (29, 6000) (binned spectra)\n",
      "Output columns are m/z bin starting points: ['2000', '2003', '2006', '2009', '2012']...\n"
     ]
    }
   ],
   "source": [
    "# RawWarping on its own: raw files in, aligned binned spectra out\n",
    "raw_warper = RawWarping(\n",
    "    method=\"piecewise\",\n",
    "    bin_width=3,  # width of the output m/z bins\n",
    "    max_shift_da=10.0,  # largest allowed shift, in Da\n",
    "    n_segments=5,  # number of piecewise segments\n",
    ")\n",
    "\n",
    "# Fit and transform: loads raw files, warps at full resolution, bins the output\n",
    "raw_warper.fit(X_raw)\n",
    "X_raw_aligned = raw_warper.transform(X_raw)\n",
    "print(f\"Input shape: {X_raw.shape} (single 'path' column)\")\n",
    "print(f\"Output shape: {X_raw_aligned.shape} (binned spectra)\")\n",
    "print(\n",
    "    f\"Output columns are m/z bin starting points: {X_raw_aligned.columns[:5].tolist()}...\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parallelization\n",
    "\n",
    "Use the `n_jobs` parameter to align spectra in parallel; `n_jobs=-1` uses all available CPU cores."
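   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check whether parallelism pays off on a given machine, you can time the same warping run with different `n_jobs` values. This is a rough sketch: it assumes, as in scikit-learn, that `n_jobs=1` runs serially, and the timings depend on the number of cores and spectra. With only 29 spectra, the overhead of starting workers may outweigh the speed-up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Rough timing sketch: the same piecewise warping, serial vs. all cores.\n",
    "# Assumption: n_jobs=1 runs serially (scikit-learn convention).\n",
    "for n_jobs in (1, -1):\n",
    "    start = time.perf_counter()\n",
    "    warper_timed = Warping(method=\"piecewise\", n_segments=10, n_jobs=n_jobs)\n",
    "    warper_timed.fit(X)\n",
    "    warper_timed.transform(X)\n",
    "    elapsed = time.perf_counter() - start\n",
    "    print(f\"n_jobs={n_jobs:>2}: {elapsed:.2f} s\")"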
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Aligned 29 spectra\n"
     ]
    }
   ],
   "source": [
    "# Parallel warping (n_jobs=-1 uses all cores)\n",
    "warper_parallel = Warping(method=\"piecewise\", n_segments=10, n_jobs=-1)\n",
    "warper_parallel.fit(X)\n",
    "X_aligned_parallel = warper_parallel.transform(X)\n",
    "print(f\"Aligned {len(X_aligned_parallel)} spectra\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RawWarping in sklearn Pipeline\n",
    "\n",
    "Because `RawWarping` accepts a DataFrame of file paths and returns binned spectra, it can serve as the first step of a scikit-learn `Pipeline`. During cross-validation it is refit on the training files of each fold, just like `Warping` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RawWarping Pipeline - CV ROC AUC: 0.375 +/- 0.250\n"
     ]
    }
   ],
   "source": [
    "# Full pipeline: raw spectra -> alignment -> scaling -> classification\n",
    "pipe_raw = Pipeline(\n",
    "    [\n",
    "        (\"warp\", RawWarping(method=\"piecewise\", bin_width=3, n_segments=5)),\n",
    "        (\"scaler\", StandardScaler()),\n",
    "        (\"clf\", LogisticRegression()),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Cross-validation with the RawWarping pipeline.\n",
    "# X_raw holds one file path per spectrum and y holds the labels; cross_val_score\n",
    "# pairs them by position, so both must list the samples in the same order.\n",
    "scores = cross_val_score(pipe_raw, X_raw, y, cv=cv, scoring=\"roc_auc\")\n",
    "print(f\"RawWarping Pipeline - CV ROC AUC: {scores.mean():.3f} +/- {scores.std():.3f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "maldiamrkit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}