Compare commits

...

1 Commits

Author SHA1 Message Date
Dhruv Manilawala
ec2861bc4f [ty] Pass the generic context through the decorator 2026-01-13 16:06:43 +05:30
3 changed files with 103 additions and 17 deletions

View File

@@ -886,8 +886,6 @@ class GenericClass[T]:
raise NotImplementedError
def _(x: list[str]):
# TODO: This fails because we are not propagating GenericClass's generic context into the
# Callable that we create for it.
# revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass]
reveal_type(into_callable(GenericClass))
# revealed: ty_extensions.GenericContext[T@GenericClass]
@@ -895,15 +893,10 @@ def _(x: list[str]):
# revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass]
reveal_type(accepts_callable(GenericClass))
# TODO: revealed: ty_extensions.GenericContext[T@GenericClass]
# revealed: None
# revealed: ty_extensions.GenericContext[T@GenericClass]
reveal_type(generic_context(accepts_callable(GenericClass)))
# TODO: revealed: GenericClass[str]
# TODO: no errors
# revealed: GenericClass[T@GenericClass]
# error: [invalid-argument-type]
# error: [invalid-argument-type]
# revealed: GenericClass[str]
reveal_type(accepts_callable(GenericClass)(x, x))
```

View File

@@ -800,3 +800,78 @@ def f(x: int, y: str):
reveal_type(infer_paramspec(f)) # revealed: (x: int, y: str) -> None
```
## Generic context preservation through `ParamSpec` decorators
When a generic function is decorated with a `ParamSpec`-based decorator, the generic context of the
decorated function should be preserved. This allows type inference to work correctly when calling the
decorated function.
Regression test for <https://github.com/astral-sh/ty/issues/2336>
### Basic
```py
from typing import Callable
from ty_extensions import generic_context
def decorator[**P, T](func: Callable[P, T]) -> Callable[P, T]:
return func
@decorator
def identity[T](value: T) -> T:
return value
@decorator
def pair[T, U](first: T, second: U) -> tuple[T, U]:
return (first, second)
# revealed: ty_extensions.GenericContext[T@identity]
reveal_type(generic_context(identity))
# revealed: ty_extensions.GenericContext[T@pair, U@pair]
reveal_type(generic_context(pair))
reveal_type(identity(1)) # revealed: Literal[1]
reveal_type(identity("hello")) # revealed: Literal["hello"]
reveal_type(pair(1, "a")) # revealed: tuple[Literal[1], Literal["a"]]
reveal_type(pair("x", 2.5)) # revealed: tuple[Literal["x"], float]
```
### Chained decorators with generic functions
```py
from typing import Callable
def decorator1[**P, R](func: Callable[P, R]) -> Callable[P, R]:
return func
def decorator2[**P, R](func: Callable[P, R]) -> Callable[P, R]:
return func
@decorator1
@decorator2
def chained_generic[T](value: T) -> T:
return value
reveal_type(chained_generic(42)) # revealed: Literal[42]
reveal_type(chained_generic("test")) # revealed: Literal["test"]
```
### Generic method decoration
```py
from typing import Callable
def method_decorator[**P, R](func: Callable[P, R]) -> Callable[P, R]:
return func
class Container:
@method_decorator
def generic_method[T](self, value: T) -> T:
return value
c = Container()
reveal_type(c.generic_method(100)) # revealed: Literal[100]
reveal_type(c.generic_method([1, 2, 3])) # revealed: list[Unknown | int]
```

View File

@@ -188,9 +188,13 @@ impl<'db> CallableSignature<'db> {
{
Some(CallableSignature::from_overloads(
callable.signatures(db).iter().map(|signature| Signature {
generic_context: self_signature.generic_context.map(|context| {
type_mapping.update_signature_generic_context(db, context)
}),
generic_context: GenericContext::merge_optional(
db,
signature.generic_context,
self_signature.generic_context.map(|context| {
type_mapping.update_signature_generic_context(db, context)
}),
),
definition: signature.definition,
parameters: if signature.parameters().is_top() {
signature.parameters().clone()
@@ -414,7 +418,11 @@ impl<'db> CallableSignature<'db> {
db,
CallableSignature::from_overloads(other_signatures.iter().map(
|signature| {
Signature::new(signature.parameters().clone(), Type::unknown())
Signature::new_generic(
signature.generic_context,
signature.parameters().clone(),
Type::unknown(),
)
},
)),
CallableTypeKind::ParamSpecValue,
@@ -446,7 +454,11 @@ impl<'db> CallableSignature<'db> {
db,
CallableSignature::from_overloads(self_signatures.iter().map(
|signature| {
Signature::new(signature.parameters().clone(), Type::unknown())
Signature::new_generic(
signature.generic_context,
signature.parameters().clone(),
Type::unknown(),
)
},
)),
CallableTypeKind::ParamSpecValue,
@@ -1083,7 +1095,11 @@ impl<'db> Signature<'db> {
let upper = Type::Callable(CallableType::new(
db,
CallableSignature::from_overloads(other.overloads.iter().map(|signature| {
Signature::new(signature.parameters().clone(), Type::unknown())
Signature::new_generic(
signature.generic_context,
signature.parameters().clone(),
Type::unknown(),
)
})),
CallableTypeKind::ParamSpecValue,
));
@@ -1339,7 +1355,8 @@ impl<'db> Signature<'db> {
(Some(self_bound_typevar), None) => {
let upper = Type::Callable(CallableType::new(
db,
CallableSignature::single(Signature::new(
CallableSignature::single(Signature::new_generic(
other.generic_context,
other.parameters.clone(),
Type::unknown(),
)),
@@ -1358,7 +1375,8 @@ impl<'db> Signature<'db> {
(None, Some(other_bound_typevar)) => {
let lower = Type::Callable(CallableType::new(
db,
CallableSignature::single(Signature::new(
CallableSignature::single(Signature::new_generic(
self.generic_context,
self.parameters.clone(),
Type::unknown(),
)),