From 34a5d7cb7fc6abe0239757d35be306a97c3068e7 Mon Sep 17 00:00:00 2001 From: TomerBin Date: Fri, 1 Nov 2024 21:23:18 +0200 Subject: [PATCH] [red-knot] Infer type of if-expression if test has statically known truthiness (#14048) ## Summary Detecting statically known truthy or falsy test in if expressions (ternary). ## Test Plan new mdtest --- .../resources/mdtest/expression/if.md | 24 +++++++++++++++++++ .../resources/mdtest/subscript/class.md | 5 ++-- .../src/types/infer.rs | 10 ++++---- 3 files changed, 33 insertions(+), 6 deletions(-) create mode 100644 crates/red_knot_python_semantic/resources/mdtest/expression/if.md diff --git a/crates/red_knot_python_semantic/resources/mdtest/expression/if.md b/crates/red_knot_python_semantic/resources/mdtest/expression/if.md new file mode 100644 index 0000000000..ec687f798f --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/expression/if.md @@ -0,0 +1,24 @@ +# If expression + +## Union + +```py +def bool_instance() -> bool: + return True + +reveal_type(1 if bool_instance() else 2) # revealed: Literal[1, 2] +``` + +## Statically known branches + +```py +reveal_type(1 if True else 2) # revealed: Literal[1] +reveal_type(1 if "not empty" else 2) # revealed: Literal[1] +reveal_type(1 if (1,) else 2) # revealed: Literal[1] +reveal_type(1 if 1 else 2) # revealed: Literal[1] + +reveal_type(1 if False else 2) # revealed: Literal[2] +reveal_type(1 if None else 2) # revealed: Literal[2] +reveal_type(1 if "" else 2) # revealed: Literal[2] +reveal_type(1 if 0 else 2) # revealed: Literal[2] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/class.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/class.md index c3327c5bdc..56870fefc6 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/class.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/class.md @@ -39,7 +39,8 @@ reveal_type(UnionClassGetItem[0]) # revealed: str | int ## Class getitem with class union ```py -flag = True +def bool_instance() -> bool: + return True class A: def __class_getitem__(cls, item: int) -> str: @@ -49,7 +50,7 @@ class B: def __class_getitem__(cls, item: int) -> int: return item -x = A if flag else B +x = A if bool_instance() else B reveal_type(x) # revealed: Literal[A, B] reveal_type(x[0]) # revealed: str | int diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index fd3f807122..51d4a4d975 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2362,13 +2362,15 @@ impl<'db> TypeInferenceBuilder<'db> { orelse, } = if_expression; - self.infer_expression(test); - - // TODO detect statically known truthy or falsy test + let test_ty = self.infer_expression(test); let body_ty = self.infer_expression(body); let orelse_ty = self.infer_expression(orelse); - UnionType::from_elements(self.db, [body_ty, orelse_ty]) + match test_ty.bool(self.db) { + Truthiness::AlwaysTrue => body_ty, + Truthiness::AlwaysFalse => orelse_ty, + Truthiness::Ambiguous => UnionType::from_elements(self.db, [body_ty, orelse_ty]), + } } fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) {