Compare commits

...

1 Commits

Author SHA1 Message Date
Jack O'Connor
a38a18e2d3 WIP: Jack + Codex mucking around with loop control flow
(The `unsafe` code here is obviously not going to survive review.)
2026-01-12 19:22:02 -08:00
11 changed files with 758 additions and 48 deletions

View File

@@ -127,3 +127,64 @@ class NotBoolable:
while NotBoolable():
...
```
## Backwards control flow
```py
i = 0
reveal_type(i) # revealed: Literal[0]
while i < 1_000_000:
reveal_type(i) # revealed: int
i += 1
reveal_type(i) # revealed: int
reveal_type(i) # revealed: int
# TODO: None of these should need to be raised to `int`. Loop control flow analysis should take the
# loop condition into account.
i = 0
reveal_type(i) # revealed: Literal[0]
while i < 2:
# TODO: Should be Literal[0, 1].
reveal_type(i) # revealed: int
i += 1
# TODO: Should be Literal[1, 2].
reveal_type(i) # revealed: int
# TODO: Should be Literal[2].
reveal_type(i) # revealed: int
```
```py
def random() -> bool:
raise NotImplementedError
i = 0
while True:
reveal_type(i) # revealed: Literal[0, 1, 2]
if random():
i = 1
else:
i = "break"
break
# To get here we must take the `i = 1` branch above.
reveal_type(i) # revealed: Literal[1]
if random():
i = 2
reveal_type(i) # revealed: Literal[1, 2]
reveal_type(i) # revealed: Literal["break"]
i = 0
while random():
if random():
reveal_type(i) # revealed: Literal[0, 1, 2, 3]
i = 1
reveal_type(i) # revealed: Literal[1]
while random():
if random():
reveal_type(i) # revealed: Literal[1, 0, 2, 3]
i = 2
reveal_type(i) # revealed: Literal[2]
if random():
reveal_type(i) # revealed: Literal[1, 2, 0, 3]
i = 3
reveal_type(i) # revealed: Literal[3]
```

View File

@@ -229,6 +229,9 @@ pub(crate) struct SemanticIndex<'db> {
/// Map from a standalone expression to its [`Expression`] ingredient.
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
/// Map from loop-header definitions to their constituent definitions.
loop_header_definitions: FxHashMap<Definition<'db>, Vec<Definition<'db>>>,
/// Map from nodes that create a scope to the scope they create.
scopes_by_node: FxHashMap<NodeWithScopeKey, FileScopeId>,
@@ -319,6 +322,15 @@ impl<'db> SemanticIndex<'db> {
self.scope_ids_by_scope.iter().copied()
}
pub(crate) fn loop_header_definitions(
&self,
definition: Definition<'db>,
) -> Option<&[Definition<'db>]> {
self.loop_header_definitions
.get(&definition)
.map(|defs| defs.as_slice())
}
pub(crate) fn symbol_is_global_in_scope(
&self,
symbol: ScopedSymbolId,

View File

@@ -141,4 +141,10 @@ pub(crate) mod node_key {
Self(NodeKey::from_node(value))
}
}
impl From<ExpressionNodeKey> for NodeKey {
fn from(value: ExpressionNodeKey) -> Self {
value.0
}
}
}

View File

@@ -19,14 +19,15 @@ use ty_module_resolver::{ModuleName, resolve_module};
use crate::ast_node_ref::AstNodeRef;
use crate::node_key::NodeKey;
use crate::semantic_index::ast_ids::AstIdsBuilder;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::{AstIdsBuilder, ScopedUseId};
use crate::semantic_index::definition::{
AnnotatedAssignmentDefinitionNodeRef, AssignmentDefinitionNodeRef,
ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionNodeKey,
DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionNodeRef,
ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, ImportFromSubmoduleDefinitionNodeRef,
MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef, WithItemDefinitionNodeRef,
ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionKind,
DefinitionNodeKey, DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef,
ForStmtDefinitionNodeRef, ImportDefinitionNodeRef, ImportFromDefinitionNodeRef,
ImportFromSubmoduleDefinitionNodeRef, LoopHeaderDefinitionKind, MatchPatternDefinitionNodeRef,
StarImportDefinitionNodeRef, WithItemDefinitionNodeRef,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::place::{PlaceExpr, PlaceTableBuilder, ScopedPlaceId};
@@ -43,6 +44,7 @@ use crate::semantic_index::scope::{
};
use crate::semantic_index::scope::{Scope, ScopeId, ScopeKind, ScopeLaziness};
use crate::semantic_index::symbol::{ScopedSymbolId, Symbol};
use crate::semantic_index::use_def::Bindings;
use crate::semantic_index::use_def::{
EnclosingSnapshotKey, FlowSnapshot, ScopedEnclosingSnapshotId, UseDefMapBuilder,
};
@@ -53,22 +55,35 @@ use crate::{Db, Program};
mod except_handlers;
#[derive(Clone, Debug)]
struct LoopUse {
place: ScopedPlaceId,
use_id: ScopedUseId,
}
#[derive(Clone, Debug, Default)]
struct Loop {
/// Flow states at each `break` in the current loop.
break_states: Vec<FlowSnapshot>,
uses: Vec<LoopUse>,
defined_places: FxHashSet<ScopedPlaceId>,
}
impl Loop {
fn push_break(&mut self, state: FlowSnapshot) {
self.break_states.push(state);
}
fn record_definition(&mut self, place: ScopedPlaceId) {
self.defined_places.insert(place);
}
}
struct ScopeInfo {
file_scope_id: FileScopeId,
/// Current loop state; None if we are not currently visiting a loop
current_loop: Option<Loop>,
condition_place_uses: Option<FxHashSet<ScopedPlaceId>>,
}
pub(super) struct SemanticIndexBuilder<'db, 'ast> {
@@ -109,6 +124,7 @@ pub(super) struct SemanticIndexBuilder<'db, 'ast> {
scopes_by_expression: ExpressionsScopeMapBuilder,
definitions_by_node: FxHashMap<DefinitionNodeKey, Definitions<'db>>,
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
loop_header_definitions: FxHashMap<Definition<'db>, Vec<Definition<'db>>>,
imported_modules: FxHashSet<ModuleName>,
seen_submodule_imports: FxHashSet<String>,
/// Hashset of all [`FileScopeId`]s that correspond to [generator functions].
@@ -147,6 +163,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
scopes_by_node: FxHashMap::default(),
definitions_by_node: FxHashMap::default(),
expressions_by_node: FxHashMap::default(),
loop_header_definitions: FxHashMap::default(),
seen_submodule_imports: FxHashSet::default(),
imported_modules: FxHashSet::default(),
@@ -256,8 +273,17 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
/// Pop a loop, replacing with the previous saved outer loop, if any.
fn pop_loop(&mut self, outer_loop: Option<Loop>) -> Loop {
std::mem::replace(&mut self.current_scope_info_mut().current_loop, outer_loop)
.expect("pop_loop() should not be called without a prior push_loop()")
let inner_loop = std::mem::take(&mut self.current_scope_info_mut().current_loop)
.expect("pop_loop() should not be called without a prior push_loop()");
let merged_outer = outer_loop.map(|mut outer| {
outer.uses.extend(inner_loop.uses.iter().cloned());
outer
.defined_places
.extend(inner_loop.defined_places.iter().copied());
outer
});
self.current_scope_info_mut().current_loop = merged_outer;
inner_loop
}
fn current_loop_mut(&mut self) -> Option<&mut Loop> {
@@ -308,6 +334,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
self.scope_stack.push(ScopeInfo {
file_scope_id,
current_loop: None,
condition_place_uses: None,
});
}
@@ -656,6 +683,9 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
place: ScopedPlaceId,
definition_node: impl Into<DefinitionNodeRef<'ast, 'db>> + std::fmt::Debug + Copy,
) -> Definition<'db> {
if let Some(current_loop) = self.current_loop_mut() {
current_loop.record_definition(place);
}
let (definition, num_definitions) = self.push_additional_definition(place, definition_node);
debug_assert_eq!(
num_definitions, 1,
@@ -755,6 +785,40 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
(definition, num_definitions)
}
fn add_loop_header_definition(
&mut self,
place: ScopedPlaceId,
loop_node: &'ast ast::StmtWhile,
definitions: Vec<Definition<'db>>,
seed_definitions: Vec<Definition<'db>>,
bindings: Bindings,
seed_bindings: Bindings,
) -> Definition<'db> {
let kind = DefinitionKind::LoopHeader(LoopHeaderDefinitionKind::new(
AstNodeRef::new(self.module, loop_node),
definitions,
seed_definitions,
bindings,
seed_bindings,
));
let is_reexported = kind.is_reexported();
let definition = Definition::new(
self.db,
self.file,
self.current_scope(),
place,
kind,
is_reexported,
);
self.add_entry_for_definition_key(DefinitionNodeKey::from_node_key(NodeKey::from_node(
loop_node,
)))
.push(definition);
definition
}
fn record_expression_narrowing_constraint(
&mut self,
predicate_node: &ast::Expr,
@@ -1318,6 +1382,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
scopes: self.scopes,
definitions_by_node: self.definitions_by_node,
expressions_by_node: self.expressions_by_node,
loop_header_definitions: self.loop_header_definitions,
scope_ids_by_scope: self.scope_ids_by_scope,
ast_ids,
scopes_by_expression: self.scopes_by_expression.build(),
@@ -1341,6 +1406,96 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
self.source_text
.get_or_init(|| source_text(self.db, self.file))
}
fn record_use(&mut self, place_id: ScopedPlaceId, expr_node_key: ExpressionNodeKey) {
if let ScopedPlaceId::Symbol(symbol_id) = place_id {
self.mark_symbol_used(symbol_id);
}
let use_id = self.current_ast_ids().record_use(expr_node_key);
self.current_use_def_map_mut()
.record_use(place_id, use_id, expr_node_key.into());
if let Some(condition_place_uses) = &mut self.current_scope_info_mut().condition_place_uses
{
condition_place_uses.insert(place_id);
}
if let Some(current_loop) = self.current_loop_mut() {
current_loop.uses.push(LoopUse {
place: place_id,
use_id,
});
}
}
fn create_loop_header_definitions(
&mut self,
loop_node: &'ast ast::StmtWhile,
loop_state: &Loop,
pre_loop: &FlowSnapshot,
post_body: &FlowSnapshot,
) {
let mut used_places = FxHashSet::default();
for loop_use in &loop_state.uses {
used_places.insert(loop_use.place);
}
let scope_id = self.current_scope();
for place in loop_state.defined_places.iter() {
if !used_places.contains(place) {
continue;
}
let pre_loop_binding_ids = pre_loop.binding_ids_for_place_excluding_unbound(*place);
let seed_bindings = pre_loop.bindings_for_place(*place);
let loop_bindings = self
.current_use_def_map_mut()
.merge_bindings(seed_bindings.clone(), post_body.bindings_for_place(*place));
let mut seed_definitions = self
.current_use_def_map()
.definitions_for_place_in_snapshot(pre_loop, *place);
let mut definitions = seed_definitions.clone();
definitions.extend(
self.current_use_def_map()
.definitions_for_place_in_snapshot(post_body, *place),
);
definitions.sort();
definitions.dedup();
seed_definitions.sort();
seed_definitions.dedup();
if definitions.is_empty() {
continue;
}
let header_definition = self.add_loop_header_definition(
*place,
loop_node,
definitions.clone(),
seed_definitions,
loop_bindings,
seed_bindings,
);
let header_definition_id = self.use_def_maps[scope_id]
.register_definition_with_bindings(
header_definition,
pre_loop.bindings_for_place(*place),
pre_loop.declarations_for_place(*place),
);
self.loop_header_definitions
.insert(header_definition, definitions);
if pre_loop_binding_ids.is_empty() {
continue;
}
for loop_use in loop_state.uses.iter().filter(|use_| use_.place == *place) {
self.current_use_def_map_mut().replace_use_bindings(
loop_use.use_id,
&pre_loop_binding_ids,
header_definition_id,
);
}
}
}
}
impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
@@ -1924,41 +2079,66 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
self.in_type_checking_block = is_outer_block_in_type_checking;
}
ast::Stmt::While(ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
}) => {
ast::Stmt::While(stmt_while) => {
let ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
} = stmt_while;
self.current_scope_info_mut()
.condition_place_uses
.replace(FxHashSet::default());
let outer_loop = self.push_loop();
let has_outer_loop = outer_loop.is_some();
self.visit_expr(test);
let condition_place_uses = self
.current_scope_info_mut()
.condition_place_uses
.take()
.unwrap();
let pre_loop = self.flow_snapshot();
let predicate = self.record_expression_narrowing_constraint(test);
self.record_reachability_constraint(predicate);
let predicate = self.build_predicate(test);
let predicate_id = self.add_predicate(predicate);
self.record_narrowing_constraint_id(predicate_id);
self.record_reachability_constraint_id(predicate_id);
let outer_loop = self.push_loop();
self.visit_body(body);
let this_loop = self.pop_loop(outer_loop);
let post_body = self.flow_snapshot();
if !has_outer_loop {
self.create_loop_header_definitions(
stmt_while, &this_loop, &pre_loop, &post_body,
);
}
// We execute the `else` branch once the condition evaluates to false. This could
// happen without ever executing the body, if the condition is false the first time
// it's tested. Or it could happen if a _later_ evaluation of the condition yields
// false. So we merge in the pre-loop state here into the post-body state:
self.flow_merge(pre_loop);
self.flow_merge(pre_loop.clone());
// The `else` branch can only be reached if the loop condition *can* be false. To
// model this correctly, we need a second copy of the while condition constraint,
// since the first and later evaluations might produce different results. We would
// otherwise simplify `predicate AND ~predicate` to `False`.
let later_predicate_id = self.current_use_def_map_mut().add_predicate(predicate);
let later_reachability_constraint = self
.current_reachability_constraints_mut()
.add_atom(later_predicate_id);
self.record_negated_reachability_constraint(later_reachability_constraint);
self.record_negated_narrowing_constraint(predicate);
let condition_depends_on_loop = condition_place_uses
.iter()
.any(|place| pre_loop.has_new_bindings_for_place(&post_body, *place));
if condition_depends_on_loop {
self.record_ambiguous_reachability();
} else {
let later_predicate_id =
self.current_use_def_map_mut().add_predicate(predicate);
let later_reachability_constraint = self
.current_reachability_constraints_mut()
.add_atom(later_predicate_id);
self.record_negated_reachability_constraint(later_reachability_constraint);
self.record_negated_narrowing_constraint(predicate);
}
self.visit_body(orelse);
@@ -2469,12 +2649,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
let place_id = self.add_place(place_expr);
if is_use {
if let ScopedPlaceId::Symbol(symbol_id) = place_id {
self.mark_symbol_used(symbol_id);
}
let use_id = self.current_ast_ids().record_use(expr);
self.current_use_def_map_mut()
.record_use(place_id, use_id, node_key);
self.record_use(place_id, expr.into());
}
if is_definition {

View File

@@ -13,6 +13,7 @@ use crate::node_key::NodeKey;
use crate::semantic_index::place::ScopedPlaceId;
use crate::semantic_index::scope::{FileScopeId, ScopeId};
use crate::semantic_index::symbol::ScopedSymbolId;
use crate::semantic_index::use_def::Bindings;
use crate::unpack::{Unpack, UnpackPosition};
/// A definition of a place.
@@ -753,6 +754,7 @@ pub enum DefinitionKind<'db> {
TypeVar(AstNodeRef<ast::TypeParamTypeVar>),
ParamSpec(AstNodeRef<ast::TypeParamParamSpec>),
TypeVarTuple(AstNodeRef<ast::TypeParamTypeVarTuple>),
LoopHeader(LoopHeaderDefinitionKind<'db>),
}
impl DefinitionKind<'_> {
@@ -835,6 +837,7 @@ impl DefinitionKind<'_> {
DefinitionKind::TypeVarTuple(type_var_tuple) => {
type_var_tuple.node(module).name.range()
}
DefinitionKind::LoopHeader(loop_header) => loop_header.node(module).range(),
}
}
@@ -880,6 +883,7 @@ impl DefinitionKind<'_> {
DefinitionKind::TypeVar(type_var) => type_var.node(module).range(),
DefinitionKind::ParamSpec(param_spec) => param_spec.node(module).range(),
DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.node(module).range(),
DefinitionKind::LoopHeader(loop_header) => loop_header.node(module).range(),
}
}
@@ -935,7 +939,8 @@ impl DefinitionKind<'_> {
| DefinitionKind::WithItem(_)
| DefinitionKind::MatchPattern(_)
| DefinitionKind::ImportFromSubmodule(_)
| DefinitionKind::ExceptHandler(_) => DefinitionCategory::Binding,
| DefinitionKind::ExceptHandler(_)
| DefinitionKind::LoopHeader(_) => DefinitionCategory::Binding,
}
}
}
@@ -1211,9 +1216,62 @@ impl ExceptHandlerDefinitionKind {
}
}
#[derive(Clone, Debug, get_size2::GetSize)]
pub struct LoopHeaderDefinitionKind<'db> {
node: AstNodeRef<ast::StmtWhile>,
definitions: Vec<Definition<'db>>,
seed_definitions: Vec<Definition<'db>>,
bindings: Bindings,
seed_bindings: Bindings,
}
impl<'db> LoopHeaderDefinitionKind<'db> {
pub(crate) fn new(
node: AstNodeRef<ast::StmtWhile>,
definitions: Vec<Definition<'db>>,
seed_definitions: Vec<Definition<'db>>,
bindings: Bindings,
seed_bindings: Bindings,
) -> Self {
Self {
node,
definitions,
seed_definitions,
bindings,
seed_bindings,
}
}
pub(crate) fn node<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtWhile {
self.node.node(module)
}
pub(crate) fn definitions(&self) -> &[Definition<'db>] {
&self.definitions
}
pub(crate) fn seed_definitions(&self) -> &[Definition<'db>] {
&self.seed_definitions
}
pub(crate) fn bindings(&self) -> &Bindings {
&self.bindings
}
pub(crate) fn seed_bindings(&self) -> &Bindings {
&self.seed_bindings
}
}
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, salsa::Update, get_size2::GetSize)]
pub(crate) struct DefinitionNodeKey(NodeKey);
impl DefinitionNodeKey {
pub(crate) fn from_node_key(node_key: NodeKey) -> Self {
Self(node_key)
}
}
impl From<&ast::Alias> for DefinitionNodeKey {
fn from(node: &ast::Alias) -> Self {
Self(NodeKey::from_node(node))
@@ -1280,6 +1338,12 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
}
}
impl From<&ast::StmtWhile> for DefinitionNodeKey {
fn from(node: &ast::StmtWhile) -> Self {
Self(NodeKey::from_node(node))
}
}
impl From<&ast::Parameter> for DefinitionNodeKey {
fn from(node: &ast::Parameter) -> Self {
Self(NodeKey::from_node(node))

View File

@@ -111,6 +111,13 @@ impl NarrowingConstraintsBuilder {
) -> ScopedNarrowingConstraint {
self.lists.intersect(a, b)
}
pub(crate) fn iter_predicates(
&self,
set: ScopedNarrowingConstraint,
) -> NarrowingConstraintsIterator<'_> {
self.lists.iter_set_reverse(set).copied()
}
}
// Iteration
@@ -142,12 +149,5 @@ mod tests {
}
}
impl NarrowingConstraintsBuilder {
pub(crate) fn iter_predicates(
&self,
set: ScopedNarrowingConstraint,
) -> NarrowingConstraintsIterator<'_> {
self.lists.iter_set_reverse(set).copied()
}
}
// Test-only impl removed; use the main impl above.
}

View File

@@ -261,8 +261,9 @@ use crate::semantic_index::reachability_constraints::{
};
use crate::semantic_index::scope::{FileScopeId, ScopeKind, ScopeLaziness};
use crate::semantic_index::symbol::ScopedSymbolId;
pub(crate) use crate::semantic_index::use_def::place_state::Bindings;
use crate::semantic_index::use_def::place_state::{
Bindings, Declarations, EnclosingSnapshot, LiveBindingsIterator, LiveDeclaration,
Declarations, EnclosingSnapshot, LiveBindingsIterator, LiveDeclaration,
LiveDeclarationsIterator, PlaceState, PreviousDefinitions, ScopedDefinitionId,
};
use crate::semantic_index::{EnclosingSnapshotResult, SemanticIndex};
@@ -364,6 +365,13 @@ impl<'db> UseDefMap<'db> {
)
}
pub(crate) fn bindings_from_snapshot<'a>(
&'a self,
bindings: &'a Bindings,
) -> BindingWithConstraintsIterator<'a, 'db> {
self.bindings_iterator(bindings, BoundnessAnalysis::BasedOnUnboundVisibility)
}
pub(crate) fn applicable_constraints(
&self,
constraint_key: ConstraintKey,
@@ -825,6 +833,52 @@ pub(super) struct FlowSnapshot {
reachability: ScopedReachabilityConstraintId,
}
impl FlowSnapshot {
pub(super) fn has_new_bindings_for_place(
&self,
other: &FlowSnapshot,
place: ScopedPlaceId,
) -> bool {
let (self_ids, other_ids) = match place {
ScopedPlaceId::Symbol(symbol) => (
self.symbol_states[symbol].bindings().binding_ids(),
other.symbol_states[symbol].bindings().binding_ids(),
),
ScopedPlaceId::Member(member) => (
self.member_states[member].bindings().binding_ids(),
other.member_states[member].bindings().binding_ids(),
),
};
other_ids.iter().any(|id| !self_ids.contains(id))
}
pub(super) fn bindings_for_place(&self, place: ScopedPlaceId) -> Bindings {
match place {
ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].bindings().clone(),
ScopedPlaceId::Member(member) => self.member_states[member].bindings().clone(),
}
}
pub(super) fn binding_ids_for_place_excluding_unbound(
&self,
place: ScopedPlaceId,
) -> Vec<ScopedDefinitionId> {
let mut ids = match place {
ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].bindings().binding_ids(),
ScopedPlaceId::Member(member) => self.member_states[member].bindings().binding_ids(),
};
ids.retain(|id| *id != ScopedDefinitionId::UNBOUND);
ids
}
pub(super) fn declarations_for_place(&self, place: ScopedPlaceId) -> Declarations {
match place {
ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].declarations().clone(),
ScopedPlaceId::Member(member) => self.member_states[member].declarations().clone(),
}
}
}
/// A snapshot of the state of a single symbol (e.g. `obj`) and all of its associated members
/// (e.g. `obj.attr`, `obj["key"]`).
pub(super) struct SingleSymbolSnapshot {
@@ -1248,6 +1302,107 @@ impl<'db> UseDefMapBuilder<'db> {
self.record_node_reachability(node_key);
}
pub(super) fn merge_use_with_snapshot(
&mut self,
use_id: ScopedUseId,
snapshot: &FlowSnapshot,
place: ScopedPlaceId,
predicate: Option<ScopedPredicateId>,
) {
let mut bindings = self.bindings_by_use[use_id].clone();
let mut backedge_bindings = snapshot.bindings_for_place(place);
if let Some(predicate) = predicate {
if predicate != ScopedPredicateId::ALWAYS_TRUE
&& predicate != ScopedPredicateId::ALWAYS_FALSE
{
backedge_bindings
.record_narrowing_constraint(&mut self.narrowing_constraints, predicate.into());
}
}
bindings.merge(
backedge_bindings,
&mut self.narrowing_constraints,
&mut self.reachability_constraints,
);
self.bindings_by_use[use_id] = bindings;
}
pub(super) fn register_definition(
&mut self,
definition: Definition<'db>,
) -> ScopedDefinitionId {
self.all_definitions
.push(DefinitionState::Defined(definition))
}
pub(super) fn register_definition_with_bindings(
&mut self,
definition: Definition<'db>,
bindings: Bindings,
declarations: Declarations,
) -> ScopedDefinitionId {
let def_id = self
.all_definitions
.push(DefinitionState::Defined(definition));
self.bindings_by_definition.insert(definition, bindings);
self.declarations_by_binding
.insert(definition, declarations);
def_id
}
pub(super) fn replace_use_with_definition(
&mut self,
use_id: ScopedUseId,
definition_id: ScopedDefinitionId,
) {
self.bindings_by_use[use_id].replace_definition(definition_id);
}
pub(super) fn replace_use_bindings(
&mut self,
use_id: ScopedUseId,
from: &[ScopedDefinitionId],
definition_id: ScopedDefinitionId,
) {
self.bindings_by_use[use_id].replace_definitions(
from,
definition_id,
&mut self.narrowing_constraints,
&mut self.reachability_constraints,
);
}
pub(super) fn merge_bindings(&mut self, mut bindings: Bindings, other: Bindings) -> Bindings {
bindings.merge(
other,
&mut self.narrowing_constraints,
&mut self.reachability_constraints,
);
bindings
}
pub(super) fn definitions_for_place_in_snapshot(
&self,
snapshot: &FlowSnapshot,
place: ScopedPlaceId,
) -> Vec<Definition<'db>> {
let binding_ids = match place {
ScopedPlaceId::Symbol(symbol) => {
snapshot.symbol_states[symbol].bindings().binding_ids()
}
ScopedPlaceId::Member(member) => {
snapshot.member_states[member].bindings().binding_ids()
}
};
binding_ids
.into_iter()
.filter_map(|binding_id| match self.all_definitions[binding_id] {
DefinitionState::Defined(definition) => Some(definition),
DefinitionState::Undefined | DefinitionState::Deleted => None,
})
.collect()
}
pub(super) fn record_node_reachability(&mut self, node_key: NodeKey) {
self.node_reachability.insert(node_key, self.reachability);
}

View File

@@ -56,7 +56,7 @@ use crate::semantic_index::reachability_constraints::{
/// A newtype-index for a definition in a particular scope.
#[newtype_index]
#[derive(Ord, PartialOrd, get_size2::GetSize)]
pub(super) struct ScopedDefinitionId;
pub(crate) struct ScopedDefinitionId;
impl ScopedDefinitionId {
/// A special ID that is used to describe an implicit start-of-scope state. When
@@ -74,7 +74,7 @@ impl ScopedDefinitionId {
/// Live declarations for a single place at some point in control flow, with their
/// corresponding reachability constraints.
#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(super) struct Declarations {
pub(crate) struct Declarations {
/// A list of live declarations for this place, sorted by their `ScopedDefinitionId`
live_declarations: SmallVec<[LiveDeclaration; 2]>,
}
@@ -206,7 +206,7 @@ impl EnclosingSnapshot {
/// Live bindings for a single place at some point in control flow. Each live binding comes
/// with a set of narrowing constraints and a reachability constraint.
#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub(super) struct Bindings {
pub(crate) struct Bindings {
/// The narrowing constraint applicable to the "unbound" binding, if we need access to it even
/// when it's not visible. This happens in class scopes, where local name bindings are not visible
/// to nested scopes, but we still need to know what narrowing constraints were applied to the
@@ -222,12 +222,109 @@ impl Bindings {
.unwrap_or(self.live_bindings[0].narrowing_constraint)
}
pub(super) fn has_defined_bindings(&self) -> bool {
self.live_bindings
.iter()
.any(|binding| !binding.binding.is_unbound())
}
pub(super) fn from_single(
binding: ScopedDefinitionId,
narrowing_constraint: ScopedNarrowingConstraint,
reachability_constraint: ScopedReachabilityConstraintId,
) -> Self {
Self {
unbound_narrowing_constraint: None,
live_bindings: smallvec![LiveBinding {
binding,
narrowing_constraint,
reachability_constraint,
}],
}
}
pub(super) fn representative_narrowing_constraint(&self) -> ScopedNarrowingConstraint {
self.live_bindings
.first()
.map(|binding| binding.narrowing_constraint)
.unwrap_or_else(ScopedNarrowingConstraint::empty)
}
pub(super) fn retain_bindings(
&mut self,
mut predicate: impl FnMut(ScopedDefinitionId) -> bool,
) {
self.live_bindings
.retain(|binding| predicate(binding.binding));
}
pub(super) fn binding_ids(&self) -> Vec<ScopedDefinitionId> {
self.live_bindings
.iter()
.map(|binding| binding.binding)
.collect()
}
pub(super) fn finish(&mut self, reachability_constraints: &mut ReachabilityConstraintsBuilder) {
self.live_bindings.shrink_to_fit();
for binding in &self.live_bindings {
reachability_constraints.mark_used(binding.reachability_constraint);
}
}
pub(super) fn replace_definition(&mut self, binding: ScopedDefinitionId) {
for live_binding in &mut self.live_bindings {
live_binding.binding = binding;
}
self.live_bindings
.sort_by(|left, right| left.binding.cmp(&right.binding));
}
pub(super) fn replace_definitions(
&mut self,
from: &[ScopedDefinitionId],
replacement: ScopedDefinitionId,
narrowing_constraints: &mut NarrowingConstraintsBuilder,
reachability_constraints: &mut ReachabilityConstraintsBuilder,
) {
if from.is_empty() {
return;
}
let mut changed = false;
for live_binding in &mut self.live_bindings {
if from.contains(&live_binding.binding) {
live_binding.binding = replacement;
changed = true;
}
}
if !changed {
return;
}
self.live_bindings
.sort_by(|left, right| left.binding.cmp(&right.binding));
let mut merged: SmallVec<[LiveBinding; 2]> = SmallVec::new();
for binding in std::mem::take(&mut self.live_bindings) {
match merged.last_mut() {
Some(last) if last.binding == binding.binding => {
last.narrowing_constraint = narrowing_constraints.intersect_constraints(
last.narrowing_constraint,
binding.narrowing_constraint,
);
last.reachability_constraint = reachability_constraints.add_or_constraint(
last.reachability_constraint,
binding.reachability_constraint,
);
}
_ => merged.push(binding),
}
}
self.live_bindings = merged;
}
}
/// One of the live bindings for a single place at some point in control flow.
@@ -291,6 +388,26 @@ impl Bindings {
}
}
pub(super) fn add_narrowing_constraint_from(
&mut self,
narrowing_constraints: &mut NarrowingConstraintsBuilder,
constraint: ScopedNarrowingConstraint,
) {
let predicates: Vec<_> = narrowing_constraints.iter_predicates(constraint).collect();
let mut unbound_constraint = self.unbound_narrowing_constraint;
for predicate in predicates {
if let Some(existing) = unbound_constraint {
unbound_constraint =
Some(narrowing_constraints.add_predicate_to_constraint(existing, predicate));
}
for binding in &mut self.live_bindings {
binding.narrowing_constraint = narrowing_constraints
.add_predicate_to_constraint(binding.narrowing_constraint, predicate);
}
}
self.unbound_narrowing_constraint = unbound_constraint;
}
/// Add given reachability constraint to all live bindings.
pub(super) fn record_reachability_constraint(
&mut self,

View File

@@ -18,6 +18,8 @@ use ruff_db::parsed::parsed_module;
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::FxHashMap;
use salsa::plumbing::{AsId, Id};
use smallvec::{SmallVec, smallvec};
use ty_module_resolver::{KnownModule, Module, ModuleName, resolve_module};
@@ -158,8 +160,59 @@ pub fn check_types(db: &dyn Db, file: File) -> Vec<Diagnostic> {
diagnostics
}
thread_local! {
static LOOP_HEADER_OVERRIDE: RefCell<FxHashMap<Id, Type<'static>>> =
RefCell::new(FxHashMap::default());
}
fn loop_header_override<'db>(definition: Definition<'db>) -> Option<Type<'db>> {
let id = definition.as_id();
LOOP_HEADER_OVERRIDE.with(|cell| cell.borrow().get(&id).copied().map(restore_type_lifetime))
}
pub(crate) fn with_loop_header_override<'db, R>(
definition: Definition<'db>,
ty: Type<'db>,
f: impl FnOnce() -> R,
) -> R {
let id = definition.as_id();
LOOP_HEADER_OVERRIDE.with(|cell| {
let mut overrides = cell.borrow_mut();
let previous = overrides.insert(id, erase_type_lifetime(ty));
drop(overrides);
let result = f();
let mut overrides = cell.borrow_mut();
match previous {
Some(previous) => {
overrides.insert(id, previous);
}
None => {
overrides.remove(&id);
}
}
result
})
}
fn erase_type_lifetime<'db>(ty: Type<'db>) -> Type<'static> {
// SAFETY: `Type` is a copyable, db-backed handle; we only use this within
// a single thread to provide a temporary loop-header override.
unsafe { std::mem::transmute::<Type<'db>, Type<'static>>(ty) }
}
fn restore_type_lifetime<'db>(ty: Type<'static>) -> Type<'db> {
// SAFETY: This is the inverse of `erase_type_lifetime` and is only used
// within the dynamic scope of a loop-header override.
unsafe { std::mem::transmute::<Type<'static>, Type<'db>>(ty) }
}
/// Infer the type of a binding.
pub(crate) fn binding_type<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Type<'db> {
if let Some(override_type) = loop_header_override(definition) {
return override_type;
}
let inference = infer_definition_types(db, definition);
inference.binding_type(definition)
}
@@ -11792,7 +11845,9 @@ impl<'db> UnionType<'db> {
elements
.into_iter()
.fold(
UnionBuilder::new(db).cycle_recovery(true),
UnionBuilder::new(db)
.cycle_recovery(true)
.recursively_defined(RecursivelyDefined::Yes),
|builder, element| builder.add(element.into()),
)
.build()

View File

@@ -1377,7 +1377,8 @@ mod resolve_definition {
| DefinitionKind::ExceptHandler(_)
| DefinitionKind::TypeVar(_)
| DefinitionKind::ParamSpec(_)
| DefinitionKind::TypeVarTuple(_) => {
| DefinitionKind::TypeVarTuple(_)
| DefinitionKind::LoopHeader(_) => {
// Not yet implemented
return Err(());
}

View File

@@ -39,7 +39,7 @@ use crate::semantic_index::ast_ids::{HasScopedUseId, ScopedUseId};
use crate::semantic_index::definition::{
AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind,
Definition, DefinitionKind, DefinitionNodeKey, DefinitionState, ExceptHandlerDefinitionKind,
ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind,
ForStmtDefinitionKind, LoopHeaderDefinitionKind, TargetKind, WithItemDefinitionKind,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::narrowing_constraints::ConstraintKey;
@@ -121,7 +121,7 @@ use crate::types::{
TypeQualifiers, TypeVarBoundOrConstraints, TypeVarBoundOrConstraintsEvaluation,
TypeVarDefaultEvaluation, TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance,
TypedDictType, UnionBuilder, UnionType, UnionTypeInstance, binding_type, infer_scope_types,
todo_type,
todo_type, with_loop_header_override,
};
use crate::types::{CallableTypes, overrides};
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
@@ -1501,6 +1501,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
DefinitionKind::TypeVarTuple(node) => {
self.infer_typevartuple_definition(node.node(self.module()), definition);
}
DefinitionKind::LoopHeader(loop_header) => {
self.infer_loop_header_definition(&loop_header, definition);
}
}
}
@@ -1530,6 +1533,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
DefinitionKind::Assignment(assignment) => {
self.infer_assignment_deferred(assignment.value(self.module()));
}
DefinitionKind::LoopHeader(_) => {}
_ => {}
}
}
@@ -5384,6 +5388,66 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
add.insert(self, target_ty);
}
fn infer_loop_header_definition(
&mut self,
loop_header: &LoopHeaderDefinitionKind<'db>,
definition: Definition<'db>,
) {
const MAX_LOOP_HEADER_ITERATIONS: usize = 8;
let max_iterations = MAX_LOOP_HEADER_ITERATIONS;
let should_widen = true;
let add = self.add_binding(loop_header.node(self.module()).into(), definition);
let mut seed_union = UnionBuilder::new(self.db());
for seed_definition in loop_header.seed_definitions() {
seed_union = seed_union.add(binding_type(self.db(), *seed_definition));
}
let mut current = if seed_union.is_empty() {
Type::unknown()
} else {
seed_union.build()
};
let mut changed = false;
for _ in 0..max_iterations {
let previous_multi = self.multi_inference_state;
self.multi_inference_state = MultiInferenceState::Overwrite;
let next = with_loop_header_override(definition, current, || {
let mut union = UnionBuilder::new(self.db());
for definition in loop_header.definitions() {
let inferred = match definition.kind(self.db()) {
DefinitionKind::Assignment(assignment) => self
.infer_assignment_definition_impl(
&assignment,
*definition,
TypeContext::default(),
),
DefinitionKind::AugmentedAssignment(augmented_assignment) => {
self.infer_augment_assignment(augmented_assignment.node(self.module()))
}
_ => binding_type(self.db(), *definition),
};
union = union.add(inferred);
}
union.build()
});
self.multi_inference_state = previous_multi;
if next == current {
changed = false;
break;
}
changed = true;
current = next;
}
if should_widen && changed {
current = KnownClass::Int.to_instance(self.db());
}
add.insert(self, current);
}
fn infer_assignment_definition_impl(
&mut self,
assignment: &AssignmentDefinitionKind<'db>,