From 70bdde40855e70f856cda9e065d276008e7c23e3 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Fri, 1 Nov 2024 15:49:18 -0400 Subject: [PATCH] Handle unions in augmented assignments (#14045) ## Summary Removing more TODOs from the augmented assignment test suite. Now, if the _target_ is a union, we correctly infer the union of results: ```python if flag: f = Foo() else: f = 42.0 f += 12 ``` --- .../resources/mdtest/assignment/augmented.md | 37 +++- crates/red_knot_python_semantic/src/types.rs | 2 +- .../src/types/infer.rs | 162 ++++++++++-------- 3 files changed, 126 insertions(+), 75 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md b/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md index 061c965080..dd96fdc819 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md +++ b/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md @@ -126,8 +126,7 @@ else: f = 42.0 f += 12 -# TODO(charlie): This should be `str | int | float` -reveal_type(f) # revealed: @Todo +reveal_type(f) # revealed: int | str | float ``` ## Target union @@ -148,6 +147,36 @@ else: f = 42.0 f += 12 -# TODO(charlie): This should be `str | float`. -reveal_type(f) # revealed: @Todo +reveal_type(f) # revealed: str | float +``` + +## Partially bound target union with `__add__` + +```py +def bool_instance() -> bool: + return True + +flag = bool_instance() + +class Foo: + def __add__(self, other: int) -> str: + return "Hello, world!" + if bool_instance(): + def __iadd__(self, other: int) -> int: + return 42 + +class Bar: + def __add__(self, other: int) -> bytes: + return b"Hello, world!" + + def __iadd__(self, other: int) -> float: + return 42.0 + +if flag: + f = Foo() +else: + f = Bar() +f += 12 + +reveal_type(f) # revealed: int | str | float ``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 481531e460..f064f22c60 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1902,7 +1902,7 @@ impl<'db> UnionType<'db> { pub fn map( &self, db: &'db dyn Db, - transform_fn: impl Fn(&Type<'db>) -> Type<'db>, + transform_fn: impl FnMut(&Type<'db>) -> Type<'db>, ) -> Type<'db> { Self::from_elements(db, self.elements(db).iter().map(transform_fn)) } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 51d4a4d975..c16e1c1a11 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1516,6 +1516,96 @@ impl<'db> TypeInferenceBuilder<'db> { } } + fn infer_augmented_op( + &mut self, + assignment: &ast::StmtAugAssign, + target_type: Type<'db>, + value_type: Type<'db>, + ) -> Type<'db> { + // If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that + // dunder. + let op = assignment.op; + match target_type { + Type::Union(union) => { + return union.map(self.db, |&target_type| { + self.infer_augmented_op(assignment, target_type, value_type) + }) + } + Type::Instance(class) => { + if let Symbol::Type(class_member, boundness) = + class.class_member(self.db, op.in_place_dunder()) + { + let call = class_member.call(self.db, &[target_type, value_type]); + let augmented_return_ty = match call.return_ty_result( + self.db, + AnyNodeRef::StmtAugAssign(assignment), + &mut self.diagnostics, + ) { + Ok(t) => t, + Err(e) => { + self.diagnostics.add( + assignment.into(), + "unsupported-operator", + format_args!( + "Operator `{op}=` is unsupported between objects of type `{}` and `{}`", + target_type.display(self.db), + value_type.display(self.db) + ), + ); + e.return_ty() + } + }; + + return match boundness { + Boundness::Bound => augmented_return_ty, + Boundness::MayBeUnbound => { + let left_ty = target_type; + let right_ty = value_type; + + let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, op) + .unwrap_or_else(|| { + self.diagnostics.add( + assignment.into(), + "unsupported-operator", + format_args!( + "Operator `{op}=` is unsupported between objects of type `{}` and `{}`", + left_ty.display(self.db), + right_ty.display(self.db) + ), + ); + Type::Unknown + }); + + UnionType::from_elements( + self.db, + [augmented_return_ty, binary_return_ty], + ) + } + }; + } + } + _ => {} + } + + // By default, fall back to non-augmented binary operator inference. + let left_ty = target_type; + let right_ty = value_type; + + self.infer_binary_expression_type(left_ty, right_ty, op) + .unwrap_or_else(|| { + self.diagnostics.add( + assignment.into(), + "unsupported-operator", + format_args!( + "Operator `{op}=` is unsupported between objects of type `{}` and `{}`", + left_ty.display(self.db), + right_ty.display(self.db) + ), + ); + Type::Unknown + }) + } + fn infer_augment_assignment_definition( &mut self, assignment: &ast::StmtAugAssign, @@ -1529,7 +1619,7 @@ impl<'db> TypeInferenceBuilder<'db> { let ast::StmtAugAssign { range: _, target, - op, + op: _, value, } = assignment; @@ -1547,75 +1637,7 @@ impl<'db> TypeInferenceBuilder<'db> { }; let value_type = self.infer_expression(value); - // If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that - // dunder. - if let Type::Instance(class) = target_type { - if let Symbol::Type(class_member, boundness) = - class.class_member(self.db, op.in_place_dunder()) - { - let call = class_member.call(self.db, &[target_type, value_type]); - let augmented_return_ty = match call.return_ty_result( - self.db, - AnyNodeRef::StmtAugAssign(assignment), - &mut self.diagnostics, - ) { - Ok(t) => t, - Err(e) => { - self.diagnostics.add( - assignment.into(), - "unsupported-operator", - format_args!( - "Operator `{op}=` is unsupported between objects of type `{}` and `{}`", - target_type.display(self.db), - value_type.display(self.db) - ), - ); - e.return_ty() - } - }; - - return match boundness { - Boundness::Bound => augmented_return_ty, - Boundness::MayBeUnbound => { - let left_ty = target_type; - let right_ty = value_type; - - let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, *op) - .unwrap_or_else(|| { - self.diagnostics.add( - assignment.into(), - "unsupported-operator", - format_args!( - "Operator `{op}=` is unsupported between objects of type `{}` and `{}`", - left_ty.display(self.db), - right_ty.display(self.db) - ), - ); - Type::Unknown - }); - - UnionType::from_elements(self.db, [augmented_return_ty, binary_return_ty]) - } - }; - } - } - - let left_ty = target_type; - let right_ty = value_type; - - self.infer_binary_expression_type(left_ty, right_ty, *op) - .unwrap_or_else(|| { - self.diagnostics.add( - assignment.into(), - "unsupported-operator", - format_args!( - "Operator `{op}=` is unsupported between objects of type `{}` and `{}`", - left_ty.display(self.db), - right_ty.display(self.db) - ), - ); - Type::Unknown - }) + self.infer_augmented_op(assignment, target_type, value_type) } fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {