Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 48 additions & 26 deletions dlt/common/libs/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
raise ImportError(f"Found pydantic {PYDANTIC_VERSION} but dlt requires pydantic>=2.0")
from pydantic import (
BaseModel,
Field,
ValidationError,
Json,
create_model,
Expand Down Expand Up @@ -127,29 +128,19 @@ def _build_discriminator_map(
return None

ann = root_field.annotation
discriminator: Optional[str] = None
discriminator = _get_field_discriminator(root_field)
union_args: Tuple[Any, ...] = ()

if is_annotated(ann):
if root_field.discriminator is not None:
# on pydantic >= 2.13 the discriminator is stored on the field object, and the
# annotation may be either the bare union or Annotated[Union[A, B], <other metadata>]
union_ann = get_args(ann)[0] if is_annotated(ann) else ann
union_args = get_args(union_ann)
elif is_annotated(ann):
# on pydantic < 2.13 `ann` is an annotated union type, e.g. Annotated[Union[A, B], Field(discriminator="kind")]
args = get_args(ann)
# metadata may be FieldInfo directly or wrapped in a tuple by _process_annotation
for a in args[1:]:
items = a if isinstance(a, (list, tuple)) else (a,)
for item in items:
if isinstance(item, FieldInfo) and isinstance(item.discriminator, str):
discriminator = item.discriminator
break
if discriminator:
break
union_args = get_args(args[0])

# pydantic 2.13+: annotation is plain Union (not Annotated), discriminator
# is on root_field directly
if not discriminator and isinstance(root_field.discriminator, str):
discriminator = root_field.discriminator
if not union_args:
union_args = get_args(ann)

if not discriminator or not union_args:
return None

Expand All @@ -162,6 +153,25 @@ def _build_discriminator_map(
return discriminator, mapping


def _get_field_discriminator(field: FieldInfo) -> Optional[str]:
    """Return the discriminator field name configured for `field`, or None.

    Two places are inspected, in order:

    1. `field.discriminator` itself (where pydantic >= 2.13 keeps the name).
    2. `FieldInfo` entries in the metadata of an `Annotated` annotation, e.g.
       `Annotated[Union[A, B], Field(discriminator="kind")]` (pydantic < 2.13).
       A metadata entry may be a `FieldInfo` or a list/tuple containing one.

    Only string discriminators are reported; any other value is ignored.
    """
    direct = field.discriminator
    if isinstance(direct, str):
        return direct

    annotation = field.annotation
    if is_annotated(annotation):
        # skip the first arg (the annotated type) and walk the metadata entries
        for entry in get_args(annotation)[1:]:
            candidates = entry if isinstance(entry, (list, tuple)) else (entry,)
            for candidate in candidates:
                if isinstance(candidate, FieldInfo) and isinstance(candidate.discriminator, str):
                    return candidate.discriminator

    return None


def resolve_variant_model(
model: Type[BaseModel],
item: Any,
Expand Down Expand Up @@ -344,7 +354,9 @@ def apply_schema_contract_to_model(
{"__module__": model.__module__},
)
else:
model = create_model(model.__name__ + "Any", **{n: (Any, None) for n in model.model_fields}) # type: ignore
model = create_model(
model.__name__ + "Any", **{n: (Any, None) for n in model.model_fields}
) # type: ignore
elif data_mode == "discard_value":
raise NotImplementedError(
"`data_mode='discard_value'`. Cannot discard defined fields with validation errors"
Expand All @@ -369,7 +381,7 @@ def _process_annotation(t_: Type[Any]) -> Type[Any]:
"""Recursively recreates models with applied schema contract"""
if is_annotated(t_):
a_t, *a_m = get_args(t_)
return Annotated[_process_annotation(a_t), tuple(a_m)] # type: ignore[return-value]
return Annotated[(_process_annotation(a_t), *a_m)] # type: ignore[return-value]
Copy link
Copy Markdown
Contributor Author

@Travior Travior Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear to me why we previously wrapped the remaining annotation metadata in a tuple, which loses its semantics:
Annotated[str, FieldInfo(...), AfterValidator(...)]
would be turned into:
Annotated[str, (FieldInfo(...), AfterValidator(...))].

Pydantic then ignores this tuple, so we effectively lose the validator (or any other annotation metadata).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is certainly a bug. I probably overlooked it in the Claude-generated code.

origin = get_origin(t_)
# tuple must be checked before is_list_generic_type (tuple is a Sequence)
if origin is tuple:
Expand Down Expand Up @@ -417,21 +429,31 @@ def _rebuild_annotated(f: Any) -> Type[Any]:
if getattr(model, "__pydantic_root_model__", False):
root_field = model.model_fields.get("root")
if root_field:
processed_ann = _process_annotation(_rebuild_annotated(root_field))
# preserve discriminator that pydantic 2.13+ strips from Annotated metadata
if isinstance(root_field.discriminator, str) and not is_annotated(processed_ann):
processed_ann = Annotated[processed_ann, FieldInfo(discriminator=root_field.discriminator)] # type: ignore[assignment]
# on pydantic >= 2.13 `_rebuild_annotated` might not return an annotated type,
# so processed_ann could be the bare Union type, e.g. Union[A, B]
processed_ann: Any = _process_annotation(_rebuild_annotated(root_field))
discriminator = _get_field_discriminator(root_field)
if discriminator is not None and root_field.discriminator is not None:
if is_annotated(processed_ann):
# we have other metadata (besides the discriminator) that we should preserve
args = get_args(processed_ann)
processed_ann = Annotated[
(args[0], Field(discriminator=discriminator), *args[1:])
]
else:
# no other metadata, so we can rebuild the annotated type with the discriminator
processed_ann = Annotated[processed_ann, Field(discriminator=discriminator)]
new_rm = type(
model.__name__ + "Extra" + extra.title(),
(PydanticRootModel[processed_ann],), # type: ignore[valid-type]
(PydanticRootModel[processed_ann],),
{"__module__": model.__module__},
)
if original_dlt_config:
new_rm.dlt_config = original_dlt_config # type: ignore[attr-defined]
return new_rm

processed_fields = {
n: (_process_annotation(_rebuild_annotated(f)), f) for n, f in model.model_fields.items()
n: (_process_annotation(f.annotation), f) for n, f in model.model_fields.items()
}

# use __base__ to inherit validators (@field_validator, @model_validator)
Expand Down
145 changes: 133 additions & 12 deletions tests/libs/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
)
from dlt.common.warnings import Dlt100DeprecationWarning
from pydantic import (
AfterValidator,
VERSION as PYDANTIC_VERSION,
UUID4,
BaseModel,
Field,
Expand Down Expand Up @@ -1126,6 +1128,125 @@ def test_apply_contract_root_model_discriminator_preserved() -> None:
mutated.model_validate({"kind": "unknown", "id": 3})


def test_apply_contract_root_model_preserves_root_after_validator() -> None:
    """An `AfterValidator` on the root annotation still runs on the rebuilt RootModel."""

    class Click(BaseModel):
        kind: Literal["click"]
        element_id: str

    class Purchase(BaseModel):
        kind: Literal["purchase"]
        amount: float

    def validate_click(value: Any) -> Any:
        # reject one specific click so we can observe whether the validator ran
        is_forbidden = isinstance(value, Click) and value.element_id == "forbidden"
        if is_forbidden:
            raise ValueError("forbidden element")
        return value

    # discriminator first, validator second
    U = Annotated[
        Union[Click, Purchase],
        Field(discriminator="kind"),
        AfterValidator(validate_click),
    ]

    class Event(RootModel[U]):
        pass

    frozen_model: Any = apply_schema_contract_to_model(Event, "freeze", "freeze")

    # a regular click passes and is resolved to the right variant
    accepted = frozen_model.model_validate({"kind": "click", "element_id": "btn_1"})
    assert isinstance(accepted.root, Click)
    assert accepted.root.kind == "click"

    # the forbidden click must be rejected by the preserved validator
    with pytest.raises(ValidationError, match="forbidden element"):
        frozen_model.model_validate({"kind": "click", "element_id": "forbidden"})


def test_apply_contract_preserves_non_root_after_validator() -> None:
    """An `AfterValidator` on a regular model field still runs after reconstruction."""

    def reject_negative(value: int) -> int:
        if not value < 0:
            return value
        raise ValueError("negative value")

    class ModelWithAnnotatedValidator(BaseModel):
        value: Annotated[int, AfterValidator(reject_negative)]

    frozen_model: Any = apply_schema_contract_to_model(
        ModelWithAnnotatedValidator, "freeze", "freeze"
    )

    # valid input goes through unchanged
    accepted = frozen_model.model_validate({"value": 1})
    assert accepted.value == 1

    # invalid input must hit the preserved validator
    with pytest.raises(ValidationError, match="negative value"):
        frozen_model.model_validate({"value": -1})


def test_apply_contract_preserves_multiple_annotated_metadata_entries() -> None:
    """Each `Annotated` metadata entry remains a separate entry after reconstruction."""

    class Child(BaseModel):
        x: int

    class ModelWithAnnotatedMetadata(BaseModel):
        field: Annotated[Child, "meta1", "meta2"]

    frozen_model: Any = apply_schema_contract_to_model(
        ModelWithAnnotatedMetadata, "freeze", "freeze"
    )
    field_info = frozen_model.model_fields["field"]
    rebuilt_annotation = field_info.rebuild_annotation()

    assert get_origin(rebuilt_annotation) is Annotated
    # everything after the annotated type itself is metadata
    metadata = get_args(rebuilt_annotation)[1:]
    assert metadata == ("meta1", "meta2")


@pytest.mark.skipif(
    tuple(map(int, PYDANTIC_VERSION.split(".")[:2])) < (2, 13),
    reason="Requires pydantic >= 2.13 field.discriminator behavior",
)
def test_build_discriminator_map_preserves_new_root_discriminator_with_extra_metadata() -> None:
    """The discriminator is found on the root field, before and after applying the contract."""

    def validate_event(value: Any) -> Any:
        return value

    class Click(BaseModel):
        kind: Literal["click"]
        element_id: str

    class Purchase(BaseModel):
        kind: Literal["purchase"]
        amount: float

    # extra metadata (the validator) is listed before the discriminator
    U = Annotated[
        Union[Click, Purchase],
        AfterValidator(validate_event),
        Field(discriminator="kind"),
    ]

    class Event(RootModel[U]):
        pass

    def check_discriminator(model_cls: Any) -> None:
        # the root field must expose the discriminator ...
        assert model_cls.model_fields["root"].discriminator == "kind"
        # ... and the discriminator map must list both variants
        outcome = _build_discriminator_map(model_cls)
        assert outcome is not None
        disc_field, mapping = outcome
        assert disc_field == "kind"
        assert set(mapping.keys()) == {"click", "purchase"}

    check_discriminator(Event)

    frozen_model: Any = apply_schema_contract_to_model(Event, "freeze", "freeze")
    check_discriminator(frozen_model)

    payload = frozen_model.model_validate({"kind": "click", "element_id": "btn_1"})
    assert payload.model_dump() == {"kind": "click", "element_id": "btn_1"}


def test_extra_schema_contract_conflict_warning() -> None:
"""Warns when model extra contradicts schema_contract columns setting."""

Expand Down Expand Up @@ -1304,20 +1425,20 @@ def get_inner_model_names(annotation: Any) -> List[str]:
# user_labels: List[UserLabel] — inner model should be transformed
user_labels_ann = model.model_fields["user_labels"].annotation
inner_names = get_inner_model_names(user_labels_ann)
assert any(
"ExtraAllow" in n for n in inner_names
), f"UserLabel in List not transformed: {inner_names}"
assert any("ExtraAllow" in n for n in inner_names), (
f"UserLabel in List not transformed: {inner_names}"
)

# unity: Union[UserAddress, UserLabel, Dict[str, UserAddress]]
unity_ann = model.model_fields["unity"].annotation
inner_names = get_inner_model_names(unity_ann)
# should have transformed versions of UserAddress and UserLabel
assert any(
"UserAddress" in n and "ExtraAllow" in n for n in inner_names
), f"UserAddress in Union not transformed: {inner_names}"
assert any(
"UserLabel" in n and "ExtraAllow" in n for n in inner_names
), f"UserLabel in Union not transformed: {inner_names}"
assert any("UserAddress" in n and "ExtraAllow" in n for n in inner_names), (
f"UserAddress in Union not transformed: {inner_names}"
)
assert any("UserLabel" in n and "ExtraAllow" in n for n in inner_names), (
f"UserLabel in Union not transformed: {inner_names}"
)

# address field itself is Annotated[UserAddress, ...] — check the inner type
address_ann = model.model_fields["address"].annotation
Expand All @@ -1327,9 +1448,9 @@ def get_inner_model_names(annotation: Any) -> List[str]:
addr_model = address_ann
ro_labels_ann = addr_model.model_fields["ro_labels"].annotation
inner_names = get_inner_model_names(ro_labels_ann)
assert any(
"ExtraAllow" in n for n in inner_names
), f"UserLabel in Mapping not transformed: {inner_names}"
assert any("ExtraAllow" in n for n in inner_names), (
f"UserLabel in Mapping not transformed: {inner_names}"
)


def test_child_model_cache_shared_across_nesting_levels() -> None:
Expand Down
Loading