Compare commits

..

1 Commits

Author SHA1 Message Date
Zanie Blue
d83e0dc50b [ty] Cache ClassType::nearest_disjoint_base 2026-01-03 08:03:07 -06:00
5 changed files with 44 additions and 27 deletions

View File

@@ -7267,7 +7267,10 @@ impl<'db> Type<'db> {
}
(Some(Place::Defined(new_method, ..)), Place::Defined(init_method, ..)) => {
let callable = UnionType::from_elements(db, [new_method, init_method]);
let callable = UnionBuilder::new(db)
.add(*new_method)
.add(*init_method)
.build();
let new_method_bindings = new_method
.bindings(db)
@@ -10755,7 +10758,11 @@ fn walk_type_var_constraints<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
impl<'db> TypeVarConstraints<'db> {
fn as_type(self, db: &'db dyn Db) -> Type<'db> {
UnionType::from_elements(db, self.elements(db))
let mut builder = UnionBuilder::new(db);
for ty in self.elements(db) {
builder = builder.add(*ty);
}
builder.build()
}
fn to_instance(self, db: &'db dyn Db) -> Option<TypeVarConstraints<'db>> {

View File

@@ -715,6 +715,7 @@ impl<'db> ClassType<'db> {
/// Return the [`DisjointBase`] that appears first in the MRO of this class.
///
/// Returns `None` if this class does not have any disjoint bases in its MRO.
#[salsa::tracked(heap_size=ruff_memory_usage::heap_size)]
pub(super) fn nearest_disjoint_base(self, db: &'db dyn Db) -> Option<DisjointBase<'db>> {
self.iter_mro(db)
.filter_map(ClassBase::into_class)
@@ -4233,7 +4234,7 @@ impl InheritanceCycle {
/// `TypeError`s resulting from class definitions.
///
/// [PEP 800]: https://peps.python.org/pep-0800/
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, get_size2::GetSize, salsa::Update)]
pub(super) struct DisjointBase<'db> {
pub(super) class: ClassLiteral<'db>,
pub(super) kind: DisjointBaseKind,
@@ -4270,7 +4271,7 @@ impl<'db> DisjointBase<'db> {
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize, salsa::Update)]
pub(super) enum DisjointBaseKind {
/// We know the class is a disjoint base because it's either hardcoded in ty
/// or has the `@disjoint_base` decorator.

View File

@@ -24,8 +24,9 @@ use std::cmp::Eq;
use std::hash::Hash;
use std::marker::PhantomData;
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_hash::FxHashMap;
use crate::FxIndexSet;
use crate::types::Type;
/// Maximum recursion depth for cycle detection.
@@ -63,7 +64,7 @@ pub struct CycleDetector<Tag, T, R> {
/// If the type we're visiting is present in `seen`, it indicates that we've hit a cycle (due
/// to a recursive type); we need to immediately short circuit the whole operation and return
/// the fallback value. That's why we pop items off the end of `seen` after we've visited them.
seen: RefCell<FxHashSet<T>>,
seen: RefCell<FxIndexSet<T>>,
/// Unlike `seen`, this field is a pure performance optimisation (and an essential one). If the
/// type we're trying to normalize is present in `cache`, it doesn't necessarily mean we've hit
@@ -85,7 +86,7 @@ pub struct CycleDetector<Tag, T, R> {
impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
pub fn new(fallback: R) -> Self {
CycleDetector {
seen: RefCell::new(FxHashSet::default()),
seen: RefCell::new(FxIndexSet::default()),
cache: RefCell::new(FxHashMap::default()),
depth: Cell::new(0),
fallback,
@@ -98,23 +99,24 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
return val.clone();
}
// Check depth limit to prevent stack overflow from recursive generic types
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
let current_depth = self.depth.get();
if current_depth >= MAX_RECURSION_DEPTH {
return self.fallback.clone();
}
// We hit a cycle
if !self.seen.borrow_mut().insert(item.clone()) {
return self.fallback.clone();
}
// Check depth limit to prevent stack overflow from recursive generic types
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
let current_depth = self.depth.get();
if current_depth >= MAX_RECURSION_DEPTH {
self.seen.borrow_mut().pop();
return self.fallback.clone();
}
self.depth.set(current_depth + 1);
let ret = func();
self.depth.set(current_depth);
self.seen.borrow_mut().remove(&item);
self.seen.borrow_mut().pop();
self.cache.borrow_mut().insert(item, ret.clone());
ret
@@ -125,24 +127,24 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
return Some(val.clone());
}
// Check depth limit to prevent stack overflow from recursive generic protocols
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
let current_depth = self.depth.get();
if current_depth >= MAX_RECURSION_DEPTH {
return Some(self.fallback.clone());
}
// We hit a cycle
if !self.seen.borrow_mut().insert(item.clone()) {
return Some(self.fallback.clone());
}
// Check depth limit to prevent stack overflow from recursive generic protocols
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
let current_depth = self.depth.get();
if current_depth >= MAX_RECURSION_DEPTH {
self.seen.borrow_mut().pop();
return Some(self.fallback.clone());
}
self.depth.set(current_depth + 1);
let ret = func()?;
self.depth.set(current_depth);
self.seen.borrow_mut().remove(&item);
self.seen.borrow_mut().pop();
self.cache.borrow_mut().insert(item, ret.clone());
Some(ret)

View File

@@ -1083,9 +1083,13 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
&mut self.inner_expression_inference_state,
InnerExpressionInferenceState::Get,
);
let union = union.map(self.db(), |element| {
self.infer_subscript_type_expression(subscript, *element)
});
let union = union
.elements(self.db())
.iter()
.fold(UnionBuilder::new(self.db()), |builder, elem| {
builder.add(self.infer_subscript_type_expression(subscript, *elem))
})
.build();
self.inner_expression_inference_state = previous_slice_inference_state;
union
}

View File

@@ -926,7 +926,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.build();
// Keep order: first literal complement, then broader arms.
let result = UnionType::from_elements(self.db, [narrowed_single, rest_union]);
let result = UnionBuilder::new(self.db)
.add(narrowed_single)
.add(rest_union)
.build();
Some(result)
} else {
None