## Summary

Part of #15383, this PR adds the core infrastructure for checking invalid overloads, along with a diagnostic that is raised when a given definition has fewer than two overloads.

### Design notes

The requirements for checking the overloads are:

* It requires a `FunctionType`, which has the `to_overloaded` method
* The `FunctionType` **should** be for the function that is either the implementation or, if the implementation doesn't exist, the last overload
* Avoid checking any other `FunctionType` that is part of the same overload chain
* Consider visibility constraints

This took a couple of iterations to make sure that all of the above requirements are fulfilled.

#### 1. Use a set to deduplicate

The logic would first collect every `FunctionType` that is part of an overload chain, except for the implementation (or the last overload, if the implementation doesn't exist). Then, when iterating over all the function declarations within the scope, we'd skip checking these functions. However, this approach fails to consider visibility constraints, because certain overloads _can_ be behind a version check. Those aren't part of the overload chain, but they aren't a separate overload chain either.

<details><summary>Implementation:</summary>
<p>

```rs
fn check_overloaded_functions(&mut self) {
    let function_definitions = || {
        self.types
            .declarations
            .iter()
            .filter_map(|(definition, ty)| {
                // Filter out function literals that result from anything other than a function
                // definition e.g., imports.
                if let DefinitionKind::Function(function) = definition.kind(self.db()) {
                    ty.inner_type()
                        .into_function_literal()
                        .map(|ty| (ty, definition.symbol(self.db()), function.node()))
                } else {
                    None
                }
            })
    };

    // A set of all the functions that are part of an overloaded function definition except for
    // the implementation function and the last overload in case the implementation doesn't
    // exist. This allows us to collect all the function definitions that need to be skipped
    // when checking for invalid overload usages.
    let mut overloads: HashSet<FunctionType<'db>> = HashSet::default();

    for (function, _, _) in function_definitions() {
        let Some(overloaded) = function.to_overloaded(self.db()) else {
            continue;
        };
        if overloaded.implementation.is_some() {
            overloads.extend(overloaded.overloads.iter().copied());
        } else if let Some((_, previous_overloads)) = overloaded.overloads.split_last() {
            overloads.extend(previous_overloads.iter().copied());
        }
    }

    for (function, _, function_node) in function_definitions() {
        let Some(overloaded) = function.to_overloaded(self.db()) else {
            continue;
        };
        if overloads.contains(&function) {
            continue;
        }

        // At this point, the `function` variable is either the implementation function or the
        // last overloaded function if the implementation doesn't exist.

        if overloaded.overloads.len() < 2 {
            if let Some(builder) = self
                .context
                .report_lint(&INVALID_OVERLOAD, &function_node.name)
            {
                let mut diagnostic = builder.into_diagnostic(format_args!(
                    "Function `{}` requires at least two overloads",
                    &function_node.name
                ));
                if let Some(first_overload) = overloaded.overloads.first() {
                    diagnostic.annotate(
                        self.context
                            .secondary(first_overload.focus_range(self.db()))
                            .message(format_args!("Only one overload defined here")),
                    );
                }
            }
        }
    }
}
```

</p>
</details>

#### 2. Define a `predecessor` query

The `predecessor` query would return the previous `FunctionType` for a given `FunctionType`. In other words, the current logic would be extracted into a query. This could then be used to make sure that the entire overload chain is checked only once. It would have been implemented with a `to_overloaded` variant that takes the root of the overload chain instead of the leaf. However, this would require updating the use-def map so that it can somehow return the _following_ functions for a given definition.

#### 3. Create a successor link

This is what Pyrefly uses: we'd create a forward link between consecutive functions in an overload chain.
This means that, for a given function, we can get its successor function. This could be used to find the _leaf_ of the overload chain, which can then be passed to the `to_overloaded` method to get the entire overload chain. However, this would also require updating the use-def map so that it can "see" the _following_ function.

### Implementation

This leads to the final approach implemented in this PR, which handles overloaded functions as follows:

* Collect all the **function symbols** that are defined **and** called within the same file. Each of these could potentially be an overloaded function.
* Use the public bindings to get the leaf of the overload chain, use that leaf to get the entire overload chain via `to_overloaded`, and perform the check.

This approach has a limitation: if a function redefines an overload, that overload will not be checked. For example:

```py
from typing import overload

@overload
def f() -> None: ...
@overload
def f(x: int) -> int: ...

# The above overload will not be checked because the function below with the same name
# shadows it

def f(*args: int) -> int: ...
```

## Test Plan

Update existing mdtest and add snapshot diagnostics.
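The following is a small, hypothetical Python model of the leaf-based approach from the Implementation section. None of these names come from ty. `FunctionDef`, its `previous` link, `OverloadChain` and `check_overloads` are invented for illustration, and `to_overloaded` only loosely mirrors the Rust method of the same name. The model assumes that, as with the use-def map, a definition can only look *backwards* at the binding it shadows. For that reason, the walk has to start from the public binding (the leaf).

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FunctionDef:
    """Hypothetical stand-in for one `def` statement (not ty's `FunctionType`)."""

    name: str
    signature: str
    is_overload: bool
    # The earlier binding of the same name that this definition shadows.
    # Like the use-def map, this only lets us look *backwards*.
    previous: Optional[FunctionDef] = None


@dataclass
class OverloadChain:
    overloads: list[FunctionDef]
    implementation: Optional[FunctionDef]


def to_overloaded(leaf: FunctionDef) -> Optional[OverloadChain]:
    """Walk backwards from `leaf`, collecting the contiguous run of overloads."""
    implementation = None if leaf.is_overload else leaf
    # If the leaf is itself an overload, it is the last overload of the chain.
    current = leaf if implementation is None else leaf.previous
    overloads: list[FunctionDef] = []
    while current is not None and current.is_overload:
        overloads.append(current)
        current = current.previous
    if not overloads:
        return None  # a plain function, not an overloaded one
    overloads.reverse()  # restore source order
    return OverloadChain(overloads, implementation)


def check_overloads(public_bindings: dict[str, FunctionDef]) -> list[str]:
    """Check each symbol exactly once, starting from its public binding (the leaf)."""
    errors = []
    for name, leaf in public_bindings.items():
        chain = to_overloaded(leaf)
        if chain is not None and len(chain.overloads) < 2:
            errors.append(f"Function `{name}` requires at least two overloads")
    return errors


# Usage: a single overload followed by an implementation is reported...
g_overload = FunctionDef("g", "(x: int) -> int", is_overload=True)
g_impl = FunctionDef("g", "(x: int | str) -> int | str", is_overload=False, previous=g_overload)
assert check_overloads({"g": g_impl}) == ["Function `g` requires at least two overloads"]

# ...but once a plain function shadows that chain, walking back from the public
# binding never reaches it, so the invalid overload goes unreported.
g_shadow = FunctionDef("g", "(*args: int) -> int", is_overload=False, previous=g_impl)
assert check_overloads({"g": g_shadow}) == []
```

Starting from the leaf means each chain is visited once, without a deduplication set. The trade-off is the shadowing limitation shown in the last case.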
# Overloads

Reference: <https://typing.python.org/en/latest/spec/overload.html>

## `typing.overload`

The definition of `typing.overload` in typeshed is an identity function.

```py
from typing import overload

def foo(x: int) -> int:
    return x

reveal_type(foo)  # revealed: def foo(x: int) -> int
bar = overload(foo)
reveal_type(bar)  # revealed: def foo(x: int) -> int
```
## Functions

```py
from typing import overload

@overload
def add() -> None: ...
@overload
def add(x: int) -> int: ...
@overload
def add(x: int, y: int) -> int: ...
def add(x: int | None = None, y: int | None = None) -> int | None:
    return (x or 0) + (y or 0)

reveal_type(add)  # revealed: Overload[() -> None, (x: int) -> int, (x: int, y: int) -> int]
reveal_type(add())  # revealed: None
reveal_type(add(1))  # revealed: int
reveal_type(add(1, 2))  # revealed: int
```
## Overriding

These scenarios verify that overloaded and non-overloaded definitions correctly override each other.

An overloaded function is overriding another overloaded function:

```py
from typing import overload

@overload
def foo() -> None: ...
@overload
def foo(x: int) -> int: ...
def foo(x: int | None = None) -> int | None:
    return x

reveal_type(foo)  # revealed: Overload[() -> None, (x: int) -> int]
reveal_type(foo())  # revealed: None
reveal_type(foo(1))  # revealed: int

@overload
def foo() -> None: ...
@overload
def foo(x: str) -> str: ...
def foo(x: str | None = None) -> str | None:
    return x

reveal_type(foo)  # revealed: Overload[() -> None, (x: str) -> str]
reveal_type(foo())  # revealed: None
reveal_type(foo(""))  # revealed: str
```

A non-overloaded function is overriding an overloaded function:

```py
def foo(x: int) -> int:
    return x

reveal_type(foo)  # revealed: def foo(x: int) -> int
```

An overloaded function is overriding a non-overloaded function:

```py
reveal_type(foo)  # revealed: def foo(x: int) -> int

@overload
def foo() -> None: ...
@overload
def foo(x: bytes) -> bytes: ...
def foo(x: bytes | None = None) -> bytes | None:
    return x

reveal_type(foo)  # revealed: Overload[() -> None, (x: bytes) -> bytes]
reveal_type(foo())  # revealed: None
reveal_type(foo(b""))  # revealed: bytes
```
## Methods

```py
from typing import overload

class Foo1:
    @overload
    def method(self) -> None: ...
    @overload
    def method(self, x: int) -> int: ...
    def method(self, x: int | None = None) -> int | None:
        return x

foo1 = Foo1()
reveal_type(foo1.method)  # revealed: Overload[() -> None, (x: int) -> int]
reveal_type(foo1.method())  # revealed: None
reveal_type(foo1.method(1))  # revealed: int

class Foo2:
    @overload
    def method(self) -> None: ...
    @overload
    def method(self, x: str) -> str: ...
    def method(self, x: str | None = None) -> str | None:
        return x

foo2 = Foo2()
reveal_type(foo2.method)  # revealed: Overload[() -> None, (x: str) -> str]
reveal_type(foo2.method())  # revealed: None
reveal_type(foo2.method(""))  # revealed: str
```
## Constructor

```py
from typing import overload

class Foo:
    @overload
    def __init__(self) -> None: ...
    @overload
    def __init__(self, x: int) -> None: ...
    def __init__(self, x: int | None = None) -> None:
        self.x = x

foo = Foo()
reveal_type(foo)  # revealed: Foo
reveal_type(foo.x)  # revealed: Unknown | int | None

foo1 = Foo(1)
reveal_type(foo1)  # revealed: Foo
reveal_type(foo1.x)  # revealed: Unknown | int | None
```
## Version specific

Function definitions can vary between multiple Python versions.

### Overload and non-overload (3.9)

Here, the same function is overloaded in one version and not in another.

```toml
[environment]
python-version = "3.9"
```

```py
import sys
from typing import overload

if sys.version_info < (3, 10):
    def func(x: int) -> int:
        return x
elif sys.version_info <= (3, 12):
    @overload
    def func() -> None: ...
    @overload
    def func(x: int) -> int: ...
    def func(x: int | None = None) -> int | None:
        return x

reveal_type(func)  # revealed: def func(x: int) -> int
func()  # error: [missing-argument]
```
### Overload and non-overload (3.10)

```toml
[environment]
python-version = "3.10"
```

```py
import sys
from typing import overload

if sys.version_info < (3, 10):
    def func(x: int) -> int:
        return x
elif sys.version_info <= (3, 12):
    @overload
    def func() -> None: ...
    @overload
    def func(x: int) -> int: ...
    def func(x: int | None = None) -> int | None:
        return x

reveal_type(func)  # revealed: Overload[() -> None, (x: int) -> int]
reveal_type(func())  # revealed: None
reveal_type(func(1))  # revealed: int
```
### Some overloads are version specific (3.9)

```toml
[environment]
python-version = "3.9"
```

`overloaded.pyi`:

```pyi
import sys
from typing import overload

if sys.version_info >= (3, 10):
    @overload
    def func() -> None: ...

@overload
def func(x: int) -> int: ...
@overload
def func(x: str) -> str: ...
```

`main.py`:

```py
from overloaded import func

reveal_type(func)  # revealed: Overload[(x: int) -> int, (x: str) -> str]
func()  # error: [no-matching-overload]
reveal_type(func(1))  # revealed: int
reveal_type(func(""))  # revealed: str
```
### Some overloads are version specific (3.10)

```toml
[environment]
python-version = "3.10"
```

`overloaded.pyi`:

```pyi
import sys
from typing import overload

@overload
def func() -> None: ...

if sys.version_info >= (3, 10):
    @overload
    def func(x: int) -> int: ...

@overload
def func(x: str) -> str: ...
```

`main.py`:

```py
from overloaded import func

reveal_type(func)  # revealed: Overload[() -> None, (x: int) -> int, (x: str) -> str]
reveal_type(func())  # revealed: None
reveal_type(func(1))  # revealed: int
reveal_type(func(""))  # revealed: str
```
## Generic

```toml
[environment]
python-version = "3.12"
```

For an overloaded generic function, it's not necessary for all overloads to be generic.

```py
from typing import overload

@overload
def func() -> None: ...
@overload
def func[T](x: T) -> T: ...
def func[T](x: T | None = None) -> T | None:
    return x

reveal_type(func)  # revealed: Overload[() -> None, (x: T) -> T]
reveal_type(func())  # revealed: None
reveal_type(func(1))  # revealed: Literal[1]
reveal_type(func(""))  # revealed: Literal[""]
```
## Invalid

### At least two overloads

At least two `@overload`-decorated definitions must be present.

```py
from typing import overload

@overload
def func(x: int) -> int: ...

# error: [invalid-overload]
def func(x: int | str) -> int | str:
    return x
```

```pyi
from typing import overload

@overload
# error: [invalid-overload]
def func(x: int) -> int: ...
```
### Overload without an implementation

#### Regular modules

In regular modules, a series of `@overload`-decorated definitions must be followed by exactly one non-`@overload`-decorated definition (for the same function/method).

```py
from typing import overload

# TODO: error because implementation does not exist
@overload
def func(x: int) -> int: ...
@overload
def func(x: str) -> str: ...

class Foo:
    # TODO: error because implementation does not exist
    @overload
    def method(self, x: int) -> int: ...
    @overload
    def method(self, x: str) -> str: ...
```
#### Stub files

Overload definitions within stub files are exempt from this check.

```pyi
from typing import overload

@overload
def func(x: int) -> int: ...
@overload
def func(x: str) -> str: ...
```
#### Protocols

Overload definitions within protocols are exempt from this check.

```py
from typing import Protocol, overload

class Foo(Protocol):
    @overload
    def f(self, x: int) -> int: ...
    @overload
    def f(self, x: str) -> str: ...
```
#### Abstract methods

Overload definitions within abstract base classes are exempt from this check.

```py
from abc import ABC, abstractmethod
from typing import overload

class AbstractFoo(ABC):
    @overload
    @abstractmethod
    def f(self, x: int) -> int: ...
    @overload
    @abstractmethod
    def f(self, x: str) -> str: ...
```

Using the `@abstractmethod` decorator requires that the class's metaclass is `ABCMeta` or is derived from it.

```py
class Foo:
    # TODO: Error because implementation does not exist
    @overload
    @abstractmethod
    def f(self, x: int) -> int: ...
    @overload
    @abstractmethod
    def f(self, x: str) -> str: ...
```

In addition, the `@abstractmethod` decorator must be present on all the `@overload`-ed methods.

```py
class PartialFoo1(ABC):
    @overload
    @abstractmethod
    def f(self, x: int) -> int: ...
    @overload
    def f(self, x: str) -> str: ...

class PartialFoo(ABC):
    @overload
    def f(self, x: int) -> int: ...
    @overload
    @abstractmethod
    def f(self, x: str) -> str: ...
```
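The exemptions in this section can be combined into a single predicate for deciding whether a missing implementation should be reported. This is a hypothetical Python sketch, not ty's implementation. `OverloadedDef` and `missing_implementation_is_error` are invented names. Their boolean fields stand in for facts a type checker would have to derive itself: whether the file is a stub, whether the class is a protocol, and whether its metaclass is `ABCMeta`.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class OverloadedDef:
    """Hypothetical summary of one overloaded function or method."""

    # One flag per `@overload`-decorated definition: is it also `@abstractmethod`?
    overload_is_abstract: list[bool] = field(default_factory=list)
    has_implementation: bool = False
    in_stub_file: bool = False
    in_protocol: bool = False
    # Whether the enclosing class's metaclass is `ABCMeta` or derived from it.
    in_abc_class: bool = False


def missing_implementation_is_error(func: OverloadedDef) -> bool:
    """Return True if the missing implementation should be reported."""
    if func.has_implementation:
        return False
    # Stub files and protocols are always exempt.
    if func.in_stub_file or func.in_protocol:
        return False
    # Abstract methods are exempt only if the class uses `ABCMeta` *and*
    # every overload is decorated with `@abstractmethod`.
    if func.in_abc_class and func.overload_is_abstract and all(func.overload_is_abstract):
        return False
    return True


# Usage, mirroring the abstract-method classes above:
assert not missing_implementation_is_error(OverloadedDef([True, True], in_abc_class=True))  # AbstractFoo
assert missing_implementation_is_error(OverloadedDef([True, True]))  # Foo: no `ABCMeta`
assert missing_implementation_is_error(OverloadedDef([True, False], in_abc_class=True))  # PartialFoo1
```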
### Inconsistent decorators

#### `@staticmethod` / `@classmethod`

If one overload signature is decorated with `@staticmethod` or `@classmethod`, all overload signatures must be similarly decorated. The implementation, if present, must also have a consistent decorator.

```py
from __future__ import annotations

from typing import overload

class CheckStaticMethod:
    # TODO: error because `@staticmethod` does not exist on all overloads
    @overload
    def method1(x: int) -> int: ...
    @overload
    def method1(x: str) -> str: ...
    @staticmethod
    def method1(x: int | str) -> int | str:
        return x

    # TODO: error because `@staticmethod` does not exist on all overloads
    @overload
    def method2(x: int) -> int: ...
    @overload
    @staticmethod
    def method2(x: str) -> str: ...
    @staticmethod
    def method2(x: int | str) -> int | str:
        return x

    # TODO: error because `@staticmethod` does not exist on the implementation
    @overload
    @staticmethod
    def method3(x: int) -> int: ...
    @overload
    @staticmethod
    def method3(x: str) -> str: ...
    def method3(x: int | str) -> int | str:
        return x

    @overload
    @staticmethod
    def method4(x: int) -> int: ...
    @overload
    @staticmethod
    def method4(x: str) -> str: ...
    @staticmethod
    def method4(x: int | str) -> int | str:
        return x

class CheckClassMethod:
    def __init__(self, x: int) -> None:
        self.x = x

    # TODO: error because `@classmethod` does not exist on all overloads
    @overload
    @classmethod
    def try_from1(cls, x: int) -> CheckClassMethod: ...
    @overload
    def try_from1(cls, x: str) -> None: ...
    @classmethod
    def try_from1(cls, x: int | str) -> CheckClassMethod | None:
        if isinstance(x, int):
            return cls(x)
        return None

    # TODO: error because `@classmethod` does not exist on all overloads
    @overload
    def try_from2(cls, x: int) -> CheckClassMethod: ...
    @overload
    @classmethod
    def try_from2(cls, x: str) -> None: ...
    @classmethod
    def try_from2(cls, x: int | str) -> CheckClassMethod | None:
        if isinstance(x, int):
            return cls(x)
        return None

    # TODO: error because `@classmethod` does not exist on the implementation
    @overload
    @classmethod
    def try_from3(cls, x: int) -> CheckClassMethod: ...
    @overload
    @classmethod
    def try_from3(cls, x: str) -> None: ...
    def try_from3(cls, x: int | str) -> CheckClassMethod | None:
        if isinstance(x, int):
            return cls(x)
        return None

    @overload
    @classmethod
    def try_from4(cls, x: int) -> CheckClassMethod: ...
    @overload
    @classmethod
    def try_from4(cls, x: str) -> None: ...
    @classmethod
    def try_from4(cls, x: int | str) -> CheckClassMethod | None:
        if isinstance(x, int):
            return cls(x)
        return None
```
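The consistency rule can be checked per decorator in two steps. First compare the overloads against each other, then compare them against the implementation. This is a hypothetical Python sketch. `check_method_kind` and its set-of-decorator-names input are invented for illustration. The messages paraphrase the `TODO` comments in the test cases rather than any diagnostic ty actually emits.

```python
from __future__ import annotations

from typing import Optional


def check_method_kind(
    overloads: list[set[str]], implementation: Optional[set[str]] = None
) -> Optional[str]:
    """Return a message if `@staticmethod`/`@classmethod` is applied inconsistently.

    Each set holds the decorator names of one definition; `implementation` is
    None when there is no implementation.
    """
    for kind in ("staticmethod", "classmethod"):
        on_overloads = [kind in decorators for decorators in overloads]
        on_implementation = implementation is not None and kind in implementation
        if not any(on_overloads) and not on_implementation:
            continue  # this decorator is not used anywhere in the chain
        if not all(on_overloads):
            return f"`@{kind}` does not exist on all overloads"
        if implementation is not None and not on_implementation:
            return f"`@{kind}` does not exist on the implementation"
    return None


# Usage, mirroring `method1`, `method3` and `method4`:
static = {"staticmethod"}
assert check_method_kind([set(), set()], static) == "`@staticmethod` does not exist on all overloads"
assert check_method_kind([static, static], set()) == "`@staticmethod` does not exist on the implementation"
assert check_method_kind([static, static], static) is None
```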
#### `@final` / `@override`

If a `@final` or `@override` decorator is supplied for a function with overloads, the decorator should be applied only to the overload implementation if it is present.

```py
from typing_extensions import final, overload, override

class Foo:
    @overload
    def method1(self, x: int) -> int: ...
    @overload
    def method1(self, x: str) -> str: ...
    @final
    def method1(self, x: int | str) -> int | str:
        return x

    # TODO: error because `@final` is not on the implementation
    @overload
    @final
    def method2(self, x: int) -> int: ...
    @overload
    def method2(self, x: str) -> str: ...
    def method2(self, x: int | str) -> int | str:
        return x

    # TODO: error because `@final` is not on the implementation
    @overload
    def method3(self, x: int) -> int: ...
    @overload
    @final
    def method3(self, x: str) -> str: ...
    def method3(self, x: int | str) -> int | str:
        return x

class Base:
    @overload
    def method(self, x: int) -> int: ...
    @overload
    def method(self, x: str) -> str: ...
    def method(self, x: int | str) -> int | str:
        return x

class Sub1(Base):
    @overload
    def method(self, x: int) -> int: ...
    @overload
    def method(self, x: str) -> str: ...
    @override
    def method(self, x: int | str) -> int | str:
        return x

class Sub2(Base):
    # TODO: error because `@override` is not on the implementation
    @overload
    def method(self, x: int) -> int: ...
    @overload
    @override
    def method(self, x: str) -> str: ...
    def method(self, x: int | str) -> int | str:
        return x

class Sub3(Base):
    # TODO: error because `@override` is not on the implementation
    @overload
    @override
    def method(self, x: int) -> int: ...
    @overload
    def method(self, x: str) -> str: ...
    def method(self, x: int | str) -> int | str:
        return x
```
#### `@final` / `@override` in stub files

If an overload implementation isn't present (for example, in a stub file), the `@final` or `@override` decorator should be applied only to the first overload.

```pyi
from typing_extensions import final, overload, override

class Foo:
    @overload
    @final
    def method1(self, x: int) -> int: ...
    @overload
    def method1(self, x: str) -> str: ...

    # TODO: error because `@final` is not on the first overload
    @overload
    def method2(self, x: int) -> int: ...
    @final
    @overload
    def method2(self, x: str) -> str: ...

class Base:
    @overload
    def method(self, x: int) -> int: ...
    @overload
    def method(self, x: str) -> str: ...

class Sub1(Base):
    @overload
    @override
    def method(self, x: int) -> int: ...
    @overload
    def method(self, x: str) -> str: ...

class Sub2(Base):
    # TODO: error because `@override` is not on the first overload
    @overload
    def method(self, x: int) -> int: ...
    @overload
    @override
    def method(self, x: str) -> str: ...
```
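The `@final` / `@override` placement rules, with and without an implementation, come down to checking which definitions in the chain carry the decorator. This is a hypothetical Python sketch, not ty's code. `check_final_override` and the set-of-decorator-names representation are assumptions. It checks only placement, not, for example, whether `@override` really overrides a base-class method.

```python
from __future__ import annotations

from typing import Optional


def check_final_override(
    overloads: list[set[str]], implementation: Optional[set[str]] = None
) -> list[str]:
    """Report `@final`/`@override` decorators placed on the wrong definition.

    With an implementation, the decorator belongs only on the implementation;
    without one (e.g. in a stub file), it belongs only on the first overload.
    """
    errors = []
    for decorator in ("final", "override"):
        if implementation is not None:
            misplaced = any(decorator in decorators for decorators in overloads)
            expected = "the implementation"
        else:
            misplaced = any(decorator in decorators for decorators in overloads[1:])
            expected = "the first overload"
        if misplaced:
            errors.append(f"`@{decorator}` should be applied only to {expected}")
    return errors


# Usage, mirroring `method1` and `method2` in the two sections above:
final_only = {"final"}
assert check_final_override([set(), set()], final_only) == []
assert check_final_override([final_only, set()], set()) == ["`@final` should be applied only to the implementation"]
assert check_final_override([final_only, set()]) == []
assert check_final_override([set(), final_only]) == ["`@final` should be applied only to the first overload"]
```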