Compare commits

...

3 Commits

Author SHA1 Message Date
Zanie Blue
fd525e8cc4 review 2025-12-19 15:37:11 -06:00
Zanie Blue
3cf876e27a propagate the narrowed element instead of skipping narrowing entirely 2025-12-19 11:27:16 -06:00
Zanie Blue
2fd90afc7a [ty] Track narrowing unions to prevent exponential blowup
When inferring a call expression with a union type context, we try narrowing
to each element of the union. If nested calls have the same union as their
parameter type, this would lead to exponential blowup. By tracking which
unions we're already narrowing against, we skip redundant nested narrowing.

On a synthetic benchmark with nested `list_schema()` calls:
- 4-level nesting: >120s → 0.13s
2025-12-18 20:36:37 -06:00
2 changed files with 109 additions and 14 deletions

View File

@@ -362,3 +362,48 @@ def _(x: X, flag: bool):
# error: [possibly-unresolved-reference] "Name `y` used when possibly not defined"
x.td = {"y": y}
```
## Nested calls with union parameter types
When inferring nested function calls where the parameter type is a union, the narrowed type context
should be propagated through nested calls rather than re-narrowing at each level. This avoids
exponential blowup while preserving type precision.
```toml
[environment]
python-version = "3.12"
```
`nested_union_inference.py`:
```py
def f(x: list[int | None]): ...
def g[T](x: T) -> list[T]:
return [x]
# Bidirectional inference uses the type context to solve generics.
# The type context list[int | None] guides the inference of g(1).
f(reveal_type(g(1))) # revealed: list[int | None]
```
`nested_same_union.py`:
When a function's parameter type is the same union as the outer type context, the narrowed element
should be propagated instead of re-narrowing (which would cause exponential blowup).
```py
type CoreSchema = A | B | C | D
class A: ...
class B: ...
class C: ...
class D: ...
def outer(x: CoreSchema) -> None: ...
def inner(x: CoreSchema) -> D:
return D()
# The outer call narrows CoreSchema to D. The inner call's parameter is also CoreSchema,
# so the narrowed type D is propagated instead of trying all 4 elements again.
outer(reveal_type(inner(inner(inner(D()))))) # revealed: D
```

View File

@@ -12,6 +12,8 @@ use ruff_python_ast::{
use ruff_python_stdlib::builtins::version_builtin_was_added;
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::{FxHashMap, FxHashSet};
use crate::FxOrderMap;
use smallvec::SmallVec;
use super::{
@@ -310,6 +312,16 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
/// A list of `dataclass_transform` field specifiers that are "active" (when inferring
/// the right hand side of an annotated assignment in a class that is a dataclass).
dataclass_field_specifiers: SmallVec<[Type<'db>; NUM_FIELD_SPECIFIERS_INLINE]>,
/// Unions we're currently narrowing against in ancestor calls, mapped to the element
/// we're currently trying.
///
/// When inferring a call expression with a union type context, we try narrowing to each
/// element of the union. If nested calls have the same union as their parameter type,
/// this would lead to exponential blowup. By tracking which unions we're already narrowing
/// against and what element we're narrowing to, nested calls can use the narrowed element
/// directly instead of re-narrowing.
narrowing_unions: FxOrderMap<UnionType<'db>, Type<'db>>,
}
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
@@ -348,6 +360,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
cycle_recovery: None,
all_definitely_bound: true,
dataclass_field_specifiers: SmallVec::new(),
narrowing_unions: FxOrderMap::default(),
}
}
@@ -7007,15 +7020,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
) -> Result<(), CallErrorKind> {
let db = self.db();
// If the type context is a union, attempt to narrow to a specific element.
let narrow_targets: &[_] = match call_expression_tcx.annotation {
// TODO: We could theoretically attempt to narrow to every element of
// the power set of this union. However, this leads to an exponential
// explosion of inference attempts, and is rarely needed in practice.
Some(Type::Union(union)) => union.elements(db),
_ => &[],
// If we're already narrowing against this union in an ancestor call, propagate
// the narrowed element instead of re-narrowing (which would cause exponential blowup).
let call_expression_tcx = if let Some(Type::Union(union)) = call_expression_tcx.annotation
&& let Some(&narrowed_ty) = self.narrowing_unions.get(&union)
{
TypeContext::new(Some(narrowed_ty))
} else {
call_expression_tcx
};
// If the type context is a union we haven't seen, attempt to narrow to a specific element.
let (narrow_union, narrow_targets): (Option<UnionType<'db>>, &[_]) =
match call_expression_tcx.annotation {
Some(Type::Union(union)) if !self.narrowing_unions.contains_key(&union) => {
// TODO: We could theoretically attempt to narrow to every element of
// the power set of this union. However, this leads to an exponential
// explosion of inference attempts, and is rarely needed in practice.
(Some(union), union.elements(db))
}
_ => (None, &[]),
};
// We silence diagnostics until we successfully narrow to a specific type.
let mut speculated_bindings = bindings.clone();
let was_in_multi_inference = self.context.set_multi_inference(true);
@@ -7023,6 +7049,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut try_narrow = |narrowed_ty| {
let narrowed_tcx = TypeContext::new(Some(narrowed_ty));
// Track that we're narrowing this union to this specific element,
// so nested calls can propagate the narrowed type.
if let Some(union) = narrow_union {
self.narrowing_unions.insert(union, narrowed_ty);
}
// Attempt to infer the argument types using the narrowed type context.
self.infer_all_argument_types(
ast_arguments,
@@ -7082,12 +7114,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
// Prefer the declared type of generic classes.
let mut narrowing_result = None;
for narrowed_ty in narrow_targets
.iter()
.filter(|ty| ty.class_specialization(db).is_some())
{
if let Some(result) = try_narrow(*narrowed_ty) {
return result;
narrowing_result = Some(result);
break;
}
}
@@ -7095,15 +7129,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
//
// TODO: We could also attempt an inference without type context, but this
// leads to similar performance issues.
for narrowed_ty in narrow_targets
.iter()
.filter(|ty| ty.class_specialization(db).is_none())
{
if let Some(result) = try_narrow(*narrowed_ty) {
return result;
if narrowing_result.is_none() {
for narrowed_ty in narrow_targets
.iter()
.filter(|ty| ty.class_specialization(db).is_none())
{
if let Some(result) = try_narrow(*narrowed_ty) {
narrowing_result = Some(result);
break;
}
}
}
// Clean up: pop the union from the tracking map.
if narrow_union.is_some() {
self.narrowing_unions.pop();
}
// If narrowing succeeded, return the result.
if let Some(result) = narrowing_result {
return result;
}
// Re-enable diagnostics, and infer against the entire union as a fallback.
self.context.set_multi_inference(was_in_multi_inference);
@@ -12592,6 +12639,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
index: _,
region: _,
return_types_and_ranges: _,
narrowing_unions: _,
} = self;
let diagnostics = context.finish();
@@ -12659,6 +12707,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
index: _,
region: _,
return_types_and_ranges: _,
narrowing_unions: _,
} = self;
let _ = scope;
@@ -12736,6 +12785,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
index: _,
region: _,
return_types_and_ranges: _,
narrowing_unions: _,
} = self;
let _ = scope;