Compare commits

...

1 Commits

Author SHA1 Message Date
Charlie Marsh
1d4dc4423a [ty] Allow enum narrowing for classes that don't override __eq__ 2026-01-06 08:27:38 -05:00
2 changed files with 120 additions and 11 deletions

View File

@@ -2050,6 +2050,59 @@ def _(u: Foo | Bar | Baz | Bing):
reveal_type(u) # revealed: Bing
```
Enum literals are also supported as tags, as long as the enum class does not override `__eq__`:
```py
from enum import Enum
class Tag(Enum):
A = 1
B = 2
C = 3
class WithEnumTagA(TypedDict):
tag: Literal[Tag.A]
class WithEnumTagB(TypedDict):
tag: Literal[Tag.B]
class WithEnumTagC(TypedDict):
tag: Literal[Tag.C]
def _(u: WithEnumTagA | WithEnumTagB | WithEnumTagC):
if u["tag"] == Tag.A:
reveal_type(u) # revealed: WithEnumTagA
elif u["tag"] == Tag.B:
reveal_type(u) # revealed: WithEnumTagB
else:
reveal_type(u) # revealed: WithEnumTagC
```
However, if an enum overrides `__eq__`, we cannot safely narrow based on its values:
```py
class WackyTag(Enum):
X = 1
Y = 2
def __eq__(self, other):
return True
class WithWackyTagX(TypedDict):
tag: Literal[WackyTag.X]
class WithWackyTagY(TypedDict):
tag: Literal[WackyTag.Y]
def _(u: WithWackyTagX | WithWackyTagY):
if u["tag"] == WackyTag.X:
# We cannot narrow here because WackyTag overrides `__eq__`.
reveal_type(u) # revealed: WithWackyTagX | WithWackyTagY
else:
# We cannot narrow here either.
reveal_type(u) # revealed: WithWackyTagX | WithWackyTagY
```
We can descend into intersections to discover `TypedDict` types that need narrowing:
```py
@@ -2226,6 +2279,61 @@ def match_statements(u: Foo | Bar | Baz | Bing):
reveal_type(u) # revealed: Bing
```
Enum literals are supported as tags in match statements:
```py
from enum import Enum
class Tag(Enum):
A = 1
B = 2
C = 3
class WithEnumTagA(TypedDict):
tag: Literal[Tag.A]
class WithEnumTagB(TypedDict):
tag: Literal[Tag.B]
class WithEnumTagC(TypedDict):
tag: Literal[Tag.C]
def match_enum_tags(u: WithEnumTagA | WithEnumTagB | WithEnumTagC):
match u["tag"]:
case Tag.A:
reveal_type(u) # revealed: WithEnumTagA
case Tag.B:
reveal_type(u) # revealed: WithEnumTagB
case _:
reveal_type(u) # revealed: WithEnumTagC
```
If an enum overrides `__eq__`, we cannot safely narrow based on its values in match statements:
```py
class WackyTag(Enum):
X = 1
Y = 2
def __eq__(self, other):
return True
class WithWackyTagX(TypedDict):
tag: Literal[WackyTag.X]
class WithWackyTagY(TypedDict):
tag: Literal[WackyTag.Y]
def match_wacky_enum_tags(u: WithWackyTagX | WithWackyTagY):
match u["tag"]:
case WackyTag.X:
# We cannot narrow here because WackyTag overrides `__eq__`.
reveal_type(u) # revealed: WithWackyTagX | WithWackyTagY
case _:
# We cannot narrow here either.
reveal_type(u) # revealed: WithWackyTagX | WithWackyTagY
```
We can also narrow a single `TypedDict` type to `Never`:
```py

View File

@@ -1606,7 +1606,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let Type::StringLiteral(key_literal) = subscript_key_type else {
return None;
};
if !is_supported_tag_literal(rhs_type) {
if !is_supported_tag_literal(self.db, rhs_type) {
return None;
}
@@ -1684,7 +1684,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let index = i32::try_from(index).ok()?;
// The comparison value must be a supported literal type.
if !is_supported_tag_literal(rhs_type) {
if !is_supported_tag_literal(self.db, rhs_type) {
return None;
}
@@ -1754,13 +1754,14 @@ fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) ->
}
}
fn is_supported_tag_literal(ty: Type) -> bool {
matches!(
ty,
// TODO: We'd like to support `EnumLiteral` also, but we have to be careful with types like
// `IntEnum` and `StrEnum` that have custom `__eq__` methods.
Type::StringLiteral(_) | Type::BytesLiteral(_) | Type::IntLiteral(_)
)
fn is_supported_tag_literal<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
match ty {
Type::StringLiteral(_) | Type::BytesLiteral(_) | Type::IntLiteral(_) => true,
// Enum literals are supported as long as the enum class does not override
// `__eq__` or `__ne__` with a custom implementation.
Type::EnumLiteral(_) => !ty.overrides_equality(db),
_ => false,
}
}
// See the comment above the call to this function.
@@ -1774,7 +1775,7 @@ fn all_matching_typeddict_fields_have_literal_types<'db>(
typeddict
.items(db)
.get(field_name)
.is_none_or(|field| is_supported_tag_literal(field.declared_ty))
.is_none_or(|field| is_supported_tag_literal(db, field.declared_ty))
};
match ty {
@@ -1832,6 +1833,6 @@ fn all_matching_tuple_elements_have_literal_types<'db>(
elem.as_nominal_instance()
.and_then(|inst| inst.tuple_spec(db))
.and_then(|spec| spec.py_index(db, index).ok())
.is_none_or(is_supported_tag_literal)
.is_none_or(|ty| is_supported_tag_literal(db, ty))
})
}