"""
GlyphX Raincloud Plot.
The raincloud combines three views of a distribution in one:
• Raw jittered data points (the "rain")
• A half-violin (KDE density curve)
• A box-and-whisker summary
It is a modern alternative to the plain box plot: you see every data point,
the full density shape, and the quantile summary at once.

Example::

    from glyphx import Figure
    from glyphx.raincloud import RaincloudSeries

    fig = Figure(width=700, height=500, auto_display=False)
    fig.add(RaincloudSeries(
        data=[group_a, group_b, group_c],
        categories=["Control", "Drug A", "Drug B"],
    ))
    fig.show()
"""
from __future__ import annotations
import numpy as np
from .violin_plot import _numpy_kde
from .colormaps import colormap_colors
from .utils import svg_escape
class RaincloudSeries:
    """
    Raincloud plot: jitter + half-violin + box for each category.

    Category ``i`` is centred at x = ``i + 0.5``. For each category:

    * the raw points ("rain") are scattered to the left of the centre line,
    * a half-violin (KDE outline) extends to the right of it,
    * a box-and-whisker summary sits on the centre line.

    Non-finite values (NaN/inf) are ignored. Categories with fewer than two
    finite values are not drawn.

    Args:
        data: List of 1-D arrays, one per category.
        categories: Category labels (same length as data). Missing labels are
            filled in with the category index as a string.
        colors: Per-category colors; cycles if fewer than categories.
            Defaults to a viridis palette.
        jitter_width: Width in pixels of the band the raw points are
            scattered over. The band is shifted left of the centre line by
            half this width.
        point_radius: Radius of each jittered data point.
        violin_width: Max pixel width of the half-violin.
        box_width: Pixel width of the IQR box.
        seed: Random seed for reproducible jitter.
        label: Legend label for the series. When unset, each point's
            ``data-label`` falls back to its category name.

    Raises:
        ValueError: If ``data`` contains no finite values at all.
    """

    def __init__(
        self,
        data: list,
        categories: list[str] | None = None,
        colors: list[str] | None = None,
        jitter_width: float = 18.0,
        point_radius: float = 3.0,
        violin_width: float = 40.0,
        box_width: float = 10.0,
        seed: int = 42,
        label: str | None = None,
    ) -> None:
        self.datasets = [np.asarray(d, dtype=float) for d in data]
        n_cats = len(self.datasets)

        # Expose x/y for domain computation. Category i sits at i + 0.5 to
        # align with the grid label mapping. The y range ignores NaN/inf,
        # which would otherwise turn the whole axis domain into NaN.
        self.x = [i + 0.5 for i in range(n_cats)]
        all_vals = np.concatenate(self.datasets) if self.datasets else np.empty(0)
        finite = all_vals[np.isfinite(all_vals)]
        if finite.size == 0:
            raise ValueError("RaincloudSeries needs at least one finite data value")
        self.y = [float(finite.min()), float(finite.max())]

        # Pad a short label list with the category index so that the zip()
        # in to_svg() never silently drops a dataset.
        cats = list(categories) if categories else []
        cats += [str(i) for i in range(len(cats), n_cats)]
        self.categories = cats

        palette = list(colors or colormap_colors("viridis", n_cats))
        if 0 < len(palette) < n_cats:
            # Cycle a short palette until every category has a colour.
            palette *= n_cats // len(palette) + 1
        self.colors = palette[:n_cats]

        self.jitter_width = jitter_width
        self.point_radius = point_radius
        self.violin_width = violin_width
        self.box_width = box_width
        self.seed = seed
        self.label = label
        self.css_class = f"series-{id(self) % 100000}"

    @staticmethod
    def _box_stats(arr: np.ndarray) -> tuple[float, float, float, float, float]:
        """Return ``(q1, median, q3, whisker_lo, whisker_hi)`` for a finite, non-empty 1-D sample."""
        q1 = float(np.percentile(arr, 25))
        q2 = float(np.median(arr))
        q3 = float(np.percentile(arr, 75))
        iqr = q3 - q1
        lo_fence = q1 - 1.5 * iqr
        hi_fence = q3 + 1.5 * iqr
        # Tukey whiskers end at the most extreme *observed* value inside the
        # fences (not at the fence itself), and never inside the box.
        # Both masks are non-empty: arr.max() >= q1 >= lo_fence and
        # arr.min() <= q3 <= hi_fence.
        w_lo = min(float(arr[arr >= lo_fence].min()), q1)
        w_hi = max(float(arr[arr <= hi_fence].max()), q3)
        return q1, q2, q3, w_lo, w_hi

    def _rain_svg(
        self,
        arr: np.ndarray,
        cx: float,
        scale_y,
        rng: np.random.Generator,
        cat: str,
        color: str,
    ) -> list[str]:
        """Jittered raw-data circles placed to the left of the centre line ``cx``."""
        # Offsets are drawn from [-jitter_width, 0], then shifted a further
        # half-width left, away from the box and violin on the centre line.
        jitter = rng.uniform(-self.jitter_width, 0, size=len(arr))
        x_attr = svg_escape(cat)
        label_attr = svg_escape(self.label or cat)
        circles = []
        for val, jit in zip(arr, jitter):
            py = scale_y(float(val))
            px = cx + jit - self.jitter_width * 0.5
            circles.append(
                f'<circle class="glyphx-point {self.css_class}" '
                f'cx="{px:.1f}" cy="{py:.1f}" r="{self.point_radius}" '
                f'fill="{color}" fill-opacity="0.55" '
                f'data-x="{x_attr}" '
                f'data-y="{val:.3g}" '
                f'data-label="{label_attr}"/>'
            )
        return circles

    def _violin_svg(self, arr: np.ndarray, cx: float, scale_y, color: str) -> list[str]:
        """Half-violin: a KDE outline extending right of the centre line ``cx``."""
        lo, hi = float(arr.min()), float(arr.max())
        if lo == hi:
            # A constant sample has no spread to draw a density over.
            return []
        y_vals = np.linspace(lo, hi, 100)
        dens = np.asarray(_numpy_kde(arr)(y_vals), dtype=float)
        if not np.all(np.isfinite(dens)):
            # Never write "nan"/"inf" coordinates into the SVG path.
            return []
        peak = float(dens.max()) or 1.0
        dens = dens / peak * self.violin_width
        right_pts = [(cx + d, scale_y(float(y))) for y, d in zip(y_vals, dens)]
        left_pts = [(cx, scale_y(float(y))) for y in reversed(y_vals)]
        path = (
            "M "
            + " L ".join(f"{px:.1f},{py:.1f}" for px, py in right_pts + left_pts)
            + " Z"
        )
        return [
            f'<path d="{path}" fill="{color}" fill-opacity="0.35" '
            f'stroke="{color}" stroke-width="1.5"/>'
        ]

    def _box_svg(self, arr: np.ndarray, cx: float, scale_y, color: str) -> list[str]:
        """Whiskers, IQR box and median line centred on ``cx``."""
        q1, q2, q3, w_lo, w_hi = self._box_stats(arr)
        hw = self.box_width / 2
        y_q1, y_q3, y_med = scale_y(q1), scale_y(q3), scale_y(q2)
        box_top = min(y_q1, y_q3)
        box_h = abs(y_q3 - y_q1)
        return [
            # Lower whisker
            f'<line x1="{cx:.1f}" x2="{cx:.1f}" '
            f'y1="{scale_y(w_lo):.1f}" y2="{y_q1:.1f}" '
            f'stroke="{color}" stroke-width="1.5"/>',
            # Upper whisker
            f'<line x1="{cx:.1f}" x2="{cx:.1f}" '
            f'y1="{y_q3:.1f}" y2="{scale_y(w_hi):.1f}" '
            f'stroke="{color}" stroke-width="1.5"/>',
            # IQR box
            f'<rect x="{cx - hw:.1f}" y="{box_top:.1f}" '
            f'width="{self.box_width}" height="{box_h:.1f}" '
            f'fill="{color}" fill-opacity="0.6" '
            f'stroke="{color}" stroke-width="1.5"/>',
            # Median line
            f'<line x1="{cx - hw:.1f}" x2="{cx + hw:.1f}" '
            f'y1="{y_med:.1f}" y2="{y_med:.1f}" '
            f'stroke="#fff" stroke-width="2.5"/>',
        ]

    def to_svg(self, ax: object, use_y2: bool = False) -> str:
        """
        Render every category as SVG elements.

        Args:
            ax: Axis object providing ``scale_x``, ``scale_y``/``scale_y2``,
                ``height`` and ``padding``.
            use_y2: Map values through ``ax.scale_y2`` instead of ``ax.scale_y``.

        Returns:
            Newline-joined SVG markup: rain points, half-violin and box for
            each category with at least two finite values, plus a category
            label unless ``_x_categories`` is set on this series.
        """
        scale_y = ax.scale_y2 if use_y2 else ax.scale_y  # type: ignore[union-attr]
        # One generator per render, consumed category by category, so the
        # jitter is reproducible for a given seed.
        rng = np.random.default_rng(self.seed)
        elements: list[str] = []
        for i, (arr, cat, color) in enumerate(
            zip(self.datasets, self.categories, self.colors)
        ):
            # NaN/inf cannot be placed on the axis; drop them before any stats.
            arr = arr[np.isfinite(arr)]
            if arr.size < 2:
                continue
            # Category i is centred at i + 0.5, matching the positions in self.x.
            cx = ax.scale_x(i + 0.5)  # type: ignore[union-attr]
            elements.extend(self._rain_svg(arr, cx, scale_y, rng, cat, color))
            elements.extend(self._violin_svg(arr, cx, scale_y, color))
            elements.extend(self._box_svg(arr, cx, scale_y, color))
            # Category label: skip it if _x_categories is set (the grid draws it).
            if not getattr(self, "_x_categories", None):
                elements.append(
                    f'<text x="{cx:.1f}" y="{ax.height - ax.padding + 16}" '  # type: ignore[union-attr]
                    f'text-anchor="middle" font-size="11" fill="#444">'
                    f'{svg_escape(cat)}</text>'
                )
        return "\n".join(elements)