# Source code for glyphx.parallel_coords

"""
GlyphX ParallelCoordinatesSeries — high-dimensional data visualization.

Each row of data becomes a polyline drawn across a set of parallel
vertical axes, one axis per variable.  This beats Plotly's verbose
``go.Parcoords`` and Seaborn, which has no parallel coordinates at all.

    from glyphx import Figure
    from glyphx.parallel_coords import ParallelCoordinatesSeries
    import pandas as pd

    df = pd.read_csv("iris.csv")
    numeric_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

    fig = Figure(width=800, height=500, auto_display=False)
    fig.add(ParallelCoordinatesSeries(
        data=df[numeric_cols].values.tolist(),
        axes=numeric_cols,
        hue=df["species"].tolist(),       # color-code by a categorical column
        alpha=0.35,
    ))
    fig.show()
"""
from __future__ import annotations

import math
import numpy as np
from typing import Any

from .colormaps import colormap_colors, apply_colormap
from .utils import svg_escape, _format_tick, LEGEND_GUTTER


class ParallelCoordinatesSeries:
    """
    Parallel coordinates plot for high-dimensional data.

    Every data row is rendered as one SVG ``<polyline>`` that crosses a set
    of evenly spaced vertical axes (one per column).  Each column is
    min/max-normalised independently, so every axis spans its own value
    range from bottom (column minimum) to top (column maximum).

    Args:
        data:       2-D array-like (rows × variables), convertible to float.
        axes:       Column / variable names (length must match data columns).
        hue:        Per-row group labels for color coding (one entry per
                    data row).  When provided, each unique value gets a
                    distinct color from ``colors`` or ``cmap``.  Pass a
                    numeric array to use a continuous colormap instead.
        cmap:       Colormap name used when ``hue`` holds numeric values, or
                    when ``colors`` is not supplied (default ``"viridis"``).
        colors:     Explicit color per unique ``hue`` group.  A dict maps
                    group label -> color (keys are compared by their ``str()``
                    form); a list/tuple is paired with the groups in order of
                    first appearance.  Groups without an entry fall back to
                    ``"#888"``.  Ignored when ``hue`` is numeric or ``None``.
        alpha:      Line opacity 0–1 (default 0.45). Lower values help with
                    overplotting.
        line_width: Stroke width (default 1.2).
        show_axes:  Draw vertical axis lines and tick labels (default True).
        label:      Legend label (unused but kept for API consistency).

    Raises:
        ValueError: If ``data`` is not 2-D, has no rows, has a column count
            different from ``len(axes)``, or if ``hue`` does not have exactly
            one entry per data row.
    """

    # Number of tick intervals per axis (TICKS + 1 labels are drawn).
    _TICKS = 5
    # Vertical pitch of legend rows, in pixels.
    _LEGEND_ROW_HEIGHT = 20
    # Color used for any categorical group without an explicit color.
    _FALLBACK_COLOR = "#888"

    def __init__(
        self,
        data,
        axes: list[str],
        hue: list | None = None,
        cmap: str = "viridis",
        colors: dict | list | None = None,
        alpha: float = 0.45,
        line_width: float = 1.2,
        show_axes: bool = True,
        label: str | None = None,
    ) -> None:
        self.matrix = np.asarray(data, dtype=float)
        self.axes_names = list(axes)
        self.alpha = float(alpha)
        self.line_width = float(line_width)
        self.show_axes = show_axes
        self.label = label
        self.css_class = f"series-{id(self) % 100000}"
        self.cmap = cmap
        self.hue = hue

        # Validate the shape up front so callers get a readable message
        # instead of a tuple-unpacking or NumPy reduction error.
        if self.matrix.ndim != 2:
            raise ValueError(
                "data must be 2-D (rows by variables); "
                f"got an array with {self.matrix.ndim} dimension(s)."
            )
        n_rows, n_cols = self.matrix.shape
        if n_cols != len(self.axes_names):
            raise ValueError(
                f"data has {n_cols} columns but {len(self.axes_names)} axis names were given."
            )
        if n_rows == 0:
            raise ValueError("data must contain at least one row.")
        if hue is not None and len(hue) != n_rows:
            # A short hue would otherwise surface as an IndexError in to_svg().
            raise ValueError(
                f"hue has {len(hue)} entries but data has {n_rows} rows."
            )

        # Per-column min/max for normalisation.  Constant columns get a
        # range of 1.0 so the later division never hits zero.
        self._col_min = self.matrix.min(axis=0)
        self._col_max = self.matrix.max(axis=0)
        self._col_range = np.where(
            self._col_max != self._col_min,
            self._col_max - self._col_min,
            1.0,
        )

        # Per-row colour assignment.  ``_groups`` stays empty unless the hue
        # is categorical; a non-empty list is what triggers the legend.
        self._groups = []
        if hue is None:
            # No hue: spread rows evenly across the colormap by row index.
            denom = max(n_rows - 1, 1)
            self._row_colors = [
                apply_colormap(i / denom, cmap) for i in range(n_rows)
            ]
        else:
            hue_arr = np.asarray(hue)
            if np.issubdtype(hue_arr.dtype, np.number):
                # Continuous hue: rescale to [0, 1]; a constant hue uses a
                # span of 1.0 so every row maps to position 0.
                h_min, h_max = hue_arr.min(), hue_arr.max()
                span = (h_max - h_min) or 1.0
                self._row_colors = [
                    apply_colormap(float((v - h_min) / span), cmap)
                    for v in hue_arr
                ]
            else:
                # Categorical hue: labels are compared by their str() form,
                # groups are kept in order of first appearance.
                labels = [str(v) for v in hue]
                self._groups = list(dict.fromkeys(labels))
                if isinstance(colors, dict):
                    # Normalise keys so e.g. {True: ...} matches label "True".
                    group_color = {str(k): v for k, v in colors.items()}
                elif isinstance(colors, (list, tuple)):
                    group_color = dict(zip(self._groups, colors))
                else:
                    palette = colormap_colors(cmap, len(self._groups))
                    group_color = dict(zip(self._groups, palette))
                self._row_colors = [
                    group_color.get(lbl, self._FALLBACK_COLOR) for lbl in labels
                ]
                # Only set for categorical hue; to_svg() reads it via getattr.
                self._legend_items = group_color  # used for legend rendering

        # x/y stubs so Figure knows this is axis-free
        self.x = None
        self.y = None

    def to_svg(self, ax: object = None) -> str:  # type: ignore[override]
        """
        Render the series as SVG element markup.

        Args:
            ax: Optional layout object.  When given, its ``padding``,
                ``width``, ``height`` and ``theme`` (a mapping read with
                ``.get``) drive the layout and styling; when ``None`` a
                740×400 canvas with built-in padding and colors is used.

        Returns:
            Newline-joined SVG elements (axis lines/labels/ticks when
            ``show_axes`` is set, one ``<polyline>`` per data row, and a
            legend for categorical hue).  No enclosing ``<svg>`` tag is
            emitted.
        """
        if ax is None:
            pad_x, pad_y = 60, 50
            w, h = 740, 400
            font, tc = "sans-serif", "#000"
        else:
            pad_x = ax.padding * 2  # type: ignore
            pad_y = ax.padding  # type: ignore
            w = ax.width  # type: ignore
            h = ax.height  # type: ignore
            font = ax.theme.get("font", "sans-serif")  # type: ignore
            tc = ax.theme.get("text_color", "#000")  # type: ignore

        n_axes = len(self.axes_names)

        # Reserve right gutter for the legend so it never overlaps the axes
        gutter = LEGEND_GUTTER if self._groups else 0
        plot_w = w - 2 * pad_x - gutter
        plot_h = h - 2 * pad_y

        # Axes are spread evenly over the plot width; a single axis sits at
        # the left edge.
        ax_step = plot_w / (n_axes - 1) if n_axes > 1 else plot_w
        ax_x = [pad_x + i * ax_step for i in range(n_axes)]

        elements: list[str] = []
        if self.show_axes:
            elements.extend(self._axis_elements(ax_x, pad_y, plot_h, font, tc))
        elements.extend(self._polyline_elements(ax_x, pad_y, plot_h))
        if self._groups:
            elements.extend(self._legend_elements(w, h, gutter, font, tc))
        return "\n".join(elements)

    def _axis_elements(self, ax_x, pad_y, plot_h, font, tc) -> list[str]:
        """Build the vertical axis lines, axis titles and tick marks/labels."""
        out: list[str] = []
        for j, (x_pos, name) in enumerate(zip(ax_x, self.axes_names)):
            # Axis line
            out.append(
                f'<line x1="{x_pos:.1f}" x2="{x_pos:.1f}" '
                f'y1="{pad_y}" y2="{pad_y + plot_h}" '
                f'stroke="{tc}" stroke-width="1.2" opacity="0.6"/>'
            )
            # Axis label, centred above the axis
            out.append(
                f'<text x="{x_pos:.1f}" y="{pad_y - 8}" '
                f'text-anchor="middle" font-size="12" '
                f'font-family="{font}" fill="{tc}" font-weight="600">'
                f'{svg_escape(name)}</text>'
            )
            # Ticks run from the column minimum (bottom) to minimum + range (top)
            col_min = self._col_min[j]
            col_rng = self._col_range[j]
            for k in range(self._TICKS + 1):
                t = k / self._TICKS
                val = col_min + t * col_rng
                y = pad_y + plot_h - t * plot_h
                out.append(
                    f'<text x="{x_pos - 6:.1f}" y="{y + 3:.1f}" '
                    f'text-anchor="end" font-size="9" '
                    f'font-family="{font}" fill="{tc}" opacity="0.7">'
                    f'{_format_tick(val)}</text>'
                )
                out.append(
                    f'<line x1="{x_pos - 3:.1f}" x2="{x_pos + 3:.1f}" '
                    f'y1="{y:.1f}" y2="{y:.1f}" '
                    f'stroke="{tc}" stroke-width="0.8" opacity="0.4"/>'
                )
        return out

    def _polyline_elements(self, ax_x, pad_y, plot_h) -> list[str]:
        """Build one ``<polyline>`` per data row, last row first."""
        # Normalise the whole matrix once instead of once per row.
        norm = (self.matrix - self._col_min) / self._col_range  # [0, 1]
        # SVG y grows downwards, so a normalised 1.0 maps to the top edge.
        ys = (pad_y + plot_h - norm * plot_h).tolist()

        out: list[str] = []
        # Draw in reverse z-order so first rows are on top
        for i in reversed(range(len(ys))):
            color = self._row_colors[i]
            pts = " ".join(f"{x:.1f},{y:.1f}" for x, y in zip(ax_x, ys[i]))
            out.append(
                f'<polyline class="glyphx-point {self.css_class}" '
                f'points="{pts}" fill="none" stroke="{color}" '
                f'stroke-width="{self.line_width}" opacity="{self.alpha}"/>'
            )
        return out

    def _legend_elements(self, w, h, gutter, font, tc) -> list[str]:
        """Build the categorical legend (swatch + label per group)."""
        # Place the legend inside the reserved gutter; fall back to a fixed
        # offset from the right edge when the gutter is zero.
        legend_x = w - gutter + 8 if gutter else w - 110
        row_h = self._LEGEND_ROW_HEIGHT
        total_h = len(self._groups) * row_h
        legend_y = (h - total_h) // 2  # vertically centred
        group_color = getattr(self, "_legend_items", {})

        out: list[str] = []
        for k, grp in enumerate(self._groups):
            col = group_color.get(grp, self._FALLBACK_COLOR)
            gy = legend_y + k * row_h
            out.append(
                f'<rect x="{legend_x}" y="{gy}" width="12" height="12" '
                f'fill="{col}" rx="2"/>'
            )
            out.append(
                f'<text x="{legend_x + 16}" y="{gy + 10}" '
                f'font-size="11" font-family="{font}" fill="{tc}">'
                f'{svg_escape(grp)}</text>'
            )
        return out