Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,14 @@ r2: RecursiveList[int] = [1, [1, 2, 3]]
r3: RecursiveList[int] = "a"
# error: [invalid-assignment]
r4: RecursiveList[int] = ["a"]
# TODO: this should be an error
# error: [invalid-assignment] "Object of type `list[int | list[RecursiveList[int]] | list[int | list[RecursiveList[int]] | str]]` is not assignable to `RecursiveList[int]`"
r5: RecursiveList[int] = [1, ["a"]]

def _(x: RecursiveList[int]):
if isinstance(x, list):
# TODO: should be `list[RecursiveList[int]]
reveal_type(x[0]) # revealed: int | list[Any]
reveal_type(x[0]) # revealed: int | list[RecursiveList[int]]
if isinstance(x, list) and isinstance(x[0], list):
# TODO: should be `list[RecursiveList[int]]`
reveal_type(x[0]) # revealed: list[Any]
reveal_type(x[0]) # revealed: list[RecursiveList[int]]
```

Assignment checks respect structural subtyping, i.e. type aliases with the same structure are
Expand Down Expand Up @@ -413,7 +411,7 @@ d1: DivergentList[int] = []
d2: DivergentList[int] = [1]
# error: [invalid-assignment]
d3: DivergentList[int] = ["a"]
# TODO: this should be an error
# error: [invalid-assignment] "Object of type `list[list[DivergentList[int]] | list[list[DivergentList[int]] | int]]` is not assignable to `DivergentList[int]`"
d4: DivergentList[int] = [[1]]

def _(x: DivergentList[int]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,11 @@ def f(x: OptNestedInt) -> None:
reveal_type(x) # revealed: int | tuple[OptNestedInt, ...] | None
if x is not None:
reveal_type(x) # revealed: int | tuple[OptNestedInt, ...]

type RecursiveList = list[RecursiveList]

def g(x: RecursiveList):
reveal_type(x[0]) # revealed: list[RecursiveList]
```

### Invalid self-referential
Expand Down Expand Up @@ -431,8 +436,7 @@ type Foo[T] = list[T] | Bar[T]
type Bar[T] = int | Foo[T]

def _(x: Bar[int]):
# TODO: should be `int | list[int]`
reveal_type(x) # revealed: int | list[int] | Any
reveal_type(x) # revealed: int | list[int]
```

### With legacy generic
Expand Down Expand Up @@ -579,7 +583,7 @@ type A = list[Union["A", str]]
def f(x: A):
reveal_type(x) # revealed: list[A | str]
for item in x:
reveal_type(item) # revealed: list[Any | str] | str
reveal_type(item) # revealed: list[A | str] | str
```

#### With new-style union
Expand All @@ -590,7 +594,7 @@ type A = list[A | str]
def f(x: A):
reveal_type(x) # revealed: list[A | str]
for item in x:
reveal_type(item) # revealed: list[Any | str] | str
reveal_type(item) # revealed: list[A | str] | str
```

#### With Optional
Expand All @@ -603,7 +607,7 @@ type A = list[Optional[Union["A", str]]]
def f(x: A):
reveal_type(x) # revealed: list[A | str | None]
for item in x:
reveal_type(item) # revealed: list[Any | str | None] | str | None
reveal_type(item) # revealed: list[A | str | None] | str | None
```

### Tuple comparison
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ class Bar:
def g(self, x: Scalar | ArrayNd) -> None:
pass

# TODO: should be `bound method Bar.g(x: Scalar | ArrayNd) -> None`
reveal_type(Bar().g) # revealed: bound method Bar.g(x: Scalar | list[Any] | tuple[Any]) -> None
reveal_type(Bar().g) # revealed: bound method Bar.g(x: Scalar | ArrayNd) -> None

type GenericArray1d[T] = list[T] | tuple[T]

Expand Down
89 changes: 49 additions & 40 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ impl<'db> ApplyTypeMappingVisitor<'db> {

pub(crate) fn visit(
&self,
db: &'db dyn Db,
ty: Type<'db>,
type_mapping: &TypeMapping<'_, 'db>,
func: impl FnOnce() -> Type<'db>,
Expand All @@ -272,15 +273,15 @@ impl<'db> ApplyTypeMappingVisitor<'db> {
TypeMapping::Materialize(MaterializationKind::Top) => self
.top_materialization
.get_or_init(TypeTransformer::default)
.visit(ty, func),
.visit_type(db, ty, func),
TypeMapping::Materialize(MaterializationKind::Bottom) => self
.bottom_materialization
.get_or_init(TypeTransformer::default)
.visit(ty, func),
.visit_type(db, ty, func),
_ => self
.default
.get_or_init(TypeTransformer::default)
.visit(ty, func),
.visit_type(db, ty, func),
}
}

Expand Down Expand Up @@ -5584,7 +5585,7 @@ impl<'db> Type<'db> {
Type::TypeVar(bound_typevar) => bound_typevar.apply_type_mapping_impl(db, type_mapping, visitor),
Type::KnownInstance(known_instance) => known_instance.apply_type_mapping_impl(db, type_mapping, tcx, visitor),

Type::FunctionLiteral(function) => visitor.visit(self, type_mapping, || {
Type::FunctionLiteral(function) => visitor.visit(db, self, type_mapping, || {
match type_mapping {
// Promote the types within the signature before promoting the signature to its
// callable form.
Expand Down Expand Up @@ -5632,7 +5633,7 @@ impl<'db> Type<'db> {
instance.apply_type_mapping_impl(db, type_mapping, tcx, visitor)
},

Type::NewTypeInstance(newtype) => visitor.visit(self, type_mapping, || {
Type::NewTypeInstance(newtype) => visitor.visit(db, self, type_mapping, || {
Type::NewTypeInstance(newtype.map_base_class_type(db, |class_type| {
class_type.apply_type_mapping_impl(db, type_mapping, tcx, visitor)
}))
Expand Down Expand Up @@ -5718,7 +5719,7 @@ impl<'db> Type<'db> {
}

// TODO(jelle): Materialize should be handled differently, since TypeIs is invariant
Type::TypeIs(type_is) => visitor.visit(self, type_mapping, || {
Type::TypeIs(type_is) => visitor.visit(db, self, type_mapping, || {
type_is.with_type(
db,
type_is
Expand All @@ -5727,7 +5728,7 @@ impl<'db> Type<'db> {
)
}),

Type::TypeGuard(type_guard) => visitor.visit(self, type_mapping, || {
Type::TypeGuard(type_guard) => visitor.visit(db, self, type_mapping, || {
type_guard.with_type(
db,
type_guard
Expand All @@ -5737,42 +5738,50 @@ impl<'db> Type<'db> {
}),

Type::TypeAlias(alias) => {
// For EagerExpansion, expand the raw value type. This path relies on Salsa's cycle
// detection rather than the visitor's cycle detection, because the visitor tracks
// Type values and `RecursiveList` is different from `RecursiveList[T]`.
if TypeMapping::EagerExpansion == *type_mapping {
return alias.raw_value_type(db).expand_eagerly(db);
}

// Do not call `value_type` here. `value_type` does the specialization internally, so `apply_type_mapping` is
// performed without `visitor` inheritance. In the case of recursive type aliases, this leads to infinite recursion.
// Instead, call `raw_value_type` and perform the specialization after the `visitor` cache has been created.
//
// IMPORTANT: All processing must happen inside a single visitor.visit() call so that if we encounter
// this same TypeAlias again (e.g., in `type RecursiveT = int | tuple[RecursiveT, ...]`), the visitor
// will detect the cycle and return the fallback value.
let mapped = visitor.visit(self, type_mapping, || {
match type_mapping {
TypeMapping::EagerExpansion => unreachable!("handled above"),

_ => {
match type_mapping {
// For EagerExpansion, expand the raw value type. This path relies on Salsa's cycle
// detection rather than the visitor's cycle detection, because the visitor tracks
// Type values and `RecursiveList` is different from `RecursiveList[T]`.
TypeMapping::EagerExpansion => {
alias.raw_value_type(db).expand_eagerly(db)
},
// When specializing a generic type alias, instead of specializing the expanded type, the type alias itself is specialized.
// Without this special handling, recursive type aliases would result in cycles, returning an unspecialized fallback type.
TypeMapping::ApplySpecialization(specialization)
| TypeMapping::ApplySpecializationWithMaterialization { specialization, .. }
if matches!(specialization, ApplySpecialization::Specialization(_) | ApplySpecialization::Partial { .. }) => {
let current_specialization = specialization.as_specialization(db).unwrap();
Type::TypeAlias(alias.apply_specialization(
db,
|generic_context| {
alias
.specialization(db)
.unwrap_or_else(|| generic_context.identity_specialization(db))
.apply_specialization(db, current_specialization)
},
))
}
_ => {
// Do not call `value_type` here. `value_type` does the specialization internally, so `apply_type_mapping` is
// performed without `visitor` inheritance. In the case of recursive type aliases, this leads to infinite recursion.
// Instead, call `raw_value_type` and perform the specialization after the `visitor` cache has been created.
//
// IMPORTANT: All processing must happen inside a single visitor.visit() call so that if we encounter
// this same TypeAlias again (e.g., in `type RecursiveT = int | tuple[RecursiveT, ...]`), the visitor
// will detect the cycle and return the fallback value.
let mapped = visitor.visit(db, self, type_mapping, || {
let value_type = alias.raw_value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor);
alias.apply_function_specialization(db, value_type).apply_type_mapping_impl(db, type_mapping, tcx, visitor)
});

// If the type mapping does not result in any change to this type alias, keep the
// alias node instead of eagerly expanding it.
if alias.value_type(db) == mapped {
self
} else {
mapped
}
}
});

let is_recursive = any_over_type(db, alias.raw_value_type(db).expand_eagerly(db), false, |ty| ty.is_divergent());

// If the type mapping does not result in any change to this (non-recursive) type alias, do not expand it.
//
// TODO: The rule that recursive type aliases must be expanded could potentially be removed, but doing so would
// currently cause a stack overflow, as the current recursive type alias specialization/expansion mechanism is
// incomplete.
if !is_recursive && alias.value_type(db) == mapped {
self
} else {
mapped
}
}

Expand Down Expand Up @@ -6802,7 +6811,7 @@ impl<'db> TypeMapping<'_, 'db> {
specialization,
materialization_kind,
} => TypeMapping::ApplySpecializationWithMaterialization {
specialization: specialization.clone(),
specialization: *specialization,
materialization_kind: materialization_kind.flip(),
},
TypeMapping::Promote(mode, kind) => TypeMapping::Promote(mode.flip(), *kind),
Expand Down
29 changes: 16 additions & 13 deletions crates/ty_python_semantic/src/types/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,15 +885,18 @@ enum NestedSubstitutionSide {
/// Identifies one nested-typevar substitution that has been applied while saturating a single
/// BDD path.
///
/// We intentionally key this by the constraint that we substitute _into_ and the typevar that we
/// substitute _for_, but not by the replacement type. For the pathological cases that matter for
/// performance, the same nested substitution shape can keep producing ever-deeper replacement
/// types (for instance, repeated `Iterable[...]` wrapping). Recording only the substitution site
/// lets [`PathAssignments`] apply that substitution at most once per path, which preserves the
/// initial cross-typevar relationship without repeatedly unfolding the same pattern.
/// We key this by the typevar of the constrained constraint (which stays the same across an
/// entire chain of derivations against a single root constraint), the typevar that we substitute
/// _for_, and the side. We deliberately do _not_ key by the constraint id we substitute into,
/// because each nested substitution produces a new derived constraint, and if we keyed by that
/// id the next derivation step would have a different id and the repeat-guard would never fire.
/// The pathological cases that matter for performance involve repeated wrapping (e.g.
/// `Iterable[...]` layers) that keeps producing ever-deeper replacement types while targeting
/// the same constrained typevar; keying by the constrained typevar plus the substituted typevar
/// lets [`PathAssignments`] apply each substitution shape at most once per path.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)]
struct NestedSubstitution {
substituted_into: ConstraintId,
constrained_typevar: TypeVarId,
substituted_typevar: TypeVarId,
side: NestedSubstitutionSide,
}
Expand All @@ -909,12 +912,12 @@ struct DerivedConstraint {
fn nested_substitution<'db>(
db: &'db dyn Db,
builder: &ConstraintSetBuilder<'db>,
substituted_into: ConstraintId,
constrained_typevar: BoundTypeVarInstance<'db>,
substituted_typevar: BoundTypeVarInstance<'db>,
side: NestedSubstitutionSide,
) -> NestedSubstitution {
NestedSubstitution {
substituted_into,
constrained_typevar: builder.typevar_id(db, constrained_typevar),
substituted_typevar: builder.typevar_id(db, substituted_typevar),
side,
}
Expand Down Expand Up @@ -4870,7 +4873,7 @@ impl SequentMap {
Some(nested_substitution(
db,
builder,
constrained_constraint,
constrained_typevar,
bound_typevar,
NestedSubstitutionSide::Upper,
)),
Expand Down Expand Up @@ -4934,7 +4937,7 @@ impl SequentMap {
Some(nested_substitution(
db,
builder,
constrained_constraint,
constrained_typevar,
bound_typevar,
NestedSubstitutionSide::Lower,
)),
Expand Down Expand Up @@ -5030,7 +5033,7 @@ impl SequentMap {
Some(nested_substitution(
db,
builder,
constrained_constraint,
constrained_typevar,
nested_typevar,
NestedSubstitutionSide::Upper,
)),
Expand Down Expand Up @@ -5074,7 +5077,7 @@ impl SequentMap {
Some(nested_substitution(
db,
builder,
constrained_constraint,
constrained_typevar,
nested_typevar,
NestedSubstitutionSide::Lower,
)),
Expand Down
Loading
Loading