diff --git a/.gitignore b/.gitignore index 740f3993c..a06bad8c9 100644 --- a/.gitignore +++ b/.gitignore @@ -268,6 +268,9 @@ renv.lock # Planning documents (local only) docs/plans/ +# Screenshot capture script (local only) +pkg-py/docs/_screenshots/ + # Playwright MCP .playwright-mcp/ diff --git a/CLAUDE.md b/CLAUDE.md index 4170fa304..98873f00e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -69,13 +69,15 @@ make py-build make py-docs ``` -Before finishing your implementation or committing any code, you should run: +Before committing any Python code, you must run all three checks and confirm they pass: ```bash uv run ruff check --fix pkg-py --config pyproject.toml +make py-check-types +make py-check-tests ``` -To get help with making sure code adheres to project standards. +Do not commit or push until all three pass. ### R Package diff --git a/docs/plans/2026-04-17-datasource-reader-bridge-design.md b/docs/plans/2026-04-17-datasource-reader-bridge-design.md new file mode 100644 index 000000000..b597aba87 --- /dev/null +++ b/docs/plans/2026-04-17-datasource-reader-bridge-design.md @@ -0,0 +1,129 @@ +# DataSourceReader Bridge Design + +## Problem + +querychat executes ggsql queries in two phases: run the SQL on the real database, then replay the VISUALISE portion locally against the result in an in-memory DuckDB. This has two drawbacks: + +1. **Scaling** — the full SQL result must be pulled into Python memory, even when ggsql's stat transforms (histogram, density, boxplot) would reduce it to a small summary. A histogram of 10M rows pulls all 10M rows into memory only to bin them into ~30 buckets. + +2. **Multi-source layers** — ggsql supports per-layer data sources (e.g., a CTE fed to a different DRAW clause). The two-phase approach loses intermediate tables at the DataSource boundary, so querychat rejects these queries. + +Both problems stem from the same root cause: querychat splits the query at the SQL/VISUALISE boundary and runs each half independently, rather than letting ggsql run the full pipeline against the real database. + +## Solution + +For `SQLAlchemySource` data sources, implement a `DataSourceReader` — a Python object that satisfies ggsql's reader protocol (`execute_sql()`, `register()`, `unregister()`) by routing SQL to the real database. Pass this reader to `ggsql.execute(query, reader)`, letting ggsql run the entire pipeline (parsing, CTEs, stat transforms, everything) against the real DB. + +Use [sqlglot](https://github.com/tobymao/sqlglot) to transpile ggsql's ANSI-generated SQL to the target database dialect. This gives broad database coverage (31 dialects) without waiting for ggsql to add each one. + +Fall back to the current two-phase approach when the bridge fails (e.g., temp table permission denied, unsupported dialect, transpilation error) or for non-SQLAlchemy data sources. + +## Data flow + +### Bridge path (SQLAlchemySource) + +``` +ggsql.execute(query, DataSourceReader) + │ + ├─ CTE materialization + │ execute_sql("SELECT … FROM orders GROUP BY …") + │ → sqlglot transpiles generic → target dialect + │ → runs on real DB + │ → result registered as temp table on real DB + │ + ├─ Global SQL + │ execute_sql("SELECT * FROM orders WHERE …") + │ → runs on real DB + │ → result registered as temp table on real DB + │ + ├─ Schema queries + │ execute_sql("SELECT … LIMIT 0") + │ → runs on real DB against temp tables + │ + ├─ Stat transforms (histograms, density, boxplot, etc.) + │ execute_sql("WITH … SELECT … binning SQL …") + │ → sqlglot transpiles generated ANSI SQL → target dialect + │ → runs on real DB against temp tables + │ + └─ Final layer queries + execute_sql("SELECT …") + → runs on real DB, small result set returned +``` + +### Fallback path (current approach, all DataSource types) + +``` +validated.sql() + → DataSource.execute_query() on real DB + → full result pulled into Python memory + → registered in local DuckDB + → ggsql replays VISUALISE portion locally +``` + +## Components + +### `DataSourceReader` + +Python class implementing ggsql's reader protocol. Lives in `_viz_ggsql.py`. + +- **Constructor**: takes a `sqlalchemy.Engine` and a sqlglot dialect string. Opens a single connection from the engine, held for the pipeline's duration. +- **`execute_sql(sql)`**: transpiles from generic SQL to target dialect via `sqlglot.transpile(sql, read="", write=dialect)`, executes on the real DB via SQLAlchemy, returns a polars DataFrame. +- **`register(name, df, replace)`**: creates a `TEMPORARY TABLE` on the real DB with column types derived from polars dtypes (generic SQL types, transpiled by sqlglot). Inserts rows in batches via SQLAlchemy. Tracks registered names for cleanup. +- **`unregister(name)`**: drops the temp table on the real DB. +- **Context manager**: `__exit__` drops all registered temp tables and closes the connection, ensuring cleanup even on error. + +### Dialect mapping + +```python +SQLGLOT_DIALECTS = { + "postgresql": "postgres", + "snowflake": "snowflake", + "duckdb": "duckdb", + "sqlite": "sqlite", + "mysql": "mysql", + "mssql": "tsql", + "bigquery": "bigquery", + "redshift": "redshift", +} +``` + +Maps `engine.dialect.name` to sqlglot dialect names. Unknown dialects skip the bridge and use the fallback. + +### Entry point + +```python +def execute_ggsql(data_source, query, validated): + if isinstance(data_source, SQLAlchemySource): + dialect = SQLGLOT_DIALECTS.get(data_source._engine.dialect.name) + if dialect is not None: + try: + with DataSourceReader(data_source._engine, dialect) as reader: + return ggsql.execute(query, reader) + except Exception: + pass # fall through + + # Fallback: current two-phase approach + return _execute_two_phase(data_source, validated) +``` + +### `_execute_two_phase` + +The current `execute_ggsql` body, renamed. Includes the existing regex-based `extract_visualise_table` and `has_layer_level_source` logic. Used for `DataFrameSource`, `PolarsLazySource`, `IbisSource`, and as the fallback for SQLAlchemy sources. + +## Dependencies + +- `sqlglot` added to the `viz` optional extra in `pyproject.toml` +- No changes to ggsql required for the initial implementation + +## Scope boundaries + +- **SQLAlchemySource only** — IbisSource could follow later +- **No ggsql changes required** — the `dialect` parameter contribution to `execute()` can come later as an optimization (skipping sqlglot when ggsql natively supports the dialect) +- **No prompt changes** — the LLM already writes SQL for the correct `db_type` + +## Testing + +- Unit tests for `DataSourceReader`: mock SQLAlchemy connection, verify transpile + execute, register/unregister lifecycle, cleanup on error +- Unit tests for sqlglot transpilation of ggsql's generated SQL patterns (recursive CTEs, NTILE percentiles, CREATE TEMPORARY TABLE) across key dialects +- Integration test for fallback: verify bridge failure triggers two-phase approach +- End-to-end with a test database connection if available diff --git a/docs/plans/2026-04-17-datasource-reader-bridge.md b/docs/plans/2026-04-17-datasource-reader-bridge.md new file mode 100644 index 000000000..4d42d8d61 --- /dev/null +++ b/docs/plans/2026-04-17-datasource-reader-bridge.md @@ -0,0 +1,825 @@ +# DataSourceReader Bridge Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Implement a `DataSourceReader` that lets ggsql run its full pipeline against the real database (via SQLAlchemy), using sqlglot for dialect transpilation, with fallback to the current two-phase approach on failure. + +**Architecture:** `DataSourceReader` implements ggsql's reader protocol (`execute_sql`, `register`, `unregister`) by routing SQL through a held SQLAlchemy connection. sqlglot transpiles ggsql's ANSI-generated SQL to the target dialect. Temp tables are created on the real DB for ggsql's intermediate results. Falls back to the existing two-phase approach (renamed `execute_two_phase`) for non-SQLAlchemy sources or on bridge failure. + +**Tech Stack:** Python, sqlglot, SQLAlchemy, ggsql (PyReaderBridge), polars, pytest + +**Design doc:** `docs/plans/2026-04-17-datasource-reader-bridge-design.md` + +--- + +### Task 1: Add sqlglot dependency + +**Files:** +- Modify: `pyproject.toml:52` + +**Step 1: Add sqlglot to the viz extra** + +In `pyproject.toml`, change line 52 from: + +```toml +viz = ["ggsql>=0.2.4", "altair>=6.0", "shinywidgets>=0.8.0", "vl-convert-python>=1.9.0"] +``` + +to: + +```toml +viz = ["ggsql>=0.2.4", "altair>=6.0", "shinywidgets>=0.8.0", "vl-convert-python>=1.9.0", "sqlglot>=26.0"] +``` + +**Step 2: Install the updated dependencies** + +Run: `cd /Users/cpsievert/github/querychat && uv sync --extra viz` +Expected: sqlglot installs successfully + +**Step 3: Verify import** + +Run: `cd /Users/cpsievert/github/querychat && uv run python -c "import sqlglot; print(sqlglot.__version__)"` +Expected: Version prints without error + +**Step 4: Commit** + +```bash +git add pyproject.toml uv.lock +git commit -m "feat: add sqlglot dependency for DataSourceReader bridge" +``` + +--- + +### Task 2: Write dialect mapping and transpile helper (tests first) + +**Files:** +- Create: `pkg-py/tests/test_datasource_reader.py` +- Create (later): `pkg-py/src/querychat/_datasource_reader.py` + +**Step 1: Write tests for dialect mapping and transpilation** + +Create `pkg-py/tests/test_datasource_reader.py`: + +```python +"""Tests for DataSourceReader bridge.""" + +import pytest + + +class TestDialectMapping: + """Tests for SQLGLOT_DIALECTS mapping.""" + + def test_known_dialects_present(self): + from querychat._datasource_reader import SQLGLOT_DIALECTS + + assert SQLGLOT_DIALECTS["postgresql"] == "postgres" + assert SQLGLOT_DIALECTS["snowflake"] == "snowflake" + assert SQLGLOT_DIALECTS["duckdb"] == "duckdb" + assert SQLGLOT_DIALECTS["sqlite"] == "sqlite" + assert SQLGLOT_DIALECTS["mysql"] == "mysql" + assert SQLGLOT_DIALECTS["mssql"] == "tsql" + + def test_unknown_dialect_not_present(self): + from querychat._datasource_reader import SQLGLOT_DIALECTS + + assert "oracle" not in SQLGLOT_DIALECTS + + +class TestTranspileSql: + """Tests for transpile_sql() helper.""" + + def test_identity_for_duckdb(self): + from querychat._datasource_reader import transpile_sql + + sql = "SELECT x, y FROM t WHERE x > 1" + result = transpile_sql(sql, "duckdb") + assert "SELECT" in result + assert "FROM" in result + + def test_transpiles_create_temp_table_to_snowflake(self): + from querychat._datasource_reader import transpile_sql + + sql = "CREATE TEMPORARY TABLE __ggsql_cte_0 AS SELECT x FROM t" + result = transpile_sql(sql, "snowflake") + assert "TEMPORARY" in result.upper() or "TEMP" in result.upper() + assert "__ggsql_cte_0" in result + + def test_transpiles_recursive_cte_to_postgres(self): + from querychat._datasource_reader import transpile_sql + + sql = ( + "WITH RECURSIVE series AS (" + "SELECT 0 AS n UNION ALL SELECT n + 1 FROM series WHERE n < 10" + ") SELECT n FROM series" + ) + result = transpile_sql(sql, "postgres") + assert "RECURSIVE" in result.upper() + + def test_transpiles_ntile_to_snowflake(self): + from querychat._datasource_reader import transpile_sql + + sql = "SELECT NTILE(4) OVER (ORDER BY x) AS quartile FROM t" + result = transpile_sql(sql, "snowflake") + assert "NTILE" in result.upper() + + def test_passthrough_on_empty_dialect(self): + """Empty string dialect means generic/ANSI — should pass through.""" + from querychat._datasource_reader import transpile_sql + + sql = "SELECT 1" + result = transpile_sql(sql, "") + assert result == "SELECT 1" +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py -v` +Expected: ImportError — `querychat._datasource_reader` does not exist + +**Step 3: Implement dialect mapping and transpile helper** + +Create `pkg-py/src/querychat/_datasource_reader.py`: + +```python +"""DataSourceReader bridge: routes ggsql's reader protocol through a real database.""" + +from __future__ import annotations + +import sqlglot + +SQLGLOT_DIALECTS: dict[str, str] = { + "postgresql": "postgres", + "snowflake": "snowflake", + "duckdb": "duckdb", + "sqlite": "sqlite", + "mysql": "mysql", + "mssql": "tsql", + "bigquery": "bigquery", + "redshift": "redshift", +} + + +def transpile_sql(sql: str, dialect: str) -> str: + """Transpile generic SQL to a target dialect using sqlglot.""" + results = sqlglot.transpile(sql, read="", write=dialect) + return results[0] +``` + +**Step 4: Run tests to verify they pass** + +Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py -v` +Expected: All pass + +**Step 5: Commit** + +```bash +git add pkg-py/tests/test_datasource_reader.py pkg-py/src/querychat/_datasource_reader.py +git commit -m "feat: add dialect mapping and transpile_sql helper" +``` + +--- + +### Task 3: Implement DataSourceReader class (tests first) + +This is the core class. It implements ggsql's reader protocol by executing SQL on the real database via SQLAlchemy. Tests use a real SQLite database (in-memory) to verify end-to-end behavior. + +**Files:** +- Modify: `pkg-py/tests/test_datasource_reader.py` +- Modify: `pkg-py/src/querychat/_datasource_reader.py` + +**Step 1: Write tests for DataSourceReader lifecycle** + +Append to `pkg-py/tests/test_datasource_reader.py`: + +```python +import polars as pl +from sqlalchemy import create_engine, text + + +@pytest.fixture +def sqlite_engine(): + """Create an in-memory SQLite database with test data.""" + engine = create_engine("sqlite://") + with engine.connect() as conn: + conn.execute(text("CREATE TABLE test_data (x INTEGER, y INTEGER, label TEXT)")) + conn.execute( + text("INSERT INTO test_data VALUES (1, 10, 'a'), (2, 20, 'b'), (3, 30, 'a')") + ) + conn.commit() + return engine + + +class TestDataSourceReader: + """Tests for DataSourceReader against a real SQLite database.""" + + def test_execute_sql_returns_polars(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + df = reader.execute_sql("SELECT * FROM test_data") + assert isinstance(df, pl.DataFrame) + assert len(df) == 3 + assert set(df.columns) == {"x", "y", "label"} + + def test_execute_sql_with_filter(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + df = reader.execute_sql("SELECT * FROM test_data WHERE x > 1") + assert len(df) == 2 + + def test_register_creates_temp_table(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + df = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]}) + with DataSourceReader(sqlite_engine, "sqlite") as reader: + reader.register("my_temp", df, True) + result = reader.execute_sql("SELECT * FROM my_temp") + assert len(result) == 2 + assert set(result.columns) == {"a", "b"} + + def test_unregister_drops_temp_table(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + df = pl.DataFrame({"a": [1]}) + with DataSourceReader(sqlite_engine, "sqlite") as reader: + reader.register("drop_me", df, True) + reader.unregister("drop_me") + with pytest.raises(Exception, match="drop_me"): + reader.execute_sql("SELECT * FROM drop_me") + + def test_context_manager_cleans_up_temp_tables(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + df = pl.DataFrame({"a": [1]}) + with DataSourceReader(sqlite_engine, "sqlite") as reader: + reader.register("cleanup_test", df, True) + + # After exiting context, temp table should be gone. + # SQLite temp tables are connection-scoped, so they vanish + # when the connection closes. Verify by opening a new connection. + with sqlite_engine.connect() as conn: + result = conn.execute( + text("SELECT name FROM sqlite_temp_master WHERE name = 'cleanup_test'") + ) + assert result.fetchone() is None + + def test_register_replace_overwrites(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [10, 20, 30]}) + with DataSourceReader(sqlite_engine, "sqlite") as reader: + reader.register("replace_me", df1, True) + reader.register("replace_me", df2, True) + result = reader.execute_sql("SELECT * FROM replace_me") + assert len(result) == 3 + + def test_execute_sql_transpiles(self, sqlite_engine): + """Verify that generated SQL gets transpiled to the target dialect.""" + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + # This is valid generic SQL; sqlglot should pass it through for SQLite + df = reader.execute_sql("SELECT x, y FROM test_data ORDER BY x LIMIT 2") + assert len(df) == 2 +``` + +**Step 2: Run tests to verify they fail** + +Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py::TestDataSourceReader -v` +Expected: ImportError — `DataSourceReader` not found + +**Step 3: Implement DataSourceReader** + +Add to `pkg-py/src/querychat/_datasource_reader.py` (below the existing code): + +```python +from typing import TYPE_CHECKING + +import polars as pl +from sqlalchemy import text + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + + +class DataSourceReader: + """ + ggsql reader protocol implementation that routes SQL through a real database. + + Implements execute_sql(), register(), and unregister() as expected by + ggsql's PyReaderBridge. Uses sqlglot to transpile ggsql's ANSI-generated + SQL to the target database dialect. + """ + + def __init__(self, engine: Engine, dialect: str): + self._engine = engine + self._dialect = dialect + self._conn: Connection | None = None + self._registered: list[str] = [] + + def __enter__(self): + self._conn = self._engine.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._conn is not None: + try: + for name in self._registered: + try: + self._conn.execute(text(f"DROP TABLE IF EXISTS {name}")) + except Exception: + pass + self._conn.commit() + finally: + self._conn.close() + self._conn = None + self._registered.clear() + return False + + def execute_sql(self, sql: str) -> pl.DataFrame: + assert self._conn is not None, "DataSourceReader must be used as a context manager" + transpiled = transpile_sql(sql, self._dialect) + result = self._conn.execute(text(transpiled)) + rows = result.fetchall() + columns = list(result.keys()) + if not rows: + return pl.DataFrame(schema={col: pl.Utf8 for col in columns}) + data = {col: [row[i] for row in rows] for i, col in enumerate(columns)} + return pl.DataFrame(data) + + def register(self, name: str, df: pl.DataFrame, replace: bool = True) -> None: + assert self._conn is not None, "DataSourceReader must be used as a context manager" + if replace: + self._conn.execute(text(f"DROP TABLE IF EXISTS {name}")) + if name in self._registered: + self._registered.remove(name) + + col_defs = ", ".join( + f"{col} {_polars_to_sql_type(dtype)}" for col, dtype in zip(df.columns, df.dtypes) + ) + create_sql = f"CREATE TEMPORARY TABLE {name} ({col_defs})" + transpiled_create = transpile_sql(create_sql, self._dialect) + self._conn.execute(text(transpiled_create)) + self._registered.append(name) + + if len(df) > 0: + placeholders = ", ".join(f":{col}" for col in df.columns) + insert_sql = f"INSERT INTO {name} VALUES ({placeholders})" + rows = df.to_dicts() + self._conn.execute(text(insert_sql), rows) + + self._conn.commit() + + def unregister(self, name: str) -> None: + assert self._conn is not None, "DataSourceReader must be used as a context manager" + self._conn.execute(text(f"DROP TABLE IF EXISTS {name}")) + self._conn.commit() + if name in self._registered: + self._registered.remove(name) + + +def _polars_to_sql_type(dtype: pl.DataType) -> str: + """Map polars dtypes to generic SQL types for CREATE TABLE.""" + if dtype.is_integer(): + return "INTEGER" + if dtype.is_float(): + return "REAL" + if dtype == pl.Boolean: + return "BOOLEAN" + if dtype == pl.Date: + return "DATE" + if dtype == pl.Datetime or dtype == pl.Duration: + return "TIMESTAMP" + return "TEXT" +``` + +**Step 4: Run tests to verify they pass** + +Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py -v` +Expected: All pass + +**Step 5: Commit** + +```bash +git add pkg-py/tests/test_datasource_reader.py pkg-py/src/querychat/_datasource_reader.py +git commit -m "feat: implement DataSourceReader with temp table lifecycle" +``` + +--- + +### Task 4: Integrate with ggsql — end-to-end test with SQLite + +Verify that `DataSourceReader` works with `ggsql.execute(query, reader)` against a real SQLite database. + +**Files:** +- Modify: `pkg-py/tests/test_datasource_reader.py` + +**Step 1: Write end-to-end test** + +Append to `pkg-py/tests/test_datasource_reader.py`: + +```python +@pytest.mark.ggsql +class TestDataSourceReaderWithGgsql: + """End-to-end tests: DataSourceReader + ggsql.execute().""" + + def test_simple_scatter(self, sqlite_engine): + import ggsql + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + spec = ggsql.execute( + "SELECT x, y FROM test_data VISUALISE x, y DRAW point", + reader, + ) + assert spec.metadata()["rows"] == 3 + assert "VISUALISE" in spec.visual() + + def test_with_filter(self, sqlite_engine): + import ggsql + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + spec = ggsql.execute( + "SELECT x, y FROM test_data WHERE x > 1 VISUALISE x, y DRAW point", + reader, + ) + assert spec.metadata()["rows"] == 2 + + def test_form_b_visualise_from(self, sqlite_engine): + import ggsql + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + spec = ggsql.execute( + "VISUALISE x, y FROM test_data DRAW point", + reader, + ) + assert spec.metadata()["rows"] == 3 + + def test_with_aggregation(self, sqlite_engine): + import ggsql + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + spec = ggsql.execute( + "SELECT label, SUM(y) AS total FROM test_data GROUP BY label " + "VISUALISE label AS x, total AS y DRAW bar", + reader, + ) + assert spec.metadata()["rows"] == 2 +``` + +**Step 2: Run tests** + +Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py::TestDataSourceReaderWithGgsql -v` +Expected: All pass (if the reader protocol is correctly implemented). If any fail, debug and fix. + +**Step 3: Commit** + +```bash +git add pkg-py/tests/test_datasource_reader.py +git commit -m "test: end-to-end DataSourceReader with ggsql.execute()" +``` + +--- + +### Task 5: Refactor execute_ggsql — rename current body, add bridge path + +**Files:** +- Modify: `pkg-py/src/querychat/_viz_ggsql.py` +- Modify: `pkg-py/src/querychat/_viz_tools.py:182` +- Modify: `pkg-py/tests/test_ggsql.py` + +**Step 1: Write test for bridge+fallback behavior** + +Add to `pkg-py/tests/test_datasource_reader.py`: + +```python +import narwhals.stable.v1 as nw + + +class TestExecuteGgsqlBridge: + """Tests for the updated execute_ggsql entry point with bridge+fallback.""" + + @pytest.mark.ggsql + def test_sqlalchemy_source_uses_bridge(self, sqlite_engine): + """SQLAlchemySource with known dialect should use the bridge path.""" + import ggsql + from querychat._datasource import SQLAlchemySource + from querychat._viz_ggsql import execute_ggsql + + # Create a table that SQLAlchemySource can find + with sqlite_engine.connect() as conn: + conn.execute(text("CREATE TABLE IF NOT EXISTS bridge_data (x INTEGER, y INTEGER)")) + conn.execute(text("INSERT INTO bridge_data VALUES (1, 10), (2, 20), (3, 30)")) + conn.commit() + + ds = SQLAlchemySource(sqlite_engine, "bridge_data") + query = "SELECT x, y FROM bridge_data VISUALISE x, y DRAW point" + validated = ggsql.validate(query) + spec = execute_ggsql(ds, query, validated) + assert spec.metadata()["rows"] == 3 + + @pytest.mark.ggsql + def test_dataframe_source_uses_fallback(self): + """DataFrameSource should always use the fallback path.""" + import ggsql + from querychat._datasource import DataFrameSource + from querychat._viz_ggsql import execute_ggsql + + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + ds = DataFrameSource(nw_df, "test_data") + query = "SELECT * FROM test_data VISUALISE x, y DRAW point" + validated = ggsql.validate(query) + spec = execute_ggsql(ds, query, validated) + assert spec.metadata()["rows"] == 3 +``` + +**Step 2: Run new tests to verify they fail** + +Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py::TestExecuteGgsqlBridge -v` +Expected: TypeError — `execute_ggsql()` doesn't accept 3 arguments yet + +**Step 3: Update execute_ggsql signature and add bridge logic** + +Modify `pkg-py/src/querychat/_viz_ggsql.py`. The full updated file: + +```python +""" +Helpers for executing ggsql queries in querychat. + +Architecture overview +--------------------- +Querychat executes ggsql queries through two possible paths: + +1. **Bridge path** (SQLAlchemySource with known dialect) — A + ``DataSourceReader`` implements ggsql's reader protocol, routing all SQL + through the real database. ggsql runs its full pipeline (CTEs, stat + transforms, layer queries) against the real DB. sqlglot transpiles + ggsql's ANSI-generated SQL to the target dialect. This path supports + multi-source layers and avoids pulling large result sets into memory. + +2. **Fallback path** (all other DataSource types, or bridge failure) — The + SQL portion (before VISUALISE) runs on the real database via + ``DataSource.execute_query()``, then the VISUALISE portion replays + locally against the SQL result using ``ggsql.DuckDBReader``. + +The fallback path requires reconstructing a valid ggsql query from the +split ``sql()`` and ``visual()`` parts. See ``execute_two_phase()`` for +details on the two VISUALISE forms (Form A and Form B). + +Limitation of fallback path: layer-specific sources +---------------------------------------------------- +ggsql supports per-layer data sources (``DRAW line MAPPING … FROM cte``), +but the fallback path can't support them because the SQL result is a single +DataFrame — CTEs don't survive the DataSource boundary. The bridge path +handles this correctly. +""" + +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING + +from ._utils import to_polars + +if TYPE_CHECKING: + import ggsql + + from ._datasource import DataSource + +logger = logging.getLogger(__name__) + + +def execute_ggsql( + data_source: DataSource, + query: str, + validated: ggsql.Validated, +) -> ggsql.Spec: + """ + Execute a ggsql query, choosing the bridge or fallback path. + + Parameters + ---------- + data_source + The querychat DataSource to execute against. + query + The original ggsql query string (needed for the bridge path). + validated + A pre-validated ggsql query (from ``ggsql.validate()``). + + Returns + ------- + ggsql.Spec + The writer-independent plot specification. + + """ + from ._datasource import SQLAlchemySource + from ._datasource_reader import SQLGLOT_DIALECTS, DataSourceReader + + if isinstance(data_source, SQLAlchemySource): + dialect = SQLGLOT_DIALECTS.get(data_source._engine.dialect.name) + if dialect is not None: + try: + with DataSourceReader(data_source._engine, dialect) as reader: + import ggsql as _ggsql + + return _ggsql.execute(query, reader) + except Exception: + logger.debug( + "DataSourceReader bridge failed, falling back to two-phase", + exc_info=True, + ) + + return execute_two_phase(data_source, validated) + + +def execute_two_phase( + data_source: DataSource, + validated: ggsql.Validated, +) -> ggsql.Spec: + """ + Execute a ggsql query using the two-phase approach. + + Phase 1: execute SQL on the real database. + Phase 2: replay the VISUALISE portion locally in DuckDB. + + This is the fallback for non-SQLAlchemy sources or when the bridge fails. + """ + from ggsql import DuckDBReader + + visual = validated.visual() + if has_layer_level_source(visual): + raise ValueError( + "Layer-specific sources are not currently supported in querychat visual " + "queries. Rewrite the query so that all layers come from the final SQL " + "result." + ) + + pl_df = to_polars(data_source.execute_query(validated.sql())) + pl_df.columns = [c.lower() for c in pl_df.columns] + + reader = DuckDBReader("duckdb://memory") + table = extract_visualise_table(visual) + + if table is not None: + name = table[1:-1] if table.startswith('"') and table.endswith('"') else table + reader.register(name, pl_df) + return reader.execute(visual) + else: + reader.register("_data", pl_df) + return reader.execute(f"SELECT * FROM _data {visual}") + + +def extract_visualise_table(visual: str) -> str | None: + """ + Extract the table name from ``VISUALISE … FROM `` if present. + + Only looks at the portion before the first DRAW clause, since FROM after + DRAW belongs to layer-level MAPPING (a different concern). + """ + draw_pos = re.search(r"\bDRAW\b", visual, re.IGNORECASE) + vis_clause = visual[: draw_pos.start()] if draw_pos else visual + m = re.search(r'\bFROM\s+("[^"]+?"|\S+)', vis_clause, re.IGNORECASE) + return m.group(1) if m else None + + +def has_layer_level_source(visual: str) -> bool: + """ + Return ``True`` when a DRAW clause defines its own ``FROM ``. + """ + clauses = re.split( + r"(?=\b(?:DRAW|SCALE|PROJECT|FACET|PLACE|LABEL|THEME)\b)", + visual, + flags=re.IGNORECASE, + ) + for clause in clauses: + if not re.match(r"^\s*DRAW\b", clause, re.IGNORECASE): + continue + if re.search( + r'\bMAPPING\b[\s\S]*?\bFROM\s+("[^"]+?"|\S+)', + clause, + re.IGNORECASE, + ): + return True + return False +``` + +**Step 4: Update the caller in `_viz_tools.py`** + +In `pkg-py/src/querychat/_viz_tools.py`, change line 182 from: + +```python + spec = execute_ggsql(data_source, validated) +``` + +to: + +```python + spec = execute_ggsql(data_source, ggsql, validated) +``` + +Note: `ggsql` here is the local parameter name (the query string) from line 151, not the module. The module import at the top of the function (`from ggsql import VegaLiteWriter, validate`) is a different scope. + +**Wait** — there's a name collision. The parameter is `ggsql: str` (line 151) and the module import is `from ggsql import ...` (line 148). The parameter shadows the module name within `visualize_query`. But `execute_ggsql` is imported at the module level, not from the `ggsql` package, so the call `execute_ggsql(data_source, ggsql, validated)` correctly passes the string parameter. This works. + +**Step 5: Update existing tests in `test_ggsql.py`** + +The existing `TestExecuteGgsql` tests call `execute_ggsql(ds, ggsql.validate(query))` with 2 args. Update them to pass 3 args. In `pkg-py/tests/test_ggsql.py`, update each call in `TestExecuteGgsql`: + +Change every occurrence of: +```python + spec = execute_ggsql(ds, ggsql.validate(query)) +``` +to: +```python + spec = execute_ggsql(ds, query, ggsql.validate(query)) +``` + +Also update the layer-level source test (line 188): +```python + execute_ggsql(ds, ggsql.validate(query)) +``` +to: +```python + execute_ggsql(ds, query, ggsql.validate(query)) +``` + +**Step 6: Run all tests** + +Run: `cd /Users/cpsievert/github/querychat && uv run pytest pkg-py/tests/test_datasource_reader.py pkg-py/tests/test_ggsql.py -v` +Expected: All pass + +**Step 7: Run type checker and linter** + +Run: `cd /Users/cpsievert/github/querychat && uv run ruff check --fix pkg-py --config pyproject.toml && make py-check-types` +Expected: No errors (fix any that arise) + +**Step 8: Commit** + +```bash +git add pkg-py/src/querychat/_viz_ggsql.py pkg-py/src/querychat/_viz_tools.py pkg-py/tests/test_ggsql.py pkg-py/tests/test_datasource_reader.py +git commit -m "feat: integrate DataSourceReader bridge into execute_ggsql" +``` + +--- + +### Task 6: Run full test suite and fix any issues + +**Files:** +- Potentially any file touched above + +**Step 1: Run full Python checks** + +Run: `cd /Users/cpsievert/github/querychat && uv run ruff check --fix pkg-py --config pyproject.toml` +Expected: Clean + +**Step 2: Run type checker** + +Run: `cd /Users/cpsievert/github/querychat && make py-check-types` +Expected: Clean + +**Step 3: Run full test suite** + +Run: `cd /Users/cpsievert/github/querychat && make py-check-tests` +Expected: All pass + +**Step 4: Fix any failures** + +If any tests fail, debug and fix. Common issues: +- sqlglot transpilation producing unexpected SQL for a specific dialect +- Empty DataFrame handling in `execute_sql` (no rows returned) +- Polars dtype mapping edge cases in `_polars_to_sql_type` + +**Step 5: Commit any fixes** + +```bash +git add -u +git commit -m "fix: address issues from full test suite run" +``` + +--- + +### Task 7: Update module docstring and clean up + +**Files:** +- Modify: `pkg-py/src/querychat/_datasource_reader.py` + +**Step 1: Ensure the module docstring is accurate** + +The docstring added in Task 3 should already be correct. Verify `_datasource_reader.py` has a clear module-level docstring explaining the bridge's purpose. + +**Step 2: Verify no dead code remains** + +Check that the old `execute_ggsql` docstring in `_viz_ggsql.py` has been updated to reflect the new 3-arg signature and bridge behavior (done in Task 5). + +**Step 3: Final commit if any cleanup was needed** + +```bash +git add -u +git commit -m "docs: clean up DataSourceReader module docstrings" +``` diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md index a02fe5f68..20785d727 100644 --- a/pkg-py/CHANGELOG.md +++ b/pkg-py/CHANGELOG.md @@ -9,16 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features +* Added a `"visualize_query"` tool that lets the LLM create inline Altair charts from natural language requests using [ggsql](https://github.com/posit-dev/ggsql) — a SQL extension for declarative data visualization. Include it via `tools=("query", "visualize_query")` (or alongside `"update"`). Charts render inline in the chat with fullscreen support, a "Show Query" toggle, and Save as PNG/SVG. Install the optional dependencies with `pip install querychat[viz]`. (#219) + +* The `querychat_query` tool now accepts an optional `collapsed` parameter. When `collapsed=True`, the result card starts collapsed so preparatory or exploratory queries don't clutter the conversation. The LLM is guided to use this automatically when running queries before a visualization. + +* Added support for Snowflake Semantic Views. When connected to Snowflake (via SQLAlchemy or Ibis), querychat automatically discovers available Semantic Views and includes their definitions in the system prompt. This helps the LLM generate correct queries using the `SEMANTIC_VIEW()` table function with certified business metrics and dimensions. (#200) + * `QueryChat()` now supports deferred chat client initialization. Pass `client=` to `server()` to provide a session-scoped chat client, enabling use cases where API credentials are only available at session time (e.g., Posit Connect managed OAuth tokens). When no `client` is specified anywhere, querychat resolves a sensible default from the `QUERYCHAT_CLIENT` environment variable (or `"openai"`). (#205) ### Improvements * When a custom `prompt_template` is provided that doesn't contain Mustache references to `{{schema}}`, the expensive `get_schema()` call is now skipped entirely. This allows users with large databases to avoid slow startup by providing their own prompt that includes schema information inline (or omits it). (#208) -### New features - -* Added support for Snowflake Semantic Views. When connected to Snowflake (via SQLAlchemy or Ibis), querychat automatically discovers available Semantic Views and includes their definitions in the system prompt. This helps the LLM generate correct queries using the `SEMANTIC_VIEW()` table function with certified business metrics and dimensions. (#200) - ## [0.5.1] - 2026-01-23 ### New features diff --git a/pkg-py/docs/_quarto.yml b/pkg-py/docs/_quarto.yml index df2576e49..7574fb787 100644 --- a/pkg-py/docs/_quarto.yml +++ b/pkg-py/docs/_quarto.yml @@ -50,6 +50,7 @@ website: - models.qmd - data-sources.qmd - context.qmd + - visualize.qmd - section: "Build custom apps" contents: - build-intro.qmd @@ -114,6 +115,8 @@ quartodoc: signature_name: short - name: tools.tool_reset_dashboard signature_name: short + - name: tools.tool_visualize_query + signature_name: short filters: - "interlinks" diff --git a/pkg-py/docs/build.qmd b/pkg-py/docs/build.qmd index 009f6cfd0..f4ff68abd 100644 --- a/pkg-py/docs/build.qmd +++ b/pkg-py/docs/build.qmd @@ -31,6 +31,14 @@ from querychat.data import titanic qc = QueryChat(titanic(), "titanic") ``` +::: {.callout-tip} +### Visualization support + +querychat supports an optional visualization tool that lets the LLM create inline charts. +Enable it by including `"visualize_query"` in the `tools` parameter. +See [Visualizations](visualize.qmd) for details. +::: + ::: {.callout-note collapse="true"} ## Quick start with `.app()` diff --git a/pkg-py/docs/images/viz-bar-chart.png b/pkg-py/docs/images/viz-bar-chart.png new file mode 100644 index 000000000..0a7033651 Binary files /dev/null and b/pkg-py/docs/images/viz-bar-chart.png differ diff --git a/pkg-py/docs/images/viz-fullscreen.png b/pkg-py/docs/images/viz-fullscreen.png new file mode 100644 index 000000000..fbefca3fb Binary files /dev/null and b/pkg-py/docs/images/viz-fullscreen.png differ diff --git a/pkg-py/docs/images/viz-scatter.png b/pkg-py/docs/images/viz-scatter.png new file mode 100644 index 000000000..db25bfe2a Binary files /dev/null and b/pkg-py/docs/images/viz-scatter.png differ diff --git a/pkg-py/docs/images/viz-show-query.png b/pkg-py/docs/images/viz-show-query.png new file mode 100644 index 000000000..fc9ae6384 Binary files /dev/null and b/pkg-py/docs/images/viz-show-query.png differ diff --git a/pkg-py/docs/index.qmd b/pkg-py/docs/index.qmd index 88b0da5c5..8d119ae3b 100644 --- a/pkg-py/docs/index.qmd +++ b/pkg-py/docs/index.qmd @@ -75,6 +75,11 @@ querychat can also handle more general questions about the data that require cal ![](/images/quickstart-summary.png){fig-alt="Screenshot of the querychat's app with a summary statistic inlined in the chat." class="lightbox shadow rounded mb-3"} +querychat can also create visualizations, powered by [ggsql](https://ggsql.org/) and [Altair](https://altair-viz.github.io/). +With the [visualization tool](visualize.qmd) enabled, ask for a chart and it appears inline in the conversation: + +![](/images/viz-bar-chart.png){fig-alt="Screenshot of querychat with an inline bar chart showing survival rate by passenger class." class="lightbox shadow rounded mb-3"} + ## Web frameworks While the examples above use [Shiny](https://shiny.posit.co/py/), querychat also supports [Streamlit](https://streamlit.io/), [Gradio](https://gradio.app/), and [Dash](https://dash.plotly.com/). Each framework has its own `QueryChat` class under the relevant sub-module, but the methods and properties are mostly consistent across all of them. diff --git a/pkg-py/docs/tools.qmd b/pkg-py/docs/tools.qmd index e438e1bde..09b46be9d 100644 --- a/pkg-py/docs/tools.qmd +++ b/pkg-py/docs/tools.qmd @@ -6,7 +6,7 @@ querychat combines [tool calling](https://posit-dev.github.io/chatlas/get-starte One important thing to understand generally about querychat's tools is they are Python functions, and that execution happens on _your machine_, not on the LLM provider's side. In other words, the SQL queries generated by the LLM are executed locally in the Python process running the app. -querychat provides the LLM access two tool groups: +querychat provides the LLM access to three tool groups: 1. **Data updating** - Filter and sort data (without sending results to the LLM). 2. **Data analysis** - Calculate summaries and return results for interpretation by the LLM. @@ -52,6 +52,40 @@ app = qc.app() ![](/images/quickstart-summary.png){fig-alt="Screenshot of the querychat's app with a summary statistic inlined in the chat." class="lightbox shadow rounded mb-3"} +## Data visualization + +When a user asks for a chart or visualization, the LLM generates a [ggsql](https://ggsql.org/) query — standard SQL extended with a `VISUALISE` clause — and requests a call to the `visualize_query` tool. +This tool: + +1. Executes the SQL portion of the query +2. Renders the `VISUALISE` clause as an Altair chart +3. Displays the chart inline in the chat + +Unlike the data updating tools, visualization queries don't affect the dashboard filter. +They query the full dataset independently, and each call produces a new inline chart message in the chat. + +The inline chart includes controls for fullscreen viewing, saving as PNG/SVG, and a "Show Query" toggle that reveals the underlying ggsql code. + +To use the visualization tool, first install the `viz` extras: + +```bash +pip install "querychat[viz]" +``` + +Then include `"visualize_query"` in the `tools` parameter (it is not enabled by default): + +```{.python filename="viz-app.py"} +from querychat import QueryChat +from querychat.data import titanic + +qc = QueryChat(titanic(), "titanic", tools=("query", "update", "visualize_query")) +app = qc.app() +``` + +![](/images/viz-scatter.png){fig-alt="Screenshot of querychat with an inline scatter plot." class="lightbox shadow rounded mb-3"} + +See [Visualizations](visualize.qmd) for more details. + ## View the source If you'd like to better understand how the tools work and how the LLM is prompted to use them, check out the following resources: @@ -65,3 +99,4 @@ If you'd like to better understand how the tools work and how the LLM is prompte - [`prompts/tool-update-dashboard.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-update-dashboard.md) - [`prompts/tool-reset-dashboard.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-reset-dashboard.md) - [`prompts/tool-query.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-query.md) +- [`prompts/tool-visualize-query.md`](https://github.com/posit-dev/querychat/blob/main/pkg-py/src/querychat/prompts/tool-visualize-query.md) diff --git a/pkg-py/docs/visualize.qmd b/pkg-py/docs/visualize.qmd new file mode 100644 index 000000000..8074869b2 --- /dev/null +++ b/pkg-py/docs/visualize.qmd @@ -0,0 +1,104 @@ +--- +title: Visualizations +lightbox: true +--- + +querychat can create charts inline in the chat. +When you ask a question that benefits from a visualization, the LLM writes a query using [ggsql](https://ggsql.org/) — a SQL-like visualization grammar — and renders an [Altair](https://altair-viz.github.io/) chart directly in the conversation. + +## Getting started + +Visualization requires two steps: + +1. **Install the `viz` extras:** + + ```bash + pip install "querychat[viz]" + ``` + +2. **Include `"visualize_query"` in the `tools` parameter:** + + ```{.python filename="app.py"} + from querychat import QueryChat + from querychat.data import titanic + + qc = QueryChat(titanic(), "titanic", tools=("query", "update", "visualize_query")) + app = qc.app() + ``` + +Ask something like "Show me survival rate by passenger class as a bar chart" and querychat will generate and display the chart inline: + +![](/images/viz-bar-chart.png){fig-alt="Bar chart showing survival rate by passenger class." class="lightbox shadow rounded mb-3"} + +## Choosing tools + +The `tools` parameter controls which capabilities the LLM has access to. +By default, only `"query"` and `"update"` are enabled — visualization must be opted into explicitly. + +To enable only query and visualization (no dashboard filtering): + +```{.python} +qc = QueryChat(titanic(), "titanic", tools=("query", "visualize_query")) +``` + +See [Tools](tools.qmd) for a full reference on available tools and what each one does. + +## Custom apps + +The example below shows a minimal custom Shiny app using only the `"query"` and `"visualize_query"` tools. +It omits `"update"` to focus entirely on data analysis and visualization rather than dashboard filtering: + +```{.python filename="app.py"} +{{< include /../examples/10-viz-app.py >}} +``` + +## What you can ask for + +querychat can generate a wide range of chart types. +Some example prompts: + +- "Show me a bar chart of survival rate by passenger class" +- "Scatter plot of age vs fare, colored by survival" +- "Line chart of average fare over time" +- "Histogram of passenger ages" +- "Facet survival rate by class and sex" + +The LLM chooses an appropriate chart type based on your question, but you can always be specific. +If you ask for a bar chart, you'll get a bar chart. + +![](/images/viz-scatter.png){fig-alt="Scatter plot of age vs fare colored by survival status." class="lightbox shadow rounded mb-3"} + +::: {.callout-tip} +If you don't like the chart, ask the LLM to adjust it — for example, "make the dots bigger" or "use a log scale on the y-axis". +::: + +## Chart controls + +Each chart has controls in its footer: + +**Fullscreen** — Click the expand icon to view the chart in fullscreen mode. + +![](/images/viz-fullscreen.png){fig-alt="A chart displayed in fullscreen mode." class="lightbox shadow rounded mb-3"} + +**Save** — Download the chart as a PNG or SVG file. + +**Show Query** — Expand the footer to see the ggsql query used to generate the chart. + +![](/images/viz-show-query.png){fig-alt="A chart with the Show Query footer expanded, showing the ggsql query." class="lightbox shadow rounded mb-3"} + +## How it works + +1. **The LLM generates a ggsql query** — a SQL-like grammar that describes both data transformation and visual encoding in a single statement. +2. **The SQL is executed** — querychat runs the data portion of the query against your data source locally. +3. **The VISUALISE clause is rendered** — the result is passed to Altair, which produces a Vega-Lite chart specification. +4. **The chart appears inline** — the chart is streamed back into the conversation as an interactive widget. + +Note that visualization queries are independent of any active dashboard filter set by the `update` tool. +They always run against the full dataset. + +Learn more about the ggsql grammar at [ggsql.org](https://ggsql.org/). + +## See also + +- [Tools](tools.qmd) — Understand what querychat can do under the hood +- [Provide context](context.qmd) — Help the LLM understand your data better diff --git a/pkg-py/examples/10-viz-app.py b/pkg-py/examples/10-viz-app.py new file mode 100644 index 000000000..5a3af3356 --- /dev/null +++ b/pkg-py/examples/10-viz-app.py @@ -0,0 +1,26 @@ +from querychat import QueryChat +from querychat.data import titanic + +from shiny import App, ui + +# Omits "update" tool — this demo focuses on query + visualization only +qc = QueryChat( + titanic(), + "titanic", + tools=("query", "visualize_query"), +) + +#def app_ui(request): +# return ui.page_fillable( +# qc.ui(), +# ) +# +# +#def server(input, output, session): +# qc.server(enable_bookmarking=True) +# +# +#app = App(app_ui, server, bookmark_store="url") + + +app = qc.app() \ No newline at end of file diff --git a/pkg-py/src/querychat/_datasource_reader.py b/pkg-py/src/querychat/_datasource_reader.py new file mode 100644 index 000000000..4ff050fa8 --- /dev/null +++ b/pkg-py/src/querychat/_datasource_reader.py @@ -0,0 +1,160 @@ +"""DataSourceReader bridge: routes ggsql's reader protocol through a real database.""" + +from __future__ import annotations + +import contextlib +import logging +from typing import TYPE_CHECKING + +import polars as pl +import sqlglot +from sqlalchemy import text + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + +logger = logging.getLogger(__name__) + +SQLGLOT_DIALECTS: dict[str, str] = { + # Built-in SQLAlchemy dialects + "postgresql": "postgres", + "mysql": "mysql", + "sqlite": "sqlite", + "mssql": "tsql", + "oracle": "oracle", + # Third-party SQLAlchemy dialects (dialect name verified via engine.dialect.name) + "duckdb": "duckdb", + "snowflake": "snowflake", + "bigquery": "bigquery", + "redshift": "redshift", + "trino": "trino", + "databricks": "databricks", + "clickhousedb": "clickhouse", # clickhouse-connect + "clickhouse": "clickhouse", # clickhouse-sqlalchemy + "awsathena": "athena", # PyAthena + "teradatasql": "teradata", # teradatasqlalchemy + "exasol": "exasol", + "doris": "doris", + "singlestoredb": "singlestore", + "risingwave": "risingwave", + "druid": "druid", + "hive": "hive", # PyHive; also covers Spark via hive:// + "presto": "presto", +} + + +def register_sqlglot_dialect(sqlalchemy_name: str, sqlglot_name: str) -> None: + """ + Register a custom SQLAlchemy dialect name to sqlglot dialect mapping. + + Use this if your database's SQLAlchemy driver reports a ``dialect.name`` + that isn't in the built-in mapping. + + Parameters + ---------- + sqlalchemy_name + The value of ``engine.dialect.name`` for your database. + sqlglot_name + The corresponding sqlglot dialect identifier (see + ``sqlglot.dialects.dialect.Dialect.classes`` for valid names). + + """ + SQLGLOT_DIALECTS[sqlalchemy_name] = sqlglot_name + + +def transpile_sql(sql: str, dialect: str) -> str: + """Transpile generic SQL to a target dialect using sqlglot.""" + results = sqlglot.transpile(sql, read="", write=dialect) + return results[0] + + +class DataSourceReader: + """ + ggsql reader protocol implementation that routes SQL through a real database. + + Implements execute_sql(), register(), and unregister() as expected by + ggsql's PyReaderBridge. + """ + + def __init__(self, engine: Engine, dialect: str): + self._engine = engine + self._dialect = dialect + self._conn: Connection | None = None + self._registered: list[str] = [] + + def __enter__(self): + self._conn = self._engine.connect() + return self + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> bool: + if self._conn is not None: + try: + for name in self._registered: + with contextlib.suppress(Exception): + self._conn.execute(text(f"DROP TABLE IF EXISTS {name}")) + self._conn.commit() + finally: + self._conn.close() + self._conn = None + self._registered.clear() + return False + + def execute_sql(self, sql: str) -> pl.DataFrame: + if self._conn is None: + raise RuntimeError("DataSourceReader must be used as a context manager") + transpiled = transpile_sql(sql, self._dialect) + result = self._conn.execute(text(transpiled)) + rows = result.fetchall() + columns = list(result.keys()) + if not rows: + return pl.DataFrame(schema=dict.fromkeys(columns, pl.Utf8)) + data = {col: [row[i] for row in rows] for i, col in enumerate(columns)} + return pl.DataFrame(data) + + def register(self, name: str, df: pl.DataFrame, replace: bool = True) -> None: # noqa: FBT001, FBT002 + if self._conn is None: + raise RuntimeError("DataSourceReader must be used as a context manager") + if replace: + self._conn.execute(text(f"DROP TABLE IF EXISTS {name}")) + if name in self._registered: + self._registered.remove(name) + + col_defs = ", ".join( + f"{col} {polars_to_sql_type(dtype)}" + for col, dtype in zip(df.columns, df.dtypes, strict=True) + ) + create_sql = f"CREATE TEMPORARY TABLE {name} ({col_defs})" + transpiled_create = transpile_sql(create_sql, self._dialect) + self._conn.execute(text(transpiled_create)) + self._registered.append(name) + + if len(df) > 0: + placeholders = ", ".join(f":{col}" for col in df.columns) + insert_sql = f"INSERT INTO {name} VALUES ({placeholders})" + rows = df.to_dicts() + self._conn.execute(text(insert_sql), rows) + + self._conn.commit() + + def unregister(self, name: str) -> None: + if self._conn is None: + raise RuntimeError("DataSourceReader must be used as a context manager") + self._conn.execute(text(f"DROP TABLE IF EXISTS {name}")) + self._conn.commit() + if name in self._registered: + self._registered.remove(name) + + +def polars_to_sql_type(dtype: pl.DataType) -> str: + """Map polars dtypes to generic SQL types for CREATE TABLE.""" + if dtype.is_integer(): + return "INTEGER" + if dtype.is_float(): + return "REAL" + if dtype == pl.Boolean: + return "BOOLEAN" + if dtype == pl.Date: + return "DATE" + if dtype in (pl.Datetime, pl.Duration): + return "TIMESTAMP" + return "TEXT" diff --git a/pkg-py/src/querychat/_icons.py b/pkg-py/src/querychat/_icons.py index 2b7683da0..fc484c9c0 100644 --- a/pkg-py/src/querychat/_icons.py +++ b/pkg-py/src/querychat/_icons.py @@ -2,19 +2,35 @@ from shiny import ui -ICON_NAMES = Literal["arrow-counterclockwise", "funnel-fill", "terminal-fill", "table"] +ICON_NAMES = Literal[ + "arrow-counterclockwise", + "bar-chart-fill", + "chevron-down", + "download", + "funnel-fill", + "graph-up", + "terminal-fill", + "table", +] -def bs_icon(name: ICON_NAMES) -> ui.HTML: +def bs_icon(name: ICON_NAMES, cls: str = "") -> ui.HTML: """Get Bootstrap icon SVG by name.""" if name not in BS_ICONS: raise ValueError(f"Unknown Bootstrap icon: {name}") - return ui.HTML(BS_ICONS[name]) + svg = BS_ICONS[name] + if cls: + svg = svg.replace('class="', f'class="{cls} ', 1) + return ui.HTML(svg) BS_ICONS = { "arrow-counterclockwise": '', + "bar-chart-fill": '', + "chevron-down": '', + "download": '', "funnel-fill": '', + "graph-up": '', "terminal-fill": '', "table": '', } diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index d3bf29e26..25e841971 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -23,11 +23,13 @@ from ._shiny_module import GREETING_PROMPT from ._system_prompt import QueryChatSystemPrompt from ._utils import MISSING, MISSING_TYPE, is_ibis_table +from ._viz_utils import has_viz_deps, has_viz_tool from .tools import ( UpdateDashboardData, tool_query, tool_reset_dashboard, tool_update_dashboard, + tool_visualize_query, ) if TYPE_CHECKING: @@ -35,8 +37,10 @@ from narwhals.stable.v1.typing import IntoFrame -TOOL_GROUPS = Literal["update", "query"] + from ._viz_tools import VisualizeQueryData +TOOL_GROUPS = Literal["update", "query", "visualize_query"] +DEFAULT_TOOLS: tuple[TOOL_GROUPS, ...] = ("update", "query") class QueryChatBase(Generic[IntoFrameT]): """ @@ -58,7 +62,7 @@ def __init__( *, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -72,7 +76,7 @@ def __init__( "Table name must begin with a letter and contain only letters, numbers, and underscores", ) - self.tools = normalize_tools(tools, default=("update", "query")) + self.tools = normalize_tools(tools, default=DEFAULT_TOOLS) self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting # Store init parameters for deferred system prompt building @@ -128,6 +132,7 @@ def _create_session_client( tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, reset_dashboard: Callable[[], None] | None = None, + visualize_query: Callable[[VisualizeQueryData], None] | None = None, ) -> chatlas.Chat: """Create a fresh, fully-configured Chat.""" spec = self._client_spec if isinstance(client_spec, MISSING_TYPE) else client_spec @@ -152,6 +157,10 @@ def _create_session_client( if "query" in resolved_tools: chat.register_tool(tool_query(data_source)) + if "visualize_query" in resolved_tools: + query_viz_fn = visualize_query or (lambda _: None) + chat.register_tool(tool_visualize_query(data_source, query_viz_fn)) + return chat def client( @@ -160,6 +169,7 @@ def client( tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE = MISSING, update_dashboard: Callable[[UpdateDashboardData], None] | None = None, reset_dashboard: Callable[[], None] | None = None, + visualize_query: Callable[[VisualizeQueryData], None] | None = None, ) -> chatlas.Chat: """ Create a chat client with registered tools. @@ -167,11 +177,14 @@ def client( Parameters ---------- tools - Which tools to include: `"update"`, `"query"`, or both. + Which tools to include: `"update"`, `"query"`, `"visualize_query"`, + or a combination. update_dashboard Callback when update_dashboard tool succeeds. reset_dashboard Callback when reset_dashboard tool is invoked. + visualize_query + Callback when visualize_query tool succeeds. Returns ------- @@ -184,6 +197,7 @@ def client( tools=tools, update_dashboard=update_dashboard, reset_dashboard=reset_dashboard, + visualize_query=visualize_query, ) def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str: @@ -293,14 +307,24 @@ def create_client(client: str | chatlas.Chat | None) -> chatlas.Chat: def normalize_tools( tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None | MISSING_TYPE, default: tuple[TOOL_GROUPS, ...] | None, + *, + check_deps: bool = True, ) -> tuple[TOOL_GROUPS, ...] | None: if tools is None or tools == (): - return None + result = None elif isinstance(tools, MISSING_TYPE): - return default + result = default elif isinstance(tools, str): - return (tools,) + result = (tools,) elif isinstance(tools, tuple): - return tools + result = tools else: - return tuple(tools) + result = tuple(tools) + if not check_deps: + return result + if has_viz_tool(result) and not has_viz_deps(): + raise ImportError( + "Visualization tools require ggsql, altair, shinywidgets, and " + "vl-convert-python. Install them with: pip install querychat[viz]" + ) + return result diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py index af0685e01..1dd132631 100644 --- a/pkg-py/src/querychat/_querychat_core.py +++ b/pkg-py/src/querychat/_querychat_core.py @@ -165,6 +165,8 @@ def format_tool_result(result: ContentToolResult) -> str: return str(result) + + def format_query_error(e: Exception) -> str: """Format a query error with helpful guidance.""" error_msg = str(e).lower() diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index c25b923fc..d49aa11a2 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -10,9 +10,10 @@ from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui from ._icons import bs_icon -from ._querychat_base import TOOL_GROUPS, QueryChatBase +from ._querychat_base import DEFAULT_TOOLS, TOOL_GROUPS, QueryChatBase from ._shiny_module import ServerValues, mod_server, mod_ui from ._utils import MISSING, MISSING_TYPE, as_narwhals +from ._viz_utils import has_viz_tool if TYPE_CHECKING: from pathlib import Path @@ -97,10 +98,11 @@ class QueryChat(QueryChatBase[IntoFrameT]): tools Which querychat tools to include in the chat client by default. Can be: - A single tool string: `"update"` or `"query"` - - A tuple of tools: `("update", "query")` + - A tuple of tools: `("update", "query", "visualize_query")` - `None` or `()` to disable all tools - Default is `("update", "query")` (both tools enabled). + Default is `("update", "query")`. The visualization tool (`"visualize_query"`) + can be opted into by including it in the tuple. Set to `"update"` to prevent the LLM from accessing data values, only allowing dashboard filtering without answering questions. @@ -156,7 +158,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -172,7 +174,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -188,7 +190,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -204,7 +206,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -219,7 +221,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, - tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -245,7 +247,7 @@ def app( """ Quickly chat with a dataset. - Creates a Shiny app with a chat sidebar and data table view -- providing a + Creates a Shiny app with a chat sidebar and data view -- providing a quick-and-easy way to start chatting with your data. Parameters @@ -301,6 +303,7 @@ def app_server(input: Inputs, output: Outputs, session: Session): greeting=self.greeting, client=self._create_session_client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @render.text @@ -399,7 +402,7 @@ def ui(self, *, id: Optional[str] = None, **kwargs): A UI component. """ - return mod_ui(id or self.id, **kwargs) + return mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), **kwargs) def server( self, @@ -506,6 +509,7 @@ def create_session_client(**kwargs) -> chatlas.Chat: greeting=self.greeting, client=create_session_client, enable_bookmarking=enable_bookmarking, + tools=self.tools, ) @@ -616,6 +620,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = ("update", "query"), data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -632,6 +637,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -648,6 +654,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -664,6 +671,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -680,6 +688,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -695,6 +704,7 @@ def __init__( id: Optional[str] = None, greeting: Optional[str | Path] = None, client: Optional[str | chatlas.Chat] = None, + tools: TOOL_GROUPS | tuple[TOOL_GROUPS, ...] | None = DEFAULT_TOOLS, data_description: Optional[str | Path] = None, categorical_threshold: int = 20, extra_instructions: Optional[str | Path] = None, @@ -714,6 +724,7 @@ def __init__( table_name, greeting=greeting, client=client, + tools=tools, data_description=data_description, categorical_threshold=categorical_threshold, extra_instructions=extra_instructions, @@ -743,6 +754,7 @@ def __init__( greeting=self.greeting, client=self._create_session_client, enable_bookmarking=enable, + tools=self.tools, ) def sidebar( @@ -804,7 +816,7 @@ def ui(self, *, id: Optional[str] = None, **kwargs): A UI component. """ - return mod_ui(id or self.id, **kwargs) + return mod_ui(id or self.id, preload_viz=has_viz_tool(self.tools), **kwargs) def df(self) -> IntoFrameT: """ diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py index 4264285bd..49ec7fa6f 100644 --- a/pkg-py/src/querychat/_shiny_module.py +++ b/pkg-py/src/querychat/_shiny_module.py @@ -1,10 +1,9 @@ from __future__ import annotations -import copy import warnings from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Generic, Union +from typing import TYPE_CHECKING, Generic, TypedDict, Union import chatlas import shinychat @@ -13,7 +12,9 @@ from shiny import module, reactive, ui from ._querychat_core import GREETING_PROMPT -from .tools import tool_query, tool_reset_dashboard, tool_update_dashboard +from ._viz_altair_widget import AltairWidget +from ._viz_ggsql import execute_ggsql +from ._viz_utils import has_viz_tool, preload_viz_deps_server, preload_viz_deps_ui if TYPE_CHECKING: from collections.abc import Callable @@ -23,6 +24,8 @@ from shiny import Inputs, Outputs, Session from ._datasource import DataSource + from ._querychat_base import TOOL_GROUPS + from ._viz_tools import VisualizeQueryData from .types import UpdateDashboardData ReactiveString = reactive.Value[str] @@ -30,11 +33,31 @@ ReactiveStringOrNone = reactive.Value[Union[str, None]] """A reactive string (or None) value.""" + +class VizWidgetEntry(TypedDict): + """A bookmarked visualization widget: enough state to re-register on restore.""" + + widget_id: str + ggsql: str + + CHAT_ID = "chat" +class _DeferredStubChatClient: + """Placeholder chat client for deferred stub sessions.""" + + def __getattr__(self, _name: str): + raise RuntimeError( + "Chat client is unavailable during stub session before data_source is set." + ) + + +ServerClient = chatlas.Chat | _DeferredStubChatClient + + @module.ui -def mod_ui(**kwargs): +def mod_ui(*, preload_viz: bool = False, **kwargs): css_path = Path(__file__).parent / "static" / "css" / "styles.css" js_path = Path(__file__).parent / "static" / "js" / "querychat.js" @@ -47,6 +70,7 @@ def mod_ui(**kwargs): ui.include_js(js_path), ), tag, + preload_viz_deps_ui() if preload_viz else None, ) @@ -76,18 +100,17 @@ class ServerValues(Generic[IntoFrameT]): `.title()`, or set it with `.title.set("...")`. Returns `None` if no title has been set. client - The session-specific chat client instance. This is a deep copy of the - base client configured for this specific session, containing the chat - history and tool registrations for this session only. This may be - `None` during stub sessions when the client depends on deferred, - session-scoped state. + Session chat client value. + For real sessions this is a `chatlas.Chat` created by the client + factory. For deferred stub sessions (where `data_source` is not set + yet), this is a placeholder client that raises when accessed. """ df: Callable[[], IntoFrameT] sql: ReactiveStringOrNone title: ReactiveStringOrNone - client: chatlas.Chat | None + client: ServerClient @module.server @@ -98,14 +121,39 @@ def mod_server( *, data_source: DataSource[IntoFrameT] | None, greeting: str | None, - client: chatlas.Chat | Callable, + client: Callable[..., chatlas.Chat], enable_bookmarking: bool, + tools: tuple[TOOL_GROUPS, ...] | None = None, ) -> ServerValues[IntoFrameT]: # Reactive values to store state sql = ReactiveStringOrNone(None) title = ReactiveStringOrNone(None) has_greeted = reactive.value[bool](False) # noqa: FBT003 + if not callable(client): + raise TypeError("mod_server() requires a callable client factory.") + + def update_dashboard(data: UpdateDashboardData): + sql.set(data["query"]) + title.set(data["title"]) + + def reset_dashboard(): + sql.set(None) + title.set(None) + + viz_widgets: list[VizWidgetEntry] = [] + + def on_visualize(data: VisualizeQueryData): + viz_widgets.append({"widget_id": data["widget_id"], "ggsql": data["ggsql"]}) + + def build_chat_client() -> chatlas.Chat: + return client( + update_dashboard=update_dashboard, + reset_dashboard=reset_dashboard, + visualize_query=on_visualize, + tools=tools, + ) + # Short-circuit for stub sessions (e.g. 1st run of an Express app) # data_source may be None during stub session for deferred pattern if session.is_stub_session(): @@ -113,11 +161,15 @@ def mod_server( def _stub_df(): raise RuntimeError("RuntimeError: No current reactive context") + stub_client = ( + _DeferredStubChatClient() if data_source is None else build_chat_client() + ) + return ServerValues( df=_stub_df, sql=sql, title=title, - client=client if isinstance(client, chatlas.Chat) else None, + client=stub_client, ) # Real session requires data_source @@ -127,27 +179,11 @@ def _stub_df(): "Set it via the data_source property before users connect." ) - def update_dashboard(data: UpdateDashboardData): - sql.set(data["query"]) - title.set(data["title"]) - - def reset_dashboard(): - sql.set(None) - title.set(None) - - # Set up the chat object for this session - # Support both a callable that creates a client and legacy instance pattern - if callable(client) and not isinstance(client, chatlas.Chat): - chat = client( - update_dashboard=update_dashboard, reset_dashboard=reset_dashboard - ) - else: - # Legacy pattern: client is Chat instance - chat = copy.deepcopy(client) + # Build the session-specific chat client through QueryChat.client(...). + chat = build_chat_client() - chat.register_tool(tool_update_dashboard(data_source, update_dashboard)) - chat.register_tool(tool_query(data_source)) - chat.register_tool(tool_reset_dashboard(reset_dashboard)) + if has_viz_tool(tools): + preload_viz_deps_server() # Execute query when SQL changes @reactive.calc @@ -211,6 +247,8 @@ def _on_bookmark(x: BookmarkState) -> None: vals["querychat_sql"] = sql.get() vals["querychat_title"] = title.get() vals["querychat_has_greeted"] = has_greeted.get() + if viz_widgets: + vals["querychat_viz_widgets"] = viz_widgets @session.bookmark.on_restore def _on_restore(x: RestoreState) -> None: @@ -221,9 +259,44 @@ def _on_restore(x: RestoreState) -> None: title.set(vals["querychat_title"]) if "querychat_has_greeted" in vals: has_greeted.set(vals["querychat_has_greeted"]) + if "querychat_viz_widgets" in vals: + restored = restore_viz_widgets( + data_source, vals["querychat_viz_widgets"] + ) + viz_widgets[:] = restored return ServerValues(df=filtered_df, sql=sql, title=title, client=chat) class GreetWarning(Warning): """Warning raised when no greeting is provided to QueryChat.""" + + +def restore_viz_widgets( + data_source: DataSource[IntoFrameT], + saved_widgets: list[VizWidgetEntry], +) -> list[VizWidgetEntry]: + """Re-execute ggsql queries, register widgets, and return restored entries.""" + from ggsql import validate + from shinywidgets import register_widget + + restored: list[VizWidgetEntry] = [] + + for entry in saved_widgets: + widget_id = entry["widget_id"] + ggsql_str = entry["ggsql"] + try: + validated = validate(ggsql_str) + spec = execute_ggsql(data_source, ggsql_str, validated) + altair_widget = AltairWidget.from_ggsql(spec, widget_id=widget_id) + register_widget(widget_id, altair_widget.widget) + restored.append(entry) + except Exception: + # If a query fails on restore (e.g. data changed), skip it. + # The placeholder will remain empty but the rest of the chat restores. + warnings.warn( + f"Failed to restore visualization widget '{widget_id}' on bookmark restore.", + stacklevel=2, + ) + + return restored diff --git a/pkg-py/src/querychat/_system_prompt.py b/pkg-py/src/querychat/_system_prompt.py index 5a8445e93..a18630153 100644 --- a/pkg-py/src/querychat/_system_prompt.py +++ b/pkg-py/src/querychat/_system_prompt.py @@ -6,6 +6,8 @@ import chevron +from ._viz_utils import has_viz_tool + _SCHEMA_TAG_RE = re.compile(r"\{\{[{#^/]?\s*schema\b") if TYPE_CHECKING: @@ -83,7 +85,14 @@ def render(self, tools: tuple[TOOL_GROUPS, ...] | None) -> str: "extra_instructions": self.extra_instructions, "has_tool_update": "update" in tools if tools else False, "has_tool_query": "query" in tools if tools else False, + "has_tool_visualize_query": has_viz_tool(tools), "include_query_guidelines": len(tools or ()) > 0, } - return chevron.render(self.template, context) + prompts_dir = str(Path(__file__).parent / "prompts") + return chevron.render( + self.template, + context, + partials_path=prompts_dir, + partials_ext="md", + ) diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py index 555e8e376..6d4c803d8 100644 --- a/pkg-py/src/querychat/_utils.py +++ b/pkg-py/src/querychat/_utils.py @@ -4,8 +4,10 @@ import re import warnings from contextlib import contextmanager +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, Optional, overload +import chevron import narwhals.stable.v1 as nw from great_tables import GT @@ -14,6 +16,50 @@ import ibis import pandas as pd + import polars as pl + from narwhals.stable.v1.typing import IntoFrame + + +_SCHEMA_DUMP_PATTERN = re.compile( + r"^\s*[\{\[]|'additionalProperties'|\"additionalProperties\"", +) + + +def truncate_error(error_msg: str, max_chars: int = 500) -> str: + if len(error_msg) <= max_chars: + return error_msg + + lines = error_msg.split("\n") + meaningful: list[str] = [] + truncated_by_schema = False + for line in lines: + if not line.strip(): + truncated_by_schema = True + break + if _SCHEMA_DUMP_PATTERN.search(line): + truncated_by_schema = True + break + meaningful.append(line) + + if truncated_by_schema and meaningful: + prefix = "\n".join(meaningful) + if len(prefix) > max_chars: + cut = prefix[:max_chars] + last_space = cut.rfind(" ") + if last_space > max_chars // 2: + cut = cut[:last_space] + prefix = cut + return prefix.rstrip() + "\n\n(error truncated)" + + # No schema markers found (or nothing before them) — apply hard cap if needed + if len(error_msg) <= max_chars: + return error_msg + + truncated = error_msg[:max_chars] + last_space = truncated.rfind(" ") + if last_space > max_chars // 2: + truncated = truncated[:last_space] + return truncated.rstrip() + "\n\n(error truncated)" class MISSING_TYPE: # noqa: N801 @@ -171,14 +217,18 @@ def get_tool_details_setting() -> Optional[Literal["expanded", "collapsed", "def return setting_lower -def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> bool: +def querychat_tool_starts_open( + action: Literal[ + "update", "query", "reset", "visualize_query" + ], +) -> bool: """ Determine whether a tool card should be open based on action and setting. Parameters ---------- action : str - The action type ('update', 'query', or 'reset') + The action type ('update', 'query', 'reset', or 'visualize_query') Returns ------- @@ -290,3 +340,15 @@ def df_to_html(df, maxrows: int = 5) -> str: table_html += f"\n\n*(Showing {maxrows} of {nrow_full} rows)*\n" return table_html + + +def to_polars(data: IntoFrame) -> pl.DataFrame: + """Convert any narwhals-compatible frame to a polars DataFrame.""" + return as_narwhals(data).to_polars() + + +def read_prompt_template(filename: str, **kwargs: object) -> str: + """Read and interpolate a prompt template file.""" + template_path = Path(__file__).parent / "prompts" / filename + template = template_path.read_text() + return chevron.render(template, kwargs) diff --git a/pkg-py/src/querychat/_viz_altair_widget.py b/pkg-py/src/querychat/_viz_altair_widget.py new file mode 100644 index 000000000..00d40347d --- /dev/null +++ b/pkg-py/src/querychat/_viz_altair_widget.py @@ -0,0 +1,187 @@ +"""Altair chart wrapper for responsive display in Shiny.""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING, Any, cast +from uuid import uuid4 + +from shiny.session import get_current_session + +from shiny import reactive + +if TYPE_CHECKING: + import altair as alt + import ggsql + +class AltairWidget: + """ + An Altair chart wrapped in ``alt.JupyterChart`` for display in Shiny. + + Always produces a ``JupyterChart`` so that ``shinywidgets`` receives + a consistent widget type and doesn't call ``chart.properties(width=...)`` + (which fails on compound specs). + + Simple charts use native ``width/height: "container"`` sizing. + Compound charts (facet, concat) get calculated cell dimensions + that are reactively updated when the output container resizes. + """ + + widget: alt.JupyterChart + widget_id: str + + def __init__( + self, + chart: alt.TopLevelMixin, + *, + widget_id: str | None = None, + ) -> None: + import altair as alt + + is_compound = isinstance( + chart, + (alt.FacetChart, alt.ConcatChart, alt.HConcatChart, alt.VConcatChart), + ) + + # Workaround: Vega-Lite's width/height: "container" doesn't work for + # compound specs (facet, concat, etc.), so we inject pixel dimensions + # and reconstruct the chart. Remove this branch when ggsql handles it + # natively: https://github.com/posit-dev/ggsql/issues/238 + if is_compound: + chart = fit_chart_to_container( + chart, DEFAULT_COMPOUND_WIDTH, DEFAULT_COMPOUND_HEIGHT + ) + else: + chart = chart.properties(width="container", height="container") + + self.widget = alt.JupyterChart(chart) + self.widget_id = widget_id or f"querychat_viz_{uuid4().hex[:8]}" + + # Reactively update compound cell sizes when the container resizes. + # Also part of the compound sizing workaround (issue #238). + if is_compound: + self._setup_reactive_sizing(self.widget, self.widget_id) + + @classmethod + def from_ggsql( + cls, spec: ggsql.Spec, *, widget_id: str | None = None + ) -> AltairWidget: + from ggsql import VegaLiteWriter + + writer = VegaLiteWriter() + return cls(writer.render_chart(spec), widget_id=widget_id) + + @staticmethod + def _setup_reactive_sizing(widget: alt.JupyterChart, widget_id: str) -> None: + session = get_current_session() + if session is None: + return + + @reactive.effect + def _sizing_effect(): + width = session.clientdata.output_width(widget_id) + height = session.clientdata.output_height(widget_id) + if width is None or height is None: + return + chart = widget.chart + if chart is None: + return + chart = cast("alt.Chart", chart) + chart2 = fit_chart_to_container(chart, int(width), int(height)) + # Must set widget.spec (a new dict) rather than widget.chart, + # because traitlets won't fire change events when the same + # chart object is assigned back after in-place mutation. + widget.spec = chart2.to_dict() + + # Clean up the effect when the session ends to avoid memory leaks + session.on_ended(_sizing_effect.destroy) + + +# --------------------------------------------------------------------------- +# Compound chart sizing helpers +# +# Vega-Lite's `width/height: "container"` doesn't work for compound specs +# (facet, concat, etc.), so we manually inject cell dimensions. Ideally ggsql +# will handle this natively: https://github.com/posit-dev/ggsql/issues/238 +# --------------------------------------------------------------------------- + +DEFAULT_COMPOUND_WIDTH = 900 +DEFAULT_COMPOUND_HEIGHT = 450 + +LEGEND_CHANNELS = frozenset( + {"color", "fill", "stroke", "shape", "size", "opacity"} +) +LEGEND_WIDTH = 120 # approximate space for a right-side legend + + +def fit_chart_to_container( + chart: alt.TopLevelMixin, + container_width: int, + container_height: int, +) -> alt.TopLevelMixin: + """ + Return a copy of ``chart`` with cell ``width``/``height`` set. + + The original chart is never mutated. + + For faceted charts, divides the container width by the number of columns. + For hconcat/concat, divides by the number of sub-specs. + For vconcat, each sub-spec gets the full width. + + Subtracts padding estimates so the rendered cells fill the container, + including space for legends when present. + """ + import altair as alt + + chart = copy.deepcopy(chart) + + # Approximate padding; will be replaced when ggsql handles compound sizing + # natively (https://github.com/posit-dev/ggsql/issues/238). + padding_x = 80 # y-axis labels + title padding + padding_y = 120 # facet headers, x-axis labels + title, bottom padding + if has_legend(chart.to_dict()): + padding_x += LEGEND_WIDTH + usable_w = max(container_width - padding_x, 100) + usable_h = max(container_height - padding_y, 100) + + if isinstance(chart, alt.FacetChart): + ncol = chart.columns if isinstance(chart.columns, int) else 1 + cell_w = usable_w // max(ncol, 1) + chart.spec.width = cell_w + chart.spec.height = usable_h + elif isinstance(chart, alt.HConcatChart): + cell_w = usable_w // max(len(chart.hconcat), 1) + for sub in chart.hconcat: + sub.width = cell_w + sub.height = usable_h + elif isinstance(chart, alt.ConcatChart): + ncol = chart.columns if isinstance(chart.columns, int) else len(chart.concat) + cell_w = usable_w // max(ncol, 1) + for sub in chart.concat: + sub.width = cell_w + sub.height = usable_h + elif isinstance(chart, alt.VConcatChart): + cell_h = usable_h // max(len(chart.vconcat), 1) + for sub in chart.vconcat: + sub.width = usable_w + sub.height = cell_h + + return chart + + +def has_legend(vl: dict[str, object]) -> bool: + """Check if any encoding in the VL spec uses a legend-producing channel with a field.""" + specs: list[dict[str, Any]] = [] + if "spec" in vl: + specs.append(vl["spec"]) # type: ignore[arg-type] + for key in ("hconcat", "vconcat", "concat"): + if key in vl: + specs.extend(vl[key]) # type: ignore[arg-type] + + for spec in specs: + for layer in spec.get("layer", [spec]): # type: ignore[union-attr] + enc = layer.get("encoding", {}) # type: ignore[union-attr] + for ch in LEGEND_CHANNELS: + if ch in enc and "field" in enc[ch]: # type: ignore[operator] + return True + return False diff --git a/pkg-py/src/querychat/_viz_ggsql.py b/pkg-py/src/querychat/_viz_ggsql.py new file mode 100644 index 000000000..710e6f46b --- /dev/null +++ b/pkg-py/src/querychat/_viz_ggsql.py @@ -0,0 +1,194 @@ +""" +Helpers for executing ggsql queries in querychat. + +Architecture overview +--------------------- +Querychat executes ggsql queries through two possible paths: + +1. **Bridge path** (SQLAlchemySource with known dialect) — A + ``DataSourceReader`` implements ggsql's reader protocol, routing all SQL + through the real database. ggsql runs its full pipeline (CTEs, stat + transforms, layer queries) against the real DB. sqlglot transpiles + ggsql's ANSI-generated SQL to the target dialect. This path supports + multi-source layers and avoids pulling large result sets into memory. + +2. **Fallback path** (all other DataSource types, or bridge failure) — The + SQL portion (before VISUALISE) runs on the real database via + ``DataSource.execute_query()``, then the VISUALISE portion replays + locally against the SQL result using ``ggsql.DuckDBReader``. + +The fallback path requires reconstructing a valid ggsql query from the +split ``sql()`` and ``visual()`` parts. See ``execute_two_phase()`` for +details on the two VISUALISE forms (Form A and Form B). + +Limitation of fallback path: layer-specific sources +---------------------------------------------------- +ggsql supports per-layer data sources (``DRAW line MAPPING … FROM cte``), +but the fallback path can't support them because the SQL result is a single +DataFrame — CTEs don't survive the DataSource boundary. The bridge path +handles this correctly. +""" + +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING + +from ._utils import to_polars + +if TYPE_CHECKING: + import ggsql + + from ._datasource import DataSource + +logger = logging.getLogger(__name__) + + +def execute_ggsql( + data_source: DataSource, + query: str, + validated: ggsql.Validated, +) -> ggsql.Spec: + """ + Execute a ggsql query, choosing the bridge or fallback path. + + Parameters + ---------- + data_source + The querychat DataSource to execute against. + query + The original ggsql query string (needed for the bridge path). + validated + A pre-validated ggsql query (from ``ggsql.validate()``). + + Returns + ------- + ggsql.Spec + The writer-independent plot specification. + + """ + from ._datasource import SQLAlchemySource + from ._datasource_reader import SQLGLOT_DIALECTS, DataSourceReader + + if isinstance(data_source, SQLAlchemySource): + sa_dialect_name = data_source._engine.dialect.name + dialect = SQLGLOT_DIALECTS.get(sa_dialect_name) + if dialect is None: + logger.warning( + "Unknown SQLAlchemy dialect %r — falling back to two-phase execution. " + "You can register it via: " + "from querychat._datasource_reader import register_sqlglot_dialect", + sa_dialect_name, + ) + if dialect is not None: + try: + with DataSourceReader(data_source._engine, dialect) as reader: + import ggsql as _ggsql + + return _ggsql.execute(query, reader) + except Exception: + logger.debug( + "DataSourceReader bridge failed, falling back to two-phase", + exc_info=True, + ) + + return execute_two_phase(data_source, validated) + + +def execute_two_phase( + data_source: DataSource, + validated: ggsql.Validated, +) -> ggsql.Spec: + """ + Execute a ggsql query using the two-phase approach (fallback path). + + Phase 1: execute SQL on the real database. + Phase 2: replay the VISUALISE portion locally in DuckDB. + """ + from ggsql import DuckDBReader + + visual = validated.visual() + if has_layer_level_source(visual): + raise ValueError( + "Layer-specific sources are not currently supported in querychat visual " + "queries. Rewrite the query so that all layers come from the final SQL " + "result." + ) + + pl_df = to_polars(data_source.execute_query(validated.sql())) + # Snowflake (and some other backends) uppercase unquoted identifiers, + # but the LLM writes lowercase aliases in the VISUALISE clause. + # DuckDB is case-insensitive, so lowercasing here lets both match. + pl_df.columns = [c.lower() for c in pl_df.columns] + + reader = DuckDBReader("duckdb://memory") + table = extract_visualise_table(visual) + + if table is not None: + name = table[1:-1] if table.startswith('"') and table.endswith('"') else table + reader.register(name, pl_df) + return reader.execute(visual) + else: + reader.register("_data", pl_df) + return reader.execute(f"SELECT * FROM _data {visual}") + + +def extract_visualise_table(visual: str) -> str | None: + """ + Extract the table name from ``VISUALISE … FROM
`` if present. + + This handles Form B queries where the visual string contains an explicit + source (e.g., ``VISUALISE FROM sales DRAW …``). We need the table name + to register the DataFrame under the correct name in local DuckDB. + + Only looks at the portion before the first DRAW clause, since FROM after + DRAW belongs to layer-level MAPPING (a different concern). + + The ggsql Python bindings don't expose the parsed VISUALISE source, so + we use a regex. This is fragile in theory (could match FROM inside a + string literal or comment), but safe in practice because LLM-generated + VISUALISE clauses are simple and well-structured. + """ + draw_pos = re.search(r"\bDRAW\b", visual, re.IGNORECASE) + vis_clause = visual[: draw_pos.start()] if draw_pos else visual + m = re.search(r'\bFROM\s+("[^"]+?"|\S+)', vis_clause, re.IGNORECASE) + return m.group(1) if m else None + + +def has_layer_level_source(visual: str) -> bool: + """ + Return ``True`` when a DRAW clause defines its own ``FROM ``. + + ggsql supports per-layer data sources:: + + WITH summary AS (…) + SELECT * FROM raw_data + VISUALISE … + DRAW point -- from global SQL result + DRAW line MAPPING region AS x, … FROM summary -- from CTE + + Querychat can't support this because we only have the single DataFrame + from executing validated.sql() on the real database. The CTE was + evaluated server-side and its result isn't available locally. We detect + this pattern upfront and raise a clear error rather than letting ggsql + fail with a confusing "table not found". + + The regex splits the visual string on clause boundaries, then checks + each DRAW clause for ``MAPPING … FROM ``. + """ + clauses = re.split( + r"(?=\b(?:DRAW|SCALE|PROJECT|FACET|PLACE|LABEL|THEME)\b)", + visual, + flags=re.IGNORECASE, + ) + for clause in clauses: + if not re.match(r"^\s*DRAW\b", clause, re.IGNORECASE): + continue + if re.search( + r'\bMAPPING\b[\s\S]*?\bFROM\s+("[^"]+?"|\S+)', + clause, + re.IGNORECASE, + ): + return True + return False diff --git a/pkg-py/src/querychat/_viz_tools.py b/pkg-py/src/querychat/_viz_tools.py new file mode 100644 index 000000000..d08ac08cc --- /dev/null +++ b/pkg-py/src/querychat/_viz_tools.py @@ -0,0 +1,328 @@ +"""Visualization tool definitions for querychat.""" + +from __future__ import annotations + +import base64 +import copy +import io +from typing import TYPE_CHECKING, Any, TypedDict +from uuid import uuid4 + +from chatlas import ContentToolResult, Tool, content_image_url +from htmltools import HTMLDependency, TagList, tags +from shinychat.types import ToolResultDisplay + +from shiny import ui + +from .__version import __version__ +from ._icons import bs_icon +from ._utils import querychat_tool_starts_open, read_prompt_template, truncate_error +from ._viz_altair_widget import AltairWidget, fit_chart_to_container +from ._viz_ggsql import execute_ggsql + +if TYPE_CHECKING: + from collections.abc import Callable + + import altair as alt + from ipywidgets.widgets.widget import Widget + + from ._datasource import DataSource + + +class VisualizeQueryData(TypedDict): + """ + Data passed to visualize_query callback. + + This TypedDict defines the structure of data passed to the + `tool_visualize_query` callback function when the LLM creates an + exploratory visualization from a ggsql query. + + Attributes + ---------- + ggsql + The full ggsql query string (SQL + VISUALISE). + title + A descriptive title for the visualization. + widget_id + The unique widget ID used to register the visualization with shinywidgets. + + """ + + ggsql: str + title: str + widget_id: str + + +def tool_visualize_query( + data_source: DataSource, + update_fn: Callable[[VisualizeQueryData], None], +) -> Tool: + """ + Create a tool that executes a ggsql query and renders the visualization. + + Parameters + ---------- + data_source + The data source to query against + update_fn + Callback function to call with VisualizeQueryData when visualization succeeds + + Returns + ------- + Tool + A tool that can be registered with chatlas + + """ + impl = visualize_query_impl(data_source, update_fn) + impl.__doc__ = read_prompt_template( + "tool-visualize-query.md", + db_type=data_source.get_db_type(), + ) + + return Tool.from_func( + impl, + name="querychat_visualize_query", + annotations={"title": "Query Visualization"}, + ) + + +class VisualizeQueryResult(ContentToolResult): + """Tool result that registers an ipywidget and embeds it inline via shinywidgets.""" + + def __init__( + self, + widget_id: str, + widget: Widget, + ggsql_str: str, + title: str, + png_bytes: bytes | None = None, + **kwargs: Any, + ): + from shinywidgets import output_widget, register_widget + + register_widget(widget_id, widget) + + title_display = f" with title '{title}'" if title else "" + text = f"Chart displayed{title_display}." + + if png_bytes is not None: + png_b64 = base64.b64encode(png_bytes).decode("ascii") + value = [ + text, + content_image_url(f"data:image/png;base64,{png_b64}"), + ] + else: + value = text + + footer = build_viz_footer(ggsql_str, title, widget_id) + + widget_html = output_widget(widget_id, fill=True, fillable=True) + widget_html.add_class("querychat-viz-container") + widget_html.append(viz_dep()) + + extra = { + "display": ToolResultDisplay( + html=widget_html, + title=title or "Query Visualization", + show_request=False, + open=querychat_tool_starts_open("visualize_query"), + full_screen=True, + icon=bs_icon("graph-up"), + footer=footer, + ), + } + + super().__init__(value=value, model_format="as_is", extra=extra, **kwargs) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def visualize_query_impl( + data_source: DataSource, + update_fn: Callable[[VisualizeQueryData], None], +) -> Callable[[str, str], ContentToolResult]: + """Create the visualize_query implementation function.""" + from ggsql import VegaLiteWriter, validate + + def visualize_query( + ggsql: str, + title: str, + ) -> ContentToolResult: + """Execute a ggsql query and render the visualization.""" + markdown = f"```sql\n{ggsql}\n```" + + try: + validated = validate(ggsql) + if not validated.has_visual(): + # When VISUALISE contains SQL expressions (e.g., CAST()), + # ggsql silently treats the entire query as plain SQL: + # valid()=True, has_visual()=False, no errors. This + # heuristic catches that case so we can guide the LLM. + # Remove when ggsql reports this as a parse error: + # https://github.com/posit-dev/ggsql/issues/256 + has_keyword = ( + "VISUALISE" in ggsql.upper() or "VISUALIZE" in ggsql.upper() + ) + if has_keyword: + raise ValueError( + "VISUALISE clause was not recognized. " + "VISUALISE and MAPPING accept column names only — " + "no SQL expressions, CAST(), or functions. " + "Move all data transformations to the SELECT clause, " + "then reference the resulting column by name in VISUALISE." + ) + raise ValueError( + "Query must include a VISUALISE clause. " + "Use querychat_query for queries without visualization." + ) + + spec = execute_ggsql(data_source, ggsql, validated) + + raw_chart = VegaLiteWriter().render_chart(spec) + altair_widget = AltairWidget(copy.deepcopy(raw_chart)) + + try: + png_bytes = render_chart_to_png(raw_chart) + except Exception: + png_bytes = None + + update_fn( + {"ggsql": ggsql, "title": title, "widget_id": altair_widget.widget_id} + ) + + return VisualizeQueryResult( + widget_id=altair_widget.widget_id, + widget=altair_widget.widget, + ggsql_str=ggsql, + title=title, + png_bytes=png_bytes, + ) + + except Exception as e: + error_msg = truncate_error(str(e)) + markdown += f"\n\n> Error: {error_msg}" + return ContentToolResult(value=markdown, error=Exception(error_msg)) + + return visualize_query + + +PNG_WIDTH = 500 +PNG_HEIGHT = 300 + + +def render_chart_to_png(chart: alt.TopLevelMixin) -> bytes: + """Render an Altair chart to PNG bytes at a fixed size for LLM feedback.""" + import altair as alt + + chart = copy.deepcopy(chart) + is_compound = isinstance( + chart, + (alt.FacetChart, alt.ConcatChart, alt.HConcatChart, alt.VConcatChart), + ) + if is_compound: + chart = fit_chart_to_container(chart, PNG_WIDTH, PNG_HEIGHT) + else: + chart = chart.properties(width=PNG_WIDTH, height=PNG_HEIGHT) + + buf = io.BytesIO() + chart.save(buf, format="png", scale_factor=1) + return buf.getvalue() + + +def viz_dep() -> HTMLDependency: + """HTMLDependency for viz-specific CSS and JS assets.""" + return HTMLDependency( + "querychat-viz", + __version__, + source={ + "package": "querychat", + "subdir": "static", + }, + stylesheet=[{"href": "css/viz.css"}], + script=[{"src": "js/viz.js"}], + ) + + +def build_viz_footer( + ggsql_str: str, + title: str, + widget_id: str, +) -> TagList: + """Build footer HTML for visualization tool results.""" + footer_id = f"querychat_footer_{uuid4().hex[:8]}" + query_section_id = f"{footer_id}_query" + code_editor_id = f"{footer_id}_code" + + # Read-only code editor for query display + code_editor = ui.input_code_editor( + id=code_editor_id, + value=ggsql_str, + language="ggsql", + read_only=True, + line_numbers=False, + height="auto", + theme_dark="github-dark", + ) + + # Query section (hidden by default) + query_section = tags.div( + {"class": "querychat-query-section", "id": query_section_id}, + code_editor, + ) + + # Footer buttons row + buttons_row = tags.div( + {"class": "querychat-footer-buttons"}, + # Left: Show Query toggle + tags.div( + {"class": "querychat-footer-left"}, + tags.button( + { + "class": "querychat-show-query-btn", + "data-target": query_section_id, + }, + tags.span({"class": "querychat-query-chevron"}, "\u25b6"), + tags.span({"class": "querychat-query-label"}, "Show Query"), + ), + ), + # Right: Save dropdown + tags.div( + {"class": "querychat-footer-right"}, + tags.div( + {"class": "querychat-save-dropdown"}, + tags.button( + { + "class": "querychat-save-btn", + "data-widget-id": widget_id, + }, + bs_icon("download", cls="querychat-icon"), + "Save", + bs_icon("chevron-down", cls="querychat-dropdown-chevron"), + ), + tags.div( + {"class": "querychat-save-menu"}, + tags.button( + { + "class": "querychat-save-png-btn", + "data-widget-id": widget_id, + "data-title": title, + }, + "Save as PNG", + ), + tags.button( + { + "class": "querychat-save-svg-btn", + "data-widget-id": widget_id, + "data-title": title, + }, + "Save as SVG", + ), + ), + ), + ), + ) + + return TagList(buttons_row, query_section) diff --git a/pkg-py/src/querychat/_viz_utils.py b/pkg-py/src/querychat/_viz_utils.py new file mode 100644 index 000000000..eb9e0897d --- /dev/null +++ b/pkg-py/src/querychat/_viz_utils.py @@ -0,0 +1,67 @@ +"""Shared visualization utilities.""" + +from __future__ import annotations + +import importlib.util + +from htmltools import HTMLDependency, tags + +from .__version import __version__ + + +def has_viz_tool(tools: tuple[str, ...] | None) -> bool: + """Check if visualize_query is among the configured tools.""" + return tools is not None and "visualize_query" in tools + + +def has_viz_deps() -> bool: + """Check whether visualization dependencies (ggsql, altair, shinywidgets, vl-convert-python) are installed.""" + return all( + importlib.util.find_spec(pkg) is not None + for pkg in ("ggsql", "altair", "shinywidgets", "vl_convert") + ) + + +PRELOAD_WIDGET_ID = "__querychat_preload_viz__" + + +def preload_viz_deps_ui(): + """Return a hidden widget output that triggers eager JS dependency loading.""" + from shinywidgets import output_widget + + return tags.div( + output_widget(PRELOAD_WIDGET_ID), + viz_preload_dep(), + class_="querychat-viz-preload", + hidden="", + aria_hidden="true", + style="position:absolute; left:-9999px; width:1px; height:1px;", + ) + + +def viz_preload_dep() -> HTMLDependency: + """HTMLDependency for viz preload-specific JS.""" + return HTMLDependency( + "querychat-viz-preload", + __version__, + source={ + "package": "querychat", + "subdir": "static", + }, + script=[{"src": "js/viz-preload.js"}], + ) + + +def preload_viz_deps_server() -> None: + """Register a minimal Altair widget to trigger full JS dependency loading.""" + from shinywidgets import register_widget + + register_widget(PRELOAD_WIDGET_ID, mock_altair_widget()) + + +def mock_altair_widget(): + """Create a minimal Altair JupyterChart suitable for preloading JS dependencies.""" + import altair as alt + + chart = alt.Chart({"values": [{"x": 0}]}).mark_point().encode(x="x:Q") + return alt.JupyterChart(chart) diff --git a/pkg-py/src/querychat/prompts/ggsql-syntax.md b/pkg-py/src/querychat/prompts/ggsql-syntax.md new file mode 100644 index 000000000..9a6f0c016 --- /dev/null +++ b/pkg-py/src/querychat/prompts/ggsql-syntax.md @@ -0,0 +1,523 @@ +## ggsql Syntax Reference + +### Quick Reference + +```sql +[WITH cte AS (...), ...] +[SELECT columns FROM table WHERE conditions] +VISUALISE [mappings] [FROM source] +DRAW geom_type + [MAPPING col AS aesthetic, ... FROM source] + [REMAPPING stat AS aesthetic, ...] + [SETTING param => value, ...] + [FILTER sql_condition] + [PARTITION BY col, ...] + [ORDER BY col [ASC|DESC], ...] +[SCALE [TYPE] aesthetic [FROM ...] [TO ...] [VIA ...] [SETTING ...] [RENAMING ...]] +[PROJECT [aesthetics] TO coord_system [SETTING ...]] +[FACET var | row_var BY col_var [SETTING free => 'x'|'y'|('x','y'), ncol => N, nrow => N]] +[PLACE geom_type SETTING param => value, ...] +[LABEL x => '...', y => '...', ...] +``` + +### VISUALISE Clause + +Entry point for visualization. Marks where SQL ends and visualization begins. Mappings in VISUALISE and MAPPING accept **column names only** — no SQL expressions, functions, or casts. All data transformations must happen in the SELECT clause. + +```sql +-- After SELECT (most common) +SELECT date, revenue, region FROM sales +VISUALISE date AS x, revenue AS y, region AS color +DRAW line + +-- Shorthand with FROM (auto-generates SELECT * FROM) +VISUALISE FROM sales +DRAW bar MAPPING region AS x, total AS y +``` + +### Mapping Styles + +| Style | Syntax | Use When | +|-------|--------|----------| +| Explicit | `date AS x` | Column name differs from aesthetic | +| Implicit | `x` | Column name equals aesthetic name | +| Wildcard | `*` | Map all matching columns automatically | +| Literal | `'string' AS color` | Use a literal value (for legend labels in multi-layer plots) | + +### DRAW Clause (Layers) + +Multiple DRAW clauses create layered visualizations. + +```sql +DRAW geom_type + [MAPPING col AS aesthetic, ... FROM source] + [REMAPPING stat AS aesthetic, ...] + [SETTING param => value, ...] + [FILTER sql_condition] + [PARTITION BY col, ...] + [ORDER BY col [ASC|DESC], ...] +``` + +**Geom types:** + +| Category | Types | +|----------|-------| +| Basic | `point`, `line`, `path`, `bar`, `area`, `tile`, `polygon`, `ribbon` | +| Statistical | `histogram`, `density`, `smooth`, `boxplot`, `violin` | +| Annotation | `text`, `label`, `segment`, `arrow`, `rule`, `errorbar` | + +- `path` is like `line` but preserves data order instead of sorting by x. +- `tile` draws rectangles for heatmaps or range indicators. Map `x`/`y` for center (defaults to width/height of 1), or use `xmin`/`xmax`/`ymin`/`ymax` for explicit bounds. +- `smooth` fits a trendline to data. Settings: `method` (`'nw'` default for kernel regression, `'ols'` for linear, `'tls'` for total least squares), `bandwidth`, `adjust`, `kernel`. +- `text` (or `label`) renders text labels. Map `label` for the text content. Settings: `format` (template string for label formatting), `offset` (pixel offset as `(x, y)`). Labels containing `\n` are automatically split into multiple lines. +- `arrow` draws arrows between two points. Requires `x`, `y`, `xend`, `yend` aesthetics. +- `rule` draws full-span reference lines. Map a value to `y` for a horizontal line or `x` for a vertical line. Optionally map `slope` to create diagonal reference lines: `y = a + slope * x` (when `y` is mapped) or `x = a + slope * y` (when `x` is mapped). +- `line` and `path` support continuously varying `linewidth`, `stroke`, and `opacity` aesthetics within groups. + +**Aesthetics (MAPPING):** + +| Category | Aesthetics | +|----------|------------| +| Position | `x`, `y`, `xmin`, `xmax`, `ymin`, `ymax`, `xend`, `yend` | +| Color | `color`/`colour`, `fill`, `stroke`, `opacity` | +| Size/Shape | `size`, `shape`, `linewidth`, `linetype`, `width`, `height` | +| Text | `label`, `typeface`, `fontweight`, `italic`, `fontsize`, `hjust`, `vjust`, `rotation` | +| Aggregation | `weight` (for histogram/bar/density/violin) | +| Rule | `slope` (for diagonal `rule` lines) | + +**PARTITION BY** groups data without visual encoding (useful for separate lines per group without color): + +```sql +DRAW line PARTITION BY category +``` + +**ORDER BY** controls row ordering within a layer: + +```sql +DRAW line ORDER BY date ASC +``` + +### PLACE Clause (Annotations) + +`PLACE` creates annotation layers with literal values only — no data mappings. Use it for reference lines, text labels, and other fixed annotations. All aesthetics are set via `SETTING` and bypass scaling. + +```sql +PLACE geom_type SETTING param => value, ... +``` + +**Examples:** +```sql +-- Horizontal reference line +PLACE rule SETTING y => 100 + +-- Vertical reference line +PLACE rule SETTING x => '2024-06-01' + +-- Multiple reference lines (array values) +PLACE rule SETTING y => (50, 75, 100) + +-- Text annotation +PLACE text SETTING x => 10, y => 50, label => 'Threshold' + +-- Diagonal reference line (y = -1 + 0.4 * x) +PLACE rule SETTING slope => 0.4, y => -1 +``` + +`PLACE` supports any geom type but is most useful with `rule`, `text`, `segment`, and `tile`. Use `PLACE` for fixed annotation values known at query time; use `DRAW` with `MAPPING` when values come from data columns. Unlike `DRAW`, `PLACE` has no `MAPPING`, `FILTER`, `PARTITION BY`, or `ORDER BY` sub-clauses. Array values in PLACE SETTING are recycled into multiple rows only for supported aesthetics; geom parameters (like `offset` on `text`) are passed through as-is. + +### Statistical Layers and REMAPPING + +Some layers compute statistics. Use REMAPPING to access computed values: + +| Layer | Computed Stats | Default Remapping | +|-------|---------------|-------------------| +| `bar` (y unmapped) | `count`, `proportion` | `count AS y` | +| `histogram` | `count`, `density` | `count AS y` | +| `density` | `density`, `intensity` | `density AS y` | +| `violin` | `density`, `intensity` | `density AS offset` | +| `smooth` | `intensity` | `intensity AS y` | +| `boxplot` | `value`, `type` | `value AS y` | + +`boxplot` displays box-and-whisker plots. Settings: `outliers` (`true` default — show outlier points), `coef` (`1.5` default — whisker fence coefficient), `width` (`0.9` default — box width, 0–1). + +`smooth` fits a trendline to data. Settings: `method` (`'nw'` or `'nadaraya-watson'` default kernel regression, `'ols'` for OLS linear, `'tls'` for total least squares). NW-only settings: `bandwidth` (numeric), `adjust` (multiplier, default 1), `kernel` (`'gaussian'` default, `'epanechnikov'`, `'triangular'`, `'rectangular'`, `'uniform'`, `'biweight'`, `'quartic'`, `'cosine'`). + +`density` computes a KDE from a continuous `x`. Settings: `bandwidth` (numeric), `adjust` (multiplier, default 1), `kernel` (`'gaussian'` default, `'epanechnikov'`, `'triangular'`, `'rectangular'`, `'uniform'`, `'biweight'`, `'quartic'`, `'cosine'`). Use `REMAPPING intensity AS y` to show unnormalized density that reflects group size differences. Use `SETTING position => 'stack'` for stacked densities. + +`violin` displays mirrored KDE curves for groups. Requires both `x` (categorical) and `y` (continuous). Accepts the same bandwidth/adjust/kernel settings as density. Use `REMAPPING intensity AS offset` to reflect group size differences. Additional settings: `side` (`'both'` default, `'left'`/`'bottom'`, `'right'`/`'top'` — for half-violin/ridgeline plots), `width` (any value > 0; values > 1 enable ridgeline-style overlapping). + +**Examples:** + +```sql +-- Density histogram (instead of count) +VISUALISE FROM products +DRAW histogram MAPPING price AS x REMAPPING density AS y + +-- Bar showing proportion +VISUALISE FROM sales +DRAW bar MAPPING region AS x REMAPPING proportion AS y + +-- Overlay histogram and density on the same scale +VISUALISE FROM measurements +DRAW histogram MAPPING value AS x SETTING opacity => 0.5 +DRAW density MAPPING value AS x REMAPPING intensity AS y SETTING opacity => 0.5 + +-- Violin plot +SELECT department, salary FROM employees +VISUALISE department AS x, salary AS y +DRAW violin +``` + +### SCALE Clause + +Configures how data maps to visual properties. All sub-clauses are optional; type and transform are auto-detected from data when omitted. + +```sql +SCALE [TYPE] aesthetic [FROM range] [TO output] [VIA transform] [SETTING prop => value, ...] [RENAMING ...] +``` + +**Type identifiers** (optional — auto-detected if omitted): + +| Type | Description | +|------|-------------| +| `CONTINUOUS` | Numeric data on a continuous axis | +| `DISCRETE` | Categorical/nominal data | +| `BINNED` | Pre-bucketed data | +| `ORDINAL` | Ordered categories with interpolated output | +| `IDENTITY` | Data values are already visual values (e.g., literal hex colors) | + +**Important — integer columns used as categories:** When an integer column represents categories (e.g., a 0/1 `survived` column), ggsql will treat it as continuous by default. This causes errors when mapping to `fill`, `color`, `shape`, or using it in `FACET`. Two fixes: +- **Preferred:** Cast to string in the SELECT clause: `SELECT CAST(survived AS VARCHAR) AS survived ...`, then map the column by name in VISUALISE: `survived AS fill` +- **Alternative:** Declare the scale: `SCALE DISCRETE fill` or `SCALE fill VIA bool` + +**FROM** — input domain: +```sql +SCALE x FROM (0, 100) -- explicit min and max +SCALE x FROM (0, null) -- explicit min, auto max +SCALE DISCRETE x FROM ('A', 'B', 'C') -- explicit category order +``` + +**TO** — output range or palette: +```sql +SCALE color TO sequential -- default continuous palette (derived from navia) +SCALE color TO viridis -- other continuous: viridis, plasma, inferno, magma, cividis, navia, batlow +SCALE color TO vik -- diverging: vik, rdbu, rdylbu, spectral, brbg, berlin, roma +SCALE DISCRETE color TO ggsql10 -- discrete (default: ggsql10): tableau10, category10, set1, set2, set3, dark2, paired, kelly +SCALE color TO ('red', 'blue') -- explicit color array +SCALE size TO (1, 10) -- numeric output range +``` + +**VIA** — transformation: +```sql +SCALE x VIA date -- date axis (auto-detected from Date columns) +SCALE x VIA datetime -- datetime axis +SCALE y VIA log10 -- base-10 logarithm +SCALE y VIA sqrt -- square root +``` + +| Category | Transforms | +|----------|------------| +| Logarithmic | `log10`, `log2`, `log` (natural) | +| Power | `sqrt`, `square` | +| Exponential | `exp`, `exp2`, `exp10` | +| Other | `asinh`, `pseudo_log` | +| Temporal | `date`, `datetime`, `time` | +| Type coercion | `integer`, `string`, `bool` | + +**SETTING** — additional properties: +```sql +SCALE x SETTING breaks => 5 -- number of tick marks +SCALE x SETTING breaks => '2 months' -- interval-based breaks +SCALE x SETTING expand => 0.05 -- expand scale range by 5% +SCALE x SETTING reverse => true -- reverse direction +``` + +**RENAMING** — custom axis/legend labels: +```sql +SCALE DISCRETE x RENAMING 'A' => 'Alpha', 'B' => 'Beta' +SCALE CONTINUOUS x RENAMING * => '{} units' -- template for all labels +SCALE x VIA date RENAMING * => '{:time %b %Y}' -- date label formatting +``` + +### Date/Time Axes + +Temporal transforms are auto-detected from column data types, including after `DATE_TRUNC`. + +**Break intervals:** +```sql +SCALE x SETTING breaks => 'month' -- one break per month +SCALE x SETTING breaks => '2 weeks' -- every 2 weeks +SCALE x SETTING breaks => '3 months' -- quarterly +SCALE x SETTING breaks => 'year' -- yearly +``` + +Valid units: `day`, `week`, `month`, `year` (for date); also `hour`, `minute`, `second` (for datetime/time). + +**Date label formatting** (strftime syntax): +```sql +SCALE x VIA date RENAMING * => '{:time %b %Y}' -- "Jan 2024" +SCALE x VIA date RENAMING * => '{:time %B %d, %Y}' -- "January 15, 2024" +SCALE x VIA date RENAMING * => '{:time %b %d}' -- "Jan 15" +``` + +### PROJECT Clause + +Sets coordinate system. Use `PROJECT ... TO` to specify coordinates. + +**Coordinate systems:** `cartesian` (default), `polar`. + +**Polar aesthetics:** In polar coordinates, positional aesthetics use `angle` and `radius` (instead of `x` and `y`). Variants `anglemin`, `anglemax`, `angleend`, `radiusmin`, `radiusmax`, `radiusend` are also available. Typically you map to `x`/`y` and let `PROJECT TO polar` handle the conversion, but you can use `angle`/`radius` explicitly when needed. + +```sql +PROJECT TO cartesian -- explicit default (usually omitted) +PROJECT y, x TO cartesian -- flip axes (maps y to horizontal, x to vertical) +PROJECT TO polar -- pie/radial charts +PROJECT TO polar SETTING start => 90 -- start at 3 o'clock +PROJECT TO polar SETTING inner => 0.5 -- donut chart (50% hole) +PROJECT TO polar SETTING start => -90, end => 90 -- half-circle gauge +``` + +**Cartesian settings:** +- `clip` — clip out-of-bounds data (default `true`) +- `ratio` — enforce aspect ratio between axes + +**Polar settings:** +- `start` — starting angle in degrees (0 = 12 o'clock, 90 = 3 o'clock) +- `end` — ending angle in degrees (default: start + 360; use for partial arcs/gauges) +- `inner` — inner radius as proportion 0–1 (0 = full pie, 0.5 = donut with 50% hole) +- `clip` — clip out-of-bounds data (default `true`) + +**Axis flipping:** To create horizontal bar charts or flip axes, use `PROJECT y, x TO cartesian`. This maps anything on `y` to the horizontal axis and `x` to the vertical axis. + +### FACET Clause + +Creates small multiples (subplots by category). + +```sql +FACET category -- Single variable, wrapped layout +FACET row_var BY col_var -- Grid layout (rows x columns) +FACET category SETTING free => 'y' -- Independent y-axes +FACET category SETTING free => ('x', 'y') -- Independent both axes +FACET category SETTING ncol => 4 -- Control number of columns +FACET category SETTING nrow => 2 -- Control number of rows (mutually exclusive with ncol) +``` + +Custom strip labels via SCALE: +```sql +FACET region +SCALE panel RENAMING 'N' => 'North', 'S' => 'South' +``` + +### LABEL Clause + +Use LABEL for axis labels only. Do NOT use `LABEL title => ...` — the tool's `title` parameter handles chart titles. Set a label to `null` to suppress it. + +```sql +LABEL x => 'X Axis Label', y => 'Y Axis Label' +LABEL x => null -- suppress x-axis label +``` + +## Complete Examples + +**Line chart with multiple series:** +```sql +SELECT date, revenue, region FROM sales WHERE year = 2024 +VISUALISE date AS x, revenue AS y, region AS color +DRAW line +SCALE x VIA date +LABEL x => 'Date', y => 'Revenue ($)' +``` + +**Bar chart (auto-count):** +```sql +VISUALISE FROM products +DRAW bar MAPPING category AS x +``` + +**Horizontal bar chart:** +```sql +SELECT region, COUNT(*) as n FROM sales GROUP BY region +VISUALISE region AS y, n AS x +DRAW bar +PROJECT y, x TO cartesian +``` + +**Scatter plot with trend line:** +```sql +SELECT mpg, hp, cylinders FROM cars +VISUALISE mpg AS x, hp AS y +DRAW point MAPPING cylinders AS color +DRAW smooth +``` + +**Histogram with density overlay:** +```sql +VISUALISE FROM measurements +DRAW histogram MAPPING value AS x SETTING bins => 20, opacity => 0.5 +DRAW density MAPPING value AS x REMAPPING intensity AS y SETTING opacity => 0.5 +``` + +**Density plot with groups:** +```sql +VISUALISE FROM measurements +DRAW density MAPPING value AS x, category AS color SETTING opacity => 0.7 +``` + +**Heatmap with tile:** +```sql +SELECT day, month, temperature FROM weather +VISUALISE day AS x, month AS y, temperature AS color +DRAW tile +``` + +**Threshold reference lines (using PLACE):** +```sql +SELECT date, temperature FROM sensor_data +VISUALISE date AS x, temperature AS y +DRAW line +PLACE rule SETTING y => 100, stroke => 'red', linetype => 'dashed' +LABEL y => 'Temperature (F)' +``` + +**Faceted chart:** +```sql +SELECT month, sales, region FROM data +VISUALISE month AS x, sales AS y +DRAW line +DRAW point +FACET region +SCALE x VIA date +``` + +**CTE with aggregation and date formatting:** +```sql +WITH monthly AS ( + SELECT DATE_TRUNC('month', order_date) as month, SUM(amount) as total + FROM orders GROUP BY 1 +) +VISUALISE month AS x, total AS y FROM monthly +DRAW line +DRAW point +SCALE x VIA date SETTING breaks => 'month' RENAMING * => '{:time %b %Y}' +LABEL y => 'Revenue ($)' +``` + +**Ribbon / confidence band:** +```sql +WITH daily AS ( + SELECT DATE_TRUNC('day', timestamp) as day, + AVG(temperature) as avg_temp, + MIN(temperature) as min_temp, + MAX(temperature) as max_temp + FROM sensor_data + GROUP BY DATE_TRUNC('day', timestamp) +) +VISUALISE day AS x FROM daily +DRAW ribbon MAPPING min_temp AS ymin, max_temp AS ymax SETTING opacity => 0.3 +DRAW line MAPPING avg_temp AS y +SCALE x VIA date +LABEL y => 'Temperature' +``` + +**Text labels on bars:** +```sql +SELECT region, COUNT(*) AS n FROM sales GROUP BY region +VISUALISE region AS x, n AS y +DRAW bar +DRAW text MAPPING n AS label SETTING offset => (0, -11), fill => 'white' +``` + +**Donut chart:** +```sql +VISUALISE FROM products +DRAW bar MAPPING category AS fill +PROJECT TO polar SETTING inner => 0.5 +``` + +## Important Notes + +1. **Numeric columns as categories**: Integer columns representing categories (e.g., 0/1 `survived`) are treated as continuous by default, causing errors with `fill`, `color`, `shape`, and `FACET`. Fix by casting in SQL or declaring the scale: + ```sql + -- WRONG: integer fill without discrete scale — causes validation error + SELECT sex, survived FROM titanic + VISUALISE sex AS x, survived AS fill + DRAW bar + + -- CORRECT: cast to string in SQL (preferred) + SELECT sex, CAST(survived AS VARCHAR) AS survived FROM titanic + VISUALISE sex AS x, survived AS fill + DRAW bar + + -- ALSO CORRECT: declare the scale as discrete + SELECT sex, survived FROM titanic + VISUALISE sex AS x, survived AS fill + DRAW bar + SCALE DISCRETE fill + ``` +2. **Do not mix `VISUALISE FROM` with a preceding `SELECT`**: `VISUALISE FROM table` is shorthand that auto-generates `SELECT * FROM table`. If you already have a `SELECT`, use `SELECT ... VISUALISE` instead: + ```sql + -- WRONG: VISUALISE FROM after SELECT + SELECT * FROM titanic + VISUALISE FROM titanic + DRAW bar MAPPING class AS x + + -- CORRECT: use VISUALISE (without FROM) after SELECT + SELECT * FROM titanic + VISUALISE class AS x + DRAW bar + + -- ALSO CORRECT: use VISUALISE FROM without any SELECT + VISUALISE FROM titanic + DRAW bar MAPPING class AS x + ``` +3. **In querychat, all layers must come from the final SQL result**: Do not use layer-specific `FROM source` inside `DRAW ... MAPPING ...` clauses. If you need raw data and a summary in one chart, put both into one final relation and distinguish layers with a column such as `layer_type`: + ```sql + WITH raw AS ( + SELECT + date, + amount, + region, + 'raw' AS layer_type + FROM sales + ), + summary AS ( + SELECT + date, + AVG(amount) AS amount, + region, + 'summary' AS layer_type + FROM sales + GROUP BY date, region + ), + combined AS ( + SELECT * FROM raw + UNION ALL + SELECT * FROM summary + ) + SELECT * FROM combined + VISUALISE date AS x, amount AS y + DRAW point MAPPING region AS color FILTER layer_type = 'raw' + DRAW line MAPPING region AS color FILTER layer_type = 'summary' + ``` +4. **String values use single quotes**: In SETTING, LABEL, and RENAMING clauses, always use single quotes for string values. Double quotes cause parse errors. +5. **Column casing in VISUALISE**: DuckDB lowercases unquoted column names in query results, and VISUALISE validates column references **case-sensitively**. If your source table has uppercase column names (e.g., from Snowflake), you **must** alias them to lowercase in the SELECT clause: + ```sql + -- WRONG: VISUALISE references uppercase name, but DuckDB lowercases it in results + SELECT ROOM_TYPE, COUNT(*) AS listings FROM airbnb + VISUALISE ROOM_TYPE AS x, listings AS y + DRAW bar + + -- CORRECT: Alias to lowercase, then reference the alias + SELECT ROOM_TYPE AS room_type, COUNT(*) AS listings FROM airbnb + VISUALISE room_type AS x, listings AS y + DRAW bar + ``` + As a general rule, always use lowercase column names and aliases in both SELECT and VISUALISE clauses. +6. **Charts vs Tables**: For visualizations use VISUALISE with DRAW. For tabular data use plain SQL without VISUALISE. +7. **Statistical layers**: When using `histogram`, `bar` (without y), `density`, `smooth`, `violin`, or `boxplot`, the layer computes statistics. Use REMAPPING to access `density`, `intensity`, `proportion`, etc. +8. **Bar position adjustments**: Bars stack automatically when `fill` is mapped. Use `SETTING position => 'dodge'` for side-by-side bars, or `position => 'stack', total => 1` for proportional (100%) stacking: + ```sql + DRAW bar MAPPING category AS x, subcategory AS fill -- stacked (default) + DRAW bar MAPPING category AS x, subcategory AS fill SETTING position => 'dodge' -- side-by-side + DRAW bar MAPPING category AS x, subcategory AS fill SETTING position => 'stack', total => 1 -- proportional + ``` diff --git a/pkg-py/src/querychat/prompts/prompt.md b/pkg-py/src/querychat/prompts/prompt.md index 8c6ff97bc..9123bbe40 100644 --- a/pkg-py/src/querychat/prompts/prompt.md +++ b/pkg-py/src/querychat/prompts/prompt.md @@ -1,4 +1,4 @@ -You are a data dashboard chatbot that operates in a sidebar interface. Your role is to help users interact with their data through filtering, sorting, and answering questions. +You are a data dashboard chatbot that operates in a sidebar interface. Your role is to help users interact with their data through filtering, sorting, answering questions, and exploring data visually. You have access to a {{db_type}} SQL database with the following schema: @@ -118,11 +118,105 @@ Response: "The average revenue is $X." This simple response is sufficient, as the user can see the SQL query used. {{/has_tool_query}} +{{#has_tool_visualize_query}} +### Visualizing Data + +You can create visualizations using the `querychat_visualize_query` tool, which uses ggsql — a SQL extension for declarative data visualization. Write a ggsql query (SQL with a VISUALISE clause), and the tool executes the SQL, renders the VISUALISE clause as an Altair chart, and displays it inline in the chat. + +#### Visualization best practices + +The database schema in this prompt includes column names, types, and summary statistics. {{#has_tool_query}}If that context isn't sufficient for a confident visualization — e.g., you're unsure about value distributions, need to check for NULLs, or want to gauge row counts before choosing a chart type — use the `querychat_query` tool to inspect the data before visualizing. Always pass `collapsed=True` for these preparatory queries so the chart remains the focal point of the response.{{/has_tool_query}} + +Follow the principles below to produce clear, interpretable charts. + +#### Axis labels must be readable + +When the x-axis contains categorical labels (names, categories, long strings), prefer flipping axes with `PROJECT y, x TO cartesian` so labels read naturally left-to-right. Short numeric or date labels on the x-axis are fine horizontal — this applies specifically to text categories. + +#### Always include axis labels with units + +Charts should be interpretable without reading the surrounding prose. Always include axis labels that describe what is shown, including units when applicable (e.g., `LABEL y => 'Revenue ($M)'`, not just `LABEL y => 'Revenue'`). + +#### Maximize data-ink ratio + +Every visual element should serve a purpose: + +- Don't map columns to aesthetics (color, size, shape) unless the distinction is meaningful to the user's question. A single-series bar chart doesn't need color. +- When using color for categories, keep to 7 or fewer distinct values. Beyond that, consider filtering to the most important categories or using facets instead. +- Avoid dual-encoding the same variable (e.g., mapping the same column to both x-position and color) unless it genuinely aids interpretation. + +#### Avoid overplotting + +When a dataset has many rows, plotting one mark per row creates clutter that obscures patterns. Before generating a query, consider the row count and data characteristics visible in the schema. + +**For large datasets (hundreds+ rows):** + +- **Aggregate first**: Use `GROUP BY` with `COUNT`, `AVG`, `SUM`, or other aggregates to reduce to meaningful summaries before visualizing. +- **Choose chart types that summarize naturally**: histograms for distributions, boxplots for group comparisons, line charts for trends over time. + +**For two numeric variables with many rows:** + +Bin in SQL and use `DRAW tile` to create a heatmap: + +```sql +WITH binned AS ( + SELECT ROUND(x_col / 5) * 5 AS x_bin, + ROUND(y_col / 5) * 5 AS y_bin, + COUNT(*) AS n + FROM large_table + GROUP BY x_bin, y_bin +) +SELECT * FROM binned +VISUALISE x_bin AS x, y_bin AS y, n AS fill +DRAW tile +SCALE fill TO viridis +``` + +**If individual points matter** (e.g., outlier detection): use `SETTING opacity` to reveal density through overlap. + +#### Choose chart types based on the data relationship + +Match the chart type to what the user is trying to understand: + +- **Comparison across categories**: bar chart (`DRAW bar`, with `PROJECT y, x TO cartesian` for long labels). Order bars by value, not alphabetically. +- **Trend over time**: line chart (`DRAW line`). Use `SCALE x VIA date` for date columns. +- **Distribution of a single variable**: histogram (`DRAW histogram`) or density (`DRAW density`). +- **Relationship between two numeric variables**: scatter plot (`DRAW point`), but prefer aggregation or heatmap if the dataset is large. +- **Part-of-whole**: stacked bar chart (map subcategory to `fill`). Avoid pie charts — position along a common scale is easier to decode than angle. + +#### Graceful recovery + +If a visualization fails, read the error message carefully and retry with a corrected query. Common fixes: correcting column names, adding `SCALE DISCRETE` for integer categories, using single quotes for strings, moving SQL expressions out of VISUALISE into the SELECT clause.{{#has_tool_query}} If the error persists, fall back to `querychat_query` for a tabular answer.{{/has_tool_query}} + +#### ggsql syntax reference + +The syntax reference below covers all available clauses, geom types, scales, and examples. + +{{> ggsql-syntax}} +{{/has_tool_visualize_query}} +{{#has_tool_query}} +{{#has_tool_visualize_query}} +### Choosing Between Query and Visualization + +Use `querychat_query` for single-value answers (averages, counts, totals, specific lookups) or when the user needs to see exact values. Use `querychat_visualize_query` when comparisons, distributions, or trends are involved — even for small result sets, a chart is often clearer than a short table. + +**Avoid redundant expanded results.** If you run a preparatory query before visualizing, or if both a table and chart would show the same data, always pass `collapsed=True` on the query so the user sees the chart prominently, not a duplicate table above it. The user can still expand the table if they want the exact values. + +{{/has_tool_visualize_query}} +{{/has_tool_query}} +{{^has_tool_visualize_query}} +### Visualization Requests + +You cannot create charts or visualizations. If users ask for a plot, chart, or visual representation of the data, explain that visualization is not currently enabled.{{#has_tool_query}} Offer to answer their question with a tabular query instead.{{/has_tool_query}} Suggest that the developer can enable visualization by installing `querychat[viz]` and adding `"visualize_query"` to the `tools` parameter. + +{{/has_tool_visualize_query}} {{^has_tool_query}} +{{^has_tool_visualize_query}} ### Questions About Data You cannot query or analyze the data. If users ask questions about data values, statistics, or calculations (e.g., "What is the average ____?" or "How many ____ are there?"), explain that you're not able to run queries on this data. Do not attempt to answer based on your own knowledge or assumptions about the data, even if the dataset seems familiar. +{{/has_tool_visualize_query}} {{/has_tool_query}} ### Providing Suggestions for Next Steps @@ -146,9 +240,16 @@ You might want to explore the advanced features **Nested lists:** ```md +{{#has_tool_query}} * Analyze the data * What's the average …? * How many …? +{{/has_tool_query}} +{{#has_tool_visualize_query}} +* Visualize the data + * Show a bar chart of … + * Plot the trend of … over time +{{/has_tool_visualize_query}} * Filter and sort * Show records from the year … * Sort the ____ by ____ … @@ -185,6 +286,7 @@ You might want to explore the advanced features - **Ask for clarification** if any request is unclear or ambiguous - **Be concise** due to the constrained interface - **Only answer data questions using your tools** - never use prior knowledge or assumptions about the data, even if the dataset seems familiar +- **Be skeptical of your own interpretations** - when describing chart results or data patterns, encourage the user to verify findings rather than presenting analytical conclusions as fact - **Use Markdown tables** for any tabular or structured data in your responses {{#extra_instructions}} diff --git a/pkg-py/src/querychat/prompts/tool-query.md b/pkg-py/src/querychat/prompts/tool-query.md index 0fcdec4b3..312b85fcc 100644 --- a/pkg-py/src/querychat/prompts/tool-query.md +++ b/pkg-py/src/querychat/prompts/tool-query.md @@ -1,17 +1,4 @@ -Execute a SQL query and return the results - -This tool executes a {{db_type}} SQL SELECT query against the database and returns the raw result data for analysis. - -**When to use:** Call this tool whenever the user asks a question that requires data analysis, aggregation, or calculations. Use this for questions like: -- "What is the average...?" -- "How many records...?" -- "Which item has the highest/lowest...?" -- "What's the total sum of...?" -- "What percentage of ...?" - -Always use SQL for counting, averaging, summing, and other calculations—NEVER attempt manual calculations on your own. Use this tool repeatedly if needed to avoid any kind of manual calculation. - -**When not to use:** Do NOT use this tool for filtering or sorting the dashboard display. If the user wants to "Show me..." or "Filter to..." certain records in the dashboard, use the `querychat_update_dashboard` tool instead. +Execute a {{db_type}} SQL SELECT query and return the results for analysis. **Important guidelines:** @@ -25,6 +12,8 @@ Parameters ---------- query : A valid {{db_type}} SQL SELECT statement. Must follow the database schema provided in the system prompt. Use clear column aliases (e.g., 'AVG(price) AS avg_price') and include SQL comments for complex logic. Subqueries and CTEs are encouraged for readability. +collapsed : + Optional (default: false). Set to true for exploratory or preparatory queries (e.g., inspecting data before visualization, checking row counts, previewing column values) whose results aren't the primary answer. When true, the result card starts collapsed so it doesn't clutter the conversation. _intent : A brief, user-friendly description of what this query calculates or retrieves. diff --git a/pkg-py/src/querychat/prompts/tool-update-dashboard.md b/pkg-py/src/querychat/prompts/tool-update-dashboard.md index dae9861c0..0b98d219b 100644 --- a/pkg-py/src/querychat/prompts/tool-update-dashboard.md +++ b/pkg-py/src/querychat/prompts/tool-update-dashboard.md @@ -1,10 +1,4 @@ -Filter and sort the dashboard data - -This tool executes a {{db_type}} SQL SELECT query to filter or sort the data used in the dashboard. - -**When to use:** Call this tool whenever the user requests filtering, sorting, or data manipulation on the dashboard with questions like "Show me..." or "Which records have...". This tool is appropriate for any request that involves showing a subset of the data or reordering it. - -**When not to use:** Do NOT use this tool for general questions about the data that can be answered with a single value or summary statistic. For those questions, use the `querychat_query` tool instead. +Filter and sort the dashboard data by executing a {{db_type}} SQL SELECT query. **Important constraints:** diff --git a/pkg-py/src/querychat/prompts/tool-visualize-query.md b/pkg-py/src/querychat/prompts/tool-visualize-query.md new file mode 100644 index 000000000..ee671a9dd --- /dev/null +++ b/pkg-py/src/querychat/prompts/tool-visualize-query.md @@ -0,0 +1,13 @@ +Render a ggsql query inline in the chat. All data transformations must happen in the SELECT clause — VISUALISE and MAPPING accept column names only, not SQL expressions or functions. + +Parameters +---------- +ggsql : + A full ggsql query. Must include a VISUALISE clause and at least one DRAW clause. The SELECT portion uses {{db_type}} SQL; VISUALISE and MAPPING accept column names only, not expressions. Do NOT include `LABEL title => ...` in the query — use the `title` parameter instead. +title : + A brief, user-friendly title for this visualization. This is displayed as the card header above the chart. + +Returns +------- +: + If successful, a static image of the rendered plot. If not, an error message. diff --git a/pkg-py/src/querychat/static/css/viz.css b/pkg-py/src/querychat/static/css/viz.css new file mode 100644 index 000000000..1b5812bc1 --- /dev/null +++ b/pkg-py/src/querychat/static/css/viz.css @@ -0,0 +1,141 @@ +/* Hide Vega's built-in action dropdown (we have our own save button) */ +.querychat-viz-container details:has(> .vega-actions) { + display: none !important; +} + +/* ---- Visualization container ---- */ + +.querychat-viz-container { + aspect-ratio: 4 / 2; + width: 100%; +} + +/* In full-screen mode, let the chart fill the available space */ +.bslib-full-screen-container .querychat-viz-container { + aspect-ratio: unset; +} + +/* ---- Visualization footer ---- */ + +.querychat-footer-buttons { + display: flex; + justify-content: space-between; + align-items: center; +} + +.querychat-footer-left, +.querychat-footer-right { + display: flex; + align-items: center; + gap: 4px; +} + +.querychat-show-query-btn, +.querychat-save-btn { + display: inline-flex; + align-items: center; + gap: 4px; + padding: 2px 8px; + height: 28px; + border: none; + border-radius: var(--bs-border-radius, 4px); + background: transparent; + color: var(--bs-secondary-color, #6c757d); + font-size: 0.75rem; + cursor: pointer; + white-space: nowrap; +} + +.querychat-show-query-btn:hover, +.querychat-save-btn:hover { + color: var(--bs-body-color, #212529); + background-color: rgba(var(--bs-emphasis-color-rgb, 0, 0, 0), 0.05); +} + +.querychat-query-chevron { + font-size: 0.625rem; + transition: transform 150ms; + display: inline-block; +} + +.querychat-query-chevron--expanded { + transform: rotate(90deg); +} + +.querychat-icon { + width: 14px; + height: 14px; +} + +.querychat-dropdown-chevron { + width: 12px; + height: 12px; + margin-left: 2px; +} + +.querychat-save-dropdown { + position: relative; +} + +.querychat-save-menu { + display: none; + position: absolute; + right: 0; + bottom: 100%; + margin-bottom: 4px; + z-index: 20; + background: var(--bs-body-bg, #fff); + border: 1px solid var(--bs-border-color, #dee2e6); + border-radius: var(--bs-border-radius, 4px); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15); + padding: 4px 0; + min-width: 120px; +} + +.querychat-save-menu--visible { + display: block; +} + +.querychat-save-menu button { + display: block; + width: 100%; + padding: 6px 12px; + border: none; + background: transparent; + color: var(--bs-body-color, #212529); + font-size: 0.75rem; + text-align: left; + cursor: pointer; +} + +.querychat-save-menu button:hover { + background-color: rgba(var(--bs-emphasis-color-rgb, 0, 0, 0), 0.05); +} + +.querychat-query-section { + display: none; + position: relative; + border-top: 1px solid var(--bs-border-color, #dee2e6); + margin: 8px -16px -8px; +} + +.querychat-query-section--visible { + display: block; +} + + +/* shinychat sets max-height:500px on all cards, which is too small for viz+editor */ +.shiny-tool-card:has(.querychat-viz-container) { + max-height: 700px; + overflow: hidden; +} + +.querychat-query-section bslib-code-editor .code-editor { + margin: 1em; +} + +.querychat-query-section bslib-code-editor .prism-code-editor { + background-color: var(--bs-light, #f8f8f8); + max-height: 200px; + overflow-y: auto; +} \ No newline at end of file diff --git a/pkg-py/src/querychat/static/js/viz-preload.js b/pkg-py/src/querychat/static/js/viz-preload.js new file mode 100644 index 000000000..088433af2 --- /dev/null +++ b/pkg-py/src/querychat/static/js/viz-preload.js @@ -0,0 +1,50 @@ +(function () { + if (!window.Shiny) return; + + var preloadObserver = null; + + function stopVizPreloadObserver() { + if (!preloadObserver) return; + preloadObserver.disconnect(); + preloadObserver = null; + } + + function handleVizPreload(root) { + if (!root || !root.isConnected) return; + + if (window.__querychatVizPreloaded) { + root.remove(); + stopVizPreloadObserver(); + return; + } + + window.__querychatVizPreloaded = true; + root.removeAttribute("hidden"); + stopVizPreloadObserver(); + } + + function processVizPreloads(node) { + if (!(node instanceof Element)) return; + + if (node.matches(".querychat-viz-preload")) { + handleVizPreload(node); + } + + node.querySelectorAll(".querychat-viz-preload").forEach(handleVizPreload); + } + + processVizPreloads(document.documentElement); + + if (!window.__querychatVizPreloaded) { + preloadObserver = new MutationObserver(function (mutations) { + mutations.forEach(function (mutation) { + mutation.addedNodes.forEach(processVizPreloads); + }); + }); + + preloadObserver.observe(document.documentElement, { + childList: true, + subtree: true, + }); + } +})(); diff --git a/pkg-py/src/querychat/static/js/viz.js b/pkg-py/src/querychat/static/js/viz.js new file mode 100644 index 000000000..a04475173 --- /dev/null +++ b/pkg-py/src/querychat/static/js/viz.js @@ -0,0 +1,129 @@ +// Helper: find a native vega-embed action link inside a widget container. +// vega-embed renders a hidden
with tags for "Save as SVG", +// "Save as PNG", etc. We find them by matching the download attribute suffix. +// +// Why not use the Vega View API (view.toSVG(), view.toImageURL()) directly? +// Altair renders charts via its anywidget ESM, which calls vegaEmbed() and +// stores the resulting View in a closure — it's never exposed on the DOM or +// any accessible object. vega-embed v7 also doesn't set __vega_embed__ on +// the element. The only code with access to the View is vega-embed's own +// action handlers, so we delegate to them. +function findVegaAction(container, extension) { + return container.querySelector( + '.vega-actions a[download$=".' + extension + '"]' + ); +} + +// Helper: find a widget container by its base ID. +// Shiny module namespacing may prefix the ID (e.g. "mod-querychat_viz_abc"), +// so we match elements whose ID ends with the base widget ID. +function findWidgetContainer(widgetId) { + return document.getElementById(widgetId) + || document.querySelector('[id$="' + CSS.escape(widgetId) + '"]'); +} + +// Helper: trigger a vega-embed export action link. +// vega-embed attaches an async mousedown handler that calls +// view.toImageURL() and sets the link's href to a data URL. +// We dispatch mousedown, then use a MutationObserver to detect +// when href changes from "#" to a data URL, and click the link. +function triggerVegaAction(link, filename) { + link.download = filename; + + // If href is already a data URL (unlikely but possible), click immediately. + if (link.href && link.href !== "#" && !link.href.endsWith("#")) { + link.click(); + return; + } + + var observer = new MutationObserver(function () { + if (link.href && link.href !== "#" && !link.href.endsWith("#")) { + observer.disconnect(); + clearTimeout(timeout); + link.click(); + } + }); + + observer.observe(link, { attributes: true, attributeFilter: ["href"] }); + + var timeout = setTimeout(function () { + observer.disconnect(); + console.error("Timed out waiting for vega-embed to generate image"); + }, 5000); + + link.dispatchEvent(new MouseEvent("mousedown", { bubbles: true })); +} + +function closeAllSaveMenus() { + document.querySelectorAll(".querychat-save-menu--visible").forEach(function (menu) { + menu.classList.remove("querychat-save-menu--visible"); + }); +} + +function handleShowQuery(event, btn) { + event.stopPropagation(); + var targetId = btn.dataset.target; + var section = document.getElementById(targetId); + if (!section) return; + var isVisible = section.classList.toggle("querychat-query-section--visible"); + var label = btn.querySelector(".querychat-query-label"); + var chevron = btn.querySelector(".querychat-query-chevron"); + if (label) label.textContent = isVisible ? "Hide Query" : "Show Query"; + if (chevron) chevron.classList.toggle("querychat-query-chevron--expanded", isVisible); +} + +function handleSaveToggle(event, btn) { + event.stopPropagation(); + var menu = btn.parentElement.querySelector(".querychat-save-menu"); + if (menu) menu.classList.toggle("querychat-save-menu--visible"); +} + +function handleSaveExport(event, btn, extension) { + event.stopPropagation(); + var widgetId = btn.dataset.widgetId; + var title = btn.dataset.title || "chart"; + var menu = btn.closest(".querychat-save-menu"); + if (menu) menu.classList.remove("querychat-save-menu--visible"); + + var container = findWidgetContainer(widgetId); + if (!container) return; + var link = findVegaAction(container, extension); + if (!link) return; + triggerVegaAction(link, title + "." + extension); +} + +function handleCopy(event, btn) { + event.stopPropagation(); + var query = btn.dataset.query; + if (!query) return; + navigator.clipboard.writeText(query).then(function () { + var original = btn.textContent; + btn.textContent = "Copied!"; + setTimeout(function () { btn.textContent = original; }, 2000); + }).catch(function (err) { + console.error("Failed to copy:", err); + }); +} + +// Single delegated click handler for all querychat viz footer buttons. +window.addEventListener("click", function (event) { + var target = event.target; + + var btn = target.closest(".querychat-show-query-btn"); + if (btn) { handleShowQuery(event, btn); return; } + + btn = target.closest(".querychat-save-png-btn"); + if (btn) { handleSaveExport(event, btn, "png"); return; } + + btn = target.closest(".querychat-save-svg-btn"); + if (btn) { handleSaveExport(event, btn, "svg"); return; } + + btn = target.closest(".querychat-copy-btn"); + if (btn) { handleCopy(event, btn); return; } + + btn = target.closest(".querychat-save-btn"); + if (btn) { handleSaveToggle(event, btn); return; } + + // Click outside any button — close open save menus + closeAllSaveMenus(); +}); diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py index 67ea453f5..27e42909f 100644 --- a/pkg-py/src/querychat/tools.py +++ b/pkg-py/src/querychat/tools.py @@ -1,14 +1,26 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any, Protocol, TypedDict, runtime_checkable -import chevron from chatlas import ContentToolResult, Tool from shinychat.types import ToolResultDisplay from ._icons import bs_icon -from ._utils import as_narwhals, df_to_html, querychat_tool_starts_open +from ._utils import ( + as_narwhals, + df_to_html, + querychat_tool_starts_open, + read_prompt_template, + truncate_error, +) +from ._viz_tools import tool_visualize_query + +__all__ = [ + "tool_query", + "tool_reset_dashboard", + "tool_update_dashboard", + "tool_visualize_query", +] if TYPE_CHECKING: from collections.abc import Callable @@ -69,13 +81,6 @@ def log_update(data: UpdateDashboardData): title: str -def _read_prompt_template(filename: str, **kwargs) -> str: - """Read and interpolate a prompt template file.""" - template_path = Path(__file__).parent / "prompts" / filename - template = template_path.read_text() - return chevron.render(template, kwargs) - - def _update_dashboard_impl( data_source: DataSource, update_fn: Callable[[UpdateDashboardData], None], @@ -103,9 +108,9 @@ def update_dashboard(query: str, title: str) -> ContentToolResult: update_fn({"query": query, "title": title}) except Exception as e: - error = str(e) + error = truncate_error(str(e)) markdown += f"\n\n> Error: {error}" - return ContentToolResult(value=markdown, error=e) + return ContentToolResult(value=markdown, error=Exception(error)) # Return ContentToolResult with display metadata return ContentToolResult( @@ -146,7 +151,7 @@ def tool_update_dashboard( """ impl = _update_dashboard_impl(data_source, update_fn) - description = _read_prompt_template( + description = read_prompt_template( "tool-update-dashboard.md", db_type=data_source.get_db_type(), ) @@ -212,7 +217,7 @@ def tool_reset_dashboard( """ impl = _reset_dashboard_impl(reset_fn) - description = _read_prompt_template("tool-reset-dashboard.md") + description = read_prompt_template("tool-reset-dashboard.md") impl.__doc__ = description return Tool.from_func( @@ -222,10 +227,14 @@ def tool_reset_dashboard( ) -def _query_impl(data_source: DataSource) -> Callable[[str, str], ContentToolResult]: +def _query_impl(data_source: DataSource) -> Callable[..., ContentToolResult]: """Create the implementation function for querying data.""" - def query(query: str, _intent: str = "") -> ContentToolResult: + def query( + query: str, + collapsed: bool | None = None, # noqa: FBT001 (LLM tool parameter) + _intent: str = "", + ) -> ContentToolResult: error = None markdown = f"```sql\n{query}\n```" value = None @@ -239,9 +248,9 @@ def query(query: str, _intent: str = "") -> ContentToolResult: markdown += "\n\n" + str(tbl_html) except Exception as e: - error = str(e) + error = truncate_error(str(e)) markdown += f"\n\n> Error: {error}" - return ContentToolResult(value=markdown, error=e) + return ContentToolResult(value=markdown, error=Exception(error)) # Return ContentToolResult with display metadata return ContentToolResult( @@ -250,7 +259,9 @@ def query(query: str, _intent: str = "") -> ContentToolResult: "display": ToolResultDisplay( markdown=markdown, show_request=False, - open=querychat_tool_starts_open("query"), + open=(not collapsed) + if collapsed is not None + else querychat_tool_starts_open("query"), icon=bs_icon("table"), ), }, @@ -276,7 +287,7 @@ def tool_query(data_source: DataSource) -> Tool: """ impl = _query_impl(data_source) - description = _read_prompt_template( + description = read_prompt_template( "tool-query.md", db_type=data_source.get_db_type() ) impl.__doc__ = description diff --git a/pkg-py/src/querychat/types/__init__.py b/pkg-py/src/querychat/types/__init__.py index f9a8163df..87b284325 100644 --- a/pkg-py/src/querychat/types/__init__.py +++ b/pkg-py/src/querychat/types/__init__.py @@ -9,6 +9,7 @@ from .._querychat_core import AppStateDict from .._shiny_module import ServerValues from .._utils import UnsafeQueryError +from .._viz_tools import VisualizeQueryData, VisualizeQueryResult from ..tools import UpdateDashboardData __all__ = ( @@ -22,4 +23,6 @@ "ServerValues", "UnsafeQueryError", "UpdateDashboardData", + "VisualizeQueryData", + "VisualizeQueryResult", ) diff --git a/pkg-py/tests/conftest.py b/pkg-py/tests/conftest.py new file mode 100644 index 000000000..95d586937 --- /dev/null +++ b/pkg-py/tests/conftest.py @@ -0,0 +1,32 @@ +"""Shared pytest fixtures for querychat unit tests.""" + +import polars as pl +import pytest + + +def _ggsql_render_works() -> bool: + """Check if ggsql.render_altair() is functional (build can be broken in some envs).""" + try: + import ggsql + + df = pl.DataFrame({"x": [1, 2], "y": [3, 4]}) + result = ggsql.render_altair(df, "VISUALISE x, y DRAW point") + spec = result.to_dict() + return "$schema" in spec + except (ValueError, ImportError): + return False + + +_ggsql_available = _ggsql_render_works() + + +def pytest_collection_modifyitems(config, items): + """Auto-skip tests marked with @pytest.mark.ggsql when ggsql is broken.""" + if _ggsql_available: + return + skip = pytest.mark.skip( + reason="ggsql.render_altair() not functional (build environment issue)" + ) + for item in items: + if "ggsql" in item.keywords: + item.add_marker(skip) diff --git a/pkg-py/tests/playwright/apps/viz_bookmark_app.py b/pkg-py/tests/playwright/apps/viz_bookmark_app.py new file mode 100644 index 000000000..6552bfa8a --- /dev/null +++ b/pkg-py/tests/playwright/apps/viz_bookmark_app.py @@ -0,0 +1,25 @@ +"""Test app for viz bookmark restore: uses server-side bookmarking to avoid URL length limits.""" + +from querychat import QueryChat +from querychat.data import titanic + +from shiny import App, ui + +qc = QueryChat( + titanic(), + "titanic", + tools=("query", "visualize_query"), +) + + +def app_ui(request): + return ui.page_fillable( + qc.ui(), + ) + + +def server(input, output, session): + qc.server(enable_bookmarking=True) + + +app = App(app_ui, server, bookmark_store="server") diff --git a/pkg-py/tests/playwright/conftest.py b/pkg-py/tests/playwright/conftest.py index 6febfd4e8..961af01f3 100644 --- a/pkg-py/tests/playwright/conftest.py +++ b/pkg-py/tests/playwright/conftest.py @@ -592,3 +592,31 @@ def dash_cleanup(_thread, server): yield url finally: _stop_dash_server(server) + + +@pytest.fixture(scope="module") +def app_10_viz() -> Generator[str, None, None]: + """Start the 10-viz-app.py Shiny server for testing.""" + app_path = str(EXAMPLES_DIR / "10-viz-app.py") + + def start_factory(): + port = _find_free_port() + url = f"http://localhost:{port}" + return url, lambda: _start_shiny_app_threaded(app_path, port) + + def shiny_cleanup(_thread, server): + _stop_shiny_server(server) + + url, _thread, server = _start_server_with_retry( + start_factory, shiny_cleanup, timeout=30.0 + ) + try: + yield url + finally: + _stop_shiny_server(server) + + +@pytest.fixture +def chat_10_viz(page: Page) -> ChatControllerType: + """Create a ChatController for the 10-viz-app chat component.""" + return _create_chat_controller(page, "titanic") diff --git a/pkg-py/tests/playwright/test_10_viz_inline.py b/pkg-py/tests/playwright/test_10_viz_inline.py new file mode 100644 index 000000000..857e860cb --- /dev/null +++ b/pkg-py/tests/playwright/test_10_viz_inline.py @@ -0,0 +1,119 @@ +""" +Playwright tests for inline visualization and fullscreen behavior. + +These tests verify that: +1. The visualize_query tool renders Altair charts inline in tool result cards +2. The fullscreen toggle button appears on visualization tool results +3. Fullscreen mode works (expand and collapse via button and Escape key) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from playwright.sync_api import expect + +if TYPE_CHECKING: + from playwright.sync_api import Page + from shinychat.playwright import ChatController + + +class TestInlineVisualization: + """Tests for inline chart rendering in tool result cards.""" + + @pytest.fixture(autouse=True) + def setup( + self, page: Page, app_10_viz: str, chat_10_viz: ChatController + ) -> None: + """Navigate to the viz app before each test.""" + page.goto(app_10_viz) + page.wait_for_selector("shiny-chat-container", timeout=30000) + self.page = page + self.chat = chat_10_viz + + def test_viz_tool_renders_inline_chart(self) -> None: + """VIZ-INLINE: Visualization tool result contains an inline chart widget.""" + self.chat.set_user_input( + "Create a scatter plot of age vs fare for the titanic passengers" + ) + self.chat.send_user_input(method="click") + + # Wait for a tool result card with full-screen attribute (viz results have it) + tool_card = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(tool_card).to_be_visible(timeout=90000) + + # The card should contain the viz container (Altair chart via shinywidgets) + viz_container = tool_card.locator(".querychat-viz-container") + expect(viz_container).to_be_visible(timeout=10000) + + def test_fullscreen_button_visible_on_viz_card(self) -> None: + """VIZ-FS-BTN: Fullscreen toggle button appears on visualization cards.""" + self.chat.set_user_input( + "Make a bar chart showing count of passengers by class" + ) + self.chat.send_user_input(method="click") + + # Wait for viz tool result + tool_card = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(tool_card).to_be_visible(timeout=90000) + + # Fullscreen toggle should be visible + fs_button = tool_card.locator(".tool-fullscreen-toggle") + expect(fs_button).to_be_visible() + + def test_fullscreen_toggle_expands_card(self) -> None: + """VIZ-FS-EXPAND: Clicking fullscreen button expands the card.""" + self.chat.set_user_input( + "Plot a histogram of passenger ages from the titanic data" + ) + self.chat.send_user_input(method="click") + + # Wait for viz tool result + tool_result = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(tool_result).to_be_visible(timeout=90000) + + # Click fullscreen toggle + fs_button = tool_result.locator(".tool-fullscreen-toggle") + fs_button.click() + + # The .shiny-tool-card inside should now have fullscreen attribute + card = tool_result.locator(".shiny-tool-card[fullscreen]") + expect(card).to_be_visible() + + def test_escape_closes_fullscreen(self) -> None: + """VIZ-FS-ESC: Pressing Escape closes fullscreen mode.""" + self.chat.set_user_input( + "Create a visualization of survival rate by passenger class" + ) + self.chat.send_user_input(method="click") + + # Wait for viz tool result + tool_result = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(tool_result).to_be_visible(timeout=90000) + + # Enter fullscreen + fs_button = tool_result.locator(".tool-fullscreen-toggle") + fs_button.click() + + card = tool_result.locator(".shiny-tool-card[fullscreen]") + expect(card).to_be_visible() + + # Press Escape + self.page.keyboard.press("Escape") + + # Fullscreen should be removed + expect(card).not_to_be_visible() + + def test_non_viz_tool_results_have_no_fullscreen(self) -> None: + """VIZ-NO-FS: Non-visualization tool results don't have fullscreen.""" + self.chat.set_user_input("Show me passengers who survived") + self.chat.send_user_input(method="click") + + # Wait for a tool result (any) + tool_result = self.page.locator(".shiny-tool-result").first + expect(tool_result).to_be_visible(timeout=90000) + + # Non-viz tool results should NOT have fullscreen toggle + fs_results = self.page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)") + expect(fs_results).to_have_count(0) diff --git a/pkg-py/tests/playwright/test_11_viz_footer.py b/pkg-py/tests/playwright/test_11_viz_footer.py new file mode 100644 index 000000000..2cd586952 --- /dev/null +++ b/pkg-py/tests/playwright/test_11_viz_footer.py @@ -0,0 +1,154 @@ +""" +Playwright tests for visualization footer interactions (Show Query, Save dropdown). + +These tests verify the client-side JS behavior in viz.js: +1. Show Query toggle reveals/hides the query section +2. Save dropdown opens/closes on click +3. Clicking outside the Save dropdown closes it +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from playwright.sync_api import expect + +if TYPE_CHECKING: + from playwright.sync_api import Page + from shinychat.playwright import ChatController + + +VIZ_PROMPT = "Use the visualize tool to create a scatter plot of age vs fare" +TOOL_RESULT_TIMEOUT = 90_000 + + +@pytest.fixture(autouse=True) +def _send_viz_prompt( + page: Page, app_10_viz: str, chat_10_viz: ChatController +) -> None: + """Navigate to the viz app and trigger a visualization before each test.""" + page.goto(app_10_viz) + page.wait_for_selector("shiny-chat-container", timeout=30_000) + + chat_10_viz.set_user_input(VIZ_PROMPT) + chat_10_viz.send_user_input(method="click") + + # Wait for the viz tool result card with fullscreen support + page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)").wait_for( + state="visible", timeout=TOOL_RESULT_TIMEOUT + ) + # Wait for the footer buttons to appear inside the card + page.locator(".querychat-footer-buttons").wait_for( + state="visible", timeout=10_000 + ) + + +class TestShowQueryToggle: + """Tests for the Show Query / Hide Query toggle button.""" + + def test_query_section_hidden_by_default(self, page: Page) -> None: + """The query section should be hidden initially.""" + section = page.locator(".querychat-query-section") + expect(section).to_be_attached() + expect(section).not_to_be_visible() + + def test_click_show_query_reveals_section(self, page: Page) -> None: + """Clicking 'Show Query' should reveal the query section.""" + btn = page.locator(".querychat-show-query-btn") + btn.click() + + section = page.locator(".querychat-query-section--visible") + expect(section).to_be_visible() + + def test_label_changes_to_hide_query(self, page: Page) -> None: + """After clicking, the label should change to 'Hide Query'.""" + btn = page.locator(".querychat-show-query-btn") + label = btn.locator(".querychat-query-label") + + expect(label).to_have_text("Show Query") + btn.click() + expect(label).to_have_text("Hide Query") + + def test_chevron_rotates_on_expand(self, page: Page) -> None: + """The chevron should get the --expanded class when query is shown.""" + btn = page.locator(".querychat-show-query-btn") + chevron = btn.locator(".querychat-query-chevron") + + expect(chevron).not_to_have_class("querychat-query-chevron--expanded") + btn.click() + expect(chevron).to_have_class("querychat-query-chevron querychat-query-chevron--expanded") + + def test_toggle_hides_section_again(self, page: Page) -> None: + """Clicking the button a second time should hide the query section.""" + btn = page.locator(".querychat-show-query-btn") + btn.click() # show + btn.click() # hide + + section = page.locator(".querychat-query-section") + expect(section).not_to_have_class("querychat-query-section--visible") + + label = btn.locator(".querychat-query-label") + expect(label).to_have_text("Show Query") + + def test_query_section_contains_code(self, page: Page) -> None: + """The revealed query section should contain the ggsql code.""" + btn = page.locator(".querychat-show-query-btn") + btn.click() + + section = page.locator(".querychat-query-section--visible") + expect(section).to_be_visible() + + # The code editor should contain VISUALISE (ggsql keyword) + code = section.locator(".code-editor") + expect(code).to_be_visible() + + +class TestSaveDropdown: + """Tests for the Save button dropdown menu.""" + + def test_save_menu_hidden_by_default(self, page: Page) -> None: + """The save dropdown menu should be hidden initially.""" + menu = page.locator(".querychat-save-menu") + expect(menu).to_be_attached() + expect(menu).not_to_be_visible() + + def test_click_save_opens_menu(self, page: Page) -> None: + """Clicking the Save button should reveal the dropdown menu.""" + btn = page.locator(".querychat-save-btn") + btn.click() + + menu = page.locator(".querychat-save-menu--visible") + expect(menu).to_be_visible() + + def test_menu_has_png_and_svg_options(self, page: Page) -> None: + """The save menu should contain 'Save as PNG' and 'Save as SVG' options.""" + btn = page.locator(".querychat-save-btn") + btn.click() + + menu = page.locator(".querychat-save-menu--visible") + expect(menu.locator(".querychat-save-png-btn")).to_be_visible() + expect(menu.locator(".querychat-save-svg-btn")).to_be_visible() + + def test_click_outside_closes_menu(self, page: Page) -> None: + """Clicking outside the dropdown should close it.""" + btn = page.locator(".querychat-save-btn") + btn.click() + + menu = page.locator(".querychat-save-menu") + expect(menu).to_have_class("querychat-save-menu querychat-save-menu--visible") + + # Click somewhere else on the page body + page.locator("body").click(position={"x": 10, "y": 10}) + + expect(menu).not_to_have_class("querychat-save-menu--visible") + + def test_toggle_save_menu(self, page: Page) -> None: + """Clicking Save twice should open then close the menu.""" + btn = page.locator(".querychat-save-btn") + btn.click() + menu = page.locator(".querychat-save-menu") + expect(menu).to_have_class("querychat-save-menu querychat-save-menu--visible") + + btn.click() + expect(menu).not_to_have_class("querychat-save-menu--visible") diff --git a/pkg-py/tests/playwright/test_12_viz_bookmark.py b/pkg-py/tests/playwright/test_12_viz_bookmark.py new file mode 100644 index 000000000..7683b4355 --- /dev/null +++ b/pkg-py/tests/playwright/test_12_viz_bookmark.py @@ -0,0 +1,136 @@ +""" +Playwright tests for visualization bookmark restore behavior. + +These tests verify that when a user creates a visualization and then +restores from a bookmark URL, the chart widget is properly re-rendered +(not just the empty HTML shell). +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest +from playwright.sync_api import expect + +if TYPE_CHECKING: + from collections.abc import Generator + + from playwright.sync_api import BrowserContext, Page + from shinychat.playwright import ChatController as ChatControllerType + +import sys + +# conftest.py is not importable directly; add the test directory to sys.path +sys.path.insert(0, str(Path(__file__).parent)) +from conftest import ( + _create_chat_controller, + _find_free_port, + _start_server_with_retry, + _start_shiny_app_threaded, + _stop_shiny_server, +) + +VIZ_PROMPT = "Use the visualize tool to create a scatter plot of age vs fare" +TOOL_RESULT_TIMEOUT = 90_000 +APPS_DIR = Path(__file__).parent / "apps" + + +@pytest.fixture(scope="module") +def app_viz_bookmark() -> Generator[str, None, None]: + """Start the viz bookmark test app with server-side bookmarking.""" + app_path = str(APPS_DIR / "viz_bookmark_app.py") + + def start_factory(): + port = _find_free_port() + url = f"http://localhost:{port}" + return url, lambda: _start_shiny_app_threaded(app_path, port) + + def shiny_cleanup(_thread, server): + _stop_shiny_server(server) + + url, _thread, server = _start_server_with_retry( + start_factory, shiny_cleanup, timeout=30.0 + ) + try: + yield url + finally: + _stop_shiny_server(server) + + +@pytest.fixture +def chat_viz_bookmark(page: Page) -> ChatControllerType: + return _create_chat_controller(page, "titanic") + + +class TestVizBookmarkRestore: + """Tests for visualization restoration from bookmark URLs.""" + + @pytest.fixture(autouse=True) + def setup( + self, page: Page, app_viz_bookmark: str, chat_viz_bookmark: ChatControllerType + ) -> None: + """Navigate to the viz app and create a viz before each test.""" + self.app_url = app_viz_bookmark + self.page = page + self.chat = chat_viz_bookmark + + page.goto(app_viz_bookmark) + page.wait_for_selector("shiny-chat-container", timeout=30_000) + + # Wait for the greeting bookmark URL to be set first + # (bookmark_on="response" auto-bookmarks after greeting) + page.wait_for_function( + "() => window.location.search.includes('_state_id_=')", + timeout=30_000, + ) + self.greeting_url = page.url + + # Create a visualization + chat_viz_bookmark.set_user_input(VIZ_PROMPT) + chat_viz_bookmark.send_user_input(method="click") + + # Wait for the viz tool result to fully render + page.locator(".shiny-tool-result:has(.tool-fullscreen-toggle)").wait_for( + state="visible", timeout=TOOL_RESULT_TIMEOUT + ) + page.locator(".querychat-footer-buttons").wait_for( + state="visible", timeout=10_000 + ) + + def _wait_for_viz_bookmark_url(self) -> str: + """Wait for the URL to update from the greeting bookmark to the viz bookmark.""" + greeting_search = self.greeting_url.split("?", 1)[1] if "?" in self.greeting_url else "" + self.page.wait_for_function( + "(greetingSearch) => window.location.search.includes('_state_id_=') && window.location.search !== '?' + greetingSearch", + arg=greeting_search, + timeout=30_000, + ) + return self.page.url + + def test_bookmark_url_updates_after_viz(self) -> None: + """BOOKMARK-VIZ-URL: URL should update with new state ID after viz is created.""" + url = self._wait_for_viz_bookmark_url() + assert url != self.greeting_url, "URL should have changed after viz bookmarking" + + def test_viz_widget_renders_on_bookmark_restore(self, context: BrowserContext) -> None: + """BOOKMARK-VIZ-RESTORE: Restored bookmark should re-render the chart widget, not just the HTML shell.""" + bookmark_url = self._wait_for_viz_bookmark_url() + + # Open the bookmark URL in a new page (new session) + new_page = context.new_page() + new_page.goto(bookmark_url) + new_page.wait_for_selector("shiny-chat-container", timeout=30_000) + + # The viz container HTML should be restored (shinychat restores message HTML) + viz_container = new_page.locator(".querychat-viz-container") + expect(viz_container).to_be_visible(timeout=30_000) + + # The critical check: the widget should actually render a chart, + # not just be an empty output_widget div. A rendered Vega-Lite chart + # will have a canvas or SVG inside a .vega-embed container. + chart_element = viz_container.locator("canvas, svg, .vega-embed") + expect(chart_element.first).to_be_visible(timeout=30_000) + + new_page.close() diff --git a/pkg-py/tests/test_datasource_reader.py b/pkg-py/tests/test_datasource_reader.py new file mode 100644 index 000000000..13a83cd96 --- /dev/null +++ b/pkg-py/tests/test_datasource_reader.py @@ -0,0 +1,242 @@ +"""Tests for DataSourceReader bridge.""" + +import polars as pl +import pytest +from sqlalchemy import create_engine, text + + +class TestDialectMapping: + """Tests for SQLGLOT_DIALECTS mapping.""" + + def test_known_dialects_present(self): + from querychat._datasource_reader import SQLGLOT_DIALECTS + + expected = { + "postgresql": "postgres", + "mysql": "mysql", + "sqlite": "sqlite", + "mssql": "tsql", + "oracle": "oracle", + "duckdb": "duckdb", + "snowflake": "snowflake", + "bigquery": "bigquery", + "redshift": "redshift", + "trino": "trino", + "databricks": "databricks", + "clickhousedb": "clickhouse", + "clickhouse": "clickhouse", + "awsathena": "athena", + "teradatasql": "teradata", + "exasol": "exasol", + "doris": "doris", + "singlestoredb": "singlestore", + "risingwave": "risingwave", + "druid": "druid", + "hive": "hive", + "presto": "presto", + } + for sa_name, sqlglot_name in expected.items(): + assert SQLGLOT_DIALECTS[sa_name] == sqlglot_name, f"mismatch for {sa_name}" + + def test_unknown_dialect_not_present(self): + from querychat._datasource_reader import SQLGLOT_DIALECTS + + assert "nonexistent_db" not in SQLGLOT_DIALECTS + + def test_register_custom_dialect(self): + from querychat._datasource_reader import ( + SQLGLOT_DIALECTS, + register_sqlglot_dialect, + ) + + register_sqlglot_dialect("my_custom_db", "mysql") + assert SQLGLOT_DIALECTS["my_custom_db"] == "mysql" + # Clean up to avoid polluting other tests + del SQLGLOT_DIALECTS["my_custom_db"] + + +class TestTranspileSql: + """Tests for transpile_sql() helper.""" + + def test_identity_for_duckdb(self): + from querychat._datasource_reader import transpile_sql + + sql = "SELECT x, y FROM t WHERE x > 1" + result = transpile_sql(sql, "duckdb") + assert "SELECT" in result + assert "FROM" in result + + def test_transpiles_create_temp_table_to_snowflake(self): + from querychat._datasource_reader import transpile_sql + + sql = "CREATE TEMPORARY TABLE __ggsql_cte_0 AS SELECT x FROM t" + result = transpile_sql(sql, "snowflake") + assert "TEMPORARY" in result.upper() or "TEMP" in result.upper() + assert "__ggsql_cte_0" in result + + def test_transpiles_recursive_cte_to_postgres(self): + from querychat._datasource_reader import transpile_sql + + sql = ( + "WITH RECURSIVE series AS (" + "SELECT 0 AS n UNION ALL SELECT n + 1 FROM series WHERE n < 10" + ") SELECT n FROM series" + ) + result = transpile_sql(sql, "postgres") + assert "RECURSIVE" in result.upper() + + def test_transpiles_ntile_to_snowflake(self): + from querychat._datasource_reader import transpile_sql + + sql = "SELECT NTILE(4) OVER (ORDER BY x) AS quartile FROM t" + result = transpile_sql(sql, "snowflake") + assert "NTILE" in result.upper() + + def test_passthrough_on_empty_dialect(self): + """Empty string dialect means generic/ANSI — should pass through.""" + from querychat._datasource_reader import transpile_sql + + sql = "SELECT 1" + result = transpile_sql(sql, "") + assert result == "SELECT 1" + + +@pytest.fixture +def sqlite_engine(): + """Create an in-memory SQLite database with test data.""" + engine = create_engine("sqlite://") + with engine.connect() as conn: + conn.execute(text("CREATE TABLE test_data (x INTEGER, y INTEGER, label TEXT)")) + conn.execute( + text("INSERT INTO test_data VALUES (1, 10, 'a'), (2, 20, 'b'), (3, 30, 'a')") + ) + conn.commit() + return engine + + +class TestDataSourceReader: + """Tests for DataSourceReader against a real SQLite database.""" + + def test_execute_sql_returns_polars(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + df = reader.execute_sql("SELECT * FROM test_data") + assert isinstance(df, pl.DataFrame) + assert len(df) == 3 + assert set(df.columns) == {"x", "y", "label"} + + def test_execute_sql_with_filter(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + df = reader.execute_sql("SELECT * FROM test_data WHERE x > 1") + assert len(df) == 2 + + def test_register_creates_temp_table(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + df = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]}) + with DataSourceReader(sqlite_engine, "sqlite") as reader: + reader.register("my_temp", df, replace=True) + result = reader.execute_sql("SELECT * FROM my_temp") + assert len(result) == 2 + assert set(result.columns) == {"a", "b"} + + def test_unregister_drops_temp_table(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + df = pl.DataFrame({"a": [1]}) + with DataSourceReader(sqlite_engine, "sqlite") as reader: + reader.register("drop_me", df, replace=True) + reader.unregister("drop_me") + with pytest.raises(Exception, match="drop_me"): + reader.execute_sql("SELECT * FROM drop_me") + + def test_context_manager_cleans_up_temp_tables(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + df = pl.DataFrame({"a": [1]}) + with DataSourceReader(sqlite_engine, "sqlite") as reader: + reader.register("cleanup_test", df, replace=True) + + # After exiting context, temp table should be gone. + # SQLite temp tables are connection-scoped, so they vanish + # when the connection closes. + with sqlite_engine.connect() as conn: + result = conn.execute( + text("SELECT name FROM sqlite_temp_master WHERE name = 'cleanup_test'") + ) + assert result.fetchone() is None + + def test_register_replace_overwrites(self, sqlite_engine): + from querychat._datasource_reader import DataSourceReader + + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [10, 20, 30]}) + with DataSourceReader(sqlite_engine, "sqlite") as reader: + reader.register("replace_me", df1, replace=True) + reader.register("replace_me", df2, replace=True) + result = reader.execute_sql("SELECT * FROM replace_me") + assert len(result) == 3 + + def test_execute_sql_transpiles(self, sqlite_engine): + """Verify that generated SQL gets transpiled to the target dialect.""" + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + df = reader.execute_sql("SELECT x, y FROM test_data ORDER BY x LIMIT 2") + assert len(df) == 2 + + +@pytest.mark.ggsql +class TestDataSourceReaderWithGgsql: + """End-to-end tests: DataSourceReader + ggsql.execute().""" + + def test_simple_scatter(self, sqlite_engine): + import ggsql + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + spec = ggsql.execute( + "SELECT x, y FROM test_data VISUALISE x, y DRAW point", + reader, + ) + assert spec.metadata()["rows"] == 3 + assert "VISUALISE" in spec.visual() + + def test_with_filter(self, sqlite_engine): + import ggsql + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + spec = ggsql.execute( + "SELECT x, y FROM test_data WHERE x > 1 VISUALISE x, y DRAW point", + reader, + ) + assert spec.metadata()["rows"] == 2 + + def test_form_b_visualise_from(self, sqlite_engine): + import ggsql + from querychat._datasource_reader import DataSourceReader + + with DataSourceReader(sqlite_engine, "sqlite") as reader: + spec = ggsql.execute( + "VISUALISE x, y FROM test_data DRAW point", + reader, + ) + assert spec.metadata()["rows"] == 3 + + def test_with_aggregation(self, sqlite_engine): + import ggsql + from querychat._datasource_reader import DataSourceReader + + # ggsql bar layer requires columns named x/y; "label" is a reserved keyword + # so we alias: label -> x, SUM(y) -> y. Two distinct label values -> 2 rows. + with DataSourceReader(sqlite_engine, "sqlite") as reader: + spec = ggsql.execute( + "SELECT label AS x, SUM(y) AS y FROM test_data GROUP BY label " + "VISUALISE x, y DRAW bar", + reader, + ) + assert spec.metadata()["rows"] == 2 diff --git a/pkg-py/tests/test_deferred_shiny.py b/pkg-py/tests/test_deferred_shiny.py index 39899a772..96ba29656 100644 --- a/pkg-py/tests/test_deferred_shiny.py +++ b/pkg-py/tests/test_deferred_shiny.py @@ -2,6 +2,7 @@ import os +import chatlas import pandas as pd import pytest from chatlas import ChatOpenAI @@ -95,8 +96,9 @@ def spy_create_client(client_spec): with session_context(ExpressStubSession()): vals = qc.server(data_source=sample_df, client=override_client) - assert vals.client is None - assert recorded_specs == [] + assert isinstance(vals.client, chatlas.Chat) + assert len(recorded_specs) == 1 + assert recorded_specs[0] is override_client assert qc._client_spec is init_client def test_multiple_server_overrides_do_not_leak_into_shared_state(self, sample_df): diff --git a/pkg-py/tests/test_ggsql.py b/pkg-py/tests/test_ggsql.py new file mode 100644 index 000000000..6f1d43b39 --- /dev/null +++ b/pkg-py/tests/test_ggsql.py @@ -0,0 +1,234 @@ +"""Tests for ggsql integration helpers.""" + +import ggsql +import narwhals.stable.v1 as nw +import polars as pl +import pytest +from querychat._datasource import DataFrameSource +from querychat._viz_altair_widget import AltairWidget +from querychat._viz_ggsql import ( + execute_ggsql, + extract_visualise_table, + has_layer_level_source, +) + + +class TestExtractVisualiseTable: + """Tests for extract_visualise_table() parsing.""" + + def test_bare_identifier(self): + assert extract_visualise_table("VISUALISE x, y FROM mytable DRAW point") == "mytable" + + def test_quoted_identifier(self): + assert ( + extract_visualise_table('VISUALISE x FROM "my table" DRAW point') + == '"my table"' + ) + + def test_no_from_returns_none(self): + assert extract_visualise_table("VISUALISE x, y DRAW point") is None + + def test_ignores_draw_level_from(self): + visual = "VISUALISE x, y DRAW bar MAPPING z AS fill FROM summary" + assert extract_visualise_table(visual) is None + + +class TestHasLayerLevelSource: + def test_detects_draw_level_from(self): + visual = "VISUALISE x, y DRAW bar MAPPING z AS fill FROM summary" + assert has_layer_level_source(visual) + + def test_ignores_visualise_from(self): + visual = "VISUALISE x, y FROM sales DRAW point MAPPING z AS color" + assert not has_layer_level_source(visual) + + def test_ignores_scale_from(self): + visual = "VISUALISE x, y DRAW point MAPPING z AS color SCALE x FROM [0, 10]" + assert not has_layer_level_source(visual) + + +class TestGgsqlValidate: + """Tests for ggsql.validate() usage (split SQL and VISUALISE).""" + + def test_splits_query_with_visualise(self): + query = "SELECT x, y FROM data VISUALISE x, y DRAW point" + validated = ggsql.validate(query) + assert validated.sql() == "SELECT x, y FROM data" + assert validated.visual() == "VISUALISE x, y DRAW point" + assert validated.has_visual() + + def test_returns_empty_viz_without_visualise(self): + query = "SELECT x, y FROM data" + validated = ggsql.validate(query) + assert validated.sql() == "SELECT x, y FROM data" + assert validated.visual() == "" + assert not validated.has_visual() + + def test_handles_complex_query(self): + query = """ + SELECT date, SUM(revenue) as total + FROM sales + GROUP BY date + VISUALISE date AS x, total AS y + DRAW line + LABEL title => 'Revenue Over Time' + """ + validated = ggsql.validate(query) + assert "SELECT date, SUM(revenue)" in validated.sql() + assert "GROUP BY date" in validated.sql() + assert "VISUALISE date AS x" in validated.visual() + assert "LABEL title" in validated.visual() + + + +@pytest.fixture(autouse=True) +def _allow_widget_outside_session(monkeypatch): + """Allow JupyterChart (an ipywidget) to be constructed without a Shiny session.""" + from ipywidgets.widgets.widget import Widget + + monkeypatch.setattr(Widget, "_widget_construction_callback", lambda _w: None) + + +class TestAltairWidget: + @pytest.mark.ggsql + def test_produces_jupyter_chart(self): + import altair as alt + import ggsql + + reader = ggsql.DuckDBReader("duckdb://memory") + df = pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + reader.register("data", df) + spec = reader.execute("SELECT * FROM data VISUALISE x, y DRAW point") + altair_widget = AltairWidget.from_ggsql(spec) + assert isinstance(altair_widget.widget, alt.JupyterChart) + result = altair_widget.widget.chart.to_dict() + assert "$schema" in result + assert "vega-lite" in result["$schema"] + + +class TestExecuteGgsql: + @pytest.mark.ggsql + def test_full_pipeline(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + ds = DataFrameSource(nw_df, "test_data") + query = "SELECT * FROM test_data VISUALISE x, y DRAW point" + spec = execute_ggsql(ds, query, ggsql.validate(query)) + altair_widget = AltairWidget.from_ggsql(spec) + result = altair_widget.widget.chart.to_dict() + assert "$schema" in result + + @pytest.mark.ggsql + def test_with_filtered_query(self): + nw_df = nw.from_native( + pl.DataFrame({"x": [1, 2, 3, 4, 5], "y": [10, 20, 30, 40, 50]}) + ) + ds = DataFrameSource(nw_df, "test_data") + query = "SELECT * FROM test_data WHERE x > 2 VISUALISE x, y DRAW point" + spec = execute_ggsql(ds, query, ggsql.validate(query)) + assert spec.metadata()["rows"] == 3 + + @pytest.mark.ggsql + def test_spec_has_visual(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2], "y": [3, 4]})) + ds = DataFrameSource(nw_df, "test_data") + query = "SELECT * FROM test_data VISUALISE x, y DRAW point" + spec = execute_ggsql(ds, query, ggsql.validate(query)) + assert "VISUALISE" in spec.visual() + + @pytest.mark.ggsql + def test_visualise_from_path(self): + nw_df = nw.from_native(pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + ds = DataFrameSource(nw_df, "test_data") + query = "VISUALISE x, y FROM test_data DRAW point" + spec = execute_ggsql(ds, query, ggsql.validate(query)) + assert spec.metadata()["rows"] == 3 + assert "VISUALISE" in spec.visual() + + @pytest.mark.ggsql + def test_with_pandas_dataframe(self): + import pandas as pd + + nw_df = nw.from_native(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + ds = DataFrameSource(nw_df, "test_data") + query = "SELECT * FROM test_data VISUALISE x, y DRAW point" + spec = execute_ggsql(ds, query, ggsql.validate(query)) + altair_widget = AltairWidget.from_ggsql(spec) + result = altair_widget.widget.chart.to_dict() + assert "$schema" in result + + @pytest.mark.ggsql + def test_rejects_layer_level_from_sources_with_clear_error(self): + nw_df = nw.from_native( + pl.DataFrame( + { + "date": ["2024-01", "2024-01", "2024-02", "2024-02"], + "region": ["north", "south", "north", "south"], + "amount": [10, 20, 30, 40], + } + ) + ) + ds = DataFrameSource(nw_df, "sales") + query = """ + WITH summary AS ( + SELECT region, SUM(amount) AS total + FROM sales + GROUP BY region + ) + SELECT * + FROM sales + VISUALISE date AS x, amount AS y + DRAW line + DRAW bar MAPPING region AS x, total AS y FROM summary + """ + + with pytest.raises( + ValueError, + match="Layer-specific sources are not currently supported", + ): + execute_ggsql(ds, query, ggsql.validate(query)) + + @pytest.mark.ggsql + def test_supports_single_relation_raw_plus_summary_overlay(self): + nw_df = nw.from_native( + pl.DataFrame( + { + "x": [1, 1, 2, 2], + "y": [10, 20, 30, 40], + "category": ["a", "b", "a", "b"], + } + ) + ) + ds = DataFrameSource(nw_df, "sales") + query = """ + WITH raw AS ( + SELECT + x, + y, + category, + 'raw' AS layer_type + FROM sales + ), + summary AS ( + SELECT + x, + AVG(y) AS y, + category, + 'summary' AS layer_type + FROM sales + GROUP BY x, category + ), + combined AS ( + SELECT * FROM raw + UNION ALL + SELECT * FROM summary + ) + SELECT * + FROM combined + VISUALISE x AS x, y AS y + DRAW point MAPPING category AS color FILTER layer_type = 'raw' + DRAW line MAPPING category AS color FILTER layer_type = 'summary' + """ + + spec = execute_ggsql(ds, query, ggsql.validate(query)) + assert spec.metadata()["rows"] == 4 + assert "VISUALISE" in spec.visual() diff --git a/pkg-py/tests/test_shiny_viz_regressions.py b/pkg-py/tests/test_shiny_viz_regressions.py new file mode 100644 index 000000000..b9d51772c --- /dev/null +++ b/pkg-py/tests/test_shiny_viz_regressions.py @@ -0,0 +1,387 @@ +"""Regression tests for Shiny ggsql tool wiring and bookmark restore.""" + +import inspect +import os +from types import SimpleNamespace +from unittest.mock import patch + +import chatlas +import pytest +from querychat import QueryChat +from querychat._shiny import QueryChatExpress +from querychat._shiny_module import mod_server +from querychat.data import tips + +from shiny import reactive + + +@pytest.fixture(autouse=True) +def set_dummy_api_key(): + old_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "sk-dummy-api-key-for-testing" + yield + if old_api_key is not None: + os.environ["OPENAI_API_KEY"] = old_api_key + else: + del os.environ["OPENAI_API_KEY"] + + +@pytest.fixture +def sample_df(): + return tips() + + +def _identity(fn): + return fn + + +def _event(*_args, **_kwargs): + def wrapper(fn): + return fn + + return wrapper + + +def _raw_mod_server(): + return inspect.getclosurevars(mod_server).nonlocals["fn"] + + +class DummyBookmark: + def on_bookmark(self, fn): + self.bookmark_fn = fn + return fn + + def on_restore(self, fn): + self.restore_fn = fn + return fn + + +class DummySession: + def __init__(self): + self.bookmark = DummyBookmark() + + def is_stub_session(self): + return False + + +class DummyStubSession(DummySession): + def is_stub_session(self): + return True + + +class DummyChatUi: + def __init__(self, *_args, **_kwargs): + pass + + def on_user_submit(self, fn): + return fn + + async def append_message_stream(self, _stream): + return None + + async def append_message(self, _message): + return None + + def enable_bookmarking(self, _chat): + return None + + +class DummyProvider(chatlas.Provider): + def __init__(self, *, name, model): + super().__init__(name=name, model=model) + + def list_models(self): + return [] + + def chat_perform(self, *, stream, turns, tools, data_model, kwargs): + return () if stream else SimpleNamespace() + + async def chat_perform_async( + self, *, stream, turns, tools, data_model, kwargs + ): + return () if stream else SimpleNamespace() + + def stream_content(self, chunk): + return None + + def stream_text(self, chunk): + return None + + def stream_merge_chunks(self, completion, chunk): + return completion or {} + + def stream_turn(self, completion, has_data_model): + return SimpleNamespace() + + def value_turn(self, completion, has_data_model): + return SimpleNamespace() + + def value_tokens(self, completion): + return (0, 0, 0) + + def token_count(self, *args, tools, data_model): + return 0 + + async def token_count_async(self, *args, tools, data_model): + return 0 + + def translate_model_params(self, params): + return params + + def supported_model_params(self): + return set() + + +def test_app_passes_callable_client_to_mod_server(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize_query")) + app = qc.app() + captured = {} + + def fake_mod_server(*args, **kwargs): + captured.update(kwargs) + vals = SimpleNamespace() + vals.title = lambda: None + vals.sql = lambda: None + vals.df = list + vals.title.set = lambda _value: None + vals.sql.set = lambda _value: None + return vals + + with ( + patch("querychat._shiny.mod_server", fake_mod_server), + patch("querychat._shiny.render.text", _identity), + patch("querychat._shiny.render.ui", _identity), + patch("querychat._shiny.render.data_frame", _identity), + patch("querychat._shiny.reactive.effect", _identity), + patch("querychat._shiny.reactive.event", _event), + patch("querychat._shiny.req", lambda value: value), + patch("querychat._shiny.output_markdown_stream", lambda *a, **k: None), + ): + app.server( + SimpleNamespace(reset_query=lambda: None), + SimpleNamespace(), + SimpleNamespace(), + ) + + assert callable(captured["client"]) + assert not isinstance(captured["client"], chatlas.Chat) + + +def test_express_passes_callable_client_to_mod_server(sample_df, monkeypatch): + captured = {} + + class CurrentSession: + pass + + monkeypatch.setattr("querychat._shiny.get_current_session", lambda: CurrentSession()) + monkeypatch.setattr( + "querychat._shiny.mod_server", + lambda *args, **kwargs: captured.update(kwargs) or SimpleNamespace(), + ) + + QueryChatExpress( + sample_df, + "tips", + tools=("query", "visualize_query"), + enable_bookmarking=False, + ) + + assert callable(captured["client"]) + assert not isinstance(captured["client"], chatlas.Chat) + + +def test_server_passes_callable_client_to_mod_server(sample_df, monkeypatch): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize_query")) + captured = {} + + class CurrentSession: + pass + + monkeypatch.setattr("querychat._shiny.get_current_session", lambda: CurrentSession()) + monkeypatch.setattr( + "querychat._shiny.mod_server", + lambda *args, **kwargs: captured.update(kwargs) or SimpleNamespace(), + ) + + qc.server(enable_bookmarking=False) + + assert callable(captured["client"]) + assert not isinstance(captured["client"], chatlas.Chat) + + +def test_mod_server_rejects_raw_chat_instance(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize_query")) + raw_chat = chatlas.Chat(provider=DummyProvider(name="dummy", model="dummy")) + + with ( + patch("querychat._shiny_module.preload_viz_deps_server", lambda: None), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + pytest.raises(TypeError, match="callable"), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + DummySession(), + data_source=qc.data_source, + greeting=qc.greeting, + client=raw_chat, + enable_bookmarking=False, + tools=qc.tools, + ) + + +def test_mod_server_stub_session_deferred_client_factory_does_not_raise(): + qc = QueryChat(None, "users") + + vals = _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + DummyStubSession(), + data_source=None, + greeting=qc.greeting, + client=qc.client, + enable_bookmarking=False, + tools=qc.tools, + ) + + with pytest.raises(RuntimeError, match="unavailable during stub session"): + _ = vals.client.stream_async + + +def test_callable_mod_server_passes_visualize_callback_and_tools(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize_query")) + captured = {} + + def client_factory(**kwargs): + captured.update(kwargs) + return qc.client(**kwargs) + + with ( + patch("querychat._shiny_module.preload_viz_deps_server", lambda: None), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + DummySession(), + data_source=qc.data_source, + greeting=qc.greeting, + client=client_factory, + enable_bookmarking=False, + tools=qc.tools, + ) + + assert captured["tools"] == ("query", "visualize_query") + assert callable(captured["visualize_query"]) + assert callable(captured["update_dashboard"]) + assert callable(captured["reset_dashboard"]) + + +def test_mod_server_preloads_viz_for_each_real_session_instance(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize_query")) + session = DummySession() + preload_calls = [] + + with ( + patch( + "querychat._shiny_module.preload_viz_deps_server", + lambda: preload_calls.append("called"), + ), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + session, + data_source=qc.data_source, + greeting=qc.greeting, + client=qc.client, + enable_bookmarking=False, + tools=qc.tools, + ) + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + session, + data_source=qc.data_source, + greeting=qc.greeting, + client=qc.client, + enable_bookmarking=False, + tools=qc.tools, + ) + + assert preload_calls == ["called", "called"] + + +def test_mod_server_stub_session_does_not_preload_viz(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize_query")) + preload_calls = [] + + with ( + patch( + "querychat._shiny_module.preload_viz_deps_server", + lambda: preload_calls.append("called"), + ), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + DummyStubSession(), + data_source=qc.data_source, + greeting=qc.greeting, + client=qc.client, + enable_bookmarking=False, + tools=qc.tools, + ) + + assert preload_calls == [] + + +def test_restored_viz_widgets_survive_second_bookmark_cycle(sample_df): + qc = QueryChat(sample_df, "tips", tools=("query", "visualize_query")) + callbacks = {} + session = DummySession() + + def client_factory(**kwargs): + callbacks.update(kwargs) + return qc.client(**kwargs) + + with ( + patch("querychat._shiny_module.preload_viz_deps_server", lambda: None), + patch("querychat._shiny_module.shinychat.Chat", DummyChatUi), + patch( + "querychat._shiny_module.restore_viz_widgets", + lambda _data_source, saved_widgets: list(saved_widgets), + ), + ): + _raw_mod_server()( + SimpleNamespace(chat_update=lambda: None), + SimpleNamespace(), + session, + data_source=qc.data_source, + greeting=qc.greeting, + client=client_factory, + enable_bookmarking=True, + tools=qc.tools, + ) + saved = [ + { + "widget_id": "querychat_viz_1", + "ggsql": "SELECT 1 VISUALISE 1 AS x DRAW point", + } + ] + callbacks["visualize_query"](saved[0]) + + first_bookmark = SimpleNamespace(values={}) + with reactive.isolate(): + session.bookmark.bookmark_fn(first_bookmark) + assert first_bookmark.values["querychat_viz_widgets"] == saved + + with reactive.isolate(): + session.bookmark.restore_fn(SimpleNamespace(values=first_bookmark.values)) + + second_bookmark = SimpleNamespace(values={}) + with reactive.isolate(): + session.bookmark.bookmark_fn(second_bookmark) + assert second_bookmark.values["querychat_viz_widgets"] == saved diff --git a/pkg-py/tests/test_system_prompt.py b/pkg-py/tests/test_system_prompt.py index 64b64c9b7..68bf6c1c6 100644 --- a/pkg-py/tests/test_system_prompt.py +++ b/pkg-py/tests/test_system_prompt.py @@ -298,3 +298,109 @@ def test_schema_computed_for_conditional_section(self, sample_data_source): ) assert prompt.schema != "" + + +class TestVizPromptConditionals: + """Tests for visualization-related conditional rendering in the real prompt.""" + + def test_graceful_recovery_fallback_excluded_without_query_tool( + self, sample_data_source + ): + """ + When only visualize_query is enabled (no query tool), the fallback + to querychat_query should not appear in the rendered prompt. + """ + from pathlib import Path + + template_path = ( + Path(__file__).parent.parent + / "src" + / "querychat" + / "prompts" + / "prompt.md" + ) + prompt = QueryChatSystemPrompt( + prompt_template=template_path, + data_source=sample_data_source, + ) + + rendered = prompt.render(tools=("update", "visualize_query")) + + assert "fall back to" not in rendered + + def test_graceful_recovery_fallback_included_with_query_tool( + self, sample_data_source + ): + """ + When both query and visualize_query are enabled, the fallback + to querychat_query should appear. + """ + from pathlib import Path + + template_path = ( + Path(__file__).parent.parent + / "src" + / "querychat" + / "prompts" + / "prompt.md" + ) + prompt = QueryChatSystemPrompt( + prompt_template=template_path, + data_source=sample_data_source, + ) + + rendered = prompt.render(tools=("update", "query", "visualize_query")) + + assert "fall back to" in rendered + + def test_viz_only_has_no_cannot_query_message(self, sample_data_source): + """ + When only visualize_query is enabled (no query tool), the rendered prompt + should NOT contain "cannot query or analyze" and SHOULD contain + "Visualizing Data". + """ + from pathlib import Path + + template_path = ( + Path(__file__).parent.parent + / "src" + / "querychat" + / "prompts" + / "prompt.md" + ) + prompt = QueryChatSystemPrompt( + prompt_template=template_path, + data_source=sample_data_source, + ) + + rendered = prompt.render(tools=("visualize_query",)) + + assert "cannot query or analyze" not in rendered + assert "Visualizing Data" in rendered + + def test_choosing_section_only_with_both_tools(self, sample_data_source): + """ + The "Choosing Between Query and Visualization" section should only appear + when both query and visualize_query are enabled. + """ + from pathlib import Path + + template_path = ( + Path(__file__).parent.parent + / "src" + / "querychat" + / "prompts" + / "prompt.md" + ) + prompt = QueryChatSystemPrompt( + prompt_template=template_path, + data_source=sample_data_source, + ) + + rendered_both = prompt.render(tools=("query", "visualize_query")) + rendered_query_only = prompt.render(tools=("query",)) + rendered_viz_only = prompt.render(tools=("visualize_query",)) + + assert "Choosing Between Query and Visualization" in rendered_both + assert "Choosing Between Query and Visualization" not in rendered_query_only + assert "Choosing Between Query and Visualization" not in rendered_viz_only diff --git a/pkg-py/tests/test_tools.py b/pkg-py/tests/test_tools.py index 682f259cf..3267b6570 100644 --- a/pkg-py/tests/test_tools.py +++ b/pkg-py/tests/test_tools.py @@ -2,7 +2,52 @@ import warnings +import narwhals.stable.v1 as nw +import pandas as pd +import pytest +from querychat._datasource import DataFrameSource from querychat._utils import querychat_tool_starts_open +from querychat.tools import _query_impl + + +@pytest.fixture +def data_source(): + df = nw.from_native(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) + return DataFrameSource(df, "test_table") + + +class TestQueryCollapsedParameter: + """Tests for the query tool's collapsed parameter.""" + + def test_collapsed_true_sets_open_false(self, data_source, monkeypatch): + monkeypatch.delenv("QUERYCHAT_TOOL_DETAILS", raising=False) + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table", collapsed=True) + assert result.extra["display"].open is False + + def test_collapsed_false_sets_open_true(self, data_source, monkeypatch): + monkeypatch.delenv("QUERYCHAT_TOOL_DETAILS", raising=False) + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table", collapsed=False) + assert result.extra["display"].open is True + + def test_collapsed_none_falls_back_to_default(self, data_source, monkeypatch): + monkeypatch.delenv("QUERYCHAT_TOOL_DETAILS", raising=False) + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table") + assert result.extra["display"].open is True # default for query + + def test_collapsed_overrides_env_expanded(self, data_source, monkeypatch): + monkeypatch.setenv("QUERYCHAT_TOOL_DETAILS", "expanded") + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table", collapsed=True) + assert result.extra["display"].open is False + + def test_collapsed_overrides_env_collapsed(self, data_source, monkeypatch): + monkeypatch.setenv("QUERYCHAT_TOOL_DETAILS", "collapsed") + query_fn = _query_impl(data_source) + result = query_fn("SELECT * FROM test_table", collapsed=False) + assert result.extra["display"].open is True def test_querychat_tool_starts_open_default_behavior(monkeypatch): @@ -12,6 +57,7 @@ def test_querychat_tool_starts_open_default_behavior(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize_query") is True def test_querychat_tool_starts_open_expanded(monkeypatch): @@ -21,6 +67,7 @@ def test_querychat_tool_starts_open_expanded(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is True + assert querychat_tool_starts_open("visualize_query") is True def test_querychat_tool_starts_open_collapsed(monkeypatch): @@ -30,6 +77,7 @@ def test_querychat_tool_starts_open_collapsed(monkeypatch): assert querychat_tool_starts_open("query") is False assert querychat_tool_starts_open("update") is False assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize_query") is False def test_querychat_tool_starts_open_default_setting(monkeypatch): @@ -39,6 +87,7 @@ def test_querychat_tool_starts_open_default_setting(monkeypatch): assert querychat_tool_starts_open("query") is True assert querychat_tool_starts_open("update") is True assert querychat_tool_starts_open("reset") is False + assert querychat_tool_starts_open("visualize_query") is True def test_querychat_tool_starts_open_case_insensitive(monkeypatch): diff --git a/pkg-py/tests/test_truncate_error.py b/pkg-py/tests/test_truncate_error.py new file mode 100644 index 000000000..57b2db169 --- /dev/null +++ b/pkg-py/tests/test_truncate_error.py @@ -0,0 +1,52 @@ +"""Tests for truncate_error.""" + +from querychat._utils import truncate_error + + +class TestTruncateError: + def test_short_message_unchanged(self): + msg = "Column 'foo' not found" + assert truncate_error(msg) == msg + + def test_empty_string(self): + assert truncate_error("") == "" + + def test_short_message_with_blank_line_unchanged(self): + msg = "line1\n\nline2" + assert truncate_error(msg) == msg + + def test_truncates_at_blank_line(self): + msg = "Something went wrong\n\n" + "x" * 500 + result = truncate_error(msg) + assert result == "Something went wrong\n\n(error truncated)" + + def test_truncates_at_schema_dump_line(self): + msg = "Bad property\nFailed validating 'additionalProperties' in schema[0]:\n" + "x" * 500 + result = truncate_error(msg) + assert "Bad property" in result + assert "(error truncated)" in result + assert "{'additionalProperties'" not in result + + def test_hard_cap_on_long_single_line(self): + msg = "x " * 300 # 600 chars, single line, no schema markers + result = truncate_error(msg, max_chars=500) + assert len(result) <= 500 + len("\n\n(error truncated)") + assert result.endswith("\n\n(error truncated)") + + def test_hard_cap_cuts_on_word_boundary(self): + msg = "word " * 200 + result = truncate_error(msg, max_chars=100) + assert not result.split("\n\n(error truncated)")[0].endswith(" w") + + def test_preserves_first_line_of_altair_error(self): + first_line = "Additional properties are not allowed ('offset' was unexpected)" + schema_dump = "\n\nFailed validating 'additionalProperties' in schema[0]['properties']['encoding']:\n {'additionalProperties': False,\n 'properties': {'angle': " + "x" * 10000 + msg = first_line + schema_dump + result = truncate_error(msg) + assert result.startswith(first_line) + assert len(result) < 600 + + def test_custom_max_chars(self): + msg = "a" * 200 + result = truncate_error(msg, max_chars=100) + assert len(result) <= 100 + len("\n\n(error truncated)") diff --git a/pkg-py/tests/test_viz_footer.py b/pkg-py/tests/test_viz_footer.py new file mode 100644 index 000000000..7051fec43 --- /dev/null +++ b/pkg-py/tests/test_viz_footer.py @@ -0,0 +1,109 @@ +""" +Tests for visualization footer (Save dropdown, Show Query). + +The footer HTML (containing Save dropdown and Show Query toggle) is built by +_build_viz_footer() and passed as the `footer` parameter to ToolResultDisplay. +shinychat renders this in the card footer area. +""" + +from unittest.mock import MagicMock + +import narwhals.stable.v1 as nw +import polars as pl +import pytest +from htmltools import TagList, tags +from querychat._datasource import DataFrameSource + + +@pytest.fixture +def sample_df(): + return pl.DataFrame( + {"x": [1, 2, 3, 4, 5], "y": [10, 20, 15, 25, 30]} + ) + + +@pytest.fixture +def data_source(sample_df): + nw_df = nw.from_native(sample_df) + return DataFrameSource(nw_df, "test_data") + + +def _mock_output_widget(widget_id, **kwargs): + return tags.div(id=widget_id) + + +@pytest.fixture(autouse=True) +def _patch_deps(monkeypatch): + monkeypatch.setattr( + "shinywidgets.register_widget", lambda _widget_id, _chart: None + ) + monkeypatch.setattr("shinywidgets.output_widget", _mock_output_widget) + + mock_spec = MagicMock() + mock_spec.metadata.return_value = {"rows": 5, "columns": ["x", "y"]} + mock_chart = MagicMock() + mock_chart.properties.return_value = mock_chart + + mock_altair_widget = MagicMock() + mock_altair_widget.widget = mock_chart + mock_altair_widget.widget_id = "querychat_viz_test1234" + mock_altair_widget.is_compound = False + + monkeypatch.setattr( + "querychat._viz_ggsql.execute_ggsql", lambda _ds, _q: mock_spec + ) + monkeypatch.setattr( + "querychat._viz_altair_widget.AltairWidget.from_ggsql", + staticmethod(lambda _spec: mock_altair_widget), + ) + + import ggsql + from querychat import _viz_tools + + mock_raw_chart = MagicMock() + mock_vl_writer = MagicMock() + mock_vl_writer.render_chart.return_value = mock_raw_chart + monkeypatch.setattr(ggsql, "VegaLiteWriter", lambda: mock_vl_writer) + monkeypatch.setattr( + _viz_tools, "render_chart_to_png", lambda _chart: b"\x89PNG\r\n\x1a\n" + ) + + +class TestVizFooterIcons: + """Verify Bootstrap icons used in viz footer are defined in _icons.py.""" + + def test_download_icon_exists(self): + from querychat._icons import bs_icon + + html = str(bs_icon("download")) + assert "svg" in html + assert "bi-download" in html + + def test_chevron_down_icon_exists(self): + from querychat._icons import bs_icon + + html = str(bs_icon("chevron-down")) + assert "svg" in html + assert "bi-chevron-down" in html + + def test_cls_parameter_injects_class(self): + from querychat._icons import bs_icon + + html = str(bs_icon("download", cls="querychat-icon")) + assert "querychat-icon" in html + + +class TestVizPreloadMarkup: + def test_preload_markup_has_no_inline_script(self): + from querychat._viz_utils import PRELOAD_WIDGET_ID, preload_viz_deps_ui + + rendered = TagList(preload_viz_deps_ui()).render() + preload_dep = next( + dep for dep in rendered["dependencies"] if dep.name == "querychat-viz-preload" + ) + + assert PRELOAD_WIDGET_ID in rendered["html"] + assert "querychat-viz-preload" in rendered["html"] + assert "hidden" in rendered["html"] + assert "=1.5.1", - "shinychat>=0.2.8", + "shiny @ git+https://github.com/posit-dev/py-shiny.git", + "shinychat @ git+https://github.com/posit-dev/shinychat.git", "htmltools", "chatlas>=0.13.2", "narwhals>=2.2.0", @@ -48,6 +48,8 @@ ibis = ["ibis-framework>=9.0.0", "pandas"] # pandas required for ibis .execute( streamlit = ["streamlit>=1.30"] gradio = ["gradio>=6.0"] dash = ["dash-ag-grid>=31.0", "dash[async]>=3.1", "dash-bootstrap-components>=2.0", "pandas"] +# Visualization with ggsql +viz = ["ggsql>=0.2.4", "altair>=6.0", "shinywidgets>=0.8.0", "vl-convert-python>=1.9.0", "sqlglot>=26.0"] [project.urls] Homepage = "https://github.com/posit-dev/querychat" # TODO update when we have docs @@ -55,6 +57,15 @@ Repository = "https://github.com/posit-dev/querychat" Issues = "https://github.com/posit-dev/querychat/issues" Source = "https://github.com/posit-dev/querychat/tree/main/pkg-py" +[tool.uv] +# Restrict lock-file resolution to platforms we actually target in CI. +# Without this, uv may resolve dependency versions whose wheels aren't +# available on all platforms (e.g. non-x86_64 Linux), causing CI failures. +required-environments = [ + "sys_platform == 'linux' and platform_machine == 'x86_64'", + "sys_platform == 'darwin'", +] + [tool.hatch.metadata] allow-direct-references = true @@ -76,7 +87,7 @@ git_describe_command = "git describe --dirty --tags --long --match 'py/v*'" version-file = "pkg-py/src/querychat/__version.py" [dependency-groups] -dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "polars>=1.0.0", "pyarrow>=14.0.0", "ibis-framework[duckdb]>=9.0.0"] +dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "polars>=1.0.0", "pyarrow>=14.0.0", "ibis-framework[duckdb]>=9.0.0", "ggsql>=0.2.4", "altair>=6.0", "shinywidgets>=0.8.0", "vl-convert-python>=1.9.0"] docs = ["quartodoc>=0.11.1", "griffe<2", "nbformat", "nbclient", "ipykernel"] examples = [ "openai", @@ -214,13 +225,14 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # disable S101 (flagging asserts) for tests [tool.ruff.lint.per-file-ignores] -"pkg-py/tests/*.py" = ["S101", "PLR2004"] # Allow assert and magic numbers in tests +"pkg-py/tests/*.py" = ["S101", "PLR2004", "ARG", "PLW0108"] # Allow assert, magic numbers, unused args, and unnecessary lambdas in tests "pkg-py/tests/playwright/*.py" = ["S101", "PLR2004", "S310", "S603", "S607", "PERF203"] # Test fixtures launch subprocesses "pkg-py/examples/tests/*.py" = ["S101", "PLR2004"] # Allow assert and magic numbers in tests "pkg-py/src/querychat/_dash.py" = ["E402"] # Backwards-compat aliases at end of file "pkg-py/src/querychat/_gradio.py" = ["E402"] # Backwards-compat aliases at end of file "pkg-py/src/querychat/_streamlit.py" = ["E402"] # Backwards-compat aliases at end of file "pkg-py/src/querychat/types/__init__.py" = ["A005"] # Deliberately shadows stdlib types module +"pkg-py/docs/_screenshots/*.py" = ["S310", "PLR2004", "PERF203"] # Dev utility scripts [tool.ruff.format] quote-style = "double" @@ -230,6 +242,9 @@ line-ending = "auto" docstring-code-format = true docstring-code-line-length = "dynamic" +[tool.pytest.ini_options] +markers = ["ggsql: requires working ggsql.render_altair()"] + [tool.pyright] include = ["pkg-py/src/querychat"]