diff --git a/dlt/common/libs/pydantic.py b/dlt/common/libs/pydantic.py index cdd843bec8..dbb36d93ed 100644 --- a/dlt/common/libs/pydantic.py +++ b/dlt/common/libs/pydantic.py @@ -46,6 +46,7 @@ raise ImportError(f"Found pydantic {PYDANTIC_VERSION} but dlt requires pydantic>=2.0") from pydantic import ( BaseModel, + Field, ValidationError, Json, create_model, @@ -127,29 +128,19 @@ def _build_discriminator_map( return None ann = root_field.annotation - discriminator: Optional[str] = None + discriminator = _get_field_discriminator(root_field) union_args: Tuple[Any, ...] = () - if is_annotated(ann): + if root_field.discriminator is not None: + # on pydantic >= 2.13 the discriminator is stored on the field object, and the + # annotation may be either the bare union or Annotated[Union[A, B], ] + union_ann = get_args(ann)[0] if is_annotated(ann) else ann + union_args = get_args(union_ann) + elif is_annotated(ann): + # on pydantic < 2.13 `ann` is an annotated union type, e.g. Annotated[Union[A, B], Field(discriminator="kind")] args = get_args(ann) - # metadata may be FieldInfo directly or wrapped in a tuple by _process_annotation - for a in args[1:]: - items = a if isinstance(a, (list, tuple)) else (a,) - for item in items: - if isinstance(item, FieldInfo) and isinstance(item.discriminator, str): - discriminator = item.discriminator - break - if discriminator: - break union_args = get_args(args[0]) - # pydantic 2.13+: annotation is plain Union (not Annotated), discriminator - # is on root_field directly - if not discriminator and isinstance(root_field.discriminator, str): - discriminator = root_field.discriminator - if not union_args: - union_args = get_args(ann) - if not discriminator or not union_args: return None @@ -162,6 +153,25 @@ def _build_discriminator_map( return discriminator, mapping +def _get_field_discriminator(field: FieldInfo) -> Optional[str]: + # pydantic >= 2.13 stores the discriminator field name directly on the field object + if isinstance(field.discriminator, str): + return field.discriminator + + if not is_annotated(field.annotation): + return None + + # on pydantic < 2.13 `ann` is an annotated union type, e.g. Annotated[Union[A, B], Field(discriminator="kind")] + args = get_args(field.annotation) + for a in args[1:]: + items = a if isinstance(a, (list, tuple)) else (a,) + for item in items: + if isinstance(item, FieldInfo) and isinstance(item.discriminator, str): + return item.discriminator + + return None + + def resolve_variant_model( model: Type[BaseModel], item: Any, @@ -344,7 +354,9 @@ def apply_schema_contract_to_model( {"__module__": model.__module__}, ) else: - model = create_model(model.__name__ + "Any", **{n: (Any, None) for n in model.model_fields}) # type: ignore + model = create_model( + model.__name__ + "Any", **{n: (Any, None) for n in model.model_fields} + ) # type: ignore elif data_mode == "discard_value": raise NotImplementedError( "`data_mode='discard_value'`. Cannot discard defined fields with validation errors" @@ -369,7 +381,7 @@ def _process_annotation(t_: Type[Any]) -> Type[Any]: """Recursively recreates models with applied schema contract""" if is_annotated(t_): a_t, *a_m = get_args(t_) - return Annotated[_process_annotation(a_t), tuple(a_m)] # type: ignore[return-value] + return Annotated[(_process_annotation(a_t), *a_m)] # type: ignore[return-value] origin = get_origin(t_) # tuple must be checked before is_list_generic_type (tuple is a Sequence) if origin is tuple: @@ -417,13 +429,23 @@ def _rebuild_annotated(f: Any) -> Type[Any]: if getattr(model, "__pydantic_root_model__", False): root_field = model.model_fields.get("root") if root_field: - processed_ann = _process_annotation(_rebuild_annotated(root_field)) - # preserve discriminator that pydantic 2.13+ strips from Annotated metadata - if isinstance(root_field.discriminator, str) and not is_annotated(processed_ann): - processed_ann = Annotated[processed_ann, FieldInfo(discriminator=root_field.discriminator)] # type: ignore[assignment] + # on pydantic >= 2.13 `_rebuild_annotated` might not return an annotated type, + # so processed_ann could be the bare Union type, e.g. Union[A, B] + processed_ann: Any = _process_annotation(_rebuild_annotated(root_field)) + discriminator = _get_field_discriminator(root_field) + if discriminator is not None and root_field.discriminator is not None: + if is_annotated(processed_ann): + # we have other metadata (besides the discriminator) that we should preserve + args = get_args(processed_ann) + processed_ann = Annotated[ + (args[0], Field(discriminator=discriminator), *args[1:]) + ] + else: + # no other metadata, so we can rebuild the annotated type with the discriminator + processed_ann = Annotated[processed_ann, Field(discriminator=discriminator)] new_rm = type( model.__name__ + "Extra" + extra.title(), - (PydanticRootModel[processed_ann],), # type: ignore[valid-type] + (PydanticRootModel[processed_ann],), {"__module__": model.__module__}, ) if original_dlt_config: @@ -431,7 +453,7 @@ def _rebuild_annotated(f: Any) -> Type[Any]: return new_rm processed_fields = { - n: (_process_annotation(_rebuild_annotated(f)), f) for n, f in model.model_fields.items() + n: (_process_annotation(f.annotation), f) for n, f in model.model_fields.items() } # use __base__ to inherit validators (@field_validator, @model_validator) diff --git a/tests/libs/test_pydantic.py b/tests/libs/test_pydantic.py index f2e1caec0c..0e55685f34 100644 --- a/tests/libs/test_pydantic.py +++ b/tests/libs/test_pydantic.py @@ -47,6 +47,8 @@ ) from dlt.common.warnings import Dlt100DeprecationWarning from pydantic import ( + AfterValidator, + VERSION as PYDANTIC_VERSION, UUID4, BaseModel, Field, @@ -1126,6 +1128,178 @@ def test_apply_contract_root_model_discriminator_preserved() -> None: mutated.model_validate({"kind": "unknown", "id": 3}) +def test_apply_contract_root_model_preserves_root_after_validator() -> None: + """Root-level validator metadata must survive RootModel reconstruction.""" + + def validate_click(value: Any) -> Any: + if isinstance(value, Click) and value.element_id == "forbidden": + raise ValueError("forbidden element") + return value + + class Click(BaseModel): + kind: Literal["click"] + element_id: str + + class Purchase(BaseModel): + kind: Literal["purchase"] + amount: float + + U = Annotated[ + Union[Click, Purchase], + Field(discriminator="kind"), + AfterValidator(validate_click), + ] + + class Event(RootModel[U]): + pass + + mutated: Any = apply_schema_contract_to_model(Event, "freeze", "freeze") + validated = mutated.model_validate({"kind": "click", "element_id": "btn_1"}) + + assert isinstance(validated.root, Click) + assert validated.root.kind == "click" + with pytest.raises(ValidationError, match="forbidden element"): + mutated.model_validate({"kind": "click", "element_id": "forbidden"}) + + +def test_apply_contract_preserves_non_root_after_validator() -> None: + """Non-root Annotated validator metadata must survive model reconstruction.""" + + def reject_negative(value: int) -> int: + if value < 0: + raise ValueError("negative value") + return value + + class ModelWithAnnotatedValidator(BaseModel): + value: Annotated[int, AfterValidator(reject_negative)] + + mutated: Any = apply_schema_contract_to_model(ModelWithAnnotatedValidator, "freeze", "freeze") + validated = mutated.model_validate({"value": 1}) + + assert validated.value == 1 + with pytest.raises(ValidationError, match="negative value"): + mutated.model_validate({"value": -1}) + + +def test_apply_contract_preserves_multiple_annotated_metadata_entries() -> None: + """Annotated metadata entries stay separate after model reconstruction.""" + + class Child(BaseModel): + x: int + + class ModelWithAnnotatedMetadata(BaseModel): + field: Annotated[Child, "meta1", "meta2"] + + mutated: Any = apply_schema_contract_to_model(ModelWithAnnotatedMetadata, "freeze", "freeze") + rebuilt = mutated.model_fields["field"].rebuild_annotation() + + assert get_origin(rebuilt) is Annotated + assert get_args(rebuilt)[1:] == ("meta1", "meta2") + + +def test_apply_contract_preserves_nested_annotated_metadata_entries() -> None: + """Nested Annotated metadata entries must remain separate slots, not be packed + into a single tuple, when _process_annotation recurses through containers.""" + + class Child(BaseModel): + x: int + + class M(BaseModel): + items: List[Annotated[Child, "meta1", "meta2"]] + + mutated: Any = apply_schema_contract_to_model(M, "freeze", "freeze") + inner = get_args(mutated.model_fields["items"].annotation)[0] + inner_args = get_args(inner) + + assert inner_args[0].__name__.endswith("ExtraForbid") + assert inner_args[1:] == ("meta1", "meta2") + + +def test_apply_contract_non_root_discriminated_union_with_validator() -> None: + """Non-root field with a discriminated union plus extra metadata keeps both + the discriminator and the validator after model reconstruction.""" + + class Click(BaseModel): + kind: Literal["click"] + element_id: str + + class Purchase(BaseModel): + kind: Literal["purchase"] + amount: float + + def reject_forbidden(value: Any) -> Any: + if isinstance(value, Click) and value.element_id == "forbidden": + raise ValueError("forbidden element") + return value + + class Container(BaseModel): + event: Annotated[ + Union[Click, Purchase], + Field(discriminator="kind"), + AfterValidator(reject_forbidden), + ] + + mutated: Any = apply_schema_contract_to_model(Container, "freeze", "freeze") + + ok = mutated.model_validate({"event": {"kind": "click", "element_id": "btn_1"}}) + assert ok.event.kind == "click" + with pytest.raises(ValidationError): + mutated.model_validate({"event": {"kind": "unknown", "element_id": "btn_1"}}) + + with pytest.raises(ValidationError, match="forbidden element"): + mutated.model_validate({"event": {"kind": "click", "element_id": "forbidden"}}) + + +@pytest.mark.skipif( + tuple(int(part) for part in PYDANTIC_VERSION.split(".")[:2]) < (2, 13), + reason="Requires pydantic >= 2.13 field.discriminator behavior", +) +def test_build_discriminator_map_preserves_new_root_discriminator_with_extra_metadata() -> None: + """Discriminator extraction works when pydantic stores it on the root field.""" + + def validate_event(value: Any) -> Any: + return value + + class Click(BaseModel): + kind: Literal["click"] + element_id: str + + class Purchase(BaseModel): + kind: Literal["purchase"] + amount: float + + U = Annotated[ + Union[Click, Purchase], + AfterValidator(validate_event), + Field(discriminator="kind"), + ] + + class Event(RootModel[U]): + pass + + root_field = Event.model_fields["root"] + assert root_field.discriminator == "kind" + + result = _build_discriminator_map(Event) + assert result is not None + disc_field, mapping = result + assert disc_field == "kind" + assert set(mapping.keys()) == {"click", "purchase"} + + mutated: Any = apply_schema_contract_to_model(Event, "freeze", "freeze") + root_field = mutated.model_fields["root"] + assert root_field.discriminator == "kind" + + result = _build_discriminator_map(mutated) + assert result is not None + disc_field, mapping = result + assert disc_field == "kind" + assert set(mapping.keys()) == {"click", "purchase"} + + validated = mutated.model_validate({"kind": "click", "element_id": "btn_1"}) + assert validated.model_dump() == {"kind": "click", "element_id": "btn_1"} + + def test_extra_schema_contract_conflict_warning() -> None: """Warns when model extra contradicts schema_contract columns setting."""