Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions dlt/common/libs/sqlglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ def bind_query(
qualified_query: sge.Query,
sqlglot_schema: Any, # SQLGlotSchema
*,
expand_table_name: Callable[[str], List[str]],
expand_table_name: Callable[[str, Optional[str]], List[str]],
casefold_identifier: Callable[[str], str],
) -> sge.Query:
"""Binds a logical query (compliant with dlt schema) to physical tables in the destination dataset.
Expand All @@ -971,7 +971,9 @@ def bind_query(
Args:
qualified_query: SQLGlot query expression with qualified table/column references
sqlglot_schema: Schema mapping for name validation and column resolution
expand_table_name: Function that expands table name to fully qualified path [catalog, schema, table]
expand_table_name: Function ``(table_name, dataset_name | None) -> [catalog, schema, table]``
that expands a table name to a fully qualified path. The second argument is the
dataset qualifier from the query (``node.db``), or `None` for the default dataset.
casefold_identifier: Case transformation function (`str`, `str.upper`, or `str.lower`)

Returns:
Expand All @@ -993,7 +995,7 @@ def bind_query(
# expand names of known tables. this is currently a ClickHouse thing where
# we use dataset.table in queries but render those as dataset___table
if sqlglot_schema.column_names(node):
expanded_path = expand_table_name(node.name)
expanded_path = expand_table_name(node.name, node.db or None)
# set the table name
if node.name != expanded_path[-1]:
node.this.set("this", expanded_path[-1])
Expand Down
118 changes: 108 additions & 10 deletions dlt/dataset/_join.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

from functools import reduce
from typing import TYPE_CHECKING, Optional, Sequence
from typing import TYPE_CHECKING, Optional, Sequence, Union

import sqlglot
import sqlglot.expressions as sge

from dlt.common.typing import TypedDict
from dlt.common.schema import Schema, utils as schema_utils
from dlt.common.schema.typing import TTableReference
from dlt.common.schema.typing import TTableReference, TTableSchemaColumns
from dlt.common.libs.sqlglot import TSqlGlotDialect

if TYPE_CHECKING:
from dlt.dataset.relation import TJoinType
from dlt.dataset.relation import Relation, TJoinType

_INTERMEDIATE_JOIN_ALIAS_PREFIX = "_dlt_int_t"

Expand Down Expand Up @@ -268,17 +271,16 @@ def _discover_join_params(
def _apply_join_projection(
query: sge.Select,
*,
schema: Schema,
left_table: str,
target_table: str,
target_columns: TTableSchemaColumns,
target_qualifier: str,
projection_prefix: str,
allow_existing_target_projection: bool,
) -> None:
"""Apply join projection contract onto ``query``.

Preserves the left-side projection and appends only columns from the explicitly
joined ``target_table`` as ``{projection_prefix}__{column}`` aliases.
Preserves the left-side projection and appends only columns from the
joined target as ``{projection_prefix}__{column}`` aliases.

``allow_existing_target_projection`` is used for idempotent re-joins: when a
join call contributes no new join edges, all target-prefixed columns may already
Expand All @@ -304,7 +306,6 @@ def _apply_join_projection(
if expr.output_name not in {"", "*"}
}

target_columns = schema.tables[target_table]["columns"]
target_output_names = {
f"{projection_prefix}__{column_name}" for column_name in target_columns.keys()
}
Expand Down Expand Up @@ -376,11 +377,108 @@ def _apply_join(

_apply_join_projection(
query,
schema=schema,
left_table=left_table,
target_table=right_table,
target_columns=schema.tables[right_table]["columns"],
target_qualifier=target_qualifier,
projection_prefix=projection_prefix,
allow_existing_target_projection=not join_params,
)
return query


def _rewrite_on_qualifiers(
    on_expr: sge.Expression,
    target_table: str,
    internal_alias: str,
) -> sge.Expression:
    """Redirect column qualifiers that name ``target_table`` to ``internal_alias``.

    Join conditions are written with logical table names (e.g.
    ``on="users.id = orders.user_id"``). Once the join target has been given an
    internal alias, every qualifier still naming the logical table must be
    rewritten to that alias so the SQL engine can resolve the columns.

    The input expression is not mutated; a rewritten copy is returned.
    """
    rewritten = on_expr.copy()
    for column in rewritten.find_all(sge.Column):
        qualifier = column.args.get("table")
        if not isinstance(qualifier, sge.Identifier):
            continue
        if qualifier.name == target_table:
            qualifier.set("this", internal_alias)
    return rewritten


def _apply_explicit_join(
    expression: sge.Query,
    *,
    target: Optional["Relation"] = None,
    target_table: str,
    target_dataset_name: Optional[str],
    target_columns: TTableSchemaColumns,
    on: Union[str, sge.Expression],
    projection_prefix: str,
    kind: "TJoinType",
    destination_dialect: TSqlGlotDialect,
) -> sge.Select:
    """Apply an explicit-ON join to ``expression`` and return the new query.

    Args:
        expression: Left-side query to join onto.
        target: Right-hand Relation object (if transformed/subquery), or None for
            string / base-table targets.
        target_table: Bare table name for schema lookups and projection.
        target_dataset_name: Foreign dataset qualifier, or None for local.
        target_columns: Columns from the right-hand side for projection.
        on: Join condition as a SQL string or sqlglot expression.
        projection_prefix: Prefix for appended column aliases.
        kind: SQL join type.
        destination_dialect: Dialect for parsing string ON expressions.

    Raises:
        ValueError: If ``expression`` is not a SELECT, or if its FROM clause
            does not reference a plain base table.
    """
    # work on a copy so the caller's expression stays untouched
    query = expression.copy()
    if not isinstance(query, sge.Select):
        raise ValueError(f"Join query `{query}` must be an SQL SELECT statement.")

    # deterministic per-join alias; also the qualifier the ON expression is
    # rewritten to below, and the qualifier used for the appended projection
    internal_alias = f"_dlt_jt_{projection_prefix}"

    # build target expression
    target_expr: sge.Expression
    if target is not None:
        # transformed Relation -> subquery (preserves WHERE, SELECT, etc.)
        target_expr = sge.Subquery(
            this=target.sqlglot_expression,
            alias=sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)),
        )
    else:
        # base-table target (Relation with _table_name, or str)
        table_node_args: dict[str, sge.Expression] = {
            "this": sge.to_identifier(target_table, quoted=True),
            "alias": sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)),
        }
        if target_dataset_name:
            # cross-dataset join: qualify the table with the foreign dataset
            table_node_args["db"] = sge.to_identifier(target_dataset_name, quoted=False)
        target_expr = sge.Table(**table_node_args)

    # accept either a raw SQL string (parsed in the destination dialect) or a
    # pre-built sqlglot expression for the join condition
    if isinstance(on, str):
        on_expr = sqlglot.parse_one(on, dialect=destination_dialect)
    else:
        on_expr = on

    on_expr = _rewrite_on_qualifiers(on_expr, target_table, internal_alias)

    # NOTE(review): sqlglot distinguishes Join "kind" (INNER/CROSS/OUTER) from
    # "side" (LEFT/RIGHT/FULL); passing e.g. "LEFT" as kind= relies on the
    # builder normalizing it — confirm against the sqlglot version in use
    join_expr = sge.Join(this=target_expr, kind=kind.upper()).on(on_expr)
    query = query.join(join_expr)

    # presumably "from_" vs "from" covers differing sqlglot versions' arg
    # naming — TODO confirm which key the pinned sqlglot actually uses
    from_expr = query.args.get("from_") or query.args.get("from")
    if not isinstance(from_expr, sge.From) or not isinstance(from_expr.this, sge.Table):
        raise ValueError(
            "Cannot apply explicit join: left-side query must have a base table "
            "in its FROM clause (not a subquery or derived table)."
        )
    # bare name of the left-side base table, needed by the projection contract
    left_table = from_expr.this.this.name

    # append only the target's columns as `{projection_prefix}__{column}`
    # aliases, preserving the left-side projection
    _apply_join_projection(
        query,
        left_table=left_table,
        target_columns=target_columns,
        target_qualifier=internal_alias,
        projection_prefix=projection_prefix,
        allow_existing_target_projection=False,
    )
    return query
16 changes: 13 additions & 3 deletions dlt/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(
self._default_schema_name: Optional[str] = None
self._resolved: bool = False

self._foreign_schemas: Dict[str, List[dlt.Schema]] = {}

self._sql_client: SqlClientBase[Any] = None
self._opened_sql_client: SqlClientBase[Any] = None
self._table_client: SupportsOpenTables = None
Expand Down Expand Up @@ -158,14 +160,22 @@ def _ipython_key_completions_(self) -> list[str]:
"""Provide table names as completion suggestion in interactive environments."""
return self.tables

def _add_foreign_schemas(self, dataset_name: str, schemas: Sequence[dlt.Schema]) -> None:
    """Register schemas from a foreign dataset for cross-dataset joins."""
    # the local dataset is already covered by `self.schemas`; never let a
    # same-named "foreign" entry shadow it
    if dataset_name != self.dataset_name:
        self._foreign_schemas[dataset_name] = list(schemas)

@property
def sqlglot_schema(self) -> SQLGlotSchema:
    """SQLGlot schema of the dataset derived from all dlt schemas."""
    # NOTE: not cached for now — computing the current schema hash to detect
    # staleness is probably more expensive than rebuilding the sqlglot schema
    schema_map: Dict[str, Sequence[dlt.Schema]] = {self.dataset_name: list(self.schemas)}
    schema_map.update(self._foreign_schemas)
    return lineage.create_sqlglot_schema(schema_map, dialect=self.destination_dialect)

@property
def destination_dialect(self) -> TSqlGlotDialect:
Expand Down
Loading
Loading