Source code for glyphx.clustermap

"""
GlyphX clustermap -- hierarchically clustered heatmap with dendrograms.

The clustermap is Seaborn's most distinctive chart in bioinformatics and
machine learning.  Seaborn's ``sns.clustermap()`` requires scipy; GlyphX
implements the full pipeline (hierarchical clustering, dendrogram layout,
heatmap rendering) in pure NumPy -- no scipy, no matplotlib required.

    from glyphx.clustermap import clustermap

    fig = clustermap(
        df,                      # DataFrame or 2-D array
        cmap="viridis",
        row_cluster=True,
        col_cluster=True,
        title="Gene Expression",
    )
    fig.show()
"""
from __future__ import annotations

import math
import numpy as np
from typing import Any

from .figure   import Figure
from .series   import HeatmapSeries
from .colormaps import colormap_colors, apply_colormap
from .utils    import svg_escape, _format_tick


# ---------------------------------------------------------------------------
# Pure-NumPy hierarchical clustering (average linkage, Euclidean distance)
# ---------------------------------------------------------------------------

def _pdist(X: np.ndarray) -> np.ndarray:
    """Pairwise Euclidean distance matrix (nxn)."""
    n = len(X)
    D = np.zeros((n, n))
    for i in range(n):
        diff = X[i] - X   # broadcast
        D[i] = np.sqrt((diff ** 2).sum(axis=1))
    return D


def _average_linkage(D: np.ndarray) -> list[tuple]:
    """
    UPGMA (average linkage) hierarchical clustering.

    Returns a linkage list compatible with the Scipy/Matplotlib convention:
    ``[(left_id, right_id, distance, cluster_size), ...]``.
    """
    n    = len(D)
    dist = D.copy()
    np.fill_diagonal(dist, np.inf)

    # Cluster membership: initially each point is its own cluster
    members: list[list[int]] = [[i] for i in range(n)]
    active  = list(range(n))
    linkage = []
    next_id = n

    while len(active) > 1:
        # Find closest pair among active clusters
        min_d = np.inf
        ci, cj = -1, -1
        for ii in range(len(active)):
            for jj in range(ii + 1, len(active)):
                a, b = active[ii], active[jj]
                d    = dist[a, b]
                if d < min_d:
                    min_d = d
                    ci, cj = ii, jj

        ai, aj = active[ci], active[cj]
        merged  = members[ai] + members[aj]
        linkage.append((ai, aj, min_d, len(merged)))

        # Merge: update distances to the new cluster (average linkage)
        new_row = np.full(dist.shape[0] + 1, np.inf)
        for ak in active:
            if ak == ai or ak == aj:
                continue
            d_new = (dist[ai, ak] * len(members[ai]) +
                     dist[aj, ak] * len(members[aj])) / len(merged)
            new_row[ak] = d_new

        # Grow distance matrix
        old_size  = dist.shape[0]
        new_dist  = np.full((old_size + 1, old_size + 1), np.inf)
        new_dist[:old_size, :old_size] = dist
        new_dist[next_id, :old_size]   = new_row[:old_size]
        new_dist[:old_size, next_id]   = new_row[:old_size]
        dist = new_dist

        members.append(merged)
        active.remove(ai)
        active.remove(aj)
        active.append(next_id)
        next_id += 1

    return linkage


def _leaf_order(linkage: list[tuple], n_leaves: int) -> list[int]:
    """
    Traverse the linkage tree and return the leaf order (left-to-right DFS).
    """
    n       = n_leaves
    children: dict[int, tuple[int, int]] = {}
    for i, (a, b, *_) in enumerate(linkage):
        children[n + i] = (int(a), int(b))

    root = n + len(linkage) - 1

    def _dfs(node: int) -> list[int]:
        if node < n:
            return [node]
        l, r = children[node]
        return _dfs(l) + _dfs(r)

    return _dfs(root)


def _dendrogram_svg(
    linkage:   list[tuple],
    n_leaves:  int,
    leaf_order: list[int],
    orient:    str,    # "left" or "top"
    x0: float, y0: float,
    plot_w: float, plot_h: float,
    color: str = "#555",
    line_width: float = 1.2,
) -> str:
    """
    Render a dendrogram as SVG polylines.

    ``orient="top"`` draws the tree growing downward (column dendrogram).
    ``orient="left"`` draws the tree growing rightward (row dendrogram).
    """
    n         = n_leaves
    pos_map   = {leaf: i for i, leaf in enumerate(leaf_order)}
    max_height = max(d for _, _, d, _ in linkage) if linkage else 1.0
    cluster_pos: dict[int, float] = {i: i + 0.5 for i in range(n)}

    elements: list[str] = []

    def _leaf_px(leaf: int) -> float:
        """Pixel position of a leaf along the axis."""
        rank = pos_map.get(leaf, leaf)
        if orient == "top":
            return x0 + (rank / n) * plot_w
        else:
            return y0 + (rank / n) * plot_h

    def _height_px(height: float) -> float:
        """Pixel position for a given linkage height."""
        norm = height / max_height if max_height > 0 else 0
        if orient == "top":
            return y0 + norm * plot_h
        else:
            return x0 + norm * plot_w

    current_id = n
    for left, right, height, _ in linkage:
        h_px = _height_px(height)

        lp = cluster_pos.get(left,  left  if left  < n else left)
        rp = cluster_pos.get(right, right if right < n else right)

        lx = _leaf_px(lp) if left  < n else cluster_pos.get(left,  0)
        rx = _leaf_px(rp) if right < n else cluster_pos.get(right, 0)

        # Recalculate using leaf positions correctly
        def _node_px(node_id: int) -> float:
            if node_id < n:
                return _leaf_px(node_id)
            return cluster_pos.get(node_id, 0)

        lp_px = _node_px(left)
        rp_px = _node_px(right)
        mid   = (lp_px + rp_px) / 2

        if orient == "top":
            # Horizontal segments at height h_px, vertical connectors
            lh_px = cluster_pos.get(left + 10000, y0)   # previous height
            rh_px = cluster_pos.get(right + 10000, y0)
            elements.append(
                f'<polyline points="{lp_px:.1f},{lh_px:.1f} '
                f'{lp_px:.1f},{h_px:.1f} '
                f'{rp_px:.1f},{h_px:.1f} '
                f'{rp_px:.1f},{rh_px:.1f}" '
                f'fill="none" stroke="{color}" stroke-width="{line_width}"/>'
            )
            cluster_pos[current_id]           = mid
            cluster_pos[current_id + 10000]   = h_px
        else:
            # orient = "left"
            lh_px = cluster_pos.get(left  + 10000, x0)
            rh_px = cluster_pos.get(right + 10000, x0)
            elements.append(
                f'<polyline points="{lh_px:.1f},{lp_px:.1f} '
                f'{h_px:.1f},{lp_px:.1f} '
                f'{h_px:.1f},{rp_px:.1f} '
                f'{rh_px:.1f},{rp_px:.1f}" '
                f'fill="none" stroke="{color}" stroke-width="{line_width}"/>'
            )
            cluster_pos[current_id]           = mid
            cluster_pos[current_id + 10000]   = h_px

        current_id += 1

    return "\n".join(elements)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

[docs] def clustermap( data, row_labels: list[str] | None = None, col_labels: list[str] | None = None, cmap: str = "viridis", row_cluster: bool = True, col_cluster: bool = True, standard_scale: str | None = None, # "row", "col", or None z_score: str | None = None, # "row", "col", or None show_values: bool = False, figsize: tuple[int,int] = (720, 640), title: str = "", dendrogram_ratio: float = 0.15, line_color: str = "#555", theme: str = "default", ) -> Figure: """ Hierarchically clustered heatmap with row and column dendrograms. Equivalent to ``seaborn.clustermap()`` but implemented in pure NumPy -- no scipy, no matplotlib required. Args: data: 2-D numeric array or pandas DataFrame. row_labels: Labels for rows (inferred from DataFrame index if None). col_labels: Labels for columns (inferred from DataFrame columns if None). cmap: Colormap name (default ``"viridis"``). row_cluster: Cluster and reorder rows (default True). col_cluster: Cluster and reorder columns (default True). standard_scale: ``"row"`` or ``"col"`` -- scale each row/column to [0,1]. z_score: ``"row"`` or ``"col"`` -- z-score normalise each row/column. show_values: Overlay the numeric value in each cell. figsize: ``(width, height)`` in pixels. title: Chart title. dendrogram_ratio: Fraction of figure width/height used by dendrograms. line_color: Dendrogram line colour. theme: GlyphX theme name. Returns: :class:`~glyphx.Figure` containing the clustered heatmap and dendrograms rendered as SVG. Example:: import pandas as pd from glyphx.clustermap import clustermap df = pd.read_csv("gene_expression.csv", index_col=0) fig = clustermap(df, cmap="coolwarm", z_score="row", title="Gene Expression Heatmap") fig.show() fig.save("clustermap.html") """ import pandas as pd # Coerce to numpy if isinstance(data, pd.DataFrame): if row_labels is None: row_labels = [str(i) for i in data.index] if col_labels is None: col_labels = [str(c) for c in data.columns] mat = data.values.astype(float) else: mat = np.asarray(data, dtype=float) n_rows, n_cols = mat.shape if row_labels is None: row_labels = [str(i) for i in range(n_rows)] if col_labels is None: col_labels = [str(j) for j in range(n_cols)] # Preprocessing if z_score == "row": mu = mat.mean(axis=1, keepdims=True) sig = mat.std(axis=1, keepdims=True) + 1e-10 mat = (mat - mu) / sig elif z_score == "col": mu = mat.mean(axis=0, keepdims=True) sig = mat.std(axis=0, keepdims=True) + 1e-10 mat = (mat - mu) / sig if standard_scale == "row": lo = mat.min(axis=1, keepdims=True) hi = mat.max(axis=1, keepdims=True) + 1e-10 mat = (mat - lo) / (hi - lo) elif standard_scale == "col": lo = mat.min(axis=0, keepdims=True) hi = mat.max(axis=0, keepdims=True) + 1e-10 mat = (mat - lo) / (hi - lo) # Clustering row_order = list(range(n_rows)) col_order = list(range(n_cols)) row_linkage: list = [] col_linkage: list = [] if row_cluster and n_rows > 1: D = _pdist(mat) row_linkage = _average_linkage(D) row_order = _leaf_order(row_linkage, n_rows) if col_cluster and n_cols > 1: D = _pdist(mat.T) col_linkage = _average_linkage(D) col_order = _leaf_order(col_linkage, n_cols) # Reorder matrix and labels mat_r = mat[np.ix_(row_order, col_order)] row_lbl_r = [row_labels[i] for i in row_order] col_lbl_r = [col_labels[j] for j in col_order] # Build figure as raw SVG (Figure wraps it as axis-free) W, H = figsize dend_w = int(W * dendrogram_ratio) # row dendrogram width (left) dend_h = int(H * dendrogram_ratio) # col dendrogram height (top) title_h = 32 if title else 0 label_w = max(max(len(l) for l in row_lbl_r) * 6, 60) label_h = max(max(len(l) for l in col_lbl_r) * 6, 40) colorbar_w = 18 heat_x = dend_w + label_w heat_y = title_h + dend_h heat_w = W - heat_x - colorbar_w - 8 heat_h = H - heat_y - label_h cell_w = heat_w / n_cols cell_h = heat_h / n_rows vmin, vmax = float(mat_r.min()), float(mat_r.max()) span = vmax - vmin or 1.0 from .themes import themes as _themes theme_dict = _themes.get(theme, _themes["default"]) bg = theme_dict.get("background", "#fff") tc = theme_dict.get("text_color", "#000") font = theme_dict.get("font", "sans-serif") parts: list[str] = [ f'<svg width="{W}" height="{H}" xmlns="http://www.w3.org/2000/svg" ' f'viewBox="0 0 {W} {H}">', f'<rect width="{W}" height="{H}" fill="{bg}"/>', ] # Title if title: parts.append( f'<text x="{W//2}" y="22" text-anchor="middle" ' f'font-size="15" font-weight="bold" ' f'font-family="{font}" fill="{tc}">{svg_escape(title)}</text>' ) # -- Heatmap cells ------------------------------------------------- for ri in range(n_rows): for ci in range(n_cols): v = float(mat_r[ri, ci]) norm = (v - vmin) / span col = apply_colormap(norm, cmap) cx = heat_x + ci * cell_w cy = heat_y + ri * cell_h parts.append( f'<rect x="{cx:.1f}" y="{cy:.1f}" ' f'width="{cell_w:.1f}" height="{cell_h:.1f}" ' f'fill="{col}" ' f'data-row="{svg_escape(row_lbl_r[ri])}" ' f'data-col="{svg_escape(col_lbl_r[ci])}" ' f'data-value="{v:.3g}"/>' ) if show_values and cell_w > 24 and cell_h > 14: txt_col = "#fff" if norm < 0.6 else "#000" parts.append( f'<text x="{cx + cell_w/2:.1f}" y="{cy + cell_h/2 + 4:.1f}" ' f'text-anchor="middle" font-size="9" ' f'font-family="{font}" fill="{txt_col}">' f'{_format_tick(v)}</text>' ) # -- Row labels (right side of row dendrogram, left of heatmap) --- for ri, lbl in enumerate(row_lbl_r): cy = heat_y + ri * cell_h + cell_h / 2 parts.append( f'<text x="{heat_x - 4:.1f}" y="{cy + 4:.1f}" ' f'text-anchor="end" font-size="10" ' f'font-family="{font}" fill="{tc}">{svg_escape(lbl)}</text>' ) # -- Column labels (below heatmap) -------------------------------- for ci, lbl in enumerate(col_lbl_r): cx = heat_x + ci * cell_w + cell_w / 2 cy = heat_y + heat_h + 4 parts.append( f'<text x="{cx:.1f}" y="{cy:.1f}" ' f'text-anchor="start" font-size="10" ' f'font-family="{font}" fill="{tc}" ' f'transform="rotate(45,{cx:.1f},{cy:.1f})">' f'{svg_escape(lbl)}</text>' ) # -- Row dendrogram (left panel, growing rightward) --------------- if row_cluster and row_linkage: parts.append(_dendrogram_svg( row_linkage, n_rows, list(range(n_rows)), orient="left", x0=dend_w * 0.05, y0=heat_y, plot_w=dend_w * 0.90, plot_h=heat_h, color=line_color, )) # -- Column dendrogram (top panel, growing downward) -------------- if col_cluster and col_linkage: parts.append(_dendrogram_svg( col_linkage, n_cols, list(range(n_cols)), orient="top", x0=heat_x, y0=title_h + dend_h * 0.05, plot_w=heat_w, plot_h=dend_h * 0.90, color=line_color, )) # -- Colorbar ----------------------------------------------------- cb_x = heat_x + heat_w + 6 cb_y = heat_y n_steps = 50 step_h = heat_h / n_steps for k in range(n_steps): norm = 1 - k / n_steps col = apply_colormap(norm, cmap) parts.append( f'<rect x="{cb_x}" y="{cb_y + k * step_h:.1f}" ' f'width="{colorbar_w - 2}" height="{step_h + 0.5:.1f}" ' f'fill="{col}"/>' ) parts.append( f'<text x="{cb_x + colorbar_w}" y="{cb_y + 10}" ' f'font-size="9" font-family="{font}" fill="{tc}">' f'{_format_tick(vmax)}</text>' ) parts.append( f'<text x="{cb_x + colorbar_w}" y="{cb_y + heat_h}" ' f'font-size="9" font-family="{font}" fill="{tc}">' f'{_format_tick(vmin)}</text>' ) parts.append("</svg>") # Wrap in an axis-free Figure fig = Figure(width=W, height=H, auto_display=False, theme=theme) fig.title = "" # already in SVG # Inject raw SVG via a custom series stub class _RawSVG: x = None; y = None; label = None; color = "#000" css_class = "clustermap" def to_svg(self, ax=None, use_y2=False): return "\n".join(parts[2:-1]) fig._raw_svg = "\n".join(parts) fig._clustermap = True # Override render_svg to return our pre-built SVG _orig_render = fig.render_svg def _render_patched(): return fig._raw_svg import types fig.render_svg = types.MethodType(lambda self: self._raw_svg, fig) return fig