Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions dlt/common/libs/sqlglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ def bind_query(
qualified_query: sge.Query,
sqlglot_schema: Any, # SQLGlotSchema
*,
expand_table_name: Callable[[str], List[str]],
expand_table_name: Callable[[str, Optional[str]], List[str]],
casefold_identifier: Callable[[str], str],
) -> sge.Query:
"""Binds a logical query (compliant with dlt schema) to physical tables in the destination dataset.
Expand All @@ -971,7 +971,9 @@ def bind_query(
Args:
qualified_query: SQLGlot query expression with qualified table/column references
sqlglot_schema: Schema mapping for name validation and column resolution
expand_table_name: Function that expands table name to fully qualified path [catalog, schema, table]
expand_table_name: Function ``(table_name, dataset_name | None) -> [catalog, schema, table]``
that expands a table name to a fully qualified path. The second argument is the
dataset qualifier from the query (``node.db``), or `None` for the default dataset.
casefold_identifier: Case transformation function (`str`, `str.upper`, or `str.lower`)

Returns:
Expand All @@ -993,7 +995,7 @@ def bind_query(
# expand named of known tables. this is currently clickhouse things where
# we use dataset.table in queries but render those as dataset___table
if sqlglot_schema.column_names(node):
expanded_path = expand_table_name(node.name)
expanded_path = expand_table_name(node.name, node.db or None)
# set the table name
if node.name != expanded_path[-1]:
node.this.set("this", expanded_path[-1])
Expand Down
118 changes: 108 additions & 10 deletions dlt/dataset/_join.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

from functools import reduce
from typing import TYPE_CHECKING, Optional, Sequence
from typing import TYPE_CHECKING, Optional, Sequence, Union

import sqlglot
import sqlglot.expressions as sge

from dlt.common.typing import TypedDict
from dlt.common.schema import Schema, utils as schema_utils
from dlt.common.schema.typing import TTableReference
from dlt.common.schema.typing import TTableReference, TTableSchemaColumns
from dlt.common.libs.sqlglot import TSqlGlotDialect

if TYPE_CHECKING:
from dlt.dataset.relation import TJoinType
from dlt.dataset.relation import Relation, TJoinType

_INTERMEDIATE_JOIN_ALIAS_PREFIX = "_dlt_int_t"

Expand Down Expand Up @@ -268,17 +271,16 @@ def _discover_join_params(
def _apply_join_projection(
query: sge.Select,
*,
schema: Schema,
left_table: str,
target_table: str,
target_columns: TTableSchemaColumns,
target_qualifier: str,
projection_prefix: str,
allow_existing_target_projection: bool,
) -> None:
"""Apply join projection contract onto ``query``.

Preserves the left-side projection and appends only columns from the explicitly
joined ``target_table`` as ``{projection_prefix}__{column}`` aliases.
Preserves the left-side projection and appends only columns from the
joined target as ``{projection_prefix}__{column}`` aliases.

``allow_existing_target_projection`` is used for idempotent re-joins: when a
join call contributes no new join edges, all target-prefixed columns may already
Expand All @@ -304,7 +306,6 @@ def _apply_join_projection(
if expr.output_name not in {"", "*"}
}

target_columns = schema.tables[target_table]["columns"]
target_output_names = {
f"{projection_prefix}__{column_name}" for column_name in target_columns.keys()
}
Expand Down Expand Up @@ -376,11 +377,108 @@ def _apply_join(

_apply_join_projection(
query,
schema=schema,
left_table=left_table,
target_table=right_table,
target_columns=schema.tables[right_table]["columns"],
target_qualifier=target_qualifier,
projection_prefix=projection_prefix,
allow_existing_target_projection=not join_params,
)
return query


def _rewrite_on_qualifiers(
    on_expr: sge.Expression,
    target_table: str,
    internal_alias: str,
) -> sge.Expression:
    """Redirect column qualifiers that name ``target_table`` to ``internal_alias``.

    Join conditions are written with logical table names, e.g.
    ``on="users.id = orders.user_id"``. Once the join target has been aliased
    internally, those qualifiers must be rewritten to the alias so the SQL
    engine can resolve them against the aliased relation.

    Returns a rewritten copy; the input expression is not mutated.
    """
    rewritten = on_expr.copy()
    for column in rewritten.find_all(sge.Column):
        qualifier = column.args.get("table")
        if isinstance(qualifier, sge.Identifier) and qualifier.name == target_table:
            qualifier.set("this", internal_alias)
    return rewritten


def _apply_explicit_join(
    expression: sge.Query,
    *,
    target: Optional["Relation"] = None,
    target_table: str,
    target_dataset_name: Optional[str],
    target_columns: TTableSchemaColumns,
    on: Union[str, sge.Expression],
    projection_prefix: str,
    kind: "TJoinType",
    destination_dialect: TSqlGlotDialect,
) -> sge.Select:
    """Apply an explicit-ON join to ``expression`` and return the new query.

    Args:
        expression: Left-side query to join onto.
        target: Right-hand Relation object (if transformed/subquery), or None for
            string / base-table targets.
        target_table: Bare table name for schema lookups and projection.
        target_dataset_name: Foreign dataset qualifier, or None for local.
        target_columns: Columns from the right-hand side for projection.
        on: Join condition as a SQL string or sqlglot expression.
        projection_prefix: Prefix for appended column aliases.
        kind: SQL join type.
        destination_dialect: Dialect for parsing string ON expressions.

    Raises:
        ValueError: If ``expression`` is not a SELECT, or its FROM clause is not
            a plain base table.
    """
    joined = expression.copy()
    if not isinstance(joined, sge.Select):
        raise ValueError(f"Join query `{joined}` must be an SQL SELECT statement.")

    alias = f"_dlt_jt_{projection_prefix}"

    # build the right-hand side of the join
    right_side: sge.Expression
    if target is not None:
        # a transformed Relation joins as a derived table, which preserves
        # its own SELECT list, WHERE clause, etc.
        right_side = sge.Subquery(
            this=target.sqlglot_expression,
            alias=sge.TableAlias(this=sge.to_identifier(alias, quoted=False)),
        )
    else:
        # base-table target (Relation with _table_name, or plain str),
        # optionally qualified with a foreign dataset name
        table_args: dict[str, sge.Expression] = {
            "this": sge.to_identifier(target_table, quoted=True),
            "alias": sge.TableAlias(this=sge.to_identifier(alias, quoted=False)),
        }
        if target_dataset_name:
            table_args["db"] = sge.to_identifier(target_dataset_name, quoted=False)
        right_side = sge.Table(**table_args)

    # normalize the ON condition to an expression and point it at the alias
    condition: sge.Expression
    if isinstance(on, str):
        condition = sqlglot.parse_one(on, dialect=destination_dialect)
    else:
        condition = on
    condition = _rewrite_on_qualifiers(condition, target_table, alias)

    joined = joined.join(sge.Join(this=right_side, kind=kind.upper()).on(condition))

    from_clause = joined.args.get("from_") or joined.args.get("from")
    if not isinstance(from_clause, sge.From) or not isinstance(from_clause.this, sge.Table):
        raise ValueError(
            "Cannot apply explicit join: left-side query must have a base table "
            "in its FROM clause (not a subquery or derived table)."
        )

    _apply_join_projection(
        joined,
        left_table=from_clause.this.this.name,
        target_columns=target_columns,
        target_qualifier=alias,
        projection_prefix=projection_prefix,
        allow_existing_target_projection=False,
    )
    return joined
22 changes: 19 additions & 3 deletions dlt/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(
self._default_schema_name: Optional[str] = None
self._resolved: bool = False

self._foreign_schemas: Dict[str, List[dlt.Schema]] = {}

self._sql_client: SqlClientBase[Any] = None
self._opened_sql_client: SqlClientBase[Any] = None
self._table_client: SupportsOpenTables = None
Expand Down Expand Up @@ -158,14 +160,28 @@ def _ipython_key_completions_(self) -> list[str]:
"""Provide table names as completion suggestion in interactive environments."""
return self.tables

def _is_same_dataset(self, other: dlt.Dataset) -> bool:
    """Whether `other` represents the same logical dataset.

    Only the dataset name is compared; destination/connection details are
    not taken into account.
    """
    # TODO: currently only compares dataset name;
    # once hardened, consider implementing __eq__ based on this method
    return self.dataset_name == other.dataset_name

def _add_foreign_schemas(self, dataset_name: str, schemas: Sequence[dlt.Schema]) -> None:
"""Register schemas from a foreign dataset for cross-dataset joins."""
if dataset_name == self.dataset_name:
return
self._foreign_schemas[dataset_name] = list(schemas)

@property
def sqlglot_schema(self) -> SQLGlotSchema:
    """SQLGlot schema of the dataset derived from all dlt schemas.

    Schemas of registered foreign datasets are merged in so that
    cross-dataset queries can be validated and qualified.
    """
    # NOTE: not cached — computing the current schema hash to detect
    # staleness would likely cost more than rebuilding the sqlglot schema
    by_dataset: Dict[str, Sequence[dlt.Schema]] = {self.dataset_name: list(self.schemas)}
    by_dataset.update(self._foreign_schemas)
    return lineage.create_sqlglot_schema(by_dataset, dialect=self.destination_dialect)

@property
def destination_dialect(self) -> TSqlGlotDialect:
Expand Down
Loading
Loading