Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions pkg-py/src/querychat/_dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
stream_response_async,
)
from ._ui_assets import DASH_CSS, DASH_JS, SUGGESTION_CSS
from ._utils import as_narwhals
from ._utils import as_narwhals, maybe_truncate

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -207,10 +207,17 @@ def store_id(self) -> str:
"""
return self._ids.store

def app(self) -> dash.Dash:
def app(self, *, max_rows: Optional[int] = 1000) -> dash.Dash:
"""
Create a complete Dash app.

Parameters
----------
max_rows
Maximum number of rows to display in the data table. This does not
affect the number of rows that the LLM can query against. Default
is 1000. Set to ``None`` to disable row limit.

Returns
-------
dash.Dash
Expand All @@ -237,6 +244,7 @@ def app(self) -> dash.Dash:
self._ids,
data_source.table_name,
self._deserialize_state,
max_rows=max_rows,
)

return app
Expand Down Expand Up @@ -425,6 +433,8 @@ def register_app_callbacks(
ids: IDs,
table_name: str,
deserialize_state: Callable[[AppStateDict], AppState],
*,
max_rows: int | None = None,
) -> None:
"""Register callbacks for SQL display, data table, and export."""
from dash.dcc.express import send_data_frame
Expand Down Expand Up @@ -459,17 +469,16 @@ def update_display(state_data: AppStateDict, reset_clicks):
sql_title = state.title or "SQL Query"
sql_code = f"```sql\n{state.get_display_sql()}\n```"

nw_df = as_narwhals(state.get_current_data())
nrow, ncol = nw_df.shape
result = maybe_truncate(state.get_current_data(), max_rows)

display_df = nw_df.to_pandas()
display_df = result.df.to_pandas()
table_data = display_df.to_dict("records")
table_columns = [{"field": col} for col in display_df.columns]

data_info_parts = []
if state.error:
data_info_parts.append(f"Warning: {state.error}")
data_info_parts.append(f"Data has {nrow} rows and {ncol} columns.")
data_info_parts.append(result.info_message)
data_info = " ".join(data_info_parts)

return (
Expand Down
18 changes: 12 additions & 6 deletions pkg-py/src/querychat/_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
stream_response,
)
from ._ui_assets import GRADIO_CSS, GRADIO_JS, SUGGESTION_CSS
from ._utils import as_narwhals
from ._utils import maybe_truncate

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -317,10 +317,17 @@ def submit_message(message: str, state_dict: AppStateDict):

return state_holder

def app(self) -> GradioBlocksWrapper:
def app(self, *, max_rows: Optional[int] = 1000) -> GradioBlocksWrapper:
"""
Create a complete Gradio app.

Parameters
----------
max_rows
Maximum number of rows to display in the data table. This does not
affect the number of rows that the LLM can query against. Default
is 1000. Set to ``None`` to disable row limit.

Returns
-------
GradioBlocksWrapper
Expand Down Expand Up @@ -379,14 +386,13 @@ def update_displays(state_dict: AppStateDict):
)

df = self.df(state_dict)
nw_df = as_narwhals(df)
nrow, ncol = nw_df.shape
native_df = nw_df.to_native()
result = maybe_truncate(df, max_rows)
native_df = result.df.to_native()

data_info_parts = []
if error:
data_info_parts.append(f"⚠️ {error}")
data_info_parts.append(f"*Data has {nrow} rows and {ncol} columns.*")
data_info_parts.append(f"*{result.info_message}*")
data_info_text = " ".join(data_info_parts)

return sql_title_text, sql_code, native_df, data_info_text
Expand Down
26 changes: 22 additions & 4 deletions pkg-py/src/querychat/_shiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ._icons import bs_icon
from ._querychat_base import DEFAULT_TOOLS, TOOL_GROUPS, QueryChatBase
from ._shiny_module import ServerValues, mod_server, mod_ui
from ._utils import MISSING, MISSING_TYPE, as_narwhals
from ._utils import MISSING, MISSING_TYPE, maybe_truncate
from ._viz_utils import has_viz_tool

if TYPE_CHECKING:
Expand Down Expand Up @@ -242,7 +242,10 @@ def __init__(
self.id = id or f"querychat_{table_name}"

def app(
self, *, bookmark_store: Literal["url", "server", "disable"] = "url"
self,
*,
max_rows: Optional[int] = 1000,
bookmark_store: Literal["url", "server", "disable"] = "url",
) -> App:
"""
Quickly chat with a dataset.
Expand All @@ -252,6 +255,10 @@ def app(

Parameters
----------
max_rows
Maximum number of rows to display in the data table. This does not
affect the number of rows that the LLM can query against. Default
is 1000. Set to ``None`` to disable row limit.
bookmark_store
The bookmarking store to use for the Shiny app. Options are:
- `"url"`: Store bookmarks in the URL (default).
Expand Down Expand Up @@ -290,6 +297,10 @@ def app_ui(request):
ui.card(
ui.card_header(bs_icon("table"), " Data"),
ui.output_data_frame("dt"),
ui.card_footer(
ui.output_text("data_info"),
class_="text-muted small",
),
),
title=ui.span("querychat with ", ui.code(table_name)),
class_="bslib-page-dashboard",
Expand Down Expand Up @@ -325,10 +336,17 @@ def _():
vals.sql.set(None)
vals.title.set(None)

@reactive.calc
def truncated():
return maybe_truncate(vals.df(), max_rows)

@render.data_frame
def dt():
# Collect lazy sources (LazyFrame, Ibis Table) to eager DataFrame
return as_narwhals(vals.df())
return truncated().df

@render.text
def data_info():
return truncated().info_message

@render.ui
def sql_output():
Expand Down
25 changes: 18 additions & 7 deletions pkg-py/src/querychat/_streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
stream_response,
)
from ._ui_assets import STREAMLIT_JS, SUGGESTION_CSS
from ._utils import as_narwhals
from ._utils import maybe_truncate

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -175,12 +175,20 @@ def _get_state(self) -> AppState:
)
return st.session_state[self._state_key]

def app(self) -> None:
def app(self, *, max_rows: Optional[int] = 1000) -> None:
"""
Render a complete Streamlit app.

Configures the page, renders chat in sidebar, and displays
SQL query and data table in the main area.

Parameters
----------
max_rows
Maximum number of rows to display in the data table. This does not
affect the number of rows that the LLM can query against. Default
is 1000. Set to ``None`` to disable row limit.

"""
data_source = self._require_data_source("app")
import streamlit as st
Expand All @@ -192,7 +200,7 @@ def app(self) -> None:
)

self.sidebar()
self._render_main_content()
self._render_main_content(max_rows=max_rows)

def sidebar(self) -> None:
"""Render the chat interface in the Streamlit sidebar."""
Expand Down Expand Up @@ -303,7 +311,7 @@ def reset(self) -> None:
state.reset_dashboard()
st.rerun()

def _render_main_content(self) -> None:
def _render_main_content(self, *, max_rows: Optional[int] = None) -> None:
"""Render the main content area (SQL + data table)."""
data_source = self._require_data_source("_render_main_content")
import streamlit as st
Expand All @@ -324,10 +332,13 @@ def _render_main_content(self) -> None:
st.rerun()

st.subheader("Data view")
df = as_narwhals(state.get_current_data())
if state.error:
st.error(state.error)
result = maybe_truncate(state.get_current_data(), max_rows)
st.dataframe(
df.to_native(), use_container_width=True, height=400, hide_index=True
result.df.to_native(),
use_container_width=True,
height=400,
hide_index=True,
)
st.caption(f"Data has {df.shape[0]} rows and {df.shape[1]} columns.")
st.caption(result.info_message)
66 changes: 66 additions & 0 deletions pkg-py/src/querychat/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional, overload

Expand Down Expand Up @@ -354,3 +355,68 @@ def read_prompt_template(filename: str, **kwargs: object) -> str:
template_path = Path(__file__).parent / "prompts" / filename
template = template_path.read_text()
return chevron.render(template, kwargs)


@dataclass
class TruncationResult:
"""Result of maybe_truncate(), holding the (possibly truncated) DataFrame and metadata."""

df: nw.DataFrame
total_rows: int
total_cols: int
truncated: bool

@property
def info_message(self) -> str:
"""User-facing message describing the data dimensions and any truncation."""
if self.truncated:
return f"Showing first {len(self.df)} of {self.total_rows} rows ({self.total_cols} columns)."
return f"Data has {self.total_rows} rows and {self.total_cols} columns."


def maybe_truncate(
df: Any,
max_rows: int | None,
*,
warn: bool = True,
) -> TruncationResult:
"""
Collect and optionally truncate data for display.

Accepts any type that :func:`as_narwhals` understands (native DataFrame,
narwhals frame, Polars LazyFrame, Ibis Table). For lazy sources, truncation
is applied before collection so the backend only transfers *max_rows* rows.

Parameters
----------
df
Raw data from a data source.
max_rows
Maximum rows to display. ``None`` disables truncation.
warn
If True and truncation occurs, emit a warning for the developer.

"""
if max_rows is None:
nw_df = as_narwhals(df)
total_rows, total_cols = nw_df.shape
else:
nw_lazy = as_narwhals(df, lazy=True)
total_rows = int(nw_lazy.select(nw.len()).collect().item())
total_cols = len(nw_lazy.collect_schema())
nw_df = nw_lazy.head(max_rows).collect() if total_rows > max_rows else nw_lazy.collect()

truncated = max_rows is not None and total_rows > max_rows
if truncated and warn:
warnings.warn(
f"querychat: Displaying {max_rows} of {total_rows} rows. "
"Set `max_rows` to increase or `None` to disable.",
stacklevel=2,
)

return TruncationResult(
df=nw_df,
total_rows=int(total_rows),
total_cols=int(total_cols),
truncated=truncated,
)
Loading