Compare commits

...

1 Commits

Author SHA1 Message Date
Ibraheem Ahmed
c781969973 propagate type context through comprehensions 2026-01-13 20:54:40 -05:00
10 changed files with 319 additions and 214 deletions

View File

@@ -650,6 +650,27 @@ reveal_type(x5) # revealed: list[Iterable[Any]]
x6: Iterable[list[Any]] = [[1, 2, 3]]
reveal_type(x6) # revealed: list[list[Any]]
x7: Sequence[Any] = [i for i in [1, 2, 3]]
# TODO: This should infer `list[int]`.
reveal_type(x7) # revealed: list[Unknown | int]
x8: MutableSequence[Any] = [i for i in [1, 2, 3]]
reveal_type(x8) # revealed: list[Any]
x9: Iterable[Any] = [i for i in [1, 2, 3]]
# TODO: This should infer `list[int]`.
reveal_type(x9) # revealed: list[Unknown | int]
x10: Iterable[Iterable[Any]] = [[i] for i in [1, 2, 3]]
# TODO: This should infer `list[list[int]]`.
reveal_type(x10) # revealed: list[list[Unknown | int]]
x11: list[Iterable[Any]] = [[i] for i in [1, 2, 3]]
reveal_type(x11) # revealed: list[Iterable[Any]]
x12: Iterable[list[Any]] = [[i] for i in [1, 2, 3]]
reveal_type(x12) # revealed: list[list[Any]]
class X[T]:
value: T
@@ -660,29 +681,29 @@ class A[T](X[T]): ...
def a[T](value: T) -> A[T]:
return A(value)
x7: A[object] = A(1)
reveal_type(x7) # revealed: A[object]
x13: A[object] = A(1)
reveal_type(x13) # revealed: A[object]
x8: X[object] = A(1)
reveal_type(x8) # revealed: A[object]
x14: X[object] = A(1)
reveal_type(x14) # revealed: A[object]
x9: X[object] | None = A(1)
reveal_type(x9) # revealed: A[object]
x15: X[object] | None = A(1)
reveal_type(x15) # revealed: A[object]
x10: X[object] | None = a(1)
reveal_type(x10) # revealed: A[object]
x16: X[object] | None = a(1)
reveal_type(x16) # revealed: A[object]
def f[T](x: T) -> list[list[T]]:
return [[x]]
x11: Sequence[Sequence[Any]] = f(1)
reveal_type(x11) # revealed: list[list[int]]
x17: Sequence[Sequence[Any]] = f(1)
reveal_type(x17) # revealed: list[list[int]]
x12: Sequence[list[Any]] = f(1)
reveal_type(x12) # revealed: list[list[Any]]
x18: Sequence[list[Any]] = f(1)
reveal_type(x18) # revealed: list[list[Any]]
x13: dict[int, dict[str, int]] = defaultdict(dict)
reveal_type(x13) # revealed: defaultdict[int, dict[str, int]]
x19: dict[int, dict[str, int]] = defaultdict(dict)
reveal_type(x19) # revealed: defaultdict[int, dict[str, int]]
```
## Narrow generic unions

View File

@@ -122,6 +122,17 @@ async def _():
[reveal_type(x) async for x in range(3)]
```
## Comprehension value type
The type of the expression being iterated over is immutable, and so should not be widened with
`Unknown` or through literal promotion:
```py
# TODO: This should reveal `Literal["a", "b"]`
# revealed: Unknown | str
x = [reveal_type(string) for string in ["a", "b"]]
```
## Comprehension expression types
The type of the comprehension expression itself should reflect the inferred element type:
@@ -129,16 +140,16 @@ The type of the comprehension expression itself should reflect the inferred elem
```py
from typing import TypedDict, Literal
# revealed: list[int | Unknown]
# revealed: list[Unknown | int]
reveal_type([x for x in range(10)])
# revealed: set[int | Unknown]
# revealed: set[Unknown | int]
reveal_type({x for x in range(10)})
# revealed: dict[int | Unknown, str | Unknown]
# revealed: dict[Unknown | int, Unknown | str]
reveal_type({x: str(x) for x in range(10)})
# revealed: list[tuple[int, Unknown | str] | Unknown]
# revealed: list[Unknown | tuple[int, Unknown | str]]
reveal_type([(x, y) for x in range(5) for y in ["a", "b", "c"]])
squares: list[int | None] = [x**2 for x in range(10)]
@@ -148,27 +159,34 @@ reveal_type(squares) # revealed: list[int | None]
Inference for comprehensions takes the type context into account:
```py
from typing import Sequence
# Without type context:
reveal_type([x for x in [1, 2, 3]]) # revealed: list[Unknown | int]
reveal_type({x: "a" for x in [1, 2, 3]}) # revealed: dict[Unknown | int, str | Unknown]
reveal_type({str(x): x for x in [1, 2, 3]}) # revealed: dict[str | Unknown, Unknown | int]
reveal_type({x: "a" for x in [1, 2, 3]}) # revealed: dict[Unknown | int, Unknown | str]
reveal_type({str(x): x for x in [1, 2, 3]}) # revealed: dict[Unknown | str, Unknown | int]
reveal_type({x for x in [1, 2, 3]}) # revealed: set[Unknown | int]
# With type context:
xs: list[int] = [x for x in [1, 2, 3]]
reveal_type(xs) # revealed: list[int]
x1: list[int] = [x for x in [1, 2, 3]]
reveal_type(x1) # revealed: list[int]
ys: dict[int, str] = {x: str(x) for x in [1, 2, 3]}
reveal_type(ys) # revealed: dict[int, str]
x2: Sequence[int] = [x for x in [1, 2, 3]]
# TODO: This should reveal `list[int]`.
reveal_type(x2) # revealed: list[Unknown | int]
zs: set[int] = {x for x in [1, 2, 3]}
x3: dict[int, str] = {x: str(x) for x in [1, 2, 3]}
reveal_type(x3) # revealed: dict[int, str]
x4: set[int] = {x for x in [1, 2, 3]}
reveal_type(x4) # revealed: set[int]
```
This also works for nested comprehensions:
```py
table = [[(x, y) for x in range(3)] for y in range(3)]
reveal_type(table) # revealed: list[list[tuple[int, int] | Unknown] | Unknown]
reveal_type(table) # revealed: list[Unknown | list[Unknown | tuple[int, int]]]
table_with_content: list[list[tuple[int, int, str | None]]] = [[(x, y, None) for x in range(3)] for y in range(3)]
reveal_type(table_with_content) # revealed: list[list[tuple[int, int, str | None]]]
@@ -177,25 +195,29 @@ reveal_type(table_with_content) # revealed: list[list[tuple[int, int, str | Non
The type context is propagated down into the comprehension:
```py
y1: list[list[int]] = [[n] for n in [1, 2, 3]]
reveal_type(y1) # revealed: list[list[int]]
y2: list[Sequence[int]] = [[i] for i in [1, 2, 3]]
reveal_type(y2) # revealed: list[Sequence[int]]
class Person(TypedDict):
name: str
# TODO: This should not error.
# error: [invalid-assignment]
persons: list[Person] = [{"name": n} for n in ["Alice", "Bob"]]
reveal_type(persons) # revealed: list[Person]
y3: list[Person] = [{"name": n} for n in ["Alice", "Bob"]]
reveal_type(y3) # revealed: list[Person]
# TODO: This should be an invalid-key error.
# error: [invalid-assignment]
invalid: list[Person] = [{"misspelled": n} for n in ["Alice", "Bob"]]
y4: list[Person] = [{"misspelled": n} for n in ["Alice", "Bob"]]
```
We promote literals to avoid overly-precise types in invariant positions:
```py
reveal_type([x for x in ("a", "b", "c")]) # revealed: list[str | Unknown]
reveal_type({x for x in (1, 2, 3)}) # revealed: set[int | Unknown]
reveal_type({k: 0 for k in ("a", "b", "c")}) # revealed: dict[str | Unknown, int | Unknown]
reveal_type([x for x in ("a", "b", "c")]) # revealed: list[Unknown | str]
reveal_type({x for x in (1, 2, 3)}) # revealed: set[Unknown | int]
reveal_type({k: 0 for k in ("a", "b", "c")}) # revealed: dict[Unknown | str, Unknown | int]
```
Type context can prevent this promotion from happening:

View File

@@ -51,6 +51,6 @@ reveal_type({"a": 1, "b": (1, 2), "c": (1, 2, 3)})
## Dict comprehensions
```py
# revealed: dict[int | Unknown, int | Unknown]
# revealed: dict[Unknown | int, Unknown | int]
reveal_type({x: y for x, y in enumerate(range(42))})
```

View File

@@ -41,5 +41,5 @@ reveal_type([1, (1, 2), (1, 2, 3)])
## List comprehensions
```py
reveal_type([x for x in range(42)]) # revealed: list[int | Unknown]
reveal_type([x for x in range(42)]) # revealed: list[Unknown | int]
```

View File

@@ -35,5 +35,5 @@ reveal_type({1, (1, 2), (1, 2, 3)})
## Set comprehensions
```py
reveal_type({x for x in range(42)}) # revealed: set[int | Unknown]
reveal_type({x for x in range(42)}) # revealed: set[Unknown | int]
```

View File

@@ -15,7 +15,7 @@ use crate::semantic_index::definition::Definition;
use crate::semantic_index::scope::FileScopeId;
use crate::semantic_index::semantic_index;
use crate::types::list_members::{Member, all_members, all_reachable_members};
use crate::types::{Type, binding_type, infer_scope_types};
use crate::types::{Type, TypeContext, binding_type, infer_scope_types};
/// The primary interface the LSP should use for querying semantic information about a [`File`].
///
@@ -358,7 +358,7 @@ impl<'db> SemanticModel<'db> {
let index = semantic_index(self.db, self.file);
let file_scope = index.expression_scope_id(&expr);
let scope = file_scope.to_scope_id(self.db, self.file);
if !infer_scope_types(self.db, scope).is_string_annotation(expr) {
if !infer_scope_types(self.db, scope, TypeContext::default()).is_string_annotation(expr) {
return None;
}
@@ -467,7 +467,7 @@ impl HasType for ast::ExprRef<'_> {
let file_scope = index.try_expression_scope_id(&model.expr_ref_in_ast(*self))?;
let scope = file_scope.to_scope_id(model.db, model.file);
infer_scope_types(model.db, scope).try_expression_type(*self)
infer_scope_types(model.db, scope, TypeContext::default()).try_expression_type(*self)
}
}

View File

@@ -131,7 +131,12 @@ pub fn check_types(db: &dyn Db, file: File) -> Vec<Diagnostic> {
let mut diagnostics = TypeCheckDiagnostics::default();
for scope_id in index.scope_ids() {
let result = infer_scope_types(db, scope_id);
// TODO: Scopes that need type context, e.g., list comprehensions, are inferred during
// the inference of their outer scope, and so end up getting inferred a second time here
// without type context. This does not affect the type inferred in its outer scope, but
// is necessary as some IDE operations may be performed without type context, e.g.,
// hovering in the inner scope.
let result = infer_scope_types(db, scope_id, TypeContext::default());
if let Some(scope_diagnostics) = result.diagnostics() {
diagnostics.extend(scope_diagnostics);
@@ -198,7 +203,7 @@ fn definition_expression_type<'db>(
}
} else {
// expression is in a type-params sub-scope
infer_scope_types(db, scope).expression_type(expression)
infer_scope_types(db, scope, TypeContext::default()).expression_type(expression)
}
}

View File

@@ -64,41 +64,6 @@ mod builder;
#[cfg(test)]
mod tests;
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
/// scope.
#[salsa::tracked(returns(ref), cycle_fn=scope_cycle_recover, cycle_initial=scope_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> {
let file = scope.file(db);
let _span = tracing::trace_span!("infer_scope_types", scope=?scope.as_id(), ?file).entered();
let module = parsed_module(db, file).load(db);
// Using the index here is fine because the code below depends on the AST anyway.
// The isolation of the query is by the return inferred types.
let index = semantic_index(db, file);
TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index, &module).finish_scope()
}
fn scope_cycle_recover<'db>(
db: &'db dyn Db,
cycle: &salsa::Cycle,
previous_inference: &ScopeInference<'db>,
inference: ScopeInference<'db>,
_scope: ScopeId<'db>,
) -> ScopeInference<'db> {
inference.cycle_normalized(db, previous_inference, cycle)
}
fn scope_cycle_initial<'db>(
_db: &'db dyn Db,
id: salsa::Id,
_scope: ScopeId<'db>,
) -> ScopeInference<'db> {
ScopeInference::cycle_initial(Type::divergent(id))
}
/// Infer all types for a [`Definition`] (including sub-expressions).
/// Use when resolving a place use or public type of a place.
#[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
@@ -182,6 +147,53 @@ fn deferred_cycle_initial<'db>(
DefinitionInference::cycle_initial(definition.scope(db), Type::divergent(id))
}
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
/// scope.
pub(crate) fn infer_scope_types<'db>(
db: &'db dyn Db,
scope: ScopeId<'db>,
tcx: TypeContext<'db>,
) -> &'db ScopeInference<'db> {
infer_scope_types_impl(db, InferScope::new(db, scope, tcx))
}
#[salsa::tracked(returns(ref), cycle_fn=scope_cycle_recover, cycle_initial=scope_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
pub(crate) fn infer_scope_types_impl<'db>(
db: &'db dyn Db,
input: InferScope<'db>,
) -> ScopeInference<'db> {
let (scope, tcx) = input.into_inner(db);
let file = scope.file(db);
let _span = tracing::trace_span!("infer_scope_types", scope=?scope.as_id(), ?file).entered();
let module = parsed_module(db, file).load(db);
// Using the index here is fine because the code below depends on the AST anyway.
// The isolation of the query is by the return inferred types.
let index = semantic_index(db, file);
TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope, tcx), index, &module).finish_scope()
}
fn scope_cycle_recover<'db>(
db: &'db dyn Db,
cycle: &salsa::Cycle,
previous_inference: &ScopeInference<'db>,
inference: ScopeInference<'db>,
_input: InferScope<'db>,
) -> ScopeInference<'db> {
inference.cycle_normalized(db, previous_inference, cycle)
}
fn scope_cycle_initial<'db>(
_db: &'db dyn Db,
id: salsa::Id,
_input: InferScope<'db>,
) -> ScopeInference<'db> {
ScopeInference::cycle_initial(Type::divergent(id))
}
/// Infer all types for an [`Expression`] (including sub-expressions).
/// Use rarely; only for cases where we'd otherwise risk double-inferring an expression: RHS of an
/// assignment, which might be unpacking/multi-target and thus part of multiple definitions, or a
@@ -199,7 +211,7 @@ pub(super) fn infer_expression_types_impl<'db>(
db: &'db dyn Db,
input: InferExpression<'db>,
) -> ExpressionInference<'db> {
let (expression, tcx) = (input.expression(db), input.tcx(db));
let (expression, tcx) = input.into_inner(db);
let file = expression.file(db);
let module = parsed_module(db, file).load(db);
@@ -237,8 +249,9 @@ fn expression_cycle_initial<'db>(
id: salsa::Id,
input: InferExpression<'db>,
) -> ExpressionInference<'db> {
let (expression, _) = input.into_inner(db);
let cycle_recovery = Type::divergent(id);
ExpressionInference::cycle_initial(input.expression(db).scope(db), cycle_recovery)
ExpressionInference::cycle_initial(expression.scope(db), cycle_recovery)
}
/// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query.
@@ -273,13 +286,15 @@ pub(crate) fn infer_expression_type<'db>(
#[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn infer_expression_type_impl<'db>(db: &'db dyn Db, input: InferExpression<'db>) -> Type<'db> {
let file = input.expression(db).file(db);
let (expression, _) = input.into_inner(db);
let file = expression.file(db);
let module = parsed_module(db, file).load(db);
// It's okay to call the "same file" version here because we're inside a salsa query.
let inference = infer_expression_types_impl(db, input);
inference.expression_type(input.expression(db).node_ref(db, &module))
inference.expression_type(expression.node_ref(db, &module))
}
fn single_expression_cycle_recover<'db>(
@@ -310,6 +325,12 @@ pub(super) enum InferExpression<'db> {
WithContext(ExpressionWithContext<'db>),
}
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
pub(super) struct ExpressionWithContext<'db> {
expression: Expression<'db>,
tcx: TypeContext<'db>,
}
impl<'db> InferExpression<'db> {
pub(super) fn new(
db: &'db dyn Db,
@@ -319,37 +340,57 @@ impl<'db> InferExpression<'db> {
if tcx.annotation.is_some() {
InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx))
} else {
// Drop the empty `TypeContext` to avoid the interning cost.
InferExpression::Bare(expression)
}
}
fn expression(self, db: &'db dyn Db) -> Expression<'db> {
fn into_inner(self, db: &'db dyn Db) -> (Expression<'db>, TypeContext<'db>) {
match self {
InferExpression::Bare(expression) => expression,
InferExpression::WithContext(expression_with_context) => {
expression_with_context.expression(db)
}
}
}
fn tcx(self, db: &'db dyn Db) -> TypeContext<'db> {
match self {
InferExpression::Bare(_) => TypeContext::default(),
InferExpression::WithContext(expression_with_context) => {
expression_with_context.tcx(db)
}
InferExpression::Bare(expression) => (expression, TypeContext::default()),
InferExpression::WithContext(expression_with_context) => (
expression_with_context.expression(db),
expression_with_context.tcx(db),
),
}
}
}
/// An `Expression` with a `TypeContext`.
/// A `ScopeId` with an optional `TypeContext`.
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update)]
pub(super) enum InferScope<'db> {
Bare(ScopeId<'db>),
WithContext(ScopeWithContext<'db>),
}
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
pub(super) struct ExpressionWithContext<'db> {
expression: Expression<'db>,
pub(super) struct ScopeWithContext<'db> {
scope: ScopeId<'db>,
tcx: TypeContext<'db>,
}
impl<'db> InferScope<'db> {
pub(super) fn new(
db: &'db dyn Db,
scope: ScopeId<'db>,
tcx: TypeContext<'db>,
) -> InferScope<'db> {
if tcx.annotation.is_some() {
InferScope::WithContext(ScopeWithContext::new(db, scope, tcx))
} else {
InferScope::Bare(scope)
}
}
fn into_inner(self, db: &'db dyn Db) -> (ScopeId<'db>, TypeContext<'db>) {
match self {
InferScope::Bare(scope) => (scope, TypeContext::default()),
InferScope::WithContext(scope_with_context) => {
(scope_with_context.scope(db), scope_with_context.tcx(db))
}
}
}
}
/// The type context for a given expression, namely the type annotation
/// in an annotated assignment.
///
@@ -513,7 +554,7 @@ pub(crate) enum InferenceRegion<'db> {
/// infer deferred types for a [`Definition`]
Deferred(Definition<'db>),
/// infer types for an entire [`ScopeId`]
Scope(ScopeId<'db>),
Scope(ScopeId<'db>, TypeContext<'db>),
}
impl<'db> InferenceRegion<'db> {
@@ -523,7 +564,7 @@ impl<'db> InferenceRegion<'db> {
InferenceRegion::Definition(definition) | InferenceRegion::Deferred(definition) => {
definition.scope(db)
}
InferenceRegion::Scope(scope) => scope,
InferenceRegion::Scope(scope, _) => scope,
}
}
}

View File

@@ -517,17 +517,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let file_scope = self.index.expression_scope_id(expression);
let expr_scope = file_scope.to_scope_id(self.db(), self.file());
match self.region {
InferenceRegion::Scope(scope) if scope == expr_scope => {
InferenceRegion::Scope(scope, _) if scope == expr_scope => {
self.expression_type(expression)
}
_ => infer_scope_types(self.db(), expr_scope).expression_type(expression),
_ => infer_scope_types(self.db(), expr_scope, TypeContext::default())
.expression_type(expression),
}
}
/// Infers types in the given [`InferenceRegion`].
fn infer_region(&mut self) {
match self.region {
InferenceRegion::Scope(scope) => self.infer_region_scope(scope),
InferenceRegion::Scope(scope, tcx) => self.infer_region_scope(scope, tcx),
InferenceRegion::Definition(definition) => self.infer_region_definition(definition),
InferenceRegion::Deferred(definition) => self.infer_region_deferred(definition),
InferenceRegion::Expression(expression, tcx) => {
@@ -536,7 +537,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
fn infer_region_scope(&mut self, scope: ScopeId<'db>) {
fn infer_region_scope(&mut self, scope: ScopeId<'db>, tcx: TypeContext<'db>) {
let node = scope.node(self.db());
match node {
NodeWithScopeKind::Module => {
@@ -560,13 +561,22 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.infer_type_alias(type_alias.node(self.module()));
}
NodeWithScopeKind::ListComprehension(comprehension) => {
self.infer_list_comprehension_expression_scope(comprehension.node(self.module()));
self.infer_list_comprehension_expression_scope(
comprehension.node(self.module()),
tcx,
);
}
NodeWithScopeKind::SetComprehension(comprehension) => {
self.infer_set_comprehension_expression_scope(comprehension.node(self.module()));
self.infer_set_comprehension_expression_scope(
comprehension.node(self.module()),
tcx,
);
}
NodeWithScopeKind::DictComprehension(comprehension) => {
self.infer_dict_comprehension_expression_scope(comprehension.node(self.module()));
self.infer_dict_comprehension_expression_scope(
comprehension.node(self.module()),
tcx,
);
}
NodeWithScopeKind::GeneratorExpression(generator) => {
self.infer_generator_expression_scope(generator.node(self.module()));
@@ -8448,9 +8458,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = list;
let mut elts = elts.iter().map(|elt| [Some(elt)]);
let infer_elt_ty =
&mut |builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
self.infer_collection_literal(&mut elts, tcx, infer_elt_ty, KnownClass::List)
let mut infer_elt_ty =
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
self.infer_collection_literal(KnownClass::List, &mut elts, &mut infer_elt_ty, tcx)
.unwrap_or_else(|| {
KnownClass::List.to_specialized_instance(self.db(), &[Type::unknown()])
})
@@ -8464,9 +8475,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = set;
let mut elts = elts.iter().map(|elt| [Some(elt)]);
let infer_elt_ty =
&mut |builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
self.infer_collection_literal(&mut elts, tcx, infer_elt_ty, KnownClass::Set)
let mut infer_elt_ty =
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
self.infer_collection_literal(KnownClass::Set, &mut elts, &mut infer_elt_ty, tcx)
.unwrap_or_else(|| {
KnownClass::Set.to_specialized_instance(self.db(), &[Type::unknown()])
})
@@ -8556,14 +8568,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Avoid inferring the items multiple times if we already attempted to infer the
// dictionary literal as a `TypedDict`. This also allows us to infer using the
// type context of the expected `TypedDict` field.
let infer_elt_ty = &mut |builder: &mut Self, (_, elt, tcx): ArgExpr<'db, '_>| {
let mut infer_elt_ty = |builder: &mut Self, (_, elt, tcx): ArgExpr<'db, '_>| {
item_types
.get(&elt.node_index().load())
.copied()
.unwrap_or_else(|| builder.infer_expression(elt, tcx))
};
self.infer_collection_literal(&mut items, tcx, infer_elt_ty, KnownClass::Dict)
self.infer_collection_literal(KnownClass::Dict, &mut items, &mut infer_elt_ty, tcx)
.unwrap_or_else(|| {
KnownClass::Dict
.to_specialized_instance(self.db(), &[Type::unknown(), Type::unknown()])
@@ -8614,10 +8626,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Infer the type of a collection literal expression.
fn infer_collection_literal<'expr, const N: usize>(
&mut self,
elts: &mut dyn Iterator<Item = [Option<&'expr ast::Expr>; N]>,
tcx: TypeContext<'db>,
infer_elt_expression: &mut dyn FnMut(&mut Self, ArgExpr<'db, 'expr>) -> Type<'db>,
collection_class: KnownClass,
elts: &mut dyn Iterator<Item = [Option<&'expr ast::Expr>; N]>,
infer_elt_expression: &mut dyn FnMut(&mut Self, ArgExpr<'db, 'expr>) -> Type<'db>,
tcx: TypeContext<'db>,
) -> Option<Type<'db>> {
// Extract the type variable `T` from `list[T]` in typeshed.
let elt_tys = |collection_class: KnownClass| {
@@ -8638,7 +8650,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
};
let Some((collection_alias, generic_context, elt_tys)) = elt_tys(collection_class) else {
// Infer the element types without type context, and fallback to unknown for
// Infer the element types without type context, and fallback to `Unknown` for
// custom typesheds.
for (i, elt) in elts.flatten().flatten().enumerate() {
infer_elt_expression(self, (i, elt, TypeContext::default()));
@@ -8660,10 +8672,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Collect type constraints from the declared element types.
let (elt_tcx_constraints, elt_tcx_variance) = {
let mut builder = SpecializationBuilder::new(
self.db(),
generic_context.inferable_typevars(self.db()),
);
let mut builder = SpecializationBuilder::new(self.db(), inferable);
// For a given type variable, we keep track of the variance of any assignments to
// that type variable in the type context.
@@ -8829,7 +8838,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return Type::unknown();
};
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let inference = infer_scope_types(self.db(), scope, TypeContext::default());
let yield_type = inference.expression_type(elt.as_ref());
if evaluation_mode.is_async() {
@@ -8845,47 +8854,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred
/// element / key-value types from the comprehension expression.
fn infer_comprehension_specialization(
&self,
fn infer_comprehension_specialization<const N: usize>(
&mut self,
collection_class: KnownClass,
inferred_element_types: &[Type<'db>],
elements: &[Option<&ast::Expr>; N],
inference: &ScopeInference<'db>,
tcx: TypeContext<'db>,
) -> Type<'db> {
// Remove any union elements of that are unrelated to the collection type.
let tcx = tcx.map(|annotation| {
annotation.filter_disjoint_elements(
self.db(),
collection_class.to_instance(self.db()),
InferableTypeVars::None,
)
});
) -> Option<Type<'db>> {
let mut elements = [elements].into_iter().copied();
let mut infer_element_ty =
|_builder: &mut Self, (_, elt, _)| inference.expression_type(elt);
if let Some(annotated_element_types) = tcx
.known_specialization(self.db(), collection_class)
.map(|specialization| specialization.types(self.db()))
&& annotated_element_types
.iter()
.zip(inferred_element_types.iter())
.all(|(annotated, inferred)| inferred.is_assignable_to(self.db(), *annotated))
{
collection_class.to_specialized_instance(self.db(), annotated_element_types)
} else {
collection_class.to_specialized_instance(
self.db(),
inferred_element_types
.iter()
.map(|ty| {
UnionType::from_elements(
self.db(),
[
ty.promote_literals(self.db(), TypeContext::default()),
Type::unknown(),
],
)
})
.collect::<Vec<_>>(),
)
}
self.infer_collection_literal(collection_class, &mut elements, &mut infer_element_ty, tcx)
}
fn infer_list_comprehension_expression(
@@ -8909,10 +8889,41 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return Type::unknown();
};
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let element_type = inference.expression_type(elt.as_ref());
let inference = infer_scope_types(self.db(), scope, tcx);
self.infer_comprehension_specialization(KnownClass::List, &[element_type], tcx)
self.infer_comprehension_specialization(KnownClass::List, &[Some(elt)], inference, tcx)
.unwrap_or_else(|| {
KnownClass::List.to_specialized_instance(self.db(), &[Type::unknown()])
})
}
fn infer_set_comprehension_expression(
&mut self,
setcomp: &ast::ExprSetComp,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprSetComp {
range: _,
node_index: _,
elt,
generators,
} = setcomp;
self.infer_first_comprehension_iter(generators);
let Some(scope_id) = self
.index
.try_node_scope(NodeWithScopeRef::SetComprehension(setcomp))
else {
return Type::unknown();
};
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope, tcx);
self.infer_comprehension_specialization(KnownClass::Set, &[Some(elt)], inference, tcx)
.unwrap_or_else(|| {
KnownClass::Set.to_specialized_instance(self.db(), &[Type::unknown()])
})
}
fn infer_dict_comprehension_expression(
@@ -8937,38 +8948,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return Type::unknown();
};
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let key_type = inference.expression_type(key.as_ref());
let value_type = inference.expression_type(value.as_ref());
let inference = infer_scope_types(self.db(), scope, tcx);
self.infer_comprehension_specialization(KnownClass::Dict, &[key_type, value_type], tcx)
}
fn infer_set_comprehension_expression(
&mut self,
setcomp: &ast::ExprSetComp,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprSetComp {
range: _,
node_index: _,
elt,
generators,
} = setcomp;
self.infer_first_comprehension_iter(generators);
let Some(scope_id) = self
.index
.try_node_scope(NodeWithScopeRef::SetComprehension(setcomp))
else {
return Type::unknown();
};
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let element_type = inference.expression_type(elt.as_ref());
self.infer_comprehension_specialization(KnownClass::Set, &[element_type], tcx)
self.infer_comprehension_specialization(
KnownClass::Dict,
&[Some(key), Some(value)],
inference,
tcx,
)
.unwrap_or_else(|| {
KnownClass::Dict.to_specialized_instance(self.db(), &[Type::unknown(), Type::unknown()])
})
}
fn infer_generator_expression_scope(&mut self, generator: &ast::ExprGenerator) {
@@ -8984,7 +8974,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.infer_comprehensions(generators);
}
fn infer_list_comprehension_expression_scope(&mut self, listcomp: &ast::ExprListComp) {
fn infer_list_comprehension_expression_scope(
&mut self,
listcomp: &ast::ExprListComp,
tcx: TypeContext<'db>,
) {
let ast::ExprListComp {
range: _,
node_index: _,
@@ -8992,11 +8986,41 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generators,
} = listcomp;
self.infer_expression(elt, TypeContext::default());
// Infer the element type using the outer type context.
let mut elts = [[Some(elt.as_ref())]].into_iter();
let mut infer_elt_ty =
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
self.infer_collection_literal(KnownClass::List, &mut elts, &mut infer_elt_ty, tcx);
self.infer_comprehensions(generators);
}
fn infer_dict_comprehension_expression_scope(&mut self, dictcomp: &ast::ExprDictComp) {
fn infer_set_comprehension_expression_scope(
&mut self,
setcomp: &ast::ExprSetComp,
tcx: TypeContext<'db>,
) {
let ast::ExprSetComp {
range: _,
node_index: _,
elt,
generators,
} = setcomp;
// Infer the element type using the outer type context.
let mut elts = [[Some(elt.as_ref())]].into_iter();
let mut infer_elt_ty =
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
self.infer_collection_literal(KnownClass::Set, &mut elts, &mut infer_elt_ty, tcx);
self.infer_comprehensions(generators);
}
fn infer_dict_comprehension_expression_scope(
&mut self,
dictcomp: &ast::ExprDictComp,
tcx: TypeContext<'db>,
) {
let ast::ExprDictComp {
range: _,
node_index: _,
@@ -9005,20 +9029,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generators,
} = dictcomp;
self.infer_expression(key, TypeContext::default());
self.infer_expression(value, TypeContext::default());
self.infer_comprehensions(generators);
}
// Infer the key and value types using the outer type context.
let mut elts = [[Some(key.as_ref()), Some(value.as_ref())]].into_iter();
let mut infer_elt_ty =
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
self.infer_collection_literal(KnownClass::Dict, &mut elts, &mut infer_elt_ty, tcx);
fn infer_set_comprehension_expression_scope(&mut self, setcomp: &ast::ExprSetComp) {
let ast::ExprSetComp {
range: _,
node_index: _,
elt,
generators,
} = setcomp;
self.infer_expression(elt, TypeContext::default());
self.infer_comprehensions(generators);
}

View File

@@ -54,7 +54,7 @@ fn function_signature_expression_type<'db>(
infer_deferred_types(db, definition).expression_type(expression)
} else {
// expression is in the PEP-695 type params sub-scope
infer_scope_types(db, scope).expression_type(expression)
infer_scope_types(db, scope, TypeContext::default()).expression_type(expression)
}
}