diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/in.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/in.md new file mode 100644 index 0000000000..dad0374702 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/in.md @@ -0,0 +1,80 @@ +# Narrowing for `in` conditionals + +## `in` for tuples + +```py +def _(x: int): + if x in (1, 2, 3): + reveal_type(x) # revealed: int + else: + reveal_type(x) # revealed: int +``` + +```py +def _(x: str): + if x in ("a", "b", "c"): + reveal_type(x) # revealed: str + else: + reveal_type(x) # revealed: str +``` + +```py +from typing import Literal + +def _(x: Literal[1, 2, "a", "b", False, b"abc"]): + if x in (1,): + reveal_type(x) # revealed: Literal[1] + elif x in (2, "a"): + reveal_type(x) # revealed: Literal[2, "a"] + elif x in (b"abc",): + reveal_type(x) # revealed: Literal[b"abc"] + elif x not in (3,): + reveal_type(x) # revealed: Literal["b", False] + else: + reveal_type(x) # revealed: Never +``` + +```py +def _(x: Literal["a", "b", "c", 1]): + if x in ("a", "b", "c", 2): + reveal_type(x) # revealed: Literal["a", "b", "c"] + else: + reveal_type(x) # revealed: Literal[1] +``` + +## `in` for `str` and literal strings + +```py +def _(x: str): + if x in "abc": + reveal_type(x) # revealed: str + else: + reveal_type(x) # revealed: str +``` + +```py +from typing import Literal + +def _(x: Literal["a", "b", "c", "d"]): + if x in "abc": + reveal_type(x) # revealed: Literal["a", "b", "c"] + else: + reveal_type(x) # revealed: Literal["d"] +``` + +```py +def _(x: Literal["a", "b", "c", "e"]): + if x in "abcd": + reveal_type(x) # revealed: Literal["a", "b", "c"] + else: + reveal_type(x) # revealed: Literal["e"] +``` + +```py +def _(x: Literal[1, "a", "b", "c", "d"]): + # error: [unsupported-operator] + if x in "abc": + reveal_type(x) # revealed: Literal["a", "b", "c"] + else: + reveal_type(x) # revealed: Literal[1, "d"] +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index d5ed2b6d77..48da634620 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -408,6 +408,11 @@ impl<'db> Type<'db> { matches!(self, Type::FunctionLiteral(..)) } + pub fn is_union_of_single_valued(&self, db: &'db dyn Db) -> bool { + self.into_union() + .is_some_and(|union| union.elements(db).iter().all(|ty| ty.is_single_valued(db))) + } + pub const fn into_int_literal(self) -> Option { match self { Type::IntLiteral(value) => Some(value), @@ -422,6 +427,10 @@ impl<'db> Type<'db> { } } + pub fn is_string_literal(&self) -> bool { + matches!(self, Type::StringLiteral(..)) + } + #[track_caller] pub fn expect_int_literal(self) -> i64 { self.into_int_literal() @@ -5403,6 +5412,14 @@ impl<'db> StringLiteralType<'db> { pub fn python_len(&self, db: &'db dyn Db) -> usize { self.value(db).chars().count() } + + /// Return an iterator over each character in the string literal. + /// as would be returned by Python's `iter()`. + pub fn iter_each_char(&self, db: &'db dyn Db) -> impl Iterator { + self.value(db) + .chars() + .map(|c| StringLiteralType::new(db, c.to_string().as_str())) + } } #[salsa::interned(debug)] diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index ccf1816db0..ceaec3d3c1 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -19,6 +19,8 @@ use rustc_hash::FxHashMap; use std::collections::hash_map::Entry; use std::sync::Arc; +use super::UnionType; + /// Return the type constraint that `test` (if true) would place on `definition`, if any. /// /// For example, if we have this code: @@ -288,6 +290,28 @@ impl<'db> NarrowingConstraintsBuilder<'db> { NarrowingConstraints::from_iter([(symbol, ty)]) } + fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { + if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) { + match rhs_ty { + Type::Tuple(rhs_tuple) => Some(UnionType::from_elements( + self.db, + rhs_tuple.elements(self.db), + )), + + Type::StringLiteral(string_literal) => Some(UnionType::from_elements( + self.db, + string_literal + .iter_each_char(self.db) + .map(Type::StringLiteral), + )), + + _ => None, + } + } else { + None + } + } + fn evaluate_expr_compare( &mut self, expr_compare: &ast::ExprCompare, @@ -371,6 +395,16 @@ impl<'db> NarrowingConstraintsBuilder<'db> { ast::CmpOp::Eq if lhs_ty.is_literal_string() => { constraints.insert(symbol, rhs_ty); } + ast::CmpOp::In => { + if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) { + constraints.insert(symbol, ty); + } + } + ast::CmpOp::NotIn => { + if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) { + constraints.insert(symbol, ty.negate(self.db)); + } + } _ => { // TODO other comparison types }