# Source code for glyphx.nlp

"""
GlyphX Natural Language Chart Generation.

``glyphx.from_prompt()`` lets you describe a chart in plain English
and get back a fully rendered Figure — no axis wrangling, no theme
fiddling, no column juggling.

Requires the ``anthropic`` Python package::

    pip install anthropic

And an API key, either passed directly or via the ``ANTHROPIC_API_KEY``
environment variable.

Example::

    import pandas as pd
    from glyphx import from_prompt

    df = pd.read_csv("sales.csv")
    fig = from_prompt("bar chart of total revenue by region", df=df)
    fig.share("revenue_by_region.html")
"""

from __future__ import annotations

import json
import os
import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd

# Contract for the model's reply: a single JSON object describing the chart.
# The text is sent to the model verbatim as part of the system prompt.
_SCHEMA_DOC = """\
Return ONLY a JSON object (no markdown fences, no explanation) with this schema:

{
  "kind":      "line" | "bar" | "scatter" | "pie" | "donut" | "hist" | "box",
  "x":         "<column name or null>",
  "y":         "<column name or null>",
  "groupby":   "<column name or null>",
  "agg":       "sum" | "mean" | "count" | "max" | "min",
  "bins":      <integer, for hist only>,
  "title":     "<chart title or null>",
  "theme":     "default"|"dark"|"colorblind"|"pastel"|"warm"|"ocean"|"monochrome",
  "color":     "<hex color or null>",
  "label":     "<series label or null>",
  "xlabel":    "<x-axis label or null>",
  "ylabel":    "<y-axis label or null>",
  "sort_by":   "x" | "y" | null,
  "sort_desc": true | false,
  "top_n":     <integer or null — keep only top N rows after aggregation>,
  "reasoning": "<one-sentence explanation of your choices>"
}

Rules:
- kind must be one of the listed values
- x/y/groupby must be exact column names from the schema, or null
- agg defaults to "sum" when groupby is set
- bins defaults to 10
- theme defaults to "default"
- top_n is useful for "top 10 X by Y" queries
- sort_desc defaults to true when top_n is set, false otherwise
"""

# System prompt = role/task preamble followed by the reply schema above.
_SYSTEM_PROMPT = "".join(
    (
        "You are a data visualisation expert who configures GlyphX charts.\n",
        "Given a user's description and optional DataFrame schema, choose the best ",
        "chart type, map columns to axes, select a fitting theme, and return the ",
        "configuration as JSON.\n\n",
        _SCHEMA_DOC,
    )
)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def from_prompt(
    prompt: str,
    df=None,
    api_key: str | None = None,
    model: str = "claude-sonnet-4-20250514",
    auto_display: bool = True,
) -> "Figure":  # noqa: F821
    """
    Generate a GlyphX Figure from a plain-English description.

    The prompt (plus, when ``df`` is given, its schema and first rows) is sent
    to an Anthropic model, which replies with a JSON chart configuration. That
    configuration is parsed and turned into a Figure.

    Args:
        prompt (str): Natural language description of the desired chart,
            e.g. ``"bar chart of monthly revenue grouped by region"``.
        df (pd.DataFrame | None): DataFrame to plot. Column names, dtypes,
            and a sample are sent to the model so it can choose sensible
            x/y mappings. When omitted, illustrative sample data is plotted.
        api_key (str | None): Anthropic API key. Falls back to the
            ``ANTHROPIC_API_KEY`` environment variable.
        model (str): Anthropic model name.
        auto_display (bool): Auto-render and show the figure when True.

    Returns:
        Figure: A fully configured and rendered GlyphX Figure.

    Raises:
        ImportError: If the ``anthropic`` package is not installed.
        ValueError: If no API key is found, or the model returns an
            empty response.
        json.JSONDecodeError: If the model returns unparseable JSON
            (rarely happens; usually recoverable).

    Examples::

        # Simple — no data, just a chart type hint
        fig = from_prompt("show me a sample line chart of sin(x)",
                          auto_display=False)

        # With a DataFrame
        import pandas as pd
        df = pd.DataFrame({"month": range(1,13), "sales": [120,135,98,...]})
        fig = from_prompt("line chart of sales over time", df=df)

        # Grouped bar
        fig = from_prompt(
            "top 5 products by total revenue, grouped by region",
            df=sales_df,
        )
    """
    # Imported lazily so the rest of GlyphX works without the optional package.
    try:
        import anthropic
    except ImportError as err:
        raise ImportError(
            "Natural language chart generation requires the anthropic package.\n"
            "Install it with: pip install anthropic"
        ) from err

    key = api_key or os.environ.get("ANTHROPIC_API_KEY")
    if not key:
        raise ValueError(
            "No Anthropic API key found. Either pass api_key= or set the "
            "ANTHROPIC_API_KEY environment variable."
        )

    # Build the user message: the prompt, then the DataFrame context if any.
    user_parts = [prompt]
    if df is not None:
        user_parts.append(_df_context(df))
    user_msg = "\n\n".join(user_parts)

    client = anthropic.Anthropic(api_key=key)
    response = client.messages.create(
        model=model,
        max_tokens=600,
        system=_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": user_msg}],
    )

    # Fail with a clear message instead of an IndexError on an empty reply.
    if not response.content:
        raise ValueError(
            "The model returned an empty response; no chart configuration to parse."
        )

    raw = response.content[0].text.strip()
    config = _parse_json(raw)
    return _build_figure(config, df, auto_display=auto_display)
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _df_context(df) -> str:
    """Format a DataFrame's schema and sample for the LLM.

    Args:
        df: DataFrame whose shape, numeric / non-numeric column names,
            dtypes and first five rows are described.

    Returns:
        str: A plain-text description appended to the user prompt.
    """
    try:
        sample = df.head(5).to_string(index=False)
    except Exception:
        # Best effort: the sample is optional context, so never fail on it.
        sample = "(sample unavailable)"

    numeric_cols = df.select_dtypes(include="number").columns.tolist()
    category_cols = df.select_dtypes(exclude="number").columns.tolist()

    return (
        f"DataFrame schema:\n"
        f" Shape : {df.shape[0]:,} rows × {df.shape[1]} columns\n"
        f" Numeric : {numeric_cols}\n"
        f" Categorical: {category_cols}\n"
        f" Dtypes : {df.dtypes.to_dict()}\n\n"
        f"First 5 rows:\n{sample}"
    )


def _parse_json(raw: str) -> dict:
    """Strip markdown fences if present, then parse the JSON object.

    If the cleaned text is not valid JSON, a second attempt is made on the
    span from the first ``{`` to the last ``}``, which tolerates a sentence
    of prose before or after the object.

    Args:
        raw: Text returned by the model.

    Returns:
        dict: The decoded configuration.

    Raises:
        json.JSONDecodeError: If no parseable JSON is found.
        ValueError: If the decoded JSON is not an object.
    """
    # Remove ```json ... ``` fences
    text = re.sub(r"^```(?:json)?\s*", "", raw.strip(), flags=re.IGNORECASE)
    text = re.sub(r"\s*```$", "", text.strip()).strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end + 1])

    if not isinstance(parsed, dict):
        raise ValueError(
            f"Expected a JSON object for the chart configuration, "
            f"got {type(parsed).__name__}."
        )
    return parsed


def _coerce(df, col: str):
    """Return df[col].tolist(), or None if col is None / not in df."""
    if col is None or df is None:
        return None
    if col not in df.columns:
        return None
    return df[col].tolist()


def _normalize_config(config: dict) -> dict:
    """Return a copy of the LLM config with defaults filled in.

    Explicit JSON ``null`` values are treated like missing keys, so a reply
    such as ``{"kind": null}`` falls back to the default instead of crashing.
    ``sort_desc`` defaults to True when ``top_n`` is set and False otherwise,
    and an unrecognised ``agg`` falls back to ``"sum"``.
    """
    top_n = config.get("top_n")
    sort_desc = config.get("sort_desc")
    if sort_desc is None:
        # Without this, "top N" would keep the N smallest rows.
        sort_desc = bool(top_n)

    agg = config.get("agg")
    if agg not in ("sum", "mean", "count", "max", "min"):
        agg = "sum"

    return {
        "kind": str(config.get("kind") or "line").lower(),
        "title": config.get("title"),
        "theme": config.get("theme") or "default",
        "color": config.get("color"),
        "label": config.get("label"),
        "xlabel": config.get("xlabel"),
        "ylabel": config.get("ylabel"),
        "x": config.get("x"),
        "y": config.get("y"),
        "groupby": config.get("groupby"),
        "agg": agg,
        "bins": int(config.get("bins") or 10),
        "sort_by": config.get("sort_by"),
        "sort_desc": bool(sort_desc),
        "top_n": int(top_n) if top_n else None,
    }


def _pick_column(df, *candidates):
    """Return the first candidate that is a column of ``df``.

    Falls back to the first numeric column when no candidate matches, which
    covers column names the model invented.

    Raises:
        ValueError: If no candidate matches and ``df`` has no numeric column.
    """
    for name in candidates:
        if name is not None and name in df.columns:
            return name
    numeric = df.select_dtypes(include="number").columns
    if len(numeric) == 0:
        raise ValueError("The DataFrame has no numeric column to plot.")
    return numeric[0]


def _build_figure(config: dict, df, auto_display: bool = True):
    """
    Translate a config dict returned by the LLM into a GlyphX Figure.

    Handles:
      - Aggregation (groupby + agg)
      - top_n filtering
      - Sorting
      - Multi-series (groupby without aggregation)
      - All supported chart kinds

    Args:
        config: Chart configuration following the schema in the system prompt.
        df: DataFrame to plot, or None to plot illustrative sample data.
        auto_display: Call ``fig.show()`` before returning when True.

    Returns:
        The populated Figure.

    Raises:
        ValueError: If a numeric column is needed but ``df`` has none.
    """
    from .figure import Figure
    from .series import (
        PieSeries, DonutSeries, HistogramSeries, BoxPlotSeries,
    )

    cfg = _normalize_config(config)
    kind = cfg["kind"]
    title = cfg["title"]
    theme = cfg["theme"]
    color = cfg["color"]
    label = cfg["label"]
    x_col = cfg["x"]
    y_col = cfg["y"]
    groupby = cfg["groupby"]
    agg = cfg["agg"]
    sort_by = cfg["sort_by"]
    sort_desc = cfg["sort_desc"]
    top_n = cfg["top_n"]

    fig = Figure(title=title, theme=theme, auto_display=False)
    fig.axes.xlabel = cfg["xlabel"]
    fig.axes.ylabel = cfg["ylabel"]

    # ── No DataFrame: generate illustrative sample data ──────────────────
    if df is None:
        fig = _build_sample_figure(kind, title, theme, color, label, fig)

    # ── Axis-free kinds (hist, box, pie, donut) ──────────────────────────
    elif kind == "hist":
        col = _pick_column(df, y_col, x_col)
        data = df[col].dropna().tolist()
        fig.add(HistogramSeries(data, bins=cfg["bins"], color=color,
                                label=label or col))

    elif kind == "box":
        if groupby and groupby in df.columns:
            # One box per distinct value of the groupby column.
            groups = df[groupby].unique().tolist()
            col = _pick_column(df, y_col)
            arrays = [df[df[groupby] == g][col].dropna().tolist() for g in groups]
            fig.add(BoxPlotSeries(arrays, categories=[str(g) for g in groups],
                                  color=color or "#1f77b4"))
        else:
            col = _pick_column(df, y_col, x_col)
            data = df[col].dropna().tolist()
            fig.add(BoxPlotSeries(data, color=color or "#1f77b4",
                                  label=label or col))

    elif kind in {"pie", "donut"}:
        values, labels = _pie_data(df, x_col, y_col, agg)
        if kind == "pie":
            fig.add(PieSeries(values, labels=labels))
        else:
            fig.add(DonutSeries(values, labels=[str(l) for l in labels]))

    # ── Aggregation (groupby) → single or multi-series ───────────────────
    elif groupby and groupby in df.columns:
        if x_col and y_col and x_col in df.columns and y_col in df.columns:
            # Pivot: x = x_col, one series per groupby value
            theme_colors = (fig.theme.get("colors")
                            or ["#1f77b4", "#ff7f0e", "#2ca02c"])
            for i, (grp_val, grp_df) in enumerate(df.groupby(groupby)):
                grp_color = theme_colors[i % len(theme_colors)]
                s = _make_series(kind, grp_df[x_col].tolist(),
                                 grp_df[y_col].tolist(), grp_color, str(grp_val))
                if s:
                    fig.add(s)
        else:
            # Aggregate one column by groupby
            num_col = _pick_column(df, y_col)
            agg_df = df.groupby(groupby)[num_col].agg(agg).reset_index()
            agg_df = _apply_sort_top(agg_df, groupby, num_col,
                                     sort_by, sort_desc, top_n)
            s = _make_series(kind, agg_df[groupby].tolist(),
                             agg_df[num_col].tolist(), color,
                             label or f"{agg}({num_col})")
            if s:
                fig.add(s)

    # ── Simple x / y mapping ─────────────────────────────────────────────
    else:
        # sort_values / head return new frames, so df itself is never mutated.
        work_df = df
        if sort_by == "y" and y_col in work_df.columns:
            work_df = work_df.sort_values(y_col, ascending=not sort_desc)
        elif sort_by == "x" and x_col in work_df.columns:
            work_df = work_df.sort_values(x_col, ascending=not sort_desc)
        if top_n:
            work_df = work_df.head(top_n)

        x_data = _coerce(work_df, x_col)
        if x_data is None:
            # No usable x column: plot against the row position.
            x_data = list(range(len(work_df)))

        y_name = y_col
        y_data = _coerce(work_df, y_col)
        if y_data is None:
            # No usable y column: fall back to the first numeric column.
            y_name = _pick_column(work_df)
            y_data = work_df[y_name].tolist()

        s = _make_series(kind, x_data, y_data, color, label or y_name)
        if s:
            fig.add(s)

    if auto_display:
        fig.show()
    return fig


def _make_series(kind, x, y, color, label):
    """Instantiate the right series class.

    ``"bar"`` and ``"scatter"`` map to their own classes; every other kind
    produces a LineSeries.
    """
    from .series import LineSeries, BarSeries, ScatterSeries

    if kind == "bar":
        return BarSeries(x, y, color=color, label=label)
    if kind == "scatter":
        return ScatterSeries(x, y, color=color, label=label)
    return LineSeries(x, y, color=color, label=label)  # default / "line"


def _pie_data(df, x_col, y_col, agg):
    """Extract (values, labels) for pie/donut from a DataFrame.

    When both columns exist, ``y_col`` is aggregated per distinct ``x_col``
    value using ``agg`` (unknown names fall back to ``"sum"``). Otherwise the
    raw values of ``y_col`` (or of the first numeric column) are used, with
    the row positions as labels.

    Raises:
        ValueError: If a fallback column is needed but ``df`` has no
            numeric column.
    """
    if x_col and y_col and x_col in df.columns and y_col in df.columns:
        agg_name = agg if agg in ("sum", "mean", "count", "max", "min") else "sum"
        agg_df = df.groupby(x_col)[y_col].agg(agg_name).reset_index()
        return agg_df[y_col].tolist(), agg_df[x_col].tolist()

    col = _pick_column(df, y_col)
    return df[col].tolist(), list(range(len(df)))


def _apply_sort_top(df, x_col, y_col, sort_by, sort_desc, top_n):
    """Sort an aggregated frame and optionally keep only the first rows.

    When ``top_n`` is set the frame is ordered by ``y_col`` regardless of
    ``sort_by``, so that the rows kept are the extreme ones; otherwise
    ``sort_by`` ("x" or "y") selects the sort column. ``sort_desc`` picks
    descending order.
    """
    if sort_by == "y" or top_n:
        df = df.sort_values(y_col, ascending=not sort_desc)
    elif sort_by == "x":
        df = df.sort_values(x_col, ascending=not sort_desc)
    if top_n:
        df = df.head(int(top_n))
    return df


def _build_sample_figure(kind, title, theme, color, label, fig):
    """Return a figure with illustrative data when no DataFrame is given.

    ``title`` and ``theme`` are accepted for signature compatibility but not
    read here. Random samples come from private, fixed-seed generators, so
    the output is reproducible and the global ``random`` state is untouched.
    """
    import math
    import random
    from .series import (
        LineSeries, BarSeries, ScatterSeries, PieSeries, DonutSeries,
        HistogramSeries, BoxPlotSeries,
    )

    color = color or "#1f77b4"

    if kind == "bar":
        fig.add(BarSeries(["A", "B", "C", "D", "E"], [23, 47, 31, 56, 38],
                          color=color, label=label or "Sample"))
    elif kind == "scatter":
        rng = random.Random(42)
        x = [rng.gauss(0, 1) for _ in range(60)]
        y = [v + rng.gauss(0, 0.5) for v in x]
        fig.add(ScatterSeries(x, y, color=color, label=label or "Sample"))
    elif kind == "pie":
        fig.add(PieSeries([30, 25, 20, 15, 10],
                          labels=["Alpha", "Beta", "Gamma", "Delta", "Epsilon"]))
    elif kind == "donut":
        fig.add(DonutSeries([30, 25, 20, 15, 10],
                            labels=["Alpha", "Beta", "Gamma", "Delta", "Epsilon"]))
    elif kind == "hist":
        rng = random.Random(0)
        data = [rng.gauss(50, 15) for _ in range(200)]
        fig.add(HistogramSeries(data, bins=15, color=color))
    elif kind == "box":
        rng = random.Random(1)
        fig.add(BoxPlotSeries([rng.gauss(50, 10) for _ in range(100)],
                              color=color, label=label or "Sample"))
    else:  # line (default)
        x = list(range(20))
        y = [math.sin(i * 0.4) * 10 + 20 for i in x]
        fig.add(LineSeries(x, y, color=color, label=label or "Sample"))
    return fig