Source code for maldiamrkit.differential.plots

"""Visualizations for differential analysis."""

from __future__ import annotations

import warnings
from enum import Enum
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

from ..visualization._common import show_with_warning

if TYPE_CHECKING:
    import matplotlib.pyplot as plt


# Lower bound used when clipping adjusted p-values before taking -log10 in
# the volcano and Manhattan plots: a p-value of exactly 0 is plotted at
# -log10(1e-300) = 300 instead of +inf.
_VOLCANO_EPS = 1e-300


class DrugComparisonKind(str, Enum):
    """Rendering styles accepted by :func:`~maldiamrkit.differential.plot_drug_comparison`.

    Members are also :class:`str` instances, so each one compares equal to
    its plain string value (``"heatmap"`` / ``"upset"``).

    Attributes
    ----------
    heatmap : str
        Boolean ``rows x drugs`` heatmap (compact, precise positions).
    upset : str
        UpSet-style intersection plot: a bar chart of intersection sizes
        combined with a dot matrix of drug membership.
    """

    heatmap = "heatmap"
    upset = "upset"
def _annotate_top_k(ax, xs, ys, labels_like, k: int) -> None: """Place text labels at the top-k points (largest ys). ``labels_like`` is aligned with ``xs``/``ys`` and used as the text content (typically the ``mz_bin`` value). """ if k <= 0 or len(ys) == 0: return order = np.argsort(ys)[::-1][:k] for i in order: ax.annotate( f"{labels_like[i]}", xy=(xs[i], ys[i]), xytext=(3, 3), textcoords="offset points", fontsize=8, color="black", )
def plot_volcano(
    results: pd.DataFrame,
    fc_threshold: float = 1.0,
    p_threshold: float = 0.05,
    *,
    ax: plt.Axes | None = None,
    title: str | None = None,
    drug: str | None = None,
    figsize: tuple[float, float] = (8, 6),
    annotate_top_k: int | None = None,
    grid: bool = True,
    show: bool = True,
) -> tuple[plt.Figure, plt.Axes]:
    r"""Volcano plot of log2 fold change vs. -log10 adjusted p-value.

    Points are coloured by direction and significance: grey for
    non-significant, red for up in resistant
    (``fold_change > fc_threshold`` and ``adjusted_p_value <= p_threshold``),
    blue for up in susceptible (``fold_change < -fc_threshold`` and
    ``adjusted_p_value <= p_threshold``). Dashed lines mark the thresholds,
    and the legend reports the number of points in each class.

    Parameters
    ----------
    results : pd.DataFrame
        Output of :attr:`DifferentialAnalysis.results`. Must contain
        ``fold_change`` and ``adjusted_p_value`` columns.
    fc_threshold : float, default=1.0
        Absolute log2 fold-change threshold (drawn as vertical dashed
        lines at :math:`\pm` ``fc_threshold``).
    p_threshold : float, default=0.05
        Adjusted p-value threshold (drawn as a horizontal dashed line at
        ``-log10(p_threshold)``).
    ax : Axes or None, default=None
        Pre-existing axes. If ``None``, a new figure and axes are created.
    title : str or None, default=None
        Plot title. Defaults to ``"Volcano plot"``; if ``drug`` is given,
        the default becomes ``f"Volcano plot - {drug}"``.
    drug : str or None, default=None
        Drug name appended to the default title. Ignored when ``title``
        is explicitly provided.
    figsize : tuple of float, default=(8, 6)
        Figure size in inches (only used when ``ax`` is ``None``).
    annotate_top_k : int, optional
        If given, label the ``k`` most significant peaks with their
        ``mz_bin`` value. Requires an ``mz_bin`` column in ``results``.
    grid : bool, default=True
        Draw a faint background grid.
    show : bool, default=True
        Forwarded to ``show_with_warning`` at the end.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If ``results`` lacks ``fold_change`` or ``adjusted_p_value``.

    Warns
    -----
    UserWarning
        If ``annotate_top_k`` is requested but ``results`` has no
        ``mz_bin`` column; the plot is still drawn, without labels.
    """
    import matplotlib.pyplot as plt

    required = {"fold_change", "adjusted_p_value"}
    missing = required - set(results.columns)
    if missing:
        raise ValueError(f"'results' is missing required columns: {sorted(missing)}")

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    fc = results["fold_change"].to_numpy(dtype=float)
    adj_p = results["adjusted_p_value"].to_numpy(dtype=float)
    # Clip so that p == 0 maps to a finite y value instead of +inf.
    neg_log10_p = -np.log10(np.clip(adj_p, _VOLCANO_EPS, 1.0))

    sig = adj_p <= p_threshold
    up_r = sig & (fc > fc_threshold)
    up_s = sig & (fc < -fc_threshold)
    ns = ~(up_r | up_s)

    # Draw order matters: grey background first, highlighted classes on top.
    groups = (
        (ns, "NS", "lightgrey", 10, 0.6),
        (up_s, "Up in S", "#3b82f6", 14, 0.85),
        (up_r, "Up in R", "#ef4444", 14, 0.85),
    )
    for mask, name, color, size, alpha in groups:
        ax.scatter(
            fc[mask],
            neg_log10_p[mask],
            s=size,
            color=color,
            label=f"{name} (n={int(mask.sum())})",
            alpha=alpha,
        )

    ax.axhline(-np.log10(p_threshold), color="black", linestyle="--", linewidth=0.8)
    ax.axvline(fc_threshold, color="black", linestyle="--", linewidth=0.8)
    ax.axvline(-fc_threshold, color="black", linestyle="--", linewidth=0.8)

    if annotate_top_k:
        if "mz_bin" in results.columns:
            labels_like = results["mz_bin"].to_numpy()
            _annotate_top_k(ax, fc, neg_log10_p, labels_like, annotate_top_k)
        else:
            # Labels were explicitly requested; do not drop them silently.
            warnings.warn(
                "plot_volcano: 'annotate_top_k' was requested but 'results' "
                "has no 'mz_bin' column; no peaks were labelled.",
                UserWarning,
                stacklevel=2,
            )

    ax.set_xlabel("log2 fold change (R / S)")
    ax.set_ylabel(r"$-\log_{10}$(adjusted p-value)")
    if title is None:
        title = "Volcano plot" + (f" - {drug}" if drug else "")
    ax.set_title(title)
    ax.legend(loc="best", frameon=False)
    if grid:
        ax.grid(True, alpha=0.3)

    show_with_warning(show)
    return fig, ax
def plot_manhattan(
    results: pd.DataFrame,
    p_threshold: float = 0.05,
    *,
    ax: plt.Axes | None = None,
    title: str | None = None,
    drug: str | None = None,
    figsize: tuple[float, float] = (12, 4),
    annotate_top_k: int | None = None,
    grid: bool = True,
    show: bool = True,
) -> tuple[plt.Figure, plt.Axes]:
    """Manhattan plot along the m/z axis.

    x-axis is the numeric m/z bin value; y-axis is
    ``-log10(adjusted_p_value)``. Points with
    ``adjusted_p_value <= p_threshold`` are highlighted in red, and the
    legend reports per-class counts.

    Parameters
    ----------
    results : pd.DataFrame
        Output of :attr:`DifferentialAnalysis.results`. Must contain
        ``mz_bin`` and ``adjusted_p_value`` columns. ``mz_bin`` values
        that cannot be coerced to float are excluded.
    p_threshold : float, default=0.05
        Adjusted p-value threshold.
    ax : Axes or None, default=None
        Pre-existing axes. If ``None``, a new figure and axes are created.
    title : str or None, default=None
        Plot title. Defaults to ``"Manhattan plot"``; if ``drug`` is
        given, the default becomes ``f"Manhattan plot - {drug}"``.
    drug : str or None, default=None
        Drug name appended to the default title. Ignored when ``title``
        is explicitly provided.
    figsize : tuple of float, default=(12, 4)
        Figure size in inches (only used when ``ax`` is ``None``).
    annotate_top_k : int, optional
        If given, label the ``k`` most significant peaks with their
        original (un-coerced) ``mz_bin`` value.
    grid : bool, default=True
        Draw a faint background grid.
    show : bool, default=True
        Forwarded to ``show_with_warning`` at the end.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If ``results`` lacks ``mz_bin`` or ``adjusted_p_value``.
    """
    import matplotlib.pyplot as plt

    required = {"mz_bin", "adjusted_p_value"}
    missing = required - set(results.columns)
    if missing:
        raise ValueError(f"'results' is missing required columns: {sorted(missing)}")

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Force a plain float array (missing -> NaN) so np.isnan also works when
    # the column uses a nullable/extension dtype.
    mz_values = pd.to_numeric(results["mz_bin"], errors="coerce").to_numpy(
        dtype=float, na_value=np.nan
    )
    adj_p = results["adjusted_p_value"].to_numpy(dtype=float)

    # Keep only rows whose m/z bin is numeric; all arrays stay aligned.
    valid = ~np.isnan(mz_values)
    raw_bins = results["mz_bin"].to_numpy()[valid]
    mz_values = mz_values[valid]
    adj_p = adj_p[valid]

    # Clip so that p == 0 maps to a finite y value instead of +inf.
    neg_log10_p = -np.log10(np.clip(adj_p, _VOLCANO_EPS, 1.0))
    sig = adj_p <= p_threshold

    ax.scatter(
        mz_values[~sig],
        neg_log10_p[~sig],
        s=8,
        color="#6b7280",
        alpha=0.6,
        label=f"NS (n={int((~sig).sum())})",
    )
    ax.scatter(
        mz_values[sig],
        neg_log10_p[sig],
        s=14,
        color="#ef4444",
        alpha=0.9,
        label=f"Significant (n={int(sig.sum())})",
    )
    ax.axhline(-np.log10(p_threshold), color="black", linestyle="--", linewidth=0.8)

    if annotate_top_k:
        _annotate_top_k(ax, mz_values, neg_log10_p, raw_bins, annotate_top_k)

    ax.set_xlabel("m/z")
    ax.set_ylabel(r"$-\log_{10}$(adjusted p-value)")
    if title is None:
        title = "Manhattan plot" + (f" - {drug}" if drug else "")
    ax.set_title(title)
    ax.legend(loc="best", frameon=False)
    if grid:
        ax.grid(True, alpha=0.3)

    show_with_warning(show)
    return fig, ax
def _draw_no_peaks_placeholder(ax: plt.Axes, title: str | None) -> None:
    """Turn ``ax`` into a blank panel with a centred "No significant peaks" note."""
    ax.text(
        0.5,
        0.5,
        "No significant peaks",
        ha="center",
        va="center",
        transform=ax.transAxes,
    )
    ax.set_axis_off()
    ax.set_title(title or "Drug comparison")


def _plot_drug_comparison_heatmap(
    comparison_df: pd.DataFrame,
    ax: plt.Axes | None,
    title: str | None,
    figsize: tuple[float, float],
) -> tuple[plt.Figure, plt.Axes]:
    """Render the boolean comparison as a seaborn binary heatmap.

    Parameters
    ----------
    comparison_df : pd.DataFrame
        Significance matrix (rows = m/z bins, columns = drugs). Values are
        coerced to ``bool`` before plotting.
    ax : Axes or None
        Axes to draw on; a new figure is created when ``None``.
    title : str or None
        Plot title; falls back to ``"Drug comparison"``.
    figsize : tuple of float
        Figure size in inches (only used when ``ax`` is ``None``).

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Nothing to draw: show the same placeholder as the UpSet view.
    if comparison_df.empty:
        _draw_no_peaks_placeholder(ax, title)
        return fig, ax

    # Coerce through bool so every cell is exactly 0 or 1, matching the
    # two-colour map and the behaviour of the UpSet view.
    data = comparison_df.astype(bool).astype(int)

    # Append per-drug significant-peak counts to the column labels so
    # the reader sees both the presence pattern AND how many peaks each
    # drug contributes. Pairing positionally keeps this correct even if
    # two columns share a label.
    xticklabels = [
        f"{col} (n={int(total)})"
        for col, total in zip(data.columns, data.sum(axis=0), strict=True)
    ]

    sns.heatmap(
        data,
        ax=ax,
        cmap=["#f3f4f6", "#ef4444"],
        # Pin the colour scale: with data-derived limits, a matrix that is
        # all ones gets vmin == vmax and is drawn entirely in the "off"
        # colour.
        vmin=0,
        vmax=1,
        cbar=False,
        linewidths=0.25,
        linecolor="white",
        xticklabels=xticklabels,
        yticklabels=data.shape[0] <= 60,
    )
    ax.set_xlabel("Drug")
    ax.set_ylabel("m/z bin")
    ax.set_title(title or "Drug comparison")
    return fig, ax


def _plot_drug_comparison_upset(
    comparison_df: pd.DataFrame,
    title: str | None,
    figsize: tuple[float, float],
) -> tuple[plt.Figure, plt.Axes]:
    """Render an UpSet-style intersection plot using matplotlib only.

    Layout follows the ``UpSetPlot`` convention (Nothman, 2018):

    - top-right: intersection-size bar chart
    - bottom-left: per-set totals bar chart (horizontal, right-to-left)
    - bottom-right: dot matrix of set membership, with alternating row
      shading for readability

    The returned Axes is the intersection-size bar chart at the top. When
    the matrix is empty, or no row is significant for any drug, a single
    Axes with a "No significant peaks" note is returned instead.

    Parameters
    ----------
    comparison_df : pd.DataFrame
        Significance matrix (rows = m/z bins, columns = drugs). Values are
        coerced to ``bool``.
    title : str or None
        Plot title; falls back to ``"Drug comparison"``.
    figsize : tuple of float
        Figure size in inches.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    """
    import matplotlib.pyplot as plt

    bool_df = comparison_df.astype(bool)
    drugs = list(bool_df.columns)

    sig_counts = None
    if drugs and not bool_df.empty:
        # One membership tuple per row, then count identical tuples.
        signatures = bool_df.apply(lambda row: tuple(row.values), axis=1)
        sig_counts = signatures.value_counts()
        # Rows significant for no drug are not an intersection.
        all_false = tuple(False for _ in drugs)
        if all_false in sig_counts.index:
            sig_counts = sig_counts.drop(all_false)
        sig_counts = sig_counts.sort_values(ascending=False)

    # Covers an empty matrix as well as a matrix whose rows are all False;
    # the latter would otherwise yield an empty composite figure.
    if sig_counts is None or sig_counts.empty:
        fig, ax = plt.subplots(figsize=figsize)
        _draw_no_peaks_placeholder(ax, title)
        return fig, ax

    combos = list(sig_counts.index)
    counts = sig_counts.to_numpy()
    set_totals = bool_df.sum(axis=0).to_numpy()

    # Palette close to UpSetPlot defaults
    dot_on = "#1f1f1f"
    dot_off = "#d9d9d9"
    bar_color = "#1f1f1f"
    row_shade = "#f2f2f2"

    fig = plt.figure(figsize=figsize)
    matrix_height = max(1.5, 0.45 * len(drugs))
    gs = fig.add_gridspec(
        nrows=2,
        ncols=2,
        height_ratios=[3.0, matrix_height],
        width_ratios=[1.0, max(3.0, 0.55 * len(combos))],
        hspace=0.06,
        wspace=0.04,
    )
    ax_bars = fig.add_subplot(gs[0, 1])
    ax_totals = fig.add_subplot(gs[1, 0])
    # Shared axes keep matrix columns under their bars and matrix rows
    # beside their set-size bars.
    ax_matrix = fig.add_subplot(gs[1, 1], sharex=ax_bars, sharey=ax_totals)
    ax_corner = fig.add_subplot(gs[0, 0])
    ax_corner.set_axis_off()

    # Top-right: intersection sizes, each bar labelled with its count.
    x = np.arange(len(combos))
    ax_bars.bar(x, counts, color=bar_color, width=0.55, zorder=2)
    for xi, ci in zip(x, counts, strict=True):
        ax_bars.text(
            float(xi),
            float(ci),
            str(int(ci)),
            ha="center",
            va="bottom",
            fontsize=9,
            color=bar_color,
        )
    ax_bars.set_ylabel("Intersection size")
    ax_bars.grid(axis="y", linestyle=":", linewidth=0.6, color="#bdbdbd", zorder=0)
    ax_bars.set_axisbelow(True)
    for spine in ("top", "right"):
        ax_bars.spines[spine].set_visible(False)
    ax_bars.tick_params(axis="x", bottom=False, labelbottom=False)
    ax_bars.set_title(title or "Drug comparison")

    # Bottom-left: per-drug totals, growing right-to-left.
    y_pos = np.arange(len(drugs))
    ax_totals.barh(y_pos, set_totals, color=bar_color, height=0.55, zorder=2)
    ax_totals.set_xlabel("Set size")
    ax_totals.invert_xaxis()
    ax_totals.set_yticks(y_pos)
    # Append n= counts to drug labels for parity with the heatmap view.
    ax_totals.set_yticklabels(
        [f"{d} (n={int(t)})" for d, t in zip(drugs, set_totals, strict=True)]
    )
    ax_totals.grid(axis="x", linestyle=":", linewidth=0.6, color="#bdbdbd", zorder=0)
    ax_totals.set_axisbelow(True)
    for spine in ("top", "right", "left"):
        ax_totals.spines[spine].set_visible(False)
    ax_totals.tick_params(axis="y", left=False)

    # Bottom-right: shade every other drug row for readability.
    for yi in range(0, len(drugs), 2):
        ax_matrix.axhspan(yi - 0.5, yi + 0.5, color=row_shade, zorder=0)

    for col_idx, combo in enumerate(combos):
        col_x: float = float(col_idx)
        members: list[float] = [float(yi) for yi, flag in enumerate(combo) if flag]
        non_members: list[float] = [
            float(yi) for yi, flag in enumerate(combo) if not flag
        ]
        if non_members:
            ax_matrix.scatter(
                [col_x] * len(non_members),
                non_members,
                s=80,
                color=dot_off,
                edgecolors="none",
                zorder=1,
            )
        if members:
            ax_matrix.scatter(
                [col_x] * len(members),
                members,
                s=80,
                color=dot_on,
                edgecolors="none",
                zorder=3,
            )
            # Connect the member dots of a multi-drug intersection.
            if len(members) > 1:
                ax_matrix.plot(
                    [col_x, col_x],
                    [min(members), max(members)],
                    color=dot_on,
                    linewidth=1.8,
                    zorder=2,
                )

    ax_matrix.set_xticks(x)
    ax_matrix.set_xticklabels([])
    ax_matrix.set_xlim(-0.5, len(combos) - 0.5)
    ax_matrix.set_ylim(-0.5, len(drugs) - 0.5)
    # First drug at the top; the shared y-axis flips the totals chart too.
    ax_matrix.invert_yaxis()
    for spine in ("top", "right", "bottom", "left"):
        ax_matrix.spines[spine].set_visible(False)
    ax_matrix.tick_params(axis="x", bottom=False)
    ax_matrix.tick_params(axis="y", left=False, labelleft=False)

    return fig, ax_bars
def plot_drug_comparison(
    comparison_df: pd.DataFrame,
    *,
    kind: str | DrugComparisonKind = DrugComparisonKind.heatmap,
    ax: plt.Axes | None = None,
    title: str | None = None,
    figsize: tuple[float, float] = (10, 8),
    show: bool = True,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot which m/z bins are significant for which drugs.

    Parameters
    ----------
    comparison_df : pd.DataFrame
        Boolean significance matrix from
        :meth:`DifferentialAnalysis.compare_drugs`. Index = m/z bins,
        columns = drug names, values coerced to ``bool``.
    kind : {"heatmap", "upset"} or DrugComparisonKind, default="heatmap"
        Rendering style.

        - ``"heatmap"``: compact binary heatmap of peaks x drugs. Drug
          labels show per-drug significant-peak counts.
        - ``"upset"``: UpSet-style plot showing intersection counts
          across drug combinations.
    ax : Axes or None, default=None
        Pre-existing axes (used only by ``kind="heatmap"``; ignored, with
        a ``UserWarning``, for ``"upset"``, which builds its own
        composite figure).
    title : str or None, default=None
        Plot title. Defaults to ``"Drug comparison"``.
    figsize : tuple of float, default=(10, 8)
        Figure size in inches (only used when ``ax`` is ``None``).
    show : bool, default=True
        Forwarded to ``show_with_warning`` at the end.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
        For ``kind="upset"``, the returned Axes is the intersection-size
        bar chart; the drug-membership matrix is drawn on a second Axes
        inside the same Figure.

    Raises
    ------
    ValueError
        If ``kind`` is not a valid :class:`DrugComparisonKind` value.
    """
    style = DrugComparisonKind(kind)

    if style is not DrugComparisonKind.heatmap:
        # The UpSet view is a multi-axes composite, so a caller-supplied
        # axes cannot be honoured.
        if ax is not None:
            warnings.warn(
                "plot_drug_comparison(kind='upset') creates its own composite "
                "figure; the provided 'ax' is ignored.",
                UserWarning,
                stacklevel=2,
            )
        rendered = _plot_drug_comparison_upset(
            comparison_df, title=title, figsize=figsize
        )
    else:
        rendered = _plot_drug_comparison_heatmap(
            comparison_df, ax=ax, title=title, figsize=figsize
        )

    fig, ax = rendered
    show_with_warning(show)
    return fig, ax