Files
ruff/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md
David Peter dc66019fbc [ty] Expansion of enums into unions of literals (#19382)
## Summary

Implement expansion of enums into unions of enum literals (and the
reverse operation). For the enum below, this allows us to understand
that `Color = Literal[Color.RED, Color.GREEN, Color.BLUE]`, or that
`Color & ~Literal[Color.RED] = Literal[Color.GREEN, Color.BLUE]`. This
helps in exhaustiveness checking, which is why we see some removed
`assert_never` false positives. And since exhaustiveness checking also
helps with understanding terminal control flow, we also see a few
removed `invalid-return-type` and `possibly-unresolved-reference` false
positives. This PR also adds expansion of enums in overload resolution
and type narrowing constructs.

```py
from enum import Enum
from typing_extensions import Literal, assert_never
from ty_extensions import Intersection, Not, static_assert, is_equivalent_to

class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3

type Red = Literal[Color.RED]
type Green = Literal[Color.GREEN]
type Blue = Literal[Color.BLUE]

static_assert(is_equivalent_to(Red | Green | Blue, Color))
static_assert(is_equivalent_to(Intersection[Color, Not[Red]], Green | Blue))


def color_name(color: Color) -> str:  # no error here (we detect that this can not implicitly return None)
    if color is Color.RED:
        return "Red"
    elif color is Color.GREEN:
        return "Green"
    elif color is Color.BLUE:
        return "Blue"
    else:
        assert_never(color)  # no error here
```

## Performance

I avoided an initial regression here for large enums, but the
`UnionBuilder` and `IntersectionBuilder` parts can certainly still be
optimized. We might want to use the same technique that we also use for
unions of other literals. I didn't see any problems in our benchmarks so
far, so this is not included yet.

## Test Plan

Many new Markdown tests
2025-07-21 19:37:55 +02:00

4.9 KiB

Narrowing for != and == conditionals

x != None

def _(flag: bool):
    x = None if flag else 1

    if x != None:
        reveal_type(x)  # revealed: Literal[1]
    else:
        reveal_type(x)  # revealed: None

!= for other singleton types

Bool

def _(x: bool):
    if x != False:
        reveal_type(x)  # revealed: Literal[True]
    else:
        reveal_type(x)  # revealed: Literal[False]

def _(x: bool):
    if x == False:
        reveal_type(x)  # revealed: Literal[False]
    else:
        reveal_type(x)  # revealed: Literal[True]

Enums

from enum import Enum

class Answer(Enum):
    NO = 0
    YES = 1

def _(answer: Answer):
    if answer != Answer.NO:
        reveal_type(answer)  # revealed: Literal[Answer.YES]
    else:
        reveal_type(answer)  # revealed: Literal[Answer.NO]

def _(answer: Answer):
    if answer == Answer.NO:
        reveal_type(answer)  # revealed: Literal[Answer.NO]
    else:
        reveal_type(answer)  # revealed: Literal[Answer.YES]

class Single(Enum):
    VALUE = 1

def _(x: Single | int):
    if x != Single.VALUE:
        reveal_type(x)  # revealed: int
    else:
        # `int` is not eliminated here because there could be subclasses of `int` with custom `__eq__`/`__ne__` methods
        reveal_type(x)  # revealed: Single | int

def _(x: Single | int):
    if x == Single.VALUE:
        reveal_type(x)  # revealed: Single | int
    else:
        reveal_type(x)  # revealed: int

This narrowing behavior is only safe if the enum has no custom __eq__/__ne__ method:

from enum import Enum

class AmbiguousEnum(Enum):
    NO = 0
    YES = 1

    def __ne__(self, other) -> bool:
        return True

def _(answer: AmbiguousEnum):
    if answer != AmbiguousEnum.NO:
        reveal_type(answer)  # revealed: AmbiguousEnum
    else:
        reveal_type(answer)  # revealed: AmbiguousEnum

Similar if that method is inherited from a base class:

from enum import Enum

class Mixin:
    def __eq__(self, other) -> bool:
        return True

class AmbiguousEnum(Mixin, Enum):
    NO = 0
    YES = 1

def _(answer: AmbiguousEnum):
    if answer == AmbiguousEnum.NO:
        reveal_type(answer)  # revealed: AmbiguousEnum
    else:
        reveal_type(answer)  # revealed: AmbiguousEnum

x != y where y is of literal type

def _(flag: bool):
    x = 1 if flag else 2

    if x != 1:
        reveal_type(x)  # revealed: Literal[2]

x != y where y is a single-valued type

def _(flag: bool):
    class A: ...
    class B: ...
    C = A if flag else B

    if C != A:
        reveal_type(C)  # revealed: <class 'B'>
    else:
        reveal_type(C)  # revealed: <class 'A'>

x != y where y has multiple single-valued options

def _(flag1: bool, flag2: bool):
    x = 1 if flag1 else 2
    y = 2 if flag2 else 3

    if x != y:
        reveal_type(x)  # revealed: Literal[1, 2]
    else:
        reveal_type(x)  # revealed: Literal[2]

!= for non-single-valued types

Only single-valued types should narrow the type:

def _(flag: bool, a: int, y: int):
    x = a if flag else None

    if x != y:
        reveal_type(x)  # revealed: int | None

Mix of single-valued and non-single-valued types

def _(flag1: bool, flag2: bool, a: int):
    x = 1 if flag1 else 2
    y = 2 if flag2 else a

    if x != y:
        reveal_type(x)  # revealed: Literal[1, 2]
    else:
        reveal_type(x)  # revealed: Literal[1, 2]

Assignment expressions

from typing import Literal

def f() -> Literal[1, 2, 3]:
    return 1

if (x := f()) != 1:
    reveal_type(x)  # revealed: Literal[2, 3]
else:
    reveal_type(x)  # revealed: Literal[1]

Union with Any

from typing import Any

def _(x: Any | None, y: Any | None):
    if x != 1:
        reveal_type(x)  # revealed: (Any & ~Literal[1]) | None
    if y == 1:
        reveal_type(y)  # revealed: Any & ~None

Booleans and integers

from typing import Literal

def _(b: bool, i: Literal[1, 2]):
    if b == 1:
        reveal_type(b)  # revealed: Literal[True]
    else:
        reveal_type(b)  # revealed: Literal[False]

    if b == 6:
        reveal_type(b)  # revealed: Never
    else:
        reveal_type(b)  # revealed: bool

    if b == 0:
        reveal_type(b)  # revealed: Literal[False]
    else:
        reveal_type(b)  # revealed: Literal[True]

    if i == True:
        reveal_type(i)  # revealed: Literal[1]
    else:
        reveal_type(i)  # revealed: Literal[2]

Narrowing LiteralString in union

from typing_extensions import Literal, LiteralString, Any

def _(s: LiteralString | None, t: LiteralString | Any):
    if s == "foo":
        reveal_type(s)  # revealed: Literal["foo"]

    if s == 1:
        reveal_type(s)  # revealed: Never

    if t == "foo":
        # TODO could be `Literal["foo"] | Any`
        reveal_type(t)  # revealed: LiteralString | Any