diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 9b2f8e8db..422c207b9 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -129,6 +129,15 @@ def __init__( # Guarded by ``_async_handles_lock`` so concurrent cursors on the # same connection don't race on submit / close / close-session. self._async_handles: Dict[str, Any] = {} + # Parent ``Statement`` objects kept alive alongside async handles. + # On the kernel, ``Statement.close()`` flips the validity flag on + # the produced executed handle (see kernel + # ``statement::mutable::close``), so we cannot close the + # Statement immediately after ``submit()`` as we do for sync + # ``execute()``. Instead retain it here and close it in + # ``close_command`` / ``close_session`` after the async handle + # has finished its work. + self._async_statements: Dict[str, Any] = {} # CommandId.guids of async commands that have already been # closed (via ``close_command`` or ``close_session``). Lets # ``get_query_state`` report ``CLOSED`` for them rather than @@ -167,6 +176,16 @@ def open_session( schema=schema or self._schema, session_conf=session_conf, complex_types_as_json=not self._use_arrow_native_complex_types, + # Pyarrow's Python bindings cannot decode Arrow's + # ``month_interval`` type at all (id 21 — raises + # ``KeyError`` from ``.as_py``, ``to_pylist``, + # ``cast(string)``, and ``to_pandas``). Ask the kernel + # to stringify INTERVAL / DURATION columns server-side + # so result sets containing interval columns are + # decodable on the Python side. Matches the Thrift + # backend's surface (interval columns arrive as + # strings). + intervals_as_string=True, **auth_kwargs, ) except Exception as exc: @@ -197,7 +216,9 @@ def close_session(self, session_id: SessionId) -> None: # server-side CloseStatement before the session goes away. with self._async_handles_lock: tracked = list(self._async_handles.items()) + tracked_stmts = list(self._async_statements.items()) self._async_handles.clear() + self._async_statements.clear() for guid, _ in tracked: self._closed_commands.add(guid) for _, handle in tracked: @@ -211,6 +232,16 @@ def close_session(self, session_id: SessionId) -> None: logger.warning( "Error closing async handle during session close: %s", exc ) + # Now drop the parent Statements that were keeping those handles + # alive. Same non-fatal close semantics — close errors are not + # actionable at session-close time. + for _, stmt in tracked_stmts: + try: + stmt.close() + except Exception as exc: + logger.warning( + "Error closing async statement during session close: %s", exc + ) try: self._kernel_session.close() except Exception as exc: @@ -249,6 +280,11 @@ def execute_command( stmt = self._kernel_session.statement() except Exception as exc: raise _wrap_kernel_exception("execute_command", exc) from exc + # ``async_op`` keeps ``stmt`` alive (tracked in + # ``_async_statements`` and closed by ``close_command``); the sync + # path drops it in finally. ``close_stmt`` is the post-success + # decision flag — it stays True on sync, flips to False on async. + close_stmt = True try: try: stmt.set_sql(operation) @@ -262,21 +298,26 @@ def execute_command( cursor.active_command_id = command_id with self._async_handles_lock: self._async_handles[command_id.guid] = async_exec + # Closing the kernel ``Statement`` invalidates the + # async handle (see kernel validity flag). Retain + # the Statement here and close it on + # ``close_command`` / ``close_session``. + self._async_statements[command_id.guid] = stmt + close_stmt = False return None executed = stmt.execute() except Exception as exc: raise _wrap_kernel_exception("execute_command", exc) from exc finally: - # ``Statement`` is a lifecycle owner separate from the - # executed handle it produces. Drop it here so the - # parent doesn't keep the handle alive longer than the - # caller expects. Swallow all close errors (including - # PyO3 native exceptions) — a failed stmt.close() is - # not actionable for the caller. - try: - stmt.close() - except Exception: - pass + if close_stmt: + # Sync path: ``Statement`` is a lifecycle owner separate + # from the executed handle. Drop it here so the parent + # doesn't outlive its caller. Swallow close errors — + # they're not actionable. + try: + stmt.close() + except Exception: + pass command_id = CommandId.from_sea_statement_id(executed.statement_id) cursor.active_command_id = command_id @@ -307,17 +348,34 @@ def cancel_command(self, command_id: CommandId) -> None: def close_command(self, command_id: CommandId) -> None: with self._async_handles_lock: handle = self._async_handles.pop(command_id.guid, None) + stmt = self._async_statements.pop(command_id.guid, None) if handle is not None: # Record the close so ``get_query_state`` can report # ``CLOSED`` (not ``SUCCEEDED``) for this command. self._closed_commands.add(command_id.guid) if handle is None: logger.debug("close_command: no tracked handle for %s", command_id) + # Still drop the parent Statement if somehow tracked without + # the handle — keeps the invariant clean even on bookkeeping + # races. + if stmt is not None: + try: + stmt.close() + except Exception: + pass return try: handle.close() except Exception as exc: raise _wrap_kernel_exception("close_command", exc) from exc + finally: + # Now safe to close the parent Statement — the executed + # handle has finished its lifecycle. + if stmt is not None: + try: + stmt.close() + except Exception: + pass def get_query_state(self, command_id: CommandId) -> CommandState: with self._async_handles_lock: @@ -378,6 +436,7 @@ def get_execution_result( # it wraps. Drop tracking and fire-and-forget the close. with self._async_handles_lock: self._async_handles.pop(command_id.guid, None) + stmt = self._async_statements.pop(command_id.guid, None) self._closed_commands.add(command_id.guid) try: async_exec.close() @@ -387,6 +446,18 @@ def get_execution_result( command_id, exc, ) + # The parent Statement is no longer needed once the async handle + # has produced its ResultStream. Close to release server-side + # tracking; matches the sync path's eager Statement close. + if stmt is not None: + try: + stmt.close() + except Exception as exc: + logger.warning( + "Error closing async statement after await_result for %s: %s", + command_id, + exc, + ) # ``KernelResultSet.__init__`` calls ``arrow_schema()`` which # can raise — map that to PEP 249 too. try: diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py index fc1a338cd..29dd875ff 100644 --- a/src/databricks/sql/backend/kernel/type_mapping.py +++ b/src/databricks/sql/backend/kernel/type_mapping.py @@ -21,7 +21,7 @@ from __future__ import annotations -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple import pyarrow @@ -102,6 +102,14 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]: backend's behaviour; other precise types (``INTERVAL_*``, ``GEOMETRY``, ``GEOGRAPHY``) collapse to their Arrow shape on both backends and don't need a remap. + + ``precision`` / ``scale`` are extracted from ``Decimal128Type`` / + ``Decimal256Type`` so DECIMAL columns expose the same + ``(precision, scale)`` pair the Thrift backend reports. The Arrow + schema carries these on the type itself; without this extraction + the kernel-backend description would silently drop them, breaking + parity for any consumer (SQLAlchemy, pandas-read-sql, etc.) that + reads slots 4/5 to know how to display or round decimal values. """ return [ ( @@ -109,14 +117,32 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]: _databricks_type_for_field(field), None, None, - None, - None, + *_precision_scale_for_arrow_type(field.type), None, ) for field in schema ] +def _precision_scale_for_arrow_type( + arrow_type: pyarrow.DataType, +) -> Tuple[Optional[int], Optional[int]]: + """Extract PEP 249 ``(precision, scale)`` from an Arrow type. + + Only Arrow's decimal types carry both; every other type collapses + to ``(None, None)`` to match the Thrift backend's behaviour. Future + extensions (e.g. fractional-second precision from + ``Time64Type`` / ``Timestamp``) can land here without touching the + description builder above. + """ + if pyarrow.types.is_decimal(arrow_type): + # Decimal128Type / Decimal256Type both expose `.precision` and + # `.scale`. The cast is for the type checker — pyarrow's + # `DataType` base type doesn't declare them. + return arrow_type.precision, arrow_type.scale # type: ignore[attr-defined] + return None, None + + def _databricks_type_for_field(field: pyarrow.Field) -> str: """Pick the PEP 249 type code for a single field.