Compare commits

...

16 Commits

Author SHA1 Message Date
Douglas Creager
e1f9ba7f05 don't substitute recursively! 2025-12-19 13:35:33 -05:00
Douglas Creager
d6a98d8f2c mdlint 2025-12-18 20:59:34 -05:00
Douglas Creager
15c9d66ef6 document missing case 2025-12-18 20:56:20 -05:00
Douglas Creager
3d0dea21ad use or_else here, it's cleaner 2025-12-18 11:20:18 -05:00
Douglas Creager
a35c733b64 clean up callback types 2025-12-18 11:20:18 -05:00
Douglas Creager
a694ede51d clippy 2025-12-18 11:20:18 -05:00
Douglas Creager
3cc77a626d update higher-order callable tests 2025-12-18 11:20:18 -05:00
Douglas Creager
056414bfac typevar ordering strikes again 2025-12-18 11:20:18 -05:00
Douglas Creager
804842ac80 failing tests for propagating typevars 2025-12-18 09:12:28 -05:00
Douglas Creager
026b753735 handle the covariant case 2025-12-18 09:12:28 -05:00
Douglas Creager
6113935578 partial specialize single typevar 2025-12-18 09:12:27 -05:00
Douglas Creager
566bfc5ecb partial spec is an enum 2025-12-18 09:12:27 -05:00
Douglas Creager
7d38dc8685 restructure a bit 2025-12-18 09:12:27 -05:00
Douglas Creager
f22890bd88 add quantification test cases 2025-12-18 09:12:27 -05:00
Douglas Creager
b051943504 add exists too 2025-12-18 09:12:26 -05:00
Douglas Creager
6d4681e65d add retain_one extension method 2025-12-18 08:51:40 -05:00
9 changed files with 942 additions and 77 deletions

View File

@@ -481,22 +481,99 @@ reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None
## Passing generic functions to generic functions
```py
`functions.pyi`:
```pyi
from typing import Callable
def invoke[A, B](fn: Callable[[A], B], value: A) -> B:
return fn(value)
def invoke[A, B](fn: Callable[[A], B], value: A) -> B: ...
def identity[T](x: T) -> T:
return x
def identity[T](x: T) -> T: ...
def head[T](xs: list[T]) -> T:
return xs[0]
class Covariant[T]:
def get(self) -> T:
raise NotImplementedError
def head_covariant[T](xs: Covariant[T]) -> T: ...
def lift_covariant[T](xs: T) -> Covariant[T]: ...
class Contravariant[T]:
def receive(self, input: T): ...
def head_contravariant[T](xs: Contravariant[T]) -> T: ...
def lift_contravariant[T](xs: T) -> Contravariant[T]: ...
class Invariant[T]:
mutable_attribute: T
def head_invariant[T](xs: Invariant[T]) -> T: ...
def lift_invariant[T](xs: T) -> Invariant[T]: ...
```
A simple function that passes through its parameter type unchanged:
`simple.py`:
```py
from functions import invoke, identity
reveal_type(invoke(identity, 1)) # revealed: Literal[1]
```
# TODO: this should be `Unknown | int`
reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown
When the either the parameter or the return type is a generic alias referring to the typevar, we
should still be able to propagate the specializations through. This should work regardless of the
typevar's variance in the generic alias.
TODO: This currently only works for the `lift` functions (TODO: and only currently for the covariant
case). For the `lift` functions, the parameter type is a bare typevar, resulting in us inferring a
type mapping of `A = int, B = Class[A]`. When specializing, we can substitute the mapping of `A`
into the mapping of `B`, giving the correct return type.
For the `head` functions, the parameter type is a generic alias, resulting in us inferring a type
mapping of `A = Class[int], A = Class[B]`. At this point, the old solver is not able to unify the
two mappings for `A`, and we have no mapping for `B`. As a result, we infer `Unknown` for the return
type.
As part of migrating to the new solver, we will generate a single constraint set combining all of
the facts that we learn while checking the arguments. And the constraint set implementation should
be able to unify the two assignments to `A`.
`covariant.py`:
```py
from functions import invoke, Covariant, head_covariant, lift_covariant
# TODO: revealed: `int`
# revealed: Unknown
reveal_type(invoke(head_covariant, Covariant[int]()))
# revealed: Covariant[Literal[1]]
reveal_type(invoke(lift_covariant, 1))
```
`contravariant.py`:
```py
from functions import invoke, Contravariant, head_contravariant, lift_contravariant
# TODO: revealed: `int`
# revealed: Unknown
reveal_type(invoke(head_contravariant, Contravariant[int]()))
# TODO: revealed: Contravariant[int]
# revealed: Unknown
reveal_type(invoke(lift_contravariant, 1))
```
`invariant.py`:
```py
from functions import invoke, Invariant, head_invariant, lift_invariant
# TODO: revealed: `int`
# revealed: Unknown
reveal_type(invoke(head_invariant, Invariant[int]()))
# TODO: revealed: `Invariant[int]`
# revealed: Unknown
reveal_type(invoke(lift_invariant, 1))
```
## Protocols as TypeVar bounds

View File

@@ -416,19 +416,43 @@ def mutually_bound[T: Base, U]():
## Nested typevars
A typevar's constraint can _mention_ another typevar without _constraining_ it. In this example, `U`
must be specialized to `list[T]`, but it cannot affect what `T` is specialized to.
The specialization of one typevar can affect the specialization of another, even if it is not a
"top-level" type in the bounds. (That is, if it appears as inside the specialization of a generic
class.)
```py
from typing import Never
from ty_extensions import ConstraintSet, generic_context
def mentions[T, U]():
# (T@mentions ≤ int) ∧ (U@mentions = list[T@mentions])
constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(list[T], U, list[T])
# TODO: revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = list[int]]
# revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = Unknown]
reveal_type(generic_context(mentions).specialize_constrained(constraints))
class Covariant[T]:
def get(self) -> T:
raise NotImplementedError
class Contravariant[T]:
def receive(self, input: T): ...
class Invariant[T]:
mutable_attribute: T
def mentions_covariant[T, U]():
# (T@mentions_covariant ≤ int) ∧ (U@mentions_covariant ≤ Covariant[T@mentions_covariant])
constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Never, U, Covariant[T])
# revealed: ty_extensions.Specialization[T@mentions_covariant = int, U@mentions_covariant = Covariant[int]]
reveal_type(generic_context(mentions_covariant).specialize_constrained(constraints))
def mentions_contravariant[T, U]():
# (T@mentions_contravariant ≤ int) ∧ (Contravariant[T@mentions_contravariant] ≤ U@mentions_contravariant)
constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Contravariant[T], U, object)
# TODO: revealed: ty_extensions.Specialization[T@mentions_contravariant = int, U@mentions_contravariant = Contravariant[int]]
# revealed: ty_extensions.Specialization[T@mentions_contravariant = int, U@mentions_contravariant = Unknown]
reveal_type(generic_context(mentions_contravariant).specialize_constrained(constraints))
def mentions_invariant[T, U]():
# (T@mentions_invariant ≤ int) ∧ (U@mentions_invariant = Invariant[T@mentions_invariant])
constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(Invariant[T], U, Invariant[T])
# TODO: revealed: ty_extensions.Specialization[T@mentions_invariant = int, U@mentions_invariant = Invariant[int]]
# revealed: ty_extensions.Specialization[T@mentions_invariant = int, U@mentions_invariant = Unknown]
reveal_type(generic_context(mentions_invariant).specialize_constrained(constraints))
```
If the constraint set contains mutually recursive bounds, specialization inference will not
@@ -437,8 +461,8 @@ this case.
```py
def divergent[T, U]():
# (T@divergent = list[U@divergent]) ∧ (U@divergent = list[T@divergent]))
constraints = ConstraintSet.range(list[U], T, list[U]) & ConstraintSet.range(list[T], U, list[T])
# (T@divergent = Invariant[U@divergent]) ∧ (U@divergent = Invariant[T@divergent]))
constraints = ConstraintSet.range(Invariant[U], T, Invariant[U]) & ConstraintSet.range(Invariant[T], U, Invariant[T])
# revealed: None
reveal_type(generic_context(divergent).specialize_constrained(constraints))
```

View File

@@ -0,0 +1,407 @@
# Constraint set quantification
```toml
[environment]
python-version = "3.12"
```
We can _existentially quantify_ a constraint set over a type variable. The result is a copy of the
constraint set that only mentions the requested typevar. All constraints mentioning any other
typevars are removed. Importantly, they are removed "safely", with their constraints propagated
through to the remaining constraints as needed.
## Keeping a single typevar
If a constraint set only mentions a single typevar, and we keep that typevar when quantifying, the
result is unchanged.
```py
from ty_extensions import ConstraintSet, static_assert
class Base: ...
class Sub(Base): ...
def keep_single[T]():
constraints = ConstraintSet.always()
quantified = ConstraintSet.always()
static_assert(constraints.retain_one(T) == quantified)
constraints = ConstraintSet.never()
quantified = ConstraintSet.never()
static_assert(constraints.retain_one(T) == quantified)
constraints = ConstraintSet.range(Sub, T, Base)
quantified = ConstraintSet.range(Sub, T, Base)
static_assert(constraints.retain_one(T) == quantified)
```
## Removing a single typevar
If a constraint set only mentions a single typevar, and we remove that typevar when quantifying, the
result is usually "always". The only exception is if the original constraint set has no solution. In
that case, the result is also unsatisfiable.
```py
from ty_extensions import ConstraintSet, static_assert
class Base: ...
class Sub(Base): ...
def remove_single[T]():
constraints = ConstraintSet.always()
quantified = ConstraintSet.always()
static_assert(constraints.exists(T) == quantified)
constraints = ConstraintSet.never()
quantified = ConstraintSet.never()
static_assert(constraints.exists(T) == quantified)
constraints = ConstraintSet.range(Sub, T, Base)
quantified = ConstraintSet.always()
static_assert(constraints.exists(T) == quantified)
```
This also holds when the constraint set contains multiple typevars. In the cases below, we are
keeping `U`, and the constraints on `T` do not ever affect what `U` can specialize to — `U` can
specialize to anything (unless the original constraint set is unsatisfiable).
```py
from ty_extensions import ConstraintSet, static_assert
class Base: ...
class Sub(Base): ...
def remove_other[T, U]():
constraints = ConstraintSet.always()
quantified = ConstraintSet.always()
static_assert(constraints.retain_one(U) == quantified)
constraints = ConstraintSet.never()
quantified = ConstraintSet.never()
static_assert(constraints.retain_one(U) == quantified)
constraints = ConstraintSet.range(Sub, T, Base)
quantified = ConstraintSet.always()
static_assert(constraints.retain_one(U) == quantified)
```
## Transitivity
When a constraint set mentions two typevars, and compares them directly, then we can use
transitivity to propagate the other constraints when quantifying.
```py
from typing import Never
from ty_extensions import ConstraintSet, static_assert
class Super: ...
class Base(Super): ...
class Sub(Base): ...
def transitivity[T, U]():
# (Base ≤ T) ∧ (T ≤ U) → (Base ≤ U)
constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(T, U, object)
quantified = ConstraintSet.range(Base, U, object)
static_assert(constraints.exists(T) == quantified)
# (Base ≤ T ≤ Super) ∧ (T ≤ U) → (Base ≤ U)
constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(T, U, object)
quantified = ConstraintSet.range(Base, U, object)
static_assert(constraints.exists(T) == quantified)
# (T ≤ Base) ∧ (U ≤ T) → (U ≤ Base)
constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Never, U, T)
quantified = ConstraintSet.range(Never, U, Base)
static_assert(constraints.exists(T) == quantified)
# (Sub ≤ T ≤ Base) ∧ (U ≤ T) → (U ≤ Base)
constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, T)
quantified = ConstraintSet.range(Never, U, Base)
static_assert(constraints.exists(T) == quantified)
```
## Covariant transitivity
The same applies when one of the typevars is used covariantly in a bound of the other typevar.
```py
from typing import Never
from ty_extensions import ConstraintSet, static_assert
class Super: ...
class Base(Super): ...
class Sub(Base): ...
class Covariant[T]:
def get(self) -> T:
raise NotImplementedError
def covariant_transitivity[T, U]():
# (Base ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[Base] ≤ U)
constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(Covariant[T], U, object)
quantified = ConstraintSet.range(Covariant[Base], U, object)
static_assert(constraints.exists(T) == quantified)
# (Base ≤ T ≤ Super) ∧ (Covariant[T] ≤ U) → (Covariant[Base] ≤ U)
constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(Covariant[T], U, object)
quantified = ConstraintSet.range(Covariant[Base], U, object)
static_assert(constraints.exists(T) == quantified)
# (T ≤ Base) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[Base])
constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Never, U, Covariant[T])
quantified = ConstraintSet.range(Never, U, Covariant[Base])
static_assert(constraints.exists(T) == quantified)
# (Sub ≤ T ≤ Base) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[Base])
constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, Covariant[T])
quantified = ConstraintSet.range(Never, U, Covariant[Base])
static_assert(constraints.exists(T) == quantified)
```
Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test
with both variable orderings for the constraint that involves two typevars.
```py
def covariant_typevar_transitivity[B, T, U]():
# (B ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[B] ≤ U)
constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Covariant[T], U, object)
quantified = ConstraintSet.range(Covariant[B], U, object)
static_assert(constraints.exists(T) == quantified)
# (T ≤ B) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[B])
constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Covariant[T])
quantified = ConstraintSet.range(Never, U, Covariant[B])
static_assert(constraints.exists(T) == quantified)
def covariant_typevar_transitivity_reversed[T, B, U]():
# (B ≤ T) ∧ (Covariant[T] ≤ U) → (Covariant[B] ≤ U)
constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Covariant[T], U, object)
quantified = ConstraintSet.range(Covariant[B], U, object)
static_assert(constraints.exists(T) == quantified)
# (T ≤ B) ∧ (U ≤ Covariant[T]) → (U ≤ Covariant[B])
constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Covariant[T])
quantified = ConstraintSet.range(Never, U, Covariant[B])
static_assert(constraints.exists(T) == quantified)
```
## Contravariant transitivity
Similar rules apply, but in reverse, when one of the typevars is used contravariantly in a bound of
the other typevar.
```py
from typing import Never
from ty_extensions import ConstraintSet, static_assert
class Super: ...
class Base(Super): ...
class Sub(Base): ...
class Contravariant[T]:
def receive(self, input: T): ...
def contravariant_transitivity[T, U]():
# (Base ≤ T) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[Base])
constraints = ConstraintSet.range(Base, T, object) & ConstraintSet.range(Never, U, Contravariant[T])
quantified = ConstraintSet.range(Never, U, Contravariant[Base])
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (Base ≤ T ≤ Super) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[Base])
constraints = ConstraintSet.range(Base, T, Super) & ConstraintSet.range(Never, U, Contravariant[T])
quantified = ConstraintSet.range(Never, U, Contravariant[Base])
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (T ≤ Base) ∧ (Contravariant[T] ≤ U) → (Contravariant[Base] ≤ U)
constraints = ConstraintSet.range(Never, T, Base) & ConstraintSet.range(Contravariant[T], U, object)
quantified = ConstraintSet.range(Contravariant[Base], U, object)
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (Sub ≤ T ≤ Base) ∧ (Contravariant[T] ≤ U) → (Contravariant[Base] ≤ U)
constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Contravariant[T], U, object)
quantified = ConstraintSet.range(Contravariant[Base], U, object)
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
```
Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test
with both variable orderings for the constraint that involves two typevars.
```py
def contravariant_typevar_transitivity[B, T, U]():
# (B ≤ T) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[B])
constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Never, U, Contravariant[T])
quantified = ConstraintSet.range(Never, U, Contravariant[B])
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (T ≤ B) ∧ (Contravariant[T] ≤ U) → (Contravariant[B] ≤ U)
constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Contravariant[T], U, object)
quantified = ConstraintSet.range(Contravariant[B], U, object)
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
def contravariant_typevar_transitivity_reversed[T, B, U]():
# (B ≤ T) ∧ (U ≤ Contravariant[T]) → (U ≤ Contravariant[B])
constraints = ConstraintSet.range(B, T, object) & ConstraintSet.range(Never, U, Contravariant[T])
quantified = ConstraintSet.range(Never, U, Contravariant[B])
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (T ≤ B) ∧ (Contravariant[T] ≤ U) → (Contravariant[B] ≤ U)
constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Contravariant[T], U, object)
quantified = ConstraintSet.range(Contravariant[B], U, object)
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
```
## Invariant transitivity involving equality constraints
Invariant uses of a typevar are more subtle. The simplest case is when there is an _equality_
constraint on the invariant typevar. In that case, we know precisely which specialization is
required.
```py
from typing import Never
from ty_extensions import ConstraintSet, static_assert
class Base: ...
class Invariant[T]:
mutable_attribute: T
def invariant_equality_transitivity[T, U]():
# (T = Base) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Base])
constraints = ConstraintSet.range(Base, T, Base) & ConstraintSet.range(Never, U, Invariant[T])
quantified = ConstraintSet.range(Never, U, Invariant[Base])
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (T = Base) ∧ (Invariant[T] ≤ U) → (Invariant[Base] ≤ U)
constraints = ConstraintSet.range(Base, T, Base) & ConstraintSet.range(Invariant[T], U, object)
quantified = ConstraintSet.range(Invariant[Base], U, object)
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
```
Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test
with both variable orderings for the constraint that involves two typevars.
```py
def invariant_equality_typevar_transitivity[B, T, U]():
# (T = B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[B])
constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Never, U, Invariant[T])
quantified = ConstraintSet.range(Never, U, Invariant[B])
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (T = B) ∧ (Invariant[T] ≤ U) → (Invariant[B] ≤ U)
constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Invariant[T], U, object)
quantified = ConstraintSet.range(Invariant[B], U, object)
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
def invariant_equality_typevar_transitivity_reverse[T, B, U]():
# (T = B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[B])
constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Never, U, Invariant[T])
quantified = ConstraintSet.range(Never, U, Invariant[B])
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (T = B) ∧ (Invariant[T] ≤ U) → (Invariant[B] ≤ U)
constraints = ConstraintSet.range(B, T, B) & ConstraintSet.range(Invariant[T], U, object)
quantified = ConstraintSet.range(Invariant[B], U, object)
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
```
## Invariant transitivity involving range constraints
When there is a _range_ constraint on the invariant typevar, we still have to retain information
about which range of types the quantified-away typevar can specialize to, since this affects which
types the remaining typevar can specialize to, and invariant typevars are not monotonic like
covariant and contravariant typevars.
```py
from typing import Never
from ty_extensions import ConstraintSet, static_assert
class Base: ...
class Sub(Base): ...
class Invariant[T]:
mutable_attribute: T
def invariant_range_transitivity[T, U]():
# (Sub ≤ T ≤ Base) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Exists[Sub, Base]])
constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Never, U, Invariant[T])
# TODO: The existential that we need doesn't exist yet.
quantified = ConstraintSet.never()
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (Sub ≤ T ≤ Base) ∧ (Invariant[T] ≤ U) → (Invariant[Exists[Sub, Base]] ≤ U)
constraints = ConstraintSet.range(Sub, T, Base) & ConstraintSet.range(Invariant[T], U, object)
# TODO: The existential that we need doesn't exist yet.
quantified = ConstraintSet.never()
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
```
Same as above, but when propagating a third typevar instead of a concrete type. We make sure to test
with both variable orderings for the constraint that involves two typevars.
```py
def invariant_range_typevar_transitivity[B, T, U]():
# (T ≤ B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Exists[Never, B]])
constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Invariant[T])
# TODO: The existential that we need doesn't exist yet.
quantified = ConstraintSet.never()
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (T ≤ B) ∧ (Invariant[T] ≤ U) → (Invariant[Exists[Never, B]] ≤ U)
constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Invariant[T], U, object)
# TODO: The existential that we need doesn't exist yet.
quantified = ConstraintSet.never()
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
def invariant_range_typevar_transitivity_reverse[T, B, U]():
# (T ≤ B) ∧ (U ≤ Invariant[T]) → (U ≤ Invariant[Exists[Never, B]])
constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Never, U, Invariant[T])
# TODO: The existential that we need doesn't exist yet.
quantified = ConstraintSet.never()
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
# (T ≤ B) ∧ (Invariant[T] ≤ U) → (Invariant[Exists[Never, B]] ≤ U)
constraints = ConstraintSet.range(Never, T, B) & ConstraintSet.range(Invariant[T], U, object)
# TODO: The existential that we need doesn't exist yet.
quantified = ConstraintSet.never()
# TODO: no error
# error: [static-assert-error]
static_assert(constraints.exists(T) == quantified)
```

View File

@@ -5026,6 +5026,12 @@ impl<'db> Type<'db> {
))
.into()
}
Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked)) if name == "exists" => {
Place::bound(Type::KnownBoundMethod(
KnownBoundMethodType::ConstraintSetExists(tracked),
))
.into()
}
Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked))
if name == "implies_subtype_of" =>
{
@@ -5034,6 +5040,14 @@ impl<'db> Type<'db> {
))
.into()
}
Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked))
if name == "retain_one" =>
{
Place::bound(Type::KnownBoundMethod(
KnownBoundMethodType::ConstraintSetRetainOne(tracked),
))
.into()
}
Type::KnownInstance(KnownInstanceType::ConstraintSet(tracked))
if name == "satisfies" =>
{
@@ -8035,7 +8049,9 @@ impl<'db> Type<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_)
@@ -8254,7 +8270,9 @@ impl<'db> Type<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_),
@@ -12674,7 +12692,9 @@ pub enum KnownBoundMethodType<'db> {
ConstraintSetRange,
ConstraintSetAlways,
ConstraintSetNever,
ConstraintSetExists(TrackedConstraintSet<'db>),
ConstraintSetImpliesSubtypeOf(TrackedConstraintSet<'db>),
ConstraintSetRetainOne(TrackedConstraintSet<'db>),
ConstraintSetSatisfies(TrackedConstraintSet<'db>),
ConstraintSetSatisfiedByAllTypeVars(TrackedConstraintSet<'db>),
@@ -12706,7 +12726,9 @@ pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Size
KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_) => {}
@@ -12773,10 +12795,18 @@ impl<'db> KnownBoundMethodType<'db> {
KnownBoundMethodType::ConstraintSetNever,
KnownBoundMethodType::ConstraintSetNever,
)
| (
KnownBoundMethodType::ConstraintSetExists(_),
KnownBoundMethodType::ConstraintSetExists(_),
)
| (
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_),
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_),
)
| (
KnownBoundMethodType::ConstraintSetRetainOne(_),
KnownBoundMethodType::ConstraintSetRetainOne(_),
)
| (
KnownBoundMethodType::ConstraintSetSatisfies(_),
KnownBoundMethodType::ConstraintSetSatisfies(_),
@@ -12799,7 +12829,9 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_),
@@ -12811,7 +12843,9 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_),
@@ -12864,9 +12898,17 @@ impl<'db> KnownBoundMethodType<'db> {
) => ConstraintSet::from(true),
(
KnownBoundMethodType::ConstraintSetExists(left_constraints),
KnownBoundMethodType::ConstraintSetExists(right_constraints),
)
| (
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(left_constraints),
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(right_constraints),
)
| (
KnownBoundMethodType::ConstraintSetRetainOne(left_constraints),
KnownBoundMethodType::ConstraintSetRetainOne(right_constraints),
)
| (
KnownBoundMethodType::ConstraintSetSatisfies(left_constraints),
KnownBoundMethodType::ConstraintSetSatisfies(right_constraints),
@@ -12892,7 +12934,9 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_),
@@ -12904,7 +12948,9 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_),
@@ -12930,7 +12976,9 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_) => self,
@@ -12968,7 +13016,9 @@ impl<'db> KnownBoundMethodType<'db> {
| KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_) => Some(self),
@@ -12986,7 +13036,9 @@ impl<'db> KnownBoundMethodType<'db> {
KnownBoundMethodType::ConstraintSetRange
| KnownBoundMethodType::ConstraintSetAlways
| KnownBoundMethodType::ConstraintSetNever
| KnownBoundMethodType::ConstraintSetExists(_)
| KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_)
| KnownBoundMethodType::ConstraintSetRetainOne(_)
| KnownBoundMethodType::ConstraintSetSatisfies(_)
| KnownBoundMethodType::ConstraintSetSatisfiedByAllTypeVars(_)
| KnownBoundMethodType::GenericContextSpecializeConstrained(_) => {
@@ -13132,6 +13184,18 @@ impl<'db> KnownBoundMethodType<'db> {
)))
}
KnownBoundMethodType::ConstraintSetExists(_) => {
Either::Right(std::iter::once(Signature::new(
Parameters::new(
db,
[Parameter::variadic(Name::new_static("typevars"))
.type_form()
.with_annotated_type(Type::any())],
),
Some(KnownClass::ConstraintSet.to_instance(db)),
)))
}
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => {
Either::Right(std::iter::once(Signature::new(
Parameters::new(
@@ -13149,6 +13213,20 @@ impl<'db> KnownBoundMethodType<'db> {
)))
}
KnownBoundMethodType::ConstraintSetRetainOne(_) => {
Either::Right(std::iter::once(Signature::new(
Parameters::new(
db,
[
Parameter::positional_only(Some(Name::new_static("typevar")))
.type_form()
.with_annotated_type(Type::any()),
],
),
Some(KnownClass::ConstraintSet.to_instance(db)),
)))
}
KnownBoundMethodType::ConstraintSetSatisfies(_) => {
Either::Right(std::iter::once(Signature::new(
Parameters::new(

View File

@@ -1255,6 +1255,25 @@ impl<'db> Bindings<'db> {
));
}
Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetExists(tracked)) => {
let typevars: Option<Vec<_>> = overload
.arguments_for_parameter(argument_types, 0)
.map(|(_, ty)| {
ty.as_typevar()
.map(|bound_typevar| bound_typevar.identity(db))
})
.collect();
let Some(typevars) = typevars else {
continue;
};
let result = tracked.constraints(db).exists(db, typevars);
let tracked = TrackedConstraintSet::new(db, result);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::ConstraintSet(tracked),
));
}
Type::KnownBoundMethod(
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(tracked),
) => {
@@ -1274,6 +1293,20 @@ impl<'db> Bindings<'db> {
));
}
Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetRetainOne(
tracked,
)) => {
let [Some(Type::TypeVar(typevar))] = overload.parameter_types() else {
continue;
};
let result = tracked.constraints(db).retain_one(db, typevar.identity(db));
let tracked = TrackedConstraintSet::new(db, result);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::ConstraintSet(tracked),
));
}
Type::KnownBoundMethod(KnownBoundMethodType::ConstraintSetSatisfies(
tracked,
)) => {

View File

@@ -75,13 +75,16 @@ use itertools::Itertools;
use rustc_hash::{FxHashMap, FxHashSet};
use salsa::plumbing::AsId;
use crate::types::generics::{GenericContext, InferableTypeVars, Specialization};
use crate::types::generics::{
GenericContext, InferableTypeVars, PartialSpecialization, Specialization,
};
use crate::types::variance::VarianceInferable;
use crate::types::visitor::{
TypeCollector, TypeVisitor, any_over_type, walk_type_with_recursion_guard,
};
use crate::types::{
BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeVarBoundOrConstraints,
UnionType, walk_bound_type_var_type,
BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeContext, TypeMapping,
TypeVarBoundOrConstraints, TypeVarVariance, UnionType, walk_bound_type_var_type,
};
use crate::{Db, FxOrderMap};
@@ -432,6 +435,29 @@ impl<'db> ConstraintSet<'db> {
}
}
/// Returns a new constraint set that is the _existential abstraction_ of `self` for a set of
/// typevars. The result will return true whenever `self` returns true for _any_ assignment of
/// those typevars. The result will not contain any constraints that mention those typevars.
pub(crate) fn exists(
self,
db: &'db dyn Db,
bound_typevars: impl IntoIterator<Item = BoundTypeVarIdentity<'db>>,
) -> Self {
let node = self.node.exists(db, bound_typevars);
Self { node }
}
/// Quantifies over this constraint set so that it only contains constraints that mention the
/// given typevar. All other typevars are quantified away.
pub(crate) fn retain_one(
self,
db: &'db dyn Db,
bound_typevar: BoundTypeVarIdentity<'db>,
) -> Self {
let node = self.node.retain_one(db, bound_typevar);
Self { node }
}
/// Reduces the set of inferable typevars for this constraint set. You provide an iterator of
/// the typevars that were inferable when this constraint set was created, and which should be
/// abstracted away. Those typevars will be removed from the constraint set, and the constraint
@@ -3027,30 +3053,59 @@ impl<'db> SequentMap<'db> {
left_constraint: ConstrainedTypeVar<'db>,
right_constraint: ConstrainedTypeVar<'db>,
) {
// We've structured our constraints so that a typevar's upper/lower bound can only
// be another typevar if the bound is "later" in our arbitrary ordering. That means
// we only have to check this pair of constraints in one direction — though we do
// have to figure out which of the two typevars is constrained, and which one is
// the upper/lower bound.
// Add sequents when the two typevars are mutually constrained directly — that is, when the
// lower/upper bound _is_ the typevar, not _contains_ the typevar. We've structured our
// constraints so that a typevar's upper/lower bound can only be another typevar if the
// bound is "later" in our arbitrary ordering. That means we only have to check this pair
// of constraints in one direction — though we do have to figure out which of the two
// typevars is constrained, and which one is the upper/lower bound.
let left_typevar = left_constraint.typevar(db);
let right_typevar = right_constraint.typevar(db);
let (bound_typevar, bound_constraint, constrained_typevar, constrained_constraint) =
if left_typevar.can_be_bound_for(db, right_typevar) {
(
left_typevar,
left_constraint,
right_typevar,
right_constraint,
)
} else {
(
right_typevar,
right_constraint,
left_typevar,
left_constraint,
)
};
if left_typevar.can_be_bound_for(db, right_typevar) {
self.add_direct_mutual_sequents_for_different_typevars(
db,
left_constraint,
left_typevar,
right_constraint,
right_typevar,
);
} else {
self.add_direct_mutual_sequents_for_different_typevars(
db,
right_constraint,
right_typevar,
left_constraint,
left_typevar,
);
}
// Add sequents when one of the typevars is mentioned deeply inside the bounds of the
// other. Because the "bound" typevar appears deeply inside of the "constrained" typevar,
// our constraint ordering doesn't apply, and we have to try each direction.
self.add_deep_mutual_sequents(
db,
left_constraint,
left_typevar,
right_constraint,
right_typevar,
);
self.add_deep_mutual_sequents(
db,
right_constraint,
right_typevar,
left_constraint,
left_typevar,
);
}
fn add_direct_mutual_sequents_for_different_typevars(
&mut self,
db: &'db dyn Db,
bound_constraint: ConstrainedTypeVar<'db>,
bound_typevar: BoundTypeVarInstance<'db>,
constrained_constraint: ConstrainedTypeVar<'db>,
constrained_typevar: BoundTypeVarInstance<'db>,
) {
// We then look for cases where the "constrained" typevar's upper and/or lower bound
// matches the "bound" typevar. If so, we're going to add an implication sequent that
// replaces the upper/lower bound that matched with the bound constraint's corresponding
@@ -3104,10 +3159,157 @@ impl<'db> SequentMap<'db> {
let post_constraint =
ConstrainedTypeVar::new(db, constrained_typevar, new_lower, new_upper);
self.add_pair_implication(db, left_constraint, right_constraint, post_constraint);
self.add_pair_implication(
db,
bound_constraint,
constrained_constraint,
post_constraint,
);
self.enqueue_constraint(post_constraint);
}
fn add_deep_mutual_sequents(
&mut self,
db: &'db dyn Db,
bound_constraint: ConstrainedTypeVar<'db>,
bound_typevar: BoundTypeVarInstance<'db>,
constrained_constraint: ConstrainedTypeVar<'db>,
constrained_typevar: BoundTypeVarInstance<'db>,
) {
let bound_lower = bound_constraint.lower(db);
let bound_upper = bound_constraint.upper(db);
let constrained_lower = constrained_constraint.lower(db);
let constrained_upper = constrained_constraint.upper(db);
let mut add_constraint = |post_constraint| {
self.add_pair_implication(
db,
bound_constraint,
constrained_constraint,
post_constraint,
);
self.enqueue_constraint(post_constraint);
};
match constrained_lower.variance_of(db, bound_typevar) {
// B does not appear in CU, or if it does, it appears bivariantly. The constraints of B
// do not affect the valid specializations of C.
TypeVarVariance::Bivariant => {}
// (Covariant[B] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[BL] ≤ C ≤ CU)
TypeVarVariance::Covariant => {
// Only substitute to create a new sequent if the substitution is interesting, and
// doesn't recursively contain the typevar we are substituting for.
if !bound_lower.is_never()
&& !bound_lower.is_object()
&& bound_lower.variance_of(db, bound_typevar) == TypeVarVariance::Bivariant
{
let partial = PartialSpecialization::Single {
bound_typevar,
ty: bound_lower,
};
let new_lower = constrained_lower.apply_type_mapping(
db,
&TypeMapping::PartialSpecialization(partial),
TypeContext::default(),
);
add_constraint(ConstrainedTypeVar::new(
db,
constrained_typevar,
new_lower,
constrained_upper,
));
}
}
// TODO
TypeVarVariance::Contravariant | TypeVarVariance::Invariant => {}
}
// (Covariant[BU] ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (Covariant[B] ≤ C ≤ CU)
if let Type::TypeVar(bound_upper_typevar) = bound_upper
&& !bound_upper_typevar.is_same_typevar_as(db, constrained_typevar)
{
if constrained_lower.variance_of(db, bound_upper_typevar) == TypeVarVariance::Covariant
{
let partial = PartialSpecialization::Single {
bound_typevar: bound_upper_typevar,
ty: Type::TypeVar(bound_typevar),
};
let new_lower = constrained_lower.apply_type_mapping(
db,
&TypeMapping::PartialSpecialization(partial),
TypeContext::default(),
);
add_constraint(ConstrainedTypeVar::new(
db,
constrained_typevar,
new_lower,
constrained_upper,
));
}
}
match constrained_upper.variance_of(db, bound_typevar) {
// B does not appear in CU, or if it does, it appears bivariantly. The constraints of B
// do not affect the valid specializations of C.
TypeVarVariance::Bivariant => {}
// (CL ≤ C ≤ Covariant[B]) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ Covariant[BU])
TypeVarVariance::Covariant => {
// Only substitute to create a new sequent if the substitution is interesting, and
// doesn't recursively contain the typevar we are substituting for.
if !bound_upper.is_never()
&& !bound_upper.is_object()
&& bound_upper.variance_of(db, bound_typevar) == TypeVarVariance::Bivariant
{
let partial = PartialSpecialization::Single {
bound_typevar,
ty: bound_upper,
};
let new_upper = constrained_upper.apply_type_mapping(
db,
&TypeMapping::PartialSpecialization(partial),
TypeContext::default(),
);
add_constraint(ConstrainedTypeVar::new(
db,
constrained_typevar,
constrained_lower,
new_upper,
));
}
}
// TODO
TypeVarVariance::Contravariant | TypeVarVariance::Invariant => {}
}
// (CL ≤ C ≤ Covariant[BL]) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ Covariant[B])
if let Type::TypeVar(bound_lower_typevar) = bound_lower
&& !bound_lower_typevar.is_same_typevar_as(db, constrained_typevar)
{
if constrained_upper.variance_of(db, bound_lower_typevar) == TypeVarVariance::Covariant
{
let partial = PartialSpecialization::Single {
bound_typevar: bound_lower_typevar,
ty: Type::TypeVar(bound_typevar),
};
let new_upper = constrained_upper.apply_type_mapping(
db,
&TypeMapping::PartialSpecialization(partial),
TypeContext::default(),
);
add_constraint(ConstrainedTypeVar::new(
db,
constrained_typevar,
constrained_lower,
new_upper,
));
}
}
}
fn add_mutual_sequents_for_same_typevars(
&mut self,
db: &'db dyn Db,

View File

@@ -838,9 +838,15 @@ impl<'db> FmtDetailed<'db> for DisplayRepresentation<'db> {
KnownBoundMethodType::ConstraintSetNever => {
return f.write_str("bound method `ConstraintSet.never`");
}
KnownBoundMethodType::ConstraintSetExists(_) => {
return f.write_str("bound method `ConstraintSet.exists`");
}
KnownBoundMethodType::ConstraintSetImpliesSubtypeOf(_) => {
return f.write_str("bound method `ConstraintSet.implies_subtype_of`");
}
KnownBoundMethodType::ConstraintSetRetainOne(_) => {
return f.write_str("bound method `ConstraintSet.retain_one`");
}
KnownBoundMethodType::ConstraintSetSatisfies(_) => {
return f.write_str("bound method `ConstraintSet.satisfies`");
}

View File

@@ -568,7 +568,7 @@ impl<'db> GenericContext<'db> {
loop {
let mut any_changed = false;
for i in 0..len {
let partial = PartialSpecialization {
let partial = PartialSpecialization::FromGenericContext {
generic_context: self,
types: &types,
// Don't recursively substitute type[i] in itself. Ideally, we could instead
@@ -646,7 +646,7 @@ impl<'db> GenericContext<'db> {
// Typevars are only allowed to refer to _earlier_ typevars in their defaults. (This is
// statically enforced for PEP-695 contexts, and is explicitly called out as a
// requirement for legacy contexts.)
let partial = PartialSpecialization {
let partial = PartialSpecialization::FromGenericContext {
generic_context: self,
types: &expanded[0..idx],
skip: None,
@@ -1452,12 +1452,18 @@ impl<'db> Specialization<'db> {
/// You will usually use [`Specialization`] instead of this type. This type is used when we need to
/// substitute types for type variables before we have fully constructed a [`Specialization`].
#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)]
pub struct PartialSpecialization<'a, 'db> {
generic_context: GenericContext<'db>,
types: &'a [Type<'db>],
/// An optional typevar to _not_ substitute when applying the specialization. We use this to
/// avoid recursively substituting a type inside of itself.
skip: Option<usize>,
pub enum PartialSpecialization<'a, 'db> {
FromGenericContext {
generic_context: GenericContext<'db>,
types: &'a [Type<'db>],
/// An optional typevar to _not_ substitute when applying the specialization. We use this to
/// avoid recursively substituting a type inside of itself.
skip: Option<usize>,
},
Single {
bound_typevar: BoundTypeVarInstance<'db>,
ty: Type<'db>,
},
}
impl<'db> PartialSpecialization<'_, 'db> {
@@ -1466,16 +1472,30 @@ impl<'db> PartialSpecialization<'_, 'db> {
pub(crate) fn get(
&self,
db: &'db dyn Db,
bound_typevar: BoundTypeVarInstance<'db>,
needle_bound_typevar: BoundTypeVarInstance<'db>,
) -> Option<Type<'db>> {
let index = self
.generic_context
.variables_inner(db)
.get_index_of(&bound_typevar.identity(db))?;
if self.skip.is_some_and(|skip| skip == index) {
return Some(Type::Never);
match self {
PartialSpecialization::FromGenericContext {
generic_context,
types,
skip,
} => {
let index = generic_context
.variables_inner(db)
.get_index_of(&needle_bound_typevar.identity(db))?;
if skip.is_some_and(|skip| skip == index) {
return Some(Type::Never);
}
types.get(index).copied()
}
PartialSpecialization::Single { bound_typevar, ty } => {
if bound_typevar.is_same_typevar_as(db, needle_bound_typevar) {
Some(*ty)
} else {
None
}
}
}
self.types.get(index).copied()
}
}
@@ -1540,7 +1560,7 @@ impl<'db> SpecializationBuilder<'db> {
bound_typevar: BoundTypeVarInstance<'db>,
ty: Type<'db>,
variance: TypeVarVariance,
mut f: impl FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
) {
let identity = bound_typevar.identity(self.db);
let Some(ty) = f((identity, variance, ty)) else {
@@ -1582,7 +1602,7 @@ impl<'db> SpecializationBuilder<'db> {
&mut self,
formal: Type<'db>,
constraints: ConstraintSet<'db>,
mut f: impl FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
) {
#[derive(Default)]
struct Bounds<'db> {
@@ -1626,16 +1646,20 @@ impl<'db> SpecializationBuilder<'db> {
}
for (bound_typevar, bounds) in mappings.drain() {
let variance = formal.variance_of(self.db, bound_typevar);
let upper = IntersectionType::from_elements(self.db, bounds.upper);
if !upper.is_object() {
self.add_type_mapping(bound_typevar, upper, variance, &mut f);
let try_upper = || {
let upper = IntersectionType::from_elements(self.db, bounds.upper);
(!upper.is_object()).then_some(upper)
};
let try_lower = || {
let lower = UnionType::from_elements(self.db, bounds.lower);
(!lower.is_never()).then_some(lower)
};
let Some(mapped_type) = try_upper().or_else(try_lower) else {
continue;
}
let lower = UnionType::from_elements(self.db, bounds.lower);
if !lower.is_never() {
self.add_type_mapping(bound_typevar, lower, variance, &mut f);
}
};
let variance = formal.variance_of(self.db, bound_typevar);
self.add_type_mapping(bound_typevar, mapped_type, variance, f);
}
}
}
@@ -1667,7 +1691,7 @@ impl<'db> SpecializationBuilder<'db> {
formal: Type<'db>,
actual: Type<'db>,
polarity: TypeVarVariance,
mut f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
) -> Result<(), SpecializationError<'db>> {
// TODO: Eventually, the builder will maintain a constraint set, instead of a hash-map of
// type mappings, to represent the specialization that we are building up. At that point,
@@ -1779,7 +1803,7 @@ impl<'db> SpecializationBuilder<'db> {
let mut first_error = None;
let mut found_matching_element = false;
for formal_element in union_formal.elements(self.db) {
let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f);
let result = self.infer_map_impl(*formal_element, actual, polarity, f);
if let Err(err) = result {
first_error.get_or_insert(err);
} else {
@@ -1882,7 +1906,7 @@ impl<'db> SpecializationBuilder<'db> {
formal_tuple.all_elements().zip(actual_tuple.all_elements())
{
let variance = TypeVarVariance::Covariant.compose(polarity);
self.infer_map_impl(*formal_element, *actual_element, variance, &mut f)?;
self.infer_map_impl(*formal_element, *actual_element, variance, f)?;
}
return Ok(());
}
@@ -1925,7 +1949,7 @@ impl<'db> SpecializationBuilder<'db> {
base_specialization
) {
let variance = typevar.variance_with_polarity(self.db, polarity);
self.infer_map_impl(*formal_ty, *base_ty, variance, &mut f)?;
self.infer_map_impl(*formal_ty, *base_ty, variance, f)?;
}
return Ok(());
}
@@ -1949,7 +1973,7 @@ impl<'db> SpecializationBuilder<'db> {
formal_callable,
self.inferable,
);
self.add_type_mappings_from_constraint_set(formal, when, &mut f);
self.add_type_mappings_from_constraint_set(formal, when, f);
} else {
for actual_signature in &actual_callable.signatures(self.db).overloads {
let when = actual_signature
@@ -1958,7 +1982,7 @@ impl<'db> SpecializationBuilder<'db> {
formal_callable,
self.inferable,
);
self.add_type_mappings_from_constraint_set(formal, when, &mut f);
self.add_type_mappings_from_constraint_set(formal, when, f);
}
}
}

View File

@@ -59,6 +59,14 @@ class ConstraintSet:
def never() -> Self:
"""Returns a constraint set that is never satisfied"""
def exists(self, *typevars: Any) -> Self:
"""
Returns a new constraint set that is the _existential abstraction_ of
`self` for a set of typevars. The result will return true whenever
`self` returns true for _any_ assignment of those typevars. The result
will not contain any constraints that mention those typevars.
"""
def implies_subtype_of(self, ty: Any, of: Any) -> Self:
"""
Returns a constraint set that is satisfied when `ty` is a `subtype`_ of
@@ -67,6 +75,12 @@ class ConstraintSet:
.. _subtype: https://typing.python.org/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence
"""
def retain_one(self, typevar: Any) -> Self:
"""
Quantifies over this constraint set so that it only contains constraints
that mention `typevar`. All other typevars are quantified away.
"""
def satisfies(self, other: Self) -> Self:
"""
Returns whether this constraint set satisfies another — that is, whether