From 531ca7e47a71fe018215bddbad7d737cd2b84ee5 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Thu, 11 Dec 2025 10:15:03 -0500 Subject: [PATCH] improve generic call inference performance --- crates/ty_python_semantic/src/types.rs | 23 +++++----- .../ty_python_semantic/src/types/call/bind.rs | 45 ++++++++++--------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 40f8038cce..9fad0accbe 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1003,7 +1003,7 @@ impl<'db> Type<'db> { pub(crate) fn class_specialization( self, db: &'db dyn Db, - ) -> Option<(ClassType<'db>, Specialization<'db>)> { + ) -> Option<(ClassLiteral<'db>, Specialization<'db>)> { self.specialization_of_optional(db, None) } @@ -1021,7 +1021,7 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, expected_class: Option>, - ) -> Option<(ClassType<'db>, Specialization<'db>)> { + ) -> Option<(ClassLiteral<'db>, Specialization<'db>)> { let class_type = match self { Type::NominalInstance(instance) => instance, Type::ProtocolInstance(instance) => instance.to_nominal_instance()?, @@ -1035,7 +1035,7 @@ impl<'db> Type<'db> { return None; } - Some((class_type, specialization?)) + Some((class_literal, specialization?)) } /// Given a type variable `T` from the generic context of a class `C`: @@ -3905,17 +3905,20 @@ impl<'db> Type<'db> { .zip(specialization.types(db)) { let variance = type_var.variance_with_polarity(db, polarity); - let tcx = tcx.and_then(|tcx| { - tcx.filter_union(db, |ty| { - ty.find_type_var_from(db, type_var, class_literal).is_some() - }) - .find_type_var_from(db, type_var, class_literal) + let narrowed_tcx = tcx.and_then(|annotation| match annotation { + Type::Union(union) => union + .elements(db) + .iter() + .filter_map(|ty| ty.find_type_var_from(db, type_var, class_literal)) + .exactly_one() + .ok(), + _ => annotation.find_type_var_from(db, type_var, class_literal), }); - f(type_var, *ty, variance, tcx); + f(type_var, *ty, variance, narrowed_tcx); visitor.visit(*ty, || { - ty.visit_specialization_impl(db, tcx, variance, f, visitor); + ty.visit_specialization_impl(db, narrowed_tcx, variance, f, visitor); }); } } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index c05186a13e..00fb354aa5 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -2818,35 +2818,38 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // Prefer the declared type of generic classes. let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| { - let tcx = tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()); - tcx.class_specialization(self.db)?; - - let return_specialization = return_ty + let (tcx_class, tcx_specialization) = tcx .filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()) - .class_specialization(self.db); + .class_specialization(self.db)?; + + let Some((return_class, return_specialization)) = return_ty + .filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()) + .class_specialization(self.db) + else { + builder.infer(return_ty, tcx).ok()?; + return Some(builder.type_mappings().clone()); + }; // TODO: We should use the constraint solver here to determine the type mappings for more // complex subtyping relationships, e.g., callables, protocols, or unions containing multiple // generic elements. - if let Some((class_literal, _)) = return_specialization - && let Some(generic_alias) = class_literal.into_generic_alias() - { - let specialization = generic_alias.specialization(self.db); - for (class_type_var, return_ty) in specialization - .generic_context(self.db) - .variables(self.db) - .zip(specialization.types(self.db)) - { - if let Some(ty) = tcx.find_type_var_from( - self.db, - class_type_var, - generic_alias.origin(self.db), + for base in return_class.iter_mro(self.db, Some(return_specialization)) { + let Some((base_class, Some(base_specialization))) = + base.into_class().map(|class| class.class_literal(self.db)) + else { + continue; + }; + + if base_class == tcx_class { + for (base_ty, tcx_ty) in std::iter::zip( + base_specialization.types(self.db), + tcx_specialization.types(self.db), ) { - builder.infer(*return_ty, ty).ok()?; + builder.infer(*base_ty, *tcx_ty).ok()?; } + + break; } - } else { - builder.infer(return_ty, tcx).ok()?; } Some(builder.type_mappings().clone())