[ty] Handle annotated self parameter in constructor of non-invariant generic classes (#21325)
This manifested as an error when inferring the type of a PEP-695 generic
class via its constructor parameters:
```py
class D[T, U]:
@overload
def __init__(self: "D[str, U]", u: U) -> None: ...
@overload
def __init__(self, t: T, u: U) -> None: ...
def __init__(self, *args) -> None: ...
# revealed: D[Unknown, str]
# SHOULD BE: D[str, str]
reveal_type(D("string"))
```
This manifested because `D` is inferred to be bivariant in both `T` and
`U`. We weren't seeing this in the equivalent example for legacy
typevars, since those default to invariant. (This issue also showed up
for _covariant_ typevars, so this issue was not limited to bivariance.)
The underlying cause was because of a heuristic that we have in our
current constraint solver, which attempts to handle situations like
this:
```py
def f[T](t: T | None): ...
f(None)
```
Here, the `None` argument matches the non-typevar union element, so this
argument should not add any constraints on what `T` can specialize to.
Our previous heuristic would check for this by seeing if the argument
type is a subtype of the parameter annotation as a whole — even if it
isn't a union! That would cause us to erroneously ignore the `self`
parameter in our constructor call, since bivariant classes are
equivalent to each other, regardless of their specializations.
The quick fix is to move this heuristic "down a level", so that we only
apply it when the parameter annotation is a union. This heuristic should
go away completely 🤞 with the new constraint solver.
This commit is contained in:
@@ -674,8 +674,7 @@ x6: Covariant[Any] = covariant(1)
|
||||
x7: Contravariant[Any] = contravariant(1)
|
||||
x8: Invariant[Any] = invariant(1)
|
||||
|
||||
# TODO: This could reveal `Bivariant[Any]`.
|
||||
reveal_type(x5) # revealed: Bivariant[Literal[1]]
|
||||
reveal_type(x5) # revealed: Bivariant[Any]
|
||||
reveal_type(x6) # revealed: Covariant[Any]
|
||||
reveal_type(x7) # revealed: Contravariant[Any]
|
||||
reveal_type(x8) # revealed: Invariant[Any]
|
||||
|
||||
@@ -436,9 +436,7 @@ def test_seq(x: Sequence[T]) -> Sequence[T]:
|
||||
def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]):
|
||||
reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]]
|
||||
reveal_type(test_seq(t2)) # revealed: Sequence[int | str]
|
||||
|
||||
# TODO: this should be `Sequence[Never]`
|
||||
reveal_type(test_seq(t3)) # revealed: Sequence[Unknown]
|
||||
reveal_type(test_seq(t3)) # revealed: Sequence[Never]
|
||||
```
|
||||
|
||||
### `__init__` is itself generic
|
||||
@@ -466,6 +464,7 @@ wrong_innards: C[int] = C("five", 1)
|
||||
from typing_extensions import overload, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
class C(Generic[T]):
|
||||
@overload
|
||||
@@ -497,6 +496,17 @@ C[int](12)
|
||||
C[None]("string") # error: [no-matching-overload]
|
||||
C[None](b"bytes") # error: [no-matching-overload]
|
||||
C[None](12)
|
||||
|
||||
class D(Generic[T, U]):
|
||||
@overload
|
||||
def __init__(self: "D[str, U]", u: U) -> None: ...
|
||||
@overload
|
||||
def __init__(self, t: T, u: U) -> None: ...
|
||||
def __init__(self, *args) -> None: ...
|
||||
|
||||
reveal_type(D("string")) # revealed: D[str, str]
|
||||
reveal_type(D(1)) # revealed: D[str, int]
|
||||
reveal_type(D(1, "string")) # revealed: D[int, str]
|
||||
```
|
||||
|
||||
### Synthesized methods with dataclasses
|
||||
|
||||
@@ -375,9 +375,7 @@ def test_seq[T](x: Sequence[T]) -> Sequence[T]:
|
||||
def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]):
|
||||
reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]]
|
||||
reveal_type(test_seq(t2)) # revealed: Sequence[int | str]
|
||||
|
||||
# TODO: this should be `Sequence[Never]`
|
||||
reveal_type(test_seq(t3)) # revealed: Sequence[Unknown]
|
||||
reveal_type(test_seq(t3)) # revealed: Sequence[Never]
|
||||
```
|
||||
|
||||
### `__init__` is itself generic
|
||||
@@ -436,6 +434,17 @@ C[int](12)
|
||||
C[None]("string") # error: [no-matching-overload]
|
||||
C[None](b"bytes") # error: [no-matching-overload]
|
||||
C[None](12)
|
||||
|
||||
class D[T, U]:
|
||||
@overload
|
||||
def __init__(self: "D[str, U]", u: U) -> None: ...
|
||||
@overload
|
||||
def __init__(self, t: T, u: U) -> None: ...
|
||||
def __init__(self, *args) -> None: ...
|
||||
|
||||
reveal_type(D("string")) # revealed: D[str, str]
|
||||
reveal_type(D(1)) # revealed: D[str, int]
|
||||
reveal_type(D(1, "string")) # revealed: D[int, str]
|
||||
```
|
||||
|
||||
### Synthesized methods with dataclasses
|
||||
|
||||
Reference in New Issue
Block a user