[ty] increase the limit on the number of elements in a non-recursively defined literal union (#21683)

## Summary

Closes https://github.com/astral-sh/ty/issues/957

As explained in https://github.com/astral-sh/ty/issues/957, literal
union types for recursively defined values ​​can be widened early to
speed up the convergence of fixed-point iterations.
This PR achieves this by embedding a marker in `UnionType` that
distinguishes whether a value is recursively defined.

This also allows us to identify values ​​that are not recursively
defined, so I've increased the limit on the number of elements in a
literal union type for such values.

Edit: while this PR doesn't provide the significant performance
improvement initially hoped for, it does have the benefit of allowing
the number of elements in a literal union to be raised above the salsa
limit, and indeed mypy_primer results revealed that a literal union of
220 elements was actually being used.

## Test Plan

`call/union.md` has been updated
This commit is contained in:
Shunsuke Shibayama
2025-12-05 11:01:48 +09:00
committed by GitHub
parent a9de6b5c3e
commit f3e5713d90
6 changed files with 138 additions and 22 deletions

View File

@@ -44,6 +44,7 @@ use crate::semantic_index::scope::ScopeId;
use crate::semantic_index::{imported_modules, place_table, semantic_index};
use crate::suppression::check_suppressions;
use crate::types::bound_super::BoundSuperType;
use crate::types::builder::RecursivelyDefined;
use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding};
pub(crate) use crate::types::class_base::ClassBase;
use crate::types::constraints::{
@@ -9668,6 +9669,7 @@ impl<'db> TypeVarInstance<'db> {
.skip(1)
.map(|arg| definition_expression_type(db, definition, arg))
.collect::<Box<_>>(),
RecursivelyDefined::No,
)
}
_ => return None,
@@ -10120,6 +10122,7 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
.iter()
.map(|ty| ty.normalized_impl(db, visitor))
.collect::<Box<_>>(),
constraints.recursively_defined(db),
))
}
}
@@ -10201,6 +10204,7 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
.iter()
.map(|ty| ty.materialize(db, materialization_kind, visitor))
.collect::<Box<_>>(),
RecursivelyDefined::No,
))
}
}
@@ -13142,6 +13146,9 @@ pub struct UnionType<'db> {
/// The union type includes values in any of these types.
#[returns(deref)]
pub elements: Box<[Type<'db>]>,
/// Whether the value pointed to by this type is recursively defined.
/// If `Yes`, union literal widening is performed early.
recursively_defined: RecursivelyDefined,
}
pub(crate) fn walk_union<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
@@ -13226,7 +13233,14 @@ impl<'db> UnionType<'db> {
db: &'db dyn Db,
transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
) -> Type<'db> {
Self::from_elements(db, self.elements(db).iter().map(transform_fn))
self.elements(db)
.iter()
.map(transform_fn)
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(element)
})
.recursively_defined(self.recursively_defined(db))
.build()
}
/// A fallible version of [`UnionType::map`].
@@ -13241,7 +13255,12 @@ impl<'db> UnionType<'db> {
db: &'db dyn Db,
transform_fn: impl FnMut(&Type<'db>) -> Option<Type<'db>>,
) -> Option<Type<'db>> {
Self::try_from_elements(db, self.elements(db).iter().map(transform_fn))
let mut builder = UnionBuilder::new(db);
for element in self.elements(db).iter().map(transform_fn) {
builder = builder.add(element?);
}
builder = builder.recursively_defined(self.recursively_defined(db));
Some(builder.build())
}
pub(crate) fn to_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {
@@ -13253,7 +13272,14 @@ impl<'db> UnionType<'db> {
db: &'db dyn Db,
mut f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> {
Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty)))
self.elements(db)
.iter()
.filter(|ty| f(ty))
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(*element)
})
.recursively_defined(self.recursively_defined(db))
.build()
}
pub(crate) fn map_with_boundness(
@@ -13288,7 +13314,9 @@ impl<'db> UnionType<'db> {
Place::Undefined
} else {
Place::Defined(
builder.build(),
builder
.recursively_defined(self.recursively_defined(db))
.build(),
origin,
if possibly_unbound {
Definedness::PossiblyUndefined
@@ -13336,7 +13364,9 @@ impl<'db> UnionType<'db> {
Place::Undefined
} else {
Place::Defined(
builder.build(),
builder
.recursively_defined(self.recursively_defined(db))
.build(),
origin,
if possibly_unbound {
Definedness::PossiblyUndefined
@@ -13371,6 +13401,7 @@ impl<'db> UnionType<'db> {
.unpack_aliases(true),
UnionBuilder::add,
)
.recursively_defined(self.recursively_defined(db))
.build()
}
@@ -13383,7 +13414,8 @@ impl<'db> UnionType<'db> {
let mut builder = UnionBuilder::new(db)
.order_elements(false)
.unpack_aliases(false)
.cycle_recovery(true);
.cycle_recovery(true)
.recursively_defined(self.recursively_defined(db));
let mut empty = true;
for ty in self.elements(db) {
if nested {
@@ -13398,6 +13430,7 @@ impl<'db> UnionType<'db> {
// `Divergent` in a union type does not mean true divergence, so we skip it if not nested.
// e.g. T | Divergent == T | (T | (T | (T | ...))) == T
if ty == &div {
builder = builder.recursively_defined(RecursivelyDefined::Yes);
continue;
}
builder = builder.add(

View File

@@ -202,12 +202,30 @@ enum ReduceResult<'db> {
Type(Type<'db>),
}
// TODO increase this once we extend `UnionElement` throughout all union/intersection
// representations, so that we can make large unions of literals fast in all operations.
//
// For now (until we solve https://github.com/astral-sh/ty/issues/957), keep this number
// below 200, which is the salsa fixpoint iteration limit.
const MAX_UNION_LITERALS: usize = 190;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, get_size2::GetSize)]
pub enum RecursivelyDefined {
Yes,
No,
}
impl RecursivelyDefined {
const fn is_yes(self) -> bool {
matches!(self, RecursivelyDefined::Yes)
}
const fn or(self, other: RecursivelyDefined) -> RecursivelyDefined {
match (self, other) {
(RecursivelyDefined::Yes, _) | (_, RecursivelyDefined::Yes) => RecursivelyDefined::Yes,
_ => RecursivelyDefined::No,
}
}
}
/// If the value is defined recursively, widening is performed from fewer literal elements, resulting in faster convergence of the fixed-point iteration.
const MAX_RECURSIVE_UNION_LITERALS: usize = 10;
/// If the value is defined non-recursively, the fixed-point iteration will converge in one go,
/// so in principle we can have as many literal elements as we want, but to avoid unintended huge computational loads, we limit it to 256.
const MAX_NON_RECURSIVE_UNION_LITERALS: usize = 256;
pub(crate) struct UnionBuilder<'db> {
elements: Vec<UnionElement<'db>>,
@@ -217,6 +235,7 @@ pub(crate) struct UnionBuilder<'db> {
// This is enabled when joining types in a `cycle_recovery` function.
// Since a cycle cannot be created within a `cycle_recovery` function, execution of `is_redundant_with` is skipped.
cycle_recovery: bool,
recursively_defined: RecursivelyDefined,
}
impl<'db> UnionBuilder<'db> {
@@ -227,6 +246,7 @@ impl<'db> UnionBuilder<'db> {
unpack_aliases: true,
order_elements: false,
cycle_recovery: false,
recursively_defined: RecursivelyDefined::No,
}
}
@@ -248,6 +268,11 @@ impl<'db> UnionBuilder<'db> {
self
}
pub(crate) fn recursively_defined(mut self, val: RecursivelyDefined) -> Self {
self.recursively_defined = val;
self
}
pub(crate) fn is_empty(&self) -> bool {
self.elements.is_empty()
}
@@ -258,6 +283,27 @@ impl<'db> UnionBuilder<'db> {
self.elements.push(UnionElement::Type(Type::object()));
}
fn widen_literal_types(&mut self, seen_aliases: &mut Vec<Type<'db>>) {
let mut replace_with = vec![];
for elem in &self.elements {
match elem {
UnionElement::IntLiterals(_) => {
replace_with.push(KnownClass::Int.to_instance(self.db));
}
UnionElement::StringLiterals(_) => {
replace_with.push(KnownClass::Str.to_instance(self.db));
}
UnionElement::BytesLiterals(_) => {
replace_with.push(KnownClass::Bytes.to_instance(self.db));
}
UnionElement::Type(_) => {}
}
}
for ty in replace_with {
self.add_in_place_impl(ty, seen_aliases);
}
}
/// Adds a type to this union.
pub(crate) fn add(mut self, ty: Type<'db>) -> Self {
self.add_in_place(ty);
@@ -270,6 +316,15 @@ impl<'db> UnionBuilder<'db> {
}
pub(crate) fn add_in_place_impl(&mut self, ty: Type<'db>, seen_aliases: &mut Vec<Type<'db>>) {
let cycle_recovery = self.cycle_recovery;
let should_widen = |literals, recursively_defined: RecursivelyDefined| {
if recursively_defined.is_yes() && cycle_recovery {
literals >= MAX_RECURSIVE_UNION_LITERALS
} else {
literals >= MAX_NON_RECURSIVE_UNION_LITERALS
}
};
match ty {
Type::Union(union) => {
let new_elements = union.elements(self.db);
@@ -277,6 +332,20 @@ impl<'db> UnionBuilder<'db> {
for element in new_elements {
self.add_in_place_impl(*element, seen_aliases);
}
self.recursively_defined = self
.recursively_defined
.or(union.recursively_defined(self.db));
if self.cycle_recovery && self.recursively_defined.is_yes() {
let literals = self.elements.iter().fold(0, |acc, elem| match elem {
UnionElement::IntLiterals(literals) => acc + literals.len(),
UnionElement::StringLiterals(literals) => acc + literals.len(),
UnionElement::BytesLiterals(literals) => acc + literals.len(),
UnionElement::Type(_) => acc,
});
if should_widen(literals, self.recursively_defined) {
self.widen_literal_types(seen_aliases);
}
}
}
// Adding `Never` to a union is a no-op.
Type::Never => {}
@@ -300,7 +369,7 @@ impl<'db> UnionBuilder<'db> {
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::StringLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
if should_widen(literals.len(), self.recursively_defined) {
let replace_with = KnownClass::Str.to_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
return;
@@ -345,7 +414,7 @@ impl<'db> UnionBuilder<'db> {
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::BytesLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
if should_widen(literals.len(), self.recursively_defined) {
let replace_with = KnownClass::Bytes.to_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
return;
@@ -390,7 +459,7 @@ impl<'db> UnionBuilder<'db> {
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::IntLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
if should_widen(literals.len(), self.recursively_defined) {
let replace_with = KnownClass::Int.to_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
return;
@@ -585,6 +654,7 @@ impl<'db> UnionBuilder<'db> {
_ => Some(Type::Union(UnionType::new(
self.db,
types.into_boxed_slice(),
self.recursively_defined,
))),
}
}
@@ -696,6 +766,7 @@ impl<'db> IntersectionBuilder<'db> {
enum_member_literals(db, instance.class_literal(db), None)
.expect("Calling `enum_member_literals` on an enum class")
.collect::<Box<[_]>>(),
RecursivelyDefined::No,
)),
seen_aliases,
)

View File

@@ -50,6 +50,7 @@ use crate::semantic_index::{
ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table,
};
use crate::subscript::{PyIndex, PySlice};
use crate::types::builder::RecursivelyDefined;
use crate::types::call::bind::{CallableDescription, MatchingOverloadIndex};
use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind};
use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator};
@@ -3283,6 +3284,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
elts.iter()
.map(|expr| self.infer_type_expression(expr))
.collect::<Box<[_]>>(),
RecursivelyDefined::No,
));
self.store_expression_type(expr, ty);
}

View File

@@ -416,7 +416,7 @@ impl<'db> SubclassOfInner<'db> {
)
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
let constraints = constraints
let constraints_types = constraints
.elements(db)
.iter()
.map(|constraint| {
@@ -425,7 +425,11 @@ impl<'db> SubclassOfInner<'db> {
})
.collect::<Box<_>>();
TypeVarBoundOrConstraints::Constraints(UnionType::new(db, constraints))
TypeVarBoundOrConstraints::Constraints(UnionType::new(
db,
constraints_types,
constraints.recursively_defined(db),
))
}
})
});

View File

@@ -23,6 +23,7 @@ use itertools::{Either, EitherOrBoth, Itertools};
use crate::semantic_index::definition::Definition;
use crate::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError};
use crate::types::builder::RecursivelyDefined;
use crate::types::class::{ClassType, KnownClass};
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
use crate::types::generics::InferableTypeVars;
@@ -1458,7 +1459,7 @@ impl<'db> Tuple<Type<'db>> {
// those techniques ensure that union elements are deduplicated and unions are eagerly simplified
// into other types where necessary. Here, however, we know that there are no duplicates
// in this union, so it's probably more efficient to use `UnionType::new()` directly.
Type::Union(UnionType::new(db, elements))
Type::Union(UnionType::new(db, elements, RecursivelyDefined::No))
};
TupleSpec::heterogeneous([