Source code for maldiamrkit.similarity.plots

"""Visualizations for spectral similarity analysis."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from ..visualization._common import show_with_warning

if TYPE_CHECKING:
    import matplotlib.pyplot as plt


# Known theoretical upper bounds for spectral-distance metrics.  When
# the caller passes `metric=...`, the colorbar is clamped to these
# bounds (unless explicit ``vmin``/``vmax`` are given) so the colour
# scale becomes comparable across heatmaps computed with the same
# metric.  Metrics without a finite upper bound (wasserstein, dtw) are
# omitted.
_METRIC_BOUNDS: dict[str, tuple[float, float]] = {
    "cosine": (0.0, 1.0),
    "pearson": (0.0, 2.0),
    "spectral_contrast_angle": (0.0, float(np.pi / 2)),
}


def _dynamic_heatmap_figsize(n: int) -> tuple[float, float]:
    r"""Scale the figure side with matrix size, capped at 16\"."""
    side = min(16.0, 4.0 + 0.1 * n)
    return (side, side)


[docs] def plot_distance_heatmap( distance_matrix: np.ndarray, labels: list[str] | np.ndarray | None = None, *, metric: str | None = None, cmap: str = "viridis", ax: plt.Axes | None = None, title: str | None = None, figsize: tuple[float, float] | None = None, annot: bool | None = None, cluster: bool = False, vmin: float | None = None, vmax: float | None = None, cbar_label: str = "distance", show: bool = True, ) -> tuple[plt.Figure, plt.Axes]: """Plot a pairwise distance matrix as a heatmap. Parameters ---------- distance_matrix : ndarray of shape (n, n) Symmetric distance matrix. labels : list of str, ndarray, or None, default=None Tick labels for rows and columns. metric : str, optional Name of the distance metric used (e.g. ``"cosine"``, ``"pearson"``, ``"spectral_contrast_angle"``). When given and recognised, the colourbar limits are clamped to the metric's theoretical bounds so heatmaps computed with the same metric share a comparable colour scale. Explicit ``vmin``/``vmax`` always win. cmap : str, default="viridis" Matplotlib / seaborn colormap name. ax : Axes or None, default=None Pre-existing axes. If ``None``, a new figure and axes are created. title : str or None, default=None Plot title. Defaults to ``"Pairwise distance"`` (including the metric name, if provided). figsize : tuple of float, optional Figure size in inches. When ``None``, scales with ``n`` (``side = min(16, 4 + 0.1 * n)``). Only used when ``ax`` is ``None``. annot : bool, optional When ``True``, annotate each cell with its distance value. When ``None`` (default), annotate iff the matrix is small (``n ≤ 15``). cluster : bool, default=False When ``True``, reorder rows and columns via hierarchical clustering so similar samples group visually. Labels reorder accordingly. vmin, vmax : float, optional Explicit colourbar limits. Override any metric-derived bounds. cbar_label : str, default="distance" Label drawn on the colourbar. show : bool, default=True Call ``plt.show()`` at the end. Returns ------- fig : matplotlib.figure.Figure ax : matplotlib.axes.Axes """ import matplotlib.pyplot as plt import seaborn as sns D = np.asarray(distance_matrix, dtype=float) n = D.shape[0] # Clustering reorder (before figsize / annot decisions). if cluster: from scipy.cluster.hierarchy import leaves_list from .clustering import hierarchical_clustering linkage = hierarchical_clustering(D, method="average") order = leaves_list(linkage) D = D[np.ix_(order, order)] if labels is not None: labels = np.asarray(labels)[order] # Metric-derived colourbar bounds; user-supplied vmin/vmax win. if metric is not None and metric in _METRIC_BOUNDS: lo, hi = _METRIC_BOUNDS[metric] if vmin is None: vmin = lo if vmax is None: vmax = hi if annot is None: annot = n <= 15 if ax is None: if figsize is None: figsize = _dynamic_heatmap_figsize(n) fig, ax = plt.subplots(figsize=figsize) else: fig = ax.get_figure() tick_labels = labels if labels is not None else False sns.heatmap( D, ax=ax, cmap=cmap, xticklabels=tick_labels, yticklabels=tick_labels, square=True, vmin=vmin, vmax=vmax, annot=annot, fmt=".2f" if annot else "", cbar_kws={"label": cbar_label, "shrink": 0.8}, ) if title is None: title = "Pairwise distance" if metric: title = f"{title} ({metric})" ax.set_title(title) show_with_warning(show) return fig, ax
[docs] def plot_dendrogram( linkage_matrix: np.ndarray, labels: list[str] | None = None, *, ax: plt.Axes | None = None, title: str | None = None, figsize: tuple[float, float] = (10, 6), leaf_rotation: float = 90.0, color_threshold: float | None = None, truncate_mode: str | None = None, p: int = 30, show: bool = True, ) -> tuple[plt.Figure, plt.Axes]: """Plot a dendrogram from a hierarchical clustering linkage matrix. Parameters ---------- linkage_matrix : ndarray of shape (n - 1, 4) Linkage matrix from :func:`~maldiamrkit.similarity.hierarchical_clustering`. labels : list of str or None, default=None Leaf labels. ax : Axes or None, default=None Pre-existing axes. title : str or None, default=None Plot title. Defaults to ``"Hierarchical clustering dendrogram"``. figsize : tuple of float, default=(10, 6) Figure size in inches (used only when ``ax`` is ``None``). leaf_rotation : float, default=90.0 Rotation (in degrees) of leaf labels along the bottom axis. color_threshold : float, optional Colour threshold forwarded to scipy's ``dendrogram``. Clusters below this threshold share a colour. When ``None`` (default), scipy chooses ``0.7 * max(linkage[:, 2])``. truncate_mode : {"lastp", "level", None}, optional Forwarded to scipy's ``dendrogram`` to collapse deep branches. Essential for large trees; pair with ``p``. p : int, default=30 Number of leaves / merges to keep when ``truncate_mode`` is set. show : bool, default=True Call ``plt.show()`` at the end. Returns ------- fig : matplotlib.figure.Figure ax : matplotlib.axes.Axes """ import matplotlib.pyplot as plt from scipy.cluster.hierarchy import dendrogram if ax is None: fig, ax = plt.subplots(figsize=figsize) else: fig = ax.get_figure() kwargs: dict = { "labels": labels, "ax": ax, "leaf_rotation": leaf_rotation, } if color_threshold is not None: kwargs["color_threshold"] = color_threshold if truncate_mode is not None: kwargs["truncate_mode"] = truncate_mode kwargs["p"] = p dendrogram(linkage_matrix, **kwargs) ax.set_title(title or "Hierarchical clustering dendrogram") show_with_warning(show) return fig, ax