diff --git a/dlt/dataset/_incremental.py b/dlt/dataset/_incremental.py
new file mode 100644
index 0000000000..f0c283bb9c
--- /dev/null
+++ b/dlt/dataset/_incremental.py
@@ -0,0 +1,233 @@
+from __future__ import annotations
+
+import warnings
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple, Type, TYPE_CHECKING
+
+import sqlglot.expressions as sge
+from jsonpath_ng.exceptions import JSONPathError
+
+from dlt.common.jsonpath import extract_simple_field_name
+from dlt.common.libs.sqlglot import build_typed_literal, to_sqlglot_type
+from dlt.common.schema.typing import TTableSchemaColumns
+
+if TYPE_CHECKING:
+    from dlt.extract.incremental import Incremental
+
+
+_AGG_CURSOR_ALIAS = "__dlt_inc_cursor"
+
+
+@dataclass(frozen=True)
+class _RelationIncrementalContext:
+    """Private per-relation marker tying a `Relation` back to its `Incremental`.
+
+    Set by `Relation.incremental()` and consumed by downstream lifecycle
+    code (e.g. `dlthub` transformations) that needs to advance the cursor
+    state after the relation executes.
+    """
+
+    incremental: Incremental[Any]
+    cursor_column: sge.Column
+
+
+def _build_incremental_aggregate(
+    base_query: sge.Query,
+    ctx: _RelationIncrementalContext,
+) -> sge.Select:
+    """Build `SELECT <agg>(<alias>) FROM (SELECT <cursor> AS <alias> FROM <base>)`.
+
+    Bare cursor: wraps the base query as a subquery so projections, GROUP BY,
+    HAVING and aliased computed cursors are preserved. Qualified cursor
+    (`table.column`, from an auto-join): replaces the base query's projection
+    list inline so the join qualifier resolves.
+    """
+    if ctx.incremental.end_value is None and (
+        base_query.args.get("limit") is not None or base_query.args.get("order") is not None
+    ):
+        raise ValueError(
+            "LIMIT and ORDER BY aren't supported on stateful `.incremental()` as "
+            "state would advance past only the returned rows, silently skipping "
+            "the rest on the next run. Remove them, or set `end_value=` for a "
+            "bounded read."
+        )
+
+    cursor_alias = sge.to_identifier(_AGG_CURSOR_ALIAS, quoted=True)
+    if ctx.cursor_column.table:
+        inner = base_query.copy()
+        inner.set(
+            "expressions",
+            [sge.Alias(this=ctx.cursor_column.copy(), alias=cursor_alias)],
+        )
+    else:
+        bare_cursor = sge.Column(this=ctx.cursor_column.this.copy())
+        inner = sge.Select(
+            expressions=[sge.Alias(this=bare_cursor, alias=cursor_alias)]
+        ).from_(base_query.copy().subquery())
+
+    agg_cls: Type[sge.AggFunc]
+    if ctx.incremental.last_value_func is max:
+        agg_cls = sge.Max
+    elif ctx.incremental.last_value_func is min:
+        agg_cls = sge.Min
+    else:
+        raise ValueError(
+            "Incremental aggregate can only be built for `min` or `max` "
+            f"`last_value_func`, got {ctx.incremental.last_value_func!r}."
+        )
+
+    outer_ref = sge.Column(this=cursor_alias.copy())
+    return sge.Select(expressions=[agg_cls(this=outer_ref)]).from_(inner.subquery())
+
+
+def _parse_incremental_cursor_path(cursor_path: str) -> Tuple[Optional[str], str]:
+    """Split `table.column` into parts, or return `(None, column)` for a bare field.
+
+    Rejects JSONPath expressions (wildcards, array indices, `$` root markers) that
+    cannot be pushed down to SQL.
+    """
+    if not cursor_path:
+        raise ValueError("Incremental `cursor_path` must be a non-empty string.")
+
+    if any(ch in cursor_path for ch in ("$", "[", "*")):
+        raise ValueError(
+            f"Incremental `cursor_path={cursor_path!r}` is a JSONPath expression. "
+            "`Relation.incremental()` only supports plain `column` or `table.column` cursors."
+        )
+
+    invalid_msg = (
+        f"Incremental `cursor_path={cursor_path!r}` is not a plain column identifier. "
+        "Use `column` or `table.column`."
+    )
+
+    if "." in cursor_path:
+        table_part, column_part = cursor_path.rsplit(".", 1)
+        if not table_part:
+            raise ValueError(invalid_msg)
+    else:
+        table_part, column_part = None, cursor_path
+
+    try:
+        column_name = extract_simple_field_name(column_part)
+    except JSONPathError as e:
+        raise ValueError(invalid_msg) from e
+    if column_name is None:
+        raise ValueError(invalid_msg)
+    return table_part, column_name
+
+
+def _build_incremental_condition(
+    incremental: Incremental[Any],
+    column_ref: sge.Column,
+    sqlglot_type: Optional[sge.DataType],
+) -> Optional[sge.Expression]:
+    """Build the WHERE condition for an Incremental cursor on `column_ref`.
+
+    Operator matrix (closed/open bounds):
+
+    - `max` + closed start -> `>=`, open start -> `>`
+    - `max` + open end -> `<`, closed end -> `<=`
+    - `min` + closed start -> `<=`, open start -> `<`
+    - `min` + open end -> `>`, closed end -> `>=`
+
+    Args:
+        incremental (Incremental): The incremental carrying cursor bounds, range, and
+            `on_cursor_value_missing` policy.
+        column_ref (sge.Column): Reference to the cursor column in the target query.
+        sqlglot_type (Optional[sge.DataType]): SQLGlot data type used to CAST the
+            bound literals; pass `None` to skip casting.
+
+    Returns:
+        Optional[sge.Expression]: A boolean expression ready to be attached via
+            `.where(...)`, or `None` when there is nothing to filter (no bounds
+            under the `"include"` policy).
+
+    Raises:
+        ValueError: If `incremental.last_value_func` is not `min` or `max`, or if
+            `on_cursor_value_missing` is not one of `"include"`, `"exclude"`, `"raise"`.
+    """
+    last_value_func = incremental.last_value_func
+    start_op_cls: Type[sge.Binary]
+    end_op_cls: Type[sge.Binary]
+    if last_value_func is max:
+        start_op_cls = sge.GTE if incremental.range_start == "closed" else sge.GT
+        end_op_cls = sge.LT if incremental.range_end == "open" else sge.LTE
+    elif last_value_func is min:
+        start_op_cls = sge.LTE if incremental.range_start == "closed" else sge.LT
+        end_op_cls = sge.GT if incremental.range_end == "open" else sge.GTE
+    else:
+        raise ValueError(
+            f"Incremental `last_value_func={last_value_func!r}` cannot be pushed "
+            "down to SQL. Only `min` and `max` are supported by `Relation.incremental()`."
+        )
+
+    on_missing = incremental.on_cursor_value_missing
+    if on_missing not in ("include", "exclude", "raise"):
+        raise ValueError(
+            "Incremental `on_cursor_value_missing="
+            f"{on_missing!r}` is not supported by "
+            "`Relation.incremental()`. Expected one of: 'include', 'exclude', 'raise'."
+        )
+
+    start_value = incremental.last_value
+    end_value = incremental.end_value
+
+    bounds: Optional[sge.Expression] = None
+    if start_value is not None:
+        start_literal = build_typed_literal(start_value, sqlglot_type)
+        bounds = start_op_cls(this=column_ref.copy(), expression=start_literal)
+
+    if end_value is not None:
+        end_literal = build_typed_literal(end_value, sqlglot_type)
+        end_condition: sge.Expression = end_op_cls(this=column_ref.copy(), expression=end_literal)
+        bounds = end_condition if bounds is None else sge.And(this=bounds, expression=end_condition)
+
+    if on_missing == "include":
+        if bounds is None:
+            return None
+        is_null = sge.Is(this=column_ref.copy(), expression=sge.Null())
+        return sge.Or(this=bounds, expression=is_null)
+
+    # "exclude" and "raise" both filter NULLs out via IS NOT NULL.
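+    # e.g. for `max` with the default bounds this yields
+    # `col >= <last_value> AND col < <end_value> AND col IS NOT NULL`.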
+ # "raise" can't raise mid-query in SQL pushdown; so we warn users + is_not_null = sge.Not(this=sge.Is(this=column_ref.copy(), expression=sge.Null())) + if bounds is None: + return is_not_null + return sge.And(this=bounds, expression=is_not_null) + + +def _maybe_warn_on_cursor_missing_raise( + incremental: Incremental[Any], + columns_schema: TTableSchemaColumns, + column_name: str, +) -> None: + """Warn when `on_cursor_value_missing="raise"` is bound against a nullable cursor.""" + if incremental.on_cursor_value_missing != "raise": + return + column_schema = columns_schema.get(column_name) or {} + if column_schema.get("nullable") is False: + return + warnings.warn( + "Can't raise on NULL cursor values; rows with NULL " + "cursors will be excluded. Set on_cursor_value_missing explicitly " + "to silence.", + UserWarning, + stacklevel=3, + ) + + +def _sqlglot_type_for_column( + columns: TTableSchemaColumns, column_name: str +) -> Optional[sge.DataType]: + """Resolve the SQLGlot data type for `column_name` from a dlt columns schema.""" + column_schema = columns.get(column_name) + if not column_schema: + return None + data_type = column_schema.get("data_type") + if data_type is None: + return None + return to_sqlglot_type( + dlt_type=data_type, + precision=column_schema.get("precision"), + timezone=column_schema.get("timezone"), + nullable=column_schema.get("nullable"), + ) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 240116fec9..fa60bdc33f 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -265,6 +265,26 @@ def _discover_join_params( return joins, target_qualifier +def _normalize_left_projection(query: sge.Select, left_table: str) -> list[sge.Expression]: + """Qualify the left-side projection so an added JOIN can't leak right-side columns. + + Bare `Star` becomes `.*`; unqualified `Column`s get their + `table` set to ``. + """ + origin_identifier = sge.to_identifier(left_table, quoted=False) + normalized: list[sge.Expression] = [] + for expr in query.selects: + if isinstance(expr, sge.Star): + normalized.append(sge.Column(table=origin_identifier, this=sge.Star())) + elif isinstance(expr, sge.Column) and expr.args.get("table") is None: + expr_copy = expr.copy() + expr_copy.set("table", origin_identifier) + normalized.append(expr_copy) + else: + normalized.append(expr) + return normalized + + def _apply_join_projection( query: sge.Select, *, @@ -275,28 +295,17 @@ def _apply_join_projection( projection_prefix: str, allow_existing_target_projection: bool, ) -> None: - """Apply join projection contract onto ``query``. + """Apply join projection contract onto `query`. Preserves the left-side projection and appends only columns from the explicitly - joined ``target_table`` as ``{projection_prefix}__{column}`` aliases. + joined `target_table` as `{projection_prefix}__{column}` aliases. - ``allow_existing_target_projection`` is used for idempotent re-joins: when a + `allow_existing_target_projection` is used for idempotent re-joins: when a join call contributes no new join edges, all target-prefixed columns may already exist in the left projection and should be accepted as a no-op instead of raising a collision error. 
""" - # Unbound columns must refer to the origin table so bind them to it - origin_identifier = sge.to_identifier(left_table, quoted=False) - normalized_left_expressions: list[sge.Expression] = [] - for expr in query.selects: - if isinstance(expr, sge.Star): - normalized_left_expressions.append(sge.Column(table=origin_identifier, this=sge.Star())) - elif isinstance(expr, sge.Column) and expr.args.get("table") is None: - expr_copy = expr.copy() - expr_copy.set("table", origin_identifier) - normalized_left_expressions.append(expr_copy) - else: - normalized_left_expressions.append(expr) + normalized_left_expressions = _normalize_left_projection(query, left_table) existing_projection_column_names = { expr.output_name @@ -343,8 +352,14 @@ def _apply_join( right_table: str, projection_prefix: str, kind: TJoinType = "inner", + project: bool = True, ) -> sge.Select: - """Apply schema-driven join(s) to ``expression`` and return the new query.""" + """Apply schema-driven join(s) to `expression` and return the new query. + + When `project` is `False` the JOIN is added to the query but the SELECT + list is left untouched (filter-only join). Use this for join targets whose + columns must be referenced in WHERE/ON predicates without being projected. + """ if left_table not in schema.tables: raise ValueError(f"Table `{left_table}` not found in dataset schema") if right_table not in schema.tables: @@ -374,13 +389,18 @@ def _apply_join( ) query = query.join(join_expr) - _apply_join_projection( - query, - schema=schema, - left_table=left_table, - target_table=right_table, - target_qualifier=target_qualifier, - projection_prefix=projection_prefix, - allow_existing_target_projection=not join_params, - ) + if project: + _apply_join_projection( + query, + schema=schema, + left_table=left_table, + target_table=right_table, + target_qualifier=target_qualifier, + projection_prefix=projection_prefix, + allow_existing_target_projection=not join_params, + ) + else: + # Filter-only join: qualify the left projection so a bare `*` does not + # expand across the joined table and leak right-side columns at runtime. + query.set("expressions", _normalize_left_projection(query, left_table)) return query diff --git a/dlt/dataset/dataset.py b/dlt/dataset/dataset.py index f336a4e128..087d91e4d6 100644 --- a/dlt/dataset/dataset.py +++ b/dlt/dataset/dataset.py @@ -47,6 +47,7 @@ if TYPE_CHECKING: from dlt.common.libs.ibis import ir from dlt.common.libs.ibis import BaseBackend as IbisBackend + from dlt.extract.incremental import Incremental class Dataset: @@ -262,9 +263,22 @@ def __call__( return self.query(query, query_dialect, _execute_raw_query=_execute_raw_query) def table( - self, table_name: str, *, load_ids: Optional[Collection[str]] = None, **kwargs: Any + self, + table_name: str, + *, + load_ids: Optional[Collection[str]] = None, + incremental: Optional[Incremental[Any]] = None, + **kwargs: Any, ) -> dlt.Relation: - """Get a `dlt.Relation` associated with a table from the dataset.""" + """Get a `dlt.Relation` associated with a table from the dataset. + + Args: + table_name (str): Name of the table in the dataset schema. + load_ids (Optional[Collection[str]]): If provided, restrict rows to the + given load ids via `Relation.from_loads()`. + incremental (Optional[Incremental[Any]]): If provided, apply the cursor + range as a `WHERE` clause via `Relation.incremental()`. + """ if table_name not in self.tables: # TODO: raise TableNotFound raise ValueError(f"Table `{table_name}` not found. 
Available table(s): {self.tables}") @@ -277,10 +291,12 @@ def table( " Ibis Table." ) + relation = dlt.Relation(dataset=self, table_name=table_name) if load_ids: - return dlt.Relation(dataset=self, table_name=table_name).from_loads(load_ids) - else: - return dlt.Relation(dataset=self, table_name=table_name) + relation = relation.from_loads(load_ids) + if incremental is not None: + relation = relation.incremental(incremental) + return relation def loads_table(self) -> dlt.Relation: """Get `_dlt_loads` table from the dataset.""" diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 13988f8644..f1232182c6 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -37,12 +37,21 @@ from dlt.destinations.sql_client import SqlClientBase, WithSqlClient from dlt.destinations.queries import bind_query, build_select_expr from dlt.common.destination.dataset import SupportsDataAccess -from dlt.dataset._join import _apply_join +from dlt.dataset._incremental import ( + _build_incremental_aggregate, + _build_incremental_condition, + _maybe_warn_on_cursor_missing_raise, + _parse_incremental_cursor_path, + _RelationIncrementalContext, + _sqlglot_type_for_column, +) +from dlt.dataset._join import _apply_join, _extract_joined_table_aliases if TYPE_CHECKING: from dlt.common.libs.ibis import ir from dlt.common.libs.pandas import pandas as pd + from dlt.extract.incremental import Incremental from dlt.common.libs.pyarrow import pyarrow as pa from dlt.helpers.ibis import Expr as IbisExpr @@ -106,6 +115,7 @@ def __init__( self._opened_sql_client: SqlClientBase[Any] = None self._sqlglot_expression: sge.Query = None self._schema: Optional[TTableSchemaColumns] = None + self._incremental_ctx: Optional[_RelationIncrementalContext] = None def df(self, *args: Any, **kwargs: Any) -> pd.DataFrame | None: with self._cursor() as cursor: @@ -438,6 +448,134 @@ def join( rel._sqlglot_expression = query return rel + def incremental(self, incremental: Incremental[Any]) -> Self: + """Filter this relation to a cursor range using an Incremental. + + Translates the `Incremental` bounds (`initial_value`/`end_value`, `range_start`/ + `range_end`, `last_value_func`) into a SQL `WHERE` clause. When the cursor + path is `table.column`, joins the referenced table via the dataset schema + without adding its columns to the projection, then filters on the joined + column. If the target is already joined, the existing JOIN is reused. + + Args: + incremental (Incremental[Any]): The incremental whose cursor path and + range define the filter. `last_value_func` must be `min` or `max`. + + Returns: + Self: A new relation with the incremental filter applied. + + Raises: + ValueError: If the cursor path is a JSONPath expression, if a dotted + cursor is used on a relation whose FROM is not a bare base table + (e.g. after `.from_loads()` or on a `.query(...)` relation), if + the referenced table is not in the dataset schema, or if + `last_value_func` is not `min` or `max`. + + Notes: + Aggregate (GROUP BY) cursors with `range_start="open"`: late + rows for already-emitted buckets are silently dropped. Set + `lag=N` to widen the lower bound and let `merge` overwrite + stale totals. + + Scheduler mode (`end_value` set): align window bounds to + bucket boundaries (misaligned bounds shift the result by a + bucket), use non-overlapping windows only, and re-run the + affected window manually to repair late data — `lag=` is + suppressed in this mode. 
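+
+        Example:
+            Illustrative sketch, assuming the dataset has an `events` table
+            with an integer `id` column. With the defaults this emits
+            `id >= 0 AND id < 100 AND id IS NOT NULL`::
+
+                cursor = dlt.sources.incremental("id", initial_value=0, end_value=100)
+                relation = dataset.table("events").incremental(cursor)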
+ """ + if self._incremental_ctx is not None: + raise ValueError( + "`.incremental()` has already been applied to this relation with " + f"cursor `{self._incremental_ctx.incremental.cursor_path}`." + ) + + table_name, column_name = _parse_incremental_cursor_path(incremental.cursor_path) + + if table_name is None: + column_ref = sge.Column(this=sge.to_identifier(column_name, quoted=True)) + sqlglot_type = _sqlglot_type_for_column(self.columns_schema, column_name) + _maybe_warn_on_cursor_missing_raise(incremental, self.columns_schema, column_name) + condition = _build_incremental_condition(incremental, column_ref, sqlglot_type) + rel = self.__copy__() + if condition is not None: + rel._sqlglot_expression = rel.sqlglot_expression.where(condition) + rel._incremental_ctx = _RelationIncrementalContext( + incremental=incremental, + cursor_column=column_ref.copy(), + ) + return rel + + if not self._table_name: + raise ValueError( + f"Incremental cursor `{incremental.cursor_path}` references table " + f"`{table_name}` but the relation has no base table to resolve joins. " + "Call `.incremental()` on `dataset.table(...)`, not on a `.query(...)`." + ) + if table_name not in self._dataset.schema.tables: + raise ValueError( + f"Incremental cursor target table `{table_name}` not found in dataset schema." + ) + if self._table_name not in _extract_joined_table_aliases(self.sqlglot_expression): + raise ValueError( + f"Incremental cursor `{incremental.cursor_path}` requires a " + f"base-table relation to resolve the join to `{table_name}`. " + f"This relation is derived from base table `{self._table_name}` " + "(e.g. via `.from_loads()`, `dataset.table(load_ids=...)`, or " + "`.select()`), so a dotted cursor cannot be applied. Use a " + f"cursor on a column of `{self._table_name}` instead, or drop " + "the derivation." + ) + + query = _apply_join( + self.sqlglot_expression, + schema=self._dataset.schema, + left_table=self._table_name, + right_table=table_name, + projection_prefix=table_name, + kind="inner", + project=False, + ) + qualifier_map = _extract_joined_table_aliases(query) + target_qualifier = qualifier_map[table_name] + + column_ref = sge.Column( + this=sge.to_identifier(column_name, quoted=True), + table=sge.to_identifier(target_qualifier, quoted=False), + ) + target_columns = self._dataset.schema.tables[table_name].get("columns", {}) + sqlglot_type = _sqlglot_type_for_column(target_columns, column_name) + _maybe_warn_on_cursor_missing_raise(incremental, target_columns, column_name) + + condition = _build_incremental_condition(incremental, column_ref, sqlglot_type) + rel = self.__copy__() + rel._sqlglot_expression = query.where(condition) if condition is not None else query + rel._incremental_ctx = _RelationIncrementalContext( + incremental=incremental, + cursor_column=column_ref.copy(), + ) + return rel + + @property + def is_incremental(self) -> bool: + """True if any clause on this relation was produced by `.incremental()`.""" + return self._incremental_ctx is not None + + def _incremental_aggregate_relation(self) -> Optional[Self]: + """Return a relation computing `(cursor)` over this relation. + + Used by downstream lifecycle code (e.g. dlthub transformations) to advance + incremental state after the relation executes. Returns `None` when this + relation was not produced by `.incremental()`. 
+ """ + if self._incremental_ctx is None: + return None + agg_query = _build_incremental_aggregate(self.sqlglot_expression, self._incremental_ctx) + rel = self.__copy__() + rel._sqlglot_expression = agg_query + # Derived relation — do not re-advance state from the aggregate itself. + rel._incremental_ctx = None + return rel + # NOTE we currently force to have one column selected; we could be more flexible # and rewrite the query to compute the AGG of all selected columns # `SELECT AGG(col1), AGG(col2), ... FROM table`` @@ -613,7 +751,9 @@ def with_load_id_col(self) -> dlt.Relation: self._dataset.schema.tables, self._table_name )["name"] if root_table_name == self._table_name: - raise ValueError(f"{root_table_name} is a root table, but load id column is not present.") + raise ValueError( + f"{root_table_name} is a root table, but load id column is not present." + ) join_alias = "_dlt_root" joined = self.join(root_table_name, alias=join_alias) @@ -721,6 +861,7 @@ def __repr__(self) -> str: def __copy__(self) -> Self: rel = self.__class__(dataset=self._dataset, query=self.sqlglot_expression) rel._table_name = self._table_name + rel._incremental_ctx = self._incremental_ctx return rel diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index 56ff85dcae..f8cd0a44d4 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -63,6 +63,7 @@ from dlt.extract.incremental.transform import ( JsonIncremental, ArrowIncremental, + ModelIncremental, IncrementalTransform, ) from dlt.extract.incremental.lag import apply_lag_with_suppression @@ -638,14 +639,25 @@ def _make_or_get_transformer(self, cls: Type[IncrementalTransform]) -> Increment self.range_start, self.range_end, ) + # ModelIncremental needs a back-reference so it can auto-apply + # `relation.incremental(self)` when the user yields a bare relation. + if isinstance(transformer, ModelIncremental): + transformer._incremental = self return transformer def _get_transform(self, items: TDataItems) -> IncrementalTransform: """Gets transform implementation that handles particular data item type""" + # Lazy import to avoid failure with a partially-initialised + # `dlt.extract` during dlt startup. + # TODO: we should consider creating a registry for transforms + from dlt.dataset.relation import Relation + # assume list is all of the same type for item in items if isinstance(items, list) else [items]: if is_arrow_object(item) or is_pandas_frame(item) or is_polars_frame(item): return self._make_or_get_transformer(ArrowIncremental) + elif isinstance(item, Relation): + return self._make_or_get_transformer(ModelIncremental) return self._make_or_get_transformer(JsonIncremental) return self._make_or_get_transformer(JsonIncremental) diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index c6bd9de59c..e0b0f40bfe 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -610,3 +610,52 @@ def _process_null_at_cursor_path(self, tbl: "pa.Table") -> Tuple["pa.Table", "pa if rows_with_null.num_rows > 0: raise IncrementalCursorPathHasValueNone(self.resource_name, self.cursor_path) return rows_without_null, rows_with_null + + +class ModelIncremental(IncrementalTransform): + """Incremental transform for `Relation` items. + + Filtering happens via SQL pushdown when `Relation.incremental(cursor)` is + applied. + + Modes: + - `end_value` is set: external scheduler/ephemeral: no aggregate, state is not + advanced from observed data. 
+ - `range_start="open"`, no `end_value`: stateful open-range: aggregate runs + and `last_value` advances. Open range on the next run excludes the + boundary, so no deduplication is required. + - Otherwise (closed-range stateful): rejected as boundary deduplication via + `unique_hashes` cannot be reproduced from a single aggregate. + """ + + # Parent `Incremental` so we can auto-apply below + _incremental: Optional["Incremental[Any]"] = None # type: ignore[name-defined] # noqa: F821 + + def __call__(self, relation: TDataItem) -> Tuple[Optional[TDataItem], bool, bool]: + ctx = getattr(relation, "_incremental_ctx", None) + if ctx is None: + # Bare relation, no `.incremental()`. Auto-apply using the parent `Incremental` + relation = relation.incremental(self._incremental) + + if self.end_value is not None: + # External scheduler/ephemeral mode: state not advanced from observed data. + self.seen_data = True + return relation, False, False + + if self.range_start != "open": + raise ValueError( + f"Stateful incremental on resource '{self.resource_name}' over a " + "Relation requires `range_start='open'`. Closed-range semantics " + "rely on boundary deduplication via `unique_hashes`, which a " + "SQL aggregate cannot reproduce. Either set `range_start='open'` " + "on the Incremental, or provide `end_value=` for scheduler mode." + ) + + agg_rel = relation._incremental_aggregate_relation() + if agg_rel is not None: + new_value = agg_rel.fetchscalar() + if new_value is not None: + self.last_value = new_value + + self.seen_data = True + return relation, False, False diff --git a/tests/dataset/test_relation_incremental.py b/tests/dataset/test_relation_incremental.py new file mode 100644 index 0000000000..2ab6d11412 --- /dev/null +++ b/tests/dataset/test_relation_incremental.py @@ -0,0 +1,1009 @@ +from __future__ import annotations + +import pathlib +import warnings +from typing import Any, Iterator, Literal + +import pytest +from sqlglot import expressions as sge + +import dlt +from dlt.common.pendulum import pendulum +from dlt.extract.incremental.transform import ModelIncremental + + +EVENTS_LOAD_0 = [ + {"id": 1, "created_at": "2026-01-01T00:00:00+00:00", "value": 1.0}, + {"id": 2, "created_at": "2026-01-05T00:00:00+00:00", "value": 2.0}, + {"id": 3, "created_at": "2026-01-10T00:00:00+00:00", "value": 3.0}, +] +EVENTS_LOAD_1 = [ + {"id": 4, "created_at": "2026-01-15T00:00:00+00:00", "value": 4.0}, + {"id": 5, "created_at": "2026-01-20T00:00:00+00:00", "value": 5.0}, +] + +END_VALUE_DT = pendulum.datetime(2999, 1, 1, tz="UTC") +END_VALUE_ID = 10**12 + + +@pytest.fixture(scope="module") +def incremental_pipeline(module_tmp_path: pathlib.Path) -> dlt.Pipeline: + pipeline = dlt.pipeline( + pipeline_name="relation_incremental", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(str(module_tmp_path / "incremental.db")), + dev_mode=True, + ) + + @dlt.resource(name="events", primary_key="id", write_disposition="append") + def events(batch: int) -> Iterator[Any]: + if batch == 0: + yield EVENTS_LOAD_0 + else: + yield EVENTS_LOAD_1 + + pipeline.run(events(batch=0)) + pipeline.run(events(batch=1)) + return pipeline + + +@pytest.fixture(scope="module") +def incremental_dataset(incremental_pipeline: dlt.Pipeline) -> dlt.Dataset: + return incremental_pipeline.dataset() + + +@pytest.fixture(scope="module") +def dataset_with_incomplete_join_target(module_tmp_path: pathlib.Path) -> dlt.Dataset: + """Two sibling tables joined by an explicit reference, where the join target + 
declares an incomplete column hint via `columns=`. + + `phantom_field` is declared on `categories` with no `data_type`, so it never + materializes at the destination. `Schema.get_table_columns()` filters it out + via `is_complete_column`; raw `schema.tables[...]["columns"]` does not. + """ + pipeline = dlt.pipeline( + pipeline_name="relation_incremental_incomplete", + pipelines_dir=str(module_tmp_path / "pipelines_dir_incomplete"), + destination=dlt.destinations.duckdb(str(module_tmp_path / "incomplete.db")), + dev_mode=True, + ) + + @dlt.resource( + name="categories", + primary_key="id", + columns=[{"name": "phantom_field", "nullable": True}], + ) + def categories() -> Iterator[Any]: + yield [{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}] + + @dlt.resource( + name="products", + primary_key="id", + columns=[{"name": "category_id", "data_type": "bigint"}], + references=[{ + "referenced_table": "categories", + "columns": ["category_id"], + "referenced_columns": ["id"], + }], + ) + def products() -> Iterator[Any]: + yield [ + {"id": 10, "category_id": 1}, + {"id": 11, "category_id": 2}, + {"id": 12, "category_id": 1}, + ] + + pipeline.run([categories(), products()]) + return pipeline.dataset() + + +def _where(relation: dlt.Relation) -> sge.Expression: + where_node = relation.sqlglot_expression.args.get("where") + assert isinstance(where_node, sge.Where), f"Expected WHERE clause, got {where_node!r}" + return where_node.this + + +def _column_name(expr: sge.Expression) -> str: + assert isinstance(expr, sge.Column), f"Expected Column, got {expr!r}" + return expr.args["this"].name + + +def _column_table(expr: sge.Expression) -> str | None: + assert isinstance(expr, sge.Column), f"Expected Column, got {expr!r}" + table = expr.args.get("table") + return table.name if table is not None else None + + +def _join_target_names(relation: dlt.Relation) -> list[str]: + joins = relation.sqlglot_expression.args.get("joins") or [] + names: list[str] = [] + for join in joins: + target = join.this + assert isinstance(target, sge.Table) + names.append(target.this.name) + return names + + +def test_incremental_emits_where_on_simple_cursor(incremental_dataset: dlt.Dataset) -> None: + incremental = dlt.sources.incremental("id", initial_value=2, end_value=END_VALUE_ID) + relation = incremental_dataset.table("events").incremental(incremental) + + condition = _where(relation) + assert isinstance(condition, sge.And) + bound_pair = condition.this + assert isinstance(bound_pair, sge.And) + assert isinstance(bound_pair.this, sge.GTE) + assert _column_name(bound_pair.this.this) == "id" + # no join is added for a simple cursor path + assert (relation.sqlglot_expression.args.get("joins") or []) == [] + + +def test_incremental_sets_is_incremental_flag(incremental_dataset: dlt.Dataset) -> None: + base = incremental_dataset.table("events") + assert base.is_incremental is False + + incremental = dlt.sources.incremental("id", initial_value=1, end_value=END_VALUE_ID) + flagged = base.incremental(incremental) + assert flagged.is_incremental is True + + # flag survives further chaining, context propagates through copies + chained = flagged.select("id", "value").where("value", "gt", 0) + assert chained.is_incremental is True + + # a plain where() never sets the flag + assert base.where("id", "gt", 1).is_incremental is False + + +def test_incremental_kwarg_on_table_equivalent_to_method( + incremental_dataset: dlt.Dataset, +) -> None: + incremental = dlt.sources.incremental("id", initial_value=2, end_value=END_VALUE_ID) + + 
via_kwarg = incremental_dataset.table( + "events", incremental=incremental + ).sqlglot_expression.sql() + via_method = ( + incremental_dataset.table("events").incremental(incremental).sqlglot_expression.sql() + ) + + assert via_kwarg == via_method + + +def test_incremental_returns_new_relation(incremental_dataset: dlt.Dataset) -> None: + base = incremental_dataset.table("events") + sql_before = base.sqlglot_expression.sql() + + incremental = dlt.sources.incremental("id", initial_value=2, end_value=END_VALUE_ID) + filtered = base.incremental(incremental) + + assert filtered is not base + assert base.sqlglot_expression.sql() == sql_before + assert filtered.sqlglot_expression.sql() != sql_before + + +@pytest.mark.parametrize( + "last_value_func,range_start,range_end,expected_start_cls,expected_end_cls", + [ + pytest.param("max", "closed", "open", sge.GTE, sge.LT, id="max-closed-open-default"), + pytest.param("max", "open", "closed", sge.GT, sge.LTE, id="max-open-closed"), + pytest.param("min", "closed", "open", sge.LTE, sge.GT, id="min-closed-open"), + pytest.param("min", "open", "closed", sge.LT, sge.GTE, id="min-open-closed"), + ], +) +def test_incremental_operators_matrix( + incremental_dataset: dlt.Dataset, + last_value_func: Literal["min", "max"], + range_start: Literal["open", "closed"], + range_end: Literal["open", "closed"], + expected_start_cls: type, + expected_end_cls: type, +) -> None: + incremental = dlt.sources.incremental( + "id", + initial_value=2, + end_value=4, + last_value_func=last_value_func, + range_start=range_start, + range_end=range_end, + ) + relation = incremental_dataset.table("events").incremental(incremental) + + condition = _where(relation) + assert isinstance(condition, sge.And) + bound_pair = condition.this + assert isinstance(bound_pair, sge.And) + start_op = bound_pair.this + end_op = bound_pair.expression + assert isinstance(start_op, expected_start_cls) + assert isinstance(end_op, expected_end_cls) + assert isinstance(start_op, sge.Binary) and isinstance(end_op, sge.Binary) + assert _column_name(start_op.this) == "id" + assert _column_name(end_op.this) == "id" + + +def test_incremental_datetime_cursor_renders_as_sql_literal( + incremental_dataset: dlt.Dataset, +) -> None: + ts = pendulum.datetime(2026, 1, 5, tz="UTC") + incremental = dlt.sources.incremental("created_at", initial_value=ts, end_value=END_VALUE_DT) + # `created_at` is nullable, below silence "raise" warning + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + relation = incremental_dataset.table("events").incremental(incremental) + + sql = relation.sqlglot_expression.sql(dialect=incremental_dataset.destination_dialect) + assert "2026-01-05" in sql + assert "DateTime(" not in sql + assert "datetime.datetime" not in sql + + +def test_incremental_dotted_cursor_auto_joins_target( + incremental_dataset: dlt.Dataset, +) -> None: + incremental = dlt.sources.incremental( + "_dlt_loads.inserted_at", + initial_value=pendulum.datetime(2026, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + ) + # _dlt_loads.inserted_at is `nullable=False` in the system schema, so the + # default "raise" policy stays silent here — no warnings.catch_warnings needed + relation = incremental_dataset.table("events").incremental(incremental) + + # exactly one JOIN added, targeting _dlt_loads + assert _join_target_names(relation) == ["_dlt_loads"] + + # bound pair is wrapped with AND IS NOT NULL by the default "raise" policy + condition = _where(relation) + assert isinstance(condition, sge.And) + 
bound_pair = condition.this + assert isinstance(bound_pair, sge.And) + start_op = bound_pair.this + assert isinstance(start_op, sge.Binary) + # WHERE column is qualified to the joined table + assert _column_name(start_op.this) == "inserted_at" + assert _column_table(start_op.this) == "_dlt_loads" + + +def test_incremental_dotted_cursor_does_not_pollute_projection( + incremental_dataset: dlt.Dataset, +) -> None: + # end-only: valid unbound mode, last_value is None -> single LT condition, + # enough to trigger the auto-join without needing a start bound. + incremental: dlt.sources.incremental[Any] = dlt.sources.incremental( + "_dlt_loads.inserted_at", end_value=END_VALUE_DT + ) + relation = incremental_dataset.table("events").incremental(incremental) + + # no column from _dlt_loads appears in the SELECT list — the auto-join + # is filter-only (project=False path). + selects = relation.sqlglot_expression.selects + output_names = [expr.output_name for expr in selects] + assert not any(name.startswith("_dlt_loads__") for name in output_names) + + +def test_incremental_dotted_cursor_runtime_columns_base_only( + incremental_dataset: dlt.Dataset, +) -> None: + incremental: dlt.sources.incremental[Any] = dlt.sources.incremental( + "_dlt_loads.inserted_at", + initial_value=pendulum.datetime(2026, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + ) + relation = incremental_dataset.table("events").incremental(incremental) + + expected_columns = set(incremental_dataset.table("events").columns) + assert set(relation.columns) == expected_columns + assert not any(c.startswith("_dlt_loads__") for c in relation.columns) + + row = relation.fetchone() + assert row is not None + assert len(row) == len(relation.columns) + + +@pytest.mark.xfail( + strict=True, + reason=( + "Bug: `Relation.incremental` dotted-cursor branch reads target columns via raw" + " `schema.tables[name]['columns']`, accepting incomplete column hints (no" + " `data_type`) as cursors. The resulting WHERE references a column that" + " doesn't exist on the destination — materialization fails at lineage." + " Fix: source target columns via `Schema.get_table_columns(table_name)` and" + " reject `.incremental()` on cursors that aren't materialized." + ), +) +def test_incremental_dotted_cursor_rejects_incomplete_target_column( + dataset_with_incomplete_join_target: dlt.Dataset, +) -> None: + """An incomplete (declared but unmaterialized) cursor column must not produce + a relation that emits SQL referencing a non-existent column. Materializing + the relation against duckdb is the source of truth. + """ + incremental = dlt.sources.incremental( + "categories.phantom_field", + initial_value=0, + end_value=10**12, + on_cursor_value_missing="exclude", + ) + relation = dataset_with_incomplete_join_target.table("products").incremental(incremental) + # Hard failure today: lineage rejects "Unknown column: phantom_field" because + # the SQLGlot schema filters incomplete columns but the WHERE built by + # `.incremental()` does not — they disagree, lineage raises. + relation.fetchall() + + +@pytest.mark.xfail( + strict=True, + reason=( + "Bug: `_apply_join_projection` reads `schema.tables[target]['columns']` raw" + " and aliases every column into the SELECT — including incomplete columns" + " (no `data_type`) that don't exist on the destination. Fix: source columns" + " via `Schema.get_table_columns(target)` so incomplete hints are filtered" + " out of the projection." 
+ ), +) +def test_join_does_not_project_incomplete_target_columns( + dataset_with_incomplete_join_target: dlt.Dataset, +) -> None: + """`relation.join(other)` must not emit projection aliases for columns that + are declared as hints but were never materialized. Materializing the join is + the source of truth: today it raises `LineageFailedException` because the + projected `categories__phantom_field` has no underlying column. + """ + relation = dataset_with_incomplete_join_target.table("products").join("categories") + rows = relation.fetchall() + assert rows is not None + # 3 products inner-joined to 2 categories on category_id → 3 rows + assert len(rows) == 3 + + +def test_incremental_dotted_cursor_reuses_existing_join( + incremental_dataset: dlt.Dataset, +) -> None: + """An explicit .join() before .incremental() on the same target should + not be duplicated — the WHERE latches onto the existing qualifier. + """ + pre_joined = incremental_dataset.table("events").join("_dlt_loads") + existing_targets = _join_target_names(pre_joined) + assert existing_targets.count("_dlt_loads") == 1 + + incremental: dlt.sources.incremental[Any] = dlt.sources.incremental( + "_dlt_loads.inserted_at", end_value=END_VALUE_DT + ) + relation = pre_joined.incremental(incremental) + + assert _join_target_names(relation).count("_dlt_loads") == 1 + + +def test_incremental_aggregate_on_simple_cursor(incremental_dataset: dlt.Dataset) -> None: + """`_incremental_aggregate_relation` returns the MAX cursor over the filter.""" + incremental = dlt.sources.incremental("id", initial_value=2, end_value=END_VALUE_ID) + relation = incremental_dataset.table("events").incremental(incremental) + # max id across EVENTS_LOAD_0 + EVENTS_LOAD_1 with id >= 2 is 5 + assert relation._incremental_aggregate_relation().fetchscalar() == 5 + + +def test_incremental_aggregate_on_dotted_cursor(incremental_dataset: dlt.Dataset) -> None: + incremental: dlt.sources.incremental[Any] = dlt.sources.incremental( + "_dlt_loads.inserted_at", + initial_value=pendulum.datetime(2026, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + ) + relation = incremental_dataset.table("events").incremental(incremental) + # exact value depends on load timing, but a MAX of inserted_at should be non-null + agg_value = relation._incremental_aggregate_relation().fetchscalar() + assert agg_value is not None + + +def test_incremental_aggregate_returns_none_when_not_incremental( + incremental_dataset: dlt.Dataset, +) -> None: + not_incremental = incremental_dataset.table("events") + assert not_incremental._incremental_aggregate_relation() is None + + +def test_incremental_aggregate_honors_min(incremental_dataset: dlt.Dataset) -> None: + """`last_value_func=min` flips the aggregate to SQL `MIN`.""" + # for min: closed start -> `<=`, closed end -> `>=`. Window [0, 5] contains ids 1-5. 
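+    # Note the inversion: with `min`, `initial_value` is the upper bound and
+    # `end_value` the lower bound of the window.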
+ incremental = dlt.sources.incremental( + "id", + initial_value=5, + end_value=0, + last_value_func="min", + range_end="closed", + ) + relation = incremental_dataset.table("events").incremental(incremental) + assert relation._incremental_aggregate_relation().fetchscalar() == 1 + + +def test_incremental_aggregate_on_query_with_group_by(incremental_dataset: dlt.Dataset) -> None: + incremental = dlt.sources.incremental( + "day", + initial_value=pendulum.datetime(2000, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + ) + sql = ( + "SELECT CAST(date_trunc('day', created_at) AS TIMESTAMP WITH TIME ZONE) AS day," + " COUNT(*) AS total FROM events GROUP BY day" + ) + relation = incremental_dataset(sql).incremental(incremental) + assert relation._incremental_aggregate_relation().fetchscalar() == pendulum.datetime( + 2026, 1, 20, tz="UTC" + ) + + +def test_incremental_aggregate_on_query_relation_bare_cursor( + incremental_dataset: dlt.Dataset, +) -> None: + incremental = dlt.sources.incremental("id", initial_value=2, end_value=END_VALUE_ID) + relation = incremental_dataset("SELECT id, value FROM events WHERE value > 0").incremental( + incremental + ) + assert relation._incremental_aggregate_relation().fetchscalar() == 5 + + +def test_incremental_aggregate_preserves_distinct(incremental_dataset: dlt.Dataset) -> None: + incremental = dlt.sources.incremental("id", initial_value=2, end_value=END_VALUE_ID) + relation = incremental_dataset("SELECT DISTINCT id FROM events").incremental(incremental) + assert relation._incremental_aggregate_relation().fetchscalar() == 5 + + +def test_incremental_aggregate_branches_on_cursor_qualifier( + incremental_dataset: dlt.Dataset, +) -> None: + bare = dlt.sources.incremental("id", initial_value=0, end_value=END_VALUE_ID) + bare_rel = incremental_dataset.table("events").incremental(bare) + bare_agg = bare_rel._incremental_aggregate_relation().sqlglot_expression + bare_inner_subq = bare_agg.args["from_"].this + assert isinstance(bare_inner_subq, sge.Subquery) + bare_inner_select = bare_inner_subq.this + bare_inner_from = bare_inner_select.args["from_"].this + assert isinstance( + bare_inner_from, sge.Subquery + ), "Bare cursor: base query must be wrapped as a subquery" + + dotted = dlt.sources.incremental( + "_dlt_loads.inserted_at", + initial_value=pendulum.datetime(2026, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + ) + dotted_rel = incremental_dataset.table("events").incremental(dotted) + dotted_agg = dotted_rel._incremental_aggregate_relation().sqlglot_expression + dotted_inner_subq = dotted_agg.args["from_"].this + assert isinstance(dotted_inner_subq, sge.Subquery) + dotted_inner_select = dotted_inner_subq.this + dotted_inner_from = dotted_inner_select.args["from_"].this + assert isinstance( + dotted_inner_from, sge.Table + ), "Qualified cursor: inline-projection path must keep the base table in FROM" + assert dotted_inner_select.args.get( + "joins" + ), "Qualified cursor: JOIN must be preserved so the qualifier still resolves" + + +@pytest.mark.parametrize( + "shape", + [ + pytest.param(lambda r: r.limit(2), id="limit-only"), + pytest.param(lambda r: r.order_by("id", "desc"), id="order-by-only"), + pytest.param(lambda r: r.order_by("id").limit(2), id="order-by-limit"), + ], +) +def test_incremental_aggregate_rejects_limit_or_order_by_in_stateful_mode( + incremental_pipeline: dlt.Pipeline, shape: Any +) -> None: + # In stateful mode (no end_value), LIMIT/ORDER BY would advance state past + # only the returned rows. Rejected so callers can't silently skip rows. 
+ # Empty yield -> no rows pass the pipe step -> state never advances, so a + # fixed resource name is safe to reuse across params. + dataset = incremental_pipeline.dataset() + captured: dlt.Relation | None = None + + @dlt.resource(name="probe_reject") + def probe( + cursor: dlt.sources.incremental[int] = dlt.sources.incremental( + "id", initial_value=0, range_start="open" + ), + ) -> Iterator[Any]: + nonlocal captured + captured = shape(dataset.table("events").incremental(cursor)) + yield from [] + + incremental_pipeline.extract(probe()) + assert captured is not None + with pytest.raises(ValueError, match="LIMIT and ORDER BY aren't supported"): + captured._incremental_aggregate_relation() + + +def test_incremental_inside_resource_captures_bound_sql( + incremental_pipeline: dlt.Pipeline, +) -> None: + dataset = incremental_pipeline.dataset() + captured: dlt.Relation | None = None + + @dlt.resource(name="probe_simple_cursor") + def probe( + cursor: dlt.sources.incremental[int] = dlt.sources.incremental("id", initial_value=2), + ) -> Iterator[Any]: + nonlocal captured + captured = dataset.table("events").incremental(cursor) + yield from [] + + incremental_pipeline.extract(probe()) + assert captured is not None + condition = _where(captured) + assert isinstance(condition, sge.And) + start_op = condition.this + assert isinstance(start_op, sge.GTE) + assert _column_name(start_op.this) == "id" + + +def test_incremental_custom_last_value_func_raises( + incremental_dataset: dlt.Dataset, +) -> None: + """Only `min` and `max` can be pushed down to SQL; custom callables can't.""" + incremental = dlt.sources.incremental("id", initial_value=1, last_value_func=lambda xs: max(xs)) + with pytest.raises(ValueError, match="last_value_func"): + incremental_dataset.table("events").incremental(incremental) + + +def test_incremental_unknown_dotted_target_raises( + incremental_dataset: dlt.Dataset, +) -> None: + incremental = dlt.sources.incremental("not_a_table.ts", initial_value=1) + with pytest.raises(ValueError, match="not found in dataset schema"): + incremental_dataset.table("events").incremental(incremental) + + +def test_incremental_dotted_cursor_on_query_relation_raises( + incremental_dataset: dlt.Dataset, +) -> None: + """Dotted cursors need a base-table relation to resolve the join chain.""" + query_relation = incremental_dataset.query("SELECT * FROM events") + incremental = dlt.sources.incremental( + "_dlt_loads.inserted_at", + initial_value=pendulum.datetime(2026, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + ) + with pytest.raises(ValueError, match="no base table"): + query_relation.incremental(incremental) + + +def test_incremental_chained_call_raises(incremental_dataset: dlt.Dataset) -> None: + incremental_a = dlt.sources.incremental("id", initial_value=1, end_value=END_VALUE_ID) + incremental_b = dlt.sources.incremental("value", initial_value=0.0, end_value=10.0) + + relation = incremental_dataset.table("events").incremental(incremental_a) + with pytest.raises(ValueError, match="already been applied"): + relation.incremental(incremental_b) + + +@pytest.mark.parametrize( + "build_relation", + [ + pytest.param( + lambda ds, load_ids, incremental: ds.table( + "events", load_ids=load_ids, incremental=incremental + ), + id="kwargs", + ), + pytest.param( + lambda ds, load_ids, incremental: ds.table("events") + .from_loads(load_ids) + .incremental(incremental), + id="chained", + ), + ], +) +def test_incremental_dotted_cursor_after_from_loads_raises( + incremental_pipeline: dlt.Pipeline, build_relation: 
Any +) -> None: + """`.from_loads()` wraps FROM in a subquery, so a subsequent dotted-cursor + `.incremental()` cannot resolve the join. Both the kwargs combo on + `dataset.table()` and the chained form must fail with a clear, user-facing + message rather than the internal `_discover_join_params` error. + """ + dataset = incremental_pipeline.dataset() + load_ids = dataset.load_ids() + assert load_ids, "fixture must produce at least one load" + + incremental = dlt.sources.incremental( + "_dlt_loads.inserted_at", + initial_value=pendulum.datetime(2026, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + ) + with pytest.raises(ValueError, match="dotted cursor cannot be applied"): + build_relation(dataset, load_ids, incremental) + + +@pytest.mark.parametrize( + "cursor_path", + [ + pytest.param("$.items[*].name", id="jsonpath-wildcard"), + pytest.param("$.name", id="jsonpath-root"), + pytest.param("items[0]", id="array-index"), + ], +) +def test_incremental_rejects_jsonpath_cursor( + incremental_dataset: dlt.Dataset, cursor_path: str +) -> None: + incremental = dlt.sources.incremental(cursor_path, initial_value=1) + with pytest.raises(ValueError, match="JSONPath|plain column"): + incremental_dataset.table("events").incremental(incremental) + + +@pytest.mark.parametrize( + "cursor_path,match", + [ + pytest.param("", "non-empty string", id="empty"), + pytest.param("col.", "not a plain column identifier", id="trailing-dot"), + pytest.param(".col", "not a plain column identifier", id="leading-dot"), + pytest.param('"col with.dot"', "not a plain column identifier", id="quoted-with-dot"), + pytest.param("$.name", "JSONPath expression", id="jsonpath-root"), + pytest.param("items[0]", "JSONPath expression", id="array-index"), + ], +) +def test_parse_incremental_cursor_path_rejects_malformed(cursor_path: str, match: str) -> None: + from dlt.dataset._incremental import _parse_incremental_cursor_path + + with pytest.raises(ValueError, match=match): + _parse_incremental_cursor_path(cursor_path) + + +def test_incremental_rejects_quoted_cursor_with_inner_dot( + incremental_dataset: dlt.Dataset, +) -> None: + incremental = dlt.sources.incremental('"col with.dot"', initial_value=1) + with pytest.raises(ValueError, match="not a plain column identifier"): + incremental_dataset.table("events").incremental(incremental) + + +@pytest.mark.parametrize( + "bounds_kwargs,bind_via_resource", + [ + pytest.param({"initial_value": 2}, True, id="start-only"), + pytest.param({"end_value": END_VALUE_ID}, False, id="end-only"), + pytest.param({"initial_value": 2, "end_value": END_VALUE_ID}, False, id="start-and-end"), + ], +) +@pytest.mark.parametrize( + "policy,expected_root_cls", + [ + pytest.param("include", sge.Or, id="include-or-is-null"), + pytest.param("exclude", sge.And, id="exclude-and-is-not-null"), + ], +) +def test_incremental_on_cursor_value_missing( + incremental_pipeline: dlt.Pipeline, + bounds_kwargs: dict[str, Any], + bind_via_resource: bool, + policy: Literal["include", "exclude"], + expected_root_cls: type, +) -> None: + dataset = incremental_pipeline.dataset() + + if bind_via_resource: + bounds_id = "_".join(sorted(bounds_kwargs)) + resource_name = f"probe_null_guard_{policy}_{bounds_id}" + captured: dlt.Relation | None = None + + @dlt.resource(name=resource_name) + def probe( + cursor: dlt.sources.incremental[int] = dlt.sources.incremental( + "id", on_cursor_value_missing=policy, **bounds_kwargs + ), + ) -> Iterator[Any]: + nonlocal captured + captured = dataset.table("events").incremental(cursor) + yield from 
[] + + incremental_pipeline.extract(probe()) + assert captured is not None + relation = captured + else: + incremental: dlt.sources.incremental[Any] = dlt.sources.incremental( + "id", on_cursor_value_missing=policy, **bounds_kwargs + ) + relation = dataset.table("events").incremental(incremental) + + condition = _where(relation) + assert isinstance(condition, expected_root_cls), ( + f"Expected `{expected_root_cls.__name__}` root for policy={policy} " + f"bounds={bounds_kwargs}, got {type(condition).__name__}: " + f"{condition.sql()}" + ) + # right-hand side of the wrapper is the null-guard on the cursor column: + # `Is(col, Null)` for include, `Not(Is(col, Null))` for exclude + null_guard = condition.expression + if isinstance(null_guard, sge.Not): + null_guard = null_guard.this + assert isinstance(null_guard, sge.Is) + assert isinstance(null_guard.expression, sge.Null) + assert _column_name(null_guard.this) == "id" + + +def test_incremental_raise_emits_is_not_null_pushdown( + incremental_dataset: dlt.Dataset, +) -> None: + # We can't raise on NULL cursor values, so `"raise"` (the default) + # falls back to `... AND col IS NOT NULL`, same shape as `"exclude"` + incremental = dlt.sources.incremental( + "id", + initial_value=2, + end_value=END_VALUE_ID, + on_cursor_value_missing="raise", + ) + relation = incremental_dataset.table("events").incremental(incremental) + + condition = _where(relation) + assert isinstance(condition, sge.And), ( + "raise pushdown must wrap with `AND IS NOT NULL`, got " + f"{type(condition).__name__}: {condition.sql()}" + ) + null_guard = condition.expression + assert isinstance(null_guard, sge.Not) + inner = null_guard.this + assert isinstance(inner, sge.Is) + assert isinstance(inner.expression, sge.Null) + assert _column_name(inner.this) == "id" + + +def test_incremental_raise_warns_on_nullable_cursor( + incremental_dataset: dlt.Dataset, +) -> None: + incremental = dlt.sources.incremental( + "created_at", + initial_value=pendulum.datetime(2026, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + on_cursor_value_missing="raise", + ) + with pytest.warns(UserWarning, match="Can't raise on NULL cursor"): + incremental_dataset.table("events").incremental(incremental) + + +def test_incremental_raise_no_warn_on_non_nullable_cursor( + incremental_dataset: dlt.Dataset, +) -> None: + incremental = dlt.sources.incremental( + "_dlt_loads.inserted_at", + initial_value=pendulum.datetime(2026, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + on_cursor_value_missing="raise", + ) + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always", UserWarning) + incremental_dataset.table("events").incremental(incremental) + pushdown_warnings = [w for w in captured if "Can't raise on NULL cursor" in str(w.message)] + assert pushdown_warnings == [], ( + "unexpected pushdown warning on a non-nullable cursor: " + f"{[str(w.message) for w in pushdown_warnings]}" + ) + + +def test_incremental_no_bounds_include_emits_no_where( + incremental_pipeline: dlt.Pipeline, +) -> None: + dataset = incremental_pipeline.dataset() + captured: dlt.Relation | None = None + + @dlt.resource(name="probe_no_bounds_include") + def probe( + cursor: dlt.sources.incremental[int] = dlt.sources.incremental( + "id", on_cursor_value_missing="include" + ), + ) -> Iterator[Any]: + nonlocal captured + captured = dataset.table("events").incremental(cursor) + yield from [] + + incremental_pipeline.extract(probe()) + assert captured is not None + relation = captured + + assert 
relation.sqlglot_expression.args.get("where") is None + assert relation.is_incremental is True + # the aggregate over the unfiltered base should still observe the full max id (5) + assert relation._incremental_aggregate_relation().fetchscalar() == 5 + + +@pytest.mark.parametrize("policy", ["exclude", "raise"]) +def test_incremental_no_bounds_exclude_or_raise_emits_only_is_not_null( + incremental_pipeline: dlt.Pipeline, policy: Literal["exclude", "raise"] +) -> None: + dataset = incremental_pipeline.dataset() + captured: dlt.Relation | None = None + + @dlt.resource(name=f"probe_no_bounds_{policy}") + def probe( + cursor: dlt.sources.incremental[int] = dlt.sources.incremental( + "id", on_cursor_value_missing=policy + ), + ) -> Iterator[Any]: + nonlocal captured + captured = dataset.table("events").incremental(cursor) + yield from [] + + incremental_pipeline.extract(probe()) + assert captured is not None + relation = captured + + condition = _where(relation) + assert isinstance(condition, sge.Not), ( + f"expected bare `IS NOT NULL` for no-bounds policy={policy!r}, " + f"got {type(condition).__name__}: {condition.sql()}" + ) + inner = condition.this + assert isinstance(inner, sge.Is) + assert isinstance(inner.expression, sge.Null) + assert _column_name(inner.this) == "id" + assert relation.is_incremental is True + + +@pytest.mark.parametrize("policy", ["include", "exclude"]) +def test_incremental_no_warn_when_policy_explicit( + incremental_dataset: dlt.Dataset, policy: Literal["include", "exclude"] +) -> None: + incremental: dlt.sources.incremental[Any] = dlt.sources.incremental( + "created_at", + initial_value=pendulum.datetime(2026, 1, 1, tz="UTC"), + end_value=END_VALUE_DT, + on_cursor_value_missing=policy, + ) + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always", UserWarning) + incremental_dataset.table("events").incremental(incremental) + assert ( + captured == [] + ), f"unexpected warning for policy={policy!r}: {[str(w.message) for w in captured]}" + + +def _model_transformer( + *, + cursor_path: str = "id", + start_value: Any = 0, + end_value: Any = None, + last_value_func: Any = max, + range_start: Literal["open", "closed"] = "open", + range_end: Literal["open", "closed"] = "open", +) -> ModelIncremental: + return ModelIncremental( + resource_name="test", + cursor_path=cursor_path, + initial_value=start_value, + start_value=start_value, + end_value=end_value, + last_value_func=last_value_func, + primary_key=None, + unique_hashes=set(), + range_start=range_start, + range_end=range_end, + ) + + +def _capture_stateful_relation( + pipeline: dlt.Pipeline, + *, + resource_name: str, + initial_value: int, + range_start: Literal["open", "closed"] = "open", +) -> dlt.Relation: + """Build an `.incremental()`-applied Relation against a bound stateful cursor. + + Stateful incrementals need an active pipeline to resolve + `get_state()`, so we wrap the build in a no-op resource and `extract()` it + just to bind. 
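+
+    Usage, mirroring the tests below::
+
+        relation = _capture_stateful_relation(
+            pipeline, resource_name="probe", initial_value=2
+        )
+        assert relation.is_incremental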
+ """ + dataset = pipeline.dataset() + captured: dlt.Relation | None = None + + @dlt.resource(name=resource_name) + def probe( + cursor: dlt.sources.incremental[int] = dlt.sources.incremental( + "id", initial_value=initial_value, range_start=range_start + ), + ) -> Iterator[Any]: + nonlocal captured + captured = dataset.table("events").incremental(cursor) + yield from [] + + pipeline.extract(probe()) + assert captured is not None + return captured + + +def test_get_transform_dispatches_modelincremental_for_relation( + incremental_dataset: dlt.Dataset, +) -> None: + incremental: dlt.sources.incremental[int] = dlt.sources.incremental( + "id", initial_value=0, end_value=END_VALUE_ID + ) + incremental._cached_state = { + "unique_hashes": [], + "initial_value": 0, + "last_value": 0, + "start_value": 0, + } + relation = incremental_dataset.table("events") + incremental_transform = incremental._get_transform(relation) + assert isinstance(incremental_transform, ModelIncremental) + assert incremental_transform.cursor_path == "id" + + +def test_model_incremental_advances_last_value_for_open_range( + incremental_pipeline: dlt.Pipeline, +) -> None: + relation = _capture_stateful_relation( + incremental_pipeline, resource_name="probe_advance", initial_value=2 + ) + transformer = _model_transformer(start_value=2) + out, start_out_of_range, end_out_of_range = transformer(relation) + + assert out is relation + assert (start_out_of_range, end_out_of_range) == (False, False) + assert transformer.last_value == 5 + + +def test_model_incremental_no_advance_in_scheduler_mode( + incremental_dataset: dlt.Dataset, +) -> None: + incremental = dlt.sources.incremental("id", initial_value=0, end_value=END_VALUE_ID) + relation = incremental_dataset.table("events").incremental(incremental) + + transformer = _model_transformer(start_value=0, end_value=END_VALUE_ID, range_start="closed") + transformer(relation) + + assert transformer.last_value == 0 + + +def test_model_incremental_rejects_closed_range_stateful( + incremental_pipeline: dlt.Pipeline, +) -> None: + relation = _capture_stateful_relation( + incremental_pipeline, + resource_name="probe_reject_closed", + initial_value=0, + range_start="closed", + ) + + transformer = _model_transformer(start_value=0, range_start="closed") + with pytest.raises(ValueError, match="range_start='open'"): + transformer(relation) + + +def test_model_incremental_auto_applies_on_bare_relation( + incremental_pipeline: dlt.Pipeline, +) -> None: + dataset = incremental_pipeline.dataset() + yielded: dlt.Relation | None = None + + @dlt.resource(name="probe_auto_apply") + def probe( + cursor: dlt.sources.incremental[int] = dlt.sources.incremental( + "id", initial_value=2, range_start="open" + ), + ) -> Iterator[Any]: + nonlocal yielded + yielded = dataset.table("events") + yield yielded + + resource = probe() + incremental_pipeline.extract(resource) + + assert yielded is not None + assert yielded.is_incremental is False + + # max(id) over the events table is 5, so `last_value` becomes 5. + assert resource.state["incremental"]["id"]["last_value"] == 5 + + +def test_model_incremental_does_not_clobber_last_value_on_empty_filter( + incremental_pipeline: dlt.Pipeline, +) -> None: + # initial_value above all data (max is 5) so the WHERE excludes everything. 
+ relation = _capture_stateful_relation( + incremental_pipeline, resource_name="probe_empty_filter", initial_value=10**9 + ) + + transformer = _model_transformer(start_value=10**9) + transformer(relation) + + assert transformer.last_value == 10**9 diff --git a/tests/load/test_model_item_format.py b/tests/load/test_model_item_format.py index 7dfe7ab97b..caa472572a 100644 --- a/tests/load/test_model_item_format.py +++ b/tests/load/test_model_item_format.py @@ -6,8 +6,6 @@ from dlt.extract.hints import make_hints -from dlt.pipeline.exceptions import PipelineStepFailed - from dlt.common.schema.typing import TWriteDisposition from dlt.common.utils import uniq_id @@ -52,17 +50,25 @@ def test_simple_incremental(destination_config: DestinationTestConfiguration) -> example_table_columns = dataset.schema.tables["example_table"]["columns"] - # TODO: incremental is not supported for models yet @dlt.resource() - def copied_table(incremental_field=dlt.sources.incremental("a")) -> Any: - rel = dataset["example_table"].limit(8) + def copied_table( + cursor: dlt.sources.incremental[int] = dlt.sources.incremental( + "a", initial_value=2, end_value=100, on_cursor_value_missing="exclude" + ), + ) -> Any: + rel = dataset["example_table"].incremental(cursor) yield dlt.mark.with_hints( rel, hints=make_hints(columns=example_table_columns), ) - with pytest.raises(PipelineStepFailed): - pipeline.run([copied_table()]) + info = pipeline.run( + [copied_table()], + loader_file_format="model", + table_format=destination_config.run_kwargs["table_format"], + ) + assert_load_info(info) + assert load_table_counts(pipeline, "copied_table") == {"copied_table": 8} @pytest.mark.parametrize(