diff --git a/dlt/common/libs/sqlglot.py b/dlt/common/libs/sqlglot.py
index 58599344b7..af2d9435c5 100644
--- a/dlt/common/libs/sqlglot.py
+++ b/dlt/common/libs/sqlglot.py
@@ -948,7 +948,7 @@ def bind_query(
     qualified_query: sge.Query,
     sqlglot_schema: Any,  # SQLGlotSchema
     *,
-    expand_table_name: Callable[[str], List[str]],
+    expand_table_name: Callable[[str, Optional[str]], List[str]],
     casefold_identifier: Callable[[str], str],
 ) -> sge.Query:
     """Binds a logical query (compliant with dlt schema) to physical tables in the destination dataset.
@@ -971,7 +971,9 @@ def bind_query(
     Args:
         qualified_query: SQLGlot query expression with qualified table/column references
         sqlglot_schema: Schema mapping for name validation and column resolution
-        expand_table_name: Function that expands table name to fully qualified path [catalog, schema, table]
+        expand_table_name: Function ``(table_name, dataset_name | None) -> [catalog, schema, table]``
+            that expands a table name to a fully qualified path. The second argument is the
+            dataset qualifier from the query (``node.db``), or `None` for the default dataset.
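+            Illustrative example (catalog and dataset names are hypothetical)::
+
+                expand_table_name("items", None)        # ["db", "my_dataset", "items"]
+                expand_table_name("items", "other_ds")  # ["db", "other_ds", "items"]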
         casefold_identifier: Case transformation function (`str`, `str.upper`, or `str.lower`)
 
     Returns:
@@ -993,7 +995,7 @@ def bind_query(
             # expand names of known tables. this is currently a clickhouse thing where
             # we use dataset.table in queries but render those as dataset___table
             if sqlglot_schema.column_names(node):
-                expanded_path = expand_table_name(node.name)
+                expanded_path = expand_table_name(node.name, node.db or None)
                 # set the table name
                 if node.name != expanded_path[-1]:
                     node.this.set("this", expanded_path[-1])
diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py
index 240116fec9..39be0b12b0 100644
--- a/dlt/dataset/_join.py
+++ b/dlt/dataset/_join.py
@@ -1,15 +1,18 @@
 from __future__ import annotations
 from functools import reduce
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Optional, Sequence, Union
 
+import sqlglot
 import sqlglot.expressions as sge
+
 from dlt.common.typing import TypedDict
 from dlt.common.schema import Schema, utils as schema_utils
-from dlt.common.schema.typing import TTableReference
+from dlt.common.schema.typing import TTableReference, TTableSchemaColumns
+from dlt.common.libs.sqlglot import TSqlGlotDialect
 
 if TYPE_CHECKING:
-    from dlt.dataset.relation import TJoinType
+    from dlt.dataset.relation import Relation, TJoinType
 
 
 _INTERMEDIATE_JOIN_ALIAS_PREFIX = "_dlt_int_t"
@@ -268,17 +271,16 @@ def _discover_join_params(
 def _apply_join_projection(
     query: sge.Select,
     *,
-    schema: Schema,
     left_table: str,
-    target_table: str,
+    target_columns: TTableSchemaColumns,
     target_qualifier: str,
     projection_prefix: str,
     allow_existing_target_projection: bool,
 ) -> None:
     """Apply join projection contract onto ``query``.
 
-    Preserves the left-side projection and appends only columns from the explicitly
-    joined ``target_table`` as ``{projection_prefix}__{column}`` aliases.
+    Preserves the left-side projection and appends only columns from the
+    joined target as ``{projection_prefix}__{column}`` aliases.
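+
+    For example, joining ``orders`` with ``projection_prefix="o"`` leaves the
+    left-side projection untouched and appends ``o__order_id``, ``o__amount``, etc.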
 
     ``allow_existing_target_projection`` is used for idempotent re-joins: when a
     join call contributes no new join edges, all target-prefixed columns may already
@@ -304,7 +306,6 @@ def _apply_join_projection(
         if expr.output_name not in {"", "*"}
     }
 
-    target_columns = schema.tables[target_table]["columns"]
     target_output_names = {
         f"{projection_prefix}__{column_name}" for column_name in target_columns.keys()
    }
@@ -376,11 +377,108 @@ def _apply_join(
 
     _apply_join_projection(
         query,
-        schema=schema,
         left_table=left_table,
-        target_table=right_table,
+        target_columns=schema.tables[right_table]["columns"],
         target_qualifier=target_qualifier,
         projection_prefix=projection_prefix,
         allow_existing_target_projection=not join_params,
     )
 
     return query
+
+
+def _rewrite_on_qualifiers(
+    on_expr: sge.Expression,
+    target_table: str,
+    internal_alias: str,
+) -> sge.Expression:
+    """Rewrite column qualifiers in the ON expression that reference the target table.
+
+    The user writes ``on="users.id = orders.user_id"`` using logical table names.
+    Once the target is aliased internally, those references must point to the alias
+    so the SQL engine can resolve them.
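+
+    For example, with ``target_table="users"`` and ``internal_alias="_dlt_jt_users"``,
+    the predicate ``users.id = orders.user_id`` is rewritten to
+    ``_dlt_jt_users.id = orders.user_id``.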
+    """
+    on_expr = on_expr.copy()
+    for col in on_expr.find_all(sge.Column):
+        table_node = col.args.get("table")
+        if isinstance(table_node, sge.Identifier) and table_node.name == target_table:
+            table_node.set("this", internal_alias)
+    return on_expr
+
+
+def _apply_explicit_join(
+    expression: sge.Query,
+    *,
+    target: Optional["Relation"] = None,
+    target_table: str,
+    target_dataset_name: Optional[str],
+    target_columns: TTableSchemaColumns,
+    on: Union[str, sge.Expression],
+    projection_prefix: str,
+    kind: "TJoinType",
+    destination_dialect: TSqlGlotDialect,
+) -> sge.Select:
+    """Apply an explicit-ON join to ``expression`` and return the new query.
+
+    Args:
+        expression: Left-side query to join onto.
+        target: Right-hand Relation object (if transformed/subquery), or None for
+            string / base-table targets.
+        target_table: Bare table name for schema lookups and projection.
+        target_dataset_name: Foreign dataset qualifier, or None for local.
+        target_columns: Columns from the right-hand side for projection.
+        on: Join condition as a SQL string or sqlglot expression.
+        projection_prefix: Prefix for appended column aliases.
+        kind: SQL join type.
+        destination_dialect: Dialect for parsing string ON expressions.
+    """
+    query = expression.copy()
+    if not isinstance(query, sge.Select):
+        raise ValueError(f"Join query `{query}` must be an SQL SELECT statement.")
+
+    internal_alias = f"_dlt_jt_{projection_prefix}"
+
+    # build target expression
+    target_expr: sge.Expression
+    if target is not None:
+        # transformed Relation -> subquery (preserves WHERE, SELECT, etc.)
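+        # e.g. ds.table("orders").where("amount", "gt", 50.0) joins as
+        #   JOIN (SELECT ... FROM orders WHERE amount > 50.0) AS _dlt_jt_orders ON ...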
+        target_expr = sge.Subquery(
+            this=target.sqlglot_expression,
+            alias=sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)),
+        )
+    else:
+        # base-table target (Relation with _table_name, or str)
+        table_node_args: dict[str, sge.Expression] = {
+            "this": sge.to_identifier(target_table, quoted=True),
+            "alias": sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)),
+        }
+        if target_dataset_name:
+            table_node_args["db"] = sge.to_identifier(target_dataset_name, quoted=False)
+        target_expr = sge.Table(**table_node_args)
+
+    if isinstance(on, str):
+        on_expr = sqlglot.parse_one(on, dialect=destination_dialect)
+    else:
+        on_expr = on
+
+    on_expr = _rewrite_on_qualifiers(on_expr, target_table, internal_alias)
+
+    join_expr = sge.Join(this=target_expr, kind=kind.upper()).on(on_expr)
+    query = query.join(join_expr)
+
+    from_expr = query.args.get("from_") or query.args.get("from")
+    if not isinstance(from_expr, sge.From) or not isinstance(from_expr.this, sge.Table):
+        raise ValueError(
+            "Cannot apply explicit join: left-side query must have a base table "
+            "in its FROM clause (not a subquery or derived table)."
+        )
+    left_table = from_expr.this.this.name
+
+    _apply_join_projection(
+        query,
+        left_table=left_table,
+        target_columns=target_columns,
+        target_qualifier=internal_alias,
+        projection_prefix=projection_prefix,
+        allow_existing_target_projection=False,
+    )
+    return query
diff --git a/dlt/dataset/dataset.py b/dlt/dataset/dataset.py
index f336a4e128..71494efeb0 100644
--- a/dlt/dataset/dataset.py
+++ b/dlt/dataset/dataset.py
@@ -69,6 +69,8 @@ def __init__(
         self._default_schema_name: Optional[str] = None
         self._resolved: bool = False
 
+        self._foreign_schemas: Dict[str, List[dlt.Schema]] = {}
+
         self._sql_client: SqlClientBase[Any] = None
         self._opened_sql_client: SqlClientBase[Any] = None
         self._table_client: SupportsOpenTables = None
@@ -158,14 +160,28 @@ def _ipython_key_completions_(self) -> list[str]:
         """Provide table names as completion suggestion in interactive environments."""
         return self.tables
 
+    def _is_same_dataset(self, other: dlt.Dataset) -> bool:
+        """Whether `other` represents the same logical dataset."""
+        # TODO: currently only compares dataset names;
+        # once hardened, consider implementing __eq__ based on this method
+        return self.dataset_name == other.dataset_name
+
+    def _add_foreign_schemas(self, dataset_name: str, schemas: Sequence[dlt.Schema]) -> None:
+        """Register schemas from a foreign dataset for cross-dataset joins."""
+        if dataset_name == self.dataset_name:
+            return
+        self._foreign_schemas[dataset_name] = list(schemas)
+
     @property
     def sqlglot_schema(self) -> SQLGlotSchema:
         """SQLGlot schema of the dataset derived from all dlt schemas."""
         # NOTE: no cache for now, it is probably more expensive to compute the current schema hash
         # to see whether this is stale than to compute a new sqlglot schema
-        return lineage.create_sqlglot_schema(
-            {self.dataset_name: list(self.schemas)}, dialect=self.destination_dialect
-        )
+        schema_map: Dict[str, Sequence[dlt.Schema]] = {
+            self.dataset_name: list(self.schemas),
+            **self._foreign_schemas,
+        }
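+        # e.g. {"crm_data": [...], "inv_data": [...]} once a cross-dataset join
+        # has registered the foreign dataset's schemas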
+        return lineage.create_sqlglot_schema(schema_map, dialect=self.destination_dialect)
 
     @property
     def destination_dialect(self) -> TSqlGlotDialect:
diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py
index 102bf42307..9d6b4e1037 100644
--- a/dlt/dataset/relation.py
+++ b/dlt/dataset/relation.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 from collections.abc import Collection, Sequence
-from functools import partial
 from typing import (
     overload,
     Union,
@@ -35,10 +34,10 @@
 from dlt.common.typing import Self, TSortOrder, TypedDict
 from dlt.common.exceptions import ValueErrorWithKnownValues
 from dlt.dataset import lineage
-from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
+from dlt.destinations.sql_client import SqlClientBase, WithSchemas, WithSqlClient
 from dlt.destinations.queries import bind_query, build_select_expr
 from dlt.common.destination.dataset import SupportsDataAccess
-from dlt.dataset._join import _apply_join
+from dlt.dataset._join import _apply_join, _apply_explicit_join
 
 
 if TYPE_CHECKING:
@@ -254,14 +253,16 @@ def to_sql(self, pretty: bool = False, *, _raw_query: bool = False) -> str:
             query = self.sqlglot_expression
         else:
             _, _qualified_query = _get_relation_output_columns_schema(self)
+
+            def _expand(table_name: str, db: Optional[str] = None) -> list[str]:
+                return self.sql_client.make_qualified_table_name_path(
+                    table_name, quote=False, casefold=False, dataset_name=db
+                )
+
             query = bind_query(
                 qualified_query=_qualified_query,
                 sqlglot_schema=self._dataset.sqlglot_schema,
-                expand_table_name=partial(
-                    self.sql_client.make_qualified_table_name_path,
-                    quote=False,
-                    casefold=False,
-                ),
+                expand_table_name=_expand,
                 casefold_identifier=self.sql_client.capabilities.casefold_identifier,
             )
 
@@ -358,33 +359,46 @@ def order_by(self, column_name: str, direction: TSortOrder = "asc") -> Self:
         rel._sqlglot_expression = rel.sqlglot_expression.order_by(order_expr)
         return rel
 
+    @overload
     def join(
         self,
         other: str | Self,
         *,
         kind: TJoinType = "inner",
         alias: Optional[str] = None,
-    ) -> Self:
-        """Join this relation to another table using dlt schema references.
+    ) -> Self: ...
 
-        Join conditions are discovered automatically from the schema's reference
-        chain (parent/child/root relationships created by dlt during loading).
-        Both the current relation and ``other`` must be base-table relations
-        (i.e., created via ``dataset[table_name]``, not transformed with
-        ``.select()``/``.where()`` etc.).
+    @overload
+    def join(
+        self,
+        other: str | Self,
+        on: str | sge.Expression,
+        *,
+        kind: TJoinType = "inner",
+        alias: Optional[str] = None,
+    ) -> Self: ...
 
-        This method is designed for the common case of navigating dlt's
-        built-in table hierarchy. For more complex join scenarios — such as
-        custom join predicates, joining on non-reference columns, self-joins,
-        or multi-way joins with mixed conditions — use ``Relation.to_ibis()``
-        to obtain an ibis table expression and construct the join manually::
+    def join(
+        self,
+        other: str | Self,
+        on: str | sge.Expression | None = None,
+        *,
+        kind: TJoinType = "inner",
+        alias: Optional[str] = None,
+    ) -> Self:
+        """Join this relation to another table.
 
-            t1 = dataset["orders"].to_ibis()
-            t2 = dataset["products"].to_ibis()
-            joined = t1.join(t2, t1.product_id == t2.id, how="left")
+        Without ``on``, join conditions are discovered automatically from the
+        schema's reference chain (parent/child/root relationships created by
+        dlt during loading). With ``on``, an explicit join predicate is used
+        instead — this also enables cross-dataset joins.
 
         Args:
-            other: Table name or base-table relation to join.
+            other: Table name or Relation to join. For cross-dataset joins,
+                pass a Relation from a different ``dlt.Dataset``.
+            on: Explicit join condition as an SQL string or sqlglot expression.
+                Required for cross-dataset joins and joins between tables
+                without dlt schema references.
             kind: Type of SQL join: ``"inner"``, ``"left"``, ``"right"``,
                 or ``"full"``.
             alias: Projection prefix for the joined table's columns. Columns
@@ -392,53 +406,152 @@ def join(
                 the target table name.
 
         Returns:
-            A new relation with the join(s) applied and the target table's
+            A new relation with the join applied and the target table's
             columns appended to the projection.
 
         Raises:
-            ValueError: If schema references between the two tables cannot be
-                resolved, or if either relation is not join-eligible.
+            ValueError: If the join cannot be resolved.
+
+        Example::
+
+            # auto join (schema references)
+            dataset["orders"].join("users")
+
+            # explicit ON
+            dataset["orders"].join("users", on="orders._dlt_parent_id = users._dlt_id")
+
+            # cross-dataset join
+            local["orders"].join(
+                foreign["products"],
+                on="orders.product_id = products.id",
+            )
         """
         if alias == "":
             raise ValueError("`alias` must be a non-empty string when provided.")
 
-        if not self._table_name:
-            raise ValueError("This relation has no base table to resolve references.")
+        target_dataset, target_table, target_columns = self._resolve_join_target(other, on=on)
 
-        if isinstance(other, dlt.Relation):
-            # TODO: remove once we allow cross-dataset joins
-            if not (
-                self._dataset.is_same_physical_destination(other._dataset)
-                and self._dataset.dataset_name == other._dataset.dataset_name
-            ):
-                raise ValueError(
-                    "Cannot join relations from different datasets: "
-                    f"'{other._dataset.dataset_name}' vs '{self._dataset.dataset_name}'"
-                )
-            target_table = other._table_name
-            if not target_table:
-                raise ValueError(f"Relation `{other}` has no base table to resolve references.")
-        else:
-            target_table = other
+        is_same_dataset = self._dataset._is_same_dataset(target_dataset)
 
-        if not target_table or not isinstance(target_table, str):
-            raise ValueError("`other` must be a table name or a base table relation.")
-        if target_table not in self._dataset.schema.tables:
-            raise ValueError(f"Table `{target_table}` not found in dataset schema")
+        # self-join detection
+        if target_table == self._table_name and is_same_dataset:
+            raise ValueError("Self-joins are not supported.")
 
         projection_prefix = alias or target_table
-        query = _apply_join(
-            self.sqlglot_expression,
-            schema=self._dataset.schema,
-            left_table=self._table_name,
-            right_table=target_table,
-            projection_prefix=projection_prefix,
-            kind=kind,
-        )
+
+        if on is None:
+            if not self._table_name:
+                raise ValueError("This relation has no base table to resolve references.")
+            if not is_same_dataset:
+                raise ValueError("`on` is required when joining relations from different datasets.")
+            if target_table not in self._dataset.schema.tables:
+                raise ValueError(f"Table `{target_table}` not found in dataset schema")
+            query = _apply_join(
+                self.sqlglot_expression,
+                schema=self._dataset.schema,
+                left_table=self._table_name,
+                right_table=target_table,
+                projection_prefix=projection_prefix,
+                kind=kind,
+            )
+        else:
+            if not is_same_dataset:
+                self._dataset._add_foreign_schemas(
+                    target_dataset.dataset_name,
+                    list(target_dataset.schemas),
+                )
+            # pass Relation as target when it's been transformed so it
+            # becomes a subquery (preserving WHERE, SELECT, LIMIT, etc.)
+            subquery_rhs: Optional[Relation] = (
+                other if isinstance(other, dlt.Relation) and other._query is not None else None
+            )
+            query = _apply_explicit_join(
+                self.sqlglot_expression,
+                target=subquery_rhs,
+                target_table=target_table,
+                target_dataset_name=(None if is_same_dataset else target_dataset.dataset_name),
+                target_columns=target_columns,
+                on=on,
+                projection_prefix=projection_prefix,
+                kind=kind,
+                destination_dialect=self.destination_dialect,
+            )
 
         rel = self.__copy__()
         rel._sqlglot_expression = query
         return rel
 
+    def _resolve_join_target(
+        self,
+        other: Union[str, Self],
+        *,
+        on: Union[str, sge.Expression, None] = None,
+    ) -> tuple[dlt.Dataset, str, TTableSchemaColumns]:
+        """Resolve the target dataset, table name, and columns for a join.
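+
+        ``other`` may be a Relation, a bare table name, or a ``"dataset.table"``
+        string; e.g. ``"inv_data.purchases"`` resolves against the foreign schemas
+        previously registered for dataset ``inv_data``.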
+
+        Returns:
+            Tuple of (target_dataset, target_table_name, target_columns).
+        """
+        if isinstance(other, dlt.Relation):
+            target_dataset = other._dataset
+
+            # physical destination check
+            if not self._dataset._is_same_dataset(target_dataset):
+                if not self._dataset.is_same_physical_destination(target_dataset):
+                    raise ValueError(
+                        "Cannot join relations from different physical destinations: "
+                        f"'{target_dataset.dataset_name}' vs '{self._dataset.dataset_name}'"
+                    )
+                # cross-dataset filesystem not supported
+                if isinstance(self.sql_client, WithSchemas):
+                    raise ValueError(
+                        "Cross-dataset joins are not supported on filesystem destinations."
+                    )
+
+            target_table = other._table_name
+            is_transformed = other._query is not None
+            if target_table and not is_transformed:
+                # pristine base-table Relation: look up columns from schema
+                target_columns = _find_table_columns(target_dataset.schemas, target_table)
+            elif target_table and is_transformed:
+                # transformed Relation that still tracks its origin table
+                # (e.g., .where(), .select()); use its actual output columns
+                target_columns = other.columns_schema
+            else:
+                # no base table at all (e.g., from .query())
+                if on is None:
+                    raise ValueError(f"Relation `{other}` has no base table to resolve references.")
+                target_table = _extract_subquery_alias(other)
+                target_columns = other.columns_schema
+        elif isinstance(other, str):
+            if "." in other:
+                ds_name, tbl_name = other.split(".", 1)
+                if ds_name == self._dataset.dataset_name:
+                    target_dataset = self._dataset
+                elif ds_name in self._dataset._foreign_schemas:
+                    target_dataset = self._dataset
+                    # columns come from the foreign schemas already registered
+                    target_table = tbl_name
+                    target_columns = _find_table_columns(
+                        self._dataset._foreign_schemas[ds_name], tbl_name
+                    )
+                    return target_dataset, target_table, target_columns
+                else:
+                    raise ValueError(
+                        f"Dataset `{ds_name}` is not registered. Pass a Relation from the "
+                        "foreign dataset to automatically register its schema."
+                    )
+                target_table = tbl_name
+                target_columns = _find_table_columns(target_dataset.schemas, target_table)
+            else:
+                target_dataset = self._dataset
+                target_table = other
+                target_columns = _find_table_columns(target_dataset.schemas, target_table)
+        else:
+            raise ValueError("`other` must be a table name or a base table relation.")
+
+        return target_dataset, target_table, target_columns
+
     # NOTE we currently force to have one column selected; we could be more flexible
     # and rewrite the query to compute the AGG of all selected columns
     # `SELECT AGG(col1), AGG(col2), ... FROM table``
@@ -877,3 +990,22 @@ def _add_load_id_via_parent_key(relation: dlt.Relation) -> dlt.Relation:
     rel = relation.__copy__()
     rel._sqlglot_expression = query
     return rel
+
+
+def _find_table_columns(schemas: Sequence[dlt.Schema], table_name: str) -> TTableSchemaColumns:
+    """Find the columns schema for a table across a sequence of schemas."""
+    for schema in schemas:
+        if table_name in schema.tables:
+            return schema.tables[table_name]["columns"]
+    raise ValueError(f"Table `{table_name}` not found in dataset schema")
+
+
+def _extract_subquery_alias(relation: dlt.Relation) -> str:
+    """Extract a stable alias for a transformed Relation without a base table."""
+    expr = relation.sqlglot_expression
+    from_expr = expr.args.get("from_") or expr.args.get("from")
+    if isinstance(from_expr, sge.From) and isinstance(from_expr.this, sge.Table):
+        table_id = from_expr.this.this
+        if isinstance(table_id, sge.Identifier):
+            return table_id.name
+    return "subquery"
diff --git a/dlt/destinations/impl/clickhouse/sql_client.py b/dlt/destinations/impl/clickhouse/sql_client.py
index d94611aa8f..9e264e11e8 100644
--- a/dlt/destinations/impl/clickhouse/sql_client.py
+++ b/dlt/destinations/impl/clickhouse/sql_client.py
@@ -287,14 +287,21 @@ def catalog_name(self, quote: bool = True, casefold: bool = True) -> Optional[st
         return database_name
 
     def make_qualified_table_name_path(
-        self, table_name: Optional[str], quote: bool = True, casefold: bool = True
+        self,
+        table_name: Optional[str],
+        quote: bool = True,
+        casefold: bool = True,
+        dataset_name: Optional[str] = None,
     ) -> List[str]:
         # get catalog and dataset
-        path = super().make_qualified_table_name_path(None, quote=quote, casefold=casefold)
+        path = super().make_qualified_table_name_path(
+            None, quote=quote, casefold=casefold, dataset_name=dataset_name
+        )
+        effective_dataset = dataset_name or self.dataset_name
         if table_name:
             # table name combines dataset name and table name
-            if self.dataset_name:
-                table_name = f"{self.dataset_name}{self.config.dataset_table_separator}{table_name}"
+            if effective_dataset:
+                table_name = f"{effective_dataset}{self.config.dataset_table_separator}{table_name}"
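+                # e.g. dataset "inv_data" and table "purchases" render as
+                # "inv_data___purchases" (with the default "___" separator)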
         else:
             # without dataset just use the table name
             pass
diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py
index 1e924d4e65..5119bc51ee 100644
--- a/dlt/destinations/impl/sqlalchemy/db_api_client.py
+++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py
@@ -337,19 +337,23 @@ def create_table(self, table_obj: sa.Table) -> None:
         table_obj.create(self._current_connection)
 
     def make_qualified_table_name_path(
-        self, table_name: Optional[str], quote: bool = True, casefold: bool = True
+        self,
+        table_name: Optional[str],
+        quote: bool = True,
+        casefold: bool = True,
+        dataset_name: Optional[str] = None,
     ) -> List[str]:
         path: List[str] = []
         # no catalog for sqlalchemy
         if catalog_name := self.catalog_name(quote=quote, casefold=casefold):
             path.append(catalog_name)
-        dataset_name = self.dataset_name
+        effective_dataset = dataset_name or self.dataset_name
         if self.dialect.requires_name_normalize and casefold:  # type: ignore[attr-defined]
-            dataset_name = str(self.dialect.normalize_name(dataset_name))  # type: ignore[func-returns-value]
+            effective_dataset = str(self.dialect.normalize_name(effective_dataset))  # type: ignore[func-returns-value]
         if quote:
-            dataset_name = self.dialect.identifier_preparer.quote_identifier(dataset_name)  # type: ignore[attr-defined]
-        path.append(dataset_name)
+            effective_dataset = self.dialect.identifier_preparer.quote_identifier(effective_dataset)  # type: ignore[attr-defined]
+        path.append(effective_dataset)
         if table_name:
             if self.dialect.requires_name_normalize and casefold:  # type: ignore[attr-defined]
                 table_name = str(self.dialect.normalize_name(table_name))  # type: ignore[func-returns-value]
diff --git a/dlt/destinations/queries.py b/dlt/destinations/queries.py
index e171e2a5e8..c3ef72afb5 100644
--- a/dlt/destinations/queries.py
+++ b/dlt/destinations/queries.py
@@ -1,5 +1,4 @@
-from functools import partial
-from typing import Any, List
+from typing import Any, List, Optional
 
 import sqlglot.expressions as sge
 from sqlglot.schema import Schema as SQLGlotSchema
@@ -20,12 +19,16 @@ def _normalize_query(
 
     TODO: remove after next dlthub release
     """
+
+    def _expand(table_name: str, db: Optional[str] = None) -> List[str]:
+        return sql_client.make_qualified_table_name_path(
+            table_name, quote=False, casefold=False, dataset_name=db
+        )
+
     return bind_query(
         qualified_query,
         sqlglot_schema,
-        expand_table_name=partial(
-            sql_client.make_qualified_table_name_path, quote=False, casefold=False
-        ),
+        expand_table_name=_expand,
         casefold_identifier=casefold_identifier,
     )
diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py
index c5198724b6..c77bb5f324 100644
--- a/dlt/destinations/sql_client.py
+++ b/dlt/destinations/sql_client.py
@@ -225,20 +225,27 @@ def make_qualified_table_name(
 
     # TODO make it a staticmethod to avoid passing SQLClient instances all around
     def make_qualified_table_name_path(
-        self, table_name: Optional[str], quote: bool = True, casefold: bool = True
+        self,
+        table_name: Optional[str],
+        quote: bool = True,
+        casefold: bool = True,
+        dataset_name: Optional[str] = None,
     ) -> List[str]:
         """Returns a list with path components leading from catalog to table_name. Used to
         construct fully qualified names. `table_name` is optional.
+
+        Args:
+            dataset_name: Override the default dataset name for cross-dataset references.
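+                Illustrative: ``make_qualified_table_name_path("users", dataset_name="other_ds")``
+                returns ``[<catalog>, "other_ds", "users"]`` instead of using this
+                client's own dataset name.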
""" path: List[str] = [] if catalog_name := self.catalog_name(quote=quote, casefold=casefold): path.append(catalog_name) - dataset_name = self.dataset_name + effective_dataset = dataset_name or self.dataset_name if casefold: - dataset_name = self.capabilities.casefold_identifier(self.dataset_name) + effective_dataset = self.capabilities.casefold_identifier(effective_dataset) if quote: - dataset_name = self.capabilities.escape_identifier(dataset_name) - path.append(dataset_name) + effective_dataset = self.capabilities.escape_identifier(effective_dataset) + path.append(effective_dataset) if table_name: if casefold: table_name = self.capabilities.casefold_identifier(table_name) diff --git a/tests/dataset/conftest.py b/tests/dataset/conftest.py index e2c34178f0..ec9ad4cf66 100644 --- a/tests/dataset/conftest.py +++ b/tests/dataset/conftest.py @@ -8,9 +8,12 @@ from tests.dataset.utils import ( LOAD_0_STATS, LOAD_1_STATS, + TCrossDsFixture, TLoadsFixture, annotated_references, crm, + inventory, + relational_tables, ) from tests.utils import ( auto_test_run_context, @@ -86,6 +89,47 @@ def dataset_with_loads( raise ValueError(f"Unknown dataset fixture: {request.param}") +@pytest.fixture(scope="module") +def dataset_with_relational_tables(module_tmp_path: pathlib.Path) -> dlt.Dataset: + pipeline = dlt.pipeline( + pipeline_name="relational_tables", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(str(module_tmp_path / "relational.db")), + dev_mode=True, + ) + pipeline.run(relational_tables()) + return pipeline.dataset() + + +@pytest.fixture(scope="module") +def cross_dataset_duckdb(module_tmp_path: pathlib.Path) -> TCrossDsFixture: + db_path = str(module_tmp_path / "cross_dataset.db") + + # dataset A: CRM data (users + orders) + pipeline_a = dlt.pipeline( + pipeline_name="cross_ds_a", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="crm_data", + dev_mode=True, + ) + source_a = crm(0) + source_a.root_key = True + pipeline_a.run(source_a) + + # dataset B: inventory data (products + warehouses) + pipeline_b = dlt.pipeline( + pipeline_name="cross_ds_b", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="inv_data", + dev_mode=True, + ) + pipeline_b.run(inventory()) + + return pipeline_a.dataset(), pipeline_b.dataset() + + @pytest.fixture(scope="module") def dataset_with_annotated_references(module_tmp_path: pathlib.Path) -> dlt.Dataset: pipeline = dlt.pipeline( diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index d83c0ef39c..4747895e0f 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -13,7 +13,7 @@ _to_join_ref, ) from dlt.dataset.relation import TJoinType -from tests.dataset.utils import TLoadsFixture +from tests.dataset.utils import TCrossDsFixture, TLoadsFixture class _ColumnRef(TypedDict): @@ -208,8 +208,7 @@ def test_resolve_reference_chain_rejects_self_join(dataset_with_loads: TLoadsFix @pytest.mark.parametrize("dataset_with_loads", ["with_root_key"], indirect=True) -def test_join_rejects_cross_dataset(dataset_with_loads: TLoadsFixture) -> None: - """Test that joining relations from different datasets raises an error.""" +def test_join_rejects_different_physical_destination(dataset_with_loads: TLoadsFixture) -> None: dataset, _, _ = dataset_with_loads with tempfile.TemporaryDirectory() as tmp: @@ -227,12 +226,11 @@ def other_data(): 
         pipeline.run([other_data])
         other_dataset = pipeline.dataset()
 
-        # Try to join with a relation from the other dataset
         rel = dataset.table("users")
         other_rel = other_dataset.table("other_data")
 
-        with pytest.raises(ValueError, match="different datasets"):
-            rel.join(other_rel)
+        with pytest.raises(ValueError, match="different physical destinations"):
+            rel.join(other_rel, on="users._dlt_id = other_data._dlt_id")
@@ -300,7 +298,7 @@ def test_resolve_reference_chain_rejects_unrelated_tables(
         pytest.param(
             lambda ds: ds.table("users"),
             "users",
-            "Cannot join a table to itself",
+            "Self-joins are not supported",
             id="self-join",
         ),
         pytest.param(
@@ -920,3 +918,202 @@ def test_join_columns_schema_resolves_with_name_mutating_normalizer(
         for column_name in normalized_dataset.schema.tables[normalized_right]["columns"].keys()
     }
     assert expected_right_aliases.issubset(schema_cols)
+
+
+def test_explicit_on_joins_relational_tables(
+    dataset_with_relational_tables: dlt.Dataset,
+) -> None:
+    ds = dataset_with_relational_tables
+    joined = ds.table("customers").join("orders", on="customers.customer_id = orders.customer_id")
+    df = joined.df()
+    assert len(df) == 4
+    assert "orders__amount" in df.columns
+    assert list(df["orders__amount"]) == [50.0, 75.0, 200.0, 30.0]
+
+    # auto join should fail: no dlt reference between customers and orders
+    with pytest.raises(ValueError, match="Unable to resolve reference chain"):
+        ds.table("customers").join("orders")
+
+
+def test_explicit_on_accepts_sqlglot_expression(
+    dataset_with_relational_tables: dlt.Dataset,
+) -> None:
+    ds = dataset_with_relational_tables
+    on_expr = sge.EQ(
+        this=sge.Column(
+            table=sge.to_identifier("customers"),
+            this=sge.to_identifier("country_code"),
+        ),
+        expression=sge.Column(
+            table=sge.to_identifier("countries"),
+            this=sge.to_identifier("code"),
+        ),
+    )
+    joined = ds.table("customers").join("countries", on=on_expr)
+    df = joined.df()
+    assert len(df) == 3
+    assert list(df["countries__name"]) == ["Germany", "France", "Germany"]
+
+
+def test_explicit_on_non_eq_predicate(
+    dataset_with_relational_tables: dlt.Dataset,
+) -> None:
+    ds = dataset_with_relational_tables
+    joined = ds.table("customers").join(
+        "orders",
+        on="customers.customer_id = orders.customer_id AND orders.amount > 50",
+    )
+    df = joined.df()
+    assert len(df) == 2
+    assert list(df["orders__amount"]) == [75.0, 200.0]
+
+
+def test_explicit_on_projection_prefix(
+    dataset_with_relational_tables: dlt.Dataset,
+) -> None:
+    ds = dataset_with_relational_tables
+    joined = ds.table("customers").join(
+        "orders", on="customers.customer_id = orders.customer_id", alias="o"
+    )
+    selects = joined.sqlglot_expression.selects
+    right_aliases = {expr.output_name for expr in selects if expr.output_name.startswith("o__")}
+    assert right_aliases
+    expected = {f"o__{col}" for col in ds.schema.tables["orders"]["columns"].keys()}
+    assert right_aliases == expected
+
+
+def test_explicit_on_rejects_empty_alias(
+    dataset_with_relational_tables: dlt.Dataset,
+) -> None:
+    ds = dataset_with_relational_tables
+    with pytest.raises(ValueError, match="must be a non-empty string"):
+        ds.table("customers").join(
+            "orders", on="customers.customer_id = orders.customer_id", alias=""
+        )
+
+
+def test_explicit_on_rejects_self_join(
+    dataset_with_relational_tables: dlt.Dataset,
+) -> None:
+    ds = dataset_with_relational_tables
+    with pytest.raises(ValueError, match="Self-joins are not supported"):
+        ds.table("customers").join(
+            "customers",
on="customers.customer_id = customers.customer_id", + alias="c2", + ) + + +def test_explicit_on_with_filtered_rhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + expensive_orders = ds.table("orders").where("amount", "gt", 50.0) + joined = ds.table("customers").join( + expensive_orders, on="customers.customer_id = orders.customer_id" + ) + df = joined.df() + assert len(df) == 2 + assert list(df["name"]) == ["Alice", "Bob"] + assert list(df["orders__amount"]) == [75.0, 200.0] + + +def test_explicit_on_with_projected_rhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + narrow_orders = ds.table("orders").select("order_id", "customer_id") + joined = ds.table("customers").join( + narrow_orders, on="customers.customer_id = orders.customer_id" + ) + df = joined.df() + assert len(df) == 4 + rhs_cols = {c for c in df.columns if c.startswith("orders__")} + assert rhs_cols == {"orders__order_id", "orders__customer_id"} + assert "orders__amount" not in df.columns + + +def test_cross_dataset_join_registers_foreign_schemas( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + """Cross-dataset join registers the foreign dataset's schemas.""" + ds_a, ds_b = cross_dataset_duckdb + users = ds_a.table("users") + purchases = ds_b.table("purchases") + + assert ds_b.dataset_name not in ds_a._foreign_schemas + + users.join(purchases, on="users.id = purchases.user_id") + + assert ds_b.dataset_name in ds_a._foreign_schemas + foreign_schemas = ds_a._foreign_schemas[ds_b.dataset_name] + assert len(foreign_schemas) >= 1 + + +def test_cross_dataset_join_requires_on( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_a, ds_b = cross_dataset_duckdb + users = ds_a.table("users") + purchases = ds_b.table("purchases") + + with pytest.raises(ValueError, match="`on` is required"): + users.join(purchases) + + +def test_cross_dataset_join_e2e( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_a, ds_b = cross_dataset_duckdb + users = ds_a.table("users") + purchases = ds_b.table("purchases") + + joined = users.join(purchases, on="users.id = purchases.user_id") + df = joined.df() + assert len(df) == 3 + assert "purchases__sku" in df.columns + assert "purchases__quantity" in df.columns + assert sorted(df["purchases__sku"]) == ["G-001", "W-001", "W-001"] + + +_MATCHED = { + "purchases__purchase_id": [1, 2, 3], + "purchases__user_id": [1, 1, 2], + "purchases__sku": ["W-001", "G-001", "W-001"], + "purchases__quantity": [2, 1, 1], + "name": ["Alice", "Alice", "Bob"], +} +_MATCHED_PLUS_ORPHAN = { + "purchases__purchase_id": [1, 2, 3, 4], + "purchases__user_id": [1, 1, 2, 99], + "purchases__sku": ["W-001", "G-001", "W-001", "D-001"], + "purchases__quantity": [2, 1, 1, 5], + "name": ["Alice", "Alice", "Bob", None], # orphan's matched user name is NULL +} + + +@pytest.mark.parametrize( + "kind,expected", + [ + # inner + left: both users match, so LEFT adds no extra rows + pytest.param("inner", _MATCHED, id="inner"), + pytest.param("left", _MATCHED, id="left"), + # right + full: orphan purchase appears with NULL on the user side + pytest.param("right", _MATCHED_PLUS_ORPHAN, id="right"), + pytest.param("full", _MATCHED_PLUS_ORPHAN, id="full"), + ], +) +def test_cross_dataset_join_kind_parameter( + cross_dataset_duckdb: TCrossDsFixture, + kind: TJoinType, + expected: dict[str, list[Any]], +) -> None: + ds_a, ds_b = cross_dataset_duckdb + users = ds_a.table("users") + purchases = ds_b.table("purchases") + + joined = 
+    df = joined.df()
+
+    for col, expected_values in expected.items():
+        assert list(df[col]) == expected_values, f"column `{col}` mismatch"
diff --git a/tests/dataset/utils.py b/tests/dataset/utils.py
index f866289a3d..46fdf242bb 100644
--- a/tests/dataset/utils.py
+++ b/tests/dataset/utils.py
@@ -51,8 +51,44 @@ class AccountMembershipRow(TypedDict):
     user_name: str
 
 
+class WarehouseRow(TypedDict):
+    warehouse_id: int
+    city: str
+
+
+class InventoryItemRow(TypedDict):
+    sku: str
+    warehouse_id: int
+    quantity: int
+
+
+class PurchaseRow(TypedDict):
+    purchase_id: int
+    user_id: int
+    sku: str
+    quantity: int
+
+
+class CustomerRow(TypedDict):
+    customer_id: int
+    name: str
+    country_code: str
+
+
+class CustomerOrderRow(TypedDict):
+    order_id: int
+    customer_id: int
+    amount: float
+
+
+class CountryRow(TypedDict):
+    code: str
+    name: str
+
+
 TLoadStats = dict[str, int]
 TLoadsFixture = tuple[dlt.Dataset, tuple[str, str], tuple[TLoadStats, TLoadStats]]
+TCrossDsFixture = tuple[dlt.Dataset, dlt.Dataset]
 
 
 USERS_DATA_0: list[UserRow] = [
@@ -144,6 +180,88 @@ def products(batch_idx: int):
     return [users(i), products(i)]
 
 
+WAREHOUSES: list[WarehouseRow] = [
+    {"warehouse_id": 1, "city": "Berlin"},
+    {"warehouse_id": 2, "city": "Paris"},
+]
+
+INVENTORY_ITEMS: list[InventoryItemRow] = [
+    {"sku": "W-001", "warehouse_id": 1, "quantity": 50},
+    {"sku": "G-001", "warehouse_id": 2, "quantity": 30},
+    {"sku": "D-001", "warehouse_id": 1, "quantity": 10},
+]
+
+PURCHASES: list[PurchaseRow] = [
+    {"purchase_id": 1, "user_id": 1, "sku": "W-001", "quantity": 2},
+    {"purchase_id": 2, "user_id": 1, "sku": "G-001", "quantity": 1},
+    {"purchase_id": 3, "user_id": 2, "sku": "W-001", "quantity": 1},
+    {"purchase_id": 4, "user_id": 99, "sku": "D-001", "quantity": 5},
+]
+
+
+@dlt.source
+def inventory():
+    @dlt.resource(name="warehouses")
+    def warehouses():
+        yield WAREHOUSES
+
+    @dlt.resource(
+        name="inventory_items",
+        references=[
+            {
+                "referenced_table": "warehouses",
+                "columns": ["warehouse_id"],
+                "referenced_columns": ["warehouse_id"],
+            }
+        ],
+    )
+    def inventory_items():
+        yield INVENTORY_ITEMS
+
+    @dlt.resource(name="purchases")
+    def purchases():
+        yield PURCHASES
+
+    return [warehouses(), inventory_items(), purchases()]
+
+
+CUSTOMERS: list[CustomerRow] = [
+    {"customer_id": 1, "name": "Alice", "country_code": "DE"},
+    {"customer_id": 2, "name": "Bob", "country_code": "FR"},
+    {"customer_id": 3, "name": "Charlie", "country_code": "DE"},
+]
+
+CUSTOMER_ORDERS: list[CustomerOrderRow] = [
+    {"order_id": 100, "customer_id": 1, "amount": 50.0},
+    {"order_id": 101, "customer_id": 1, "amount": 75.0},
+    {"order_id": 102, "customer_id": 2, "amount": 200.0},
+    {"order_id": 103, "customer_id": 3, "amount": 30.0},
+]
+
+COUNTRIES: list[CountryRow] = [
+    {"code": "DE", "name": "Germany"},
+    {"code": "FR", "name": "France"},
+    {"code": "ES", "name": "Spain"},
+]
+
+
+@dlt.source
+def relational_tables():
+    @dlt.resource(name="customers")
+    def customers():
+        yield CUSTOMERS
+
+    @dlt.resource(name="orders")
+    def orders():
+        yield CUSTOMER_ORDERS
+
+    @dlt.resource(name="countries")
+    def countries():
+        yield COUNTRIES
+
+    return [customers(), orders(), countries()]
+
+
 @dlt.source
 def annotated_references():
     @dlt.resource(name="users")
diff --git a/tests/destinations/test_queries.py b/tests/destinations/test_queries.py
index a41c65e880..757e614701 100644
--- a/tests/destinations/test_queries.py
+++ b/tests/destinations/test_queries.py
@@ -1,5 +1,4 @@
-from functools import partial
-from typing import cast
+from typing import List, Optional, cast
 
 import duckdb
 import pytest
@@ -130,12 +129,16 @@ def test_normalize_query():
     )
 
     with duckdb_destination_client.sql_client as sql_client:
+
+        def _expand(table_name: str, db: Optional[str] = None) -> List[str]:
+            return sql_client.make_qualified_table_name_path(
+                table_name, quote=False, casefold=False, dataset_name=db
+            )
+
         normalized_query_expr = bind_query(
             qualified_query=cast(sge.Query, qualified_query_expr),
             sqlglot_schema=sqlglot_schema,
-            expand_table_name=partial(
-                sql_client.make_qualified_table_name_path, quote=False, casefold=False
-            ),
+            expand_table_name=_expand,
             casefold_identifier=sql_client.capabilities.casefold_identifier,
         )
         normalized_query = normalized_query_expr.sql()