{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MaldiAMRKit - Spectral Alignment\n",
"\n",
"This notebook covers spectral alignment (warping) methods to correct for mass calibration drift."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"from maldiamrkit import MaldiSet\n",
"from maldiamrkit.alignment import RawWarping, Warping, create_raw_input"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Features shape: (29, 6000)\n"
]
}
],
"source": [
"data = MaldiSet.from_directory(\n",
"    \"../data/\",\n",
"    \"../data/metadata/metadata.csv\",\n",
"    aggregate_by=dict(antibiotics=\"Drug\"),\n",
")\n",
"X = data.X\n",
"# Binary target: susceptible (S) -> 0; intermediate (I) and resistant (R) -> 1\n",
"y = data.y[\"Drug\"].map({\"S\": 0, \"I\": 1, \"R\": 1})\n",
"\n",
"print(f\"Features shape: {X.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Warping Methods\n",
"\n",
"MaldiAMRKit supports four alignment methods, selected with the `method` argument of `Warping`:\n",
"\n",
"- **shift**: Global median shift (fastest and simplest)\n",
"- **linear**: Least-squares linear transformation\n",
"- **piecewise**: Local shifts estimated per spectrum segment (flexible; segment count set by `n_segments`)\n",
"- **dtw**: Dynamic Time Warping (best for non-linear drift, but the slowest)"
]
},
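{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the simplest option concrete, the next cell sketches what a global shift alignment could look like in plain NumPy. It is an illustration only, not MaldiAMRKit's implementation: the helper `estimate_shift`, the `max_shift` parameter, and the synthetic spectra are all hypothetical, and the shift is estimated by a brute-force cross-correlation search rather than by the median-based estimate that `shift` is described as using."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def estimate_shift(reference, spectrum, max_shift=20):\n",
"    # Hypothetical helper: try every integer bin shift in [-max_shift, max_shift]\n",
"    # and keep the one whose shifted spectrum overlaps the reference most.\n",
"    shifts = np.arange(-max_shift, max_shift + 1)\n",
"    scores = [np.dot(reference, np.roll(spectrum, s)) for s in shifts]\n",
"    return int(shifts[int(np.argmax(scores))])\n",
"\n",
"\n",
"# Synthetic data: one Gaussian peak, then the same peak displaced by 7 bins\n",
"bins = np.arange(500)\n",
"reference = np.exp(-0.5 * ((bins - 250) / 4.0) ** 2)\n",
"drifted = np.roll(reference, 7)\n",
"\n",
"shift = estimate_shift(reference, drifted)\n",
"realigned = np.roll(drifted, shift)  # apply the correction\n",
"print(f\"Estimated shift: {shift} bins\")\n",
"print(f\"Matches reference after correction: {np.allclose(realigned, reference)}\")"
]
},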
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shift - CV ROC AUC: 0.400 +/- 0.255\n"
]
}
],
"source": [
"cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)\n",
"\n",
"# Shift method (fastest)\n",
"pipe_shift = Pipeline(\n",
"    [\n",
"        (\"warp\", Warping(method=\"shift\")),\n",
"        (\"scaler\", StandardScaler()),\n",
"        (\"clf\", LogisticRegression()),\n",
"    ]\n",
")\n",
"\n",
"scores = cross_val_score(pipe_shift, X, y, cv=cv, scoring=\"roc_auc\")\n",
"print(f\"Shift - CV ROC AUC: {scores.mean():.3f} +/- {scores.std():.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear - CV ROC AUC: 0.400 +/- 0.289\n"
]
}
],
"source": [
"# Linear method\n",
"pipe_linear = Pipeline(\n",
"    [\n",
"        (\"warp\", Warping(method=\"linear\")),\n",
"        (\"scaler\", StandardScaler()),\n",
"        (\"clf\", LogisticRegression()),\n",
"    ]\n",
")\n",
"\n",
"scores = cross_val_score(pipe_linear, X, y, cv=cv, scoring=\"roc_auc\")\n",
"print(f\"Linear - CV ROC AUC: {scores.mean():.3f} +/- {scores.std():.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Piecewise - CV ROC AUC: 0.400 +/- 0.289\n"
]
}
],
"source": [
"# Piecewise method (often best trade-off)\n",
"pipe_piecewise = Pipeline(\n",
"    [\n",
"        (\"warp\", Warping(method=\"piecewise\", n_segments=10)),\n",
"        (\"scaler\", StandardScaler()),\n",
"        (\"clf\", LogisticRegression()),\n",
"    ]\n",
")\n",
"\n",
"scores = cross_val_score(pipe_piecewise, X, y, cv=cv, scoring=\"roc_auc\")\n",
"print(f\"Piecewise - CV ROC AUC: {scores.mean():.3f} +/- {scores.std():.3f}\")"
]
},
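{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Alignment Quality Assessment\n",
"\n",
"Use `get_alignment_quality()` to measure how well spectra were aligned to the reference. It returns one row per spectrum with the correlation and RMSE before and after alignment, plus the change in correlation (`improvement`)."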
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean correlation improvement: 0.0056\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>correlation_before</th>\n",
"      <th>correlation_after</th>\n",
"      <th>improvement</th>\n",
"      <th>rmse_before</th>\n",
"      <th>rmse_after</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>10s</th>\n",
"      <td>0.850781</td>\n",
"      <td>0.850781</td>\n",
"      <td>0.000000</td>\n",
"      <td>0.000137</td>\n",
"      <td>0.000137</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>11s</th>\n",
"      <td>0.854397</td>\n",
"      <td>0.854397</td>\n",
"      <td>0.000000</td>\n",
"      <td>0.000185</td>\n",
"      <td>0.000185</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>12s</th>\n",
"      <td>0.898360</td>\n",
"      <td>0.898360</td>\n",
"      <td>0.000000</td>\n",
"      <td>0.000192</td>\n",
"      <td>0.000192</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>13s</th>\n",
"      <td>0.817404</td>\n",
"      <td>0.817404</td>\n",
"      <td>0.000000</td>\n",
"      <td>0.000240</td>\n",
"      <td>0.000240</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>14s</th>\n",
"      <td>0.825112</td>\n",
"      <td>0.825087</td>\n",
"      <td>-0.000024</td>\n",
"      <td>0.000177</td>\n",
"      <td>0.000177</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"     correlation_before  correlation_after  improvement  rmse_before  \\\n",
"10s            0.850781           0.850781     0.000000     0.000137   \n",
"11s            0.854397           0.854397     0.000000     0.000185   \n",
"12s            0.898360           0.898360     0.000000     0.000192   \n",
"13s            0.817404           0.817404     0.000000     0.000240   \n",
"14s            0.825112           0.825087    -0.000024     0.000177   \n",
"\n",
"     rmse_after  \n",
"10s    0.000137  \n",
"11s    0.000185  \n",
"12s    0.000192  \n",
"13s    0.000240  \n",
"14s    0.000177  "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit warping and check alignment quality\n",
"warper = Warping(method=\"piecewise\", n_segments=10)\n",
"warper.fit(X)\n",
"X_aligned = warper.transform(X)\n",
"\n",
"# Get alignment quality metrics\n",
"quality = warper.get_alignment_quality(X, X_aligned)\n",
"print(f\"Mean correlation improvement: {quality['improvement'].mean():.4f}\")\n",
"quality.head()"
]
},
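{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about what the columns above measure, the next cell computes a per-spectrum correlation and RMSE against a reference on synthetic data. This is a sketch under stated assumptions, not the library's code: the helper `corr_and_rmse` is hypothetical, the reference is a known synthetic template, and the \"alignment\" simply undoes offsets that are known in advance. How `get_alignment_quality()` chooses its reference and computes its metrics is not shown in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"bins = np.arange(300)\n",
"# Hypothetical reference spectrum: a single Gaussian peak\n",
"reference = np.exp(-0.5 * ((bins - 150) / 3.0) ** 2)\n",
"\n",
"# Five synthetic spectra, each displaced by a small random number of bins\n",
"offsets = rng.integers(-3, 4, size=5)\n",
"before = np.stack([np.roll(reference, int(o)) for o in offsets])\n",
"# Idealised alignment: undo each known offset\n",
"after = np.stack([np.roll(row, -int(o)) for row, o in zip(before, offsets)])\n",
"\n",
"\n",
"def corr_and_rmse(spectra, ref):\n",
"    # Per-spectrum Pearson correlation and RMSE against the reference\n",
"    corr = np.array([np.corrcoef(row, ref)[0, 1] for row in spectra])\n",
"    rmse = np.sqrt(((spectra - ref) ** 2).mean(axis=1))\n",
"    return corr, rmse\n",
"\n",
"\n",
"corr_before, rmse_before = corr_and_rmse(before, reference)\n",
"corr_after, rmse_after = corr_and_rmse(after, reference)\n",
"improvement = corr_after - corr_before\n",
"print(\"improvement:\", np.round(improvement, 4))\n",
"print(\"rmse_after:\", np.round(rmse_after, 6))"
]
},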
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Raw Spectra Warping\n",
"\n",
"`RawWarping` aligns spectra at full m/z resolution, before binning, for higher precision. It loads the raw spectrum files during `fit`/`transform` and returns binned data.\n",
"\n",
"**Key workflow:**\n",
"1. Use `create_raw_input()` to build an input DataFrame of file paths\n",
"2. Pass this DataFrame to `RawWarping`, on its own or as a pipeline step\n",
"3. Get binned, aligned spectra as output\n",
"\n",
"Because file loading happens inside the transformer, `RawWarping` works as a regular step in sklearn pipelines."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input DataFrame shape: (29, 1)\n",
"Columns: ['path']\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>path</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>10s</th>\n",
"      <td>../data/10s.txt</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>11s</th>\n",
"      <td>../data/11s.txt</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>12s</th>\n",
"      <td>../data/12s.txt</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>13s</th>\n",
"      <td>../data/13s.txt</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>14s</th>\n",
"      <td>../data/14s.txt</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"                path\n",
"10s  ../data/10s.txt\n",
"11s  ../data/11s.txt\n",
"12s  ../data/12s.txt\n",
"13s  ../data/13s.txt\n",
"14s  ../data/14s.txt"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create input DataFrame from raw spectra directory\n",
"X_raw = create_raw_input(\"../data/\")\n",
"print(f\"Input DataFrame shape: {X_raw.shape}\")\n",
"print(f\"Columns: {X_raw.columns.tolist()}\")\n",
"X_raw.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (29, 1) (single 'path' column)\n",
"Output shape: (29, 6000) (binned spectra)\n",
"Output columns are m/z bin starting points: ['2000', '2003', '2006', '2009', '2012']...\n"
]
}
],
"source": [
"# RawWarping in a pipeline - outputs binned spectra\n",
"raw_warper = RawWarping(\n",
"    method=\"piecewise\",\n",
"    bin_width=3,\n",
"    max_shift_da=10.0,\n",
"    n_segments=5,\n",
")\n",
"\n",
"# Fit and transform - loads raw files, warps at full resolution, bins output\n",
"raw_warper.fit(X_raw)\n",
"X_raw_aligned = raw_warper.transform(X_raw)\n",
"print(f\"Input shape: {X_raw.shape} (single 'path' column)\")\n",
"print(f\"Output shape: {X_raw_aligned.shape} (binned spectra)\")\n",
"print(\n",
"    f\"Output columns are m/z bin starting points: {X_raw_aligned.columns[:5].tolist()}...\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Parallelization\n",
"\n",
"Set the `n_jobs` parameter to run the alignment in parallel. `n_jobs=-1` uses all available cores."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Aligned 29 spectra\n"
]
}
],
"source": [
"# Parallel warping (use all cores)\n",
"warper_parallel = Warping(method=\"piecewise\", n_segments=10, n_jobs=-1)\n",
"warper_parallel.fit(X)\n",
"X_aligned_parallel = warper_parallel.transform(X)\n",
"print(f\"Aligned {len(X)} spectra\")"
]
},
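{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RawWarping in sklearn Pipeline\n",
"\n",
"Because `RawWarping` accepts a path-based DataFrame and outputs binned spectra, it can serve as the first step of an sklearn pipeline. The pipeline then takes file paths as input, and the alignment is refit on the training portion of each cross-validation fold."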
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RawWarping Pipeline - CV ROC AUC: 0.375 +/- 0.250\n"
]
}
],
"source": [
"# Full pipeline: raw spectra -> alignment -> scaling -> classification\n",
"pipe_raw = Pipeline(\n",
"    [\n",
"        (\"warp\", RawWarping(method=\"piecewise\", bin_width=3, n_segments=5)),\n",
"        (\"scaler\", StandardScaler()),\n",
"        (\"clf\", LogisticRegression()),\n",
"    ]\n",
")\n",
"\n",
"# Cross-validation with RawWarping pipeline\n",
"# Note: X_raw contains file paths, y contains labels\n",
"scores = cross_val_score(pipe_raw, X_raw, y, cv=cv, scoring=\"roc_auc\")\n",
"print(f\"RawWarping Pipeline - CV ROC AUC: {scores.mean():.3f} +/- {scores.std():.3f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "maldiamrkit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}