propagate the narrowed element instead of skipping narrowing entirely
This commit is contained in:
@@ -311,13 +311,15 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
|
||||
/// the right hand side of an annotated assignment in a class that is a dataclass).
|
||||
dataclass_field_specifiers: SmallVec<[Type<'db>; NUM_FIELD_SPECIFIERS_INLINE]>,
|
||||
|
||||
/// Unions we're currently narrowing against in ancestor calls.
|
||||
/// Unions we're currently narrowing against in ancestor calls, mapped to the element
|
||||
/// we're currently trying.
|
||||
///
|
||||
/// When inferring a call expression with a union type context, we try narrowing to each
|
||||
/// element of the union. If nested calls have the same union as their parameter type,
|
||||
/// this would lead to exponential blowup. By tracking which unions we're already narrowing
|
||||
/// against, we skip redundant nested narrowing.
|
||||
narrowing_unions: FxHashSet<UnionType<'db>>,
|
||||
/// against and what element we're narrowing to, nested calls can use the narrowed element
|
||||
/// directly instead of re-narrowing.
|
||||
narrowing_unions: FxHashMap<UnionType<'db>, Type<'db>>,
|
||||
}
|
||||
|
||||
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
@@ -356,7 +358,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
cycle_recovery: None,
|
||||
all_definitely_bound: true,
|
||||
dataclass_field_specifiers: SmallVec::new(),
|
||||
narrowing_unions: FxHashSet::default(),
|
||||
narrowing_unions: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7016,12 +7018,23 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
) -> Result<(), CallErrorKind> {
|
||||
let db = self.db();
|
||||
|
||||
// If the type context is a union, attempt to narrow to a specific element.
|
||||
// However, skip narrowing if we're already narrowing against the same union
|
||||
// in an ancestor call to avoid exponential blowup with deeply nested calls.
|
||||
// If we're already narrowing against this union in an ancestor call, propagate
|
||||
// the narrowed element instead of re-narrowing (which would cause exponential blowup).
|
||||
let call_expression_tcx = match call_expression_tcx.annotation {
|
||||
Some(Type::Union(union)) => {
|
||||
if let Some(&narrowed_ty) = self.narrowing_unions.get(&union) {
|
||||
TypeContext::new(Some(narrowed_ty))
|
||||
} else {
|
||||
call_expression_tcx
|
||||
}
|
||||
}
|
||||
_ => call_expression_tcx,
|
||||
};
|
||||
|
||||
// If the type context is a union we haven't seen, attempt to narrow to a specific element.
|
||||
let (narrow_union, narrow_targets): (Option<UnionType<'db>>, &[_]) =
|
||||
match call_expression_tcx.annotation {
|
||||
Some(Type::Union(union)) if !self.narrowing_unions.contains(&union) => {
|
||||
Some(Type::Union(union)) if !self.narrowing_unions.contains_key(&union) => {
|
||||
// TODO: We could theoretically attempt to narrow to every element of
|
||||
// the power set of this union. However, this leads to an exponential
|
||||
// explosion of inference attempts, and is rarely needed in practice.
|
||||
@@ -7030,12 +7043,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
_ => (None, &[]),
|
||||
};
|
||||
|
||||
// Track that we're narrowing against this union to prevent nested calls
|
||||
// from redundantly narrowing against the same union.
|
||||
if let Some(union) = narrow_union {
|
||||
self.narrowing_unions.insert(union);
|
||||
}
|
||||
|
||||
// We silence diagnostics until we successfully narrow to a specific type.
|
||||
let mut speculated_bindings = bindings.clone();
|
||||
let was_in_multi_inference = self.context.set_multi_inference(true);
|
||||
@@ -7043,6 +7050,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
let mut try_narrow = |narrowed_ty| {
|
||||
let narrowed_tcx = TypeContext::new(Some(narrowed_ty));
|
||||
|
||||
// Track that we're narrowing this union to this specific element,
|
||||
// so nested calls can propagate the narrowed type.
|
||||
if let Some(union) = narrow_union {
|
||||
self.narrowing_unions.insert(union, narrowed_ty);
|
||||
}
|
||||
|
||||
// Attempt to infer the argument types using the narrowed type context.
|
||||
self.infer_all_argument_types(
|
||||
ast_arguments,
|
||||
|
||||
Reference in New Issue
Block a user