Compare commits

...

4 Commits

Author SHA1 Message Date
Jack O'Connor
8effe4e84c got a salsa cycle 2026-01-13 23:39:41 -08:00
Jack O'Connor
4bcf624e9e getting the first new loop test working 2026-01-13 16:54:01 -08:00
Jack O'Connor
f13926ea01 Claude getting started, no inference yet 2026-01-13 15:48:04 -08:00
Jack O'Connor
6b79f415f9 Jack's basic loop tests 2026-01-13 11:27:12 -08:00
5 changed files with 433 additions and 11 deletions

View File

@@ -127,3 +127,53 @@ class NotBoolable:
while NotBoolable():
...
```
## Cyclic control flow
```py
def random() -> bool:
return False
i = 0
reveal_type(i) # revealed: Literal[0]
while random():
i += 1
reveal_type(i) # revealed: int
```
A more complex example, where the loop condition narrows both the loop-back value and the
end-of-loop value.
```py
x = "A"
while x != "C":
reveal_type(x) # revealed: Literal["A", "B"]
if random():
x = "B"
else:
x = "C"
reveal_type(x) # revealed: Literal["B", "C"]
reveal_type(x) # revealed: Literal["C"]
```
The same thing, but nested loops.
```py
x = "A"
while x != "E":
reveal_type(x) # revealed: Literal["A", "B", "D"]
while x != "C":
reveal_type(x) # revealed: Literal["A", "B", "D"]
if random():
x = "B"
else:
x = "C"
reveal_type(x) # revealed: Literal["B", "C", "D"]
reveal_type(x) # revealed: Literal["C"]
if random():
x = "D"
else:
x = "E"
reveal_type(x) # revealed: Literal["D", "E"]
reveal_type(x) # revealed: Literal["E"]
```

View File

@@ -26,7 +26,8 @@ use crate::semantic_index::definition::{
ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionNodeKey,
DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionNodeRef,
ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, ImportFromSubmoduleDefinitionNodeRef,
MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef, WithItemDefinitionNodeRef,
LoopHeaderDefinitionNodeRef, MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef,
WithItemDefinitionNodeRef,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::place::{PlaceExpr, PlaceTableBuilder, ScopedPlaceId};
@@ -65,6 +66,237 @@ impl Loop {
}
}
/// Visitor that collects all places (symbols and members) that are assigned within a loop body.
///
/// This is used for cyclic control flow analysis. Any place assigned in a loop may have
/// a different value on subsequent iterations, so we create loop header definitions
/// to represent the possible values at loop entry.
///
/// The visitor also tracks which places have augmented assignments (like `i += 1`), since
/// these create unbounded value growth that requires type widening.
///
/// The visitor does NOT recurse into nested function/class definitions since those are different scopes.
#[derive(Debug, Default)]
struct LoopBindingCollector {
/// All places that are assigned within the loop body.
bound_places: Vec<PlaceExpr>,
/// Places that have augmented assignments (subset of `bound_places`).
augmented_places: Vec<PlaceExpr>,
}
impl LoopBindingCollector {
/// Collect all places assigned in the given statements.
/// Returns `(all_bound_places, augmented_places)`.
fn collect(body: &[ast::Stmt]) -> (Vec<PlaceExpr>, Vec<PlaceExpr>) {
let mut collector = Self::default();
collector.visit_body(body);
(collector.bound_places, collector.augmented_places)
}
/// Add a place from an expression target (assignment LHS).
fn add_place_from_target(&mut self, target: &ast::Expr) {
match target {
// Simple name assignment: x = ...
ast::Expr::Name(name) => {
self.bound_places.push(PlaceExpr::from_expr_name(name));
}
// Attribute/subscript assignment: x.y = ..., x[0] = ...
ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => {
if let Some(place) = PlaceExpr::try_from_expr(target) {
self.bound_places.push(place);
}
}
// Unpacking: x, y = ... or (x, y) = ... or [x, y] = ...
ast::Expr::Tuple(tuple) => {
for elt in &tuple.elts {
self.add_place_from_target(elt);
}
}
ast::Expr::List(list) => {
for elt in &list.elts {
self.add_place_from_target(elt);
}
}
// Starred in unpacking: *x = ...
ast::Expr::Starred(starred) => {
self.add_place_from_target(&starred.value);
}
_ => {}
}
}
/// Add a place from an augmented assignment target.
fn add_augmented_place(&mut self, target: &ast::Expr) {
// First add to bound_places
self.add_place_from_target(target);
// Then also track as augmented
match target {
ast::Expr::Name(name) => {
self.augmented_places.push(PlaceExpr::from_expr_name(name));
}
ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => {
if let Some(place) = PlaceExpr::try_from_expr(target) {
self.augmented_places.push(place);
}
}
_ => {}
}
}
}
impl<'ast> Visitor<'ast> for LoopBindingCollector {
fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) {
match stmt {
// Assignment statements
ast::Stmt::Assign(node) => {
for target in &node.targets {
self.add_place_from_target(target);
}
}
ast::Stmt::AugAssign(node) => {
self.add_augmented_place(&node.target);
}
ast::Stmt::AnnAssign(node) => {
// Only a binding if there's a value
if node.value.is_some() {
self.add_place_from_target(&node.target);
}
}
// For loop target is bound
ast::Stmt::For(node) => {
self.add_place_from_target(&node.target);
self.visit_body(&node.body);
self.visit_body(&node.orelse);
}
// With statement binds the `as` target
ast::Stmt::With(node) => {
for item in &node.items {
if let Some(vars) = &item.optional_vars {
self.add_place_from_target(vars);
}
}
self.visit_body(&node.body);
}
// Exception handler binds the `as` name
ast::Stmt::Try(node) => {
self.visit_body(&node.body);
for handler in &node.handlers {
let ast::ExceptHandler::ExceptHandler(h) = handler;
if let Some(name) = &h.name {
self.bound_places
.push(PlaceExpr::Symbol(Symbol::new(name.id.clone())));
}
self.visit_body(&h.body);
}
self.visit_body(&node.orelse);
self.visit_body(&node.finalbody);
}
// Import statements bind names
ast::Stmt::Import(node) => {
for alias in &node.names {
let name = alias.asname.as_ref().unwrap_or(&alias.name);
self.bound_places
.push(PlaceExpr::Symbol(Symbol::new(name.id.clone())));
}
}
ast::Stmt::ImportFrom(node) => {
for alias in &node.names {
if &*alias.name != "*" {
let name = alias.asname.as_ref().unwrap_or(&alias.name);
self.bound_places
.push(PlaceExpr::Symbol(Symbol::new(name.id.clone())));
}
}
}
// Function/class definitions bind their name (but we don't recurse into their body)
ast::Stmt::FunctionDef(node) => {
self.bound_places
.push(PlaceExpr::Symbol(Symbol::new(node.name.id.clone())));
}
ast::Stmt::ClassDef(node) => {
self.bound_places
.push(PlaceExpr::Symbol(Symbol::new(node.name.id.clone())));
}
// Match statement can bind names in patterns
ast::Stmt::Match(node) => {
for case in &node.cases {
self.collect_pattern_bindings(&case.pattern);
self.visit_body(&case.body);
}
}
// Other statements: recurse using default behavior
_ => walk_stmt(self, stmt),
}
}
fn visit_expr(&mut self, expr: &'ast ast::Expr) {
// Named expressions (walrus operator): x := ...
if let ast::Expr::Named(node) = expr {
self.add_place_from_target(&node.target);
}
walk_expr(self, expr);
}
}
impl LoopBindingCollector {
/// Collect bindings from match patterns.
fn collect_pattern_bindings(&mut self, pattern: &ast::Pattern) {
match pattern {
ast::Pattern::MatchAs(p) => {
if let Some(name) = &p.name {
self.bound_places
.push(PlaceExpr::Symbol(Symbol::new(name.id.clone())));
}
if let Some(pat) = &p.pattern {
self.collect_pattern_bindings(pat);
}
}
ast::Pattern::MatchStar(p) => {
if let Some(name) = &p.name {
self.bound_places
.push(PlaceExpr::Symbol(Symbol::new(name.id.clone())));
}
}
ast::Pattern::MatchMapping(p) => {
for pat in &p.patterns {
self.collect_pattern_bindings(pat);
}
if let Some(rest) = &p.rest {
self.bound_places
.push(PlaceExpr::Symbol(Symbol::new(rest.id.clone())));
}
}
ast::Pattern::MatchSequence(p) => {
for pat in &p.patterns {
self.collect_pattern_bindings(pat);
}
}
ast::Pattern::MatchClass(p) => {
for pat in &p.arguments.patterns {
self.collect_pattern_bindings(pat);
}
for kw in &p.arguments.keywords {
self.collect_pattern_bindings(&kw.pattern);
}
}
ast::Pattern::MatchOr(p) => {
for pat in &p.patterns {
self.collect_pattern_bindings(pat);
}
}
ast::Pattern::MatchValue(_) | ast::Pattern::MatchSingleton(_) => {}
}
}
}
struct ScopeInfo {
file_scope_id: FileScopeId,
/// Current loop state; None if we are not currently visiting a loop
@@ -1924,19 +2156,38 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
self.in_type_checking_block = is_outer_block_in_type_checking;
}
ast::Stmt::While(ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
}) => {
ast::Stmt::While(
while_stmt @ ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
},
) => {
self.visit_expr(test);
let pre_loop = self.flow_snapshot();
let predicate = self.record_expression_narrowing_constraint(test);
self.record_reachability_constraint(predicate);
// Collect all places assigned in the loop body.
// Loop headers are needed for all bindings to track values that could
// flow back from previous iterations.
let (bound_places, _augmented_places) = LoopBindingCollector::collect(body);
// Create loop header definitions for each bound place.
// The loop header type will be the union of the seed type (before loop)
// and all types assigned in the loop body.
for place_expr in bound_places {
let place_id = self.add_place(place_expr);
let loop_header_ref = LoopHeaderDefinitionNodeRef {
while_stmt,
place: place_id,
};
self.push_additional_definition(place_id, loop_header_ref);
}
let outer_loop = self.push_loop();
self.visit_body(body);
let this_loop = self.pop_loop(outer_loop);

View File

@@ -285,6 +285,14 @@ pub(crate) enum DefinitionNodeRef<'ast, 'db> {
TypeVar(&'ast ast::TypeParamTypeVar),
ParamSpec(&'ast ast::TypeParamParamSpec),
TypeVarTuple(&'ast ast::TypeParamTypeVarTuple),
/// A loop header definition for cyclic control flow analysis.
LoopHeader(LoopHeaderDefinitionNodeRef<'ast>),
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct LoopHeaderDefinitionNodeRef<'ast> {
pub(crate) while_stmt: &'ast ast::StmtWhile,
pub(crate) place: ScopedPlaceId,
}
impl<'ast> From<&'ast ast::StmtFunctionDef> for DefinitionNodeRef<'ast, '_> {
@@ -335,6 +343,12 @@ impl<'ast> From<&'ast ast::TypeParamTypeVarTuple> for DefinitionNodeRef<'ast, '_
}
}
impl<'ast> From<LoopHeaderDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, '_> {
fn from(value: LoopHeaderDefinitionNodeRef<'ast>) -> Self {
Self::LoopHeader(value)
}
}
impl<'ast> From<ImportDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, '_> {
fn from(node_ref: ImportDefinitionNodeRef<'ast>) -> Self {
Self::Import(node_ref)
@@ -619,6 +633,12 @@ impl<'db> DefinitionNodeRef<'_, 'db> {
DefinitionNodeRef::TypeVarTuple(node) => {
DefinitionKind::TypeVarTuple(AstNodeRef::new(parsed, node))
}
DefinitionNodeRef::LoopHeader(LoopHeaderDefinitionNodeRef { while_stmt, place }) => {
DefinitionKind::LoopHeader(LoopHeaderDefinitionKind {
while_stmt: AstNodeRef::new(parsed, while_stmt),
place,
})
}
}
}
@@ -683,6 +703,7 @@ impl<'db> DefinitionNodeRef<'_, 'db> {
Self::TypeVar(node) => node.into(),
Self::ParamSpec(node) => node.into(),
Self::TypeVarTuple(node) => node.into(),
Self::LoopHeader(LoopHeaderDefinitionNodeRef { while_stmt, .. }) => while_stmt.into(),
}
}
}
@@ -753,6 +774,27 @@ pub enum DefinitionKind<'db> {
TypeVar(AstNodeRef<ast::TypeParamTypeVar>),
ParamSpec(AstNodeRef<ast::TypeParamParamSpec>),
TypeVarTuple(AstNodeRef<ast::TypeParamTypeVarTuple>),
/// A loop header definition representing the fixed-point type of a place at loop entry.
///
/// This is used to handle cyclic control flow in loops, where a variable's type at the
/// start of an iteration depends on assignments from previous iterations.
LoopHeader(LoopHeaderDefinitionKind),
}
/// Definition kind for a loop header entry.
#[derive(Clone, Debug, get_size2::GetSize)]
pub struct LoopHeaderDefinitionKind {
/// The while statement that defines this loop.
while_stmt: AstNodeRef<ast::StmtWhile>,
/// The place being defined at the loop header.
/// Currently unused but stored for future use in more precise loop analysis.
place: ScopedPlaceId,
}
impl LoopHeaderDefinitionKind {
pub(crate) fn while_stmt<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtWhile {
self.while_stmt.node(module)
}
}
impl DefinitionKind<'_> {
@@ -835,6 +877,8 @@ impl DefinitionKind<'_> {
DefinitionKind::TypeVarTuple(type_var_tuple) => {
type_var_tuple.node(module).name.range()
}
// For loop headers, use the while statement's test expression range
DefinitionKind::LoopHeader(loop_header) => loop_header.while_stmt(module).test.range(),
}
}
@@ -880,6 +924,8 @@ impl DefinitionKind<'_> {
DefinitionKind::TypeVar(type_var) => type_var.node(module).range(),
DefinitionKind::ParamSpec(param_spec) => param_spec.node(module).range(),
DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.node(module).range(),
// For loop headers, use the entire while statement range
DefinitionKind::LoopHeader(loop_header) => loop_header.while_stmt(module).range(),
}
}
@@ -935,7 +981,8 @@ impl DefinitionKind<'_> {
| DefinitionKind::WithItem(_)
| DefinitionKind::MatchPattern(_)
| DefinitionKind::ImportFromSubmodule(_)
| DefinitionKind::ExceptHandler(_) => DefinitionCategory::Binding,
| DefinitionKind::ExceptHandler(_)
| DefinitionKind::LoopHeader(_) => DefinitionCategory::Binding,
}
}
}
@@ -1280,6 +1327,12 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
}
}
impl From<&ast::StmtWhile> for DefinitionNodeKey {
fn from(node: &ast::StmtWhile) -> Self {
Self(NodeKey::from_node(node))
}
}
impl From<&ast::Parameter> for DefinitionNodeKey {
fn from(node: &ast::Parameter) -> Self {
Self(NodeKey::from_node(node))

View File

@@ -1377,7 +1377,8 @@ mod resolve_definition {
| DefinitionKind::ExceptHandler(_)
| DefinitionKind::TypeVar(_)
| DefinitionKind::ParamSpec(_)
| DefinitionKind::TypeVarTuple(_) => {
| DefinitionKind::TypeVarTuple(_)
| DefinitionKind::LoopHeader(_) => {
// Not yet implemented
return Err(());
}

View File

@@ -39,7 +39,7 @@ use crate::semantic_index::ast_ids::{HasScopedUseId, ScopedUseId};
use crate::semantic_index::definition::{
AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind,
Definition, DefinitionKind, DefinitionNodeKey, DefinitionState, ExceptHandlerDefinitionKind,
ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind,
ForStmtDefinitionKind, LoopHeaderDefinitionKind, TargetKind, WithItemDefinitionKind,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::narrowing_constraints::ConstraintKey;
@@ -1541,6 +1541,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
DefinitionKind::TypeVarTuple(node) => {
self.infer_typevartuple_definition(node.node(self.module()), definition);
}
DefinitionKind::LoopHeader(loop_header) => {
self.infer_loop_header_definition(loop_header, definition);
}
}
}
@@ -3809,6 +3812,70 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
);
}
/// Infer the type for a loop header definition.
///
/// Loop headers represent the fixed-point type of a place at loop entry.
/// The type is the union of:
/// 1. The seed type (visible before the loop)
/// 2. Types from all bindings within the loop body
fn infer_loop_header_definition(
&mut self,
loop_header: &LoopHeaderDefinitionKind,
definition: Definition<'db>,
) {
let db = self.db();
let module = self.module();
// Get seed type (visible before loop)
let file_scope = self.scope().file_scope_id(db);
let use_def = self.index.use_def_map(file_scope);
let seed_bindings = use_def.bindings_at_definition(definition);
let seed_place = place_from_bindings(db, seed_bindings);
let seed_ty = match seed_place.place {
Place::Defined(defined) => defined.ty,
Place::Undefined => Type::unknown(),
};
// Get the while loop body range
let while_stmt = loop_header.while_stmt(&module);
let body_range = while_stmt.body.first().map(|first| {
let last = while_stmt.body.last().unwrap();
ruff_text_size::TextRange::new(first.range().start(), last.range().end())
});
// Find all bindings for this place within the loop body and collect their types
let place = definition.place(db);
let mut all_types = vec![seed_ty];
if let Some(body_range) = body_range {
let all_bindings = use_def.reachable_bindings(place);
for binding_with_constraints in all_bindings {
let DefinitionState::Defined(binding) = binding_with_constraints.binding else {
continue;
};
// Skip the loop header itself
if binding == definition {
continue;
}
// Check if this binding is within the loop body
let binding_range = binding.kind(db).full_range(&module);
if body_range.contains_range(binding_range) {
let binding_ty = binding_type(db, binding);
all_types.push(binding_ty);
}
}
}
// Union all types together
let final_ty = UnionType::from_elements(db, all_types);
self.bindings
.insert(definition, final_ty, self.multi_inference_state);
}
fn infer_match_statement(&mut self, match_statement: &ast::StmtMatch) {
let ast::StmtMatch {
range: _,