diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md index 00967654f4..f21ea7257e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md @@ -58,7 +58,9 @@ reveal_type(c >= d) # revealed: Literal[True] #### Results with Ambiguity ```py -def bool_instance() -> bool: ... +def bool_instance() -> bool: + return True + def int_instance() -> int: return 42 @@ -134,23 +136,158 @@ reveal_type(c >= c) # revealed: Literal[True] #### Non Boolean Rich Comparisons +Rich comparison methods defined in a class affect tuple comparisons as well. Proper type inference +should be possible even in cases where these methods return non-boolean types. + +Note: Tuples use lexicographic comparisons. If the `==` result for all paired elements in the tuple +is True, the comparison then considers the tuple’s length. Regardless of the return type of the +dunder methods, the final result can still be a boolean value. + +(+cpython: For tuples, `==` and `!=` always produce boolean results, regardless of the return type +of the dunder methods.) + ```py +from __future__ import annotations + class A: - def __eq__(self, o) -> str: ... - def __ne__(self, o) -> int: ... - def __lt__(self, o) -> float: ... - def __le__(self, o) -> object: ... - def __gt__(self, o) -> tuple: ... - def __ge__(self, o) -> list: ... + def __eq__(self, o: object) -> str: + return "hello" + + def __ne__(self, o: object) -> bytes: + return b"world" + + def __lt__(self, o: A) -> float: + return 3.14 + + def __le__(self, o: A) -> complex: + return complex(0.5, -0.5) + + def __gt__(self, o: A) -> tuple: + return (1, 2, 3) + + def __ge__(self, o: A) -> list: + return [1, 2, 3] a = (A(), A()) +reveal_type(a == a) # revealed: bool +reveal_type(a != a) # revealed: bool +reveal_type(a < a) # revealed: float | Literal[False] +reveal_type(a <= a) # revealed: complex | Literal[True] +reveal_type(a > a) # revealed: tuple | Literal[False] +reveal_type(a >= a) # revealed: list | Literal[True] + +# If lexicographic comparison is finished before comparing A() +b = ("1_foo", A()) +c = ("2_bar", A()) + +reveal_type(b == c) # revealed: Literal[False] +reveal_type(b != c) # revealed: Literal[True] +reveal_type(b < c) # revealed: Literal[True] +reveal_type(b <= c) # revealed: Literal[True] +reveal_type(b > c) # revealed: Literal[False] +reveal_type(b >= c) # revealed: Literal[False] + +class B: + def __lt__(self, o: B) -> set: + return set() + +reveal_type((A(), B()) < (A(), B())) # revealed: float | set | Literal[False] +``` + +#### Special Handling of Eq and NotEq in Lexicographic Comparisons + +> Example: `(int_instance(), "foo") == (int_instance(), "bar")` + +`Eq` and `NotEq` have unique behavior compared to other operators in lexicographic comparisons. +Specifically, for `Eq`, if any non-equal pair exists within the tuples being compared, we can +immediately conclude that the tuples are not equal. Conversely, for `NotEq`, if any non-equal pair +exists, we can determine that the tuples are unequal. + +In contrast, with operators like `<` and `>`, the comparison must consider each pair of elements +sequentially, and the final outcome might remain ambiguous until all pairs are compared. + +```py +def str_instance() -> str: + return "hello" + +def int_instance() -> int: + return 42 + +reveal_type("foo" == "bar") # revealed: Literal[False] +reveal_type(("foo",) == ("bar",)) # revealed: Literal[False] +reveal_type((4, "foo") == (4, "bar")) # revealed: Literal[False] +reveal_type((int_instance(), "foo") == (int_instance(), "bar")) # revealed: Literal[False] + +a = (str_instance(), int_instance(), "foo") + reveal_type(a == a) # revealed: bool reveal_type(a != a) # revealed: bool reveal_type(a < a) # revealed: bool reveal_type(a <= a) # revealed: bool reveal_type(a > a) # revealed: bool reveal_type(a >= a) # revealed: bool + +b = (str_instance(), int_instance(), "bar") + +reveal_type(a == b) # revealed: Literal[False] +reveal_type(a != b) # revealed: Literal[True] +reveal_type(a < b) # revealed: bool +reveal_type(a <= b) # revealed: bool +reveal_type(a > b) # revealed: bool +reveal_type(a >= b) # revealed: bool + +c = (str_instance(), int_instance(), "foo", "different_length") +reveal_type(a == c) # revealed: Literal[False] +reveal_type(a != c) # revealed: Literal[True] +reveal_type(a < c) # revealed: bool +reveal_type(a <= c) # revealed: bool +reveal_type(a > c) # revealed: bool +reveal_type(a >= c) # revealed: bool +``` + +#### Error Propagation + +Errors occurring within a tuple comparison should propagate outward. However, if the tuple +comparison can clearly conclude before encountering an error, the error should not be raised. + +```py +def int_instance() -> int: + return 42 + +def str_instance() -> str: + return "hello" + +class A: ... + +# error: [unsupported-operator] "Operator `<` is not supported for types `A` and `A`" +A() < A() +# error: [unsupported-operator] "Operator `<=` is not supported for types `A` and `A`" +A() <= A() +# error: [unsupported-operator] "Operator `>` is not supported for types `A` and `A`" +A() > A() +# error: [unsupported-operator] "Operator `>=` is not supported for types `A` and `A`" +A() >= A() + +a = (0, int_instance(), A()) + +# error: [unsupported-operator] "Operator `<` is not supported for types `A` and `A`, in comparing `tuple[Literal[0], int, A]` with `tuple[Literal[0], int, A]`" +reveal_type(a < a) # revealed: Unknown +# error: [unsupported-operator] "Operator `<=` is not supported for types `A` and `A`, in comparing `tuple[Literal[0], int, A]` with `tuple[Literal[0], int, A]`" +reveal_type(a <= a) # revealed: Unknown +# error: [unsupported-operator] "Operator `>` is not supported for types `A` and `A`, in comparing `tuple[Literal[0], int, A]` with `tuple[Literal[0], int, A]`" +reveal_type(a > a) # revealed: Unknown +# error: [unsupported-operator] "Operator `>=` is not supported for types `A` and `A`, in comparing `tuple[Literal[0], int, A]` with `tuple[Literal[0], int, A]`" +reveal_type(a >= a) # revealed: Unknown + +# Comparison between `a` and `b` should only involve the first elements, `Literal[0]` and `Literal[99999]`, +# and should terminate immediately. +b = (99999, int_instance(), A()) + +reveal_type(a < b) # revealed: Literal[True] +reveal_type(a <= b) # revealed: Literal[True] +reveal_type(a > b) # revealed: Literal[False] +reveal_type(a >= b) # revealed: Literal[False] ``` ### Membership Test Comparisons diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md index c1f23b1d48..472cac5073 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md @@ -4,6 +4,8 @@ def bool_instance() -> bool: return True +class A: ... + a = 1 in 7 # error: "Operator `in` is not supported for types `Literal[1]` and `Literal[7]`" reveal_type(a) # revealed: bool @@ -33,4 +35,8 @@ reveal_type(e) # revealed: bool f = (1, 2) < (1, "hello") # TODO: should be Unknown, once operand type check is implemented reveal_type(f) # revealed: bool + +# error: [unsupported-operator] "Operator `<` is not supported for types `A` and `A`, in comparing `tuple[bool, A]` with `tuple[bool, A]`" +g = (bool_instance(), A()) < (bool_instance(), A()) +reveal_type(g) # revealed: Unknown ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/sys_version_info.md b/crates/red_knot_python_semantic/resources/mdtest/sys_version_info.md index afac3a5766..a9c175a0e7 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/sys_version_info.md +++ b/crates/red_knot_python_semantic/resources/mdtest/sys_version_info.md @@ -58,8 +58,7 @@ reveal_type(sys.version_info >= (3, 9, 1, "final", 0)) # revealed: bool # emitting a lint diagnostic of some kind warning them about the probable error? reveal_type(sys.version_info >= (3, 9, 1, "final", 0, 5)) # revealed: bool -# TODO: this should be `Literal[False]`; see #14279 -reveal_type(sys.version_info == (3, 9, 1, "finallllll", 0)) # revealed: bool +reveal_type(sys.version_info == (3, 8, 1, "finallllll", 0)) # revealed: Literal[False] ``` ## Imports and aliases diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index b6e8a76a07..ecc358d0f5 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -3643,16 +3643,16 @@ impl<'db> TypeInferenceBuilder<'db> { let lhs_elements = lhs.elements(self.db); let rhs_elements = rhs.elements(self.db); - let mut lexicographic_type_comparison = - |op| self.infer_lexicographic_type_comparison(lhs_elements, op, rhs_elements); + let mut tuple_rich_comparison = + |op| self.infer_tuple_rich_comparison(lhs_elements, op, rhs_elements); match op { - ast::CmpOp::Eq => lexicographic_type_comparison(RichCompareOperator::Eq), - ast::CmpOp::NotEq => lexicographic_type_comparison(RichCompareOperator::Ne), - ast::CmpOp::Lt => lexicographic_type_comparison(RichCompareOperator::Lt), - ast::CmpOp::LtE => lexicographic_type_comparison(RichCompareOperator::Le), - ast::CmpOp::Gt => lexicographic_type_comparison(RichCompareOperator::Gt), - ast::CmpOp::GtE => lexicographic_type_comparison(RichCompareOperator::Ge), + ast::CmpOp::Eq => tuple_rich_comparison(RichCompareOperator::Eq), + ast::CmpOp::NotEq => tuple_rich_comparison(RichCompareOperator::Ne), + ast::CmpOp::Lt => tuple_rich_comparison(RichCompareOperator::Lt), + ast::CmpOp::LtE => tuple_rich_comparison(RichCompareOperator::Le), + ast::CmpOp::Gt => tuple_rich_comparison(RichCompareOperator::Gt), + ast::CmpOp::GtE => tuple_rich_comparison(RichCompareOperator::Ge), ast::CmpOp::In | ast::CmpOp::NotIn => { let mut eq_count = 0usize; let mut not_eq_count = 0usize; @@ -3685,8 +3685,7 @@ impl<'db> TypeInferenceBuilder<'db> { ast::CmpOp::Is | ast::CmpOp::IsNot => { // - `[ast::CmpOp::Is]`: returns `false` if the elements are definitely unequal, otherwise `bool` // - `[ast::CmpOp::IsNot]`: returns `true` if the elements are definitely unequal, otherwise `bool` - let eq_result = lexicographic_type_comparison(RichCompareOperator::Eq) - .expect( + let eq_result = tuple_rich_comparison(RichCompareOperator::Eq).expect( "infer_binary_type_comparison should never return None for `CmpOp::Eq`", ); @@ -3751,53 +3750,80 @@ impl<'db> TypeInferenceBuilder<'db> { } } - /// Performs lexicographic comparison between two slices of types. + /// Simulates rich comparison between tuples and returns the inferred result. + /// This performs a lexicographic comparison, returning a union of all possible return types that could result from the comparison. /// - /// For lexicographic comparison, elements from both slices are compared pairwise using - /// `infer_binary_type_comparison`. If a conclusive result cannot be determined as a `BooleanLiteral`, - /// it returns `bool`. Returns `None` if the comparison is not supported. - fn infer_lexicographic_type_comparison( + /// basically it's based on cpython's `tuple_richcompare` + /// see `` + fn infer_tuple_rich_comparison( &mut self, left: &[Type<'db>], op: RichCompareOperator, right: &[Type<'db>], ) -> Result, CompareUnsupportedError<'db>> { - // Compare paired elements from left and right slices - for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) { - let eq_result = self + let left_iter = left.iter().copied(); + let right_iter = right.iter().copied(); + + let mut builder = UnionBuilder::new(self.db); + + for (l_ty, r_ty) in left_iter.zip(right_iter) { + let pairwise_eq_result = self .infer_binary_type_comparison(l_ty, ast::CmpOp::Eq, r_ty) .expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); - match eq_result { + match pairwise_eq_result { // If propagation is required, return the result as is Type::Todo => return Ok(Type::Todo), ty => match ty.bool(self.db) { - // Types are equal, continue to the next pair + // - AlwaysTrue : Continue to the next pair for lexicographic comparison Truthiness::AlwaysTrue => continue, - // Types are not equal, perform the specified comparison and return the result - Truthiness::AlwaysFalse => { - return self.infer_binary_type_comparison(l_ty, op.into(), r_ty) + // - AlwaysFalse: + // Lexicographic comparisons will always terminate with this pair. + // Complete the comparison and return the result. + // - Ambiguous: + // Lexicographic comparisons might continue to the next pair (if eq_result is true), + // or terminate here (if eq_result is false). + // To account for cases where the comparison terminates here, add the pairwise comparison result to the union builder. + eq_truthiness @ (Truthiness::AlwaysFalse | Truthiness::Ambiguous) => { + let pairwise_compare_result = match op { + RichCompareOperator::Lt + | RichCompareOperator::Le + | RichCompareOperator::Gt + | RichCompareOperator::Ge => { + self.infer_binary_type_comparison(l_ty, op.into(), r_ty)? + } + // For `==` and `!=`, we already figure out the result from `pairwise_eq_result` + // NOTE: The CPython implementation does not account for non-boolean return types + // or cases where `!=` is not the negation of `==`, we also do not consider these cases. + RichCompareOperator::Eq => Type::BooleanLiteral(false), + RichCompareOperator::Ne => Type::BooleanLiteral(true), + }; + + builder = builder.add(pairwise_compare_result); + + if eq_truthiness.is_ambiguous() { + continue; + } + + return Ok(builder.build()); } - // If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral. - // In this case, we simply return a bool instance. - Truthiness::Ambiguous => return Ok(KnownClass::Bool.to_instance(self.db)), }, } } - // At this point, the lengths of the two slices may be different, but the prefix of - // left and right slices is entirely identical. - // We return a comparison of the slice lengths based on the operator. + // if no more items to compare, we just compare sizes let (left_len, right_len) = (left.len(), right.len()); - Ok(Type::BooleanLiteral(match op { + builder = builder.add(Type::BooleanLiteral(match op { RichCompareOperator::Eq => left_len == right_len, RichCompareOperator::Ne => left_len != right_len, RichCompareOperator::Lt => left_len < right_len, RichCompareOperator::Le => left_len <= right_len, RichCompareOperator::Gt => left_len > right_len, RichCompareOperator::Ge => left_len >= right_len, - })) + })); + + Ok(builder.build()) } fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> {