diff --git a/crates/ty_python_semantic/resources/mdtest/comparison/tuples.md b/crates/ty_python_semantic/resources/mdtest/comparison/tuples.md index 03169ef851..e9bf3ccf38 100644 --- a/crates/ty_python_semantic/resources/mdtest/comparison/tuples.md +++ b/crates/ty_python_semantic/resources/mdtest/comparison/tuples.md @@ -392,3 +392,16 @@ class A: # error: [unsupported-bool-conversion] (A(),) == (A(),) ``` + +## Recursive NamedTuple + +```py +from __future__ import annotations +from typing import NamedTuple + +class Node(NamedTuple): + parent: Node | None + +def _(n: Node): + reveal_type(n.parent is n) # revealed: bool +``` diff --git a/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md b/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md index a17a46ce15..b3b1d15ef8 100644 --- a/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md +++ b/crates/ty_python_semantic/resources/mdtest/pep695_type_aliases.md @@ -351,3 +351,12 @@ def f(x: A): for item in x: reveal_type(item) # revealed: list[A | str | None] | str | None ``` + +### Tuple comparison + +```py +type X = tuple[X, int] + +def _(x: X): + reveal_type(x is x) # revealed: bool +``` diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 5e2b308433..34638b8bb3 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -45,6 +45,7 @@ use crate::semantic_index::{ use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator}; use crate::types::context::{InNoTypeCheck, InferContext}; +use crate::types::cyclic::CycleDetector; use crate::types::diagnostic::{ CALL_NON_CALLABLE, CONFLICTING_DECLARATIONS, CONFLICTING_METACLASS, CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_KW_ONLY, INCONSISTENT_MRO, INVALID_ARGUMENT_TYPE, @@ -132,6 +133,13 @@ impl<'db> DeclaredAndInferredType<'db> { } } +/// A [`CycleDetector`] that is used in `infer_binary_type_comparison`. +type BinaryComparisonVisitor<'db> = CycleDetector< + ast::CmpOp, + (Type<'db>, ast::CmpOp, Type<'db>), + Result, CompareUnsupportedError<'db>>, +>; + /// Builder to infer all types in a region. /// /// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling @@ -7438,7 +7446,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let range = TextRange::new(left.start(), right.end()); let ty = builder - .infer_binary_type_comparison(left_ty, *op, right_ty, range) + .infer_binary_type_comparison( + left_ty, + *op, + right_ty, + range, + &BinaryComparisonVisitor::new(Ok(Type::BooleanLiteral(true))), + ) .unwrap_or_else(|error| { if let Some(diagnostic_builder) = builder.context.report_lint(&UNSUPPORTED_OPERATOR, range) @@ -7484,6 +7498,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { other: Type<'db>, intersection_on: IntersectionOn, range: TextRange, + visitor: &BinaryComparisonVisitor<'db>, ) -> Result, CompareUnsupportedError<'db>> { enum State<'db> { // We have not seen any positive elements (yet) @@ -7500,8 +7515,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // intersection type, which is even more specific. for pos in intersection.positive(self.db()) { let result = match intersection_on { - IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other, range), - IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos, range), + IntersectionOn::Left => { + self.infer_binary_type_comparison(*pos, op, other, range, visitor) + } + IntersectionOn::Right => { + self.infer_binary_type_comparison(other, op, *pos, range, visitor) + } }; if let Ok(Type::BooleanLiteral(_)) = result { @@ -7514,10 +7533,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for neg in intersection.negative(self.db()) { let result = match intersection_on { IntersectionOn::Left => self - .infer_binary_type_comparison(*neg, op, other, range) + .infer_binary_type_comparison(*neg, op, other, range, visitor) .ok(), IntersectionOn::Right => self - .infer_binary_type_comparison(other, op, *neg, range) + .infer_binary_type_comparison(other, op, *neg, range, visitor) .ok(), }; @@ -7578,8 +7597,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for pos in intersection.positive(self.db()) { let result = match intersection_on { - IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other, range), - IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos, range), + IntersectionOn::Left => { + self.infer_binary_type_comparison(*pos, op, other, range, visitor) + } + IntersectionOn::Right => { + self.infer_binary_type_comparison(other, op, *pos, range, visitor) + } }; match result { @@ -7614,10 +7637,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // We didn't see any positive elements, check if the operation is supported on `object`: match intersection_on { IntersectionOn::Left => { - self.infer_binary_type_comparison(Type::object(), op, other, range) + self.infer_binary_type_comparison(Type::object(), op, other, range, visitor) } IntersectionOn::Right => { - self.infer_binary_type_comparison(other, op, Type::object(), range) + self.infer_binary_type_comparison(other, op, Type::object(), range, visitor) } } } @@ -7637,6 +7660,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op: ast::CmpOp, right: Type<'db>, range: TextRange, + visitor: &BinaryComparisonVisitor<'db>, ) -> Result, CompareUnsupportedError<'db>> { // Note: identity (is, is not) for equal builtin types is unreliable and not part of the // language spec. @@ -7689,7 +7713,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut builder = UnionBuilder::new(self.db()); for element in union.elements(self.db()) { builder = - builder.add(self.infer_binary_type_comparison(*element, op, other, range)?); + builder.add(self.infer_binary_type_comparison(*element, op, other, range, visitor)?); } Some(Ok(builder.build())) } @@ -7697,7 +7721,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut builder = UnionBuilder::new(self.db()); for element in union.elements(self.db()) { builder = - builder.add(self.infer_binary_type_comparison(other, op, *element, range)?); + builder.add(self.infer_binary_type_comparison(other, op, *element, range, visitor)?); } Some(Ok(builder.build())) } @@ -7709,6 +7733,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { right, IntersectionOn::Left, range, + visitor, )) } (left, Type::Intersection(intersection)) => { @@ -7718,22 +7743,29 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { left, IntersectionOn::Right, range, + visitor, )) } - (Type::TypeAlias(alias), right) => Some(self.infer_binary_type_comparison( - alias.value_type(self.db()), - op, - right, - range, - )), + (Type::TypeAlias(alias), right) => Some( + visitor.visit((left, op, right), || { self.infer_binary_type_comparison( + alias.value_type(self.db()), + op, + right, + range, + visitor, + ) + })), - (left, Type::TypeAlias(alias)) => Some(self.infer_binary_type_comparison( - left, - op, - alias.value_type(self.db()), - range, - )), + (left, Type::TypeAlias(alias)) => Some( + visitor.visit((left, op, right), || { self.infer_binary_type_comparison( + left, + op, + alias.value_type(self.db()), + range, + visitor, + ) + })), (Type::IntLiteral(n), Type::IntLiteral(m)) => Some(match op { ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)), @@ -7771,6 +7803,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op, right, range, + visitor, )) } (Type::NominalInstance(_), Type::IntLiteral(_)) => { @@ -7779,6 +7812,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op, KnownClass::Int.to_instance(self.db()), range, + visitor, )) } @@ -7789,6 +7823,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op, Type::IntLiteral(i64::from(b)), range, + visitor, )) } (Type::BooleanLiteral(b), Type::IntLiteral(m)) => { @@ -7797,6 +7832,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op, Type::IntLiteral(m), range, + visitor, )) } (Type::BooleanLiteral(a), Type::BooleanLiteral(b)) => { @@ -7805,6 +7841,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op, Type::IntLiteral(i64::from(b)), range, + visitor, )) } @@ -7842,12 +7879,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op, right, range, + visitor, )), (_, Type::StringLiteral(_)) => Some(self.infer_binary_type_comparison( left, op, KnownClass::Str.to_instance(self.db()), range, + visitor, )), (Type::LiteralString, _) => Some(self.infer_binary_type_comparison( @@ -7855,12 +7894,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op, right, range, + visitor, )), (_, Type::LiteralString) => Some(self.infer_binary_type_comparison( left, op, KnownClass::Str.to_instance(self.db()), range, + visitor, )), (Type::BytesLiteral(salsa_b1), Type::BytesLiteral(salsa_b2)) => { @@ -7901,12 +7942,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op, right, range, + visitor, )), (_, Type::BytesLiteral(_)) => Some(self.infer_binary_type_comparison( left, op, KnownClass::Bytes.to_instance(self.db()), range, + visitor, )), (Type::EnumLiteral(literal_1), Type::EnumLiteral(literal_2)) @@ -7933,7 +7976,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .and_then(|lhs_tuple| Some((lhs_tuple, nominal2.tuple_spec(self.db())?))) .map(|(lhs_tuple, rhs_tuple)| { let mut tuple_rich_comparison = - |op| self.infer_tuple_rich_comparison(&lhs_tuple, op, &rhs_tuple, range); + |rich_op| visitor.visit((left, op, right), || { + self.infer_tuple_rich_comparison(&lhs_tuple, rich_op, &rhs_tuple, range, visitor) + }); match op { ast::CmpOp::Eq => tuple_rich_comparison(RichCompareOperator::Eq), @@ -7952,6 +7997,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::CmpOp::Eq, ty, range, + visitor ).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); match eq_result { @@ -8125,6 +8171,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { op: RichCompareOperator, right: &TupleSpec<'db>, range: TextRange, + visitor: &BinaryComparisonVisitor<'db>, ) -> Result, CompareUnsupportedError<'db>> { // If either tuple is variable length, we can make no assumptions about the relative // lengths of the tuples, and therefore neither about how they compare lexicographically. @@ -8141,7 +8188,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for (l_ty, r_ty) in left_iter.zip(right_iter) { let pairwise_eq_result = self - .infer_binary_type_comparison(l_ty, ast::CmpOp::Eq, r_ty, range) + .infer_binary_type_comparison(l_ty, ast::CmpOp::Eq, r_ty, range, visitor) .expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); match pairwise_eq_result @@ -8166,9 +8213,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { RichCompareOperator::Lt | RichCompareOperator::Le | RichCompareOperator::Gt - | RichCompareOperator::Ge => { - self.infer_binary_type_comparison(l_ty, op.into(), r_ty, range)? - } + | RichCompareOperator::Ge => self.infer_binary_type_comparison( + l_ty, + op.into(), + r_ty, + range, + visitor, + )?, // For `==` and `!=`, we already figure out the result from `pairwise_eq_result` // NOTE: The CPython implementation does not account for non-boolean return types // or cases where `!=` is not the negation of `==`, we also do not consider these cases.