[ty] narrow tagged unions of TypedDict (#22104)
Identify and narrow cases like this:
```py
class Foo(TypedDict):
tag: Literal["foo"]
class Bar(TypedDict):
tag: Literal["bar"]
def _(union: Foo | Bar):
if union["tag"] == "foo":
reveal_type(union) # Foo
```
Fixes part of https://github.com/astral-sh/ty/issues/1479.
---------
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
@@ -10,6 +10,9 @@ use crate::semantic_index::scope::ScopeId;
|
||||
use crate::types::enums::{enum_member_literals, enum_metadata};
|
||||
use crate::types::function::KnownFunction;
|
||||
use crate::types::infer::infer_same_file_expression_type;
|
||||
use crate::types::typed_dict::{
|
||||
SynthesizedTypedDictType, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
|
||||
};
|
||||
use crate::types::{
|
||||
CallableType, ClassLiteral, ClassType, IntersectionBuilder, KnownClass, KnownInstanceType,
|
||||
SpecialFormType, SubclassOfInner, SubclassOfType, Truthiness, Type, TypeContext,
|
||||
@@ -17,6 +20,7 @@ use crate::types::{
|
||||
};
|
||||
|
||||
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
|
||||
use ruff_python_ast::name::Name;
|
||||
use ruff_python_stdlib::identifiers::is_identifier;
|
||||
|
||||
use itertools::Itertools;
|
||||
@@ -877,6 +881,72 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
|
||||
// Narrow tagged unions of `TypedDict`s with `Literal` keys, for example:
|
||||
//
|
||||
// class Foo(TypedDict):
|
||||
// tag: Literal["foo"]
|
||||
// class Bar(TypedDict):
|
||||
// tag: Literal["bar"]
|
||||
// def _(union: Foo | Bar):
|
||||
// if union["tag"] == "foo":
|
||||
// reveal_type(union) # Foo
|
||||
//
|
||||
// Importantly, `my_typeddict_union["tag"]` isn't the place we're going to constraint.
|
||||
// Instead, we're going to constrain `my_typeddict_union` itself.
|
||||
if matches!(&**ops, [ast::CmpOp::Eq | ast::CmpOp::NotEq])
|
||||
&& let ast::Expr::Subscript(subscript) = &**left
|
||||
&& let lhs_value_type = inference.expression_type(&*subscript.value)
|
||||
// Checking for `TypedDict`s up front isn't strictly necessary, since the intersection
|
||||
// we're going to build is compatible with non-`TypedDict` types, but we don't want to
|
||||
// do the work to build it and intersect it (or for that matter, let the user see it)
|
||||
// in the common case where there are no `TypedDict`s.
|
||||
&& is_typeddict_or_union_with_typeddicts(self.db, lhs_value_type)
|
||||
&& let Some(subscript_place_expr) = place_expr(&subscript.value)
|
||||
&& let Type::StringLiteral(key_literal) = inference.expression_type(&*subscript.slice)
|
||||
&& let rhs_type = inference.expression_type(&comparators[0])
|
||||
&& is_supported_typeddict_tag_literal(rhs_type)
|
||||
{
|
||||
// If we have an equality constraint (either `==` on the `if` side, or `!=` on the
|
||||
// `else` side), we have to be careful. If all the matching fields in all the
|
||||
// `TypedDict`s here have literal types, then yes, equality is as good as a type check.
|
||||
// However, if any of them are e.g. `int` or `str` or some random class, then we can't
|
||||
// narrow their type at all, because subclasses of those types can implement `__eq__`
|
||||
// in any perverse way they like. On the other hand, if this is an *inequality*
|
||||
// constraint, then we can go ahead and assert "you can't be this exact literal type"
|
||||
// without worrying about what other types might be present.
|
||||
let constrain_with_equality = is_positive == (ops[0] == ast::CmpOp::Eq);
|
||||
if !constrain_with_equality
|
||||
|| all_matching_typeddict_fields_have_literal_types(
|
||||
self.db,
|
||||
lhs_value_type,
|
||||
key_literal.value(self.db),
|
||||
)
|
||||
{
|
||||
let field_name = Name::from(key_literal.value(self.db));
|
||||
let rhs_type = inference.expression_type(&comparators[0]);
|
||||
// To avoid excluding non-`TypedDict` types, our constraints are always expressed
|
||||
// as a negative intersection (i.e. "you're *not* this kind of `TypedDict`"). If
|
||||
// `constrain_with_equality` is true, the whole constraint is going to be a double
|
||||
// negative, i.e. "you're *not* a `TypedDict` *without* this literal field". As the
|
||||
// first step of building that, we negate the right hand side.
|
||||
let field_type = rhs_type.negate_if(self.db, constrain_with_equality);
|
||||
// Create the synthesized `TypedDict` with that (possibly negated) field. We don't
|
||||
// want to constrain the mutability or required-ness of the field, so the most
|
||||
// compatible form is not-required and read-only.
|
||||
let field = TypedDictFieldBuilder::new(field_type)
|
||||
.required(false)
|
||||
.read_only(true)
|
||||
.build();
|
||||
let schema = TypedDictSchema::from_iter([(field_name, field)]);
|
||||
let synthesized_typeddict =
|
||||
TypedDictType::Synthesized(SynthesizedTypedDictType::new(self.db, schema));
|
||||
// As mentioned above, the synthesized `TypedDict` is always negated.
|
||||
let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db);
|
||||
let place = self.expect_place(&subscript_place_expr);
|
||||
constraints.insert(place, intersection);
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_rhs_ty: Option<Type> = None;
|
||||
|
||||
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
|
||||
@@ -1212,3 +1282,71 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return true if the given type is a `TypedDict`, or if it's a union that includes at least one
|
||||
// `TypedDict` (even if other types are present).
|
||||
fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
|
||||
match ty {
|
||||
Type::TypedDict(_) => true,
|
||||
Type::Union(union) => {
|
||||
union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|union_member_ty| match union_member_ty {
|
||||
Type::TypedDict(_) => true,
|
||||
Type::Intersection(intersection) => {
|
||||
intersection.positive(db).iter().any(Type::is_typed_dict)
|
||||
}
|
||||
_ => false,
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_supported_typeddict_tag_literal(ty: Type) -> bool {
|
||||
matches!(
|
||||
ty,
|
||||
// TODO: We'd like to support `EnumLiteral` also, but we have to be careful with types like
|
||||
// `IntEnum` and `StrEnum` that have custom `__eq__` methods.
|
||||
Type::StringLiteral(_) | Type::BytesLiteral(_) | Type::IntLiteral(_)
|
||||
)
|
||||
}
|
||||
|
||||
// See the comment above the call to this function.
|
||||
fn all_matching_typeddict_fields_have_literal_types<'db>(
|
||||
db: &'db dyn Db,
|
||||
ty: Type<'db>,
|
||||
field_name: &str,
|
||||
) -> bool {
|
||||
let matching_field_is_literal = |typeddict: &TypedDictType<'db>| {
|
||||
// There's no matching field to check if `.get()` returns `None`.
|
||||
typeddict
|
||||
.items(db)
|
||||
.get(field_name)
|
||||
.is_none_or(|field| is_supported_typeddict_tag_literal(field.declared_ty))
|
||||
};
|
||||
|
||||
match ty {
|
||||
Type::TypedDict(td) => matching_field_is_literal(&td),
|
||||
Type::Union(union) => {
|
||||
union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.all(|union_member_ty| match union_member_ty {
|
||||
Type::TypedDict(td) => matching_field_is_literal(td),
|
||||
Type::Intersection(intersection) => {
|
||||
intersection
|
||||
.positive(db)
|
||||
.iter()
|
||||
.all(|intersection_member_ty| match intersection_member_ty {
|
||||
Type::TypedDict(td) => matching_field_is_literal(td),
|
||||
_ => true,
|
||||
})
|
||||
}
|
||||
_ => true,
|
||||
})
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user