Compare commits
2 Commits
zb/cache-n
...
alex/less-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e08e60c64b | ||
|
|
1463ce5820 |
@@ -2124,26 +2124,20 @@ shows up in a subset of the union members) is present, but that isn't generally
|
||||
field, it could be *assigned to* with another `TypedDict` that does:
|
||||
|
||||
```py
|
||||
from typing_extensions import Literal
|
||||
|
||||
class Foo(TypedDict):
|
||||
foo: int
|
||||
|
||||
class Bar(TypedDict):
|
||||
bar: int
|
||||
|
||||
def disappointment(u: Foo | Bar, v: Literal["foo"]):
|
||||
def disappointment(u: Foo | Bar):
|
||||
if "foo" in u:
|
||||
# We can't narrow the union here...
|
||||
reveal_type(u) # revealed: Foo | Bar
|
||||
else:
|
||||
# ...(even though we *can* narrow it here)...
|
||||
reveal_type(u) # revealed: Bar
|
||||
|
||||
if v in u:
|
||||
# TODO: This should narrow to `Bar`, because "foo" is required in `Foo`.
|
||||
reveal_type(u) # revealed: Foo | Bar
|
||||
else:
|
||||
reveal_type(u) # revealed: Bar
|
||||
|
||||
# ...because `u` could turn out to be one of these.
|
||||
class FooBar(TypedDict):
|
||||
@@ -2154,39 +2148,6 @@ static_assert(is_assignable_to(FooBar, Foo))
|
||||
static_assert(is_assignable_to(FooBar, Bar))
|
||||
```
|
||||
|
||||
`not in` works in the opposite way to `in`: we can narrow in the positive case, but we cannot narrow
|
||||
in the negative case. The following snippet also tests our narrowing behaviour for intersections
|
||||
that contain `TypedDict`s, and unions that contain intersections that contain `TypedDict`s:
|
||||
|
||||
```py
|
||||
from typing_extensions import Literal, Any
|
||||
from ty_extensions import Intersection, is_assignable_to, static_assert
|
||||
|
||||
def _(t: Bar, u: Foo | Intersection[Bar, Any], v: Intersection[Bar, Any], w: Literal["bar"]):
|
||||
reveal_type(u) # revealed: Foo | (Bar & Any)
|
||||
reveal_type(v) # revealed: Bar & Any
|
||||
|
||||
if "bar" not in t:
|
||||
reveal_type(t) # revealed: Never
|
||||
else:
|
||||
reveal_type(t) # revealed: Bar
|
||||
|
||||
if "bar" not in u:
|
||||
reveal_type(u) # revealed: Foo
|
||||
else:
|
||||
reveal_type(u) # revealed: Foo | (Bar & Any)
|
||||
|
||||
if "bar" not in v:
|
||||
reveal_type(v) # revealed: Never
|
||||
else:
|
||||
reveal_type(v) # revealed: Bar & Any
|
||||
|
||||
if w not in u:
|
||||
reveal_type(u) # revealed: Foo
|
||||
else:
|
||||
reveal_type(u) # revealed: Foo | (Bar & Any)
|
||||
```
|
||||
|
||||
TODO: The narrowing that we didn't do above will become possible when we add support for
|
||||
`closed=True`. This is [one of the main use cases][closed] that motivated the `closed` feature.
|
||||
|
||||
|
||||
@@ -14107,19 +14107,21 @@ impl<'db> UnionType<'db> {
|
||||
self.try_map(db, |element| element.to_instance(db))
|
||||
}
|
||||
|
||||
pub(crate) fn filter(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
mut f: impl FnMut(&Type<'db>) -> bool,
|
||||
) -> Type<'db> {
|
||||
self.elements(db)
|
||||
.iter()
|
||||
.filter(|ty| f(ty))
|
||||
.fold(UnionBuilder::new(db), |builder, element| {
|
||||
builder.add(*element)
|
||||
})
|
||||
.recursively_defined(self.recursively_defined(db))
|
||||
.build()
|
||||
pub(crate) fn filter(self, db: &'db dyn Db, f: impl FnMut(&Type<'db>) -> bool) -> Type<'db> {
|
||||
let current = self.elements(db);
|
||||
let new: Vec<Type<'db>> = current.iter().copied().filter(f).collect();
|
||||
match new.len() {
|
||||
0 => Type::Never,
|
||||
1 => new[0],
|
||||
len if len == current.len() => Type::Union(self),
|
||||
_ => new
|
||||
.iter()
|
||||
.fold(UnionBuilder::new(db), |builder, element| {
|
||||
builder.add(*element)
|
||||
})
|
||||
.recursively_defined(self.recursively_defined(db))
|
||||
.build(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn map_with_boundness(
|
||||
|
||||
@@ -715,7 +715,6 @@ impl<'db> ClassType<'db> {
|
||||
/// Return the [`DisjointBase`] that appears first in the MRO of this class.
|
||||
///
|
||||
/// Returns `None` if this class does not have any disjoint bases in its MRO.
|
||||
#[salsa::tracked(heap_size=ruff_memory_usage::heap_size)]
|
||||
pub(super) fn nearest_disjoint_base(self, db: &'db dyn Db) -> Option<DisjointBase<'db>> {
|
||||
self.iter_mro(db)
|
||||
.filter_map(ClassBase::into_class)
|
||||
@@ -4234,7 +4233,7 @@ impl InheritanceCycle {
|
||||
/// `TypeError`s resulting from class definitions.
|
||||
///
|
||||
/// [PEP 800]: https://peps.python.org/pep-0800/
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, get_size2::GetSize, salsa::Update)]
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)]
|
||||
pub(super) struct DisjointBase<'db> {
|
||||
pub(super) class: ClassLiteral<'db>,
|
||||
pub(super) kind: DisjointBaseKind,
|
||||
@@ -4271,7 +4270,7 @@ impl<'db> DisjointBase<'db> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize, salsa::Update)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub(super) enum DisjointBaseKind {
|
||||
/// We know the class is a disjoint base because it's either hardcoded in ty
|
||||
/// or has the `@disjoint_base` decorator.
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::types::enums::{enum_member_literals, enum_metadata};
|
||||
use crate::types::function::KnownFunction;
|
||||
use crate::types::infer::{ExpressionInference, infer_same_file_expression_type};
|
||||
use crate::types::typed_dict::{
|
||||
SynthesizedTypedDictType, TypedDictField, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
|
||||
SynthesizedTypedDictType, TypedDictFieldBuilder, TypedDictSchema, TypedDictType,
|
||||
};
|
||||
use crate::types::{
|
||||
CallableType, ClassLiteral, ClassType, IntersectionBuilder, IntersectionType, KnownClass,
|
||||
@@ -1027,31 +1027,23 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
&& rhs_ty.is_singleton(self.db)
|
||||
{
|
||||
let is_positive_check = is_positive == (ops[0] == ast::CmpOp::Is);
|
||||
let filtered: Vec<_> = union
|
||||
.elements(self.db)
|
||||
.iter()
|
||||
.filter(|elem| {
|
||||
elem.as_nominal_instance()
|
||||
.and_then(|inst| inst.tuple_spec(self.db))
|
||||
.and_then(|spec| spec.py_index(self.db, index).ok())
|
||||
.is_none_or(|el_ty| {
|
||||
if is_positive_check {
|
||||
// `is X` context: keep tuples where element could be X
|
||||
!el_ty.is_disjoint_from(self.db, rhs_ty)
|
||||
} else {
|
||||
// `is not X` context: keep tuples where element is not always X
|
||||
!el_ty.is_subtype_of(self.db, rhs_ty)
|
||||
}
|
||||
})
|
||||
})
|
||||
.copied()
|
||||
.collect();
|
||||
if filtered.len() < union.elements(self.db).len() {
|
||||
let filtered = union.filter(self.db, |elem| {
|
||||
elem.as_nominal_instance()
|
||||
.and_then(|inst| inst.tuple_spec(self.db))
|
||||
.and_then(|spec| spec.py_index(self.db, index).ok())
|
||||
.is_none_or(|el_ty| {
|
||||
if is_positive_check {
|
||||
// `is X` context: keep tuples where element could be X
|
||||
!el_ty.is_disjoint_from(self.db, rhs_ty)
|
||||
} else {
|
||||
// `is not X` context: keep tuples where element is not always X
|
||||
!el_ty.is_subtype_of(self.db, rhs_ty)
|
||||
}
|
||||
})
|
||||
});
|
||||
if filtered != Type::Union(union) {
|
||||
let place = self.expect_place(&subscript_place_expr);
|
||||
constraints.insert(
|
||||
place,
|
||||
NarrowingConstraint::regular(UnionType::from_elements(self.db, filtered)),
|
||||
);
|
||||
constraints.insert(place, NarrowingConstraint::typeguard(filtered));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1099,75 +1091,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
}
|
||||
}
|
||||
|
||||
// Narrow unions and intersections of `TypedDict` in cases where required keys are
|
||||
// excluded:
|
||||
//
|
||||
// class Foo(TypedDict):
|
||||
// foo: int
|
||||
// class Bar(TypedDict):
|
||||
// bar: int
|
||||
//
|
||||
// def _(u: Foo | Bar):
|
||||
// if "foo" not in u:
|
||||
// reveal_type(u) # revealed: Bar
|
||||
if matches!(&**ops, [ast::CmpOp::In | ast::CmpOp::NotIn])
|
||||
&& let Type::StringLiteral(key) = inference.expression_type(&**left)
|
||||
&& let Some(rhs_place_expr) = place_expr(&comparators[0])
|
||||
&& let rhs_type = inference.expression_type(&comparators[0])
|
||||
&& is_typeddict_or_union_with_typeddicts(self.db, rhs_type)
|
||||
{
|
||||
let is_negative_check = is_positive == (ops[0] == ast::CmpOp::NotIn);
|
||||
if is_negative_check {
|
||||
let requires_key = |td: TypedDictType<'db>| -> bool {
|
||||
td.items(self.db)
|
||||
.get(key.value(self.db))
|
||||
.is_some_and(TypedDictField::is_required)
|
||||
};
|
||||
|
||||
let narrowed = match rhs_type {
|
||||
Type::TypedDict(td) => {
|
||||
if requires_key(td) {
|
||||
Type::Never
|
||||
} else {
|
||||
rhs_type
|
||||
}
|
||||
}
|
||||
Type::Intersection(intersection) => {
|
||||
if intersection
|
||||
.positive(self.db)
|
||||
.iter()
|
||||
.copied()
|
||||
.filter_map(Type::as_typed_dict)
|
||||
.any(requires_key)
|
||||
{
|
||||
Type::Never
|
||||
} else {
|
||||
rhs_type
|
||||
}
|
||||
}
|
||||
Type::Union(union) => {
|
||||
// remove all members of the union that would require the key
|
||||
union.filter(self.db, |ty| match ty {
|
||||
Type::TypedDict(td) => !requires_key(*td),
|
||||
Type::Intersection(intersection) => !intersection
|
||||
.positive(self.db)
|
||||
.iter()
|
||||
.copied()
|
||||
.filter_map(Type::as_typed_dict)
|
||||
.any(requires_key),
|
||||
_ => true,
|
||||
})
|
||||
}
|
||||
_ => rhs_type,
|
||||
};
|
||||
|
||||
if narrowed != rhs_type {
|
||||
let place = self.expect_place(&rhs_place_expr);
|
||||
constraints.insert(place, NarrowingConstraint::typeguard(narrowed));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut last_rhs_ty: Option<Type> = None;
|
||||
|
||||
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
|
||||
@@ -1708,33 +1631,25 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
}
|
||||
|
||||
// Filter the union based on whether each tuple element at the index could match the rhs.
|
||||
let filtered: Vec<_> = union
|
||||
.elements(self.db)
|
||||
.iter()
|
||||
.filter(|elem| {
|
||||
elem.as_nominal_instance()
|
||||
.and_then(|inst| inst.tuple_spec(self.db))
|
||||
.and_then(|spec| spec.py_index(self.db, index).ok())
|
||||
.is_none_or(|el_ty| {
|
||||
if constrain_with_equality {
|
||||
// Keep tuples where element could be equal to rhs.
|
||||
!el_ty.is_disjoint_from(self.db, rhs_type)
|
||||
} else {
|
||||
// Keep tuples where element is not always equal to rhs.
|
||||
!el_ty.is_subtype_of(self.db, rhs_type)
|
||||
}
|
||||
})
|
||||
})
|
||||
.copied()
|
||||
.collect();
|
||||
let filtered = union.filter(self.db, |elem| {
|
||||
elem.as_nominal_instance()
|
||||
.and_then(|inst| inst.tuple_spec(self.db))
|
||||
.and_then(|spec| spec.py_index(self.db, index).ok())
|
||||
.is_none_or(|el_ty| {
|
||||
if constrain_with_equality {
|
||||
// Keep tuples where element could be equal to rhs.
|
||||
!el_ty.is_disjoint_from(self.db, rhs_type)
|
||||
} else {
|
||||
// Keep tuples where element is not always equal to rhs.
|
||||
!el_ty.is_subtype_of(self.db, rhs_type)
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
// Only create a constraint if we actually narrowed something.
|
||||
if filtered.len() < union.elements(self.db).len() {
|
||||
if filtered != rhs_type {
|
||||
let place = self.expect_place(&subscript_place_expr);
|
||||
Some((
|
||||
place,
|
||||
NarrowingConstraint::regular(UnionType::from_elements(self.db, filtered)),
|
||||
))
|
||||
Some((place, NarrowingConstraint::typeguard(filtered)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -1746,13 +1661,18 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
||||
fn is_typeddict_or_union_with_typeddicts<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
|
||||
match ty {
|
||||
Type::TypedDict(_) => true,
|
||||
Type::Intersection(intersection) => {
|
||||
intersection.positive(db).iter().any(Type::is_typed_dict)
|
||||
Type::Union(union) => {
|
||||
union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|union_member_ty| match union_member_ty {
|
||||
Type::TypedDict(_) => true,
|
||||
Type::Intersection(intersection) => {
|
||||
intersection.positive(db).iter().any(Type::is_typed_dict)
|
||||
}
|
||||
_ => false,
|
||||
})
|
||||
}
|
||||
Type::Union(union) => union
|
||||
.elements(db)
|
||||
.iter()
|
||||
.any(|union_member_ty| is_typeddict_or_union_with_typeddicts(db, *union_member_ty)),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user