From 3cf876e27a406822311cdb941e0112a0b61ba8cb Mon Sep 17 00:00:00 2001 From: Zanie Blue Date: Fri, 19 Dec 2025 11:20:15 -0600 Subject: [PATCH] propagate the narrowed element instead of skipping narrowing entirely --- .../resources/mdtest/bidirectional.md | 45 +++++++++++++++++++ .../src/types/infer/builder.rs | 41 +++++++++++------ 2 files changed, 72 insertions(+), 14 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index bae848c4a7..3208368812 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -362,3 +362,48 @@ def _(x: X, flag: bool): # error: [possibly-unresolved-reference] "Name `y` used when possibly not defined" x.td = {"y": y} ``` + +## Nested calls with union parameter types + +When inferring nested function calls where the parameter type is a union, the narrowed type context +should be propagated through nested calls rather than re-narrowing at each level. This avoids +exponential blowup while preserving type precision. + +```toml +[environment] +python-version = "3.12" +``` + +`nested_union_inference.py`: + +```py +def f(x: list[int | None]): ... +def g[T](x: T) -> list[T]: + return [x] + +# Bidirectional inference uses the type context to solve generics. +# The type context list[int | None] guides the inference of g(1). +f(reveal_type(g(1))) # revealed: list[int | None] +``` + +`nested_same_union.py`: + +When a function's parameter type is the same union as the outer type context, the narrowed element +should be propagated instead of re-narrowing (which would cause exponential blowup). + +```py +type CoreSchema = A | B | C | D + +class A: ... +class B: ... +class C: ... +class D: ... + +def outer(x: CoreSchema) -> None: ... +def inner(x: CoreSchema) -> D: + return D() + +# The outer call narrows CoreSchema to D. The inner call's parameter is also CoreSchema, +# so the narrowed type D is propagated instead of trying all 4 elements again. +outer(reveal_type(inner(inner(inner(D()))))) # revealed: D +``` diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index b3cd5708d5..8cb0001cad 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -311,13 +311,15 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// the right hand side of an annotated assignment in a class that is a dataclass). dataclass_field_specifiers: SmallVec<[Type<'db>; NUM_FIELD_SPECIFIERS_INLINE]>, - /// Unions we're currently narrowing against in ancestor calls. + /// Unions we're currently narrowing against in ancestor calls, mapped to the element + /// we're currently trying. /// /// When inferring a call expression with a union type context, we try narrowing to each /// element of the union. If nested calls have the same union as their parameter type, /// this would lead to exponential blowup. By tracking which unions we're already narrowing - /// against, we skip redundant nested narrowing. - narrowing_unions: FxHashSet>, + /// against and what element we're narrowing to, nested calls can use the narrowed element + /// directly instead of re-narrowing. + narrowing_unions: FxHashMap, Type<'db>>, } impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { @@ -356,7 +358,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { cycle_recovery: None, all_definitely_bound: true, dataclass_field_specifiers: SmallVec::new(), - narrowing_unions: FxHashSet::default(), + narrowing_unions: FxHashMap::default(), } } @@ -7016,12 +7018,23 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ) -> Result<(), CallErrorKind> { let db = self.db(); - // If the type context is a union, attempt to narrow to a specific element. - // However, skip narrowing if we're already narrowing against the same union - // in an ancestor call to avoid exponential blowup with deeply nested calls. + // If we're already narrowing against this union in an ancestor call, propagate + // the narrowed element instead of re-narrowing (which would cause exponential blowup). + let call_expression_tcx = match call_expression_tcx.annotation { + Some(Type::Union(union)) => { + if let Some(&narrowed_ty) = self.narrowing_unions.get(&union) { + TypeContext::new(Some(narrowed_ty)) + } else { + call_expression_tcx + } + } + _ => call_expression_tcx, + }; + + // If the type context is a union we haven't seen, attempt to narrow to a specific element. let (narrow_union, narrow_targets): (Option>, &[_]) = match call_expression_tcx.annotation { - Some(Type::Union(union)) if !self.narrowing_unions.contains(&union) => { + Some(Type::Union(union)) if !self.narrowing_unions.contains_key(&union) => { // TODO: We could theoretically attempt to narrow to every element of // the power set of this union. However, this leads to an exponential // explosion of inference attempts, and is rarely needed in practice. @@ -7030,12 +7043,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { _ => (None, &[]), }; - // Track that we're narrowing against this union to prevent nested calls - // from redundantly narrowing against the same union. - if let Some(union) = narrow_union { - self.narrowing_unions.insert(union); - } - // We silence diagnostics until we successfully narrow to a specific type. let mut speculated_bindings = bindings.clone(); let was_in_multi_inference = self.context.set_multi_inference(true); @@ -7043,6 +7050,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut try_narrow = |narrowed_ty| { let narrowed_tcx = TypeContext::new(Some(narrowed_ty)); + // Track that we're narrowing this union to this specific element, + // so nested calls can propagate the narrowed type. + if let Some(union) = narrow_union { + self.narrowing_unions.insert(union, narrowed_ty); + } + // Attempt to infer the argument types using the narrowed type context. self.infer_all_argument_types( ast_arguments,