diff --git a/.gitignore b/.gitignore
index 740f3993c..a06bad8c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -268,6 +268,9 @@ renv.lock
# Planning documents (local only)
docs/plans/
+# Screenshot capture script (local only)
+pkg-py/docs/_screenshots/
+
# Playwright MCP
.playwright-mcp/
diff --git a/CLAUDE.md b/CLAUDE.md
index 4170fa304..98873f00e 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -69,13 +69,15 @@ make py-build
make py-docs
```
-Before finishing your implementation or committing any code, you should run:
+Before committing any Python code, you must run all three checks and confirm they pass:
```bash
uv run ruff check --fix pkg-py --config pyproject.toml
+make py-check-types
+make py-check-tests
```
-To get help with making sure code adheres to project standards.
+Do not commit or push until all three pass.
### R Package
diff --git a/docs/plans/2026-04-17-datasource-reader-bridge-design.md b/docs/plans/2026-04-17-datasource-reader-bridge-design.md
new file mode 100644
index 000000000..b597aba87
--- /dev/null
+++ b/docs/plans/2026-04-17-datasource-reader-bridge-design.md
@@ -0,0 +1,129 @@
+# DataSourceReader Bridge Design
+
+## Problem
+
+querychat executes ggsql queries in two phases: run the SQL on the real database, then replay the VISUALISE portion locally against the result in an in-memory DuckDB. This has two drawbacks:
+
+1. **Scaling** — the full SQL result must be pulled into Python memory, even when ggsql's stat transforms (histogram, density, boxplot) would reduce it to a small summary. A histogram of 10M rows pulls all 10M rows into memory only to bin them into ~30 buckets.
+
+2. **Multi-source layers** — ggsql supports per-layer data sources (e.g., a CTE fed to a different DRAW clause). The two-phase approach loses intermediate tables at the DataSource boundary, so querychat rejects these queries.
+
+Both problems stem from the same root cause: querychat splits the query at the SQL/VISUALISE boundary and runs each half independently, rather than letting ggsql run the full pipeline against the real database.
+
+## Solution
+
+For `SQLAlchemySource` data sources, implement a `DataSourceReader` — a Python object that satisfies ggsql's reader protocol (`execute_sql()`, `register()`, `unregister()`) by routing SQL to the real database. Pass this reader to `ggsql.execute(query, reader)`, letting ggsql run the entire pipeline (parsing, CTEs, stat transforms, everything) against the real DB.
+
+Use [sqlglot](https://github.com/tobymao/sqlglot) to transpile ggsql's ANSI-generated SQL to the target database dialect. This gives broad database coverage (31 dialects) without waiting for ggsql to add each one.
+
+Fall back to the current two-phase approach when the bridge fails (e.g., temp table permission denied, unsupported dialect, transpilation error) or for non-SQLAlchemy data sources.
+
+## Data flow
+
+### Bridge path (SQLAlchemySource)
+
+```
+ggsql.execute(query, DataSourceReader)
+ │
+ ├─ CTE materialization
+ │ execute_sql("SELECT … FROM orders GROUP BY …")
+ │ → sqlglot transpiles generic → target dialect
+ │ → runs on real DB
+ │ → result registered as temp table on real DB
+ │
+ ├─ Global SQL
+ │ execute_sql("SELECT * FROM orders WHERE …")
+ │ → runs on real DB
+ │ → result registered as temp table on real DB
+ │
+ ├─ Schema queries
+ │ execute_sql("SELECT … LIMIT 0")
+ │ → runs on real DB against temp tables
+ │
+ ├─ Stat transforms (histograms, density, boxplot, etc.)
+ │ execute_sql("WITH … SELECT … binning SQL …")
+ │ → sqlglot transpiles generated ANSI SQL → target dialect
+ │ → runs on real DB against temp tables
+ │
+ └─ Final layer queries
+ execute_sql("SELECT …")
+ → runs on real DB, small result set returned
+```
+
+### Fallback path (current approach, all DataSource types)
+
+```
+validated.sql()
+ → DataSource.execute_query() on real DB
+ → full result pulled into Python memory
+ → registered in local DuckDB
+ → ggsql replays VISUALISE portion locally
+```
+
+## Components
+
+### `DataSourceReader`
+
+Python class implementing ggsql's reader protocol. Lives in `_datasource_reader.py` (imported by `_viz_ggsql.py`).
+
+- **Constructor**: takes a `sqlalchemy.Engine` and a sqlglot dialect string. Opens a single connection from the engine, held for the pipeline's duration.
+- **`execute_sql(sql)`**: transpiles from generic SQL to target dialect via `sqlglot.transpile(sql, read="", write=dialect)`, executes on the real DB via SQLAlchemy, returns a polars DataFrame.
+- **`register(name, df, replace)`**: creates a `TEMPORARY TABLE` on the real DB with column types derived from polars dtypes (generic SQL types, transpiled by sqlglot). Inserts rows in batches via SQLAlchemy. Tracks registered names for cleanup.
+- **`unregister(name)`**: drops the temp table on the real DB.
+- **Context manager**: `__exit__` drops all registered temp tables and closes the connection, ensuring cleanup even on error.
+
+### Dialect mapping
+
+```python
+SQLGLOT_DIALECTS = {
+ "postgresql": "postgres",
+ "snowflake": "snowflake",
+ "duckdb": "duckdb",
+ "sqlite": "sqlite",
+ "mysql": "mysql",
+ "mssql": "tsql",
+ "bigquery": "bigquery",
+ "redshift": "redshift",
+}
+```
+
+Maps `engine.dialect.name` to sqlglot dialect names. Unknown dialects skip the bridge and use the fallback.
+
+### Entry point
+
+```python
+def execute_ggsql(data_source, query, validated):
+ if isinstance(data_source, SQLAlchemySource):
+ dialect = SQLGLOT_DIALECTS.get(data_source._engine.dialect.name)
+ if dialect is not None:
+ try:
+ with DataSourceReader(data_source._engine, dialect) as reader:
+ return ggsql.execute(query, reader)
+ except Exception:
+ pass # fall through
+
+ # Fallback: current two-phase approach
+ return _execute_two_phase(data_source, validated)
+```
+
+### `_execute_two_phase`
+
+The current `execute_ggsql` body, renamed. Includes the existing regex-based `extract_visualise_table` and `has_layer_level_source` logic. Used for `DataFrameSource`, `PolarsLazySource`, `IbisSource`, and as the fallback for SQLAlchemy sources.
+
+## Dependencies
+
+- `sqlglot` added to the `viz` optional extra in `pyproject.toml`
+- No changes to ggsql required for the initial implementation
+
+## Scope boundaries
+
+- **SQLAlchemySource only** — IbisSource could follow later
+- **No ggsql changes required** — the `dialect` parameter contribution to `execute()` can come later as an optimization (skipping sqlglot when ggsql natively supports the dialect)
+- **No prompt changes** — the LLM already writes SQL for the correct `db_type`
+
+## Testing
+
+- Unit tests for `DataSourceReader`: mock SQLAlchemy connection, verify transpile + execute, register/unregister lifecycle, cleanup on error
+- Unit tests for sqlglot transpilation of ggsql's generated SQL patterns (recursive CTEs, NTILE percentiles, CREATE TEMPORARY TABLE) across key dialects
+- Integration test for fallback: verify bridge failure triggers two-phase approach
+- End-to-end with a test database connection if available
diff --git a/docs/plans/2026-04-17-datasource-reader-bridge.md b/docs/plans/2026-04-17-datasource-reader-bridge.md
new file mode 100644
index 000000000..4d42d8d61
--- /dev/null
+++ b/docs/plans/2026-04-17-datasource-reader-bridge.md
@@ -0,0 +1,825 @@
+# DataSourceReader Bridge Implementation Plan
+
+> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
+
+**Goal:** Implement a `DataSourceReader` that lets ggsql run its full pipeline against the real database (via SQLAlchemy), using sqlglot for dialect transpilation, with fallback to the current two-phase approach on failure.
+
+**Architecture:** `DataSourceReader` implements ggsql's reader protocol (`execute_sql`, `register`, `unregister`) by routing SQL through a held SQLAlchemy connection. sqlglot transpiles ggsql's ANSI-generated SQL to the target dialect. Temp tables are created on the real DB for ggsql's intermediate results. Falls back to the existing two-phase approach (renamed `execute_two_phase`) for non-SQLAlchemy sources or on bridge failure.
+
+**Tech Stack:** Python, sqlglot, SQLAlchemy, ggsql (PyReaderBridge), polars, pytest
+
+**Design doc:** `docs/plans/2026-04-17-datasource-reader-bridge-design.md`
+
+---
+
+### Task 1: Add sqlglot dependency
+
+**Files:**
+- Modify: `pyproject.toml:52`
+
+**Step 1: Add sqlglot to the viz extra**
+
+In `pyproject.toml`, change line 52 from:
+
+```toml
+viz = ["ggsql>=0.2.4", "altair>=6.0", "shinywidgets>=0.8.0", "vl-convert-python>=1.9.0"]
+```
+
+to:
+
+```toml
+viz = ["ggsql>=0.2.4", "altair>=6.0", "shinywidgets>=0.8.0", "vl-convert-python>=1.9.0", "sqlglot>=26.0"]
+```
+
+**Step 2: Install the updated dependencies**
+
+Run: `cd /Users/cpsievert/github/querychat && uv sync --extra viz`
+Expected: sqlglot installs successfully
+
+**Step 3: Verify import**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run python -c "import sqlglot; print(sqlglot.__version__)"`
+Expected: Version prints without error
+
+**Step 4: Commit**
+
+```bash
+git add pyproject.toml uv.lock
+git commit -m "feat: add sqlglot dependency for DataSourceReader bridge"
+```
+
+---
+
+### Task 2: Write dialect mapping and transpile helper (tests first)
+
+**Files:**
+- Create: `pkg-py/tests/test_datasource_reader.py`
+- Create (later): `pkg-py/src/querychat/_datasource_reader.py`
+
+**Step 1: Write tests for dialect mapping and transpilation**
+
+Create `pkg-py/tests/test_datasource_reader.py`:
+
+```python
+"""Tests for DataSourceReader bridge."""
+
+import pytest
+
+
+class TestDialectMapping:
+ """Tests for SQLGLOT_DIALECTS mapping."""
+
+ def test_known_dialects_present(self):
+ from querychat._datasource_reader import SQLGLOT_DIALECTS
+
+ assert SQLGLOT_DIALECTS["postgresql"] == "postgres"
+ assert SQLGLOT_DIALECTS["snowflake"] == "snowflake"
+ assert SQLGLOT_DIALECTS["duckdb"] == "duckdb"
+ assert SQLGLOT_DIALECTS["sqlite"] == "sqlite"
+ assert SQLGLOT_DIALECTS["mysql"] == "mysql"
+ assert SQLGLOT_DIALECTS["mssql"] == "tsql"
+
+ def test_unknown_dialect_not_present(self):
+ from querychat._datasource_reader import SQLGLOT_DIALECTS
+
+ assert "oracle" not in SQLGLOT_DIALECTS
+
+
+class TestTranspileSql:
+ """Tests for transpile_sql() helper."""
+
+ def test_identity_for_duckdb(self):
+ from querychat._datasource_reader import transpile_sql
+
+ sql = "SELECT x, y FROM t WHERE x > 1"
+ result = transpile_sql(sql, "duckdb")
+ assert "SELECT" in result
+ assert "FROM" in result
+
+ def test_transpiles_create_temp_table_to_snowflake(self):
+ from querychat._datasource_reader import transpile_sql
+
+ sql = "CREATE TEMPORARY TABLE __ggsql_cte_0 AS SELECT x FROM t"
+ result = transpile_sql(sql, "snowflake")
+ assert "TEMPORARY" in result.upper() or "TEMP" in result.upper()
+ assert "__ggsql_cte_0" in result
+
+ def test_transpiles_recursive_cte_to_postgres(self):
+ from querychat._datasource_reader import transpile_sql
+
+ sql = (
+ "WITH RECURSIVE series AS ("
+ "SELECT 0 AS n UNION ALL SELECT n + 1 FROM series WHERE n < 10"
+ ") SELECT n FROM series"
+ )
+ result = transpile_sql(sql, "postgres")
+ assert "RECURSIVE" in result.upper()
+
+ def test_transpiles_ntile_to_snowflake(self):
+ from querychat._datasource_reader import transpile_sql
+
+ sql = "SELECT NTILE(4) OVER (ORDER BY x) AS quartile FROM t"
+ result = transpile_sql(sql, "snowflake")
+ assert "NTILE" in result.upper()
+
+ def test_passthrough_on_empty_dialect(self):
+ """Empty string dialect means generic/ANSI — should pass through."""
+ from querychat._datasource_reader import transpile_sql
+
+ sql = "SELECT 1"
+ result = transpile_sql(sql, "")
+ assert result == "SELECT 1"
+```
+
+**Step 2: Run tests to verify they fail**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py -v`
+Expected: ImportError — `querychat._datasource_reader` does not exist
+
+**Step 3: Implement dialect mapping and transpile helper**
+
+Create `pkg-py/src/querychat/_datasource_reader.py`:
+
+```python
+"""DataSourceReader bridge: routes ggsql's reader protocol through a real database."""
+
+from __future__ import annotations
+
+import sqlglot
+
+SQLGLOT_DIALECTS: dict[str, str] = {
+ "postgresql": "postgres",
+ "snowflake": "snowflake",
+ "duckdb": "duckdb",
+ "sqlite": "sqlite",
+ "mysql": "mysql",
+ "mssql": "tsql",
+ "bigquery": "bigquery",
+ "redshift": "redshift",
+}
+
+
+def transpile_sql(sql: str, dialect: str) -> str:
+ """Transpile generic SQL to a target dialect using sqlglot."""
+ results = sqlglot.transpile(sql, read="", write=dialect)
+ return results[0]
+```
+
+**Step 4: Run tests to verify they pass**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py -v`
+Expected: All pass
+
+**Step 5: Commit**
+
+```bash
+git add pkg-py/tests/test_datasource_reader.py pkg-py/src/querychat/_datasource_reader.py
+git commit -m "feat: add dialect mapping and transpile_sql helper"
+```
+
+---
+
+### Task 3: Implement DataSourceReader class (tests first)
+
+This is the core class. It implements ggsql's reader protocol by executing SQL on the real database via SQLAlchemy. Tests use a real SQLite database (in-memory) to verify end-to-end behavior.
+
+**Files:**
+- Modify: `pkg-py/tests/test_datasource_reader.py`
+- Modify: `pkg-py/src/querychat/_datasource_reader.py`
+
+**Step 1: Write tests for DataSourceReader lifecycle**
+
+Append to `pkg-py/tests/test_datasource_reader.py`:
+
+```python
+import polars as pl
+from sqlalchemy import create_engine, text
+
+
+@pytest.fixture
+def sqlite_engine():
+ """Create an in-memory SQLite database with test data."""
+ engine = create_engine("sqlite://")
+ with engine.connect() as conn:
+ conn.execute(text("CREATE TABLE test_data (x INTEGER, y INTEGER, label TEXT)"))
+ conn.execute(
+ text("INSERT INTO test_data VALUES (1, 10, 'a'), (2, 20, 'b'), (3, 30, 'a')")
+ )
+ conn.commit()
+ return engine
+
+
+class TestDataSourceReader:
+ """Tests for DataSourceReader against a real SQLite database."""
+
+ def test_execute_sql_returns_polars(self, sqlite_engine):
+ from querychat._datasource_reader import DataSourceReader
+
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ df = reader.execute_sql("SELECT * FROM test_data")
+ assert isinstance(df, pl.DataFrame)
+ assert len(df) == 3
+ assert set(df.columns) == {"x", "y", "label"}
+
+ def test_execute_sql_with_filter(self, sqlite_engine):
+ from querychat._datasource_reader import DataSourceReader
+
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ df = reader.execute_sql("SELECT * FROM test_data WHERE x > 1")
+ assert len(df) == 2
+
+ def test_register_creates_temp_table(self, sqlite_engine):
+ from querychat._datasource_reader import DataSourceReader
+
+ df = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]})
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ reader.register("my_temp", df, True)
+ result = reader.execute_sql("SELECT * FROM my_temp")
+ assert len(result) == 2
+ assert set(result.columns) == {"a", "b"}
+
+ def test_unregister_drops_temp_table(self, sqlite_engine):
+ from querychat._datasource_reader import DataSourceReader
+
+ df = pl.DataFrame({"a": [1]})
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ reader.register("drop_me", df, True)
+ reader.unregister("drop_me")
+ with pytest.raises(Exception, match="drop_me"):
+ reader.execute_sql("SELECT * FROM drop_me")
+
+ def test_context_manager_cleans_up_temp_tables(self, sqlite_engine):
+ from querychat._datasource_reader import DataSourceReader
+
+ df = pl.DataFrame({"a": [1]})
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ reader.register("cleanup_test", df, True)
+
+ # After exiting context, temp table should be gone.
+ # SQLite temp tables are connection-scoped, so they vanish
+ # when the connection closes. Verify by opening a new connection.
+ with sqlite_engine.connect() as conn:
+ result = conn.execute(
+ text("SELECT name FROM sqlite_temp_master WHERE name = 'cleanup_test'")
+ )
+ assert result.fetchone() is None
+
+ def test_register_replace_overwrites(self, sqlite_engine):
+ from querychat._datasource_reader import DataSourceReader
+
+ df1 = pl.DataFrame({"a": [1, 2]})
+ df2 = pl.DataFrame({"a": [10, 20, 30]})
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ reader.register("replace_me", df1, True)
+ reader.register("replace_me", df2, True)
+ result = reader.execute_sql("SELECT * FROM replace_me")
+ assert len(result) == 3
+
+ def test_execute_sql_transpiles(self, sqlite_engine):
+ """Verify that generated SQL gets transpiled to the target dialect."""
+ from querychat._datasource_reader import DataSourceReader
+
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ # This is valid generic SQL; sqlglot should pass it through for SQLite
+ df = reader.execute_sql("SELECT x, y FROM test_data ORDER BY x LIMIT 2")
+ assert len(df) == 2
+```
+
+**Step 2: Run tests to verify they fail**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py::TestDataSourceReader -v`
+Expected: ImportError — `DataSourceReader` not found
+
+**Step 3: Implement DataSourceReader**
+
+Add to `pkg-py/src/querychat/_datasource_reader.py` (below the existing code):
+
+```python
+from typing import TYPE_CHECKING
+
+import polars as pl
+from sqlalchemy import text
+
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Connection, Engine
+
+
+class DataSourceReader:
+ """
+ ggsql reader protocol implementation that routes SQL through a real database.
+
+ Implements execute_sql(), register(), and unregister() as expected by
+ ggsql's PyReaderBridge. Uses sqlglot to transpile ggsql's ANSI-generated
+ SQL to the target database dialect.
+ """
+
+ def __init__(self, engine: Engine, dialect: str):
+ self._engine = engine
+ self._dialect = dialect
+ self._conn: Connection | None = None
+ self._registered: list[str] = []
+
+ def __enter__(self):
+ self._conn = self._engine.connect()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self._conn is not None:
+ try:
+ for name in self._registered:
+ try:
+ self._conn.execute(text(f"DROP TABLE IF EXISTS {name}"))
+ except Exception:
+ pass
+ self._conn.commit()
+ finally:
+ self._conn.close()
+ self._conn = None
+ self._registered.clear()
+ return False
+
+ def execute_sql(self, sql: str) -> pl.DataFrame:
+ assert self._conn is not None, "DataSourceReader must be used as a context manager"
+ transpiled = transpile_sql(sql, self._dialect)
+ result = self._conn.execute(text(transpiled))
+ rows = result.fetchall()
+ columns = list(result.keys())
+ if not rows:
+ return pl.DataFrame(schema={col: pl.Utf8 for col in columns})
+ data = {col: [row[i] for row in rows] for i, col in enumerate(columns)}
+ return pl.DataFrame(data)
+
+ def register(self, name: str, df: pl.DataFrame, replace: bool = True) -> None:
+ assert self._conn is not None, "DataSourceReader must be used as a context manager"
+ if replace:
+ self._conn.execute(text(f"DROP TABLE IF EXISTS {name}"))
+ if name in self._registered:
+ self._registered.remove(name)
+
+ col_defs = ", ".join(
+ f"{col} {_polars_to_sql_type(dtype)}" for col, dtype in zip(df.columns, df.dtypes)
+ )
+ create_sql = f"CREATE TEMPORARY TABLE {name} ({col_defs})"
+ transpiled_create = transpile_sql(create_sql, self._dialect)
+ self._conn.execute(text(transpiled_create))
+ self._registered.append(name)
+
+ if len(df) > 0:
+ placeholders = ", ".join(f":{col}" for col in df.columns)
+ insert_sql = f"INSERT INTO {name} VALUES ({placeholders})"
+ rows = df.to_dicts()
+ self._conn.execute(text(insert_sql), rows)
+
+ self._conn.commit()
+
+ def unregister(self, name: str) -> None:
+ assert self._conn is not None, "DataSourceReader must be used as a context manager"
+ self._conn.execute(text(f"DROP TABLE IF EXISTS {name}"))
+ self._conn.commit()
+ if name in self._registered:
+ self._registered.remove(name)
+
+
+def _polars_to_sql_type(dtype: pl.DataType) -> str:
+ """Map polars dtypes to generic SQL types for CREATE TABLE."""
+ if dtype.is_integer():
+ return "INTEGER"
+ if dtype.is_float():
+ return "REAL"
+ if dtype == pl.Boolean:
+ return "BOOLEAN"
+ if dtype == pl.Date:
+ return "DATE"
+ if dtype == pl.Datetime or dtype == pl.Duration:
+ return "TIMESTAMP"
+ return "TEXT"
+```
+
+**Step 4: Run tests to verify they pass**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py -v`
+Expected: All pass
+
+**Step 5: Commit**
+
+```bash
+git add pkg-py/tests/test_datasource_reader.py pkg-py/src/querychat/_datasource_reader.py
+git commit -m "feat: implement DataSourceReader with temp table lifecycle"
+```
+
+---
+
+### Task 4: Integrate with ggsql — end-to-end test with SQLite
+
+Verify that `DataSourceReader` works with `ggsql.execute(query, reader)` against a real SQLite database.
+
+**Files:**
+- Modify: `pkg-py/tests/test_datasource_reader.py`
+
+**Step 1: Write end-to-end test**
+
+Append to `pkg-py/tests/test_datasource_reader.py`:
+
+```python
+@pytest.mark.ggsql
+class TestDataSourceReaderWithGgsql:
+ """End-to-end tests: DataSourceReader + ggsql.execute()."""
+
+ def test_simple_scatter(self, sqlite_engine):
+ import ggsql
+ from querychat._datasource_reader import DataSourceReader
+
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ spec = ggsql.execute(
+ "SELECT x, y FROM test_data VISUALISE x, y DRAW point",
+ reader,
+ )
+ assert spec.metadata()["rows"] == 3
+ assert "VISUALISE" in spec.visual()
+
+ def test_with_filter(self, sqlite_engine):
+ import ggsql
+ from querychat._datasource_reader import DataSourceReader
+
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ spec = ggsql.execute(
+ "SELECT x, y FROM test_data WHERE x > 1 VISUALISE x, y DRAW point",
+ reader,
+ )
+ assert spec.metadata()["rows"] == 2
+
+ def test_form_b_visualise_from(self, sqlite_engine):
+ import ggsql
+ from querychat._datasource_reader import DataSourceReader
+
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ spec = ggsql.execute(
+ "VISUALISE x, y FROM test_data DRAW point",
+ reader,
+ )
+ assert spec.metadata()["rows"] == 3
+
+ def test_with_aggregation(self, sqlite_engine):
+ import ggsql
+ from querychat._datasource_reader import DataSourceReader
+
+ with DataSourceReader(sqlite_engine, "sqlite") as reader:
+ spec = ggsql.execute(
+ "SELECT label, SUM(y) AS total FROM test_data GROUP BY label "
+ "VISUALISE label AS x, total AS y DRAW bar",
+ reader,
+ )
+ assert spec.metadata()["rows"] == 2
+```
+
+**Step 2: Run tests**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py::TestDataSourceReaderWithGgsql -v`
+Expected: All pass (if the reader protocol is correctly implemented). If any fail, debug and fix.
+
+**Step 3: Commit**
+
+```bash
+git add pkg-py/tests/test_datasource_reader.py
+git commit -m "test: end-to-end DataSourceReader with ggsql.execute()"
+```
+
+---
+
+### Task 5: Refactor execute_ggsql — rename current body, add bridge path
+
+**Files:**
+- Modify: `pkg-py/src/querychat/_viz_ggsql.py`
+- Modify: `pkg-py/src/querychat/_viz_tools.py:182`
+- Modify: `pkg-py/tests/test_ggsql.py`
+
+**Step 1: Write test for bridge+fallback behavior**
+
+Add to `pkg-py/tests/test_datasource_reader.py`:
+
+```python
+import narwhals.stable.v1 as nw
+
+
+class TestExecuteGgsqlBridge:
+ """Tests for the updated execute_ggsql entry point with bridge+fallback."""
+
+ @pytest.mark.ggsql
+ def test_sqlalchemy_source_uses_bridge(self, sqlite_engine):
+ """SQLAlchemySource with known dialect should use the bridge path."""
+ import ggsql
+ from querychat._datasource import SQLAlchemySource
+ from querychat._viz_ggsql import execute_ggsql
+
+ # Create a table that SQLAlchemySource can find
+ with sqlite_engine.connect() as conn:
+ conn.execute(text("CREATE TABLE IF NOT EXISTS bridge_data (x INTEGER, y INTEGER)"))
+ conn.execute(text("INSERT INTO bridge_data VALUES (1, 10), (2, 20), (3, 30)"))
+ conn.commit()
+
+ ds = SQLAlchemySource(sqlite_engine, "bridge_data")
+ query = "SELECT x, y FROM bridge_data VISUALISE x, y DRAW point"
+ validated = ggsql.validate(query)
+ spec = execute_ggsql(ds, query, validated)
+ assert spec.metadata()["rows"] == 3
+
+ @pytest.mark.ggsql
+ def test_dataframe_source_uses_fallback(self):
+ """DataFrameSource should always use the fallback path."""
+ import ggsql
+ from querychat._datasource import DataFrameSource
+ from querychat._viz_ggsql import execute_ggsql
+
+ nw_df = nw.from_native(pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}))
+ ds = DataFrameSource(nw_df, "test_data")
+ query = "SELECT * FROM test_data VISUALISE x, y DRAW point"
+ validated = ggsql.validate(query)
+ spec = execute_ggsql(ds, query, validated)
+ assert spec.metadata()["rows"] == 3
+```
+
+**Step 2: Run new tests to verify they fail**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py::TestExecuteGgsqlBridge -v`
+Expected: TypeError — `execute_ggsql()` doesn't accept 3 arguments yet
+
+**Step 3: Update execute_ggsql signature and add bridge logic**
+
+Modify `pkg-py/src/querychat/_viz_ggsql.py`. The full updated file:
+
+```python
+"""
+Helpers for executing ggsql queries in querychat.
+
+Architecture overview
+---------------------
+Querychat executes ggsql queries through two possible paths:
+
+1. **Bridge path** (SQLAlchemySource with known dialect) — A
+ ``DataSourceReader`` implements ggsql's reader protocol, routing all SQL
+ through the real database. ggsql runs its full pipeline (CTEs, stat
+ transforms, layer queries) against the real DB. sqlglot transpiles
+ ggsql's ANSI-generated SQL to the target dialect. This path supports
+ multi-source layers and avoids pulling large result sets into memory.
+
+2. **Fallback path** (all other DataSource types, or bridge failure) — The
+ SQL portion (before VISUALISE) runs on the real database via
+ ``DataSource.execute_query()``, then the VISUALISE portion replays
+ locally against the SQL result using ``ggsql.DuckDBReader``.
+
+The fallback path requires reconstructing a valid ggsql query from the
+split ``sql()`` and ``visual()`` parts. See ``execute_two_phase()`` for
+details on the two VISUALISE forms (Form A and Form B).
+
+Limitation of fallback path: layer-specific sources
+----------------------------------------------------
+ggsql supports per-layer data sources (``DRAW line MAPPING … FROM cte``),
+but the fallback path can't support them because the SQL result is a single
+DataFrame — CTEs don't survive the DataSource boundary. The bridge path
+handles this correctly.
+"""
+
+from __future__ import annotations
+
+import logging
+import re
+from typing import TYPE_CHECKING
+
+from ._utils import to_polars
+
+if TYPE_CHECKING:
+ import ggsql
+
+ from ._datasource import DataSource
+
+logger = logging.getLogger(__name__)
+
+
+def execute_ggsql(
+ data_source: DataSource,
+ query: str,
+ validated: ggsql.Validated,
+) -> ggsql.Spec:
+ """
+ Execute a ggsql query, choosing the bridge or fallback path.
+
+ Parameters
+ ----------
+ data_source
+ The querychat DataSource to execute against.
+ query
+ The original ggsql query string (needed for the bridge path).
+ validated
+ A pre-validated ggsql query (from ``ggsql.validate()``).
+
+ Returns
+ -------
+ ggsql.Spec
+ The writer-independent plot specification.
+
+ """
+ from ._datasource import SQLAlchemySource
+ from ._datasource_reader import SQLGLOT_DIALECTS, DataSourceReader
+
+ if isinstance(data_source, SQLAlchemySource):
+ dialect = SQLGLOT_DIALECTS.get(data_source._engine.dialect.name)
+ if dialect is not None:
+ try:
+ with DataSourceReader(data_source._engine, dialect) as reader:
+ import ggsql as _ggsql
+
+ return _ggsql.execute(query, reader)
+ except Exception:
+ logger.debug(
+ "DataSourceReader bridge failed, falling back to two-phase",
+ exc_info=True,
+ )
+
+ return execute_two_phase(data_source, validated)
+
+
+def execute_two_phase(
+ data_source: DataSource,
+ validated: ggsql.Validated,
+) -> ggsql.Spec:
+ """
+ Execute a ggsql query using the two-phase approach.
+
+ Phase 1: execute SQL on the real database.
+ Phase 2: replay the VISUALISE portion locally in DuckDB.
+
+ This is the fallback for non-SQLAlchemy sources or when the bridge fails.
+ """
+ from ggsql import DuckDBReader
+
+ visual = validated.visual()
+ if has_layer_level_source(visual):
+ raise ValueError(
+ "Layer-specific sources are not currently supported in querychat visual "
+ "queries. Rewrite the query so that all layers come from the final SQL "
+ "result."
+ )
+
+ pl_df = to_polars(data_source.execute_query(validated.sql()))
+ pl_df.columns = [c.lower() for c in pl_df.columns]
+
+ reader = DuckDBReader("duckdb://memory")
+ table = extract_visualise_table(visual)
+
+ if table is not None:
+ name = table[1:-1] if table.startswith('"') and table.endswith('"') else table
+ reader.register(name, pl_df)
+ return reader.execute(visual)
+ else:
+ reader.register("_data", pl_df)
+ return reader.execute(f"SELECT * FROM _data {visual}")
+
+
+def extract_visualise_table(visual: str) -> str | None:
+ """
+ Extract the table name from ``VISUALISE … FROM <table>`` if present.
+
+ Only looks at the portion before the first DRAW clause, since FROM after
+ DRAW belongs to layer-level MAPPING (a different concern).
+ """
+ draw_pos = re.search(r"\bDRAW\b", visual, re.IGNORECASE)
+ vis_clause = visual[: draw_pos.start()] if draw_pos else visual
+ m = re.search(r'\bFROM\s+("[^"]+?"|\S+)', vis_clause, re.IGNORECASE)
+ return m.group(1) if m else None
+
+
+def has_layer_level_source(visual: str) -> bool:
+ """
+ Return ``True`` when a DRAW clause defines its own ``FROM <source>``.
+ """
+ clauses = re.split(
+ r"(?=\b(?:DRAW|SCALE|PROJECT|FACET|PLACE|LABEL|THEME)\b)",
+ visual,
+ flags=re.IGNORECASE,
+ )
+ for clause in clauses:
+ if not re.match(r"^\s*DRAW\b", clause, re.IGNORECASE):
+ continue
+ if re.search(
+ r'\bMAPPING\b[\s\S]*?\bFROM\s+("[^"]+?"|\S+)',
+ clause,
+ re.IGNORECASE,
+ ):
+ return True
+ return False
+```
+
+**Step 4: Update the caller in `_viz_tools.py`**
+
+In `pkg-py/src/querychat/_viz_tools.py`, change line 182 from:
+
+```python
+ spec = execute_ggsql(data_source, validated)
+```
+
+to:
+
+```python
+ spec = execute_ggsql(data_source, ggsql, validated)
+```
+
+Note: `ggsql` here is the local parameter name (the query string) from line 151, not the module. The import at the top of the function (`from ggsql import VegaLiteWriter, validate`, line 148) binds only `VegaLiteWriter` and `validate`; it never binds the name `ggsql`, so it does not clash with the parameter.
+
+The parameter `ggsql: str` does shadow the module name within `visualize_query`, but `execute_ggsql` is imported at the module level, not from the `ggsql` package, so the call `execute_ggsql(data_source, ggsql, validated)` correctly passes the string parameter. This works.
+
+**Step 5: Update existing tests in `test_ggsql.py`**
+
+The existing `TestExecuteGgsql` tests call `execute_ggsql(ds, ggsql.validate(query))` with 2 args. Update them to pass 3 args. In `pkg-py/tests/test_ggsql.py`, update each call in `TestExecuteGgsql`:
+
+Change every occurrence of:
+```python
+ spec = execute_ggsql(ds, ggsql.validate(query))
+```
+to:
+```python
+ spec = execute_ggsql(ds, query, ggsql.validate(query))
+```
+
+Also update the layer-level source test (line 188):
+```python
+ execute_ggsql(ds, ggsql.validate(query))
+```
+to:
+```python
+ execute_ggsql(ds, query, ggsql.validate(query))
+```
+
+**Step 6: Run all tests**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py pkg-py/tests/test_ggsql.py -v`
+Expected: All pass
+
+**Step 7: Run type checker and linter**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run ruff check --fix pkg-py --config pyproject.toml && make py-check-types`
+Expected: No errors (fix any that arise)
+
+**Step 8: Commit**
+
+```bash
+git add pkg-py/src/querychat/_viz_ggsql.py pkg-py/src/querychat/_viz_tools.py pkg-py/tests/test_ggsql.py pkg-py/tests/test_datasource_reader.py
+git commit -m "feat: integrate DataSourceReader bridge into execute_ggsql"
+```
+
+---
+
+### Task 6: Run full test suite and fix any issues
+
+**Files:**
+- Potentially any file touched above
+
+**Step 1: Run full Python checks**
+
+Run: `cd /Users/cpsievert/github/querychat && uv run ruff check --fix pkg-py --config pyproject.toml`
+Expected: Clean
+
+**Step 2: Run type checker**
+
+Run: `cd /Users/cpsievert/github/querychat && make py-check-types`
+Expected: Clean
+
+**Step 3: Run full test suite**
+
+Run: `cd /Users/cpsievert/github/querychat && make py-check-tests`
+Expected: All pass
+
+**Step 4: Fix any failures**
+
+If any tests fail, debug and fix. Common issues:
+- sqlglot transpilation producing unexpected SQL for a specific dialect
+- Empty DataFrame handling in `execute_sql` (no rows returned)
+- Polars dtype mapping edge cases in `_polars_to_sql_type`
+
+**Step 5: Commit any fixes**
+
+```bash
+git add -u
+git commit -m "fix: address issues from full test suite run"
+```
+
+---
+
+### Task 7: Update module docstring and clean up
+
+**Files:**
+- Modify: `pkg-py/src/querychat/_datasource_reader.py`
+
+**Step 1: Ensure the module docstring is accurate**
+
+The module docstring added in Task 2 (when `_datasource_reader.py` was created) should already be correct. Verify `_datasource_reader.py` has a clear module-level docstring explaining the bridge's purpose.
+
+**Step 2: Verify no stale docstrings remain**
+
+Check that the old `execute_ggsql` docstring in `_viz_ggsql.py` has been updated to reflect the new 3-arg signature and bridge behavior (done in Task 5).
+
+**Step 3: Final commit if any cleanup was needed**
+
+```bash
+git add -u
+git commit -m "docs: clean up DataSourceReader module docstrings"
+```
diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md
index a02fe5f68..20785d727 100644
--- a/pkg-py/CHANGELOG.md
+++ b/pkg-py/CHANGELOG.md
@@ -9,16 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### New features
+* Added a `"visualize_query"` tool that lets the LLM create inline Altair charts from natural language requests using [ggsql](https://github.com/posit-dev/ggsql) — a SQL extension for declarative data visualization. Include it via `tools=("query", "visualize_query")` (or alongside `"update"`). Charts render inline in the chat with fullscreen support, a "Show Query" toggle, and Save as PNG/SVG. Install the optional dependencies with `pip install querychat[viz]`. (#219)
+
+* The `querychat_query` tool now accepts an optional `collapsed` parameter. When `collapsed=True`, the result card starts collapsed so preparatory or exploratory queries don't clutter the conversation. The LLM is guided to use this automatically when running queries before a visualization.
+
+* Added support for Snowflake Semantic Views. When connected to Snowflake (via SQLAlchemy or Ibis), querychat automatically discovers available Semantic Views and includes their definitions in the system prompt. This helps the LLM generate correct queries using the `SEMANTIC_VIEW()` table function with certified business metrics and dimensions. (#200)
+
* `QueryChat()` now supports deferred chat client initialization. Pass `client=` to `server()` to provide a session-scoped chat client, enabling use cases where API credentials are only available at session time (e.g., Posit Connect managed OAuth tokens). When no `client` is specified anywhere, querychat resolves a sensible default from the `QUERYCHAT_CLIENT` environment variable (or `"openai"`). (#205)
### Improvements
* When a custom `prompt_template` is provided that doesn't contain Mustache references to `{{schema}}`, the expensive `get_schema()` call is now skipped entirely. This allows users with large databases to avoid slow startup by providing their own prompt that includes schema information inline (or omits it). (#208)
-### New features
-
-* Added support for Snowflake Semantic Views. When connected to Snowflake (via SQLAlchemy or Ibis), querychat automatically discovers available Semantic Views and includes their definitions in the system prompt. This helps the LLM generate correct queries using the `SEMANTIC_VIEW()` table function with certified business metrics and dimensions. (#200)
-
## [0.5.1] - 2026-01-23
### New features
diff --git a/pkg-py/docs/_quarto.yml b/pkg-py/docs/_quarto.yml
index df2576e49..7574fb787 100644
--- a/pkg-py/docs/_quarto.yml
+++ b/pkg-py/docs/_quarto.yml
@@ -50,6 +50,7 @@ website:
- models.qmd
- data-sources.qmd
- context.qmd
+ - visualize.qmd
- section: "Build custom apps"
contents:
- build-intro.qmd
@@ -114,6 +115,8 @@ quartodoc:
signature_name: short
- name: tools.tool_reset_dashboard
signature_name: short
+ - name: tools.tool_visualize_query
+ signature_name: short
filters:
- "interlinks"
diff --git a/pkg-py/docs/build.qmd b/pkg-py/docs/build.qmd
index 009f6cfd0..f4ff68abd 100644
--- a/pkg-py/docs/build.qmd
+++ b/pkg-py/docs/build.qmd
@@ -31,6 +31,14 @@ from querychat.data import titanic
qc = QueryChat(titanic(), "titanic")
```
+::: {.callout-tip}
+### Visualization support
+
+querychat supports an optional visualization tool that lets the LLM create inline charts.
+Enable it by including `"visualize_query"` in the `tools` parameter.
+See [Visualizations](visualize.qmd) for details.
+:::
+
::: {.callout-note collapse="true"}
## Quick start with `.app()`
diff --git a/pkg-py/docs/images/viz-bar-chart.png b/pkg-py/docs/images/viz-bar-chart.png
new file mode 100644
index 000000000..0a7033651
Binary files /dev/null and b/pkg-py/docs/images/viz-bar-chart.png differ
diff --git a/pkg-py/docs/images/viz-fullscreen.png b/pkg-py/docs/images/viz-fullscreen.png
new file mode 100644
index 000000000..fbefca3fb
Binary files /dev/null and b/pkg-py/docs/images/viz-fullscreen.png differ
diff --git a/pkg-py/docs/images/viz-scatter.png b/pkg-py/docs/images/viz-scatter.png
new file mode 100644
index 000000000..db25bfe2a
Binary files /dev/null and b/pkg-py/docs/images/viz-scatter.png differ
diff --git a/pkg-py/docs/images/viz-show-query.png b/pkg-py/docs/images/viz-show-query.png
new file mode 100644
index 000000000..fc9ae6384
Binary files /dev/null and b/pkg-py/docs/images/viz-show-query.png differ
diff --git a/pkg-py/docs/index.qmd b/pkg-py/docs/index.qmd
index 88b0da5c5..8d119ae3b 100644
--- a/pkg-py/docs/index.qmd
+++ b/pkg-py/docs/index.qmd
@@ -75,6 +75,11 @@ querychat can also handle more general questions about the data that require cal
{fig-alt="Screenshot of the querychat's app with a summary statistic inlined in the chat." class="lightbox shadow rounded mb-3"}
+querychat can also create visualizations, powered by [ggsql](https://ggsql.org/) and [Altair](https://altair-viz.github.io/).
+With the [visualization tool](visualize.qmd) enabled, ask for a chart and it appears inline in the conversation:
+
+{fig-alt="Screenshot of querychat with an inline bar chart showing survival rate by passenger class." class="lightbox shadow rounded mb-3"}
+
## Web frameworks
While the examples above use [Shiny](https://shiny.posit.co/py/), querychat also supports [Streamlit](https://streamlit.io/), [Gradio](https://gradio.app/), and [Dash](https://dash.plotly.com/). Each framework has its own `QueryChat` class under the relevant sub-module, but the methods and properties are mostly consistent across all of them.
diff --git a/pkg-py/docs/tools.qmd b/pkg-py/docs/tools.qmd
index e438e1bde..09b46be9d 100644
--- a/pkg-py/docs/tools.qmd
+++ b/pkg-py/docs/tools.qmd
@@ -6,7 +6,7 @@ querychat combines [tool calling](https://posit-dev.github.io/chatlas/get-starte
One important thing to understand generally about querychat's tools is they are Python functions, and that execution happens on _your machine_, not on the LLM provider's side. In other words, the SQL queries generated by the LLM are executed locally in the Python process running the app.
-querychat provides the LLM access two tool groups:
+querychat provides the LLM access to three tool groups:
1. **Data updating** - Filter and sort data (without sending results to the LLM).
2. **Data analysis** - Calculate summaries and return results for interpretation by the LLM.
+3. **Data visualization** - Create charts from ggsql queries and display them inline in the chat.
@@ -52,6 +52,40 @@ app = qc.app()
{fig-alt="Screenshot of the querychat's app with a summary statistic inlined in the chat." class="lightbox shadow rounded mb-3"}
+## Data visualization
+
+When a user asks for a chart or visualization, the LLM generates a [ggsql](https://ggsql.org/) query — standard SQL extended with a `VISUALISE` clause — and requests a call to the `visualize_query` tool.
+This tool:
+
+1. Executes the SQL portion of the query
+2. Renders the `VISUALISE` clause as an Altair chart
+3. Displays the chart inline in the chat
+
+Unlike the data updating tools, visualization queries don't affect the dashboard filter.
+They query the full dataset independently, and each call produces a new inline chart message in the chat.
+
+The inline chart includes controls for fullscreen viewing, saving as PNG/SVG, and a "Show Query" toggle that reveals the underlying ggsql code.
+
+To use the visualization tool, first install the `viz` extras:
+
+```bash
+pip install "querychat[viz]"
+```
+
+Then include `"visualize_query"` in the `tools` parameter (it is not enabled by default):
+
+```{.python filename="viz-app.py"}
+from querychat import QueryChat
+from querychat.data import titanic
+
+qc = QueryChat(titanic(), "titanic", tools=("query", "update", "visualize_query"))
+app = qc.app()
+```
+
+{fig-alt="Screenshot of querychat with an inline scatter plot." class="lightbox shadow rounded mb-3"}
+
+See [Visualizations](visualize.qmd) for more details.
+
## View the source
If you'd like to better understand how the tools work and how the LLM is prompted to use them, check out the following resources:
@@ -65,3 +99,4 @@ If you'd like to better understand how the tools work and how the LLM is prompte
- [`prompts/tool-update-dashboard.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-update-dashboard.md)
- [`prompts/tool-reset-dashboard.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-reset-dashboard.md)
- [`prompts/tool-query.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-query.md)
+- [`prompts/tool-visualize-query.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-visualize-query.md)
diff --git a/pkg-py/docs/visualize.qmd b/pkg-py/docs/visualize.qmd
new file mode 100644
index 000000000..8074869b2
--- /dev/null
+++ b/pkg-py/docs/visualize.qmd
@@ -0,0 +1,104 @@
+---
+title: Visualizations
+lightbox: true
+---
+
+querychat can create charts inline in the chat.
+When you ask a question that benefits from a visualization, the LLM writes a query using [ggsql](https://ggsql.org/) — a SQL-like visualization grammar — and renders an [Altair](https://altair-viz.github.io/) chart directly in the conversation.
+
+## Getting started
+
+Visualization requires two steps:
+
+1. **Install the `viz` extras:**
+
+ ```bash
+ pip install "querychat[viz]"
+ ```
+
+2. **Include `"visualize_query"` in the `tools` parameter:**
+
+ ```{.python filename="app.py"}
+ from querychat import QueryChat
+ from querychat.data import titanic
+
+ qc = QueryChat(titanic(), "titanic", tools=("query", "update", "visualize_query"))
+ app = qc.app()
+ ```
+
+Ask something like "Show me survival rate by passenger class as a bar chart" and querychat will generate and display the chart inline:
+
+{fig-alt="Bar chart showing survival rate by passenger class." class="lightbox shadow rounded mb-3"}
+
+## Choosing tools
+
+The `tools` parameter controls which capabilities the LLM has access to.
+By default, only `"query"` and `"update"` are enabled — visualization must be opted into explicitly.
+
+To enable only query and visualization (no dashboard filtering):
+
+```{.python}
+qc = QueryChat(titanic(), "titanic", tools=("query", "visualize_query"))
+```
+
+See [Tools](tools.qmd) for a full reference on available tools and what each one does.
+
+## Custom apps
+
+The example below shows a minimal custom Shiny app using only the `"query"` and `"visualize_query"` tools.
+It omits `"update"` to focus entirely on data analysis and visualization rather than dashboard filtering:
+
+```{.python filename="app.py"}
+{{< include /../examples/10-viz-app.py >}}
+```
+
+## What you can ask for
+
+querychat can generate a wide range of chart types.
+Some example prompts:
+
+- "Show me a bar chart of survival rate by passenger class"
+- "Scatter plot of age vs fare, colored by survival"
+- "Line chart of average fare over time"
+- "Histogram of passenger ages"
+- "Facet survival rate by class and sex"
+
+The LLM chooses an appropriate chart type based on your question, but you can always be specific.
+If you ask for a bar chart, you'll get a bar chart.
+
+{fig-alt="Scatter plot of age vs fare colored by survival status." class="lightbox shadow rounded mb-3"}
+
+::: {.callout-tip}
+If you don't like the chart, ask the LLM to adjust it — for example, "make the dots bigger" or "use a log scale on the y-axis".
+:::
+
+## Chart controls
+
+Each chart has controls in its footer:
+
+**Fullscreen** — Click the expand icon to view the chart in fullscreen mode.
+
+{fig-alt="A chart displayed in fullscreen mode." class="lightbox shadow rounded mb-3"}
+
+**Save** — Download the chart as a PNG or SVG file.
+
+**Show Query** — Expand the footer to see the ggsql query used to generate the chart.
+
+{fig-alt="A chart with the Show Query footer expanded, showing the ggsql query." class="lightbox shadow rounded mb-3"}
+
+## How it works
+
+1. **The LLM generates a ggsql query** — a SQL-like grammar that describes both data transformation and visual encoding in a single statement.
+2. **The SQL is executed** — querychat runs the data portion of the query against your data source locally.
+3. **The VISUALISE clause is rendered** — the result is passed to Altair, which produces a Vega-Lite chart specification.
+4. **The chart appears inline** — the chart is streamed back into the conversation as an interactive widget.
+
+Note that visualization queries are independent of any active dashboard filter set by the `update` tool.
+They always run against the full dataset.
+
+Learn more about the ggsql grammar at [ggsql.org](https://ggsql.org/).
+
+## See also
+
+- [Tools](tools.qmd) — Understand what querychat can do under the hood
+- [Provide context](context.qmd) — Help the LLM understand your data better
diff --git a/pkg-py/examples/10-viz-app.py b/pkg-py/examples/10-viz-app.py
new file mode 100644
index 000000000..5a3af3356
--- /dev/null
+++ b/pkg-py/examples/10-viz-app.py
@@ -0,0 +1,26 @@
+from querychat import QueryChat
+from querychat.data import titanic
+
+# Quick-start example restricted to the "query" and "visualize_query" tools.
+# Omits "update" tool — this demo focuses on query + visualization only
+qc = QueryChat(
+    titanic(),
+    "titanic",
+    tools=("query", "visualize_query"),
+)
+
+# Build the ready-made Shiny app (chat sidebar plus data view) from the
+# QueryChat instance.
+app = qc.app()
\ No newline at end of file
diff --git a/pkg-py/src/querychat/_datasource_reader.py b/pkg-py/src/querychat/_datasource_reader.py
new file mode 100644
index 000000000..4ff050fa8
--- /dev/null
+++ b/pkg-py/src/querychat/_datasource_reader.py
@@ -0,0 +1,160 @@
+"""DataSourceReader bridge: routes ggsql's reader protocol through a real database."""
+
+from __future__ import annotations
+
+import contextlib
+import logging
+from typing import TYPE_CHECKING
+
+import polars as pl
+import sqlglot
+from sqlalchemy import text
+
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Connection, Engine
+
+logger = logging.getLogger(__name__)
+
+# Maps SQLAlchemy dialect names (the value of ``engine.dialect.name``) to
+# sqlglot dialect identifiers. Entries can be added or overridden at runtime
+# with ``register_sqlglot_dialect()``.
+SQLGLOT_DIALECTS: dict[str, str] = {
+    # Built-in SQLAlchemy dialects
+    "postgresql": "postgres",
+    "mysql": "mysql",
+    "sqlite": "sqlite",
+    "mssql": "tsql",
+    "oracle": "oracle",
+    # Third-party SQLAlchemy dialects (dialect name verified via engine.dialect.name)
+    "duckdb": "duckdb",
+    "snowflake": "snowflake",
+    "bigquery": "bigquery",
+    "redshift": "redshift",
+    "trino": "trino",
+    "databricks": "databricks",
+    "clickhousedb": "clickhouse",  # clickhouse-connect
+    "clickhouse": "clickhouse",  # clickhouse-sqlalchemy
+    "awsathena": "athena",  # PyAthena
+    "teradatasql": "teradata",  # teradatasqlalchemy
+    "exasol": "exasol",
+    "doris": "doris",
+    "singlestoredb": "singlestore",
+    "risingwave": "risingwave",
+    "druid": "druid",
+    "hive": "hive",  # PyHive; also covers Spark via hive://
+    "presto": "presto",
+}
+
+
+def register_sqlglot_dialect(sqlalchemy_name: str, sqlglot_name: str) -> None:
+    """
+    Add (or override) a SQLAlchemy-to-sqlglot dialect name mapping.
+
+    Intended for databases whose SQLAlchemy driver reports a ``dialect.name``
+    that the built-in ``SQLGLOT_DIALECTS`` table does not cover. The module-level
+    table is updated in place, so an existing entry for the same SQLAlchemy
+    name is replaced.
+
+    Parameters
+    ----------
+    sqlalchemy_name
+        The value of ``engine.dialect.name`` for your database.
+    sqlglot_name
+        The matching sqlglot dialect identifier (see
+        ``sqlglot.dialects.dialect.Dialect.classes`` for valid names). The
+        value is stored as given; it is not validated here.
+
+    """
+    SQLGLOT_DIALECTS.update({sqlalchemy_name: sqlglot_name})
+
+
+def transpile_sql(sql: str, dialect: str) -> str:
+    """
+    Transpile generic SQL to a target dialect using sqlglot.
+
+    Parameters
+    ----------
+    sql
+        The SQL text to convert.
+    dialect
+        sqlglot dialect identifier to write the SQL in.
+
+    Returns
+    -------
+    str
+        The first element of the list returned by ``sqlglot.transpile``.
+
+    """
+    # NOTE(review): read="" presumably selects sqlglot's default,
+    # dialect-agnostic parser -- confirm against the sqlglot documentation.
+    return sqlglot.transpile(sql, read="", write=dialect)[0]
+
+
+class DataSourceReader:
+    """
+    ggsql reader protocol implementation that routes SQL through a real database.
+
+    Implements execute_sql(), register(), and unregister() as expected by
+    ggsql's PyReaderBridge.
+
+    Must be used as a context manager: the database connection is opened in
+    ``__enter__`` and closed in ``__exit__``, and the reader methods raise
+    ``RuntimeError`` when no connection is open. Tables created through
+    ``register()`` are tracked by name and dropped again on exit.
+    """
+
+    def __init__(self, engine: Engine, dialect: str):
+        """
+        Store the engine and target dialect; no connection is opened here.
+
+        Parameters
+        ----------
+        engine
+            SQLAlchemy engine used to open the connection in ``__enter__``.
+        dialect
+            Dialect name handed to ``transpile_sql()`` as the target dialect.
+
+        """
+        self._engine = engine
+        self._dialect = dialect
+        # Open connection while inside the ``with`` block, otherwise None.
+        self._conn: Connection | None = None
+        # Names of tables created by register() that still need dropping.
+        self._registered: list[str] = []
+
+    def __enter__(self):
+        """Open a connection on the engine and return this reader."""
+        self._conn = self._engine.connect()
+        return self
+
+    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> bool:
+        """
+        Drop the tracked tables, close the connection, and reset state.
+
+        Returns ``False`` so an exception raised inside the ``with`` block
+        is not suppressed.
+        """
+        if self._conn is not None:
+            try:
+                for name in self._registered:
+                    # Best-effort cleanup: a failing DROP must not stop the
+                    # remaining tables from being dropped.
+                    with contextlib.suppress(Exception):
+                        self._conn.execute(text(f"DROP TABLE IF EXISTS {name}"))
+                self._conn.commit()
+            finally:
+                # Release the connection and forget tracked tables even if
+                # the cleanup above raised.
+                self._conn.close()
+                self._conn = None
+                self._registered.clear()
+        return False
+
+    def execute_sql(self, sql: str) -> pl.DataFrame:
+        """
+        Transpile ``sql`` to the target dialect, run it, and return the rows.
+
+        Parameters
+        ----------
+        sql
+            SQL text to execute; it is passed through ``transpile_sql()`` first.
+
+        Returns
+        -------
+        pl.DataFrame
+            One column per result column. When the query returns no rows, the
+            frame is empty and every column is typed ``pl.Utf8``; otherwise the
+            frame is built from the fetched values with no explicit schema.
+
+        Raises
+        ------
+        RuntimeError
+            If called outside the context manager (no open connection).
+
+        """
+        if self._conn is None:
+            raise RuntimeError("DataSourceReader must be used as a context manager")
+        transpiled = transpile_sql(sql, self._dialect)
+        result = self._conn.execute(text(transpiled))
+        rows = result.fetchall()
+        columns = list(result.keys())
+        if not rows:
+            # No values to build columns from, so only the column names are
+            # kept and every column is declared as Utf8.
+            return pl.DataFrame(schema=dict.fromkeys(columns, pl.Utf8))
+        # Pivot the row tuples into one list of values per column.
+        data = {col: [row[i] for row in rows] for i, col in enumerate(columns)}
+        return pl.DataFrame(data)
+
+    def register(self, name: str, df: pl.DataFrame, replace: bool = True) -> None:  # noqa: FBT001, FBT002
+        """
+        Create a temporary table called ``name`` and load ``df`` into it.
+
+        Parameters
+        ----------
+        name
+            Table name. It is interpolated into the SQL text as-is.
+        df
+            Data to load; column SQL types come from ``polars_to_sql_type()``.
+        replace
+            When true, drop any existing table of that name first.
+
+        Raises
+        ------
+        RuntimeError
+            If called outside the context manager (no open connection).
+
+        """
+        if self._conn is None:
+            raise RuntimeError("DataSourceReader must be used as a context manager")
+        if replace:
+            self._conn.execute(text(f"DROP TABLE IF EXISTS {name}"))
+            if name in self._registered:
+                self._registered.remove(name)
+
+        # NOTE(review): table and column names are inserted unquoted, so names
+        # with spaces, reserved words, or other special characters would
+        # presumably break these statements -- confirm what ggsql passes in.
+        col_defs = ", ".join(
+            f"{col} {polars_to_sql_type(dtype)}"
+            for col, dtype in zip(df.columns, df.dtypes, strict=True)
+        )
+        create_sql = f"CREATE TEMPORARY TABLE {name} ({col_defs})"
+        # Only the CREATE statement is transpiled; the DROP and INSERT
+        # statements are sent as written.
+        transpiled_create = transpile_sql(create_sql, self._dialect)
+        self._conn.execute(text(transpiled_create))
+        # Track the table so __exit__ drops it.
+        self._registered.append(name)
+
+        if len(df) > 0:
+            # One named bind parameter per column, one parameter dict per row.
+            placeholders = ", ".join(f":{col}" for col in df.columns)
+            insert_sql = f"INSERT INTO {name} VALUES ({placeholders})"
+            rows = df.to_dicts()
+            self._conn.execute(text(insert_sql), rows)
+
+        self._conn.commit()
+
+    def unregister(self, name: str) -> None:
+        """
+        Drop the table ``name`` if it exists and stop tracking it.
+
+        Raises
+        ------
+        RuntimeError
+            If called outside the context manager (no open connection).
+
+        """
+        if self._conn is None:
+            raise RuntimeError("DataSourceReader must be used as a context manager")
+        self._conn.execute(text(f"DROP TABLE IF EXISTS {name}"))
+        self._conn.commit()
+        if name in self._registered:
+            self._registered.remove(name)
+
+
+def polars_to_sql_type(dtype: pl.DataType) -> str:
+    """
+    Map a polars dtype to a generic SQL type name for CREATE TABLE.
+
+    Parameters
+    ----------
+    dtype
+        The polars dtype of a column.
+
+    Returns
+    -------
+    str
+        ``"INTEGER"``, ``"REAL"``, ``"BOOLEAN"``, ``"DATE"``, ``"TIMESTAMP"``,
+        or ``"TEXT"`` for any dtype not matched by the earlier checks.
+
+    """
+    # NOTE(review): every integer width maps to INTEGER and every float width
+    # to REAL. On backends where those are 32-bit types this could overflow or
+    # lose precision for 64-bit data -- confirm against the target databases.
+    if dtype.is_integer():
+        sql_type = "INTEGER"
+    elif dtype.is_float():
+        sql_type = "REAL"
+    elif dtype == pl.Boolean:
+        sql_type = "BOOLEAN"
+    elif dtype == pl.Date:
+        sql_type = "DATE"
+    elif dtype in (pl.Datetime, pl.Duration):
+        # Durations share the TIMESTAMP mapping with datetimes.
+        sql_type = "TIMESTAMP"
+    else:
+        sql_type = "TEXT"
+    return sql_type
diff --git a/pkg-py/src/querychat/_icons.py b/pkg-py/src/querychat/_icons.py
index 2b7683da0..fc484c9c0 100644
--- a/pkg-py/src/querychat/_icons.py
+++ b/pkg-py/src/querychat/_icons.py
@@ -2,19 +2,35 @@
from shiny import ui
-ICON_NAMES = Literal["arrow-counterclockwise", "funnel-fill", "terminal-fill", "table"]
+ICON_NAMES = Literal[
+ "arrow-counterclockwise",
+ "bar-chart-fill",
+ "chevron-down",
+ "download",
+ "funnel-fill",
+ "graph-up",
+ "terminal-fill",
+ "table",
+]
-def bs_icon(name: ICON_NAMES) -> ui.HTML:
+def bs_icon(name: ICON_NAMES, cls: str = "") -> ui.HTML:
"""Get Bootstrap icon SVG by name."""
if name not in BS_ICONS:
raise ValueError(f"Unknown Bootstrap icon: {name}")
- return ui.HTML(BS_ICONS[name])
+ svg = BS_ICONS[name]
+ if cls:
+ svg = svg.replace('class="', f'class="{cls} ', 1)
+ return ui.HTML(svg)
BS_ICONS = {
"arrow-counterclockwise": '',
+ "bar-chart-fill": '',
+ "chevron-down": '',
+ "download": '',
"funnel-fill": '',
+ "graph-up": '',
"terminal-fill": '',
"table": '',
}
diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py
index d3bf29e26..25e841971 100644
--- a/pkg-py/src/querychat/_querychat_base.py
+++ b/pkg-py/src/querychat/_querychat_base.py
@@ -23,11 +23,13 @@
from ._shiny_module import GREETING_PROMPT
from ._system_prompt import QueryChatSystemPrompt
from ._utils import MISSING, MISSING_TYPE, is_ibis_table
+from ._viz_utils import has_viz_deps, has_viz_tool
from .tools import (
UpdateDashboardData,
tool_query,
tool_reset_dashboard,
tool_update_dashboard,
+ tool_visualize_query,
)
if TYPE_CHECKING:
@@ -35,8 +37,10 @@
from narwhals.stable.v1.typing import IntoFrame
-TOOL_GROUPS = Literal["update", "query"]
+ from ._viz_tools import VisualizeQueryData
+TOOL_GROUPS = Literal["update", "query", "visualize_query"]
+DEFAULT_TOOLS: tuple[TOOL_GROUPS, ...] = ("update", "query")
class QueryChatBase(Generic[IntoFrameT]):
"""
@@ -58,7 +62,7 @@ def __init__(
*,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
- tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -72,7 +76,7 @@ def __init__(
"Table name must begin with a letter and contain only letters, numbers, and underscores",
)
- self.tools = normalize_tools(tools, default=("update", "query"))
+ self.tools = normalize_tools(tools, default=DEFAULT_TOOLS)
self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting
# Store init parameters for deferred system prompt building
@@ -128,6 +132,7 @@ def _create_session_client(
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING,
update_dashboard: Callable[[UpdateDashboardData], None] | None = None,
reset_dashboard: Callable[[], None] | None = None,
+ visualize_query: Callable[[VisualizeQueryData], None] | None = None,
) -> chatlas.Chat:
"""Create a fresh, fully-configured Chat."""
spec = self._client_spec if isinstance(client_spec, MISSING_TYPE) else client_spec
@@ -152,6 +157,10 @@ def _create_session_client(
if "query" in resolved_tools:
chat.register_tool(tool_query(data_source))
+ if "visualize_query" in resolved_tools:
+ query_viz_fn = visualize_query or (lambda _: None)
+ chat.register_tool(tool_visualize_query(data_source, query_viz_fn))
+
return chat
def client(
@@ -160,6 +169,7 @@ def client(
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING,
update_dashboard: Callable[[UpdateDashboardData], None] | None = None,
reset_dashboard: Callable[[], None] | None = None,
+ visualize_query: Callable[[VisualizeQueryData], None] | None = None,
) -> chatlas.Chat:
"""
Create a chat client with registered tools.
@@ -167,11 +177,14 @@ def client(
Parameters
----------
tools
- Which tools to include: `"update"`, `"query"`, or both.
+ Which tools to include: `"update"`, `"query"`, `"visualize_query"`,
+ or a combination.
update_dashboard
Callback when update_dashboard tool succeeds.
reset_dashboard
Callback when reset_dashboard tool is invoked.
+ visualize_query
+ Callback when visualize_query tool succeeds.
Returns
-------
@@ -184,6 +197,7 @@ def client(
tools=tools,
update_dashboard=update_dashboard,
reset_dashboard=reset_dashboard,
+ visualize_query=visualize_query,
)
def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str:
@@ -293,14 +307,24 @@ def create_client(client: str | chatlas.Chat | None) -> chatlas.Chat:
def normalize_tools(
tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE,
default: tuple[TOOL_GROUPS, ...] | None,
+ *,
+ check_deps: bool = True,
) -> tuple[TOOL_GROUPS, ...] | None:
if tools is None or tools == ():
- return None
+ result = None
elif isinstance(tools, MISSING_TYPE):
- return default
+ result = default
elif isinstance(tools, str):
- return (tools,)
+ result = (tools,)
elif isinstance(tools, tuple):
- return tools
+ result = tools
else:
- return tuple(tools)
+ result = tuple(tools)
+ if not check_deps:
+ return result
+ if has_viz_tool(result) and not has_viz_deps():
+ raise ImportError(
+ "Visualization tools require ggsql, altair, shinywidgets, and "
+ "vl-convert-python. Install them with: pip install querychat[viz]"
+ )
+ return result
diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py
index af0685e01..1dd132631 100644
--- a/pkg-py/src/querychat/_querychat_core.py
+++ b/pkg-py/src/querychat/_querychat_core.py
@@ -165,6 +165,8 @@ def format_tool_result(result: ContentToolResult) -> str:
return str(result)
+
+
def format_query_error(e: Exception) -> str:
"""Format a query error with helpful guidance."""
error_msg = str(e).lower()
diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py
index c25b923fc..d49aa11a2 100644
--- a/pkg-py/src/querychat/_shiny.py
+++ b/pkg-py/src/querychat/_shiny.py
@@ -10,9 +10,10 @@
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
from ._icons import bs_icon
-from ._querychat_base import TOOL_GROUPS, QueryChatBase
+from ._querychat_base import DEFAULT_TOOLS, TOOL_GROUPS, QueryChatBase
from ._shiny_module import ServerValues, mod_server, mod_ui
from ._utils import MISSING, MISSING_TYPE, as_narwhals
+from ._viz_utils import has_viz_tool
if TYPE_CHECKING:
from pathlib import Path
@@ -97,10 +98,11 @@ class QueryChat(QueryChatBase[IntoFrameT]):
tools
Which querychat tools to include in the chat client by default. Can be:
- A single tool string: `"update"` or `"query"`
- - A tuple of tools: `("update", "query")`
+ - A tuple of tools: `("update", "query", "visualize_query")`
- `None` or `()` to disable all tools
- Default is `("update", "query")` (both tools enabled).
+ Default is `("update", "query")`. The visualization tool (`"visualize_query"`)
+ can be opted into by including it in the tuple.
Set to `"update"` to prevent the LLM from accessing data values, only
allowing dashboard filtering without answering questions.
@@ -156,7 +158,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
- tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -172,7 +174,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
- tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -188,7 +190,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
- tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -204,7 +206,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
- tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -219,7 +221,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
- tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -245,7 +247,7 @@ def app(
"""
Quickly chat with a dataset.
- Creates a Shiny app with a chat sidebar and data table view -- providing a
+ Creates a Shiny app with a chat sidebar and data view -- providing a
quick-and-easy way to start chatting with your data.
Parameters
@@ -301,6 +303,7 @@ def app_server(input: Inputs, output: Outputs, session: Session):
greeting=self.greeting,
client=self._create_session_client,
enable_bookmarking=enable_bookmarking,
+ tools=self.tools,
)
@render.text
@@ -399,7 +402,7 @@ def ui(self, *, id: Optional[str] = None, **kwargs):
A UI component.
"""
- return mod_ui(id or self.id, **kwargs)
+ return mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), **kwargs)
def server(
self,
@@ -506,6 +509,7 @@ def create_session_client(**kwargs) -> chatlas.Chat:
greeting=self.greeting,
client=create_session_client,
enable_bookmarking=enable_bookmarking,
+ tools=self.tools,
)
@@ -616,6 +620,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"),
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -632,6 +637,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -648,6 +654,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -664,6 +671,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -680,6 +688,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -695,6 +704,7 @@ def __init__(
id: Optional[str] = None,
greeting: Optional[str | Path] = None,
client: Optional[str | chatlas.Chat] = None,
+ tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS,
data_description: Optional[str | Path] = None,
categorical_threshold: int = 20,
extra_instructions: Optional[str | Path] = None,
@@ -714,6 +724,7 @@ def __init__(
table_name,
greeting=greeting,
client=client,
+ tools=tools,
data_description=data_description,
categorical_threshold=categorical_threshold,
extra_instructions=extra_instructions,
@@ -743,6 +754,7 @@ def __init__(
greeting=self.greeting,
client=self._create_session_client,
enable_bookmarking=enable,
+ tools=self.tools,
)
def sidebar(
@@ -804,7 +816,7 @@ def ui(self, *, id: Optional[str] = None, **kwargs):
A UI component.
"""
- return mod_ui(id or self.id, **kwargs)
+ return mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), **kwargs)
def df(self) -> IntoFrameT:
"""
diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py
index 4264285bd..49ec7fa6f 100644
--- a/pkg-py/src/querychat/_shiny_module.py
+++ b/pkg-py/src/querychat/_shiny_module.py
@@ -1,10 +1,9 @@
from __future__ import annotations
-import copy
import warnings
from dataclasses import dataclass
from pathlib import Path
-from typing import TYPE_CHECKING, Generic, Union
+from typing import TYPE_CHECKING, Generic, TypedDict, Union
import chatlas
import shinychat
@@ -13,7 +12,9 @@
from shiny import module, reactive, ui
from ._querychat_core import GREETING_PROMPT
-from .tools import tool_query, tool_reset_dashboard, tool_update_dashboard
+from ._viz_altair_widget import AltairWidget
+from ._viz_ggsql import execute_ggsql
+from ._viz_utils import has_viz_tool, preload_viz_deps_server, preload_viz_deps_ui
if TYPE_CHECKING:
from collections.abc import Callable
@@ -23,6 +24,8 @@
from shiny import Inputs, Outputs, Session
from ._datasource import DataSource
+ from ._querychat_base import TOOL_GROUPS
+ from ._viz_tools import VisualizeQueryData
from .types import UpdateDashboardData
ReactiveString = reactive.Value[str]
@@ -30,11 +33,31 @@
ReactiveStringOrNone = reactive.Value[Union[str, None]]
"""A reactive string (or None) value."""
+
+class VizWidgetEntry(TypedDict):
+    """A bookmarked visualization widget: enough state to re-register on restore."""
+
+    # Identifier the widget is registered under; reused as-is on restore.
+    widget_id: str
+    # ggsql query text recorded for the widget; re-executed on restore.
+    ggsql: str
+
+
CHAT_ID = "chat"
+class _DeferredStubChatClient:
+    """
+    Placeholder chat client for deferred stub sessions.
+
+    Returned in place of a real ``chatlas.Chat`` when a stub session runs
+    before ``data_source`` has been set. Using it like a chat client fails
+    loudly with ``RuntimeError`` rather than silently doing nothing.
+    """
+
+    def __getattr__(self, name: str):
+        """
+        Reject attribute lookups that are not found through normal means.
+
+        Raises
+        ------
+        AttributeError
+            For dunder names (e.g. ``__deepcopy__``). Protocol probes such as
+            ``hasattr``/``getattr(..., default)`` and ``copy``/``pickle`` only
+            tolerate ``AttributeError``, so anything else would break them.
+        RuntimeError
+            For every other attribute, signalling that the client is not
+            available yet.
+        """
+        # Honour the __getattr__ contract for special-method probing so the
+        # placeholder can be inspected, copied, or stored in a dataclass safely.
+        if name.startswith("__") and name.endswith("__"):
+            raise AttributeError(name)
+        raise RuntimeError(
+            "Chat client is unavailable during stub session before data_source is set."
+        )
+
+
+# Spelled with ``Union[...]`` rather than ``X | Y``: this alias is evaluated at
+# import time (the ``annotations`` future import only defers annotations), and
+# ``|`` between classes requires Python 3.10+. This also matches the
+# ``Union``-based runtime alias used for ``ReactiveStringOrNone`` in this module.
+ServerClient = Union[chatlas.Chat, _DeferredStubChatClient]
+
+
@module.ui
-def mod_ui(**kwargs):
+def mod_ui(*, preload_viz: bool = False, **kwargs):
css_path = Path(__file__).parent / "static" / "css" / "styles.css"
js_path = Path(__file__).parent / "static" / "js" / "querychat.js"
@@ -47,6 +70,7 @@ def mod_ui(**kwargs):
ui.include_js(js_path),
),
tag,
+ preload_viz_deps_ui() if preload_viz else None,
)
@@ -76,18 +100,17 @@ class ServerValues(Generic[IntoFrameT]):
`.title()`, or set it with `.title.set("...")`. Returns
`None` if no title has been set.
client
- The session-specific chat client instance. This is a deep copy of the
- base client configured for this specific session, containing the chat
- history and tool registrations for this session only. This may be
- `None` during stub sessions when the client depends on deferred,
- session-scoped state.
+ Session chat client value.
+ For real sessions this is a `chatlas.Chat` created by the client
+ factory. For deferred stub sessions (where `data_source` is not set
+ yet), this is a placeholder client that raises when accessed.
"""
df: Callable[[], IntoFrameT]
sql: ReactiveStringOrNone
title: ReactiveStringOrNone
- client: chatlas.Chat | None
+ client: ServerClient
@module.server
@@ -98,14 +121,39 @@ def mod_server(
*,
data_source: DataSource[IntoFrameT] | None,
greeting: str | None,
- client: chatlas.Chat | Callable,
+ client: Callable[..., chatlas.Chat],
enable_bookmarking: bool,
+ tools: tuple[TOOL_GROUPS, ...] | None = None,
) -> ServerValues[IntoFrameT]:
# Reactive values to store state
sql = ReactiveStringOrNone(None)
title = ReactiveStringOrNone(None)
has_greeted = reactive.value[bool](False) # noqa: FBT003
+ if not callable(client):
+ raise TypeError("mod_server() requires a callable client factory.")
+
+ def update_dashboard(data: UpdateDashboardData):
+ sql.set(data["query"])
+ title.set(data["title"])
+
+ def reset_dashboard():
+ sql.set(None)
+ title.set(None)
+
+ viz_widgets: list[VizWidgetEntry] = []
+
+ def on_visualize(data: VisualizeQueryData):
+ viz_widgets.append({"widget_id": data["widget_id"], "ggsql": data["ggsql"]})
+
+ def build_chat_client() -> chatlas.Chat:
+ return client(
+ update_dashboard=update_dashboard,
+ reset_dashboard=reset_dashboard,
+ visualize_query=on_visualize,
+ tools=tools,
+ )
+
# Short-circuit for stub sessions (e.g. 1st run of an Express app)
# data_source may be None during stub session for deferred pattern
if session.is_stub_session():
@@ -113,11 +161,15 @@ def mod_server(
def _stub_df():
raise RuntimeError("RuntimeError: No current reactive context")
+ stub_client = (
+ _DeferredStubChatClient() if data_source is None else build_chat_client()
+ )
+
return ServerValues(
df=_stub_df,
sql=sql,
title=title,
- client=client if isinstance(client, chatlas.Chat) else None,
+ client=stub_client,
)
# Real session requires data_source
@@ -127,27 +179,11 @@ def _stub_df():
"Set it via the data_source property before users connect."
)
- def update_dashboard(data: UpdateDashboardData):
- sql.set(data["query"])
- title.set(data["title"])
-
- def reset_dashboard():
- sql.set(None)
- title.set(None)
-
- # Set up the chat object for this session
- # Support both a callable that creates a client and legacy instance pattern
- if callable(client) and not isinstance(client, chatlas.Chat):
- chat = client(
- update_dashboard=update_dashboard, reset_dashboard=reset_dashboard
- )
- else:
- # Legacy pattern: client is Chat instance
- chat = copy.deepcopy(client)
+ # Build the session-specific chat client through QueryChat.client(...).
+ chat = build_chat_client()
- chat.register_tool(tool_update_dashboard(data_source, update_dashboard))
- chat.register_tool(tool_query(data_source))
- chat.register_tool(tool_reset_dashboard(reset_dashboard))
+ if has_viz_tool(tools):
+ preload_viz_deps_server()
# Execute query when SQL changes
@reactive.calc
@@ -211,6 +247,8 @@ def _on_bookmark(x: BookmarkState) -> None:
vals["querychat_sql"] = sql.get()
vals["querychat_title"] = title.get()
vals["querychat_has_greeted"] = has_greeted.get()
+ if viz_widgets:
+ vals["querychat_viz_widgets"] = viz_widgets
@session.bookmark.on_restore
def _on_restore(x: RestoreState) -> None:
@@ -221,9 +259,44 @@ def _on_restore(x: RestoreState) -> None:
title.set(vals["querychat_title"])
if "querychat_has_greeted" in vals:
has_greeted.set(vals["querychat_has_greeted"])
+ if "querychat_viz_widgets" in vals:
+ restored = restore_viz_widgets(
+ data_source, vals["querychat_viz_widgets"]
+ )
+ viz_widgets[:] = restored
return ServerValues(df=filtered_df, sql=sql, title=title, client=chat)
class GreetWarning(Warning):
"""Warning raised when no greeting is provided to QueryChat."""
+
+
+def restore_viz_widgets(
+    data_source: DataSource[IntoFrameT],
+    saved_widgets: list[VizWidgetEntry],
+) -> list[VizWidgetEntry]:
+    """
+    Re-execute bookmarked ggsql queries and re-register their widgets.
+
+    Parameters
+    ----------
+    data_source
+        Data source the saved ggsql queries are executed against.
+    saved_widgets
+        Entries read back from the bookmark state.
+
+    Returns
+    -------
+    list[VizWidgetEntry]
+        The entries whose widget was registered successfully, in their
+        original order. Entries that fail are skipped with a warning.
+
+    """
+    # Nothing to restore: return before importing the visualization packages.
+    if not saved_widgets:
+        return []
+
+    from ggsql import validate
+    from shinywidgets import register_widget
+
+    restored: list[VizWidgetEntry] = []
+
+    for entry in saved_widgets:
+        widget_id = entry["widget_id"]
+        ggsql_str = entry["ggsql"]
+        try:
+            validated = validate(ggsql_str)
+            spec = execute_ggsql(data_source, ggsql_str, validated)
+            altair_widget = AltairWidget.from_ggsql(spec, widget_id=widget_id)
+            register_widget(widget_id, altair_widget.widget)
+        except Exception as exc:
+            # Best effort: if a query fails on restore (e.g. data changed), skip
+            # it so the rest of the chat still restores. The placeholder stays
+            # empty; the cause is reported so the failure can be diagnosed.
+            warnings.warn(
+                f"Failed to restore visualization widget '{widget_id}' on bookmark restore. "
+                f"Cause: {type(exc).__name__}: {exc}",
+                stacklevel=2,
+            )
+        else:
+            restored.append(entry)
+
+    return restored
diff --git a/pkg-py/src/querychat/_system_prompt.py b/pkg-py/src/querychat/_system_prompt.py
index 5a8445e93..a18630153 100644
--- a/pkg-py/src/querychat/_system_prompt.py
+++ b/pkg-py/src/querychat/_system_prompt.py
@@ -6,6 +6,8 @@
import chevron
+from ._viz_utils import has_viz_tool
+
_SCHEMA_TAG_RE = re.compile(r"\{\{[{#^/]?\s*schema\b")
if TYPE_CHECKING:
@@ -83,7 +85,14 @@ def render(self, tools: tuple[TOOL_GROUPS, ...] | None) -> str:
"extra_instructions": self.extra_instructions,
"has_tool_update": "update" in tools if tools else False,
"has_tool_query": "query" in tools if tools else False,
+ "has_tool_visualize_query": has_viz_tool(tools),
"include_query_guidelines": len(tools or ()) > 0,
}
- return chevron.render(self.template, context)
+ prompts_dir = str(Path(__file__).parent / "prompts")
+ return chevron.render(
+ self.template,
+ context,
+ partials_path=prompts_dir,
+ partials_ext="md",
+ )
diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py
index 555e8e376..6d4c803d8 100644
--- a/pkg-py/src/querychat/_utils.py
+++ b/pkg-py/src/querychat/_utils.py
@@ -4,8 +4,10 @@
import re
import warnings
from contextlib import contextmanager
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, overload
+import chevron
import narwhals.stable.v1 as nw
from great_tables import GT
@@ -14,6 +16,50 @@
import ibis
import pandas as pd
+ import polars as pl
+ from narwhals.stable.v1.typing import IntoFrame
+
+
+# A line that looks like the start of a dumped schema/JSON blob: it opens with
+# ``{`` or ``[``, or mentions the ``additionalProperties`` key.
+_SCHEMA_DUMP_PATTERN = re.compile(
+    r"^\s*[\{\[]|'additionalProperties'|\"additionalProperties\"",
+)
+
+
+def _cut_at_word_boundary(text: str, max_chars: int) -> str:
+    """Cap ``text`` at ``max_chars``, preferring to end at the last space in the second half."""
+    if len(text) <= max_chars:
+        return text
+    cut = text[:max_chars]
+    last_space = cut.rfind(" ")
+    # Only back up to a space when that still keeps more than half the budget.
+    if last_space > max_chars // 2:
+        cut = cut[:last_space]
+    return cut
+
+
+def truncate_error(error_msg: str, max_chars: int = 500) -> str:
+    """
+    Shorten an over-long error message.
+
+    Messages of at most ``max_chars`` characters are returned unchanged.
+    Longer ones keep the leading lines up to the first blank line or the
+    first line matching ``_SCHEMA_DUMP_PATTERN``. If no such line exists, or
+    it is the very first line, the whole message is used instead. The kept
+    text is capped at ``max_chars`` and suffixed with ``"(error truncated)"``.
+
+    Parameters
+    ----------
+    error_msg
+        The raw error text.
+    max_chars
+        Maximum length of the retained text (the suffix is added on top).
+
+    Returns
+    -------
+    str
+        The original message, or the shortened text plus the truncation note.
+
+    """
+    if len(error_msg) <= max_chars:
+        return error_msg
+
+    meaningful: list[str] = []
+    hit_marker = False
+    for line in error_msg.split("\n"):
+        if not line.strip() or _SCHEMA_DUMP_PATTERN.search(line):
+            hit_marker = True
+            break
+        meaningful.append(line)
+
+    # Fall back to the whole message when no marker was found or nothing
+    # precedes it; it is longer than max_chars here, so it always gets cut.
+    text = "\n".join(meaningful) if hit_marker and meaningful else error_msg
+    return _cut_at_word_boundary(text, max_chars).rstrip() + "\n\n(error truncated)"
class MISSING_TYPE: # noqa: N801
@@ -171,14 +217,18 @@ def get_tool_details_setting() -> Optional[Literal["expanded", "collapsed", "def
return setting_lower
-def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> bool:
+def querychat_tool_starts_open(
+ action: Literal[
+ "update", "query", "reset", "visualize_query"
+ ],
+) -> bool:
"""
Determine whether a tool card should be open based on action and setting.
Parameters
----------
action : str
- The action type ('update', 'query', or 'reset')
+ The action type ('update', 'query', 'reset', or 'visualize_query')
Returns
-------
@@ -290,3 +340,15 @@ def df_to_html(df, maxrows: int = 5) -> str:
table_html += f"\n\n*(Showing {maxrows} of {nrow_full} rows)*\n"
return table_html
+
+
+def to_polars(data: IntoFrame) -> pl.DataFrame:
+    """Return ``data`` as a polars DataFrame via ``as_narwhals(data).to_polars()``."""
+    wrapped = as_narwhals(data)
+    return wrapped.to_polars()
+
+
+def read_prompt_template(filename: str, **kwargs: object) -> str:
+    """
+    Read a prompt template file and render it with chevron.
+
+    Parameters
+    ----------
+    filename
+        Name of the file inside the ``prompts`` directory next to this module.
+    **kwargs
+        Values passed to the template as its rendering context.
+
+    Returns
+    -------
+    str
+        The rendered template text.
+
+    """
+    template_path = Path(__file__).parent / "prompts" / filename
+    # Decode explicitly as UTF-8: without ``encoding`` the platform locale is
+    # used, which can garble non-ASCII characters in templates.
+    template = template_path.read_text(encoding="utf-8")
+    return chevron.render(template, kwargs)
diff --git a/pkg-py/src/querychat/_viz_altair_widget.py b/pkg-py/src/querychat/_viz_altair_widget.py
new file mode 100644
index 000000000..00d40347d
--- /dev/null
+++ b/pkg-py/src/querychat/_viz_altair_widget.py
@@ -0,0 +1,187 @@
+"""Altair chart wrapper for responsive display in Shiny."""
+
+from __future__ import annotations
+
+import copy
+from typing import TYPE_CHECKING, Any, cast
+from uuid import uuid4
+
+from shiny.session import get_current_session
+
+from shiny import reactive
+
+if TYPE_CHECKING:
+ import altair as alt
+ import ggsql
+
+class AltairWidget:
+    """
+    Holder for an Altair chart shown in Shiny through ``alt.JupyterChart``.
+
+    Every chart is wrapped in a ``JupyterChart`` so ``shinywidgets`` always
+    receives the same widget type and doesn't call
+    ``chart.properties(width=...)`` itself (which fails on compound specs).
+
+    Non-compound charts are sized with ``width/height: "container"``.
+    Compound charts (facet, concat) are given computed cell dimensions, and
+    those are recomputed reactively whenever the output container resizes.
+    """
+
+    widget: alt.JupyterChart
+    widget_id: str
+
+    def __init__(
+        self,
+        chart: alt.TopLevelMixin,
+        *,
+        widget_id: str | None = None,
+    ) -> None:
+        import altair as alt
+
+        compound_types = (
+            alt.FacetChart,
+            alt.ConcatChart,
+            alt.HConcatChart,
+            alt.VConcatChart,
+        )
+        needs_pixel_sizing = isinstance(chart, compound_types)
+
+        # Workaround: Vega-Lite's width/height: "container" doesn't work for
+        # compound specs (facet, concat, etc.), so those get pixel dimensions
+        # injected into a rebuilt chart. Drop the compound branch once ggsql
+        # handles it natively: https://github.com/posit-dev/ggsql/issues/238
+        if not needs_pixel_sizing:
+            sized = chart.properties(width="container", height="container")
+        else:
+            sized = fit_chart_to_container(
+                chart, DEFAULT_COMPOUND_WIDTH, DEFAULT_COMPOUND_HEIGHT
+            )
+
+        self.widget = alt.JupyterChart(sized)
+        if widget_id:
+            self.widget_id = widget_id
+        else:
+            # No usable id supplied: generate a short random one.
+            self.widget_id = f"querychat_viz_{uuid4().hex[:8]}"
+
+        # Second half of the compound sizing workaround (issue #238): keep the
+        # cell sizes in step with the container.
+        if needs_pixel_sizing:
+            self._setup_reactive_sizing(self.widget, self.widget_id)
+
+    @classmethod
+    def from_ggsql(
+        cls, spec: ggsql.Spec, *, widget_id: str | None = None
+    ) -> AltairWidget:
+        """Build a widget from a ggsql spec rendered by ``VegaLiteWriter``."""
+        from ggsql import VegaLiteWriter
+
+        rendered = VegaLiteWriter().render_chart(spec)
+        return cls(rendered, widget_id=widget_id)
+
+    @staticmethod
+    def _setup_reactive_sizing(widget: alt.JupyterChart, widget_id: str) -> None:
+        """Install an effect that refits ``widget`` to its output size; no-op without a session."""
+        session = get_current_session()
+        if session is None:
+            return
+
+        @reactive.effect
+        def _sizing_effect():
+            # Read both dimensions up front so the effect depends on each.
+            width = session.clientdata.output_width(widget_id)
+            height = session.clientdata.output_height(widget_id)
+            if width is None or height is None:
+                return
+            current = widget.chart
+            if current is None:
+                return
+            resized = fit_chart_to_container(
+                cast("alt.Chart", current), int(width), int(height)
+            )
+            # Assign widget.spec (a fresh dict) instead of widget.chart:
+            # traitlets won't fire change events when the same chart object
+            # is assigned back after in-place mutation.
+            widget.spec = resized.to_dict()
+
+        # Destroy the effect when the session ends to avoid memory leaks.
+        session.on_ended(_sizing_effect.destroy)
+
+
+# ---------------------------------------------------------------------------
+# Compound chart sizing helpers
+#
+# Vega-Lite's `width/height: "container"` doesn't work for compound specs
+# (facet, concat, etc.), so we manually inject cell dimensions. Ideally ggsql
+# will handle this natively: https://github.com/posit-dev/ggsql/issues/238
+# ---------------------------------------------------------------------------
+
+# Container size (in pixels) passed to `fit_chart_to_container` when a compound
+# chart is first built by `AltairWidget`.
+DEFAULT_COMPOUND_WIDTH = 900
+DEFAULT_COMPOUND_HEIGHT = 450
+
+# Encoding channels that `has_legend` checks for a field mapping.
+LEGEND_CHANNELS = frozenset(
+    {"color", "fill", "stroke", "shape", "size", "opacity"}
+)
+# Added to the horizontal padding when `has_legend` reports a legend.
+LEGEND_WIDTH = 120  # approximate space for a right-side legend
+
+
+def fit_chart_to_container(
+    chart: alt.TopLevelMixin,
+    container_width: int,
+    container_height: int,
+) -> alt.TopLevelMixin:
+    """
+    Return a copy of ``chart`` with cell ``width``/``height`` set.
+
+    The original chart is never mutated.
+
+    For faceted charts, divides the container width by the number of columns.
+    For hconcat, divides the width by the number of sub-specs.
+    For concat, divides the width by the number of columns (or sub-specs when
+    ``columns`` is unset) and the height by the resulting number of rows.
+    For vconcat, each sub-spec gets the full width and a share of the height.
+
+    Subtracts padding estimates so the rendered cells fill the container,
+    including space for legends when present.
+
+    Parameters
+    ----------
+    chart
+        The compound chart to size. Other chart types are copied unchanged.
+    container_width, container_height
+        Size of the output container.
+
+    Returns
+    -------
+    alt.TopLevelMixin
+        A deep copy of ``chart`` with the computed cell sizes applied.
+
+    """
+    import altair as alt
+
+    chart = copy.deepcopy(chart)
+
+    # Approximate padding; will be replaced when ggsql handles compound sizing
+    # natively (https://github.com/posit-dev/ggsql/issues/238).
+    padding_x = 80  # y-axis labels + title padding
+    padding_y = 120  # facet headers, x-axis labels + title, bottom padding
+    if has_legend(chart.to_dict()):
+        padding_x += LEGEND_WIDTH
+    # Never let the usable area drop below 100 in either dimension.
+    usable_w = max(container_width - padding_x, 100)
+    usable_h = max(container_height - padding_y, 100)
+
+    if isinstance(chart, alt.FacetChart):
+        ncol = chart.columns if isinstance(chart.columns, int) else 1
+        cell_w = usable_w // max(ncol, 1)
+        chart.spec.width = cell_w
+        # NOTE(review): a wrapped facet with several rows still gets the full
+        # height per row; the row count depends on the data, which is not
+        # available here.
+        chart.spec.height = usable_h
+    elif isinstance(chart, alt.HConcatChart):
+        cell_w = usable_w // max(len(chart.hconcat), 1)
+        for sub in chart.hconcat:
+            sub.width = cell_w
+            sub.height = usable_h
+    elif isinstance(chart, alt.ConcatChart):
+        n_sub = len(chart.concat)
+        ncol = max(chart.columns if isinstance(chart.columns, int) else n_sub, 1)
+        # A wrapped concat lays its sub-charts out in ceil(n_sub / ncol) rows,
+        # so the height is shared between rows (as vconcat does) instead of
+        # giving every row the full height and overflowing the container.
+        nrow = max(-(-n_sub // ncol), 1)
+        cell_w = usable_w // ncol
+        cell_h = usable_h // nrow
+        for sub in chart.concat:
+            sub.width = cell_w
+            sub.height = cell_h
+    elif isinstance(chart, alt.VConcatChart):
+        cell_h = usable_h // max(len(chart.vconcat), 1)
+        for sub in chart.vconcat:
+            sub.width = usable_w
+            sub.height = cell_h
+
+    return chart
+
+
+def has_legend(vl: dict[str, object]) -> bool:
+    """
+    Report whether a Vega-Lite spec maps a field to a legend-producing channel.
+
+    Looks at the sub-specs under ``spec``, ``hconcat``, ``vconcat`` and
+    ``concat``. For each one, its ``layer`` entries (or the sub-spec itself
+    when it has no ``layer``) are checked for an ``encoding`` that has one of
+    ``LEGEND_CHANNELS`` with a ``field`` in it.
+    """
+    candidates: list[dict[str, Any]] = []
+    if "spec" in vl:
+        candidates.append(vl["spec"])  # type: ignore[arg-type]
+    for key in ("hconcat", "vconcat", "concat"):
+        if key in vl:
+            candidates.extend(vl[key])  # type: ignore[arg-type]
+
+    return any(
+        channel in encoding and "field" in encoding[channel]
+        for sub_spec in candidates
+        for layer in sub_spec.get("layer", [sub_spec])
+        for encoding in (layer.get("encoding", {}),)
+        for channel in LEGEND_CHANNELS
+    )
diff --git a/pkg-py/src/querychat/_viz_ggsql.py b/pkg-py/src/querychat/_viz_ggsql.py
new file mode 100644
index 000000000..710e6f46b
--- /dev/null
+++ b/pkg-py/src/querychat/_viz_ggsql.py
@@ -0,0 +1,194 @@
+"""
+Helpers for executing ggsql queries in querychat.
+
+Architecture overview
+---------------------
+Querychat executes ggsql queries through two possible paths:
+
+1. **Bridge path** (SQLAlchemySource with known dialect) — A
+ ``DataSourceReader`` implements ggsql's reader protocol, routing all SQL
+ through the real database. ggsql runs its full pipeline (CTEs, stat
+ transforms, layer queries) against the real DB. sqlglot transpiles
+ ggsql's ANSI-generated SQL to the target dialect. This path supports
+ multi-source layers and avoids pulling large result sets into memory.
+
+2. **Fallback path** (all other DataSource types, or bridge failure) — The
+ SQL portion (before VISUALISE) runs on the real database via
+ ``DataSource.execute_query()``, then the VISUALISE portion replays
+ locally against the SQL result using ``ggsql.DuckDBReader``.
+
+The fallback path requires reconstructing a valid ggsql query from the
+split ``sql()`` and ``visual()`` parts. See ``execute_two_phase()`` for
+details on the two VISUALISE forms (Form A and Form B).
+
+Limitation of fallback path: layer-specific sources
+----------------------------------------------------
+ggsql supports per-layer data sources (``DRAW line MAPPING … FROM cte``),
+but the fallback path can't support them because the SQL result is a single
+DataFrame — CTEs don't survive the DataSource boundary. The bridge path
+handles this correctly.
+"""
+
+from __future__ import annotations
+
+import logging
+import re
+from typing import TYPE_CHECKING
+
+from ._utils import to_polars
+
+if TYPE_CHECKING:
+ import ggsql
+
+ from ._datasource import DataSource
+
+logger = logging.getLogger(__name__)
+
+
+def execute_ggsql(
+    data_source: DataSource,
+    query: str,
+    validated: ggsql.Validated,
+) -> ggsql.Spec:
+    """
+    Run a ggsql query through the bridge when possible, else the fallback.
+
+    The bridge is attempted only for a ``SQLAlchemySource`` whose SQLAlchemy
+    dialect name has an entry in ``SQLGLOT_DIALECTS``. Any other data source,
+    an unknown dialect, or an exception raised by the bridge leads to
+    ``execute_two_phase``.
+
+    Parameters
+    ----------
+    data_source
+        The querychat DataSource to execute against.
+    query
+        The original ggsql query string (used by the bridge path).
+    validated
+        A pre-validated ggsql query (from ``ggsql.validate()``), used by the
+        fallback path.
+
+    Returns
+    -------
+    ggsql.Spec
+        The writer-independent plot specification.
+
+    """
+    from ._datasource import SQLAlchemySource
+    from ._datasource_reader import SQLGLOT_DIALECTS, DataSourceReader
+
+    if not isinstance(data_source, SQLAlchemySource):
+        return execute_two_phase(data_source, validated)
+
+    sa_dialect_name = data_source._engine.dialect.name
+    dialect = SQLGLOT_DIALECTS.get(sa_dialect_name)
+    if dialect is None:
+        logger.warning(
+            "Unknown SQLAlchemy dialect %r — falling back to two-phase execution. "
+            "You can register it via: "
+            "from querychat._datasource_reader import register_sqlglot_dialect",
+            sa_dialect_name,
+        )
+        return execute_two_phase(data_source, validated)
+
+    try:
+        with DataSourceReader(data_source._engine, dialect) as reader:
+            import ggsql as _ggsql
+
+            return _ggsql.execute(query, reader)
+    except Exception:
+        # Deliberate best effort: a bridge failure is only logged at debug
+        # level and the query is retried through the fallback path below.
+        logger.debug(
+            "DataSourceReader bridge failed, falling back to two-phase",
+            exc_info=True,
+        )
+
+    return execute_two_phase(data_source, validated)
+
+
+def execute_two_phase(
+    data_source: DataSource,
+    validated: ggsql.Validated,
+) -> ggsql.Spec:
+    """
+    Run a ggsql query in two phases (the fallback path).
+
+    First the SQL part is executed through ``data_source.execute_query()``.
+    Then the result is registered in an in-memory ``DuckDBReader`` and the
+    VISUALISE part is executed against it. When ``extract_visualise_table``
+    finds a table name, the result is registered under that name (surrounding
+    double quotes removed) and the visual part runs as is. Otherwise the
+    result is registered as ``_data`` and the visual part is prefixed with
+    ``SELECT * FROM _data``.
+
+    Raises
+    ------
+    ValueError
+        If ``has_layer_level_source`` reports a layer-specific source.
+
+    """
+    from ggsql import DuckDBReader
+
+    visual = validated.visual()
+    if has_layer_level_source(visual):
+        raise ValueError(
+            "Layer-specific sources are not currently supported in querychat visual "
+            "queries. Rewrite the query so that all layers come from the final SQL "
+            "result."
+        )
+
+    sql_result = data_source.execute_query(validated.sql())
+    frame = to_polars(sql_result)
+    # Snowflake (and some other backends) uppercase unquoted identifiers,
+    # while the LLM writes lowercase aliases in the VISUALISE clause.
+    # DuckDB is case-insensitive, so lowercasing here lets both match.
+    frame.columns = [column.lower() for column in frame.columns]
+
+    reader = DuckDBReader("duckdb://memory")
+    table = extract_visualise_table(visual)
+
+    if table is None:
+        reader.register("_data", frame)
+        return reader.execute(f"SELECT * FROM _data {visual}")
+
+    is_quoted = table.startswith('"') and table.endswith('"')
+    reader.register(table[1:-1] if is_quoted else table, frame)
+    return reader.execute(visual)
+
+
+def extract_visualise_table(visual: str) -> str | None:
+ """
+ Extract the table name from ``VISUALISE … FROM