diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md new file mode 100644 index 0000000000..43a18db0b1 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md @@ -0,0 +1,244 @@ +# Narrowing for `issubclass` checks + +Narrowing for `issubclass(class, classinfo)` expressions. + +## `classinfo` is a single type + +### Basic example + +```py +def flag() -> bool: ... + +t = int if flag() else str + +if issubclass(t, bytes): + reveal_type(t) # revealed: Never + +if issubclass(t, object): + reveal_type(t) # revealed: Literal[int, str] + +if issubclass(t, int): + reveal_type(t) # revealed: Literal[int] +else: + reveal_type(t) # revealed: Literal[str] + +if issubclass(t, str): + reveal_type(t) # revealed: Literal[str] + if issubclass(t, int): + reveal_type(t) # revealed: Never +``` + +### Proper narrowing in `elif` and `else` branches + +```py +def flag() -> bool: ... + +t = int if flag() else str if flag() else bytes + +if issubclass(t, int): + reveal_type(t) # revealed: Literal[int] +else: + reveal_type(t) # revealed: Literal[str, bytes] + +if issubclass(t, int): + reveal_type(t) # revealed: Literal[int] +elif issubclass(t, str): + reveal_type(t) # revealed: Literal[str] +else: + reveal_type(t) # revealed: Literal[bytes] +``` + +### Multiple derived classes + +```py +class Base: ... +class Derived1(Base): ... +class Derived2(Base): ... +class Unrelated: ... + +def flag() -> bool: ... + +t1 = Derived1 if flag() else Derived2 + +if issubclass(t1, Base): + reveal_type(t1) # revealed: Literal[Derived1, Derived2] + +if issubclass(t1, Derived1): + reveal_type(t1) # revealed: Literal[Derived1] +else: + reveal_type(t1) # revealed: Literal[Derived2] + +t2 = Derived1 if flag() else Base + +if issubclass(t2, Base): + reveal_type(t2) # revealed: Literal[Derived1, Base] + +t3 = Derived1 if flag() else Unrelated + +if issubclass(t3, Base): + reveal_type(t3) # revealed: Literal[Derived1] +else: + reveal_type(t3) # revealed: Literal[Unrelated] +``` + +### Narrowing for non-literals + +```py +class A: ... +class B: ... + +def get_class() -> type[object]: ... + +t = get_class() + +if issubclass(t, A): + reveal_type(t) # revealed: type[A] + if issubclass(t, B): + reveal_type(t) # revealed: type[A] & type[B] +else: + reveal_type(t) # revealed: type[object] & ~type[A] +``` + +### Handling of `None` + +```py +from types import NoneType + +def flag() -> bool: ... + +t = int if flag() else NoneType + +if issubclass(t, NoneType): + reveal_type(t) # revealed: Literal[NoneType] + +if issubclass(t, type(None)): + # TODO: this should be just `Literal[NoneType]` + reveal_type(t) # revealed: Literal[int, NoneType] +``` + +## `classinfo` contains multiple types + +### (Nested) tuples of types + +```py +class Unrelated: ... + +def flag() -> bool: ... + +t = int if flag() else str if flag() else bytes + +if issubclass(t, (int, (Unrelated, (bytes,)))): + reveal_type(t) # revealed: Literal[int, bytes] +else: + reveal_type(t) # revealed: Literal[str] +``` + +## Special cases + +### Emit a diagnostic if the first argument is of wrong type + +#### Too wide + +`type[object]` is a subtype of `object`, but not every `object` can be passed as the first argument +to `issubclass`: + +```py +class A: ... + +def get_object() -> object: ... + +t = get_object() + +# TODO: we should emit a diagnostic here +if issubclass(t, A): + reveal_type(t) # revealed: type[A] +``` + +#### Wrong + +`Literal[1]` and `type` are entirely disjoint, so the inferred type of `Literal[1] & type[int]` is +eagerly simplified to `Never` as a result of the type narrowing in the `if issubclass(t, int)` +branch: + +```py +t = 1 + +# TODO: we should emit a diagnostic here +if issubclass(t, int): + reveal_type(t) # revealed: Never +``` + +### Do not use custom `issubclass` for narrowing + +```py +def issubclass(c, ci): + return True + +def flag() -> bool: ... + +t = int if flag() else str +if issubclass(t, int): + reveal_type(t) # revealed: Literal[int, str] +``` + +### Do support narrowing if `issubclass` is aliased + +```py +issubclass_alias = issubclass + +def flag() -> bool: ... + +t = int if flag() else str +if issubclass_alias(t, int): + reveal_type(t) # revealed: Literal[int] +``` + +### Do support narrowing if `issubclass` is imported + +```py +from builtins import issubclass as imported_issubclass + +def flag() -> bool: ... + +t = int if flag() else str +if imported_issubclass(t, int): + reveal_type(t) # revealed: Literal[int] +``` + +### Do not narrow if second argument is not a proper `classinfo` argument + +```py +from typing import Any + +def flag() -> bool: ... + +t = int if flag() else str + +# TODO: this should cause us to emit a diagnostic during +# type checking +if issubclass(t, "str"): + reveal_type(t) # revealed: Literal[int, str] + +# TODO: this should cause us to emit a diagnostic during +# type checking +if issubclass(t, (bytes, "str")): + reveal_type(t) # revealed: Literal[int, str] + +# TODO: this should cause us to emit a diagnostic during +# type checking +if issubclass(t, Any): + reveal_type(t) # revealed: Literal[int, str] +``` + +### Do not narrow if there are keyword arguments + +```py +def flag() -> bool: ... + +t = int if flag() else str + +# TODO: this should cause us to emit a diagnostic +# (`issubclass` has no `foo` parameter) +if issubclass(t, int, foo="bar"): + reveal_type(t) # revealed: Literal[int, str] +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 69efca7538..df5e069f93 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -331,6 +331,8 @@ pub enum Type<'db> { ModuleLiteral(File), /// A specific class object ClassLiteral(ClassLiteralType<'db>), + // The set of all class objects that are subclasses of the given class (C), spelled `type[C]`. + SubclassOf(SubclassOfType<'db>), /// The set of Python objects with the given class in their __class__'s method resolution order Instance(InstanceType<'db>), /// The set of objects in any of the types in the union @@ -400,6 +402,15 @@ impl<'db> Type<'db> { IntersectionBuilder::new(db).add_negative(*self).build() } + #[must_use] + pub fn negate_if(&self, db: &'db dyn Db, yes: bool) -> Type<'db> { + if yes { + self.negate(db) + } else { + *self + } + } + pub const fn into_union(self) -> Option> { match self { Type::Union(union_type) => Some(union_type), @@ -524,6 +535,26 @@ impl<'db> Type<'db> { { true } + (Type::ClassLiteral(self_class), Type::SubclassOf(target_class)) => { + self_class.class.is_subclass_of(db, target_class.class) + } + (Type::SubclassOf(self_class), Type::SubclassOf(target_class)) => { + self_class.class.is_subclass_of(db, target_class.class) + } + ( + Type::SubclassOf(SubclassOfType { class: self_class }), + Type::Instance(InstanceType { + class: target_class, + .. + }), + ) if self_class + .metaclass(db) + .into_class_literal() + .map(|meta| meta.class.is_subclass_of(db, target_class)) + .unwrap_or(false) => + { + true + } (Type::Union(union), ty) => union .elements(db) .iter() @@ -620,6 +651,9 @@ impl<'db> Type<'db> { // TODO equivalent but not identical structural types, differently-ordered unions and // intersections, other cases? + // TODO: Once we have support for final classes, we can establish that + // `Type::SubclassOf('FinalClass')` is equivalent to `Type::ClassLiteral('FinalClass')`. + // TODO: The following is a workaround that is required to unify the two different // versions of `NoneType` in typeshed. This should not be required anymore once we // understand `sys.version_info` branches. @@ -684,6 +718,41 @@ impl<'db> Type<'db> { | Type::ClassLiteral(..)), ) => left != right, + (Type::SubclassOf(type_class), Type::ClassLiteral(class_literal)) + | (Type::ClassLiteral(class_literal), Type::SubclassOf(type_class)) => { + !class_literal.class.is_subclass_of(db, type_class.class) + } + (Type::SubclassOf(_), Type::SubclassOf(_)) => false, + (Type::SubclassOf(_), Type::Instance(_)) | (Type::Instance(_), Type::SubclassOf(_)) => { + false + } + ( + Type::SubclassOf(_), + Type::BooleanLiteral(..) + | Type::IntLiteral(..) + | Type::StringLiteral(..) + | Type::BytesLiteral(..) + | Type::SliceLiteral(..) + | Type::FunctionLiteral(..) + | Type::ModuleLiteral(..), + ) + | ( + Type::BooleanLiteral(..) + | Type::IntLiteral(..) + | Type::StringLiteral(..) + | Type::BytesLiteral(..) + | Type::SliceLiteral(..) + | Type::FunctionLiteral(..) + | Type::ModuleLiteral(..), + Type::SubclassOf(_), + ) => true, + (Type::SubclassOf(_), _) | (_, Type::SubclassOf(_)) => { + // TODO: Once we have support for final classes, we can determine disjointness in some cases + // here. However, note that it might be better to turn `Type::SubclassOf('FinalClass')` into + // `Type::ClassLiteral('FinalClass')` during construction, instead of adding special cases for + // final classes inside `Type::SubclassOf` everywhere. + false + } ( Type::Instance(InstanceType { class: class_none, .. @@ -825,6 +894,11 @@ impl<'db> Type<'db> { // are both of type Literal[345], for example. false } + Type::SubclassOf(..) => { + // TODO once we have support for final classes, we can return `true` for some + // cases: type[C] is a singleton if C is final. + false + } Type::BooleanLiteral(_) | Type::FunctionLiteral(..) | Type::ClassLiteral(..) @@ -872,6 +946,11 @@ impl<'db> Type<'db> { | Type::BytesLiteral(..) | Type::SliceLiteral(..) => true, + Type::SubclassOf(..) => { + // TODO: Same comment as above for `is_singleton` + false + } + Type::Tuple(tuple) => tuple .elements(db) .iter() @@ -965,6 +1044,7 @@ impl<'db> Type<'db> { } } Type::ClassLiteral(class_ty) => class_ty.member(db, name), + Type::SubclassOf(subclass_of_ty) => subclass_of_ty.member(db, name), Type::Instance(_) => { // TODO MRO? get_own_instance_member, get_instance_member Type::Todo.into() @@ -1055,6 +1135,10 @@ impl<'db> Type<'db> { // More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing Truthiness::Ambiguous } + Type::SubclassOf(_) => { + // TODO: see above + Truthiness::Ambiguous + } Type::Instance(InstanceType { class, .. }) => { // TODO: lookup `__bool__` and `__len__` methods on the instance's class // More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing @@ -1239,6 +1323,7 @@ impl<'db> Type<'db> { Type::Unknown => Type::Unknown, Type::Never => Type::Never, Type::ClassLiteral(ClassLiteralType { class }) => Type::anonymous_instance(*class), + Type::SubclassOf(SubclassOfType { class }) => Type::anonymous_instance(*class), Type::Union(union) => union.map(db, |element| element.to_instance(db)), // TODO: we can probably do better here: --Alex Type::Intersection(_) => Type::Todo, @@ -1272,10 +1357,8 @@ impl<'db> Type<'db> { pub fn to_meta_type(&self, db: &'db dyn Db) -> Type<'db> { match self { Type::Never => Type::Never, - // TODO: not really correct -- the meta-type of an `InstanceType { class: T }` should be `type[T]` - // () Type::Instance(InstanceType { class, .. }) => { - Type::ClassLiteral(ClassLiteralType { class: *class }) + Type::SubclassOf(SubclassOfType { class: *class }) } Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)), Type::BooleanLiteral(_) => KnownClass::Bool.to_class(db), @@ -1286,7 +1369,14 @@ impl<'db> Type<'db> { Type::ModuleLiteral(_) => KnownClass::ModuleType.to_class(db), Type::Tuple(_) => KnownClass::Tuple.to_class(db), Type::ClassLiteral(ClassLiteralType { class }) => class.metaclass(db), - // TODO can we do better here? `type[LiteralString]`? + Type::SubclassOf(SubclassOfType { class }) => Type::SubclassOf( + class + .try_metaclass(db) + .ok() + .and_then(Type::into_class_literal) + .unwrap_or(KnownClass::Type.to_class(db).expect_class_literal()) + .to_subclass_of_type(), + ), Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_class(db), // TODO: `type[Any]`? Type::Any => Type::Any, @@ -1910,14 +2000,47 @@ impl<'db> FunctionType<'db> { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum KnownConstraintFunction { + /// `builtins.isinstance` + IsInstance, + /// `builtins.issubclass` + IsSubclass, +} + /// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might /// have special behavior. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum KnownFunction { + ConstraintFunction(KnownConstraintFunction), /// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type` RevealType, - /// `builtins.isinstance` - IsInstance, +} + +impl KnownFunction { + pub fn constraint_function(self) -> Option { + match self { + Self::ConstraintFunction(f) => Some(f), + Self::RevealType => None, + } + } + + fn from_definition<'db>( + db: &'db dyn Db, + definition: Definition<'db>, + name: &str, + ) -> Option { + match name { + "reveal_type" if definition.is_typing_definition(db) => Some(KnownFunction::RevealType), + "isinstance" if definition.is_builtin_definition(db) => Some( + KnownFunction::ConstraintFunction(KnownConstraintFunction::IsInstance), + ), + "issubclass" if definition.is_builtin_definition(db) => Some( + KnownFunction::ConstraintFunction(KnownConstraintFunction::IsSubclass), + ), + _ => None, + } + } } /// Representation of a runtime class object. @@ -2228,6 +2351,10 @@ impl<'db> ClassLiteralType<'db> { fn member(self, db: &'db dyn Db, name: &str) -> Symbol<'db> { self.class.class_member(db, name) } + + fn to_subclass_of_type(self) -> SubclassOfType<'db> { + SubclassOfType { class: self.class } + } } impl<'db> From> for Type<'db> { @@ -2236,6 +2363,18 @@ impl<'db> From> for Type<'db> { } } +/// A type that represents `type[C]`, i.e. the class literal `C` and class literals that are subclasses of `C`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SubclassOfType<'db> { + class: Class<'db>, +} + +impl<'db> SubclassOfType<'db> { + fn member(self, db: &'db dyn Db, name: &str) -> Symbol<'db> { + self.class.class_member(db, name) + } +} + /// A type representing the set of runtime objects which are instances of a certain class. /// /// Some specific instances of some types need to be treated specially by the type system: @@ -2463,7 +2602,10 @@ mod tests { StringLiteral(&'static str), LiteralString, BytesLiteral(&'static str), + // BuiltinInstance("str") corresponds to an instance of the builtin `str` class BuiltinInstance(&'static str), + // BuiltinClassLiteral("str") corresponds to the builtin `str` class object itself + BuiltinClassLiteral(&'static str), Union(Vec), Intersection { pos: Vec, neg: Vec }, Tuple(Vec), @@ -2483,6 +2625,7 @@ mod tests { Ty::LiteralString => Type::LiteralString, Ty::BytesLiteral(s) => Type::BytesLiteral(BytesLiteralType::new(db, s.as_bytes())), Ty::BuiltinInstance(s) => builtins_symbol(db, s).expect_type().to_instance(db), + Ty::BuiltinClassLiteral(s) => builtins_symbol(db, s).expect_type(), Ty::Union(tys) => { UnionType::from_elements(db, tys.into_iter().map(|ty| ty.into_type(db))) } @@ -2569,6 +2712,8 @@ mod tests { #[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})] #[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]})] #[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("str")], neg: vec![Ty::StringLiteral("foo")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})] + #[test_case(Ty::BuiltinClassLiteral("int"), Ty::BuiltinClassLiteral("int"))] + #[test_case(Ty::BuiltinClassLiteral("int"), Ty::BuiltinInstance("object"))] fn is_subtype_of(from: Ty, to: Ty) { let db = setup_db(); assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db))); @@ -2594,6 +2739,8 @@ mod tests { #[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]})] #[test_case(Ty::BuiltinInstance("int"), Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})] #[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(1)]})] + #[test_case(Ty::BuiltinClassLiteral("int"), Ty::BuiltinClassLiteral("object"))] + #[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinClassLiteral("int"))] fn is_not_subtype_of(from: Ty, to: Ty) { let db = setup_db(); assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db))); @@ -2605,24 +2752,43 @@ mod tests { db.write_dedented( "/src/module.py", " - class A: ... - class B: ... - U = A if flag else B + class Base: ... + class Derived(Base): ... + class Unrelated: ... + U = Base if flag else Unrelated ", ) .unwrap(); let module = ruff_db::files::system_path_to_file(&db, "/src/module.py").unwrap(); - let type_a = super::global_symbol(&db, module, "A").expect_type(); - let type_u = super::global_symbol(&db, module, "U").expect_type(); + // `literal_base` represents `Literal[Base]`. + let literal_base = super::global_symbol(&db, module, "Base").expect_type(); + let literal_derived = super::global_symbol(&db, module, "Derived").expect_type(); + let u = super::global_symbol(&db, module, "U").expect_type(); - assert!(type_a.is_class_literal()); - assert!(type_a.is_subtype_of(&db, Ty::BuiltinInstance("type").into_type(&db))); - assert!(type_a.is_subtype_of(&db, Ty::BuiltinInstance("object").into_type(&db))); + assert!(literal_base.is_class_literal()); + assert!(literal_base.is_subtype_of(&db, Ty::BuiltinInstance("type").into_type(&db))); + assert!(literal_base.is_subtype_of(&db, Ty::BuiltinInstance("object").into_type(&db))); - assert!(type_u.is_union()); - assert!(type_u.is_subtype_of(&db, Ty::BuiltinInstance("type").into_type(&db))); - assert!(type_u.is_subtype_of(&db, Ty::BuiltinInstance("object").into_type(&db))); + assert!(literal_derived.is_class_literal()); + + // `subclass_of_base` represents `Type[Base]`. + let subclass_of_base = + Type::SubclassOf(literal_base.expect_class_literal().to_subclass_of_type()); + assert!(literal_base.is_subtype_of(&db, subclass_of_base)); + assert!(literal_derived.is_subtype_of(&db, subclass_of_base)); + + let subclass_of_derived = + Type::SubclassOf(literal_derived.expect_class_literal().to_subclass_of_type()); + assert!(literal_derived.is_subtype_of(&db, subclass_of_derived)); + assert!(!literal_base.is_subtype_of(&db, subclass_of_derived)); + + // Type[Derived] <: Type[Base] + assert!(subclass_of_derived.is_subtype_of(&db, subclass_of_base)); + + assert!(u.is_union()); + assert!(u.is_subtype_of(&db, Ty::BuiltinInstance("type").into_type(&db))); + assert!(u.is_subtype_of(&db, Ty::BuiltinInstance("object").into_type(&db))); } #[test] @@ -2750,6 +2916,42 @@ mod tests { assert!(!type_a.is_disjoint_from(&db, type_u)); } + #[test] + fn is_disjoint_type_type() { + let mut db = setup_db(); + db.write_dedented( + "/src/module.py", + " + class A: ... + class B: ... + ", + ) + .unwrap(); + let module = ruff_db::files::system_path_to_file(&db, "/src/module.py").unwrap(); + + let literal_a = super::global_symbol(&db, module, "A").expect_type(); + let literal_b = super::global_symbol(&db, module, "B").expect_type(); + + let subclass_of_a = + Type::SubclassOf(literal_a.expect_class_literal().to_subclass_of_type()); + let subclass_of_b = + Type::SubclassOf(literal_b.expect_class_literal().to_subclass_of_type()); + + // Class literals are always disjoint. They are singleton types + assert!(literal_a.is_disjoint_from(&db, literal_b)); + + // The class A is a subclass of A, so A is not disjoint from type[A] + assert!(!literal_a.is_disjoint_from(&db, subclass_of_a)); + + // The class A is disjoint from type[B] because it's not a subclass + // of B: + assert!(literal_a.is_disjoint_from(&db, subclass_of_b)); + + // However, type[A] is not disjoint from type[B], as there could be + // classes that inherit from both A and B: + assert!(!subclass_of_a.is_disjoint_from(&db, subclass_of_b)); + } + #[test_case(Ty::None)] #[test_case(Ty::BooleanLiteral(true))] #[test_case(Ty::BooleanLiteral(false))] diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 36e1727a6a..9e0b5dfeed 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -6,7 +6,9 @@ use ruff_db::display::FormatterJoinExtension; use ruff_python_ast::str::Quote; use ruff_python_literal::escape::AsciiEscape; -use crate::types::{ClassLiteralType, InstanceType, IntersectionType, KnownClass, Type, UnionType}; +use crate::types::{ + ClassLiteralType, InstanceType, IntersectionType, KnownClass, SubclassOfType, Type, UnionType, +}; use crate::Db; use rustc_hash::FxHashMap; @@ -77,6 +79,9 @@ impl Display for DisplayRepresentation<'_> { } // TODO functions and classes should display using a fully qualified name Type::ClassLiteral(ClassLiteralType { class }) => f.write_str(class.name(self.db)), + Type::SubclassOf(SubclassOfType { class }) => { + write!(f, "type[{}]", class.name(self.db)) + } Type::Instance(InstanceType { class, known }) => f.write_str(match known { Some(super::KnownInstance::Literal) => "Literal", _ => class.name(self.db), diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index f07ab2c3c5..7acfce8ff0 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -861,15 +861,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } - let function_kind = match &**name { - "reveal_type" if definition.is_typing_definition(self.db) => { - Some(KnownFunction::RevealType) - } - "isinstance" if definition.is_builtin_definition(self.db) => { - Some(KnownFunction::IsInstance) - } - _ => None, - }; + let function_kind = KnownFunction::from_definition(self.db, definition, name); let body_scope = self .index @@ -3882,15 +3874,15 @@ impl<'db> TypeInferenceBuilder<'db> { let value_ty = self.infer_expression(value); - if value_ty - .into_class_literal() - .is_some_and(|ClassLiteralType { class }| { - class.is_known(self.db, KnownClass::Tuple) - }) - { - self.infer_tuple_type_expression(slice) - } else { - self.infer_subscript_type_expression(subscript, value_ty) + match value_ty { + Type::ClassLiteral(class_literal_ty) => { + match class_literal_ty.class.known(self.db) { + Some(KnownClass::Tuple) => self.infer_tuple_type_expression(slice), + Some(KnownClass::Type) => self.infer_subclass_of_type_expression(slice), + _ => self.infer_subscript_type_expression(subscript, value_ty), + } + } + _ => self.infer_subscript_type_expression(subscript, value_ty), } } @@ -4063,6 +4055,25 @@ impl<'db> TypeInferenceBuilder<'db> { } } + /// Given the slice of a `type[]` annotation, return the type that the annotation represents + fn infer_subclass_of_type_expression(&mut self, slice: &ast::Expr) -> Type<'db> { + match slice { + ast::Expr::Name(name) => { + let name_ty = self.infer_name_expression(name); + if let Some(class_literal) = name_ty.into_class_literal() { + Type::SubclassOf(class_literal.to_subclass_of_type()) + } else { + Type::Todo + } + } + // TODO: attributes, unions, subscripts, etc. + _ => { + self.infer_type_expression(slice); + Type::Todo + } + } + } + fn infer_subscript_type_expression( &mut self, subscript: &ast::ExprSubscript, diff --git a/crates/red_knot_python_semantic/src/types/mro.rs b/crates/red_knot_python_semantic/src/types/mro.rs index e41ed031eb..74e44fe464 100644 --- a/crates/red_knot_python_semantic/src/types/mro.rs +++ b/crates/red_knot_python_semantic/src/types/mro.rs @@ -377,7 +377,8 @@ impl<'db> ClassBase<'db> { | Type::LiteralString | Type::Tuple(_) | Type::SliceLiteral(_) - | Type::ModuleLiteral(_) => None, + | Type::ModuleLiteral(_) + | Type::SubclassOf(_) => None, } } diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 07ff0803b9..d28006aded 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -5,8 +5,8 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; use crate::semantic_index::symbol_table; use crate::types::{ - infer_expression_types, ClassLiteralType, IntersectionBuilder, KnownClass, KnownFunction, - Truthiness, Type, UnionBuilder, + infer_expression_types, ClassLiteralType, IntersectionBuilder, KnownClass, + KnownConstraintFunction, KnownFunction, Truthiness, Type, UnionBuilder, }; use crate::Db; use itertools::Itertools; @@ -78,24 +78,27 @@ fn all_negative_narrowing_constraints_for_expression<'db>( NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), false).finish() } -/// Generate a constraint from the *type* of the second argument of an `isinstance` call. +/// Generate a constraint from the type of a `classinfo` argument to `isinstance` or `issubclass`. /// -/// Example: for `isinstance(…, str)`, we would infer `Type::ClassLiteral(str)` from the -/// second argument, but we need to generate a `Type::Instance(str)` constraint that can -/// be used to narrow down the type of the first argument. -fn generate_isinstance_constraint<'db>( +/// The `classinfo` argument can be a class literal, a tuple of (tuples of) class literals. PEP 604 +/// union types are not yet supported. Returns `None` if the `classinfo` argument has a wrong type. +fn generate_classinfo_constraint<'db, F>( db: &'db dyn Db, classinfo: &Type<'db>, -) -> Option> { + to_constraint: F, +) -> Option> +where + F: Fn(ClassLiteralType<'db>) -> Type<'db> + Copy, +{ match classinfo { - Type::ClassLiteral(ClassLiteralType { class }) => Some(Type::anonymous_instance(*class)), Type::Tuple(tuple) => { let mut builder = UnionBuilder::new(db); for element in tuple.elements(db) { - builder = builder.add(generate_isinstance_constraint(db, element)?); + builder = builder.add(generate_classinfo_constraint(db, element, to_constraint)?); } Some(builder.build()) } + Type::ClassLiteral(class_literal_type) => Some(to_constraint(*class_literal_type)), _ => None, } } @@ -330,34 +333,49 @@ impl<'db> NarrowingConstraintsBuilder<'db> { let scope = self.scope(); let inference = infer_expression_types(self.db, expression); - if let Some(func_type) = inference + // TODO: add support for PEP 604 union types on the right hand side of `isinstance` + // and `issubclass`, for example `isinstance(x, str | (int | float))`. + match inference .expression_ty(expr_call.func.scoped_ast_id(self.db, scope)) .into_function_literal() + .and_then(|f| f.known(self.db)) + .and_then(KnownFunction::constraint_function) { - if func_type.is_known(self.db, KnownFunction::IsInstance) - && expr_call.arguments.keywords.is_empty() - { - if let [ast::Expr::Name(ast::ExprName { id, .. }), rhs] = &*expr_call.arguments.args + Some(function) if expr_call.arguments.keywords.is_empty() => { + if let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] = + &*expr_call.arguments.args { let symbol = self.symbols().symbol_id_by_name(id).unwrap(); - let rhs_type = inference.expression_ty(rhs.scoped_ast_id(self.db, scope)); + let class_info_ty = + inference.expression_ty(class_info.scoped_ast_id(self.db, scope)); - // TODO: add support for PEP 604 union types on the right hand side: - // isinstance(x, str | (int | float)) - if let Some(mut constraint) = generate_isinstance_constraint(self.db, &rhs_type) - { - if !is_positive { - constraint = constraint.negate(self.db); + let to_constraint = match function { + KnownConstraintFunction::IsInstance => { + |class_literal: ClassLiteralType<'db>| { + Type::anonymous_instance(class_literal.class) + } } - let mut constraints = NarrowingConstraints::default(); - constraints.insert(symbol, constraint); - return Some(constraints); - } + KnownConstraintFunction::IsSubclass => { + |class_literal: ClassLiteralType<'db>| { + Type::SubclassOf(class_literal.to_subclass_of_type()) + } + } + }; + + generate_classinfo_constraint(self.db, &class_info_ty, to_constraint).map( + |constraint| { + let mut constraints = NarrowingConstraints::default(); + constraints.insert(symbol, constraint.negate_if(self.db, !is_positive)); + constraints + }, + ) + } else { + None } } + _ => None, } - None } fn evaluate_match_pattern_singleton(