[red-knot] Allow explicit specialization of generic classes (#17023)

This PR lets you explicitly specialize a generic class using a subscript
expression. It introduces three new Rust types for representing classes:

- `NonGenericClass`
- `GenericClass` (not specialized)
- `GenericAlias` (specialized)

and two enum wrappers:

- `ClassType` (a non-generic class or generic alias, represents a class
_type_ at runtime)
- `ClassLiteralType` (a non-generic class or generic class, represents a
class body in the AST)

We also add internal support for specializing callables, in particular
function literals. (That is, the internal `Type` representation now
attaches an optional specialization to a function literal.) This is used
in this PR for the methods of a generic class, but should also give us
most of what we need for specializing generic _functions_ (which this PR
does not yet tackle).

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Douglas Creager
2025-04-09 11:18:46 -04:00
committed by GitHub
parent 7c81408c54
commit ff376fc262
21 changed files with 1559 additions and 435 deletions

View File

@@ -425,6 +425,17 @@ impl<'db> SymbolAndQualifiers<'db> {
self.qualifiers.contains(TypeQualifiers::CLASS_VAR)
}
#[must_use]
pub(crate) fn map_type(
self,
f: impl FnOnce(Type<'db>) -> Type<'db>,
) -> SymbolAndQualifiers<'db> {
SymbolAndQualifiers {
symbol: self.symbol.map_type(f),
qualifiers: self.qualifiers,
}
}
/// Transform symbol and qualifiers into a [`LookupResult`],
/// a [`Result`] type in which the `Ok` variant represents a definitely bound symbol
/// and the `Err` variant represents a symbol that is either definitely or possibly unbound.

View File

@@ -35,12 +35,16 @@ use crate::symbol::{imported_symbol, Boundness, Symbol, SymbolAndQualifiers};
use crate::types::call::{Bindings, CallArgumentTypes};
pub(crate) use crate::types::class_base::ClassBase;
use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION};
use crate::types::generics::Specialization;
use crate::types::infer::infer_unpack_types;
use crate::types::mro::{Mro, MroError, MroIterator};
pub(crate) use crate::types::narrow::infer_narrowing_constraint;
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::{Db, FxOrderSet, Module, Program};
pub(crate) use class::{Class, ClassLiteralType, InstanceType, KnownClass, KnownInstanceType};
pub(crate) use class::{
Class, ClassLiteralType, ClassType, GenericAlias, GenericClass, InstanceType, KnownClass,
KnownInstanceType, NonGenericClass,
};
mod builder;
mod call;
@@ -49,6 +53,7 @@ mod class_base;
mod context;
mod diagnostic;
mod display;
mod generics;
mod infer;
mod mro;
mod narrow;
@@ -276,6 +281,18 @@ pub struct PropertyInstanceType<'db> {
setter: Option<Type<'db>>,
}
impl<'db> PropertyInstanceType<'db> {
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
let getter = self
.getter(db)
.map(|ty| ty.apply_specialization(db, specialization));
let setter = self
.setter(db)
.map(|ty| ty.apply_specialization(db, specialization));
Self::new(db, getter, setter)
}
}
/// Representation of a type: a set of possible values at runtime.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]
pub enum Type<'db> {
@@ -319,7 +336,9 @@ pub enum Type<'db> {
ModuleLiteral(ModuleLiteralType<'db>),
/// A specific class object
ClassLiteral(ClassLiteralType<'db>),
// The set of all class objects that are subclasses of the given class (C), spelled `type[C]`.
/// A specialization of a generic class
GenericAlias(GenericAlias<'db>),
/// The set of all class objects that are subclasses of the given class (C), spelled `type[C]`.
SubclassOf(SubclassOfType<'db>),
/// The set of Python objects with the given class in their __class__'s method resolution order
Instance(InstanceType<'db>),
@@ -382,20 +401,17 @@ impl<'db> Type<'db> {
fn is_none(&self, db: &'db dyn Db) -> bool {
self.into_instance()
.is_some_and(|instance| instance.class().is_known(db, KnownClass::NoneType))
.is_some_and(|instance| instance.class.is_known(db, KnownClass::NoneType))
}
pub fn is_notimplemented(&self, db: &'db dyn Db) -> bool {
self.into_instance().is_some_and(|instance| {
instance
.class()
.is_known(db, KnownClass::NotImplementedType)
})
self.into_instance()
.is_some_and(|instance| instance.class.is_known(db, KnownClass::NotImplementedType))
}
pub fn is_object(&self, db: &'db dyn Db) -> bool {
self.into_instance()
.is_some_and(|instance| instance.class().is_object(db))
.is_some_and(|instance| instance.class.is_object(db))
}
pub const fn is_todo(&self) -> bool {
@@ -426,6 +442,12 @@ impl<'db> Type<'db> {
| Self::WrapperDescriptor(_)
| Self::MethodWrapper(_) => false,
Self::GenericAlias(generic) => generic
.specialization(db)
.types(db)
.iter()
.any(|ty| ty.contains_todo(db)),
Self::Callable(callable) => {
let signature = callable.signature(db);
signature.parameters().iter().any(|param| {
@@ -467,10 +489,6 @@ impl<'db> Type<'db> {
}
}
pub const fn class_literal(class: Class<'db>) -> Self {
Self::ClassLiteral(ClassLiteralType { class })
}
pub const fn into_class_literal(self) -> Option<ClassLiteralType<'db>> {
match self {
Type::ClassLiteral(class_type) => Some(class_type),
@@ -492,6 +510,29 @@ impl<'db> Type<'db> {
matches!(self, Type::ClassLiteral(..))
}
pub const fn into_class_type(self) -> Option<ClassType<'db>> {
match self {
Type::ClassLiteral(ClassLiteralType::NonGeneric(non_generic)) => {
Some(ClassType::NonGeneric(non_generic))
}
Type::GenericAlias(alias) => Some(ClassType::Generic(alias)),
_ => None,
}
}
#[track_caller]
pub fn expect_class_type(self) -> ClassType<'db> {
self.into_class_type()
.expect("Expected a Type::GenericAlias or non-generic Type::ClassLiteral variant")
}
pub const fn is_class_type(&self) -> bool {
matches!(
self,
Type::ClassLiteral(ClassLiteralType::NonGeneric(_)) | Type::GenericAlias(_)
)
}
pub const fn is_instance(&self) -> bool {
matches!(self, Type::Instance(..))
}
@@ -631,7 +672,7 @@ impl<'db> Type<'db> {
matches!(self, Type::LiteralString)
}
pub const fn instance(class: Class<'db>) -> Self {
pub const fn instance(class: ClassType<'db>) -> Self {
Self::Instance(InstanceType { class })
}
@@ -693,6 +734,10 @@ impl<'db> Type<'db> {
| Type::KnownInstance(_)
| Type::IntLiteral(_)
| Type::SubclassOf(_) => self,
Type::GenericAlias(generic) => {
let specialization = generic.specialization(db).normalized(db);
Type::GenericAlias(GenericAlias::new(db, generic.origin(db), specialization))
}
Type::TypeVar(typevar) => match typevar.bound_or_constraints(db) {
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
Type::TypeVar(TypeVarInstance::new(
@@ -932,13 +977,16 @@ impl<'db> Type<'db> {
// `Literal[<class 'C'>]` is a subtype of `type[B]` if `C` is a subclass of `B`,
// since `type[B]` describes all possible runtime subclasses of the class object `B`.
(
Type::ClassLiteral(ClassLiteralType { class }),
Type::SubclassOf(target_subclass_ty),
) => target_subclass_ty
(Type::ClassLiteral(class), Type::SubclassOf(target_subclass_ty)) => target_subclass_ty
.subclass_of()
.into_class()
.is_some_and(|target_class| class.is_subclass_of(db, target_class)),
.is_some_and(|target_class| class.is_subclass_of(db, None, target_class)),
(Type::GenericAlias(alias), Type::SubclassOf(target_subclass_ty)) => target_subclass_ty
.subclass_of()
.into_class()
.is_some_and(|target_class| {
ClassType::from(alias).is_subclass_of(db, target_class)
}),
// This branch asks: given two types `type[T]` and `type[S]`, is `type[T]` a subtype of `type[S]`?
(Type::SubclassOf(self_subclass_ty), Type::SubclassOf(target_subclass_ty)) => {
@@ -948,9 +996,12 @@ impl<'db> Type<'db> {
// `Literal[str]` is a subtype of `type` because the `str` class object is an instance of its metaclass `type`.
// `Literal[abc.ABC]` is a subtype of `abc.ABCMeta` because the `abc.ABC` class object
// is an instance of its metaclass `abc.ABCMeta`.
(Type::ClassLiteral(ClassLiteralType { class }), _) => {
(Type::ClassLiteral(class), _) => {
class.metaclass_instance_type(db).is_subtype_of(db, target)
}
(Type::GenericAlias(alias), _) => ClassType::from(alias)
.metaclass_instance_type(db)
.is_subtype_of(db, target),
// `type[str]` (== `SubclassOf("str")` in red-knot) describes all possible runtime subclasses
// of the class object `str`. It is a subtype of `type` (== `Instance("type")`) because `str`
@@ -1141,11 +1192,10 @@ impl<'db> Type<'db> {
// Every class literal type is also assignable to `type[Any]`, because the class
// literal type for a class `C` is a subtype of `type[C]`, and `type[C]` is assignable
// to `type[Any]`.
(Type::ClassLiteral(_) | Type::SubclassOf(_), Type::SubclassOf(target_subclass_of))
if target_subclass_of.is_dynamic() =>
{
true
}
(
Type::ClassLiteral(_) | Type::GenericAlias(_) | Type::SubclassOf(_),
Type::SubclassOf(target_subclass_of),
) if target_subclass_of.is_dynamic() => true,
// `type[Any]` is assignable to any type that `type[object]` is assignable to, because
// `type[Any]` can materialize to `type[object]`.
@@ -1386,6 +1436,7 @@ impl<'db> Type<'db> {
| Type::WrapperDescriptor(..)
| Type::ModuleLiteral(..)
| Type::ClassLiteral(..)
| Type::GenericAlias(..)
| Type::KnownInstance(..)),
right @ (Type::BooleanLiteral(..)
| Type::IntLiteral(..)
@@ -1398,6 +1449,7 @@ impl<'db> Type<'db> {
| Type::WrapperDescriptor(..)
| Type::ModuleLiteral(..)
| Type::ClassLiteral(..)
| Type::GenericAlias(..)
| Type::KnownInstance(..)),
) => left != right,
@@ -1406,6 +1458,7 @@ impl<'db> Type<'db> {
(
Type::Tuple(..),
Type::ClassLiteral(..)
| Type::GenericAlias(..)
| Type::ModuleLiteral(..)
| Type::BooleanLiteral(..)
| Type::BytesLiteral(..)
@@ -1420,6 +1473,7 @@ impl<'db> Type<'db> {
)
| (
Type::ClassLiteral(..)
| Type::GenericAlias(..)
| Type::ModuleLiteral(..)
| Type::BooleanLiteral(..)
| Type::BytesLiteral(..)
@@ -1434,17 +1488,23 @@ impl<'db> Type<'db> {
Type::Tuple(..),
) => true,
(
Type::SubclassOf(subclass_of_ty),
Type::ClassLiteral(ClassLiteralType { class: class_b }),
)
| (
Type::ClassLiteral(ClassLiteralType { class: class_b }),
Type::SubclassOf(subclass_of_ty),
) => match subclass_of_ty.subclass_of() {
ClassBase::Dynamic(_) => false,
ClassBase::Class(class_a) => !class_b.is_subclass_of(db, class_a),
},
(Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b))
| (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => {
match subclass_of_ty.subclass_of() {
ClassBase::Dynamic(_) => false,
ClassBase::Class(class_a) => !class_b.is_subclass_of(db, None, class_a),
}
}
(Type::SubclassOf(subclass_of_ty), Type::GenericAlias(alias_b))
| (Type::GenericAlias(alias_b), Type::SubclassOf(subclass_of_ty)) => {
match subclass_of_ty.subclass_of() {
ClassBase::Dynamic(_) => false,
ClassBase::Class(class_a) => {
!ClassType::from(alias_b).is_subclass_of(db, class_a)
}
}
}
(
Type::SubclassOf(_),
@@ -1561,12 +1621,14 @@ impl<'db> Type<'db> {
// A class-literal type `X` is always disjoint from an instance type `Y`,
// unless the type expressing "all instances of `Z`" is a subtype of of `Y`,
// where `Z` is `X`'s metaclass.
(Type::ClassLiteral(ClassLiteralType { class }), instance @ Type::Instance(_))
| (instance @ Type::Instance(_), Type::ClassLiteral(ClassLiteralType { class })) => {
!class
.metaclass_instance_type(db)
.is_subtype_of(db, instance)
}
(Type::ClassLiteral(class), instance @ Type::Instance(_))
| (instance @ Type::Instance(_), Type::ClassLiteral(class)) => !class
.metaclass_instance_type(db)
.is_subtype_of(db, instance),
(Type::GenericAlias(alias), instance @ Type::Instance(_))
| (instance @ Type::Instance(_), Type::GenericAlias(alias)) => !ClassType::from(alias)
.metaclass_instance_type(db)
.is_subtype_of(db, instance),
(Type::FunctionLiteral(..), Type::Instance(InstanceType { class }))
| (Type::Instance(InstanceType { class }), Type::FunctionLiteral(..)) => {
@@ -1692,7 +1754,7 @@ impl<'db> Type<'db> {
},
Type::SubclassOf(subclass_of_ty) => subclass_of_ty.is_fully_static(),
Type::ClassLiteral(_) | Type::Instance(_) => {
Type::ClassLiteral(_) | Type::GenericAlias(_) | Type::Instance(_) => {
// TODO: Ideally, we would iterate over the MRO of the class, check if all
// bases are fully static, and only return `true` if that is the case.
//
@@ -1763,6 +1825,7 @@ impl<'db> Type<'db> {
| Type::FunctionLiteral(..)
| Type::WrapperDescriptor(..)
| Type::ClassLiteral(..)
| Type::GenericAlias(..)
| Type::ModuleLiteral(..)
| Type::KnownInstance(..) => true,
Type::Callable(_) => {
@@ -1828,6 +1891,7 @@ impl<'db> Type<'db> {
| Type::MethodWrapper(_)
| Type::ModuleLiteral(..)
| Type::ClassLiteral(..)
| Type::GenericAlias(..)
| Type::IntLiteral(..)
| Type::BooleanLiteral(..)
| Type::StringLiteral(..)
@@ -1912,7 +1976,7 @@ impl<'db> Type<'db> {
Type::Dynamic(_) | Type::Never => Some(Symbol::bound(self).into()),
Type::ClassLiteral(class_literal @ ClassLiteralType { class }) => {
Type::ClassLiteral(class) => {
match (class.known(db), name) {
(Some(KnownClass::FunctionType), "__get__") => Some(
Symbol::bound(Type::WrapperDescriptor(
@@ -1956,10 +2020,14 @@ impl<'db> Type<'db> {
"__get__" | "__set__" | "__delete__",
) => Some(Symbol::Unbound.into()),
_ => Some(class_literal.class_member(db, name, policy)),
_ => Some(class.class_member(db, None, name, policy)),
}
}
Type::GenericAlias(alias) => {
Some(ClassType::from(*alias).class_member(db, name, policy))
}
Type::SubclassOf(subclass_of)
if name == "__get__"
&& matches!(
@@ -2126,7 +2194,9 @@ impl<'db> Type<'db> {
// a `__dict__` that is filled with class level attributes. Modeling this is currently not
// required, as `instance_member` is only called for instance-like types through `member`,
// but we might want to add this in the future.
Type::ClassLiteral(_) | Type::SubclassOf(_) => Symbol::Unbound.into(),
Type::ClassLiteral(_) | Type::GenericAlias(_) | Type::SubclassOf(_) => {
Symbol::Unbound.into()
}
}
}
@@ -2439,7 +2509,7 @@ impl<'db> Type<'db> {
)
.into(),
Type::ClassLiteral(ClassLiteralType { class })
Type::ClassLiteral(class)
if name == "__get__" && class.is_known(db, KnownClass::FunctionType) =>
{
Symbol::bound(Type::WrapperDescriptor(
@@ -2447,7 +2517,7 @@ impl<'db> Type<'db> {
))
.into()
}
Type::ClassLiteral(ClassLiteralType { class })
Type::ClassLiteral(class)
if name == "__get__" && class.is_known(db, KnownClass::Property) =>
{
Symbol::bound(Type::WrapperDescriptor(
@@ -2455,7 +2525,7 @@ impl<'db> Type<'db> {
))
.into()
}
Type::ClassLiteral(ClassLiteralType { class })
Type::ClassLiteral(class)
if name == "__set__" && class.is_known(db, KnownClass::Property) =>
{
Symbol::bound(Type::WrapperDescriptor(
@@ -2599,8 +2669,8 @@ impl<'db> Type<'db> {
}
}
Type::ClassLiteral(..) | Type::SubclassOf(..) => {
let class_attr_plain = self.find_name_in_mro_with_policy(db, name_str, policy).expect(
Type::ClassLiteral(..) | Type::GenericAlias(..) | Type::SubclassOf(..) => {
let class_attr_plain = self.find_name_in_mro_with_policy(db, name_str,policy).expect(
"Calling `find_name_in_mro` on class literals and subclass-of types should always return `Some`",
);
@@ -2785,14 +2855,17 @@ impl<'db> Type<'db> {
Type::AlwaysFalsy => Truthiness::AlwaysFalse,
Type::ClassLiteral(ClassLiteralType { class }) => class
Type::ClassLiteral(class) => class
.metaclass_instance_type(db)
.try_bool_impl(db, allow_short_circuit)?,
Type::GenericAlias(alias) => ClassType::from(*alias)
.metaclass_instance_type(db)
.try_bool_impl(db, allow_short_circuit)?,
Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() {
ClassBase::Dynamic(_) => Truthiness::Ambiguous,
ClassBase::Class(class) => {
Type::class_literal(class).try_bool_impl(db, allow_short_circuit)?
Type::from(class).try_bool_impl(db, allow_short_circuit)?
}
},
@@ -3141,7 +3214,7 @@ impl<'db> Type<'db> {
)),
},
Type::ClassLiteral(ClassLiteralType { class }) => match class.known(db) {
Type::ClassLiteral(class) => match class.known(db) {
// TODO: Ideally we'd use `try_call_constructor` for all constructor calls.
// Currently we don't for a few special known types, either because their
// constructors are defined with overloads, or because we want to special case
@@ -3345,13 +3418,23 @@ impl<'db> Type<'db> {
}
},
Type::GenericAlias(_) => {
// TODO annotated return type on `__new__` or metaclass `__call__`
// TODO check call vs signatures of `__new__` and/or `__init__`
let signature = CallableSignature::single(
self,
Signature::new(Parameters::gradual_form(), self.to_instance(db)),
);
Signatures::single(signature)
}
Type::SubclassOf(subclass_of_type) => match subclass_of_type.subclass_of() {
ClassBase::Dynamic(dynamic_type) => Type::Dynamic(dynamic_type).signatures(db),
// Most type[] constructor calls are handled by `try_call_constructor` and not via
// getting the signature here. This signature can still be used in some cases (e.g.
// evaluating callable subtyping). TODO improve this definition (intersection of
// `__new__` and `__init__` signatures? and respect metaclass `__call__`).
ClassBase::Class(class) => Type::class_literal(class).signatures(db),
ClassBase::Class(class) => Type::from(class).signatures(db),
},
Type::Instance(_) => {
@@ -3729,7 +3812,8 @@ impl<'db> Type<'db> {
pub fn to_instance(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Type::Dynamic(_) | Type::Never => Some(*self),
Type::ClassLiteral(ClassLiteralType { class }) => Some(Type::instance(*class)),
Type::ClassLiteral(class) => Some(Type::instance(class.default_specialization(db))),
Type::GenericAlias(alias) => Some(Type::instance(ClassType::from(*alias))),
Type::SubclassOf(subclass_of_ty) => Some(subclass_of_ty.to_instance()),
Type::Union(union) => {
let mut builder = UnionBuilder::new(db);
@@ -3774,7 +3858,7 @@ impl<'db> Type<'db> {
match self {
// Special cases for `float` and `complex`
// https://typing.readthedocs.io/en/latest/spec/special-types.html#special-cases-for-float-and-complex
Type::ClassLiteral(ClassLiteralType { class }) => {
Type::ClassLiteral(class) => {
let ty = match class.known(db) {
Some(KnownClass::Any) => Type::any(),
Some(KnownClass::Complex) => UnionType::from_elements(
@@ -3792,10 +3876,11 @@ impl<'db> Type<'db> {
KnownClass::Float.to_instance(db),
],
),
_ => Type::instance(*class),
_ => Type::instance(class.default_specialization(db)),
};
Ok(ty)
}
Type::GenericAlias(alias) => Ok(Type::instance(ClassType::from(*alias))),
Type::SubclassOf(_)
| Type::BooleanLiteral(_)
@@ -4033,7 +4118,8 @@ impl<'db> Type<'db> {
}
},
Type::ClassLiteral(ClassLiteralType { class }) => class.metaclass(db),
Type::ClassLiteral(class) => class.metaclass(db),
Type::GenericAlias(alias) => ClassType::from(*alias).metaclass(db),
Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() {
ClassBase::Dynamic(_) => *self,
ClassBase::Class(class) => SubclassOfType::from(
@@ -4055,6 +4141,121 @@ impl<'db> Type<'db> {
}
}
/// Applies a specialization to this type, replacing any typevars with the types that they are
/// specialized to.
///
/// Note that this does not specialize generic classes, functions, or type aliases! That is a
/// different operation that is performed explicitly (via a subscript operation), or implicitly
/// via a call to the generic object.
#[must_use]
#[salsa::tracked]
pub fn apply_specialization(
self,
db: &'db dyn Db,
specialization: Specialization<'db>,
) -> Type<'db> {
match self {
Type::TypeVar(typevar) => specialization.get(db, typevar).unwrap_or(self),
Type::FunctionLiteral(function) => {
Type::FunctionLiteral(function.apply_specialization(db, specialization))
}
// Note that we don't need to apply the specialization to `self_instance`, since it
// must either be a non-generic class literal (which cannot have any typevars to
// specialize) or a generic alias (which has already been fully specialized). For a
// generic alias, the specialization being applied here must be for some _other_
// generic context nested within the generic alias's class literal, which the generic
// alias's context cannot refer to. (The _method_ does need to be specialized, since it
// might be a nested generic method, whose generic context is what is now being
// specialized.)
Type::BoundMethod(method) => Type::BoundMethod(BoundMethodType::new(
db,
method.function(db).apply_specialization(db, specialization),
method.self_instance(db),
)),
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(
function.apply_specialization(db, specialization),
))
}
Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(property)) => {
Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(
property.apply_specialization(db, specialization),
))
}
Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(property)) => {
Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(
property.apply_specialization(db, specialization),
))
}
Type::Callable(callable) => {
Type::Callable(callable.apply_specialization(db, specialization))
}
Type::GenericAlias(generic) => {
let specialization = generic
.specialization(db)
.apply_specialization(db, specialization);
Type::GenericAlias(GenericAlias::new(db, generic.origin(db), specialization))
}
Type::PropertyInstance(property) => {
Type::PropertyInstance(property.apply_specialization(db, specialization))
}
Type::Union(union) => union.map(db, |element| {
element.apply_specialization(db, specialization)
}),
Type::Intersection(intersection) => {
let mut builder = IntersectionBuilder::new(db);
for positive in intersection.positive(db) {
builder =
builder.add_positive(positive.apply_specialization(db, specialization));
}
for negative in intersection.negative(db) {
builder =
builder.add_negative(negative.apply_specialization(db, specialization));
}
builder.build()
}
Type::Tuple(tuple) => TupleType::from_elements(
db,
tuple
.iter(db)
.map(|ty| ty.apply_specialization(db, specialization)),
),
Type::Dynamic(_)
| Type::Never
| Type::AlwaysTruthy
| Type::AlwaysFalsy
| Type::WrapperDescriptor(_)
| Type::ModuleLiteral(_)
// A non-generic class never needs to be specialized. A generic class is specialized
// explicitly (via a subscript expression) or implicitly (via a call), and not because
// some other generic context's specialization is applied to it.
| Type::ClassLiteral(_)
// SubclassOf contains a ClassType, which has already been specialized if needed, like
// above with BoundMethod's self_instance.
| Type::SubclassOf(_)
| Type::IntLiteral(_)
| Type::BooleanLiteral(_)
| Type::LiteralString
| Type::StringLiteral(_)
| Type::BytesLiteral(_)
| Type::SliceLiteral(_)
// Instance contains a ClassType, which has already been specialized if needed, like
// above with BoundMethod's self_instance.
| Type::Instance(_)
| Type::KnownInstance(_) => self,
}
}
/// Return the string representation of this type when converted to string as it would be
/// provided by the `__str__` method.
///
@@ -4111,11 +4312,10 @@ impl<'db> Type<'db> {
}
Self::ModuleLiteral(module) => Some(TypeDefinition::Module(module.module(db))),
Self::ClassLiteral(class_literal) => {
Some(TypeDefinition::Class(class_literal.class().definition(db)))
}
Self::Instance(instance) => {
Some(TypeDefinition::Class(instance.class().definition(db)))
Some(TypeDefinition::Class(class_literal.definition(db)))
}
Self::GenericAlias(alias) => Some(TypeDefinition::Class(alias.definition(db))),
Self::Instance(instance) => Some(TypeDefinition::Class(instance.class.definition(db))),
Self::KnownInstance(instance) => match instance {
KnownInstanceType::TypeVar(var) => {
Some(TypeDefinition::TypeVar(var.definition(db)))
@@ -5133,6 +5333,11 @@ pub struct FunctionType<'db> {
/// A set of special decorators that were applied to this function
decorators: FunctionDecorators,
/// A specialization that should be applied to the function's parameter and return types,
/// either because the function is itself generic, or because it appears in the body of a
/// generic class.
specialization: Option<Specialization<'db>>,
}
#[salsa::tracked]
@@ -5183,13 +5388,17 @@ impl<'db> FunctionType<'db> {
/// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked(return_ref)]
pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> {
let internal_signature = self.internal_signature(db);
let mut internal_signature = self.internal_signature(db);
if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) {
Signature::todo("return type of overloaded function")
} else {
internal_signature
return Signature::todo("return type of overloaded function");
}
if let Some(specialization) = self.specialization(db) {
internal_signature.apply_specialization(db, specialization);
}
internal_signature
}
/// Typed internally-visible signature for this function.
@@ -5212,6 +5421,21 @@ impl<'db> FunctionType<'db> {
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
self.known(db) == Some(known_function)
}
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
let specialization = match self.specialization(db) {
Some(existing) => existing.apply_specialization(db, specialization),
None => specialization,
};
Self::new(
db,
self.name(db).clone(),
self.known(db),
self.body_scope(db),
self.decorators(db),
Some(specialization),
)
}
}
/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might
@@ -5382,6 +5606,12 @@ impl<'db> CallableType<'db> {
CallableType::new(db, Signature::new(parameters, return_ty))
}
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
let mut signature = self.signature(db).clone();
signature.apply_specialization(db, specialization);
Self::new(db, signature)
}
/// Returns `true` if this is a fully static callable type.
///
/// A callable type is fully static if all of its parameters and return type are fully static
@@ -6006,8 +6236,8 @@ impl<'db> TypeAliasType<'db> {
/// Either the explicit `metaclass=` keyword of the class, or the inferred metaclass of one of its base classes.
#[derive(Debug, Clone, PartialEq, Eq, salsa::Update)]
pub(super) struct MetaclassCandidate<'db> {
metaclass: Class<'db>,
explicit_metaclass_of: Class<'db>,
metaclass: ClassType<'db>,
explicit_metaclass_of: ClassLiteralType<'db>,
}
#[salsa::interned(debug)]
@@ -6579,6 +6809,10 @@ impl<'db> TupleType<'db> {
pub fn len(&self, db: &'db dyn Db) -> usize {
self.elements(db).len()
}
pub fn iter(&self, db: &'db dyn Db) -> impl Iterator<Item = Type<'db>> + 'db + '_ {
self.elements(db).iter().copied()
}
}
// Make sure that the `Type` enum does not grow unexpectedly.

View File

@@ -292,7 +292,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
_ => {
let known_instance = new_positive
.into_instance()
.and_then(|instance| instance.class().known(db));
.and_then(|instance| instance.class.known(db));
if known_instance == Some(KnownClass::Object) {
// `object & T` -> `T`; it is always redundant to add `object` to an intersection
@@ -312,7 +312,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
new_positive = Type::BooleanLiteral(false);
}
Type::Instance(instance)
if instance.class().is_known(db, KnownClass::Bool) =>
if instance.class.is_known(db, KnownClass::Bool) =>
{
match new_positive {
// `bool & AlwaysTruthy` -> `Literal[True]`
@@ -406,7 +406,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
self.positive
.iter()
.filter_map(|ty| ty.into_instance())
.filter_map(|instance| instance.class().known(db))
.filter_map(|instance| instance.class.known(db))
.any(KnownClass::is_bool)
};
@@ -422,7 +422,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
Type::Never => {
// Adding ~Never to an intersection is a no-op.
}
Type::Instance(instance) if instance.class().is_object(db) => {
Type::Instance(instance) if instance.class.is_object(db) => {
// Adding ~object to an intersection results in Never.
*self = Self::default();
self.positive.insert(Type::Never);

View File

@@ -18,8 +18,8 @@ use crate::types::diagnostic::{
};
use crate::types::signatures::{Parameter, ParameterForm};
use crate::types::{
todo_type, BoundMethodType, ClassLiteralType, FunctionDecorators, KnownClass, KnownFunction,
KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind,
todo_type, BoundMethodType, FunctionDecorators, KnownClass, KnownFunction, KnownInstanceType,
MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind,
};
use ruff_db::diagnostic::{OldSecondaryDiagnosticMessage, Span};
use ruff_python_ast as ast;
@@ -566,7 +566,7 @@ impl<'db> Bindings<'db> {
_ => {}
},
Type::ClassLiteral(ClassLiteralType { class }) => match class.known(db) {
Type::ClassLiteral(class) => match class.known(db) {
Some(KnownClass::Bool) => match overload.parameter_types() {
[Some(arg)] => overload.set_return_type(arg.bool(db).into_type(db)),
[None] => overload.set_return_type(Type::BooleanLiteral(false)),
@@ -1064,7 +1064,7 @@ impl<'db> CallableDescription<'db> {
}),
Type::ClassLiteral(class_type) => Some(CallableDescription {
kind: "class",
name: class_type.class().name(db),
name: class_type.name(db),
}),
Type::BoundMethod(bound_method) => Some(CallableDescription {
kind: "bound method",

View File

@@ -6,6 +6,7 @@ use super::{
Type, TypeAliasType, TypeQualifiers, TypeVarInstance,
};
use crate::semantic_index::definition::Definition;
use crate::types::generics::{GenericContext, Specialization};
use crate::{
module_resolver::file_to_module,
semantic_index::{
@@ -24,36 +25,199 @@ use crate::{
};
use indexmap::IndexSet;
use itertools::Itertools as _;
use ruff_db::files::{File, FileRange};
use ruff_db::files::File;
use ruff_python_ast::{self as ast, PythonVersion};
use rustc_hash::FxHashSet;
/// Representation of a runtime class object.
///
/// Does not in itself represent a type,
/// but is used as the inner data for several structs that *do* represent types.
#[salsa::interned(debug)]
fn explicit_bases_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &[Type<'db>],
_count: u32,
_self: ClassLiteralType<'db>,
) -> salsa::CycleRecoveryAction<Box<[Type<'db>]>> {
salsa::CycleRecoveryAction::Iterate
}
fn explicit_bases_cycle_initial<'db>(
_db: &'db dyn Db,
_self: ClassLiteralType<'db>,
) -> Box<[Type<'db>]> {
Box::default()
}
fn try_mro_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Result<Mro<'db>, MroError<'db>>,
_count: u32,
_self: ClassLiteralType<'db>,
_specialization: Option<Specialization<'db>>,
) -> salsa::CycleRecoveryAction<Result<Mro<'db>, MroError<'db>>> {
salsa::CycleRecoveryAction::Iterate
}
#[allow(clippy::unnecessary_wraps)]
fn try_mro_cycle_initial<'db>(
db: &'db dyn Db,
self_: ClassLiteralType<'db>,
specialization: Option<Specialization<'db>>,
) -> Result<Mro<'db>, MroError<'db>> {
Ok(Mro::from_error(
db,
self_.apply_optional_specialization(db, specialization),
))
}
#[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)]
fn inheritance_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Option<InheritanceCycle>,
_count: u32,
_self: ClassLiteralType<'db>,
) -> salsa::CycleRecoveryAction<Option<InheritanceCycle>> {
salsa::CycleRecoveryAction::Iterate
}
fn inheritance_cycle_initial<'db>(
_db: &'db dyn Db,
_self: ClassLiteralType<'db>,
) -> Option<InheritanceCycle> {
None
}
/// Representation of a class definition statement in the AST. This does not in itself represent a
/// type, but is used as the inner data for several structs that *do* represent types.
#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
pub struct Class<'db> {
/// Name of the class at definition
#[return_ref]
pub(crate) name: ast::name::Name,
body_scope: ScopeId<'db>,
pub(crate) body_scope: ScopeId<'db>,
pub(crate) known: Option<KnownClass>,
}
#[salsa::tracked]
impl<'db> Class<'db> {
fn file(&self, db: &dyn Db) -> File {
self.body_scope.file(db)
}
/// Return the original [`ast::StmtClassDef`] node associated with this class
///
/// ## Note
/// Only call this function from queries in the same file or your
/// query depends on the AST of another file (bad!).
fn node(&self, db: &'db dyn Db) -> &'db ast::StmtClassDef {
self.body_scope.node(db).expect_class()
}
fn definition(&self, db: &'db dyn Db) -> Definition<'db> {
let index = semantic_index(db, self.body_scope.file(db));
index.expect_single_definition(self.body_scope.node(db).expect_class())
}
}
/// A [`Class`] that is not generic.
#[salsa::interned(debug)]
pub struct NonGenericClass<'db> {
#[return_ref]
pub(crate) class: Class<'db>,
}
impl<'db> From<NonGenericClass<'db>> for Type<'db> {
fn from(class: NonGenericClass<'db>) -> Type<'db> {
Type::ClassLiteral(ClassLiteralType::NonGeneric(class))
}
}
/// A [`Class`] that is generic.
#[salsa::interned(debug)]
pub struct GenericClass<'db> {
#[return_ref]
pub(crate) class: Class<'db>,
pub(crate) generic_context: GenericContext<'db>,
}
impl<'db> From<GenericClass<'db>> for Type<'db> {
fn from(class: GenericClass<'db>) -> Type<'db> {
Type::ClassLiteral(ClassLiteralType::Generic(class))
}
}
/// A specialization of a generic class with a particular assignment of types to typevars.
#[salsa::interned(debug)]
pub struct GenericAlias<'db> {
pub(crate) origin: GenericClass<'db>,
pub(crate) specialization: Specialization<'db>,
}
impl<'db> GenericAlias<'db> {
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
let scope = self.body_scope(db);
let index = semantic_index(db, scope.file(db));
index.expect_single_definition(scope.node(db).expect_class())
self.origin(db).class(db).definition(db)
}
}
impl<'db> From<GenericAlias<'db>> for Type<'db> {
fn from(alias: GenericAlias<'db>) -> Type<'db> {
Type::GenericAlias(alias)
}
}
/// Represents a class type, which might be a non-generic class, or a specialization of a generic
/// class.
#[derive(
Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Supertype, salsa::Update,
)]
pub enum ClassType<'db> {
NonGeneric(NonGenericClass<'db>),
Generic(GenericAlias<'db>),
}
#[salsa::tracked]
impl<'db> ClassType<'db> {
fn class(self, db: &'db dyn Db) -> &'db Class<'db> {
match self {
Self::NonGeneric(non_generic) => non_generic.class(db),
Self::Generic(generic) => generic.origin(db).class(db),
}
}
/// Returns the class literal and specialization for this class. For a non-generic class, this
/// is the class itself. For a generic alias, this is the alias's origin.
pub(crate) fn class_literal(
self,
db: &'db dyn Db,
) -> (ClassLiteralType<'db>, Option<Specialization<'db>>) {
match self {
Self::NonGeneric(non_generic) => (ClassLiteralType::NonGeneric(non_generic), None),
Self::Generic(generic) => (
ClassLiteralType::Generic(generic.origin(db)),
Some(generic.specialization(db)),
),
}
}
pub(crate) fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {
&self.class(db).name
}
pub(crate) fn known(self, db: &'db dyn Db) -> Option<KnownClass> {
self.class(db).known
}
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
self.class(db).definition(db)
}
fn specialize_type(self, db: &'db dyn Db, ty: Type<'db>) -> Type<'db> {
match self {
Self::NonGeneric(_) => ty,
Self::Generic(generic) => ty.apply_specialization(db, generic.specialization(db)),
}
}
/// Return `true` if this class represents `known_class`
pub(crate) fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool {
self.known(db) == Some(known_class)
self.class(db).known == Some(known_class)
}
/// Return `true` if this class represents the builtin class `object`
@@ -61,6 +225,202 @@ impl<'db> Class<'db> {
self.is_known(db, KnownClass::Object)
}
/// Iterate over the [method resolution order] ("MRO") of the class.
///
/// If the MRO could not be accurately resolved, this method falls back to iterating
/// over an MRO that has the class directly inheriting from `Unknown`. Use
/// [`ClassLiteralType::try_mro`] if you need to distinguish between the success and failure
/// cases rather than simply iterating over the inferred resolution order for the class.
///
/// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order
pub(super) fn iter_mro(self, db: &'db dyn Db) -> impl Iterator<Item = ClassBase<'db>> {
let (class_literal, specialization) = self.class_literal(db);
class_literal.iter_mro(db, specialization)
}
/// Is this class final?
pub(super) fn is_final(self, db: &'db dyn Db) -> bool {
let (class_literal, _) = self.class_literal(db);
class_literal.is_final(db)
}
/// Return `true` if `other` is present in this class's MRO.
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
// `is_subclass_of` is checking the subtype relation, in which gradual types do not
// participate, so we should not return `True` if we find `Any/Unknown` in the MRO.
self.iter_mro(db).contains(&ClassBase::Class(other))
}
/// Return the metaclass of this class, or `type[Unknown]` if the metaclass cannot be inferred.
pub(super) fn metaclass(self, db: &'db dyn Db) -> Type<'db> {
let (class_literal, _) = self.class_literal(db);
self.specialize_type(db, class_literal.metaclass(db))
}
/// Return a type representing "the set of all instances of the metaclass of this class".
pub(super) fn metaclass_instance_type(self, db: &'db dyn Db) -> Type<'db> {
self
.metaclass(db)
.to_instance(db)
.expect("`Type::to_instance()` should always return `Some()` when called on the type of a metaclass")
}
/// Returns the class member of this class named `name`.
///
/// The member resolves to a member on the class itself or any of its proper superclasses.
///
/// TODO: Should this be made private...?
pub(super) fn class_member(
self,
db: &'db dyn Db,
name: &str,
policy: MemberLookupPolicy,
) -> SymbolAndQualifiers<'db> {
let (class_literal, specialization) = self.class_literal(db);
class_literal
.class_member(db, specialization, name, policy)
.map_type(|ty| self.specialize_type(db, ty))
}
/// Returns the inferred type of the class member named `name`. Only bound members
/// or those marked as ClassVars are considered.
///
/// Returns [`Symbol::Unbound`] if `name` cannot be found in this class's scope
/// directly. Use [`ClassType::class_member`] if you require a method that will
/// traverse through the MRO until it finds the member.
pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
let (class_literal, _) = self.class_literal(db);
class_literal
.own_class_member(db, name)
.map_type(|ty| self.specialize_type(db, ty))
}
/// Returns the `name` attribute of an instance of this class.
///
/// The attribute could be defined in the class body, but it could also be an implicitly
/// defined attribute that is only present in a method (typically `__init__`).
///
/// The attribute might also be defined in a superclass of this class.
pub(super) fn instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
let (class_literal, specialization) = self.class_literal(db);
class_literal
.instance_member(db, specialization, name)
.map_type(|ty| self.specialize_type(db, ty))
}
/// A helper function for `instance_member` that looks up the `name` attribute only on
/// this class, not on its superclasses.
fn own_instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
let (class_literal, _) = self.class_literal(db);
class_literal
.own_instance_member(db, name)
.map_type(|ty| self.specialize_type(db, ty))
}
}
impl<'db> From<GenericAlias<'db>> for ClassType<'db> {
fn from(generic: GenericAlias<'db>) -> ClassType<'db> {
ClassType::Generic(generic)
}
}
impl<'db> From<ClassType<'db>> for Type<'db> {
fn from(class: ClassType<'db>) -> Type<'db> {
match class {
ClassType::NonGeneric(non_generic) => non_generic.into(),
ClassType::Generic(generic) => generic.into(),
}
}
}
/// Represents a single class object at runtime, which might be a non-generic class, or a generic
/// class that has not been specialized.
#[derive(
Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Supertype, salsa::Update,
)]
pub enum ClassLiteralType<'db> {
NonGeneric(NonGenericClass<'db>),
Generic(GenericClass<'db>),
}
#[salsa::tracked]
impl<'db> ClassLiteralType<'db> {
fn class(self, db: &'db dyn Db) -> &'db Class<'db> {
match self {
Self::NonGeneric(non_generic) => non_generic.class(db),
Self::Generic(generic) => generic.class(db),
}
}
pub(crate) fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {
&self.class(db).name
}
pub(crate) fn known(self, db: &'db dyn Db) -> Option<KnownClass> {
self.class(db).known
}
/// Return `true` if this class represents `known_class`
pub(crate) fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool {
self.class(db).known == Some(known_class)
}
/// Return `true` if this class represents the builtin class `object`
pub(crate) fn is_object(self, db: &'db dyn Db) -> bool {
self.is_known(db, KnownClass::Object)
}
pub(crate) fn body_scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.class(db).body_scope
}
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
self.class(db).definition(db)
}
pub(crate) fn apply_optional_specialization(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
) -> ClassType<'db> {
match (self, specialization) {
(Self::NonGeneric(non_generic), _) => ClassType::NonGeneric(non_generic),
(Self::Generic(generic), None) => {
let specialization = generic.generic_context(db).default_specialization(db);
ClassType::Generic(GenericAlias::new(db, generic, specialization))
}
(Self::Generic(generic), Some(specialization)) => {
ClassType::Generic(GenericAlias::new(db, generic, specialization))
}
}
}
/// Returns the default specialization of this class. For non-generic classes, the class is
/// returned unchanged. For a non-specialized generic class, we return a generic alias that
/// applies the default specialization to the class's typevars.
pub(crate) fn default_specialization(self, db: &'db dyn Db) -> ClassType<'db> {
match self {
Self::NonGeneric(non_generic) => ClassType::NonGeneric(non_generic),
Self::Generic(generic) => {
let specialization = generic.generic_context(db).default_specialization(db);
ClassType::Generic(GenericAlias::new(db, generic, specialization))
}
}
}
/// Returns the unknown specialization of this class. For non-generic classes, the class is
/// returned unchanged. For a non-specialized generic class, we return a generic alias that
/// maps each of the class's typevars to `Unknown`.
pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> ClassType<'db> {
match self {
Self::NonGeneric(non_generic) => ClassType::NonGeneric(non_generic),
Self::Generic(generic) => {
let specialization = generic.generic_context(db).unknown_specialization(db);
ClassType::Generic(GenericAlias::new(db, generic, specialization))
}
}
}
/// Return an iterator over the inferred types of this class's *explicit* bases.
///
/// Note that any class (except for `object`) that has no explicit
@@ -79,21 +439,21 @@ impl<'db> Class<'db> {
}
/// Iterate over this class's explicit bases, filtering out any bases that are not class objects.
fn fully_static_explicit_bases(self, db: &'db dyn Db) -> impl Iterator<Item = Class<'db>> {
fn fully_static_explicit_bases(self, db: &'db dyn Db) -> impl Iterator<Item = ClassType<'db>> {
self.explicit_bases(db)
.iter()
.copied()
.filter_map(Type::into_class_literal)
.map(|ClassLiteralType { class }| class)
.filter_map(Type::into_class_type)
}
#[salsa::tracked(return_ref, cycle_fn=explicit_bases_cycle_recover, cycle_initial=explicit_bases_cycle_initial)]
fn explicit_bases_query(self, db: &'db dyn Db) -> Box<[Type<'db>]> {
tracing::trace!("Class::explicit_bases_query: {}", self.name(db));
let class = self.class(db);
tracing::trace!("ClassLiteralType::explicit_bases_query: {}", class.name);
let class_stmt = self.node(db);
let class_stmt = class.node(db);
let class_definition =
semantic_index(db, self.file(db)).expect_single_definition(class_stmt);
semantic_index(db, class.file(db)).expect_single_definition(class_stmt);
class_stmt
.bases()
@@ -102,40 +462,19 @@ impl<'db> Class<'db> {
.collect()
}
fn file(self, db: &dyn Db) -> File {
self.body_scope(db).file(db)
}
/// Return the original [`ast::StmtClassDef`] node associated with this class
///
/// ## Note
/// Only call this function from queries in the same file or your
/// query depends on the AST of another file (bad!).
fn node(self, db: &'db dyn Db) -> &'db ast::StmtClassDef {
self.body_scope(db).node(db).expect_class()
}
/// Returns the file range of the class's name.
pub fn focus_range(self, db: &dyn Db) -> FileRange {
FileRange::new(self.file(db), self.node(db).name.range)
}
pub fn full_range(self, db: &dyn Db) -> FileRange {
FileRange::new(self.file(db), self.node(db).range)
}
/// Return the types of the decorators on this class
#[salsa::tracked(return_ref)]
fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> {
tracing::trace!("Class::decorators: {}", self.name(db));
let class = self.class(db);
tracing::trace!("ClassLiteralType::decorators: {}", class.name);
let class_stmt = self.node(db);
let class_stmt = class.node(db);
if class_stmt.decorator_list.is_empty() {
return Box::new([]);
}
let class_definition =
semantic_index(db, self.file(db)).expect_single_definition(class_stmt);
semantic_index(db, class.file(db)).expect_single_definition(class_stmt);
class_stmt
.decorator_list
@@ -164,28 +503,43 @@ impl<'db> Class<'db> {
///
/// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order
#[salsa::tracked(return_ref, cycle_fn=try_mro_cycle_recover, cycle_initial=try_mro_cycle_initial)]
pub(super) fn try_mro(self, db: &'db dyn Db) -> Result<Mro<'db>, MroError<'db>> {
tracing::trace!("Class::try_mro: {}", self.name(db));
Mro::of_class(db, self)
pub(super) fn try_mro(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
) -> Result<Mro<'db>, MroError<'db>> {
let class = self.class(db);
tracing::trace!("ClassLiteralType::try_mro: {}", class.name);
Mro::of_class(db, self, specialization)
}
/// Iterate over the [method resolution order] ("MRO") of the class.
///
/// If the MRO could not be accurately resolved, this method falls back to iterating
/// over an MRO that has the class directly inheriting from `Unknown`. Use
/// [`Class::try_mro`] if you need to distinguish between the success and failure
/// [`ClassLiteralType::try_mro`] if you need to distinguish between the success and failure
/// cases rather than simply iterating over the inferred resolution order for the class.
///
/// [method resolution order]: https://docs.python.org/3/glossary.html#term-method-resolution-order
pub(super) fn iter_mro(self, db: &'db dyn Db) -> impl Iterator<Item = ClassBase<'db>> {
MroIterator::new(db, self)
pub(super) fn iter_mro(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
) -> impl Iterator<Item = ClassBase<'db>> {
MroIterator::new(db, self, specialization)
}
/// Return `true` if `other` is present in this class's MRO.
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: Class) -> bool {
pub(super) fn is_subclass_of(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
other: ClassType<'db>,
) -> bool {
// `is_subclass_of` is checking the subtype relation, in which gradual types do not
// participate, so we should not return `True` if we find `Any/Unknown` in the MRO.
self.iter_mro(db).contains(&ClassBase::Class(other))
self.iter_mro(db, specialization)
.contains(&ClassBase::Class(other))
}
/// Return the explicit `metaclass` of this class, if one is defined.
@@ -194,14 +548,15 @@ impl<'db> Class<'db> {
/// Only call this function from queries in the same file or your
/// query depends on the AST of another file (bad!).
fn explicit_metaclass(self, db: &'db dyn Db) -> Option<Type<'db>> {
let class_stmt = self.node(db);
let class = self.class(db);
let class_stmt = class.node(db);
let metaclass_node = &class_stmt
.arguments
.as_ref()?
.find_keyword("metaclass")?
.value;
let class_definition = self.definition(db);
let class_definition = class.definition(db);
Some(definition_expression_type(
db,
@@ -227,7 +582,8 @@ impl<'db> Class<'db> {
/// Return the metaclass of this class, or an error if the metaclass cannot be inferred.
#[salsa::tracked]
pub(super) fn try_metaclass(self, db: &'db dyn Db) -> Result<Type<'db>, MetaclassError<'db>> {
tracing::trace!("Class::try_metaclass: {}", self.name(db));
let class = self.class(db);
tracing::trace!("ClassLiteralType::try_metaclass: {}", class.name);
// Identify the class's own metaclass (or take the first base class's metaclass).
let mut base_classes = self.fully_static_explicit_bases(db).peekable();
@@ -243,18 +599,19 @@ impl<'db> Class<'db> {
let (metaclass, class_metaclass_was_from) = if let Some(metaclass) = explicit_metaclass {
(metaclass, self)
} else if let Some(base_class) = base_classes.next() {
(base_class.metaclass(db), base_class)
let (base_class_literal, _) = base_class.class_literal(db);
(base_class.metaclass(db), base_class_literal)
} else {
(KnownClass::Type.to_class_literal(db), self)
};
let mut candidate = if let Type::ClassLiteral(metaclass_ty) = metaclass {
let mut candidate = if let Some(metaclass_ty) = metaclass.into_class_type() {
MetaclassCandidate {
metaclass: metaclass_ty.class,
metaclass: metaclass_ty,
explicit_metaclass_of: class_metaclass_was_from,
}
} else {
let name = Type::string_literal(db, self.name(db));
let name = Type::string_literal(db, &class.name);
let bases = TupleType::from_elements(db, self.explicit_bases(db));
// TODO: Should be `dict[str, Any]`
let namespace = KnownClass::Dict.to_instance(db);
@@ -290,32 +647,34 @@ impl<'db> Class<'db> {
// - https://github.com/python/cpython/blob/83ba8c2bba834c0b92de669cac16fcda17485e0e/Objects/typeobject.c#L3629-L3663
for base_class in base_classes {
let metaclass = base_class.metaclass(db);
let Type::ClassLiteral(metaclass) = metaclass else {
let Some(metaclass) = metaclass.into_class_type() else {
continue;
};
if metaclass.class.is_subclass_of(db, candidate.metaclass) {
if metaclass.is_subclass_of(db, candidate.metaclass) {
let (base_class_literal, _) = base_class.class_literal(db);
candidate = MetaclassCandidate {
metaclass: metaclass.class,
explicit_metaclass_of: base_class,
metaclass,
explicit_metaclass_of: base_class_literal,
};
continue;
}
if candidate.metaclass.is_subclass_of(db, metaclass.class) {
if candidate.metaclass.is_subclass_of(db, metaclass) {
continue;
}
let (base_class_literal, _) = base_class.class_literal(db);
return Err(MetaclassError {
kind: MetaclassErrorKind::Conflict {
candidate1: candidate,
candidate2: MetaclassCandidate {
metaclass: metaclass.class,
explicit_metaclass_of: base_class,
metaclass,
explicit_metaclass_of: base_class_literal,
},
candidate1_is_base_class: explicit_metaclass.is_none(),
},
});
}
Ok(Type::class_literal(candidate.metaclass))
Ok(candidate.metaclass.into())
}
/// Returns the class member of this class named `name`.
@@ -326,11 +685,12 @@ impl<'db> Class<'db> {
pub(super) fn class_member(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
name: &str,
policy: MemberLookupPolicy,
) -> SymbolAndQualifiers<'db> {
if name == "__mro__" {
let tuple_elements = self.iter_mro(db).map(Type::from);
let tuple_elements = self.iter_mro(db, specialization).map(Type::from);
return Symbol::bound(TupleType::from_elements(db, tuple_elements)).into();
}
@@ -345,7 +705,7 @@ impl<'db> Class<'db> {
let mut lookup_result: LookupResult<'db> =
Err(LookupError::Unbound(TypeQualifiers::empty()));
for superclass in self.iter_mro(db) {
for superclass in self.iter_mro(db, specialization) {
match superclass {
ClassBase::Dynamic(DynamicType::TodoProtocol) => {
// TODO: We currently skip `Protocol` when looking up class members, in order to
@@ -415,7 +775,7 @@ impl<'db> Class<'db> {
/// or those marked as ClassVars are considered.
///
/// Returns [`Symbol::Unbound`] if `name` cannot be found in this class's scope
/// directly. Use [`Class::class_member`] if you require a method that will
/// directly. Use [`ClassLiteralType::class_member`] if you require a method that will
/// traverse through the MRO until it finds the member.
pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
let body_scope = self.body_scope(db);
@@ -428,11 +788,16 @@ impl<'db> Class<'db> {
/// defined attribute that is only present in a method (typically `__init__`).
///
/// The attribute might also be defined in a superclass of this class.
pub(super) fn instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
pub(super) fn instance_member(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
name: &str,
) -> SymbolAndQualifiers<'db> {
let mut union = UnionBuilder::new(db);
let mut union_qualifiers = TypeQualifiers::empty();
for superclass in self.iter_mro(db) {
for superclass in self.iter_mro(db, specialization) {
match superclass {
ClassBase::Dynamic(DynamicType::TodoProtocol) => {
// TODO: We currently skip `Protocol` when looking up instance members, in order to
@@ -680,21 +1045,22 @@ impl<'db> Class<'db> {
/// Also, populates `visited_classes` with all base classes of `self`.
fn is_cyclically_defined_recursive<'db>(
db: &'db dyn Db,
class: Class<'db>,
classes_on_stack: &mut IndexSet<Class<'db>>,
visited_classes: &mut IndexSet<Class<'db>>,
class: ClassLiteralType<'db>,
classes_on_stack: &mut IndexSet<ClassLiteralType<'db>>,
visited_classes: &mut IndexSet<ClassLiteralType<'db>>,
) -> bool {
let mut result = false;
for explicit_base_class in class.fully_static_explicit_bases(db) {
if !classes_on_stack.insert(explicit_base_class) {
let (explicit_base_class_literal, _) = explicit_base_class.class_literal(db);
if !classes_on_stack.insert(explicit_base_class_literal) {
return true;
}
if visited_classes.insert(explicit_base_class) {
if visited_classes.insert(explicit_base_class_literal) {
// If we find a cycle, keep searching to check if we can reach the starting class.
result |= is_cyclically_defined_recursive(
db,
explicit_base_class,
explicit_base_class_literal,
classes_on_stack,
visited_classes,
);
@@ -718,48 +1084,13 @@ impl<'db> Class<'db> {
}
}
fn explicit_bases_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &[Type<'db>],
_count: u32,
_self: Class<'db>,
) -> salsa::CycleRecoveryAction<Box<[Type<'db>]>> {
salsa::CycleRecoveryAction::Iterate
}
fn explicit_bases_cycle_initial<'db>(_db: &'db dyn Db, _self: Class<'db>) -> Box<[Type<'db>]> {
Box::default()
}
fn try_mro_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Result<Mro<'db>, MroError<'db>>,
_count: u32,
_self: Class<'db>,
) -> salsa::CycleRecoveryAction<Result<Mro<'db>, MroError<'db>>> {
salsa::CycleRecoveryAction::Iterate
}
#[allow(clippy::unnecessary_wraps)]
fn try_mro_cycle_initial<'db>(
db: &'db dyn Db,
self_: Class<'db>,
) -> Result<Mro<'db>, MroError<'db>> {
Ok(Mro::from_error(db, self_))
}
#[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)]
fn inheritance_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Option<InheritanceCycle>,
_count: u32,
_self: Class<'db>,
) -> salsa::CycleRecoveryAction<Option<InheritanceCycle>> {
salsa::CycleRecoveryAction::Iterate
}
fn inheritance_cycle_initial<'db>(_db: &'db dyn Db, _self: Class<'db>) -> Option<InheritanceCycle> {
None
impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
fn from(class: ClassLiteralType<'db>) -> Type<'db> {
match class {
ClassLiteralType::NonGeneric(non_generic) => non_generic.into(),
ClassLiteralType::Generic(generic) => generic.into(),
}
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@@ -778,48 +1109,13 @@ impl InheritanceCycle {
}
}
/// A singleton type representing a single class object at runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)]
pub struct ClassLiteralType<'db> {
pub(super) class: Class<'db>,
}
impl<'db> ClassLiteralType<'db> {
pub(super) fn class(self) -> Class<'db> {
self.class
}
pub(crate) fn body_scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.class.body_scope(db)
}
pub(super) fn class_member(
self,
db: &'db dyn Db,
name: &str,
policy: MemberLookupPolicy,
) -> SymbolAndQualifiers<'db> {
self.class.class_member(db, name, policy)
}
}
impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
fn from(value: ClassLiteralType<'db>) -> Self {
Self::ClassLiteral(value)
}
}
/// A type representing the set of runtime objects which are instances of a certain class.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update)]
pub struct InstanceType<'db> {
pub(super) class: Class<'db>,
pub class: ClassType<'db>,
}
impl<'db> InstanceType<'db> {
pub(super) fn class(self) -> Class<'db> {
self.class
}
pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: InstanceType<'db>) -> bool {
// N.B. The subclass relation is fully static
self.class.is_subclass_of(db, other.class)
@@ -839,7 +1135,7 @@ impl<'db> From<InstanceType<'db>> for Type<'db> {
/// Note: good candidates are any classes in `[crate::module_resolver::module::KnownModule]`
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(test, derive(strum_macros::EnumIter))]
pub enum KnownClass {
pub(crate) enum KnownClass {
// To figure out where an stdlib symbol is defined, you can go into `crates/red_knot_vendored`
// and grep for the symbol name in any `.pyi` file.
@@ -1081,8 +1377,8 @@ impl<'db> KnownClass {
/// If the class cannot be found in typeshed, a debug-level log message will be emitted stating this.
pub(crate) fn to_instance(self, db: &'db dyn Db) -> Type<'db> {
self.to_class_literal(db)
.into_class_literal()
.map(|ClassLiteralType { class }| Type::instance(class))
.into_class_type()
.map(Type::instance)
.unwrap_or_else(Type::unknown)
}
@@ -1096,9 +1392,9 @@ impl<'db> KnownClass {
) -> Result<ClassLiteralType<'db>, KnownClassLookupError<'db>> {
let symbol = known_module_symbol(db, self.canonical_module(db), self.name(db)).symbol;
match symbol {
Symbol::Type(Type::ClassLiteral(class_type), Boundness::Bound) => Ok(class_type),
Symbol::Type(Type::ClassLiteral(class_type), Boundness::PossiblyUnbound) => {
Err(KnownClassLookupError::ClassPossiblyUnbound { class_type })
Symbol::Type(Type::ClassLiteral(class_literal), Boundness::Bound) => Ok(class_literal),
Symbol::Type(Type::ClassLiteral(class_literal), Boundness::PossiblyUnbound) => {
Err(KnownClassLookupError::ClassPossiblyUnbound { class_literal })
}
Symbol::Type(found_type, _) => {
Err(KnownClassLookupError::SymbolNotAClass { found_type })
@@ -1133,8 +1429,8 @@ impl<'db> KnownClass {
}
match lookup_error {
KnownClassLookupError::ClassPossiblyUnbound { class_type, .. } => {
Type::class_literal(class_type.class)
KnownClassLookupError::ClassPossiblyUnbound { class_literal, .. } => {
class_literal.into()
}
KnownClassLookupError::ClassNotFound { .. }
| KnownClassLookupError::SymbolNotAClass { .. } => Type::unknown(),
@@ -1148,16 +1444,16 @@ impl<'db> KnownClass {
/// If the class cannot be found in typeshed, a debug-level log message will be emitted stating this.
pub(crate) fn to_subclass_of(self, db: &'db dyn Db) -> Type<'db> {
self.to_class_literal(db)
.into_class_literal()
.map(|ClassLiteralType { class }| SubclassOfType::from(db, class))
.into_class_type()
.map(|class| SubclassOfType::from(db, class))
.unwrap_or_else(SubclassOfType::subclass_of_unknown)
}
/// Return `true` if this symbol can be resolved to a class definition `class` in typeshed,
/// *and* `class` is a subclass of `other`.
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: Class<'db>) -> bool {
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
self.try_to_class_literal(db)
.is_ok_and(|ClassLiteralType { class }| class.is_subclass_of(db, other))
.is_ok_and(|class| class.is_subclass_of(db, None, other))
}
/// Return the module in which we should look up the definition for this class
@@ -1489,7 +1785,9 @@ pub(crate) enum KnownClassLookupError<'db> {
SymbolNotAClass { found_type: Type<'db> },
/// There is a symbol by that name in the expected typeshed module,
/// and it's a class definition, but it's possibly unbound.
ClassPossiblyUnbound { class_type: ClassLiteralType<'db> },
ClassPossiblyUnbound {
class_literal: ClassLiteralType<'db>,
},
}
impl<'db> KnownClassLookupError<'db> {
@@ -1769,7 +2067,7 @@ impl<'db> KnownInstanceType<'db> {
}
/// Return `true` if this symbol is an instance of `class`.
pub(super) fn is_instance_of(self, db: &'db dyn Db, class: Class<'db>) -> bool {
pub(super) fn is_instance_of(self, db: &'db dyn Db, class: ClassType<'db>) -> bool {
self.class().is_subclass_of(db, class)
}

View File

@@ -1,16 +1,19 @@
use crate::types::{todo_type, Class, DynamicType, KnownClass, KnownInstanceType, Type};
use crate::types::{todo_type, ClassType, DynamicType, KnownClass, KnownInstanceType, Type};
use crate::Db;
use itertools::Either;
/// Enumeration of the possible kinds of types we allow in class bases.
///
/// This is much more limited than the [`Type`] enum:
/// all types that would be invalid to have as a class base are
/// transformed into [`ClassBase::unknown`]
/// This is much more limited than the [`Type`] enum: all types that would be invalid to have as a
/// class base are transformed into [`ClassBase::unknown`]
///
/// Note that a non-specialized generic class _cannot_ be a class base. When we see a
/// non-specialized generic class in any type expression (including the list of base classes), we
/// automatically construct the default specialization for that class.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, salsa::Update)]
pub(crate) enum ClassBase<'db> {
Dynamic(DynamicType),
Class(Class<'db>),
Class(ClassType<'db>),
}
impl<'db> ClassBase<'db> {
@@ -39,7 +42,12 @@ impl<'db> ClassBase<'db> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.base {
ClassBase::Dynamic(dynamic) => dynamic.fmt(f),
ClassBase::Class(class) => write!(f, "<class '{}'>", class.name(self.db)),
ClassBase::Class(class @ ClassType::NonGeneric(_)) => {
write!(f, "<class '{}'>", class.name(self.db))
}
ClassBase::Class(ClassType::Generic(alias)) => {
write!(f, "<class '{}'>", alias.display(self.db))
}
}
}
}
@@ -51,8 +59,8 @@ impl<'db> ClassBase<'db> {
pub(super) fn object(db: &'db dyn Db) -> Self {
KnownClass::Object
.to_class_literal(db)
.into_class_literal()
.map_or(Self::unknown(), |literal| Self::Class(literal.class()))
.into_class_type()
.map_or(Self::unknown(), Self::Class)
}
/// Attempt to resolve `ty` into a `ClassBase`.
@@ -61,11 +69,12 @@ impl<'db> ClassBase<'db> {
pub(super) fn try_from_type(db: &'db dyn Db, ty: Type<'db>) -> Option<Self> {
match ty {
Type::Dynamic(dynamic) => Some(Self::Dynamic(dynamic)),
Type::ClassLiteral(literal) => Some(if literal.class().is_known(db, KnownClass::Any) {
Type::ClassLiteral(literal) => Some(if literal.is_known(db, KnownClass::Any) {
Self::Dynamic(DynamicType::Any)
} else {
Self::Class(literal.class())
Self::Class(literal.default_specialization(db))
}),
Type::GenericAlias(generic) => Some(Self::Class(ClassType::Generic(generic))),
Type::Union(_) => None, // TODO -- forces consideration of multiple possible MROs?
Type::Intersection(_) => None, // TODO -- probably incorrect?
Type::Instance(_) => None, // TODO -- handle `__mro_entries__`?
@@ -159,7 +168,7 @@ impl<'db> ClassBase<'db> {
}
}
pub(super) fn into_class(self) -> Option<Class<'db>> {
pub(super) fn into_class(self) -> Option<ClassType<'db>> {
match self {
Self::Class(class) => Some(class),
Self::Dynamic(_) => None,
@@ -178,8 +187,8 @@ impl<'db> ClassBase<'db> {
}
}
impl<'db> From<Class<'db>> for ClassBase<'db> {
fn from(value: Class<'db>) -> Self {
impl<'db> From<ClassType<'db>> for ClassBase<'db> {
fn from(value: ClassType<'db>) -> Self {
ClassBase::Class(value)
}
}
@@ -188,7 +197,7 @@ impl<'db> From<ClassBase<'db>> for Type<'db> {
fn from(value: ClassBase<'db>) -> Self {
match value {
ClassBase::Dynamic(dynamic) => Type::Dynamic(dynamic),
ClassBase::Class(class) => Type::class_literal(class),
ClassBase::Class(class) => class.into(),
}
}
}

View File

@@ -7,7 +7,7 @@ use crate::types::string_annotation::{
IMPLICIT_CONCATENATED_STRING_TYPE_ANNOTATION, INVALID_SYNTAX_IN_FORWARD_ANNOTATION,
RAW_STRING_TYPE_ANNOTATION,
};
use crate::types::{ClassLiteralType, KnownInstanceType, Type};
use crate::types::{KnownInstanceType, Type};
use ruff_db::diagnostic::{Diagnostic, OldSecondaryDiagnosticMessage, Span};
use ruff_python_ast::{self as ast, AnyNodeRef};
use ruff_text_size::Ranged;
@@ -1025,7 +1025,7 @@ fn report_invalid_assignment_with_message(
message: std::fmt::Arguments,
) {
match target_ty {
Type::ClassLiteral(ClassLiteralType { class }) => {
Type::ClassLiteral(class) => {
context.report_lint(&INVALID_ASSIGNMENT, node, format_args!(
"Implicit shadowing of class `{}`; annotate to make it explicit if this is intentional",
class.name(context.db())));

View File

@@ -6,11 +6,13 @@ use ruff_db::display::FormatterJoinExtension;
use ruff_python_ast::str::{Quote, TripleQuotes};
use ruff_python_literal::escape::AsciiEscape;
use crate::types::class::{ClassType, GenericAlias, GenericClass};
use crate::types::class_base::ClassBase;
use crate::types::generics::Specialization;
use crate::types::signatures::{Parameter, Parameters, Signature};
use crate::types::{
ClassLiteralType, InstanceType, IntersectionType, KnownClass, MethodWrapperKind,
StringLiteralType, Type, UnionType, WrapperDescriptorKind,
InstanceType, IntersectionType, KnownClass, MethodWrapperKind, StringLiteralType, Type,
TypeVarInstance, UnionType, WrapperDescriptorKind,
};
use crate::Db;
use rustc_hash::FxHashMap;
@@ -34,7 +36,7 @@ impl Display for DisplayType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let representation = self.ty.representation(self.db);
match self.ty {
Type::ClassLiteral(literal) if literal.class().is_known(self.db, KnownClass::Any) => {
Type::ClassLiteral(literal) if literal.is_known(self.db, KnownClass::Any) => {
write!(f, "typing.Any")
}
Type::IntLiteral(_)
@@ -42,6 +44,7 @@ impl Display for DisplayType<'_> {
| Type::StringLiteral(_)
| Type::BytesLiteral(_)
| Type::ClassLiteral(_)
| Type::GenericAlias(_)
| Type::FunctionLiteral(_) => {
write!(f, "Literal[{representation}]")
}
@@ -69,20 +72,21 @@ impl Display for DisplayRepresentation<'_> {
match self.ty {
Type::Dynamic(dynamic) => dynamic.fmt(f),
Type::Never => f.write_str("Never"),
Type::Instance(InstanceType { class }) => {
let representation = match class.known(self.db) {
Some(KnownClass::NoneType) => "None",
Some(KnownClass::NoDefaultType) => "NoDefault",
_ => class.name(self.db),
};
f.write_str(representation)
}
Type::Instance(InstanceType { class }) => match (class, class.known(self.db)) {
(_, Some(KnownClass::NoneType)) => f.write_str("None"),
(_, Some(KnownClass::NoDefaultType)) => f.write_str("NoDefault"),
(ClassType::NonGeneric(class), _) => f.write_str(&class.class(self.db).name),
(ClassType::Generic(alias), _) => write!(f, "{}", alias.display(self.db)),
},
Type::PropertyInstance(_) => f.write_str("property"),
Type::ModuleLiteral(module) => {
write!(f, "<module '{}'>", module.module(self.db).name())
}
// TODO functions and classes should display using a fully qualified name
Type::ClassLiteral(ClassLiteralType { class }) => f.write_str(class.name(self.db)),
Type::ClassLiteral(class) => f.write_str(class.name(self.db)),
Type::GenericAlias(generic) => {
write!(f, "{}", generic.display(self.db))
}
Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() {
// Only show the bare class name here; ClassBase::display would render this as
// type[<class 'Foo'>] instead of type[Foo].
@@ -90,21 +94,49 @@ impl Display for DisplayRepresentation<'_> {
ClassBase::Dynamic(dynamic) => write!(f, "type[{dynamic}]"),
},
Type::KnownInstance(known_instance) => f.write_str(known_instance.repr(self.db)),
Type::FunctionLiteral(function) => f.write_str(function.name(self.db)),
Type::FunctionLiteral(function) => {
f.write_str(function.name(self.db))?;
if let Some(specialization) = function.specialization(self.db) {
specialization.display_short(self.db).fmt(f)?;
}
Ok(())
}
Type::Callable(callable) => callable.signature(self.db).display(self.db).fmt(f),
Type::BoundMethod(bound_method) => {
let function = bound_method.function(self.db);
let self_instance = bound_method.self_instance(self.db);
let self_instance_specialization = match self_instance {
Type::Instance(InstanceType {
class: ClassType::Generic(alias),
}) => Some(alias.specialization(self.db)),
_ => None,
};
let specialization = match function.specialization(self.db) {
Some(specialization)
if self_instance_specialization.is_none_or(|sis| specialization == sis) =>
{
specialization.display_short(self.db).to_string()
}
_ => String::new(),
};
write!(
f,
"<bound method `{method}` of `{instance}`>",
method = bound_method.function(self.db).name(self.db),
instance = bound_method.self_instance(self.db).display(self.db)
"<bound method `{method}{specialization}` of `{instance}`>",
method = function.name(self.db),
instance = bound_method.self_instance(self.db).display(self.db),
)
}
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
write!(
f,
"<method-wrapper `__get__` of `{function}`>",
function = function.name(self.db)
"<method-wrapper `__get__` of `{function}{specialization}`>",
function = function.name(self.db),
specialization = if let Some(specialization) = function.specialization(self.db)
{
specialization.display_short(self.db).to_string()
} else {
String::new()
},
)
}
Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(_)) => {
@@ -174,6 +206,86 @@ impl Display for DisplayRepresentation<'_> {
}
}
impl<'db> GenericAlias<'db> {
pub(crate) fn display(&'db self, db: &'db dyn Db) -> DisplayGenericAlias<'db> {
DisplayGenericAlias {
origin: self.origin(db),
specialization: self.specialization(db),
db,
}
}
}
pub(crate) struct DisplayGenericAlias<'db> {
origin: GenericClass<'db>,
specialization: Specialization<'db>,
db: &'db dyn Db,
}
impl Display for DisplayGenericAlias<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(
f,
"{origin}{specialization}",
origin = self.origin.class(self.db).name,
specialization = self.specialization.display_short(self.db),
)
}
}
impl<'db> Specialization<'db> {
/// Renders the specialization in full, e.g. `{T = int, U = str}`.
pub fn display(&'db self, db: &'db dyn Db) -> DisplaySpecialization<'db> {
DisplaySpecialization {
typevars: self.generic_context(db).variables(db),
types: self.types(db),
db,
full: true,
}
}
/// Renders the specialization as it would appear in a subscript expression, e.g. `[int, str]`.
pub fn display_short(&'db self, db: &'db dyn Db) -> DisplaySpecialization<'db> {
DisplaySpecialization {
typevars: self.generic_context(db).variables(db),
types: self.types(db),
db,
full: false,
}
}
}
pub struct DisplaySpecialization<'db> {
typevars: &'db [TypeVarInstance<'db>],
types: &'db [Type<'db>],
db: &'db dyn Db,
full: bool,
}
impl Display for DisplaySpecialization<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
if self.full {
f.write_char('{')?;
for (idx, (var, ty)) in self.typevars.iter().zip(self.types).enumerate() {
if idx > 0 {
f.write_str(", ")?;
}
write!(f, "{} = {}", var.name(self.db), ty.display(self.db))?;
}
f.write_char('}')
} else {
f.write_char('[')?;
for (idx, (_, ty)) in self.typevars.iter().zip(self.types).enumerate() {
if idx > 0 {
f.write_str(", ")?;
}
write!(f, "{}", ty.display(self.db))?;
}
f.write_char(']')
}
}
}
impl<'db> Signature<'db> {
fn display(&'db self, db: &'db dyn Db) -> DisplaySignature<'db> {
DisplaySignature {

View File

@@ -0,0 +1,148 @@
use ruff_python_ast as ast;
use crate::semantic_index::SemanticIndex;
use crate::types::signatures::{Parameter, Parameters, Signature};
use crate::types::{
declaration_type, KnownInstanceType, Type, TypeVarBoundOrConstraints, TypeVarInstance,
UnionType,
};
use crate::Db;
/// A list of formal type variables for a generic function, class, or type alias.
#[salsa::tracked(debug)]
pub struct GenericContext<'db> {
#[return_ref]
pub(crate) variables: Box<[TypeVarInstance<'db>]>,
}
impl<'db> GenericContext<'db> {
pub(crate) fn from_type_params(
db: &'db dyn Db,
index: &'db SemanticIndex<'db>,
type_params_node: &ast::TypeParams,
) -> Self {
let variables = type_params_node
.iter()
.filter_map(|type_param| Self::variable_from_type_param(db, index, type_param))
.collect();
Self::new(db, variables)
}
fn variable_from_type_param(
db: &'db dyn Db,
index: &'db SemanticIndex<'db>,
type_param_node: &ast::TypeParam,
) -> Option<TypeVarInstance<'db>> {
match type_param_node {
ast::TypeParam::TypeVar(node) => {
let definition = index.expect_single_definition(node);
let Type::KnownInstance(KnownInstanceType::TypeVar(typevar)) =
declaration_type(db, definition).inner_type()
else {
panic!("typevar should be inferred as a TypeVarInstance");
};
Some(typevar)
}
// TODO: Support these!
ast::TypeParam::ParamSpec(_) => None,
ast::TypeParam::TypeVarTuple(_) => None,
}
}
pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> {
let parameters = Parameters::new(
self.variables(db)
.iter()
.map(|typevar| Self::parameter_from_typevar(db, *typevar)),
);
Signature::new(parameters, None)
}
fn parameter_from_typevar(db: &'db dyn Db, typevar: TypeVarInstance<'db>) -> Parameter<'db> {
let mut parameter = Parameter::positional_only(Some(typevar.name(db).clone()));
match typevar.bound_or_constraints(db) {
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
// TODO: This should be a type form.
parameter = parameter.with_annotated_type(bound);
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
// TODO: This should be a new type variant where only these exact types are
// assignable, and not subclasses of them, nor a union of them.
parameter = parameter
.with_annotated_type(UnionType::from_elements(db, constraints.iter(db)));
}
None => {}
}
parameter
}
pub(crate) fn default_specialization(self, db: &'db dyn Db) -> Specialization<'db> {
let types = self
.variables(db)
.iter()
.map(|typevar| typevar.default_ty(db).unwrap_or(Type::unknown()))
.collect();
self.specialize(db, types)
}
pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> Specialization<'db> {
let types = vec![Type::unknown(); self.variables(db).len()];
self.specialize(db, types.into())
}
pub(crate) fn specialize(
self,
db: &'db dyn Db,
types: Box<[Type<'db>]>,
) -> Specialization<'db> {
Specialization::new(db, self, types)
}
}
/// An assignment of a specific type to each type variable in a generic scope.
#[salsa::tracked(debug)]
pub struct Specialization<'db> {
pub(crate) generic_context: GenericContext<'db>,
#[return_ref]
pub(crate) types: Box<[Type<'db>]>,
}
impl<'db> Specialization<'db> {
/// Applies a specialization to this specialization. This is used, for instance, when a generic
/// class inherits from a generic alias:
///
/// ```py
/// class A[T]: ...
/// class B[U](A[U]): ...
/// ```
///
/// `B` is a generic class, whose MRO includes the generic alias `A[U]`, which specializes `A`
/// with the specialization `{T: U}`. If `B` is specialized to `B[int]`, with specialization
/// `{U: int}`, we can apply the second specialization to the first, resulting in `T: int`.
/// That lets us produce the generic alias `A[int]`, which is the corresponding entry in the
/// MRO of `B[int]`.
pub(crate) fn apply_specialization(self, db: &'db dyn Db, other: Specialization<'db>) -> Self {
let types = self
.types(db)
.into_iter()
.map(|ty| ty.apply_specialization(db, other))
.collect();
Specialization::new(db, self.generic_context(db), types)
}
pub(crate) fn normalized(self, db: &'db dyn Db) -> Self {
let types = self.types(db).iter().map(|ty| ty.normalized(db)).collect();
Self::new(db, self.generic_context(db), types)
}
/// Returns the type that a typevar is specialized to, or None if the typevar isn't part of
/// this specialization.
pub(crate) fn get(self, db: &'db dyn Db, typevar: TypeVarInstance<'db>) -> Option<Type<'db>> {
self.generic_context(db)
.variables(db)
.into_iter()
.zip(self.types(db))
.find(|(var, _)| **var == typevar)
.map(|(_, ty)| *ty)
}
}

View File

@@ -37,6 +37,7 @@ use itertools::{Either, Itertools};
use ruff_db::diagnostic::{DiagnosticId, Severity};
use ruff_db::files::File;
use ruff_db::parsed::parsed_module;
use ruff_python_ast::visitor::{walk_expr, Visitor};
use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext};
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::{FxHashMap, FxHashSet};
@@ -61,6 +62,7 @@ use crate::symbol::{
typing_extensions_symbol, Boundness, LookupError,
};
use crate::types::call::{Argument, Bindings, CallArgumentTypes, CallArguments, CallError};
use crate::types::class::MetaclassErrorKind;
use crate::types::diagnostic::{
report_implicit_return_type, report_invalid_arguments_to_annotated,
report_invalid_arguments_to_callable, report_invalid_assignment,
@@ -73,18 +75,18 @@ use crate::types::diagnostic::{
INVALID_TYPE_VARIABLE_CONSTRAINTS, POSSIBLY_UNBOUND_IMPORT, UNDEFINED_REVEAL,
UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT, UNSUPPORTED_OPERATOR,
};
use crate::types::generics::GenericContext;
use crate::types::mro::MroErrorKind;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
class::MetaclassErrorKind, todo_type, Class, DynamicType, FunctionType, IntersectionBuilder,
IntersectionType, KnownClass, KnownFunction, KnownInstanceType, MetaclassCandidate, Parameter,
ParameterForm, Parameters, SliceLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers,
todo_type, CallDunderError, CallableSignature, CallableType, Class, ClassLiteralType,
DynamicType, FunctionDecorators, FunctionType, GenericAlias, GenericClass, IntersectionBuilder,
IntersectionType, KnownClass, KnownFunction, KnownInstanceType, MemberLookupPolicy,
MetaclassCandidate, NonGenericClass, Parameter, ParameterForm, Parameters, Signature,
Signatures, SliceLiteralType, StringLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers,
Truthiness, TupleType, Type, TypeAliasType, TypeAndQualifiers, TypeArrayDisplay,
TypeQualifiers, TypeVarBoundOrConstraints, TypeVarInstance, UnionBuilder, UnionType,
};
use crate::types::{
CallableType, FunctionDecorators, MemberLookupPolicy, Signature, StringLiteralType,
};
use crate::unpack::{Unpack, UnpackPosition};
use crate::util::subscript::{PyIndex, PySlice};
use crate::Db;
@@ -102,7 +104,6 @@ use super::slots::check_class_slots;
use super::string_annotation::{
parse_string_annotation, BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION,
};
use super::{CallDunderError, ClassLiteralType};
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
@@ -708,7 +709,7 @@ impl<'db> TypeInferenceBuilder<'db> {
if let DefinitionKind::Class(class) = definition.kind(self.db()) {
ty.inner_type()
.into_class_literal()
.map(|ty| (ty.class(), class.node()))
.map(|ty| (ty, class.node()))
} else {
None
}
@@ -736,10 +737,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// (2) Check for classes that inherit from `@final` classes
for (i, base_class) in class.explicit_bases(self.db()).iter().enumerate() {
// dynamic/unknown bases are never `@final`
let Some(base_class) = base_class
.into_class_literal()
.map(super::class::ClassLiteralType::class)
else {
let Some(base_class) = base_class.into_class_literal() else {
continue;
};
if !base_class.is_final(self.db()) {
@@ -757,7 +755,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// (3) Check that the class's MRO is resolvable
match class.try_mro(self.db()).as_ref() {
match class.try_mro(self.db(), None).as_ref() {
Err(mro_error) => {
match mro_error.reason() {
MroErrorKind::DuplicateBases(duplicates) => {
@@ -983,7 +981,7 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::BooleanLiteral(_) | Type::IntLiteral(_) => {}
Type::Instance(instance)
if matches!(
instance.class().known(self.db()),
instance.class.known(self.db()),
Some(KnownClass::Float | KnownClass::Int | KnownClass::Bool)
) => {}
_ => return false,
@@ -1415,7 +1413,7 @@ impl<'db> TypeInferenceBuilder<'db> {
continue;
}
} else if let Type::ClassLiteral(class) = decorator_ty {
if class.class.is_known(self.db(), KnownClass::Classmethod) {
if class.is_known(self.db(), KnownClass::Classmethod) {
function_decorators |= FunctionDecorators::CLASSMETHOD;
continue;
}
@@ -1453,12 +1451,15 @@ impl<'db> TypeInferenceBuilder<'db> {
.node_scope(NodeWithScopeRef::Function(function))
.to_scope_id(self.db(), self.file());
let specialization = None;
let mut inferred_ty = Type::FunctionLiteral(FunctionType::new(
self.db(),
&name.id,
function_kind,
body_scope,
function_decorators,
specialization,
));
for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() {
@@ -1695,6 +1696,10 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_decorator(decorator);
}
let generic_context = type_params.as_ref().map(|type_params| {
GenericContext::from_type_params(self.db(), self.index, type_params)
});
let body_scope = self
.index
.node_scope(NodeWithScopeRef::Class(class_node))
@@ -1702,8 +1707,18 @@ impl<'db> TypeInferenceBuilder<'db> {
let maybe_known_class = KnownClass::try_from_file_and_name(self.db(), self.file(), name);
let class = Class::new(self.db(), &name.id, body_scope, maybe_known_class);
let class_ty = Type::class_literal(class);
let class = Class {
name: name.id.clone(),
body_scope,
known: maybe_known_class,
};
let class_literal = match generic_context {
Some(generic_context) => {
ClassLiteralType::Generic(GenericClass::new(self.db(), class, generic_context))
}
None => ClassLiteralType::NonGeneric(NonGenericClass::new(self.db(), class)),
};
let class_ty = Type::from(class_literal);
self.add_declaration_with_binding(
class_node.into(),
@@ -1719,8 +1734,12 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// Inference of bases deferred in stubs
// TODO also defer stringified generic type parameters
if self.are_all_types_deferred() {
// TODO: Only defer the references that are actually string literals, instead of
// deferring the entire class definition if a string literal occurs anywhere in the
// base class list.
if self.are_all_types_deferred()
|| class_node.bases().iter().any(contains_string_literal)
{
self.types.deferred.insert(definition);
} else {
for base in class_node.bases() {
@@ -2532,7 +2551,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
Type::ClassLiteral(..) | Type::SubclassOf(..) => {
Type::ClassLiteral(..) | Type::GenericAlias(..) | Type::SubclassOf(..) => {
match object_ty.class_member(db, attribute.into()) {
SymbolAndQualifiers {
symbol: Symbol::Type(meta_attr_ty, meta_attr_boundness),
@@ -2856,10 +2875,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// Handle various singletons.
if let Type::Instance(instance) = declared_ty.inner_type() {
if instance
.class()
.is_known(self.db(), KnownClass::SpecialForm)
{
if instance.class.is_known(self.db(), KnownClass::SpecialForm) {
if let Some(name_expr) = target.as_name_expr() {
if let Some(known_instance) = KnownInstanceType::try_from_file_and_name(
self.db(),
@@ -4018,9 +4034,12 @@ impl<'db> TypeInferenceBuilder<'db> {
let class = match callable_type {
Type::SubclassOf(subclass_of_type) => match subclass_of_type.subclass_of() {
ClassBase::Dynamic(_) => None,
ClassBase::Class(class) => Some(class),
ClassBase::Class(class) => {
let (class_literal, _) = class.class_literal(self.db());
Some(class_literal)
}
},
Type::ClassLiteral(ClassLiteralType { class }) => Some(class),
Type::ClassLiteral(class) => Some(class),
_ => None,
};
@@ -4475,7 +4494,7 @@ impl<'db> TypeInferenceBuilder<'db> {
LookupError::Unbound(_) => {
let bound_on_instance = match value_type {
Type::ClassLiteral(class) => {
!class.class().instance_member(db, attr).symbol.is_unbound()
!class.instance_member(db, None, attr).symbol.is_unbound()
}
Type::SubclassOf(subclass_of @ SubclassOfType { .. }) => {
match subclass_of.subclass_of() {
@@ -4588,6 +4607,7 @@ impl<'db> TypeInferenceBuilder<'db> {
| Type::BoundMethod(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::GenericAlias(_)
| Type::SubclassOf(_)
| Type::Instance(_)
| Type::KnownInstance(_)
@@ -4864,6 +4884,7 @@ impl<'db> TypeInferenceBuilder<'db> {
| Type::MethodWrapper(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::GenericAlias(_)
| Type::SubclassOf(_)
| Type::Instance(_)
| Type::KnownInstance(_)
@@ -4885,6 +4906,7 @@ impl<'db> TypeInferenceBuilder<'db> {
| Type::MethodWrapper(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::GenericAlias(_)
| Type::SubclassOf(_)
| Type::Instance(_)
| Type::KnownInstance(_)
@@ -5452,9 +5474,7 @@ impl<'db> TypeInferenceBuilder<'db> {
range,
),
(Type::Tuple(_), Type::Instance(instance))
if instance
.class()
.is_known(self.db(), KnownClass::VersionInfo) =>
if instance.class.is_known(self.db(), KnownClass::VersionInfo) =>
{
self.infer_binary_type_comparison(
left,
@@ -5464,9 +5484,7 @@ impl<'db> TypeInferenceBuilder<'db> {
)
}
(Type::Instance(instance), Type::Tuple(_))
if instance
.class()
.is_known(self.db(), KnownClass::VersionInfo) =>
if instance.class.is_known(self.db(), KnownClass::VersionInfo) =>
{
self.infer_binary_type_comparison(
Type::version_info_tuple(self.db()),
@@ -5774,11 +5792,71 @@ impl<'db> TypeInferenceBuilder<'db> {
ctx: _,
} = subscript;
// HACK ALERT: If we are subscripting a generic class, short-circuit the rest of the
// subscript inference logic and treat this as an explicit specialization.
// TODO: Move this logic into a custom callable, and update `find_name_in_mro` to return
// this callable as the `__class_getitem__` method on `type`. That probably requires
// updating all of the subscript logic below to use custom callables for all of the _other_
// special cases, too.
let value_ty = self.infer_expression(value);
if let Type::ClassLiteral(ClassLiteralType::Generic(generic_class)) = value_ty {
return self.infer_explicit_class_specialization(
subscript,
value_ty,
generic_class,
slice,
);
}
let slice_ty = self.infer_expression(slice);
self.infer_subscript_expression_types(value, value_ty, slice_ty)
}
fn infer_explicit_class_specialization(
&mut self,
subscript: &ast::ExprSubscript,
value_ty: Type<'db>,
generic_class: GenericClass<'db>,
slice_node: &ast::Expr,
) -> Type<'db> {
let mut call_argument_types = match slice_node {
ast::Expr::Tuple(tuple) => CallArgumentTypes::positional(
tuple.elts.iter().map(|elt| self.infer_type_expression(elt)),
),
_ => CallArgumentTypes::positional([self.infer_type_expression(slice_node)]),
};
let generic_context = generic_class.generic_context(self.db());
let signatures = Signatures::single(CallableSignature::single(
value_ty,
generic_context.signature(self.db()),
));
let bindings = match Bindings::match_parameters(signatures, &mut call_argument_types)
.check_types(self.db(), &mut call_argument_types)
{
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => {
bindings.report_diagnostics(&self.context, subscript.into());
return Type::unknown();
}
};
let callable = bindings
.into_iter()
.next()
.expect("valid bindings should have one callable");
let (_, overload) = callable
.matching_overload()
.expect("valid bindings should have matching overload");
let specialization = generic_context.specialize(
self.db(),
overload
.parameter_types()
.iter()
.map(|ty| ty.unwrap_or(Type::unknown()))
.collect(),
);
Type::from(GenericAlias::new(self.db(), generic_class, specialization))
}
fn infer_subscript_expression_types(
&mut self,
value_node: &ast::Expr,
@@ -5789,16 +5867,12 @@ impl<'db> TypeInferenceBuilder<'db> {
(
Type::Instance(instance),
Type::IntLiteral(_) | Type::BooleanLiteral(_) | Type::SliceLiteral(_),
) if instance
.class()
.is_known(self.db(), KnownClass::VersionInfo) =>
{
self.infer_subscript_expression_types(
) if instance.class.is_known(self.db(), KnownClass::VersionInfo) => self
.infer_subscript_expression_types(
value_node,
Type::version_info_tuple(self.db()),
slice_ty,
)
}
),
// Ex) Given `("a", "b", "c", "d")[1]`, return `"b"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) if i32::try_from(int).is_ok() => {
@@ -6006,9 +6080,16 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
if matches!(value_ty, Type::ClassLiteral(class_literal) if class_literal.class().is_known(self.db(), KnownClass::Type))
{
return KnownClass::GenericAlias.to_instance(self.db());
if let Type::ClassLiteral(class) = value_ty {
if class.is_known(self.db(), KnownClass::Type) {
return KnownClass::GenericAlias.to_instance(self.db());
}
if let ClassLiteralType::Generic(_) = class {
// TODO: specialize the generic class using these explicit type
// variable assignments
return value_ty;
}
}
report_non_subscriptable(
@@ -6031,6 +6112,10 @@ impl<'db> TypeInferenceBuilder<'db> {
// TODO: proper support for generic classes
// For now, just infer `Sequence`, if we see something like `Sequence[str]`. This allows us
// to look up attributes on generic base classes, even if we don't understand generics yet.
// Note that this isn't handled by the clause up above for generic classes
// that use legacy type variables and an explicit `Generic` base class.
// Once we handle legacy typevars, this special case will be removed in
// favor of the specialization logic above.
value_ty
}
_ => Type::unknown(),
@@ -6063,7 +6148,7 @@ impl<'db> TypeInferenceBuilder<'db> {
},
Some(Type::BooleanLiteral(b)) => SliceArg::Arg(Some(i32::from(b))),
Some(Type::Instance(instance))
if instance.class().is_known(self.db(), KnownClass::NoneType) =>
if instance.class.is_known(self.db(), KnownClass::NoneType) =>
{
SliceArg::Arg(None)
}
@@ -6629,8 +6714,7 @@ impl<'db> TypeInferenceBuilder<'db> {
value_ty: Type<'db>,
) -> Type<'db> {
match value_ty {
Type::ClassLiteral(class_literal_ty) => match class_literal_ty.class().known(self.db())
{
Type::ClassLiteral(class_literal) => match class_literal.known(self.db()) {
Some(KnownClass::Tuple) => self.infer_tuple_type_expression(slice),
Some(KnownClass::Type) => self.infer_subclass_of_type_expression(slice),
_ => self.infer_subscript_type_expression(subscript, value_ty),
@@ -6732,14 +6816,14 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::Expr::Name(_) | ast::Expr::Attribute(_) => {
let name_ty = self.infer_expression(slice);
match name_ty {
Type::ClassLiteral(class_literal_ty) => {
if class_literal_ty
.class()
.is_known(self.db(), KnownClass::Any)
{
Type::ClassLiteral(class_literal) => {
if class_literal.is_known(self.db(), KnownClass::Any) {
SubclassOfType::subclass_of_any()
} else {
SubclassOfType::from(self.db(), class_literal_ty.class())
SubclassOfType::from(
self.db(),
class_literal.default_specialization(self.db()),
)
}
}
Type::KnownInstance(KnownInstanceType::Any) => {
@@ -6820,7 +6904,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} = subscript;
match value_ty {
Type::ClassLiteral(literal) if literal.class().is_known(self.db(), KnownClass::Any) => {
Type::ClassLiteral(literal) if literal.is_known(self.db(), KnownClass::Any) => {
self.context.report_lint(
&INVALID_TYPE_FORM,
subscript,
@@ -7516,6 +7600,21 @@ impl StringPartsCollector {
}
}
fn contains_string_literal(expr: &ast::Expr) -> bool {
struct ContainsStringLiteral(bool);
impl<'a> Visitor<'a> for ContainsStringLiteral {
fn visit_expr(&mut self, expr: &'a ast::Expr) {
self.0 |= matches!(expr, ast::Expr::StringLiteral(_));
walk_expr(self, expr);
}
}
let mut visitor = ContainsStringLiteral(false);
visitor.visit_expr(expr);
visitor.0
}
#[cfg(test)]
mod tests {
use crate::db::tests::{setup_db, TestDb};

View File

@@ -4,34 +4,58 @@ use std::ops::Deref;
use rustc_hash::FxHashSet;
use crate::types::class_base::ClassBase;
use crate::types::{Class, Type};
use crate::types::generics::Specialization;
use crate::types::{ClassLiteralType, ClassType, Type};
use crate::Db;
/// The inferred method resolution order of a given class.
///
/// See [`Class::iter_mro`] for more details.
/// An MRO cannot contain non-specialized generic classes. (This is why [`ClassBase`] contains a
/// [`ClassType`], not a [`ClassLiteralType`].) Any generic classes in a base class list are always
/// specialized — either because the class is explicitly specialized if there is a subscript
/// expression, or because we create the default specialization if there isn't.
///
/// The MRO of a non-specialized generic class can contain generic classes that are specialized
/// with a typevar from the inheriting class. When the inheriting class is specialized, the MRO of
/// the resulting generic alias will substitute those type variables accordingly. For instance, in
/// the following example, the MRO of `D[int]` includes `C[int]`, and the MRO of `D[U]` includes
/// `C[U]` (which is a generic alias, not a non-specialized generic class):
///
/// ```py
/// class C[T]: ...
/// class D[U](C[U]): ...
/// ```
///
/// See [`ClassType::iter_mro`] for more details.
#[derive(PartialEq, Eq, Clone, Debug, salsa::Update)]
pub(super) struct Mro<'db>(Box<[ClassBase<'db>]>);
impl<'db> Mro<'db> {
/// Attempt to resolve the MRO of a given class
/// Attempt to resolve the MRO of a given class. Because we derive the MRO from the list of
/// base classes in the class definition, this operation is performed on a [class
/// literal][ClassLiteralType], not a [class type][ClassType]. (You can _also_ get the MRO of a
/// class type, but this is done by first getting the MRO of the underlying class literal, and
/// specializing each base class as needed if the class type is a generic alias.)
///
/// In the event that a possible list of bases would (or could) lead to a
/// `TypeError` being raised at runtime due to an unresolvable MRO, we infer
/// the MRO of the class as being `[<the class in question>, Unknown, object]`.
/// This seems most likely to reduce the possibility of cascading errors
/// elsewhere.
/// In the event that a possible list of bases would (or could) lead to a `TypeError` being
/// raised at runtime due to an unresolvable MRO, we infer the MRO of the class as being `[<the
/// class in question>, Unknown, object]`. This seems most likely to reduce the possibility of
/// cascading errors elsewhere. (For a generic class, the first entry in this fallback MRO uses
/// the default specialization of the class's type variables.)
///
/// (We emit a diagnostic warning about the runtime `TypeError` in
/// [`super::infer::TypeInferenceBuilder::infer_region_scope`].)
pub(super) fn of_class(db: &'db dyn Db, class: Class<'db>) -> Result<Self, MroError<'db>> {
Self::of_class_impl(db, class).map_err(|error_kind| MroError {
kind: error_kind,
fallback_mro: Self::from_error(db, class),
pub(super) fn of_class(
db: &'db dyn Db,
class: ClassLiteralType<'db>,
specialization: Option<Specialization<'db>>,
) -> Result<Self, MroError<'db>> {
Self::of_class_impl(db, class, specialization).map_err(|err| {
err.into_mro_error(db, class.apply_optional_specialization(db, specialization))
})
}
pub(super) fn from_error(db: &'db dyn Db, class: Class<'db>) -> Self {
pub(super) fn from_error(db: &'db dyn Db, class: ClassType<'db>) -> Self {
Self::from([
ClassBase::Class(class),
ClassBase::unknown(),
@@ -39,20 +63,30 @@ impl<'db> Mro<'db> {
])
}
fn of_class_impl(db: &'db dyn Db, class: Class<'db>) -> Result<Self, MroErrorKind<'db>> {
fn of_class_impl(
db: &'db dyn Db,
class: ClassLiteralType<'db>,
specialization: Option<Specialization<'db>>,
) -> Result<Self, MroErrorKind<'db>> {
let class_bases = class.explicit_bases(db);
if !class_bases.is_empty() && class.inheritance_cycle(db).is_some() {
// We emit errors for cyclically defined classes elsewhere.
// It's important that we don't even try to infer the MRO for a cyclically defined class,
// or we'll end up in an infinite loop.
return Ok(Mro::from_error(db, class));
return Ok(Mro::from_error(
db,
class.apply_optional_specialization(db, specialization),
));
}
match class_bases {
// `builtins.object` is the special case:
// the only class in Python that has an MRO with length <2
[] if class.is_object(db) => Ok(Self::from([ClassBase::Class(class)])),
[] if class.is_object(db) => Ok(Self::from([
// object is not generic, so the default specialization should be a no-op
ClassBase::Class(class.apply_optional_specialization(db, specialization)),
])),
// All other classes in Python have an MRO with length >=2.
// Even if a class has no explicit base classes,
@@ -67,7 +101,10 @@ impl<'db> Mro<'db> {
// >>> Foo.__mro__
// (<class '__main__.Foo'>, <class 'object'>)
// ```
[] => Ok(Self::from([ClassBase::Class(class), ClassBase::object(db)])),
[] => Ok(Self::from([
ClassBase::Class(class.apply_optional_specialization(db, specialization)),
ClassBase::object(db),
])),
// Fast path for a class that has only a single explicit base.
//
@@ -77,9 +114,11 @@ impl<'db> Mro<'db> {
[single_base] => ClassBase::try_from_type(db, *single_base).map_or_else(
|| Err(MroErrorKind::InvalidBases(Box::from([(0, *single_base)]))),
|single_base| {
Ok(std::iter::once(ClassBase::Class(class))
.chain(single_base.mro(db))
.collect())
Ok(std::iter::once(ClassBase::Class(
class.apply_optional_specialization(db, specialization),
))
.chain(single_base.mro(db))
.collect())
},
),
@@ -103,7 +142,9 @@ impl<'db> Mro<'db> {
return Err(MroErrorKind::InvalidBases(invalid_bases.into_boxed_slice()));
}
let mut seqs = vec![VecDeque::from([ClassBase::Class(class)])];
let mut seqs = vec![VecDeque::from([ClassBase::Class(
class.apply_optional_specialization(db, specialization),
)])];
for base in &valid_bases {
seqs.push(base.mro(db).collect());
}
@@ -118,7 +159,8 @@ impl<'db> Mro<'db> {
.filter_map(|(index, base)| Some((index, base.into_class()?)))
{
if !seen_bases.insert(base) {
duplicate_bases.push((index, base));
let (base_class_literal, _) = base.class_literal(db);
duplicate_bases.push((index, base_class_literal));
}
}
@@ -178,12 +220,15 @@ impl<'db> FromIterator<ClassBase<'db>> for Mro<'db> {
///
/// Even for first-party code, where we will have to resolve the MRO for every class we encounter,
/// loading the cached MRO comes with a certain amount of overhead, so it's best to avoid calling the
/// Salsa-tracked [`Class::try_mro`] method unless it's absolutely necessary.
/// Salsa-tracked [`ClassLiteralType::try_mro`] method unless it's absolutely necessary.
pub(super) struct MroIterator<'db> {
db: &'db dyn Db,
/// The class whose MRO we're iterating over
class: Class<'db>,
class: ClassLiteralType<'db>,
/// The specialization to apply to each MRO element, if any
specialization: Option<Specialization<'db>>,
/// Whether or not we've already yielded the first element of the MRO
first_element_yielded: bool,
@@ -197,10 +242,15 @@ pub(super) struct MroIterator<'db> {
}
impl<'db> MroIterator<'db> {
pub(super) fn new(db: &'db dyn Db, class: Class<'db>) -> Self {
pub(super) fn new(
db: &'db dyn Db,
class: ClassLiteralType<'db>,
specialization: Option<Specialization<'db>>,
) -> Self {
Self {
db,
class,
specialization,
first_element_yielded: false,
subsequent_elements: None,
}
@@ -211,7 +261,7 @@ impl<'db> MroIterator<'db> {
fn full_mro_except_first_element(&mut self) -> impl Iterator<Item = ClassBase<'db>> + '_ {
self.subsequent_elements
.get_or_insert_with(|| {
let mut full_mro_iter = match self.class.try_mro(self.db) {
let mut full_mro_iter = match self.class.try_mro(self.db, self.specialization) {
Ok(mro) => mro.iter(),
Err(error) => error.fallback_mro().iter(),
};
@@ -228,7 +278,10 @@ impl<'db> Iterator for MroIterator<'db> {
fn next(&mut self) -> Option<Self::Item> {
if !self.first_element_yielded {
self.first_element_yielded = true;
return Some(ClassBase::Class(self.class));
return Some(ClassBase::Class(
self.class
.apply_optional_specialization(self.db, self.specialization),
));
}
self.full_mro_except_first_element().next()
}
@@ -273,11 +326,11 @@ pub(super) enum MroErrorKind<'db> {
/// The class has one or more duplicate bases.
///
/// This variant records the indices and [`Class`]es
/// This variant records the indices and [`ClassLiteralType`]s
/// of the duplicate bases. The indices are the indices of nodes
/// in the bases list of the class's [`StmtClassDef`](ruff_python_ast::StmtClassDef) node.
/// Each index is the index of a node representing a duplicate base.
DuplicateBases(Box<[(usize, Class<'db>)]>),
DuplicateBases(Box<[(usize, ClassLiteralType<'db>)]>),
/// The MRO is otherwise unresolvable through the C3-merge algorithm.
///
@@ -285,6 +338,15 @@ pub(super) enum MroErrorKind<'db> {
UnresolvableMro { bases_list: Box<[ClassBase<'db>]> },
}
impl<'db> MroErrorKind<'db> {
pub(super) fn into_mro_error(self, db: &'db dyn Db, class: ClassType<'db>) -> MroError<'db> {
MroError {
kind: self,
fallback_mro: Mro::from_error(db, class),
}
}
}
/// Implementation of the [C3-merge algorithm] for calculating a Python class's
/// [method resolution order].
///

View File

@@ -159,10 +159,10 @@ impl KnownConstraintFunction {
Type::ClassLiteral(class_literal) => {
// At runtime (on Python 3.11+), this will return `True` for classes that actually
// do inherit `typing.Any` and `False` otherwise. We could accurately model that?
if class_literal.class().is_known(db, KnownClass::Any) {
if class_literal.is_known(db, KnownClass::Any) {
None
} else {
Some(constraint_fn(class_literal.class()))
Some(constraint_fn(class_literal.default_specialization(db)))
}
}
Type::SubclassOf(subclass_of_ty) => {
@@ -473,8 +473,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
range: _,
},
}) if keywords.is_empty() => {
let Type::ClassLiteral(ClassLiteralType { class: rhs_class }) = rhs_ty else {
continue;
let rhs_class = match rhs_ty {
Type::ClassLiteral(class) => class,
Type::GenericAlias(alias) => {
ClassLiteralType::Generic(alias.origin(self.db))
}
_ => {
continue;
}
};
let [ast::Expr::Name(ast::ExprName { id, .. })] = &**args else {
@@ -496,10 +502,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
if callable_type
.into_class_literal()
.is_some_and(|c| c.class().is_known(self.db, KnownClass::Type))
.is_some_and(|c| c.is_known(self.db, KnownClass::Type))
{
let symbol = self.expect_expr_name_symbol(id);
constraints.insert(symbol, Type::instance(rhs_class));
constraints.insert(
symbol,
Type::instance(rhs_class.unknown_specialization(self.db)),
);
}
}
_ => {}
@@ -550,7 +559,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
Type::ClassLiteral(class_type)
if expr_call.arguments.args.len() == 1
&& expr_call.arguments.keywords.is_empty()
&& class_type.class().is_known(self.db, KnownClass::Bool) =>
&& class_type.is_known(self.db, KnownClass::Bool) =>
{
self.evaluate_expression_node_predicate(
&expr_call.arguments.args[0],

View File

@@ -11,6 +11,8 @@ use ruff_python_ast::name::Name;
/// A test representation of a type that can be transformed unambiguously into a real Type,
/// given a db.
///
/// TODO: We should add some variants that exercise generic classes and specializations thereof.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Ty {
Never,
@@ -167,7 +169,7 @@ impl Ty {
.symbol
.expect_type()
.expect_class_literal()
.class,
.default_specialization(db),
),
Ty::SubclassOfAbcClass(s) => SubclassOfType::from(
db,
@@ -175,7 +177,7 @@ impl Ty {
.symbol
.expect_type()
.expect_class_literal()
.class,
.default_specialization(db),
),
Ty::AlwaysTruthy => Type::AlwaysTruthy,
Ty::AlwaysFalsy => Type::AlwaysFalsy,

View File

@@ -14,6 +14,7 @@ use smallvec::{smallvec, SmallVec};
use super::{definition_expression_type, DynamicType, Type};
use crate::semantic_index::definition::Definition;
use crate::types::generics::Specialization;
use crate::types::todo_type;
use crate::Db;
use ruff_python_ast::{self as ast, name::Name};
@@ -261,6 +262,17 @@ impl<'db> Signature<'db> {
}
}
pub(crate) fn apply_specialization(
&mut self,
db: &'db dyn Db,
specialization: Specialization<'db>,
) {
self.parameters.apply_specialization(db, specialization);
self.return_ty = self
.return_ty
.map(|ty| ty.apply_specialization(db, specialization));
}
/// Return the parameters in this signature.
pub(crate) fn parameters(&self) -> &Parameters<'db> {
&self.parameters
@@ -445,6 +457,12 @@ impl<'db> Parameters<'db> {
)
}
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
self.value
.iter_mut()
.for_each(|param| param.apply_specialization(db, specialization));
}
pub(crate) fn len(&self) -> usize {
self.value.len()
}
@@ -606,6 +624,13 @@ impl<'db> Parameter<'db> {
self
}
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
self.annotated_type = self
.annotated_type
.map(|ty| ty.apply_specialization(db, specialization));
self.kind.apply_specialization(db, specialization);
}
/// Strip information from the parameter so that two equivalent parameters compare equal.
/// Normalize nested unions and intersections in the annotated type, if any.
///
@@ -792,6 +817,19 @@ pub(crate) enum ParameterKind<'db> {
},
}
impl<'db> ParameterKind<'db> {
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
match self {
Self::PositionalOnly { default_type, .. }
| Self::PositionalOrKeyword { default_type, .. }
| Self::KeywordOnly { default_type, .. } => {
*default_type = default_type.map(|ty| ty.apply_specialization(db, specialization));
}
Self::Variadic { .. } | Self::KeywordVariadic { .. } => {}
}
}
}
/// Whether a parameter is used as a value or a type form.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) enum ParameterForm {

View File

@@ -4,7 +4,7 @@ use crate::db::Db;
use crate::symbol::{Boundness, Symbol};
use crate::types::class_base::ClassBase;
use crate::types::diagnostic::report_base_with_incompatible_slots;
use crate::types::{Class, ClassLiteralType, Type};
use crate::types::{ClassLiteralType, Type};
use super::InferContext;
@@ -23,7 +23,7 @@ enum SlotsKind {
}
impl SlotsKind {
fn from(db: &dyn Db, base: Class) -> Self {
fn from(db: &dyn Db, base: ClassLiteralType) -> Self {
let Symbol::Type(slots_ty, bound) = base.own_class_member(db, "__slots__").symbol else {
return Self::NotSpecified;
};
@@ -50,7 +50,11 @@ impl SlotsKind {
}
}
pub(super) fn check_class_slots(context: &InferContext, class: Class, node: &ast::StmtClassDef) {
pub(super) fn check_class_slots(
context: &InferContext,
class: ClassLiteralType,
node: &ast::StmtClassDef,
) {
let db = context.db();
let mut first_with_solid_base = None;
@@ -58,16 +62,17 @@ pub(super) fn check_class_slots(context: &InferContext, class: Class, node: &ast
let mut found_second = false;
for (index, base) in class.explicit_bases(db).iter().enumerate() {
let Type::ClassLiteral(ClassLiteralType { class: base }) = base else {
let Type::ClassLiteral(base) = base else {
continue;
};
let solid_base = base.iter_mro(db).find_map(|current| {
let solid_base = base.iter_mro(db, None).find_map(|current| {
let ClassBase::Class(current) = current else {
return None;
};
match SlotsKind::from(db, current) {
let (class_literal, _) = current.class_literal(db);
match SlotsKind::from(db, class_literal) {
SlotsKind::NotEmpty => Some(current),
SlotsKind::NotSpecified | SlotsKind::Empty => None,
SlotsKind::Dynamic => None,

View File

@@ -1,6 +1,6 @@
use crate::symbol::SymbolAndQualifiers;
use super::{ClassBase, ClassLiteralType, Db, KnownClass, MemberLookupPolicy, Type};
use super::{ClassBase, Db, KnownClass, MemberLookupPolicy, Type};
/// A type that represents `type[C]`, i.e. the class object `C` and class objects that are subclasses of `C`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)]
@@ -27,7 +27,7 @@ impl<'db> SubclassOfType<'db> {
ClassBase::Dynamic(_) => Type::SubclassOf(Self { subclass_of }),
ClassBase::Class(class) => {
if class.is_final(db) {
Type::ClassLiteral(ClassLiteralType { class })
Type::from(class)
} else if class.is_object(db) {
KnownClass::Type.to_instance(db)
} else {

View File

@@ -2,10 +2,7 @@ use std::cmp::Ordering;
use crate::db::Db;
use super::{
class_base::ClassBase, ClassLiteralType, DynamicType, InstanceType, KnownInstanceType,
TodoType, Type,
};
use super::{class_base::ClassBase, DynamicType, InstanceType, KnownInstanceType, TodoType, Type};
/// Return an [`Ordering`] that describes the canonical order in which two types should appear
/// in an [`crate::types::IntersectionType`] or a [`crate::types::UnionType`] in order for them
@@ -93,13 +90,14 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
(Type::ModuleLiteral(_), _) => Ordering::Less,
(_, Type::ModuleLiteral(_)) => Ordering::Greater,
(
Type::ClassLiteral(ClassLiteralType { class: left }),
Type::ClassLiteral(ClassLiteralType { class: right }),
) => left.cmp(right),
(Type::ClassLiteral(left), Type::ClassLiteral(right)) => left.cmp(right),
(Type::ClassLiteral(_), _) => Ordering::Less,
(_, Type::ClassLiteral(_)) => Ordering::Greater,
(Type::GenericAlias(left), Type::GenericAlias(right)) => left.cmp(right),
(Type::GenericAlias(_), _) => Ordering::Less,
(_, Type::GenericAlias(_)) => Ordering::Greater,
(Type::SubclassOf(left), Type::SubclassOf(right)) => {
match (left.subclass_of(), right.subclass_of()) {
(ClassBase::Class(left), ClassBase::Class(right)) => left.cmp(&right),