# Source code for glyphx.facet_grid

"""
GlyphX FacetGrid -- small-multiples grid of the same chart type, faceted
by a categorical column.

Modeled on Seaborn's ``FacetGrid`` API, except that ``map()`` takes a chart-kind
string (e.g. ``"scatter"``) instead of a plotting function:

    from glyphx.facet_grid import FacetGrid

    g = FacetGrid(df, col="species", hue="sex", height=300, aspect=1.2)
    g.map("scatter", x="bill_length", y="bill_depth")
    g.show()

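Each cell is a full GlyphX Figure, ``height`` pixels tall and ``height * aspect``
pixels wide. By default the grid has one column per ``col`` facet and one row per
``row`` facet; set ``col_wrap`` to wrap the cells into rows of at most that many
columns.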
"""
from __future__ import annotations

import math
from typing import Any

import numpy as np
import pandas as pd

from .figure  import Figure
from .themes  import themes as _themes
from .colormaps import colormap_colors
from .utils   import svg_escape


[docs] class FacetGrid: """ Small-multiples grid. Args: data: Source DataFrame. col: Column name to facet along the X axis. row: Column name to facet along the Y axis (optional). hue: Column name for color-coding within each cell (optional). height: Pixel height of each cell. aspect: Width/height ratio of each cell. col_wrap: Wrap into a new row after this many columns. theme: GlyphX theme name. sharex: Share X-axis scale across all cells. sharey: Share Y-axis scale across all cells. col_order: Explicit ordering of column facets. row_order: Explicit ordering of row facets. hue_order: Explicit ordering of hue groups. palette: Color palette for hue groups (list of hex strings). """ def __init__( self, data, col: str | None = None, row: str | None = None, hue: str | None = None, height: int = 300, aspect: float = 1.4, col_wrap: int | None = None, theme: str = "default", sharex: bool = True, sharey: bool = True, col_order: list | None = None, row_order: list | None = None, hue_order: list | None = None, palette: list | None = None, ) -> None: self._df = data self._col = col self._row = row self._hue = hue self._height = height self._width = int(height * aspect) self._col_wrap = col_wrap self._theme = theme self._sharex = sharex self._sharey = sharey self._figs: list[tuple[Figure, str, str]] = [] self._map_kind: str | None = None # Facet values self._col_vals = (col_order if col_order else list(data[col].unique()) if col else [None]) self._row_vals = (row_order if row_order else list(data[row].unique()) if row else [None]) self._hue_vals = (hue_order if hue_order else list(data[hue].unique()) if hue else [None]) # Palette n_hue = len(self._hue_vals) if hue else 1 self._palette = palette or colormap_colors("viridis", max(n_hue, 2)) # ------------------------------------------------------------------
[docs] def map( self, kind: str, x: str | None = None, y: str | None = None, **kwargs, ) -> "FacetGrid": """ Apply a chart type to each facet cell. Args: kind: Chart kind (``"scatter"``, ``"line"``, ``"bar"``, ``"hist"``, ``"kde"``, ``"box"``, ``"violin"``). x: X-axis column name. y: Y-axis column name (not needed for ``"hist"``/``"kde"``). **kwargs: Passed to the series constructor. Returns: ``self`` for chaining. Example:: g = FacetGrid(df, col="species", hue="island") g.map("scatter", x="bill_length", y="flipper_length") g.map("hist", x="body_mass") """ from .series import (LineSeries, BarSeries, ScatterSeries, HistogramSeries, BoxPlotSeries) from .kde import KDESeries from .violin_plot import ViolinPlotSeries self._map_kind = kind self._figs = [] theme_dict = _themes.get(self._theme, _themes["default"]) for r_val in self._row_vals: for c_val in self._col_vals: # Build cell title parts = [] if c_val is not None: parts.append(f"{self._col}={c_val}") if r_val is not None: parts.append(f"{self._row}={r_val}") cell_title = " | ".join(parts) fig = Figure(width=self._width, height=self._height, auto_display=False, theme=self._theme) fig.set_title(cell_title) for hi, h_val in enumerate(self._hue_vals): # Subset the data sub = self._df.copy() if c_val is not None: sub = sub[sub[self._col] == c_val] if r_val is not None: sub = sub[sub[self._row] == r_val] if h_val is not None and self._hue: sub = sub[sub[self._hue] == h_val] if sub.empty: continue color = self._palette[hi % len(self._palette)] label = str(h_val) if h_val is not None else None series = None if kind == "scatter" and x and y: series = ScatterSeries( sub[x].tolist(), sub[y].tolist(), color=color, label=label, **kwargs ) elif kind == "line" and x and y: series = LineSeries( sub[x].tolist(), sub[y].tolist(), color=color, label=label, **kwargs ) elif kind == "bar" and x and y: series = BarSeries( sub[x].tolist(), sub[y].tolist(), color=color, label=label, **kwargs ) elif kind == "hist" and x: series = 
HistogramSeries( sub[x].dropna().tolist(), color=color, label=label, **kwargs ) elif kind == "kde" and x: series = KDESeries( sub[x].dropna().tolist(), color=color, label=label, **kwargs ) elif kind in ("box", "violin") and y: grp_data = [sub[y].dropna().tolist()] if kind == "box": series = BoxPlotSeries( grp_data, color=color, label=label, **kwargs ) else: series = ViolinPlotSeries( grp_data, color=color, label=label, **kwargs ) if series is not None: fig.add(series) if x: fig.set_xlabel(x) if y: fig.set_ylabel(y) self._figs.append((fig, str(r_val), str(c_val))) return self
# ------------------------------------------------------------------
[docs] def render_svg(self) -> str: """Composite all cell figures into a single SVG grid.""" if not self._figs: return '<svg xmlns="http://www.w3.org/2000/svg"></svg>' n_total = len(self._figs) n_cols = (self._col_wrap or len(self._col_vals)) or 1 n_rows = math.ceil(n_total / n_cols) gap = 16 title_h = 40 cell_w = self._width cell_h = self._height total_w = n_cols * cell_w + (n_cols + 1) * gap total_h = n_rows * cell_h + (n_rows + 1) * gap + title_h theme_dict = _themes.get(self._theme, _themes["default"]) bg = theme_dict.get("background", "#fff") tc = theme_dict.get("text_color", "#000") font = theme_dict.get("font", "sans-serif") parts = [ f'<svg width="{total_w}" height="{total_h}" ' f'xmlns="http://www.w3.org/2000/svg" ' f'viewBox="0 0 {total_w} {total_h}">', f'<rect width="{total_w}" height="{total_h}" fill="{bg}"/>', ] # Hue legend at top right if self._hue and self._hue_vals[0] is not None: lx = total_w - 120 ly = 8 for hi, hv in enumerate(self._hue_vals): col = self._palette[hi % len(self._palette)] parts.append( f'<rect x="{lx}" y="{ly + hi*16}" width="10" height="10" fill="{col}"/>' ) parts.append( f'<text x="{lx+14}" y="{ly + hi*16 + 9}" ' f'font-size="10" font-family="{font}" fill="{tc}">' f'{svg_escape(str(hv))}</text>' ) for idx, (fig, r_val, c_val) in enumerate(self._figs): row_i = idx // n_cols col_i = idx % n_cols tx = gap + col_i * (cell_w + gap) ty = title_h + gap + row_i * (cell_h + gap) inner_svg = fig.render_svg() # Extract inner content from the <svg> root import re body_match = re.search(r'<svg[^>]*>(.*)</svg>', inner_svg, re.DOTALL) if body_match: inner = body_match.group(1) else: inner = inner_svg parts.append( f'<g transform="translate({tx},{ty})">' f'<rect width="{cell_w}" height="{cell_h}" ' f'fill="{bg}" rx="4" stroke="#eee" stroke-width="1"/>' + inner + '</g>' ) parts.append("</svg>") return "\n".join(parts)
[docs] def show(self) -> "FacetGrid": """Display in Jupyter or open in browser.""" svg = self.render_svg() try: from IPython.display import SVG, display as jd jd(SVG(svg)); return self except Exception: pass import tempfile, webbrowser tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".svg", mode="w") tmp.write(svg); tmp.close() webbrowser.open(f"file://{tmp.name}") return self
[docs] def save(self, path: str) -> "FacetGrid": """Save the composite SVG to a file.""" from pathlib import Path Path(path).write_text(self.render_svg(), encoding="utf-8") return self
def __repr__(self) -> str: return (f"<FacetGrid col={self._col!r} row={self._row!r} " f"hue={self._hue!r} cells={len(self._figs)}>")