From 7568eeb7a5fcb77af740c8ed0d77eef2f946d479 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Wed, 30 Apr 2025 20:34:21 +0530 Subject: [PATCH] [red-knot] Check decorator consistency on overloads (#17684) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Part of #15383. As per the spec (https://typing.python.org/en/latest/spec/overload.html#invalid-overload-definitions): For `@staticmethod` and `@classmethod`: > If one overload signature is decorated with `@staticmethod` or `@classmethod`, all overload signatures must be similarly decorated. The implementation, if present, must also have a consistent decorator. Type checkers should report an error if these conditions are not met. For `@final` and `@override`: > If a `@final` or `@override` decorator is supplied for a function with overloads, the decorator should be applied only to the overload implementation if it is present. If an overload implementation isn’t present (for example, in a stub file), the `@final` or `@override` decorator should be applied only to the first overload. Type checkers should enforce these rules and generate an error when they are violated. If a `@final` or `@override` decorator follows these rules, a type checker should treat the decorator as if it is present on all overloads. ## Test Plan Update existing tests; add snapshots. --- .../resources/mdtest/overloads.md | 146 +++++++++++------- ...onsistent_decorators_-_`@classmethod`.snap | 128 +++++++++++++++ ..._-_Inconsistent_decorators_-_`@final`.snap | 111 +++++++++++++ ...Inconsistent_decorators_-_`@override`.snap | 129 ++++++++++++++++ crates/red_knot_python_semantic/src/types.rs | 7 + .../src/types/infer.rs | 97 ++++++++++++ 6 files changed, 561 insertions(+), 57 deletions(-) create mode 100644 crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@classmethod`.snap create mode 100644 crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@final`.snap create mode 100644 crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@override`.snap diff --git a/crates/red_knot_python_semantic/resources/mdtest/overloads.md b/crates/red_knot_python_semantic/resources/mdtest/overloads.md index 4089040d9b..c26b9cf42b 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/overloads.md +++ b/crates/red_knot_python_semantic/resources/mdtest/overloads.md @@ -438,11 +438,10 @@ class PartialFoo(ABC): ### Inconsistent decorators -#### `@staticmethod` / `@classmethod` +#### `@staticmethod` -If one overload signature is decorated with `@staticmethod` or `@classmethod`, all overload -signatures must be similarly decorated. The implementation, if present, must also have a consistent -decorator. +If one overload signature is decorated with `@staticmethod`, all overload signatures must be +similarly decorated. The implementation, if present, must also have a consistent decorator. ```py from __future__ import annotations @@ -486,39 +485,54 @@ class CheckStaticMethod: @staticmethod def method4(x: int | str) -> int | str: return x +``` + +#### `@classmethod` + + + +The same rules apply for `@classmethod` as for [`@staticmethod`](#staticmethod). + +```py +from __future__ import annotations + +from typing import overload class CheckClassMethod: def __init__(self, x: int) -> None: self.x = x - # TODO: error because `@classmethod` does not exist on all overloads + @overload @classmethod def try_from1(cls, x: int) -> CheckClassMethod: ... @overload def try_from1(cls, x: str) -> None: ... @classmethod + # error: [invalid-overload] "Overloaded function `try_from1` does not use the `@classmethod` decorator consistently" def try_from1(cls, x: int | str) -> CheckClassMethod | None: if isinstance(x, int): return cls(x) return None - # TODO: error because `@classmethod` does not exist on all overloads + @overload def try_from2(cls, x: int) -> CheckClassMethod: ... @overload @classmethod def try_from2(cls, x: str) -> None: ... @classmethod + # error: [invalid-overload] def try_from2(cls, x: int | str) -> CheckClassMethod | None: if isinstance(x, int): return cls(x) return None - # TODO: error because `@classmethod` does not exist on the implementation + @overload @classmethod def try_from3(cls, x: int) -> CheckClassMethod: ... @overload @classmethod def try_from3(cls, x: str) -> None: ... + # error: [invalid-overload] def try_from3(cls, x: int | str) -> CheckClassMethod | None: if isinstance(x, int): return cls(x) @@ -537,13 +551,15 @@ class CheckClassMethod: return None ``` -#### `@final` / `@override` +#### `@final` -If a `@final` or `@override` decorator is supplied for a function with overloads, the decorator -should be applied only to the overload implementation if it is present. + + +If a `@final` decorator is supplied for a function with overloads, the decorator should be applied +only to the overload implementation if it is present. ```py -from typing_extensions import final, overload, override +from typing_extensions import final, overload class Foo: @overload @@ -553,68 +569,31 @@ class Foo: @final def method1(self, x: int | str) -> int | str: return x - # TODO: error because `@final` is not on the implementation + @overload @final def method2(self, x: int) -> int: ... @overload def method2(self, x: str) -> str: ... + # error: [invalid-overload] def method2(self, x: int | str) -> int | str: return x - # TODO: error because `@final` is not on the implementation + @overload def method3(self, x: int) -> int: ... @overload @final def method3(self, x: str) -> str: ... + # error: [invalid-overload] def method3(self, x: int | str) -> int | str: return x - -class Base: - @overload - def method(self, x: int) -> int: ... - @overload - def method(self, x: str) -> str: ... - def method(self, x: int | str) -> int | str: - return x - -class Sub1(Base): - @overload - def method(self, x: int) -> int: ... - @overload - def method(self, x: str) -> str: ... - @override - def method(self, x: int | str) -> int | str: - return x - -class Sub2(Base): - # TODO: error because `@override` is not on the implementation - @overload - def method(self, x: int) -> int: ... - @overload - @override - def method(self, x: str) -> str: ... - def method(self, x: int | str) -> int | str: - return x - -class Sub3(Base): - # TODO: error because `@override` is not on the implementation - @overload - @override - def method(self, x: int) -> int: ... - @overload - def method(self, x: str) -> str: ... - def method(self, x: int | str) -> int | str: - return x ``` -#### `@final` / `@override` in stub files - -If an overload implementation isn’t present (for example, in a stub file), the `@final` or -`@override` decorator should be applied only to the first overload. +If an overload implementation isn't present (for example, in a stub file), the `@final` decorator +should be applied only to the first overload. ```pyi -from typing_extensions import final, overload, override +from typing_extensions import final, overload class Foo: @overload @@ -623,12 +602,65 @@ class Foo: @overload def method1(self, x: str) -> str: ... - # TODO: error because `@final` is not on the first overload @overload def method2(self, x: int) -> int: ... @final @overload + # error: [invalid-overload] def method2(self, x: str) -> str: ... +``` + +#### `@override` + + + +The same rules apply for `@override` as for [`@final`](#final). + +```py +from typing_extensions import overload, override + +class Base: + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + def method(self, x: int | str) -> int | str: + return x + +class Sub1(Base): + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + @override + def method(self, x: int | str) -> int | str: + return x + +class Sub2(Base): + @overload + def method(self, x: int) -> int: ... + @overload + @override + def method(self, x: str) -> str: ... + # error: [invalid-overload] + def method(self, x: int | str) -> int | str: + return x + +class Sub3(Base): + @overload + @override + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + # error: [invalid-overload] + def method(self, x: int | str) -> int | str: + return x +``` + +And, similarly, in stub files: + +```pyi +from typing_extensions import overload, override class Base: @overload @@ -644,10 +676,10 @@ class Sub1(Base): def method(self, x: str) -> str: ... class Sub2(Base): - # TODO: error because `@override` is not on the first overload @overload def method(self, x: int) -> int: ... @overload @override + # error: [invalid-overload] def method(self, x: str) -> str: ... ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@classmethod`.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@classmethod`.snap new file mode 100644 index 0000000000..aef6a58ff1 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@classmethod`.snap @@ -0,0 +1,128 @@ +--- +source: crates/red_knot_test/src/lib.rs +expression: snapshot +--- +--- +mdtest name: overloads.md - Overloads - Invalid - Inconsistent decorators - `@classmethod` +mdtest path: crates/red_knot_python_semantic/resources/mdtest/overloads.md +--- + +# Python source files + +## mdtest_snippet.py + +``` + 1 | from __future__ import annotations + 2 | + 3 | from typing import overload + 4 | + 5 | class CheckClassMethod: + 6 | def __init__(self, x: int) -> None: + 7 | self.x = x + 8 | + 9 | @overload +10 | @classmethod +11 | def try_from1(cls, x: int) -> CheckClassMethod: ... +12 | @overload +13 | def try_from1(cls, x: str) -> None: ... +14 | @classmethod +15 | # error: [invalid-overload] "Overloaded function `try_from1` does not use the `@classmethod` decorator consistently" +16 | def try_from1(cls, x: int | str) -> CheckClassMethod | None: +17 | if isinstance(x, int): +18 | return cls(x) +19 | return None +20 | +21 | @overload +22 | def try_from2(cls, x: int) -> CheckClassMethod: ... +23 | @overload +24 | @classmethod +25 | def try_from2(cls, x: str) -> None: ... +26 | @classmethod +27 | # error: [invalid-overload] +28 | def try_from2(cls, x: int | str) -> CheckClassMethod | None: +29 | if isinstance(x, int): +30 | return cls(x) +31 | return None +32 | +33 | @overload +34 | @classmethod +35 | def try_from3(cls, x: int) -> CheckClassMethod: ... +36 | @overload +37 | @classmethod +38 | def try_from3(cls, x: str) -> None: ... +39 | # error: [invalid-overload] +40 | def try_from3(cls, x: int | str) -> CheckClassMethod | None: +41 | if isinstance(x, int): +42 | return cls(x) +43 | return None +44 | +45 | @overload +46 | @classmethod +47 | def try_from4(cls, x: int) -> CheckClassMethod: ... +48 | @overload +49 | @classmethod +50 | def try_from4(cls, x: str) -> None: ... +51 | @classmethod +52 | def try_from4(cls, x: int | str) -> CheckClassMethod | None: +53 | if isinstance(x, int): +54 | return cls(x) +55 | return None +``` + +# Diagnostics + +``` +error: lint:invalid-overload: Overloaded function `try_from3` does not use the `@classmethod` decorator consistently + --> src/mdtest_snippet.py:40:9 + | +38 | def try_from3(cls, x: str) -> None: ... +39 | # error: [invalid-overload] +40 | def try_from3(cls, x: int | str) -> CheckClassMethod | None: + | --------- + | | + | Missing here +41 | if isinstance(x, int): +42 | return cls(x) + | + +``` + +``` +error: lint:invalid-overload: Overloaded function `try_from1` does not use the `@classmethod` decorator consistently + --> src/mdtest_snippet.py:13:9 + | +11 | def try_from1(cls, x: int) -> CheckClassMethod: ... +12 | @overload +13 | def try_from1(cls, x: str) -> None: ... + | --------- Missing here +14 | @classmethod +15 | # error: [invalid-overload] "Overloaded function `try_from1` does not use the `@classmethod` decorator consistently" +16 | def try_from1(cls, x: int | str) -> CheckClassMethod | None: + | ^^^^^^^^^ +17 | if isinstance(x, int): +18 | return cls(x) + | + +``` + +``` +error: lint:invalid-overload: Overloaded function `try_from2` does not use the `@classmethod` decorator consistently + --> src/mdtest_snippet.py:28:9 + | +26 | @classmethod +27 | # error: [invalid-overload] +28 | def try_from2(cls, x: int | str) -> CheckClassMethod | None: + | ^^^^^^^^^ +29 | if isinstance(x, int): +30 | return cls(x) + | + ::: src/mdtest_snippet.py:22:9 + | +21 | @overload +22 | def try_from2(cls, x: int) -> CheckClassMethod: ... + | --------- Missing here +23 | @overload +24 | @classmethod + | + +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@final`.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@final`.snap new file mode 100644 index 0000000000..97692ad5b8 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@final`.snap @@ -0,0 +1,111 @@ +--- +source: crates/red_knot_test/src/lib.rs +expression: snapshot +--- +--- +mdtest name: overloads.md - Overloads - Invalid - Inconsistent decorators - `@final` +mdtest path: crates/red_knot_python_semantic/resources/mdtest/overloads.md +--- + +# Python source files + +## mdtest_snippet.py + +``` + 1 | from typing_extensions import final, overload + 2 | + 3 | class Foo: + 4 | @overload + 5 | def method1(self, x: int) -> int: ... + 6 | @overload + 7 | def method1(self, x: str) -> str: ... + 8 | @final + 9 | def method1(self, x: int | str) -> int | str: +10 | return x +11 | +12 | @overload +13 | @final +14 | def method2(self, x: int) -> int: ... +15 | @overload +16 | def method2(self, x: str) -> str: ... +17 | # error: [invalid-overload] +18 | def method2(self, x: int | str) -> int | str: +19 | return x +20 | +21 | @overload +22 | def method3(self, x: int) -> int: ... +23 | @overload +24 | @final +25 | def method3(self, x: str) -> str: ... +26 | # error: [invalid-overload] +27 | def method3(self, x: int | str) -> int | str: +28 | return x +``` + +## mdtest_snippet.pyi + +``` + 1 | from typing_extensions import final, overload + 2 | + 3 | class Foo: + 4 | @overload + 5 | @final + 6 | def method1(self, x: int) -> int: ... + 7 | @overload + 8 | def method1(self, x: str) -> str: ... + 9 | +10 | @overload +11 | def method2(self, x: int) -> int: ... +12 | @final +13 | @overload +14 | # error: [invalid-overload] +15 | def method2(self, x: str) -> str: ... +``` + +# Diagnostics + +``` +error: lint:invalid-overload: `@final` decorator should be applied only to the overload implementation + --> src/mdtest_snippet.py:27:9 + | +25 | def method3(self, x: str) -> str: ... +26 | # error: [invalid-overload] +27 | def method3(self, x: int | str) -> int | str: + | ------- + | | + | Implementation defined here +28 | return x + | + +``` + +``` +error: lint:invalid-overload: `@final` decorator should be applied only to the overload implementation + --> src/mdtest_snippet.py:18:9 + | +16 | def method2(self, x: str) -> str: ... +17 | # error: [invalid-overload] +18 | def method2(self, x: int | str) -> int | str: + | ------- + | | + | Implementation defined here +19 | return x + | + +``` + +``` +error: lint:invalid-overload: `@final` decorator should be applied only to the first overload + --> src/mdtest_snippet.pyi:11:9 + | +10 | @overload +11 | def method2(self, x: int) -> int: ... + | ------- First overload defined here +12 | @final +13 | @overload +14 | # error: [invalid-overload] +15 | def method2(self, x: str) -> str: ... + | ^^^^^^^ + | + +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@override`.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@override`.snap new file mode 100644 index 0000000000..bae8cd1bc7 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_Inconsistent_decorators_-_`@override`.snap @@ -0,0 +1,129 @@ +--- +source: crates/red_knot_test/src/lib.rs +expression: snapshot +--- +--- +mdtest name: overloads.md - Overloads - Invalid - Inconsistent decorators - `@override` +mdtest path: crates/red_knot_python_semantic/resources/mdtest/overloads.md +--- + +# Python source files + +## mdtest_snippet.py + +``` + 1 | from typing_extensions import overload, override + 2 | + 3 | class Base: + 4 | @overload + 5 | def method(self, x: int) -> int: ... + 6 | @overload + 7 | def method(self, x: str) -> str: ... + 8 | def method(self, x: int | str) -> int | str: + 9 | return x +10 | +11 | class Sub1(Base): +12 | @overload +13 | def method(self, x: int) -> int: ... +14 | @overload +15 | def method(self, x: str) -> str: ... +16 | @override +17 | def method(self, x: int | str) -> int | str: +18 | return x +19 | +20 | class Sub2(Base): +21 | @overload +22 | def method(self, x: int) -> int: ... +23 | @overload +24 | @override +25 | def method(self, x: str) -> str: ... +26 | # error: [invalid-overload] +27 | def method(self, x: int | str) -> int | str: +28 | return x +29 | +30 | class Sub3(Base): +31 | @overload +32 | @override +33 | def method(self, x: int) -> int: ... +34 | @overload +35 | def method(self, x: str) -> str: ... +36 | # error: [invalid-overload] +37 | def method(self, x: int | str) -> int | str: +38 | return x +``` + +## mdtest_snippet.pyi + +``` + 1 | from typing_extensions import overload, override + 2 | + 3 | class Base: + 4 | @overload + 5 | def method(self, x: int) -> int: ... + 6 | @overload + 7 | def method(self, x: str) -> str: ... + 8 | + 9 | class Sub1(Base): +10 | @overload +11 | @override +12 | def method(self, x: int) -> int: ... +13 | @overload +14 | def method(self, x: str) -> str: ... +15 | +16 | class Sub2(Base): +17 | @overload +18 | def method(self, x: int) -> int: ... +19 | @overload +20 | @override +21 | # error: [invalid-overload] +22 | def method(self, x: str) -> str: ... +``` + +# Diagnostics + +``` +error: lint:invalid-overload: `@override` decorator should be applied only to the overload implementation + --> src/mdtest_snippet.py:27:9 + | +25 | def method(self, x: str) -> str: ... +26 | # error: [invalid-overload] +27 | def method(self, x: int | str) -> int | str: + | ------ + | | + | Implementation defined here +28 | return x + | + +``` + +``` +error: lint:invalid-overload: `@override` decorator should be applied only to the overload implementation + --> src/mdtest_snippet.py:37:9 + | +35 | def method(self, x: str) -> str: ... +36 | # error: [invalid-overload] +37 | def method(self, x: int | str) -> int | str: + | ------ + | | + | Implementation defined here +38 | return x + | + +``` + +``` +error: lint:invalid-overload: `@override` decorator should be applied only to the first overload + --> src/mdtest_snippet.pyi:18:9 + | +16 | class Sub2(Base): +17 | @overload +18 | def method(self, x: int) -> int: ... + | ------ First overload defined here +19 | @overload +20 | @override +21 | # error: [invalid-overload] +22 | def method(self, x: str) -> str: ... + | ^^^^^^ + | + +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 5d2c430fa8..007904f177 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -6493,6 +6493,13 @@ struct OverloadedFunction<'db> { implementation: Option>, } +impl<'db> OverloadedFunction<'db> { + /// Returns an iterator over all overloads and the implementation, in that order. + fn all(&self) -> impl Iterator> + '_ { + self.overloads.iter().copied().chain(self.implementation) + } +} + #[salsa::interned(debug)] pub struct FunctionType<'db> { /// Name of the function at definition. diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index b9c56a02af..9b5846d83c 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1096,6 +1096,103 @@ impl<'db> TypeInferenceBuilder<'db> { } } } + + // TODO: Add `@staticmethod` + for (decorator, name) in [(FunctionDecorators::CLASSMETHOD, "classmethod")] { + let mut decorator_present = false; + let mut decorator_missing = vec![]; + + for function in overloaded.all() { + if function.has_known_decorator(self.db(), decorator) { + decorator_present = true; + } else { + decorator_missing.push(function); + } + } + + if !decorator_present { + // Both overloads and implementation does not have the decorator + continue; + } + if decorator_missing.is_empty() { + // All overloads and implementation have the decorator + continue; + } + + let function_node = function.node(self.db(), self.file()); + if let Some(builder) = self + .context + .report_lint(&INVALID_OVERLOAD, &function_node.name) + { + let mut diagnostic = builder.into_diagnostic(format_args!( + "Overloaded function `{}` does not use the `@{name}` decorator \ + consistently", + &function_node.name + )); + for function in decorator_missing { + diagnostic.annotate( + self.context + .secondary(function.focus_range(self.db())) + .message(format_args!("Missing here")), + ); + } + } + } + + for (decorator, name) in [ + (FunctionDecorators::FINAL, "final"), + (FunctionDecorators::OVERRIDE, "override"), + ] { + if let Some(implementation) = overloaded.implementation.as_ref() { + for overload in &overloaded.overloads { + if !overload.has_known_decorator(self.db(), decorator) { + continue; + } + let function_node = function.node(self.db(), self.file()); + let Some(builder) = self + .context + .report_lint(&INVALID_OVERLOAD, &function_node.name) + else { + continue; + }; + let mut diagnostic = builder.into_diagnostic(format_args!( + "`@{name}` decorator should be applied only to the \ + overload implementation" + )); + diagnostic.annotate( + self.context + .secondary(implementation.focus_range(self.db())) + .message(format_args!("Implementation defined here")), + ); + } + } else { + let mut overloads = overloaded.overloads.iter(); + let Some(first_overload) = overloads.next() else { + continue; + }; + for overload in overloads { + if !overload.has_known_decorator(self.db(), decorator) { + continue; + } + let function_node = function.node(self.db(), self.file()); + let Some(builder) = self + .context + .report_lint(&INVALID_OVERLOAD, &function_node.name) + else { + continue; + }; + let mut diagnostic = builder.into_diagnostic(format_args!( + "`@{name}` decorator should be applied only to the \ + first overload" + )); + diagnostic.annotate( + self.context + .secondary(first_overload.focus_range(self.db())) + .message(format_args!("First overload defined here")), + ); + } + } + } } }