## Summary

Historically we have avoided narrowing on `==` tests because in many cases it's unsound: an instance of a subclass can override `__eq__` and compare equal to who-knows-what. But there are a lot of types whose `__eq__` behavior we know (literals, unions of literals, and single-valued types such as `None`), and we can safely narrow those away based on equality comparisons. This PR implements equality narrowing in the cases where it is sound.

The most elegant way to do this (and the way that is most in line with our approach up until now) would be to introduce new `Type` variants `NeverEqualTo[...]` and `AlwaysEqualTo[...]`, implement all type relations for those variants, narrow by intersection, and let union and intersection simplification sort it all out. This is analogous to our existing handling for `AlwaysFalse` and `AlwaysTrue`. But I'm reluctant to add new `Type` variants for this, mostly because they could end up unsimplified in some types and make types even more complex. So let's try this approach instead, where we handle more of the narrowing logic as a special case.

## Test Plan

Updated and added tests.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
Co-authored-by: Carl Meyer <carl@oddbird.net>
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
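The soundness argument in the summary can be made concrete with a small, self-contained Python sketch. This is not ty's implementation (which is written in Rust); `EqualsEverything`, `Lit`, `Opaque`, and `narrow_eq` are hypothetical names invented for illustration, and the model deliberately ignores details such as expanding `bool` or enums into their members. It first shows why narrowing `x: int` to `Literal[2]` on `x == 2` would be unsound, then models the sound rule for unions: a literal member (single-valued, with known `__eq__`) is kept or dropped according to the branch, while any other member survives both branches.

```python
from dataclasses import dataclass
from typing import Tuple, Union

LiteralValue = Union[int, str, bytes, bool, None]


class EqualsEverything(int):
    """Hypothetical int subclass whose __eq__ claims equality with every object."""

    def __eq__(self, other: object) -> bool:
        return True

    # Defining __eq__ sets __hash__ to None; restore int's hash.
    __hash__ = int.__hash__


@dataclass(frozen=True)
class Lit:
    """Hypothetical single-valued union member with known __eq__, e.g. Literal[2] or None."""

    value: LiteralValue


@dataclass(frozen=True)
class Opaque:
    """Hypothetical non-literal union member such as `int`; its instances may
    belong to subclasses with arbitrary __eq__ behavior."""

    name: str


Member = Union[Lit, Opaque]


def narrow_eq(
    union: Tuple[Member, ...], value: LiteralValue, positive: bool
) -> Tuple[Member, ...]:
    """Narrow `union` in the branch where `x == value` is true (positive=True)
    or false (positive=False). Toy model only."""
    kept = []
    for member in union:
        if isinstance(member, Opaque):
            # Unknown __eq__: nothing can be concluded, so keep it in either branch.
            kept.append(member)
        elif (member.value == value) == positive:
            # Known __eq__: keep the literal only if it agrees with the branch.
            kept.append(member)
    return tuple(kept)


# Unsound case: a value of type `int` that is not 2 still satisfies `x == 2`.
x = EqualsEverything(5)
print(x == 2, int(x))  # prints: True 5

one_two_three = (Lit(1), Lit(2), Lit(3))
print(narrow_eq(one_two_three, 2, positive=True))   # prints: (Lit(value=2),)
print(narrow_eq(one_two_three, 2, positive=False))  # prints: (Lit(value=1), Lit(value=3))

int_or_none = (Opaque("int"), Lit(None))
print(narrow_eq(int_or_none, None, positive=True))   # prints: (Opaque(name='int'), Lit(value=None))
print(narrow_eq(int_or_none, None, positive=False))  # prints: (Opaque(name='int'),)
```

In the `int | None` case, the `!=` branch can drop `None` because `None == None` is always true, but the `==` branch cannot drop `int`, since an `int` subclass might compare equal to `None`. That asymmetry is why equality narrowing is only applied where the left-hand type's `__eq__` is known.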

# Narrowing with assert statements

## `assert` a value `is None` or `is not None`

```py
def _(x: str | None, y: str | None):
    assert x is not None
    reveal_type(x)  # revealed: str
    assert y is None
    reveal_type(y)  # revealed: None
```

## `assert` a value is truthy or falsy

```py
def _(x: bool, y: bool):
    assert x
    reveal_type(x)  # revealed: Literal[True]
    assert not y
    reveal_type(y)  # revealed: Literal[False]
```

## `assert` with `is` and `==` for literals

```py
from typing import Literal

def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
    assert x is 2
    reveal_type(x)  # revealed: Literal[2]
    assert y == 2
    reveal_type(y)  # revealed: Literal[2]
```

## `assert` with `isinstance`

```py
def _(x: int | str):
    assert isinstance(x, int)
    reveal_type(x)  # revealed: int
```

## `assert` a value `in` a tuple

```py
from typing import Literal

def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
    assert x in (1, 2)
    reveal_type(x)  # revealed: Literal[1, 2]
    assert y not in (1, 2)
    reveal_type(y)  # revealed: Literal[3]
```