Skip to content

Commit ee4b6e2

Browse files
committed
Fix TypedDict
1 parent 2b246a1 commit ee4b6e2

3 files changed

Lines changed: 200 additions & 11 deletions

File tree

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,165 @@ takes_formatter({"format": "%(message)s"})
7171
takes_formatter({"factory": object(), "facility": "local0"})
7272
```
7373

74+
Large `dict[str, TypedDict]` literals should still preserve inner `TypedDict` keys after crossing
75+
the large-collection promotion threshold:
76+
77+
```py
78+
from typing import TypedDict
79+
80+
class Entry(TypedDict):
81+
a: str
82+
b: bool
83+
84+
entries: dict[str, Entry] = {
85+
"k0": {"a": "v", "b": False},
86+
"k1": {"a": "v", "b": False},
87+
"k2": {"a": "v", "b": False},
88+
"k3": {"a": "v", "b": False},
89+
"k4": {"a": "v", "b": False},
90+
"k5": {"a": "v", "b": False},
91+
"k6": {"a": "v", "b": False},
92+
"k7": {"a": "v", "b": False},
93+
"k8": {"a": "v", "b": False},
94+
"k9": {"a": "v", "b": False},
95+
"k10": {"a": "v", "b": False},
96+
"k11": {"a": "v", "b": False},
97+
"k12": {"a": "v", "b": False},
98+
"k13": {"a": "v", "b": False},
99+
"k14": {"a": "v", "b": False},
100+
"k15": {"a": "v", "b": False},
101+
"k16": {"a": "v", "b": False},
102+
"k17": {"a": "v", "b": False},
103+
"k18": {"a": "v", "b": False},
104+
"k19": {"a": "v", "b": False},
105+
"k20": {"a": "v", "b": False},
106+
"k21": {"a": "v", "b": False},
107+
"k22": {"a": "v", "b": False},
108+
"k23": {"a": "v", "b": False},
109+
"k24": {"a": "v", "b": False},
110+
"k25": {"a": "v", "b": False},
111+
"k26": {"a": "v", "b": False},
112+
"k27": {"a": "v", "b": False},
113+
"k28": {"a": "v", "b": False},
114+
"k29": {"a": "v", "b": False},
115+
"k30": {"a": "v", "b": False},
116+
"k31": {"a": "v", "b": False},
117+
"k32": {"a": "v", "b": False},
118+
"k33": {"a": "v", "b": False},
119+
"k34": {"a": "v", "b": False},
120+
"k35": {"a": "v", "b": False},
121+
"k36": {"a": "v", "b": False},
122+
"k37": {"a": "v", "b": False},
123+
"k38": {"a": "v", "b": False},
124+
"k39": {"a": "v", "b": False},
125+
"k40": {"a": "v", "b": False},
126+
"k41": {"a": "v", "b": False},
127+
"k42": {"a": "v", "b": False},
128+
"k43": {"a": "v", "b": False},
129+
"k44": {"a": "v", "b": False},
130+
"k45": {"a": "v", "b": False},
131+
"k46": {"a": "v", "b": False},
132+
"k47": {"a": "v", "b": False},
133+
"k48": {"a": "v", "b": False},
134+
"k49": {"a": "v", "b": False},
135+
"k50": {"a": "v", "b": False},
136+
"k51": {"a": "v", "b": False},
137+
"k52": {"a": "v", "b": False},
138+
"k53": {"a": "v", "b": False},
139+
"k54": {"a": "v", "b": False},
140+
"k55": {"a": "v", "b": False},
141+
"k56": {"a": "v", "b": False},
142+
"k57": {"a": "v", "b": False},
143+
"k58": {"a": "v", "b": False},
144+
"k59": {"a": "v", "b": False},
145+
"k60": {"a": "v", "b": False},
146+
"k61": {"a": "v", "b": False},
147+
"k62": {"a": "v", "b": False},
148+
"k63": {"a": "v", "b": False},
149+
"k64": {"a": "v", "b": False},
150+
}
151+
152+
reveal_type(entries["k0"]) # revealed: Entry
153+
```
154+
155+
Large dict literals should still preserve `TypedDict` subscript keys in nested value expressions:
156+
157+
```py
158+
class BacktestContent(TypedDict):
159+
final_balance: float
160+
backtest_start_time: int
161+
162+
def backtest_stats(content: BacktestContent) -> dict[str, object]:
163+
return {
164+
"k0": None,
165+
"k1": None,
166+
"k2": None,
167+
"k3": None,
168+
"k4": None,
169+
"k5": None,
170+
"k6": None,
171+
"k7": None,
172+
"k8": None,
173+
"k9": None,
174+
"k10": None,
175+
"k11": None,
176+
"k12": None,
177+
"k13": None,
178+
"k14": None,
179+
"k15": None,
180+
"k16": None,
181+
"k17": None,
182+
"k18": None,
183+
"k19": None,
184+
"k20": None,
185+
"k21": None,
186+
"k22": None,
187+
"k23": None,
188+
"k24": None,
189+
"k25": None,
190+
"k26": None,
191+
"k27": None,
192+
"k28": None,
193+
"k29": None,
194+
"k30": None,
195+
"k31": None,
196+
"k32": None,
197+
"k33": None,
198+
"k34": None,
199+
"k35": None,
200+
"k36": None,
201+
"k37": None,
202+
"k38": None,
203+
"k39": None,
204+
"k40": None,
205+
"k41": None,
206+
"k42": None,
207+
"k43": None,
208+
"k44": None,
209+
"k45": None,
210+
"k46": None,
211+
"k47": None,
212+
"k48": None,
213+
"k49": None,
214+
"k50": None,
215+
"k51": None,
216+
"k52": None,
217+
"k53": None,
218+
"k54": None,
219+
"k55": None,
220+
"k56": None,
221+
"k57": None,
222+
"k58": None,
223+
"k59": None,
224+
"k60": None,
225+
"k61": None,
226+
"k62": None,
227+
"k63": None,
228+
"final_balance": content["final_balance"],
229+
"backtest_run_start_ts": content["backtest_start_time"],
230+
}
231+
```
232+
74233
Methods that are available on `dict`s are also available on `TypedDict`s:
75234

76235
```py

crates/ty_python_semantic/src/types/infer/builder/subscript.rs

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ use ty_python_core::place::{PlaceExpr, PlaceExprRef};
3838
use ty_python_core::scope::FileScopeId;
3939

4040
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
41+
fn infer_subscript_slice(&mut self, slice: &ast::Expr, tcx: TypeContext<'db>) -> Type<'db> {
42+
self.with_inference_flag(InferenceFlags::PROMOTE_LITERALS, false, |builder| {
43+
builder.infer_expression(slice, tcx)
44+
})
45+
}
46+
4147
pub(super) fn infer_subscript_expression(
4248
&mut self,
4349
subscript: &ast::ExprSubscript,
@@ -54,19 +60,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
5460
ExprContext::Load => self.infer_subscript_load(subscript),
5561
ExprContext::Store => {
5662
let value_ty = self.infer_expression(value, TypeContext::default());
57-
let slice_ty = self.infer_expression(slice, TypeContext::default());
63+
let slice_ty = self.infer_subscript_slice(slice, TypeContext::default());
5864
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
5965
Type::Never
6066
}
6167
ExprContext::Del => {
6268
let value_ty = self.infer_expression(value, TypeContext::default());
63-
let slice_ty = self.infer_expression(slice, TypeContext::default());
69+
let slice_ty = self.infer_subscript_slice(slice, TypeContext::default());
6470
self.validate_subscript_deletion(subscript, value_ty, slice_ty);
6571
Type::Never
6672
}
6773
ExprContext::Invalid => {
6874
let value_ty = self.infer_expression(value, TypeContext::default());
69-
let slice_ty = self.infer_expression(slice, TypeContext::default());
75+
let slice_ty = self.infer_subscript_slice(slice, TypeContext::default());
7076
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
7177
Type::unknown()
7278
}
@@ -119,7 +125,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
119125
{
120126
// Even if we can obtain the subscript type based on the assignments, we still perform default type inference
121127
// (to store the expression type and to report errors).
122-
let slice_ty = self.infer_expression(slice, TypeContext::default());
128+
let slice_ty = self.infer_subscript_slice(slice, TypeContext::default());
123129
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
124130
return ty;
125131
}
@@ -162,7 +168,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
162168
Type::KnownInstance(KnownInstanceType::TypeAliasType(TypeAliasType::ManualPEP695(
163169
_,
164170
))) => {
165-
let slice_ty = self.infer_expression(slice, TypeContext::default());
171+
let slice_ty = self.infer_subscript_slice(slice, TypeContext::default());
166172
let mut variables = FxOrderSet::default();
167173
slice_ty.bind_and_find_all_legacy_typevars(
168174
db,
@@ -263,7 +269,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
263269
return union_type;
264270
}
265271
_ => {
266-
return self.infer_expression(slice, TypeContext::default());
272+
return self.infer_subscript_slice(slice, TypeContext::default());
267273
}
268274
},
269275
SpecialFormType::Type => {
@@ -332,7 +338,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
332338
return self.infer_explicit_type_alias_specialization(subscript, value_ty, false);
333339
}
334340
Type::Dynamic(DynamicType::Unknown) => {
335-
let slice_ty = self.infer_expression(slice, TypeContext::default());
341+
let slice_ty = self.infer_subscript_slice(slice, TypeContext::default());
336342
let mut variables = FxOrderSet::default();
337343
slice_ty.bind_and_find_all_legacy_typevars(
338344
db,
@@ -345,7 +351,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
345351
_ => {}
346352
}
347353

348-
let slice_ty = self.infer_expression(slice, TypeContext::default());
354+
let slice_ty = self.infer_subscript_slice(slice, TypeContext::default());
349355
let result_ty = self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
350356
self.narrow_expr_with_applicable_constraints(subscript, result_ty, &constraint_keys)
351357
}
@@ -1070,7 +1076,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
10701076
} = target;
10711077

10721078
let object_ty = self.infer_expression(object, TypeContext::default());
1073-
let mut infer_slice_ty = |builder: &mut Self, tcx| builder.infer_expression(slice, tcx);
1079+
let mut infer_slice_ty =
1080+
|builder: &mut Self, tcx| builder.infer_subscript_slice(slice, tcx);
10741081

10751082
self.validate_subscript_assignment_impl(
10761083
target,

crates/ty_python_semantic/src/types/infer/builder/typed_dict.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::types::diagnostic::{
1111
INVALID_ARGUMENT_TYPE, INVALID_TYPE_FORM, MISSING_ARGUMENT, TOO_MANY_POSITIONAL_ARGUMENTS,
1212
UNKNOWN_ARGUMENT, report_mismatched_type_name,
1313
};
14+
use crate::types::infer::InferenceFlags;
1415
use crate::types::infer::builder::DeferredExpressionState;
1516
use crate::types::special_form::TypeQualifier;
1617
use crate::types::typed_dict::{
@@ -61,6 +62,25 @@ impl<'expr> TypedDictConstructorForm<'expr> {
6162
}
6263

6364
impl<'db> TypeInferenceBuilder<'db, '_> {
65+
/// Preserve string literal keys while inferring `TypedDict` fields, even when the enclosing
66+
/// collection enables large-literal promotion.
67+
fn infer_typed_dict_key_expression(&mut self, key: &ast::Expr) -> Type<'db> {
68+
if let Some(key_ty) = self.try_expression_type(key) {
69+
return key_ty.as_string_literal().map_or_else(
70+
|| {
71+
key.as_string_literal_expr().map_or(key_ty, |literal| {
72+
Type::string_literal(self.db(), literal.value.to_str())
73+
})
74+
},
75+
|_| key_ty,
76+
);
77+
}
78+
79+
self.with_inference_flag(InferenceFlags::PROMOTE_LITERALS, false, |builder| {
80+
builder.infer_expression(key, TypeContext::default())
81+
})
82+
}
83+
6484
/// Infer a `TypedDict(name, fields)` call expression.
6585
///
6686
/// This method *does not* call `infer_expression` on the object being called;
@@ -309,7 +329,10 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
309329
let typed_dict_items = typed_dict.items(self.db());
310330

311331
for item in items {
312-
let key_ty = self.infer_optional_expression(item.key.as_ref(), TypeContext::default());
332+
let key_ty = item
333+
.key
334+
.as_ref()
335+
.map(|key| self.infer_typed_dict_key_expression(key));
313336
if let Some((key, key_ty)) = item.key.as_ref().zip(key_ty) {
314337
item_types.insert(key.node_index().load(), key_ty);
315338
}
@@ -436,7 +459,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
436459
let value_tcx = item
437460
.key
438461
.as_ref()
439-
.map(|key| self.get_or_infer_expression(key, TypeContext::default()))
462+
.map(|key| self.infer_typed_dict_key_expression(key))
440463
.and_then(Type::as_string_literal)
441464
.and_then(|key| items.get(key.value(self.db())))
442465
.map(|field| TypeContext::new(Some(field.declared_ty)))

0 commit comments

Comments
 (0)