diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 080cff6dfa..6455e5fa90 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -179,6 +179,10 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): """Default precision of the timestamp type""" max_timestamp_precision: int = 6 """Maximum supported timestamp precision""" + supports_timestamp_precision_configuration: bool = True + """Whether destination supports configuring precision of its timestamp type.""" + supports_binary_precision_configuration: bool = True + """Whether destination supports configuring precision of its binary type.""" max_rows_per_insert: Optional[int] = None insert_values_writer_type: str = "default" diff --git a/dlt/common/destination/utils.py b/dlt/common/destination/utils.py index f44a3f66db..d5fab66205 100644 --- a/dlt/common/destination/utils.py +++ b/dlt/common/destination/utils.py @@ -16,6 +16,7 @@ TableNotFound, ) from dlt.common.schema.typing import ( + TColumnSchema, TColumnType, TLoaderMergeStrategy, TLoaderReplaceStrategy, @@ -230,6 +231,8 @@ def prepare_load_table( # remove incomplete columns for column, _ in find_incomplete_columns(table): prep_table["columns"].pop(column["name"]) + for column in prep_table["columns"].values(): + _drop_unsupported_precision_hints(column, destination_capabilities) return prep_table # type: ignore[return-value] except KeyError: raise TableNotFound("<>", table_name) @@ -298,3 +301,21 @@ def resolve_merge_strategy( ) return merge_strategy return None + + +def _drop_unsupported_precision_hints( + column: TColumnSchema, destination_capabilities: DestinationCapabilitiesContext +) -> None: + if "precision" not in column: + return + + if ( + column["data_type"] == "timestamp" + and not destination_capabilities.supports_timestamp_precision_configuration + ): + column.pop("precision", None) + elif ( + column["data_type"] == "binary" + and not destination_capabilities.supports_binary_precision_configuration + ): + column.pop("precision", None) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index adaebbb5e5..dc2467464b 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -207,6 +207,17 @@ def __str__(self) -> str: return self.job_id() +def group_jobs_by_table_name( + jobs: Iterable[ParsedLoadJobFileName], +) -> dict[str, list[ParsedLoadJobFileName]]: + """Returns dictionary with table names as keys and list of jobs for those tables as values.""" + + jobs_by_table_name: dict[str, list[ParsedLoadJobFileName]] = {} + for job in jobs: + jobs_by_table_name.setdefault(job.table_name, []).append(job) + return jobs_by_table_name + + class LoadJobInfo(NamedTuple): state: TPackageJobState file_path: str diff --git a/dlt/common/time.py b/dlt/common/time.py index 63b18b598b..7a6a8d4bad 100644 --- a/dlt/common/time.py +++ b/dlt/common/time.py @@ -19,6 +19,7 @@ PAST_TIMESTAMP: float = 0.0 FUTURE_TIMESTAMP: float = 9999999999.0 DAY_DURATION_SEC: float = 24 * 60 * 60.0 +UNIX_EPOCH_DATE = datetime.date(1970, 1, 1) precise_time: Callable[[], float] = None """A precise timer using win_precise_time library on windows and time.time on other systems""" @@ -284,6 +285,11 @@ def datetime_obj_to_str( return datatime.strftime(datetime_format) +def date_to_epoch_days(value: datetime.date) -> int: + """Converts date value to number of days since Unix epoch.""" + return value.toordinal() - UNIX_EPOCH_DATE.toordinal() + + def 
ensure_pendulum_time(value: Union[str, int, float, datetime.time, timedelta]) -> pendulum.Time: """Coerce a time-like value to a `pendulum.Time` object using timezone=False semantics. @@ -439,6 +445,10 @@ def datetime_to_timestamp_ms(moment: Union[datetime.datetime, pendulum.DateTime] return int(moment.timestamp() * 1000) +def datetime_to_timestamp_us(moment: Union[datetime.datetime, pendulum.DateTime]) -> int: + return datetime_to_timestamp(moment) * 1_000_000 + moment.microsecond + + def _datetime_from_ts_or_iso( value: Union[int, float, str] ) -> Union[pendulum.DateTime, pendulum.Date, pendulum.Time]: diff --git a/dlt/common/typing.py b/dlt/common/typing.py index d5d6932409..29e5233104 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -125,6 +125,10 @@ class SecretSentinel: """A single data item as extracted from data source""" TDataItems: TypeAlias = Union[TDataItem, List[TDataItem]] "A single data item or a list as extracted from the data source" +TDataRecord = dict[str, Any] +"""Table row dictionary. Not guaranteed to be JSON serializable without custom encoding.""" +TDataRecordBatch = list[TDataRecord] +"""List of table row dictionaries. Not guaranteed to be JSON serializable without custom encoding.""" TAnyDateTime = Union[pendulum.DateTime, pendulum.Date, datetime, date, str, float, int] """DateTime represented as pendulum/python object, ISO string or unix timestamp""" TTimeInterval = Tuple[datetime, datetime] diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 13988f8644..7241aa40b1 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -613,7 +613,9 @@ def with_load_id_col(self) -> dlt.Relation: self._dataset.schema.tables, self._table_name )["name"] if root_table_name == self._table_name: - raise ValueError(f"{root_table_name} is a root table, but load id column is not present.") + raise ValueError( + f"{root_table_name} is a root table, but load id column is not present." 
+ ) join_alias = "_dlt_root" joined = self.join(root_table_name, alias=join_alias) diff --git a/dlt/destinations/exceptions.py b/dlt/destinations/exceptions.py index aae0440996..c08b386ce2 100644 --- a/dlt/destinations/exceptions.py +++ b/dlt/destinations/exceptions.py @@ -74,6 +74,11 @@ def __init__(self, file_path: str, message: str) -> None: super().__init__(f"Job with `{file_path=:}` encountered unrecoverable problem: {message}") +class LoadJobTransientException(DestinationTransientException): + def __init__(self, file_path: str, message: str) -> None: + super().__init__(f"Job with `{file_path=:}` encountered recoverable problem: {message}") + + class LoadJobInvalidStateTransitionException(DestinationTerminalException): def __init__(self, from_state: TLoadJobState, to_state: TLoadJobState) -> None: self.from_state = from_state diff --git a/dlt/destinations/file_batching.py b/dlt/destinations/file_batching.py new file mode 100644 index 0000000000..a2e5493804 --- /dev/null +++ b/dlt/destinations/file_batching.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, Iterator, Sequence, Sized, TypeVar + +if TYPE_CHECKING: + from dlt.common.libs.pyarrow import pyarrow + +from dlt.common.json import json +from dlt.common.storages import FileStorage +from dlt.common.typing import TDataRecordBatch + +TRecordBatch = TypeVar("TRecordBatch", bound=Sized) + + +class FileBatchIterator(ABC, Generic[TRecordBatch]): + def __init__( + self, + file_path: str, + batch_size: int, + record_offset: int, + columns: Sequence[str] = (), + ) -> None: + self._file_path = file_path + self._batch_size = batch_size + self._record_offset = record_offset + self._columns = list(columns) + + @abstractmethod + def __iter__(self) -> Iterator[TRecordBatch]: + pass + + +class ParquetFileBatchIterator(FileBatchIterator["pyarrow.RecordBatch"]): + def __init__( + self, + file_path: str, + batch_size: int, + record_offset: int, + columns: Sequence[str] = (), + ) -> None: + super().__init__(file_path, batch_size, record_offset, columns) + self._batch_offset, remainder = divmod(self._record_offset, self._batch_size) + assert remainder == 0, "`_record_offset` must be a multiple of `_batch_size`" + + def __iter__(self) -> Iterator[pyarrow.RecordBatch]: + from dlt.common.libs.pyarrow import pyarrow + + batches_to_skip = self._batch_offset + with pyarrow.parquet.ParquetFile(self._file_path) as reader: + for record_batch in reader.iter_batches( + batch_size=self._batch_size, + columns=self._columns or None, + ): + if batches_to_skip > 0: + batches_to_skip -= 1 + continue + yield record_batch + + +class JsonlFileBatchIterator(FileBatchIterator[TDataRecordBatch]): + def __iter__(self) -> Iterator[TDataRecordBatch]: + current_batch: TDataRecordBatch = [] + records_to_skip = self._record_offset + projected_columns = set(self._columns) + + with FileStorage.open_zipsafe_ro(self._file_path) as f: + for line in f: + records = json.typed_loads(line) + if isinstance(records, dict): + records = [records] + + for record in records: + if records_to_skip > 0: + records_to_skip -= 1 + continue + if projected_columns: + record = { + key: value for key, value in record.items() if key in projected_columns + } + current_batch.append(record) + if len(current_batch) == self._batch_size: + yield current_batch + current_batch = [] + + # yield any remaining records in last partial batch + if current_batch: + yield current_batch diff --git 
a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 7f644c703e..a0d4110855 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -249,7 +249,6 @@ def create_load_job( self.config, # type: ignore destination_state(), _streaming_load, # type: ignore - [], callable_requires_job_client_args=True, ) else: diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index d87f4efb7e..4e795fac3c 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -1,15 +1,31 @@ +from __future__ import annotations + import dataclasses -from typing import ClassVar, Final, Optional, Any, Dict, List, List, Dict, cast, Callable +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, Optional, Union, cast from urllib.parse import urlparse from dlt.common import logger +from dlt.common.configuration.specs.base_configuration import ( + BaseConfiguration, + CredentialsConfiguration, + configspec, +) from dlt.common.typing import TSecretStrValue -from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.utils import digest128 +from dlt.destinations.impl.databricks.typing import TDatabricksInsertApi + +if TYPE_CHECKING: + from zerobus import ArrowStreamConfigurationOptions, IPCCompression + DATABRICKS_APPLICATION_ID = "dltHub_dlt" +DEFAULT_DATABRICKS_INSERT_API: TDatabricksInsertApi = "copy_into" +# ZSTD was fastest in my benchmarks out of the three `ipc_compression` options +# currently available (NONE, LZ4_FRAME, ZSTD) — NONE (no compression) was slowest +DEFAULT_DATABRICKS_ZEROBUS_IPC_COMPRESSION = "ZSTD" @configspec @@ -163,10 +179,62 @@ def to_connector_params(self) -> Dict[str, Any]: conn_params["access_token"] = self.access_token return conn_params + def to_workspace_url(self) -> str: + if not self.server_hostname: + raise ConfigurationValueError( + "Cannot construct workspace URL: `server_hostname` is not set." 
+ ) + return f"https://{self.server_hostname}" + def __str__(self) -> str: return f"databricks://{self.server_hostname}{self.http_path}/{self.catalog}" +@configspec +class DatabricksZerobusCredentials(CredentialsConfiguration): + client_id: str = None + client_secret: TSecretStrValue = None + + +@configspec +class DatabricksZerobusConfiguration(BaseConfiguration): + endpoint_url: str = None + """URL of the Zerobus server endpoint.""" + credentials: DatabricksZerobusCredentials = None + """Credentials to authenticate to the Zerobus server.""" + batch_size: int = 25_000 + """Number of records per batch to ingest into Zerobus.""" + stream_options: Optional[dict[str, Any]] = None + """Stream configuration options forwarded to `options` argument of `ZerobusSdk.create_arrow_stream()`.""" + + def on_partial(self) -> None: + if not self.endpoint_url: + return + + if self.credentials is None: + # we'll attempt to resolve credentials later in `DatabricksClientConfiguration.on_resolved()` + self.credentials = DatabricksZerobusCredentials() + + self.resolve() + + def to_arrow_stream_configuration_options(self) -> ArrowStreamConfigurationOptions: + from zerobus import ArrowStreamConfigurationOptions + + options = deepcopy(self.stream_options) if self.stream_options else dict() + if "ipc_compression" not in options: + options["ipc_compression"] = DEFAULT_DATABRICKS_ZEROBUS_IPC_COMPRESSION + options["ipc_compression"] = self._coerce_ipc_compression(options["ipc_compression"]) + return ArrowStreamConfigurationOptions(**options) + + @staticmethod + def _coerce_ipc_compression(ipc_compression: Union[str, IPCCompression]) -> IPCCompression: + from zerobus import IPCCompression + + if isinstance(ipc_compression, str): + return getattr(IPCCompression, ipc_compression) + return ipc_compression + + @configspec class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration): destination_type: Final[str] = dataclasses.field(default="databricks", init=False, repr=False, compare=False) # type: ignore[misc] @@ -179,9 +247,12 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration """Name of the Databricks managed volume for temporary storage, e.g., ... Defaults to '_dlt_temp_load_volume' if not set.""" keep_staged_files: Optional[bool] = True """Tells if to keep the files in internal (volume) stage""" - - """Whether PRIMARY KEY or FOREIGN KEY constrains should be created""" create_indexes: bool = False + """Whether PRIMARY KEY or FOREIGN KEY constrains should be created""" + insert_api: TDatabricksInsertApi = DEFAULT_DATABRICKS_INSERT_API + """Ingestion backend for `append` write disposition. Can be overridden per resource via `databricks_adapter`.""" + zerobus: Optional[DatabricksZerobusConfiguration] = None + """Databricks Zerobus Configuration including endpoint and credentials. 
Required when using the `zerobus` insert API.""" def __str__(self) -> str: """Return displayable destination location""" @@ -190,6 +261,31 @@ def __str__(self) -> str: else: return "" + def on_resolved(self) -> None: + if self.zerobus is None: + return + + if self.zerobus.credentials.client_id and self.zerobus.credentials.client_secret: + return + + # fall back to main credentials if Zerobus credentials are not fully provided + if self.credentials.client_id and self.credentials.client_secret: + self.zerobus.credentials = DatabricksZerobusCredentials( + client_id=self.credentials.client_id, + client_secret=self.credentials.client_secret, + ) + self.zerobus.credentials.resolve() + + if not self.zerobus.credentials.is_resolved(): + raise ConfigurationValueError( + "`client_id` and `client_secret` are required when" + " `destination.databricks.zerobus` is configured. Set either" + " `destination.databricks.zerobus.credentials.client_id` and" + " `destination.databricks.zerobus.credentials.client_secret`, or" + " `destination.databricks.credentials.client_id` and" + " `destination.databricks.credentials.client_secret`." + ) + def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" if self.credentials and self.credentials.server_hostname: diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index e5ac49b8d5..8008248549 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -1,56 +1,104 @@ +from __future__ import annotations + import os -from typing import Any, Dict, Optional, Sequence, List, cast, Union from urllib.parse import urlparse from pathlib import Path +import os +from abc import ABC, abstractmethod +from copy import deepcopy +from functools import cached_property +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Generic, + Iterable, + List, + Optional, + Sequence, + Union, + cast, +) +from urllib.parse import urlparse + +from dlt.common import logger +from dlt.common.configuration.exceptions import ConfigurationValueError +from dlt.common.configuration.specs import ( + AwsCredentialsWithoutDefaults, + AzureCredentialsWithoutDefaults, +) from dlt.common.configuration.specs.azure_credentials import ( AzureServicePrincipalCredentialsWithoutDefaults, ) +from dlt.common.data_types import TDataType +from dlt.common.data_writers.escape import escape_databricks_literal from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.client import ( HasFollowupJobs, FollowupJobRequest, + LoadJob, PreparedTableSchema, RunnableLoadJob, SupportsStagingDestination, - LoadJob, ) -from dlt.common.configuration.specs import ( - AwsCredentialsWithoutDefaults, - AzureCredentialsWithoutDefaults, +from dlt.common.destination.exceptions import ( + DestinationException, + DestinationInvalidFileFormat, + WriteDispositionNotSupported, ) +from dlt.common.exceptions import TerminalValueError +from dlt.common.typing import TDataRecordBatch +from dlt.common.schema import Schema, TColumnSchema, TTableSchema +from dlt.common.schema.typing import TColumnType from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.storages import FilesystemConfiguration, fsspec_from_config from dlt.common.storages.configuration import ensure_canonical_az_url from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.fsspec_filesystem import ( 
AZURE_BLOB_STORAGE_PROTOCOLS, - S3_PROTOCOLS, GCS_PROTOCOLS, + S3_PROTOCOLS, +) +from dlt.common.storages.load_package import ( + ParsedLoadJobFileName, + destination_state, + group_jobs_by_table_name, +) +from dlt.common.utils import uniq_id +from dlt.destinations.exceptions import LoadJobTerminalException, LoadJobTransientException +from dlt.destinations.file_batching import ( + JsonlFileBatchIterator, + ParquetFileBatchIterator, + TRecordBatch, ) +from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.databricks_adapter import ( CLUSTER_HINT, - TABLE_PROPERTIES_HINT, - TABLE_COMMENT_HINT, - TABLE_TAGS_HINT, COLUMN_COMMENT_HINT, COLUMN_TAGS_HINT, + INSERT_API_HINT, + TABLE_COMMENT_HINT, + TABLE_PROPERTIES_HINT, + TABLE_TAGS_HINT, ) -from dlt.common.schema import TColumnSchema, Schema, TTableSchema, TColumnHint -from dlt.common.schema.typing import TColumnType -from dlt.common.storages import FilesystemConfiguration, fsspec_from_config -from dlt.common.utils import uniq_id -from dlt.common import logger -from dlt.common.data_writers.escape import escape_databricks_literal -from dlt.common.exceptions import TerminalValueError -from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset -from dlt.destinations.exceptions import LoadJobTerminalException -from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient +from dlt.destinations.impl.databricks.typing import TDatabricksInsertApi +from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset +from dlt.destinations.job_impl import BatchedFileLoadJob, ReferenceFollowupJobRequest from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.job_impl import ReferenceFollowupJobRequest -from dlt.destinations.impl.databricks.typing import TDatabricksColumnHint from dlt.destinations.path_utils import get_file_format_and_compression +from dlt.destinations.sql_jobs import SqlMergeFollowupJob + +if TYPE_CHECKING: + from dlt.common.libs.pyarrow import pyarrow + from zerobus.sdk.shared import ZerobusException + from zerobus.sdk.sync import ZerobusArrowStream, ZerobusSdk + SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS + GCS_PROTOCOLS @@ -309,6 +357,134 @@ def gen_delete_from_sql( """ +class DatabricksZerobusLoadJob(BatchedFileLoadJob[TRecordBatch], ABC, Generic[TRecordBatch]): + def __init__( + self, + file_path: str, + config: DatabricksClientConfiguration, + destination_state: Dict[str, int], + ) -> None: + assert config.zerobus is not None + super().__init__(file_path, config.zerobus.batch_size, destination_state) + self._config = config + self._job_client: DatabricksClient = None + self.zerobus_config = config.zerobus + + def run(self) -> None: + from zerobus.sdk.shared import ZerobusException + + stream = None + try: + self._job_client.grant_zerobus_permissions(self.load_table_name) + stream = self._create_stream() + self._ingest_batches(stream) + except ZerobusException as exc: + raise self._wrap_zerobus_exception(exc) from exc + finally: + if stream is not None: + stream.close() + + @cached_property + def zerobus_sdk(self) -> ZerobusSdk: + from zerobus.sdk.sync import ZerobusSdk + + return ZerobusSdk( + host=self.zerobus_config.endpoint_url, + unity_catalog_url=self._config.credentials.to_workspace_url(), + ) + + @cached_property + def 
_arrow_schema(self) -> pyarrow.Schema: + """Returns Arrow schema for the table we're streaming into.""" + + from dlt.common.libs.pyarrow import columns_to_arrow + + columns = deepcopy(self._load_table["columns"]) + + for column_name, column in columns.items(): + # DatabricksTypeMapper maps `time` to `STRING` + if column.get("data_type") == "time": + columns[column_name]["data_type"] = "text" + + return columns_to_arrow(columns, self._job_client.capabilities) + + def _create_stream(self) -> ZerobusArrowStream: + table_name = self._job_client.sql_client.make_qualified_table_name( + self.load_table_name, quote=False + ) + client_id = self.zerobus_config.credentials.client_id + client_secret = self.zerobus_config.credentials.client_secret + return self.zerobus_sdk.create_arrow_stream( + table_name, + self._arrow_schema, + client_id, + client_secret, + options=self.zerobus_config.to_arrow_stream_configuration_options(), + ) + + def _ingest_batch(self, stream: ZerobusArrowStream, batch: pyarrow.RecordBatch) -> None: + offset = stream.ingest_batch(batch) + stream.wait_for_offset(offset) + self._advance_record_offset(int(batch.num_rows)) + + def _ingest_batches(self, stream: ZerobusArrowStream) -> None: + for batch in self.iter_batches(): + self._ingest_batch(stream, self._ensure_arrow_record_batch(batch)) + + def _wrap_zerobus_exception(self, exc: ZerobusException) -> DestinationException: + from zerobus.sdk.shared import NonRetriableException + + if isinstance(exc, NonRetriableException): + return LoadJobTerminalException( + self._file_path, + f"Databricks Zerobus load failed with non-retriable error: {exc}", + ) + return LoadJobTransientException( + self._file_path, f"Databricks Zerobus load failed with retriable error: {exc}" + ) + + @abstractmethod + def _ensure_arrow_record_batch(self, batch: TRecordBatch) -> pyarrow.RecordBatch: + """Returns a `pyarrow.RecordBatch` with a schema compliant for ingestion.""" + pass + + +class DatabricksZerobusJsonlLoadJob(DatabricksZerobusLoadJob[TDataRecordBatch]): + file_batch_iterator_class = JsonlFileBatchIterator + _ARRAY_CAST_TYPES: ClassVar[frozenset[TDataType]] = frozenset({"date", "timestamp"}) + """Data types that require array-level casting because record batch-level casting fails.""" + + def _ensure_arrow_record_batch(self, batch: TDataRecordBatch) -> pyarrow.RecordBatch: + from dlt.common.libs.pyarrow import pyarrow + + if any( + column.get("data_type") in self._ARRAY_CAST_TYPES + for column in self._load_table["columns"].values() + ): + inferred_batch = pyarrow.RecordBatch.from_pylist(batch) + + return pyarrow.RecordBatch.from_arrays( + [ + ( + inferred_batch.column(field.name).cast(field.type) + if field.name in inferred_batch.column_names + else pyarrow.nulls(inferred_batch.num_rows, type=field.type) + ) + for field in self._arrow_schema + ], + schema=self._arrow_schema, + ) + + return pyarrow.RecordBatch.from_pylist(batch, schema=self._arrow_schema) + + +class DatabricksZerobusParquetLoadJob(DatabricksZerobusLoadJob["pyarrow.RecordBatch"]): + file_batch_iterator_class = ParquetFileBatchIterator + + def _ensure_arrow_record_batch(self, batch: pyarrow.RecordBatch) -> pyarrow.RecordBatch: + return batch.cast(self._arrow_schema) + + class DatabricksClient(SqlJobClientWithStagingDataset, SupportsStagingDestination): def __init__( self, @@ -340,16 +516,113 @@ def _get_column_def_sql(self, column: TColumnSchema, table: PreparedTableSchema column_def_sql = f"{column_def_sql} COMMENT {escaped_comment}" return column_def_sql + def 
_verify_zerobus_write_disposition(self, table: PreparedTableSchema) -> None: + if table.get(INSERT_API_HINT) != "zerobus": + return + if table["write_disposition"] != "append": + raise WriteDispositionNotSupported( + table["write_disposition"], + f"The `zerobus` insert API does not support the `{table['write_disposition']}`" + " write disposition.", + ) + + def _verify_zerobus_file_format( + self, + table: PreparedTableSchema, + table_jobs: list[ParsedLoadJobFileName], + ) -> None: + if table.get(INSERT_API_HINT) != "zerobus": + return + if model_job := next((job for job in table_jobs if job.file_format == "model"), None): + raise DestinationInvalidFileFormat( + self.config.destination_type, + model_job.file_format, + model_job.file_name(), + "The `zerobus` insert API does not support the `model` file format.", + ) + + def _verify_zerobus_tables( + self, + loaded_tables: list[PreparedTableSchema], + new_jobs: Optional[list[ParsedLoadJobFileName]], + ) -> list[PreparedTableSchema]: + zerobus_tables = [ + table for table in loaded_tables if table.get(INSERT_API_HINT) == "zerobus" + ] + jobs_by_table_name = group_jobs_by_table_name(new_jobs or []) + for table in zerobus_tables: + self._verify_zerobus_write_disposition(table) + self._verify_zerobus_file_format( + table, table_jobs=jobs_by_table_name.get(table["name"], []) + ) + return zerobus_tables + + def _verify_zerobus_configuration(self) -> None: + if self.config.zerobus is None: + raise ConfigurationValueError( + "Databricks Zerobus configuration is required when using the `zerobus` insert API." + ) + + def verify_schema( + self, only_tables: Iterable[str] = None, new_jobs: Iterable[ParsedLoadJobFileName] = None + ) -> list[PreparedTableSchema]: + new_jobs = list(new_jobs) if new_jobs is not None else None + loaded_tables = super().verify_schema(only_tables, new_jobs) + zerobus_tables = self._verify_zerobus_tables(loaded_tables, new_jobs) + if zerobus_tables: + self._verify_zerobus_configuration() + return loaded_tables + + def prepare_load_table(self, table_name: str) -> PreparedTableSchema: + table = super().prepare_load_table(table_name) + if table_name in self.schema.dlt_table_names(): + table[INSERT_API_HINT] = "copy_into" # type: ignore[typeddict-unknown-key] + elif INSERT_API_HINT not in table: + table[INSERT_API_HINT] = self.config.insert_api # type: ignore[typeddict-unknown-key] + return table + + def get_load_job_class( + self, table: PreparedTableSchema, file_path: str + ) -> Union[type[DatabricksLoadJob], type[DatabricksZerobusLoadJob[Any]]]: + insert_api: TDatabricksInsertApi = table[INSERT_API_HINT] # type: ignore[typeddict-item] + if insert_api == "copy_into": + return DatabricksLoadJob + elif insert_api == "zerobus": + if ReferenceFollowupJobRequest.is_reference_job(file_path): + raise LoadJobTerminalException( + file_path, + "The `zerobus` insert API does not support using a staging destination.", + ) + file_format = ParsedLoadJobFileName.parse(file_path).file_format + if file_format == "jsonl": + return DatabricksZerobusJsonlLoadJob + if file_format == "parquet": + return DatabricksZerobusParquetLoadJob + raise ValueError( + f"The `zerobus` insert API does not support the `{file_path}` file format." 
+ ) + def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job = super().create_load_job(table, file_path, load_id, restore) if not job: - job = DatabricksLoadJob( - file_path, - staging_config=cast(FilesystemConfiguration, self.config.staging_config), - ) + job_cls = self.get_load_job_class(table, file_path) + if job_cls is DatabricksLoadJob: + job_cls = cast(type[DatabricksLoadJob], job_cls) + job = job_cls( + file_path, + staging_config=cast(FilesystemConfiguration, self.config.staging_config), + ) + else: + job_cls = cast(type[DatabricksZerobusLoadJob[Any]], job_cls) + job = job_cls( + file_path, + self.config, + destination_state(), + ) + return job def _create_merge_followup_jobs( @@ -608,3 +881,17 @@ def _get_storage_table_query_columns(self) -> List[str]: def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load + + def _make_zerobus_grant_sql(self, table_name: str) -> list[str]: + zerobus = self.config.zerobus + assert zerobus is not None + principal = self.sql_client.capabilities.escape_identifier(zerobus.credentials.client_id) + schema = self.sql_client.fully_qualified_dataset_name() + table = self.sql_client.make_qualified_table_name(table_name) + return [ + f"GRANT USE SCHEMA ON SCHEMA {schema} TO {principal}", + f"GRANT MODIFY, SELECT ON TABLE {table} TO {principal}", + ] + + def grant_zerobus_permissions(self, table_name: str) -> None: + self.sql_client.execute_many(self._make_zerobus_grant_sql(table_name)) diff --git a/dlt/destinations/impl/databricks/databricks_adapter.py b/dlt/destinations/impl/databricks/databricks_adapter.py index c44aad881d..bd76e039ec 100644 --- a/dlt/destinations/impl/databricks/databricks_adapter.py +++ b/dlt/destinations/impl/databricks/databricks_adapter.py @@ -2,10 +2,13 @@ from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.typing import TColumnNames +from dlt.destinations.impl.databricks.typing import ( + TDatabricksInsertApi, + TDatabricksTableSchemaColumns, +) from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource from dlt.extract.items import TTableHintTemplate -from dlt.destinations.impl.databricks.typing import TDatabricksTableSchemaColumns CLUSTER_HINT: Literal["x-databricks-cluster"] = "x-databricks-cluster" @@ -14,6 +17,7 @@ TABLE_PROPERTIES_HINT: Literal["x-databricks-table-properties"] = "x-databricks-table-properties" COLUMN_COMMENT_HINT: Literal["x-databricks-column-comment"] = "x-databricks-column-comment" COLUMN_TAGS_HINT: Literal["x-databricks-column-tags"] = "x-databricks-column-tags" +INSERT_API_HINT: Literal["x-insert-api"] = "x-insert-api" def databricks_adapter( @@ -25,6 +29,7 @@ def databricks_adapter( table_tags: Optional[List[Union[str, Dict[str, str]]]] = None, table_properties: Optional[Dict[str, Union[str, int, bool, float]]] = None, column_hints: Optional[TDatabricksTableSchemaColumns] = None, + insert_api: Optional[TDatabricksInsertApi] = None, ) -> DltResource: """ Prepares data for loading into Databricks. @@ -54,6 +59,14 @@ def databricks_adapter( The supported hints are: - `column_comment` - adds a comment to the column. Supports basic markdown format [basic-syntax](https://www.markdownguide.org/cheat-sheet/#basic-syntax). - `column_tags` - adds tags to the column. Supports a list of strings and/or key-value pairs. 
+ insert_api (Optional[TDatabricksInsertApi], optional): Backend for `append` write disposition. + Supported values are: + - `copy_into`: insert records using Databricks `COPY INTO` command + - `zerobus`: insert records using Databricks Zerobus + + Destination falls back to `DatabricksClientConfiguration.insert_api` if `insert_api` + is not specified in `databricks_adapter`. `dlt` system tables ignore any `insert_api` + configuration, and always use `copy_into`. Returns: A `DltResource` object that is ready to be loaded into Databricks. @@ -197,8 +210,18 @@ def databricks_adapter( additional_table_hints[TABLE_PROPERTIES_HINT] = table_properties + if insert_api: + if insert_api == "zerobus" and resource.write_disposition != "append": + raise ValueError( + f"Cannot use `zerobus` insert API with `{resource.write_disposition}` write" + " disposition. `zerobus` insert API only supports `append` write disposition." + ) + additional_table_hints[INSERT_API_HINT] = insert_api + resource.apply_hints( - columns=cast(TTableSchemaColumns, additional_column_hints), + columns=( + cast(TTableSchemaColumns, additional_column_hints) if additional_column_hints else None + ), additional_table_hints=additional_table_hints, ) diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index 875860ba2d..f91d2ee0ae 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -1,8 +1,8 @@ -from typing import Any, Optional, Type, Union, Dict, TYPE_CHECKING, Sequence, Tuple +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Type, Union -from dlt.common import logger from dlt.common.data_types.typing import TDataType from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination.configuration import ParquetFormatConfiguration from dlt.common.data_writers.escape import escape_databricks_identifier, escape_databricks_literal from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.destination.typing import PreparedTableSchema @@ -12,15 +12,30 @@ from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.databricks.configuration import ( - DatabricksCredentials, + DEFAULT_DATABRICKS_INSERT_API, DatabricksClientConfiguration, + DatabricksCredentials, + DatabricksZerobusConfiguration, ) +from dlt.destinations.impl.databricks.databricks_adapter import INSERT_API_HINT +from dlt.destinations.impl.databricks.typing import TDatabricksInsertApi if TYPE_CHECKING: from dlt.destinations.impl.databricks.databricks import DatabricksClient class DatabricksTypeMapper(TypeMapperImpl): + UNSUPPORTED_TYPES: ClassVar[ + dict[tuple[TDatabricksInsertApi, TLoaderFileFormat], frozenset[TDataType]] + ] = { + (DEFAULT_DATABRICKS_INSERT_API, "parquet"): frozenset({"time"}), + (DEFAULT_DATABRICKS_INSERT_API, "jsonl"): frozenset({"decimal", "wei", "binary", "date"}), + (DEFAULT_DATABRICKS_INSERT_API, "model"): frozenset(), + ("zerobus", "parquet"): frozenset({"decimal", "wei"}), + ("zerobus", "jsonl"): frozenset({"decimal", "wei", "binary", "json"}), + ("zerobus", "model"): frozenset(), + } + sct_to_unbound_dbt = { "json": "STRING", # Json type stored as string "text": "STRING", @@ -60,26 +75,29 @@ def ensure_supported_type( table: PreparedTableSchema, loader_file_format: TLoaderFileFormat, ) -> None: - if loader_file_format == "jsonl": - if column["data_type"] in { - "decimal", - "wei", - "binary", - "json", - "date", - }: - raise 
TerminalValueError("", column["data_type"]) - if column["data_type"] == "timestamp" and column.get("timezone") is False: - raise TerminalValueError( - "Cannot load naive timestamps from json, use parquet", column["data_type"] - ) - if loader_file_format == "parquet": - if column["data_type"] in {"time"}: - raise TerminalValueError( - "Spark can't read Time from parquet. Convert your time column to string or" - " change file format.", - column["data_type"], - ) + insert_api = table[INSERT_API_HINT] # type: ignore[typeddict-item] + unsupported_types = self.UNSUPPORTED_TYPES[(insert_api, loader_file_format)] + + if insert_api == "copy_into": + if loader_file_format == "jsonl": + if column["data_type"] == "timestamp" and column.get("timezone") is False: + raise TerminalValueError( + "Cannot load naive timestamps from json, use parquet", column["data_type"] + ) + elif loader_file_format == "parquet": + if column["data_type"] == "time": + raise TerminalValueError( + "Spark can't read Time from parquet. Convert your time column to string" + " or change file format.", + column["data_type"], + ) + + if column["data_type"] in unsupported_types: + raise TerminalValueError( + f"The `{insert_api}` insert API does not support data type" + f" `{column['data_type']}` with loader file format `{loader_file_format}`.", + column["data_type"], + ) def to_db_integer_type(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: precision = column.get("precision") @@ -102,18 +120,7 @@ def to_db_datetime_type( column: TColumnSchema, table: PreparedTableSchema = None, ) -> str: - column_name = column["name"] - table_name = table["name"] - timezone = column.get("timezone", True) - precision = column.get("precision") - - if precision and precision != self.capabilities.timestamp_precision: - logger.warn( - f"Databricks does not support precision {precision} for column '{column_name}' in" - f" table '{table_name}'. Will default to 6." - ) - - return "TIMESTAMP" if timezone else "TIMESTAMP_NTZ" + return "TIMESTAMP" if column.get("timezone", True) else "TIMESTAMP_NTZ" def from_destination_type( self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None @@ -153,6 +160,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.has_case_sensitive_identifiers = False caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.supports_timestamp_precision_configuration = False + caps.supports_binary_precision_configuration = False caps.max_identifier_length = 255 caps.max_column_identifier_length = 255 caps.max_query_length = 2 * 1024 * 1024 @@ -172,6 +181,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: "staging-optimized", ] caps.sqlglot_dialect = "databricks" + # Databricks timestamps always have microsecond precision + caps.parquet_format = ParquetFormatConfiguration(coerce_timestamps="us") return caps @@ -190,6 +201,8 @@ def __init__( environment: str = None, staging_volume_name: str = None, create_indexes: bool = False, + insert_api: TDatabricksInsertApi = DEFAULT_DATABRICKS_INSERT_API, + zerobus: DatabricksZerobusConfiguration = None, **kwargs: Any, ) -> None: """Configure the Databricks destination to use in a pipeline. 
@@ -205,6 +218,9 @@ def __init__( environment (str, optional): Environment of the destination staging_volume_name (str, optional): Name of the staging volume to use create_indexes (bool, optional): Whether PRIMARY KEY or FOREIGN KEY constrains should be created + insert_api (TDatabricksInsertApi, optional): Ingestion backend for `append` write + disposition. Can be overridden per resource via `databricks_adapter`. + zerobus (DatabricksZerobusConfiguration, optional): Zerobus configuration including Zerobus endpoint, credentials, batch size, and optional Arrow stream settings. **kwargs (Any): Additional arguments passed to the destination config """ super().__init__( @@ -215,6 +231,8 @@ def __init__( environment=environment, staging_volume_name=staging_volume_name, create_indexes=create_indexes, + insert_api=insert_api, + zerobus=zerobus, **kwargs, ) diff --git a/dlt/destinations/impl/databricks/typing.py b/dlt/destinations/impl/databricks/typing.py index ee846d8871..9750969f92 100644 --- a/dlt/destinations/impl/databricks/typing.py +++ b/dlt/destinations/impl/databricks/typing.py @@ -1,6 +1,8 @@ -from typing import Optional, List, Dict, Union, Literal -from dlt.common.schema.typing import TColumnSchema, TColumnHint +from typing import Dict, List, Literal, Optional, Union +from dlt.common.schema.typing import TColumnHint, TColumnSchema + +TDatabricksInsertApi = Literal["copy_into", "zerobus"] TDatabricksColumnHint = Union[TColumnHint, Literal["foreign_key"]] diff --git a/dlt/destinations/impl/destination/destination.py b/dlt/destinations/impl/destination/destination.py index ef105db271..e655e2bc0f 100644 --- a/dlt/destinations/impl/destination/destination.py +++ b/dlt/destinations/impl/destination/destination.py @@ -60,12 +60,6 @@ def create_load_job( if is_dlt_table_or_column(table["name"], self.schema._dlt_tables_prefix): return FinalizedLoadJob(file_path) - skipped_columns: List[str] = [] - if self.config.skip_dlt_columns_and_tables: - for column in list(self.schema.get_table(table["name"])["columns"].keys()): - if is_dlt_table_or_column(column, self.schema._dlt_tables_prefix): - skipped_columns.append(column) - # save our state in destination name scope load_state = destination_state() @@ -76,7 +70,6 @@ def create_load_job( self.config, load_state, self.destination_callable, - skipped_columns, ) if parsed_file.file_format in ["jsonl", "typed-jsonl"]: return DestinationJsonlLoadJob( @@ -84,7 +77,6 @@ def create_load_job( self.config, load_state, self.destination_callable, - skipped_columns, ) return None diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index 1791f5a5d4..e56c26f008 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -1,10 +1,11 @@ -from abc import ABC, abstractmethod +from __future__ import annotations + import os import tempfile # noqa: 251 -from typing import Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, Generic, Iterable, List, Optional + from dlt.common import pendulum -from dlt.common.json import json from dlt.common.destination.client import ( HasFollowupJobs, TLoadJobState, @@ -14,15 +15,23 @@ ) from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.storages.load_package import commit_load_package_state -from dlt.common.storages import FileStorage -from dlt.common.typing import TDataItems +from dlt.common.typing import TDataItems, TDataRecordBatch from dlt.common.storages.load_storage import ParsedLoadJobFileName +from 
dlt.destinations.file_batching import ( + FileBatchIterator, + JsonlFileBatchIterator, + ParquetFileBatchIterator, + TRecordBatch, +) from dlt.destinations.impl.destination.configuration import ( CustomDestinationClientConfiguration, TDestinationCallable, ) +if TYPE_CHECKING: + from dlt.common.libs.pyarrow import pyarrow + class FinalizedLoadJob(LoadJob): """ @@ -142,39 +151,65 @@ def resolve_reference(file_path: str) -> str: return refs[0] -class DestinationLoadJob(RunnableLoadJob, ABC): +class BatchedFileLoadJob(RunnableLoadJob, Generic[TRecordBatch]): + file_batch_iterator_class: type[FileBatchIterator[TRecordBatch]] + + def __init__(self, file_path: str, batch_size: int, destination_state: Dict[str, int]) -> None: + super().__init__(file_path) + self._batch_size = batch_size + self._destination_state = destination_state + self._state_key = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" + + @property + def _record_offset(self) -> int: + return self._destination_state.get(self._state_key, 0) + + def iter_batches(self) -> Iterable[TRecordBatch]: + return self.file_batch_iterator_class( + self._file_path, + self._batch_size, + self._record_offset, + list(self._load_table["columns"].keys()), + ) + + def _advance_record_offset(self, processed_count: int) -> None: + self._destination_state[self._state_key] = self._record_offset + processed_count + commit_load_package_state() + + def _process_batches(self, process_batch: Callable[[TRecordBatch], int]) -> None: + for batch in self.iter_batches(): + processed_count = process_batch(batch) + self._advance_record_offset(processed_count) + + +class DestinationLoadJob(BatchedFileLoadJob[TRecordBatch]): def __init__( self, file_path: str, config: CustomDestinationClientConfiguration, destination_state: Dict[str, int], destination_callable: TDestinationCallable, - skipped_columns: List[str], callable_requires_job_client_args: bool = False, ) -> None: - super().__init__(file_path) + super().__init__(file_path, config.batch_size, destination_state) self._config = config self._callable = destination_callable - self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" - self._skipped_columns = skipped_columns - self._destination_state = destination_state self._callable_requires_job_client_args = callable_requires_job_client_args def run(self) -> None: # update filepath, it will be in running jobs now - if self._config.batch_size == 0: + if self._batch_size == 0: # on batch size zero we only call the callable with the filename self.call_callable_with_items(self._file_path) # save progress commit_load_package_state() else: - current_index = self._destination_state.get(self._storage_id, 0) - for batch in self.get_batches(current_index): + + def process_batch(batch: TRecordBatch) -> int: self.call_callable_with_items(batch) - current_index += len(batch) - self._destination_state[self._storage_id] = current_index - # save progress - commit_load_package_state() + return len(batch) + + self._process_batches(process_batch) def call_callable_with_items(self, items: TDataItems) -> None: if not items: @@ -185,56 +220,10 @@ def call_callable_with_items(self, items: TDataItems) -> None: else: self._callable(items, self._load_table) - @abstractmethod - def get_batches(self, start_index: int) -> Iterable[TDataItems]: - pass - - -class DestinationParquetLoadJob(DestinationLoadJob): - def get_batches(self, start_index: int) -> Iterable[TDataItems]: - # stream items - from dlt.common.libs.pyarrow import pyarrow - - # 
guard against changed batch size after restart of loadjob - assert ( - start_index % self._config.batch_size - ) == 0, "Batch size was changed during processing of one load package" - - # on record batches we cannot drop columns, we need to - # select the ones we want to keep - keep_columns = list(self._load_table["columns"].keys()) - start_batch = start_index / self._config.batch_size - with pyarrow.parquet.ParquetFile(self._file_path) as reader: - for record_batch in reader.iter_batches( - batch_size=self._config.batch_size, columns=keep_columns - ): - if start_batch > 0: - start_batch -= 1 - continue - yield record_batch - - -class DestinationJsonlLoadJob(DestinationLoadJob): - def get_batches(self, start_index: int) -> Iterable[TDataItems]: - current_batch: TDataItems = [] - - # stream items - with FileStorage.open_zipsafe_ro(self._file_path) as f: - for line in f: - encoded_json = json.typed_loads(line) - if isinstance(encoded_json, dict): - encoded_json = [encoded_json] - - for item in encoded_json: - # find correct start position - if start_index > 0: - start_index -= 1 - continue - # skip internal columns - for column in self._skipped_columns: - item.pop(column, None) - current_batch.append(item) - if len(current_batch) == self._config.batch_size: - yield current_batch - current_batch = [] - yield current_batch + +class DestinationParquetLoadJob(DestinationLoadJob["pyarrow.RecordBatch"]): + file_batch_iterator_class = ParquetFileBatchIterator + + +class DestinationJsonlLoadJob(DestinationLoadJob[TDataRecordBatch]): + file_batch_iterator_class = JsonlFileBatchIterator diff --git a/docs/uv.lock b/docs/uv.lock index 713e15246a..121d8f827b 100644 --- a/docs/uv.lock +++ b/docs/uv.lock @@ -618,6 +618,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/15/a674e16ebbe36190cf79cbf2c25e5dc3b5f0203b471d9e9e98b7c8e70381/databricks_sql_connector-4.1.4-py3-none-any.whl", hash = "sha256:cabe1640412c240b328291d7155c280570892961ce56d0529593f354e9958727", size = 202303, upload-time = "2025-10-15T17:45:33.078Z" }, ] +[[package]] +name = "databricks-zerobus-ingest-sdk" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3b/60/e6ae4785e933eec4e9cec823798f65a8946c4c2bba0c43b97eb1e1489dff/databricks_zerobus_ingest_sdk-1.2.0.tar.gz", hash = "sha256:5330f1bf7544fcc016de34e18b7bbfc66ec589dbaf59bbd530185d19b444d750", size = 65236, upload-time = "2026-04-27T13:15:55.728Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/95/0c8073eed3cb85a5bb6dc80b129c6b6afaf432b4b67e299f1b20d8c56a56/databricks_zerobus_ingest_sdk-1.2.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e48e82744e06cff8ef4704131e0aa295ec053b71098ac83155491f5dc72e77f5", size = 6167350, upload-time = "2026-04-27T13:15:50.592Z" }, + { url = "https://files.pythonhosted.org/packages/5b/13/b63f615e997f7bb5eaa3f20fe10217d064820a84e18a794784f1b3ad6009/databricks_zerobus_ingest_sdk-1.2.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c8cbd504efc749e939fe4f7ebf5b7f5fd5bf520452aef932f1213e3bab9a69f0", size = 6169884, upload-time = "2026-04-27T13:15:53.131Z" }, + { url = "https://files.pythonhosted.org/packages/78/1d/9aedb2f20edffb5113a71d79343f9cb7ba7a1fc5a6e1f6249169444c8749/databricks_zerobus_ingest_sdk-1.2.0-cp39-abi3-win_amd64.whl", hash = "sha256:72d426d997cbc209d352953b2af5dfc78099cb0ed3d6d7fc81efd71ed521237e", size = 5024590, upload-time = 
"2026-04-27T13:15:54.507Z" }, +] + [[package]] name = "db-dtypes" version = "1.4.3" @@ -876,6 +891,7 @@ bigquery = [ databricks = [ { name = "databricks-sdk" }, { name = "databricks-sql-connector" }, + { name = "databricks-zerobus-ingest-sdk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32') or (platform_machine == 'x86_64' and sys_platform == 'win32')" }, ] duckdb = [ { name = "duckdb" }, @@ -925,6 +941,7 @@ requires-dist = [ { name = "databricks-sdk", marker = "extra == 'databricks'", specifier = ">=0.38.0" }, { name = "databricks-sql-connector", marker = "python_full_version >= '3.13' and extra == 'databricks'", specifier = ">=3.6.0" }, { name = "databricks-sql-connector", marker = "python_full_version < '3.13' and extra == 'databricks'", specifier = ">=2.9.3" }, + { name = "databricks-zerobus-ingest-sdk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'databricks') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'databricks') or (platform_machine == 'AMD64' and sys_platform == 'win32' and extra == 'databricks') or (platform_machine == 'x86_64' and sys_platform == 'win32' and extra == 'databricks')", specifier = ">=1.2.0" }, { name = "db-dtypes", marker = "extra == 'bigquery'", specifier = ">=1.2.0" }, { name = "db-dtypes", marker = "extra == 'gcp'", specifier = ">=1.2.0" }, { name = "deltalake", marker = "extra == 'deltalake'", specifier = ">=0.25.1" }, @@ -971,6 +988,7 @@ requires-dist = [ { name = "pip", marker = "extra == 'cli'", specifier = ">=23.0.0" }, { name = "pipdeptree", marker = "extra == 'cli'", specifier = ">=2.9.3,<2.10" }, { name = "pluggy", specifier = ">=1.3.0" }, + { name = "polars", marker = "extra == 'polars'", specifier = ">=1.0.0" }, { name = "psycopg2-binary", marker = "extra == 'postgis'", specifier = ">=2.9.1" }, { name = "psycopg2-binary", marker = "extra == 'postgres'", specifier = ">=2.9.1" }, { name = "psycopg2-binary", marker = "extra == 'redshift'", specifier = ">=2.9.1" }, @@ -1024,7 +1042,7 @@ requires-dist = [ { name = "weaviate-client", marker = "extra == 'weaviate'", specifier = ">=4.0.0,<5.0.0" }, { name = "win-precise-time", marker = "python_full_version < '3.13' and os_name == 'nt'", specifier = ">=1.4.2" }, ] -provides-extras = ["hub", "gcp", "bigquery", "postgres", "redshift", "parquet", "duckdb", "ducklake", "filesystem", "s3", "gs", "az", "hf", "sftp", "http", "snowflake", "motherduck", "cli", "athena", "weaviate", "mssql", "synapse", "fabric", "oracle", "qdrant", "databricks", "clickhouse", "dremio", "lancedb", "lance", "deltalake", "sql-database", "sqlalchemy", "pyiceberg", "postgis", "workspace", "dbml"] +provides-extras = ["hub", "gcp", "bigquery", "postgres", "redshift", "parquet", "polars", "duckdb", "ducklake", "filesystem", "s3", "gs", "az", "hf", "sftp", "http", "snowflake", "motherduck", "cli", "athena", "weaviate", "mssql", "synapse", "fabric", "oracle", "qdrant", "databricks", "clickhouse", "dremio", "lancedb", "lance", "deltalake", "sql-database", "sqlalchemy", "pyiceberg", "postgis", "workspace", "dbml"] [package.metadata.requires-dev] adbc = [ @@ -1727,7 +1745,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/ed/6bfa4109fcb23a58819600392564fea69cdc6551ffd5e69ccf1d52a40cbc/greenlet-3.2.4-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:8c68325b0d0acf8d91dde4e6f930967dd52a5302cd4062932a6b2e7c2969f47c", 
size = 271061, upload-time = "2025-08-07T13:17:15.373Z" }, { url = "https://files.pythonhosted.org/packages/2a/fc/102ec1a2fc015b3a7652abab7acf3541d58c04d3d17a8d3d6a44adae1eb1/greenlet-3.2.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:94385f101946790ae13da500603491f04a76b6e4c059dab271b3ce2e283b2590", size = 629475, upload-time = "2025-08-07T13:42:54.009Z" }, { url = "https://files.pythonhosted.org/packages/c5/26/80383131d55a4ac0fb08d71660fd77e7660b9db6bdb4e8884f46d9f2cc04/greenlet-3.2.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f10fd42b5ee276335863712fa3da6608e93f70629c631bf77145021600abc23c", size = 640802, upload-time = "2025-08-07T13:45:25.52Z" }, - { url = "https://files.pythonhosted.org/packages/9f/7c/e7833dbcd8f376f3326bd728c845d31dcde4c84268d3921afcae77d90d08/greenlet-3.2.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c8c9e331e58180d0d83c5b7999255721b725913ff6bc6cf39fa2a45841a4fd4b", size = 636703, upload-time = "2025-08-07T13:53:12.622Z" }, { url = "https://files.pythonhosted.org/packages/e9/49/547b93b7c0428ede7b3f309bc965986874759f7d89e4e04aeddbc9699acb/greenlet-3.2.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:58b97143c9cc7b86fc458f215bd0932f1757ce649e05b640fea2e79b54cedb31", size = 635417, upload-time = "2025-08-07T13:18:25.189Z" }, { url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" }, { url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" }, @@ -1738,7 +1755,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" }, { url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" }, { url = "https://files.pythonhosted.org/packages/ae/8f/95d48d7e3d433e6dae5b1682e4292242a53f22df82e6d3dda81b1701a960/greenlet-3.2.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:94abf90142c2a18151632371140b3dba4dee031633fe614cb592dbb6c9e17bc3", size = 644646, upload-time = "2025-08-07T13:45:26.523Z" }, - { url = "https://files.pythonhosted.org/packages/d5/5e/405965351aef8c76b8ef7ad370e5da58d57ef6068df197548b015464001a/greenlet-3.2.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:4d1378601b85e2e5171b99be8d2dc85f594c79967599328f95c1dc1a40f1c633", size = 640519, upload-time = "2025-08-07T13:53:13.928Z" }, { url = 
"https://files.pythonhosted.org/packages/25/5d/382753b52006ce0218297ec1b628e048c4e64b155379331f25a7316eb749/greenlet-3.2.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0db5594dce18db94f7d1650d7489909b57afde4c580806b8d9203b6e79cdc079", size = 639707, upload-time = "2025-08-07T13:18:27.146Z" }, { url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" }, { url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" }, @@ -1749,7 +1765,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" }, { url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" }, { url = "https://files.pythonhosted.org/packages/3b/16/035dcfcc48715ccd345f3a93183267167cdd162ad123cd93067d86f27ce4/greenlet-3.2.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f28588772bb5fb869a8eb331374ec06f24a83a9c25bfa1f38b6993afe9c1e968", size = 655185, upload-time = "2025-08-07T13:45:27.624Z" }, - { url = "https://files.pythonhosted.org/packages/31/da/0386695eef69ffae1ad726881571dfe28b41970173947e7c558d9998de0f/greenlet-3.2.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5c9320971821a7cb77cfab8d956fa8e39cd07ca44b6070db358ceb7f8797c8c9", size = 649926, upload-time = "2025-08-07T13:53:15.251Z" }, { url = "https://files.pythonhosted.org/packages/68/88/69bf19fd4dc19981928ceacbc5fd4bb6bc2215d53199e367832e98d1d8fe/greenlet-3.2.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c60a6d84229b271d44b70fb6e5fa23781abb5d742af7b808ae3f6efd7c9c60f6", size = 651839, upload-time = "2025-08-07T13:18:30.281Z" }, { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, { url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index f668ff6af1..7f303fceab 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ 
b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -516,6 +516,89 @@ This destination fully supports [dlt state sync](../../general-usage/state#synci We enable Databricks to identify that the connection is created by `dlt`. Databricks will use this user agent identifier to better understand the usage patterns associated with dlt integration. The connection identifier is `dltHub_dlt`. +## Databricks Zerobus + +By default, `dlt` uses [COPY INTO](https://docs.databricks.com/aws/en/sql/language-manual/delta-copy-into) statements to load data. It is also possible to ingest data using [Databricks Zerobus](https://docs.databricks.com/aws/en/ingestion/zerobus-overview). This is only supported for the `append` write disposition. + +### Enable Zerobus + +:::note +Databricks Zerobus is currently supported on Linux and Windows only. macOS is not supported because the `databricks-zerobus-ingest-sdk` package does not publish macOS wheels. +::: + +```toml +[destination.databricks] +insert_api = "zerobus" + +[destination.databricks.credentials] +catalog = "your-catalog" +server_hostname = "adb-1234567890123456.7.azuredatabricks.net" +http_path = "/sql/1.0/warehouses/1234567890abcdef" +client_id = "your-client-id" +client_secret = "your-client-secret" + +[destination.databricks.zerobus] +endpoint_url = "https://" +``` + +See the Databricks guide on [how to construct your Zerobus endpoint URL](https://docs.databricks.com/aws/en/ingestion/zerobus-ingest#get-your-workspace-url-and-zerobus-ingest-endpoint). + +### Dedicated Zerobus Service Principal + +By default, `dlt` uses the `client_id` and `client_secret` from `destination.databricks.credentials` for both the regular Databricks SQL connection and the Zerobus stream. + +If you want the Zerobus connection to use a different service principal, configure `destination.databricks.zerobus.credentials` explicitly: + +```toml +[destination.databricks.credentials] +client_id = "your-sql-client-id" +client_secret = "your-sql-client-secret" + +[destination.databricks.zerobus.credentials] +client_id = "your-zerobus-client-id" +client_secret = "your-zerobus-client-secret" +``` + +Concerns are separated as follows: + +1. **SQL principal** — defined in `destination.databricks.credentials`: authenticates the regular Databricks SQL connection, creates or updates the table, and grants the [necessary privileges](https://docs.databricks.com/aws/en/ingestion/zerobus-ingest#create-a-service-principal-and-grant-permissions) to the Zerobus principal +2. **Zerobus principal** — defined in `destination.databricks.zerobus.credentials`: authenticates the Zerobus stream that ingests data into the table + +### Tune batch size and stream options + +`batch_size` controls how many records `dlt` sends in each Zerobus batch. `stream_options` are mapped to `ArrowStreamConfigurationOptions` and passed to the Zerobus SDK when creating the stream. + +```toml +[destination.databricks.zerobus] +batch_size = 100_000 # default is 25_000 +stream_options = {ipc_compression = "NONE"} # default `ipc_compression` is `ZSTD` +``` + +:::note +In our internal benchmarking, a batch size of 25_000 with `ZSTD` compression performed best. Since this depends on the workload, you may want to experiment with different settings. +::: + +### File formats and type support + +Both `parquet` and `jsonl` are supported. We strongly recommend `parquet` for best performance and the broadest data type support. 
+ +| File format | Unsupported with `zerobus` | +| --- | --- | +| `parquet` | `decimal`, `wei` | +| `jsonl` | `decimal`, `wei`, `binary`, `json` | + +### Concurrent Zerobus streams + +`dlt` opens one Zerobus stream per load job. When a load is split into multiple jobs, multiple Zerobus streams can run at the same time. This can increase throughput and reduce load times. + +See the [Load](../../reference/performance.md#load) section of the performance guide to learn how to split a load into multiple concurrent jobs. As with batch size, the right setting depends on your workload. + +### Other notes + +- Zerobus provides *at-least-once* guarantees, so a destination table may contain duplicates +- `dlt` uses Arrow-backed Zerobus messages (as opposed to JSON- or protobuf-backed messages) +- `dlt` system tables always use `copy_into`, even when you set `insert_api` to `zerobus` + ## Databricks adapter You can use the `databricks_adapter` function to add Databricks-specific hints to a resource. These hints influence how data is loaded into Databricks tables, such as adding comments and tags. Hints can be defined at both the column level and table level. @@ -531,6 +614,7 @@ The adapter updates the DltResource with metadata about the destination column a - `table_comment`: Adds a comment to the table. Supports basic markdown format [basic-syntax](https://www.markdownguide.org/cheat-sheet/#basic-syntax) - `table_tags`: Adds tags to the table. Supports a list of strings and/or key-value pairs - `table_properties`: Dictionary of table properties for Delta Lake optimization (TBLPROPERTIES) +- `insert_api`: Ingestion backend for `append` write disposition. Can be `"copy_into"` or `"zerobus"`. Overrides the destination-wide `insert_api` setting for the resource. **Column-level hints:** - `column_hints`: Dictionary of column-specific hints @@ -576,6 +660,23 @@ databricks_adapter( ) ``` +### Override the insert API for one resource + +Use `databricks_adapter(..., insert_api=...)` when you want a resource to use a different insert API than the rest of the destination. + +```py +import dlt +from dlt.destinations.adapters import databricks_adapter + +@dlt.resource(write_disposition="append") +def events(): + yield from [{"id": 1}, {"id": 2}] + +databricks_adapter(events, insert_api="zerobus") +``` + +This hint only applies to `append` loads. Use `insert_api="copy_into"` to opt a resource out of a destination-wide Zerobus default. 
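+
+For example, a minimal sketch of the opposite setup, assuming Zerobus is enabled destination-wide via `insert_api = "zerobus"` as shown above (the `audit_log` resource name is illustrative):
+
+```py
+import dlt
+from dlt.destinations.adapters import databricks_adapter
+
+# hypothetical resource used only for illustration
+@dlt.resource(write_disposition="append")
+def audit_log():
+    yield from [{"id": 1}, {"id": 2}]
+
+# keep this resource on COPY INTO while the rest of the destination uses Zerobus
+databricks_adapter(audit_log, insert_api="copy_into")
+```
+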
+ ### Advanced examples #### Clustering and partitioning @@ -826,4 +927,3 @@ If this workaround is necessary, validate your setup after each platform upgrade ::: - diff --git a/pyproject.toml b/pyproject.toml index b6920623ab..9118fd886d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,6 +165,7 @@ databricks = [ "databricks-sql-connector>=2.9.3 ; python_version <= '3.12'", "databricks-sql-connector>=3.6.0 ; python_version >= '3.13'", "databricks-sdk>=0.38.0", + "databricks-zerobus-ingest-sdk>=1.2.0 ; (platform_system == 'Linux' and (platform_machine == 'x86_64' or platform_machine == 'aarch64')) or (platform_system == 'Windows' and (platform_machine == 'AMD64' or platform_machine == 'x86_64'))", ] clickhouse = [ "clickhouse-driver>=0.2.7", @@ -506,6 +507,7 @@ module = [ "adbc_driver_manager.*", "playwright.*", "lancedb.*", + "zerobus.*", ] ignore_missing_imports = true diff --git a/tests/common/destination/test_destination_capabilities.py b/tests/common/destination/test_destination_capabilities.py index 2e81547687..ef60a94d10 100644 --- a/tests/common/destination/test_destination_capabilities.py +++ b/tests/common/destination/test_destination_capabilities.py @@ -1,7 +1,11 @@ +from typing import cast + import pytest +from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.destination.exceptions import DestinationCapabilitiesException, UnsupportedDataType from dlt.common.destination.utils import ( + prepare_load_table, resolve_merge_strategy, verify_schema_capabilities, verify_supported_data_types, @@ -12,6 +16,7 @@ from dlt.common.schema.utils import new_table from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.destinations.impl.bigquery.bigquery_adapter import AUTODETECT_SCHEMA_HINT +from dlt.destinations.impl.databricks.databricks_adapter import INSERT_API_HINT def test_resolve_merge_strategy() -> None: @@ -137,6 +142,7 @@ def test_verify_capabilities_data_types() -> None: assert exceptions[0].available_in_formats == ["model"] # time not supported on databricks + schema.tables["table"][INSERT_API_HINT] = "copy_into" # type: ignore[typeddict-unknown-key] exceptions = verify_supported_data_types( schema.tables.values(), new_jobs_parquet, databricks().capabilities(), "databricks" # type: ignore[arg-type] ) @@ -150,6 +156,36 @@ def test_verify_capabilities_data_types() -> None: assert exceptions[0].column == "col2" assert set(exceptions[0].available_in_formats) == {"parquet", "model"} + # decimal and wei not supported on databricks zerobus + schema_zerobus = Schema("schema_zerobus") + table = new_table( + "table", + write_disposition="append", + columns=[ + # supported types + {"name": "date_col", "data_type": "date"}, + {"name": "time_col", "data_type": "time"}, + # unsupported types + {"name": "decimal_col", "data_type": "decimal"}, + {"name": "wei_col", "data_type": "wei"}, + ], + ) + table[INSERT_API_HINT] = "zerobus" # type: ignore[typeddict-unknown-key] + schema_zerobus.update_table(table, normalize_identifiers=False) + exceptions = verify_supported_data_types( + schema_zerobus.tables.values(), # type: ignore[arg-type] + new_jobs_parquet, + databricks().capabilities(), + "databricks", + ) + assert len(exceptions) == 2 + assert all(isinstance(exception, UnsupportedDataType) for exception in exceptions) + unsupported_exceptions = cast(list[UnsupportedDataType], exceptions) + assert {exception.column for exception in unsupported_exceptions} == { + "decimal_col", + "wei_col", + } + # exclude binary type if precision is 
set on column schema_bin = Schema("schema_bin") table = new_table( @@ -238,3 +274,43 @@ def test_verify_capabilities_data_types() -> None: ) assert len(exceptions) == 1 assert isinstance(exceptions[0], TerminalValueError) + + +def test_prepare_load_table_drops_unsupported_precision_hints() -> None: + schema = Schema("foo") + table_name = "bar" + table = new_table( + table_name, + columns=[ + {"name": "ts", "data_type": "timestamp", "precision": 3}, + {"name": "bin", "data_type": "binary", "precision": 16}, + ], + ) + schema.update_table(table) + + caps = DestinationCapabilitiesContext() + caps.supports_timestamp_precision_configuration = True + caps.supports_binary_precision_configuration = True + + prepared_table = prepare_load_table( + schema.tables, + schema.tables[table_name], + destination_capabilities=caps, + ) + + assert "precision" in prepared_table["columns"]["ts"] + assert "precision" in prepared_table["columns"]["bin"] + + caps.supports_timestamp_precision_configuration = False + caps.supports_binary_precision_configuration = False + + prepared_table = prepare_load_table( + schema.tables, + schema.tables[table_name], + destination_capabilities=caps, + ) + + assert "precision" not in prepared_table["columns"]["ts"] + assert "precision" not in prepared_table["columns"]["bin"] + assert "precision" in schema.tables[table_name]["columns"]["ts"] + assert "precision" in schema.tables[table_name]["columns"]["bin"] diff --git a/tests/common/test_time.py b/tests/common/test_time.py index d7c7afe3b2..b5feecd6bf 100644 --- a/tests/common/test_time.py +++ b/tests/common/test_time.py @@ -15,6 +15,7 @@ from dlt.common.time import ( MonotonicPreciseTime, LockedMonotonicPreciseTime, + date_to_epoch_days, increasing_precise_time, precise_time, parse_iso_like_datetime, @@ -984,6 +985,11 @@ def test_create_load_id_strictly_increasing() -> None: assert restored >= baseline +def test_date_to_epoch_days() -> None: + assert date_to_epoch_days(date(1970, 1, 1)) == 0 + assert date_to_epoch_days(pendulum.date(1970, 1, 2)) == 1 + + @pytest.mark.parametrize( "fn, args", [ diff --git a/tests/load/databricks/test_databricks_adapter.py b/tests/load/databricks/test_databricks_adapter.py index 32c3db2089..41e1499f34 100644 --- a/tests/load/databricks/test_databricks_adapter.py +++ b/tests/load/databricks/test_databricks_adapter.py @@ -2,13 +2,17 @@ import pytest import dlt +from dlt.common.schema.typing import TWriteDisposition from dlt.common.utils import uniq_id from dlt.destinations.adapters import databricks_adapter from dlt.destinations import databricks from dlt.destinations.impl.databricks.databricks_adapter import ( - CLUSTER_HINT, - TABLE_PROPERTIES_HINT, + COLUMN_COMMENT_HINT, + INSERT_API_HINT, ) +from dlt.destinations.impl.databricks.configuration import DEFAULT_DATABRICKS_INSERT_API +from dlt.destinations.impl.databricks.factory import DatabricksTypeMapper +from tests.cases import table_update_and_row from tests.load.utils import ( destinations_configs, DestinationTestConfiguration, @@ -540,19 +544,22 @@ def demo_source(): def test_databricks_adapter_iceberg_all_data_types( destination_config: DestinationTestConfiguration, ) -> None: - """Test ICEBERG table format with all supported dlt data types""" - from tests.cases import TABLE_UPDATE, TABLE_ROW_ALL_DATA_TYPES + """Test ICEBERG table format with all supported data types.""" + + columns, data_row = table_update_and_row( + exclude_types=tuple( + DatabricksTypeMapper.UNSUPPORTED_TYPES[(DEFAULT_DATABRICKS_INSERT_API, "parquet")] + ), + 
exclude_columns=("col1_precision",), + ) pipeline = destination_config.setup_pipeline( f"databricks_{uniq_id()}", dev_mode=True, destination=databricks() ) - # Create columns dict from TABLE_UPDATE - columns = {col["name"]: col for col in TABLE_UPDATE} - @dlt.resource(columns=columns, primary_key="col1") def iceberg_all_types() -> Iterator[Dict[str, Any]]: - yield TABLE_ROW_ALL_DATA_TYPES + yield data_row # Apply ICEBERG format databricks_adapter(iceberg_all_types, table_format="ICEBERG") @@ -592,10 +599,10 @@ def demo_source(): f" {pipeline.dataset_name}.iceberg_all_types" ) as cur: row = cur.fetchone() - assert row[0] == TABLE_ROW_ALL_DATA_TYPES["col1"] # bigint - assert abs(row[1] - TABLE_ROW_ALL_DATA_TYPES["col2"]) < 0.001 # double - assert row[2] == TABLE_ROW_ALL_DATA_TYPES["col3"] # bool - assert row[3] == TABLE_ROW_ALL_DATA_TYPES["col5"] # text + assert row[0] == data_row["col1"] # bigint + assert abs(row[1] - data_row["col2"]) < 0.001 # double + assert row[2] == data_row["col3"] # bool + assert row[3] == data_row["col5"] # text def test_databricks_adapter_invalid_table_format(): @@ -683,3 +690,49 @@ def dummy_resource(): dummy_resource, table_properties={"test_prop": {"nested": "value"}}, # type: ignore[dict-item] ) + + +def test_databricks_adapter_preserves_existing_columns() -> None: + # define resource with column hint + res = dlt.resource([{"id": 1}], name="foo", columns={"id": {"data_type": "bigint"}}) + + # apply adapter without `column_hints` + column_hints = None + databricks_adapter(res, column_hints=column_hints) + column = res.compute_table_schema()["columns"]["id"] + assert column["data_type"] == "bigint" # existing column hint preserved + + # apply adapter with `column_hints` + column_hints = {"id": {"column_comment": "foo"}} + databricks_adapter(res, column_hints=column_hints) # type: ignore[arg-type] + column = res.compute_table_schema()["columns"]["id"] + assert column["data_type"] == "bigint" # existing column hint preserved + assert ( # new column hint added + column[COLUMN_COMMENT_HINT] == "foo" # type: ignore[typeddict-item] + ) + + +def test_databricks_adapter_insert_api() -> None: + default_res = dlt.resource([{"id": 1}], name="default", write_disposition="append") + zerobus_res = dlt.resource([{"id": 1}], name="zerobus", write_disposition="append") + + databricks_adapter(default_res, insert_api=None) + databricks_adapter(zerobus_res, insert_api="zerobus") + + default_table_schema = default_res.compute_table_schema() + zerobus_table_schema = zerobus_res.compute_table_schema() + + assert INSERT_API_HINT not in default_table_schema + assert zerobus_table_schema[INSERT_API_HINT] == "zerobus" # type: ignore[typeddict-item] + + +@pytest.mark.parametrize("write_disposition", ("replace", "merge")) +def test_databricks_adapter_zerobus_insert_api_requires_append( + write_disposition: TWriteDisposition, +) -> None: + res = dlt.resource([{"id": 1}], name="foo", write_disposition=write_disposition) + + with pytest.raises(ValueError): + databricks_adapter(res, insert_api="zerobus") + + databricks_adapter(res, insert_api="copy_into") # `copy_into` should not raise diff --git a/tests/load/databricks/test_databricks_client.py b/tests/load/databricks/test_databricks_client.py new file mode 100644 index 0000000000..604c697a81 --- /dev/null +++ b/tests/load/databricks/test_databricks_client.py @@ -0,0 +1,250 @@ +from dataclasses import dataclass +from typing import Any, Iterator, Optional, cast + +import pytest +from pytest_mock import MockerFixture +from zerobus import 
IPCCompression +from zerobus.sdk.shared import NonRetriableException, ZerobusException + +from dlt.common.configuration.exceptions import ConfigurationValueError +from dlt.common.destination.client import LoadJob, TLoadJobState +from dlt.common.destination.exceptions import ( + DestinationInvalidFileFormat, + WriteDispositionNotSupported, +) +from dlt.common.schema.typing import TWriteDisposition +from dlt.common.schema.utils import new_table +from dlt.common.storages.load_package import ParsedLoadJobFileName +from dlt.common.utils import uniq_id +from dlt.destinations.exceptions import LoadJobTerminalException, LoadJobTransientException +from dlt.destinations.impl.databricks.configuration import ( + DatabricksClientConfiguration, + DatabricksZerobusConfiguration, + DatabricksZerobusCredentials, +) +from dlt.destinations.impl.databricks.databricks import ( + DatabricksClient, + DatabricksLoadJob, + DatabricksZerobusJsonlLoadJob, + DatabricksZerobusParquetLoadJob, +) +from dlt.destinations.impl.databricks.databricks_adapter import INSERT_API_HINT +from dlt.destinations.impl.databricks.typing import TDatabricksInsertApi +from tests.load.utils import yield_client + + +pytestmark = pytest.mark.essential + + +@pytest.fixture(scope="function") +def client() -> Iterator[DatabricksClient]: + dataset_name = "test_" + uniq_id() + yield from cast( + Iterator[DatabricksClient], + # skip entering the client to avoid starting the Databricks cluster, which takes multiple + # minutes and is not necessary for these tests + yield_client("databricks", dataset_name=dataset_name, enter_client=False), + ) + + +@pytest.mark.parametrize( + ("config_insert_api", "table_insert_api", "expected_insert_api"), + [ + ("copy_into", None, "copy_into"), + ("zerobus", None, "zerobus"), + ("zerobus", "copy_into", "copy_into"), + ("copy_into", "zerobus", "zerobus"), + ], +) +def test_databricks_client_prepare_load_table_resolves_insert_api( + client: DatabricksClient, + config_insert_api: TDatabricksInsertApi, + table_insert_api: Optional[TDatabricksInsertApi], + expected_insert_api: TDatabricksInsertApi, +) -> None: + client.config.insert_api = config_insert_api + table = new_table("items", write_disposition="append") + if table_insert_api is not None: + table[INSERT_API_HINT] = table_insert_api # type: ignore[typeddict-unknown-key] + client.schema.update_table(table) + + prepared_table = client.prepare_load_table("items") + prepared_dlt_table = client.prepare_load_table(client.schema.version_table_name) + + assert prepared_table[INSERT_API_HINT] == expected_insert_api # type: ignore[typeddict-item] + # dlt tables should disregard `insert_api` configuration and always use `copy_into` + assert prepared_dlt_table[INSERT_API_HINT] == "copy_into" # type: ignore[typeddict-item] + + +def test_databricks_client_verify_schema_zerobus_file_format(client: DatabricksClient) -> None: + """Asserts exception is raised if `zerobus` insert API is used with `model` file format.""" + + table = new_table("items", write_disposition="append") + table[INSERT_API_HINT] = "zerobus" # type: ignore[typeddict-unknown-key] + client.schema.update_table(table) + + with pytest.raises(DestinationInvalidFileFormat) as exc_info: + client.verify_schema( + ["items"], + [ParsedLoadJobFileName.parse("items.1.1.model")], + ) + + assert exc_info.value.file_format == "model" + + +@pytest.mark.parametrize("write_disposition", ("replace", "merge")) +def test_databricks_client_verify_schema_zerobus_write_disposition( + client: DatabricksClient, + write_disposition: 
TWriteDisposition, +) -> None: + """Asserts exception is raised if `zerobus` insert API is used with non-`append` write disposition.""" + + table = new_table("items", write_disposition=write_disposition) + table[INSERT_API_HINT] = "zerobus" # type: ignore[typeddict-unknown-key] + client.schema.update_table(table) + + with pytest.raises(WriteDispositionNotSupported) as exc_info: + client.verify_schema(["items"]) + + assert exc_info.value.write_disposition == write_disposition + + +def test_databricks_client_verify_schema_zerobus_requires_config( + client: DatabricksClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Asserts Zerobus configuration must be present when `zerobus` insert API is used.""" + + table = new_table("items", write_disposition="append") + table[INSERT_API_HINT] = "zerobus" # type: ignore[typeddict-unknown-key] + client.schema.update_table(table) + monkeypatch.setattr(client.config, "zerobus", None) # temporarily remove Zerobus configuration + + with pytest.raises(ConfigurationValueError, match="Zerobus configuration is required"): + client.verify_schema(["items"]) + + +@pytest.mark.parametrize( + ("insert_api", "file_extension", "expected_class", "expected_exception_match"), + [ + (None, "jsonl", DatabricksLoadJob, None), + ("copy_into", "parquet", DatabricksLoadJob, None), + ("copy_into", "reference", DatabricksLoadJob, None), + ("zerobus", "jsonl", DatabricksZerobusJsonlLoadJob, None), + ("zerobus", "parquet", DatabricksZerobusParquetLoadJob, None), + ("zerobus", "reference", None, "does not support using a staging destination"), + ], +) +def test_databricks_client_get_load_job_class( + client: DatabricksClient, + insert_api: Optional[TDatabricksInsertApi], + file_extension: str, + expected_class: Optional[type[LoadJob]], + expected_exception_match: Optional[str], +) -> None: + table_name = "foo" + table = new_table(table_name, write_disposition="append") + if insert_api is not None: + table[INSERT_API_HINT] = insert_api # type: ignore[typeddict-unknown-key] + client.schema.update_table(table) + + prepared_table = client.prepare_load_table(table_name) + file_path = f"{table_name}.1.1.{file_extension}" + + if expected_exception_match is None: + assert client.get_load_job_class(prepared_table, file_path) is expected_class + else: + with pytest.raises(LoadJobTerminalException, match=expected_exception_match): + client.get_load_job_class(prepared_table, file_path) + + +def test_databricks_zerobus_load_job_calls_create_arrow_stream_with_expected_args( + mocker: MockerFixture, +) -> None: + @dataclass + class FakeSqlClient: + def make_qualified_table_name(self, table_name: str, quote: bool = False) -> str: + return "catalog.schema.items" + + @dataclass + class FakeJobClient: + sql_client: FakeSqlClient + + @dataclass + class FakeZerobusSdk: + create_arrow_stream: Any + + create_arrow_stream = mocker.Mock() + zerobus_config = DatabricksZerobusConfiguration( + credentials=DatabricksZerobusCredentials( + client_id="client-id", client_secret="client-secret" + ), + stream_options={"ipc_compression": "LZ4_FRAME", "max_inflight_batches": 32}, + ) + job = DatabricksZerobusParquetLoadJob( + "/tmp/items.1.1.parquet", + DatabricksClientConfiguration(zerobus=zerobus_config), + {}, + ) + job._job_client = cast(Any, FakeJobClient(sql_client=FakeSqlClient())) + job._load_table = {"name": "items"} + job.zerobus_sdk = FakeZerobusSdk(create_arrow_stream=create_arrow_stream) + job._arrow_schema = object() + + job._create_stream() + + args, kwargs = create_arrow_stream.call_args + 
create_arrow_stream.assert_called_once() + assert args == ( + "catalog.schema.items", + job._arrow_schema, + "client-id", + "client-secret", + ) + assert kwargs["options"].ipc_compression == IPCCompression.LZ4_FRAME + assert kwargs["options"].max_inflight_batches == 32 + + +@pytest.mark.parametrize( + ("zerobus_exception", "expected_exception", "expected_state"), + [ + pytest.param( + ZerobusException("retriable failure"), + LoadJobTransientException, + "retry", + id="retriable", + ), + pytest.param( + NonRetriableException("terminal failure"), + LoadJobTerminalException, + "failed", + id="non-retriable", + ), + ], +) +def test_databricks_zerobus_load_job_error_handling( + mocker: MockerFixture, + zerobus_exception: ZerobusException, + expected_exception: type[Exception], + expected_state: TLoadJobState, +) -> None: + class FakeJobClient: + def prepare_load_job_execution(self, job: LoadJob) -> None: + pass + + def grant_zerobus_permissions(self, table_name: str) -> None: + pass + + job = DatabricksZerobusParquetLoadJob( + "/tmp/items.1.1.parquet", + DatabricksClientConfiguration(zerobus=DatabricksZerobusConfiguration()), + {}, + ) + job._load_table = {"name": "items"} + mocker.patch.object(job, "_create_stream", side_effect=zerobus_exception) + + job.run_managed(FakeJobClient(), None) # type: ignore[arg-type] + + assert isinstance(job.exception(), expected_exception) + assert job.exception().__cause__ is zerobus_exception + assert job.state() == expected_state diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index 1ec84c14ef..21ca0cd603 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest import os @@ -11,12 +13,15 @@ from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob from dlt.common.configuration import resolve_configuration +from zerobus import IPCCompression from dlt.destinations import databricks from dlt.destinations.impl.databricks.configuration import ( DatabricksClientConfiguration, DATABRICKS_APPLICATION_ID, DatabricksCredentials, + DatabricksZerobusConfiguration, + DatabricksZerobusCredentials, ) # mark all tests as essential, do not remove @@ -429,3 +434,115 @@ def test_default_warehouse() -> None: )._bind_dataset_name(dataset_name="my-dataset-1234") ) assert config.credentials.http_path == "/sql/1.0/warehouses/588dbd71bd802f4d" + + +@pytest.mark.parametrize( + "zerobus_credentials", + [ + pytest.param(None, id="without-zerobus-credentials"), + pytest.param( + DatabricksZerobusCredentials(client_id="zerobus-client-id"), + id="without-zerobus-client-secret", + ), + pytest.param( + DatabricksZerobusCredentials(client_secret="zerobus-client-secret"), + id="without-zerobus-client-id", + ), + ], +) +def test_databricks_zerobus_credentials_fall_back_to_databricks_credentials( + zerobus_credentials: Optional[DatabricksZerobusCredentials], +) -> None: + config = resolve_configuration( + DatabricksClientConfiguration( + credentials=DatabricksCredentials( + catalog="foo", + server_hostname="foo", + http_path="foo", + client_id="sql-client-id", + client_secret="sql-client-secret", + ), + zerobus=DatabricksZerobusConfiguration( + endpoint_url="foo", credentials=zerobus_credentials + ), + )._bind_dataset_name(dataset_name="foo") + ) + + assert config.zerobus is not None + assert 
config.zerobus.credentials is not None + assert config.zerobus.credentials.client_id == "sql-client-id" + assert config.zerobus.credentials.client_secret == "sql-client-secret" + + +def test_databricks_zerobus_credentials_fallback_requires_oauth_credentials() -> None: + with pytest.raises( + ConfigurationValueError, + match=( + "`client_id` and `client_secret` are required when" + " `destination.databricks.zerobus` is configured" + ), + ): + resolve_configuration( + DatabricksClientConfiguration( + credentials=DatabricksCredentials( + catalog="foo", + server_hostname="foo", + http_path="foo", + access_token="foo", + ), + zerobus=DatabricksZerobusConfiguration(endpoint_url="foo"), + )._bind_dataset_name(dataset_name="foo") + ) + + +def test_databricks_zerobus_credentials_take_precedence() -> None: + config = resolve_configuration( + DatabricksClientConfiguration( + credentials=DatabricksCredentials( + catalog="foo", + server_hostname="foo", + http_path="foo", + client_id="sql-client-id", + client_secret="sql-client-secret", + ), + zerobus=DatabricksZerobusConfiguration( + endpoint_url="foo", + credentials=DatabricksZerobusCredentials( + client_id="zerobus-client-id", + client_secret="zerobus-client-secret", + ), + ), + )._bind_dataset_name(dataset_name="foo") + ) + + assert config.zerobus is not None + assert config.zerobus.credentials is not None + assert config.zerobus.credentials.client_id == "zerobus-client-id" + assert config.zerobus.credentials.client_secret == "zerobus-client-secret" + + +def test_databricks_zerobus_stream_options_setting() -> None: + options = DatabricksZerobusConfiguration( + stream_options={ + "ipc_compression": "LZ4_FRAME", + "max_inflight_batches": 16, + "recovery": False, + }, + ).to_arrow_stream_configuration_options() + + assert options.ipc_compression == IPCCompression.LZ4_FRAME + assert options.max_inflight_batches == 16 + assert options.recovery is False + + +def test_databricks_zerobus_stream_options_defaults() -> None: + options = DatabricksZerobusConfiguration().to_arrow_stream_configuration_options() + + assert options.ipc_compression == IPCCompression.ZSTD + + +def test_databricks_zerobus_stream_options_reject_invalid_values() -> None: + with pytest.raises(AttributeError, match="foo"): + DatabricksZerobusConfiguration( + stream_options={"ipc_compression": "foo"}, + ).to_arrow_stream_configuration_options() diff --git a/tests/load/pipeline/test_databricks_zerobus.py b/tests/load/pipeline/test_databricks_zerobus.py new file mode 100644 index 0000000000..175676e528 --- /dev/null +++ b/tests/load/pipeline/test_databricks_zerobus.py @@ -0,0 +1,223 @@ +import os +from typing import Any, Callable, Sequence + +import dlt +import pytest + +from dlt.common import sleep +from dlt.common.data_types import TDataType +from dlt.common.typing import TLoaderFileFormat +from dlt.destinations.impl.databricks.databricks import DatabricksZerobusJsonlLoadJob +from dlt.destinations.impl.databricks.factory import DatabricksTypeMapper +from tests.cases import assert_all_data_types_row, table_update_and_row +from tests.load.utils import DestinationTestConfiguration, destinations_configs +from tests.pipeline.utils import assert_load_info + + +pytestmark = pytest.mark.essential + + +def query_rows_eventually( + dataset: dlt.Dataset, + query: str, + rows_ready: Callable[[Sequence[Sequence[Any]]], bool], + max_attempts: int = 24, + poll_interval_seconds: int = 5, +) -> list[tuple[Any, ...]]: + """Run query repeatedly until rows satisfy readiness check or we exhaust attempts. 
+ + Useful for Databricks Zerobus, which provides eventual consistency. + """ + + rows: list[tuple[Any, ...]] = [] + for _ in range(max_attempts): + rows = dataset(query).fetchall() + if rows_ready(rows): + return rows + sleep(poll_interval_seconds) + assert rows_ready(rows), f"Timed out waiting for rows to satisfy readiness check: {query}" + return rows + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(include_cids="databricks_zerobus"), + ids=lambda x: x.name, +) +@pytest.mark.parametrize( + ("file_format", "extra_exclude_types"), + [ + pytest.param("parquet", set(), id="parquet"), + pytest.param("jsonl", set(), id="jsonl"), + pytest.param( + "jsonl", + set(DatabricksZerobusJsonlLoadJob._ARRAY_CAST_TYPES), + id="jsonl-no-array-cast-types", + ), + ], +) +def test_databricks_zerobus_data_types( + destination_config: DestinationTestConfiguration, + file_format: TLoaderFileFormat, + extra_exclude_types: set[TDataType], +) -> None: + """Tests all data types `dlt` supports for `zerobus` insert API.""" + + unsupported_types = DatabricksTypeMapper.UNSUPPORTED_TYPES[("zerobus", file_format)] + exclude_types = set(unsupported_types) | extra_exclude_types + columns, data_row = table_update_and_row(exclude_types=tuple(exclude_types)) + + @dlt.resource( + write_disposition="append", + columns=columns, + file_format=file_format, + ) + def data_types(): + yield data_row + + # insert row with all supported data types + pipe = destination_config.setup_pipeline("test_databricks_zerobus_data_types", dev_mode=True) + info = pipe.run(data_types, **destination_config.run_kwargs) + assert_load_info(info) + + # assert inserted row has expected values + selected_columns = ", ".join(columns.keys()) + rows = query_rows_eventually( + dataset=pipe.dataset(), + query=f"SELECT {selected_columns} FROM data_types ORDER BY col1 LIMIT 1", + # "greater than" because Zerobus can deliver duplicates (at-least-once guarantee) + rows_ready=lambda rows: len(rows) > 0, + ) + assert_all_data_types_row( + pipe.destination_client().capabilities, + rows[0], + expected_row=data_row, + schema=columns, + ) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(include_cids="databricks_zerobus"), + ids=lambda x: x.name, +) +def test_databricks_zerobus_concurrent_streams( + destination_config: DestinationTestConfiguration, +) -> None: + """Asserts multiple Zerobus jobs can stream concurrently into the same table.""" + + # NOTE: run this test with `-s` to see Zerobus SDK logs, which show the lifecycle of the + # multiple streams nicely + + n_rows = 3 + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "1" # force 1 row per file to create multiple jobs + os.environ["LOAD__WORKERS"] = str(n_rows) # allow concurrent loads for all rows + + rows = [{"id": i} for i in range(n_rows)] + + @dlt.resource(write_disposition="append") + def my_resource(): + yield rows + + pipe = destination_config.setup_pipeline( + "test_databricks_zerobus_concurrent_streams", dev_mode=True + ) + info = pipe.run(my_resource) + assert_load_info(info) + + # pipeline used one job (stream) per row, and all jobs completed successfully + completed_jobs = [ + job + for job in info.load_packages[0].jobs["completed_jobs"] + if job.job_file_info.table_name == my_resource.table_name + ] + assert len(completed_jobs) == len(rows) + + # all expected rows are eventually available in the table — we use set comparison because + # Zerobus can deliver duplicates (at-least-once guarantee) + expected_rows = [(row["id"],) for row in rows] + 
observed_rows = query_rows_eventually( + dataset=pipe.dataset(), + query=f"SELECT id FROM {my_resource.table_name} ORDER BY id", + rows_ready=lambda rows: set(rows) == set(expected_rows), + ) + assert set(observed_rows) == set(expected_rows) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(include_cids="databricks_zerobus"), + ids=lambda x: x.name, +) +def test_databricks_zerobus_schema_evolution_add_column( + destination_config: DestinationTestConfiguration, +) -> None: + """Asserts the `zerobus` insert API handles column additions.""" + + initial_data = [{"id": 1, "foo": "a"}] + evolved_data = [{"id": 2, "foo": "b", "new": "x"}] + + @dlt.resource(name="items", write_disposition="append", file_format="parquet") + def items(data): + yield data + + pipe = destination_config.setup_pipeline( + "test_databricks_zerobus_schema_evolution_add_column", dev_mode=True + ) + + info = pipe.run(items(initial_data)) + assert_load_info(info) + + info = pipe.run(items(evolved_data)) + assert_load_info(info) + + table_schema = pipe.default_schema.tables[items.table_name] # type: ignore[index] + assert "new" in table_schema["columns"] + + expected_rows = [(1, "a", None), (2, "b", "x")] + observed_rows = query_rows_eventually( + dataset=pipe.dataset(), + query=f"SELECT DISTINCT id, foo, new FROM {items.table_name} ORDER BY id", + rows_ready=lambda rows: set(rows) == set(expected_rows), + ) + assert set(observed_rows) == set(expected_rows) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(include_cids="databricks_zerobus"), + ids=lambda x: x.name, +) +def test_databricks_zerobus_schema_evolution_alter_column( + destination_config: DestinationTestConfiguration, +) -> None: + """Asserts the `zerobus` insert API handles column type changes.""" + + initial_data = [{"id": 1, "foo": 1}] # `foo` is `bigint` + evolved_data = [{"id": 2, "foo": "b"}] # `foo` is `text` + + @dlt.resource(name="items", write_disposition="append", file_format="parquet") + def items(data): + yield data + + pipe = destination_config.setup_pipeline( + "test_databricks_zerobus_schema_evolution_alter_column", dev_mode=True + ) + + info = pipe.run(items(initial_data)) + assert_load_info(info) + + info = pipe.run(items(evolved_data)) + assert_load_info(info) + + table_schema = pipe.default_schema.tables[items.table_name] # type: ignore[index] + assert "foo" in table_schema["columns"] + assert "foo__v_text" in table_schema["columns"] # variant column created + + expected_rows = [(1, 1, None), (2, None, "b")] + observed_rows = query_rows_eventually( + dataset=pipe.dataset(), + query=f"SELECT DISTINCT id, foo, foo__v_text FROM {items.table_name} ORDER BY id", + rows_ready=lambda rows: set(rows) == set(expected_rows), + ) + assert set(observed_rows) == set(expected_rows) diff --git a/tests/load/utils.py b/tests/load/utils.py index 6e5af61ce1..6c4d650243 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -339,6 +339,7 @@ def destinations_configs( table_format_filesystem_configs: bool = False, table_format_local_configs: bool = False, read_only_sqlclient_configs: bool = False, + include_cids: Union[str, Sequence[str]] = (), subset: Sequence[str] = (), bucket_subset: Sequence[str] = (), exclude: Sequence[str] = (), @@ -371,6 +372,7 @@ def destinations_configs( table_format_local_configs: Include delta and iceberg configs for local file bucket only. 
read_only_sqlclient_configs: Include all configs that support read-only SQL client (filesystem with all buckets, table formats, and lancedb). + include_cids: Include configs by configuration id. Active Destination Filtering: The candidate list is first filtered to include only destinations in ACTIVE_DESTINATIONS @@ -428,19 +430,17 @@ def destinations_configs( # build destination configs destination_configs: List[DestinationTestConfiguration] = [] - # default sql configs that are also default staging configs - default_sql_configs_with_staging = [ - # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though. + cid_configs = [ DestinationTestConfiguration( - destination_type="athena", cid="athena", + destination_type="athena", file_format="parquet", supports_merge=False, bucket_url=AWS_BUCKET, ), DestinationTestConfiguration( - destination_type="athena", cid="athena-iceberg", + destination_type="athena", file_format="parquet", bucket_url=AWS_BUCKET, supports_merge=True, @@ -448,8 +448,8 @@ def destinations_configs( table_format="iceberg", ), DestinationTestConfiguration( - destination_type="athena", cid="athena-s3-tables", + destination_type="athena", file_format="parquet", bucket_url=AWS_BUCKET, supports_merge=True, @@ -458,6 +458,21 @@ def destinations_configs( table_format="iceberg", naming_convention="s3_tables", ), + DestinationTestConfiguration( + cid="databricks_zerobus", + destination_type="databricks", + destination_name="databricks_zerobus", + env_vars={"DESTINATION__INSERT_API": "zerobus"}, + ), + ] + cid_configs_by_cid = {config.cid: config for config in cid_configs if config.cid is not None} + + # default sql configs that are also default staging configs + default_sql_configs_with_staging = [ + # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though. 
+ cid_configs_by_cid["athena"], + cid_configs_by_cid["athena-iceberg"], + cid_configs_by_cid["athena-s3-tables"], ] lance_configs = [ @@ -904,6 +919,16 @@ def destinations_configs( ] destination_configs += lance_configs + if include_cids: + if isinstance(include_cids, str): + include_cids = (include_cids,) + existing_cids = {config.cid for config in destination_configs if config.cid is not None} + destination_configs += [ + config + for config in cid_configs + if config.cid in include_cids and config.cid not in existing_cids + ] + try: # register additional destinations from _addons.py which must be placed in the same folder # as tests @@ -1223,6 +1248,7 @@ def yield_client( dataset_name: str = None, default_config_values: StrAny = None, schema_name: str = "event", + enter_client: bool = True, ) -> Iterator[SqlJobClientBase]: os.environ.pop("DATASET_NAME", None) # import destination reference by name @@ -1268,15 +1294,19 @@ def yield_client( ) ) ): - with destination.client(schema, dest_config) as client: # type: ignore[assignment] - try: - from dlt.destinations.impl.duckdb.sql_client import WithTableScanners - - # open table scanners automatically, context manager above does not do that - if issubclass(client.sql_client_class, WithTableScanners): - client.sql_client.open_connection() - except (ImportError, MissingDependencyException): - pass + client = destination.client(schema, dest_config) # type: ignore[assignment] + if enter_client: + with client: + try: + from dlt.destinations.impl.duckdb.sql_client import WithTableScanners + + # open table scanners automatically, context manager above does not do that + if issubclass(client.sql_client_class, WithTableScanners): + client.sql_client.open_connection() + except (ImportError, MissingDependencyException): + pass + yield client + else: yield client @@ -1286,8 +1316,9 @@ def cm_yield_client( dataset_name: str, default_config_values: StrAny = None, schema_name: str = "event", + enter_client: bool = True, ) -> Iterator[SqlJobClientBase]: - return yield_client(destination, dataset_name, default_config_values, schema_name) + return yield_client(destination, dataset_name, default_config_values, schema_name, enter_client) def yield_client_with_storage( diff --git a/uv.lock b/uv.lock index 7aba02c767..18beb1b947 100644 --- a/uv.lock +++ b/uv.lock @@ -1978,6 +1978,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/15/a674e16ebbe36190cf79cbf2c25e5dc3b5f0203b471d9e9e98b7c8e70381/databricks_sql_connector-4.1.4-py3-none-any.whl", hash = "sha256:cabe1640412c240b328291d7155c280570892961ce56d0529593f354e9958727", size = 202303, upload-time = "2025-10-15T17:45:33.078Z" }, ] +[[package]] +name = "databricks-zerobus-ingest-sdk" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf", marker = "sys_platform != 'emscripten'" }, + { name = "requests", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3b/60/e6ae4785e933eec4e9cec823798f65a8946c4c2bba0c43b97eb1e1489dff/databricks_zerobus_ingest_sdk-1.2.0.tar.gz", hash = "sha256:5330f1bf7544fcc016de34e18b7bbfc66ec589dbaf59bbd530185d19b444d750", size = 65236, upload-time = "2026-04-27T13:15:55.728Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/95/0c8073eed3cb85a5bb6dc80b129c6b6afaf432b4b67e299f1b20d8c56a56/databricks_zerobus_ingest_sdk-1.2.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e48e82744e06cff8ef4704131e0aa295ec053b71098ac83155491f5dc72e77f5", size = 
6167350, upload-time = "2026-04-27T13:15:50.592Z" }, + { url = "https://files.pythonhosted.org/packages/5b/13/b63f615e997f7bb5eaa3f20fe10217d064820a84e18a794784f1b3ad6009/databricks_zerobus_ingest_sdk-1.2.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c8cbd504efc749e939fe4f7ebf5b7f5fd5bf520452aef932f1213e3bab9a69f0", size = 6169884, upload-time = "2026-04-27T13:15:53.131Z" }, + { url = "https://files.pythonhosted.org/packages/78/1d/9aedb2f20edffb5113a71d79343f9cb7ba7a1fc5a6e1f6249169444c8749/databricks_zerobus_ingest_sdk-1.2.0-cp39-abi3-win_amd64.whl", hash = "sha256:72d426d997cbc209d352953b2af5dfc78099cb0ed3d6d7fc81efd71ed521237e", size = 5024590, upload-time = "2026-04-27T13:15:54.507Z" }, +] + [[package]] name = "datasets" version = "4.5.0" @@ -2520,6 +2535,7 @@ clickhouse = [ databricks = [ { name = "databricks-sdk" }, { name = "databricks-sql-connector" }, + { name = "databricks-zerobus-ingest-sdk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine == 'AMD64' and sys_platform == 'win32') or (platform_machine == 'x86_64' and sys_platform == 'win32')" }, ] dbml = [ { name = "pydbml" }, @@ -2806,6 +2822,7 @@ requires-dist = [ { name = "databricks-sdk", marker = "extra == 'databricks'", specifier = ">=0.38.0" }, { name = "databricks-sql-connector", marker = "python_full_version >= '3.13' and extra == 'databricks'", specifier = ">=3.6.0" }, { name = "databricks-sql-connector", marker = "python_full_version < '3.13' and extra == 'databricks'", specifier = ">=2.9.3" }, + { name = "databricks-zerobus-ingest-sdk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'databricks') or (platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'databricks') or (platform_machine == 'AMD64' and sys_platform == 'win32' and extra == 'databricks') or (platform_machine == 'x86_64' and sys_platform == 'win32' and extra == 'databricks')", specifier = ">=1.2.0" }, { name = "db-dtypes", marker = "extra == 'bigquery'", specifier = ">=1.2.0" }, { name = "db-dtypes", marker = "extra == 'gcp'", specifier = ">=1.2.0" }, { name = "deltalake", marker = "extra == 'deltalake'", specifier = ">=0.25.1" },