Compare commits

..

2 Commits

4 changed files with 63 additions and 181 deletions

View File

@@ -2124,26 +2124,20 @@ shows up in a subset of the union members) is present, but that isn't generally
field, it could be *assigned to* with another `TypedDict` that does:
```py
from typing_extensions import Literal
class Foo(TypedDict):
foo: int
class Bar(TypedDict):
bar: int
def disappointment(u: Foo | Bar, v: Literal["foo"]):
def disappointment(u: Foo | Bar):
if "foo" in u:
# We can't narrow the union here...
reveal_type(u) # revealed: Foo | Bar
else:
# ...(even though we *can* narrow it here)...
reveal_type(u) # revealed: Bar
if v in u:
# TODO: This should narrow to `Bar`, because "foo" is required in `Foo`.
reveal_type(u) # revealed: Foo | Bar
else:
reveal_type(u) # revealed: Bar
# ...because `u` could turn out to be one of these.
class FooBar(TypedDict):
@@ -2154,39 +2148,6 @@ static_assert(is_assignable_to(FooBar, Foo))
static_assert(is_assignable_to(FooBar, Bar))
```
`not in` works in the opposite way to `in`: we can narrow in the positive case, but we cannot narrow
in the negative case. The following snippet also tests our narrowing behaviour for intersections
that contain `TypedDict`s, and unions that contain intersections that contain `TypedDict`s:
```py
from typing_extensions import Literal, Any
from ty_extensions import Intersection, is_assignable_to, static_assert
def _(t: Bar, u: Foo | Intersection[Bar, Any], v: Intersection[Bar, Any], w: Literal["bar"]):
reveal_type(u) # revealed: Foo | (Bar & Any)
reveal_type(v) # revealed: Bar & Any
if "bar" not in t:
reveal_type(t) # revealed: Never
else:
reveal_type(t) # revealed: Bar
if "bar" not in u:
reveal_type(u) # revealed: Foo
else:
reveal_type(u) # revealed: Foo | (Bar & Any)
if "bar" not in v:
reveal_type(v) # revealed: Never
else:
reveal_type(v) # revealed: Bar & Any
if w not in u:
reveal_type(u) # revealed: Foo
else:
reveal_type(u) # revealed: Foo | (Bar & Any)
```
TODO: The narrowing that we didn't do above will become possible when we add support for
`closed=True`. This is [one of the main use cases][closed] that motivated the `closed` feature.

View File

@@ -14107,19 +14107,21 @@ impl<'db> UnionType<'db> {
self.try_map(db, |element| element.to_instance(db))
}
pub(crate) fn filter(
self,
db: &'db dyn Db,
mut f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> {
self.elements(db)
.iter()
.filter(|ty| f(ty))
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(*element)
})
.recursively_defined(self.recursively_defined(db))
.build()
pub(crate) fn filter(self, db: &'db dyn Db, f: impl FnMut(&Type<'db>) -> bool) -> Type<'db> {
let current = self.elements(db);
let new: Vec<Type<'db>> = current.iter().copied().filter(f).collect();
match new.len() {
0 => Type::Never,
1 => new[0],
len if len == current.len() => Type::Union(self),
_ => new
.iter()
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(*element)
})
.recursively_defined(self.recursively_defined(db))
.build(),
}
}
pub(crate) fn map_with_boundness(

View File

@@ -715,7 +715,6 @@ impl<'db> ClassType<'db> {
/// Return the [`DisjointBase`] that appears first in the MRO of this class.
///
/// Returns `None` if this class does not have any disjoint bases in its MRO.
#[salsa::tracked(heap_size=ruff_memory_usage::heap_size)]
pub(super) fn nearest_disjoint_base(self, db: &'db dyn Db) -> Option<DisjointBase<'db>> {
self.iter_mro(db)
.filter_map(ClassBase::into_class)
@@ -4234,7 +4233,7 @@ impl InheritanceCycle {
/// `TypeError`s resulting from class definitions.
///
/// [PEP 800]: https://peps.python.org/pep-0800/
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, get_size2::GetSize, salsa::Update)]
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
pub(super) struct DisjointBase<'db> {
pub(super) class: ClassLiteral<'db>,
pub(super) kind: DisjointBaseKind,
@@ -4271,7 +4270,7 @@ impl<'db> DisjointBase<'db> {
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize, salsa::Update)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub(super) enum DisjointBaseKind {
/// We know the class is a disjoint base because it's either hardcoded in ty
/// or has the `@disjoint_base` decorator.

View File

@@ -12,7 +12,7 @@ use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::function::KnownFunction;
use crate::types::infer::{ExpressionInference, infer_same_file_expression_type};
use crate::types::typed_dict::{
SynthesizedTypedDictType, TypedDictField, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
SynthesizedTypedDictType, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
};
use crate::types::{
CallableType, ClassLiteral, ClassType, IntersectionBuilder, IntersectionType, KnownClass,
@@ -1027,31 +1027,23 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&& rhs_ty.is_singleton(self.db)
{
let is_positive_check = is_positive == (ops[0] == ast::CmpOp::Is);
let filtered: Vec<_> = union
.elements(self.db)
.iter()
.filter(|elem| {
elem.as_nominal_instance()
.and_then(|inst| inst.tuple_spec(self.db))
.and_then(|spec| spec.py_index(self.db, index).ok())
.is_none_or(|el_ty| {
if is_positive_check {
// `is X` context: keep tuples where element could be X
!el_ty.is_disjoint_from(self.db, rhs_ty)
} else {
// `is not X` context: keep tuples where element is not always X
!el_ty.is_subtype_of(self.db, rhs_ty)
}
})
})
.copied()
.collect();
if filtered.len() < union.elements(self.db).len() {
let filtered = union.filter(self.db, |elem| {
elem.as_nominal_instance()
.and_then(|inst| inst.tuple_spec(self.db))
.and_then(|spec| spec.py_index(self.db, index).ok())
.is_none_or(|el_ty| {
if is_positive_check {
// `is X` context: keep tuples where element could be X
!el_ty.is_disjoint_from(self.db, rhs_ty)
} else {
// `is not X` context: keep tuples where element is not always X
!el_ty.is_subtype_of(self.db, rhs_ty)
}
})
});
if filtered != Type::Union(union) {
let place = self.expect_place(&subscript_place_expr);
constraints.insert(
place,
NarrowingConstraint::regular(UnionType::from_elements(self.db, filtered)),
);
constraints.insert(place, NarrowingConstraint::typeguard(filtered));
}
}
@@ -1099,75 +1091,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
}
// Narrow unions and intersections of `TypedDict` in cases where required keys are
// excluded:
//
// class Foo(TypedDict):
// foo: int
// class Bar(TypedDict):
// bar: int
//
// def _(u: Foo | Bar):
// if "foo" not in u:
// reveal_type(u) # revealed: Bar
if matches!(&**ops, [ast::CmpOp::In | ast::CmpOp::NotIn])
&& let Type::StringLiteral(key) = inference.expression_type(&**left)
&& let Some(rhs_place_expr) = place_expr(&comparators[0])
&& let rhs_type = inference.expression_type(&comparators[0])
&& is_typeddict_or_union_with_typeddicts(self.db, rhs_type)
{
let is_negative_check = is_positive == (ops[0] == ast::CmpOp::NotIn);
if is_negative_check {
let requires_key = |td: TypedDictType<'db>| -> bool {
td.items(self.db)
.get(key.value(self.db))
.is_some_and(TypedDictField::is_required)
};
let narrowed = match rhs_type {
Type::TypedDict(td) => {
if requires_key(td) {
Type::Never
} else {
rhs_type
}
}
Type::Intersection(intersection) => {
if intersection
.positive(self.db)
.iter()
.copied()
.filter_map(Type::as_typed_dict)
.any(requires_key)
{
Type::Never
} else {
rhs_type
}
}
Type::Union(union) => {
// remove all members of the union that would require the key
union.filter(self.db, |ty| match ty {
Type::TypedDict(td) => !requires_key(*td),
Type::Intersection(intersection) => !intersection
.positive(self.db)
.iter()
.copied()
.filter_map(Type::as_typed_dict)
.any(requires_key),
_ => true,
})
}
_ => rhs_type,
};
if narrowed != rhs_type {
let place = self.expect_place(&rhs_place_expr);
constraints.insert(place, NarrowingConstraint::typeguard(narrowed));
}
}
}
let mut last_rhs_ty: Option<Type> = None;
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
@@ -1708,33 +1631,25 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
// Filter the union based on whether each tuple element at the index could match the rhs.
let filtered: Vec<_> = union
.elements(self.db)
.iter()
.filter(|elem| {
elem.as_nominal_instance()
.and_then(|inst| inst.tuple_spec(self.db))
.and_then(|spec| spec.py_index(self.db, index).ok())
.is_none_or(|el_ty| {
if constrain_with_equality {
// Keep tuples where element could be equal to rhs.
!el_ty.is_disjoint_from(self.db, rhs_type)
} else {
// Keep tuples where element is not always equal to rhs.
!el_ty.is_subtype_of(self.db, rhs_type)
}
})
})
.copied()
.collect();
let filtered = union.filter(self.db, |elem| {
elem.as_nominal_instance()
.and_then(|inst| inst.tuple_spec(self.db))
.and_then(|spec| spec.py_index(self.db, index).ok())
.is_none_or(|el_ty| {
if constrain_with_equality {
// Keep tuples where element could be equal to rhs.
!el_ty.is_disjoint_from(self.db, rhs_type)
} else {
// Keep tuples where element is not always equal to rhs.
!el_ty.is_subtype_of(self.db, rhs_type)
}
})
});
// Only create a constraint if we actually narrowed something.
if filtered.len() < union.elements(self.db).len() {
if filtered != rhs_type {
let place = self.expect_place(&subscript_place_expr);
Some((
place,
NarrowingConstraint::regular(UnionType::from_elements(self.db, filtered)),
))
Some((place, NarrowingConstraint::typeguard(filtered)))
} else {
None
}
@@ -1746,13 +1661,18 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
match ty {
Type::TypedDict(_) => true,
Type::Intersection(intersection) => {
intersection.positive(db).iter().any(Type::is_typed_dict)
Type::Union(union) => {
union
.elements(db)
.iter()
.any(|union_member_ty| match union_member_ty {
Type::TypedDict(_) => true,
Type::Intersection(intersection) => {
intersection.positive(db).iter().any(Type::is_typed_dict)
}
_ => false,
})
}
Type::Union(union) => union
.elements(db)
.iter()
.any(|union_member_ty| is_typeddict_or_union_with_typeddicts(db, *union_member_ty)),
_ => false,
}
}