Compare commits

...

3 Commits

Author SHA1 Message Date
David Peter
a3f8f60e1e Same trick for tuples 2025-10-21 15:30:53 +02:00
David Peter
be94796e6d Same trick for tuples 2025-10-21 15:01:03 +02:00
David Peter
d2f73a404a [ty] Fall back to Divergent for deeply nested specializations 2025-10-21 14:52:33 +02:00
11 changed files with 255 additions and 28 deletions

View File

@@ -2457,6 +2457,31 @@ class Counter:
reveal_type(Counter().count) # revealed: Unknown | int
```
We also handle infinitely nested generics:
```py
class NestedLists:
def __init__(self: "NestedLists"):
self.x = 1
def f(self: "NestedLists"):
self.x = [self.x]
reveal_type(NestedLists().x) # revealed: Unknown | Literal[1] | list[Divergent]
class NestedMixed:
def f(self: "NestedMixed"):
self.x = [self.x]
def g(self: "NestedMixed"):
self.x = {self.x}
def h(self: "NestedMixed"):
self.x = {"a": self.x}
reveal_type(NestedMixed().x) # revealed: Unknown | list[Divergent] | set[Divergent] | dict[Unknown | str, Divergent]
```
### Builtin types attributes
This test can probably be removed eventually, but we currently include it because we do not yet
@@ -2551,13 +2576,28 @@ reveal_type(Answer.__members__) # revealed: MappingProxyType[str, Unknown]
## Divergent inferred implicit instance attribute types
```py
# TODO: This test currently panics, see https://github.com/astral-sh/ty/issues/837
class C:
def f(self, other: "C"):
self.x = (other.x, 1)
# class C:
# def f(self, other: "C"):
# self.x = (other.x, 1)
reveal_type(C().x) # revealed: Unknown | tuple[Divergent, Literal[1]]
```
That also works if the tuple is not constructed directly:
```py
# from typing import TypeVar, Literal
#
# reveal_type(C().x) # revealed: Unknown | tuple[Divergent, Literal[1]]
# T = TypeVar("T")
#
# def make_tuple(a: T) -> tuple[T, Literal[1]]:
# return (a, 1)
#
# class D:
# def f(self, other: "D"):
# self.x = make_tuple(other.x)
#
# reveal_type(D().x) # revealed: Unknown | tuple[Divergent, Literal[1]]
```
## Attributes of standard library modules that aren't yet defined

View File

@@ -69,7 +69,7 @@ use crate::types::tuple::{TupleSpec, TupleSpecBuilder};
pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type};
pub use crate::types::variance::TypeVarVariance;
use crate::types::variance::VarianceInferable;
use crate::types::visitor::any_over_type;
use crate::types::visitor::{any_over_type, specialization_depth};
use crate::unpack::EvaluationMode;
use crate::{Db, FxOrderSet, Module, Program};
pub(crate) use class::{ClassLiteral, ClassType, GenericAlias, KnownClass};
@@ -827,10 +827,14 @@ impl<'db> Type<'db> {
Self::Dynamic(DynamicType::Unknown)
}
pub(crate) fn divergent(scope: ScopeId<'db>) -> Self {
pub(crate) fn divergent(scope: Option<ScopeId<'db>>) -> Self {
Self::Dynamic(DynamicType::Divergent(DivergentType { scope }))
}
pub(crate) const fn is_divergent(&self) -> bool {
matches!(self, Type::Dynamic(DynamicType::Divergent(_)))
}
pub const fn is_unknown(&self) -> bool {
matches!(self, Type::Dynamic(DynamicType::Unknown))
}
@@ -7642,7 +7646,7 @@ impl<'db> KnownInstanceType<'db> {
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)]
pub struct DivergentType<'db> {
/// The scope where this divergence was detected.
scope: ScopeId<'db>,
scope: Option<ScopeId<'db>>,
}
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)]
@@ -11747,7 +11751,7 @@ pub(crate) mod tests {
let file_scope_id = FileScopeId::global();
let scope = file_scope_id.to_scope_id(&db, file);
let div = Type::Dynamic(DynamicType::Divergent(DivergentType { scope }));
let div = Type::Dynamic(DynamicType::Divergent(DivergentType { scope: Some(scope) }));
// The `Divergent` type must not be eliminated in union with other dynamic types,
// as this would prevent detection of divergent type inference using `Divergent`.

View File

@@ -301,7 +301,7 @@ fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
}
})
.multi_cartesian_product()
.map(|types| Type::tuple(TupleType::heterogeneous(db, types)))
.map(|types| Type::heterogeneous_tuple(db, types))
.collect::<Vec<_>>();
if expanded.len() == 1 {

View File

@@ -30,14 +30,16 @@ use crate::types::member::{Member, class_member};
use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature};
use crate::types::tuple::{TupleSpec, TupleType};
use crate::types::typed_dict::typed_dict_params_from_class_def;
use crate::types::visitor::{NonAtomicType, TypeKind, TypeVisitor, walk_non_atomic_type};
use crate::types::visitor::{
MAX_SPECIALIZATION_DEPTH, NonAtomicType, TypeKind, TypeVisitor, walk_non_atomic_type,
};
use crate::types::{
ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DataclassFlags,
DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType,
MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType,
TypeContext, TypeMapping, TypeRelation, TypedDictParams, UnionBuilder, VarianceInferable,
declaration_type, determine_upper_bound, infer_definition_types,
declaration_type, determine_upper_bound, infer_definition_types, specialization_depth,
};
use crate::{
Db, FxIndexMap, FxIndexSet, FxOrderSet, Program,
@@ -1612,7 +1614,15 @@ impl<'db> ClassLiteral<'db> {
match self.generic_context(db) {
None => ClassType::NonGeneric(self),
Some(generic_context) => {
let specialization = f(generic_context);
let mut specialization = f(generic_context);
for (idx, ty) in specialization.types(db).iter().enumerate() {
if specialization_depth(db, *ty) > MAX_SPECIALIZATION_DEPTH {
specialization =
specialization.with_replaced_type(db, idx, Type::divergent(None));
}
}
ClassType::Generic(GenericAlias::new(db, self, specialization))
}
}

View File

@@ -1264,6 +1264,27 @@ impl<'db> Specialization<'db> {
// A tuple's specialization will include all of its element types, so we don't need to also
// look in `self.tuple`.
}
/// Returns a copy of this specialization with the type at a given index replaced.
pub(crate) fn with_replaced_type(
self,
db: &'db dyn Db,
index: usize,
new_type: Type<'db>,
) -> Self {
debug_assert!(index < self.types(db).len());
let mut new_types: Box<[_]> = self.types(db).to_vec().into_boxed_slice();
new_types[index] = new_type;
Self::new(
db,
self.generic_context(db),
new_types,
self.materialization_kind(db),
self.tuple_inner(db),
)
}
}
/// A mapping between type variables and types.

View File

@@ -62,7 +62,7 @@ mod builder;
mod tests;
/// How many fixpoint iterations to allow before falling back to Divergent type.
const ITERATIONS_BEFORE_FALLBACK: u32 = 10;
const ITERATIONS_BEFORE_FALLBACK: u32 = 20;
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
@@ -567,7 +567,7 @@ impl<'db> CycleRecovery<'db> {
fn fallback_type(self) -> Type<'db> {
match self {
Self::Initial => Type::Never,
Self::Divergent(scope) => Type::divergent(scope),
Self::Divergent(scope) => Type::divergent(Some(scope)),
}
}
}

View File

@@ -5968,16 +5968,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut annotated_elt_tys = annotated_tuple.as_ref().map(Tuple::all_elements);
let db = self.db();
let divergent = Type::divergent(self.scope());
let element_types = elts.iter().map(|element| {
let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied();
let element_type = self.infer_expression(element, TypeContext::new(annotated_elt_ty));
if element_type.has_divergent_type(self.db(), divergent) {
divergent
} else {
element_type
}
self.infer_expression(element, TypeContext::new(annotated_elt_ty))
});
Type::heterogeneous_tuple(db, element_types)

View File

@@ -22,7 +22,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
/// Infer the type of a type expression.
pub(super) fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
let mut ty = self.infer_type_expression_no_store(expression);
let divergent = Type::divergent(self.scope());
let divergent = Type::divergent(Some(self.scope()));
if ty.has_divergent_type(self.db(), divergent) {
ty = divergent;
}

View File

@@ -12,10 +12,11 @@ use crate::types::enums::is_single_member_enum;
use crate::types::generics::{InferableTypeVars, walk_specialization};
use crate::types::protocol_class::walk_protocol_interface;
use crate::types::tuple::{TupleSpec, TupleType};
use crate::types::visitor::MAX_SPECIALIZATION_DEPTH;
use crate::types::{
ApplyTypeMappingVisitor, ClassBase, ClassLiteral, FindLegacyTypeVarsVisitor,
HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, TypeContext,
TypeMapping, TypeRelation, VarianceInferable,
TypeMapping, TypeRelation, VarianceInferable, specialization_depth,
};
use crate::{Db, FxOrderSet};

View File

@@ -26,10 +26,11 @@ use crate::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroErro
use crate::types::class::{ClassType, KnownClass};
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
use crate::types::generics::InferableTypeVars;
use crate::types::visitor::MAX_SPECIALIZATION_DEPTH;
use crate::types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping, TypeRelation,
UnionBuilder, UnionType,
UnionBuilder, UnionType, specialization_depth,
};
use crate::types::{Truthiness, TypeContext};
use crate::{Db, FxOrderSet, Program};
@@ -178,7 +179,16 @@ impl<'db> TupleType<'db> {
db: &'db dyn Db,
types: impl IntoIterator<Item = Type<'db>>,
) -> Option<Self> {
TupleType::new(db, &TupleSpec::heterogeneous(types))
TupleType::new(
db,
&TupleSpec::heterogeneous(types.into_iter().map(|ty| {
if specialization_depth(db, ty) > MAX_SPECIALIZATION_DEPTH {
Type::divergent(None)
} else {
ty
}
})),
)
}
#[cfg(test)]

View File

@@ -1,3 +1,5 @@
use rustc_hash::FxHashMap;
use crate::{
Db, FxIndexSet,
types::{
@@ -16,7 +18,10 @@ use crate::{
walk_typed_dict_type, walk_typeis_type, walk_union,
},
};
use std::cell::{Cell, RefCell};
use std::{
cell::{Cell, RefCell},
collections::hash_map::Entry,
};
/// A visitor trait that recurses into nested types.
///
@@ -295,3 +300,146 @@ pub(super) fn any_over_type<'db>(
visitor.visit_type(db, ty);
visitor.found_matching_type.get()
}
// To prevent infinite recursion during type inference for infinite types, we fall back to
// `C[Divergent]` once a certain amount of levels of specialization have occurred. For
// example:
//
// ```py
// x = 1
// while random_bool():
// x = [x]
//
// reveal_type(x) # Unknown | Literal[1] | list[Divergent]
// ```
pub(super) const MAX_SPECIALIZATION_DEPTH: usize = 10;
/// Returns the maximum number of layers of generic specializations for a given type.
///
/// For example, `int` has a depth of `0`, `list[int]` has a depth of `1`, and `list[set[int]]`
/// has a depth of `2`. A set-theoretic type like `list[int] | list[list[int]]` has a maximum
/// depth of `2`.
pub(super) fn specialization_depth(db: &dyn Db, ty: Type<'_>) -> usize {
struct SpecializationDepthVisitor<'db> {
seen_types: RefCell<FxHashMap<NonAtomicType<'db>, Option<usize>>>,
max_depth: Cell<usize>,
}
impl<'db> TypeVisitor<'db> for SpecializationDepthVisitor<'db> {
fn should_visit_lazy_type_attributes(&self) -> bool {
false
}
fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) {
match TypeKind::from(ty) {
TypeKind::Atomic => {
if ty.is_divergent() {
self.max_depth.set(usize::MAX);
}
}
TypeKind::NonAtomic(non_atomic_type) => {
match self.seen_types.borrow_mut().entry(non_atomic_type) {
Entry::Occupied(cached_depth) => {
self.max_depth
.update(|current| current.max(cached_depth.get().unwrap_or(0)));
return;
}
Entry::Vacant(entry) => {
entry.insert(None);
}
}
let self_depth: usize =
matches!(non_atomic_type, NonAtomicType::GenericAlias(_)).into();
let previous_max_depth = self.max_depth.replace(0);
walk_non_atomic_type(db, non_atomic_type, self);
self.max_depth.update(|max_child_depth| {
previous_max_depth.max(max_child_depth.saturating_add(self_depth))
});
self.seen_types
.borrow_mut()
.insert(non_atomic_type, Some(self.max_depth.get()));
}
}
}
}
let visitor = SpecializationDepthVisitor {
seen_types: RefCell::new(FxHashMap::default()),
max_depth: Cell::new(0),
};
visitor.visit_type(db, ty);
visitor.max_depth.get()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{db::tests::setup_db, types::KnownClass};
#[test]
fn test_generics_layering_depth() {
let db = setup_db();
let list_of_int =
KnownClass::List.to_specialized_instance(&db, [KnownClass::Int.to_instance(&db)]);
assert_eq!(specialization_depth(&db, list_of_int), 1);
let list_of_list_of_int = KnownClass::List.to_specialized_instance(&db, [list_of_int]);
assert_eq!(specialization_depth(&db, list_of_list_of_int), 2);
let list_of_list_of_list_of_int =
KnownClass::List.to_specialized_instance(&db, [list_of_list_of_int]);
assert_eq!(specialization_depth(&db, list_of_list_of_list_of_int), 3);
let set_of_dict_of_str_and_list_of_int = KnownClass::Set.to_specialized_instance(
&db,
[KnownClass::Dict
.to_specialized_instance(&db, [KnownClass::Str.to_instance(&db), list_of_int])],
);
assert_eq!(
specialization_depth(&db, set_of_dict_of_str_and_list_of_int),
3
);
let union_type_1 =
UnionType::from_elements(&db, [list_of_list_of_list_of_int, list_of_list_of_int]);
assert_eq!(specialization_depth(&db, union_type_1), 3);
let union_type_2 =
UnionType::from_elements(&db, [list_of_list_of_int, list_of_list_of_list_of_int]);
assert_eq!(specialization_depth(&db, union_type_2), 3);
let tuple_of_tuple_of_int = Type::heterogeneous_tuple(
&db,
[Type::heterogeneous_tuple(
&db,
[KnownClass::Int.to_instance(&db)],
)],
);
assert_eq!(specialization_depth(&db, tuple_of_tuple_of_int), 2);
let tuple_of_list_of_int_and_str = KnownClass::Tuple
.to_specialized_instance(&db, [list_of_int, KnownClass::Str.to_instance(&db)]);
assert_eq!(specialization_depth(&db, tuple_of_list_of_int_and_str), 1);
let list_of_union_of_lists = KnownClass::List.to_specialized_instance(
&db,
[UnionType::from_elements(
&db,
[
KnownClass::List
.to_specialized_instance(&db, [KnownClass::Int.to_instance(&db)]),
KnownClass::List
.to_specialized_instance(&db, [KnownClass::Str.to_instance(&db)]),
KnownClass::List
.to_specialized_instance(&db, [KnownClass::Bytes.to_instance(&db)]),
],
)],
);
assert_eq!(specialization_depth(&db, list_of_union_of_lists), 2);
}
}