diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py index 4fb7918b9..e91afc0dd 100644 --- a/tests/e2e/test_transactions.py +++ b/tests/e2e/test_transactions.py @@ -29,6 +29,7 @@ import pytest import databricks.sql as sql +from databricks.sql.exc import DatabaseError logger = logging.getLogger(__name__) @@ -472,150 +473,78 @@ def test_unsupported_isolation_level_rejected(self, mst_conn_params): class TestMstMetadata: - """Metadata RPCs inside active transactions. - - Python uses Thrift RPCs for cursor.columns, cursor.tables, etc. These - RPCs bypass MST context and return non-transactional data — they see - concurrent DDL changes that the transaction shouldn't see. + """Thrift metadata RPCs inside active transactions. + + Python's cursor.columns/tables/schemas/catalogs map to Thrift + Get{Columns,Tables,Schemas,Catalogs} RPCs. The server's MST guard + rejects these RPCs with a "not supported within a multi-statement + transaction" error. The rejection happens before reaching the txn, + so the active transaction itself remains usable (unlike the SQL + forms in TestMstBlockedSql, which abort the txn). """ - def test_cursor_columns_in_mst( - self, mst_conn_params, mst_table, mst_catalog, mst_schema - ): - fq_table, table_name = mst_table - with sql.connect(**mst_conn_params) as conn: - conn.autocommit = False - with conn.cursor() as cursor: - cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") - cursor.columns( - catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name - ) - columns = cursor.fetchall() - assert len(columns) > 0 - conn.rollback() + def _assert_metadata_rpc_blocked(self, mst_conn_params, fq_table, rpc): + """Assert the metadata RPC raises inside an active MST. - def test_cursor_tables_in_mst( - self, mst_conn_params, mst_table, mst_catalog, mst_schema - ): - fq_table, table_name = mst_table - with sql.connect(**mst_conn_params) as conn: - conn.autocommit = False - with conn.cursor() as cursor: - cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") - cursor.tables( - catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name - ) - tables = cursor.fetchall() - assert len(tables) > 0 - conn.rollback() + The Thrift Get* RPCs are rejected by the MST gateway before reaching + the transaction, so the txn itself remains usable — only the RPC + call fails. - def test_cursor_schemas_in_mst(self, mst_conn_params, mst_table, mst_catalog): - fq_table, _ = mst_table + `rpc` is a callable that takes a cursor and invokes the metadata + RPC under test. + """ with sql.connect(**mst_conn_params) as conn: conn.autocommit = False with conn.cursor() as cursor: - cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") - cursor.schemas(catalog_name=mst_catalog) - schemas = cursor.fetchall() - assert len(schemas) > 0 - conn.rollback() + cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'before_blocked')") - def test_cursor_catalogs_in_mst(self, mst_conn_params, mst_table): - fq_table, _ = mst_table - with sql.connect(**mst_conn_params) as conn: - conn.autocommit = False - with conn.cursor() as cursor: - cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") - cursor.catalogs() - catalogs = cursor.fetchall() - assert len(catalogs) > 0 + with pytest.raises(DatabaseError, match="multi-statement transaction"): + rpc(cursor) conn.rollback() - @pytest.mark.xdist_group(name="mst_freshness_columns") - def test_cursor_columns_non_transactional_after_concurrent_ddl( + def test_cursor_columns_blocked( self, mst_conn_params, mst_table, mst_catalog, mst_schema ): - """Thrift cursor.columns() bypasses MST — sees concurrent ALTER TABLE.""" fq_table, table_name = mst_table - with sql.connect(**mst_conn_params) as conn: - conn.autocommit = False - with conn.cursor() as cursor: - cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") - cursor.columns( - catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name - ) - before_cols = {row[3].lower() for row in cursor.fetchall()} - - # External connection alters schema - with sql.connect(**mst_conn_params) as ext_conn: - with ext_conn.cursor() as ext_cursor: - ext_cursor.execute( - f"ALTER TABLE {fq_table} ADD COLUMN new_col STRING" - ) - - # Re-read columns in same txn — Thrift RPC bypasses txn isolation, - # so new_col IS visible (proves non-transactional behavior) - with conn.cursor() as cursor: - cursor.columns( - catalog_name=mst_catalog, schema_name=mst_schema, table_name=table_name - ) - after_cols = {row[3].lower() for row in cursor.fetchall()} - - assert "new_col" in after_cols, ( - "Thrift cursor.columns() should see concurrent DDL " - "(non-transactional behavior)" - ) - assert before_cols != after_cols - conn.rollback() + self._assert_metadata_rpc_blocked( + mst_conn_params, + fq_table, + lambda cursor: cursor.columns( + catalog_name=mst_catalog, + schema_name=mst_schema, + table_name=table_name, + ), + ) - @pytest.mark.xdist_group(name="mst_freshness_tables") - def test_cursor_tables_non_transactional_after_concurrent_create( + def test_cursor_tables_blocked( self, mst_conn_params, mst_table, mst_catalog, mst_schema ): - """Thrift cursor.tables() bypasses MST — sees concurrent CREATE TABLE.""" - fq_table, _ = mst_table - new_table_name = _unique_table_name_raw("freshness_new_tbl") - fq_new_table = f"{mst_catalog}.{mst_schema}.{new_table_name}" - - try: - with sql.connect(**mst_conn_params) as conn: - conn.autocommit = False - with conn.cursor() as cursor: - cursor.execute(f"INSERT INTO {fq_table} VALUES (1, 'test')") - cursor.tables( - catalog_name=mst_catalog, - schema_name=mst_schema, - table_name=new_table_name, - ) - assert len(cursor.fetchall()) == 0 + fq_table, table_name = mst_table + self._assert_metadata_rpc_blocked( + mst_conn_params, + fq_table, + lambda cursor: cursor.tables( + catalog_name=mst_catalog, + schema_name=mst_schema, + table_name=table_name, + ), + ) - # External connection creates the table - with sql.connect(**mst_conn_params) as ext_conn: - with ext_conn.cursor() as ext_cursor: - ext_cursor.execute( - f"CREATE TABLE {fq_new_table} (id INT) USING DELTA " - f"TBLPROPERTIES ('delta.feature.catalogManaged' = 'supported')" - ) + def test_cursor_schemas_blocked(self, mst_conn_params, mst_table, mst_catalog): + fq_table, _ = mst_table + self._assert_metadata_rpc_blocked( + mst_conn_params, + fq_table, + lambda cursor: cursor.schemas(catalog_name=mst_catalog), + ) - # Re-read in same txn — should see the new table - with conn.cursor() as cursor: - cursor.tables( - catalog_name=mst_catalog, - schema_name=mst_schema, - table_name=new_table_name, - ) - assert len(cursor.fetchall()) > 0, ( - "Thrift cursor.tables() should see concurrent CREATE TABLE " - "(non-transactional behavior)" - ) - conn.rollback() - finally: - try: - with sql.connect(**mst_conn_params) as conn: - with conn.cursor() as cursor: - cursor.execute(f"DROP TABLE IF EXISTS {fq_new_table}") - except Exception as e: - logger.warning(f"Failed to drop {fq_new_table}: {e}") + def test_cursor_catalogs_blocked(self, mst_conn_params, mst_table): + fq_table, _ = mst_table + self._assert_metadata_rpc_blocked( + mst_conn_params, + fq_table, + lambda cursor: cursor.catalogs(), + ) # ==================== D. BLOCKED SQL (MSTCheckRule) ==================== @@ -635,6 +564,7 @@ class TestMstBlockedSql: - SHOW TABLES, SHOW SCHEMAS, SHOW CATALOGS, SHOW FUNCTIONS - DESCRIBE QUERY, DESCRIBE TABLE EXTENDED - SELECT FROM information_schema + - Thrift Get{Catalogs,Schemas,Tables,Columns} RPCs (see TestMstMetadata) Allowed: - DESCRIBE TABLE (basic form) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 136c99e53..5d37cd9a5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -8,6 +8,7 @@ THandleIdentifier, ) from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.common.agent import KNOWN_AGENTS from databricks.sql.session import Session import databricks.sql @@ -97,7 +98,9 @@ def test_tls_arg_passthrough(self, mock_client_class, mock_http_client): assert kwargs["_tls_client_cert_key_password"] == "key password" @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_useragent_header(self, mock_client_class): + def test_useragent_header(self, mock_client_class, monkeypatch): + for env_var, _ in KNOWN_AGENTS: + monkeypatch.delenv(env_var, raising=False) databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) call_kwargs = mock_client_class.call_args[1]