Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 48 additions & 26 deletions dlt/common/libs/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
raise ImportError(f"Found pydantic {PYDANTIC_VERSION} but dlt requires pydantic>=2.0")
from pydantic import (
BaseModel,
Field,
ValidationError,
Json,
create_model,
Expand Down Expand Up @@ -127,29 +128,19 @@ def _build_discriminator_map(
return None

ann = root_field.annotation
discriminator: Optional[str] = None
discriminator = _get_field_discriminator(root_field)
union_args: Tuple[Any, ...] = ()

if is_annotated(ann):
if root_field.discriminator is not None:
# on pydantic >= 2.13 the discriminator is stored on the field object, and the
# annotation may be either the bare union or Annotated[Union[A, B], <other metadata>]
union_ann = get_args(ann)[0] if is_annotated(ann) else ann
union_args = get_args(union_ann)
elif is_annotated(ann):
# on pydantic < 2.13 `ann` is an annotated union type, e.g. Annotated[Union[A, B], Field(discriminator="kind")]
args = get_args(ann)
# metadata may be FieldInfo directly or wrapped in a tuple by _process_annotation
for a in args[1:]:
items = a if isinstance(a, (list, tuple)) else (a,)
for item in items:
if isinstance(item, FieldInfo) and isinstance(item.discriminator, str):
discriminator = item.discriminator
break
if discriminator:
break
union_args = get_args(args[0])

# pydantic 2.13+: annotation is plain Union (not Annotated), discriminator
# is on root_field directly
if not discriminator and isinstance(root_field.discriminator, str):
discriminator = root_field.discriminator
if not union_args:
union_args = get_args(ann)

if not discriminator or not union_args:
return None

Expand All @@ -162,6 +153,25 @@ def _build_discriminator_map(
return discriminator, mapping


def _get_field_discriminator(field: FieldInfo) -> Optional[str]:
# pydantic >= 2.13 stores the discriminator field name directly on the field object
if isinstance(field.discriminator, str):
return field.discriminator

if not is_annotated(field.annotation):
return None

# on pydantic < 2.13 `ann` is an annotated union type, e.g. Annotated[Union[A, B], Field(discriminator="kind")]
args = get_args(field.annotation)
for a in args[1:]:
items = a if isinstance(a, (list, tuple)) else (a,)
for item in items:
if isinstance(item, FieldInfo) and isinstance(item.discriminator, str):
return item.discriminator

return None


def resolve_variant_model(
model: Type[BaseModel],
item: Any,
Expand Down Expand Up @@ -344,7 +354,9 @@ def apply_schema_contract_to_model(
{"__module__": model.__module__},
)
else:
model = create_model(model.__name__ + "Any", **{n: (Any, None) for n in model.model_fields}) # type: ignore
model = create_model(
model.__name__ + "Any", **{n: (Any, None) for n in model.model_fields}
) # type: ignore
elif data_mode == "discard_value":
raise NotImplementedError(
"`data_mode='discard_value'`. Cannot discard defined fields with validation errors"
Expand All @@ -369,7 +381,7 @@ def _process_annotation(t_: Type[Any]) -> Type[Any]:
"""Recursively recreates models with applied schema contract"""
if is_annotated(t_):
a_t, *a_m = get_args(t_)
return Annotated[_process_annotation(a_t), tuple(a_m)] # type: ignore[return-value]
return Annotated[(_process_annotation(a_t), *a_m)] # type: ignore[return-value]
Copy link
Copy Markdown
Contributor Author

@Travior Travior Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me it's unclear why we previously wrapped other annotation metadata into tuples, losing their semantics:
Annotated[str, FieldInfo(...), AfterValidator(...)]
would be turned into:
Annotated[str, (FieldInfo(...), AfterValidator(...)].

Pydantic then ignores this tuple and we effectively lose the validator (or other annotation metadata)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is certainly a bug. I probably overlooked that in claude-generated code.

origin = get_origin(t_)
# tuple must be checked before is_list_generic_type (tuple is a Sequence)
if origin is tuple:
Expand Down Expand Up @@ -417,21 +429,31 @@ def _rebuild_annotated(f: Any) -> Type[Any]:
if getattr(model, "__pydantic_root_model__", False):
root_field = model.model_fields.get("root")
if root_field:
processed_ann = _process_annotation(_rebuild_annotated(root_field))
# preserve discriminator that pydantic 2.13+ strips from Annotated metadata
if isinstance(root_field.discriminator, str) and not is_annotated(processed_ann):
processed_ann = Annotated[processed_ann, FieldInfo(discriminator=root_field.discriminator)] # type: ignore[assignment]
# on pydantic >= 2.13 `_rebuild_annotated` might not return an annotated type,
# so processed_ann could be the bare Union type, e.g. Union[A, B]
processed_ann: Any = _process_annotation(_rebuild_annotated(root_field))
discriminator = _get_field_discriminator(root_field)
if discriminator is not None and root_field.discriminator is not None:
if is_annotated(processed_ann):
# we have other metadata (besides the discriminator) that we should preserve
args = get_args(processed_ann)
processed_ann = Annotated[
(args[0], Field(discriminator=discriminator), *args[1:])
]
else:
# no other metadata, so we can rebuild the annotated type with the discriminator
processed_ann = Annotated[processed_ann, Field(discriminator=discriminator)]
new_rm = type(
model.__name__ + "Extra" + extra.title(),
(PydanticRootModel[processed_ann],), # type: ignore[valid-type]
(PydanticRootModel[processed_ann],),
{"__module__": model.__module__},
)
if original_dlt_config:
new_rm.dlt_config = original_dlt_config # type: ignore[attr-defined]
return new_rm

processed_fields = {
n: (_process_annotation(_rebuild_annotated(f)), f) for n, f in model.model_fields.items()
n: (_process_annotation(f.annotation), f) for n, f in model.model_fields.items()
}

# use __base__ to inherit validators (@field_validator, @model_validator)
Expand Down
174 changes: 174 additions & 0 deletions tests/libs/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
)
from dlt.common.warnings import Dlt100DeprecationWarning
from pydantic import (
AfterValidator,
VERSION as PYDANTIC_VERSION,
UUID4,
BaseModel,
Field,
Expand Down Expand Up @@ -1126,6 +1128,178 @@ def test_apply_contract_root_model_discriminator_preserved() -> None:
mutated.model_validate({"kind": "unknown", "id": 3})


def test_apply_contract_root_model_preserves_root_after_validator() -> None:
"""Root-level validator metadata must survive RootModel reconstruction."""

def validate_click(value: Any) -> Any:
if isinstance(value, Click) and value.element_id == "forbidden":
raise ValueError("forbidden element")
return value

class Click(BaseModel):
kind: Literal["click"]
element_id: str

class Purchase(BaseModel):
kind: Literal["purchase"]
amount: float

U = Annotated[
Union[Click, Purchase],
Field(discriminator="kind"),
AfterValidator(validate_click),
]

class Event(RootModel[U]):
pass

mutated: Any = apply_schema_contract_to_model(Event, "freeze", "freeze")
validated = mutated.model_validate({"kind": "click", "element_id": "btn_1"})

assert isinstance(validated.root, Click)
assert validated.root.kind == "click"
with pytest.raises(ValidationError, match="forbidden element"):
mutated.model_validate({"kind": "click", "element_id": "forbidden"})


def test_apply_contract_preserves_non_root_after_validator() -> None:
"""Non-root Annotated validator metadata must survive model reconstruction."""

def reject_negative(value: int) -> int:
if value < 0:
raise ValueError("negative value")
return value

class ModelWithAnnotatedValidator(BaseModel):
value: Annotated[int, AfterValidator(reject_negative)]

mutated: Any = apply_schema_contract_to_model(ModelWithAnnotatedValidator, "freeze", "freeze")
validated = mutated.model_validate({"value": 1})

assert validated.value == 1
with pytest.raises(ValidationError, match="negative value"):
mutated.model_validate({"value": -1})


def test_apply_contract_preserves_multiple_annotated_metadata_entries() -> None:
"""Annotated metadata entries stay separate after model reconstruction."""

class Child(BaseModel):
x: int

class ModelWithAnnotatedMetadata(BaseModel):
field: Annotated[Child, "meta1", "meta2"]

mutated: Any = apply_schema_contract_to_model(ModelWithAnnotatedMetadata, "freeze", "freeze")
rebuilt = mutated.model_fields["field"].rebuild_annotation()

assert get_origin(rebuilt) is Annotated
assert get_args(rebuilt)[1:] == ("meta1", "meta2")


def test_apply_contract_preserves_nested_annotated_metadata_entries() -> None:
"""Nested Annotated metadata entries must remain separate slots, not be packed
into a single tuple, when _process_annotation recurses through containers."""

class Child(BaseModel):
x: int

class M(BaseModel):
items: List[Annotated[Child, "meta1", "meta2"]]

mutated: Any = apply_schema_contract_to_model(M, "freeze", "freeze")
inner = get_args(mutated.model_fields["items"].annotation)[0]
inner_args = get_args(inner)

assert inner_args[0].__name__.endswith("ExtraForbid")
assert inner_args[1:] == ("meta1", "meta2")


def test_apply_contract_non_root_discriminated_union_with_validator() -> None:
"""Non-root field with a discriminated union plus extra metadata keeps both
the discriminator and the validator after model reconstruction."""

class Click(BaseModel):
kind: Literal["click"]
element_id: str

class Purchase(BaseModel):
kind: Literal["purchase"]
amount: float

def reject_forbidden(value: Any) -> Any:
if isinstance(value, Click) and value.element_id == "forbidden":
raise ValueError("forbidden element")
return value

class Container(BaseModel):
event: Annotated[
Union[Click, Purchase],
Field(discriminator="kind"),
AfterValidator(reject_forbidden),
]

mutated: Any = apply_schema_contract_to_model(Container, "freeze", "freeze")

ok = mutated.model_validate({"event": {"kind": "click", "element_id": "btn_1"}})
assert ok.event.kind == "click"
with pytest.raises(ValidationError):
mutated.model_validate({"event": {"kind": "unknown", "element_id": "btn_1"}})

with pytest.raises(ValidationError, match="forbidden element"):
mutated.model_validate({"event": {"kind": "click", "element_id": "forbidden"}})


@pytest.mark.skipif(
tuple(int(part) for part in PYDANTIC_VERSION.split(".")[:2]) < (2, 13),
reason="Requires pydantic >= 2.13 field.discriminator behavior",
)
def test_build_discriminator_map_preserves_new_root_discriminator_with_extra_metadata() -> None:
"""Discriminator extraction works when pydantic stores it on the root field."""

def validate_event(value: Any) -> Any:
return value

class Click(BaseModel):
kind: Literal["click"]
element_id: str

class Purchase(BaseModel):
kind: Literal["purchase"]
amount: float

U = Annotated[
Union[Click, Purchase],
AfterValidator(validate_event),
Field(discriminator="kind"),
]

class Event(RootModel[U]):
pass

root_field = Event.model_fields["root"]
assert root_field.discriminator == "kind"

result = _build_discriminator_map(Event)
assert result is not None
disc_field, mapping = result
assert disc_field == "kind"
assert set(mapping.keys()) == {"click", "purchase"}

mutated: Any = apply_schema_contract_to_model(Event, "freeze", "freeze")
root_field = mutated.model_fields["root"]
assert root_field.discriminator == "kind"

result = _build_discriminator_map(mutated)
assert result is not None
disc_field, mapping = result
assert disc_field == "kind"
assert set(mapping.keys()) == {"click", "purchase"}

validated = mutated.model_validate({"kind": "click", "element_id": "btn_1"})
assert validated.model_dump() == {"kind": "click", "element_id": "btn_1"}


def test_extra_schema_contract_conflict_warning() -> None:
"""Warns when model extra contradicts schema_contract columns setting."""

Expand Down
Loading