Compare commits
1 Commits
alex/subsc
...
dhruv/para
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec2861bc4f |
@@ -886,8 +886,6 @@ class GenericClass[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _(x: list[str]):
|
||||
# TODO: This fails because we are not propagating GenericClass's generic context into the
|
||||
# Callable that we create for it.
|
||||
# revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass]
|
||||
reveal_type(into_callable(GenericClass))
|
||||
# revealed: ty_extensions.GenericContext[T@GenericClass]
|
||||
@@ -895,15 +893,10 @@ def _(x: list[str]):
|
||||
|
||||
# revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass]
|
||||
reveal_type(accepts_callable(GenericClass))
|
||||
# TODO: revealed: ty_extensions.GenericContext[T@GenericClass]
|
||||
# revealed: None
|
||||
# revealed: ty_extensions.GenericContext[T@GenericClass]
|
||||
reveal_type(generic_context(accepts_callable(GenericClass)))
|
||||
|
||||
# TODO: revealed: GenericClass[str]
|
||||
# TODO: no errors
|
||||
# revealed: GenericClass[T@GenericClass]
|
||||
# error: [invalid-argument-type]
|
||||
# error: [invalid-argument-type]
|
||||
# revealed: GenericClass[str]
|
||||
reveal_type(accepts_callable(GenericClass)(x, x))
|
||||
```
|
||||
|
||||
|
||||
@@ -800,3 +800,78 @@ def f(x: int, y: str):
|
||||
|
||||
reveal_type(infer_paramspec(f)) # revealed: (x: int, y: str) -> None
|
||||
```
|
||||
|
||||
## Generic context preservation through `ParamSpec` decorators
|
||||
|
||||
When a generic function is decorated with a `ParamSpec`-based decorator, the generic context of the
|
||||
decorated function should be preserved. This allows type inference to work correctly when calling the
|
||||
decorated function.
|
||||
|
||||
Regression test for <https://github.com/astral-sh/ty/issues/2336>
|
||||
|
||||
### Basic
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
from ty_extensions import generic_context
|
||||
|
||||
def decorator[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
||||
return func
|
||||
|
||||
@decorator
|
||||
def identity[T](value: T) -> T:
|
||||
return value
|
||||
|
||||
@decorator
|
||||
def pair[T, U](first: T, second: U) -> tuple[T, U]:
|
||||
return (first, second)
|
||||
|
||||
# revealed: ty_extensions.GenericContext[T@identity]
|
||||
reveal_type(generic_context(identity))
|
||||
# revealed: ty_extensions.GenericContext[T@pair, U@pair]
|
||||
reveal_type(generic_context(pair))
|
||||
|
||||
reveal_type(identity(1)) # revealed: Literal[1]
|
||||
reveal_type(identity("hello")) # revealed: Literal["hello"]
|
||||
|
||||
reveal_type(pair(1, "a")) # revealed: tuple[Literal[1], Literal["a"]]
|
||||
reveal_type(pair("x", 2.5)) # revealed: tuple[Literal["x"], float]
|
||||
```
|
||||
|
||||
### Chained decorators with generic functions
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
|
||||
def decorator1[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
return func
|
||||
|
||||
def decorator2[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
return func
|
||||
|
||||
@decorator1
|
||||
@decorator2
|
||||
def chained_generic[T](value: T) -> T:
|
||||
return value
|
||||
|
||||
reveal_type(chained_generic(42)) # revealed: Literal[42]
|
||||
reveal_type(chained_generic("test")) # revealed: Literal["test"]
|
||||
```
|
||||
|
||||
### Generic method decoration
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
|
||||
def method_decorator[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
return func
|
||||
|
||||
class Container:
|
||||
@method_decorator
|
||||
def generic_method[T](self, value: T) -> T:
|
||||
return value
|
||||
|
||||
c = Container()
|
||||
reveal_type(c.generic_method(100)) # revealed: Literal[100]
|
||||
reveal_type(c.generic_method([1, 2, 3])) # revealed: list[Unknown | int]
|
||||
```
|
||||
|
||||
@@ -188,9 +188,13 @@ impl<'db> CallableSignature<'db> {
|
||||
{
|
||||
Some(CallableSignature::from_overloads(
|
||||
callable.signatures(db).iter().map(|signature| Signature {
|
||||
generic_context: self_signature.generic_context.map(|context| {
|
||||
type_mapping.update_signature_generic_context(db, context)
|
||||
}),
|
||||
generic_context: GenericContext::merge_optional(
|
||||
db,
|
||||
signature.generic_context,
|
||||
self_signature.generic_context.map(|context| {
|
||||
type_mapping.update_signature_generic_context(db, context)
|
||||
}),
|
||||
),
|
||||
definition: signature.definition,
|
||||
parameters: if signature.parameters().is_top() {
|
||||
signature.parameters().clone()
|
||||
@@ -414,7 +418,11 @@ impl<'db> CallableSignature<'db> {
|
||||
db,
|
||||
CallableSignature::from_overloads(other_signatures.iter().map(
|
||||
|signature| {
|
||||
Signature::new(signature.parameters().clone(), Type::unknown())
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
Type::unknown(),
|
||||
)
|
||||
},
|
||||
)),
|
||||
CallableTypeKind::ParamSpecValue,
|
||||
@@ -446,7 +454,11 @@ impl<'db> CallableSignature<'db> {
|
||||
db,
|
||||
CallableSignature::from_overloads(self_signatures.iter().map(
|
||||
|signature| {
|
||||
Signature::new(signature.parameters().clone(), Type::unknown())
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
Type::unknown(),
|
||||
)
|
||||
},
|
||||
)),
|
||||
CallableTypeKind::ParamSpecValue,
|
||||
@@ -1083,7 +1095,11 @@ impl<'db> Signature<'db> {
|
||||
let upper = Type::Callable(CallableType::new(
|
||||
db,
|
||||
CallableSignature::from_overloads(other.overloads.iter().map(|signature| {
|
||||
Signature::new(signature.parameters().clone(), Type::unknown())
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
Type::unknown(),
|
||||
)
|
||||
})),
|
||||
CallableTypeKind::ParamSpecValue,
|
||||
));
|
||||
@@ -1339,7 +1355,8 @@ impl<'db> Signature<'db> {
|
||||
(Some(self_bound_typevar), None) => {
|
||||
let upper = Type::Callable(CallableType::new(
|
||||
db,
|
||||
CallableSignature::single(Signature::new(
|
||||
CallableSignature::single(Signature::new_generic(
|
||||
other.generic_context,
|
||||
other.parameters.clone(),
|
||||
Type::unknown(),
|
||||
)),
|
||||
@@ -1358,7 +1375,8 @@ impl<'db> Signature<'db> {
|
||||
(None, Some(other_bound_typevar)) => {
|
||||
let lower = Type::Callable(CallableType::new(
|
||||
db,
|
||||
CallableSignature::single(Signature::new(
|
||||
CallableSignature::single(Signature::new_generic(
|
||||
self.generic_context,
|
||||
self.parameters.clone(),
|
||||
Type::unknown(),
|
||||
)),
|
||||
|
||||
Reference in New Issue
Block a user