From 33b942c7ad4e3c71b832dbb234adbb8b3ee9b0be Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Mon, 10 Nov 2025 19:46:49 -0500 Subject: [PATCH] [ty] Handle annotated `self` parameter in constructor of non-invariant generic classes (#21325) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This manifested as an error when inferring the type of a PEP-695 generic class via its constructor parameters: ```py class D[T, U]: @overload def __init__(self: "D[str, U]", u: U) -> None: ... @overload def __init__(self, t: T, u: U) -> None: ... def __init__(self, *args) -> None: ... # revealed: D[Unknown, str] # SHOULD BE: D[str, str] reveal_type(D("string")) ``` This manifested because `D` is inferred to be bivariant in both `T` and `U`. We weren't seeing this in the equivalent example for legacy typevars, since those default to invariant. (This issue also showed up for _covariant_ typevars, so this issue was not limited to bivariance.) The underlying cause was because of a heuristic that we have in our current constraint solver, which attempts to handle situations like this: ```py def f[T](t: T | None): ... f(None) ``` Here, the `None` argument matches the non-typevar union element, so this argument should not add any constraints on what `T` can specialize to. Our previous heuristic would check for this by seeing if the argument type is a subtype of the parameter annotation as a whole — even if it isn't a union! That would cause us to erroneously ignore the `self` parameter in our constructor call, since bivariant classes are equivalent to each other, regardless of their specializations. The quick fix is to move this heuristic "down a level", so that we only apply it when the parameter annotation is a union. This heuristic should go away completely :crossed_fingers: with the new constraint solver. --- .../mdtest/assignment/annotations.md | 3 +- .../mdtest/generics/legacy/classes.md | 16 ++++-- .../mdtest/generics/pep695/classes.md | 15 ++++-- .../ty_python_semantic/src/types/generics.rs | 53 +++++++++---------- 4 files changed, 50 insertions(+), 37 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 3865572726..043380338b 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -674,8 +674,7 @@ x6: Covariant[Any] = covariant(1) x7: Contravariant[Any] = contravariant(1) x8: Invariant[Any] = invariant(1) -# TODO: This could reveal `Bivariant[Any]`. -reveal_type(x5) # revealed: Bivariant[Literal[1]] +reveal_type(x5) # revealed: Bivariant[Any] reveal_type(x6) # revealed: Covariant[Any] reveal_type(x7) # revealed: Contravariant[Any] reveal_type(x8) # revealed: Invariant[Any] diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md index a1f47c3b11..7ba6803dda 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md @@ -436,9 +436,7 @@ def test_seq(x: Sequence[T]) -> Sequence[T]: def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]): reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]] reveal_type(test_seq(t2)) # revealed: Sequence[int | str] - - # TODO: this should be `Sequence[Never]` - reveal_type(test_seq(t3)) # revealed: Sequence[Unknown] + reveal_type(test_seq(t3)) # revealed: Sequence[Never] ``` ### `__init__` is itself generic @@ -466,6 +464,7 @@ wrong_innards: C[int] = C("five", 1) from typing_extensions import overload, Generic, TypeVar T = TypeVar("T") +U = TypeVar("U") class C(Generic[T]): @overload @@ -497,6 +496,17 @@ C[int](12) C[None]("string") # error: [no-matching-overload] C[None](b"bytes") # error: [no-matching-overload] C[None](12) + +class D(Generic[T, U]): + @overload + def __init__(self: "D[str, U]", u: U) -> None: ... + @overload + def __init__(self, t: T, u: U) -> None: ... + def __init__(self, *args) -> None: ... + +reveal_type(D("string")) # revealed: D[str, str] +reveal_type(D(1)) # revealed: D[str, int] +reveal_type(D(1, "string")) # revealed: D[int, str] ``` ### Synthesized methods with dataclasses diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md index 30a9ee88ae..a01b468ad0 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md @@ -375,9 +375,7 @@ def test_seq[T](x: Sequence[T]) -> Sequence[T]: def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]): reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]] reveal_type(test_seq(t2)) # revealed: Sequence[int | str] - - # TODO: this should be `Sequence[Never]` - reveal_type(test_seq(t3)) # revealed: Sequence[Unknown] + reveal_type(test_seq(t3)) # revealed: Sequence[Never] ``` ### `__init__` is itself generic @@ -436,6 +434,17 @@ C[int](12) C[None]("string") # error: [no-matching-overload] C[None](b"bytes") # error: [no-matching-overload] C[None](12) + +class D[T, U]: + @overload + def __init__(self: "D[str, U]", u: U) -> None: ... + @overload + def __init__(self, t: T, u: U) -> None: ... + def __init__(self, *args) -> None: ... + +reveal_type(D("string")) # revealed: D[str, str] +reveal_type(D(1)) # revealed: D[str, int] +reveal_type(D(1, "string")) # revealed: D[int, str] ``` ### Synthesized methods with dataclasses diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 11b708cf69..8be33c95fc 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1393,31 +1393,6 @@ impl<'db> SpecializationBuilder<'db> { return Ok(()); } - // If the actual type is a subtype of the formal type, then return without adding any new - // type mappings. (Note that if the formal type contains any typevars, this check will - // fail, since no non-typevar types are assignable to a typevar. Also note that we are - // checking _subtyping_, not _assignability_, so that we do specialize typevars to dynamic - // argument types; and we have a special case for `Never`, which is a subtype of all types, - // but which we also do want as a specialization candidate.) - // - // In particular, this handles a case like - // - // ```py - // def f[T](t: T | None): ... - // - // f(None) - // ``` - // - // without specializing `T` to `None`. - if !matches!(formal, Type::ProtocolInstance(_)) - && !actual.is_never() - && actual - .when_subtype_of(self.db, formal, self.inferable) - .is_always_satisfied(self.db) - { - return Ok(()); - } - // Remove the union elements from `actual` that are not related to `formal`, and vice // versa. // @@ -1473,10 +1448,30 @@ impl<'db> SpecializationBuilder<'db> { self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter); } (Type::Union(formal), _) => { - // Second, if the formal is a union, and precisely one union element _is_ a typevar (not - // _contains_ a typevar), then we add a mapping between that typevar and the actual - // type. (Note that we've already handled above the case where the actual is - // assignable to any _non-typevar_ union element.) + // Second, if the formal is a union, and precisely one union element is assignable + // from the actual type, then we don't add any type mapping. This handles a case like + // + // ```py + // def f[T](t: T | None): ... + // + // f(None) + // ``` + // + // without specializing `T` to `None`. + // + // Otherwise, if precisely one union element _is_ a typevar (not _contains_ a + // typevar), then we add a mapping between that typevar and the actual type. + if !actual.is_never() { + let assignable_elements = (formal.elements(self.db).iter()).filter(|ty| { + actual + .when_subtype_of(self.db, **ty, self.inferable) + .is_always_satisfied(self.db) + }); + if assignable_elements.exactly_one().is_ok() { + return Ok(()); + } + } + let bound_typevars = (formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar()); if let Ok(bound_typevar) = bound_typevars.exactly_one() {