From e50fc049abaa563edb4a05551222d7329051f7b7 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Thu, 3 Apr 2025 07:15:33 -0400 Subject: [PATCH] [red-knot] visibility_constraint analysis for match cases (#17077) ## Summary Add visibility constraint analysis for pattern predicate kinds `Singleton`, `Or`, and `Class`. ## Test Plan update conditional/match.md --- .../resources/mdtest/conditional/match.md | 234 ++++++++++++++++++ .../semantic_index/visibility_constraints.rs | 120 +++++++-- 2 files changed, 327 insertions(+), 27 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/conditional/match.md b/crates/red_knot_python_semantic/resources/mdtest/conditional/match.md index 2f0ad24fe1..5037484212 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/conditional/match.md +++ b/crates/red_knot_python_semantic/resources/mdtest/conditional/match.md @@ -44,6 +44,240 @@ def _(target: int): reveal_type(y) # revealed: Literal[2, 3, 4] ``` +## Value match + +A value pattern matches based on equality: the first `case` branch here will be taken if `subject` +is equal to `2`, even if `subject` is not an instance of `int`. We can't know whether `C` here has a +custom `__eq__` implementation that might cause it to compare equal to `2`, so we have to consider +the possibility that the `case` branch might be taken even though the type `C` is disjoint from the +type `Literal[2]`. + +This leads us to infer `Literal[1, 3]` as the type of `y` after the `match` statement, rather than +`Literal[1]`: + +```py +from typing import final + +@final +class C: + pass + +def _(subject: C): + y = 1 + match subject: + case 2: + y = 3 + reveal_type(y) # revealed: Literal[1, 3] +``` + +## Class match + +A `case` branch with a class pattern is taken if the subject is an instance of the given class, and +all subpatterns in the class pattern match. + +```py +from typing import final + +class Foo: + pass + +class FooSub(Foo): + pass + +class Bar: + pass + +@final +class Baz: + pass + +def _(target: FooSub): + y = 1 + + match target: + case Baz(): + y = 2 + case Foo(): + y = 3 + case Bar(): + y = 4 + + reveal_type(y) # revealed: Literal[3] + +def _(target: FooSub): + y = 1 + + match target: + case Baz(): + y = 2 + case Bar(): + y = 3 + case Foo(): + y = 4 + + reveal_type(y) # revealed: Literal[3, 4] + +def _(target: FooSub | str): + y = 1 + + match target: + case Baz(): + y = 2 + case Foo(): + y = 3 + case Bar(): + y = 4 + + reveal_type(y) # revealed: Literal[1, 3, 4] +``` + +## Singleton match + +Singleton patterns are matched based on identity, not equality comparisons or `isinstance()` checks. + +```py +from typing import Literal + +def _(target: Literal[True, False]): + y = 1 + + match target: + case True: + y = 2 + case False: + y = 3 + case None: + y = 4 + + # TODO: with exhaustiveness checking, this should be Literal[2, 3] + reveal_type(y) # revealed: Literal[1, 2, 3] + +def _(target: bool): + y = 1 + + match target: + case True: + y = 2 + case False: + y = 3 + case None: + y = 4 + + # TODO: with exhaustiveness checking, this should be Literal[2, 3] + reveal_type(y) # revealed: Literal[1, 2, 3] + +def _(target: None): + y = 1 + + match target: + case True: + y = 2 + case False: + y = 3 + case None: + y = 4 + + reveal_type(y) # revealed: Literal[4] + +def _(target: None | Literal[True]): + y = 1 + + match target: + case True: + y = 2 + case False: + y = 3 + case None: + y = 4 + + # TODO: with exhaustiveness checking, this should be Literal[2, 4] + reveal_type(y) # revealed: Literal[1, 2, 4] + +# bool is an int subclass +def _(target: int): + y = 1 + + match target: + case True: + y = 2 + case False: + y = 3 + case None: + y = 4 + + reveal_type(y) # revealed: Literal[1, 2, 3] + +def _(target: str): + y = 1 + + match target: + case True: + y = 2 + case False: + y = 3 + case None: + y = 4 + + reveal_type(y) # revealed: Literal[1] +``` + +## Or match + +A `|` pattern matches if any of the subpatterns match. + +```py +from typing import Literal, final + +def _(target: Literal["foo", "baz"]): + y = 1 + + match target: + case "foo" | "bar": + y = 2 + case "baz": + y = 3 + + # TODO: with exhaustiveness, this should be Literal[2, 3] + reveal_type(y) # revealed: Literal[1, 2, 3] + +def _(target: None): + y = 1 + + match target: + case None | 3: + y = 2 + case "foo" | 4 | True: + y = 3 + + reveal_type(y) # revealed: Literal[2] + +@final +class Baz: + pass + +def _(target: int | None | float): + y = 1 + + match target: + case None | 3: + y = 2 + case Baz(): + y = 3 + + reveal_type(y) # revealed: Literal[1, 2] + +def _(target: None | str): + y = 1 + + match target: + case Baz() | True | False: + y = 2 + case int(): + y = 3 + + reveal_type(y) # revealed: Literal[1, 3] +``` + ## Guard with object that implements `__bool__` incorrectly ```py diff --git a/crates/red_knot_python_semantic/src/semantic_index/visibility_constraints.rs b/crates/red_knot_python_semantic/src/semantic_index/visibility_constraints.rs index 80970756d5..2a428d4cdf 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/visibility_constraints.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/visibility_constraints.rs @@ -178,10 +178,11 @@ use std::cmp::Ordering; use ruff_index::{Idx, IndexVec}; use rustc_hash::FxHashMap; +use crate::semantic_index::expression::Expression; use crate::semantic_index::predicate::{ - PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId, + PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId, }; -use crate::types::{infer_expression_type, Truthiness}; +use crate::types::{infer_expression_type, Truthiness, Type}; use crate::Db; /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula @@ -553,37 +554,102 @@ impl VisibilityConstraints { } } + fn analyze_single_pattern_predicate_kind<'db>( + db: &'db dyn Db, + predicate_kind: &PatternPredicateKind<'db>, + subject: Expression<'db>, + ) -> Truthiness { + match predicate_kind { + PatternPredicateKind::Value(value) => { + let subject_ty = infer_expression_type(db, subject); + let value_ty = infer_expression_type(db, *value); + + if subject_ty.is_single_valued(db) { + Truthiness::from(subject_ty.is_equivalent_to(db, value_ty)) + } else { + Truthiness::Ambiguous + } + } + PatternPredicateKind::Singleton(singleton) => { + let subject_ty = infer_expression_type(db, subject); + + let singleton_ty = match singleton { + ruff_python_ast::Singleton::None => Type::none(db), + ruff_python_ast::Singleton::True => Type::BooleanLiteral(true), + ruff_python_ast::Singleton::False => Type::BooleanLiteral(false), + }; + + debug_assert!(singleton_ty.is_singleton(db)); + + if subject_ty.is_equivalent_to(db, singleton_ty) { + Truthiness::AlwaysTrue + } else if subject_ty.is_disjoint_from(db, singleton_ty) { + Truthiness::AlwaysFalse + } else { + Truthiness::Ambiguous + } + } + PatternPredicateKind::Or(predicates) => { + use std::ops::ControlFlow; + let (ControlFlow::Break(truthiness) | ControlFlow::Continue(truthiness)) = + predicates + .iter() + .map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject)) + // this is just a "max", but with a slight optimization: `AlwaysTrue` is the "greatest" possible element, so we short-circuit if we get there + .try_fold(Truthiness::AlwaysFalse, |acc, next| match (acc, next) { + (Truthiness::AlwaysTrue, _) | (_, Truthiness::AlwaysTrue) => { + ControlFlow::Break(Truthiness::AlwaysTrue) + } + (Truthiness::Ambiguous, _) | (_, Truthiness::Ambiguous) => { + ControlFlow::Continue(Truthiness::Ambiguous) + } + (Truthiness::AlwaysFalse, Truthiness::AlwaysFalse) => { + ControlFlow::Continue(Truthiness::AlwaysFalse) + } + }); + truthiness + } + PatternPredicateKind::Class(class_expr) => { + let subject_ty = infer_expression_type(db, subject); + let class_ty = infer_expression_type(db, *class_expr).to_instance(db); + + class_ty.map_or(Truthiness::Ambiguous, |class_ty| { + if subject_ty.is_subtype_of(db, class_ty) { + Truthiness::AlwaysTrue + } else if subject_ty.is_disjoint_from(db, class_ty) { + Truthiness::AlwaysFalse + } else { + Truthiness::Ambiguous + } + }) + } + PatternPredicateKind::Unsupported => Truthiness::Ambiguous, + } + } + + fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness { + let truthiness = Self::analyze_single_pattern_predicate_kind( + db, + predicate.kind(db), + predicate.subject(db), + ); + + if truthiness == Truthiness::AlwaysTrue && predicate.guard(db).is_some() { + // Fall back to ambiguous, the guard might change the result. + // TODO: actually analyze guard truthiness + Truthiness::Ambiguous + } else { + truthiness + } + } + fn analyze_single(db: &dyn Db, predicate: &Predicate) -> Truthiness { match predicate.node { PredicateNode::Expression(test_expr) => { let ty = infer_expression_type(db, test_expr); ty.bool(db).negate_if(!predicate.is_positive) } - PredicateNode::Pattern(inner) => match inner.kind(db) { - PatternPredicateKind::Value(value) => { - let subject_expression = inner.subject(db); - let subject_ty = infer_expression_type(db, subject_expression); - let value_ty = infer_expression_type(db, *value); - - if subject_ty.is_single_valued(db) { - let truthiness = - Truthiness::from(subject_ty.is_equivalent_to(db, value_ty)); - - if truthiness.is_always_true() && inner.guard(db).is_some() { - // Fall back to ambiguous, the guard might change the result. - Truthiness::Ambiguous - } else { - truthiness - } - } else { - Truthiness::Ambiguous - } - } - PatternPredicateKind::Singleton(..) - | PatternPredicateKind::Class(..) - | PatternPredicateKind::Or(..) - | PatternPredicateKind::Unsupported => Truthiness::Ambiguous, - }, + PredicateNode::Pattern(inner) => Self::analyze_single_pattern_predicate(db, inner), } } }