diff --git a/crates/red_knot/tests/cli.rs b/crates/red_knot/tests/cli.rs index 3f2dd4a98a..a37468e824 100644 --- a/crates/red_knot/tests/cli.rs +++ b/crates/red_knot/tests/cli.rs @@ -252,7 +252,7 @@ fn configuration_rule_severity() -> anyhow::Result<()> { r#" y = 4 / 0 - for a in range(0, y): + for a in range(0, int(y)): x = a print(x) # possibly-unresolved-reference @@ -271,7 +271,7 @@ fn configuration_rule_severity() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | warning: lint:possibly-unresolved-reference @@ -307,7 +307,7 @@ fn configuration_rule_severity() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | Found 1 diagnostic @@ -328,7 +328,7 @@ fn cli_rule_severity() -> anyhow::Result<()> { y = 4 / 0 - for a in range(0, y): + for a in range(0, int(y)): x = a print(x) # possibly-unresolved-reference @@ -358,7 +358,7 @@ fn cli_rule_severity() -> anyhow::Result<()> { 4 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 5 | - 6 | for a in range(0, y): + 6 | for a in range(0, int(y)): | warning: lint:possibly-unresolved-reference @@ -405,7 +405,7 @@ fn cli_rule_severity() -> anyhow::Result<()> { 4 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 5 | - 6 | for a in range(0, y): + 6 | for a in range(0, int(y)): | Found 2 diagnostics @@ -426,7 +426,7 @@ fn cli_rule_severity_precedence() -> anyhow::Result<()> { r#" y = 4 / 0 - for a in range(0, y): + for a in range(0, int(y)): x = a print(x) # possibly-unresolved-reference @@ -445,7 +445,7 @@ fn cli_rule_severity_precedence() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | warning: lint:possibly-unresolved-reference @@ -482,7 +482,7 @@ fn cli_rule_severity_precedence() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | Found 1 diagnostic @@ -814,7 +814,7 @@ fn user_configuration() -> anyhow::Result<()> { r#" y = 4 / 0 - for a in range(0, y): + for a in range(0, int(y)): x = a print(x) @@ -841,7 +841,7 @@ fn user_configuration() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | warning: lint:possibly-unresolved-reference @@ -883,7 +883,7 @@ fn user_configuration() -> anyhow::Result<()> { 2 | y = 4 / 0 | ^^^^^ Cannot divide object of type `Literal[4]` by zero 3 | - 4 | for a in range(0, y): + 4 | for a in range(0, int(y)): | error: lint:possibly-unresolved-reference diff --git a/crates/red_knot_python_semantic/resources/mdtest/annotations/literal_string.md b/crates/red_knot_python_semantic/resources/mdtest/annotations/literal_string.md index 45dff62b80..445c7a05da 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/annotations/literal_string.md +++ b/crates/red_knot_python_semantic/resources/mdtest/annotations/literal_string.md @@ -72,13 +72,11 @@ reveal_type(baz) # revealed: Literal["bazfoo"] qux = (foo, bar) reveal_type(qux) # revealed: tuple[Literal["foo"], Literal["bar"]] -# TODO: Infer "LiteralString" -reveal_type(foo.join(qux)) # revealed: @Todo(return type of overloaded function) +reveal_type(foo.join(qux)) # revealed: LiteralString template: LiteralString = "{}, {}" reveal_type(template) # revealed: Literal["{}, {}"] -# TODO: Infer `LiteralString` -reveal_type(template.format(foo, bar)) # revealed: @Todo(return type of overloaded function) +reveal_type(template.format(foo, bar)) # revealed: LiteralString ``` ### Assignability diff --git a/crates/red_knot_python_semantic/resources/mdtest/attributes.md b/crates/red_knot_python_semantic/resources/mdtest/attributes.md index 15882300c1..3f6c5321c0 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/attributes.md +++ b/crates/red_knot_python_semantic/resources/mdtest/attributes.md @@ -1698,9 +1698,9 @@ Most attribute accesses on bool-literal types are delegated to `builtins.bool`, bools are instances of that class: ```py -# revealed: bound method Literal[True].__and__(**kwargs: @Todo(todo signature **kwargs)) -> @Todo(return type of overloaded function) +# revealed: Overload[(value: bool, /) -> bool, (value: int, /) -> int] reveal_type(True.__and__) -# revealed: bound method Literal[False].__or__(**kwargs: @Todo(todo signature **kwargs)) -> @Todo(return type of overloaded function) +# revealed: Overload[(value: bool, /) -> bool, (value: int, /) -> int] reveal_type(False.__or__) ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md b/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md index b435a34d76..650e7a636c 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md +++ b/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md @@ -310,9 +310,7 @@ reveal_type(A() + 1) # revealed: A reveal_type(1 + A()) # revealed: A reveal_type(A() + "foo") # revealed: A -# TODO should be `A` since `str.__add__` doesn't support `A` instances -# TODO overloads -reveal_type("foo" + A()) # revealed: @Todo(return type of overloaded function) +reveal_type("foo" + A()) # revealed: A reveal_type(A() + b"foo") # revealed: A # TODO should be `A` since `bytes.__add__` doesn't support `A` instances @@ -320,16 +318,14 @@ reveal_type(b"foo" + A()) # revealed: bytes reveal_type(A() + ()) # revealed: A # TODO this should be `A`, since `tuple.__add__` doesn't support `A` instances -reveal_type(() + A()) # revealed: @Todo(return type of overloaded function) +reveal_type(() + A()) # revealed: @Todo(full tuple[...] support) literal_string_instance = "foo" * 1_000_000_000 # the test is not testing what it's meant to be testing if this isn't a `LiteralString`: reveal_type(literal_string_instance) # revealed: LiteralString reveal_type(A() + literal_string_instance) # revealed: A -# TODO should be `A` since `str.__add__` doesn't support `A` instances -# TODO overloads -reveal_type(literal_string_instance + A()) # revealed: @Todo(return type of overloaded function) +reveal_type(literal_string_instance + A()) # revealed: A ``` ## Operations involving instances of classes inheriting from `Any` diff --git a/crates/red_knot_python_semantic/resources/mdtest/binary/integers.md b/crates/red_knot_python_semantic/resources/mdtest/binary/integers.md index a4172c821d..dcbc180317 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/binary/integers.md +++ b/crates/red_knot_python_semantic/resources/mdtest/binary/integers.md @@ -50,9 +50,11 @@ reveal_type(1 ** (largest_u32 + 1)) # revealed: int reveal_type(2**largest_u32) # revealed: int def variable(x: int): - reveal_type(x**2) # revealed: @Todo(return type of overloaded function) - reveal_type(2**x) # revealed: @Todo(return type of overloaded function) - reveal_type(x**x) # revealed: @Todo(return type of overloaded function) + reveal_type(x**2) # revealed: int + # TODO: should be `Any` (overload 5 on `__pow__`), requires correct overload matching + reveal_type(2**x) # revealed: int + # TODO: should be `Any` (overload 5 on `__pow__`), requires correct overload matching + reveal_type(x**x) # revealed: int ``` If the second argument is \<0, a `float` is returned at runtime. If the first argument is \<0 but diff --git a/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md index 28baa67ba4..94958c883b 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md +++ b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md @@ -547,14 +547,14 @@ the descriptor's `__get__` method as if it had been called on the class itself, for the `instance` argument. ```py -from typing import overload +from typing import Literal, overload from dataclasses import dataclass class ConvertToLength: _len: int = 0 @overload - def __get__(self, instance: None, owner: type) -> str: ... + def __get__(self, instance: None, owner: type) -> Literal[""]: ... @overload def __get__(self, instance: object, owner: type | None) -> int: ... def __get__(self, instance: object | None, owner: type | None) -> str | int: @@ -570,12 +570,10 @@ class ConvertToLength: class C: converter: ConvertToLength = ConvertToLength() -# TODO: Should be `(converter: str = Literal[""]) -> None` once we understand overloads -reveal_type(C.__init__) # revealed: (converter: str = str | int) -> None +reveal_type(C.__init__) # revealed: (converter: str = Literal[""]) -> None c = C("abc") -# TODO: Should be `int` once we understand overloads -reveal_type(c.converter) # revealed: str | int +reveal_type(c.converter) # revealed: int # This is also okay: C() @@ -611,8 +609,7 @@ class AcceptsStrAndInt: class C: field: AcceptsStrAndInt = AcceptsStrAndInt() -# TODO: Should be `field: str | int = int` once we understand overloads -reveal_type(C.__init__) # revealed: (field: Unknown = int) -> None +reveal_type(C.__init__) # revealed: (field: str | int = int) -> None ``` ## `dataclasses.field` diff --git a/crates/red_knot_python_semantic/resources/mdtest/descriptor_protocol.md b/crates/red_knot_python_semantic/resources/mdtest/descriptor_protocol.md index aeabb336ec..8f8d17ae76 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/descriptor_protocol.md +++ b/crates/red_knot_python_semantic/resources/mdtest/descriptor_protocol.md @@ -459,11 +459,9 @@ class Descriptor: class C: d: Descriptor = Descriptor() -# TODO: should be `Literal["called on class object"] -reveal_type(C.d) # revealed: LiteralString +reveal_type(C.d) # revealed: Literal["called on class object"] -# TODO: should be `Literal["called on instance"] -reveal_type(C().d) # revealed: LiteralString +reveal_type(C().d) # revealed: Literal["called on instance"] ``` ## Descriptor protocol for dunder methods diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md index 84a14f93f5..f8f5ab8c5e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md @@ -246,7 +246,7 @@ class MetaTruthy(type): class MetaDeferred(type): def __bool__(self) -> MetaAmbiguous: - return MetaAmbiguous() + raise NotImplementedError class AmbiguousClass(metaclass=MetaAmbiguous): ... class FalsyClass(metaclass=MetaFalsy): ... diff --git a/crates/red_knot_python_semantic/resources/mdtest/overloads.md b/crates/red_knot_python_semantic/resources/mdtest/overloads.md new file mode 100644 index 0000000000..764b6a4412 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/overloads.md @@ -0,0 +1,638 @@ +# Overloads + +Reference: + +## `typing.overload` + +The definition of `typing.overload` in typeshed is an identity function. + +```py +from typing import overload + +def foo(x: int) -> int: + return x + +reveal_type(foo) # revealed: def foo(x: int) -> int +bar = overload(foo) +reveal_type(bar) # revealed: def foo(x: int) -> int +``` + +## Functions + +```py +from typing import overload + +@overload +def add() -> None: ... +@overload +def add(x: int) -> int: ... +@overload +def add(x: int, y: int) -> int: ... +def add(x: int | None = None, y: int | None = None) -> int | None: + return (x or 0) + (y or 0) + +reveal_type(add) # revealed: Overload[() -> None, (x: int) -> int, (x: int, y: int) -> int] +reveal_type(add()) # revealed: None +reveal_type(add(1)) # revealed: int +reveal_type(add(1, 2)) # revealed: int +``` + +## Overriding + +These scenarios are to verify that the overloaded and non-overloaded definitions are correctly +overridden by each other. + +An overloaded function is overriding another overloaded function: + +```py +from typing import overload + +@overload +def foo() -> None: ... +@overload +def foo(x: int) -> int: ... +def foo(x: int | None = None) -> int | None: + return x + +reveal_type(foo) # revealed: Overload[() -> None, (x: int) -> int] +reveal_type(foo()) # revealed: None +reveal_type(foo(1)) # revealed: int + +@overload +def foo() -> None: ... +@overload +def foo(x: str) -> str: ... +def foo(x: str | None = None) -> str | None: + return x + +reveal_type(foo) # revealed: Overload[() -> None, (x: str) -> str] +reveal_type(foo()) # revealed: None +reveal_type(foo("")) # revealed: str +``` + +A non-overloaded function is overriding an overloaded function: + +```py +def foo(x: int) -> int: + return x + +reveal_type(foo) # revealed: def foo(x: int) -> int +``` + +An overloaded function is overriding a non-overloaded function: + +```py +reveal_type(foo) # revealed: def foo(x: int) -> int + +@overload +def foo() -> None: ... +@overload +def foo(x: bytes) -> bytes: ... +def foo(x: bytes | None = None) -> bytes | None: + return x + +reveal_type(foo) # revealed: Overload[() -> None, (x: bytes) -> bytes] +reveal_type(foo()) # revealed: None +reveal_type(foo(b"")) # revealed: bytes +``` + +## Methods + +```py +from typing import overload + +class Foo1: + @overload + def method(self) -> None: ... + @overload + def method(self, x: int) -> int: ... + def method(self, x: int | None = None) -> int | None: + return x + +foo1 = Foo1() +reveal_type(foo1.method) # revealed: Overload[() -> None, (x: int) -> int] +reveal_type(foo1.method()) # revealed: None +reveal_type(foo1.method(1)) # revealed: int + +class Foo2: + @overload + def method(self) -> None: ... + @overload + def method(self, x: str) -> str: ... + def method(self, x: str | None = None) -> str | None: + return x + +foo2 = Foo2() +reveal_type(foo2.method) # revealed: Overload[() -> None, (x: str) -> str] +reveal_type(foo2.method()) # revealed: None +reveal_type(foo2.method("")) # revealed: str +``` + +## Constructor + +```py +from typing import overload + +class Foo: + @overload + def __init__(self) -> None: ... + @overload + def __init__(self, x: int) -> None: ... + def __init__(self, x: int | None = None) -> None: + self.x = x + +foo = Foo() +reveal_type(foo) # revealed: Foo +reveal_type(foo.x) # revealed: Unknown | int | None + +foo1 = Foo(1) +reveal_type(foo1) # revealed: Foo +reveal_type(foo1.x) # revealed: Unknown | int | None +``` + +## Version specific + +Function definitions can vary between multiple Python versions. + +### Overload and non-overload (3.9) + +Here, the same function is overloaded in one version and not in another. + +```toml +[environment] +python-version = "3.9" +``` + +```py +import sys +from typing import overload + +if sys.version_info < (3, 10): + def func(x: int) -> int: + return x + +elif sys.version_info <= (3, 12): + @overload + def func() -> None: ... + @overload + def func(x: int) -> int: ... + def func(x: int | None = None) -> int | None: + return x + +reveal_type(func) # revealed: def func(x: int) -> int +func() # error: [missing-argument] +``` + +### Overload and non-overload (3.10) + +```toml +[environment] +python-version = "3.10" +``` + +```py +import sys +from typing import overload + +if sys.version_info < (3, 10): + def func(x: int) -> int: + return x + +elif sys.version_info <= (3, 12): + @overload + def func() -> None: ... + @overload + def func(x: int) -> int: ... + def func(x: int | None = None) -> int | None: + return x + +reveal_type(func) # revealed: Overload[() -> None, (x: int) -> int] +reveal_type(func()) # revealed: None +reveal_type(func(1)) # revealed: int +``` + +### Some overloads are version specific (3.9) + +```toml +[environment] +python-version = "3.9" +``` + +`overloaded.pyi`: + +```pyi +import sys +from typing import overload + +if sys.version_info >= (3, 10): + @overload + def func() -> None: ... + +@overload +def func(x: int) -> int: ... +@overload +def func(x: str) -> str: ... +``` + +`main.py`: + +```py +from overloaded import func + +reveal_type(func) # revealed: Overload[(x: int) -> int, (x: str) -> str] +func() # error: [no-matching-overload] +reveal_type(func(1)) # revealed: int +reveal_type(func("")) # revealed: str +``` + +### Some overloads are version specific (3.10) + +```toml +[environment] +python-version = "3.10" +``` + +`overloaded.pyi`: + +```pyi +import sys +from typing import overload + +@overload +def func() -> None: ... + +if sys.version_info >= (3, 10): + @overload + def func(x: int) -> int: ... + +@overload +def func(x: str) -> str: ... +``` + +`main.py`: + +```py +from overloaded import func + +reveal_type(func) # revealed: Overload[() -> None, (x: int) -> int, (x: str) -> str] +reveal_type(func()) # revealed: None +reveal_type(func(1)) # revealed: int +reveal_type(func("")) # revealed: str +``` + +## Generic + +```toml +[environment] +python-version = "3.12" +``` + +For an overloaded generic function, it's not necessary for all overloads to be generic. + +```py +from typing import overload + +@overload +def func() -> None: ... +@overload +def func[T](x: T) -> T: ... +def func[T](x: T | None = None) -> T | None: + return x + +reveal_type(func) # revealed: Overload[() -> None, (x: T) -> T] +reveal_type(func()) # revealed: None +reveal_type(func(1)) # revealed: Literal[1] +reveal_type(func("")) # revealed: Literal[""] +``` + +## Invalid + +### At least two overloads + +At least two `@overload`-decorated definitions must be present. + +```py +from typing import overload + +# TODO: error +@overload +def func(x: int) -> int: ... +def func(x: int | str) -> int | str: + return x +``` + +### Overload without an implementation + +#### Regular modules + +In regular modules, a series of `@overload`-decorated definitions must be followed by exactly one +non-`@overload`-decorated definition (for the same function/method). + +```py +from typing import overload + +# TODO: error because implementation does not exists +@overload +def func(x: int) -> int: ... +@overload +def func(x: str) -> str: ... + +class Foo: + # TODO: error because implementation does not exists + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... +``` + +#### Stub files + +Overload definitions within stub files are exempt from this check. + +```pyi +from typing import overload + +@overload +def func(x: int) -> int: ... +@overload +def func(x: str) -> str: ... +``` + +#### Protocols + +Overload definitions within protocols are exempt from this check. + +```py +from typing import Protocol, overload + +class Foo(Protocol): + @overload + def f(self, x: int) -> int: ... + @overload + def f(self, x: str) -> str: ... +``` + +#### Abstract methods + +Overload definitions within abstract base classes are exempt from this check. + +```py +from abc import ABC, abstractmethod +from typing import overload + +class AbstractFoo(ABC): + @overload + @abstractmethod + def f(self, x: int) -> int: ... + @overload + @abstractmethod + def f(self, x: str) -> str: ... +``` + +Using the `@abstractmethod` decorator requires that the class's metaclass is `ABCMeta` or is derived +from it. + +```py +class Foo: + # TODO: Error because implementation does not exists + @overload + @abstractmethod + def f(self, x: int) -> int: ... + @overload + @abstractmethod + def f(self, x: str) -> str: ... +``` + +And, the `@abstractmethod` decorator must be present on all the `@overload`-ed methods. + +```py +class PartialFoo1(ABC): + @overload + @abstractmethod + def f(self, x: int) -> int: ... + @overload + def f(self, x: str) -> str: ... + +class PartialFoo(ABC): + @overload + def f(self, x: int) -> int: ... + @overload + @abstractmethod + def f(self, x: str) -> str: ... +``` + +### Inconsistent decorators + +#### `@staticmethod` / `@classmethod` + +If one overload signature is decorated with `@staticmethod` or `@classmethod`, all overload +signatures must be similarly decorated. The implementation, if present, must also have a consistent +decorator. + +```py +from __future__ import annotations + +from typing import overload + +class CheckStaticMethod: + # TODO: error because `@staticmethod` does not exist on all overloads + @overload + def method1(x: int) -> int: ... + @overload + def method1(x: str) -> str: ... + @staticmethod + def method1(x: int | str) -> int | str: + return x + # TODO: error because `@staticmethod` does not exist on all overloads + @overload + def method2(x: int) -> int: ... + @overload + @staticmethod + def method2(x: str) -> str: ... + @staticmethod + def method2(x: int | str) -> int | str: + return x + # TODO: error because `@staticmethod` does not exist on the implementation + @overload + @staticmethod + def method3(x: int) -> int: ... + @overload + @staticmethod + def method3(x: str) -> str: ... + def method3(x: int | str) -> int | str: + return x + + @overload + @staticmethod + def method4(x: int) -> int: ... + @overload + @staticmethod + def method4(x: str) -> str: ... + @staticmethod + def method4(x: int | str) -> int | str: + return x + +class CheckClassMethod: + def __init__(self, x: int) -> None: + self.x = x + # TODO: error because `@classmethod` does not exist on all overloads + @overload + @classmethod + def try_from1(cls, x: int) -> CheckClassMethod: ... + @overload + def try_from1(cls, x: str) -> None: ... + @classmethod + def try_from1(cls, x: int | str) -> CheckClassMethod | None: + if isinstance(x, int): + return cls(x) + return None + # TODO: error because `@classmethod` does not exist on all overloads + @overload + def try_from2(cls, x: int) -> CheckClassMethod: ... + @overload + @classmethod + def try_from2(cls, x: str) -> None: ... + @classmethod + def try_from2(cls, x: int | str) -> CheckClassMethod | None: + if isinstance(x, int): + return cls(x) + return None + # TODO: error because `@classmethod` does not exist on the implementation + @overload + @classmethod + def try_from3(cls, x: int) -> CheckClassMethod: ... + @overload + @classmethod + def try_from3(cls, x: str) -> None: ... + def try_from3(cls, x: int | str) -> CheckClassMethod | None: + if isinstance(x, int): + return cls(x) + return None + + @overload + @classmethod + def try_from4(cls, x: int) -> CheckClassMethod: ... + @overload + @classmethod + def try_from4(cls, x: str) -> None: ... + @classmethod + def try_from4(cls, x: int | str) -> CheckClassMethod | None: + if isinstance(x, int): + return cls(x) + return None +``` + +#### `@final` / `@override` + +If a `@final` or `@override` decorator is supplied for a function with overloads, the decorator +should be applied only to the overload implementation if it is present. + +```py +from typing_extensions import final, overload, override + +class Foo: + @overload + def method1(self, x: int) -> int: ... + @overload + def method1(self, x: str) -> str: ... + @final + def method1(self, x: int | str) -> int | str: + return x + # TODO: error because `@final` is not on the implementation + @overload + @final + def method2(self, x: int) -> int: ... + @overload + def method2(self, x: str) -> str: ... + def method2(self, x: int | str) -> int | str: + return x + # TODO: error because `@final` is not on the implementation + @overload + def method3(self, x: int) -> int: ... + @overload + @final + def method3(self, x: str) -> str: ... + def method3(self, x: int | str) -> int | str: + return x + +class Base: + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + def method(self, x: int | str) -> int | str: + return x + +class Sub1(Base): + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + @override + def method(self, x: int | str) -> int | str: + return x + +class Sub2(Base): + # TODO: error because `@override` is not on the implementation + @overload + def method(self, x: int) -> int: ... + @overload + @override + def method(self, x: str) -> str: ... + def method(self, x: int | str) -> int | str: + return x + +class Sub3(Base): + # TODO: error because `@override` is not on the implementation + @overload + @override + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + def method(self, x: int | str) -> int | str: + return x +``` + +#### `@final` / `@override` in stub files + +If an overload implementation isn’t present (for example, in a stub file), the `@final` or +`@override` decorator should be applied only to the first overload. + +```pyi +from typing_extensions import final, overload, override + +class Foo: + @overload + @final + def method1(self, x: int) -> int: ... + @overload + def method1(self, x: str) -> str: ... + + # TODO: error because `@final` is not on the first overload + @overload + def method2(self, x: int) -> int: ... + @final + @overload + def method2(self, x: str) -> str: ... + +class Base: + @overload + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + +class Sub1(Base): + @overload + @override + def method(self, x: int) -> int: ... + @overload + def method(self, x: str) -> str: ... + +class Sub2(Base): + # TODO: error because `@override` is not on the first overload + @overload + def method(self, x: int) -> int: ... + @overload + @override + def method(self, x: str) -> str: ... +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md index 255cb6bc9d..7127263dd4 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/bytes.md @@ -24,8 +24,7 @@ reveal_type(y) # revealed: Unknown def _(n: int): a = b"abcde"[n] - # TODO: Support overloads... Should be `bytes` - reveal_type(a) # revealed: @Todo(return type of overloaded function) + reveal_type(a) # revealed: int ``` ## Slices @@ -43,11 +42,9 @@ b[::0] # error: [zero-stepsize-in-slice] def _(m: int, n: int): byte_slice1 = b[m:n] - # TODO: Support overloads... Should be `bytes` - reveal_type(byte_slice1) # revealed: @Todo(return type of overloaded function) + reveal_type(byte_slice1) # revealed: bytes def _(s: bytes) -> bytes: byte_slice2 = s[0:5] - # TODO: Support overloads... Should be `bytes` - return reveal_type(byte_slice2) # revealed: @Todo(return type of overloaded function) + return reveal_type(byte_slice2) # revealed: bytes ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/lists.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/lists.md index 408a4f43b6..5f9264f6fa 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/lists.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/lists.md @@ -12,13 +12,13 @@ x = [1, 2, 3] reveal_type(x) # revealed: list # TODO reveal int -reveal_type(x[0]) # revealed: @Todo(return type of overloaded function) +reveal_type(x[0]) # revealed: Unknown | @Todo(Support for `typing.TypeVar` instances in type expressions) # TODO reveal list -reveal_type(x[0:1]) # revealed: @Todo(return type of overloaded function) +reveal_type(x[0:1]) # revealed: @Todo(generics) -# TODO error -reveal_type(x["a"]) # revealed: @Todo(return type of overloaded function) +# error: [call-non-callable] +reveal_type(x["a"]) # revealed: Unknown ``` ## Assignments within list assignment @@ -29,9 +29,11 @@ In assignment, we might also have a named assignment. This should also get type x = [1, 2, 3] x[0 if (y := 2) else 1] = 5 -# TODO error? (indeterminite index type) +# TODO: better error than "method `__getitem__` not callable on type `list`" +# error: [call-non-callable] x["a" if (y := 2) else 1] = 6 -# TODO error (can't index via string) +# TODO: better error than "method `__getitem__` not callable on type `list`" +# error: [call-non-callable] x["a" if (y := 2) else "b"] = 6 ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md index 9a875dc323..469300c0b4 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/string.md @@ -21,8 +21,7 @@ reveal_type(b) # revealed: Unknown def _(n: int): a = "abcde"[n] - # TODO: Support overloads... Should be `str` - reveal_type(a) # revealed: @Todo(return type of overloaded function) + reveal_type(a) # revealed: LiteralString ``` ## Slices @@ -75,12 +74,10 @@ def _(m: int, n: int, s2: str): s[::0] # error: [zero-stepsize-in-slice] substring1 = s[m:n] - # TODO: Support overloads... Should be `LiteralString` - reveal_type(substring1) # revealed: @Todo(return type of overloaded function) + reveal_type(substring1) # revealed: LiteralString substring2 = s2[0:5] - # TODO: Support overloads... Should be `str` - reveal_type(substring2) # revealed: @Todo(return type of overloaded function) + reveal_type(substring2) # revealed: str ``` ## Unsupported slice types diff --git a/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md b/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md index 8ac7ff2534..579d7451f6 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md +++ b/crates/red_knot_python_semantic/resources/mdtest/subscript/tuple.md @@ -69,8 +69,8 @@ def _(m: int, n: int): t[::0] # error: [zero-stepsize-in-slice] tuple_slice = t[m:n] - # TODO: Support overloads... Should be `tuple[Literal[1, 'a', b"b"] | None, ...]` - reveal_type(tuple_slice) # revealed: @Todo(return type of overloaded function) + # TODO: Should be `tuple[Literal[1, 'a', b"b"] | None, ...]` + reveal_type(tuple_slice) # revealed: @Todo(full tuple[...] support) ``` ## Inheritance diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index 5ca7c9c628..fdd7ddbe7c 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -522,4 +522,35 @@ c: Callable[[Any], str] = A().f c: Callable[[Any], str] = A().g ``` +### Overloads + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +@overload +def overloaded() -> None: ... +@overload +def overloaded(a: str) -> str: ... +@overload +def overloaded(a: str, b: Any) -> str: ... +``` + +```py +from overloaded import overloaded +from typing import Any, Callable + +c: Callable[[], None] = overloaded +c: Callable[[str], str] = overloaded +c: Callable[[str, Any], Any] = overloaded +c: Callable[..., str] = overloaded + +# error: [invalid-assignment] +c: Callable[..., int] = overloaded + +# error: [invalid-assignment] +c: Callable[[int], str] = overloaded +``` + [typing documentation]: https://typing.python.org/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md index 39aa18fc43..3a234bb8a6 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md @@ -254,4 +254,8 @@ from knot_extensions import is_equivalent_to, static_assert static_assert(is_equivalent_to(int | Callable[[int | str], None], Callable[[str | int], None] | int)) ``` +### Overloads + +TODO + [the equivalence relation]: https://typing.python.org/en/latest/spec/glossary.html#term-equivalent diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_fully_static.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_fully_static.md index 8af5af4154..9a44ce090b 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_fully_static.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_fully_static.md @@ -99,3 +99,34 @@ static_assert(not is_fully_static(CallableTypeOf[f13])) static_assert(not is_fully_static(CallableTypeOf[f14])) static_assert(not is_fully_static(CallableTypeOf[f15])) ``` + +## Overloads + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +@overload +def gradual() -> None: ... +@overload +def gradual(a: Any) -> None: ... + +@overload +def static() -> None: ... +@overload +def static(x: int) -> None: ... +@overload +def static(x: str) -> str: ... +``` + +```py +from knot_extensions import CallableTypeOf, TypeOf, is_fully_static, static_assert +from overloaded import gradual, static + +static_assert(is_fully_static(TypeOf[gradual])) +static_assert(is_fully_static(TypeOf[static])) + +static_assert(not is_fully_static(CallableTypeOf[gradual])) +static_assert(is_fully_static(CallableTypeOf[static])) +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md index 43062b1802..e3e46a96e6 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md @@ -157,4 +157,6 @@ def f6(a, /): ... static_assert(not is_gradual_equivalent_to(CallableTypeOf[f1], CallableTypeOf[f6])) ``` +TODO: Overloads + [materializations]: https://typing.python.org/en/latest/spec/glossary.html#term-materialize diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md index aa5f025e13..c61f748598 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md @@ -1153,5 +1153,187 @@ static_assert(not is_subtype_of(TypeOf[A.g], Callable[[], int])) static_assert(is_subtype_of(TypeOf[A.f], Callable[[A, int], int])) ``` +### Overloads + +#### Subtype overloaded + +For `B <: A`, if a callable `B` is overloaded with two or more signatures, it is a subtype of +callable `A` if _at least one_ of the overloaded signatures in `B` is a subtype of `A`. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B: ... +class C: ... + +@overload +def overloaded(x: A) -> None: ... +@overload +def overloaded(x: B) -> None: ... +``` + +```py +from knot_extensions import CallableTypeOf, is_subtype_of, static_assert +from overloaded import A, B, C, overloaded + +def accepts_a(x: A) -> None: ... +def accepts_b(x: B) -> None: ... +def accepts_c(x: C) -> None: ... + +static_assert(is_subtype_of(CallableTypeOf[overloaded], CallableTypeOf[accepts_a])) +static_assert(is_subtype_of(CallableTypeOf[overloaded], CallableTypeOf[accepts_b])) +static_assert(not is_subtype_of(CallableTypeOf[overloaded], CallableTypeOf[accepts_c])) +``` + +#### Supertype overloaded + +For `B <: A`, if a callable `A` is overloaded with two or more signatures, callable `B` is a subtype +of `A` if `B` is a subtype of _all_ of the signatures in `A`. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class Grandparent: ... +class Parent(Grandparent): ... +class Child(Parent): ... + +@overload +def overloaded(a: Child) -> None: ... +@overload +def overloaded(a: Parent) -> None: ... +@overload +def overloaded(a: Grandparent) -> None: ... +``` + +```py +from knot_extensions import CallableTypeOf, is_subtype_of, static_assert +from overloaded import Grandparent, Parent, Child, overloaded + +# This is a subtype of only the first overload +def child(a: Child) -> None: ... + +# This is a subtype of the first and second overload +def parent(a: Parent) -> None: ... + +# This is the only function that's a subtype of all overloads +def grandparent(a: Grandparent) -> None: ... + +static_assert(not is_subtype_of(CallableTypeOf[child], CallableTypeOf[overloaded])) +static_assert(not is_subtype_of(CallableTypeOf[parent], CallableTypeOf[overloaded])) +static_assert(is_subtype_of(CallableTypeOf[grandparent], CallableTypeOf[overloaded])) +``` + +#### Both overloads + +For `B <: A`, if both `A` and `B` is a callable that's overloaded with two or more signatures, then +`B` is a subtype of `A` if for _every_ signature in `A`, there is _at least one_ signature in `B` +that is a subtype of it. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class Grandparent: ... +class Parent(Grandparent): ... +class Child(Parent): ... +class Other: ... + +@overload +def pg(a: Parent) -> None: ... +@overload +def pg(a: Grandparent) -> None: ... + +@overload +def po(a: Parent) -> None: ... +@overload +def po(a: Other) -> None: ... + +@overload +def go(a: Grandparent) -> None: ... +@overload +def go(a: Other) -> None: ... + +@overload +def cpg(a: Child) -> None: ... +@overload +def cpg(a: Parent) -> None: ... +@overload +def cpg(a: Grandparent) -> None: ... + +@overload +def empty_go() -> Child: ... +@overload +def empty_go(a: Grandparent) -> None: ... +@overload +def empty_go(a: Other) -> Other: ... + +@overload +def empty_cp() -> Parent: ... +@overload +def empty_cp(a: Child) -> None: ... +@overload +def empty_cp(a: Parent) -> None: ... +``` + +```py +from knot_extensions import CallableTypeOf, is_subtype_of, static_assert +from overloaded import pg, po, go, cpg, empty_go, empty_cp + +static_assert(is_subtype_of(CallableTypeOf[pg], CallableTypeOf[cpg])) +static_assert(is_subtype_of(CallableTypeOf[cpg], CallableTypeOf[pg])) + +static_assert(not is_subtype_of(CallableTypeOf[po], CallableTypeOf[pg])) +static_assert(not is_subtype_of(CallableTypeOf[pg], CallableTypeOf[po])) + +static_assert(is_subtype_of(CallableTypeOf[go], CallableTypeOf[pg])) +static_assert(not is_subtype_of(CallableTypeOf[pg], CallableTypeOf[go])) + +# Overload 1 in `empty_go` is a subtype of overload 1 in `empty_cp` +# Overload 2 in `empty_go` is a subtype of overload 2 in `empty_cp` +# Overload 2 in `empty_go` is a subtype of overload 3 in `empty_cp` +# +# All overloads in `empty_cp` has a subtype in `empty_go` +static_assert(is_subtype_of(CallableTypeOf[empty_go], CallableTypeOf[empty_cp])) + +static_assert(not is_subtype_of(CallableTypeOf[empty_cp], CallableTypeOf[empty_go])) +``` + +#### Order of overloads + +Order of overloads is irrelevant for subtyping. + +`overloaded.pyi`: + +```pyi +from typing import overload + +class A: ... +class B: ... + +@overload +def overload_ab(x: A) -> None: ... +@overload +def overload_ab(x: B) -> None: ... + +@overload +def overload_ba(x: B) -> None: ... +@overload +def overload_ba(x: A) -> None: ... +``` + +```py +from overloaded import overload_ab, overload_ba +from knot_extensions import CallableTypeOf, is_subtype_of, static_assert + +static_assert(is_subtype_of(CallableTypeOf[overload_ab], CallableTypeOf[overload_ba])) +static_assert(is_subtype_of(CallableTypeOf[overload_ba], CallableTypeOf[overload_ab])) +``` + [special case for float and complex]: https://typing.python.org/en/latest/spec/special-types.html#special-cases-for-float-and-complex [typing documentation]: https://typing.python.org/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence diff --git a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs index 160f3af7b9..bd5e93ada6 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs @@ -58,6 +58,13 @@ pub trait HasScopedUseId { fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId; } +impl HasScopedUseId for ast::Identifier { + fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId { + let ast_ids = ast_ids(db, scope); + ast_ids.use_id(self) + } +} + impl HasScopedUseId for ast::ExprName { fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId { let expression_ref = ExprRef::from(self); @@ -157,7 +164,7 @@ impl AstIdsBuilder { } /// Adds `expr` to the use ids map and returns its id. - pub(super) fn record_use(&mut self, expr: &ast::Expr) -> ScopedUseId { + pub(super) fn record_use(&mut self, expr: impl Into) -> ScopedUseId { let use_id = self.uses_map.len().into(); self.uses_map.insert(expr.into(), use_id); @@ -196,4 +203,10 @@ pub(crate) mod node_key { Self(NodeKey::from_node(value)) } } + + impl From<&ast::Identifier> for ExpressionNodeKey { + fn from(value: &ast::Identifier) -> Self { + Self(NodeKey::from_node(value)) + } + } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index e4c25f4840..f9d312ab53 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -1115,6 +1115,17 @@ where // at the end to match the runtime evaluation of parameter defaults // and return-type annotations. let (symbol, _) = self.add_symbol(name.id.clone()); + + // Record a use of the function name in the scope that it is defined in, so that it + // can be used to find previously defined functions with the same name. This is + // used to collect all the overloaded definitions of a function. This needs to be + // done on the `Identifier` node as opposed to `ExprName` because that's what the + // AST uses. + self.mark_symbol_used(symbol); + let use_id = self.current_ast_ids().record_use(name); + self.current_use_def_map_mut() + .record_use(symbol, use_id, NodeKey::from_node(name)); + self.add_definition(symbol, function_def); } ast::Stmt::ClassDef(class) => { diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 4a2b446d2e..6782de5f69 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -29,12 +29,14 @@ pub(crate) use self::signatures::{CallableSignature, Signature, Signatures}; pub(crate) use self::subclass_of::SubclassOfType; use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module, KnownModule}; -use crate::semantic_index::ast_ids::HasScopedExpressionId; +use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId}; use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::ScopeId; use crate::semantic_index::{imported_modules, semantic_index}; use crate::suppression::check_suppressions; -use crate::symbol::{imported_symbol, Boundness, Symbol, SymbolAndQualifiers}; +use crate::symbol::{ + imported_symbol, symbol_from_bindings, Boundness, Symbol, SymbolAndQualifiers, +}; use crate::types::call::{Bindings, CallArgumentTypes, CallableBinding}; pub(crate) use crate::types::class_base::ClassBase; use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION}; @@ -507,12 +509,14 @@ impl<'db> Type<'db> { .any(|ty| ty.contains_todo(db)), Self::Callable(callable) => { - let signature = callable.signature(db); - signature.parameters().iter().any(|param| { - param - .annotated_type() - .is_some_and(|ty| ty.contains_todo(db)) - }) || signature.return_ty.is_some_and(|ty| ty.contains_todo(db)) + let signatures = callable.signatures(db); + signatures.iter().any(|signature| { + signature.parameters().iter().any(|param| { + param + .annotated_type() + .is_some_and(|ty| ty.contains_todo(db)) + }) || signature.return_ty.is_some_and(|ty| ty.contains_todo(db)) + }) } Self::SubclassOf(subclass_of) => match subclass_of.subclass_of() { @@ -1029,9 +1033,9 @@ impl<'db> Type<'db> { .to_instance(db) .is_subtype_of(db, target), - (Type::Callable(self_callable), Type::Callable(other_callable)) => self_callable - .signature(db) - .is_subtype_of(db, other_callable.signature(db)), + (Type::Callable(self_callable), Type::Callable(other_callable)) => { + self_callable.is_subtype_of(db, other_callable) + } (Type::DataclassDecorator(_), _) => { // TODO: Implement subtyping using an equivalent `Callable` type. @@ -1358,9 +1362,9 @@ impl<'db> Type<'db> { ) } - (Type::Callable(self_callable), Type::Callable(target_callable)) => self_callable - .signature(db) - .is_assignable_to(db, target_callable.signature(db)), + (Type::Callable(self_callable), Type::Callable(target_callable)) => { + self_callable.is_assignable_to(db, target_callable) + } (Type::FunctionLiteral(self_function_literal), Type::Callable(_)) => { self_function_literal @@ -1391,9 +1395,7 @@ impl<'db> Type<'db> { left.is_equivalent_to(db, right) } (Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right), - (Type::Callable(left), Type::Callable(right)) => { - left.signature(db).is_equivalent_to(db, right.signature(db)) - } + (Type::Callable(left), Type::Callable(right)) => left.is_equivalent_to(db, right), _ => self == other && self.is_fully_static(db) && other.is_fully_static(db), } } @@ -1454,9 +1456,9 @@ impl<'db> Type<'db> { first.is_gradual_equivalent_to(db, second) } - (Type::Callable(first), Type::Callable(second)) => first - .signature(db) - .is_gradual_equivalent_to(db, second.signature(db)), + (Type::Callable(first), Type::Callable(second)) => { + first.is_gradual_equivalent_to(db, second) + } _ => false, } @@ -1904,7 +1906,7 @@ impl<'db> Type<'db> { .elements(db) .iter() .all(|elem| elem.is_fully_static(db)), - Type::Callable(callable) => callable.signature(db).is_fully_static(db), + Type::Callable(callable) => callable.is_fully_static(db), } } @@ -3141,16 +3143,27 @@ impl<'db> Type<'db> { /// [`CallErrorKind::NotCallable`]. fn signatures(self, db: &'db dyn Db) -> Signatures<'db> { match self { - Type::Callable(callable) => Signatures::single(CallableSignature::single( - self, - callable.signature(db).clone(), - )), + Type::Callable(callable) => { + Signatures::single(match callable.signatures(db).as_ref() { + [signature] => CallableSignature::single(self, signature.clone()), + signatures => { + CallableSignature::from_overloads(self, signatures.iter().cloned()) + } + }) + } Type::BoundMethod(bound_method) => { let signature = bound_method.function(db).signature(db); - let signature = CallableSignature::single(self, signature.clone()) - .with_bound_type(bound_method.self_instance(db)); - Signatures::single(signature) + Signatures::single(match signature { + FunctionSignature::Single(signature) => { + CallableSignature::single(self, signature.clone()) + .with_bound_type(bound_method.self_instance(db)) + } + FunctionSignature::Overloaded(signatures, _) => { + CallableSignature::from_overloads(self, signatures.iter().cloned()) + .with_bound_type(bound_method.self_instance(db)) + } + }) } Type::MethodWrapper( @@ -3497,10 +3510,14 @@ impl<'db> Type<'db> { Signatures::single(signature) } - _ => Signatures::single(CallableSignature::single( - self, - function_type.signature(db).clone(), - )), + _ => Signatures::single(match function_type.signature(db) { + FunctionSignature::Single(signature) => { + CallableSignature::single(self, signature.clone()) + } + FunctionSignature::Overloaded(signatures, _) => { + CallableSignature::from_overloads(self, signatures.iter().cloned()) + } + }), }, Type::ClassLiteral(class) => match class.known(db) { @@ -3692,7 +3709,10 @@ impl<'db> Type<'db> { .with_annotated_type(UnionType::from_elements( db, [ - Type::Callable(CallableType::new(db, getter_signature)), + Type::Callable(CallableType::single( + db, + getter_signature, + )), Type::none(db), ], )) @@ -3701,7 +3721,10 @@ impl<'db> Type<'db> { .with_annotated_type(UnionType::from_elements( db, [ - Type::Callable(CallableType::new(db, setter_signature)), + Type::Callable(CallableType::single( + db, + setter_signature, + )), Type::none(db), ], )) @@ -3710,7 +3733,7 @@ impl<'db> Type<'db> { .with_annotated_type(UnionType::from_elements( db, [ - Type::Callable(CallableType::new( + Type::Callable(CallableType::single( db, deleter_signature, )), @@ -5738,15 +5761,54 @@ bitflags! { } } +/// A function signature, which can be either a single signature or an overloaded signature. +#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] +pub(crate) enum FunctionSignature<'db> { + /// A single function signature. + Single(Signature<'db>), + + /// An overloaded function signature containing the `@overload`-ed signatures and an optional + /// implementation signature. + Overloaded(Vec>, Option>), +} + +impl<'db> FunctionSignature<'db> { + /// Returns a slice of all signatures. + /// + /// For an overloaded function, this only includes the `@overload`-ed signatures and not the + /// implementation signature. + pub(crate) fn as_slice(&self) -> &[Signature<'db>] { + match self { + Self::Single(signature) => std::slice::from_ref(signature), + Self::Overloaded(signatures, _) => signatures, + } + } + + /// Returns an iterator over the signatures. + pub(crate) fn iter(&self) -> Iter> { + self.as_slice().iter() + } +} + +impl<'db> IntoIterator for &'db FunctionSignature<'db> { + type Item = &'db Signature<'db>; + type IntoIter = Iter<'db, Signature<'db>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + #[salsa::interned(debug)] pub struct FunctionType<'db> { - /// name of the function at definition + /// Name of the function at definition. #[return_ref] pub name: ast::name::Name, /// Is this a function that we special-case somehow? If so, which one? known: Option, + /// The scope that's created by the function, in which the function body is evaluated. body_scope: ScopeId<'db>, /// A set of special decorators that were applied to this function @@ -5768,10 +5830,11 @@ impl<'db> FunctionType<'db> { } /// Convert the `FunctionType` into a [`Type::Callable`]. - /// - /// This powers the `CallableTypeOf` special form from the `knot_extensions` module. pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { - Type::Callable(CallableType::new(db, self.signature(db).clone())) + Type::Callable(CallableType::from_overloads( + db, + self.signature(db).iter().cloned(), + )) } /// Returns the [`FileRange`] of the function's name. @@ -5808,18 +5871,55 @@ impl<'db> FunctionType<'db> { /// Were this not a salsa query, then the calling query /// would depend on the function's AST and rerun for every change in that file. #[salsa::tracked(return_ref)] - pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { + pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> { let mut internal_signature = self.internal_signature(db); - if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) { - return Signature::todo("return type of overloaded function"); - } - if let Some(specialization) = self.specialization(db) { - internal_signature.apply_specialization(db, specialization); + internal_signature = internal_signature.apply_specialization(db, specialization); } - internal_signature + // The semantic model records a use for each function on the name node. This is used here + // to get the previous function definition with the same name. + let scope = self.definition(db).scope(db); + let use_def = semantic_index(db, scope.file(db)).use_def_map(scope.file_scope_id(db)); + let use_id = self + .body_scope(db) + .node(db) + .expect_function() + .name + .scoped_use_id(db, scope); + + if let Symbol::Type(Type::FunctionLiteral(function_literal), Boundness::Bound) = + symbol_from_bindings(db, use_def.bindings_at_use(use_id)) + { + match function_literal.signature(db) { + FunctionSignature::Single(_) => { + debug_assert!( + !function_literal.has_known_decorator(db, FunctionDecorators::OVERLOAD), + "Expected `FunctionSignature::Overloaded` if the previous function was an overload" + ); + } + FunctionSignature::Overloaded(_, Some(_)) => { + // If the previous overloaded function already has an implementation, then this + // new signature completely replaces it. + } + FunctionSignature::Overloaded(signatures, None) => { + return if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) { + let mut signatures = signatures.clone(); + signatures.push(internal_signature); + FunctionSignature::Overloaded(signatures, None) + } else { + FunctionSignature::Overloaded(signatures.clone(), Some(internal_signature)) + }; + } + } + } + + if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) { + FunctionSignature::Overloaded(vec![internal_signature], None) + } else { + FunctionSignature::Single(internal_signature) + } } /// Typed internally-visible signature for this function. @@ -6013,27 +6113,54 @@ pub struct BoundMethodType<'db> { impl<'db> BoundMethodType<'db> { pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { - Type::Callable(CallableType::new( + Type::Callable(CallableType::from_overloads( db, - self.function(db).signature(db).bind_self(), + self.function(db) + .signature(db) + .iter() + .map(signatures::Signature::bind_self), )) } } -/// This type represents the set of all callable objects with a certain signature. -/// It can be written in type expressions using `typing.Callable`. -/// `lambda` expressions are inferred directly as `CallableType`s; all function-literal types -/// are subtypes of a `CallableType`. +/// This type represents the set of all callable objects with a certain, possibly overloaded, +/// signature. +/// +/// It can be written in type expressions using `typing.Callable`. `lambda` expressions are +/// inferred directly as `CallableType`s; all function-literal types are subtypes of a +/// `CallableType`. #[salsa::interned(debug)] pub struct CallableType<'db> { #[return_ref] - signature: Signature<'db>, + signatures: Box<[Signature<'db>]>, } impl<'db> CallableType<'db> { + /// Create a non-overloaded callable type with a single signature. + pub(crate) fn single(db: &'db dyn Db, signature: Signature<'db>) -> Self { + CallableType::new(db, vec![signature].into_boxed_slice()) + } + + /// Create an overloaded callable type with multiple signatures. + /// + /// # Panics + /// + /// Panics if `overloads` is empty. + pub(crate) fn from_overloads(db: &'db dyn Db, overloads: I) -> Self + where + I: IntoIterator>, + { + let overloads = overloads.into_iter().collect::>().into_boxed_slice(); + assert!( + !overloads.is_empty(), + "CallableType must have at least one signature" + ); + CallableType::new(db, overloads) + } + /// Create a callable type which accepts any parameters and returns an `Unknown` type. pub(crate) fn unknown(db: &'db dyn Db) -> Self { - CallableType::new( + CallableType::single( db, Signature::new(Parameters::unknown(), Some(Type::unknown())), ) @@ -6043,22 +6170,127 @@ impl<'db> CallableType<'db> { /// /// See [`Type::normalized`] for more details. fn normalized(self, db: &'db dyn Db) -> Self { - let signature = self.signature(db); - let parameters = signature - .parameters() - .iter() - .map(|param| param.normalized(db)) - .collect(); - let return_ty = signature - .return_ty - .map(|return_ty| return_ty.normalized(db)); - CallableType::new(db, Signature::new(parameters, return_ty)) + CallableType::from_overloads( + db, + self.signatures(db) + .iter() + .map(|signature| signature.normalized(db)), + ) } + /// Apply a specialization to this callable type. + /// + /// See [`Type::apply_specialization`] for more details. fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { - let mut signature = self.signature(db).clone(); - signature.apply_specialization(db, specialization); - Self::new(db, signature) + CallableType::from_overloads( + db, + self.signatures(db) + .iter() + .map(|signature| signature.apply_specialization(db, specialization)), + ) + } + + /// Check whether this callable type is fully static. + /// + /// See [`Type::is_fully_static`] for more details. + fn is_fully_static(self, db: &'db dyn Db) -> bool { + self.signatures(db) + .iter() + .all(|signature| signature.is_fully_static(db)) + } + + /// Check whether this callable type is a subtype of another callable type. + /// + /// See [`Type::is_subtype_of`] for more details. + fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool { + self.is_assignable_to_impl(db, other, &|self_signature, other_signature| { + self_signature.is_subtype_of(db, other_signature) + }) + } + + /// Check whether this callable type is assignable to another callable type. + /// + /// See [`Type::is_assignable_to`] for more details. + fn is_assignable_to(self, db: &'db dyn Db, other: Self) -> bool { + self.is_assignable_to_impl(db, other, &|self_signature, other_signature| { + self_signature.is_assignable_to(db, other_signature) + }) + } + + /// Implementation for the various relation checks between two, possible overloaded, callable + /// types. + /// + /// The `check_signature` closure is used to check the relation between two [`Signature`]s. + fn is_assignable_to_impl(self, db: &'db dyn Db, other: Self, check_signature: &F) -> bool + where + F: Fn(&Signature<'db>, &Signature<'db>) -> bool, + { + match (&**self.signatures(db), &**other.signatures(db)) { + ([self_signature], [other_signature]) => { + // Base case: both callable types contain a single signature. + check_signature(self_signature, other_signature) + } + + // `self` is possibly overloaded while `other` is definitely not overloaded. + (self_signatures, [other_signature]) => { + let other_callable = CallableType::single(db, other_signature.clone()); + self_signatures + .iter() + .map(|self_signature| CallableType::single(db, self_signature.clone())) + .any(|self_callable| { + self_callable.is_assignable_to_impl(db, other_callable, check_signature) + }) + } + + // `self` is definitely not overloaded while `other` is possibly overloaded. + ([self_signature], other_signatures) => { + let self_callable = CallableType::single(db, self_signature.clone()); + other_signatures + .iter() + .map(|other_signature| CallableType::single(db, other_signature.clone())) + .all(|other_callable| { + self_callable.is_assignable_to_impl(db, other_callable, check_signature) + }) + } + + // `self` is definitely overloaded while `other` is possibly overloaded. + (_, other_signatures) => other_signatures + .iter() + .map(|other_signature| CallableType::single(db, other_signature.clone())) + .all(|other_callable| { + self.is_assignable_to_impl(db, other_callable, check_signature) + }), + } + } + + /// Check whether this callable type is equivalent to another callable type. + /// + /// See [`Type::is_equivalent_to`] for more details. + fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool { + match (&**self.signatures(db), &**other.signatures(db)) { + ([self_signature], [other_signature]) => { + self_signature.is_equivalent_to(db, other_signature) + } + _ => { + // TODO: overloads + false + } + } + } + + /// Check whether this callable type is gradual equivalent to another callable type. + /// + /// See [`Type::is_gradual_equivalent_to`] for more details. + fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool { + match (&**self.signatures(db), &**other.signatures(db)) { + ([self_signature], [other_signature]) => { + self_signature.is_gradual_equivalent_to(db, other_signature) + } + _ => { + // TODO: overloads + false + } + } } } diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index 16f00029e5..7232a08f8d 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -19,7 +19,7 @@ use crate::types::diagnostic::{ use crate::types::generics::{Specialization, SpecializationBuilder}; use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ - todo_type, BoundMethodType, DataclassMetadata, FunctionDecorators, KnownClass, KnownFunction, + BoundMethodType, DataclassMetadata, FunctionDecorators, KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind, }; use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic}; @@ -536,7 +536,11 @@ impl<'db> Bindings<'db> { } Some(KnownFunction::Overload) => { - overload.set_return_type(todo_type!("overload[..] return type")); + // TODO: This can be removed once we understand legacy generics because the + // typeshed definition for `typing.overload` is an identity function. + if let [Some(ty)] = overload.parameter_types() { + overload.set_return_type(*ty); + } } Some(KnownFunction::GetattrStatic) => { diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index d59fa8e006..f22e8a668e 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -931,7 +931,7 @@ impl<'db> ClassLiteralType<'db> { let init_signature = Signature::new(Parameters::new(parameters), Some(Type::none(db))); - return Some(Type::Callable(CallableType::new(db, init_signature))); + return Some(Type::Callable(CallableType::single(db, init_signature))); } else if matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") { if metadata.contains(DataclassMetadata::ORDER) { let signature = Signature::new( @@ -943,7 +943,7 @@ impl<'db> ClassLiteralType<'db> { Some(KnownClass::Bool.to_instance(db)), ); - return Some(Type::Callable(CallableType::new(db, signature))); + return Some(Type::Callable(CallableType::single(db, signature))); } } diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 5a6da80132..96c3c5c535 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -11,12 +11,15 @@ use crate::types::class_base::ClassBase; use crate::types::generics::{GenericContext, Specialization}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::{ - InstanceType, IntersectionType, KnownClass, MethodWrapperKind, StringLiteralType, Type, - TypeVarBoundOrConstraints, TypeVarInstance, UnionType, WrapperDescriptorKind, + FunctionSignature, InstanceType, IntersectionType, KnownClass, MethodWrapperKind, + StringLiteralType, Type, TypeVarBoundOrConstraints, TypeVarInstance, UnionType, + WrapperDescriptorKind, }; use crate::Db; use rustc_hash::FxHashMap; +use super::CallableType; + impl<'db> Type<'db> { pub fn display(&self, db: &'db dyn Db) -> DisplayType { DisplayType { ty: self, db } @@ -95,32 +98,59 @@ impl Display for DisplayRepresentation<'_> { Type::KnownInstance(known_instance) => f.write_str(known_instance.repr(self.db)), Type::FunctionLiteral(function) => { let signature = function.signature(self.db); + // TODO: when generic function types are supported, we should add // the generic type parameters to the signature, i.e. // show `def foo[T](x: T) -> T`. - write!( - f, - // "def {name}{specialization}{signature}", - "def {name}{signature}", - name = function.name(self.db), - signature = signature.display(self.db) - ) + match signature { + FunctionSignature::Single(signature) => { + write!( + f, + // "def {name}{specialization}{signature}", + "def {name}{signature}", + name = function.name(self.db), + signature = signature.display(self.db) + ) + } + FunctionSignature::Overloaded(signatures, _) => { + // TODO: How to display overloads? + f.write_str("Overload[")?; + let mut join = f.join(", "); + for signature in signatures { + join.entry(&signature.display(self.db)); + } + f.write_str("]") + } + } } - Type::Callable(callable) => callable.signature(self.db).display(self.db).fmt(f), + Type::Callable(callable) => callable.display(self.db).fmt(f), Type::BoundMethod(bound_method) => { let function = bound_method.function(self.db); // TODO: use the specialization from the method. Similar to the comment above // about the function specialization, - write!( - f, - "bound method {instance}.{method}{signature}", - method = function.name(self.db), - instance = bound_method.self_instance(self.db).display(self.db), - signature = function.signature(self.db).bind_self().display(self.db) - ) + match function.signature(self.db) { + FunctionSignature::Single(signature) => { + write!( + f, + "bound method {instance}.{method}{signature}", + method = function.name(self.db), + instance = bound_method.self_instance(self.db).display(self.db), + signature = signature.bind_self().display(self.db) + ) + } + FunctionSignature::Overloaded(signatures, _) => { + // TODO: How to display overloads? + f.write_str("Overload[")?; + let mut join = f.join(", "); + for signature in signatures { + join.entry(&signature.bind_self().display(self.db)); + } + f.write_str("]") + } + } } Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { write!( @@ -355,8 +385,40 @@ impl Display for DisplaySpecialization<'_> { } } +impl<'db> CallableType<'db> { + pub(crate) fn display(&'db self, db: &'db dyn Db) -> DisplayCallableType<'db> { + DisplayCallableType { + signatures: self.signatures(db), + db, + } + } +} + +pub(crate) struct DisplayCallableType<'db> { + signatures: &'db [Signature<'db>], + db: &'db dyn Db, +} + +impl Display for DisplayCallableType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.signatures { + [signature] => write!(f, "{}", signature.display(self.db)), + signatures => { + // TODO: How to display overloads? + f.write_str("Overload[")?; + let mut join = f.join(", "); + for signature in signatures { + join.entry(&signature.display(self.db)); + } + join.finish()?; + f.write_char(']') + } + } + } +} + impl<'db> Signature<'db> { - fn display(&'db self, db: &'db dyn Db) -> DisplaySignature<'db> { + pub(crate) fn display(&'db self, db: &'db dyn Db) -> DisplaySignature<'db> { DisplaySignature { parameters: self.parameters(), return_ty: self.return_ty, @@ -365,7 +427,7 @@ impl<'db> Signature<'db> { } } -struct DisplaySignature<'db> { +pub(crate) struct DisplaySignature<'db> { parameters: &'db Parameters<'db>, return_ty: Option>, db: &'db dyn Db, diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 59248f9f7b..29a92c058d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -4133,7 +4133,7 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO: Useful inference of a lambda's return type will require a different approach, // which does the inference of the body expression based on arguments at each call site, // rather than eagerly computing a return type without knowing the argument types. - Type::Callable(CallableType::new( + Type::Callable(CallableType::single( self.db(), Signature::new(parameters, Some(Type::unknown())), )) @@ -7305,7 +7305,7 @@ impl<'db> TypeInferenceBuilder<'db> { let callable_type = if let (Some(parameters), Some(return_type), true) = (parameters, return_type, correct_argument_number) { - CallableType::new(db, Signature::new(parameters, Some(return_type))) + CallableType::single(db, Signature::new(parameters, Some(return_type))) } else { CallableType::unknown(db) }; @@ -7386,8 +7386,22 @@ impl<'db> TypeInferenceBuilder<'db> { let argument_type = self.infer_expression(arguments_slice); let signatures = argument_type.signatures(db); - // TODO overloads - let Some(signature) = signatures.iter().flatten().next() else { + // SAFETY: This is enforced by the constructor methods on `Signatures` even in + // the case of a non-callable union. + let callable_signature = signatures + .iter() + .next() + .expect("`Signatures` should have at least one `CallableSignature`"); + + let mut signature_iter = callable_signature.iter().map(|signature| { + if argument_type.is_bound_method() { + signature.bind_self() + } else { + signature.clone() + } + }); + + let Some(signature) = signature_iter.next() else { self.context.report_lint_old( &INVALID_TYPE_FORM, arguments_slice, @@ -7400,13 +7414,10 @@ impl<'db> TypeInferenceBuilder<'db> { return Type::unknown(); }; - let revealed_signature = if argument_type.is_bound_method() { - signature.bind_self() - } else { - signature.clone() - }; - - Type::Callable(CallableType::new(db, revealed_signature)) + Type::Callable(CallableType::from_overloads( + db, + std::iter::once(signature).chain(signature_iter), + )) } }, diff --git a/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs b/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs index 8072190159..470f3e0778 100644 --- a/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs +++ b/crates/red_knot_python_semantic/src/types/property_tests/type_generation.rs @@ -188,7 +188,7 @@ impl Ty { create_bound_method(db, function, builtins_class) } - Ty::Callable { params, returns } => Type::Callable(CallableType::new( + Ty::Callable { params, returns } => Type::Callable(CallableType::single( db, Signature::new( params.into_parameters(db), diff --git a/crates/red_knot_python_semantic/src/types/signatures.rs b/crates/red_knot_python_semantic/src/types/signatures.rs index 50c022e8d6..52b1ab827f 100644 --- a/crates/red_knot_python_semantic/src/types/signatures.rs +++ b/crates/red_knot_python_semantic/src/types/signatures.rs @@ -250,16 +250,6 @@ impl<'db> Signature<'db> { } } - /// Return a todo signature: (*args: Todo, **kwargs: Todo) -> Todo - #[allow(unused_variables)] // 'reason' only unused in debug builds - pub(crate) fn todo(reason: &'static str) -> Self { - Signature { - generic_context: None, - parameters: Parameters::todo(), - return_ty: Some(todo_type!(reason)), - } - } - /// Return a typed signature from a function definition. pub(super) fn from_function( db: &'db dyn Db, @@ -286,15 +276,30 @@ impl<'db> Signature<'db> { } } + pub(crate) fn normalized(&self, db: &'db dyn Db) -> Self { + Self { + generic_context: self.generic_context, + parameters: self + .parameters + .iter() + .map(|param| param.normalized(db)) + .collect(), + return_ty: self.return_ty.map(|return_ty| return_ty.normalized(db)), + } + } + pub(crate) fn apply_specialization( - &mut self, + &self, db: &'db dyn Db, specialization: Specialization<'db>, - ) { - self.parameters.apply_specialization(db, specialization); - self.return_ty = self - .return_ty - .map(|ty| ty.apply_specialization(db, specialization)); + ) -> Self { + Self { + generic_context: self.generic_context, + parameters: self.parameters.apply_specialization(db, specialization), + return_ty: self + .return_ty + .map(|ty| ty.apply_specialization(db, specialization)), + } } /// Return the parameters in this signature. @@ -995,10 +1000,15 @@ impl<'db> Parameters<'db> { ) } - fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) { - self.value - .iter_mut() - .for_each(|param| param.apply_specialization(db, specialization)); + fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { + Self { + value: self + .value + .iter() + .map(|param| param.apply_specialization(db, specialization)) + .collect(), + is_gradual: self.is_gradual, + } } pub(crate) fn len(&self) -> usize { @@ -1162,11 +1172,14 @@ impl<'db> Parameter<'db> { self } - fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) { - self.annotated_type = self - .annotated_type - .map(|ty| ty.apply_specialization(db, specialization)); - self.kind.apply_specialization(db, specialization); + fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { + Self { + annotated_type: self + .annotated_type + .map(|ty| ty.apply_specialization(db, specialization)), + kind: self.kind.apply_specialization(db, specialization), + form: self.form, + } } /// Strip information from the parameter so that two equivalent parameters compare equal. @@ -1356,14 +1369,27 @@ pub(crate) enum ParameterKind<'db> { } impl<'db> ParameterKind<'db> { - fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) { + fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self { match self { - Self::PositionalOnly { default_type, .. } - | Self::PositionalOrKeyword { default_type, .. } - | Self::KeywordOnly { default_type, .. } => { - *default_type = default_type.map(|ty| ty.apply_specialization(db, specialization)); - } - Self::Variadic { .. } | Self::KeywordVariadic { .. } => {} + Self::PositionalOnly { default_type, name } => Self::PositionalOnly { + default_type: default_type + .as_ref() + .map(|ty| ty.apply_specialization(db, specialization)), + name: name.clone(), + }, + Self::PositionalOrKeyword { default_type, name } => Self::PositionalOrKeyword { + default_type: default_type + .as_ref() + .map(|ty| ty.apply_specialization(db, specialization)), + name: name.clone(), + }, + Self::KeywordOnly { default_type, name } => Self::KeywordOnly { + default_type: default_type + .as_ref() + .map(|ty| ty.apply_specialization(db, specialization)), + name: name.clone(), + }, + Self::Variadic { .. } | Self::KeywordVariadic { .. } => self.clone(), } } } @@ -1380,7 +1406,7 @@ mod tests { use super::*; use crate::db::tests::{setup_db, TestDb}; use crate::symbol::global_symbol; - use crate::types::{FunctionType, KnownClass}; + use crate::types::{FunctionSignature, FunctionType, KnownClass}; use ruff_db::system::DbWithWritableSystem as _; #[track_caller] @@ -1627,6 +1653,9 @@ mod tests { let expected_sig = func.internal_signature(&db); // With no decorators, internal and external signature are the same - assert_eq!(func.signature(&db), &expected_sig); + assert_eq!( + func.signature(&db), + &FunctionSignature::Single(expected_sig) + ); } }