diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md index da5678f310..c13b7dd4f5 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/match.md @@ -16,3 +16,48 @@ def _(flag: bool): reveal_type(y) # revealed: Literal[0] | None ``` + +## Class patterns + +```py +def get_object() -> object: ... + +class A: ... +class B: ... + +x = get_object() + +reveal_type(x) # revealed: object + +match x: + case A(): + reveal_type(x) # revealed: A + case B(): + # TODO could be `B & ~A` + reveal_type(x) # revealed: B + +reveal_type(x) # revealed: object +``` + +## Class pattern with guard + +```py +def get_object() -> object: ... + +class A: + def y() -> int: ... + +class B: ... + +x = get_object() + +reveal_type(x) # revealed: object + +match x: + case A() if reveal_type(x): # revealed: A + pass + case B() if reveal_type(x): # revealed: B + pass + +reveal_type(x) # revealed: object +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 91ce783d69..3e7140d56a 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -404,6 +404,17 @@ impl<'db> SemanticIndexBuilder<'db> { pattern: &ast::Pattern, guard: Option<&ast::Expr>, ) -> Constraint<'db> { + // This is called for the top-level pattern of each match arm. We need to create a + // standalone expression for each arm of a match statement, since they can introduce + // constraints on the match subject. (Or more accurately, for the match arm's pattern, + // since its the pattern that introduces any constraints, not the body.) Ideally, that + // standalone expression would wrap the match arm's pattern as a whole. But a standalone + // expression can currently only wrap an ast::Expr, which patterns are not. So, we need to + // choose an Expr that can “stand in” for the pattern, which we can wrap in a standalone + // expression. + // + // See the comment in TypeInferenceBuilder::infer_match_pattern for more details. + let guard = guard.map(|guard| self.add_standalone_expression(guard)); let kind = match pattern { @@ -414,6 +425,10 @@ impl<'db> SemanticIndexBuilder<'db> { ast::Pattern::MatchSingleton(singleton) => { PatternConstraintKind::Singleton(singleton.value, guard) } + ast::Pattern::MatchClass(pattern) => { + let cls = self.add_standalone_expression(&pattern.cls); + PatternConstraintKind::Class(cls, guard) + } _ => PatternConstraintKind::Unsupported, }; @@ -1089,37 +1104,35 @@ where cases, range: _, }) => { + debug_assert_eq!(self.current_match_case, None); + let subject_expr = self.add_standalone_expression(subject); self.visit_expr(subject); - - let after_subject = self.flow_snapshot(); - let Some((first, remaining)) = cases.split_first() else { + if cases.is_empty() { return; }; - let first_constraint_id = self.add_pattern_constraint( - subject_expr, - &first.pattern, - first.guard.as_deref(), - ); - - self.visit_match_case(first); - - let first_vis_constraint_id = - self.record_visibility_constraint(first_constraint_id); - let mut vis_constraints = vec![first_vis_constraint_id]; - + let after_subject = self.flow_snapshot(); + let mut vis_constraints = vec![]; let mut post_case_snapshots = vec![]; - for case in remaining { - post_case_snapshots.push(self.flow_snapshot()); - self.flow_restore(after_subject.clone()); + for (i, case) in cases.iter().enumerate() { + if i != 0 { + post_case_snapshots.push(self.flow_snapshot()); + self.flow_restore(after_subject.clone()); + } + + self.current_match_case = Some(CurrentMatchCase::new(&case.pattern)); + self.visit_pattern(&case.pattern); + self.current_match_case = None; let constraint_id = self.add_pattern_constraint( subject_expr, &case.pattern, case.guard.as_deref(), ); - self.visit_match_case(case); - + if let Some(expr) = &case.guard { + self.visit_expr(expr); + } + self.visit_body(&case.body); for id in &vis_constraints { self.record_negated_visibility_constraint(*id); } @@ -1538,18 +1551,6 @@ where } } - fn visit_match_case(&mut self, match_case: &'ast ast::MatchCase) { - debug_assert!(self.current_match_case.is_none()); - self.current_match_case = Some(CurrentMatchCase::new(&match_case.pattern)); - self.visit_pattern(&match_case.pattern); - self.current_match_case = None; - - if let Some(expr) = &match_case.guard { - self.visit_expr(expr); - } - self.visit_body(&match_case.body); - } - fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) { if let ast::Pattern::MatchStar(ast::PatternMatchStar { name: Some(name), @@ -1636,6 +1637,7 @@ impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> { } } +#[derive(Debug, PartialEq)] struct CurrentMatchCase<'a> { /// The pattern that's part of the current match case. pattern: &'a ast::Pattern, diff --git a/crates/red_knot_python_semantic/src/semantic_index/constraint.rs b/crates/red_knot_python_semantic/src/semantic_index/constraint.rs index 3e7a8fc668..d99eb44b85 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/constraint.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/constraint.rs @@ -22,6 +22,7 @@ pub(crate) enum ConstraintNode<'db> { pub(crate) enum PatternConstraintKind<'db> { Singleton(Singleton, Option>), Value(Expression<'db>, Option>), + Class(Expression<'db>, Option>), Unsupported, } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 48d5e8cd3f..3eb4f72c1d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1780,26 +1780,62 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_match_pattern(&mut self, pattern: &ast::Pattern) { + // We need to create a standalone expression for each arm of a match statement, since they + // can introduce constraints on the match subject. (Or more accurately, for the match arm's + // pattern, since its the pattern that introduces any constraints, not the body.) Ideally, + // that standalone expression would wrap the match arm's pattern as a whole. But a + // standalone expression can currently only wrap an ast::Expr, which patterns are not. So, + // we need to choose an Expr that can “stand in” for the pattern, which we can wrap in a + // standalone expression. + // + // That said, when inferring the type of a standalone expression, we don't have access to + // its parent or sibling nodes. That means, for instance, that in a class pattern, where + // we are currently using the class name as the standalone expression, we do not have + // access to the class pattern's arguments in the standalone expression inference scope. + // At the moment, we aren't trying to do anything with those arguments when creating a + // narrowing constraint for the pattern. But in the future, if we do, we will have to + // either wrap those arguments in their own standalone expressions, or update Expression to + // be able to wrap other AST node types besides just ast::Expr. + // + // This function is only called for the top-level pattern of a match arm, and is + // responsible for inferring the standalone expression for each supported pattern type. It + // then hands off to `infer_nested_match_pattern` for any subexpressions and subpatterns, + // where we do NOT have any additional standalone expressions to infer through. + // // TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against // the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510 match pattern { ast::Pattern::MatchValue(match_value) => { self.infer_standalone_expression(&match_value.value); } + ast::Pattern::MatchClass(match_class) => { + let ast::PatternMatchClass { + range: _, + cls, + arguments, + } = match_class; + for pattern in &arguments.patterns { + self.infer_nested_match_pattern(pattern); + } + for keyword in &arguments.keywords { + self.infer_nested_match_pattern(&keyword.pattern); + } + self.infer_standalone_expression(cls); + } _ => { - self.infer_match_pattern_impl(pattern); + self.infer_nested_match_pattern(pattern); } } } - fn infer_match_pattern_impl(&mut self, pattern: &ast::Pattern) { + fn infer_nested_match_pattern(&mut self, pattern: &ast::Pattern) { match pattern { ast::Pattern::MatchValue(match_value) => { self.infer_expression(&match_value.value); } ast::Pattern::MatchSequence(match_sequence) => { for pattern in &match_sequence.patterns { - self.infer_match_pattern_impl(pattern); + self.infer_nested_match_pattern(pattern); } } ast::Pattern::MatchMapping(match_mapping) => { @@ -1813,7 +1849,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(key); } for pattern in patterns { - self.infer_match_pattern_impl(pattern); + self.infer_nested_match_pattern(pattern); } } ast::Pattern::MatchClass(match_class) => { @@ -1823,21 +1859,21 @@ impl<'db> TypeInferenceBuilder<'db> { arguments, } = match_class; for pattern in &arguments.patterns { - self.infer_match_pattern_impl(pattern); + self.infer_nested_match_pattern(pattern); } for keyword in &arguments.keywords { - self.infer_match_pattern_impl(&keyword.pattern); + self.infer_nested_match_pattern(&keyword.pattern); } self.infer_expression(cls); } ast::Pattern::MatchAs(match_as) => { if let Some(pattern) = &match_as.pattern { - self.infer_match_pattern_impl(pattern); + self.infer_nested_match_pattern(pattern); } } ast::Pattern::MatchOr(match_or) => { for pattern in &match_or.patterns { - self.infer_match_pattern_impl(pattern); + self.infer_nested_match_pattern(pattern); } } ast::Pattern::MatchStar(_) | ast::Pattern::MatchSingleton(_) => {} diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 7db4f4e5b3..ad81107c54 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -233,6 +233,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { PatternConstraintKind::Singleton(singleton, _guard) => { self.evaluate_match_pattern_singleton(*subject, *singleton) } + PatternConstraintKind::Class(cls, _guard) => { + self.evaluate_match_pattern_class(*subject, *cls) + } // TODO: support more pattern kinds PatternConstraintKind::Value(..) | PatternConstraintKind::Unsupported => None, } @@ -486,6 +489,27 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } + fn evaluate_match_pattern_class( + &mut self, + subject: Expression<'db>, + cls: Expression<'db>, + ) -> Option> { + if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() { + // SAFETY: we should always have a symbol for every Name node. + let symbol = self.symbols().symbol_id_by_name(id).unwrap(); + let scope = self.scope(); + let inference = infer_expression_types(self.db, cls); + let ty = inference + .expression_ty(cls.node_ref(self.db).scoped_expression_id(self.db, scope)) + .to_instance(self.db); + let mut constraints = NarrowingConstraints::default(); + constraints.insert(symbol, ty); + Some(constraints) + } else { + None + } + } + fn evaluate_bool_op( &mut self, expr_bool_op: &ExprBoolOp, diff --git a/crates/red_knot_python_semantic/src/visibility_constraints.rs b/crates/red_knot_python_semantic/src/visibility_constraints.rs index 3c2a8a1e18..2bdc2d3484 100644 --- a/crates/red_knot_python_semantic/src/visibility_constraints.rs +++ b/crates/red_knot_python_semantic/src/visibility_constraints.rs @@ -329,9 +329,9 @@ impl<'db> VisibilityConstraints<'db> { Truthiness::Ambiguous } } - PatternConstraintKind::Singleton(..) | PatternConstraintKind::Unsupported => { - Truthiness::Ambiguous - } + PatternConstraintKind::Singleton(..) + | PatternConstraintKind::Class(..) + | PatternConstraintKind::Unsupported => Truthiness::Ambiguous, }, } }