Compare commits

...

4 Commits

Author SHA1 Message Date
Charlie Marsh
e14cf84e33 Add variable-length tuples 2025-12-30 17:54:59 -05:00
Charlie Marsh
8e01d73d90 Minor tweaks 2025-12-30 17:54:59 -05:00
Charlie Marsh
1b799e77f5 Add tuple intersection 2025-12-30 17:54:59 -05:00
Charlie Marsh
e7e87e86cb [ty] Add support for narrowing on tuple match cases 2025-12-30 17:54:59 -05:00
9 changed files with 582 additions and 6 deletions

View File

@@ -442,3 +442,123 @@ def _(x: tuple[Literal["tag1"], A] | tuple[str, B]):
# But we *can* narrow with inequality
reveal_type(x) # revealed: tuple[str, B]
```
## Sequence patterns
Sequence patterns narrow tuple element types based on the patterns matched against each element.
```py
def _(subj: tuple[int | str, int | str]):
match subj:
case (x, str()):
reveal_type(subj) # revealed: tuple[int | str, str]
case (int(), y):
# After first case, subject is tuple[int | str, int] (second element can't be str)
reveal_type(subj) # revealed: tuple[int, int]
def _(subj: tuple[int | str, int | str]):
match subj:
case (int(), str()):
reveal_type(subj) # revealed: tuple[int, str]
def _(subj: tuple[int | str | None, int | str | None]):
match subj:
case (None, _):
reveal_type(subj) # revealed: tuple[None, int | str | None]
case (_, None):
# After first case, subject is tuple[int | str, int | str | None] (first element can't be None)
reveal_type(subj) # revealed: tuple[int | str, None]
```
## Sequence patterns with nested tuples
```py
def _(subj: tuple[tuple[int | str, int], int | str]):
match subj:
case ((str(), _), _):
# The inner tuple is narrowed by intersecting element-wise with the pattern's constraint
# tuple[int | str, int] & tuple[str, object] -> tuple[str, int]
reveal_type(subj) # revealed: tuple[tuple[str, int], int | str]
```
## Sequence patterns with or patterns
```py
def _(subj: tuple[int | str | bytes, int | str]):
match subj:
case (int() | str(), _):
reveal_type(subj) # revealed: tuple[int | str, int | str]
```
## Sequence patterns with wildcards
Wildcards (`_`) and name patterns don't narrow the element type.
```py
def _(subj: tuple[int | str, int | str]):
match subj:
case (_, _):
reveal_type(subj) # revealed: tuple[int | str, int | str]
def _(subj: tuple[int | str, int | str]):
match subj:
case (x, y):
reveal_type(subj) # revealed: tuple[int | str, int | str]
```
## Sequence pattern negative narrowing
When a sequence pattern doesn't match, the type is narrowed by subtracting the pattern type. The
type system simplifies `tuple[A, B] & ~tuple[C, D]` to `tuple[A & ~C, B] | tuple[A, B & ~D]`.
```py
def _(subj: tuple[int | str, int | str]):
match subj:
case (int(), int()):
reveal_type(subj) # revealed: tuple[int, int]
case _:
# tuple[int | str, int | str] & ~tuple[int, int]
# = tuple[str, int | str] | tuple[int | str, str]
reveal_type(subj) # revealed: tuple[str, int | str] | tuple[int | str, str]
def _(subj: tuple[int | str, int | str]):
match subj:
case (x, str()):
reveal_type(subj) # revealed: tuple[int | str, str]
case y:
# tuple[int | str, int | str] & ~tuple[object, str]
# First element: (int | str) & ~object = Never (can't differ here)
# Second element: (int | str) & ~str = int
# Result: tuple[int | str, int]
reveal_type(subj) # revealed: tuple[int | str, int]
```
## Sequence pattern exhaustiveness
When a sequence pattern exhaustively matches all possible tuple values, subsequent cases should be
unreachable (`Never`).
```py
def _(subj: tuple[int, str]):
match subj:
case (int(), str()):
reveal_type(subj) # revealed: tuple[int, str]
case _:
reveal_type(subj) # revealed: Never
```
## Sequence patterns with homogeneous tuples
Sequence patterns on homogeneous tuples narrow to a fixed-length tuple with the specified length.
```py
def _(subj: tuple[int | str, ...]):
match subj:
case (x, str()):
reveal_type(subj) # revealed: tuple[int | str, str]
def _(subj: tuple[int | str, ...]):
match subj:
case (int(), int(), y):
reveal_type(subj) # revealed: tuple[int, int, int | str]
```

View File

@@ -429,6 +429,10 @@ class Foo: ...
class Bar: ...
def test4(val: Intersection[tuple[Foo], tuple[Bar]]):
# TODO: should be `Foo & Bar`
reveal_type(val[0]) # revealed: @Todo(Subscript expressions on intersections)
# Intersection of tuples is simplified element-wise
reveal_type(val[0]) # revealed: Foo & Bar
def test5(val: Intersection[tuple[Foo, ...], tuple[Bar, ...]]):
# Intersection of homogeneous variable tuples is simplified element-wise
reveal_type(val[0]) # revealed: Foo & Bar
```

View File

@@ -948,6 +948,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
.map(|p| Box::new(self.predicate_kind(p))),
pattern.name.as_ref().map(|name| name.id.clone()),
),
ast::Pattern::MatchSequence(pattern) => {
let predicates = pattern
.patterns
.iter()
.map(|pattern| self.predicate_kind(pattern))
.collect();
PatternPredicateKind::Sequence(predicates)
}
_ => PatternPredicateKind::Unsupported,
}
}

View File

@@ -137,6 +137,7 @@ pub(crate) enum PatternPredicateKind<'db> {
Or(Vec<PatternPredicateKind<'db>>),
Class(Expression<'db>, ClassPatternKind),
As(Option<Box<PatternPredicateKind<'db>>>, Option<Name>),
Sequence(Vec<PatternPredicateKind<'db>>),
Unsupported,
}

View File

@@ -208,8 +208,8 @@ use crate::semantic_index::predicate::{
Predicates, ScopedPredicateId,
};
use crate::types::{
CallableTypes, IntersectionBuilder, Truthiness, Type, TypeContext, UnionBuilder, UnionType,
infer_expression_type, static_expression_truthiness,
CallableTypes, IntersectionBuilder, Truthiness, TupleSpec, Type, TypeContext, UnionBuilder,
UnionType, infer_expression_type, static_expression_truthiness,
};
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
@@ -348,6 +348,13 @@ fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>)
.as_deref()
.map(|p| pattern_kind_to_type(db, p))
.unwrap_or_else(Type::object),
PatternPredicateKind::Sequence(patterns) => {
let elements: Vec<_> = patterns
.iter()
.map(|p| pattern_kind_to_type(db, p))
.collect();
Type::heterogeneous_tuple(db, elements)
}
PatternPredicateKind::Unsupported => Type::Never,
}
}
@@ -852,6 +859,52 @@ impl ReachabilityConstraints {
.as_deref()
.map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject_ty))
.unwrap_or(Truthiness::AlwaysTrue),
PatternPredicateKind::Sequence(patterns) => {
// Check if the subject is a tuple with matching length.
let tuple_spec = match subject_ty {
Type::NominalInstance(instance) => instance.tuple_spec(db),
_ => None,
};
let Some(tuple_spec) = tuple_spec else {
// Subject is not a tuple type; can't determine if it matches.
return Truthiness::Ambiguous;
};
match tuple_spec.as_ref() {
TupleSpec::Fixed(fixed) => {
if fixed.len() != patterns.len() {
// Length mismatch; pattern definitely can't match.
return Truthiness::AlwaysFalse;
}
// Check each element pattern against its corresponding element type.
let mut result = Truthiness::AlwaysTrue;
for (element_ty, pattern) in
fixed.all_elements().iter().zip(patterns.iter())
{
let element_result = Self::analyze_single_pattern_predicate_kind(
db,
pattern,
*element_ty,
);
match element_result {
Truthiness::AlwaysFalse => return Truthiness::AlwaysFalse,
Truthiness::Ambiguous => result = Truthiness::Ambiguous,
Truthiness::AlwaysTrue => {}
}
}
result
}
TupleSpec::Variable(_) => {
// Variable-length tuples could match patterns of various lengths.
Truthiness::Ambiguous
}
}
}
PatternPredicateKind::Unsupported => Truthiness::Ambiguous,
}
}

View File

@@ -35,6 +35,7 @@ pub(crate) use self::infer::{
pub use self::signatures::ParameterKind;
pub(crate) use self::signatures::{CallableSignature, Signature};
pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType};
pub(crate) use self::tuple::TupleSpec;
pub use crate::diagnostic::add_inferred_python_version_hint_to_diagnostic;
use crate::place::{
Definedness, Place, PlaceAndQualifiers, TypeOrigin, Widening, builtins_module_scope,
@@ -70,7 +71,7 @@ pub(crate) use crate::types::narrow::{NarrowingConstraint, infer_narrowing_const
use crate::types::newtype::NewType;
pub(crate) use crate::types::signatures::{Parameter, Parameters};
use crate::types::signatures::{ParameterForm, walk_signature};
use crate::types::tuple::{Tuple, TupleSpec, TupleSpecBuilder};
use crate::types::tuple::{Tuple, TupleSpecBuilder};
pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type};
pub use crate::types::variance::TypeVarVariance;
use crate::types::variance::VarianceInferable;

View File

@@ -957,6 +957,23 @@ impl<'db> IntersectionBuilder<'db> {
}
}
/// Result of attempting to intersect two fixed-length tuples element-wise.
enum TupleIntersectionResult<'db> {
/// The new type is not a fixed-length tuple, or no existing tuple matches.
NotApplicable,
/// One tuple is a subtype of the other; let normal subsumption handle it.
SubtypeRelationship,
/// Element-wise intersection resulted in `Never` (disjoint element types).
Never,
/// Successfully computed the element-wise intersection.
Intersected {
/// The resulting tuple type.
tuple: Type<'db>,
/// Index of the existing positive type to replace.
replace_index: usize,
},
}
#[derive(Debug, Clone, Default)]
struct InnerIntersectionBuilder<'db> {
positive: FxOrderSet<Type<'db>>,
@@ -1004,6 +1021,27 @@ impl<'db> InnerIntersectionBuilder<'db> {
return;
}
// Handle tuple intersection: tuple[A, B] & tuple[C, D] -> tuple[A & C, B & D]
match self.try_intersect_tuples(db, new_positive) {
TupleIntersectionResult::Never => {
self.positive.clear();
self.negative.clear();
self.positive.insert(Type::Never);
return;
}
TupleIntersectionResult::Intersected {
tuple,
replace_index,
} => {
self.positive.swap_remove_index(replace_index);
self.positive.insert(tuple);
return;
}
// Not applicable or subtype relationship: continue with normal processing
TupleIntersectionResult::NotApplicable
| TupleIntersectionResult::SubtypeRelationship => {}
}
let addition_is_bool_instance = known_instance == Some(KnownClass::Bool);
for (index, existing_positive) in self.positive.iter().enumerate() {
@@ -1103,6 +1141,106 @@ impl<'db> InnerIntersectionBuilder<'db> {
}
}
/// Try to intersect two tuples element-wise.
///
/// For fixed-length tuples: `tuple[A, B] & tuple[C, D]` -> `tuple[A & C, B & D]`
/// For variable-length tuples: `tuple[A, ...] & tuple[B, ...]` -> `tuple[A & B, ...]`
fn try_intersect_tuples(
&self,
db: &'db dyn Db,
new_positive: Type<'db>,
) -> TupleIntersectionResult<'db> {
use crate::types::tuple::TupleSpec;
let Some(new_instance) = new_positive.as_nominal_instance() else {
return TupleIntersectionResult::NotApplicable;
};
let Some(new_tuple_spec) = new_instance.own_tuple_spec(db) else {
return TupleIntersectionResult::NotApplicable;
};
// Find an existing positive tuple to intersect with.
for (index, existing_positive) in self.positive.iter().enumerate() {
let Some(existing_instance) = existing_positive.as_nominal_instance() else {
continue;
};
let Some(existing_spec) = existing_instance.own_tuple_spec(db) else {
continue;
};
// Optimization: if one tuple is a subtype of the other, the intersection
// is just the subtype, so we can skip element-wise intersection and let
// the normal subsumption logic handle it. This also prevents potential
// infinite recursion when element-wise intersection would recursively
// create the same intersection we're already building.
if new_positive.is_subtype_of(db, *existing_positive)
|| existing_positive.is_subtype_of(db, new_positive)
{
return TupleIntersectionResult::SubtypeRelationship;
}
match (existing_spec.as_ref(), new_tuple_spec.as_ref()) {
(TupleSpec::Fixed(existing_fixed), TupleSpec::Fixed(new_fixed)) => {
if existing_fixed.len() != new_fixed.len() {
continue;
}
let intersected_elements: Vec<_> = existing_fixed
.all_elements()
.iter()
.zip(new_fixed.all_elements())
.map(|(a, b)| {
IntersectionBuilder::new(db)
.add_positive(*a)
.add_positive(*b)
.build()
})
.collect();
if intersected_elements.iter().any(Type::is_never) {
return TupleIntersectionResult::Never;
}
return TupleIntersectionResult::Intersected {
tuple: Type::heterogeneous_tuple(db, intersected_elements),
replace_index: index,
};
}
(TupleSpec::Variable(existing_var), TupleSpec::Variable(new_var)) => {
// Only handle simple homogeneous tuples (no prefix/suffix) for now
if !existing_var.prefix_elements().is_empty()
|| !existing_var.suffix_elements().is_empty()
|| !new_var.prefix_elements().is_empty()
|| !new_var.suffix_elements().is_empty()
{
continue;
}
let intersected_variable = IntersectionBuilder::new(db)
.add_positive(*existing_var.variable_element())
.add_positive(*new_var.variable_element())
.build();
let intersected_tuple = if intersected_variable.is_never() {
Type::empty_tuple(db)
} else {
Type::homogeneous_tuple(db, intersected_variable)
};
return TupleIntersectionResult::Intersected {
tuple: intersected_tuple,
replace_index: index,
};
}
_ => continue,
}
}
TupleIntersectionResult::NotApplicable
}
/// Adds a negative type to this intersection.
fn add_negative(&mut self, db: &'db dyn Db, new_negative: Type<'db>) {
let contains_bool = || {

View File

@@ -260,7 +260,7 @@ impl<'db> NominalInstanceType<'db> {
///
/// I.e., for the type `tuple[int, str]`, this will return the tuple spec `[int, str]`.
/// For a subclass of `tuple[int, str]`, it will return the same tuple spec.
pub(super) fn tuple_spec(&self, db: &'db dyn Db) -> Option<Cow<'db, TupleSpec<'db>>> {
pub(crate) fn tuple_spec(&self, db: &'db dyn Db) -> Option<Cow<'db, TupleSpec<'db>>> {
match self.0 {
NominalInstanceInner::ExactTuple(tuple) => Some(Cow::Borrowed(tuple.tuple(db))),
NominalInstanceInner::NonTuple(class) => {

View File

@@ -32,6 +32,8 @@ use rustc_hash::FxHashMap;
use smallvec::{SmallVec, smallvec};
use std::collections::hash_map::Entry;
use super::tuple::TupleSpec;
/// Return the type constraint that `test` (if true) would place on `symbol`, if any.
///
/// For example, if we have this code:
@@ -595,6 +597,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
PatternPredicateKind::As(pattern, _) => pattern
.as_deref()
.and_then(|p| self.evaluate_pattern_predicate_kind(p, subject, is_positive)),
PatternPredicateKind::Sequence(element_patterns) => {
self.evaluate_match_pattern_sequence(subject, element_patterns, is_positive)
}
PatternPredicateKind::Unsupported => None,
}
}
@@ -1472,6 +1477,252 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
})
}
/// Evaluate a sequence pattern like `case (x, y, z):` or `case [a, b]:`.
///
/// For each element pattern, we narrow the corresponding element of the tuple subject.
fn evaluate_match_pattern_sequence(
&mut self,
subject: Expression<'db>,
element_patterns: &[PatternPredicateKind<'db>],
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
// Get the subject expression's place.
let place_expr = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&place_expr);
// Get the subject's type.
let subject_ty =
infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module);
// Get the tuple spec, if it's a tuple type.
let tuple_spec = match subject_ty {
Type::NominalInstance(instance) => instance.tuple_spec(self.db)?,
_ => return None,
};
// Check if any element pattern provides narrowing constraints.
let has_any_constraint = element_patterns
.iter()
.any(|pattern| self.pattern_to_type_constraint(pattern).is_some());
// If no element pattern provides constraints (e.g., all wildcards), don't narrow.
if !has_any_constraint {
return None;
}
// For negative narrowing, we compute tuple[A, B] & ~tuple[C, D] directly.
// A value is NOT in tuple[C, D] if either:
// - position 0 is not in C, OR
// - position 1 is not in D
// So: tuple[A, B] & ~tuple[C, D] = tuple[A & ~C, B] | tuple[A, B & ~D]
if !is_positive {
return self.evaluate_match_pattern_sequence_negative(
place,
tuple_spec.as_ref(),
element_patterns,
);
}
// Positive narrowing: narrow each element based on its pattern.
let narrowed_elements: Vec<Type<'db>> = match tuple_spec.as_ref() {
TupleSpec::Fixed(fixed) => {
// Require exact length match for fixed-length tuples.
if fixed.len() != element_patterns.len() {
return None;
}
fixed
.all_elements()
.iter()
.zip(element_patterns.iter())
.map(|(element_ty, pattern)| {
if let Some(constraint_ty) = self.pattern_to_type_constraint(pattern) {
IntersectionBuilder::new(self.db)
.add_positive(*element_ty)
.add_positive(constraint_ty)
.build()
} else {
*element_ty
}
})
.collect()
}
TupleSpec::Variable(variable) => {
// For variable-length tuples, narrow to a fixed-length tuple with the pattern's length.
let pattern_len = element_patterns.len();
let prefix = variable.prefix_elements();
let suffix = variable.suffix_elements();
let prefix_len = prefix.len();
let suffix_len = suffix.len();
if pattern_len < prefix_len + suffix_len {
return None;
}
element_patterns
.iter()
.enumerate()
.map(|(i, pattern)| {
let element_ty = if i < prefix_len {
prefix[i]
} else if i >= pattern_len - suffix_len {
suffix[i - (pattern_len - suffix_len)]
} else {
*variable.variable_element()
};
if let Some(constraint_ty) = self.pattern_to_type_constraint(pattern) {
IntersectionBuilder::new(self.db)
.add_positive(element_ty)
.add_positive(constraint_ty)
.build()
} else {
element_ty
}
})
.collect()
}
};
let narrowed_tuple = Type::heterogeneous_tuple(self.db, narrowed_elements);
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(narrowed_tuple),
)]))
}
/// Handle negative narrowing for sequence patterns.
///
/// For `tuple[A, B] & ~tuple[C, D]`, a value is NOT in `tuple[C, D]` if either:
/// - position 0 is not in C, OR
/// - position 1 is not in D
///
/// So: `tuple[A, B] & ~tuple[C, D]` = `tuple[A & ~C, B] | tuple[A, B & ~D]`
///
/// We compute this directly to avoid stack overflow issues with the intersection builder.
fn evaluate_match_pattern_sequence_negative(
&mut self,
place: ScopedPlaceId,
tuple_spec: &TupleSpec<'db>,
element_patterns: &[PatternPredicateKind<'db>],
) -> Option<NarrowingConstraints<'db>> {
let db = self.db;
// Only handle fixed-length tuples for now
let subject_elements: Vec<Type<'db>> = match tuple_spec {
TupleSpec::Fixed(fixed) => {
if fixed.len() != element_patterns.len() {
return None;
}
fixed.all_elements().to_vec()
}
TupleSpec::Variable(_) => return None,
};
// For unconstrained patterns (wildcards), use `object` so `A & ~object = Never`
// means those positions won't contribute to the union.
let pattern_elements: Vec<Type<'db>> = element_patterns
.iter()
.map(|pattern| {
self.pattern_to_type_constraint(pattern)
.unwrap_or(Type::object())
})
.collect();
// For each position, compute tuple where that position doesn't match the pattern.
// Result: tuple[A & ~C, B] | tuple[A, B & ~D] for tuple[A, B] & ~tuple[C, D]
let mut union_elements: Vec<Type<'db>> = Vec::new();
for i in 0..subject_elements.len() {
let narrowed_element = IntersectionBuilder::new(db)
.add_positive(subject_elements[i])
.add_negative(pattern_elements[i])
.build();
if narrowed_element.is_never() {
continue;
}
let mut new_elements = subject_elements.clone();
new_elements[i] = narrowed_element;
union_elements.push(Type::heterogeneous_tuple(db, new_elements));
}
if union_elements.is_empty() {
return Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(Type::Never),
)]));
}
let narrowed = UnionType::from_elements(db, union_elements);
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::regular(narrowed),
)]))
}
/// Convert a pattern kind to the type it constrains to.
///
/// Returns `None` for patterns that don't constrain the type (like wildcards or name patterns).
fn pattern_to_type_constraint(&self, pattern: &PatternPredicateKind<'db>) -> Option<Type<'db>> {
match pattern {
PatternPredicateKind::Singleton(singleton) => Some(match singleton {
ast::Singleton::None => Type::none(self.db),
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
}),
PatternPredicateKind::Class(cls, _) => {
let class_ty = infer_same_file_expression_type(
self.db,
*cls,
TypeContext::default(),
self.module,
);
match class_ty {
Type::ClassLiteral(class) => {
Some(Type::instance(self.db, class.top_materialization(self.db)))
}
dynamic @ Type::Dynamic(_) => Some(dynamic),
Type::SpecialForm(SpecialFormType::Any) => Some(Type::any()),
_ => None,
}
}
PatternPredicateKind::Value(expr) => Some(infer_same_file_expression_type(
self.db,
*expr,
TypeContext::default(),
self.module,
)),
PatternPredicateKind::Or(patterns) => {
// Union of all pattern constraints.
let elements: Vec<_> = patterns
.iter()
.filter_map(|p| self.pattern_to_type_constraint(p))
.collect();
if elements.is_empty() {
None
} else {
Some(UnionType::from_elements(self.db, elements))
}
}
PatternPredicateKind::As(inner, _) => inner
.as_deref()
.and_then(|p| self.pattern_to_type_constraint(p)),
PatternPredicateKind::Sequence(patterns) => {
// For nested sequences, create a tuple type.
let elements: Vec<_> = patterns
.iter()
.map(|p| self.pattern_to_type_constraint(p).unwrap_or(Type::object()))
.collect();
Some(Type::heterogeneous_tuple(self.db, elements))
}
PatternPredicateKind::Unsupported => None,
}
}
fn evaluate_bool_op(
&mut self,
expr_bool_op: &ExprBoolOp,