"""Spectrum and dataset plotting functions."""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
import pandas as pd
from ._common import DEFAULT_LABEL_MAP, order_labels, show_with_warning
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from ..dataset import MaldiSet
from ..spectrum import MaldiSpectrum
# Processing stages accepted by plot_spectrum(stage=...). Each name is also
# the MaldiSpectrum attribute that _resolve_stage reads for that stage.
SpectrumStage = Literal["binned", "preprocessed", "raw"]
def _resolve_stage(
spectrum: MaldiSpectrum,
stage: SpectrumStage,
binned: bool | None,
) -> tuple[SpectrumStage, pd.DataFrame]:
"""Pick the stage to plot and fetch its DataFrame.
Handles the deprecated ``binned`` boolean: ``binned=True`` maps to
``"binned"`` and ``binned=False`` falls back to the best available
non-binned stage (``"preprocessed"`` over ``"raw"``).
"""
if binned is not None:
warnings.warn(
"plot_spectrum(binned=...) is deprecated; use "
"stage='binned'|'preprocessed'|'raw' instead.",
DeprecationWarning,
stacklevel=3,
)
if binned:
stage = "binned"
else:
stage = "preprocessed" if spectrum.is_preprocessed else "raw"
if stage == "binned":
return stage, spectrum.binned
if stage == "preprocessed":
return stage, spectrum.preprocessed
if stage == "raw":
return stage, spectrum.raw
raise ValueError(
f"Unknown stage {stage!r}; expected 'binned', 'preprocessed', or 'raw'."
)
[docs]
def plot_spectrum(
    spectrum: MaldiSpectrum,
    *,
    stage: SpectrumStage = "binned",
    peaks: list[float] | np.ndarray | None = None,
    highlight_regions: list[tuple[float, float]] | None = None,
    ax: Axes | None = None,
    color: str | None = None,
    figsize: tuple[float, float] = (10, 4),
    title: str | None = None,
    log_y: bool = False,
    ylim: tuple[float, float] | None = None,
    show: bool = True,
    binned: bool | None = None,
    **kwargs: Any,
) -> tuple[Figure, Axes]:
    """Plot a single MALDI-TOF spectrum with real m/z axis.

    Parameters
    ----------
    spectrum : MaldiSpectrum
        Spectrum to plot. The DataFrame of the selected stage must provide
        ``"mass"`` and ``"intensity"`` columns.
    stage : {"binned", "preprocessed", "raw"}, default="binned"
        Processing stage to render. ``"binned"`` uses a bar plot with
        bar width inferred from the bin spacing; ``"preprocessed"`` and
        ``"raw"`` use a line plot.
    peaks : list of float or ndarray, optional
        If given, draw a downward-triangle marker at each peak m/z, placed
        at the spectrum intensity linearly interpolated at that m/z
        (``np.interp``, which expects the masses in ascending order).
        Skipped when the spectrum has no points.
    highlight_regions : list of (mz_min, mz_max) tuples, optional
        Shaded m/z bands drawn behind the spectrum (e.g. regions of
        interest from differential analysis).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates a new figure.
    color : str, optional
        Colour for the spectrum (bars / line). Matplotlib default used
        when None.
    figsize : tuple of float, default=(10, 4)
        Figure size in inches (only used when ``ax`` is None).
    title : str, optional
        Overrides the auto-generated title
        (``"{spectrum.id} ({stage})"``).
    log_y : bool, default=False
        Use a logarithmic y-axis.
    ylim : tuple of float, optional
        Override y-axis limits. Defaults to matplotlib autoscaling
        (no clipping of negatives).
    show : bool, default=True
        Show the figure at the end (via the shared ``show_with_warning``
        helper).
    binned : bool, optional
        *Deprecated.* Use ``stage=`` instead. ``binned=True`` maps to
        ``stage="binned"``; ``binned=False`` maps to ``"preprocessed"``
        if available, else ``"raw"``.
    **kwargs : dict
        Additional keyword arguments forwarded to ``ax.bar`` (binned
        stage) or ``ax.plot`` (raw / preprocessed). For the bar plot,
        ``width``, ``align`` and ``linewidth`` given here override the
        defaults (inferred bin width, ``"center"``, ``0``).

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If ``stage`` is not one of the recognised stage names.
    """
    import matplotlib.pyplot as plt

    resolved_stage, data = _resolve_stage(spectrum, stage, binned)

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    mass = data["mass"].to_numpy(dtype=float)
    intensity = data["intensity"].to_numpy(dtype=float)

    if highlight_regions:
        for low, high in highlight_regions:
            ax.axvspan(low, high, color="goldenrod", alpha=0.15, zorder=0)

    if resolved_stage == "binned":
        # Bar width derived from bin spacing: the median diff is robust to
        # non-uniform binning (adaptive / custom) and exact for uniform bins.
        bar_width = float(np.median(np.diff(mass))) if mass.size >= 2 else 1.0
        # Caller kwargs override these defaults instead of colliding with
        # them (a duplicate keyword would otherwise raise TypeError).
        bar_kwargs = {
            "width": bar_width,
            "align": "center",
            "linewidth": 0,
            **kwargs,
        }
        ax.bar(mass, intensity, color=color, **bar_kwargs)
    else:
        ax.plot(mass, intensity, color=color, **kwargs)

    if peaks is not None:
        peak_mz = np.asarray(peaks, dtype=float)
        # np.interp cannot interpolate on an empty sample grid, so markers
        # are only drawn when both the peaks and the spectrum are non-empty.
        if peak_mz.size and mass.size:
            y_at_peaks = np.interp(peak_mz, mass, intensity)
            ax.scatter(
                peak_mz,
                y_at_peaks,
                marker="v",
                color="crimson",
                s=30,
                zorder=5,
                label="peak",
            )

    ax.set_xlabel("m/z (Da)")
    ax.set_ylabel("intensity")
    ax.set_title(title or f"{spectrum.id} ({resolved_stage})")
    if log_y:
        ax.set_yscale("log")
    if ylim is not None:
        ax.set_ylim(ylim)

    show_with_warning(show)
    return fig, ax
[docs]
def plot_pseudogel(
    dataset: MaldiSet,
    *,
    antibiotic: str | None = None,
    species: str | None = None,
    regions: tuple[float, float] | list[tuple[float, float]] | None = None,
    cmap: str = "inferno",
    vmin: float | None = None,
    vmax: float | None = None,
    figsize: tuple[float, float] | None = None,
    log_scale: bool = True,
    sort_by: str | None = "intensity",
    label_map: dict | None = None,
    title: str | None = None,
    show: bool = True,
    sort_by_intensity: bool | None = None,
) -> tuple[Figure, np.ndarray]:
    """Display a pseudogel heatmap of the spectra.

    Creates one subplot per unique value of the antibiotic column, in the
    order given by ``order_labels`` (susceptibility order S, I, R with
    unknown labels appended alphabetically).

    Parameters
    ----------
    dataset : MaldiSet
        Dataset to visualize.
    antibiotic : str, optional
        Target column to group by. Defaults to the first configured
        antibiotic in the MaldiSet.
    species : str, optional
        When given, restrict the pseudogel to that species via
        :class:`~maldiamrkit.filters.SpeciesFilter`. Default ``None``
        keeps all samples.
    regions : tuple or list of tuples, optional
        m/z region(s) to display. None shows all. Multiple regions are
        separated by a blank column.
    cmap : str, default="inferno"
        Matplotlib colormap name.
    vmin, vmax : float, optional
        Colour-scale limits in the *raw intensity* units the caller
        is familiar with. When ``log_scale=True`` both values are
        automatically mapped through ``np.log1p`` before being passed
        to ``imshow``, so the plotted range matches what the user
        specified.
    figsize : tuple, optional
        Figure size. Defaults to ``(14.0, 2.5 * n_groups)`` so the
        m/z axis is wide enough for typical binned data (thousands of
        columns).
    log_scale : bool, default=True
        Apply ``np.log1p`` to intensities.
    sort_by : {"intensity", "id", None}, default="intensity"
        How to order samples within each group:

        - ``"intensity"``: sort by mean intensity (descending).
        - ``"id"``: sort by the string form of the sample's index value
          (deterministic).
        - ``None``: keep the order encountered in the MaldiSet.
    label_map : dict, optional
        Mapping from raw group label to display name, merged over the
        default ``DEFAULT_LABEL_MAP``; labels missing from both are
        stringified as-is.
    title : str, optional
        Figure title. Defaults to ``f"Pseudogel: {antibiotic}"`` when
        omitted.
    show : bool, default=True
        Show the figure at the end (via the shared ``show_with_warning``
        helper).
    sort_by_intensity : bool, optional
        *Deprecated.* Use ``sort_by=`` instead. Retained for
        backwards-compatibility; ``True`` maps to ``sort_by="intensity"``
        and ``False`` maps to ``sort_by=None``.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : ndarray of Axes

    Raises
    ------
    ValueError
        If the antibiotic column is not defined, if no sample has a
        non-missing label for it, if a region has min_mz > max_mz, if no
        m/z values lie within a specified region, or if ``sort_by`` is not
        one of the recognised values.
    """
    import matplotlib.pyplot as plt

    from ..filters import SpeciesFilter

    if sort_by_intensity is not None:
        warnings.warn(
            "plot_pseudogel(sort_by_intensity=...) is deprecated; use "
            "sort_by='intensity'|'id'|None instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        sort_by = "intensity" if sort_by_intensity else None
    if sort_by not in (None, "intensity", "id"):
        raise ValueError(
            f"sort_by must be 'intensity', 'id', or None; got {sort_by!r}."
        )

    if antibiotic is None:
        antibiotic = dataset.antibiotics[0] if dataset.antibiotics else None
    if antibiotic is None:
        raise ValueError("Antibiotic column not defined.")

    if species is not None:
        dataset = dataset.filter(SpeciesFilter(species))

    X = dataset.X.copy()
    y = dataset.get_y_single(antibiotic)
    X = _apply_region_filter(X, regions)

    # groupby drops missing labels by default, so this can be empty.
    groups = y.groupby(y).groups
    n_groups = len(groups)
    if n_groups == 0:
        # plt.subplots needs at least one row; fail with a clear message.
        raise ValueError(
            f"No labelled samples to plot for antibiotic {antibiotic!r}."
        )
    if figsize is None:
        figsize = (14.0, 2.5 * n_groups)

    fig, axes = plt.subplots(
        n_groups, 1, figsize=figsize, sharex=True, constrained_layout=True
    )
    if n_groups == 1:
        # A single-row grid yields a bare Axes; normalise to an array.
        axes = np.asarray([axes])

    cmap_obj = plt.get_cmap(cmap).copy()
    # NaN cells (e.g. blank separator columns between regions) render white.
    cmap_obj.set_bad(color="white", alpha=1.0)

    # Map user-supplied vmin/vmax from raw-intensity to display units
    # so they behave consistently with log_scale.
    display_vmin = np.log1p(vmin) if (log_scale and vmin is not None) else vmin
    display_vmax = np.log1p(vmax) if (log_scale and vmax is not None) else vmax

    merged_label_map: dict = dict(DEFAULT_LABEL_MAP)
    if label_map:
        merged_label_map.update(label_map)

    ordered_items = [(lab, groups[lab]) for lab in order_labels(list(groups))]
    im = None
    for ax, (label, idx) in zip(axes, ordered_items, strict=True):
        display_label = merged_label_map.get(label, str(label))
        im = _render_pseudogel_group(
            ax,
            X.loc[idx].to_numpy(),
            display_label,
            log_scale,
            sort_by,
            cmap_obj,
            display_vmin,
            display_vmax,
            sample_ids=list(idx),
        )

    _set_pseudogel_xaxis(axes, X)
    if im is not None:
        cbar = fig.colorbar(im, ax=axes, orientation="vertical", pad=0.01)
        cbar.set_label("Log(intensity + 1)" if log_scale else "intensity")
    fig.suptitle(title or f"Pseudogel: {antibiotic}")

    # Same display path as plot_spectrum.
    show_with_warning(show)
    return fig, axes
def _render_pseudogel_group(
ax,
M,
label,
log_scale,
sort_by,
cmap_obj,
vmin,
vmax,
*,
sample_ids=None,
):
"""Render a single group panel in a pseudogel heatmap."""
order: np.ndarray | None = None
if sort_by == "intensity":
order = np.argsort(np.nanmean(M, axis=1))[::-1]
elif sort_by == "id" and sample_ids is not None:
order = np.argsort([str(sid) for sid in sample_ids])
if order is not None:
M = M[order]
if log_scale:
M = np.log1p(M)
im = ax.imshow(
M,
aspect="auto",
interpolation="nearest",
cmap=cmap_obj,
vmin=vmin,
vmax=vmax,
)
ax.set_ylabel(f"{label}\n(n={M.shape[0]})", rotation=0, ha="right", va="center")
ax.set_yticks([])
return im
def _apply_region_filter(
X: pd.DataFrame,
regions: tuple[float, float] | list[tuple[float, float]] | None,
) -> pd.DataFrame:
"""Filter feature matrix to specified m/z regions."""
if regions is None:
return X
if (
isinstance(regions, tuple)
and len(regions) == 2
and not isinstance(regions[0], (tuple, list))
):
regions = [regions]
mz_values = X.columns.astype(float)
region_dfs = []
for min_mz, max_mz in regions:
if min_mz > max_mz:
raise ValueError(f"Invalid region: min_mz ({min_mz}) > max_mz ({max_mz})")
mask = (mz_values >= min_mz) & (mz_values <= max_mz)
if not mask.any():
raise ValueError(f"No m/z values found in region ({min_mz}, {max_mz})")
region_dfs.append(X.iloc[:, mask])
if len(region_dfs) < len(regions):
blank_col = pd.DataFrame(
np.nan, index=X.index, columns=[f"_blank_{len(region_dfs)}"]
)
region_dfs.append(blank_col)
return pd.concat(region_dfs, axis=1)
def _set_pseudogel_xaxis(axes: np.ndarray, X: pd.DataFrame) -> None:
"""Set x-axis ticks and labels for pseudogel plot."""
n_ticks = min(10, X.shape[1])
xticks = np.linspace(0, X.shape[1] - 1, n_ticks, dtype=int)
xticklabels = []
for i in xticks:
col_name = str(X.columns[i])
if col_name.startswith("_blank_"):
xticklabels.append("")
else:
xticklabels.append(col_name)
axes[-1].set_xticks(xticks)
axes[-1].set_xticklabels(xticklabels, rotation=90)
axes[-1].set_xlabel("m/z (binned)")