From 44ad2012622d73373eb7922978f391eb6c1e54fa Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Fri, 18 Apr 2025 09:57:40 +0530 Subject: [PATCH] [red-knot] Add support for overloaded functions (#17366) ## Summary Part of #15383, this PR adds support for overloaded callables. Typing spec: https://typing.python.org/en/latest/spec/overload.html Specifically, it does the following: 1. Update the `FunctionType::signature` method to return signatures from a possibly overloaded callable using a new `FunctionSignature` enum 2. Update `CallableType` to accommodate overloaded callable by updating the inner type to `Box<[Signature]>` 3. Update the relation methods on `CallableType` with logic specific to overloads 4. Update the display of callable type to display a list of signatures enclosed by parenthesis 5. Update `CallableTypeOf` special form to recognize overloaded callable 6. Update subtyping, assignability and fully static check to account for callables (equivalence is planned to be done as a follow-up) For (2), it is required to be done in this PR because otherwise I'd need to add some workaround for `into_callable_type` and I though it would be best to include it in here. For (2), another possible design would be convert `CallableType` in an enum with two variants `CallableType::Single` and `CallableType::Overload` but I decided to go with `Box<[Signature]>` for now to (a) mirror it to be equivalent to `overload` field on `CallableSignature` and (b) to avoid any refactor in this PR. This could be done in a follow-up to better split the two kind of callables. ### Design There were two main candidates on how to represent the overloaded definition: 1. To include it in the existing infrastructure which is what this PR is doing by recognizing all the signatures within the `FunctionType::signature` method 2. To create a new `Overload` type variant
For context, this is what I had in mind with the new type variant:

```rs pub enum Type { FunctionLiteral(FunctionType), Overload(OverloadType), BoundMethod(BoundMethodType), ... } pub struct OverloadType { // FunctionLiteral or BoundMethod overloads: Box<[Type]>, // FunctionLiteral or BoundMethod implementation: Option } pub struct BoundMethodType { kind: BoundMethodKind, self_instance: Type, } pub enum BoundMethodKind { Function(FunctionType), Overload(OverloadType), } ```

The main reasons to choose (1) are the simplicity in the implementation, reusing the existing infrastructure, avoiding any complications that the new type variant has specifically around the different variants between function and methods which would require the overload type to use `Type` instead. ### Implementation The core logic is how to collect all the overloaded functions. The way this is done in this PR is by recording a **use** on the `Identifier` node that represents the function name in the use-def map. This is then used to fetch the previous symbol using the same name. This way the signatures are going to be propagated from top to bottom (from first overload to the final overload or the implementation) with each function / method. For example: ```py from typing import overload @overload def foo(x: int) -> int: ... @overload def foo(x: str) -> str: ... def foo(x: int | str) -> int | str: return x ``` Here, each definition of `foo` knows about all the signatures that comes before itself. So, the first overload would only see itself, the second would see the first and itself and so on until the implementation or the final overload. This approach required some updates specifically recognizing `Identifier` node to record the function use because it doesn't use `ExprName`. ## Test Plan Update existing test cases which were limited by the overload support and add test cases for the following cases: * Valid overloads as functions, methods, generics, version specific * Invalid overloads as stated in https://typing.python.org/en/latest/spec/overload.html#invalid-overload-definitions (implementation will be done in a follow-up) * Various relation: fully static, subtyping, and assignability (others in a follow-up) ## Ecosystem changes _WIP_ After going through the ecosystem changes (there are a lot!), here's what I've found: We need assignability check between a callable type and a class literal because a lot of builtins are defined as classes in typeshed whose constructor method is overloaded e.g., `map`, `sorted`, `list.sort`, `max`, `min` with the `key` parameter, `collections.abc.defaultdict`, etc. (https://github.com/astral-sh/ruff/issues/17343). This makes up most of the ecosystem diff **roughly 70 diagnostics**. For example: ```py from collections import defaultdict # red-knot: No overload of bound method `__init__` matches arguments [lint:no-matching-overload] defaultdict(int) # red-knot: No overload of bound method `__init__` matches arguments [lint:no-matching-overload] defaultdict(list) class Foo: def __init__(self, x: int): self.x = x # red-knot: No overload of function `__new__` matches arguments [lint:no-matching-overload] map(Foo, ["a", "b", "c"]) ``` Duplicate diagnostics in unpacking (https://github.com/astral-sh/ruff/issues/16514) has **~16 diagnostics**. Support for the `callable` builtin which requires `TypeIs` support. This is **5 diagnostics**. For example: ```py from typing import Any def _(x: Any | None) -> None: if callable(x): # red-knot: `Any | None` # Pyright: `(...) -> object` # mypy: `Any` # pyrefly: `(...) -> object` reveal_type(x) ``` Narrowing on `assert` which has **11 diagnostics**. This is being worked on in https://github.com/astral-sh/ruff/pull/17345. For example: ```py import re match = re.search("", "") assert match match.group() # error: [possibly-unbound-attribute] ``` Others: * `Self`: 2 * Type aliases: 6 * Generics: 3 * Protocols: 13 * Unpacking in comprehension: 1 (https://github.com/astral-sh/ruff/pull/17396) ## Performance Refer to https://github.com/astral-sh/ruff/pull/17366#issuecomment-2814053046. --- crates/red_knot/tests/cli.rs | 24 +- .../mdtest/annotations/literal_string.md | 6 +- .../resources/mdtest/attributes.md | 4 +- .../resources/mdtest/binary/instances.md | 10 +- .../resources/mdtest/binary/integers.md | 8 +- .../resources/mdtest/dataclasses.md | 13 +- .../resources/mdtest/descriptor_protocol.md | 6 +- .../resources/mdtest/narrow/truthiness.md | 2 +- .../resources/mdtest/overloads.md | 638 ++++++++++++++++++ .../resources/mdtest/subscript/bytes.md | 9 +- .../resources/mdtest/subscript/lists.md | 14 +- .../resources/mdtest/subscript/string.md | 9 +- .../resources/mdtest/subscript/tuple.md | 4 +- .../type_properties/is_assignable_to.md | 31 + .../type_properties/is_equivalent_to.md | 4 + .../mdtest/type_properties/is_fully_static.md | 31 + .../is_gradual_equivalent_to.md | 2 + .../mdtest/type_properties/is_subtype_of.md | 182 +++++ .../src/semantic_index/ast_ids.rs | 15 +- .../src/semantic_index/builder.rs | 11 + crates/red_knot_python_semantic/src/types.rs | 366 ++++++++-- .../src/types/call/bind.rs | 8 +- .../src/types/class.rs | 4 +- .../src/types/display.rs | 100 ++- .../src/types/infer.rs | 33 +- .../types/property_tests/type_generation.rs | 2 +- .../src/types/signatures.rs | 97 ++- 27 files changed, 1435 insertions(+), 198 deletions(-) create mode 100644 crates/red_knot_python_semantic/resources/mdtest/overloads.md diff --git a/crates/red_knot/tests/cli.rs b/crates/red_knot/tests/cli.rs index 3f2dd4a98a..a37468e824 100644 --- a/crates/red_knot/tests/cli.rs +++ b/crates/red_knot/tests/cli.rs @@ -252,7 +252,7 @@ fn configuration_rule_severity() -> anyhow::Result<()> { r#" y = 4 / 0 - for a in range(0, y): + for a in range(0, int(y)): x = a print(x) # possibly-unresolved-reference @@ -271,7 +271,7 @@ fn configuration_rule_severity() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | warning: lint:possibly-unresolved-reference @@ -307,7 +307,7 @@ fn configuration_rule_severity() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | Found 1 diagnostic @@ -328,7 +328,7 @@ fn cli_rule_severity() -> anyhow::Result<()> { y = 4 / 0 - for a in range(0, y): + for a in range(0, int(y)): x = a print(x) # possibly-unresolved-reference @@ -358,7 +358,7 @@ fn cli_rule_severity() -> anyhow::Result<()> { 4 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 5 | - 6 | for a in range(0, y): + 6 | for a in range(0, int(y)): | warning: lint:possibly-unresolved-reference @@ -405,7 +405,7 @@ fn cli_rule_severity() -> anyhow::Result<()> { 4 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 5 | - 6 | for a in range(0, y): + 6 | for a in range(0, int(y)): | Found 2 diagnostics @@ -426,7 +426,7 @@ fn cli_rule_severity_precedence() -> anyhow::Result<()> { r#" y = 4 / 0 - for a in range(0, y): + for a in range(0, int(y)): x = a print(x) # possibly-unresolved-reference @@ -445,7 +445,7 @@ fn cli_rule_severity_precedence() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | warning: lint:possibly-unresolved-reference @@ -482,7 +482,7 @@ fn cli_rule_severity_precedence() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | Found 1 diagnostic @@ -814,7 +814,7 @@ fn user_configuration() -> anyhow::Result<()> { r#" y = 4 / 0 - for a in range(0, y): + for a in range(0, int(y)): x = a print(x) @@ -841,7 +841,7 @@ fn user_configuration() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | warning: lint:possibly-unresolved-reference @@ -883,7 +883,7 @@ fn user_configuration() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | error: lint:possibly-unresolved-reference diff --git a/crates/red_knot_python_semantic/resources/mdtest/annotations/literal_string.md b/crates/red_knot_python_semantic/resources/mdtest/annotations/literal_string.md index 45dff62b80..445c7a05da 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/annotations/literal_string.md +++ b/crates/red_knot_python_semantic/resources/mdtest/annotations/literal_string.md @@ -72,13 +72,11 @@ reveal_type(baz) # revealed: Literal["bazfoo"] qux = (foo, bar) reveal_type(qux) # revealed: tuple[Literal["foo"], Literal["bar"]] -# TODO: Infer "LiteralString" -reveal_type(foo.join(qux)) # revealed: @Todo(return type of overloaded function) +reveal_type(foo.join(qux)) # revealed: LiteralString template: LiteralString = "{}, {}" reveal_type(template) # revealed: Literal["{}, {}"] -# TODO: Infer `LiteralString` -reveal_type(template.format(foo, bar)) # revealed: @Todo(return type of overloaded function) +reveal_type(template.format(foo, bar)) # revealed: LiteralString ``` ### Assignability diff --git a/crates/red_knot_python_semantic/resources/mdtest/attributes.md b/crates/red_knot_python_semantic/resources/mdtest/attributes.md index 15882300c1..3f6c5321c0 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/attributes.md +++ b/crates/red_knot_python_semantic/resources/mdtest/attributes.md @@ -1698,9 +1698,9 @@ Most attribute accesses on bool-literal types are delegated to `builtins.bool`, bools are instances of that class: ```py -# revealed: bound method Literal[True].__and__(**kwargs: @Todo(todo signature **kwargs)) -> @Todo(return type of overloaded function) +# revealed: Overload[(value: bool, /) -> bool, (value: int, /) -> int] reveal_type(True.__and__) -# revealed: bound method Literal[False].__or__(**kwargs: @Todo(todo signature **kwargs)) -> @Todo(return type of overloaded function) +# revealed: Overload[(value: bool, /) -> bool, (value: int, /) -> int] reveal_type(False.__or__) ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md b/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md index b435a34d76..650e7a636c 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md +++ b/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md @@ -310,9 +310,7 @@ reveal_type(A() + 1) # revealed: A reveal_type(1 + A()) # revealed: A reveal_type(A() + "foo") # revealed: A -# TODO should be `A` since `str.__add__` doesn't support `A` instances -# TODO overloads -reveal_type("foo" + A()) # revealed: @Todo(return type of overloaded function) +reveal_type("foo" + A()) # revealed: A reveal_type(A() + b"foo") # revealed: A # TODO should be `A` since `bytes.__add__` doesn't support `A` instances @@ -320,16 +318,14 @@ reveal_type(b"foo" + A()) # revealed: bytes reveal_type(A() + ()) # revealed: A # TODO this should be `A`, since `tuple.__add__` doesn't support `A` instances -reveal_type(() + A()) # revealed: @Todo(return type of overloaded function) +reveal_type(() + A()) # revealed: @Todo(full tuple[...] support) literal_string_instance = "foo" * 1_000_000_000 # the test is not testing what it's meant to be testing if this isn't a `LiteralString`: reveal_type(literal_string_instance) # revealed: LiteralString reveal_type(A() + literal_string_instance) # revealed: A -# TODO should be `A` since `str.__add__` doesn't support `A` instances -# TODO overloads -reveal_type(literal_string_instance + A()) # revealed: @Todo(return type of overloaded function) +reveal_type(literal_string_instance + A()) # revealed: A ``` ## Operations involving instances of classes inheriting from `Any` diff --git a/crates/red_knot_python_semantic/resources/mdtest/binary/integers.md b/crates/red_knot_python_semantic/resources/mdtest/binary/integers.md index a4172c821d..dcbc180317 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/binary/integers.md +++ b/crates/red_knot_python_semantic/resources/mdtest/binary/integers.md @@ -50,9 +50,11 @@ reveal_type(1 ** (largest_u32 + 1)) # revealed: int reveal_type(2**largest_u32) # revealed: int def variable(x: int): - reveal_type(x**2) # revealed: @Todo(return type of overloaded function) - reveal_type(2**x) # revealed: @Todo(return type of overloaded function) - reveal_type(x**x) # revealed: @Todo(return type of overloaded function) + reveal_type(x**2) # revealed: int + # TODO: should be `Any` (overload 5 on `__pow__`), requires correct overload matching + reveal_type(2**x) # revealed: int + # TODO: should be `Any` (overload 5 on `__pow__`), requires correct overload matching + reveal_type(x**x) # revealed: int ``` If the second argument is \<0, a `float` is returned at runtime. If the first argument is \<0 but diff --git a/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md index 28baa67ba4..94958c883b 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md +++ b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md @@ -547,14 +547,14 @@ the descriptor's `__get__` method as if it had been called on the class itself, for the `instance` argument. ```py -from typing import overload +from typing import Literal, overload from dataclasses import dataclass class ConvertToLength: _len: int = 0 @overload - def __get__(self, instance: None, owner: type) -> str: ... + def __get__(self, instance: None, owner: type) -> Literal[""]: ... @overload def __get__(self, instance: object, owner: type | None) -> int: ... def __get__(self, instance: object | None, owner: type | None) -> str | int: @@ -570,12 +570,10 @@ class ConvertToLength: class C: converter: ConvertToLength = ConvertToLength() -# TODO: Should be `(converter: str = Literal[""]) -> None` once we understand overloads -reveal_type(C.__init__) # revealed: (converter: str = str | int) -> None +reveal_type(C.__init__) # revealed: (converter: str = Literal[""]) -> None c = C("abc") -# TODO: Should be `int` once we understand overloads -reveal_type(c.converter) # revealed: str | int +reveal_type(c.converter) # revealed: int # This is also okay: C() @@ -611,8 +609,7 @@ class AcceptsStrAndInt: class C: field: AcceptsStrAndInt = AcceptsStrAndInt() -# TODO: Should be `field: str | int = int` once we understand overloads -reveal_type(C.__init__) # revealed: (field: Unknown = int) -> None +reveal_type(C.__init__) # revealed: (field: str | int = int) -> None ``` ## `dataclasses.field` diff --git a/crates/red_knot_python_semantic/resources/mdtest/descriptor_protocol.md b/crates/red_knot_python_semantic/resources/mdtest/descriptor_protocol.md index aeabb336ec..8f8d17ae76 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/descriptor_protocol.md +++ b/crates/red_knot_python_semantic/resources/mdtest/descriptor_protocol.md @@ -459,11 +459,9 @@ class Descriptor: class C: d: Descriptor = Descriptor() -# TODO: should be `Literal["called on class object"] -reveal_type(C.d) # revealed: LiteralString +reveal_type(C.d) # revealed: Literal["called on class object"] -# TODO: should be `Literal["called on instance"] -reveal_type(C().d) # revealed: LiteralString +reveal_type(C().d) # revealed: Literal["called on instance"] ``` ## Descriptor protocol for dunder methods diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md index 84a14f93f5..f8f5ab8c5e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md @@ -246,7 +246,7 @@ class MetaTruthy(type): class MetaDeferred(type): def __bool__(self) -> MetaAmbiguous: - return MetaAmbiguous() + raise NotImplementedError class AmbiguousClass(metaclass=MetaAmbiguous): ... class FalsyClass(metaclass=MetaFalsy): ... diff --git a/crates/red_knot_python_semantic/resources/mdtest/overloads.md b/crates/red_knot_python_semantic/resources/mdtest/overloads.md new file mode 100644 index 0000000000..764b6a4412 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/overloads.md @@ -0,0 +1,638 @@ +# Overloads + +Reference: + +## `typing.overload` + +The definition of `typing.overload` in typeshed is an identity function. + +```py +from typing import overload + +def foo(x: int) -> int: + return x + +reveal_type(foo) # revealed: def foo(x: int) -> int +bar = overload(foo) +reveal_type(bar) # revealed: def foo(x: int) -> int +``` + +## Functions + +```py +from typing import overload + +@overload +def add() -> None: ... +@overload +def add(x: int) -> int: ... +@overload +def add(x: int, y: int) -> int: ... +def add(x: int | None = None, y: int | None = None) -> int | None: + return (x or 0) + (y or 0) + +reveal_type(add) # revealed: Overload[() -> None, (x: int) -> int, (x: int, y: int) -> int] +reveal_type(add()) # revealed: None +reveal_type(add(1)) # revealed: int +reveal_type(add(1, 2)) # revealed: int +``` + +## Overriding + +These scenarios are to verify that the overloaded and non-overloaded definitions are correctly +overridden by each other. + +An overloaded function is overriding another overloaded function: + +```py +from typing import overload + +@overload +def foo() -> None: ... +@overload +def foo(x: int) -> int: ... +def foo(x: int | None = None) -> int | None: + return x + +reveal_type(foo) # revealed: Overload[() -> None, (x: int) -> int] +reveal_type(foo()) # revealed: None +reveal_type(foo(1)) # revealed: int + +@overload +def foo() -> None: ... +@overload +def foo(x: str) -> str: ... +def foo(x: str | None = None) -> str | None: + return x + +reveal_type(foo) # revealed: Overload[() -> None, (x: str) -> str] +reveal_type(foo()) # revealed: None +reveal_type(foo("")) # revealed: str +``` + +A non-overloaded function is overriding an overloaded function: + +```py +def foo(x: int) -> int: + return x + +reveal_type(foo) # revealed: def foo(x: int) -> int +``` + +An overloaded function is overriding a non-overloaded function: + +```py +reveal_type(foo) # revealed: def foo(x: int) -> int + +@overload +def foo() -> None: ... +@overload +def foo(x: bytes) -> bytes: ... +def foo(x: bytes | None = None) -> bytes | None: + return x + +reveal_type(foo) # revealed: Overload[() -> None, (x: bytes) -> bytes] +reveal_type(foo()) # revealed: None +reveal_type(foo(b"")) # revealed: bytes +``` + +## Methods + +```py +from typing import overload + +class Foo1: + @overload + def method(self) -> None: ... + @overload + def method(self, x: int) -> int: ... + def method(self, x: int | None = None) -> int | None: + return x + +foo1 = Foo1() +reveal_type(foo1.method) # revealed: Overload[() -> None, (x: int) -> int] +reveal_type(foo1.method()) # revealed: None +reveal_type(foo1.method(1)) # revealed: int + +class Foo2: + @overload + def method(self) -> None: ... + @overload + def method(self, x: str) -> str: ... + def method(self, x: str | None = None) -> str | None: + return x + +foo2 = Foo2() +reveal_type(foo2.method) # revealed: Overload[() -> None, (x: str) -> str] +reveal_type(foo2.method()) # revealed: None +reveal_type(foo2.method("")) # revealed: str +``` + +## Constructor + +```py +from typing import overload + +class Foo: + @overload + def __init__(self) -> None: ... + @overload + def __init__(self, x: int) -> None: ... + def __init__(self, x: int | None = None) -> None: + self.x = x + +foo = Foo() +reveal_type(foo) # revealed: Foo +reveal_type(foo.x) # revealed: Unknown | int | None + +foo1 = Foo(1) +reveal_type(foo1) # revealed: Foo +reveal_type(foo1.x) # revealed: Unknown | int | None +``` + +## Version specific + +Function definitions can vary between multiple Python versions. + +### Overload and non-overload (3.9) + +Here, the same function is overloaded in one version and not in another. + +```toml +[environment] +python-version = "3.9" +``` + +```py +import sys +from typing import overload + +if sys.version_info < (3, 10): + def func(x: int) -> int: + return x + +elif sys.version_info <= (3, 12): + @overload + def func() -> None: ... + @overload + def func(x: int) -> int: ... + def func(x: int | None = None) -> int | None: + return x + +reveal_type(func) # revealed: def func(x: int) -> int +func() # error: [missing-argument] +``` + +### Overload and non-overload (3.10) + +```toml +[environment] +python-version = "3.10" +``` + +```py +import sys +from typing import overload + +if sys.version_info < (3, 10): + def func(x: int) -> int: + return x + +elif sys.version_info <= (3, 12): + @overload + def func() -> None: ... + @overload + def func(x: int) -> int: ... + def func(x: int | None = None) -> int | None: + return x + +reveal_type(func) # revealed: Overload[() -> None, (x: int) -> int] +reveal_type(func()) # revealed: None +reveal_type(func(1)) # revealed: int +``` + +### Some overloads are version specific (3.9) + +```toml +[environment] +python-version = "3.9" +``` + +`overloaded.pyi`: + +```pyi +import sys +from typing import overload + +if sys.version_info >= (3, 10): + @overload + def func() -> None: ... + +@overload +def func(x: int) -> int: ... +@overload +def func(x: str) -> str: ... +``` + +`main.py`: + +```py +from overloaded import func + +reveal_type(func) # revealed: Overload[(x: int) -> int, (x: str) -> str] +func() # error: [no-matching-overload] +reveal_type(func(1)) # revealed: int +reveal_type(func("")) # revealed: str +``` + +### Some overloads are version specific (3.10) + +```toml +[environment] +python-version = "3.10" +``` + +`overloaded.pyi`: + +```pyi +import sys +from typing import overload + +@overload +def func() -> None: ... + +if sys.version_info >= (3, 10): + @overload + def func(x: int) -> int: ... + +@overload +def func(x: str) -> str: ... +``` + +`main.py`: + +```py +from overloaded import func + +reveal_type(func) # revealed: Overload[() -> None, (x: int) -> int, (x: str) -> str] +reveal_type(func()) # revealed: None +reveal_type(func(1)) # revealed: int +reveal_type(func("")) # revealed: str +``` + +## Generic + +```toml +[environment] +python-version = "3.12" +``` + +For an overloaded generic function, it's not necessary for all overloads to be generic. + +```py +from typing import overload + +@overload +def func() -> None: ... +@overload +def func[T](x: T) -> T: ... +def func[T](x: T | None = None) -> T | None: + return x + +reveal_type(func) # revealed: Overload[() -> None, (x: T) -> T] +reveal_type(func()) # revealed: None +reveal_type(func(1)) # revealed: Literal[1] +reveal_type(func("")) # revealed: Literal[""] +``` + +## Invalid + +### At least two overloads + +At least two `@overload`-decorated definitions must be present. + +```py +from typing import overload + +# TODO: error +@overload +def func(x: int) -> int: ... +def func(x: int | str) -> int | str: + return x +``` + +### Overload without an implementation + +#### Regular modules + +In regular modules, a series of `@overload`-decorated definitions must be followed by exactly one +non-`@overload`-decorated definition (for the same function/method). + +```py +from typing import overload + +# TODO: error because implementation does not exists +@overload +def func(x: int) -> int: ... +@overload +def func(x: str) -> str: ... + +class Foo: + # TODO: error because implementation does not exists + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... +``` + +#### Stub files + +Overload definitions within stub files are exempt from this check. + +```pyi +from typing import overload + +@overload +def func(x: int) -> int: ... +@overload +def func(x: str) -> str: ... +``` + +#### Protocols + +Overload definitions within protocols are exempt from this check. + +```py +from typing import Protocol, overload + +class Foo(Protocol): + @overload + def f(self, x: int) -> int: ... + @overload + def f(self, x: str) -> str: ... +``` + +#### Abstract methods + +Overload definitions within abstract base classes are exempt from this check. + +```py +from abc import ABC, abstractmethod +from typing import overload + +class AbstractFoo(ABC): + @overload + @abstractmethod + def f(self, x: int) -> int: ... + @overload + @abstractmethod + def f(self, x: str) -> str: ... +``` + +Using the `@abstractmethod` decorator requires that the class's metaclass is `ABCMeta` or is derived +from it. + +```py +class Foo: + # TODO: Error because implementation does not exists + @overload + @abstractmethod + def f(self, x: int) -> int: ... + @overload + @abstractmethod + def f(self, x: str) -> str: ... +``` + +And, the `@abstractmethod` decorator must be present on all the `@overload`-ed methods. + +```py +class PartialFoo1(ABC): + @overload + @abstractmethod + def f(self, x: int) -> int: ... + @overload + def f(self, x: str) -> str: ... + +class PartialFoo(ABC): + @overload + def f(self, x: int) -> int: ... + @overload + @abstractmethod + def f(self, x: str) -> str: ... +``` + +### Inconsistent decorators + +#### `@staticmethod` / `@classmethod` + +If one overload signature is decorated with `@staticmethod` or `@classmethod`, all overload +signatures must be similarly decorated. The implementation, if present, must also have a consistent +decorator. + +```py +from __future__ import annotations + +from typing import overload + +class CheckStaticMethod: + # TODO: error because `@staticmethod` does not exist on all overloads + @overload + def method1(x: int) -> int: ... + @overload + def method1(x: str) -> str: ... + @staticmethod + def method1(x: int | str) -> int | str: + return x + # TODO: error because `@staticmethod` does not exist on all overloads + @overload + def method2(x: int) -> int: ... + @overload + @staticmethod + def method2(x: str) -> str: ... + @staticmethod + def method2(x: int | str) -> int | str: + return x + # TODO: error because `@staticmethod` does not exist on the implementation + @overload + @staticmethod + def method3(x: int) -> int: ... + @overload + @staticmethod + def method3(x: str) -> str: ... + def method3(x: int | str) -> int | str: + return x + + @overload + @staticmethod + def method4(x: int) -> int: ... + @overload + @staticmethod + def method4(x: str) -> str: ... + @staticmethod + def method4(x: int | str) -> int | str: + return x + +class CheckClassMethod: + def __init__(self, x: int) -> None: + self.x = x + # TODO: error because `@classmethod` does not exist on all overloads + @overload + @classmethod + def try_from1(cls, x: int) -> CheckClassMethod: ... + @overload + def try_from1(cls, x: str) -> None: ... + @classmethod + def try_from1(cls, x: int | str) -> CheckClassMethod | None: + if isinstance(x, int): + return cls(x) + return None + # TODO: error because `@classmethod` does not exist on all overloads + @overload + def try_from2(cls, x: int) -> CheckClassMethod: ... + @overload + @classmethod + def try_from2(cls, x: str) -> None: ... + @classmethod + def try_from2(cls, x: int | str) -> CheckClassMethod | None: + if isinstance(x, int): + return cls(x) + return None + # TODO: error because `@classmethod` does not exist on the implementation + @overload + @classmethod + def try_from3(cls, x: int) -> CheckClassMethod: ... + @overload + @classmethod + def try_from3(cls, x: str) -> None: ... + def try_from3(cls, x: int | str) -> CheckClassMethod | None: + if isinstance(x, int): + return cls(x) + return None + + @overload + @classmethod + def try_from4(cls, x: int) -> CheckClassMethod: ... + @overload + @classmethod + def try_from4(cls, x: str) -> None: ... + @classmethod + def try_from4(cls, x: int | str) -> CheckClassMethod | None: + if isinstance(x, int): + return cls(x) + return None +``` + +#### `@final` / `@override` + +If a `@final` or `@override` decorator is supplied for a function with overloads, the decorator +should be applied only to the overload implementation if it is present. + +```py +from typing_extensions import final, overload, override + +class Foo: + @overload + def method1(self, x: int) -> int: ... + @overload + def method1(self, x: str) -> str: ... + @final + def method1(self, x: int | str) -> int | str: + return x + # TODO: error because `@final` is not on the implementation + @overload + @final + def method2(self, x: int) -> int: ... + @overload + def method2(self, x: str) -> str: ... + def method2(self, x: int | str) -> int | str: + return x + # TODO: error because `@final` is not on the implementation + @overload + def method3(self, x: int) -> int: ... + @overload + @final + def method3(self, x: str) -> str: ... + def method3(self, x: int | str) -> int | str: + return x + +class Base: + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + def method(self, x: int | str) -> int | str: + return x + +class Sub1(Base): + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + @override + def method(self, x: int | str) -> int | str: + return x + +class Sub2(Base): + # TODO: error because `@override` is not on the implementation + @overload + def method(self, x: int) -> int: ... + @overload + @override + def method(self, x: str) -> str: ... + def method(self, x: int | str) -> int | str: + return x + +class Sub3(Base): + # TODO: error because `@override` is not on the implementation + @overload + @override + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + def method(self, x: int | str) -> int | str: + return x +``` + +#### `@final` / `@override` in stub files + +If an overload implementation isn’t present (for example, in a stub file), the `@final` or +`@override` decorator should be applied only to the first overload. + +```pyi +from typing_extensions import final, overload, override + +class Foo: + @overload + @final + def method1(self, x: int) -> int: ... + @overload + def method1(self, x: str) -> str: ... + + # TODO: error because `@final` is not on the first overload + @overload + def method2(self, x: int) -> int: ... + @final + @overload + def method2(self, x: str) -> str: ... + +class Base: + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + +class Sub1(Base): + @overload + @override + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + +class Sub2(Base): + # TODO: error because `@override` is not on the first overload + @overload + def method(self, x: int) -> int: ... + @overload + @override + def method(self, x: str) -> str: ... +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md index 255cb6bc9d..7127263dd4 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md @@ -24,8 +24,7 @@ reveal_type(y) # revealed: Unknown def _(n: int): a = b"abcde"[n] - # TODO: Support overloads... Should be `bytes` - reveal_type(a) # revealed: @Todo(return type of overloaded function) + reveal_type(a) # revealed: int ``` ## Slices @@ -43,11 +42,9 @@ b[::0] # error: [zero-stepsize-in-slice] def _(m: int, n: int): byte_slice1 = b[m:n] - # TODO: Support overloads... Should be `bytes` - reveal_type(byte_slice1) # revealed: @Todo(return type of overloaded function) + reveal_type(byte_slice1) # revealed: bytes def _(s: bytes) -> bytes: byte_slice2 = s[0:5] - # TODO: Support overloads... Should be `bytes` - return reveal_type(byte_slice2) # revealed: @Todo(return type of overloaded function) + return reveal_type(byte_slice2) # revealed: bytes ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/lists.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/lists.md index 408a4f43b6..5f9264f6fa 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/lists.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/lists.md @@ -12,13 +12,13 @@ x = [1, 2, 3] reveal_type(x) # revealed: list # TODO reveal int -reveal_type(x[0]) # revealed: @Todo(return type of overloaded function) +reveal_type(x[0]) # revealed: Unknown | @Todo(Support for `typing.TypeVar` instances in type expressions) # TODO reveal list -reveal_type(x[0:1]) # revealed: @Todo(return type of overloaded function) +reveal_type(x[0:1]) # revealed: @Todo(generics) -# TODO error -reveal_type(x["a"]) # revealed: @Todo(return type of overloaded function) +# error: [call-non-callable] +reveal_type(x["a"]) # revealed: Unknown ``` ## Assignments within list assignment @@ -29,9 +29,11 @@ In assignment, we might also have a named assignment. This should also get type x = [1, 2, 3] x[0 if (y := 2) else 1] = 5 -# TODO error? (indeterminite index type) +# TODO: better error than "method `__getitem__` not callable on type `list`" +# error: [call-non-callable] x["a" if (y := 2) else 1] = 6 -# TODO error (can't index via string) +# TODO: better error than "method `__getitem__` not callable on type `list`" +# error: [call-non-callable] x["a" if (y := 2) else "b"] = 6 ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md index 9a875dc323..469300c0b4 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md @@ -21,8 +21,7 @@ reveal_type(b) # revealed: Unknown def _(n: int): a = "abcde"[n] - # TODO: Support overloads... Should be `str` - reveal_type(a) # revealed: @Todo(return type of overloaded function) + reveal_type(a) # revealed: LiteralString ``` ## Slices @@ -75,12 +74,10 @@ def _(m: int, n: int, s2: str): s[::0] # error: [zero-stepsize-in-slice] substring1 = s[m:n] - # TODO: Support overloads... Should be `LiteralString` - reveal_type(substring1) # revealed: @Todo(return type of overloaded function) + reveal_type(substring1) # revealed: LiteralString substring2 = s2[0:5] - # TODO: Support overloads... Should be `str` - reveal_type(substring2) # revealed: @Todo(return type of overloaded function) + reveal_type(substring2) # revealed: str ``` ## Unsupported slice types diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md index 8ac7ff2534..579d7451f6 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md @@ -69,8 +69,8 @@ def _(m: int, n: int): t[::0] # error: [zero-stepsize-in-slice] tuple_slice = t[m:n] - # TODO: Support overloads... Should be `tuple[Literal[1, 'a', b"b"] | None, ...]` - reveal_type(tuple_slice) # revealed: @Todo(return type of overloaded function) + # TODO: Should be `tuple[Literal[1, 'a', b"b"] | None, ...]` + reveal_type(tuple_slice) # revealed: @Todo(full tuple[...] support) ``` ## Inheritance diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index 5ca7c9c628..fdd7ddbe7c 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -522,4 +522,35 @@ c: Callable[[Any], str] = A().f c: Callable[[Any], str] = A().g ``` +### Overloads + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +@overload +def overloaded() -> None: ... +@overload +def overloaded(a: str) -> str: ... +@overload +def overloaded(a: str, b: Any) -> str: ... +``` + +```py +from overloaded import overloaded +from typing import Any, Callable + +c: Callable[[], None] = overloaded +c: Callable[[str], str] = overloaded +c: Callable[[str, Any], Any] = overloaded +c: Callable[..., str] = overloaded + +# error: [invalid-assignment] +c: Callable[..., int] = overloaded + +# error: [invalid-assignment] +c: Callable[[int], str] = overloaded +``` + [typing documentation]: https://typing.python.org/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md index 39aa18fc43..3a234bb8a6 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md @@ -254,4 +254,8 @@ from knot_extensions import is_equivalent_to, static_assert static_assert(is_equivalent_to(int | Callable[[int | str], None], Callable[[str | int], None] | int)) ``` +### Overloads + +TODO + [the equivalence relation]: https://typing.python.org/en/latest/spec/glossary.html#term-equivalent diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_fully_static.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_fully_static.md index 8af5af4154..9a44ce090b 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_fully_static.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_fully_static.md @@ -99,3 +99,34 @@ static_assert(not is_fully_static(CallableTypeOf[f13])) static_assert(not is_fully_static(CallableTypeOf[f14])) static_assert(not is_fully_static(CallableTypeOf[f15])) ``` + +## Overloads + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +@overload +def gradual() -> None: ... +@overload +def gradual(a: Any) -> None: ... + +@overload +def static() -> None: ... +@overload +def static(x: int) -> None: ... +@overload +def static(x: str) -> str: ... +``` + +```py +from knot_extensions import CallableTypeOf, TypeOf, is_fully_static, static_assert +from overloaded import gradual, static + +static_assert(is_fully_static(TypeOf[gradual])) +static_assert(is_fully_static(TypeOf[static])) + +static_assert(not is_fully_static(CallableTypeOf[gradual])) +static_assert(is_fully_static(CallableTypeOf[static])) +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md index 43062b1802..e3e46a96e6 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md @@ -157,4 +157,6 @@ def f6(a, /): ... static_assert(not is_gradual_equivalent_to(CallableTypeOf[f1], CallableTypeOf[f6])) ``` +TODO: Overloads + [materializations]: https://typing.python.org/en/latest/spec/glossary.html#term-materialize diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md index aa5f025e13..c61f748598 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md @@ -1153,5 +1153,187 @@ static_assert(not is_subtype_of(TypeOf[A.g], Callable[[], int])) static_assert(is_subtype_of(TypeOf[A.f], Callable[[A, int], int])) ``` +### Overloads + +#### Subtype overloaded + +For `B <: A`, if a callable `B` is overloaded with two or more signatures, it is a subtype of +callable `A` if _at least one_ of the overloaded signatures in `B` is a subtype of `A`. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B: ... +class C: ... + +@overload +def overloaded(x: A) -> None: ... +@overload +def overloaded(x: B) -> None: ... +``` + +```py +from knot_extensions import CallableTypeOf, is_subtype_of, static_assert +from overloaded import A, B, C, overloaded + +def accepts_a(x: A) -> None: ... +def accepts_b(x: B) -> None: ... +def accepts_c(x: C) -> None: ... + +static_assert(is_subtype_of(CallableTypeOf[overloaded], CallableTypeOf[accepts_a])) +static_assert(is_subtype_of(CallableTypeOf[overloaded], CallableTypeOf[accepts_b])) +static_assert(not is_subtype_of(CallableTypeOf[overloaded], CallableTypeOf[accepts_c])) +``` + +#### Supertype overloaded + +For `B <: A`, if a callable `A` is overloaded with two or more signatures, callable `B` is a subtype +of `A` if `B` is a subtype of _all_ of the signatures in `A`. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class Grandparent: ... +class Parent(Grandparent): ... +class Child(Parent): ... + +@overload +def overloaded(a: Child) -> None: ... +@overload +def overloaded(a: Parent) -> None: ... +@overload +def overloaded(a: Grandparent) -> None: ... +``` + +```py +from knot_extensions import CallableTypeOf, is_subtype_of, static_assert +from overloaded import Grandparent, Parent, Child, overloaded + +# This is a subtype of only the first overload +def child(a: Child) -> None: ... + +# This is a subtype of the first and second overload +def parent(a: Parent) -> None: ... + +# This is the only function that's a subtype of all overloads +def grandparent(a: Grandparent) -> None: ... + +static_assert(not is_subtype_of(CallableTypeOf[child], CallableTypeOf[overloaded])) +static_assert(not is_subtype_of(CallableTypeOf[parent], CallableTypeOf[overloaded])) +static_assert(is_subtype_of(CallableTypeOf[grandparent], CallableTypeOf[overloaded])) +``` + +#### Both overloads + +For `B <: A`, if both `A` and `B` is a callable that's overloaded with two or more signatures, then +`B` is a subtype of `A` if for _every_ signature in `A`, there is _at least one_ signature in `B` +that is a subtype of it. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class Grandparent: ... +class Parent(Grandparent): ... +class Child(Parent): ... +class Other: ... + +@overload +def pg(a: Parent) -> None: ... +@overload +def pg(a: Grandparent) -> None: ... + +@overload +def po(a: Parent) -> None: ... +@overload +def po(a: Other) -> None: ... + +@overload +def go(a: Grandparent) -> None: ... +@overload +def go(a: Other) -> None: ... + +@overload +def cpg(a: Child) -> None: ... +@overload +def cpg(a: Parent) -> None: ... +@overload +def cpg(a: Grandparent) -> None: ... + +@overload +def empty_go() -> Child: ... +@overload +def empty_go(a: Grandparent) -> None: ... +@overload +def empty_go(a: Other) -> Other: ... + +@overload +def empty_cp() -> Parent: ... +@overload +def empty_cp(a: Child) -> None: ... +@overload +def empty_cp(a: Parent) -> None: ... +``` + +```py +from knot_extensions import CallableTypeOf, is_subtype_of, static_assert +from overloaded import pg, po, go, cpg, empty_go, empty_cp + +static_assert(is_subtype_of(CallableTypeOf[pg], CallableTypeOf[cpg])) +static_assert(is_subtype_of(CallableTypeOf[cpg], CallableTypeOf[pg])) + +static_assert(not is_subtype_of(CallableTypeOf[po], CallableTypeOf[pg])) +static_assert(not is_subtype_of(CallableTypeOf[pg], CallableTypeOf[po])) + +static_assert(is_subtype_of(CallableTypeOf[go], CallableTypeOf[pg])) +static_assert(not is_subtype_of(CallableTypeOf[pg], CallableTypeOf[go])) + +# Overload 1 in `empty_go` is a subtype of overload 1 in `empty_cp` +# Overload 2 in `empty_go` is a subtype of overload 2 in `empty_cp` +# Overload 2 in `empty_go` is a subtype of overload 3 in `empty_cp` +# +# All overloads in `empty_cp` has a subtype in `empty_go` +static_assert(is_subtype_of(CallableTypeOf[empty_go], CallableTypeOf[empty_cp])) + +static_assert(not is_subtype_of(CallableTypeOf[empty_cp], CallableTypeOf[empty_go])) +``` + +#### Order of overloads + +Order of overloads is irrelevant for subtyping. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B: ... + +@overload +def overload_ab(x: A) -> None: ... +@overload +def overload_ab(x: B) -> None: ... + +@overload +def overload_ba(x: B) -> None: ... +@overload +def overload_ba(x: A) -> None: ... +``` + +```py +from overloaded import overload_ab, overload_ba +from knot_extensions import CallableTypeOf, is_subtype_of, static_assert + +static_assert(is_subtype_of(CallableTypeOf[overload_ab], CallableTypeOf[overload_ba])) +static_assert(is_subtype_of(CallableTypeOf[overload_ba], CallableTypeOf[overload_ab])) +``` + [special case for float and complex]: https://typing.python.org/en/latest/spec/special-types.html#special-cases-for-float-and-complex [typing documentation]: https://typing.python.org/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence diff --git a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs index 160f3af7b9..bd5e93ada6 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs @@ -58,6 +58,13 @@ pub trait HasScopedUseId { fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId; } +impl HasScopedUseId for ast::Identifier { + fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId { + let ast_ids = ast_ids(db, scope); + ast_ids.use_id(self) + } +} + impl HasScopedUseId for ast::ExprName { fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId { let expression_ref = ExprRef::from(self); @@ -157,7 +164,7 @@ impl AstIdsBuilder { } /// Adds `expr` to the use ids map and returns its id. - pub(super) fn record_use(&mut self, expr: &ast::Expr) -> ScopedUseId { + pub(super) fn record_use(&mut self, expr: impl Into) -> ScopedUseId { let use_id = self.uses_map.len().into(); self.uses_map.insert(expr.into(), use_id); @@ -196,4 +203,10 @@ pub(crate) mod node_key { Self(NodeKey::from_node(value)) } } + + impl From<&ast::Identifier> for ExpressionNodeKey { + fn from(value: &ast::Identifier) -> Self { + Self(NodeKey::from_node(value)) + } + } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index e4c25f4840..f9d312ab53 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1115,6 +1115,17 @@ where // at the end to match the runtime evaluation of parameter defaults // and return-type annotations. let (symbol, _) = self.add_symbol(name.id.clone()); + + // Record a use of the function name in the scope that it is defined in, so that it + // can be used to find previously defined functions with the same name. This is + // used to collect all the overloaded definitions of a function. This needs to be + // done on the `Identifier` node as opposed to `ExprName` because that's what the + // AST uses. + self.mark_symbol_used(symbol); + let use_id = self.current_ast_ids().record_use(name); + self.current_use_def_map_mut() + .record_use(symbol, use_id, NodeKey::from_node(name)); + self.add_definition(symbol, function_def); } ast::Stmt::ClassDef(class) => { diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 4a2b446d2e..6782de5f69 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -29,12 +29,14 @@ pub(crate) use self::signatures::{CallableSignature, Signature, Signatures}; pub(crate) use self::subclass_of::SubclassOfType; use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module, KnownModule}; -use crate::semantic_index::ast_ids::HasScopedExpressionId; +use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId}; use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::ScopeId; use crate::semantic_index::{imported_modules, semantic_index}; use crate::suppression::check_suppressions; -use crate::symbol::{imported_symbol, Boundness, Symbol, SymbolAndQualifiers}; +use crate::symbol::{ + imported_symbol, symbol_from_bindings, Boundness, Symbol, SymbolAndQualifiers, +}; use crate::types::call::{Bindings, CallArgumentTypes, CallableBinding}; pub(crate) use crate::types::class_base::ClassBase; use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION}; @@ -507,12 +509,14 @@ impl<'db> Type<'db> { .any(|ty| ty.contains_todo(db)), Self::Callable(callable) => { - let signature = callable.signature(db); - signature.parameters().iter().any(|param| { - param - .annotated_type() - .is_some_and(|ty| ty.contains_todo(db)) - }) || signature.return_ty.is_some_and(|ty| ty.contains_todo(db)) + let signatures = callable.signatures(db); + signatures.iter().any(|signature| { + signature.parameters().iter().any(|param| { + param + .annotated_type() + .is_some_and(|ty| ty.contains_todo(db)) + }) || signature.return_ty.is_some_and(|ty| ty.contains_todo(db)) + }) } Self::SubclassOf(subclass_of) => match subclass_of.subclass_of() { @@ -1029,9 +1033,9 @@ impl<'db> Type<'db> { .to_instance(db) .is_subtype_of(db, target), - (Type::Callable(self_callable), Type::Callable(other_callable)) => self_callable - .signature(db) - .is_subtype_of(db, other_callable.signature(db)), + (Type::Callable(self_callable), Type::Callable(other_callable)) => { + self_callable.is_subtype_of(db, other_callable) + } (Type::DataclassDecorator(_), _) => { // TODO: Implement subtyping using an equivalent `Callable` type. @@ -1358,9 +1362,9 @@ impl<'db> Type<'db> { ) } - (Type::Callable(self_callable), Type::Callable(target_callable)) => self_callable - .signature(db) - .is_assignable_to(db, target_callable.signature(db)), + (Type::Callable(self_callable), Type::Callable(target_callable)) => { + self_callable.is_assignable_to(db, target_callable) + } (Type::FunctionLiteral(self_function_literal), Type::Callable(_)) => { self_function_literal @@ -1391,9 +1395,7 @@ impl<'db> Type<'db> { left.is_equivalent_to(db, right) } (Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right), - (Type::Callable(left), Type::Callable(right)) => { - left.signature(db).is_equivalent_to(db, right.signature(db)) - } + (Type::Callable(left), Type::Callable(right)) => left.is_equivalent_to(db, right), _ => self == other && self.is_fully_static(db) && other.is_fully_static(db), } } @@ -1454,9 +1456,9 @@ impl<'db> Type<'db> { first.is_gradual_equivalent_to(db, second) } - (Type::Callable(first), Type::Callable(second)) => first - .signature(db) - .is_gradual_equivalent_to(db, second.signature(db)), + (Type::Callable(first), Type::Callable(second)) => { + first.is_gradual_equivalent_to(db, second) + } _ => false, } @@ -1904,7 +1906,7 @@ impl<'db> Type<'db> { .elements(db) .iter() .all(|elem| elem.is_fully_static(db)), - Type::Callable(callable) => callable.signature(db).is_fully_static(db), + Type::Callable(callable) => callable.is_fully_static(db), } } @@ -3141,16 +3143,27 @@ impl<'db> Type<'db> { /// [`CallErrorKind::NotCallable`]. fn signatures(self, db: &'db dyn Db) -> Signatures<'db> { match self { - Type::Callable(callable) => Signatures::single(CallableSignature::single( - self, - callable.signature(db).clone(), - )), + Type::Callable(callable) => { + Signatures::single(match callable.signatures(db).as_ref() { + [signature] => CallableSignature::single(self, signature.clone()), + signatures => { + CallableSignature::from_overloads(self, signatures.iter().cloned()) + } + }) + } Type::BoundMethod(bound_method) => { let signature = bound_method.function(db).signature(db); - let signature = CallableSignature::single(self, signature.clone()) - .with_bound_type(bound_method.self_instance(db)); - Signatures::single(signature) + Signatures::single(match signature { + FunctionSignature::Single(signature) => { + CallableSignature::single(self, signature.clone()) + .with_bound_type(bound_method.self_instance(db)) + } + FunctionSignature::Overloaded(signatures, _) => { + CallableSignature::from_overloads(self, signatures.iter().cloned()) + .with_bound_type(bound_method.self_instance(db)) + } + }) } Type::MethodWrapper( @@ -3497,10 +3510,14 @@ impl<'db> Type<'db> { Signatures::single(signature) } - _ => Signatures::single(CallableSignature::single( - self, - function_type.signature(db).clone(), - )), + _ => Signatures::single(match function_type.signature(db) { + FunctionSignature::Single(signature) => { + CallableSignature::single(self, signature.clone()) + } + FunctionSignature::Overloaded(signatures, _) => { + CallableSignature::from_overloads(self, signatures.iter().cloned()) + } + }), }, Type::ClassLiteral(class) => match class.known(db) { @@ -3692,7 +3709,10 @@ impl<'db> Type<'db> { .with_annotated_type(UnionType::from_elements( db, [ - Type::Callable(CallableType::new(db, getter_signature)), + Type::Callable(CallableType::single( + db, + getter_signature, + )), Type::none(db), ], )) @@ -3701,7 +3721,10 @@ impl<'db> Type<'db> { .with_annotated_type(UnionType::from_elements( db, [ - Type::Callable(CallableType::new(db, setter_signature)), + Type::Callable(CallableType::single( + db, + setter_signature, + )), Type::none(db), ], )) @@ -3710,7 +3733,7 @@ impl<'db> Type<'db> { .with_annotated_type(UnionType::from_elements( db, [ - Type::Callable(CallableType::new( + Type::Callable(CallableType::single( db, deleter_signature, )), @@ -5738,15 +5761,54 @@ bitflags! { } } +/// A function signature, which can be either a single signature or an overloaded signature. +#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] +pub(crate) enum FunctionSignature<'db> { + /// A single function signature. + Single(Signature<'db>), + + /// An overloaded function signature containing the `@overload`-ed signatures and an optional + /// implementation signature. + Overloaded(Vec>, Option>), +} + +impl<'db> FunctionSignature<'db> { + /// Returns a slice of all signatures. + /// + /// For an overloaded function, this only includes the `@overload`-ed signatures and not the + /// implementation signature. + pub(crate) fn as_slice(&self) -> &[Signature<'db>] { + match self { + Self::Single(signature) => std::slice::from_ref(signature), + Self::Overloaded(signatures, _) => signatures, + } + } + + /// Returns an iterator over the signatures. + pub(crate) fn iter(&self) -> Iter> { + self.as_slice().iter() + } +} + +impl<'db> IntoIterator for &'db FunctionSignature<'db> { + type Item = &'db Signature<'db>; + type IntoIter = Iter<'db, Signature<'db>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + #[salsa::interned(debug)] pub struct FunctionType<'db> { - /// name of the function at definition + /// Name of the function at definition. #[return_ref] pub name: ast::name::Name, /// Is this a function that we special-case somehow? If so, which one? known: Option, + /// The scope that's created by the function, in which the function body is evaluated. body_scope: ScopeId<'db>, /// A set of special decorators that were applied to this function @@ -5768,10 +5830,11 @@ impl<'db> FunctionType<'db> { } /// Convert the `FunctionType` into a [`Type::Callable`]. - /// - /// This powers the `CallableTypeOf` special form from the `knot_extensions` module. pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { - Type::Callable(CallableType::new(db, self.signature(db).clone())) + Type::Callable(CallableType::from_overloads( + db, + self.signature(db).iter().cloned(), + )) } /// Returns the [`FileRange`] of the function's name. @@ -5808,18 +5871,55 @@ impl<'db> FunctionType<'db> { /// Were this not a salsa query, then the calling query /// would depend on the function's AST and rerun for every change in that file. #[salsa::tracked(return_ref)] - pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { + pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> { let mut internal_signature = self.internal_signature(db); - if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) { - return Signature::todo("return type of overloaded function"); - } - if let Some(specialization) = self.specialization(db) { - internal_signature.apply_specialization(db, specialization); + internal_signature = internal_signature.apply_specialization(db, specialization); } - internal_signature + // The semantic model records a use for each function on the name node. This is used here + // to get the previous function definition with the same name. + let scope = self.definition(db).scope(db); + let use_def = semantic_index(db, scope.file(db)).use_def_map(scope.file_scope_id(db)); + let use_id = self + .body_scope(db) + .node(db) + .expect_function() + .name + .scoped_use_id(db, scope); + + if let Symbol::Type(Type::FunctionLiteral(function_literal), Boundness::Bound) = + symbol_from_bindings(db, use_def.bindings_at_use(use_id)) + { + match function_literal.signature(db) { + FunctionSignature::Single(_) => { + debug_assert!( + !function_literal.has_known_decorator(db, FunctionDecorators::OVERLOAD), + "Expected `FunctionSignature::Overloaded` if the previous function was an overload" + ); + } + FunctionSignature::Overloaded(_, Some(_)) => { + // If the previous overloaded function already has an implementation, then this + // new signature completely replaces it. + } + FunctionSignature::Overloaded(signatures, None) => { + return if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) { + let mut signatures = signatures.clone(); + signatures.push(internal_signature); + FunctionSignature::Overloaded(signatures, None) + } else { + FunctionSignature::Overloaded(signatures.clone(), Some(internal_signature)) + }; + } + } + } + + if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) { + FunctionSignature::Overloaded(vec![internal_signature], None) + } else { + FunctionSignature::Single(internal_signature) + } } /// Typed internally-visible signature for this function. @@ -6013,27 +6113,54 @@ pub struct BoundMethodType<'db> { impl<'db> BoundMethodType<'db> { pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { - Type::Callable(CallableType::new( + Type::Callable(CallableType::from_overloads( db, - self.function(db).signature(db).bind_self(), + self.function(db) + .signature(db) + .iter() + .map(signatures::Signature::bind_self), )) } } -/// This type represents the set of all callable objects with a certain signature. -/// It can be written in type expressions using `typing.Callable`. -/// `lambda` expressions are inferred directly as `CallableType`s; all function-literal types -/// are subtypes of a `CallableType`. +/// This type represents the set of all callable objects with a certain, possibly overloaded, +/// signature. +/// +/// It can be written in type expressions using `typing.Callable`. `lambda` expressions are +/// inferred directly as `CallableType`s; all function-literal types are subtypes of a +/// `CallableType`. #[salsa::interned(debug)] pub struct CallableType<'db> { #[return_ref] - signature: Signature<'db>, + signatures: Box<[Signature<'db>]>, } impl<'db> CallableType<'db> { + /// Create a non-overloaded callable type with a single signature. + pub(crate) fn single(db: &'db dyn Db, signature: Signature<'db>) -> Self { + CallableType::new(db, vec![signature].into_boxed_slice()) + } + + /// Create an overloaded callable type with multiple signatures. + /// + /// # Panics + /// + /// Panics if `overloads` is empty. + pub(crate) fn from_overloads(db: &'db dyn Db, overloads: I) -> Self + where + I: IntoIterator>, + { + let overloads = overloads.into_iter().collect::>().into_boxed_slice(); + assert!( + !overloads.is_empty(), + "CallableType must have at least one signature" + ); + CallableType::new(db, overloads) + } + /// Create a callable type which accepts any parameters and returns an `Unknown` type. pub(crate) fn unknown(db: &'db dyn Db) -> Self { - CallableType::new( + CallableType::single( db, Signature::new(Parameters::unknown(), Some(Type::unknown())), ) @@ -6043,22 +6170,127 @@ impl<'db> CallableType<'db> { /// /// See [`Type::normalized`] for more details. fn normalized(self, db: &'db dyn Db) -> Self { - let signature = self.signature(db); - let parameters = signature - .parameters() - .iter() - .map(|param| param.normalized(db)) - .collect(); - let return_ty = signature - .return_ty - .map(|return_ty| return_ty.normalized(db)); - CallableType::new(db, Signature::new(parameters, return_ty)) + CallableType::from_overloads( + db, + self.signatures(db) + .iter() + .map(|signature| signature.normalized(db)), + ) } + /// Apply a specialization to this callable type. + /// + /// See [`Type::apply_specialization`] for more details. fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { - let mut signature = self.signature(db).clone(); - signature.apply_specialization(db, specialization); - Self::new(db, signature) + CallableType::from_overloads( + db, + self.signatures(db) + .iter() + .map(|signature| signature.apply_specialization(db, specialization)), + ) + } + + /// Check whether this callable type is fully static. + /// + /// See [`Type::is_fully_static`] for more details. + fn is_fully_static(self, db: &'db dyn Db) -> bool { + self.signatures(db) + .iter() + .all(|signature| signature.is_fully_static(db)) + } + + /// Check whether this callable type is a subtype of another callable type. + /// + /// See [`Type::is_subtype_of`] for more details. + fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool { + self.is_assignable_to_impl(db, other, &|self_signature, other_signature| { + self_signature.is_subtype_of(db, other_signature) + }) + } + + /// Check whether this callable type is assignable to another callable type. + /// + /// See [`Type::is_assignable_to`] for more details. + fn is_assignable_to(self, db: &'db dyn Db, other: Self) -> bool { + self.is_assignable_to_impl(db, other, &|self_signature, other_signature| { + self_signature.is_assignable_to(db, other_signature) + }) + } + + /// Implementation for the various relation checks between two, possible overloaded, callable + /// types. + /// + /// The `check_signature` closure is used to check the relation between two [`Signature`]s. + fn is_assignable_to_impl(self, db: &'db dyn Db, other: Self, check_signature: &F) -> bool + where + F: Fn(&Signature<'db>, &Signature<'db>) -> bool, + { + match (&**self.signatures(db), &**other.signatures(db)) { + ([self_signature], [other_signature]) => { + // Base case: both callable types contain a single signature. + check_signature(self_signature, other_signature) + } + + // `self` is possibly overloaded while `other` is definitely not overloaded. + (self_signatures, [other_signature]) => { + let other_callable = CallableType::single(db, other_signature.clone()); + self_signatures + .iter() + .map(|self_signature| CallableType::single(db, self_signature.clone())) + .any(|self_callable| { + self_callable.is_assignable_to_impl(db, other_callable, check_signature) + }) + } + + // `self` is definitely not overloaded while `other` is possibly overloaded. + ([self_signature], other_signatures) => { + let self_callable = CallableType::single(db, self_signature.clone()); + other_signatures + .iter() + .map(|other_signature| CallableType::single(db, other_signature.clone())) + .all(|other_callable| { + self_callable.is_assignable_to_impl(db, other_callable, check_signature) + }) + } + + // `self` is definitely overloaded while `other` is possibly overloaded. + (_, other_signatures) => other_signatures + .iter() + .map(|other_signature| CallableType::single(db, other_signature.clone())) + .all(|other_callable| { + self.is_assignable_to_impl(db, other_callable, check_signature) + }), + } + } + + /// Check whether this callable type is equivalent to another callable type. + /// + /// See [`Type::is_equivalent_to`] for more details. + fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool { + match (&**self.signatures(db), &**other.signatures(db)) { + ([self_signature], [other_signature]) => { + self_signature.is_equivalent_to(db, other_signature) + } + _ => { + // TODO: overloads + false + } + } + } + + /// Check whether this callable type is gradual equivalent to another callable type. + /// + /// See [`Type::is_gradual_equivalent_to`] for more details. + fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool { + match (&**self.signatures(db), &**other.signatures(db)) { + ([self_signature], [other_signature]) => { + self_signature.is_gradual_equivalent_to(db, other_signature) + } + _ => { + // TODO: overloads + false + } + } } } diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index 16f00029e5..7232a08f8d 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -19,7 +19,7 @@ use crate::types::diagnostic::{ use crate::types::generics::{Specialization, SpecializationBuilder}; use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ - todo_type, BoundMethodType, DataclassMetadata, FunctionDecorators, KnownClass, KnownFunction, + BoundMethodType, DataclassMetadata, FunctionDecorators, KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind, }; use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic}; @@ -536,7 +536,11 @@ impl<'db> Bindings<'db> { } Some(KnownFunction::Overload) => { - overload.set_return_type(todo_type!("overload[..] return type")); + // TODO: This can be removed once we understand legacy generics because the + // typeshed definition for `typing.overload` is an identity function. + if let [Some(ty)] = overload.parameter_types() { + overload.set_return_type(*ty); + } } Some(KnownFunction::GetattrStatic) => { diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index d59fa8e006..f22e8a668e 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -931,7 +931,7 @@ impl<'db> ClassLiteralType<'db> { let init_signature = Signature::new(Parameters::new(parameters), Some(Type::none(db))); - return Some(Type::Callable(CallableType::new(db, init_signature))); + return Some(Type::Callable(CallableType::single(db, init_signature))); } else if matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") { if metadata.contains(DataclassMetadata::ORDER) { let signature = Signature::new( @@ -943,7 +943,7 @@ impl<'db> ClassLiteralType<'db> { Some(KnownClass::Bool.to_instance(db)), ); - return Some(Type::Callable(CallableType::new(db, signature))); + return Some(Type::Callable(CallableType::single(db, signature))); } } diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 5a6da80132..96c3c5c535 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -11,12 +11,15 @@ use crate::types::class_base::ClassBase; use crate::types::generics::{GenericContext, Specialization}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::{ - InstanceType, IntersectionType, KnownClass, MethodWrapperKind, StringLiteralType, Type, - TypeVarBoundOrConstraints, TypeVarInstance, UnionType, WrapperDescriptorKind, + FunctionSignature, InstanceType, IntersectionType, KnownClass, MethodWrapperKind, + StringLiteralType, Type, TypeVarBoundOrConstraints, TypeVarInstance, UnionType, + WrapperDescriptorKind, }; use crate::Db; use rustc_hash::FxHashMap; +use super::CallableType; + impl<'db> Type<'db> { pub fn display(&self, db: &'db dyn Db) -> DisplayType { DisplayType { ty: self, db } @@ -95,32 +98,59 @@ impl Display for DisplayRepresentation<'_> { Type::KnownInstance(known_instance) => f.write_str(known_instance.repr(self.db)), Type::FunctionLiteral(function) => { let signature = function.signature(self.db); + // TODO: when generic function types are supported, we should add // the generic type parameters to the signature, i.e. // show `def foo[T](x: T) -> T`. - write!( - f, - // "def {name}{specialization}{signature}", - "def {name}{signature}", - name = function.name(self.db), - signature = signature.display(self.db) - ) + match signature { + FunctionSignature::Single(signature) => { + write!( + f, + // "def {name}{specialization}{signature}", + "def {name}{signature}", + name = function.name(self.db), + signature = signature.display(self.db) + ) + } + FunctionSignature::Overloaded(signatures, _) => { + // TODO: How to display overloads? + f.write_str("Overload[")?; + let mut join = f.join(", "); + for signature in signatures { + join.entry(&signature.display(self.db)); + } + f.write_str("]") + } + } } - Type::Callable(callable) => callable.signature(self.db).display(self.db).fmt(f), + Type::Callable(callable) => callable.display(self.db).fmt(f), Type::BoundMethod(bound_method) => { let function = bound_method.function(self.db); // TODO: use the specialization from the method. Similar to the comment above // about the function specialization, - write!( - f, - "bound method {instance}.{method}{signature}", - method = function.name(self.db), - instance = bound_method.self_instance(self.db).display(self.db), - signature = function.signature(self.db).bind_self().display(self.db) - ) + match function.signature(self.db) { + FunctionSignature::Single(signature) => { + write!( + f, + "bound method {instance}.{method}{signature}", + method = function.name(self.db), + instance = bound_method.self_instance(self.db).display(self.db), + signature = signature.bind_self().display(self.db) + ) + } + FunctionSignature::Overloaded(signatures, _) => { + // TODO: How to display overloads? + f.write_str("Overload[")?; + let mut join = f.join(", "); + for signature in signatures { + join.entry(&signature.bind_self().display(self.db)); + } + f.write_str("]") + } + } } Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { write!( @@ -355,8 +385,40 @@ impl Display for DisplaySpecialization<'_> { } } +impl<'db> CallableType<'db> { + pub(crate) fn display(&'db self, db: &'db dyn Db) -> DisplayCallableType<'db> { + DisplayCallableType { + signatures: self.signatures(db), + db, + } + } +} + +pub(crate) struct DisplayCallableType<'db> { + signatures: &'db [Signature<'db>], + db: &'db dyn Db, +} + +impl Display for DisplayCallableType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.signatures { + [signature] => write!(f, "{}", signature.display(self.db)), + signatures => { + // TODO: How to display overloads? + f.write_str("Overload[")?; + let mut join = f.join(", "); + for signature in signatures { + join.entry(&signature.display(self.db)); + } + join.finish()?; + f.write_char(']') + } + } + } +} + impl<'db> Signature<'db> { - fn display(&'db self, db: &'db dyn Db) -> DisplaySignature<'db> { + pub(crate) fn display(&'db self, db: &'db dyn Db) -> DisplaySignature<'db> { DisplaySignature { parameters: self.parameters(), return_ty: self.return_ty, @@ -365,7 +427,7 @@ impl<'db> Signature<'db> { } } -struct DisplaySignature<'db> { +pub(crate) struct DisplaySignature<'db> { parameters: &'db Parameters<'db>, return_ty: Option>, db: &'db dyn Db, diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 59248f9f7b..29a92c058d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -4133,7 +4133,7 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO: Useful inference of a lambda's return type will require a different approach, // which does the inference of the body expression based on arguments at each call site, // rather than eagerly computing a return type without knowing the argument types. - Type::Callable(CallableType::new( + Type::Callable(CallableType::single( self.db(), Signature::new(parameters, Some(Type::unknown())), )) @@ -7305,7 +7305,7 @@ impl<'db> TypeInferenceBuilder<'db> { let callable_type = if let (Some(parameters), Some(return_type), true) = (parameters, return_type, correct_argument_number) { - CallableType::new(db, Signature::new(parameters, Some(return_type))) + CallableType::single(db, Signature::new(parameters, Some(return_type))) } else { CallableType::unknown(db) }; @@ -7386,8 +7386,22 @@ impl<'db> TypeInferenceBuilder<'db> { let argument_type = self.infer_expression(arguments_slice); let signatures = argument_type.signatures(db); - // TODO overloads - let Some(signature) = signatures.iter().flatten().next() else { + // SAFETY: This is enforced by the constructor methods on `Signatures` even in + // the case of a non-callable union. + let callable_signature = signatures + .iter() + .next() + .expect("`Signatures` should have at least one `CallableSignature`"); + + let mut signature_iter = callable_signature.iter().map(|signature| { + if argument_type.is_bound_method() { + signature.bind_self() + } else { + signature.clone() + } + }); + + let Some(signature) = signature_iter.next() else { self.context.report_lint_old( &INVALID_TYPE_FORM, arguments_slice, @@ -7400,13 +7414,10 @@ impl<'db> TypeInferenceBuilder<'db> { return Type::unknown(); }; - let revealed_signature = if argument_type.is_bound_method() { - signature.bind_self() - } else { - signature.clone() - }; - - Type::Callable(CallableType::new(db, revealed_signature)) + Type::Callable(CallableType::from_overloads( + db, + std::iter::once(signature).chain(signature_iter), + )) } }, diff --git a/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs b/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs index 8072190159..470f3e0778 100644 --- a/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs +++ b/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs @@ -188,7 +188,7 @@ impl Ty { create_bound_method(db, function, builtins_class) } - Ty::Callable { params, returns } => Type::Callable(CallableType::new( + Ty::Callable { params, returns } => Type::Callable(CallableType::single( db, Signature::new( params.into_parameters(db), diff --git a/crates/red_knot_python_semantic/src/types/signatures.rs b/crates/red_knot_python_semantic/src/types/signatures.rs index 50c022e8d6..52b1ab827f 100644 --- a/crates/red_knot_python_semantic/src/types/signatures.rs +++ b/crates/red_knot_python_semantic/src/types/signatures.rs @@ -250,16 +250,6 @@ impl<'db> Signature<'db> { } } - /// Return a todo signature: (*args: Todo, **kwargs: Todo) -> Todo - #[allow(unused_variables)] // 'reason' only unused in debug builds - pub(crate) fn todo(reason: &'static str) -> Self { - Signature { - generic_context: None, - parameters: Parameters::todo(), - return_ty: Some(todo_type!(reason)), - } - } - /// Return a typed signature from a function definition. pub(super) fn from_function( db: &'db dyn Db, @@ -286,15 +276,30 @@ impl<'db> Signature<'db> { } } + pub(crate) fn normalized(&self, db: &'db dyn Db) -> Self { + Self { + generic_context: self.generic_context, + parameters: self + .parameters + .iter() + .map(|param| param.normalized(db)) + .collect(), + return_ty: self.return_ty.map(|return_ty| return_ty.normalized(db)), + } + } + pub(crate) fn apply_specialization( - &mut self, + &self, db: &'db dyn Db, specialization: Specialization<'db>, - ) { - self.parameters.apply_specialization(db, specialization); - self.return_ty = self - .return_ty - .map(|ty| ty.apply_specialization(db, specialization)); + ) -> Self { + Self { + generic_context: self.generic_context, + parameters: self.parameters.apply_specialization(db, specialization), + return_ty: self + .return_ty + .map(|ty| ty.apply_specialization(db, specialization)), + } } /// Return the parameters in this signature. @@ -995,10 +1000,15 @@ impl<'db> Parameters<'db> { ) } - fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) { - self.value - .iter_mut() - .for_each(|param| param.apply_specialization(db, specialization)); + fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { + Self { + value: self + .value + .iter() + .map(|param| param.apply_specialization(db, specialization)) + .collect(), + is_gradual: self.is_gradual, + } } pub(crate) fn len(&self) -> usize { @@ -1162,11 +1172,14 @@ impl<'db> Parameter<'db> { self } - fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) { - self.annotated_type = self - .annotated_type - .map(|ty| ty.apply_specialization(db, specialization)); - self.kind.apply_specialization(db, specialization); + fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { + Self { + annotated_type: self + .annotated_type + .map(|ty| ty.apply_specialization(db, specialization)), + kind: self.kind.apply_specialization(db, specialization), + form: self.form, + } } /// Strip information from the parameter so that two equivalent parameters compare equal. @@ -1356,14 +1369,27 @@ pub(crate) enum ParameterKind<'db> { } impl<'db> ParameterKind<'db> { - fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) { + fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { match self { - Self::PositionalOnly { default_type, .. } - | Self::PositionalOrKeyword { default_type, .. } - | Self::KeywordOnly { default_type, .. } => { - *default_type = default_type.map(|ty| ty.apply_specialization(db, specialization)); - } - Self::Variadic { .. } | Self::KeywordVariadic { .. } => {} + Self::PositionalOnly { default_type, name } => Self::PositionalOnly { + default_type: default_type + .as_ref() + .map(|ty| ty.apply_specialization(db, specialization)), + name: name.clone(), + }, + Self::PositionalOrKeyword { default_type, name } => Self::PositionalOrKeyword { + default_type: default_type + .as_ref() + .map(|ty| ty.apply_specialization(db, specialization)), + name: name.clone(), + }, + Self::KeywordOnly { default_type, name } => Self::KeywordOnly { + default_type: default_type + .as_ref() + .map(|ty| ty.apply_specialization(db, specialization)), + name: name.clone(), + }, + Self::Variadic { .. } | Self::KeywordVariadic { .. } => self.clone(), } } } @@ -1380,7 +1406,7 @@ mod tests { use super::*; use crate::db::tests::{setup_db, TestDb}; use crate::symbol::global_symbol; - use crate::types::{FunctionType, KnownClass}; + use crate::types::{FunctionSignature, FunctionType, KnownClass}; use ruff_db::system::DbWithWritableSystem as _; #[track_caller] @@ -1627,6 +1653,9 @@ mod tests { let expected_sig = func.internal_signature(&db); // With no decorators, internal and external signature are the same - assert_eq!(func.signature(&db), &expected_sig); + assert_eq!( + func.signature(&db), + &FunctionSignature::Single(expected_sig) + ); } }