Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 81 additions & 10 deletions src/databricks/sql/backend/kernel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ def __init__(
# Guarded by ``_async_handles_lock`` so concurrent cursors on the
# same connection don't race on submit / close / close-session.
self._async_handles: Dict[str, Any] = {}
# Parent ``Statement`` objects kept alive alongside async handles.
# On the kernel, ``Statement.close()`` flips the validity flag on
# the produced executed handle (see kernel
# ``statement::mutable::close``), so we cannot close the
# Statement immediately after ``submit()`` as we do for sync
# ``execute()``. Instead retain it here and close it in
# ``close_command`` / ``close_session`` after the async handle
# has finished its work.
self._async_statements: Dict[str, Any] = {}
# CommandId.guids of async commands that have already been
# closed (via ``close_command`` or ``close_session``). Lets
# ``get_query_state`` report ``CLOSED`` for them rather than
Expand Down Expand Up @@ -167,6 +176,16 @@ def open_session(
schema=schema or self._schema,
session_conf=session_conf,
complex_types_as_json=not self._use_arrow_native_complex_types,
# Pyarrow's Python bindings cannot decode Arrow's
# ``month_interval`` type at all (id 21 — raises
# ``KeyError`` from ``.as_py``, ``to_pylist``,
# ``cast(string)``, and ``to_pandas``). Ask the kernel
# to stringify INTERVAL / DURATION columns server-side
# so result sets containing interval columns are
# decodable on the Python side. Matches the Thrift
# backend's surface (interval columns arrive as
# strings).
intervals_as_string=True,
**auth_kwargs,
)
except Exception as exc:
Expand Down Expand Up @@ -197,7 +216,9 @@ def close_session(self, session_id: SessionId) -> None:
# server-side CloseStatement before the session goes away.
with self._async_handles_lock:
tracked = list(self._async_handles.items())
tracked_stmts = list(self._async_statements.items())
self._async_handles.clear()
self._async_statements.clear()
for guid, _ in tracked:
self._closed_commands.add(guid)
for _, handle in tracked:
Expand All @@ -211,6 +232,16 @@ def close_session(self, session_id: SessionId) -> None:
logger.warning(
"Error closing async handle during session close: %s", exc
)
# Now drop the parent Statements that were keeping those handles
# alive. Same non-fatal close semantics — close errors are not
# actionable at session-close time.
for _, stmt in tracked_stmts:
try:
stmt.close()
except Exception as exc:
logger.warning(
"Error closing async statement during session close: %s", exc
)
try:
self._kernel_session.close()
except Exception as exc:
Expand Down Expand Up @@ -249,6 +280,11 @@ def execute_command(
stmt = self._kernel_session.statement()
except Exception as exc:
raise _wrap_kernel_exception("execute_command", exc) from exc
# ``async_op`` keeps ``stmt`` alive (tracked in
# ``_async_statements`` and closed by ``close_command``); the sync
# path drops it in finally. ``close_stmt`` is the post-success
# decision flag — it stays True on sync, flips to False on async.
close_stmt = True
try:
try:
stmt.set_sql(operation)
Expand All @@ -262,21 +298,26 @@ def execute_command(
cursor.active_command_id = command_id
with self._async_handles_lock:
self._async_handles[command_id.guid] = async_exec
# Closing the kernel ``Statement`` invalidates the
# async handle (see kernel validity flag). Retain
# the Statement here and close it on
# ``close_command`` / ``close_session``.
self._async_statements[command_id.guid] = stmt
close_stmt = False
return None
executed = stmt.execute()
except Exception as exc:
raise _wrap_kernel_exception("execute_command", exc) from exc
finally:
# ``Statement`` is a lifecycle owner separate from the
# executed handle it produces. Drop it here so the
# parent doesn't keep the handle alive longer than the
# caller expects. Swallow all close errors (including
# PyO3 native exceptions) — a failed stmt.close() is
# not actionable for the caller.
try:
stmt.close()
except Exception:
pass
if close_stmt:
# Sync path: ``Statement`` is a lifecycle owner separate
# from the executed handle. Drop it here so the parent
# doesn't outlive its caller. Swallow close errors —
# they're not actionable.
try:
stmt.close()
except Exception:
pass

command_id = CommandId.from_sea_statement_id(executed.statement_id)
cursor.active_command_id = command_id
Expand Down Expand Up @@ -307,17 +348,34 @@ def cancel_command(self, command_id: CommandId) -> None:
def close_command(self, command_id: CommandId) -> None:
with self._async_handles_lock:
handle = self._async_handles.pop(command_id.guid, None)
stmt = self._async_statements.pop(command_id.guid, None)
if handle is not None:
# Record the close so ``get_query_state`` can report
# ``CLOSED`` (not ``SUCCEEDED``) for this command.
self._closed_commands.add(command_id.guid)
if handle is None:
logger.debug("close_command: no tracked handle for %s", command_id)
# Still drop the parent Statement if somehow tracked without
# the handle — keeps the invariant clean even on bookkeeping
# races.
if stmt is not None:
try:
stmt.close()
except Exception:
pass
return
try:
handle.close()
except Exception as exc:
raise _wrap_kernel_exception("close_command", exc) from exc
finally:
# Now safe to close the parent Statement — the executed
# handle has finished its lifecycle.
if stmt is not None:
try:
stmt.close()
except Exception:
pass

def get_query_state(self, command_id: CommandId) -> CommandState:
with self._async_handles_lock:
Expand Down Expand Up @@ -378,6 +436,7 @@ def get_execution_result(
# it wraps. Drop tracking and fire-and-forget the close.
with self._async_handles_lock:
self._async_handles.pop(command_id.guid, None)
stmt = self._async_statements.pop(command_id.guid, None)
self._closed_commands.add(command_id.guid)
try:
async_exec.close()
Expand All @@ -387,6 +446,18 @@ def get_execution_result(
command_id,
exc,
)
# The parent Statement is no longer needed once the async handle
# has produced its ResultStream. Close to release server-side
# tracking; matches the sync path's eager Statement close.
if stmt is not None:
try:
stmt.close()
except Exception as exc:
logger.warning(
"Error closing async statement after await_result for %s: %s",
command_id,
exc,
)
# ``KernelResultSet.__init__`` calls ``arrow_schema()`` which
# can raise — map that to PEP 249 too.
try:
Expand Down
32 changes: 29 additions & 3 deletions src/databricks/sql/backend/kernel/type_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from __future__ import annotations

from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple

import pyarrow

Expand Down Expand Up @@ -102,21 +102,47 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]:
backend's behaviour; other precise types (``INTERVAL_*``,
``GEOMETRY``, ``GEOGRAPHY``) collapse to their Arrow shape on
both backends and don't need a remap.

``precision`` / ``scale`` are extracted from ``Decimal128Type`` /
``Decimal256Type`` so DECIMAL columns expose the same
``(precision, scale)`` pair the Thrift backend reports. The Arrow
schema carries these on the type itself; without this extraction
the kernel-backend description would silently drop them, breaking
parity for any consumer (SQLAlchemy, pandas-read-sql, etc.) that
reads slots 4/5 to know how to display or round decimal values.
"""
return [
(
field.name,
_databricks_type_for_field(field),
None,
None,
None,
None,
*_precision_scale_for_arrow_type(field.type),
None,
)
for field in schema
]


def _precision_scale_for_arrow_type(
arrow_type: pyarrow.DataType,
) -> Tuple[Optional[int], Optional[int]]:
"""Extract PEP 249 ``(precision, scale)`` from an Arrow type.

Only Arrow's decimal types carry both; every other type collapses
to ``(None, None)`` to match the Thrift backend's behaviour. Future
extensions (e.g. fractional-second precision from
``Time64Type`` / ``Timestamp``) can land here without touching the
description builder above.
"""
if pyarrow.types.is_decimal(arrow_type):
# Decimal128Type / Decimal256Type both expose `.precision` and
# `.scale`. The cast is for the type checker — pyarrow's
# `DataType` base type doesn't declare them.
return arrow_type.precision, arrow_type.scale # type: ignore[attr-defined]
return None, None


def _databricks_type_for_field(field: pyarrow.Field) -> str:
"""Pick the PEP 249 type code for a single field.

Expand Down
Loading