propagate the narrowed element instead of skipping narrowing entirely

This commit is contained in:
Zanie Blue
2025-12-19 11:20:15 -06:00
parent 2fd90afc7a
commit 3cf876e27a
2 changed files with 72 additions and 14 deletions

View File

@@ -311,13 +311,15 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
/// the right hand side of an annotated assignment in a class that is a dataclass).
dataclass_field_specifiers: SmallVec<[Type<'db>; NUM_FIELD_SPECIFIERS_INLINE]>,
/// Unions we're currently narrowing against in ancestor calls.
/// Unions we're currently narrowing against in ancestor calls, mapped to the element
/// we're currently trying.
///
/// When inferring a call expression with a union type context, we try narrowing to each
/// element of the union. If nested calls have the same union as their parameter type,
/// this would lead to exponential blowup. By tracking which unions we're already narrowing
/// against, we skip redundant nested narrowing.
narrowing_unions: FxHashSet<UnionType<'db>>,
/// against and what element we're narrowing to, nested calls can use the narrowed element
/// directly instead of re-narrowing.
narrowing_unions: FxHashMap<UnionType<'db>, Type<'db>>,
}
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
@@ -356,7 +358,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
cycle_recovery: None,
all_definitely_bound: true,
dataclass_field_specifiers: SmallVec::new(),
narrowing_unions: FxHashSet::default(),
narrowing_unions: FxHashMap::default(),
}
}
@@ -7016,12 +7018,23 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
) -> Result<(), CallErrorKind> {
let db = self.db();
// If the type context is a union, attempt to narrow to a specific element.
// However, skip narrowing if we're already narrowing against the same union
// in an ancestor call to avoid exponential blowup with deeply nested calls.
// If we're already narrowing against this union in an ancestor call, propagate
// the narrowed element instead of re-narrowing (which would cause exponential blowup).
let call_expression_tcx = match call_expression_tcx.annotation {
Some(Type::Union(union)) => {
if let Some(&narrowed_ty) = self.narrowing_unions.get(&union) {
TypeContext::new(Some(narrowed_ty))
} else {
call_expression_tcx
}
}
_ => call_expression_tcx,
};
// If the type context is a union we haven't seen, attempt to narrow to a specific element.
let (narrow_union, narrow_targets): (Option<UnionType<'db>>, &[_]) =
match call_expression_tcx.annotation {
Some(Type::Union(union)) if !self.narrowing_unions.contains(&union) => {
Some(Type::Union(union)) if !self.narrowing_unions.contains_key(&union) => {
// TODO: We could theoretically attempt to narrow to every element of
// the power set of this union. However, this leads to an exponential
// explosion of inference attempts, and is rarely needed in practice.
@@ -7030,12 +7043,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
_ => (None, &[]),
};
// Track that we're narrowing against this union to prevent nested calls
// from redundantly narrowing against the same union.
if let Some(union) = narrow_union {
self.narrowing_unions.insert(union);
}
// We silence diagnostics until we successfully narrow to a specific type.
let mut speculated_bindings = bindings.clone();
let was_in_multi_inference = self.context.set_multi_inference(true);
@@ -7043,6 +7050,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut try_narrow = |narrowed_ty| {
let narrowed_tcx = TypeContext::new(Some(narrowed_ty));
// Track that we're narrowing this union to this specific element,
// so nested calls can propagate the narrowed type.
if let Some(union) = narrow_union {
self.narrowing_unions.insert(union, narrowed_ty);
}
// Attempt to infer the argument types using the narrowed type context.
self.infer_all_argument_types(
ast_arguments,