Compare commits

...

1 Commits

Author SHA1 Message Date
Charlie Marsh
7e80f19223 [ty] Try eliminating ~AlwaysFalsy and ~AlwaysTruthy from intersections 2025-12-29 18:04:04 -05:00
19 changed files with 477 additions and 99 deletions

View File

@@ -194,7 +194,7 @@ static SYMPY: Benchmark = Benchmark::new(
max_dep_date: "2025-06-17",
python_version: PythonVersion::PY312,
},
13106,
13109,
);
static TANJUN: Benchmark = Benchmark::new(

View File

@@ -32,11 +32,11 @@ reveal_type(l3) # revealed: list[int | str]
def _(l: list[int] | None = None):
l1 = l or list()
reveal_type(l1) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]
reveal_type(l1) # revealed: list[int] | list[Unknown]
l2: list[int] = l or list()
# it would be better if this were `list[int]`? (https://github.com/astral-sh/ty/issues/136)
reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]
reveal_type(l2) # revealed: list[int] | list[Unknown]
def f[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
@@ -91,7 +91,7 @@ def get_data() -> dict | None:
def wrap_data() -> list[dict]:
if not (res := get_data()):
return list1({})
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown]]
# `list[dict[Unknown, Unknown] & ~AlwaysFalsy]` and `list[dict[Unknown, Unknown]]` are incompatible,
# but the return type check passes here because the type of `list1(res)` is inferred
# by bidirectional type inference using the annotated return type, and the type of `res` is not used.
@@ -100,7 +100,7 @@ def wrap_data() -> list[dict]:
def wrap_data2() -> list[dict] | None:
if not (res := get_data()):
return None
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown]]
return list1(res)
def deco[T](func: Callable[[], T]) -> Callable[[], T]:
@@ -111,7 +111,7 @@ def outer() -> Callable[[], list[dict]]:
def inner() -> list[dict]:
if not (res := get_data()):
return list1({})
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown]]
return list1(res)
return inner

View File

@@ -189,7 +189,7 @@ def _(
h: Literal[42] | Not[Literal[42]],
i: Not[Literal[42]] | Literal[42],
):
reveal_type(a) # revealed: Literal[""] | ~AlwaysFalsy
reveal_type(a) # revealed: object
reveal_type(b) # revealed: object
reveal_type(c) # revealed: object
reveal_type(d) # revealed: object

View File

@@ -13,7 +13,7 @@ reveal_type(1 is not 1) # revealed: bool
reveal_type(1 is 2) # revealed: Literal[False]
reveal_type(1 is not 7) # revealed: Literal[True]
# error: [unsupported-operator] "Operator `<=` is not supported between objects of type `Literal[1]` and `Literal[""]`"
reveal_type(1 <= "" and 0 < 1) # revealed: (Unknown & ~AlwaysTruthy) | Literal[True]
reveal_type(1 <= "" and 0 < 1) # revealed: Unknown | Literal[True]
```
## Integer instance

View File

@@ -37,7 +37,7 @@ class C:
return self
x = A() < B() < C()
reveal_type(x) # revealed: (A & ~AlwaysTruthy) | B
reveal_type(x) # revealed: A | B
y = 0 < 1 < A() < 3
reveal_type(y) # revealed: Literal[False] | A

View File

@@ -10,8 +10,8 @@ def _(foo: str):
reveal_type(False or "z") # revealed: Literal["z"]
reveal_type(False or True) # revealed: Literal[True]
reveal_type(False or False) # revealed: Literal[False]
reveal_type(foo or False) # revealed: (str & ~AlwaysFalsy) | Literal[False]
reveal_type(foo or True) # revealed: (str & ~AlwaysFalsy) | Literal[True]
reveal_type(foo or False) # revealed: str | Literal[False]
reveal_type(foo or True) # revealed: str | Literal[True]
```
## AND
@@ -20,8 +20,8 @@ def _(foo: str):
def _(foo: str):
reveal_type(True and False) # revealed: Literal[False]
reveal_type(False and True) # revealed: Literal[False]
reveal_type(foo and False) # revealed: (str & ~AlwaysTruthy) | Literal[False]
reveal_type(foo and True) # revealed: (str & ~AlwaysTruthy) | Literal[True]
reveal_type(foo and False) # revealed: str | Literal[False]
reveal_type(foo and True) # revealed: str | Literal[True]
reveal_type("x" and "y" and "z") # revealed: Literal["z"]
reveal_type("x" and "y" and "") # revealed: Literal[""]
reveal_type("" and "y") # revealed: Literal[""]

View File

@@ -105,8 +105,8 @@ reveal_type(y) # revealed: Unknown
```py
def one(x: int | None):
assert (y := x), reveal_type(y) # revealed: (int & ~AlwaysTruthy) | None
reveal_type(y) # revealed: int & ~AlwaysFalsy
assert (y := x), reveal_type(y) # revealed: int | None
reveal_type(y) # revealed: int
def two(x: int | None):
assert isinstance((y := x), int), reveal_type(y) # revealed: None

View File

@@ -104,14 +104,14 @@ class C:
value: str | None
def foo(c: C):
# The truthiness check `c.value` narrows to `str & ~AlwaysFalsy`.
# The truthiness check `c.value` narrows to `str`.
# The subsequent `len(c.value)` doesn't narrow further since `str` is not narrowable by len().
if c.value and len(c.value):
reveal_type(c.value) # revealed: str & ~AlwaysFalsy
reveal_type(c.value) # revealed: str
# error: [invalid-argument-type] "Argument to function `len` is incorrect: Expected `Sized`, found `str | None`"
if len(c.value) and c.value:
reveal_type(c.value) # revealed: str & ~AlwaysFalsy
reveal_type(c.value) # revealed: str
if c.value is None or not len(c.value):
reveal_type(c.value) # revealed: str | None

View File

@@ -299,7 +299,7 @@ def f(l: list[str | None] | None):
def f(a: A):
if a:
def _():
reveal_type(a) # revealed: A & ~AlwaysFalsy
reveal_type(a) # revealed: A
a.x = None
```

View File

@@ -365,12 +365,12 @@ def f(
if isinstance(c, bool):
reveal_type(c) # revealed: Never
else:
reveal_type(c) # revealed: P & ~AlwaysTruthy
reveal_type(c) # revealed: P
if isinstance(d, bool):
reveal_type(d) # revealed: Never
else:
reveal_type(d) # revealed: P & ~AlwaysFalsy
reveal_type(d) # revealed: P
```
## Narrowing if an object of type `Any` or `Unknown` is used as the second argument

View File

@@ -58,9 +58,9 @@ and `tuple[()]` in the negative case (see <https://github.com/astral-sh/ty/issue
```py
def _(x: tuple[int, ...]):
if len(x):
reveal_type(x) # revealed: tuple[int, ...] & ~AlwaysFalsy
reveal_type(x) # revealed: tuple[int, ...]
else:
reveal_type(x) # revealed: tuple[int, ...] & ~AlwaysTruthy
reveal_type(x) # revealed: tuple[int, ...]
```
## Unions of narrowable types
@@ -70,9 +70,9 @@ from typing import Literal
def _(x: Literal["foo", ""] | tuple[int, ...]):
if len(x):
reveal_type(x) # revealed: Literal["foo"] | (tuple[int, ...] & ~AlwaysFalsy)
reveal_type(x) # revealed: Literal["foo"] | tuple[int, ...]
else:
reveal_type(x) # revealed: Literal[""] | (tuple[int, ...] & ~AlwaysTruthy)
reveal_type(x) # revealed: Literal[""] | tuple[int, ...]
```
## Types that are not narrowed
@@ -119,13 +119,13 @@ def _(lines: list[str]):
if not line:
continue
reveal_type(line) # revealed: str & ~AlwaysFalsy
reveal_type(line) # revealed: str
value = line if len(line) < 3 else ""
reveal_type(value) # revealed: (str & ~AlwaysFalsy) | Literal[""]
reveal_type(value) # revealed: str
if len(value):
# `Literal[""]` is removed, `str & ~AlwaysFalsy` is unchanged
reveal_type(value) # revealed: str & ~AlwaysFalsy
reveal_type(value) # revealed: str
# Accessing value[0] is safe here
_ = value[0]
```

View File

@@ -82,19 +82,19 @@ class B: ...
def f(x: A | B):
if x:
reveal_type(x) # revealed: (A & ~AlwaysFalsy) | (B & ~AlwaysFalsy)
reveal_type(x) # revealed: A | B
else:
reveal_type(x) # revealed: (A & ~AlwaysTruthy) | (B & ~AlwaysTruthy)
reveal_type(x) # revealed: A | B
if x and not x:
reveal_type(x) # revealed: (A & ~AlwaysFalsy & ~AlwaysTruthy) | (B & ~AlwaysFalsy & ~AlwaysTruthy)
reveal_type(x) # revealed: A | B
else:
reveal_type(x) # revealed: A | B
if x or not x:
reveal_type(x) # revealed: A | B
else:
reveal_type(x) # revealed: (A & ~AlwaysTruthy & ~AlwaysFalsy) | (B & ~AlwaysTruthy & ~AlwaysFalsy)
reveal_type(x) # revealed: A | B
```
### Truthiness of Types
@@ -111,9 +111,9 @@ x = int if flag() else str
reveal_type(x) # revealed: <class 'int'> | <class 'str'>
if x:
reveal_type(x) # revealed: (<class 'int'> & ~AlwaysFalsy) | (<class 'str'> & ~AlwaysFalsy)
reveal_type(x) # revealed: <class 'int'> | <class 'str'>
else:
reveal_type(x) # revealed: (<class 'int'> & ~AlwaysTruthy) | (<class 'str'> & ~AlwaysTruthy)
reveal_type(x) # revealed: <class 'int'> | <class 'str'>
```
## Determined Truthiness
@@ -179,9 +179,9 @@ if isinstance(x, str) and not isinstance(x, B):
reveal_type(z) # revealed: (A & str & ~B) | Literal[0, 42, "", "hello"]
if z:
reveal_type(z) # revealed: (A & str & ~B & ~AlwaysFalsy) | Literal[42, "hello"]
reveal_type(z) # revealed: (A & str & ~B) | Literal[42, "hello"]
else:
reveal_type(z) # revealed: (A & str & ~B & ~AlwaysTruthy) | Literal[0, ""]
reveal_type(z) # revealed: (A & str & ~B) | Literal[0, ""]
```
## Narrowing Multiple Variables
@@ -219,7 +219,7 @@ x = A()
if x and not x:
y = x
reveal_type(y) # revealed: A & ~AlwaysFalsy & ~AlwaysTruthy
reveal_type(y) # revealed: A
else:
y = x
reveal_type(y) # revealed: A
@@ -264,16 +264,16 @@ def _(
):
reveal_type(ta) # revealed: type[TruthyClass] | type[AmbiguousClass]
if ta:
reveal_type(ta) # revealed: type[TruthyClass] | (type[AmbiguousClass] & ~AlwaysFalsy)
reveal_type(ta) # revealed: type[TruthyClass] | type[AmbiguousClass]
reveal_type(af) # revealed: type[AmbiguousClass] | type[FalsyClass]
if af:
reveal_type(af) # revealed: type[AmbiguousClass] & ~AlwaysFalsy
reveal_type(af) # revealed: type[AmbiguousClass]
# error: [unsupported-bool-conversion] "Boolean conversion is not supported for type `MetaDeferred`"
if d:
# TODO: Should be `Unknown`
reveal_type(d) # revealed: type[DeferredClass] & ~AlwaysFalsy
reveal_type(d) # revealed: type[DeferredClass]
tf = TruthyClass if flag else FalsyClass
reveal_type(tf) # revealed: <class 'TruthyClass'> | <class 'FalsyClass'>
@@ -296,12 +296,12 @@ def _(x: Literal[0, 1]):
reveal_type(x and A()) # revealed: Literal[0] | A
def _(x: str):
reveal_type(x or A()) # revealed: (str & ~AlwaysFalsy) | A
reveal_type(x and A()) # revealed: (str & ~AlwaysTruthy) | A
reveal_type(x or A()) # revealed: str | A
reveal_type(x and A()) # revealed: str | A
def _(x: bool | str):
reveal_type(x or A()) # revealed: Literal[True] | (str & ~AlwaysFalsy) | A
reveal_type(x and A()) # revealed: Literal[False] | (str & ~AlwaysTruthy) | A
reveal_type(x or A()) # revealed: Literal[True] | str | A
reveal_type(x and A()) # revealed: Literal[False] | str | A
class Falsy:
def __bool__(self) -> Literal[False]:

View File

@@ -140,7 +140,7 @@ type IntOrStr = int | str
def f(x: IntOrStr, y: str | bytes):
z = x or y
reveal_type(z) # revealed: (int & ~AlwaysFalsy) | str | bytes
reveal_type(z) # revealed: int | str | bytes
```
## Multiple layers of union aliases

View File

@@ -266,7 +266,9 @@ use crate::semantic_index::use_def::place_state::{
LiveDeclarationsIterator, PlaceState, PreviousDefinitions, ScopedDefinitionId,
};
use crate::semantic_index::{EnclosingSnapshotResult, SemanticIndex};
use crate::types::{IntersectionBuilder, Truthiness, Type, infer_narrowing_constraint};
use crate::types::{
IntersectionBuilder, NarrowingConstraint, Truthiness, Type, infer_narrowing_constraint,
};
mod place_state;
@@ -757,22 +759,50 @@ impl<'db> ConstraintsIterator<'_, 'db> {
base_ty: Type<'db>,
place: ScopedPlaceId,
) -> Type<'db> {
let constraint_tys: Vec<_> = self
let constraints: Vec<_> = self
.filter_map(|constraint| infer_narrowing_constraint(db, constraint, place))
.collect();
if constraint_tys.is_empty() {
if constraints.is_empty() {
return base_ty;
}
// Separate truthiness constraints from intersection constraints.
// We apply truthiness filtering after intersecting with other constraints,
// so that pattern narrowing happens first.
let mut truthiness_constraints = Vec::new();
let mut intersection_types = Vec::new();
for constraint in constraints.into_iter().rev() {
match constraint {
NarrowingConstraint::Truthiness(is_truthy) => {
truthiness_constraints.push(is_truthy);
}
NarrowingConstraint::IntersectWith(ty) => {
intersection_types.push(ty);
}
}
}
// First intersect with type constraints.
let result_ty = if intersection_types.is_empty() {
base_ty
} else {
constraint_tys
intersection_types
.into_iter()
.rev()
.fold(
IntersectionBuilder::new(db).add_positive(base_ty),
IntersectionBuilder::add_positive,
)
.build()
}
};
// Then apply truthiness filtering.
truthiness_constraints
.into_iter()
.fold(result_ty, |ty, is_truthy| {
ty.filter_for_truthiness(db, is_truthy)
})
}
}

View File

@@ -65,7 +65,7 @@ use crate::types::generics::{
walk_generic_context,
};
use crate::types::mro::{Mro, MroError, MroIterator};
pub(crate) use crate::types::narrow::infer_narrowing_constraint;
pub(crate) use crate::types::narrow::{NarrowingConstraint, infer_narrowing_constraint};
use crate::types::newtype::NewType;
pub(crate) use crate::types::signatures::{Parameter, Parameters};
use crate::types::signatures::{ParameterForm, walk_signature};
@@ -1467,6 +1467,188 @@ impl<'db> Type<'db> {
if yes { self.negate(db) } else { *self }
}
/// Removes both `~AlwaysTruthy` and `~AlwaysFalsy` constraints from this type.
///
/// For intersection types, this removes `AlwaysTruthy` and `AlwaysFalsy` from the negative set.
/// For union types and generic aliases, this applies recursively to each element/type argument.
/// Other types are returned unchanged.
///
/// This is useful for simplifying types for display, as these truthiness constraints
/// are rarely useful to show and make types more confusing.
#[must_use]
pub(crate) fn remove_truthiness_constraints(self, db: &'db dyn Db) -> Type<'db> {
match self {
Type::Union(union) => {
// Transform elements and check if any changed.
let transformed: Box<[Type<'db>]> = union
.elements(db)
.iter()
.map(|elem| elem.remove_truthiness_constraints(db))
.collect();
let changed = transformed
.iter()
.zip(union.elements(db).iter())
.any(|(new, old)| new != old);
if !changed {
self
} else {
UnionType::from_elements(db, transformed.iter().copied())
}
}
Type::Intersection(intersection) => {
// First, recursively transform positive elements.
let transformed_positives: Box<[Type<'db>]> = intersection
.positive(db)
.iter()
.map(|ty| ty.remove_truthiness_constraints(db))
.collect();
// Filter out `AlwaysTruthy` and ``AlwaysFalsy from negatives.
let filtered_negatives: Box<[Type<'db>]> = intersection
.negative(db)
.iter()
.filter(|ty| !matches!(ty, Type::AlwaysTruthy | Type::AlwaysFalsy))
.copied()
.collect();
let positives_changed = transformed_positives
.iter()
.zip(intersection.positive(db).iter())
.any(|(new, old)| new != old);
let negatives_changed = filtered_negatives.len() != intersection.negative(db).len();
if !positives_changed && !negatives_changed {
// Nothing changed, return unchanged.
self
} else {
IntersectionBuilder::new(db)
.positive_elements(transformed_positives.iter().copied())
.negative_elements(filtered_negatives.iter().copied())
.build()
}
}
Type::GenericAlias(alias) => {
// Transform type arguments recursively.
let specialization = alias.specialization(db);
let transformed_types: Box<[Type<'db>]> = specialization
.types(db)
.iter()
.map(|ty| ty.remove_truthiness_constraints(db))
.collect();
// Check if any types changed.
let changed = transformed_types
.iter()
.zip(specialization.types(db).iter())
.any(|(new, old)| new != old);
if !changed {
self
} else {
let new_specialization = Specialization::new(
db,
specialization.generic_context(db),
transformed_types,
specialization.materialization_kind(db),
specialization.tuple_inner(db),
);
Type::GenericAlias(GenericAlias::new(db, alias.origin(db), new_specialization))
}
}
Type::NominalInstance(instance) => {
// Transform the class type if it's generic.
let class = instance.class(db);
match class {
ClassType::Generic(alias) => {
// Transform the generic alias and create a new instance.
let transformed =
Type::GenericAlias(alias).remove_truthiness_constraints(db);
if let Type::GenericAlias(new_alias) = transformed {
if new_alias == alias {
self
} else {
Type::instance(db, ClassType::Generic(new_alias))
}
} else {
// If it simplified to something else (unlikely), return as-is.
self
}
}
ClassType::NonGeneric(_) => self,
}
}
_ => self,
}
}
/// Filters this type to keep only elements matching the desired truthiness.
///
/// When `is_truthy` is true, keeps truthy elements and removes falsy ones.
/// When `is_truthy` is false, keeps falsy elements and removes truthy ones.
///
/// Special handling is provided for:
/// - `bool` → narrows to `Literal[True]` or `Literal[False]`
/// - `LiteralString` → narrows to `LiteralString & ~Literal[""]` or `Literal[""]`
/// - Unions → recursively filters each element
///
/// For other types with ambiguous truthiness (like `int`), returns the type unchanged
/// since we can't statically determine which elements match.
///
/// This is used for truthiness narrowing (e.g., `if x:` narrows x to truthy elements).
#[must_use]
pub(crate) fn filter_for_truthiness(self, db: &'db dyn Db, is_truthy: bool) -> Type<'db> {
// Check if the type's truthiness already matches or contradicts what we want.
match self.bool(db).negate_if(!is_truthy) {
Truthiness::AlwaysTrue => return self,
Truthiness::AlwaysFalse => return Type::Never,
Truthiness::Ambiguous => {}
}
// Truthiness is ambiguous. Try to filter recursively for compound types.
match self {
Type::Union(union) => UnionType::from_elements(
db,
union
.elements(db)
.iter()
.map(|elem| elem.filter_for_truthiness(db, is_truthy)),
),
Type::Intersection(_) => {
// For intersections, we can't easily filter. But we can check if
// any element makes the intersection definitely wrong.
// For now, return the intersection unchanged for ambiguous cases.
// The subtype check in add_negative handles the definite cases.
self
}
Type::NominalInstance(instance) if instance.has_known_class(db, KnownClass::Bool) => {
// bool has ambiguous truthiness: True is truthy, False is falsy
if is_truthy {
Type::BooleanLiteral(true)
} else {
Type::BooleanLiteral(false)
}
}
Type::LiteralString => {
// LiteralString has ambiguous truthiness: "" is falsy, all others truthy.
if is_truthy {
IntersectionBuilder::new(db)
.add_positive(Type::LiteralString)
.add_negative(Type::string_literal(db, ""))
.build()
} else {
Type::string_literal(db, "")
}
}
_ => {
// For other types with ambiguous truthiness, we can't filter.
// Just return the type unchanged.
self
}
}
}
/// If the type is a union, filters union elements based on the provided predicate.
///
/// Otherwise, returns the type unchanged.

View File

@@ -908,6 +908,17 @@ impl<'db> IntersectionBuilder<'db> {
self
}
pub(crate) fn negative_elements<I, T>(mut self, elements: I) -> Self
where
I: IntoIterator<Item = T>,
T: Into<Type<'db>>,
{
for element in elements {
self = self.add_negative(element.into());
}
self
}
pub(crate) fn build(mut self) -> Type<'db> {
// Avoid allocating the UnionBuilder unnecessarily if we have just one intersection:
if self.intersections.len() == 1 {
@@ -1257,6 +1268,26 @@ impl<'db> InnerIntersectionBuilder<'db> {
fn build(mut self, db: &'db dyn Db, order_elements: bool) -> Type<'db> {
self.simplify_constrained_typevars(db);
// Strip out redundant ~AlwaysTruthy and ~AlwaysFalsy constraints from the negative set.
// These are only redundant when all positive elements already have a determined truthiness
// that makes the constraint unnecessary:
// - Strip ~AlwaysFalsy when ALL positive elements are always truthy
// - Strip ~AlwaysTruthy when ALL positive elements are always falsy
//
// We keep them when they're the only element (standalone negation like ~AlwaysFalsy)
// because that's how truthiness narrowing constraints are represented.
// See https://github.com/astral-sh/ty/issues/2233
if !self.positive.is_empty() && !self.negative.is_empty() {
let all_always_truthy = self.positive.iter().all(|ty| ty.bool(db).is_always_true());
let all_always_falsy = self.positive.iter().all(|ty| ty.bool(db).is_always_false());
self.negative.retain(|ty| match ty {
Type::AlwaysFalsy => !all_always_truthy,
Type::AlwaysTruthy => !all_always_falsy,
_ => true,
});
}
// If any typevars are in `self.positive`, speculatively solve all bounded type variables
// to their upper bound and all constrained type variables to the union of their constraints.
// If that speculative intersection simplifies to `Never`, this intersection must also simplify
@@ -1470,4 +1501,64 @@ mod tests {
assert_eq!(actual.display(&db).to_string(), "Literal[SafeUUID.unknown]");
}
}
#[test]
fn build_intersection_always_falsy_int_literal() {
let db = setup_db();
// IntLiteral(0) is always falsy, so IntLiteral(0) & ~AlwaysFalsy should be Never
let ty = IntersectionBuilder::new(&db)
.add_positive(Type::IntLiteral(0))
.add_negative(Type::AlwaysFalsy)
.build();
assert_eq!(ty, Type::Never);
// Same with reversed order
let ty = IntersectionBuilder::new(&db)
.add_negative(Type::AlwaysFalsy)
.add_positive(Type::IntLiteral(0))
.build();
assert_eq!(ty, Type::Never);
// IntLiteral(1) is always truthy, so IntLiteral(1) & ~AlwaysFalsy should be IntLiteral(1)
let ty = IntersectionBuilder::new(&db)
.add_positive(Type::IntLiteral(1))
.add_negative(Type::AlwaysFalsy)
.build();
assert_eq!(ty, Type::IntLiteral(1));
// BooleanLiteral(false) is always falsy
let ty = IntersectionBuilder::new(&db)
.add_positive(Type::BooleanLiteral(false))
.add_negative(Type::AlwaysFalsy)
.build();
assert_eq!(ty, Type::Never);
// BooleanLiteral(true) is always truthy
let ty = IntersectionBuilder::new(&db)
.add_positive(Type::BooleanLiteral(true))
.add_negative(Type::AlwaysFalsy)
.build();
assert_eq!(ty, Type::BooleanLiteral(true));
// Test a union: Literal[0, 1] & ~AlwaysFalsy should become Literal[1]
let union = UnionType::from_elements(&db, [Type::IntLiteral(0), Type::IntLiteral(1)]);
let ty = IntersectionBuilder::new(&db)
.add_positive(union)
.add_negative(Type::AlwaysFalsy)
.build();
assert_eq!(ty, Type::IntLiteral(1));
// Test with isinstance-like constraint: Literal[0, 1] & (int & ~AlwaysFalsy)
let int_ty = KnownClass::Int.to_instance(&db);
let constraint = IntersectionBuilder::new(&db)
.add_positive(int_ty)
.add_negative(Type::AlwaysFalsy)
.build();
let ty = IntersectionBuilder::new(&db)
.add_positive(union)
.add_positive(constraint)
.build();
assert_eq!(ty, Type::IntLiteral(1));
}
}

View File

@@ -1545,9 +1545,12 @@ impl KnownFunction {
{
let mut diag = builder.into_diagnostic("Revealed type");
let span = context.span(&call_expression.arguments.args[0]);
// Remove ~AlwaysTruthy and ~AlwaysFalsy constraints for cleaner display,
// as they make types more confusing without adding useful information.
let display_type = revealed_type.remove_truthiness_constraints(db);
diag.annotate(Annotation::primary(span).message(format_args!(
"`{}`",
revealed_type
display_type
.display_with(db, DisplaySettings::default().preserve_long_unions())
)));
}

View File

@@ -750,7 +750,7 @@ pub struct Specialization<'db> {
/// For specializations of `tuple`, we also store more detailed information about the tuple's
/// elements, above what the class's (single) typevar can represent.
tuple_inner: Option<TupleType<'db>>,
pub(super) tuple_inner: Option<TupleType<'db>>,
}
// The Salsa heap is tracked separately.

View File

@@ -31,6 +31,24 @@ use ruff_python_ast::{BoolOp, ExprBoolOp};
use rustc_hash::FxHashMap;
use std::collections::hash_map::Entry;
/// A narrowing constraint that can be applied to a type.
///
/// This represents either a type to intersect with (for most narrowing operations)
/// or a truthiness filter (for `if x:` style narrowing).
#[derive(Clone, Copy, Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(crate) enum NarrowingConstraint<'db> {
/// Intersect with this type.
///
/// For example, `isinstance(x, int)` creates `IntersectWith(int)`.
IntersectWith(Type<'db>),
/// Filter for truthiness.
///
/// `Truthiness(true)` means keep truthy values (from `if x:`).
/// `Truthiness(false)` means keep falsy values (from `if not x:` or the else branch).
Truthiness(bool),
}
/// Return the type constraint that `test` (if true) would place on `symbol`, if any.
///
/// For example, if we have this code:
@@ -51,7 +69,7 @@ pub(crate) fn infer_narrowing_constraint<'db>(
db: &'db dyn Db,
predicate: Predicate<'db>,
place: ScopedPlaceId,
) -> Option<Type<'db>> {
) -> Option<NarrowingConstraint<'db>> {
let constraints = match predicate.node {
PredicateNode::Expression(expression) => {
if predicate.is_positive {
@@ -277,7 +295,20 @@ impl ClassInfoConstraintFunction {
}
}
type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, NarrowingConstraint<'db>>;
impl<'db> NarrowingConstraint<'db> {
/// Convert this constraint to a type for intersection operations.
///
/// `Truthiness` constraints are converted to `~AlwaysFalsy` or `~AlwaysTruthy`.
fn to_type(self, db: &'db dyn Db) -> Type<'db> {
match self {
NarrowingConstraint::IntersectWith(ty) => ty,
NarrowingConstraint::Truthiness(true) => Type::AlwaysFalsy.negate(db),
NarrowingConstraint::Truthiness(false) => Type::AlwaysTruthy.negate(db),
}
}
}
fn merge_constraints_and<'db>(
into: &mut NarrowingConstraints<'db>,
@@ -287,10 +318,26 @@ fn merge_constraints_and<'db>(
for (key, value) in from {
match into.entry(*key) {
Entry::Occupied(mut entry) => {
*entry.get_mut() = IntersectionBuilder::new(db)
.add_positive(*entry.get())
.add_positive(*value)
.build();
let merged = match (*entry.get(), *value) {
// Two truthiness constraints: if same, keep; if different, convert to types
// and let the intersection builder + filter_for_truthiness handle it.
// We can't just return Never because types with mutable truthiness
// (like class instances) can satisfy both truthy and falsy at different times.
(NarrowingConstraint::Truthiness(a), NarrowingConstraint::Truthiness(b))
if a == b =>
{
NarrowingConstraint::Truthiness(a)
}
// Mixed or two IntersectWith or different truthiness: convert to types and intersect
(existing, new) => {
let merged_ty = IntersectionBuilder::new(db)
.add_positive(existing.to_type(db))
.add_positive(new.to_type(db))
.build();
NarrowingConstraint::IntersectWith(merged_ty)
}
};
*entry.get_mut() = merged;
}
Entry::Vacant(entry) => {
entry.insert(*value);
@@ -307,16 +354,35 @@ fn merge_constraints_or<'db>(
for (key, value) in from {
match into.entry(*key) {
Entry::Occupied(mut entry) => {
*entry.get_mut() = UnionBuilder::new(db).add(*entry.get()).add(*value).build();
let merged = match (*entry.get(), *value) {
// Two truthiness constraints: if same, keep; if different, no constraint
(NarrowingConstraint::Truthiness(a), NarrowingConstraint::Truthiness(b)) => {
if a == b {
NarrowingConstraint::Truthiness(a)
} else {
// truthy OR falsy = any truthiness = object
NarrowingConstraint::IntersectWith(Type::object())
}
}
// Mixed or two IntersectWith: convert to types and union
(existing, new) => {
let merged_ty = UnionBuilder::new(db)
.add(existing.to_type(db))
.add(new.to_type(db))
.build();
NarrowingConstraint::IntersectWith(merged_ty)
}
};
*entry.get_mut() = merged;
}
Entry::Vacant(entry) => {
entry.insert(Type::object());
entry.insert(NarrowingConstraint::IntersectWith(Type::object()));
}
}
}
for (key, value) in into.iter_mut() {
if !from.contains_key(key) {
*value = Type::object();
*value = NarrowingConstraint::IntersectWith(Type::object());
}
}
}
@@ -510,8 +576,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
/// Narrow a type based on `len()`, only narrowing the parts that are safe to narrow.
///
/// For narrowable types (literals, tuples), we apply `~AlwaysFalsy` (positive) or
/// `~AlwaysTruthy` (negative). For non-narrowable types, we return them unchanged.
/// For narrowable types (literals, tuples), we filter by truthiness directly.
/// For non-narrowable types, we return them unchanged.
///
/// Returns `None` if no part of the type is narrowable.
fn narrow_type_by_len(db: &'db dyn Db, ty: Type<'db>, is_positive: bool) -> Option<Type<'db>> {
@@ -547,26 +613,15 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.any(|element| Self::is_base_type_narrowable_by_len(db, *element));
if has_narrowable {
// Apply the narrowing constraint to the whole intersection.
let mut builder = IntersectionBuilder::new(db).add_positive(ty);
if is_positive {
builder = builder.add_negative(Type::AlwaysFalsy);
} else {
builder = builder.add_negative(Type::AlwaysTruthy);
}
Some(builder.build())
// Apply truthiness filter directly to the intersection.
Some(ty.filter_for_truthiness(db, is_positive))
} else {
None
}
}
_ if Self::is_base_type_narrowable_by_len(db, ty) => {
let mut builder = IntersectionBuilder::new(db).add_positive(ty);
if is_positive {
builder = builder.add_negative(Type::AlwaysFalsy);
} else {
builder = builder.add_negative(Type::AlwaysTruthy);
}
Some(builder.build())
// Apply truthiness filter directly.
Some(ty.filter_for_truthiness(db, is_positive))
}
_ => None,
}
@@ -580,13 +635,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let target = place_expr(expr)?;
let place = self.expect_place(&target);
let ty = if is_positive {
Type::AlwaysFalsy.negate(self.db)
} else {
Type::AlwaysTruthy.negate(self.db)
};
Some(NarrowingConstraints::from_iter([(place, ty)]))
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::Truthiness(is_positive),
)]))
}
fn evaluate_expr_named(
@@ -917,7 +969,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.collect();
if filtered.len() < union.elements(self.db).len() {
let place = self.expect_place(&subscript_place_expr);
constraints.insert(place, UnionType::from_elements(self.db, filtered));
constraints.insert(
place,
NarrowingConstraint::IntersectWith(UnionType::from_elements(self.db, filtered)),
);
}
}
@@ -983,7 +1038,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
// As mentioned above, the synthesized `TypedDict` is always negated.
let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db);
let place = self.expect_place(&subscript_place_expr);
constraints.insert(place, intersection);
constraints.insert(place, NarrowingConstraint::IntersectWith(intersection));
}
}
@@ -1004,7 +1059,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
{
let place = self.expect_place(&left);
constraints.insert(place, ty);
constraints.insert(place, NarrowingConstraint::IntersectWith(ty));
}
}
ast::Expr::Call(ast::ExprCall {
@@ -1052,8 +1107,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let place = self.expect_place(&target);
constraints.insert(
place,
Type::instance(self.db, rhs_class.unknown_specialization(self.db))
.negate_if(self.db, !is_positive),
NarrowingConstraint::IntersectWith(
Type::instance(self.db, rhs_class.unknown_specialization(self.db))
.negate_if(self.db, !is_positive),
),
);
}
}
@@ -1093,7 +1150,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
Some(NarrowingConstraints::from_iter([(
place,
guarded_ty.negate_if(self.db, !is_positive),
NarrowingConstraint::IntersectWith(guarded_ty.negate_if(self.db, !is_positive)),
)]))
}
// For the expression `len(E)`, we narrow the type based on whether len(E) is truthy
@@ -1112,7 +1169,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
if let Some(narrowed_ty) = Self::narrow_type_by_len(self.db, arg_ty, is_positive) {
let target = place_expr(arg)?;
let place = self.expect_place(&target);
Some(NarrowingConstraints::from_iter([(place, narrowed_ty)]))
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::IntersectWith(narrowed_ty),
)]))
} else {
None
}
@@ -1142,7 +1202,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
return Some(NarrowingConstraints::from_iter([(
place,
constraint.negate_if(self.db, !is_positive),
NarrowingConstraint::IntersectWith(
constraint.negate_if(self.db, !is_positive),
),
)]));
}
@@ -1155,7 +1217,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.map(|constraint| {
NarrowingConstraints::from_iter([(
place,
constraint.negate_if(self.db, !is_positive),
NarrowingConstraint::IntersectWith(
constraint.negate_if(self.db, !is_positive),
),
)])
})
}
@@ -1190,7 +1254,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
ast::Singleton::False => Type::BooleanLiteral(false),
};
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)]))
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::IntersectWith(ty),
)]))
}
fn evaluate_match_pattern_class(
@@ -1223,7 +1290,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
_ => return None,
};
Some(NarrowingConstraints::from_iter([(place, narrowed_type)]))
Some(NarrowingConstraints::from_iter([(
place,
NarrowingConstraint::IntersectWith(narrowed_type),
)]))
}
fn evaluate_match_pattern_value(
@@ -1243,7 +1313,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
.map(|ty| NarrowingConstraints::from_iter([(place, ty)]))
.map(|ty| {
NarrowingConstraints::from_iter([(place, NarrowingConstraint::IntersectWith(ty))])
})
}
fn evaluate_match_pattern_or(