Compare commits

...

11 Commits

Author SHA1 Message Date
Zanie Blue
c3b4fab764 [ty] Track narrowing unions to prevent exponential blowup
When inferring a call expression with a union type context, we try narrowing
to each element of the union. If nested calls have the same union as their
parameter type, this would lead to exponential blowup. By tracking which
unions we're already narrowing against, we skip redundant nested narrowing.

On a synthetic benchmark with nested `list_schema()` calls:
- 4-level nesting: >120s → 0.13s
2025-12-18 20:24:33 -06:00
Zanie Blue
abf17c6ef4 [ty] Add short-circuits for union type relation checks
- Short-circuit union-vs-union comparisons when both sides are identical
- Add fast membership check before full relation checking when comparing
  a type against a union
2025-12-18 20:24:18 -06:00
Zanie Blue
b9f65213d0 [ty] Cache Type::is_assignable_to
Add salsa caching to `Type::is_assignable_to` to memoize repeated
assignability checks.
2025-12-18 20:24:03 -06:00
Zanie Blue
ff2553665c [ty] Cache ClassType::is_subclass_of
Add salsa caching to `ClassType::is_subclass_of` to avoid repeated MRO
traversals when checking class relationships.
2025-12-18 20:23:28 -06:00
Zanie Blue
8ebbe6b0f6 Invert case 2025-12-18 18:35:38 -06:00
Zanie Blue
6bc88c90b2 [ty] Cache Type::is_disjoint_from 2025-12-18 18:35:38 -06:00
Zanie Blue
3c694c7d86 [ty] Cache Type::is_subtype_of 2025-12-18 18:35:38 -06:00
Zanie Blue
1603948aae [ty] Cache ClassType::nearest_disjoint_base 2025-12-18 18:35:38 -06:00
Micha Reiser
cb6ba23b0a More lazy negated computations 2025-12-18 18:32:39 -06:00
Micha Reiser
a918833d19 Defer insertion 2025-12-18 18:32:39 -06:00
Micha Reiser
bcf9295973 [ty] Small union builder nits 2025-12-18 18:32:39 -06:00
5 changed files with 227 additions and 61 deletions

View File

@@ -1952,8 +1952,22 @@ impl<'db> Type<'db> {
///
/// See [`TypeRelation::Subtyping`] for more details.
pub(crate) fn is_subtype_of(self, db: &'db dyn Db, target: Type<'db>) -> bool {
self.when_subtype_of(db, target, InferableTypeVars::None)
.is_always_satisfied(db)
#[salsa::tracked(cycle_initial=is_subtype_of_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn is_subtype_of_impl<'db>(
db: &'db dyn Db,
self_ty: Type<'db>,
target: Type<'db>,
) -> bool {
self_ty
.when_subtype_of(db, target, InferableTypeVars::None)
.is_always_satisfied(db)
}
if self == target {
return true;
}
is_subtype_of_impl(db, self, target)
}
fn when_subtype_of(
@@ -1988,8 +2002,22 @@ impl<'db> Type<'db> {
///
/// See `TypeRelation::Assignability` for more details.
pub fn is_assignable_to(self, db: &'db dyn Db, target: Type<'db>) -> bool {
self.when_assignable_to(db, target, InferableTypeVars::None)
.is_always_satisfied(db)
#[salsa::tracked(cycle_initial=is_assignable_to_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn is_assignable_to_impl<'db>(
db: &'db dyn Db,
self_ty: Type<'db>,
target: Type<'db>,
) -> bool {
self_ty
.when_assignable_to(db, target, InferableTypeVars::None)
.is_always_satisfied(db)
}
if self == target {
return true;
}
is_assignable_to_impl(db, self, target)
}
/// Return true if this type is assignable to type `target` using constraint-set assignability.
@@ -2403,6 +2431,9 @@ impl<'db> Type<'db> {
// `Never` is the bottom type, the empty set.
(_, Type::Never) => ConstraintSet::from(false),
// Short-circuit: if both sides are the same union, they trivially satisfy the relation.
(Type::Union(left), Type::Union(right)) if left == right => ConstraintSet::from(true),
(Type::Union(union), _) => union.elements(db).iter().when_all(db, |&elem_ty| {
elem_ty.has_relation_to_impl(
db,
@@ -2414,16 +2445,22 @@ impl<'db> Type<'db> {
)
}),
(_, Type::Union(union)) => union.elements(db).iter().when_any(db, |&elem_ty| {
self.has_relation_to_impl(
db,
elem_ty,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
}),
(_, Type::Union(union)) => {
// Fast path: if self is directly a member of the union, no need to check relations
if union.elements(db).contains(&self) {
return ConstraintSet::from(true);
}
union.elements(db).iter().when_any(db, |&elem_ty| {
self.has_relation_to_impl(
db,
elem_ty,
inferable,
relation,
relation_visitor,
disjointness_visitor,
)
})
}
// If both sides are intersections we need to handle the right side first
// (A & B & C) is a subtype of (A & B) because the left is a subtype of both A and B,
@@ -3224,8 +3261,22 @@ impl<'db> Type<'db> {
/// This function aims to have no false positives, but might return wrong
/// `false` answers in some cases.
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
self.when_disjoint_from(db, other, InferableTypeVars::None)
.is_always_satisfied(db)
#[salsa::tracked(cycle_initial=is_disjoint_from_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn is_disjoint_from_cached<'db>(
db: &'db dyn Db,
self_ty: Type<'db>,
other: Type<'db>,
) -> bool {
self_ty
.when_disjoint_from(db, other, InferableTypeVars::None)
.is_always_satisfied(db)
}
if self == other {
return false;
}
is_disjoint_from_cached(db, self, other)
}
fn when_disjoint_from(
@@ -8671,6 +8722,37 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
}
}
/// Cycle-initial value for the salsa-tracked `is_subtype_of_impl` query, which names this
/// function in its `cycle_initial=` attribute.
///
/// Ignores all of its arguments and always returns `false`, i.e. when the query is entered
/// through a cycle the subtype relation is provisionally treated as not holding.
// NOTE(review): `is_assignable_to_cycle_initial` starts from `true` while this starts from
// `false` — confirm the asymmetry between the two relations is intended.
// NOTE(review): the lint allow below carries no justification and no small by-reference
// `Copy` parameter is visible in this signature — confirm it is still required.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_subtype_of_cycle_initial<'db>(
_db: &'db dyn Db,
_id: salsa::Id,
_self_ty: Type<'db>,
_target: Type<'db>,
) -> bool {
false
}
/// Cycle-initial value for the salsa-tracked `is_assignable_to_impl` query, which names this
/// function in its `cycle_initial=` attribute.
///
/// Ignores all of its arguments and always returns `true`, i.e. when the query is entered
/// through a cycle the types are provisionally treated as assignable.
// NOTE(review): the lint allow below carries no justification and no small by-reference
// `Copy` parameter is visible in this signature — confirm it is still required.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_assignable_to_cycle_initial<'db>(
_db: &'db dyn Db,
_id: salsa::Id,
_self_ty: Type<'db>,
_target: Type<'db>,
) -> bool {
// In case of a cycle, conservatively assume assignable to avoid false positives
true
}
/// Cycle-initial value for the salsa-tracked `is_disjoint_from_cached` query, which names this
/// function in its `cycle_initial=` attribute.
///
/// Ignores all of its arguments and always returns `false`, i.e. when the query is entered
/// through a cycle the two types are provisionally treated as not disjoint. This matches the
/// direction `Type::is_disjoint_from` documents as tolerable ("might return wrong `false`
/// answers in some cases").
// NOTE(review): the lint allow below carries no justification and no small by-reference
// `Copy` parameter is visible in this signature — confirm it is still required.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_disjoint_from_cycle_initial<'db>(
_db: &'db dyn Db,
_id: salsa::Id,
_self_ty: Type<'db>,
_other: Type<'db>,
) -> bool {
false
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_redundant_with_cycle_initial<'db>(
_db: &'db dyn Db,

View File

@@ -365,7 +365,7 @@ impl<'db> UnionBuilder<'db> {
Type::StringLiteral(literal) => {
let mut found = None;
let mut to_remove = None;
let ty_negated = ty.negate(self.db);
let mut ty_negated = None;
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::StringLiterals(literals) => {
@@ -383,8 +383,10 @@ impl<'db> UnionBuilder<'db> {
}
if existing.is_subtype_of(self.db, ty) {
to_remove = Some(index);
continue;
}
if ty_negated.is_subtype_of(self.db, *existing) {
let negated = ty_negated.get_or_insert_with(|| ty.negate(self.db));
if negated.is_subtype_of(self.db, *existing) {
// The type that includes both this new element, and its negation
// (or a supertype of its negation), must be simply `object`.
self.collapse_to_object();
@@ -410,7 +412,7 @@ impl<'db> UnionBuilder<'db> {
Type::BytesLiteral(literal) => {
let mut found = None;
let mut to_remove = None;
let ty_negated = ty.negate(self.db);
let mut ty_negated = None;
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::BytesLiterals(literals) => {
@@ -428,8 +430,11 @@ impl<'db> UnionBuilder<'db> {
}
if existing.is_subtype_of(self.db, ty) {
to_remove = Some(index);
continue;
}
if ty_negated.is_subtype_of(self.db, *existing) {
let negated = ty_negated.get_or_insert_with(|| ty.negate(self.db));
if negated.is_subtype_of(self.db, *existing) {
// The type that includes both this new element, and its negation
// (or a supertype of its negation), must be simply `object`.
self.collapse_to_object();
@@ -455,7 +460,7 @@ impl<'db> UnionBuilder<'db> {
Type::IntLiteral(literal) => {
let mut found = None;
let mut to_remove = None;
let ty_negated = ty.negate(self.db);
let mut ty_negated = None;
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::IntLiterals(literals) => {
@@ -473,8 +478,11 @@ impl<'db> UnionBuilder<'db> {
}
if existing.is_subtype_of(self.db, ty) {
to_remove = Some(index);
continue;
}
if ty_negated.is_subtype_of(self.db, *existing) {
let negated = ty_negated.get_or_insert_with(|| ty.negate(self.db));
if negated.is_subtype_of(self.db, *existing) {
// The type that includes both this new element, and its negation
// (or a supertype of its negation), must be simply `object`.
self.collapse_to_object();
@@ -549,19 +557,28 @@ impl<'db> UnionBuilder<'db> {
// unpacking them.
let should_simplify_full = !matches!(ty, Type::TypeAlias(_)) && !self.cycle_recovery;
let mut to_remove = SmallVec::<[usize; 2]>::new();
let ty_negated = if should_simplify_full {
ty.negate(self.db)
} else {
Type::Never // won't be used
let mut ty_negated: Option<Type> = None;
let mut i = 0;
let mut insertion_point: Option<usize> = None;
let mut remove_or_replace = |i: usize, elements: &mut Vec<UnionElement<'db>>| {
if insertion_point.is_none() {
insertion_point = Some(i);
} else {
elements.swap_remove(i);
}
};
for (index, element) in self.elements.iter_mut().enumerate() {
while i < self.elements.len() {
let element = &mut self.elements[i];
let element_type = match element.try_reduce(self.db, ty) {
ReduceResult::KeepIf(keep) => {
if !keep {
to_remove.push(index);
remove_or_replace(i, &mut self.elements);
}
i += 1;
continue;
}
ReduceResult::Type(ty) => ty,
@@ -587,19 +604,24 @@ impl<'db> UnionBuilder<'db> {
// problematic if some of those fields point to recursive `Union`s. To avoid cycles,
// compare `TypedDict`s by name/identity instead of using the `has_relation_to`
// machinery.
if let (Type::TypedDict(element_td), Type::TypedDict(ty_td)) = (element_type, ty) {
if element_td == ty_td {
return;
}
if element_type.is_typed_dict() && ty.is_typed_dict() {
i += 1;
continue;
}
if should_simplify_full && !matches!(element_type, Type::TypeAlias(_)) {
if ty.is_redundant_with(self.db, element_type) {
return;
} else if element_type.is_redundant_with(self.db, ty) {
to_remove.push(index);
} else if ty_negated.is_subtype_of(self.db, element_type) {
}
if element_type.is_redundant_with(self.db, ty) {
remove_or_replace(i, &mut self.elements);
i += 1;
continue;
}
let negated = ty_negated.get_or_insert_with(|| ty.negate(self.db));
if negated.is_subtype_of(self.db, element_type) {
// We add `ty` to the union. We just checked that `~ty` is a subtype of an
// existing `element`. This also means that `~ty | ty` is a subtype of
// `element | ty`, because both elements in the first union are subtypes of
@@ -613,13 +635,12 @@ impl<'db> UnionBuilder<'db> {
return;
}
}
i += 1;
}
if let Some((&first, rest)) = to_remove.split_first() {
self.elements[first] = UnionElement::Type(ty);
// We iterate in descending order to keep remaining indices valid after `swap_remove`.
for &index in rest.iter().rev() {
self.elements.swap_remove(index);
}
if let Some(insertion_point) = insertion_point {
self.elements[insertion_point] = UnionElement::Type(ty);
} else {
self.elements.push(UnionElement::Type(ty));
}

View File

@@ -600,8 +600,22 @@ impl<'db> ClassType<'db> {
/// Return `true` if `other` is present in this class's MRO.
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
self.when_subclass_of(db, other, InferableTypeVars::None)
.is_always_satisfied(db)
#[salsa::tracked(cycle_initial=is_subclass_of_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn is_subclass_of_impl<'db>(
db: &'db dyn Db,
self_ty: ClassType<'db>,
other: ClassType<'db>,
) -> bool {
self_ty
.when_subclass_of(db, other, InferableTypeVars::None)
.is_always_satisfied(db)
}
if self == other {
return true;
}
is_subclass_of_impl(db, self, other)
}
pub(super) fn when_subclass_of(
@@ -714,6 +728,7 @@ impl<'db> ClassType<'db> {
/// Return the [`DisjointBase`] that appears first in the MRO of this class.
///
/// Returns `None` if this class does not have any disjoint bases in its MRO.
#[salsa::tracked(heap_size=ruff_memory_usage::heap_size)]
pub(super) fn nearest_disjoint_base(self, db: &'db dyn Db) -> Option<DisjointBase<'db>> {
self.iter_mro(db)
.filter_map(ClassBase::into_class)
@@ -1360,6 +1375,16 @@ fn into_callable_cycle_initial<'db>(
CallableTypes::one(CallableType::bottom(db))
}
/// Cycle-initial value for the salsa-tracked `is_subclass_of_impl` query, which names this
/// function in its `cycle_initial=` attribute.
///
/// Ignores all of its arguments and always returns `false`, i.e. when the query is entered
/// through a cycle `self_ty` is provisionally treated as not being a subclass of `other`.
// NOTE(review): the lint allow below carries no justification and no small by-reference
// `Copy` parameter is visible in this signature — confirm it is still required.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_subclass_of_cycle_initial<'db>(
_db: &'db dyn Db,
_id: salsa::Id,
_self_ty: ClassType<'db>,
_other: ClassType<'db>,
) -> bool {
false
}
impl<'db> From<GenericAlias<'db>> for ClassType<'db> {
fn from(generic: GenericAlias<'db>) -> ClassType<'db> {
ClassType::Generic(generic)
@@ -4110,7 +4135,7 @@ impl InheritanceCycle {
/// `TypeError`s resulting from class definitions.
///
/// [PEP 800]: https://peps.python.org/pep-0800/
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, get_size2::GetSize, salsa::Update)]
pub(super) struct DisjointBase<'db> {
pub(super) class: ClassLiteral<'db>,
pub(super) kind: DisjointBaseKind,
@@ -4147,7 +4172,7 @@ impl<'db> DisjointBase<'db> {
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize, salsa::Update)]
pub(super) enum DisjointBaseKind {
/// We know the class is a disjoint base because it's either hardcoded in ty
/// or has the `@disjoint_base` decorator.

View File

@@ -310,6 +310,14 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
/// A list of `dataclass_transform` field specifiers that are "active" (when inferring
/// the right hand side of an annotated assignment in a class that is a dataclass).
dataclass_field_specifiers: SmallVec<[Type<'db>; NUM_FIELD_SPECIFIERS_INLINE]>,
/// Unions we're currently narrowing against in ancestor calls.
///
/// When inferring a call expression with a union type context, we try narrowing to each
/// element of the union. If nested calls have the same union as their parameter type,
/// this would lead to exponential blowup. By tracking which unions we're already narrowing
/// against, we skip redundant nested narrowing.
narrowing_unions: FxHashSet<UnionType<'db>>,
}
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
@@ -348,6 +356,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
cycle_recovery: None,
all_definitely_bound: true,
dataclass_field_specifiers: SmallVec::new(),
narrowing_unions: FxHashSet::default(),
}
}
@@ -7008,13 +7017,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let db = self.db();
// If the type context is a union, attempt to narrow to a specific element.
let narrow_targets: &[_] = match call_expression_tcx.annotation {
// TODO: We could theoretically attempt to narrow to every element of
// the power set of this union. However, this leads to an exponential
// explosion of inference attempts, and is rarely needed in practice.
Some(Type::Union(union)) => union.elements(db),
_ => &[],
};
// However, skip narrowing if we're already narrowing against the same union
// in an ancestor call to avoid exponential blowup with deeply nested calls.
let (narrow_union, narrow_targets): (Option<UnionType<'db>>, &[_]) =
match call_expression_tcx.annotation {
Some(Type::Union(union)) if !self.narrowing_unions.contains(&union) => {
// TODO: We could theoretically attempt to narrow to every element of
// the power set of this union. However, this leads to an exponential
// explosion of inference attempts, and is rarely needed in practice.
(Some(union), union.elements(db))
}
_ => (None, &[]),
};
// Track that we're narrowing against this union to prevent nested calls
// from redundantly narrowing against the same union.
if let Some(union) = narrow_union {
self.narrowing_unions.insert(union);
}
// We silence diagnostics until we successfully narrow to a specific type.
let mut speculated_bindings = bindings.clone();
@@ -7082,12 +7102,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
// Prefer the declared type of generic classes.
let mut narrowing_result = None;
for narrowed_ty in narrow_targets
.iter()
.filter(|ty| ty.class_specialization(db).is_some())
{
if let Some(result) = try_narrow(*narrowed_ty) {
return result;
narrowing_result = Some(result);
break;
}
}
@@ -7095,15 +7117,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
//
// TODO: We could also attempt an inference without type context, but this
// leads to similar performance issues.
for narrowed_ty in narrow_targets
.iter()
.filter(|ty| ty.class_specialization(db).is_none())
{
if let Some(result) = try_narrow(*narrowed_ty) {
return result;
if narrowing_result.is_none() {
for narrowed_ty in narrow_targets
.iter()
.filter(|ty| ty.class_specialization(db).is_none())
{
if let Some(result) = try_narrow(*narrowed_ty) {
narrowing_result = Some(result);
break;
}
}
}
// Clean up: remove the union from the tracking set.
if let Some(union) = narrow_union {
self.narrowing_unions.remove(&union);
}
// If narrowing succeeded, return the result.
if let Some(result) = narrowing_result {
return result;
}
// Re-enable diagnostics, and infer against the entire union as a fallback.
self.context.set_multi_inference(was_in_multi_inference);
@@ -12592,6 +12627,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
index: _,
region: _,
return_types_and_ranges: _,
narrowing_unions: _,
} = self;
let diagnostics = context.finish();
@@ -12659,6 +12695,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
index: _,
region: _,
return_types_and_ranges: _,
narrowing_unions: _,
} = self;
let _ = scope;
@@ -12736,6 +12773,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
index: _,
region: _,
return_types_and_ranges: _,
narrowing_unions: _,
} = self;
let _ = scope;

View File

@@ -720,7 +720,7 @@ impl<'db> ProtocolInstanceType<'db> {
_value: ProtocolInstanceType<'db>,
_: (),
) -> bool {
true
false
}
is_equivalent_to_object_inner(db, self, ())