[red-knot] infer attribute assignments bound in comprehensions (#17396)

## Summary

This PR is a follow-up to #16852.

Instance variables bound in comprehensions are recorded, allowing type
inference to work correctly.

This required adding support for unpacking in comprehension which
resolves https://github.com/astral-sh/ruff/issues/15369.

## Test Plan

One TODO in `mdtest/attributes.md` is now resolved, and some new test
cases are added.

---------

Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com>
This commit is contained in:
Shunsuke Shibayama
2025-04-19 10:12:48 +09:00
committed by GitHub
parent 2a478ce1b2
commit da6b68cb58
10 changed files with 349 additions and 108 deletions

View File

@@ -940,7 +940,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
panic!("expected generator definition")
};
let target = comprehension.target();
let name = target.id().as_str();
let name = target.as_name_expr().unwrap().id().as_str();
assert_eq!(name, "x");
assert_eq!(target.range(), TextRange::new(23.into(), 24.into()));

View File

@@ -18,11 +18,12 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIdsBuilder;
use crate::semantic_index::definition::{
AnnotatedAssignmentDefinitionKind, AnnotatedAssignmentDefinitionNodeRef,
AssignmentDefinitionKind, AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef,
Definition, DefinitionCategory, DefinitionKind, DefinitionNodeKey, DefinitionNodeRef,
Definitions, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionKind, ForStmtDefinitionNodeRef,
ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, MatchPatternDefinitionNodeRef,
StarImportDefinitionNodeRef, TargetKind, WithItemDefinitionKind, WithItemDefinitionNodeRef,
AssignmentDefinitionKind, AssignmentDefinitionNodeRef, ComprehensionDefinitionKind,
ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionKind,
DefinitionNodeKey, DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef,
ForStmtDefinitionKind, ForStmtDefinitionNodeRef, ImportDefinitionNodeRef,
ImportFromDefinitionNodeRef, MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef,
TargetKind, WithItemDefinitionKind, WithItemDefinitionNodeRef,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::predicate::{
@@ -850,31 +851,35 @@ impl<'db> SemanticIndexBuilder<'db> {
// The `iter` of the first generator is evaluated in the outer scope, while all subsequent
// nodes are evaluated in the inner scope.
self.add_standalone_expression(&generator.iter);
let value = self.add_standalone_expression(&generator.iter);
self.visit_expr(&generator.iter);
self.push_scope(scope);
self.push_assignment(CurrentAssignment::Comprehension {
node: generator,
first: true,
});
self.visit_expr(&generator.target);
self.pop_assignment();
self.add_unpackable_assignment(
&Unpackable::Comprehension {
node: generator,
first: true,
},
&generator.target,
value,
);
for expr in &generator.ifs {
self.visit_expr(expr);
}
for generator in generators_iter {
self.add_standalone_expression(&generator.iter);
let value = self.add_standalone_expression(&generator.iter);
self.visit_expr(&generator.iter);
self.push_assignment(CurrentAssignment::Comprehension {
node: generator,
first: false,
});
self.visit_expr(&generator.target);
self.pop_assignment();
self.add_unpackable_assignment(
&Unpackable::Comprehension {
node: generator,
first: false,
},
&generator.target,
value,
);
for expr in &generator.ifs {
self.visit_expr(expr);
@@ -933,9 +938,30 @@ impl<'db> SemanticIndexBuilder<'db> {
let current_assignment = match target {
ast::Expr::List(_) | ast::Expr::Tuple(_) => {
if matches!(unpackable, Unpackable::Comprehension { .. }) {
debug_assert_eq!(
self.scopes[self.current_scope()].node().scope_kind(),
ScopeKind::Comprehension
);
}
// The first iterator of the comprehension is evaluated in the outer scope, while all subsequent
// nodes are evaluated in the inner scope.
// SAFETY: The current scope is the comprehension, and the comprehension scope must have a parent scope.
let value_file_scope =
if let Unpackable::Comprehension { first: true, .. } = unpackable {
self.scope_stack
.iter()
.rev()
.nth(1)
.expect("The comprehension scope must have a parent scope")
.file_scope_id
} else {
self.current_scope()
};
let unpack = Some(Unpack::new(
self.db,
self.file,
value_file_scope,
self.current_scope(),
// SAFETY: `target` belongs to the `self.module` tree
#[allow(unsafe_code)]
@@ -1804,7 +1830,7 @@ where
let node_key = NodeKey::from_node(expr);
match expr {
ast::Expr::Name(name_node @ ast::ExprName { id, ctx, .. }) => {
ast::Expr::Name(ast::ExprName { id, ctx, .. }) => {
let (is_use, is_definition) = match (ctx, self.current_assignment()) {
(ast::ExprContext::Store, Some(CurrentAssignment::AugAssign(_))) => {
// For augmented assignment, the target expression is also used.
@@ -1867,12 +1893,17 @@ where
// implemented.
self.add_definition(symbol, named);
}
Some(CurrentAssignment::Comprehension { node, first }) => {
Some(CurrentAssignment::Comprehension {
unpack,
node,
first,
}) => {
self.add_definition(
symbol,
ComprehensionDefinitionNodeRef {
unpack,
iterable: &node.iter,
target: name_node,
target: expr,
first,
is_async: node.is_async,
},
@@ -2143,14 +2174,37 @@ where
DefinitionKind::WithItem(assignment),
);
}
Some(CurrentAssignment::Comprehension { .. }) => {
// TODO:
Some(CurrentAssignment::Comprehension {
unpack,
node,
first,
}) => {
// SAFETY: `iter` and `expr` belong to the `self.module` tree
#[allow(unsafe_code)]
let assignment = ComprehensionDefinitionKind {
target_kind: TargetKind::from(unpack),
iterable: unsafe {
AstNodeRef::new(self.module.clone(), &node.iter)
},
target: unsafe { AstNodeRef::new(self.module.clone(), expr) },
first,
is_async: node.is_async,
};
// Temporarily move to the scope of the method to which the instance attribute is defined.
// SAFETY: `self.scope_stack` is not empty because the targets in comprehensions should always introduce a new scope.
let scope = self.scope_stack.pop().expect("The popped scope must be a comprehension, which must have a parent scope");
self.register_attribute_assignment(
object,
attr,
DefinitionKind::Comprehension(assignment),
);
self.scope_stack.push(scope);
}
Some(CurrentAssignment::AugAssign(_)) => {
// TODO:
}
Some(CurrentAssignment::Named(_)) => {
// TODO:
// A named expression whose target is an attribute is syntactically prohibited
}
None => {}
}
@@ -2244,6 +2298,7 @@ enum CurrentAssignment<'a> {
Comprehension {
node: &'a ast::Comprehension,
first: bool,
unpack: Option<(UnpackPosition, Unpack<'a>)>,
},
WithItem {
item: &'a ast::WithItem,
@@ -2257,11 +2312,9 @@ impl CurrentAssignment<'_> {
match self {
Self::Assign { unpack, .. }
| Self::For { unpack, .. }
| Self::WithItem { unpack, .. } => unpack.as_mut().map(|(position, _)| position),
Self::AnnAssign(_)
| Self::AugAssign(_)
| Self::Named(_)
| Self::Comprehension { .. } => None,
| Self::WithItem { unpack, .. }
| Self::Comprehension { unpack, .. } => unpack.as_mut().map(|(position, _)| position),
Self::AnnAssign(_) | Self::AugAssign(_) | Self::Named(_) => None,
}
}
}
@@ -2316,13 +2369,17 @@ enum Unpackable<'a> {
item: &'a ast::WithItem,
is_async: bool,
},
Comprehension {
first: bool,
node: &'a ast::Comprehension,
},
}
impl<'a> Unpackable<'a> {
const fn kind(&self) -> UnpackKind {
match self {
Unpackable::Assign(_) => UnpackKind::Assign,
Unpackable::For(_) => UnpackKind::Iterable,
Unpackable::For(_) | Unpackable::Comprehension { .. } => UnpackKind::Iterable,
Unpackable::WithItem { .. } => UnpackKind::ContextManager,
}
}
@@ -2337,6 +2394,11 @@ impl<'a> Unpackable<'a> {
is_async: *is_async,
unpack,
},
Unpackable::Comprehension { node, first } => CurrentAssignment::Comprehension {
node,
first: *first,
unpack,
},
}
}
}

View File

@@ -281,8 +281,9 @@ pub(crate) struct ExceptHandlerDefinitionNodeRef<'a> {
#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>,
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
pub(crate) target: &'a ast::Expr,
pub(crate) first: bool,
pub(crate) is_async: bool,
}
@@ -374,11 +375,13 @@ impl<'db> DefinitionNodeRef<'db> {
is_async,
}),
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
unpack,
iterable,
target,
first,
is_async,
}) => DefinitionKind::Comprehension(ComprehensionDefinitionKind {
target_kind: TargetKind::from(unpack),
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
first,
@@ -474,7 +477,9 @@ impl<'db> DefinitionNodeRef<'db> {
unpack: _,
is_async: _,
}) => DefinitionNodeKey(NodeKey::from_node(target)),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => {
DefinitionNodeKey(NodeKey::from_node(target))
}
Self::VariadicPositionalParameter(node) => node.into(),
Self::VariadicKeywordParameter(node) => node.into(),
Self::Parameter(node) => node.into(),
@@ -550,7 +555,7 @@ pub enum DefinitionKind<'db> {
AnnotatedAssignment(AnnotatedAssignmentDefinitionKind),
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
For(ForStmtDefinitionKind<'db>),
Comprehension(ComprehensionDefinitionKind),
Comprehension(ComprehensionDefinitionKind<'db>),
VariadicPositionalParameter(AstNodeRef<ast::Parameter>),
VariadicKeywordParameter(AstNodeRef<ast::Parameter>),
Parameter(AstNodeRef<ast::ParameterWithDefault>),
@@ -749,19 +754,24 @@ impl MatchPatternDefinitionKind {
}
#[derive(Clone, Debug)]
pub struct ComprehensionDefinitionKind {
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
first: bool,
is_async: bool,
pub struct ComprehensionDefinitionKind<'db> {
pub(super) target_kind: TargetKind<'db>,
pub(super) iterable: AstNodeRef<ast::Expr>,
pub(super) target: AstNodeRef<ast::Expr>,
pub(super) first: bool,
pub(super) is_async: bool,
}
impl ComprehensionDefinitionKind {
impl<'db> ComprehensionDefinitionKind<'db> {
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
}
pub(crate) fn target(&self) -> &ast::ExprName {
pub(crate) fn target_kind(&self) -> TargetKind<'db> {
self.target_kind
}
pub(crate) fn target(&self) -> &ast::Expr {
self.target.node()
}

View File

@@ -1416,14 +1416,42 @@ impl<'db> ClassLiteralType<'db> {
}
}
}
DefinitionKind::Comprehension(_) => {
// TODO:
DefinitionKind::Comprehension(comprehension) => {
match comprehension.target_kind() {
TargetKind::Sequence(_, unpack) => {
// We found an unpacking assignment like:
//
// [... for .., self.name, .. in <iterable>]
let unpacked = infer_unpack_types(db, unpack);
let target_ast_id = comprehension
.target()
.scoped_expression_id(db, unpack.target_scope(db));
let inferred_ty = unpacked.expression_type(target_ast_id);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
}
TargetKind::NameOrAttribute => {
// We found an attribute assignment like:
//
// [... for self.name in <iterable>]
let iterable_ty = infer_expression_type(
db,
index.expression(comprehension.iterable()),
);
// TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty = iterable_ty.iterate(db);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
}
}
}
DefinitionKind::AugmentedAssignment(_) => {
// TODO:
}
DefinitionKind::NamedExpression(_) => {
// TODO:
// A named expression whose target is an attribute is syntactically prohibited
}
_ => {}
}

View File

@@ -49,9 +49,9 @@ use crate::module_resolver::resolve_module;
use crate::node_key::NodeKey;
use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId, ScopedExpressionId};
use crate::semantic_index::definition::{
AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, Definition, DefinitionKind,
DefinitionNodeKey, ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind,
WithItemDefinitionKind,
AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind,
Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind,
ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::symbol::{
@@ -306,7 +306,7 @@ pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> U
let _span =
tracing::trace_span!("infer_unpack_types", range=?unpack.range(db), ?file).entered();
let mut unpacker = Unpacker::new(db, unpack.scope(db));
let mut unpacker = Unpacker::new(db, unpack.target_scope(db), unpack.value_scope(db));
unpacker.unpack(unpack.target(db), unpack.value(db));
unpacker.finish()
}
@@ -946,13 +946,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_named_expression_definition(named_expression.node(), definition);
}
DefinitionKind::Comprehension(comprehension) => {
self.infer_comprehension_definition(
comprehension.iterable(),
comprehension.target(),
comprehension.is_first(),
comprehension.is_async(),
definition,
);
self.infer_comprehension_definition(comprehension, definition);
}
DefinitionKind::VariadicPositionalParameter(parameter) => {
self.infer_variadic_positional_parameter_definition(parameter, definition);
@@ -1937,11 +1931,13 @@ impl<'db> TypeInferenceBuilder<'db> {
for item in items {
let target = item.optional_vars.as_deref();
if let Some(target) = target {
self.infer_target(target, &item.context_expr, |db, ctx_manager_ty| {
self.infer_target(target, &item.context_expr, |builder, context_expr| {
// TODO: `infer_with_statement_definition` reports a diagnostic if `ctx_manager_ty` isn't a context manager
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `with not_context_manager as a.x: ...
ctx_manager_ty.enter(db)
builder
.infer_standalone_expression(context_expr)
.enter(builder.db())
});
} else {
// Call into the context expression inference to validate that it evaluates
@@ -2347,7 +2343,9 @@ impl<'db> TypeInferenceBuilder<'db> {
} = assignment;
for target in targets {
self.infer_target(target, value, |_, ty| ty);
self.infer_target(target, value, |builder, value_expr| {
builder.infer_standalone_expression(value_expr)
});
}
}
@@ -2357,23 +2355,16 @@ impl<'db> TypeInferenceBuilder<'db> {
/// targets (unpacking). If `target` is an attribute expression, we check that the assignment
/// is valid. For 'target's that are definitions, this check happens elsewhere.
///
/// The `to_assigned_ty` function is used to convert the inferred type of the `value` expression
/// to the type that is eventually assigned to the `target`.
///
/// # Panics
///
/// If the `value` is not a standalone expression.
fn infer_target<F>(&mut self, target: &ast::Expr, value: &ast::Expr, to_assigned_ty: F)
/// The `infer_value_expr` function is used to infer the type of the `value` expression which
/// are not `Name` expressions. The returned type is the one that is eventually assigned to the
/// `target`.
fn infer_target<F>(&mut self, target: &ast::Expr, value: &ast::Expr, infer_value_expr: F)
where
F: Fn(&'db dyn Db, Type<'db>) -> Type<'db>,
F: Fn(&mut TypeInferenceBuilder<'db>, &ast::Expr) -> Type<'db>,
{
let assigned_ty = match target {
ast::Expr::Name(_) => None,
_ => {
let value_ty = self.infer_standalone_expression(value);
Some(to_assigned_ty(self.db(), value_ty))
}
_ => Some(infer_value_expr(self, value)),
};
self.infer_target_impl(target, assigned_ty);
}
@@ -3126,11 +3117,13 @@ impl<'db> TypeInferenceBuilder<'db> {
is_async: _,
} = for_statement;
self.infer_target(target, iter, |db, iter_ty| {
self.infer_target(target, iter, |builder, iter_expr| {
// TODO: `infer_for_statement_definition` reports a diagnostic if `iter_ty` isn't iterable
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `for a.x in not_iterable: ...
iter_ty.iterate(db)
builder
.infer_standalone_expression(iter_expr)
.iterate(builder.db())
});
self.infer_body(body);
@@ -3959,15 +3952,17 @@ impl<'db> TypeInferenceBuilder<'db> {
is_async: _,
} = comprehension;
if !is_first {
self.infer_standalone_expression(iter);
}
// TODO more complex assignment targets
if let ast::Expr::Name(name) = target {
self.infer_definition(name);
} else {
self.infer_expression(target);
}
self.infer_target(target, iter, |builder, iter_expr| {
// TODO: `infer_comprehension_definition` reports a diagnostic if `iter_ty` isn't iterable
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `[... for a.x in not_iterable]
if is_first {
infer_same_file_expression_type(builder.db(), builder.index.expression(iter_expr))
} else {
builder.infer_standalone_expression(iter_expr)
}
.iterate(builder.db())
});
for expr in ifs {
self.infer_expression(expr);
}
@@ -3975,12 +3970,12 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_comprehension_definition(
&mut self,
iterable: &ast::Expr,
target: &ast::ExprName,
is_first: bool,
is_async: bool,
comprehension: &ComprehensionDefinitionKind<'db>,
definition: Definition<'db>,
) {
let iterable = comprehension.iterable();
let target = comprehension.target();
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db(), expression);
@@ -3990,7 +3985,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// (2) We must *not* call `self.extend()` on the result of the type inference,
// because `ScopedExpressionId`s are only meaningful within their own scope, so
// we'd add types for random wrong expressions in the current scope
let iterable_type = if is_first {
let iterable_type = if comprehension.is_first() {
let lookup_scope = self
.index
.parent_scope_id(self.scope().file_scope_id(self.db()))
@@ -4002,14 +3997,26 @@ impl<'db> TypeInferenceBuilder<'db> {
result.expression_type(iterable.scoped_expression_id(self.db(), self.scope()))
};
let target_type = if is_async {
let target_type = if comprehension.is_async() {
// TODO: async iterables/iterators! -- Alex
todo_type!("async iterables/iterators")
} else {
iterable_type.try_iterate(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, iterable_type, iterable.into());
err.fallback_element_type(self.db())
})
match comprehension.target_kind() {
TargetKind::Sequence(unpack_position, unpack) => {
let unpacked = infer_unpack_types(self.db(), unpack);
if unpack_position == UnpackPosition::First {
self.context.extend(unpacked.diagnostics());
}
let target_ast_id = target.scoped_expression_id(self.db(), self.scope());
unpacked.expression_type(target_ast_id)
}
TargetKind::NameOrAttribute => {
iterable_type.try_iterate(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, iterable_type, iterable.into());
err.fallback_element_type(self.db())
})
}
}
};
self.types.expressions.insert(

View File

@@ -18,16 +18,22 @@ use super::{TupleType, UnionType};
/// Unpacks the value expression type to their respective targets.
pub(crate) struct Unpacker<'db> {
context: InferContext<'db>,
scope: ScopeId<'db>,
target_scope: ScopeId<'db>,
value_scope: ScopeId<'db>,
targets: FxHashMap<ScopedExpressionId, Type<'db>>,
}
impl<'db> Unpacker<'db> {
pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>) -> Self {
pub(crate) fn new(
db: &'db dyn Db,
target_scope: ScopeId<'db>,
value_scope: ScopeId<'db>,
) -> Self {
Self {
context: InferContext::new(db, scope),
context: InferContext::new(db, target_scope),
targets: FxHashMap::default(),
scope,
target_scope,
value_scope,
}
}
@@ -43,7 +49,7 @@ impl<'db> Unpacker<'db> {
);
let value_type = infer_expression_types(self.db(), value.expression())
.expression_type(value.scoped_expression_id(self.db(), self.scope));
.expression_type(value.scoped_expression_id(self.db(), self.value_scope));
let value_type = match value.kind() {
UnpackKind::Assign => {
@@ -79,8 +85,10 @@ impl<'db> Unpacker<'db> {
) {
match target {
ast::Expr::Name(_) | ast::Expr::Attribute(_) => {
self.targets
.insert(target.scoped_expression_id(self.db(), self.scope), value_ty);
self.targets.insert(
target.scoped_expression_id(self.db(), self.target_scope),
value_ty,
);
}
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
self.unpack_inner(value, value_expr, value_ty);

View File

@@ -30,7 +30,9 @@ use crate::Db;
pub(crate) struct Unpack<'db> {
pub(crate) file: File,
pub(crate) file_scope: FileScopeId,
pub(crate) value_file_scope: FileScopeId,
pub(crate) target_file_scope: FileScopeId,
/// The target expression that is being unpacked. For example, in `(a, b) = (1, 2)`, the target
/// expression is `(a, b)`.
@@ -47,9 +49,19 @@ pub(crate) struct Unpack<'db> {
}
impl<'db> Unpack<'db> {
/// Returns the scope where the unpacking is happening.
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
/// Returns the scope in which the unpack value expression belongs.
///
/// The scope in which the target and value expression belongs to are usually the same
/// except in generator expressions and comprehensions (list/dict/set), where the value
/// expression of the first generator is evaluated in the outer scope, while the ones in the subsequent
/// generators are evaluated in the comprehension scope.
pub(crate) fn value_scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.value_file_scope(db).to_scope_id(db, self.file(db))
}
/// Returns the scope where the unpack target expression belongs to.
pub(crate) fn target_scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.target_file_scope(db).to_scope_id(db, self.file(db))
}
/// Returns the range of the unpack target expression.