Compare commits

...

5 Commits

Author SHA1 Message Date
Ibraheem Ahmed
fba4ac1dc6 improve literal promotion tests 2026-01-09 16:36:14 -05:00
Ibraheem Ahmed
d6d0b44fdb preserve bounds and constraints on synthetic type variables 2026-01-09 16:36:12 -05:00
Ibraheem Ahmed
5ed810a5d9 avoid creating invalid type mappings during synthetic inference 2026-01-09 16:35:49 -05:00
Ibraheem Ahmed
4812707fd7 avoid marking synthesized type variables as inferable unnecessarily 2026-01-09 16:35:48 -05:00
Ibraheem Ahmed
540593587d inform literal promotion with variance of inferred argument types 2026-01-09 16:34:53 -05:00
8 changed files with 534 additions and 130 deletions

View File

@@ -71,6 +71,8 @@ Literals are promoted if they are in non-covariant position in the return type o
function, or constructor of a generic class:
```py
from typing import Callable, Literal
class Bivariant[T]:
def __init__(self, value: T): ...
@@ -89,6 +91,12 @@ class Invariant[T]:
def __init__(self, value: T): ...
def takes_literal[T](x: Literal[1]):
raise NotImplementedError
def returns_literal[T]() -> Literal[1]:
raise NotImplementedError
def f1[T](x: T) -> Bivariant[T] | None: ...
def f2[T](x: T) -> Covariant[T] | None: ...
def f3[T](x: T) -> Covariant[T] | Bivariant[T] | None: ...
@@ -100,6 +108,8 @@ def f8[T](x: T) -> Invariant[T] | Covariant[T] | None: ...
def f9[T](x: T) -> tuple[Invariant[T], Invariant[T]] | None: ...
def f10[T, U](x: T, y: U) -> tuple[Invariant[T], Covariant[U]] | None: ...
def f11[T, U](x: T, y: U) -> tuple[Invariant[Covariant[T] | None], Covariant[U]] | None: ...
def f12[T](x: Callable[[T], None]) -> Invariant[T] | None: ...
def f13[T](x: Callable[[], T]) -> Invariant[T] | None: ...
reveal_type(Bivariant(1)) # revealed: Bivariant[Literal[1]]
reveal_type(Covariant(1)) # revealed: Covariant[Literal[1]]
@@ -120,6 +130,11 @@ reveal_type(f9(1)) # revealed: tuple[Invariant[int], Invariant[int]] | None
reveal_type(f10(1, 1)) # revealed: tuple[Invariant[int], Covariant[Literal[1]]] | None
reveal_type(f11(1, 1)) # revealed: tuple[Invariant[Covariant[int] | None], Covariant[Literal[1]]] | None
# TODO: we should avoid literal promotion here.
# error: [invalid-argument-type] "Argument to function `f12` is incorrect: Expected `(int, /) -> None`, found `def takes_literal[T](x: Literal[1]) -> Unknown`"
reveal_type(f12(takes_literal)) # revealed: Invariant[int] | None
reveal_type(f13(returns_literal)) # revealed: Invariant[int] | None
```
## Invariant and contravariant literal arguments are respected
@@ -179,10 +194,14 @@ promotion:
from typing import Iterable
class X[T]:
x: T
def __init__(self, x: Iterable[T]): ...
def _(x: list[Literal[1]]):
reveal_type(X(x)) # revealed: X[Literal[1]]
reveal_type(list(x)) # revealed: list[Literal[1]]
reveal_type(set(x)) # revealed: set[Literal[1]]
```
## Literals are promoted recursively

View File

@@ -2,6 +2,7 @@ use compact_str::CompactString;
use infer::nearest_enclosing_class;
use itertools::{Either, Itertools};
use ruff_diagnostics::{Edit, Fix};
use rustc_hash::FxHashMap;
use std::borrow::Cow;
use std::cell::RefCell;
@@ -2066,13 +2067,13 @@ impl<'db> Type<'db> {
builder.into_type_mappings()
};
for (type_var, ty) in generic_context.variables(db).zip(specialization.types(db)) {
for (type_var, ty) in specialization.types_and_variables(db) {
let variance = type_var.variance_with_polarity(db, polarity);
let narrowed_tcx = TypeContext::new(tcx_mappings.get(&type_var.identity(db)).copied());
f(type_var, *ty, variance, narrowed_tcx);
f(type_var, ty, variance, narrowed_tcx);
visitor.visit(*ty, || {
visitor.visit(ty, || {
ty.visit_specialization_impl(db, narrowed_tcx, variance, f, visitor);
});
}
@@ -6095,20 +6096,21 @@ impl<'db> Type<'db> {
return alias.raw_value_type(db).expand_eagerly(db);
}
if let TypeMapping::UniqueSpecialization { .. } = *type_mapping {
return visitor.visit(self, || {
alias.value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor)
});
}
// Do not call `value_type` here. `value_type` does the specialization internally, so `apply_type_mapping` is
// performed without `visitor` inheritance. In the case of recursive type aliases, this leads to infinite recursion.
// Instead, call `raw_value_type` and perform the specialization after the `visitor` cache has been created.
let value_type = visitor.visit(self, || {
match type_mapping {
// We only want to perform the unique specialization onto the specialization of the type alias below,
// not the raw value type.
TypeMapping::UniqueSpecialization { .. } => alias.raw_value_type(db),
_ => alias.raw_value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor),
}
alias.raw_value_type(db).apply_type_mapping_impl(db, type_mapping, tcx, visitor)
});
let mapped = alias.apply_function_specialization(db, value_type).apply_type_mapping_impl(db, type_mapping, tcx, visitor);
let is_recursive = any_over_type(db, alias.raw_value_type(db).expand_eagerly(db), &|ty| ty.is_divergent(), false);
// If the type mapping does not result in any change to this (non-recursive) type alias, do not expand it.
@@ -6853,29 +6855,48 @@ impl PromoteLiteralsMode {
pub enum TypeMapping<'a, 'db> {
/// Applies a specialization to the type
ApplySpecialization(ApplySpecialization<'a, 'db>),
/// Resets any specializations to contain unique synthetic type variables.
UniqueSpecialization {
// The number at which to begin counting synthetic type variables, which
// are identified numerically.
skip: usize,
// A list of synthetic type variables, and the types they replaced.
specialization: RefCell<Vec<(BoundTypeVarInstance<'db>, Type<'db>)>>,
specialization: RefCell<FxHashMap<BoundTypeVarInstance<'db>, Type<'db>>>,
// Whether or not to constrain a given synthetic type variable to the type
// it replaces.
constrain: bool,
// The inferable type variables contained in the type being mapped.
inferable: InferableTypeVars<'a, 'db>,
},
/// Replaces any literal types with their corresponding promoted type form (e.g. `Literal["string"]`
/// to `str`, or `def _() -> int` to `Callable[[], int]`).
PromoteLiterals(PromoteLiteralsMode),
/// Binds a legacy typevar with the generic context (class, function, type alias) that it is
/// being used in.
BindLegacyTypevars(BindingContext<'db>),
/// Binds any `typing.Self` typevar with a particular `self` class.
BindSelf {
self_type: Type<'db>,
binding_context: Option<BindingContext<'db>>,
},
/// Replaces occurrences of `typing.Self` with a new `Self` type variable with the given upper bound.
ReplaceSelf { new_upper_bound: Type<'db> },
/// Create the top or bottom materialization of a type.
Materialize(MaterializationKind),
/// Replace default types in parameters of callables with `Unknown`. This is used to avoid infinite
/// recursion when the type of the default value of a parameter depends on the callable itself.
ReplaceParameterDefaults,
/// Apply eager expansion to the type.
/// In the case of recursive type aliases, this will diverge, so that part will be replaced with `Divergent`.
EagerExpansion,
@@ -6917,19 +6938,24 @@ impl<'db> TypeMapping<'_, 'db> {
}
/// Returns a new `TypeMapping` that should be applied in contravariant positions.
pub(crate) fn flip(&self) -> Self {
pub(crate) fn flip(&self) -> Cow<'_, Self> {
match self {
TypeMapping::Materialize(materialization_kind) => {
TypeMapping::Materialize(materialization_kind.flip())
Cow::Owned(TypeMapping::Materialize(materialization_kind.flip()))
}
TypeMapping::PromoteLiterals(mode) => {
Cow::Owned(TypeMapping::PromoteLiterals(mode.flip()))
}
TypeMapping::PromoteLiterals(mode) => TypeMapping::PromoteLiterals(mode.flip()),
TypeMapping::ApplySpecialization(_)
| TypeMapping::UniqueSpecialization { .. }
| TypeMapping::BindLegacyTypevars(_)
| TypeMapping::BindSelf { .. }
| TypeMapping::ReplaceSelf { .. }
| TypeMapping::ReplaceParameterDefaults
| TypeMapping::EagerExpansion => self.clone(),
| TypeMapping::EagerExpansion => {
// Avoid cloning here, as type mappings may contain mutable state.
Cow::Borrowed(self)
}
}
}
}
@@ -7034,7 +7060,7 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
match known_instance {
KnownInstanceType::SubscriptedProtocol(context)
| KnownInstanceType::SubscriptedGeneric(context) => {
walk_generic_context(db, context, visitor);
visitor.visit_generic_context(db, context);
}
KnownInstanceType::TypeVar(typevar) => {
visitor.visit_type_var_type(db, typevar);
@@ -8419,8 +8445,7 @@ impl<'db> BoundTypeVarInstance<'db> {
self.identity(db) == other.identity(db)
}
/// Create a new PEP 695 type variable that can be used in signatures
/// of synthetic generic functions.
/// Create a new PEP 695 type variable that can be used in signatures of synthetic generic functions.
pub(crate) fn synthetic(db: &'db dyn Db, name: Name, variance: TypeVarVariance) -> Self {
let identity = TypeVarIdentity::new(
db,
@@ -8438,6 +8463,30 @@ impl<'db> BoundTypeVarInstance<'db> {
Self::new(db, typevar, BindingContext::Synthetic, None)
}
/// Create a new PEP 695 type variable that can be used in signatures of synthetic generic functions,
/// with the given upper bound.
pub(crate) fn synthetic_with_bounds_or_constraints(
db: &'db dyn Db,
name: Name,
variance: TypeVarVariance,
bounds_or_constraints: Option<TypeVarBoundOrConstraints<'db>>,
) -> Self {
let identity = TypeVarIdentity::new(
db,
name,
None, // definition
TypeVarKind::Pep695,
);
let typevar = TypeVarInstance::new(
db,
identity,
bounds_or_constraints.map(TypeVarBoundOrConstraintsEvaluation::from),
Some(variance),
None, // _default
);
Self::new(db, typevar, BindingContext::Synthetic, None)
}
/// Create a new synthetic `Self` type variable with the given upper bound.
pub(crate) fn synthetic_self(
db: &'db dyn Db,

View File

@@ -3020,7 +3020,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
.class_specialization(self.db)?;
builder
.infer_reverse_map(tcx, return_ty, |(_, variance, inferred_ty)| {
.infer_reverse_map(tcx, return_ty, |(_, inferred_ty), variance| {
// Avoid unnecessarily widening the return type based on a covariant
// type parameter from the type context, as it can lead to argument
// assignability errors if the type variable is constrained by a narrower
@@ -3048,15 +3048,15 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
for (parameter_index, variadic_argument_type) in
self.argument_matches[argument_index].iter()
{
let specialization_result = builder.infer_map(
let specialization_result = builder.infer_map_with_variance(
parameters[parameter_index].annotated_type(),
variadic_argument_type.unwrap_or(argument_type),
|(identity, variance, inferred_ty)| {
|(type_var, inferred_ty), variance| {
// Avoid widening the inferred type if it is already assignable to the
// preferred declared type.
if preferred_type_mappings
.as_ref()
.and_then(|types| types.get(&identity))
.and_then(|types| types.get(&type_var.identity(self.db)))
.is_some_and(|preferred_ty| {
inferred_ty.is_assignable_to(self.db, *preferred_ty)
})
@@ -3065,7 +3065,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
variance_in_arguments
.entry(identity)
.entry(type_var.identity(self.db))
.and_modify(|current| *current = current.join(variance))
.or_insert(variance);

View File

@@ -27,7 +27,7 @@ use crate::types::function::{
DataclassTransformerFlags, DataclassTransformerParams, KnownFunction,
};
use crate::types::generics::{
GenericContext, InferableTypeVars, Specialization, walk_generic_context, walk_specialization,
GenericContext, InferableTypeVars, Specialization, walk_specialization,
};
use crate::types::infer::{infer_expression_type, infer_unpack_types, nearest_enclosing_class};
use crate::types::member::{Member, class_member};
@@ -369,9 +369,7 @@ impl<'db> VarianceInferable<'db> for GenericAlias<'db> {
std::iter::once(origin.variance_of(db, typevar))
.chain(
specialization
.generic_context(db)
.variables(db)
.zip(specialization.types(db))
.types_and_variables(db)
.map(|(generic_typevar, ty)| {
if let Some(explicit_variance) =
generic_typevar.typevar(db).explicit_variance(db)
@@ -1657,7 +1655,7 @@ impl<'db> ClassLiteral<'db> {
let visitor = CollectTypeVars::default();
if let Some(generic_context) = self.generic_context(db) {
walk_generic_context(db, generic_context, &visitor);
visitor.visit_generic_context(db, generic_context);
}
for base in self.explicit_bases(db) {
visitor.visit_type(db, *base);

View File

@@ -1,7 +1,8 @@
use std::borrow::Cow;
use std::cell::RefCell;
use std::cell::{Cell, RefCell};
use std::collections::hash_map::Entry;
use std::fmt::Display;
use std::mem;
use itertools::{Either, Itertools};
use ruff_python_ast as ast;
@@ -26,8 +27,9 @@ use crate::types::{
ApplyTypeMappingVisitor, BindingContext, BoundTypeVarIdentity, BoundTypeVarInstance,
ClassLiteral, FindLegacyTypeVarsVisitor, IntersectionType, KnownClass, KnownInstanceType,
MaterializationKind, NormalizedVisitor, Type, TypeContext, TypeMapping,
TypeVarBoundOrConstraints, TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance,
UnionType, declaration_type, walk_type_var_bounds,
TypeVarBoundOrConstraints, TypeVarConstraints, TypeVarIdentity, TypeVarInstance, TypeVarKind,
TypeVarVariance, UnionBuilder, UnionType, declaration_type, walk_bound_type_var_type,
walk_type_var_bounds,
};
use crate::{Db, FxOrderMap, FxOrderSet};
@@ -123,8 +125,8 @@ pub(crate) fn typing_self<'db>(
)
}
#[derive(Clone, Copy, Debug)]
pub(crate) enum InferableTypeVars<'a, 'db> {
#[derive(Clone, Copy, Debug, PartialEq, Eq, get_size2::GetSize)]
pub enum InferableTypeVars<'a, 'db> {
None,
One(&'a FxHashSet<BoundTypeVarIdentity<'db>>),
Two(
@@ -763,7 +765,8 @@ pub(super) fn walk_specialization<'db, V: TypeVisitor<'db> + ?Sized>(
specialization: Specialization<'db>,
visitor: &V,
) {
walk_generic_context(db, specialization.generic_context(db), visitor);
visitor.visit_generic_context(db, specialization.generic_context(db));
for ty in specialization.types(db) {
visitor.visit_type(db, *ty);
}
@@ -1059,30 +1062,113 @@ impl<'db> Specialization<'db> {
return self.materialize_impl(db, *materialization_kind, visitor);
}
let types: Box<[_]> = if let TypeMapping::UniqueSpecialization { specialization } =
type_mapping
let types: Box<[_]> = if let TypeMapping::UniqueSpecialization {
specialization,
constrain,
skip,
inferable,
} = type_mapping
{
let mut specialization = specialization.borrow_mut();
self.types(db)
.iter()
.zip(self.generic_context(db).variables(db))
.map(|(ty, typevar)| {
// Create a unique synthetic type variable.
let name = format!("_T{}", specialization.len());
let synthetic =
BoundTypeVarInstance::synthetic(db, Name::new(name), typevar.variance(db));
let synthetic_bounds_or_constraints = |ty: Type<'db>| {
if !constrain {
return None;
}
specialization.push((synthetic, *ty));
let inferable_type_vars: FxOrderSet<_> =
collect_inferable_type_vars(ty, *inferable, db);
if inferable_type_vars.is_empty() {
// If we are creating a non-inferable synthetic type variable, constrain it
// such that it satisfies the same constraints as the type it replaced.
return Some(TypeVarBoundOrConstraints::Constraints(
TypeVarConstraints::new(db, vec![ty].into_boxed_slice()),
));
}
// Otherwise, we are creating an inferable synthetic type variable, and so must
// preserve the bounds and constraints of the inferable type variables we are replacing.
let mut specializations = vec![Vec::new()];
for bound_typevar in &inferable_type_vars {
match bound_typevar.typevar(db).bound_or_constraints(db) {
// If the type variable being replaced has multiple constraints, the
// synthetic type variable will have an upper bound of the union
// of those constraints.
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
for specialization in mem::take(&mut specializations) {
// Expand the constraints for each element of the union.
for constraint in constraints.elements(db) {
let mut specialization = specialization.clone();
// Note we use a gradual type here to emulate an inferable type
// variable. We cannot specialize to the constraint directly,
// as it may be in non-covariant position in the outer type
// that forms the upper bound.
specialization.push(IntersectionType::from_elements(
db,
[Type::any(), *constraint],
));
specializations.push(specialization);
}
}
}
// Otherwise, we can replace the type variable with its singular upper
// bound directly.
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
for specialization in &mut specializations {
specialization.push(IntersectionType::from_elements(
db,
[Type::any(), bound],
));
}
}
// An unbounded type variable has an implicit upper bound of `object`.
None => {
for specialization in &mut specializations {
specialization.push(Type::any());
}
}
}
}
let generic_context =
GenericContext::from_typevar_instances(db, inferable_type_vars);
// Form the upper bound of the synthetic type variable from the union
// of those constraints.
let mut upper_bounds = UnionBuilder::new(db);
for specialization in specializations {
let specialization = generic_context.specialize(db, &specialization);
upper_bounds = upper_bounds.add(ty.apply_specialization(db, specialization));
}
Some(TypeVarBoundOrConstraints::UpperBound(upper_bounds.build()))
};
self.types_and_variables(db)
.map(|(typevar, ty)| {
// Create a unique synthetic type variable.
let name = format!("_T{}", specialization.len() + skip);
let synthetic = BoundTypeVarInstance::synthetic_with_bounds_or_constraints(
db,
Name::new(name),
typevar.variance(db),
synthetic_bounds_or_constraints(ty),
);
specialization.insert(synthetic, ty);
Type::TypeVar(synthetic)
})
.collect()
} else {
self.types(db)
.iter()
.zip(self.generic_context(db).variables(db))
self.types_and_variables(db)
.enumerate()
.map(|(i, (ty, typevar))| {
.map(|(i, (typevar, ty))| {
let tcx = TypeContext::new(tcx.get(i).copied());
if typevar.variance(db).is_covariant() {
ty.apply_type_mapping_impl(db, type_mapping, tcx, visitor)
@@ -1211,9 +1297,7 @@ impl<'db> Specialization<'db> {
}
let mut has_dynamic_invariant_typevar = false;
let types: Box<[_]> = self
.generic_context(db)
.variables(db)
.zip(self.types(db))
.types_and_variables(db)
.map(|(bound_typevar, vartype)| {
match bound_typevar.variance(db) {
TypeVarVariance::Bivariant => {
@@ -1233,7 +1317,7 @@ impl<'db> Specialization<'db> {
if !vartype.is_equivalent_to(db, top_materialization) {
has_dynamic_invariant_typevar = true;
}
*vartype
vartype
}
}
})
@@ -1478,6 +1562,15 @@ impl<'db> Specialization<'db> {
// A tuple's specialization will include all of its element types, so we don't need to also
// look in `self.tuple`.
}
pub(crate) fn types_and_variables(
self,
db: &'db dyn Db,
) -> impl Iterator<Item = (BoundTypeVarInstance<'db>, Type<'db>)> {
self.generic_context(db)
.variables(db)
.zip(self.types(db).iter().copied())
}
}
/// A mapping between type variables and types.
@@ -1527,18 +1620,18 @@ impl<'db> ApplySpecialization<'_, 'db> {
/// Performs type inference between parameter annotations and argument types, producing a
/// specialization of a generic function.
pub(crate) struct SpecializationBuilder<'db> {
pub(crate) struct SpecializationBuilder<'a, 'db> {
db: &'db dyn Db,
inferable: InferableTypeVars<'db, 'db>,
inferable: InferableTypeVars<'a, 'db>,
types: FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>>,
}
/// An assignment from a bound type variable to a given type, along with the variance of the outermost
/// type with respect to the type variable.
pub(crate) type TypeVarAssignment<'db> = (BoundTypeVarIdentity<'db>, TypeVarVariance, Type<'db>);
pub(crate) type TypeVarAssignment<'db> = (BoundTypeVarInstance<'db>, Type<'db>);
impl<'db> SpecializationBuilder<'db> {
pub(crate) fn new(db: &'db dyn Db, inferable: InferableTypeVars<'db, 'db>) -> Self {
impl<'a, 'db> SpecializationBuilder<'a, 'db> {
pub(crate) fn new(db: &'db dyn Db, inferable: InferableTypeVars<'a, 'db>) -> Self {
Self {
db,
inferable,
@@ -1570,9 +1663,9 @@ impl<'db> SpecializationBuilder<'db> {
}
Self {
types,
db: self.db,
inferable: self.inferable,
types,
}
}
@@ -1590,15 +1683,13 @@ impl<'db> SpecializationBuilder<'db> {
&mut self,
bound_typevar: BoundTypeVarInstance<'db>,
ty: Type<'db>,
variance: TypeVarVariance,
mut f: impl FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
) {
let identity = bound_typevar.identity(self.db);
let Some(ty) = f((identity, variance, ty)) else {
let Some(ty) = f((bound_typevar, ty)) else {
return;
};
match self.types.entry(identity) {
match self.types.entry(bound_typevar.identity(self.db)) {
Entry::Occupied(mut entry) => {
// TODO: The spec says that when a ParamSpec is used multiple times in a signature,
// the type checker can solve it to a common behavioral supertype. We don't
@@ -1609,6 +1700,7 @@ impl<'db> SpecializationBuilder<'db> {
if bound_typevar.is_paramspec(self.db) {
return;
}
*entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]);
}
Entry::Vacant(entry) => {
@@ -1631,7 +1723,6 @@ impl<'db> SpecializationBuilder<'db> {
/// type comparisons.
fn add_type_mappings_from_constraint_set(
&mut self,
formal: Type<'db>,
constraints: ConstraintSet<'db>,
mut f: impl FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
) {
@@ -1682,18 +1773,17 @@ impl<'db> SpecializationBuilder<'db> {
}
for (bound_typevar, bounds) in mappings.drain() {
let variance = formal.variance_of(self.db, bound_typevar);
// Prefer the lower bound (often the concrete actual type seen) over the
// upper bound (which may include TypeVar bounds/constraints). The upper bound
// should only be used as a fallback when no concrete type was inferred.
let lower = UnionType::from_elements(self.db, bounds.lower);
if !lower.is_never() {
self.add_type_mapping(bound_typevar, lower, variance, &mut f);
self.add_type_mapping(bound_typevar, lower, &mut f);
continue;
}
let upper = IntersectionType::from_elements(self.db, bounds.upper);
if !upper.is_object() {
self.add_type_mapping(bound_typevar, upper, variance, &mut f);
self.add_type_mapping(bound_typevar, upper, &mut f);
}
}
}
@@ -1705,7 +1795,7 @@ impl<'db> SpecializationBuilder<'db> {
formal: Type<'db>,
actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
self.infer_map(formal, actual, |(_, _, ty)| Some(ty))
self.infer_map(formal, actual, |(_, ty)| Some(ty))
}
/// Infer type mappings for the specialization based on a given type and its declared type.
@@ -1718,14 +1808,13 @@ impl<'db> SpecializationBuilder<'db> {
actual: Type<'db>,
mut f: impl FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
) -> Result<(), SpecializationError<'db>> {
self.infer_map_impl(formal, actual, TypeVarVariance::Covariant, &mut f)
self.infer_map_impl(formal, actual, &mut f)
}
fn infer_map_impl(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
polarity: TypeVarVariance,
mut f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
) -> Result<(), SpecializationError<'db>> {
// TODO: Eventually, the builder will maintain a constraint set, instead of a hash-map of
@@ -1797,7 +1886,7 @@ impl<'db> SpecializationBuilder<'db> {
if remaining_actual.is_never() {
return Ok(());
}
self.add_type_mapping(*formal_bound_typevar, remaining_actual, polarity, f);
self.add_type_mapping(*formal_bound_typevar, remaining_actual, f);
}
(Type::Union(union_formal), _) => {
// Second, if the formal is a union, and the actual type is assignable to precisely
@@ -1843,7 +1932,7 @@ impl<'db> SpecializationBuilder<'db> {
let mut first_error = None;
let mut found_matching_element = false;
for formal_element in union_formal.elements(self.db) {
let result = self.infer_map_impl(*formal_element, actual, polarity, &mut f);
let result = self.infer_map_impl(*formal_element, actual, &mut f);
if let Err(err) = result {
first_error.get_or_insert(err);
} else {
@@ -1869,7 +1958,7 @@ impl<'db> SpecializationBuilder<'db> {
// actual type must also be disjoint from every negative element of the
// intersection, but that doesn't help us infer any type mappings.)
for positive in formal.iter_positive(self.db) {
self.infer_map_impl(positive, actual, polarity, f)?;
self.infer_map_impl(positive, actual, f)?;
}
}
@@ -1887,13 +1976,13 @@ impl<'db> SpecializationBuilder<'db> {
argument: ty,
});
}
self.add_type_mapping(bound_typevar, ty, polarity, f);
self.add_type_mapping(bound_typevar, ty, f);
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
// Prefer an exact match first.
for constraint in constraints.elements(self.db) {
if ty == *constraint {
self.add_type_mapping(bound_typevar, ty, polarity, f);
self.add_type_mapping(bound_typevar, ty, f);
return Ok(());
}
}
@@ -1903,7 +1992,14 @@ impl<'db> SpecializationBuilder<'db> {
.when_assignable_to(self.db, *constraint, self.inferable)
.is_always_satisfied(self.db)
{
self.add_type_mapping(bound_typevar, *constraint, polarity, f);
if ty.is_type_var() {
// If the actual type is also a type variable, creating a type mapping between
// the type variables instead of the constraint, which just happened to be satisfied.
self.add_type_mapping(bound_typevar, ty, f);
} else {
self.add_type_mapping(bound_typevar, *constraint, f);
}
return Ok(());
}
}
@@ -1912,7 +2008,7 @@ impl<'db> SpecializationBuilder<'db> {
argument: ty,
});
}
_ => self.add_type_mapping(bound_typevar, ty, polarity, f),
_ => self.add_type_mapping(bound_typevar, ty, f),
}
}
@@ -1921,7 +2017,7 @@ impl<'db> SpecializationBuilder<'db> {
{
let formal_instance = Type::TypeVar(subclass_of.into_type_var().unwrap());
if let Some(actual_instance) = ty.to_instance(self.db) {
return self.infer_map_impl(formal_instance, actual_instance, polarity, f);
return self.infer_map_impl(formal_instance, actual_instance, f);
}
}
@@ -1932,12 +2028,7 @@ impl<'db> SpecializationBuilder<'db> {
// will need to check the types of the protocol members to be able to
// infer the specialization of the protocol that the class implements.
if let Some(actual_nominal) = actual_protocol.to_nominal_instance() {
return self.infer_map_impl(
formal,
Type::NominalInstance(actual_nominal),
polarity,
f,
);
return self.infer_map_impl(formal, Type::NominalInstance(actual_nominal), f);
}
}
@@ -1963,8 +2054,7 @@ impl<'db> SpecializationBuilder<'db> {
.iter()
.zip(actual_tuple.all_elements())
{
let variance = TypeVarVariance::Covariant.compose(polarity);
self.infer_map_impl(*formal_element, *actual_element, variance, &mut f)?;
self.infer_map_impl(*formal_element, *actual_element, &mut f)?;
}
return Ok(());
}
@@ -1994,20 +2084,13 @@ impl<'db> SpecializationBuilder<'db> {
if formal_origin != base_alias.origin(self.db) {
continue;
}
let generic_context = formal_alias
.specialization(self.db)
.generic_context(self.db)
.variables(self.db);
let formal_specialization =
formal_alias.specialization(self.db).types(self.db);
let base_specialization = base_alias.specialization(self.db).types(self.db);
for (typevar, formal_ty, base_ty) in itertools::izip!(
generic_context,
formal_specialization,
base_specialization
) {
let variance = typevar.variance_with_polarity(self.db, polarity);
self.infer_map_impl(*formal_ty, *base_ty, variance, &mut f)?;
for (formal_ty, base_ty) in
itertools::izip!(formal_specialization, base_specialization)
{
self.infer_map_impl(*formal_ty, *base_ty, &mut f)?;
}
return Ok(());
}
@@ -2031,7 +2114,7 @@ impl<'db> SpecializationBuilder<'db> {
formal_callable,
self.inferable,
);
self.add_type_mappings_from_constraint_set(formal, when, &mut f);
self.add_type_mappings_from_constraint_set(when, &mut f);
} else {
for actual_signature in &actual_callable.signatures(self.db).overloads {
let when = actual_signature
@@ -2040,7 +2123,7 @@ impl<'db> SpecializationBuilder<'db> {
formal_callable,
self.inferable,
);
self.add_type_mappings_from_constraint_set(formal, when, &mut f);
self.add_type_mappings_from_constraint_set(when, &mut f);
}
}
}
@@ -2053,6 +2136,144 @@ impl<'db> SpecializationBuilder<'db> {
Ok(())
}
/// Infer type mappings for the specialization based on a given type and its declared type,
/// while keeping track of the variance of any inferred types.
///
/// The provided function will be called before any type mappings are created, and can
/// optionally modify the inferred type, or filter out the type mapping entirely.
///
/// Note that for a given type mapping, the provided variance is the variance of the inferred
/// type with respect to the _actual_ type.
pub(crate) fn infer_map_with_variance(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
mut f: impl FnMut(TypeVarAssignment<'db>, TypeVarVariance) -> Option<Type<'db>>,
) -> Result<(), SpecializationError<'db>> {
self.infer_map_with_variance_impl(formal, actual, TypeVarVariance::Covariant, &mut f)
}
fn infer_map_with_variance_impl(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
polarity: TypeVarVariance,
f: &mut dyn FnMut(TypeVarAssignment<'db>, TypeVarVariance) -> Option<Type<'db>>,
) -> Result<(), SpecializationError<'db>> {
// Synthesize the formal and actual types.
let (synthetic_formal, synthetic_formal_specialization) =
synthetic_specialization(formal, 0, true, self.inferable, self.db);
let skip = synthetic_formal_specialization.types(self.db).len();
let (synthetic_actual, synthetic_actual_specialization) =
synthetic_specialization(actual, skip, true, self.inferable, self.db);
// If we can't recurse into the actual type any further, just perform a regular inference
// with the current polarity.
if synthetic_actual_specialization.types(self.db).is_empty()
|| synthetic_formal_specialization.types(self.db).is_empty()
{
return self.infer_map_impl(formal, actual, &mut |type_assignment| {
f(type_assignment, polarity)
});
}
let mut synthetic_inferable = FxHashSet::default();
for (synthetic_typevar, ty) in synthetic_formal_specialization
.types_and_variables(self.db)
.chain(synthetic_actual_specialization.types_and_variables(self.db))
{
// Mark the synthetic type variable as inferable if it is representing an inferable
// type variable, or a type containing one.
if contains_inferable_type_var(ty, self.inferable, self.db) {
synthetic_inferable.insert(synthetic_typevar.identity(self.db));
}
}
let synthetic_inferable = InferableTypeVars::One(&synthetic_inferable);
// Collect the synthetic type variable to which each formal type variable is mapped.
let mut assigned_variables = FxHashMap::default();
let synthetic_type_mappings = {
let mut synthetic_builder =
SpecializationBuilder::new(self.db, self.inferable.merge(&synthetic_inferable));
let result =
synthetic_builder.infer_map(synthetic_formal, synthetic_actual, |(typevar, ty)| {
assigned_variables.insert(typevar.identity(self.db), typevar);
Some(ty)
});
// If the synthetic inference resulted in an error, we must perform inference on the
// original types to return the correct error message.
if result.is_err() {
return self.infer_map(formal, actual, |_| None);
}
synthetic_builder.into_type_mappings()
};
// We can't recurse any further, just perform a regular inference with the current polarity.
if synthetic_type_mappings.is_empty() {
return self.infer_map_impl(formal, actual, &mut |type_assignment| {
f(type_assignment, polarity)
});
}
for (identity, synthetic_type) in synthetic_type_mappings {
let bound_typevar = assigned_variables
.get(&identity)
.copied()
.expect("every type mapping stores its corresponding type variable");
let (formal_type, actual_type, synthetic_actual) = if let Some(formal_type) =
synthetic_actual_specialization.get(self.db, bound_typevar)
{
// The type variable may be a synthetic type variable from the actual type that was
// marked as inferable, in which case the inferred type is from the formal type.
let actual_type = synthetic_type.apply_type_mapping(
self.db,
&TypeMapping::ApplySpecialization(ApplySpecialization::Specialization(
synthetic_formal_specialization,
)),
TypeContext::default(),
);
(formal_type, actual_type, synthetic_formal)
} else {
let formal_type = synthetic_formal_specialization
.get(self.db, bound_typevar)
.unwrap_or(Type::TypeVar(bound_typevar));
let actual_type = synthetic_type.apply_type_mapping(
self.db,
&TypeMapping::ApplySpecialization(ApplySpecialization::Specialization(
synthetic_actual_specialization,
)),
TypeContext::default(),
);
(formal_type, actual_type, synthetic_actual)
};
if let Some(synthetic_type_var) = synthetic_type.as_typevar() {
// Created a type mapping to a synthetic type variable. Update the variance
// and recurse deeper.
let variance = synthetic_actual
.variance_of(self.db, synthetic_type_var)
.compose(polarity);
self.infer_map_with_variance_impl(formal_type, actual_type, variance, f)?;
} else {
// We can't recurse any further, just perform a regular inference with the current polarity.
self.infer_map_impl(formal_type, actual_type, &mut |type_assignment| {
f(type_assignment, polarity)
})?;
}
}
Ok(())
}
/// Infer type mappings for the specialization in the reverse direction, i.e., where the
/// actual type, not the formal type, contains inferable type variables.
pub(crate) fn infer_reverse(
@@ -2060,7 +2281,7 @@ impl<'db> SpecializationBuilder<'db> {
formal: Type<'db>,
actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
self.infer_reverse_map(formal, actual, |(_, _, ty)| Some(ty))
self.infer_reverse_map(formal, actual, |(_, ty), _| Some(ty))
}
/// Infer type mappings for the specialization in the reverse direction, i.e., where the
@@ -2068,11 +2289,14 @@ impl<'db> SpecializationBuilder<'db> {
///
/// The provided function will be called before any type mappings are created, and can
/// optionally modify the inferred type, or filter out the type mapping entirely.
///
/// Note that for a given type mapping, the provided variance is the variance of the inferred
/// type with respect to the _formal_ type.
pub(crate) fn infer_reverse_map(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
mut f: impl FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
mut f: impl FnMut(TypeVarAssignment<'db>, TypeVarVariance) -> Option<Type<'db>>,
) -> Result<(), SpecializationError<'db>> {
self.infer_reverse_map_impl(formal, actual, TypeVarVariance::Covariant, &mut f)
}
@@ -2082,30 +2306,19 @@ impl<'db> SpecializationBuilder<'db> {
formal: Type<'db>,
actual: Type<'db>,
polarity: TypeVarVariance,
f: &mut dyn FnMut(TypeVarAssignment<'db>) -> Option<Type<'db>>,
f: &mut dyn FnMut(TypeVarAssignment<'db>, TypeVarVariance) -> Option<Type<'db>>,
) -> Result<(), SpecializationError<'db>> {
// Assign each type variable on the formal type to a unique synthetic type variable.
let type_mapping = TypeMapping::UniqueSpecialization {
specialization: RefCell::new(Vec::new()),
};
let synthetic_formal =
formal.apply_type_mapping(self.db, &type_mapping, TypeContext::default());
let (synthetic_formal, synthetic_specialization) =
synthetic_specialization(formal, 0, false, self.inferable, self.db);
// Recover the synthetic type variables.
let synthetic_specialization = match type_mapping {
TypeMapping::UniqueSpecialization { specialization } => specialization.into_inner(),
_ => unreachable!(),
};
let inferable = GenericContext::from_typevar_instances(
self.db,
synthetic_specialization.iter().map(|(typevar, _)| *typevar),
)
.inferable_typevars(self.db);
let synthetic_inferable = synthetic_specialization
.generic_context(self.db)
.inferable_typevars(self.db);
// Collect the actual type to which each synthetic type variable is mapped.
let forward_type_mappings = {
let mut builder = SpecializationBuilder::new(self.db, inferable);
let mut builder = SpecializationBuilder::new(self.db, synthetic_inferable);
builder.infer(synthetic_formal, actual)?;
builder.into_type_mappings()
};
@@ -2114,14 +2327,16 @@ impl<'db> SpecializationBuilder<'db> {
//
// This is the base case for when `actual` is an inferable type variable.
if forward_type_mappings.is_empty() {
return self.infer_map_impl(actual, formal, polarity, f);
return self.infer_map_with_variance_impl(actual, formal, polarity, f);
}
// Consider the reverse inference of `Sequence[int]` given `list[T]`.
//
// Given a forward type mapping of `T@Sequence` -> `T@list`, and a synthetic type mapping of
// `T@Sequence` -> `int`, we want to infer the reverse type mapping `T@list` -> `int`.
for (synthetic_type_var, formal_type) in synthetic_specialization {
for (synthetic_type_var, formal_type) in
synthetic_specialization.types_and_variables(self.db)
{
if let Some(actual_type) =
forward_type_mappings.get(&synthetic_type_var.identity(self.db))
{
@@ -2137,6 +2352,125 @@ impl<'db> SpecializationBuilder<'db> {
}
}
// Returns `true` if the given type contains an inferable type variable.
fn contains_inferable_type_var<'db>(
ty: Type<'db>,
inferable: InferableTypeVars<'_, 'db>,
db: &'db dyn Db,
) -> bool {
struct ContainsInferableTypeVar<'a, 'db> {
result: Cell<bool>,
inferable: InferableTypeVars<'a, 'db>,
recursion_guard: TypeCollector<'db>,
}
impl<'db> TypeVisitor<'db> for ContainsInferableTypeVar<'_, 'db> {
fn should_visit_lazy_type_attributes(&self) -> bool {
true
}
fn visit_bound_type_var_type(
&self,
db: &'db dyn Db,
bound_typevar: BoundTypeVarInstance<'db>,
) {
if bound_typevar.is_inferable(db, self.inferable) {
self.result.set(true);
}
walk_bound_type_var_type(db, bound_typevar, self);
}
fn visit_generic_context(&self, _db: &'db dyn Db, _context: GenericContext<'db>) {}
fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) {
walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard);
}
}
let visitor = ContainsInferableTypeVar {
inferable,
result: Cell::new(false),
recursion_guard: TypeCollector::default(),
};
visitor.visit_type(db, ty);
visitor.result.get()
}
// Returns the set of inferable type variables contained in the given type.
fn collect_inferable_type_vars<'db>(
ty: Type<'db>,
inferable: InferableTypeVars<'_, 'db>,
db: &'db dyn Db,
) -> FxOrderSet<BoundTypeVarInstance<'db>> {
struct CollectTypeVars<'a, 'db> {
typevars: RefCell<FxOrderSet<BoundTypeVarInstance<'db>>>,
inferable: InferableTypeVars<'a, 'db>,
recursion_guard: TypeCollector<'db>,
}
impl<'db> TypeVisitor<'db> for CollectTypeVars<'_, 'db> {
fn should_visit_lazy_type_attributes(&self) -> bool {
false
}
fn visit_bound_type_var_type(
&self,
db: &'db dyn Db,
bound_typevar: BoundTypeVarInstance<'db>,
) {
if bound_typevar.is_inferable(db, self.inferable) {
self.typevars.borrow_mut().insert(bound_typevar);
}
let typevar = bound_typevar.typevar(db);
if let Some(bound_or_constraints) = typevar.bound_or_constraints(db) {
walk_type_var_bounds(db, bound_or_constraints, self);
}
}
fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) {
walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard);
}
}
let visitor = CollectTypeVars {
inferable,
typevars: RefCell::default(),
recursion_guard: TypeCollector::default(),
};
visitor.visit_type(db, ty);
visitor.typevars.into_inner()
}
// Assign each type variable on the given type to a unique synthetic type variable, returning the
// synthesized type and its specialization.
fn synthetic_specialization<'db>(
ty: Type<'db>,
skip: usize,
constrain: bool,
inferable: InferableTypeVars<'_, 'db>,
db: &'db dyn Db,
) -> (Type<'db>, Specialization<'db>) {
let type_mapping = TypeMapping::UniqueSpecialization {
skip,
constrain,
inferable,
specialization: RefCell::new(FxHashMap::default()),
};
let synthetic_ty = ty.apply_type_mapping(db, &type_mapping, TypeContext::default());
let synthetic_types = match type_mapping {
TypeMapping::UniqueSpecialization { specialization, .. } => specialization.into_inner(),
_ => unreachable!(),
};
let synthetic_specialization = {
let (type_vars, types): (Vec<_>, Vec<_>) = synthetic_types.iter().unzip();
GenericContext::from_typevar_instances(db, type_vars).specialize(db, &types)
};
(synthetic_ty, synthetic_specialization)
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum SpecializationError<'db> {
MismatchedBound {

View File

@@ -8191,9 +8191,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.infer_reverse_map(
tcx,
collection_instance,
|(typevar, variance, inferred_ty)| {
|(typevar, inferred_ty), variance| {
elt_tcx_variance
.entry(typevar)
.entry(typevar.identity(self.db()))
.and_modify(|current| *current = current.join(variance))
.or_insert(variance);

View File

@@ -19,7 +19,7 @@ use smallvec::{SmallVec, smallvec_inline};
use super::{DynamicType, Type, TypeVarVariance, definition_expression_type, semantic_index};
use crate::semantic_index::definition::Definition;
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
use crate::types::generics::{GenericContext, InferableTypeVars, walk_generic_context};
use crate::types::generics::{GenericContext, InferableTypeVars};
use crate::types::infer::{infer_deferred_types, infer_scope_types};
use crate::types::relation::{
HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, TypeRelation,
@@ -607,7 +607,7 @@ pub(super) fn walk_signature<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
visitor: &V,
) {
if let Some(generic_context) = &signature.generic_context {
walk_generic_context(db, *generic_context, visitor);
visitor.visit_generic_context(db, *generic_context);
}
// By default we usually don't visit the type of the default value,
// as it isn't relevant to most things

View File

@@ -4,16 +4,16 @@ use crate::{
Db,
types::{
BoundMethodType, BoundSuperType, BoundTypeVarInstance, CallableType, GenericAlias,
IntersectionType, KnownBoundMethodType, KnownInstanceType, NominalInstanceType,
PropertyInstanceType, ProtocolInstanceType, SubclassOfType, Type, TypeAliasType,
TypeGuardType, TypeIsType, TypeVarInstance, TypedDictType, UnionType,
GenericContext, IntersectionType, KnownBoundMethodType, KnownInstanceType,
NominalInstanceType, PropertyInstanceType, ProtocolInstanceType, SubclassOfType, Type,
TypeAliasType, TypeGuardType, TypeIsType, TypeVarInstance, TypedDictType, UnionType,
bound_super::walk_bound_super_type,
class::walk_generic_alias,
function::{FunctionType, walk_function_type},
instance::{walk_nominal_instance_type, walk_protocol_instance_type},
newtype::{NewType, walk_newtype_instance_type},
subclass_of::walk_subclass_of_type,
walk_bound_method_type, walk_bound_type_var_type, walk_callable_type,
walk_bound_method_type, walk_bound_type_var_type, walk_callable_type, walk_generic_context,
walk_intersection_type, walk_known_instance_type, walk_method_wrapper_type,
walk_property_instance_type, walk_type_alias_type, walk_type_var_type,
walk_typed_dict_type, walk_typeguard_type, walk_typeis_type, walk_union,
@@ -60,6 +60,10 @@ pub(crate) trait TypeVisitor<'db> {
walk_subclass_of_type(db, subclass_of, self);
}
fn visit_generic_context(&self, db: &'db dyn Db, context: GenericContext<'db>) {
walk_generic_context(db, context, self);
}
fn visit_generic_alias_type(&self, db: &'db dyn Db, alias: GenericAlias<'db>) {
walk_generic_alias(db, alias, self);
}