From 9dddd73c292b477fc96fbb6ff233cd46c7caa3ca Mon Sep 17 00:00:00 2001 From: Shaygan Hooshyari Date: Tue, 5 Nov 2024 02:45:46 +0100 Subject: [PATCH] [red-knot] Literal special form (#13874) Handle the `Literal` type in annotations. Resolves: #13672 ## Implementation Since `Literal` is not a fully defined type in typeshed, I used a trick to figure out when a special form is `Literal`. When inferring an annotated assignment, I check whether the annotation resolved to `typing._SpecialForm` and whether the name of the target is `Literal`; if so, I create a new instance type and set its known-instance field to `KnownInstance::Literal`. **Why not define a new type?** From this [issue](https://github.com/python/typeshed/issues/6219) I learned that we want to resolve members to the `_SpecialForm` class. So if we create a new instance here, we can rely on the member resolution that already exists. ## Tests https://typing.readthedocs.io/en/latest/spec/literal.html#equivalence-of-two-literals Since each value inside `Literal` is evaluated as a literal type (LiteralString, LiteralInt, ...), two literals are equal only when both their types and their values are equal. https://typing.readthedocs.io/en/latest/spec/literal.html#legal-and-illegal-parameterizations The illegal parameterizations are mostly implemented; I'm currently checking the slice expression and the slice type to make sure they are valid. 
https://typing.readthedocs.io/en/latest/spec/literal.html#shortening-unions-of-literals --------- Co-authored-by: Carl Meyer Co-authored-by: Alex Waygood --- .../comparison/instances/membership_test.md | 10 +- .../resources/mdtest/literal/literal.md | 91 +++++++ .../resources/mdtest/unary/instance.md | 4 +- crates/red_knot_python_semantic/src/stdlib.rs | 4 +- crates/red_knot_python_semantic/src/types.rs | 249 +++++++++++++----- .../src/types/builder.rs | 7 +- .../src/types/display.rs | 11 +- .../src/types/infer.rs | 204 +++++++++++--- .../src/types/narrow.rs | 2 +- 9 files changed, 465 insertions(+), 117 deletions(-) create mode 100644 crates/red_knot_python_semantic/resources/mdtest/literal/literal.md diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md index a8dcba2dcf..9f9b5bce10 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md @@ -93,13 +93,11 @@ class AlwaysFalse: def __contains__(self, item: int) -> Literal[""]: return "" -# TODO: it should be Literal[True] and Literal[False] -reveal_type(42 in AlwaysTrue()) # revealed: @Todo -reveal_type(42 not in AlwaysTrue()) # revealed: @Todo +reveal_type(42 in AlwaysTrue()) # revealed: Literal[True] +reveal_type(42 not in AlwaysTrue()) # revealed: Literal[False] -# TODO: it should be Literal[False] and Literal[True] -reveal_type(42 in AlwaysFalse()) # revealed: @Todo -reveal_type(42 not in AlwaysFalse()) # revealed: @Todo +reveal_type(42 in AlwaysFalse()) # revealed: Literal[False] +reveal_type(42 not in AlwaysFalse()) # revealed: Literal[True] ``` ## No Fallback for `__contains__` diff --git a/crates/red_knot_python_semantic/resources/mdtest/literal/literal.md b/crates/red_knot_python_semantic/resources/mdtest/literal/literal.md new file 
mode 100644 index 0000000000..9aaf9ce7fb --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/literal/literal.md @@ -0,0 +1,91 @@ +# Literal + + + +## Parameterization + +```py +from typing import Literal +from enum import Enum + +mode: Literal["w", "r"] +mode2: Literal["w"] | Literal["r"] +union_var: Literal[Literal[Literal[1, 2, 3], "foo"], 5, None] +a1: Literal[26] +a2: Literal[0x1A] +a3: Literal[-4] +a4: Literal["hello world"] +a5: Literal[b"hello world"] +a6: Literal[True] +a7: Literal[None] +a8: Literal[Literal[1]] +a9: Literal[Literal["w"], Literal["r"], Literal[Literal["w+"]]] + +class Color(Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + +b1: Literal[Color.RED] + +def f(): + reveal_type(mode) # revealed: Literal["w", "r"] + reveal_type(mode2) # revealed: Literal["w", "r"] + # TODO: should be revealed: Literal[1, 2, 3, "foo", 5] | None + reveal_type(union_var) # revealed: Literal[1, 2, 3, 5] | Literal["foo"] | None + reveal_type(a1) # revealed: Literal[26] + reveal_type(a2) # revealed: Literal[26] + reveal_type(a3) # revealed: Literal[-4] + reveal_type(a4) # revealed: Literal["hello world"] + reveal_type(a5) # revealed: Literal[b"hello world"] + reveal_type(a6) # revealed: Literal[True] + reveal_type(a7) # revealed: None + reveal_type(a8) # revealed: Literal[1] + reveal_type(a9) # revealed: Literal["w", "r", "w+"] + # TODO: This should be Color.RED + reveal_type(b1) # revealed: Literal[0] + +# error: [invalid-literal-parameter] +invalid1: Literal[3 + 4] +# error: [invalid-literal-parameter] +invalid2: Literal[4 + 3j] +# error: [invalid-literal-parameter] +invalid3: Literal[(3, 4)] +invalid4: Literal[ + 1 + 2, # error: [invalid-literal-parameter] + "foo", + hello, # error: [invalid-literal-parameter] + (1, 2, 3), # error: [invalid-literal-parameter] +] +``` + +## Detecting Literal outside typing and typing_extensions + +Only Literal that is defined in typing and typing_extension modules is detected as the special +Literal. 
+ +```pyi path=other.pyi +from typing import _SpecialForm + +Literal: _SpecialForm +``` + +```py +from other import Literal + +a1: Literal[26] + +def f(): + reveal_type(a1) # revealed: @Todo +``` + +## Detecting typing_extensions.Literal + +```py +from typing_extensions import Literal + +a1: Literal[26] + +def f(): + reveal_type(a1) # revealed: Literal[26] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/unary/instance.md b/crates/red_knot_python_semantic/resources/mdtest/unary/instance.md index c07bf7b56c..85f80c3b8c 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/unary/instance.md +++ b/crates/red_knot_python_semantic/resources/mdtest/unary/instance.md @@ -1,6 +1,8 @@ # Unary Operations ```py +from typing import Literal + class Number: def __init__(self, value: int): self.value = 1 @@ -18,7 +20,7 @@ a = Number() reveal_type(+a) # revealed: int reveal_type(-a) # revealed: int -reveal_type(~a) # revealed: @Todo +reveal_type(~a) # revealed: Literal[True] class NoDunder: ... 
diff --git a/crates/red_knot_python_semantic/src/stdlib.rs b/crates/red_knot_python_semantic/src/stdlib.rs index 90be044517..a212c17b34 100644 --- a/crates/red_knot_python_semantic/src/stdlib.rs +++ b/crates/red_knot_python_semantic/src/stdlib.rs @@ -11,11 +11,9 @@ use crate::Db; enum CoreStdlibModule { Builtins, Types, - // the Typing enum is currently only used in tests - #[allow(dead_code)] - Typing, Typeshed, TypingExtensions, + Typing, } impl CoreStdlibModule { diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 8fd3617ee0..1a4edae35a 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -12,7 +12,9 @@ use crate::semantic_index::{ global_scope, semantic_index, symbol_table, use_def_map, BindingWithConstraints, BindingWithConstraintsIterator, DeclarationsIterator, }; -use crate::stdlib::{builtins_symbol, types_symbol, typeshed_symbol, typing_extensions_symbol}; +use crate::stdlib::{ + builtins_symbol, types_symbol, typeshed_symbol, typing_extensions_symbol, typing_symbol, +}; use crate::symbol::{Boundness, Symbol}; use crate::types::diagnostic::TypeCheckDiagnosticsBuilder; use crate::types::narrow::narrowing_constraint; @@ -328,7 +330,7 @@ pub enum Type<'db> { /// A specific class object ClassLiteral(ClassType<'db>), /// The set of Python objects with the given class in their __class__'s method resolution order - Instance(ClassType<'db>), + Instance(InstanceType<'db>), /// The set of objects in any of the types in the union Union(UnionType<'db>), /// The set of objects in all of the types in the intersection @@ -484,23 +486,32 @@ impl<'db> Type<'db> { (_, Type::Unknown | Type::Any | Type::Todo) => false, (Type::Never, _) => true, (_, Type::Never) => false, - (_, Type::Instance(class)) if class.is_known(db, KnownClass::Object) => true, - (Type::Instance(class), _) if class.is_known(db, KnownClass::Object) => false, - (Type::BooleanLiteral(_), 
Type::Instance(class)) + (_, Type::Instance(InstanceType { class, .. })) + if class.is_known(db, KnownClass::Object) => + { + true + } + (Type::Instance(InstanceType { class, .. }), _) + if class.is_known(db, KnownClass::Object) => + { + false + } + (Type::BooleanLiteral(_), Type::Instance(InstanceType { class, .. })) if class.is_known(db, KnownClass::Bool) => { true } - (Type::IntLiteral(_), Type::Instance(class)) if class.is_known(db, KnownClass::Int) => { - true - } - (Type::StringLiteral(_), Type::LiteralString) => true, - (Type::StringLiteral(_) | Type::LiteralString, Type::Instance(class)) - if class.is_known(db, KnownClass::Str) => + (Type::IntLiteral(_), Type::Instance(InstanceType { class, .. })) + if class.is_known(db, KnownClass::Int) => { true } - (Type::BytesLiteral(_), Type::Instance(class)) + (Type::StringLiteral(_), Type::LiteralString) => true, + ( + Type::StringLiteral(_) | Type::LiteralString, + Type::Instance(InstanceType { class, .. }), + ) if class.is_known(db, KnownClass::Str) => true, + (Type::BytesLiteral(_), Type::Instance(InstanceType { class, .. })) if class.is_known(db, KnownClass::Bytes) => { true @@ -515,7 +526,7 @@ impl<'db> Type<'db> { }, ) } - (Type::ClassLiteral(..), Type::Instance(class)) + (Type::ClassLiteral(..), Type::Instance(InstanceType { class, .. })) if class.is_known(db, KnownClass::Type) => { true @@ -568,9 +579,15 @@ impl<'db> Type<'db> { .iter() .all(|&neg_ty| neg_ty.is_disjoint_from(db, ty)) } - (Type::Instance(self_class), Type::Instance(target_class)) => { - self_class.is_subclass_of(db, target_class) - } + ( + Type::Instance(InstanceType { + class: self_class, .. + }), + Type::Instance(InstanceType { + class: target_class, + .. + }), + ) => self_class.is_subclass_of(db, target_class), // TODO _ => false, } @@ -615,8 +632,12 @@ impl<'db> Type<'db> { // understand `sys.version_info` branches. 
self == other || matches!((self, other), - (Type::Instance(self_class), Type::Instance(target_class)) - if self_class.is_known(db, KnownClass::NoneType) && target_class.is_known(db, KnownClass::NoneType)) + ( + Type::Instance(InstanceType { class: self_class, .. }), + Type::Instance(InstanceType { class: target_class, .. }) + ) + if self_class.is_known(db, KnownClass::NoneType) && + target_class.is_known(db, KnownClass::NoneType)) } /// Return true if this type and `other` have no common elements. @@ -670,74 +691,88 @@ impl<'db> Type<'db> { | Type::ClassLiteral(..)), ) => left != right, - (Type::Instance(class_none), Type::Instance(class_other)) - | (Type::Instance(class_other), Type::Instance(class_none)) - if class_none.is_known(db, KnownClass::NoneType) => - { - !matches!( - class_other.known(db), - Some(KnownClass::NoneType | KnownClass::Object) - ) - } - (Type::Instance(class_none), _) | (_, Type::Instance(class_none)) - if class_none.is_known(db, KnownClass::NoneType) => - { - true - } + ( + Type::Instance(InstanceType { + class: class_none, .. + }), + Type::Instance(InstanceType { + class: class_other, .. + }), + ) + | ( + Type::Instance(InstanceType { + class: class_other, .. + }), + Type::Instance(InstanceType { + class: class_none, .. + }), + ) if class_none.is_known(db, KnownClass::NoneType) => !matches!( + class_other.known(db), + Some(KnownClass::NoneType | KnownClass::Object) + ), + ( + Type::Instance(InstanceType { + class: class_none, .. + }), + _, + ) + | ( + _, + Type::Instance(InstanceType { + class: class_none, .. + }), + ) if class_none.is_known(db, KnownClass::NoneType) => true, - (Type::BooleanLiteral(..), Type::Instance(class_type)) - | (Type::Instance(class_type), Type::BooleanLiteral(..)) => !matches!( - class_type.known(db), + (Type::BooleanLiteral(..), Type::Instance(InstanceType { class, .. })) + | (Type::Instance(InstanceType { class, .. 
}), Type::BooleanLiteral(..)) => !matches!( + class.known(db), Some(KnownClass::Bool | KnownClass::Int | KnownClass::Object) ), (Type::BooleanLiteral(..), _) | (_, Type::BooleanLiteral(..)) => true, - (Type::IntLiteral(..), Type::Instance(class_type)) - | (Type::Instance(class_type), Type::IntLiteral(..)) => !matches!( - class_type.known(db), - Some(KnownClass::Int | KnownClass::Object) - ), + (Type::IntLiteral(..), Type::Instance(InstanceType { class, .. })) + | (Type::Instance(InstanceType { class, .. }), Type::IntLiteral(..)) => { + !matches!(class.known(db), Some(KnownClass::Int | KnownClass::Object)) + } (Type::IntLiteral(..), _) | (_, Type::IntLiteral(..)) => true, (Type::StringLiteral(..), Type::LiteralString) | (Type::LiteralString, Type::StringLiteral(..)) => false, - (Type::StringLiteral(..), Type::Instance(class_type)) - | (Type::Instance(class_type), Type::StringLiteral(..)) => !matches!( - class_type.known(db), - Some(KnownClass::Str | KnownClass::Object) - ), + (Type::StringLiteral(..), Type::Instance(InstanceType { class, .. })) + | (Type::Instance(InstanceType { class, .. }), Type::StringLiteral(..)) => { + !matches!(class.known(db), Some(KnownClass::Str | KnownClass::Object)) + } (Type::StringLiteral(..), _) | (_, Type::StringLiteral(..)) => true, (Type::LiteralString, Type::LiteralString) => false, - (Type::LiteralString, Type::Instance(class_type)) - | (Type::Instance(class_type), Type::LiteralString) => !matches!( - class_type.known(db), - Some(KnownClass::Str | KnownClass::Object) - ), + (Type::LiteralString, Type::Instance(InstanceType { class, .. })) + | (Type::Instance(InstanceType { class, .. 
}), Type::LiteralString) => { + !matches!(class.known(db), Some(KnownClass::Str | KnownClass::Object)) + } (Type::LiteralString, _) | (_, Type::LiteralString) => true, - (Type::BytesLiteral(..), Type::Instance(class_type)) - | (Type::Instance(class_type), Type::BytesLiteral(..)) => !matches!( - class_type.known(db), + (Type::BytesLiteral(..), Type::Instance(InstanceType { class, .. })) + | (Type::Instance(InstanceType { class, .. }), Type::BytesLiteral(..)) => !matches!( + class.known(db), Some(KnownClass::Bytes | KnownClass::Object) ), (Type::BytesLiteral(..), _) | (_, Type::BytesLiteral(..)) => true, - (Type::SliceLiteral(..), Type::Instance(class_type)) - | (Type::Instance(class_type), Type::SliceLiteral(..)) => !matches!( - class_type.known(db), + (Type::SliceLiteral(..), Type::Instance(InstanceType { class, .. })) + | (Type::Instance(InstanceType { class, .. }), Type::SliceLiteral(..)) => !matches!( + class.known(db), Some(KnownClass::Slice | KnownClass::Object) ), (Type::SliceLiteral(..), _) | (_, Type::SliceLiteral(..)) => true, ( Type::FunctionLiteral(..) | Type::ModuleLiteral(..) | Type::ClassLiteral(..), - Type::Instance(class_type), + Type::Instance(InstanceType { class, .. }), ) | ( - Type::Instance(class_type), + Type::Instance(InstanceType { class, .. }), Type::FunctionLiteral(..) | Type::ModuleLiteral(..) | Type::ClassLiteral(..), - ) => !class_type.is_known(db, KnownClass::Object), + ) => !class.is_known(db, KnownClass::Object), (Type::Instance(..), Type::Instance(..)) => { // TODO: once we have support for `final`, there might be some cases where @@ -801,7 +836,7 @@ impl<'db> Type<'db> { | Type::FunctionLiteral(..) | Type::ClassLiteral(..) | Type::ModuleLiteral(..) => true, - Type::Instance(class) => { + Type::Instance(InstanceType { class, .. 
}) => { // TODO some more instance types can be singleton types (EllipsisType, NotImplementedType) matches!(class.known(db), Some(KnownClass::NoneType)) } @@ -849,7 +884,7 @@ impl<'db> Type<'db> { .iter() .all(|elem| elem.is_single_valued(db)), - Type::Instance(class_type) => match class_type.known(db) { + Type::Instance(InstanceType { class, .. }) => match class.known(db) { Some(KnownClass::NoneType) => true, Some( KnownClass::Bool @@ -866,7 +901,8 @@ impl<'db> Type<'db> { | KnownClass::Slice | KnownClass::GenericAlias | KnownClass::ModuleType - | KnownClass::FunctionType, + | KnownClass::FunctionType + | KnownClass::SpecialForm, ) => false, None => false, }, @@ -1026,7 +1062,7 @@ impl<'db> Type<'db> { // More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing Truthiness::Ambiguous } - Type::Instance(class) => { + Type::Instance(InstanceType { class, .. }) => { // TODO: lookup `__bool__` and `__len__` methods on the instance's class // More info in https://docs.python.org/3/library/stdtypes.html#truth-value-testing // For now, we only special-case some builtin classes @@ -1091,11 +1127,11 @@ impl<'db> Type<'db> { .first() .map(|arg| arg.bool(db).into_type(db)) .unwrap_or(Type::BooleanLiteral(false)), - _ => Type::Instance(class), + _ => class.to_instance(), }) } - Type::Instance(class) => { + Type::Instance(InstanceType { class, .. }) => { // Since `__call__` is a dunder, we need to access it as an attribute on the class // rather than the instance (matching runtime semantics). 
match class.class_member(db, "__call__") { @@ -1209,7 +1245,7 @@ impl<'db> Type<'db> { Type::Todo => Type::Todo, Type::Unknown => Type::Unknown, Type::Never => Type::Never, - Type::ClassLiteral(class) => Type::Instance(*class), + Type::ClassLiteral(class) => Type::Instance(InstanceType::anonymous(*class)), Type::Union(union) => union.map(db, |element| element.to_instance(db)), // TODO: we can probably do better here: --Alex Type::Intersection(_) => Type::Todo, @@ -1239,7 +1275,7 @@ impl<'db> Type<'db> { pub fn to_meta_type(&self, db: &'db dyn Db) -> Type<'db> { match self { Type::Never => Type::Never, - Type::Instance(class) => Type::ClassLiteral(*class), + Type::Instance(InstanceType { class, .. }) => Type::ClassLiteral(*class), Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)), Type::BooleanLiteral(_) => KnownClass::Bool.to_class(db), Type::BytesLiteral(_) => KnownClass::Bytes.to_class(db), @@ -1339,6 +1375,7 @@ pub enum KnownClass { FunctionType, // Typeshed NoneType, // Part of `types` for Python >= 3.10 + SpecialForm, } impl<'db> KnownClass { @@ -1360,6 +1397,7 @@ impl<'db> KnownClass { Self::ModuleType => "ModuleType", Self::FunctionType => "FunctionType", Self::NoneType => "NoneType", + Self::SpecialForm => "_SpecialForm", } } @@ -1384,7 +1422,7 @@ impl<'db> KnownClass { Self::GenericAlias | Self::ModuleType | Self::FunctionType => { types_symbol(db, self.as_str()).unwrap_or_unknown() } - + Self::SpecialForm => typing_symbol(db, self.as_str()).unwrap_or_unknown(), Self::NoneType => typeshed_symbol(db, self.as_str()).unwrap_or_unknown(), } } @@ -1419,6 +1457,7 @@ impl<'db> KnownClass { "NoneType" => Some(Self::NoneType), "ModuleType" => Some(Self::ModuleType), "FunctionType" => Some(Self::FunctionType), + "_SpecialForm" => Some(Self::SpecialForm), _ => None, } } @@ -1443,6 +1482,46 @@ impl<'db> KnownClass { | Self::Slice => module.name() == "builtins", Self::GenericAlias | Self::ModuleType | Self::FunctionType => module.name() == "types", 
Self::NoneType => matches!(module.name().as_str(), "_typeshed" | "types"), + Self::SpecialForm => { + matches!(module.name().as_str(), "typing" | "typing_extensions") + } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum KnownInstance { + Literal, + // TODO: fill this enum out with more special forms, etc. +} + +impl KnownInstance { + pub const fn as_str(&self) -> &'static str { + match self { + KnownInstance::Literal => "Literal", + } + } + + pub fn maybe_from_module(module: &Module, instance_name: &str) -> Option { + let candidate = Self::from_name(instance_name)?; + candidate.check_module(module).then_some(candidate) + } + + fn from_name(name: &str) -> Option { + match name { + "Literal" => Some(Self::Literal), + _ => None, + } + } + + fn check_module(self, module: &Module) -> bool { + if !module.search_path().is_standard_library() { + return false; + } + match self { + Self::Literal => { + matches!(module.name().as_str(), "typing" | "typing_extensions") + } } } } @@ -1854,6 +1933,10 @@ pub struct ClassType<'db> { #[salsa::tracked] impl<'db> ClassType<'db> { + pub fn to_instance(self) -> Type<'db> { + Type::Instance(InstanceType::anonymous(self)) + } + /// Return `true` if this class represents `known_class` pub fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool { self.known(db) == Some(known_class) @@ -1986,6 +2069,29 @@ fn infer_class_base_type<'db>( } } +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct InstanceType<'db> { + class: ClassType<'db>, + known: Option, +} + +impl<'db> InstanceType<'db> { + pub fn anonymous(class: ClassType<'db>) -> Self { + Self { class, known: None } + } + + pub fn known(class: ClassType<'db>, known: KnownInstance) -> Self { + Self { + class, + known: Some(known), + } + } + + pub fn is_known(&self, known_instance: KnownInstance) -> bool { + self.known == Some(known_instance) + } +} + #[salsa::interned] pub struct UnionType<'db> { /// The union type includes values in any of 
these types. @@ -2095,6 +2201,13 @@ mod tests { use ruff_python_ast as ast; use test_case::test_case; + #[cfg(target_pointer_width = "64")] + #[test] + fn no_bloat_enum_sizes() { + use std::mem::size_of; + assert_eq!(size_of::(), 16); + } + fn setup_db() -> TestDb { let db = TestDb::new(); diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 5853c44e1d..2ff1f9fad1 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -26,8 +26,7 @@ //! eliminate the supertype from the intersection). //! * An intersection containing two non-overlapping types should simplify to [`Type::Never`]. -use super::KnownClass; -use crate::types::{IntersectionType, Type, UnionType}; +use crate::types::{InstanceType, IntersectionType, KnownClass, Type, UnionType}; use crate::{Db, FxOrderSet}; use smallvec::SmallVec; @@ -247,8 +246,8 @@ impl<'db> InnerIntersectionBuilder<'db> { } } else { // ~Literal[True] & bool = Literal[False] - if let Type::Instance(class_type) = new_positive { - if class_type.is_known(db, KnownClass::Bool) { + if let Type::Instance(InstanceType { class, .. 
}) = new_positive { + if class.is_known(db, KnownClass::Bool) { if let Some(&Type::BooleanLiteral(value)) = self .negative .iter() diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index bea1f1646f..e69b6921f5 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -6,7 +6,7 @@ use ruff_db::display::FormatterJoinExtension; use ruff_python_ast::str::Quote; use ruff_python_literal::escape::AsciiEscape; -use crate::types::{IntersectionType, KnownClass, Type, UnionType}; +use crate::types::{InstanceType, IntersectionType, KnownClass, Type, UnionType}; use crate::Db; use rustc_hash::FxHashMap; @@ -64,7 +64,9 @@ impl Display for DisplayRepresentation<'_> { Type::Any => f.write_str("Any"), Type::Never => f.write_str("Never"), Type::Unknown => f.write_str("Unknown"), - Type::Instance(class) if class.is_known(self.db, KnownClass::NoneType) => { + Type::Instance(InstanceType { class, .. 
}) + if class.is_known(self.db, KnownClass::NoneType) => + { f.write_str("None") } // `[Type::Todo]`'s display should be explicit that is not a valid display of @@ -75,7 +77,10 @@ impl Display for DisplayRepresentation<'_> { } // TODO functions and classes should display using a fully qualified name Type::ClassLiteral(class) => f.write_str(class.name(self.db)), - Type::Instance(class) => f.write_str(class.name(self.db)), + Type::Instance(InstanceType { class, known }) => f.write_str(match known { + Some(super::KnownInstance::Literal) => "Literal", + _ => class.name(self.db), + }), Type::FunctionLiteral(function) => f.write_str(function.name(self.db)), Type::Union(union) => union.display(self.db).fmt(f), Type::Intersection(intersection) => intersection.display(self.db).fmt(f), diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 1216a3cc43..c0eac9563b 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -54,9 +54,9 @@ use crate::types::diagnostic::{ use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, typing_extensions_symbol, - Boundness, BytesLiteralType, ClassType, FunctionType, IterationOutcome, KnownClass, - KnownFunction, SliceLiteralType, StringLiteralType, Symbol, Truthiness, TupleType, Type, - TypeArrayDisplay, UnionBuilder, UnionType, + Boundness, BytesLiteralType, ClassType, FunctionType, InstanceType, IterationOutcome, + KnownClass, KnownFunction, KnownInstance, SliceLiteralType, StringLiteralType, Symbol, + Truthiness, TupleType, Type, TypeArrayDisplay, UnionBuilder, UnionType, }; use crate::unpack::Unpack; use crate::util::subscript::{PyIndex, PySlice}; @@ -611,10 +611,10 @@ impl<'db> TypeInferenceBuilder<'db> { fn check_division_by_zero(&mut self, expr: &ast::ExprBinOp, left: Type<'db>) { match left { 
Type::BooleanLiteral(_) | Type::IntLiteral(_) => {} - Type::Instance(cls) + Type::Instance(InstanceType { class, .. }) if [KnownClass::Float, KnownClass::Int, KnownClass::Bool] .iter() - .any(|&k| cls.is_known(self.db, k)) => {} + .any(|&k| class.is_known(self.db, k)) => {} _ => return, }; @@ -959,7 +959,7 @@ impl<'db> TypeInferenceBuilder<'db> { .node_scope(NodeWithScopeRef::Class(class)) .to_scope_id(self.db, self.file); - let maybe_known_class = file_to_module(self.db, body_scope.file(self.db)) + let maybe_known_class = file_to_module(self.db, self.file) .as_ref() .and_then(|module| KnownClass::maybe_from_module(module, name.as_str())); @@ -1273,12 +1273,12 @@ impl<'db> TypeInferenceBuilder<'db> { // anything else is invalid and should lead to a diagnostic being reported --Alex match node_ty { Type::Any | Type::Unknown => node_ty, - Type::ClassLiteral(class_ty) => Type::Instance(class_ty), + Type::ClassLiteral(class_ty) => class_ty.to_instance(), Type::Tuple(tuple) => UnionType::from_elements( self.db, tuple.elements(self.db).iter().map(|ty| { ty.into_class_literal_type() - .map_or(Type::Todo, Type::Instance) + .map_or(Type::Todo, ClassType::to_instance) }), ), _ => Type::Todo, @@ -1472,7 +1472,23 @@ impl<'db> TypeInferenceBuilder<'db> { simple: _, } = assignment; - let annotation_ty = self.infer_annotation_expression(annotation); + let mut annotation_ty = self.infer_annotation_expression(annotation); + + // If the declared variable is annotated with _SpecialForm class then we treat it differently + // by assigning the known field to the instance. + if let Type::Instance(InstanceType { class, .. 
}) = annotation_ty { + if class.is_known(self.db, KnownClass::SpecialForm) { + if let Some(name_expr) = target.as_name_expr() { + let maybe_known_instance = file_to_module(self.db, self.file) + .as_ref() + .and_then(|module| KnownInstance::maybe_from_module(module, &name_expr.id)); + if let Some(known_instance) = maybe_known_instance { + annotation_ty = Type::Instance(InstanceType::known(class, known_instance)); + } + } + } + } + if let Some(value) = value { let value_ty = self.infer_expression(value); self.add_declaration_with_binding( @@ -1512,7 +1528,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_augmented_op(assignment, target_type, value_type) }) } - Type::Instance(class) => { + Type::Instance(InstanceType { class, .. }) => { if let Symbol::Type(class_member, boundness) = class.class_member(self.db, op.in_place_dunder()) { @@ -2655,7 +2671,10 @@ impl<'db> TypeInferenceBuilder<'db> { (UnaryOp::Not, ty) => ty.bool(self.db).negate().into_type(self.db), (_, Type::Any) => Type::Any, (_, Type::Unknown) => Type::Unknown, - (op @ (UnaryOp::UAdd | UnaryOp::USub | UnaryOp::Invert), Type::Instance(class)) => { + ( + op @ (UnaryOp::UAdd | UnaryOp::USub | UnaryOp::Invert), + Type::Instance(InstanceType { class, .. }), + ) => { let unary_dunder_method = match op { UnaryOp::Invert => "__invert__", UnaryOp::UAdd => "__pos__", @@ -2901,7 +2920,15 @@ impl<'db> TypeInferenceBuilder<'db> { op, ), - (Type::Instance(left_class), Type::Instance(right_class), op) => { + ( + Type::Instance(InstanceType { + class: left_class, .. + }), + Type::Instance(InstanceType { + class: right_class, .. 
+ }), + op, + ) => { if left_class != right_class && right_class.is_subclass_of(self.db, left_class) { let reflected_dunder = op.reflected_dunder(); let rhs_reflected = right_class.class_member(self.db, reflected_dunder); @@ -2949,11 +2976,9 @@ impl<'db> TypeInferenceBuilder<'db> { }) } - ( - Type::BooleanLiteral(b1), - Type::BooleanLiteral(b2), - ruff_python_ast::Operator::BitOr, - ) => Some(Type::BooleanLiteral(b1 | b2)), + (Type::BooleanLiteral(b1), Type::BooleanLiteral(b2), ast::Operator::BitOr) => { + Some(Type::BooleanLiteral(b1 | b2)) + } (Type::BooleanLiteral(bool_value), right, op) => self.infer_binary_expression_type( Type::IntLiteral(i64::from(bool_value)), @@ -3312,7 +3337,14 @@ impl<'db> TypeInferenceBuilder<'db> { } // Lookup the rich comparison `__dunder__` methods on instances - (Type::Instance(left_class), Type::Instance(right_class)) => { + ( + Type::Instance(InstanceType { + class: left_class, .. + }), + Type::Instance(InstanceType { + class: right_class, .. + }), + ) => { let rich_comparison = |op| perform_rich_comparison(self.db, left_class, right_class, op); let membership_test_comparison = @@ -3653,7 +3685,9 @@ impl<'db> TypeInferenceBuilder<'db> { Err(_) => SliceArg::Unsupported, }, Some(Type::BooleanLiteral(b)) => SliceArg::Arg(Some(i32::from(b))), - Some(Type::Instance(class)) if class.is_known(self.db, KnownClass::NoneType) => { + Some(Type::Instance(InstanceType { class, .. })) + if class.is_known(self.db, KnownClass::NoneType) => + { SliceArg::Arg(None) } None => SliceArg::Arg(None), @@ -3744,8 +3778,6 @@ impl<'db> TypeInferenceBuilder<'db> { impl<'db> TypeInferenceBuilder<'db> { fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> { // https://typing.readthedocs.io/en/latest/spec/annotations.html#grammar-token-expression-grammar-type_expression - // TODO: this does not include any of the special forms, and is only a - // stub of the forms other than a standalone name in scope. 
let ty = match expression { ast::Expr::Name(name) => { @@ -3792,9 +3824,7 @@ impl<'db> TypeInferenceBuilder<'db> { { self.infer_tuple_type_expression(slice) } else { - self.infer_type_expression(slice); - // TODO: many other kinds of subscripts - Type::Todo + self.infer_subscript_type_expression(subscript, value_ty) } } @@ -3966,6 +3996,121 @@ impl<'db> TypeInferenceBuilder<'db> { } } } + + fn infer_subscript_type_expression( + &mut self, + subscript: &ast::ExprSubscript, + value_ty: Type<'db>, + ) -> Type<'db> { + let ast::ExprSubscript { + range: _, + value: _, + slice, + ctx: _, + } = subscript; + + match value_ty { + Type::Instance(InstanceType { + class: _, + known: Some(known_instance), + }) => self.infer_parameterized_known_instance_type_expression(known_instance, slice), + _ => { + self.infer_type_expression(slice); + Type::Todo // TODO: generics + } + } + } + + fn infer_parameterized_known_instance_type_expression( + &mut self, + known_instance: KnownInstance, + parameters: &ast::Expr, + ) -> Type<'db> { + match known_instance { + KnownInstance::Literal => match self.infer_literal_parameter_type(parameters) { + Ok(ty) => ty, + Err(nodes) => { + for node in nodes { + self.diagnostics.add( + node.into(), + "invalid-literal-parameter", + format_args!( + "Type arguments for `Literal` must be `None`, \ + a literal value (int, bool, str, or bytes), or an enum value" + ), + ); + } + Type::Unknown + } + }, + } + } + + fn infer_literal_parameter_type<'ast>( + &mut self, + parameters: &'ast ast::Expr, + ) -> Result, Vec<&'ast ast::Expr>> { + Ok(match parameters { + // TODO handle type aliases + ast::Expr::Subscript(ast::ExprSubscript { value, slice, .. }) => { + let value_ty = self.infer_expression(value); + if matches!( + value_ty, + Type::Instance(InstanceType { + known: Some(KnownInstance::Literal), + .. + }) + ) { + self.infer_literal_parameter_type(slice)? 
+ } else { + return Err(vec![parameters]); + } + } + ast::Expr::Tuple(tuple) if !tuple.parenthesized => { + let mut errors = vec![]; + let mut builder = UnionBuilder::new(self.db); + for elt in tuple { + match self.infer_literal_parameter_type(elt) { + Ok(ty) => { + builder = builder.add(ty); + } + Err(nodes) => { + errors.extend(nodes); + } + } + } + if errors.is_empty() { + builder.build() + } else { + return Err(errors); + } + } + + ast::Expr::StringLiteral(literal) => self.infer_string_literal_expression(literal), + ast::Expr::BytesLiteral(literal) => self.infer_bytes_literal_expression(literal), + ast::Expr::BooleanLiteral(literal) => self.infer_boolean_literal_expression(literal), + // For enum values + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { + let value_ty = self.infer_expression(value); + // TODO: Check that value type is enum otherwise return None + value_ty.member(self.db, &attr.id).unwrap_or_unknown() + } + ast::Expr::NoneLiteral(_) => Type::none(self.db), + // for negative and positive numbers + ast::Expr::UnaryOp(ref u) + if matches!(u.op, UnaryOp::USub | UnaryOp::UAdd) + && u.operand.is_number_literal_expr() => + { + self.infer_unary_expression(u) + } + ast::Expr::NumberLiteral(ref number) if number.value.is_int() => { + self.infer_number_literal_expression(number) + } + _ => { + return Err(vec![parameters]); + } + }) + } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -4131,10 +4276,7 @@ fn perform_rich_comparison<'db>( |op: RichCompareOperator, left_class: ClassType<'db>, right_class: ClassType<'db>| { match left_class.class_member(db, op.dunder()) { Symbol::Type(class_member_dunder, Boundness::Bound) => class_member_dunder - .call( - db, - &[Type::Instance(left_class), Type::Instance(right_class)], - ) + .call(db, &[left_class.to_instance(), right_class.to_instance()]) .return_ty(db), _ => None, } @@ -4160,8 +4302,8 @@ fn perform_rich_comparison<'db>( }) .ok_or_else(|| CompareUnsupportedError { op: op.into(), - left_ty: 
Type::Instance(left_class), - right_ty: Type::Instance(right_class), + left_ty: left_class.to_instance(), + right_ty: right_class.to_instance(), }) } @@ -4175,7 +4317,7 @@ fn perform_membership_test_comparison<'db>( right_class: ClassType<'db>, op: MembershipTestCompareOperator, ) -> Result, CompareUnsupportedError<'db>> { - let (left_instance, right_instance) = (Type::Instance(left_class), Type::Instance(right_class)); + let (left_instance, right_instance) = (left_class.to_instance(), right_class.to_instance()); let contains_dunder = right_class.class_member(db, "__contains__"); let compare_result_opt = match contains_dunder { diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index e088ab6d28..235ed1e3c3 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -88,7 +88,7 @@ fn generate_isinstance_constraint<'db>( classinfo: &Type<'db>, ) -> Option> { match classinfo { - Type::ClassLiteral(class) => Some(Type::Instance(*class)), + Type::ClassLiteral(class) => Some(class.to_instance()), Type::Tuple(tuple) => { let mut builder = UnionBuilder::new(db); for element in tuple.elements(db) {