diff --git a/mypy/checker.py b/mypy/checker.py index 59571954e0f7..224f3e63287c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -379,6 +379,17 @@ def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal return False +class _AssignmentExprSeeker(TraverserVisitor): + """Check if an expression tree contains a walrus operator (:=).""" + + def __init__(self) -> None: + super().__init__() + self.found = False + + def visit_assignment_expr(self, o: AssignmentExpr) -> None: + self.found = True + + class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi, SplittingVisitor): """Mypy type checker. @@ -4756,8 +4767,12 @@ def infer_rvalue_with_fallback_context( # binder.accumulate_type_assignments() and assign the types inferred for type # context that is ultimately used. This is however tricky with redefinitions. # For now we simply disable second accept in cases known to cause problems, - # see e.g. testAssignToOptionalTupleWalrus. - binder_version = self.binder.version + # see e.g. testAssignToOptionalTupleWalrus. We only need to scan for walrus + # when union fallback is otherwise applicable. + union_fallback_possible = preferred_context is not None and isinstance( + get_proper_type(lvalue_type), UnionType + ) + has_walrus = union_fallback_possible and self._has_assignment_expr(rvalue) fallback_context_used = False with ( @@ -4784,11 +4799,7 @@ def infer_rvalue_with_fallback_context( # Try re-inferring r.h.s. in empty context for union with explicit annotation, # and use it results in a narrower type. This helps with various practical # examples, see e.g. testOptionalTypeNarrowedByGenericCall. - union_fallback = ( - preferred_context is not None - and isinstance(get_proper_type(lvalue_type), UnionType) - and binder_version == self.binder.version - ) + union_fallback = union_fallback_possible and not has_walrus # Skip literal types, as they have special logic (for better errors). try_fallback = redefinition_fallback or union_fallback or argument_redefinition_fallback @@ -5091,6 +5102,13 @@ def visit_return_stmt(self, s: ReturnStmt) -> None: self.check_return_stmt(s) self.binder.unreachable() + @staticmethod + def _has_assignment_expr(expr: Expression) -> bool: + """Check if an expression tree contains a walrus operator (:=).""" + seeker = _AssignmentExprSeeker() + expr.accept(seeker) + return seeker.found + def infer_context_dependent( self, expr: Expression, type_ctx: Type, allow_none_func_call: bool ) -> ProperType: diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index eda17c820d42..49b6612fb77b 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -1514,6 +1514,19 @@ i = i if isinstance(i, int) else b reveal_type(i) # N: Revealed type is "Any | builtins.int" [builtins fixtures/isinstance.pyi] +[case testTypeNarrowingByReassignmentGeneratorTernary] +from typing import Iterable, Union + +def foo(args: Union[Iterable[Union[str, int]], str, int]) -> Iterable[str]: + if isinstance(args, (str, int)): + args = (args,) + args = ( + arg if isinstance(arg, str) else str(arg) + for arg in args + ) + return args +[builtins fixtures/isinstance.pyi] + [case testLambdaInferenceUsesNarrowedTypes] from typing import Optional, Callable