From 1ef2f73820271c04b60f9f764e1a7ce0facc0501 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 28 Jan 2025 18:32:53 +0000 Subject: [PATCH] any better --- crates/red_knot_python_semantic/src/types.rs | 42 +----- .../src/types/builder.rs | 140 ++++++++++++------ 2 files changed, 100 insertions(+), 82 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 2b87e14b8c..e6e2f9eea7 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -811,29 +811,6 @@ impl<'db> Type<'db> { } } - /// Normalize the type `bool` -> `Literal[True, False]`. - /// - /// Using this method in various type-relational methods - /// ensures that the following invariants hold true: - /// - /// - bool ≡ Literal[True, False] - /// - bool | T ≡ Literal[True, False] | T - /// - bool <: Literal[True, False] - /// - bool | T <: Literal[True, False] | T - /// - Literal[True, False] <: bool - /// - Literal[True, False] | T <: bool | T - #[must_use] - pub fn with_normalized_bools(self, db: &'db dyn Db) -> Self { - match self { - Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => { - Type::normalized_bool(db) - } - // TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`? - // We'd need to rename this method... --Alex - _ => self, - } - } - /// Return a normalized version of `self` in which all unions and intersections are sorted /// according to a canonical order, no matter how "deeply" a union/intersection may be nested. #[must_use] @@ -905,10 +882,12 @@ impl<'db> Type<'db> { (_, Type::Never) => false, (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { - Type::normalized_bool(db).is_subtype_of(db, target) + Type::BooleanLiteral(true).is_subtype_of(db, target) + && Type::BooleanLiteral(false).is_subtype_of(db, target) } (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { - self.is_boolean_literal() + self.is_subtype_of(db, Type::BooleanLiteral(true)) + || self.is_subtype_of(db, Type::BooleanLiteral(false)) } (Type::Union(union), _) => union @@ -1125,10 +1104,12 @@ impl<'db> Type<'db> { } (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { - Type::normalized_bool(db).is_assignable_to(db, target) + Type::BooleanLiteral(true).is_assignable_to(db, target) + && Type::BooleanLiteral(false).is_assignable_to(db, target) } (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { - self.is_assignable_to(db, Type::normalized_bool(db)) + self.is_assignable_to(db, Type::BooleanLiteral(false)) + || self.is_assignable_to(db, Type::BooleanLiteral(true)) } // A union is assignable to a type T iff every element of the union is assignable to T. @@ -2409,13 +2390,6 @@ impl<'db> Type<'db> { KnownClass::NoneType.to_instance(db) } - /// The type `Literal[True, False]`, which is exactly equivalent to `bool` - /// (and which `bool` is eagerly normalized to in several situations) - pub fn normalized_bool(db: &'db dyn Db) -> Type<'db> { - const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)]; - Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS))) - } - /// Return the type of `tuple(sys.version_info)`. /// /// This is not exactly the type that `sys.version_info` has at runtime, diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 49ed137a96..a9430f37a1 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -26,14 +26,14 @@ //! eliminate the supertype from the intersection). //! * An intersection containing two non-overlapping types should simplify to [`Type::Never`]. -use crate::types::{IntersectionType, KnownClass, Type, UnionType}; +use crate::types::{InstanceType, IntersectionType, KnownClass, Type, UnionType}; use crate::{Db, FxOrderSet}; use smallvec::SmallVec; pub(crate) struct UnionBuilder<'db> { elements: Vec>, db: &'db dyn Db, - contains_bool_literals: bool, + bool_literals_present: BoolLiteralsPresent, } impl<'db> UnionBuilder<'db> { @@ -41,13 +41,12 @@ impl<'db> UnionBuilder<'db> { Self { db, elements: vec![], - contains_bool_literals: false, + bool_literals_present: BoolLiteralsPresent::Zero, } } /// Adds a type to this union. pub(crate) fn add(mut self, ty: Type<'db>) -> Self { - let ty = ty.with_normalized_bools(self.db); match ty { Type::Union(union) => { let new_elements = union.elements(self.db); @@ -57,6 +56,11 @@ impl<'db> UnionBuilder<'db> { } } Type::Never => {} + Type::Instance(InstanceType { class }) if class.is_known(self.db, KnownClass::Bool) => { + self = self + .add(Type::BooleanLiteral(false)) + .add(Type::BooleanLiteral(true)); + } _ => { let mut to_remove = SmallVec::<[usize; 2]>::new(); let ty_negated = ty.negate(self.db); @@ -79,7 +83,11 @@ impl<'db> UnionBuilder<'db> { return self; } } - self.contains_bool_literals |= ty.is_boolean_literal(); + + if ty.is_boolean_literal() { + self.bool_literals_present.increment(); + } + match to_remove[..] { [] => self.elements.push(ty), [index] => self.elements[index] = ty, @@ -109,14 +117,14 @@ impl<'db> UnionBuilder<'db> { let UnionBuilder { mut elements, db, - contains_bool_literals, + bool_literals_present, } = self; match elements.len() { 0 => Type::Never, 1 => elements[0], _ => { - if contains_bool_literals { + if bool_literals_present.is_two() { let mut element_iter = elements.iter(); if let Some(first_pos) = element_iter.position(Type::is_boolean_literal) { if let Some(second_pos) = element_iter.position(Type::is_boolean_literal) { @@ -135,6 +143,27 @@ impl<'db> UnionBuilder<'db> { } } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum BoolLiteralsPresent { + Zero, + One, + Two, +} + +impl BoolLiteralsPresent { + fn increment(&mut self) { + *self = match self { + BoolLiteralsPresent::Zero => BoolLiteralsPresent::One, + BoolLiteralsPresent::One => BoolLiteralsPresent::Two, + BoolLiteralsPresent::Two => BoolLiteralsPresent::Two, + }; + } + + const fn is_two(self) -> bool { + matches!(self, BoolLiteralsPresent::Two) + } +} + #[derive(Clone)] pub(crate) struct IntersectionBuilder<'db> { // Really this builds a union-of-intersections, because we always keep our set-theoretic types @@ -162,8 +191,18 @@ impl<'db> IntersectionBuilder<'db> { } pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self { - let ty = ty.with_normalized_bools(self.db); - if let Type::Union(union) = ty { + const BOOL_LITERALS: &[Type] = &[Type::BooleanLiteral(false), Type::BooleanLiteral(true)]; + + // Treat `bool` as `Literal[True] | Literal[False]` + let union_elements = match ty { + Type::Union(union) => Some(union.elements(self.db)), + Type::Instance(InstanceType { class }) if class.is_known(self.db, KnownClass::Bool) => { + Some(BOOL_LITERALS) + } + _ => None, + }; + + if let Some(elements) = union_elements { // Distribute ourself over this union: for each union element, clone ourself and // intersect with that union element, then create a new union-of-intersections with all // of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2` @@ -172,8 +211,7 @@ impl<'db> IntersectionBuilder<'db> { // (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)` // and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 & // T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea. - union - .elements(self.db) + elements .iter() .map(|elem| self.clone().add_positive(*elem)) .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { @@ -193,45 +231,51 @@ impl<'db> IntersectionBuilder<'db> { pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self { // See comments above in `add_positive`; this is just the negated version. - let ty = ty.with_normalized_bools(self.db); - - if let Type::Union(union) = ty { - for elem in union.elements(self.db) { - self = self.add_negative(*elem); + match ty { + Type::Union(union) => { + for elem in union.elements(self.db) { + self = self.add_negative(*elem); + } + self } - self - } else if let Type::Intersection(intersection) = ty { - // (A | B) & ~(C & ~D) - // -> (A | B) & (~C | D) - // -> ((A | B) & ~C) | ((A | B) & D) - // i.e. if we have an intersection of positive constraints C - // and negative constraints D, then our new intersection - // is (existing & ~C) | (existing & D) - - let positive_side = intersection - .positive(self.db) - .iter() - // we negate all the positive constraints while distributing - .map(|elem| self.clone().add_negative(*elem)); - - let negative_side = intersection - .negative(self.db) - .iter() - // all negative constraints end up becoming positive constraints - .map(|elem| self.clone().add_positive(*elem)); - - positive_side.chain(negative_side).fold( - IntersectionBuilder::empty(self.db), - |mut builder, sub| { - builder.intersections.extend(sub.intersections); - builder - }, - ) - } else { - for inner in &mut self.intersections { - inner.add_negative(self.db, ty); + Type::Instance(InstanceType { class }) if class.is_known(self.db, KnownClass::Bool) => { + self.add_negative(Type::BooleanLiteral(false)) + .add_negative(Type::BooleanLiteral(true)) + } + Type::Intersection(intersection) => { + // (A | B) & ~(C & ~D) + // -> (A | B) & (~C | D) + // -> ((A | B) & ~C) | ((A | B) & D) + // i.e. if we have an intersection of positive constraints C + // and negative constraints D, then our new intersection + // is (existing & ~C) | (existing & D) + + let positive_side = intersection + .positive(self.db) + .iter() + // we negate all the positive constraints while distributing + .map(|elem| self.clone().add_negative(*elem)); + + let negative_side = intersection + .negative(self.db) + .iter() + // all negative constraints end up becoming positive constraints + .map(|elem| self.clone().add_positive(*elem)); + + positive_side.chain(negative_side).fold( + IntersectionBuilder::empty(self.db), + |mut builder, sub| { + builder.intersections.extend(sub.intersections); + builder + }, + ) + } + _ => { + for inner in &mut self.intersections { + inner.add_negative(self.db, ty); + } + self } - self } }