From a9fc648faf587eea54fefa41311ac16d7ba5106d Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Sat, 6 May 2023 12:20:08 -0400 Subject: [PATCH] Use `NodeId` for `Binding` source (#4234) --- crates/ruff/src/checkers/ast/mod.rs | 96 ++++++++++--------- .../rules/unused_loop_control_variable.rs | 14 +-- .../rules/pyflakes/rules/unused_variable.rs | 3 +- .../rules/pylint/rules/global_statement.rs | 6 +- .../src/analyze/branch_detection.rs | 58 ++++++----- crates/ruff_python_semantic/src/binding.rs | 4 +- crates/ruff_python_semantic/src/context.rs | 12 ++- 7 files changed, 97 insertions(+), 96 deletions(-) diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index 2a03890690..f7d06457ea 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -26,6 +26,7 @@ use ruff_python_semantic::binding::{ Importation, StarImportation, SubmoduleImportation, }; use ruff_python_semantic::context::Context; +use ruff_python_semantic::node::NodeId; use ruff_python_semantic::scope::{ClassDef, FunctionDef, Lambda, Scope, ScopeId, ScopeKind}; use ruff_python_stdlib::builtins::{BUILTINS, MAGIC_GLOBALS}; use ruff_python_stdlib::path::is_python_stub_file; @@ -231,7 +232,7 @@ where synthetic_usage: usage, typing_usage: None, range: *range, - source: Some(stmt), + source: self.ctx.stmt_id, context, exceptions, }); @@ -261,7 +262,7 @@ where synthetic_usage: usage, typing_usage: None, range: *range, - source: Some(stmt), + source: self.ctx.stmt_id, context, exceptions, }); @@ -688,7 +689,7 @@ where synthetic_usage: None, typing_usage: None, range: stmt.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -904,7 +905,7 @@ where synthetic_usage: Some((self.ctx.scope_id, alias.range())), typing_usage: None, range: alias.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -934,7 +935,7 @@ where synthetic_usage: None, typing_usage: None, range: alias.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -962,7 +963,7 @@ where }, typing_usage: None, range: alias.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -1222,7 +1223,7 @@ where synthetic_usage: Some((self.ctx.scope_id, alias.range())), typing_usage: None, range: alias.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -1318,7 +1319,7 @@ where }, typing_usage: None, range: alias.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -2022,7 +2023,7 @@ where synthetic_usage: None, typing_usage: None, range: stmt.range(), - source: Some(stmt), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }); @@ -2085,7 +2086,7 @@ where synthetic_usage: None, typing_usage: None, range: stmt.range(), - source: Some(stmt), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }); @@ -2246,7 +2247,7 @@ where synthetic_usage: None, typing_usage: None, range: stmt.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4121,7 +4122,7 @@ where synthetic_usage: None, typing_usage: None, range: arg.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4165,7 +4166,7 @@ where synthetic_usage: None, typing_usage: None, range: pattern.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4243,7 +4244,7 @@ impl<'a> Checker<'a> { && !(existing.kind.is_function_definition() && analyze::visibility::is_overload( &self.ctx, - cast::decorator_list(existing.source.as_ref().unwrap()), + cast::decorator_list(self.ctx.stmts[existing.source.unwrap()]), )) { if self.settings.rules.enabled(Rule::RedefinedWhileUnused) { @@ -4260,13 +4261,17 @@ impl<'a> Checker<'a> { BindingKind::ClassDefinition | BindingKind::FunctionDefinition ) .then(|| { - binding.source.as_ref().map_or(binding.range, |source| { - helpers::identifier_range(source, self.locator) + binding.source.map_or(binding.range, |source| { + helpers::identifier_range( + self.ctx.stmts[source], + self.locator, + ) }) }) .unwrap_or(binding.range), ); - if let Some(parent) = binding.source.as_ref() { + if let Some(parent) = binding.source { + let parent = self.ctx.stmts[parent]; if matches!(parent.node, StmtKind::ImportFrom { .. }) && parent.range().contains_range(binding.range) { @@ -4571,7 +4576,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4591,7 +4596,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4608,7 +4613,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4691,7 +4696,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4713,7 +4718,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4729,7 +4734,7 @@ impl<'a> Checker<'a> { synthetic_usage: None, typing_usage: None, range: expr.range(), - source: Some(self.ctx.current_stmt()), + source: self.ctx.stmt_id, context: self.ctx.execution_context(), exceptions: self.ctx.exceptions(), }, @@ -4937,9 +4942,7 @@ impl<'a> Checker<'a> { | StmtKind::AsyncFor { target, body, .. } = &stmt.node { if self.settings.rules.enabled(Rule::UnusedLoopControlVariable) { - flake8_bugbear::rules::unused_loop_control_variable( - self, stmt, target, body, - ); + flake8_bugbear::rules::unused_loop_control_variable(self, target, body); } } else { unreachable!("Expected ExprKind::For | ExprKind::AsyncFor"); @@ -5089,7 +5092,8 @@ impl<'a> Checker<'a> { for (name, index) in scope.bindings() { let binding = &self.ctx.bindings[*index]; if binding.kind.is_global() { - if let Some(stmt) = &binding.source { + if let Some(source) = binding.source { + let stmt = &self.ctx.stmts[source]; if matches!(stmt.node, StmtKind::Global { .. }) { diagnostics.push(Diagnostic::new( pylint::rules::GlobalVariableNotAssigned { @@ -5143,13 +5147,17 @@ impl<'a> Checker<'a> { | BindingKind::FunctionDefinition ) .then(|| { - rebound.source.as_ref().map_or(rebound.range, |source| { - helpers::identifier_range(source, self.locator) + rebound.source.map_or(rebound.range, |source| { + helpers::identifier_range( + self.ctx.stmts[source], + self.locator, + ) }) }) .unwrap_or(rebound.range), ); - if let Some(parent) = &rebound.source { + if let Some(source) = rebound.source { + let parent = &self.ctx.stmts[source]; if matches!(parent.node, StmtKind::ImportFrom { .. }) && parent.range().contains_range(rebound.range) { @@ -5203,11 +5211,7 @@ impl<'a> Checker<'a> { // Collect all unused imports by location. (Multiple unused imports at the same // location indicates an `import from`.) type UnusedImport<'a> = (&'a str, &'a TextRange); - type BindingContext<'a> = ( - RefEquality<'a, Stmt>, - Option>, - Exceptions, - ); + type BindingContext<'a> = (NodeId, Option, Exceptions); let mut unused: FxHashMap> = FxHashMap::default(); let mut ignored: FxHashMap> = @@ -5232,11 +5236,12 @@ impl<'a> Checker<'a> { continue; } - let child = binding.source.unwrap(); - let parent = self.ctx.stmts.parent(child); - let exceptions = binding.exceptions; + let child_id = binding.source.unwrap(); + let parent_id = self.ctx.stmts.parent_id(child_id); + let exceptions = binding.exceptions; let diagnostic_offset = binding.range.start(); + let child = &self.ctx.stmts[child_id]; let parent_offset = if matches!(child.node, StmtKind::ImportFrom { .. }) { Some(child.start()) } else { @@ -5249,12 +5254,12 @@ impl<'a> Checker<'a> { }) { ignored - .entry((RefEquality(child), parent.map(RefEquality), exceptions)) + .entry((child_id, parent_id, exceptions)) .or_default() .push((full_name, &binding.range)); } else { unused - .entry((RefEquality(child), parent.map(RefEquality), exceptions)) + .entry((child_id, parent_id, exceptions)) .or_default() .push((full_name, &binding.range)); } @@ -5264,10 +5269,10 @@ impl<'a> Checker<'a> { self.settings.ignore_init_module_imports && self.path.ends_with("__init__.py"); for ((defined_by, defined_in, exceptions), unused_imports) in unused .into_iter() - .sorted_by_key(|((defined_by, ..), ..)| defined_by.start()) + .sorted_by_key(|((defined_by, ..), ..)| *defined_by) { - let child: &Stmt = defined_by.into(); - let parent: Option<&Stmt> = defined_in.map(Into::into); + let child = self.ctx.stmts[defined_by]; + let parent = defined_in.map(|defined_in| self.ctx.stmts[defined_in]); let multiple = unused_imports.len() > 1; let in_except_handler = exceptions .intersects(Exceptions::MODULE_NOT_FOUND_ERROR | Exceptions::IMPORT_ERROR); @@ -5285,7 +5290,7 @@ impl<'a> Checker<'a> { ) { Ok(fix) => { if fix.is_deletion() || fix.content() == Some("pass") { - self.deletions.insert(defined_by); + self.deletions.insert(RefEquality(child)); } Some(fix) } @@ -5324,8 +5329,9 @@ impl<'a> Checker<'a> { } for ((child, .., exceptions), unused_imports) in ignored .into_iter() - .sorted_by_key(|((defined_by, ..), ..)| defined_by.start()) + .sorted_by_key(|((defined_by, ..), ..)| *defined_by) { + let child = self.ctx.stmts[child]; let multiple = unused_imports.len() > 1; let in_except_handler = exceptions .intersects(Exceptions::MODULE_NOT_FOUND_ERROR | Exceptions::IMPORT_ERROR); diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs b/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs index ad58504ecd..6d4752f22a 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/unused_loop_control_variable.rs @@ -24,7 +24,6 @@ use serde::{Deserialize, Serialize}; use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::types::RefEquality; use ruff_python_ast::visitor::Visitor; use ruff_python_ast::{helpers, visitor}; @@ -108,12 +107,7 @@ where } /// B007 -pub fn unused_loop_control_variable( - checker: &mut Checker, - stmt: &Stmt, - target: &Expr, - body: &[Stmt], -) { +pub fn unused_loop_control_variable(checker: &mut Checker, target: &Expr, body: &[Stmt]) { let control_names = { let mut finder = NameFinder::new(); finder.visit_expr(target); @@ -168,9 +162,9 @@ pub fn unused_loop_control_variable( let scope = checker.ctx.scope(); let binding = scope.bindings_for_name(name).find_map(|index| { let binding = &checker.ctx.bindings[*index]; - binding.source.and_then(|source| { - (RefEquality(source) == RefEquality(stmt)).then_some(binding) - }) + binding + .source + .and_then(|source| (Some(source) == checker.ctx.stmt_id).then_some(binding)) }); if let Some(binding) = binding { if binding.kind.is_loop_var() { diff --git a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs index e2ae12fb8d..be23789f23 100644 --- a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs +++ b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs @@ -328,7 +328,8 @@ pub fn unused_variable(checker: &mut Checker, scope: ScopeId) { binding.range, ); if checker.patch(diagnostic.kind.rule()) { - if let Some(stmt) = binding.source { + if let Some(source) = binding.source { + let stmt = checker.ctx.stmts[source]; if let Some((kind, fix)) = remove_unused_variable(stmt, binding.range, checker) { if matches!(kind, DeletionKind::Whole) { diff --git a/crates/ruff/src/rules/pylint/rules/global_statement.rs b/crates/ruff/src/rules/pylint/rules/global_statement.rs index e1efbf26cb..a05b776530 100644 --- a/crates/ruff/src/rules/pylint/rules/global_statement.rs +++ b/crates/ruff/src/rules/pylint/rules/global_statement.rs @@ -1,5 +1,3 @@ -use rustpython_parser::ast::Stmt; - use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; @@ -55,9 +53,9 @@ pub fn global_statement(checker: &mut Checker, name: &str) { if let Some(index) = scope.get(name) { let binding = &checker.ctx.bindings[*index]; if binding.kind.is_global() { - let source: &Stmt = binding + let source = checker.ctx.stmts[binding .source - .expect("`global` bindings should always have a `source`"); + .expect("`global` bindings should always have a `source`")]; let diagnostic = Diagnostic::new( GlobalStatement { name: name.to_string(), diff --git a/crates/ruff_python_semantic/src/analyze/branch_detection.rs b/crates/ruff_python_semantic/src/analyze/branch_detection.rs index 3a1499bea1..efc60fbdad 100644 --- a/crates/ruff_python_semantic/src/analyze/branch_detection.rs +++ b/crates/ruff_python_semantic/src/analyze/branch_detection.rs @@ -1,47 +1,41 @@ use std::cmp::Ordering; -use ruff_python_ast::types::RefEquality; use rustpython_parser::ast::ExcepthandlerKind::ExceptHandler; use rustpython_parser::ast::{Stmt, StmtKind}; -use crate::node::Nodes; +use crate::node::{NodeId, Nodes}; /// Return the common ancestor of `left` and `right` below `stop`, or `None`. -fn common_ancestor<'a>( - left: &'a Stmt, - right: &'a Stmt, - stop: Option<&'a Stmt>, - node_tree: &Nodes<'a>, -) -> Option<&'a Stmt> { - if stop.map_or(false, |stop| { - RefEquality(left) == RefEquality(stop) || RefEquality(right) == RefEquality(stop) - }) { +fn common_ancestor( + left: NodeId, + right: NodeId, + stop: Option, + node_tree: &Nodes, +) -> Option { + if stop.map_or(false, |stop| left == stop || right == stop) { return None; } - if RefEquality(left) == RefEquality(right) { + if left == right { return Some(left); } - let left_id = node_tree.node_id(left)?; - let right_id = node_tree.node_id(right)?; - - let left_depth = node_tree.depth(left_id); - let right_depth = node_tree.depth(right_id); + let left_depth = node_tree.depth(left); + let right_depth = node_tree.depth(right); match left_depth.cmp(&right_depth) { Ordering::Less => { - let right_id = node_tree.parent_id(right_id)?; - common_ancestor(left, node_tree[right_id], stop, node_tree) + let right = node_tree.parent_id(right)?; + common_ancestor(left, right, stop, node_tree) } Ordering::Equal => { - let left_id = node_tree.parent_id(left_id)?; - let right_id = node_tree.parent_id(right_id)?; - common_ancestor(node_tree[left_id], node_tree[right_id], stop, node_tree) + let left = node_tree.parent_id(left)?; + let right = node_tree.parent_id(right)?; + common_ancestor(left, right, stop, node_tree) } Ordering::Greater => { - let left_id = node_tree.parent_id(left_id)?; - common_ancestor(node_tree[left_id], right, stop, node_tree) + let left = node_tree.parent_id(left)?; + common_ancestor(left, right, stop, node_tree) } } } @@ -78,21 +72,23 @@ fn alternatives(stmt: &Stmt) -> Vec> { /// Return `true` if `stmt` is a descendent of any of the nodes in `ancestors`. fn descendant_of<'a>( - stmt: &'a Stmt, + stmt: NodeId, ancestors: &[&'a Stmt], - stop: &'a Stmt, + stop: NodeId, node_tree: &Nodes<'a>, ) -> bool { - ancestors - .iter() - .any(|ancestor| common_ancestor(stmt, ancestor, Some(stop), node_tree).is_some()) + ancestors.iter().any(|ancestor| { + node_tree.node_id(ancestor).map_or(false, |ancestor| { + common_ancestor(stmt, ancestor, Some(stop), node_tree).is_some() + }) + }) } /// Return `true` if `left` and `right` are on different branches of an `if` or /// `try` statement. -pub fn different_forks<'a>(left: &'a Stmt, right: &'a Stmt, node_tree: &Nodes<'a>) -> bool { +pub fn different_forks(left: NodeId, right: NodeId, node_tree: &Nodes) -> bool { if let Some(ancestor) = common_ancestor(left, right, None, node_tree) { - for items in alternatives(ancestor) { + for items in alternatives(node_tree[ancestor]) { let l = descendant_of(left, &items, ancestor, node_tree); let r = descendant_of(right, &items, ancestor, node_tree); if l ^ r { diff --git a/crates/ruff_python_semantic/src/binding.rs b/crates/ruff_python_semantic/src/binding.rs index 6628fa235b..fe0d2ceadd 100644 --- a/crates/ruff_python_semantic/src/binding.rs +++ b/crates/ruff_python_semantic/src/binding.rs @@ -3,8 +3,8 @@ use std::ops::{Deref, Index, IndexMut}; use bitflags::bitflags; use ruff_text_size::TextRange; -use rustpython_parser::ast::Stmt; +use crate::node::NodeId; use crate::scope::ScopeId; #[derive(Debug, Clone)] @@ -14,7 +14,7 @@ pub struct Binding<'a> { /// The context in which the binding was created. pub context: ExecutionContext, /// The statement in which the [`Binding`] was defined. - pub source: Option<&'a Stmt>, + pub source: Option, /// Tuple of (scope index, range) indicating the scope and range at which /// the binding was last used in a runtime context. pub runtime_usage: Option<(ScopeId, TextRange)>, diff --git a/crates/ruff_python_semantic/src/context.rs b/crates/ruff_python_semantic/src/context.rs index bdd10a5146..912de79e43 100644 --- a/crates/ruff_python_semantic/src/context.rs +++ b/crates/ruff_python_semantic/src/context.rs @@ -252,7 +252,9 @@ impl<'a> Context<'a> { .take(scope_index) .all(|scope| scope.get(name).is_none()) { - return Some((binding.source.unwrap(), format!("{name}.{member}"))); + if let Some(source) = binding.source { + return Some((self.stmts[source], format!("{name}.{member}"))); + } } } } @@ -268,7 +270,9 @@ impl<'a> Context<'a> { .take(scope_index) .all(|scope| scope.get(name).is_none()) { - return Some((binding.source.unwrap(), (*name).to_string())); + if let Some(source) = binding.source { + return Some((self.stmts[source], (*name).to_string())); + } } } } @@ -283,7 +287,9 @@ impl<'a> Context<'a> { .take(scope_index) .all(|scope| scope.get(name).is_none()) { - return Some((binding.source.unwrap(), format!("{name}.{member}"))); + if let Some(source) = binding.source { + return Some((self.stmts[source], format!("{name}.{member}"))); + } } } }