# Source code for maldiamrkit.visualization.alignment_plots

"""Alignment visualization functions."""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

if TYPE_CHECKING:
    from matplotlib.figure import Figure

    from ..alignment.warping import Warping


def _show_with_warning(show: bool) -> None:
    """Call ``plt.show()`` on request, warning first if it may not display.

    Mirrors the ``show=True`` pattern used in sibling plot modules.

    Parameters
    ----------
    show : bool
        When false, nothing is displayed and no warning is emitted.
    """
    import matplotlib
    from matplotlib import pyplot

    if not show:
        return

    # NOTE(review): ``is_interactive()`` reports matplotlib's interactive
    # mode, which is not the same thing as the backend being GUI-capable --
    # confirm the wording of the message below is what is intended.
    if not matplotlib.is_interactive():
        message = (
            "matplotlib is using a non-interactive backend; "
            "plt.show() may not display the figure"
        )
        # stacklevel=3 attributes the warning to the caller of the public
        # plot function rather than to this helper or the plot function.
        warnings.warn(message, UserWarning, stacklevel=3)

    pyplot.show()


def plot_alignment(
    warper: Warping,
    X_original: pd.DataFrame,
    X_aligned: pd.DataFrame | None = None,
    indices: int | list[int] | None = None,
    *,
    show_peaks: bool = True,
    show_sample_peaks: bool = False,
    xlim: tuple[float, float] | None = None,
    figsize: tuple[float, float] | None = None,
    alpha: float = 0.7,
    color_reference: str = "black",
    color_original: str = "red",
    color_aligned: str = "blue",
    title: str | None = None,
    show: bool = True,
) -> tuple[Figure, np.ndarray]:
    """Plot comparison of original vs aligned spectra against reference.

    One row of two panels is drawn per selected spectrum: the left panel
    overlays the original spectrum on the reference, the right panel
    overlays the aligned spectrum on the reference. All panels share both
    axes.

    Parameters
    ----------
    warper : Warping
        Fitted warping transformer.
    X_original : pd.DataFrame
        Original (unaligned) spectra.
    X_aligned : pd.DataFrame, optional
        Aligned spectra. If None, will compute by calling transform().
    indices : int or list of int, optional
        Indices of spectra to plot. If None, plots the first spectrum.
    show_peaks : bool, default=True
        Whether to draw reference peak positions (vertical dashed lines).
        These are the calibration markers used to judge alignment quality
        and are on by default.
    show_sample_peaks : bool, default=False
        If True, additionally draw per-sample (and per-aligned) peak
        positions as dashed vertical lines. Off by default because dense
        peak sets clutter the panel.
    xlim : tuple of (float, float), optional
        X-axis limits for zooming into specific m/z range.
    figsize : tuple of (float, float), optional
        Figure size in inches. When ``None``, defaults to
        ``(14, 3 * n_spectra)``.
    alpha : float, default=0.7
        Transparency for spectrum lines.
    color_reference : str, default="black"
        Line colour for the reference spectrum.
    color_original : str, default="red"
        Line colour for the original (before-alignment) spectrum.
    color_aligned : str, default="blue"
        Line colour for the aligned (after-alignment) spectrum.
    title : str, optional
        Figure-level title (suptitle). When ``None``, defaults to
        ``f"Warping ({warper.method})"``. An explicit string, including
        the empty string, is used as given.
    show : bool, default=True
        Call ``plt.show()`` at the end.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The generated figure.
    axes : ndarray of matplotlib.axes.Axes
        2-D array of shape ``(n_spectra, 2)``: column 0 = before, 1 = after.

    Raises
    ------
    RuntimeError
        If the transformer has not been fitted.
    ValueError
        If any index is out of bounds for the data.
    """
    import matplotlib.pyplot as plt

    indices, mz_axis, X_aligned = _validate_alignment_inputs(
        warper, X_original, X_aligned, indices
    )
    peaks_ctx = _compute_peak_positions(
        warper,
        X_original,
        X_aligned,
        indices,
        mz_axis,
        show_peaks,
        show_sample_peaks,
    )

    n_spectra = len(indices)
    if figsize is None:
        figsize = (14, max(3.0, 3.0 * n_spectra))

    # squeeze=False keeps ``axes`` two-dimensional even for a single row, so
    # the ``axes[row, col]`` indexing below is always valid.
    fig, axes = plt.subplots(
        n_spectra,
        2,
        figsize=figsize,
        squeeze=False,
        sharex=True,
        sharey=True,
    )

    # (column, data, legend label, colour, peaks_ctx key, column header)
    columns = (
        (0, X_original, "Original", color_original, "sample", "Before"),
        (1, X_aligned, "Aligned", color_aligned, "aligned", "After"),
    )

    for row, spectrum_idx in enumerate(indices):
        is_top = row == 0
        is_bottom = row == n_spectra - 1
        for col, frame, label, color, peaks_key, header in columns:
            _plot_alignment_panel(
                axes[row, col],
                mz_axis,
                warper.ref_spec_,
                frame.iloc[spectrum_idx].to_numpy(),
                spectrum_idx,
                peaks_ctx["ref"],
                peaks_ctx[peaks_key].get(spectrum_idx),
                xlim,
                alpha,
                label=label,
                sample_color=color,
                ref_color=color_reference,
                # Column headers only on the first row.
                column_title=header if is_top else None,
                # The row label is shown once, on the left-hand panel.
                ylabel=f"idx={spectrum_idx}" if col == 0 else None,
                is_bottom=is_bottom,
                show_sample_peaks=show_sample_peaks,
            )

    # Fall back to the default only when no title was given at all, so an
    # explicit empty string is respected.
    suptitle = title if title is not None else f"Warping ({warper.method})"
    fig.suptitle(suptitle)
    fig.tight_layout()

    _show_with_warning(show)
    return fig, axes


def _validate_alignment_inputs(warper, X_original, X_aligned, indices):
    """Validate inputs and normalize indices for alignment plotting.

    Parameters
    ----------
    warper : object
        Warping transformer; must expose ``ref_spec_`` (set by fitting) and,
        when ``X_aligned`` is None, a ``transform`` method.
    X_original : pd.DataFrame
        Original spectra, one row per sample.
    X_aligned : pd.DataFrame or None
        Aligned spectra. When None, ``warper.transform(X_original)`` is used.
    indices : int, iterable of int, or None
        Row positions to plot. None selects the first row. A single integer
        (Python or NumPy) is wrapped in a list; any other iterable is
        copied into a list.

    Returns
    -------
    indices : list
        Validated row positions.
    mz_axis : np.ndarray
        Column labels of ``X_original`` when numeric, otherwise
        ``0 .. n_columns - 1``.
    X_aligned : pd.DataFrame
        The aligned spectra (given or computed).

    Raises
    ------
    RuntimeError
        If ``warper`` has no ``ref_spec_`` attribute.
    ValueError
        If ``indices`` is empty or any index is out of bounds.
    """
    if not hasattr(warper, "ref_spec_"):
        raise RuntimeError("Warping must be fitted before plotting")

    if X_aligned is None:
        X_aligned = warper.transform(X_original)

    if indices is None:
        indices = [0]
    elif isinstance(indices, (int, np.integer)):
        # NumPy integer scalars (e.g. from ``np.argmax``) are not ``int``
        # instances and are not iterable, so they need the scalar branch too.
        indices = [indices]
    else:
        # Always hand a real list downstream: pandas ``.iloc`` interprets a
        # tuple as a multi-axis (row, column) key rather than a row list.
        indices = list(indices)

    if not indices:
        raise ValueError("indices must contain at least one spectrum index")

    n_samples = len(X_original)
    for idx in indices:
        if idx < 0 or idx >= n_samples:
            raise ValueError(
                f"Index {idx} out of bounds for data with {n_samples} samples"
            )

    mz_axis = X_original.columns.to_numpy()
    if not np.issubdtype(mz_axis.dtype, np.number):
        # Non-numeric column labels cannot serve as an x-axis; fall back to
        # positional bin numbers.
        mz_axis = np.arange(len(mz_axis))

    return indices, mz_axis, X_aligned


def _compute_peak_positions(
    warper,
    X_original,
    X_aligned,
    indices,
    mz_axis,
    show_peaks,
    show_sample_peaks,
):
    """Compute peak positions for reference and (optionally) selected spectra.

    Returns a dict ``{"ref": np.ndarray | None, "sample": dict[int,
    np.ndarray], "aligned": dict[int, np.ndarray]}``. DTW outputs are
    detected the same way as any other aligned spectra so users see peak
    markers regardless of warping method.

    When ``show_peaks`` is false nothing is computed at all: ``"ref"`` stays
    None and both per-spectrum dicts stay empty, even if
    ``show_sample_peaks`` is true. Per-spectrum peaks are computed only when
    both flags are true.
    """
    ctx: dict = {"ref": None, "sample": {}, "aligned": {}}
    if not show_peaks:
        return ctx

    def _peak_mz(peaks_row):
        # Map the non-zero entries of one detector output row to m/z values.
        return mz_axis[peaks_row.to_numpy().nonzero()[0]]

    detector = warper.peak_detector

    # The detector is fed a DataFrame, so wrap the reference spectrum as a
    # single-row frame carrying the same columns as the data.
    ref_frame = pd.DataFrame(
        warper.ref_spec_[np.newaxis, :], columns=X_original.columns
    )
    ctx["ref"] = _peak_mz(detector.transform(ref_frame).iloc[0])

    if show_sample_peaks:
        sample_peaks_df = detector.transform(X_original.iloc[indices])
        aligned_peaks_df = detector.transform(X_aligned.iloc[indices])
        # Detector output rows follow the order of ``indices``, so address
        # them by position and key the results by the original row index.
        for pos, idx in enumerate(indices):
            ctx["sample"][idx] = _peak_mz(sample_peaks_df.iloc[pos])
            ctx["aligned"][idx] = _peak_mz(aligned_peaks_df.iloc[pos])

    return ctx


def _plot_alignment_panel(
    ax,
    mz_axis,
    ref_spec,
    sample_spec,
    spectrum_idx,
    ref_peaks,
    sample_peaks,
    xlim,
    alpha,
    *,
    label,
    sample_color,
    ref_color,
    column_title,
    ylabel,
    is_bottom,
    show_sample_peaks,
):
    """Draw one (before or after) panel of the alignment plot.

    Plots the reference and the sample spectrum on ``ax``, then optional
    dashed vertical lines at ``ref_peaks`` and (when ``show_sample_peaks``
    is true) at ``sample_peaks``. ``column_title`` and ``ylabel`` are applied
    only when not None; the x-axis label is set only when ``is_bottom`` is
    true; ``xlim`` is applied only when truthy. A legend and a light grid
    are always added.
    """
    ax.plot(
        mz_axis,
        ref_spec,
        label="Reference",
        color=ref_color,
        linewidth=1.5,
        alpha=alpha,
    )
    ax.plot(
        mz_axis,
        sample_spec,
        label=f"{label} (idx={spectrum_idx})",
        color=sample_color,
        linewidth=1,
        alpha=alpha,
    )

    if ref_peaks is not None:
        for peak in ref_peaks:
            ax.axvline(
                peak, color=ref_color, linestyle="--", alpha=0.3, linewidth=0.8
            )

    if show_sample_peaks and sample_peaks is not None:
        for peak in sample_peaks:
            ax.axvline(
                peak,
                color=sample_color,
                linestyle="--",
                alpha=0.3,
                linewidth=0.8,
            )

    if column_title is not None:
        ax.set_title(column_title)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if is_bottom:
        ax.set_xlabel("m/z (Da)")

    ax.legend(loc="upper right")
    ax.grid(True, alpha=0.3)
    if xlim:
        ax.set_xlim(xlim)