From daefa74e9ad4014ed79f59b0a741a5bac2e5f0b7 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Mon, 7 Aug 2023 12:36:02 -0400 Subject: [PATCH] Remove async AST node variants for `with`, `for`, and `def` (#6369) ## Summary Per the suggestion in https://github.com/astral-sh/ruff/discussions/6183, this PR removes `AsyncWith`, `AsyncFor`, and `AsyncFunctionDef`, replacing them with an `is_async` field on the non-async variants of those structs. Unlike an interpreter, we _generally_ have identical handling for these nodes, so separating them into distinct variants adds complexity from which we don't really benefit. This can be seen below, where we get to remove a _ton_ of code related to adding generic `Any*` wrappers, and a ton of duplicate branches for these cases. ## Test Plan `cargo test` is unchanged, apart from parser snapshots. --- crates/ruff/src/autofix/edits.rs | 5 +- .../ast/analyze/deferred_for_loops.rs | 25 ++- .../checkers/ast/analyze/deferred_scopes.rs | 10 +- .../src/checkers/ast/analyze/expression.rs | 2 +- .../src/checkers/ast/analyze/statement.rs | 38 +--- crates/ruff/src/checkers/ast/mod.rs | 37 ++-- crates/ruff/src/docstrings/extraction.rs | 3 +- .../src/rules/flake8_annotations/helpers.rs | 8 - .../rules/abstract_base_class.rs | 10 +- .../rules/function_uses_loop_variable.rs | 8 +- .../rules/jump_statement_in_finally.rs | 7 +- .../rules/iter_method_return_iterable.rs | 12 +- .../flake8_pyi/rules/non_self_return_type.rs | 4 +- .../flake8_pytest_style/rules/fixture.rs | 2 +- .../rules/flake8_pytest_style/rules/raises.rs | 7 +- .../src/rules/flake8_return/rules/function.rs | 7 +- .../ruff/src/rules/flake8_return/visitor.rs | 9 +- .../flake8_simplify/rules/ast_unary_op.rs | 6 +- .../rules/flake8_simplify/rules/ast_with.rs | 24 ++- .../rules/flake8_simplify/rules/fix_with.rs | 10 +- .../rules/reimplemented_builtin.rs | 2 +- .../rules/unused_arguments.rs | 7 - crates/ruff/src/rules/isort/block.rs | 25 +-- .../mccabe/rules/function_is_too_complex.rs | 9 +- .../pycodestyle/rules/lambda_assignment.rs | 2 + .../src/rules/pydocstyle/rules/sections.rs | 4 +- .../pyflakes/rules/break_outside_loop.rs | 6 +- .../pyflakes/rules/continue_outside_loop.rs | 6 +- .../rules/pyflakes/rules/undefined_local.rs | 2 +- .../rules/pyflakes/rules/unused_variable.rs | 2 +- crates/ruff/src/rules/pylint/helpers.rs | 9 +- .../rules/pylint/rules/continue_in_finally.rs | 9 +- .../rules/pylint/rules/redefined_loop_name.rs | 19 +- .../rules/pylint/rules/too_many_branches.rs | 1 - .../rules/pylint/rules/too_many_statements.rs | 4 +- .../pylint/rules/useless_else_on_loop.rs | 9 +- .../rules/yield_from_in_async_function.rs | 12 +- .../rules/super_call_with_parameters.rs | 2 +- .../pyupgrade/rules/yield_in_for_loop.rs | 71 ++++--- .../ruff/src/rules/ruff/rules/unreachable.rs | 20 +- .../rules/type_check_without_type_error.rs | 2 +- crates/ruff_python_ast/src/cast.rs | 20 +- crates/ruff_python_ast/src/comparable.rs | 72 +------ crates/ruff_python_ast/src/function.rs | 134 ------------- crates/ruff_python_ast/src/helpers.rs | 33 +--- crates/ruff_python_ast/src/identifier.rs | 6 +- crates/ruff_python_ast/src/lib.rs | 1 - crates/ruff_python_ast/src/node.rs | 183 ------------------ crates/ruff_python_ast/src/nodes.rs | 99 ++-------- .../ruff_python_ast/src/statement_visitor.rs | 10 - crates/ruff_python_ast/src/traversal.rs | 11 -- crates/ruff_python_ast/src/visitor.rs | 38 ---- .../ruff_python_ast/src/visitor/preorder.rs | 21 +- crates/ruff_python_codegen/src/generator.rs | 89 ++------- .../src/comments/placement.rs | 14 +- .../src/expression/expr_named_expr.rs | 1 - crates/ruff_python_formatter/src/generated.rs | 108 ----------- .../src/statement/mod.rs | 6 - .../src/statement/stmt_async_for.rs | 23 --- .../src/statement/stmt_async_function_def.rs | 24 --- .../src/statement/stmt_async_with.rs | 22 --- .../src/statement/stmt_for.rs | 106 ++-------- .../src/statement/stmt_function_def.rs | 71 ++----- .../src/statement/stmt_with.rs | 106 ++-------- .../src/statement/suite.rs | 5 +- crates/ruff_python_parser/src/python.lalrpop | 18 +- crates/ruff_python_parser/src/python.rs | 20 +- ...on_parser__context__tests__assign_for.snap | 1 + ...n_parser__context__tests__assign_with.snap | 1 + ...unction__tests__function_kw_only_args.snap | 1 + ...__function_kw_only_args_with_defaults.snap | 1 + ...er__function__tests__function_no_args.snap | 1 + ...__tests__function_no_args_with_ranges.snap | 1 + ..._tests__function_pos_and_kw_only_args.snap | 1 + ...on_pos_and_kw_only_args_with_defaults.snap | 1 + ...w_only_args_with_defaults_and_varargs.snap | 1 + ..._with_defaults_and_varargs_and_kwargs.snap | 1 + ...r__function__tests__function_pos_args.snap | 1 + ...ests__function_pos_args_with_defaults.snap | 1 + ..._tests__function_pos_args_with_ranges.snap | 1 + ...rser__parser__tests__decorator_ranges.snap | 1 + ..._parser__parser__tests__jupyter_magic.snap | 2 + ...on_parser__parser__tests__parse_class.snap | 2 + ...ser__tests__parse_function_definition.snap | 7 + ...ser__parser__tests__variadic_generics.snap | 1 + ...parser__parser__tests__with_statement.snap | 28 ++- .../src/analyze/visibility.rs | 80 ++++---- crates/ruff_python_semantic/src/definition.rs | 1 - crates/ruff_python_semantic/src/globals.rs | 2 +- crates/ruff_python_semantic/src/model.rs | 11 +- crates/ruff_python_semantic/src/scope.rs | 7 - 91 files changed, 375 insertions(+), 1478 deletions(-) delete mode 100644 crates/ruff_python_ast/src/function.rs delete mode 100644 crates/ruff_python_formatter/src/statement/stmt_async_for.rs delete mode 100644 crates/ruff_python_formatter/src/statement/stmt_async_function_def.rs delete mode 100644 crates/ruff_python_formatter/src/statement/stmt_async_with.rs diff --git a/crates/ruff/src/autofix/edits.rs b/crates/ruff/src/autofix/edits.rs index 17736349a5..6bdc112ffb 100644 --- a/crates/ruff/src/autofix/edits.rs +++ b/crates/ruff/src/autofix/edits.rs @@ -179,16 +179,13 @@ fn is_only(vec: &[T], value: &T) -> bool { fn is_lone_child(child: &Stmt, parent: &Stmt) -> bool { match parent { Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) | Stmt::ClassDef(ast::StmtClassDef { body, .. }) - | Stmt::With(ast::StmtWith { body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => { + | Stmt::With(ast::StmtWith { body, .. }) => { if is_only(body, child) { return true; } } Stmt::For(ast::StmtFor { body, orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) | Stmt::While(ast::StmtWhile { body, orelse, .. }) => { if is_only(body, child) || is_only(orelse, child) { return true; diff --git a/crates/ruff/src/checkers/ast/analyze/deferred_for_loops.rs b/crates/ruff/src/checkers/ast/analyze/deferred_for_loops.rs index c9ed5be35d..8898e1df69 100644 --- a/crates/ruff/src/checkers/ast/analyze/deferred_for_loops.rs +++ b/crates/ruff/src/checkers/ast/analyze/deferred_for_loops.rs @@ -11,21 +11,18 @@ pub(crate) fn deferred_for_loops(checker: &mut Checker) { for snapshot in for_loops { checker.semantic.restore(snapshot); - if let Stmt::For(ast::StmtFor { + let Stmt::For(ast::StmtFor { target, iter, body, .. - }) - | Stmt::AsyncFor(ast::StmtAsyncFor { - target, iter, body, .. - }) = &checker.semantic.current_statement() - { - if checker.enabled(Rule::UnusedLoopControlVariable) { - flake8_bugbear::rules::unused_loop_control_variable(checker, target, body); - } - if checker.enabled(Rule::IncorrectDictIterator) { - perflint::rules::incorrect_dict_iterator(checker, target, iter); - } - } else { - unreachable!("Expected Expr::For | Expr::AsyncFor"); + }) = checker.semantic.current_statement() + else { + unreachable!("Expected Stmt::For"); + }; + + if checker.enabled(Rule::UnusedLoopControlVariable) { + flake8_bugbear::rules::unused_loop_control_variable(checker, target, body); + } + if checker.enabled(Rule::IncorrectDictIterator) { + perflint::rules::incorrect_dict_iterator(checker, target, iter); } } } diff --git a/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs b/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs index bcb655a16b..6ab4f74d30 100644 --- a/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs +++ b/crates/ruff/src/checkers/ast/analyze/deferred_scopes.rs @@ -241,10 +241,7 @@ pub(crate) fn deferred_scopes(checker: &mut Checker) { flake8_pyi::rules::unused_private_typed_dict(checker, scope, &mut diagnostics); } - if matches!( - scope.kind, - ScopeKind::Function(_) | ScopeKind::AsyncFunction(_) | ScopeKind::Lambda(_) - ) { + if matches!(scope.kind, ScopeKind::Function(_) | ScopeKind::Lambda(_)) { if checker.enabled(Rule::UnusedVariable) { pyflakes::rules::unused_variable(checker, scope, &mut diagnostics); } @@ -270,10 +267,7 @@ pub(crate) fn deferred_scopes(checker: &mut Checker) { } } - if matches!( - scope.kind, - ScopeKind::Function(_) | ScopeKind::AsyncFunction(_) | ScopeKind::Module - ) { + if matches!(scope.kind, ScopeKind::Function(_) | ScopeKind::Module) { if enforce_typing_imports { let runtime_imports: Vec<&Binding> = checker .semantic diff --git a/crates/ruff/src/checkers/ast/analyze/expression.rs b/crates/ruff/src/checkers/ast/analyze/expression.rs index 6398e14e87..038ded69c7 100644 --- a/crates/ruff/src/checkers/ast/analyze/expression.rs +++ b/crates/ruff/src/checkers/ast/analyze/expression.rs @@ -206,7 +206,7 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) { } ExprContext::Store => { if checker.enabled(Rule::NonLowercaseVariableInFunction) { - if checker.semantic.current_scope().kind.is_any_function() { + if checker.semantic.current_scope().kind.is_function() { // Ignore globals. if !checker .semantic diff --git a/crates/ruff/src/checkers/ast/analyze/statement.rs b/crates/ruff/src/checkers/ast/analyze/statement.rs index eee23db997..db0fa91f56 100644 --- a/crates/ruff/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff/src/checkers/ast/analyze/statement.rs @@ -70,22 +70,14 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { } } Stmt::FunctionDef(ast::StmtFunctionDef { + is_async, name, decorator_list, returns, parameters, body, type_params, - .. - }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - name, - decorator_list, - returns, - parameters, - body, - type_params, - .. + range: _, }) => { if checker.enabled(Rule::DjangoNonLeadingReceiverDecorator) { flake8_django::rules::non_leading_receiver_decorator(checker, decorator_list); @@ -151,11 +143,11 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { flake8_pyi::rules::non_self_return_type( checker, stmt, + *is_async, name, decorator_list, returns.as_ref().map(AsRef::as_ref), parameters, - stmt.is_async_function_def_stmt(), ); } if checker.enabled(Rule::CustomTypeVarReturnType) { @@ -181,12 +173,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { } } if checker.enabled(Rule::BadExitAnnotation) { - flake8_pyi::rules::bad_exit_annotation( - checker, - stmt.is_async_function_def_stmt(), - name, - parameters, - ); + flake8_pyi::rules::bad_exit_annotation(checker, *is_async, name, parameters); } if checker.enabled(Rule::RedundantNumericUnion) { flake8_pyi::rules::redundant_numeric_union(checker, parameters); @@ -1097,8 +1084,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { pygrep_hooks::rules::non_existent_mock_method(checker, test); } } - Stmt::With(ast::StmtWith { items, body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { items, body, .. }) => { + Stmt::With(with_ @ ast::StmtWith { items, body, .. }) => { if checker.enabled(Rule::AssertRaisesException) { flake8_bugbear::rules::assert_raises_exception(checker, items); } @@ -1108,8 +1094,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { if checker.enabled(Rule::MultipleWithStatements) { flake8_simplify::rules::multiple_with_statements( checker, - stmt, - body, + with_, checker.semantic.current_statement_parent(), ); } @@ -1134,13 +1119,6 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { iter, orelse, .. - }) - | Stmt::AsyncFor(ast::StmtAsyncFor { - target, - body, - iter, - orelse, - .. }) => { if checker.any_enabled(&[Rule::UnusedLoopControlVariable, Rule::IncorrectDictIterator]) { @@ -1339,7 +1317,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { if !checker .semantic .current_scopes() - .any(|scope| scope.kind.is_any_function()) + .any(|scope| scope.kind.is_function()) { if checker.enabled(Rule::UnprefixedTypeParam) { flake8_pyi::rules::prefix_type_params(checker, value, targets); @@ -1404,7 +1382,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { if !checker .semantic .current_scopes() - .any(|scope| scope.kind.is_any_function()) + .any(|scope| scope.kind.is_function()) { flake8_pyi::rules::annotated_assignment_default_in_stub( checker, target, value, annotation, diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index 4b6972ead1..030ea33af9 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -462,14 +462,6 @@ where returns, type_params, .. - }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - body, - parameters, - decorator_list, - type_params, - returns, - .. }) => { // Visit the decorators and arguments, but avoid the body, which will be // deferred. @@ -540,8 +532,7 @@ where self.semantic.push_scope(match &stmt { Stmt::FunctionDef(stmt) => ScopeKind::Function(stmt), - Stmt::AsyncFunctionDef(stmt) => ScopeKind::AsyncFunction(stmt), - _ => unreachable!("Expected Stmt::FunctionDef | Stmt::AsyncFunctionDef"), + _ => unreachable!("Expected Stmt::FunctionDef"), }); self.deferred.functions.push(self.semantic.snapshot()); @@ -743,8 +734,7 @@ where // Step 3: Clean-up match stmt { - Stmt::FunctionDef(ast::StmtFunctionDef { name, .. }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { name, .. }) => { + Stmt::FunctionDef(ast::StmtFunctionDef { name, .. }) => { let scope_id = self.semantic.scope_id; self.deferred.scopes.push(scope_id); self.semantic.pop_scope(); // Function scope @@ -1626,7 +1616,7 @@ impl<'a> Checker<'a> { return; } - if matches!(parent, Stmt::For(_) | Stmt::AsyncFor(_)) { + if parent.is_for_stmt() { self.add_binding( id, expr.range(), @@ -1825,19 +1815,14 @@ impl<'a> Checker<'a> { for snapshot in deferred_functions { self.semantic.restore(snapshot); - match &self.semantic.current_statement() { - Stmt::FunctionDef(ast::StmtFunctionDef { - body, parameters, .. - }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - body, parameters, .. - }) => { - self.visit_parameters(parameters); - self.visit_body(body); - } - _ => { - unreachable!("Expected Stmt::FunctionDef | Stmt::AsyncFunctionDef") - } + if let Stmt::FunctionDef(ast::StmtFunctionDef { + body, parameters, .. + }) = self.semantic.current_statement() + { + self.visit_parameters(parameters); + self.visit_body(body); + } else { + unreachable!("Expected Stmt::FunctionDef") } } } diff --git a/crates/ruff/src/docstrings/extraction.rs b/crates/ruff/src/docstrings/extraction.rs index 3a2fdb7660..84cdba07e2 100644 --- a/crates/ruff/src/docstrings/extraction.rs +++ b/crates/ruff/src/docstrings/extraction.rs @@ -30,8 +30,7 @@ pub(crate) fn extract_docstring<'a>(definition: &'a Definition<'a>) -> Option<&' Definition::Module(module) => docstring_from(module.python_ast), Definition::Member(member) => { if let Stmt::ClassDef(ast::StmtClassDef { body, .. }) - | Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) = &member.stmt + | Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) = &member.stmt { docstring_from(body) } else { diff --git a/crates/ruff/src/rules/flake8_annotations/helpers.rs b/crates/ruff/src/rules/flake8_annotations/helpers.rs index b1572efd82..04ecc41290 100644 --- a/crates/ruff/src/rules/flake8_annotations/helpers.rs +++ b/crates/ruff/src/rules/flake8_annotations/helpers.rs @@ -15,14 +15,6 @@ pub(super) fn match_function_def( body, decorator_list, .. - }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - name, - parameters, - returns, - body, - decorator_list, - .. }) => ( name, parameters, diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/abstract_base_class.rs b/crates/ruff/src/rules/flake8_bugbear/rules/abstract_base_class.rs index 8b69b2cf61..43307aa5c1 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/abstract_base_class.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/abstract_base_class.rs @@ -159,18 +159,12 @@ pub(crate) fn abstract_base_class( continue; } - let (Stmt::FunctionDef(ast::StmtFunctionDef { + let Stmt::FunctionDef(ast::StmtFunctionDef { decorator_list, body, name: method_name, .. - }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - decorator_list, - body, - name: method_name, - .. - })) = stmt + }) = stmt else { continue; }; diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/function_uses_loop_variable.rs b/crates/ruff/src/rules/flake8_bugbear/rules/function_uses_loop_variable.rs index 500acd9015..b0291c21b6 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/function_uses_loop_variable.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/function_uses_loop_variable.rs @@ -86,9 +86,6 @@ impl<'a> Visitor<'a> for SuspiciousVariablesVisitor<'a> { match stmt { Stmt::FunctionDef(ast::StmtFunctionDef { parameters, body, .. - }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - parameters, body, .. }) => { // Collect all loaded variable names. let mut visitor = LoadedNamesVisitor::default(); @@ -236,7 +233,7 @@ struct AssignedNamesVisitor<'a> { /// `Visitor` to collect all used identifiers in a statement. impl<'a> Visitor<'a> for AssignedNamesVisitor<'a> { fn visit_stmt(&mut self, stmt: &'a Stmt) { - if matches!(stmt, Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_)) { + if stmt.is_function_def_stmt() { // Don't recurse. return; } @@ -251,8 +248,7 @@ impl<'a> Visitor<'a> for AssignedNamesVisitor<'a> { } Stmt::AugAssign(ast::StmtAugAssign { target, .. }) | Stmt::AnnAssign(ast::StmtAnnAssign { target, .. }) - | Stmt::For(ast::StmtFor { target, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { target, .. }) => { + | Stmt::For(ast::StmtFor { target, .. }) => { let mut visitor = NamesFromAssignmentsVisitor::default(); visitor.visit_expr(target); self.names.extend(visitor.names); diff --git a/crates/ruff/src/rules/flake8_bugbear/rules/jump_statement_in_finally.rs b/crates/ruff/src/rules/flake8_bugbear/rules/jump_statement_in_finally.rs index df83409f2e..2d796f4706 100644 --- a/crates/ruff/src/rules/flake8_bugbear/rules/jump_statement_in_finally.rs +++ b/crates/ruff/src/rules/flake8_bugbear/rules/jump_statement_in_finally.rs @@ -69,16 +69,13 @@ fn walk_stmt(checker: &mut Checker, body: &[Stmt], f: fn(&Stmt) -> bool) { )); } match stmt { - Stmt::While(ast::StmtWhile { body, .. }) - | Stmt::For(ast::StmtFor { body, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { body, .. }) => { + Stmt::While(ast::StmtWhile { body, .. }) | Stmt::For(ast::StmtFor { body, .. }) => { walk_stmt(checker, body, Stmt::is_return_stmt); } Stmt::If(ast::StmtIf { body, .. }) | Stmt::Try(ast::StmtTry { body, .. }) | Stmt::TryStar(ast::StmtTryStar { body, .. }) - | Stmt::With(ast::StmtWith { body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => { + | Stmt::With(ast::StmtWith { body, .. }) => { walk_stmt(checker, body, f); } Stmt::Match(ast::StmtMatch { cases, .. }) => { diff --git a/crates/ruff/src/rules/flake8_pyi/rules/iter_method_return_iterable.rs b/crates/ruff/src/rules/flake8_pyi/rules/iter_method_return_iterable.rs index a3f9c1b2c1..1be0eec92f 100644 --- a/crates/ruff/src/rules/flake8_pyi/rules/iter_method_return_iterable.rs +++ b/crates/ruff/src/rules/flake8_pyi/rules/iter_method_return_iterable.rs @@ -49,14 +49,14 @@ use crate::checkers::ast::Checker; /// ``` #[violation] pub struct IterMethodReturnIterable { - async_: bool, + is_async: bool, } impl Violation for IterMethodReturnIterable { #[derive_message_formats] fn message(&self) -> String { - let IterMethodReturnIterable { async_ } = self; - if *async_ { + let IterMethodReturnIterable { is_async } = self; + if *is_async { format!("`__aiter__` methods should return an `AsyncIterator`, not an `AsyncIterable`") } else { format!("`__iter__` methods should return an `Iterator`, not an `Iterable`") @@ -91,7 +91,7 @@ pub(crate) fn iter_method_return_iterable(checker: &mut Checker, definition: &De returns }; - let async_ = match name.as_str() { + let is_async = match name.as_str() { "__iter__" => false, "__aiter__" => true, _ => return, @@ -101,7 +101,7 @@ pub(crate) fn iter_method_return_iterable(checker: &mut Checker, definition: &De .semantic() .resolve_call_path(annotation) .is_some_and(|call_path| { - if async_ { + if is_async { matches!( call_path.as_slice(), ["typing", "AsyncIterable"] | ["collections", "abc", "AsyncIterable"] @@ -115,7 +115,7 @@ pub(crate) fn iter_method_return_iterable(checker: &mut Checker, definition: &De }) { checker.diagnostics.push(Diagnostic::new( - IterMethodReturnIterable { async_ }, + IterMethodReturnIterable { is_async }, returns.range(), )); } diff --git a/crates/ruff/src/rules/flake8_pyi/rules/non_self_return_type.rs b/crates/ruff/src/rules/flake8_pyi/rules/non_self_return_type.rs index f4b9e32ea2..0f251297cb 100644 --- a/crates/ruff/src/rules/flake8_pyi/rules/non_self_return_type.rs +++ b/crates/ruff/src/rules/flake8_pyi/rules/non_self_return_type.rs @@ -113,11 +113,11 @@ impl Violation for NonSelfReturnType { pub(crate) fn non_self_return_type( checker: &mut Checker, stmt: &Stmt, + is_async: bool, name: &str, decorator_list: &[Decorator], returns: Option<&Expr>, parameters: &Parameters, - async_: bool, ) { let ScopeKind::Class(class_def) = checker.semantic().current_scope().kind else { return; @@ -138,7 +138,7 @@ pub(crate) fn non_self_return_type( return; } - if async_ { + if is_async { if name == "__aenter__" && is_name(returns, &class_def.name) && !is_final(&class_def.decorator_list, checker.semantic()) diff --git a/crates/ruff/src/rules/flake8_pytest_style/rules/fixture.rs b/crates/ruff/src/rules/flake8_pytest_style/rules/fixture.rs index 63352d2ffa..0a95d0cb60 100644 --- a/crates/ruff/src/rules/flake8_pytest_style/rules/fixture.rs +++ b/crates/ruff/src/rules/flake8_pytest_style/rules/fixture.rs @@ -448,7 +448,7 @@ where self.has_return_with_value = true; } } - Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) => {} + Stmt::FunctionDef(_) => {} _ => visitor::walk_stmt(self, stmt), } } diff --git a/crates/ruff/src/rules/flake8_pytest_style/rules/raises.rs b/crates/ruff/src/rules/flake8_pytest_style/rules/raises.rs index c8c824648b..f44ad46824 100644 --- a/crates/ruff/src/rules/flake8_pytest_style/rules/raises.rs +++ b/crates/ruff/src/rules/flake8_pytest_style/rules/raises.rs @@ -195,12 +195,9 @@ pub(crate) fn complex_raises( if raises_called { let is_too_complex = if let [stmt] = body { match stmt { - Stmt::With(ast::StmtWith { body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => { - is_non_trivial_with_body(body) - } + Stmt::With(ast::StmtWith { body, .. }) => is_non_trivial_with_body(body), // Allow function and class definitions to test decorators - Stmt::ClassDef(_) | Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) => false, + Stmt::ClassDef(_) | Stmt::FunctionDef(_) => false, stmt => is_compound_statement(stmt), } } else { diff --git a/crates/ruff/src/rules/flake8_return/rules/function.rs b/crates/ruff/src/rules/flake8_return/rules/function.rs index 00675beb65..90dfb25173 100644 --- a/crates/ruff/src/rules/flake8_return/rules/function.rs +++ b/crates/ruff/src/rules/flake8_return/rules/function.rs @@ -425,9 +425,7 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) { } Stmt::Assert(ast::StmtAssert { test, .. }) if is_const_false(test) => {} Stmt::While(ast::StmtWhile { test, .. }) if is_const_true(test) => {} - Stmt::For(ast::StmtFor { orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { orelse, .. }) - | Stmt::While(ast::StmtWhile { orelse, .. }) => { + Stmt::For(ast::StmtFor { orelse, .. }) | Stmt::While(ast::StmtWhile { orelse, .. }) => { if let Some(last_stmt) = orelse.last() { implicit_return(checker, last_stmt); } else { @@ -454,8 +452,7 @@ fn implicit_return(checker: &mut Checker, stmt: &Stmt) { } } } - Stmt::With(ast::StmtWith { body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => { + Stmt::With(ast::StmtWith { body, .. }) => { if let Some(last_stmt) = body.last() { implicit_return(checker, last_stmt); } diff --git a/crates/ruff/src/rules/flake8_return/visitor.rs b/crates/ruff/src/rules/flake8_return/visitor.rs index 80bc2f8d7d..775c3356e5 100644 --- a/crates/ruff/src/rules/flake8_return/visitor.rs +++ b/crates/ruff/src/rules/flake8_return/visitor.rs @@ -50,12 +50,6 @@ impl<'a> Visitor<'a> for ReturnVisitor<'a> { decorator_list, returns, .. - }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - parameters, - decorator_list, - returns, - .. }) => { // Visit the decorators, etc. self.sibling = Some(stmt); @@ -101,8 +95,7 @@ impl<'a> Visitor<'a> for ReturnVisitor<'a> { // x = f.read() // return x // ``` - Stmt::With(ast::StmtWith { body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => { + Stmt::With(ast::StmtWith { body, .. }) => { if let Some(stmt_assign) = body.last().and_then(Stmt::as_assign_stmt) { self.stack .assignment_return diff --git a/crates/ruff/src/rules/flake8_simplify/rules/ast_unary_op.rs b/crates/ruff/src/rules/flake8_simplify/rules/ast_unary_op.rs index c2aa396b17..f5da2c443b 100644 --- a/crates/ruff/src/rules/flake8_simplify/rules/ast_unary_op.rs +++ b/crates/ruff/src/rules/flake8_simplify/rules/ast_unary_op.rs @@ -160,8 +160,7 @@ pub(crate) fn negation_with_equal_op( } // Avoid flagging issues in dunder implementations. - if let ScopeKind::Function(ast::StmtFunctionDef { name, .. }) - | ScopeKind::AsyncFunction(ast::StmtAsyncFunctionDef { name, .. }) = + if let ScopeKind::Function(ast::StmtFunctionDef { name, .. }) = &checker.semantic().current_scope().kind { if is_dunder_method(name) { @@ -218,8 +217,7 @@ pub(crate) fn negation_with_not_equal_op( } // Avoid flagging issues in dunder implementations. - if let ScopeKind::Function(ast::StmtFunctionDef { name, .. }) - | ScopeKind::AsyncFunction(ast::StmtAsyncFunctionDef { name, .. }) = + if let ScopeKind::Function(ast::StmtFunctionDef { name, .. }) = &checker.semantic().current_scope().kind { if is_dunder_method(name) { diff --git a/crates/ruff/src/rules/flake8_simplify/rules/ast_with.rs b/crates/ruff/src/rules/flake8_simplify/rules/ast_with.rs index 55c5d0d781..528b82e050 100644 --- a/crates/ruff/src/rules/flake8_simplify/rules/ast_with.rs +++ b/crates/ruff/src/rules/flake8_simplify/rules/ast_with.rs @@ -63,18 +63,22 @@ impl Violation for MultipleWithStatements { /// Returns a boolean indicating whether it's an async with statement, the items /// and body. fn next_with(body: &[Stmt]) -> Option<(bool, &[WithItem], &[Stmt])> { - match body { - [Stmt::With(ast::StmtWith { items, body, .. })] => Some((false, items, body)), - [Stmt::AsyncWith(ast::StmtAsyncWith { items, body, .. })] => Some((true, items, body)), - _ => None, - } + let [Stmt::With(ast::StmtWith { + is_async, + items, + body, + .. + })] = body + else { + return None; + }; + Some((*is_async, items, body)) } /// SIM117 pub(crate) fn multiple_with_statements( checker: &mut Checker, - with_stmt: &Stmt, - with_body: &[Stmt], + with_stmt: &ast::StmtWith, with_parent: Option<&Stmt>, ) { // Make sure we fix from top to bottom for nested with statements, e.g. for @@ -102,8 +106,8 @@ pub(crate) fn multiple_with_statements( } } - if let Some((is_async, items, body)) = next_with(with_body) { - if is_async != with_stmt.is_async_with_stmt() { + if let Some((is_async, items, body)) = next_with(&with_stmt.body) { + if is_async != with_stmt.is_async { // One of the statements is an async with, while the other is not, // we can't merge those statements. return; @@ -133,7 +137,7 @@ pub(crate) fn multiple_with_statements( if !checker .indexer() .comment_ranges() - .intersects(TextRange::new(with_stmt.start(), with_body[0].start())) + .intersects(TextRange::new(with_stmt.start(), with_stmt.body[0].start())) { match fix_with::fix_multiple_with_statements( checker.locator(), diff --git a/crates/ruff/src/rules/flake8_simplify/rules/fix_with.rs b/crates/ruff/src/rules/flake8_simplify/rules/fix_with.rs index d5dc5ca9d8..36d1f6eb96 100644 --- a/crates/ruff/src/rules/flake8_simplify/rules/fix_with.rs +++ b/crates/ruff/src/rules/flake8_simplify/rules/fix_with.rs @@ -1,6 +1,6 @@ use anyhow::{bail, Result}; use libcst_native::{CompoundStatement, Statement, Suite, With}; -use ruff_python_ast::Ranged; +use ruff_python_ast::{self as ast, Ranged}; use crate::autofix::codemods::CodegenStylist; use ruff_diagnostics::Edit; @@ -14,15 +14,15 @@ use crate::cst::matchers::{match_function_def, match_indented_block, match_state pub(crate) fn fix_multiple_with_statements( locator: &Locator, stylist: &Stylist, - stmt: &ruff_python_ast::Stmt, + with_stmt: &ast::StmtWith, ) -> Result { // Infer the indentation of the outer block. - let Some(outer_indent) = whitespace::indentation(locator, stmt) else { + let Some(outer_indent) = whitespace::indentation(locator, with_stmt) else { bail!("Unable to fix multiline statement"); }; // Extract the module text. - let contents = locator.lines(stmt.range()); + let contents = locator.lines(with_stmt.range()); // If the block is indented, "embed" it in a function definition, to preserve // indentation while retaining valid source code. (We'll strip the prefix later @@ -82,7 +82,7 @@ pub(crate) fn fix_multiple_with_statements( .to_string() }; - let range = locator.lines_range(stmt.range()); + let range = locator.lines_range(with_stmt.range()); Ok(Edit::range_replacement(contents, range)) } diff --git a/crates/ruff/src/rules/flake8_simplify/rules/reimplemented_builtin.rs b/crates/ruff/src/rules/flake8_simplify/rules/reimplemented_builtin.rs index 989036b0e4..257c080349 100644 --- a/crates/ruff/src/rules/flake8_simplify/rules/reimplemented_builtin.rs +++ b/crates/ruff/src/rules/flake8_simplify/rules/reimplemented_builtin.rs @@ -60,7 +60,7 @@ impl Violation for ReimplementedBuiltin { /// SIM110, SIM111 pub(crate) fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt) { - if !checker.semantic().current_scope().kind.is_any_function() { + if !checker.semantic().current_scope().kind.is_function() { return; } diff --git a/crates/ruff/src/rules/flake8_unused_arguments/rules/unused_arguments.rs b/crates/ruff/src/rules/flake8_unused_arguments/rules/unused_arguments.rs index b6bc1e23b7..5699bf8e01 100644 --- a/crates/ruff/src/rules/flake8_unused_arguments/rules/unused_arguments.rs +++ b/crates/ruff/src/rules/flake8_unused_arguments/rules/unused_arguments.rs @@ -328,13 +328,6 @@ pub(crate) fn unused_arguments( body, decorator_list, .. - }) - | ScopeKind::AsyncFunction(ast::StmtAsyncFunctionDef { - name, - parameters, - body, - decorator_list, - .. }) => { match function_type::classify( name, diff --git a/crates/ruff/src/rules/isort/block.rs b/crates/ruff/src/rules/isort/block.rs index f5a2607d8b..3832444b4c 100644 --- a/crates/ruff/src/rules/isort/block.rs +++ b/crates/ruff/src/rules/isort/block.rs @@ -89,7 +89,7 @@ impl<'a> BlockBuilder<'a> { // sibling (i.e., as if the comment is the next statement, as // opposed to the class or function). match stmt { - Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) => { + Stmt::FunctionDef(_) => { if helpers::has_comment_break(stmt, self.locator) { Trailer::Sibling } else { @@ -196,12 +196,6 @@ where } self.finalize(None); } - Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) => { - for stmt in body { - self.visit_stmt(stmt); - } - self.finalize(None); - } Stmt::ClassDef(ast::StmtClassDef { body, .. }) => { for stmt in body { self.visit_stmt(stmt); @@ -219,17 +213,6 @@ where } self.finalize(None); } - Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) => { - for stmt in body { - self.visit_stmt(stmt); - } - self.finalize(None); - - for stmt in orelse { - self.visit_stmt(stmt); - } - self.finalize(None); - } Stmt::While(ast::StmtWhile { body, orelse, .. }) => { for stmt in body { self.visit_stmt(stmt); @@ -261,12 +244,6 @@ where } self.finalize(None); } - Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => { - for stmt in body { - self.visit_stmt(stmt); - } - self.finalize(None); - } Stmt::Match(ast::StmtMatch { cases, .. }) => { for match_case in cases { self.visit_match_case(match_case); diff --git a/crates/ruff/src/rules/mccabe/rules/function_is_too_complex.rs b/crates/ruff/src/rules/mccabe/rules/function_is_too_complex.rs index 927ae37b5c..60938d334c 100644 --- a/crates/ruff/src/rules/mccabe/rules/function_is_too_complex.rs +++ b/crates/ruff/src/rules/mccabe/rules/function_is_too_complex.rs @@ -82,14 +82,12 @@ fn get_complexity_number(stmts: &[Stmt]) -> usize { complexity += get_complexity_number(&clause.body); } } - Stmt::For(ast::StmtFor { body, orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) => { + Stmt::For(ast::StmtFor { body, orelse, .. }) => { complexity += 1; complexity += get_complexity_number(body); complexity += get_complexity_number(orelse); } - Stmt::With(ast::StmtWith { body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => { + Stmt::With(ast::StmtWith { body, .. }) => { complexity += get_complexity_number(body); } Stmt::While(ast::StmtWhile { body, orelse, .. }) => { @@ -131,8 +129,7 @@ fn get_complexity_number(stmts: &[Stmt]) -> usize { complexity += get_complexity_number(body); } } - Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) => { + Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => { complexity += 1; complexity += get_complexity_number(body); } diff --git a/crates/ruff/src/rules/pycodestyle/rules/lambda_assignment.rs b/crates/ruff/src/rules/pycodestyle/rules/lambda_assignment.rs index 7ae7764574..d06d032d4e 100644 --- a/crates/ruff/src/rules/pycodestyle/rules/lambda_assignment.rs +++ b/crates/ruff/src/rules/pycodestyle/rules/lambda_assignment.rs @@ -220,6 +220,7 @@ fn function( }) .collect::>(); let func = Stmt::FunctionDef(ast::StmtFunctionDef { + is_async: false, name: Identifier::new(name.to_string(), TextRange::default()), parameters: Box::new(Parameters { posonlyargs: new_posonlyargs, @@ -236,6 +237,7 @@ fn function( } } let func = Stmt::FunctionDef(ast::StmtFunctionDef { + is_async: false, name: Identifier::new(name.to_string(), TextRange::default()), parameters: Box::new(parameters.clone()), body: vec![body], diff --git a/crates/ruff/src/rules/pydocstyle/rules/sections.rs b/crates/ruff/src/rules/pydocstyle/rules/sections.rs index 001e5830ff..38f49dbdfe 100644 --- a/crates/ruff/src/rules/pydocstyle/rules/sections.rs +++ b/crates/ruff/src/rules/pydocstyle/rules/sections.rs @@ -1726,9 +1726,7 @@ fn missing_args(checker: &mut Checker, docstring: &Docstring, docstrings_args: & return; }; - let (Stmt::FunctionDef(ast::StmtFunctionDef { parameters, .. }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { parameters, .. })) = stmt - else { + let Stmt::FunctionDef(ast::StmtFunctionDef { parameters, .. }) = stmt else { return; }; diff --git a/crates/ruff/src/rules/pyflakes/rules/break_outside_loop.rs b/crates/ruff/src/rules/pyflakes/rules/break_outside_loop.rs index cfe9323c43..0fb29fcc04 100644 --- a/crates/ruff/src/rules/pyflakes/rules/break_outside_loop.rs +++ b/crates/ruff/src/rules/pyflakes/rules/break_outside_loop.rs @@ -36,14 +36,12 @@ pub(crate) fn break_outside_loop<'a>( let mut child = stmt; for parent in parents { match parent { - Stmt::For(ast::StmtFor { orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { orelse, .. }) - | Stmt::While(ast::StmtWhile { orelse, .. }) => { + Stmt::For(ast::StmtFor { orelse, .. }) | Stmt::While(ast::StmtWhile { orelse, .. }) => { if !orelse.contains(child) { return None; } } - Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) | Stmt::ClassDef(_) => { + Stmt::FunctionDef(_) | Stmt::ClassDef(_) => { break; } _ => {} diff --git a/crates/ruff/src/rules/pyflakes/rules/continue_outside_loop.rs b/crates/ruff/src/rules/pyflakes/rules/continue_outside_loop.rs index 5e773e4095..15a0accc9c 100644 --- a/crates/ruff/src/rules/pyflakes/rules/continue_outside_loop.rs +++ b/crates/ruff/src/rules/pyflakes/rules/continue_outside_loop.rs @@ -36,14 +36,12 @@ pub(crate) fn continue_outside_loop<'a>( let mut child = stmt; for parent in parents { match parent { - Stmt::For(ast::StmtFor { orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { orelse, .. }) - | Stmt::While(ast::StmtWhile { orelse, .. }) => { + Stmt::For(ast::StmtFor { orelse, .. }) | Stmt::While(ast::StmtWhile { orelse, .. }) => { if !orelse.contains(child) { return None; } } - Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) | Stmt::ClassDef(_) => { + Stmt::FunctionDef(_) | Stmt::ClassDef(_) => { break; } _ => {} diff --git a/crates/ruff/src/rules/pyflakes/rules/undefined_local.rs b/crates/ruff/src/rules/pyflakes/rules/undefined_local.rs index 43c4eda595..e4f931a7bc 100644 --- a/crates/ruff/src/rules/pyflakes/rules/undefined_local.rs +++ b/crates/ruff/src/rules/pyflakes/rules/undefined_local.rs @@ -51,7 +51,7 @@ pub(crate) fn undefined_local( scope: &Scope, diagnostics: &mut Vec, ) { - if scope.kind.is_any_function() { + if scope.kind.is_function() { for (name, binding_id) in scope.bindings() { // If the variable shadows a binding in a parent scope... if let Some(shadowed_id) = checker.semantic().shadowed_binding(binding_id) { diff --git a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs index e91407a7ca..dd3ff4c7c7 100644 --- a/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs +++ b/crates/ruff/src/rules/pyflakes/rules/unused_variable.rs @@ -299,7 +299,7 @@ fn remove_unused_variable( /// F841 pub(crate) fn unused_variable(checker: &Checker, scope: &Scope, diagnostics: &mut Vec) { - if scope.uses_locals() && scope.kind.is_any_function() { + if scope.uses_locals() && scope.kind.is_function() { return; } diff --git a/crates/ruff/src/rules/pylint/helpers.rs b/crates/ruff/src/rules/pylint/helpers.rs index 25ca39c544..60bb0dd223 100644 --- a/crates/ruff/src/rules/pylint/helpers.rs +++ b/crates/ruff/src/rules/pylint/helpers.rs @@ -24,16 +24,11 @@ pub(super) fn type_param_name(arguments: &Arguments) -> Option<&str> { pub(super) fn in_dunder_init(semantic: &SemanticModel, settings: &Settings) -> bool { let scope = semantic.current_scope(); - let (ScopeKind::Function(ast::StmtFunctionDef { + let ScopeKind::Function(ast::StmtFunctionDef { name, decorator_list, .. - }) - | ScopeKind::AsyncFunction(ast::StmtAsyncFunctionDef { - name, - decorator_list, - .. - })) = scope.kind + }) = scope.kind else { return false; }; diff --git a/crates/ruff/src/rules/pylint/rules/continue_in_finally.rs b/crates/ruff/src/rules/pylint/rules/continue_in_finally.rs index 90f75b83e4..8624f65304 100644 --- a/crates/ruff/src/rules/pylint/rules/continue_in_finally.rs +++ b/crates/ruff/src/rules/pylint/rules/continue_in_finally.rs @@ -69,11 +69,10 @@ fn traverse_body(checker: &mut Checker, body: &[Stmt]) { traverse_body(checker, body); traverse_body(checker, orelse); } - Stmt::For(ast::StmtFor { orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { orelse, .. }) - | Stmt::While(ast::StmtWhile { orelse, .. }) => traverse_body(checker, orelse), - Stmt::With(ast::StmtWith { body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => { + Stmt::For(ast::StmtFor { orelse, .. }) | Stmt::While(ast::StmtWhile { orelse, .. }) => { + traverse_body(checker, orelse); + } + Stmt::With(ast::StmtWith { body, .. }) => { traverse_body(checker, body); } Stmt::Match(ast::StmtMatch { cases, .. }) => { diff --git a/crates/ruff/src/rules/pylint/rules/redefined_loop_name.rs b/crates/ruff/src/rules/pylint/rules/redefined_loop_name.rs index 870da016f6..23ec06ba31 100644 --- a/crates/ruff/src/rules/pylint/rules/redefined_loop_name.rs +++ b/crates/ruff/src/rules/pylint/rules/redefined_loop_name.rs @@ -147,9 +147,7 @@ impl<'a, 'b> StatementVisitor<'b> for InnerForWithAssignTargetsVisitor<'a, 'b> { fn visit_stmt(&mut self, stmt: &'b Stmt) { // Collect target expressions. match stmt { - // For and async for. - Stmt::For(ast::StmtFor { target, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { target, .. }) => { + Stmt::For(ast::StmtFor { target, .. }) => { self.assignment_targets.extend( assignment_targets_from_expr(target, self.dummy_variable_rgx).map(|expr| { ExprWithInnerBindingKind { @@ -159,7 +157,6 @@ impl<'a, 'b> StatementVisitor<'b> for InnerForWithAssignTargetsVisitor<'a, 'b> { }), ); } - // With. Stmt::With(ast::StmtWith { items, .. }) => { self.assignment_targets.extend( assignment_targets_from_with_items(items, self.dummy_variable_rgx).map( @@ -170,7 +167,6 @@ impl<'a, 'b> StatementVisitor<'b> for InnerForWithAssignTargetsVisitor<'a, 'b> { ), ); } - // Assignment, augmented assignment, and annotated assignment. Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { // Check for single-target assignments which are of the // form `x = cast(..., x)`. @@ -217,8 +213,7 @@ impl<'a, 'b> StatementVisitor<'b> for InnerForWithAssignTargetsVisitor<'a, 'b> { // Decide whether to recurse. match stmt { // Don't recurse into blocks that create a new scope. - Stmt::ClassDef(_) => {} - Stmt::FunctionDef(_) => {} + Stmt::ClassDef(_) | Stmt::FunctionDef(_) => {} // Otherwise, do recurse. _ => { walk_stmt(self, stmt); @@ -339,8 +334,7 @@ fn assignment_targets_from_assign_targets<'a>( /// PLW2901 pub(crate) fn redefined_loop_name(checker: &mut Checker, stmt: &Stmt) { let (outer_assignment_targets, inner_assignment_targets) = match stmt { - Stmt::With(ast::StmtWith { items, body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { items, body, .. }) => { + Stmt::With(ast::StmtWith { items, body, .. }) => { let outer_assignment_targets: Vec = assignment_targets_from_with_items(items, &checker.settings.dummy_variable_rgx) .map(|expr| ExprWithOuterBindingKind { @@ -358,8 +352,7 @@ pub(crate) fn redefined_loop_name(checker: &mut Checker, stmt: &Stmt) { } (outer_assignment_targets, visitor.assignment_targets) } - Stmt::For(ast::StmtFor { target, body, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { target, body, .. }) => { + Stmt::For(ast::StmtFor { target, body, .. }) => { let outer_assignment_targets: Vec = assignment_targets_from_expr(target, &checker.settings.dummy_variable_rgx) .map(|expr| ExprWithOuterBindingKind { @@ -377,9 +370,7 @@ pub(crate) fn redefined_loop_name(checker: &mut Checker, stmt: &Stmt) { } (outer_assignment_targets, visitor.assignment_targets) } - _ => panic!( - "redefined_loop_name called on Statement that is not a With, For, AsyncWith, or AsyncFor" - ) + _ => panic!("redefined_loop_name called on Statement that is not a `With` or `For`"), }; let mut diagnostics = Vec::new(); diff --git a/crates/ruff/src/rules/pylint/rules/too_many_branches.rs b/crates/ruff/src/rules/pylint/rules/too_many_branches.rs index bc5e459e9f..2e6e220ec8 100644 --- a/crates/ruff/src/rules/pylint/rules/too_many_branches.rs +++ b/crates/ruff/src/rules/pylint/rules/too_many_branches.rs @@ -108,7 +108,6 @@ fn num_branches(stmts: &[Stmt]) -> usize { .sum::() } Stmt::For(ast::StmtFor { body, orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) | Stmt::While(ast::StmtWhile { body, orelse, .. }) => { 1 + num_branches(body) + (if orelse.is_empty() { diff --git a/crates/ruff/src/rules/pylint/rules/too_many_statements.rs b/crates/ruff/src/rules/pylint/rules/too_many_statements.rs index 0df12604ac..4088a9e473 100644 --- a/crates/ruff/src/rules/pylint/rules/too_many_statements.rs +++ b/crates/ruff/src/rules/pylint/rules/too_many_statements.rs @@ -78,8 +78,7 @@ fn num_statements(stmts: &[Stmt]) -> usize { count += num_statements(&clause.body); } } - Stmt::For(ast::StmtFor { body, orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) => { + Stmt::For(ast::StmtFor { body, orelse, .. }) => { count += num_statements(body); count += num_statements(orelse); } @@ -129,7 +128,6 @@ fn num_statements(stmts: &[Stmt]) -> usize { } } Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) | Stmt::With(ast::StmtWith { body, .. }) => { count += 1; count += num_statements(body); diff --git a/crates/ruff/src/rules/pylint/rules/useless_else_on_loop.rs b/crates/ruff/src/rules/pylint/rules/useless_else_on_loop.rs index 00b0caad33..bf290b59c5 100644 --- a/crates/ruff/src/rules/pylint/rules/useless_else_on_loop.rs +++ b/crates/ruff/src/rules/pylint/rules/useless_else_on_loop.rs @@ -63,8 +63,7 @@ fn loop_exits_early(body: &[Stmt]) -> bool { .iter() .any(|clause| loop_exits_early(&clause.body)) } - Stmt::With(ast::StmtWith { body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => loop_exits_early(body), + Stmt::With(ast::StmtWith { body, .. }) => loop_exits_early(body), Stmt::Match(ast::StmtMatch { cases, .. }) => cases .iter() .any(|MatchCase { body, .. }| loop_exits_early(body)), @@ -91,9 +90,9 @@ fn loop_exits_early(body: &[Stmt]) -> bool { }) => loop_exits_early(body), }) } - Stmt::For(ast::StmtFor { orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { orelse, .. }) - | Stmt::While(ast::StmtWhile { orelse, .. }) => loop_exits_early(orelse), + Stmt::For(ast::StmtFor { orelse, .. }) | Stmt::While(ast::StmtWhile { orelse, .. }) => { + loop_exits_early(orelse) + } Stmt::Break(_) => true, _ => false, }) diff --git a/crates/ruff/src/rules/pylint/rules/yield_from_in_async_function.rs b/crates/ruff/src/rules/pylint/rules/yield_from_in_async_function.rs index 723a927ed4..a94f388a03 100644 --- a/crates/ruff/src/rules/pylint/rules/yield_from_in_async_function.rs +++ b/crates/ruff/src/rules/pylint/rules/yield_from_in_async_function.rs @@ -1,7 +1,7 @@ -use ruff_python_ast::{ExprYieldFrom, Ranged}; - use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::{self as ast, Ranged}; +use ruff_python_semantic::ScopeKind; use crate::checkers::ast::Checker; @@ -37,9 +37,11 @@ impl Violation for YieldFromInAsyncFunction { } /// PLE1700 -pub(crate) fn yield_from_in_async_function(checker: &mut Checker, expr: &ExprYieldFrom) { - let scope = checker.semantic().current_scope(); - if scope.kind.is_async_function() { +pub(crate) fn yield_from_in_async_function(checker: &mut Checker, expr: &ast::ExprYieldFrom) { + if matches!( + checker.semantic().current_scope().kind, + ScopeKind::Function(ast::StmtFunctionDef { is_async: true, .. }) + ) { checker .diagnostics .push(Diagnostic::new(YieldFromInAsyncFunction, expr.range())); diff --git a/crates/ruff/src/rules/pyupgrade/rules/super_call_with_parameters.rs b/crates/ruff/src/rules/pyupgrade/rules/super_call_with_parameters.rs index 457de15f0e..d8c6981aa0 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/super_call_with_parameters.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/super_call_with_parameters.rs @@ -83,7 +83,7 @@ pub(crate) fn super_call_with_parameters( let scope = checker.semantic().current_scope(); // Check: are we in a Function scope? - if !scope.kind.is_any_function() { + if !scope.kind.is_function() { return; } diff --git a/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs b/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs index 12e686facb..acf96c59be 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/yield_in_for_loop.rs @@ -1,10 +1,10 @@ -use ruff_python_ast::{self as ast, Expr, ExprContext, Ranged, Stmt}; use rustc_hash::FxHashMap; use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::statement_visitor::StatementVisitor; use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::{self as ast, Expr, ExprContext, Ranged, Stmt}; use ruff_python_ast::{statement_visitor, visitor}; use ruff_python_semantic::StatementKey; @@ -120,7 +120,7 @@ impl<'a> StatementVisitor<'a> for YieldFromVisitor<'a> { } } } - Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) | Stmt::ClassDef(_) => { + Stmt::FunctionDef(_) | Stmt::ClassDef(_) => { // Don't recurse into anything that defines a new scope. } _ => statement_visitor::walk_stmt(self, stmt), @@ -162,39 +162,46 @@ impl<'a> Visitor<'a> for ReferenceVisitor<'a> { /// UP028 pub(crate) fn yield_in_for_loop(checker: &mut Checker, stmt: &Stmt) { // Intentionally omit async functions. - if let Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) = stmt { - let yields = { - let mut visitor = YieldFromVisitor::default(); - visitor.visit_body(body); - visitor.yields - }; + let Stmt::FunctionDef(ast::StmtFunctionDef { + is_async: false, + body, + .. + }) = stmt + else { + return; + }; - let references = { - let mut visitor = ReferenceVisitor::default(); - visitor.visit_body(body); - visitor.references - }; + let yields = { + let mut visitor = YieldFromVisitor::default(); + visitor.visit_body(body); + visitor.yields + }; - for item in yields { - // If any of the bound names are used outside of the loop, don't rewrite. - if references.iter().any(|(statement, names)| { - *statement != StatementKey::from(item.stmt) - && *statement != StatementKey::from(item.body) - && item.names.iter().any(|name| names.contains(name)) - }) { - continue; - } + let references = { + let mut visitor = ReferenceVisitor::default(); + visitor.visit_body(body); + visitor.references + }; - let mut diagnostic = Diagnostic::new(YieldInForLoop, item.stmt.range()); - if checker.patch(diagnostic.kind.rule()) { - let contents = checker.locator().slice(item.iter.range()); - let contents = format!("yield from {contents}"); - diagnostic.set_fix(Fix::suggested(Edit::range_replacement( - contents, - item.stmt.range(), - ))); - } - checker.diagnostics.push(diagnostic); + for item in yields { + // If any of the bound names are used outside of the loop, don't rewrite. + if references.iter().any(|(statement, names)| { + *statement != StatementKey::from(item.stmt) + && *statement != StatementKey::from(item.body) + && item.names.iter().any(|name| names.contains(name)) + }) { + continue; } + + let mut diagnostic = Diagnostic::new(YieldInForLoop, item.stmt.range()); + if checker.patch(diagnostic.kind.rule()) { + let contents = checker.locator().slice(item.iter.range()); + let contents = format!("yield from {contents}"); + diagnostic.set_fix(Fix::suggested(Edit::range_replacement( + contents, + item.stmt.range(), + ))); + } + checker.diagnostics.push(diagnostic); } } diff --git a/crates/ruff/src/rules/ruff/rules/unreachable.rs b/crates/ruff/src/rules/ruff/rules/unreachable.rs index 5e2c94a226..b803fce617 100644 --- a/crates/ruff/src/rules/ruff/rules/unreachable.rs +++ b/crates/ruff/src/rules/ruff/rules/unreachable.rs @@ -2,8 +2,8 @@ use std::{fmt, iter, usize}; use log::error; use ruff_python_ast::{ - Expr, Identifier, MatchCase, Pattern, PatternMatchAs, Ranged, Stmt, StmtAsyncFor, - StmtAsyncWith, StmtFor, StmtMatch, StmtReturn, StmtTry, StmtTryStar, StmtWhile, StmtWith, + Expr, Identifier, MatchCase, Pattern, PatternMatchAs, Ranged, Stmt, StmtFor, StmtMatch, + StmtReturn, StmtTry, StmtTryStar, StmtWhile, StmtWith, }; use ruff_text_size::{TextRange, TextSize}; @@ -467,7 +467,6 @@ impl<'stmt> BasicBlocksBuilder<'stmt> { let next = match stmt { // Statements that continue to the next statement after execution. Stmt::FunctionDef(_) - | Stmt::AsyncFunctionDef(_) | Stmt::Import(_) | Stmt::ImportFrom(_) | Stmt::ClassDef(_) @@ -535,12 +534,6 @@ impl<'stmt> BasicBlocksBuilder<'stmt> { body, orelse, .. - }) - | Stmt::AsyncFor(StmtAsyncFor { - iter: condition, - body, - orelse, - .. }) => loop_block(self, Condition::Iterator(condition), body, orelse, after), Stmt::Try(StmtTry { body, @@ -566,8 +559,7 @@ impl<'stmt> BasicBlocksBuilder<'stmt> { let _ = (body, handlers, orelse, finalbody); // Silence unused code warnings. self.unconditional_next_block(after) } - Stmt::With(StmtWith { items, body, .. }) - | Stmt::AsyncWith(StmtAsyncWith { items, body, .. }) => { + Stmt::With(StmtWith { items, body, .. }) => { // TODO: handle `with` statements, see // . // I recommend to `try` statements first as `with` can desugar @@ -889,7 +881,6 @@ fn needs_next_block(stmts: &[Stmt]) -> bool { Stmt::Return(_) | Stmt::Raise(_) => false, Stmt::If(stmt) => needs_next_block(&stmt.body) || stmt.elif_else_clauses.last().map_or(true, |clause| needs_next_block(&clause.body)), Stmt::FunctionDef(_) - | Stmt::AsyncFunctionDef(_) | Stmt::Import(_) | Stmt::ImportFrom(_) | Stmt::ClassDef(_) @@ -905,10 +896,8 @@ fn needs_next_block(stmts: &[Stmt]) -> bool { | Stmt::Break(_) | Stmt::Continue(_) | Stmt::For(_) - | Stmt::AsyncFor(_) | Stmt::While(_) | Stmt::With(_) - | Stmt::AsyncWith(_) | Stmt::Match(_) | Stmt::Try(_) | Stmt::TryStar(_) @@ -923,7 +912,6 @@ fn needs_next_block(stmts: &[Stmt]) -> bool { fn is_control_flow_stmt(stmt: &Stmt) -> bool { match stmt { Stmt::FunctionDef(_) - | Stmt::AsyncFunctionDef(_) | Stmt::Import(_) | Stmt::ImportFrom(_) | Stmt::ClassDef(_) @@ -937,11 +925,9 @@ fn is_control_flow_stmt(stmt: &Stmt) -> bool { | Stmt::Pass(_) => false, Stmt::Return(_) | Stmt::For(_) - | Stmt::AsyncFor(_) | Stmt::While(_) | Stmt::If(_) | Stmt::With(_) - | Stmt::AsyncWith(_) | Stmt::Match(_) | Stmt::Raise(_) | Stmt::Try(_) diff --git a/crates/ruff/src/rules/tryceratops/rules/type_check_without_type_error.rs b/crates/ruff/src/rules/tryceratops/rules/type_check_without_type_error.rs index 208f9d3ff9..91e1d25b4c 100644 --- a/crates/ruff/src/rules/tryceratops/rules/type_check_without_type_error.rs +++ b/crates/ruff/src/rules/tryceratops/rules/type_check_without_type_error.rs @@ -55,7 +55,7 @@ where { fn visit_stmt(&mut self, stmt: &'b Stmt) { match stmt { - Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) | Stmt::ClassDef(_) => { + Stmt::FunctionDef(_) | Stmt::ClassDef(_) => { // Don't recurse. } Stmt::Return(_) => self.returns.push(stmt), diff --git a/crates/ruff_python_ast/src/cast.rs b/crates/ruff_python_ast/src/cast.rs index 1b59f69b20..6246442ed7 100644 --- a/crates/ruff_python_ast/src/cast.rs +++ b/crates/ruff_python_ast/src/cast.rs @@ -1,19 +1,15 @@ use crate::{nodes, Decorator, Stmt}; pub fn name(stmt: &Stmt) -> &str { - match stmt { - Stmt::FunctionDef(nodes::StmtFunctionDef { name, .. }) - | Stmt::AsyncFunctionDef(nodes::StmtAsyncFunctionDef { name, .. }) => name.as_str(), - _ => panic!("Expected Stmt::FunctionDef | Stmt::AsyncFunctionDef"), - } + let Stmt::FunctionDef(nodes::StmtFunctionDef { name, .. }) = stmt else { + panic!("Expected Stmt::FunctionDef") + }; + name.as_str() } pub fn decorator_list(stmt: &Stmt) -> &[Decorator] { - match stmt { - Stmt::FunctionDef(nodes::StmtFunctionDef { decorator_list, .. }) - | Stmt::AsyncFunctionDef(nodes::StmtAsyncFunctionDef { decorator_list, .. }) => { - decorator_list - } - _ => panic!("Expected Stmt::FunctionDef | Stmt::AsyncFunctionDef"), - } + let Stmt::FunctionDef(nodes::StmtFunctionDef { decorator_list, .. }) = stmt else { + panic!("Expected Stmt::FunctionDef") + }; + decorator_list } diff --git a/crates/ruff_python_ast/src/comparable.rs b/crates/ruff_python_ast/src/comparable.rs index 107e079128..71b546a342 100644 --- a/crates/ruff_python_ast/src/comparable.rs +++ b/crates/ruff_python_ast/src/comparable.rs @@ -950,16 +950,7 @@ impl<'a> From<&'a ast::Expr> for ComparableExpr<'a> { #[derive(Debug, PartialEq, Eq, Hash)] pub struct StmtFunctionDef<'a> { - decorator_list: Vec>, - name: &'a str, - type_params: Option>, - parameters: ComparableParameters<'a>, - returns: Option>, - body: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -pub struct StmtAsyncFunctionDef<'a> { + is_async: bool, decorator_list: Vec>, name: &'a str, type_params: Option>, @@ -1084,14 +1075,7 @@ pub struct StmtAnnAssign<'a> { #[derive(Debug, PartialEq, Eq, Hash)] pub struct StmtFor<'a> { - target: ComparableExpr<'a>, - iter: ComparableExpr<'a>, - body: Vec>, - orelse: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -pub struct StmtAsyncFor<'a> { + is_async: bool, target: ComparableExpr<'a>, iter: ComparableExpr<'a>, body: Vec>, @@ -1114,12 +1098,7 @@ pub struct StmtIf<'a> { #[derive(Debug, PartialEq, Eq, Hash)] pub struct StmtWith<'a> { - items: Vec>, - body: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -pub struct StmtAsyncWith<'a> { + is_async: bool, items: Vec>, body: Vec>, } @@ -1194,7 +1173,6 @@ pub struct StmtLineMagic<'a> { #[derive(Debug, PartialEq, Eq, Hash)] pub enum ComparableStmt<'a> { FunctionDef(StmtFunctionDef<'a>), - AsyncFunctionDef(StmtAsyncFunctionDef<'a>), ClassDef(StmtClassDef<'a>), Return(StmtReturn<'a>), Delete(StmtDelete<'a>), @@ -1202,11 +1180,9 @@ pub enum ComparableStmt<'a> { AugAssign(StmtAugAssign<'a>), AnnAssign(StmtAnnAssign<'a>), For(StmtFor<'a>), - AsyncFor(StmtAsyncFor<'a>), While(StmtWhile<'a>), If(StmtIf<'a>), With(StmtWith<'a>), - AsyncWith(StmtAsyncWith<'a>), Match(StmtMatch<'a>), Raise(StmtRaise<'a>), Try(StmtTry<'a>), @@ -1228,6 +1204,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> { fn from(stmt: &'a ast::Stmt) -> Self { match stmt { ast::Stmt::FunctionDef(ast::StmtFunctionDef { + is_async, name, parameters, body, @@ -1236,22 +1213,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> { type_params, range: _, }) => Self::FunctionDef(StmtFunctionDef { - name: name.as_str(), - parameters: parameters.into(), - body: body.iter().map(Into::into).collect(), - decorator_list: decorator_list.iter().map(Into::into).collect(), - returns: returns.as_ref().map(Into::into), - type_params: type_params.as_ref().map(Into::into), - }), - ast::Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - name, - parameters, - body, - decorator_list, - returns, - type_params, - range: _, - }) => Self::AsyncFunctionDef(StmtAsyncFunctionDef { + is_async: *is_async, name: name.as_str(), parameters: parameters.into(), body: body.iter().map(Into::into).collect(), @@ -1320,24 +1282,14 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> { simple: *simple, }), ast::Stmt::For(ast::StmtFor { + is_async, target, iter, body, orelse, range: _, }) => Self::For(StmtFor { - target: target.into(), - iter: iter.into(), - body: body.iter().map(Into::into).collect(), - orelse: orelse.iter().map(Into::into).collect(), - }), - ast::Stmt::AsyncFor(ast::StmtAsyncFor { - target, - iter, - body, - orelse, - range: _, - }) => Self::AsyncFor(StmtAsyncFor { + is_async: *is_async, target: target.into(), iter: iter.into(), body: body.iter().map(Into::into).collect(), @@ -1364,18 +1316,12 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> { elif_else_clauses: elif_else_clauses.iter().map(Into::into).collect(), }), ast::Stmt::With(ast::StmtWith { + is_async, items, body, range: _, }) => Self::With(StmtWith { - items: items.iter().map(Into::into).collect(), - body: body.iter().map(Into::into).collect(), - }), - ast::Stmt::AsyncWith(ast::StmtAsyncWith { - items, - body, - range: _, - }) => Self::AsyncWith(StmtAsyncWith { + is_async: *is_async, items: items.iter().map(Into::into).collect(), body: body.iter().map(Into::into).collect(), }), diff --git a/crates/ruff_python_ast/src/function.rs b/crates/ruff_python_ast/src/function.rs deleted file mode 100644 index 96737bcf86..0000000000 --- a/crates/ruff_python_ast/src/function.rs +++ /dev/null @@ -1,134 +0,0 @@ -use crate::node::AnyNodeRef; -use crate::{ - Decorator, Expr, Identifier, Parameters, Ranged, StmtAsyncFunctionDef, StmtFunctionDef, Suite, - TypeParams, -}; -use ruff_text_size::TextRange; - -/// Enum that represents any python function definition. -#[derive(Copy, Clone, PartialEq, Debug)] -pub enum AnyFunctionDefinition<'a> { - FunctionDefinition(&'a StmtFunctionDef), - AsyncFunctionDefinition(&'a StmtAsyncFunctionDef), -} - -impl<'a> AnyFunctionDefinition<'a> { - pub const fn cast_ref(reference: AnyNodeRef<'a>) -> Option { - match reference { - AnyNodeRef::StmtAsyncFunctionDef(definition) => { - Some(Self::AsyncFunctionDefinition(definition)) - } - AnyNodeRef::StmtFunctionDef(definition) => Some(Self::FunctionDefinition(definition)), - _ => None, - } - } - - /// Returns `Some` if this is a [`StmtFunctionDef`] and `None` otherwise. - pub const fn as_function_definition(self) -> Option<&'a StmtFunctionDef> { - if let Self::FunctionDefinition(definition) = self { - Some(definition) - } else { - None - } - } - - /// Returns `Some` if this is a [`StmtAsyncFunctionDef`] and `None` otherwise. - pub const fn as_async_function_definition(self) -> Option<&'a StmtAsyncFunctionDef> { - if let Self::AsyncFunctionDefinition(definition) = self { - Some(definition) - } else { - None - } - } - - /// Returns the function's name - pub const fn name(self) -> &'a Identifier { - match self { - Self::FunctionDefinition(definition) => &definition.name, - Self::AsyncFunctionDefinition(definition) => &definition.name, - } - } - - /// Returns the function arguments (parameters). - pub fn arguments(self) -> &'a Parameters { - match self { - Self::FunctionDefinition(definition) => definition.parameters.as_ref(), - Self::AsyncFunctionDefinition(definition) => definition.parameters.as_ref(), - } - } - - /// Returns the function's body - pub const fn body(self) -> &'a Suite { - match self { - Self::FunctionDefinition(definition) => &definition.body, - Self::AsyncFunctionDefinition(definition) => &definition.body, - } - } - - /// Returns the decorators attributing the function. - pub fn decorators(self) -> &'a [Decorator] { - match self { - Self::FunctionDefinition(definition) => &definition.decorator_list, - Self::AsyncFunctionDefinition(definition) => &definition.decorator_list, - } - } - - pub fn returns(self) -> Option<&'a Expr> { - match self { - Self::FunctionDefinition(definition) => definition.returns.as_deref(), - Self::AsyncFunctionDefinition(definition) => definition.returns.as_deref(), - } - } - - pub fn type_params(self) -> Option<&'a TypeParams> { - match self { - Self::FunctionDefinition(definition) => definition.type_params.as_ref(), - Self::AsyncFunctionDefinition(definition) => definition.type_params.as_ref(), - } - } - - /// Returns `true` if this is [`Self::AsyncFunctionDefinition`] - pub const fn is_async(self) -> bool { - matches!(self, Self::AsyncFunctionDefinition(_)) - } -} - -impl Ranged for AnyFunctionDefinition<'_> { - fn range(&self) -> TextRange { - match self { - AnyFunctionDefinition::FunctionDefinition(definition) => definition.range(), - AnyFunctionDefinition::AsyncFunctionDefinition(definition) => definition.range(), - } - } -} - -impl<'a> From<&'a StmtFunctionDef> for AnyFunctionDefinition<'a> { - fn from(value: &'a StmtFunctionDef) -> Self { - Self::FunctionDefinition(value) - } -} - -impl<'a> From<&'a StmtAsyncFunctionDef> for AnyFunctionDefinition<'a> { - fn from(value: &'a StmtAsyncFunctionDef) -> Self { - Self::AsyncFunctionDefinition(value) - } -} - -impl<'a> From> for AnyNodeRef<'a> { - fn from(value: AnyFunctionDefinition<'a>) -> Self { - match value { - AnyFunctionDefinition::FunctionDefinition(function_def) => { - AnyNodeRef::StmtFunctionDef(function_def) - } - AnyFunctionDefinition::AsyncFunctionDefinition(async_def) => { - AnyNodeRef::StmtAsyncFunctionDef(async_def) - } - } - } -} - -impl<'a> From<&'a AnyFunctionDefinition<'a>> for AnyNodeRef<'a> { - fn from(value: &'a AnyFunctionDefinition<'a>) -> Self { - (*value).into() - } -} diff --git a/crates/ruff_python_ast/src/helpers.rs b/crates/ruff_python_ast/src/helpers.rs index d29a68b167..a7994a3b49 100644 --- a/crates/ruff_python_ast/src/helpers.rs +++ b/crates/ruff_python_ast/src/helpers.rs @@ -18,14 +18,11 @@ pub const fn is_compound_statement(stmt: &Stmt) -> bool { matches!( stmt, Stmt::FunctionDef(_) - | Stmt::AsyncFunctionDef(_) | Stmt::ClassDef(_) | Stmt::While(_) | Stmt::For(_) - | Stmt::AsyncFor(_) | Stmt::Match(_) | Stmt::With(_) - | Stmt::AsyncWith(_) | Stmt::If(_) | Stmt::Try(_) | Stmt::TryStar(_) @@ -321,14 +318,6 @@ where decorator_list, returns, .. - }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - parameters, - type_params, - body, - decorator_list, - returns, - .. }) => { parameters .posonlyargs @@ -439,13 +428,6 @@ where body, orelse, .. - }) - | Stmt::AsyncFor(ast::StmtAsyncFor { - target, - iter, - body, - orelse, - .. }) => { any_over_expr(target, func) || any_over_expr(iter, func) @@ -474,8 +456,7 @@ where || any_over_body(&clause.body, func) }) } - Stmt::With(ast::StmtWith { items, body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { items, body, .. }) => { + Stmt::With(ast::StmtWith { items, body, .. }) => { items.iter().any(|with_item| { any_over_expr(&with_item.context_expr, func) || with_item @@ -912,7 +893,7 @@ where { fn visit_stmt(&mut self, stmt: &'b Stmt) { match stmt { - Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) | Stmt::ClassDef(_) => { + Stmt::FunctionDef(_) | Stmt::ClassDef(_) => { // Don't recurse. } Stmt::Return(stmt) => self.returns.push(stmt), @@ -941,11 +922,7 @@ where self.raises .push((stmt.range(), exc.as_deref(), cause.as_deref())); } - Stmt::ClassDef(_) - | Stmt::FunctionDef(_) - | Stmt::AsyncFunctionDef(_) - | Stmt::Try(_) - | Stmt::TryStar(_) => {} + Stmt::ClassDef(_) | Stmt::FunctionDef(_) | Stmt::Try(_) | Stmt::TryStar(_) => {} Stmt::If(ast::StmtIf { body, elif_else_clauses, @@ -958,9 +935,7 @@ where } Stmt::While(ast::StmtWhile { body, .. }) | Stmt::With(ast::StmtWith { body, .. }) - | Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) - | Stmt::For(ast::StmtFor { body, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { body, .. }) => { + | Stmt::For(ast::StmtFor { body, .. }) => { walk_body(self, body); } Stmt::Match(ast::StmtMatch { cases, .. }) => { diff --git a/crates/ruff_python_ast/src/identifier.rs b/crates/ruff_python_ast/src/identifier.rs index d5d18a9641..38a014ef14 100644 --- a/crates/ruff_python_ast/src/identifier.rs +++ b/crates/ruff_python_ast/src/identifier.rs @@ -31,8 +31,7 @@ impl Identifier for Stmt { fn identifier(&self) -> TextRange { match self { Stmt::ClassDef(ast::StmtClassDef { name, .. }) - | Stmt::FunctionDef(ast::StmtFunctionDef { name, .. }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { name, .. }) => name.range(), + | Stmt::FunctionDef(ast::StmtFunctionDef { name, .. }) => name.range(), _ => self.range(), } } @@ -85,10 +84,9 @@ pub fn except(handler: &ExceptHandler, source: &str) -> TextRange { .expect("Failed to find `except` token in `ExceptHandler`") } -/// Return the [`TextRange`] of the `else` token in a `For`, `AsyncFor`, or `While` statement. +/// Return the [`TextRange`] of the `else` token in a `For` or `While` statement. pub fn else_(stmt: &Stmt, source: &str) -> Option { let (Stmt::For(ast::StmtFor { body, orelse, .. }) - | Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) | Stmt::While(ast::StmtWhile { body, orelse, .. })) = stmt else { return None; diff --git a/crates/ruff_python_ast/src/lib.rs b/crates/ruff_python_ast/src/lib.rs index a7489fb7a1..25ebbc0329 100644 --- a/crates/ruff_python_ast/src/lib.rs +++ b/crates/ruff_python_ast/src/lib.rs @@ -6,7 +6,6 @@ pub mod call_path; pub mod cast; pub mod comparable; pub mod docstrings; -pub mod function; pub mod hashable; pub mod helpers; pub mod identifier; diff --git a/crates/ruff_python_ast/src/node.rs b/crates/ruff_python_ast/src/node.rs index 6f7f2c911f..ec2c6075c0 100644 --- a/crates/ruff_python_ast/src/node.rs +++ b/crates/ruff_python_ast/src/node.rs @@ -24,7 +24,6 @@ pub enum AnyNode { ModModule(ast::ModModule), ModExpression(ast::ModExpression), StmtFunctionDef(ast::StmtFunctionDef), - StmtAsyncFunctionDef(ast::StmtAsyncFunctionDef), StmtClassDef(ast::StmtClassDef), StmtReturn(ast::StmtReturn), StmtDelete(ast::StmtDelete), @@ -33,11 +32,9 @@ pub enum AnyNode { StmtAugAssign(ast::StmtAugAssign), StmtAnnAssign(ast::StmtAnnAssign), StmtFor(ast::StmtFor), - StmtAsyncFor(ast::StmtAsyncFor), StmtWhile(ast::StmtWhile), StmtIf(ast::StmtIf), StmtWith(ast::StmtWith), - StmtAsyncWith(ast::StmtAsyncWith), StmtMatch(ast::StmtMatch), StmtRaise(ast::StmtRaise), StmtTry(ast::StmtTry), @@ -110,7 +107,6 @@ impl AnyNode { pub fn statement(self) -> Option { match self { AnyNode::StmtFunctionDef(node) => Some(Stmt::FunctionDef(node)), - AnyNode::StmtAsyncFunctionDef(node) => Some(Stmt::AsyncFunctionDef(node)), AnyNode::StmtClassDef(node) => Some(Stmt::ClassDef(node)), AnyNode::StmtReturn(node) => Some(Stmt::Return(node)), AnyNode::StmtDelete(node) => Some(Stmt::Delete(node)), @@ -119,11 +115,9 @@ impl AnyNode { AnyNode::StmtAugAssign(node) => Some(Stmt::AugAssign(node)), AnyNode::StmtAnnAssign(node) => Some(Stmt::AnnAssign(node)), AnyNode::StmtFor(node) => Some(Stmt::For(node)), - AnyNode::StmtAsyncFor(node) => Some(Stmt::AsyncFor(node)), AnyNode::StmtWhile(node) => Some(Stmt::While(node)), AnyNode::StmtIf(node) => Some(Stmt::If(node)), AnyNode::StmtWith(node) => Some(Stmt::With(node)), - AnyNode::StmtAsyncWith(node) => Some(Stmt::AsyncWith(node)), AnyNode::StmtMatch(node) => Some(Stmt::Match(node)), AnyNode::StmtRaise(node) => Some(Stmt::Raise(node)), AnyNode::StmtTry(node) => Some(Stmt::Try(node)), @@ -230,7 +224,6 @@ impl AnyNode { AnyNode::ModModule(_) | AnyNode::ModExpression(_) | AnyNode::StmtFunctionDef(_) - | AnyNode::StmtAsyncFunctionDef(_) | AnyNode::StmtClassDef(_) | AnyNode::StmtReturn(_) | AnyNode::StmtDelete(_) @@ -239,11 +232,9 @@ impl AnyNode { | AnyNode::StmtAugAssign(_) | AnyNode::StmtAnnAssign(_) | AnyNode::StmtFor(_) - | AnyNode::StmtAsyncFor(_) | AnyNode::StmtWhile(_) | AnyNode::StmtIf(_) | AnyNode::StmtWith(_) - | AnyNode::StmtAsyncWith(_) | AnyNode::StmtMatch(_) | AnyNode::StmtRaise(_) | AnyNode::StmtTry(_) @@ -291,7 +282,6 @@ impl AnyNode { AnyNode::ModExpression(node) => Some(Mod::Expression(node)), AnyNode::StmtFunctionDef(_) - | AnyNode::StmtAsyncFunctionDef(_) | AnyNode::StmtClassDef(_) | AnyNode::StmtReturn(_) | AnyNode::StmtDelete(_) @@ -300,11 +290,9 @@ impl AnyNode { | AnyNode::StmtAugAssign(_) | AnyNode::StmtAnnAssign(_) | AnyNode::StmtFor(_) - | AnyNode::StmtAsyncFor(_) | AnyNode::StmtWhile(_) | AnyNode::StmtIf(_) | AnyNode::StmtWith(_) - | AnyNode::StmtAsyncWith(_) | AnyNode::StmtMatch(_) | AnyNode::StmtRaise(_) | AnyNode::StmtTry(_) @@ -388,7 +376,6 @@ impl AnyNode { AnyNode::ModModule(_) | AnyNode::ModExpression(_) | AnyNode::StmtFunctionDef(_) - | AnyNode::StmtAsyncFunctionDef(_) | AnyNode::StmtClassDef(_) | AnyNode::StmtReturn(_) | AnyNode::StmtDelete(_) @@ -397,11 +384,9 @@ impl AnyNode { | AnyNode::StmtAugAssign(_) | AnyNode::StmtAnnAssign(_) | AnyNode::StmtFor(_) - | AnyNode::StmtAsyncFor(_) | AnyNode::StmtWhile(_) | AnyNode::StmtIf(_) | AnyNode::StmtWith(_) - | AnyNode::StmtAsyncWith(_) | AnyNode::StmtMatch(_) | AnyNode::StmtRaise(_) | AnyNode::StmtTry(_) @@ -470,7 +455,6 @@ impl AnyNode { AnyNode::ModModule(_) | AnyNode::ModExpression(_) | AnyNode::StmtFunctionDef(_) - | AnyNode::StmtAsyncFunctionDef(_) | AnyNode::StmtClassDef(_) | AnyNode::StmtReturn(_) | AnyNode::StmtDelete(_) @@ -479,11 +463,9 @@ impl AnyNode { | AnyNode::StmtAugAssign(_) | AnyNode::StmtAnnAssign(_) | AnyNode::StmtFor(_) - | AnyNode::StmtAsyncFor(_) | AnyNode::StmtWhile(_) | AnyNode::StmtIf(_) | AnyNode::StmtWith(_) - | AnyNode::StmtAsyncWith(_) | AnyNode::StmtMatch(_) | AnyNode::StmtRaise(_) | AnyNode::StmtTry(_) @@ -577,7 +559,6 @@ impl AnyNode { Self::ModModule(node) => AnyNodeRef::ModModule(node), Self::ModExpression(node) => AnyNodeRef::ModExpression(node), Self::StmtFunctionDef(node) => AnyNodeRef::StmtFunctionDef(node), - Self::StmtAsyncFunctionDef(node) => AnyNodeRef::StmtAsyncFunctionDef(node), Self::StmtClassDef(node) => AnyNodeRef::StmtClassDef(node), Self::StmtReturn(node) => AnyNodeRef::StmtReturn(node), Self::StmtDelete(node) => AnyNodeRef::StmtDelete(node), @@ -586,11 +567,9 @@ impl AnyNode { Self::StmtAugAssign(node) => AnyNodeRef::StmtAugAssign(node), Self::StmtAnnAssign(node) => AnyNodeRef::StmtAnnAssign(node), Self::StmtFor(node) => AnyNodeRef::StmtFor(node), - Self::StmtAsyncFor(node) => AnyNodeRef::StmtAsyncFor(node), Self::StmtWhile(node) => AnyNodeRef::StmtWhile(node), Self::StmtIf(node) => AnyNodeRef::StmtIf(node), Self::StmtWith(node) => AnyNodeRef::StmtWith(node), - Self::StmtAsyncWith(node) => AnyNodeRef::StmtAsyncWith(node), Self::StmtMatch(node) => AnyNodeRef::StmtMatch(node), Self::StmtRaise(node) => AnyNodeRef::StmtRaise(node), Self::StmtTry(node) => AnyNodeRef::StmtTry(node), @@ -750,34 +729,6 @@ impl AstNode for ast::StmtFunctionDef { AnyNode::from(self) } } -impl AstNode for ast::StmtAsyncFunctionDef { - fn cast(kind: AnyNode) -> Option - where - Self: Sized, - { - if let AnyNode::StmtAsyncFunctionDef(node) = kind { - Some(node) - } else { - None - } - } - - fn cast_ref(kind: AnyNodeRef) -> Option<&Self> { - if let AnyNodeRef::StmtAsyncFunctionDef(node) = kind { - Some(node) - } else { - None - } - } - - fn as_any_node_ref(&self) -> AnyNodeRef { - AnyNodeRef::from(self) - } - - fn into_any_node(self) -> AnyNode { - AnyNode::from(self) - } -} impl AstNode for ast::StmtClassDef { fn cast(kind: AnyNode) -> Option where @@ -1002,34 +953,6 @@ impl AstNode for ast::StmtFor { AnyNode::from(self) } } -impl AstNode for ast::StmtAsyncFor { - fn cast(kind: AnyNode) -> Option - where - Self: Sized, - { - if let AnyNode::StmtAsyncFor(node) = kind { - Some(node) - } else { - None - } - } - - fn cast_ref(kind: AnyNodeRef) -> Option<&Self> { - if let AnyNodeRef::StmtAsyncFor(node) = kind { - Some(node) - } else { - None - } - } - - fn as_any_node_ref(&self) -> AnyNodeRef { - AnyNodeRef::from(self) - } - - fn into_any_node(self) -> AnyNode { - AnyNode::from(self) - } -} impl AstNode for ast::StmtWhile { fn cast(kind: AnyNode) -> Option where @@ -1142,34 +1065,6 @@ impl AstNode for ast::StmtWith { AnyNode::from(self) } } -impl AstNode for ast::StmtAsyncWith { - fn cast(kind: AnyNode) -> Option - where - Self: Sized, - { - if let AnyNode::StmtAsyncWith(node) = kind { - Some(node) - } else { - None - } - } - - fn cast_ref(kind: AnyNodeRef) -> Option<&Self> { - if let AnyNodeRef::StmtAsyncWith(node) = kind { - Some(node) - } else { - None - } - } - - fn as_any_node_ref(&self) -> AnyNodeRef { - AnyNodeRef::from(self) - } - - fn into_any_node(self) -> AnyNode { - AnyNode::from(self) - } -} impl AstNode for ast::StmtMatch { fn cast(kind: AnyNode) -> Option where @@ -2996,7 +2891,6 @@ impl From for AnyNode { fn from(stmt: Stmt) -> Self { match stmt { Stmt::FunctionDef(node) => AnyNode::StmtFunctionDef(node), - Stmt::AsyncFunctionDef(node) => AnyNode::StmtAsyncFunctionDef(node), Stmt::ClassDef(node) => AnyNode::StmtClassDef(node), Stmt::Return(node) => AnyNode::StmtReturn(node), Stmt::Delete(node) => AnyNode::StmtDelete(node), @@ -3005,11 +2899,9 @@ impl From for AnyNode { Stmt::AugAssign(node) => AnyNode::StmtAugAssign(node), Stmt::AnnAssign(node) => AnyNode::StmtAnnAssign(node), Stmt::For(node) => AnyNode::StmtFor(node), - Stmt::AsyncFor(node) => AnyNode::StmtAsyncFor(node), Stmt::While(node) => AnyNode::StmtWhile(node), Stmt::If(node) => AnyNode::StmtIf(node), Stmt::With(node) => AnyNode::StmtWith(node), - Stmt::AsyncWith(node) => AnyNode::StmtAsyncWith(node), Stmt::Match(node) => AnyNode::StmtMatch(node), Stmt::Raise(node) => AnyNode::StmtRaise(node), Stmt::Try(node) => AnyNode::StmtTry(node), @@ -3113,12 +3005,6 @@ impl From for AnyNode { } } -impl From for AnyNode { - fn from(node: ast::StmtAsyncFunctionDef) -> Self { - AnyNode::StmtAsyncFunctionDef(node) - } -} - impl From for AnyNode { fn from(node: ast::StmtClassDef) -> Self { AnyNode::StmtClassDef(node) @@ -3167,12 +3053,6 @@ impl From for AnyNode { } } -impl From for AnyNode { - fn from(node: ast::StmtAsyncFor) -> Self { - AnyNode::StmtAsyncFor(node) - } -} - impl From for AnyNode { fn from(node: ast::StmtWhile) -> Self { AnyNode::StmtWhile(node) @@ -3197,12 +3077,6 @@ impl From for AnyNode { } } -impl From for AnyNode { - fn from(node: ast::StmtAsyncWith) -> Self { - AnyNode::StmtAsyncWith(node) - } -} - impl From for AnyNode { fn from(node: ast::StmtMatch) -> Self { AnyNode::StmtMatch(node) @@ -3588,7 +3462,6 @@ impl Ranged for AnyNode { AnyNode::ModModule(node) => node.range(), AnyNode::ModExpression(node) => node.range(), AnyNode::StmtFunctionDef(node) => node.range(), - AnyNode::StmtAsyncFunctionDef(node) => node.range(), AnyNode::StmtClassDef(node) => node.range(), AnyNode::StmtReturn(node) => node.range(), AnyNode::StmtDelete(node) => node.range(), @@ -3597,11 +3470,9 @@ impl Ranged for AnyNode { AnyNode::StmtAugAssign(node) => node.range(), AnyNode::StmtAnnAssign(node) => node.range(), AnyNode::StmtFor(node) => node.range(), - AnyNode::StmtAsyncFor(node) => node.range(), AnyNode::StmtWhile(node) => node.range(), AnyNode::StmtIf(node) => node.range(), AnyNode::StmtWith(node) => node.range(), - AnyNode::StmtAsyncWith(node) => node.range(), AnyNode::StmtMatch(node) => node.range(), AnyNode::StmtRaise(node) => node.range(), AnyNode::StmtTry(node) => node.range(), @@ -3677,7 +3548,6 @@ pub enum AnyNodeRef<'a> { ModModule(&'a ast::ModModule), ModExpression(&'a ast::ModExpression), StmtFunctionDef(&'a ast::StmtFunctionDef), - StmtAsyncFunctionDef(&'a ast::StmtAsyncFunctionDef), StmtClassDef(&'a ast::StmtClassDef), StmtReturn(&'a ast::StmtReturn), StmtDelete(&'a ast::StmtDelete), @@ -3686,11 +3556,9 @@ pub enum AnyNodeRef<'a> { StmtAugAssign(&'a ast::StmtAugAssign), StmtAnnAssign(&'a ast::StmtAnnAssign), StmtFor(&'a ast::StmtFor), - StmtAsyncFor(&'a ast::StmtAsyncFor), StmtWhile(&'a ast::StmtWhile), StmtIf(&'a ast::StmtIf), StmtWith(&'a ast::StmtWith), - StmtAsyncWith(&'a ast::StmtAsyncWith), StmtMatch(&'a ast::StmtMatch), StmtRaise(&'a ast::StmtRaise), StmtTry(&'a ast::StmtTry), @@ -3765,7 +3633,6 @@ impl AnyNodeRef<'_> { AnyNodeRef::ModModule(node) => NonNull::from(*node).cast(), AnyNodeRef::ModExpression(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtFunctionDef(node) => NonNull::from(*node).cast(), - AnyNodeRef::StmtAsyncFunctionDef(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtClassDef(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtReturn(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtDelete(node) => NonNull::from(*node).cast(), @@ -3774,11 +3641,9 @@ impl AnyNodeRef<'_> { AnyNodeRef::StmtAugAssign(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtAnnAssign(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtFor(node) => NonNull::from(*node).cast(), - AnyNodeRef::StmtAsyncFor(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtWhile(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtIf(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtWith(node) => NonNull::from(*node).cast(), - AnyNodeRef::StmtAsyncWith(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtMatch(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtRaise(node) => NonNull::from(*node).cast(), AnyNodeRef::StmtTry(node) => NonNull::from(*node).cast(), @@ -3859,7 +3724,6 @@ impl AnyNodeRef<'_> { AnyNodeRef::ModModule(_) => NodeKind::ModModule, AnyNodeRef::ModExpression(_) => NodeKind::ModExpression, AnyNodeRef::StmtFunctionDef(_) => NodeKind::StmtFunctionDef, - AnyNodeRef::StmtAsyncFunctionDef(_) => NodeKind::StmtAsyncFunctionDef, AnyNodeRef::StmtClassDef(_) => NodeKind::StmtClassDef, AnyNodeRef::StmtReturn(_) => NodeKind::StmtReturn, AnyNodeRef::StmtDelete(_) => NodeKind::StmtDelete, @@ -3868,11 +3732,9 @@ impl AnyNodeRef<'_> { AnyNodeRef::StmtAugAssign(_) => NodeKind::StmtAugAssign, AnyNodeRef::StmtAnnAssign(_) => NodeKind::StmtAnnAssign, AnyNodeRef::StmtFor(_) => NodeKind::StmtFor, - AnyNodeRef::StmtAsyncFor(_) => NodeKind::StmtAsyncFor, AnyNodeRef::StmtWhile(_) => NodeKind::StmtWhile, AnyNodeRef::StmtIf(_) => NodeKind::StmtIf, AnyNodeRef::StmtWith(_) => NodeKind::StmtWith, - AnyNodeRef::StmtAsyncWith(_) => NodeKind::StmtAsyncWith, AnyNodeRef::StmtMatch(_) => NodeKind::StmtMatch, AnyNodeRef::StmtRaise(_) => NodeKind::StmtRaise, AnyNodeRef::StmtTry(_) => NodeKind::StmtTry, @@ -3945,7 +3807,6 @@ impl AnyNodeRef<'_> { pub const fn is_statement(self) -> bool { match self { AnyNodeRef::StmtFunctionDef(_) - | AnyNodeRef::StmtAsyncFunctionDef(_) | AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtDelete(_) @@ -3954,11 +3815,9 @@ impl AnyNodeRef<'_> { | AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtFor(_) - | AnyNodeRef::StmtAsyncFor(_) | AnyNodeRef::StmtWhile(_) | AnyNodeRef::StmtIf(_) | AnyNodeRef::StmtWith(_) - | AnyNodeRef::StmtAsyncWith(_) | AnyNodeRef::StmtMatch(_) | AnyNodeRef::StmtRaise(_) | AnyNodeRef::StmtTry(_) @@ -4065,7 +3924,6 @@ impl AnyNodeRef<'_> { AnyNodeRef::ModModule(_) | AnyNodeRef::ModExpression(_) | AnyNodeRef::StmtFunctionDef(_) - | AnyNodeRef::StmtAsyncFunctionDef(_) | AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtDelete(_) @@ -4074,11 +3932,9 @@ impl AnyNodeRef<'_> { | AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtFor(_) - | AnyNodeRef::StmtAsyncFor(_) | AnyNodeRef::StmtWhile(_) | AnyNodeRef::StmtIf(_) | AnyNodeRef::StmtWith(_) - | AnyNodeRef::StmtAsyncWith(_) | AnyNodeRef::StmtMatch(_) | AnyNodeRef::StmtRaise(_) | AnyNodeRef::StmtTry(_) @@ -4125,7 +3981,6 @@ impl AnyNodeRef<'_> { AnyNodeRef::ModModule(_) | AnyNodeRef::ModExpression(_) => true, AnyNodeRef::StmtFunctionDef(_) - | AnyNodeRef::StmtAsyncFunctionDef(_) | AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtDelete(_) @@ -4134,11 +3989,9 @@ impl AnyNodeRef<'_> { | AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtFor(_) - | AnyNodeRef::StmtAsyncFor(_) | AnyNodeRef::StmtWhile(_) | AnyNodeRef::StmtIf(_) | AnyNodeRef::StmtWith(_) - | AnyNodeRef::StmtAsyncWith(_) | AnyNodeRef::StmtMatch(_) | AnyNodeRef::StmtRaise(_) | AnyNodeRef::StmtTry(_) @@ -4222,7 +4075,6 @@ impl AnyNodeRef<'_> { AnyNodeRef::ModModule(_) | AnyNodeRef::ModExpression(_) | AnyNodeRef::StmtFunctionDef(_) - | AnyNodeRef::StmtAsyncFunctionDef(_) | AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtDelete(_) @@ -4231,11 +4083,9 @@ impl AnyNodeRef<'_> { | AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtFor(_) - | AnyNodeRef::StmtAsyncFor(_) | AnyNodeRef::StmtWhile(_) | AnyNodeRef::StmtIf(_) | AnyNodeRef::StmtWith(_) - | AnyNodeRef::StmtAsyncWith(_) | AnyNodeRef::StmtMatch(_) | AnyNodeRef::StmtRaise(_) | AnyNodeRef::StmtTry(_) @@ -4304,7 +4154,6 @@ impl AnyNodeRef<'_> { AnyNodeRef::ModModule(_) | AnyNodeRef::ModExpression(_) | AnyNodeRef::StmtFunctionDef(_) - | AnyNodeRef::StmtAsyncFunctionDef(_) | AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtReturn(_) | AnyNodeRef::StmtDelete(_) @@ -4313,11 +4162,9 @@ impl AnyNodeRef<'_> { | AnyNodeRef::StmtAugAssign(_) | AnyNodeRef::StmtAnnAssign(_) | AnyNodeRef::StmtFor(_) - | AnyNodeRef::StmtAsyncFor(_) | AnyNodeRef::StmtWhile(_) | AnyNodeRef::StmtIf(_) | AnyNodeRef::StmtWith(_) - | AnyNodeRef::StmtAsyncWith(_) | AnyNodeRef::StmtMatch(_) | AnyNodeRef::StmtRaise(_) | AnyNodeRef::StmtTry(_) @@ -4391,13 +4238,10 @@ impl AnyNodeRef<'_> { self, AnyNodeRef::StmtIf(_) | AnyNodeRef::StmtFor(_) - | AnyNodeRef::StmtAsyncFor(_) | AnyNodeRef::StmtWhile(_) | AnyNodeRef::StmtWith(_) - | AnyNodeRef::StmtAsyncWith(_) | AnyNodeRef::StmtMatch(_) | AnyNodeRef::StmtFunctionDef(_) - | AnyNodeRef::StmtAsyncFunctionDef(_) | AnyNodeRef::StmtClassDef(_) | AnyNodeRef::StmtTry(_) | AnyNodeRef::StmtTryStar(_) @@ -4435,12 +4279,6 @@ impl<'a> From<&'a ast::StmtFunctionDef> for AnyNodeRef<'a> { } } -impl<'a> From<&'a ast::StmtAsyncFunctionDef> for AnyNodeRef<'a> { - fn from(node: &'a ast::StmtAsyncFunctionDef) -> Self { - AnyNodeRef::StmtAsyncFunctionDef(node) - } -} - impl<'a> From<&'a ast::StmtClassDef> for AnyNodeRef<'a> { fn from(node: &'a ast::StmtClassDef) -> Self { AnyNodeRef::StmtClassDef(node) @@ -4489,12 +4327,6 @@ impl<'a> From<&'a ast::StmtFor> for AnyNodeRef<'a> { } } -impl<'a> From<&'a ast::StmtAsyncFor> for AnyNodeRef<'a> { - fn from(node: &'a ast::StmtAsyncFor) -> Self { - AnyNodeRef::StmtAsyncFor(node) - } -} - impl<'a> From<&'a ast::StmtWhile> for AnyNodeRef<'a> { fn from(node: &'a ast::StmtWhile) -> Self { AnyNodeRef::StmtWhile(node) @@ -4519,12 +4351,6 @@ impl<'a> From<&'a ast::StmtWith> for AnyNodeRef<'a> { } } -impl<'a> From<&'a ast::StmtAsyncWith> for AnyNodeRef<'a> { - fn from(node: &'a ast::StmtAsyncWith) -> Self { - AnyNodeRef::StmtAsyncWith(node) - } -} - impl<'a> From<&'a ast::StmtMatch> for AnyNodeRef<'a> { fn from(node: &'a ast::StmtMatch) -> Self { AnyNodeRef::StmtMatch(node) @@ -4864,7 +4690,6 @@ impl<'a> From<&'a Stmt> for AnyNodeRef<'a> { fn from(stmt: &'a Stmt) -> Self { match stmt { Stmt::FunctionDef(node) => AnyNodeRef::StmtFunctionDef(node), - Stmt::AsyncFunctionDef(node) => AnyNodeRef::StmtAsyncFunctionDef(node), Stmt::ClassDef(node) => AnyNodeRef::StmtClassDef(node), Stmt::Return(node) => AnyNodeRef::StmtReturn(node), Stmt::Delete(node) => AnyNodeRef::StmtDelete(node), @@ -4873,11 +4698,9 @@ impl<'a> From<&'a Stmt> for AnyNodeRef<'a> { Stmt::AugAssign(node) => AnyNodeRef::StmtAugAssign(node), Stmt::AnnAssign(node) => AnyNodeRef::StmtAnnAssign(node), Stmt::For(node) => AnyNodeRef::StmtFor(node), - Stmt::AsyncFor(node) => AnyNodeRef::StmtAsyncFor(node), Stmt::While(node) => AnyNodeRef::StmtWhile(node), Stmt::If(node) => AnyNodeRef::StmtIf(node), Stmt::With(node) => AnyNodeRef::StmtWith(node), - Stmt::AsyncWith(node) => AnyNodeRef::StmtAsyncWith(node), Stmt::Match(node) => AnyNodeRef::StmtMatch(node), Stmt::Raise(node) => AnyNodeRef::StmtRaise(node), Stmt::Try(node) => AnyNodeRef::StmtTry(node), @@ -5027,7 +4850,6 @@ impl Ranged for AnyNodeRef<'_> { AnyNodeRef::ModModule(node) => node.range(), AnyNodeRef::ModExpression(node) => node.range(), AnyNodeRef::StmtFunctionDef(node) => node.range(), - AnyNodeRef::StmtAsyncFunctionDef(node) => node.range(), AnyNodeRef::StmtClassDef(node) => node.range(), AnyNodeRef::StmtReturn(node) => node.range(), AnyNodeRef::StmtDelete(node) => node.range(), @@ -5036,11 +4858,9 @@ impl Ranged for AnyNodeRef<'_> { AnyNodeRef::StmtAugAssign(node) => node.range(), AnyNodeRef::StmtAnnAssign(node) => node.range(), AnyNodeRef::StmtFor(node) => node.range(), - AnyNodeRef::StmtAsyncFor(node) => node.range(), AnyNodeRef::StmtWhile(node) => node.range(), AnyNodeRef::StmtIf(node) => node.range(), AnyNodeRef::StmtWith(node) => node.range(), - AnyNodeRef::StmtAsyncWith(node) => node.range(), AnyNodeRef::StmtMatch(node) => node.range(), AnyNodeRef::StmtRaise(node) => node.range(), AnyNodeRef::StmtTry(node) => node.range(), @@ -5118,7 +4938,6 @@ pub enum NodeKind { ModExpression, ModFunctionType, StmtFunctionDef, - StmtAsyncFunctionDef, StmtClassDef, StmtReturn, StmtDelete, @@ -5127,11 +4946,9 @@ pub enum NodeKind { StmtAugAssign, StmtAnnAssign, StmtFor, - StmtAsyncFor, StmtWhile, StmtIf, StmtWith, - StmtAsyncWith, StmtMatch, StmtRaise, StmtTry, diff --git a/crates/ruff_python_ast/src/nodes.rs b/crates/ruff_python_ast/src/nodes.rs index d536ce8ca6..5e5239872a 100644 --- a/crates/ruff_python_ast/src/nodes.rs +++ b/crates/ruff_python_ast/src/nodes.rs @@ -45,8 +45,6 @@ impl From for Mod { pub enum Stmt { #[is(name = "function_def_stmt")] FunctionDef(StmtFunctionDef), - #[is(name = "async_function_def_stmt")] - AsyncFunctionDef(StmtAsyncFunctionDef), #[is(name = "class_def_stmt")] ClassDef(StmtClassDef), #[is(name = "return_stmt")] @@ -63,16 +61,12 @@ pub enum Stmt { TypeAlias(StmtTypeAlias), #[is(name = "for_stmt")] For(StmtFor), - #[is(name = "async_for_stmt")] - AsyncFor(StmtAsyncFor), #[is(name = "while_stmt")] While(StmtWhile), #[is(name = "if_stmt")] If(StmtIf), #[is(name = "with_stmt")] With(StmtWith), - #[is(name = "async_with_stmt")] - AsyncWith(StmtAsyncWith), #[is(name = "match_stmt")] Match(StmtMatch), #[is(name = "raise_stmt")] @@ -118,10 +112,15 @@ impl From for Stmt { } } -/// See also [FunctionDef](https://docs.python.org/3/library/ast.html#ast.FunctionDef) +/// See also [FunctionDef](https://docs.python.org/3/library/ast.html#ast.FunctionDef) and +/// [AsyncFunctionDef](https://docs.python.org/3/library/ast.html#ast.AsyncFunctionDef). +/// +/// This type differs from the original Python AST, as it collapses the +/// synchronous and asynchronous variants into a single type. #[derive(Clone, Debug, PartialEq)] pub struct StmtFunctionDef { pub range: TextRange, + pub is_async: bool, pub decorator_list: Vec, pub name: Identifier, pub type_params: Option, @@ -136,24 +135,6 @@ impl From for Stmt { } } -/// See also [AsyncFunctionDef](https://docs.python.org/3/library/ast.html#ast.AsyncFunctionDef) -#[derive(Clone, Debug, PartialEq)] -pub struct StmtAsyncFunctionDef { - pub range: TextRange, - pub decorator_list: Vec, - pub name: Identifier, - pub type_params: Option, - pub parameters: Box, - pub returns: Option>, - pub body: Vec, -} - -impl From for Stmt { - fn from(payload: StmtAsyncFunctionDef) -> Self { - Stmt::AsyncFunctionDef(payload) - } -} - /// See also [ClassDef](https://docs.python.org/3/library/ast.html#ast.ClassDef) #[derive(Clone, Debug, PartialEq)] pub struct StmtClassDef { @@ -275,10 +256,15 @@ impl From for Stmt { } } -/// See also [For](https://docs.python.org/3/library/ast.html#ast.For) +/// See also [For](https://docs.python.org/3/library/ast.html#ast.For) and +/// [AsyncFor](https://docs.python.org/3/library/ast.html#ast.AsyncFor). +/// +/// This type differs from the original Python AST, as it collapses the +/// synchronous and asynchronous variants into a single type. #[derive(Clone, Debug, PartialEq)] pub struct StmtFor { pub range: TextRange, + pub is_async: bool, pub target: Box, pub iter: Box, pub body: Vec, @@ -291,23 +277,8 @@ impl From for Stmt { } } -/// See also [AsyncFor](https://docs.python.org/3/library/ast.html#ast.AsyncFor) -#[derive(Clone, Debug, PartialEq)] -pub struct StmtAsyncFor { - pub range: TextRange, - pub target: Box, - pub iter: Box, - pub body: Vec, - pub orelse: Vec, -} - -impl From for Stmt { - fn from(payload: StmtAsyncFor) -> Self { - Stmt::AsyncFor(payload) - } -} - -/// See also [While](https://docs.python.org/3/library/ast.html#ast.While) +/// See also [While](https://docs.python.org/3/library/ast.html#ast.While) and +/// [AsyncWhile](https://docs.python.org/3/library/ast.html#ast.AsyncWhile). #[derive(Clone, Debug, PartialEq)] pub struct StmtWhile { pub range: TextRange, @@ -344,10 +315,15 @@ pub struct ElifElseClause { pub body: Vec, } -/// See also [With](https://docs.python.org/3/library/ast.html#ast.With) +/// See also [With](https://docs.python.org/3/library/ast.html#ast.With) and +/// [AsyncWith](https://docs.python.org/3/library/ast.html#ast.AsyncWith). +/// +/// This type differs from the original Python AST, as it collapses the +/// synchronous and asynchronous variants into a single type. #[derive(Clone, Debug, PartialEq)] pub struct StmtWith { pub range: TextRange, + pub is_async: bool, pub items: Vec, pub body: Vec, } @@ -358,20 +334,6 @@ impl From for Stmt { } } -/// See also [AsyncWith](https://docs.python.org/3/library/ast.html#ast.AsyncWith) -#[derive(Clone, Debug, PartialEq)] -pub struct StmtAsyncWith { - pub range: TextRange, - pub items: Vec, - pub body: Vec, -} - -impl From for Stmt { - fn from(payload: StmtAsyncWith) -> Self { - Stmt::AsyncWith(payload) - } -} - /// See also [Match](https://docs.python.org/3/library/ast.html#ast.Match) #[derive(Clone, Debug, PartialEq)] pub struct StmtMatch { @@ -2599,11 +2561,6 @@ impl Ranged for crate::nodes::StmtFunctionDef { self.range } } -impl Ranged for crate::nodes::StmtAsyncFunctionDef { - fn range(&self) -> TextRange { - self.range - } -} impl Ranged for crate::nodes::StmtClassDef { fn range(&self) -> TextRange { self.range @@ -2644,11 +2601,6 @@ impl Ranged for crate::nodes::StmtFor { self.range } } -impl Ranged for crate::nodes::StmtAsyncFor { - fn range(&self) -> TextRange { - self.range - } -} impl Ranged for crate::nodes::StmtWhile { fn range(&self) -> TextRange { self.range @@ -2669,11 +2621,6 @@ impl Ranged for crate::nodes::StmtWith { self.range } } -impl Ranged for crate::nodes::StmtAsyncWith { - fn range(&self) -> TextRange { - self.range - } -} impl Ranged for crate::nodes::StmtMatch { fn range(&self) -> TextRange { self.range @@ -2748,7 +2695,6 @@ impl Ranged for crate::Stmt { fn range(&self) -> TextRange { match self { Self::FunctionDef(node) => node.range(), - Self::AsyncFunctionDef(node) => node.range(), Self::ClassDef(node) => node.range(), Self::Return(node) => node.range(), Self::Delete(node) => node.range(), @@ -2757,11 +2703,9 @@ impl Ranged for crate::Stmt { Self::AugAssign(node) => node.range(), Self::AnnAssign(node) => node.range(), Self::For(node) => node.range(), - Self::AsyncFor(node) => node.range(), Self::While(node) => node.range(), Self::If(node) => node.range(), Self::With(node) => node.range(), - Self::AsyncWith(node) => node.range(), Self::Match(node) => node.range(), Self::Raise(node) => node.range(), Self::Try(node) => node.range(), @@ -3110,8 +3054,7 @@ mod size_assertions { use static_assertions::assert_eq_size; assert_eq_size!(Stmt, [u8; 144]); - assert_eq_size!(StmtFunctionDef, [u8; 136]); - assert_eq_size!(StmtAsyncFunctionDef, [u8; 136]); + assert_eq_size!(StmtFunctionDef, [u8; 144]); assert_eq_size!(StmtClassDef, [u8; 104]); assert_eq_size!(StmtTry, [u8; 104]); assert_eq_size!(Expr, [u8; 80]); diff --git a/crates/ruff_python_ast/src/statement_visitor.rs b/crates/ruff_python_ast/src/statement_visitor.rs index c061e0beb8..57f458d8fb 100644 --- a/crates/ruff_python_ast/src/statement_visitor.rs +++ b/crates/ruff_python_ast/src/statement_visitor.rs @@ -32,9 +32,6 @@ pub fn walk_stmt<'a, V: StatementVisitor<'a> + ?Sized>(visitor: &mut V, stmt: &' Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => { visitor.visit_body(body); } - Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) => { - visitor.visit_body(body); - } Stmt::For(ast::StmtFor { body, orelse, .. }) => { visitor.visit_body(body); visitor.visit_body(orelse); @@ -42,10 +39,6 @@ pub fn walk_stmt<'a, V: StatementVisitor<'a> + ?Sized>(visitor: &mut V, stmt: &' Stmt::ClassDef(ast::StmtClassDef { body, .. }) => { visitor.visit_body(body); } - Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) => { - visitor.visit_body(body); - visitor.visit_body(orelse); - } Stmt::While(ast::StmtWhile { body, orelse, .. }) => { visitor.visit_body(body); visitor.visit_body(orelse); @@ -63,9 +56,6 @@ pub fn walk_stmt<'a, V: StatementVisitor<'a> + ?Sized>(visitor: &mut V, stmt: &' Stmt::With(ast::StmtWith { body, .. }) => { visitor.visit_body(body); } - Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => { - visitor.visit_body(body); - } Stmt::Match(ast::StmtMatch { cases, .. }) => { for match_case in cases { visitor.visit_match_case(match_case); diff --git a/crates/ruff_python_ast/src/traversal.rs b/crates/ruff_python_ast/src/traversal.rs index 95f88d13c9..6a732aa890 100644 --- a/crates/ruff_python_ast/src/traversal.rs +++ b/crates/ruff_python_ast/src/traversal.rs @@ -5,7 +5,6 @@ use crate::{self as ast, ExceptHandler, Stmt, Suite}; pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<&'a Suite> { match parent { Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => Some(body), - Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) => Some(body), Stmt::ClassDef(ast::StmtClassDef { body, .. }) => Some(body), Stmt::For(ast::StmtFor { body, orelse, .. }) => { if body.contains(stmt) { @@ -16,15 +15,6 @@ pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<&'a Suite> { None } } - Stmt::AsyncFor(ast::StmtAsyncFor { body, orelse, .. }) => { - if body.contains(stmt) { - Some(body) - } else if orelse.contains(stmt) { - Some(orelse) - } else { - None - } - } Stmt::While(ast::StmtWhile { body, orelse, .. }) => { if body.contains(stmt) { Some(body) @@ -49,7 +39,6 @@ pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<&'a Suite> { } } Stmt::With(ast::StmtWith { body, .. }) => Some(body), - Stmt::AsyncWith(ast::StmtAsyncWith { body, .. }) => Some(body), Stmt::Match(ast::StmtMatch { cases, .. }) => cases .iter() .map(|case| &case.body) diff --git a/crates/ruff_python_ast/src/visitor.rs b/crates/ruff_python_ast/src/visitor.rs index b3aac478bb..be2fc8fc6d 100644 --- a/crates/ruff_python_ast/src/visitor.rs +++ b/crates/ruff_python_ast/src/visitor.rs @@ -128,26 +128,6 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) { } visitor.visit_body(body); } - Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - parameters, - body, - decorator_list, - returns, - type_params, - .. - }) => { - for decorator in decorator_list { - visitor.visit_decorator(decorator); - } - if let Some(type_params) = type_params { - visitor.visit_type_params(type_params); - } - visitor.visit_parameters(parameters); - for expr in returns { - visitor.visit_annotation(expr); - } - visitor.visit_body(body); - } Stmt::ClassDef(ast::StmtClassDef { arguments, body, @@ -228,18 +208,6 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) { visitor.visit_body(body); visitor.visit_body(orelse); } - Stmt::AsyncFor(ast::StmtAsyncFor { - target, - iter, - body, - orelse, - .. - }) => { - visitor.visit_expr(iter); - visitor.visit_expr(target); - visitor.visit_body(body); - visitor.visit_body(orelse); - } Stmt::While(ast::StmtWhile { test, body, @@ -271,12 +239,6 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) { } visitor.visit_body(body); } - Stmt::AsyncWith(ast::StmtAsyncWith { items, body, .. }) => { - for with_item in items { - visitor.visit_with_item(with_item); - } - visitor.visit_body(body); - } Stmt::Match(ast::StmtMatch { subject, cases, diff --git a/crates/ruff_python_ast/src/visitor/preorder.rs b/crates/ruff_python_ast/src/visitor/preorder.rs index 84a9e0535a..90ae53db61 100644 --- a/crates/ruff_python_ast/src/visitor/preorder.rs +++ b/crates/ruff_python_ast/src/visitor/preorder.rs @@ -145,14 +145,6 @@ where returns, type_params, .. - }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - parameters, - body, - decorator_list, - returns, - type_params, - .. }) => { for decorator in decorator_list { visitor.visit_decorator(decorator); @@ -261,13 +253,6 @@ where body, orelse, .. - }) - | Stmt::AsyncFor(ast::StmtAsyncFor { - target, - iter, - body, - orelse, - .. }) => { visitor.visit_expr(target); visitor.visit_expr(iter); @@ -302,11 +287,7 @@ where Stmt::With(ast::StmtWith { items, body, - range: _, - }) - | Stmt::AsyncWith(ast::StmtAsyncWith { - items, - body, + is_async: _, range: _, }) => { for with_item in items { diff --git a/crates/ruff_python_codegen/src/generator.rs b/crates/ruff_python_codegen/src/generator.rs index b502abad65..162f75736d 100644 --- a/crates/ruff_python_codegen/src/generator.rs +++ b/crates/ruff_python_codegen/src/generator.rs @@ -23,7 +23,6 @@ mod precedence { pub(crate) const YIELD_FROM: u8 = 7; pub(crate) const IF: u8 = 9; pub(crate) const FOR: u8 = 9; - pub(crate) const ASYNC_FOR: u8 = 9; pub(crate) const WHILE: u8 = 9; pub(crate) const RETURN: u8 = 11; pub(crate) const SLICE: u8 = 13; @@ -204,6 +203,7 @@ impl<'a> Generator<'a> { match ast { Stmt::FunctionDef(ast::StmtFunctionDef { + is_async, name, parameters, body, @@ -220,6 +220,9 @@ impl<'a> Generator<'a> { }); } statement!({ + if *is_async { + self.p("async "); + } self.p("def "); self.p_id(name); if let Some(type_params) = type_params { @@ -239,42 +242,6 @@ impl<'a> Generator<'a> { self.newlines(2); } } - Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - name, - parameters, - body, - returns, - decorator_list, - type_params, - .. - }) => { - self.newlines(if self.indent_depth == 0 { 2 } else { 1 }); - for decorator in decorator_list { - statement!({ - self.p("@"); - self.unparse_expr(&decorator.expression, precedence::MAX); - }); - } - statement!({ - self.p("async def "); - self.p_id(name); - if let Some(type_params) = type_params { - self.unparse_type_params(type_params); - } - self.p("("); - self.unparse_parameters(parameters); - self.p(")"); - if let Some(returns) = returns { - self.p(" -> "); - self.unparse_expr(returns, precedence::MAX); - } - self.p(":"); - }); - self.body(body); - if self.indent_depth == 0 { - self.newlines(2); - } - } Stmt::ClassDef(ast::StmtClassDef { name, arguments, @@ -400,6 +367,7 @@ impl<'a> Generator<'a> { }); } Stmt::For(ast::StmtFor { + is_async, target, iter, body, @@ -407,6 +375,9 @@ impl<'a> Generator<'a> { .. }) => { statement!({ + if *is_async { + self.p("async "); + } self.p("for "); self.unparse_expr(target, precedence::FOR); self.p(" in "); @@ -421,28 +392,6 @@ impl<'a> Generator<'a> { self.body(orelse); } } - Stmt::AsyncFor(ast::StmtAsyncFor { - target, - iter, - body, - orelse, - .. - }) => { - statement!({ - self.p("async for "); - self.unparse_expr(target, precedence::ASYNC_FOR); - self.p(" in "); - self.unparse_expr(iter, precedence::MAX); - self.p(":"); - }); - self.body(body); - if !orelse.is_empty() { - statement!({ - self.p("else:"); - }); - self.body(orelse); - } - } Stmt::While(ast::StmtWhile { test, body, @@ -490,21 +439,17 @@ impl<'a> Generator<'a> { self.body(&clause.body); } } - Stmt::With(ast::StmtWith { items, body, .. }) => { + Stmt::With(ast::StmtWith { + is_async, + items, + body, + .. + }) => { statement!({ - self.p("with "); - let mut first = true; - for item in items { - self.p_delim(&mut first, ", "); - self.unparse_with_item(item); + if *is_async { + self.p("async "); } - self.p(":"); - }); - self.body(body); - } - Stmt::AsyncWith(ast::StmtAsyncWith { items, body, .. }) => { - statement!({ - self.p("async with "); + self.p("with "); let mut first = true; for item in items { self.p_delim(&mut first, ", "); diff --git a/crates/ruff_python_formatter/src/comments/placement.rs b/crates/ruff_python_formatter/src/comments/placement.rs index 865c770b19..728ea3e7ca 100644 --- a/crates/ruff_python_formatter/src/comments/placement.rs +++ b/crates/ruff_python_formatter/src/comments/placement.rs @@ -67,9 +67,7 @@ pub(super) fn place_comment<'a>( handle_module_level_own_line_comment_before_class_or_function_comment(comment, locator) } AnyNodeRef::WithItem(_) => handle_with_item_comment(comment, locator), - AnyNodeRef::StmtFunctionDef(_) | AnyNodeRef::StmtAsyncFunctionDef(_) => { - handle_leading_function_with_decorators_comment(comment) - } + AnyNodeRef::StmtFunctionDef(_) => handle_leading_function_with_decorators_comment(comment), AnyNodeRef::StmtClassDef(class_def) => { handle_leading_class_with_decorators_comment(comment, class_def) } @@ -178,7 +176,6 @@ fn handle_end_of_line_comment_around_body<'a>( fn is_first_statement_in_body(statement: AnyNodeRef, has_body: AnyNodeRef) -> bool { match has_body { AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. }) - | AnyNodeRef::StmtAsyncFor(ast::StmtAsyncFor { body, orelse, .. }) | AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => { are_same_optional(statement, body.first()) || are_same_optional(statement, orelse.first()) @@ -208,7 +205,6 @@ fn is_first_statement_in_body(statement: AnyNodeRef, has_body: AnyNodeRef) -> bo body, .. }) | AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. }) - | AnyNodeRef::StmtAsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) | AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. }) => { are_same_optional(statement, body.first()) } @@ -773,9 +769,7 @@ fn handle_module_level_own_line_comment_before_class_or_function_comment<'a>( // ... where the following is a function or class statement. if !matches!( following, - AnyNodeRef::StmtAsyncFunctionDef(_) - | AnyNodeRef::StmtFunctionDef(_) - | AnyNodeRef::StmtClassDef(_) + AnyNodeRef::StmtFunctionDef(_) | AnyNodeRef::StmtClassDef(_) ) { return CommentPlacement::Default(comment); } @@ -1408,10 +1402,8 @@ where fn last_child_in_body(node: AnyNodeRef) -> Option { let body = match node { AnyNodeRef::StmtFunctionDef(ast::StmtFunctionDef { body, .. }) - | AnyNodeRef::StmtAsyncFunctionDef(ast::StmtAsyncFunctionDef { body, .. }) | AnyNodeRef::StmtClassDef(ast::StmtClassDef { body, .. }) | AnyNodeRef::StmtWith(ast::StmtWith { body, .. }) - | AnyNodeRef::StmtAsyncWith(ast::StmtAsyncWith { body, .. }) | AnyNodeRef::MatchCase(MatchCase { body, .. }) | AnyNodeRef::ExceptHandlerExceptHandler(ast::ExceptHandlerExceptHandler { body, .. @@ -1424,7 +1416,6 @@ fn last_child_in_body(node: AnyNodeRef) -> Option { }) => elif_else_clauses.last().map_or(body, |clause| &clause.body), AnyNodeRef::StmtFor(ast::StmtFor { body, orelse, .. }) - | AnyNodeRef::StmtAsyncFor(ast::StmtAsyncFor { body, orelse, .. }) | AnyNodeRef::StmtWhile(ast::StmtWhile { body, orelse, .. }) => { if orelse.is_empty() { body @@ -1477,7 +1468,6 @@ fn last_child_in_body(node: AnyNodeRef) -> Option { fn is_first_statement_in_alternate_body(statement: AnyNodeRef, has_body: AnyNodeRef) -> bool { match has_body { AnyNodeRef::StmtFor(ast::StmtFor { orelse, .. }) - | AnyNodeRef::StmtAsyncFor(ast::StmtAsyncFor { orelse, .. }) | AnyNodeRef::StmtWhile(ast::StmtWhile { orelse, .. }) => { are_same_optional(statement, orelse.first()) } diff --git a/crates/ruff_python_formatter/src/expression/expr_named_expr.rs b/crates/ruff_python_formatter/src/expression/expr_named_expr.rs index 88bf4b2052..e0b1f674b9 100644 --- a/crates/ruff_python_formatter/src/expression/expr_named_expr.rs +++ b/crates/ruff_python_formatter/src/expression/expr_named_expr.rs @@ -46,7 +46,6 @@ impl NeedsParentheses for ExprNamedExpr { || parent.is_with_item() || parent.is_stmt_delete() || parent.is_stmt_for() - || parent.is_stmt_async_for() { OptionalParentheses::Always } else { diff --git a/crates/ruff_python_formatter/src/generated.rs b/crates/ruff_python_formatter/src/generated.rs index 1871e09bd5..5a71b2d6e9 100644 --- a/crates/ruff_python_formatter/src/generated.rs +++ b/crates/ruff_python_formatter/src/generated.rs @@ -108,42 +108,6 @@ impl<'ast> IntoFormat> for ast::StmtFunctionDef { } } -impl FormatRule> - for crate::statement::stmt_async_function_def::FormatStmtAsyncFunctionDef -{ - #[inline] - fn fmt(&self, node: &ast::StmtAsyncFunctionDef, f: &mut PyFormatter) -> FormatResult<()> { - FormatNodeRule::::fmt(self, node, f) - } -} -impl<'ast> AsFormat> for ast::StmtAsyncFunctionDef { - type Format<'a> = FormatRefWithRule< - 'a, - ast::StmtAsyncFunctionDef, - crate::statement::stmt_async_function_def::FormatStmtAsyncFunctionDef, - PyFormatContext<'ast>, - >; - fn format(&self) -> Self::Format<'_> { - FormatRefWithRule::new( - self, - crate::statement::stmt_async_function_def::FormatStmtAsyncFunctionDef::default(), - ) - } -} -impl<'ast> IntoFormat> for ast::StmtAsyncFunctionDef { - type Format = FormatOwnedWithRule< - ast::StmtAsyncFunctionDef, - crate::statement::stmt_async_function_def::FormatStmtAsyncFunctionDef, - PyFormatContext<'ast>, - >; - fn into_format(self) -> Self::Format { - FormatOwnedWithRule::new( - self, - crate::statement::stmt_async_function_def::FormatStmtAsyncFunctionDef::default(), - ) - } -} - impl FormatRule> for crate::statement::stmt_class_def::FormatStmtClassDef { @@ -424,42 +388,6 @@ impl<'ast> IntoFormat> for ast::StmtFor { } } -impl FormatRule> - for crate::statement::stmt_async_for::FormatStmtAsyncFor -{ - #[inline] - fn fmt(&self, node: &ast::StmtAsyncFor, f: &mut PyFormatter) -> FormatResult<()> { - FormatNodeRule::::fmt(self, node, f) - } -} -impl<'ast> AsFormat> for ast::StmtAsyncFor { - type Format<'a> = FormatRefWithRule< - 'a, - ast::StmtAsyncFor, - crate::statement::stmt_async_for::FormatStmtAsyncFor, - PyFormatContext<'ast>, - >; - fn format(&self) -> Self::Format<'_> { - FormatRefWithRule::new( - self, - crate::statement::stmt_async_for::FormatStmtAsyncFor::default(), - ) - } -} -impl<'ast> IntoFormat> for ast::StmtAsyncFor { - type Format = FormatOwnedWithRule< - ast::StmtAsyncFor, - crate::statement::stmt_async_for::FormatStmtAsyncFor, - PyFormatContext<'ast>, - >; - fn into_format(self) -> Self::Format { - FormatOwnedWithRule::new( - self, - crate::statement::stmt_async_for::FormatStmtAsyncFor::default(), - ) - } -} - impl FormatRule> for crate::statement::stmt_while::FormatStmtWhile { @@ -554,42 +482,6 @@ impl<'ast> IntoFormat> for ast::StmtWith { } } -impl FormatRule> - for crate::statement::stmt_async_with::FormatStmtAsyncWith -{ - #[inline] - fn fmt(&self, node: &ast::StmtAsyncWith, f: &mut PyFormatter) -> FormatResult<()> { - FormatNodeRule::::fmt(self, node, f) - } -} -impl<'ast> AsFormat> for ast::StmtAsyncWith { - type Format<'a> = FormatRefWithRule< - 'a, - ast::StmtAsyncWith, - crate::statement::stmt_async_with::FormatStmtAsyncWith, - PyFormatContext<'ast>, - >; - fn format(&self) -> Self::Format<'_> { - FormatRefWithRule::new( - self, - crate::statement::stmt_async_with::FormatStmtAsyncWith::default(), - ) - } -} -impl<'ast> IntoFormat> for ast::StmtAsyncWith { - type Format = FormatOwnedWithRule< - ast::StmtAsyncWith, - crate::statement::stmt_async_with::FormatStmtAsyncWith, - PyFormatContext<'ast>, - >; - fn into_format(self) -> Self::Format { - FormatOwnedWithRule::new( - self, - crate::statement::stmt_async_with::FormatStmtAsyncWith::default(), - ) - } -} - impl FormatRule> for crate::statement::stmt_match::FormatStmtMatch { diff --git a/crates/ruff_python_formatter/src/statement/mod.rs b/crates/ruff_python_formatter/src/statement/mod.rs index 114d771528..1a92f03e39 100644 --- a/crates/ruff_python_formatter/src/statement/mod.rs +++ b/crates/ruff_python_formatter/src/statement/mod.rs @@ -5,9 +5,6 @@ use ruff_python_ast::Stmt; pub(crate) mod stmt_ann_assign; pub(crate) mod stmt_assert; pub(crate) mod stmt_assign; -pub(crate) mod stmt_async_for; -pub(crate) mod stmt_async_function_def; -pub(crate) mod stmt_async_with; pub(crate) mod stmt_aug_assign; pub(crate) mod stmt_break; pub(crate) mod stmt_class_def; @@ -40,7 +37,6 @@ impl FormatRule> for FormatStmt { fn fmt(&self, item: &Stmt, f: &mut PyFormatter) -> FormatResult<()> { match item { Stmt::FunctionDef(x) => x.format().fmt(f), - Stmt::AsyncFunctionDef(x) => x.format().fmt(f), Stmt::ClassDef(x) => x.format().fmt(f), Stmt::Return(x) => x.format().fmt(f), Stmt::Delete(x) => x.format().fmt(f), @@ -48,11 +44,9 @@ impl FormatRule> for FormatStmt { Stmt::AugAssign(x) => x.format().fmt(f), Stmt::AnnAssign(x) => x.format().fmt(f), Stmt::For(x) => x.format().fmt(f), - Stmt::AsyncFor(x) => x.format().fmt(f), Stmt::While(x) => x.format().fmt(f), Stmt::If(x) => x.format().fmt(f), Stmt::With(x) => x.format().fmt(f), - Stmt::AsyncWith(x) => x.format().fmt(f), Stmt::Match(x) => x.format().fmt(f), Stmt::Raise(x) => x.format().fmt(f), Stmt::Try(x) => x.format().fmt(f), diff --git a/crates/ruff_python_formatter/src/statement/stmt_async_for.rs b/crates/ruff_python_formatter/src/statement/stmt_async_for.rs deleted file mode 100644 index f47b3452e2..0000000000 --- a/crates/ruff_python_formatter/src/statement/stmt_async_for.rs +++ /dev/null @@ -1,23 +0,0 @@ -use crate::prelude::*; -use crate::statement::stmt_for::AnyStatementFor; -use crate::{FormatNodeRule, PyFormatter}; -use ruff_formatter::FormatResult; -use ruff_python_ast::StmtAsyncFor; - -#[derive(Default)] -pub struct FormatStmtAsyncFor; - -impl FormatNodeRule for FormatStmtAsyncFor { - fn fmt_fields(&self, item: &StmtAsyncFor, f: &mut PyFormatter) -> FormatResult<()> { - AnyStatementFor::from(item).fmt(f) - } - - fn fmt_dangling_comments( - &self, - _node: &StmtAsyncFor, - _f: &mut PyFormatter, - ) -> FormatResult<()> { - // Handled in `fmt_fields` - Ok(()) - } -} diff --git a/crates/ruff_python_formatter/src/statement/stmt_async_function_def.rs b/crates/ruff_python_formatter/src/statement/stmt_async_function_def.rs deleted file mode 100644 index 1ba23722e5..0000000000 --- a/crates/ruff_python_formatter/src/statement/stmt_async_function_def.rs +++ /dev/null @@ -1,24 +0,0 @@ -use ruff_python_ast::StmtAsyncFunctionDef; - -use ruff_python_ast::function::AnyFunctionDefinition; - -use crate::prelude::*; -use crate::FormatNodeRule; - -#[derive(Default)] -pub struct FormatStmtAsyncFunctionDef; - -impl FormatNodeRule for FormatStmtAsyncFunctionDef { - fn fmt_fields(&self, item: &StmtAsyncFunctionDef, f: &mut PyFormatter) -> FormatResult<()> { - AnyFunctionDefinition::from(item).format().fmt(f) - } - - fn fmt_dangling_comments( - &self, - _node: &StmtAsyncFunctionDef, - _f: &mut PyFormatter, - ) -> FormatResult<()> { - // Handled by `AnyFunctionDef` - Ok(()) - } -} diff --git a/crates/ruff_python_formatter/src/statement/stmt_async_with.rs b/crates/ruff_python_formatter/src/statement/stmt_async_with.rs deleted file mode 100644 index fdd89fdb2f..0000000000 --- a/crates/ruff_python_formatter/src/statement/stmt_async_with.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::prelude::*; -use crate::statement::stmt_with::AnyStatementWith; -use crate::FormatNodeRule; -use ruff_python_ast::StmtAsyncWith; - -#[derive(Default)] -pub struct FormatStmtAsyncWith; - -impl FormatNodeRule for FormatStmtAsyncWith { - fn fmt_fields(&self, item: &StmtAsyncWith, f: &mut PyFormatter) -> FormatResult<()> { - AnyStatementWith::from(item).fmt(f) - } - - fn fmt_dangling_comments( - &self, - _node: &StmtAsyncWith, - _f: &mut PyFormatter, - ) -> FormatResult<()> { - // Handled in `fmt_fields` - Ok(()) - } -} diff --git a/crates/ruff_python_formatter/src/statement/stmt_for.rs b/crates/ruff_python_formatter/src/statement/stmt_for.rs index 957fd563c7..e9cac17d23 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_for.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_for.rs @@ -1,13 +1,12 @@ +use ruff_formatter::{format_args, write, Buffer, FormatResult}; +use ruff_python_ast::{Expr, Ranged, Stmt, StmtFor}; + use crate::comments::{leading_alternate_branch_comments, trailing_comments}; use crate::expression::expr_tuple::TupleParentheses; use crate::expression::maybe_parenthesize_expression; use crate::expression::parentheses::Parenthesize; use crate::prelude::*; use crate::{FormatNodeRule, PyFormatter}; -use ruff_formatter::{format_args, write, Buffer, FormatResult}; -use ruff_python_ast::node::AnyNodeRef; -use ruff_python_ast::{Expr, Ranged, Stmt, StmtAsyncFor, StmtFor, Suite}; -use ruff_text_size::TextRange; #[derive(Debug)] struct ExprTupleWithoutParentheses<'a>(&'a Expr); @@ -27,85 +26,19 @@ impl Format> for ExprTupleWithoutParentheses<'_> { #[derive(Default)] pub struct FormatStmtFor; -pub(super) enum AnyStatementFor<'a> { - For(&'a StmtFor), - AsyncFor(&'a StmtAsyncFor), -} - -impl<'a> AnyStatementFor<'a> { - const fn is_async(&self) -> bool { - matches!(self, AnyStatementFor::AsyncFor(_)) - } - - fn target(&self) -> &Expr { - match self { - AnyStatementFor::For(stmt) => &stmt.target, - AnyStatementFor::AsyncFor(stmt) => &stmt.target, - } - } - - #[allow(clippy::iter_not_returning_iterator)] - fn iter(&self) -> &Expr { - match self { - AnyStatementFor::For(stmt) => &stmt.iter, - AnyStatementFor::AsyncFor(stmt) => &stmt.iter, - } - } - - fn body(&self) -> &Suite { - match self { - AnyStatementFor::For(stmt) => &stmt.body, - AnyStatementFor::AsyncFor(stmt) => &stmt.body, - } - } - - fn orelse(&self) -> &Suite { - match self { - AnyStatementFor::For(stmt) => &stmt.orelse, - AnyStatementFor::AsyncFor(stmt) => &stmt.orelse, - } - } -} - -impl Ranged for AnyStatementFor<'_> { - fn range(&self) -> TextRange { - match self { - AnyStatementFor::For(stmt) => stmt.range(), - AnyStatementFor::AsyncFor(stmt) => stmt.range(), - } - } -} - -impl<'a> From<&'a StmtFor> for AnyStatementFor<'a> { - fn from(value: &'a StmtFor) -> Self { - AnyStatementFor::For(value) - } -} - -impl<'a> From<&'a StmtAsyncFor> for AnyStatementFor<'a> { - fn from(value: &'a StmtAsyncFor) -> Self { - AnyStatementFor::AsyncFor(value) - } -} - -impl<'a> From<&AnyStatementFor<'a>> for AnyNodeRef<'a> { - fn from(value: &AnyStatementFor<'a>) -> Self { - match value { - AnyStatementFor::For(stmt) => AnyNodeRef::StmtFor(stmt), - AnyStatementFor::AsyncFor(stmt) => AnyNodeRef::StmtAsyncFor(stmt), - } - } -} - -impl Format> for AnyStatementFor<'_> { - fn fmt(&self, f: &mut PyFormatter) -> FormatResult<()> { - let target = self.target(); - let iter = self.iter(); - let body = self.body(); - let orelse = self.orelse(); +impl FormatNodeRule for FormatStmtFor { + fn fmt_fields(&self, item: &StmtFor, f: &mut PyFormatter) -> FormatResult<()> { + let StmtFor { + is_async, + target, + iter, + body, + orelse, + range: _, + } = item; let comments = f.context().comments().clone(); - let dangling_comments = comments.dangling_comments(self); + let dangling_comments = comments.dangling_comments(item); let body_start = body.first().map_or(iter.end(), Stmt::start); let or_else_comments_start = dangling_comments.partition_point(|comment| comment.slice().end() < body_start); @@ -116,15 +49,14 @@ impl Format> for AnyStatementFor<'_> { write!( f, [ - self.is_async() - .then_some(format_args![text("async"), space()]), + is_async.then_some(format_args![text("async"), space()]), text("for"), space(), ExprTupleWithoutParentheses(target), space(), text("in"), space(), - maybe_parenthesize_expression(iter, self, Parenthesize::IfBreaks), + maybe_parenthesize_expression(iter, item, Parenthesize::IfBreaks), text(":"), trailing_comments(trailing_condition_comments), block_indent(&body.format()) @@ -153,12 +85,6 @@ impl Format> for AnyStatementFor<'_> { Ok(()) } -} - -impl FormatNodeRule for FormatStmtFor { - fn fmt_fields(&self, item: &StmtFor, f: &mut PyFormatter) -> FormatResult<()> { - AnyStatementFor::from(item).fmt(f) - } fn fmt_dangling_comments(&self, _node: &StmtFor, _f: &mut PyFormatter) -> FormatResult<()> { // Handled in `fmt_fields` diff --git a/crates/ruff_python_formatter/src/statement/stmt_function_def.rs b/crates/ruff_python_formatter/src/statement/stmt_function_def.rs index 37100ad626..80a8da56ca 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_function_def.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_function_def.rs @@ -1,11 +1,8 @@ +use ruff_formatter::write; use ruff_python_ast::{Ranged, StmtFunctionDef}; - -use ruff_formatter::{write, FormatOwnedWithRule, FormatRefWithRule}; -use ruff_python_ast::function::AnyFunctionDefinition; use ruff_python_trivia::lines_after_ignoring_trivia; use crate::comments::{leading_comments, trailing_comments}; - use crate::expression::parentheses::{optional_parentheses, Parentheses}; use crate::prelude::*; use crate::statement::suite::SuiteKind; @@ -16,24 +13,6 @@ pub struct FormatStmtFunctionDef; impl FormatNodeRule for FormatStmtFunctionDef { fn fmt_fields(&self, item: &StmtFunctionDef, f: &mut PyFormatter) -> FormatResult<()> { - AnyFunctionDefinition::from(item).format().fmt(f) - } - - fn fmt_dangling_comments( - &self, - _node: &StmtFunctionDef, - _f: &mut PyFormatter, - ) -> FormatResult<()> { - // Handled by `AnyFunctionDef` - Ok(()) - } -} - -#[derive(Default)] -pub struct FormatAnyFunctionDef; - -impl FormatRule, PyFormatContext<'_>> for FormatAnyFunctionDef { - fn fmt(&self, item: &AnyFunctionDefinition<'_>, f: &mut PyFormatter) -> FormatResult<()> { let comments = f.context().comments().clone(); let dangling_comments = comments.dangling_comments(item); @@ -43,9 +22,9 @@ impl FormatRule, PyFormatContext<'_>> for FormatAnyFun let (leading_definition_comments, trailing_definition_comments) = dangling_comments.split_at(trailing_definition_comments_start); - if let Some(last_decorator) = item.decorators().last() { + if let Some(last_decorator) = item.decorator_list.last() { f.join_with(hard_line_break()) - .entries(item.decorators().iter().formatted()) + .entries(item.decorator_list.iter().formatted()) .finish()?; if leading_definition_comments.is_empty() { @@ -69,21 +48,19 @@ impl FormatRule, PyFormatContext<'_>> for FormatAnyFun } } - if item.is_async() { + if item.is_async { write!(f, [text("async"), space()])?; } - let name = item.name(); + write!(f, [text("def"), space(), item.name.format()])?; - write!(f, [text("def"), space(), name.format()])?; - - if let Some(type_params) = item.type_params() { + if let Some(type_params) = item.type_params.as_ref() { write!(f, [type_params.format()])?; } - write!(f, [item.arguments().format()])?; + write!(f, [item.parameters.format()])?; - if let Some(return_annotation) = item.returns() { + if let Some(return_annotation) = item.returns.as_ref() { write!( f, [ @@ -102,33 +79,17 @@ impl FormatRule, PyFormatContext<'_>> for FormatAnyFun [ text(":"), trailing_comments(trailing_definition_comments), - block_indent(&item.body().format().with_options(SuiteKind::Function)) + block_indent(&item.body.format().with_options(SuiteKind::Function)) ] ) } -} -impl<'def, 'ast> AsFormat> for AnyFunctionDefinition<'def> { - type Format<'a> = FormatRefWithRule< - 'a, - AnyFunctionDefinition<'def>, - FormatAnyFunctionDef, - PyFormatContext<'ast>, - > where Self: 'a; - - fn format(&self) -> Self::Format<'_> { - FormatRefWithRule::new(self, FormatAnyFunctionDef) - } -} - -impl<'def, 'ast> IntoFormat> for AnyFunctionDefinition<'def> { - type Format = FormatOwnedWithRule< - AnyFunctionDefinition<'def>, - FormatAnyFunctionDef, - PyFormatContext<'ast>, - >; - - fn into_format(self) -> Self::Format { - FormatOwnedWithRule::new(self, FormatAnyFunctionDef) + fn fmt_dangling_comments( + &self, + _node: &StmtFunctionDef, + _f: &mut PyFormatter, + ) -> FormatResult<()> { + // Handled in `fmt_fields` + Ok(()) } } diff --git a/crates/ruff_python_formatter/src/statement/stmt_with.rs b/crates/ruff_python_formatter/src/statement/stmt_with.rs index 828d0e3536..b8afb2042d 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_with.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_with.rs @@ -1,9 +1,7 @@ -use ruff_python_ast::{Ranged, StmtAsyncWith, StmtWith, Suite, WithItem}; -use ruff_text_size::TextRange; - use ruff_formatter::{format_args, write, FormatError}; -use ruff_python_ast::node::AnyNodeRef; +use ruff_python_ast::{Ranged, StmtWith}; use ruff_python_trivia::{SimpleTokenKind, SimpleTokenizer}; +use ruff_text_size::TextRange; use crate::comments::trailing_comments; use crate::expression::parentheses::{ @@ -12,81 +10,29 @@ use crate::expression::parentheses::{ use crate::prelude::*; use crate::FormatNodeRule; -pub(super) enum AnyStatementWith<'a> { - With(&'a StmtWith), - AsyncWith(&'a StmtAsyncWith), -} +#[derive(Default)] +pub struct FormatStmtWith; -impl<'a> AnyStatementWith<'a> { - const fn is_async(&self) -> bool { - matches!(self, AnyStatementWith::AsyncWith(_)) - } - - fn items(&self) -> &[WithItem] { - match self { - AnyStatementWith::With(with) => with.items.as_slice(), - AnyStatementWith::AsyncWith(with) => with.items.as_slice(), - } - } - - fn body(&self) -> &Suite { - match self { - AnyStatementWith::With(with) => &with.body, - AnyStatementWith::AsyncWith(with) => &with.body, - } - } -} - -impl Ranged for AnyStatementWith<'_> { - fn range(&self) -> TextRange { - match self { - AnyStatementWith::With(with) => with.range(), - AnyStatementWith::AsyncWith(with) => with.range(), - } - } -} - -impl<'a> From<&'a StmtWith> for AnyStatementWith<'a> { - fn from(value: &'a StmtWith) -> Self { - AnyStatementWith::With(value) - } -} - -impl<'a> From<&'a StmtAsyncWith> for AnyStatementWith<'a> { - fn from(value: &'a StmtAsyncWith) -> Self { - AnyStatementWith::AsyncWith(value) - } -} - -impl<'a> From<&AnyStatementWith<'a>> for AnyNodeRef<'a> { - fn from(value: &AnyStatementWith<'a>) -> Self { - match value { - AnyStatementWith::With(with) => AnyNodeRef::StmtWith(with), - AnyStatementWith::AsyncWith(with) => AnyNodeRef::StmtAsyncWith(with), - } - } -} - -impl Format> for AnyStatementWith<'_> { - fn fmt(&self, f: &mut PyFormatter) -> FormatResult<()> { +impl FormatNodeRule for FormatStmtWith { + fn fmt_fields(&self, item: &StmtWith, f: &mut PyFormatter) -> FormatResult<()> { let comments = f.context().comments().clone(); - let dangling_comments = comments.dangling_comments(self); + let dangling_comments = comments.dangling_comments(item); write!( f, [ - self.is_async() + item.is_async .then_some(format_args![text("async"), space()]), text("with"), space() ] )?; - if are_with_items_parenthesized(self, f.context())? { + if are_with_items_parenthesized(item, f.context())? { optional_parentheses(&format_with(|f| { - let mut joiner = f.join_comma_separated(self.body().first().unwrap().start()); + let mut joiner = f.join_comma_separated(item.body.first().unwrap().start()); - for item in self.items() { + for item in &item.items { joiner.entry_with_line_separator( item, &item.format(), @@ -98,7 +44,7 @@ impl Format> for AnyStatementWith<'_> { .fmt(f)?; } else { f.join_with(format_args![text(","), space()]) - .entries(self.items().iter().formatted()) + .entries(item.items.iter().formatted()) .finish()?; } @@ -107,18 +53,20 @@ impl Format> for AnyStatementWith<'_> { [ text(":"), trailing_comments(dangling_comments), - block_indent(&self.body().format()) + block_indent(&item.body.format()) ] ) } + + fn fmt_dangling_comments(&self, _node: &StmtWith, _f: &mut PyFormatter) -> FormatResult<()> { + // Handled in `fmt_fields` + Ok(()) + } } -fn are_with_items_parenthesized( - with: &AnyStatementWith, - context: &PyFormatContext, -) -> FormatResult { +fn are_with_items_parenthesized(with: &StmtWith, context: &PyFormatContext) -> FormatResult { let first_with_item = with - .items() + .items .first() .ok_or(FormatError::syntax_error("Expected at least one with item"))?; let before_first_with_item = TextRange::new(with.start(), first_with_item.start()); @@ -145,17 +93,3 @@ fn are_with_items_parenthesized( None => Ok(false), } } - -#[derive(Default)] -pub struct FormatStmtWith; - -impl FormatNodeRule for FormatStmtWith { - fn fmt_fields(&self, item: &StmtWith, f: &mut PyFormatter) -> FormatResult<()> { - AnyStatementWith::from(item).fmt(f) - } - - fn fmt_dangling_comments(&self, _node: &StmtWith, _f: &mut PyFormatter) -> FormatResult<()> { - // Handled in `fmt_fields` - Ok(()) - } -} diff --git a/crates/ruff_python_formatter/src/statement/suite.rs b/crates/ruff_python_formatter/src/statement/suite.rs index 771ad26036..ad3da3c927 100644 --- a/crates/ruff_python_formatter/src/statement/suite.rs +++ b/crates/ruff_python_formatter/src/statement/suite.rs @@ -210,10 +210,7 @@ impl FormatRule> for FormatSuite { /// Returns `true` if a [`Stmt`] is a class or function definition. const fn is_class_or_function_definition(stmt: &Stmt) -> bool { - matches!( - stmt, - Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) | Stmt::ClassDef(_) - ) + matches!(stmt, Stmt::FunctionDef(_) | Stmt::ClassDef(_)) } /// Returns `true` if a [`Stmt`] is an import. diff --git a/crates/ruff_python_parser/src/python.lalrpop b/crates/ruff_python_parser/src/python.lalrpop index fb75124ea0..3416ab6c09 100644 --- a/crates/ruff_python_parser/src/python.lalrpop +++ b/crates/ruff_python_parser/src/python.lalrpop @@ -858,11 +858,7 @@ ForStatement: ast::Stmt = { .end(); let target = Box::new(set_context(target, ast::ExprContext::Store)); let iter = Box::new(iter); - if is_async.is_some() { - ast::Stmt::AsyncFor(ast::StmtAsyncFor { target, iter, body, orelse, range: (location..end_location).into() }) - } else { - ast::Stmt::For(ast::StmtFor { target, iter, body, orelse, range: (location..end_location).into() }) - } + ast::Stmt::For(ast::StmtFor { target, iter, body, orelse, is_async: is_async.is_some(), range: (location..end_location).into() }) }, }; @@ -975,11 +971,7 @@ ExceptClause: ast::ExceptHandler = { WithStatement: ast::Stmt = { "with" ":" => { let end_location = body.last().unwrap().end(); - if is_async.is_some() { - ast::StmtAsyncWith { items, body, range: (location..end_location).into() }.into() - } else { - ast::StmtWith { items, body, range: (location..end_location).into() }.into() - } + ast::StmtWith { items, body, is_async: is_async.is_some(), range: (location..end_location).into() }.into() }, }; @@ -1014,11 +1006,7 @@ FuncDef: ast::Stmt = { let args = Box::new(args); let returns = r.map(Box::new); let end_location = body.last().unwrap().end(); - if is_async.is_some() { - ast::StmtAsyncFunctionDef { name, parameters:args, body, decorator_list, returns, type_params, range: (location..end_location).into() }.into() - } else { - ast::StmtFunctionDef { name, parameters:args, body, decorator_list, returns, type_params, range: (location..end_location).into() }.into() - } + ast::StmtFunctionDef { name, parameters:args, body, decorator_list, returns, type_params, is_async: is_async.is_some(), range: (location..end_location).into() }.into() }, }; diff --git a/crates/ruff_python_parser/src/python.rs b/crates/ruff_python_parser/src/python.rs index 8577cb525b..4ca7af8dfe 100644 --- a/crates/ruff_python_parser/src/python.rs +++ b/crates/ruff_python_parser/src/python.rs @@ -1,5 +1,5 @@ // auto-generated: "lalrpop 0.20.0" -// sha3: f99d8cb29227bfbe1fa07719f655304a9a93fd4715726687ef40c091adbdbad5 +// sha3: d713a7771107f8c20353ce5e890fba004b3c5491f513d28e9348a49cd510c59b use num_bigint::BigInt; use ruff_text_size::TextSize; use ruff_python_ast::{self as ast, Ranged, MagicKind}; @@ -33081,11 +33081,7 @@ fn __action147< .end(); let target = Box::new(set_context(target, ast::ExprContext::Store)); let iter = Box::new(iter); - if is_async.is_some() { - ast::Stmt::AsyncFor(ast::StmtAsyncFor { target, iter, body, orelse, range: (location..end_location).into() }) - } else { - ast::Stmt::For(ast::StmtFor { target, iter, body, orelse, range: (location..end_location).into() }) - } + ast::Stmt::For(ast::StmtFor { target, iter, body, orelse, is_async: is_async.is_some(), range: (location..end_location).into() }) } } @@ -33306,11 +33302,7 @@ fn __action155< { { let end_location = body.last().unwrap().end(); - if is_async.is_some() { - ast::StmtAsyncWith { items, body, range: (location..end_location).into() }.into() - } else { - ast::StmtWith { items, body, range: (location..end_location).into() }.into() - } + ast::StmtWith { items, body, is_async: is_async.is_some(), range: (location..end_location).into() }.into() } } @@ -33407,11 +33399,7 @@ fn __action161< let args = Box::new(args); let returns = r.map(Box::new); let end_location = body.last().unwrap().end(); - if is_async.is_some() { - ast::StmtAsyncFunctionDef { name, parameters:args, body, decorator_list, returns, type_params, range: (location..end_location).into() }.into() - } else { - ast::StmtFunctionDef { name, parameters:args, body, decorator_list, returns, type_params, range: (location..end_location).into() }.into() - } + ast::StmtFunctionDef { name, parameters:args, body, decorator_list, returns, type_params, is_async: is_async.is_some(), range: (location..end_location).into() }.into() } } diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__context__tests__assign_for.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__context__tests__assign_for.snap index fb74fe3e48..2d4bf6ba9b 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__context__tests__assign_for.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__context__tests__assign_for.snap @@ -6,6 +6,7 @@ expression: parse_ast For( StmtFor { range: 0..24, + is_async: false, target: Name( ExprName { range: 4..5, diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__context__tests__assign_with.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__context__tests__assign_with.snap index c77adbe5f3..f977b37549 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__context__tests__assign_with.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__context__tests__assign_with.snap @@ -6,6 +6,7 @@ expression: parse_ast With( StmtWith { range: 0..17, + is_async: false, items: [ WithItem { range: 5..11, diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_kw_only_args.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_kw_only_args.snap index 4e6cabe7e1..ad7dbadf9a 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_kw_only_args.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_kw_only_args.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..23, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_kw_only_args_with_defaults.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_kw_only_args_with_defaults.snap index 2008962778..33a05a02ec 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_kw_only_args_with_defaults.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_kw_only_args_with_defaults.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..29, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_no_args.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_no_args.snap index 2c5ceaf090..4e2121628d 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_no_args.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_no_args.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..13, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_no_args_with_ranges.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_no_args_with_ranges.snap index 2c5ceaf090..4e2121628d 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_no_args_with_ranges.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_no_args_with_ranges.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..13, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args.snap index 5e32c2421e..cd12db1311 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..32, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults.snap index fe9ac3afcf..4276b8915b 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..38, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults_and_varargs.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults_and_varargs.snap index 8944370563..ba1ece7a57 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults_and_varargs.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults_and_varargs.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..42, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults_and_varargs_and_kwargs.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults_and_varargs_and_kwargs.snap index e5ef63795f..dc2888c7bd 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults_and_varargs_and_kwargs.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_and_kw_only_args_with_defaults_and_varargs_and_kwargs.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..52, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args.snap index 8a969a82e6..788b1b7a49 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..20, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args_with_defaults.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args_with_defaults.snap index 4f369f862b..7c0459ab6c 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args_with_defaults.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args_with_defaults.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..26, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args_with_ranges.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args_with_ranges.snap index 8a969a82e6..788b1b7a49 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args_with_ranges.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__function__tests__function_pos_args_with_ranges.snap @@ -7,6 +7,7 @@ Ok( FunctionDef( StmtFunctionDef { range: 0..20, + is_async: false, decorator_list: [], name: Identifier { id: "f", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__decorator_ranges.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__decorator_ranges.snap index b976725751..160699606e 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__decorator_ranges.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__decorator_ranges.snap @@ -6,6 +6,7 @@ expression: parse_ast FunctionDef( StmtFunctionDef { range: 0..34, + is_async: false, decorator_list: [ Decorator { range: 0..13, diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__jupyter_magic.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__jupyter_magic.snap index 895867df6e..ba9adb4a07 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__jupyter_magic.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__jupyter_magic.snap @@ -125,6 +125,7 @@ Module( FunctionDef( StmtFunctionDef { range: 566..626, + is_async: false, decorator_list: [], name: Identifier { id: "foo", @@ -199,6 +200,7 @@ Module( For( StmtFor { range: 701..727, + is_async: false, target: Name( ExprName { range: 705..706, diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__parse_class.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__parse_class.snap index 9ac6469514..601df87ff8 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__parse_class.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__parse_class.snap @@ -38,6 +38,7 @@ expression: "parse_suite(source, \"\").unwrap()" FunctionDef( StmtFunctionDef { range: 18..44, + is_async: false, decorator_list: [], name: Identifier { id: "__init__", @@ -78,6 +79,7 @@ expression: "parse_suite(source, \"\").unwrap()" FunctionDef( StmtFunctionDef { range: 46..98, + is_async: false, decorator_list: [], name: Identifier { id: "method_with_default", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__parse_function_definition.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__parse_function_definition.snap index 1c03ac6610..cf0f7df886 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__parse_function_definition.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__parse_function_definition.snap @@ -6,6 +6,7 @@ expression: "parse_suite(source, \"\").unwrap()" FunctionDef( StmtFunctionDef { range: 0..20, + is_async: false, decorator_list: [], name: Identifier { id: "func", @@ -53,6 +54,7 @@ expression: "parse_suite(source, \"\").unwrap()" FunctionDef( StmtFunctionDef { range: 22..53, + is_async: false, decorator_list: [], name: Identifier { id: "func", @@ -132,6 +134,7 @@ expression: "parse_suite(source, \"\").unwrap()" FunctionDef( StmtFunctionDef { range: 55..91, + is_async: false, decorator_list: [], name: Identifier { id: "func", @@ -219,6 +222,7 @@ expression: "parse_suite(source, \"\").unwrap()" FunctionDef( StmtFunctionDef { range: 93..138, + is_async: false, decorator_list: [], name: Identifier { id: "func", @@ -321,6 +325,7 @@ expression: "parse_suite(source, \"\").unwrap()" FunctionDef( StmtFunctionDef { range: 140..171, + is_async: false, decorator_list: [], name: Identifier { id: "func", @@ -393,6 +398,7 @@ expression: "parse_suite(source, \"\").unwrap()" FunctionDef( StmtFunctionDef { range: 173..230, + is_async: false, decorator_list: [], name: Identifier { id: "func", @@ -496,6 +502,7 @@ expression: "parse_suite(source, \"\").unwrap()" FunctionDef( StmtFunctionDef { range: 232..273, + is_async: false, decorator_list: [], name: Identifier { id: "func", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__variadic_generics.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__variadic_generics.snap index ec38d5c16b..891e02d914 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__variadic_generics.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__variadic_generics.snap @@ -6,6 +6,7 @@ expression: parse_ast FunctionDef( StmtFunctionDef { range: 1..49, + is_async: false, decorator_list: [], name: Identifier { id: "args_to_tuple", diff --git a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__with_statement.snap b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__with_statement.snap index 668b5381fe..abae0fb4bc 100644 --- a/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__with_statement.snap +++ b/crates/ruff_python_parser/src/snapshots/ruff_python_parser__parser__tests__with_statement.snap @@ -1,11 +1,12 @@ --- source: crates/ruff_python_parser/src/parser.rs -expression: "ast::Suite::parse(source, \"\").unwrap()" +expression: "parse_suite(source, \"\").unwrap()" --- [ With( StmtWith { range: 0..12, + is_async: false, items: [ WithItem { range: 5..6, @@ -33,6 +34,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 13..30, + is_async: false, items: [ WithItem { range: 18..24, @@ -68,6 +70,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 31..46, + is_async: false, items: [ WithItem { range: 36..37, @@ -108,6 +111,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 47..72, + is_async: false, items: [ WithItem { range: 52..58, @@ -164,6 +168,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 73..97, + is_async: false, items: [ WithItem { range: 78..91, @@ -214,6 +219,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 98..127, + is_async: false, items: [ WithItem { range: 103..121, @@ -272,6 +278,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 128..141, + is_async: false, items: [ WithItem { range: 133..135, @@ -297,6 +304,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 142..160, + is_async: false, items: [ WithItem { range: 147..154, @@ -330,6 +338,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 161..175, + is_async: false, items: [ WithItem { range: 167..168, @@ -357,6 +366,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 176..195, + is_async: false, items: [ WithItem { range: 181..189, @@ -392,6 +402,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 196..211, + is_async: false, items: [ WithItem { range: 202..203, @@ -419,6 +430,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 212..232, + is_async: false, items: [ WithItem { range: 217..226, @@ -462,6 +474,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 233..250, + is_async: false, items: [ WithItem { range: 239..243, @@ -502,6 +515,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 251..273, + is_async: false, items: [ WithItem { range: 256..267, @@ -554,6 +568,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 274..290, + is_async: false, items: [ WithItem { range: 279..284, @@ -593,6 +608,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 291..312, + is_async: false, items: [ WithItem { range: 296..306, @@ -640,6 +656,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 313..331, + is_async: false, items: [ WithItem { range: 318..325, @@ -688,6 +705,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 332..355, + is_async: false, items: [ WithItem { range: 337..349, @@ -744,6 +762,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 356..375, + is_async: false, items: [ WithItem { range: 361..369, @@ -783,6 +802,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 376..400, + is_async: false, items: [ WithItem { range: 381..394, @@ -830,6 +850,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 401..428, + is_async: false, items: [ WithItem { range: 406..422, @@ -898,6 +919,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 429..461, + is_async: false, items: [ WithItem { range: 434..455, @@ -974,6 +996,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 462..481, + is_async: false, items: [ WithItem { range: 468..474, @@ -1009,6 +1032,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 482..502, + is_async: false, items: [ WithItem { range: 488..494, @@ -1044,6 +1068,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 503..530, + is_async: false, items: [ WithItem { range: 509..515, @@ -1100,6 +1125,7 @@ expression: "ast::Suite::parse(source, \"\").unwrap()" With( StmtWith { range: 531..559, + is_async: false, items: [ WithItem { range: 537..543, diff --git a/crates/ruff_python_semantic/src/analyze/visibility.rs b/crates/ruff_python_semantic/src/analyze/visibility.rs index bc2d536938..846e41fa67 100644 --- a/crates/ruff_python_semantic/src/analyze/visibility.rs +++ b/crates/ruff_python_semantic/src/analyze/visibility.rs @@ -178,8 +178,7 @@ impl ModuleSource<'_> { pub(crate) fn function_visibility(stmt: &Stmt) -> Visibility { match stmt { - Stmt::FunctionDef(ast::StmtFunctionDef { name, .. }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { name, .. }) => { + Stmt::FunctionDef(ast::StmtFunctionDef { name, .. }) => { if name.starts_with('_') { Visibility::Private } else { @@ -191,52 +190,45 @@ pub(crate) fn function_visibility(stmt: &Stmt) -> Visibility { } pub(crate) fn method_visibility(stmt: &Stmt) -> Visibility { - match stmt { - Stmt::FunctionDef(ast::StmtFunctionDef { - name, - decorator_list, - .. + let Stmt::FunctionDef(ast::StmtFunctionDef { + name, + decorator_list, + .. + }) = stmt + else { + panic!("Found non-FunctionDef in method_visibility") + }; + + // Is this a setter or deleter? + if decorator_list.iter().any(|decorator| { + collect_call_path(&decorator.expression).is_some_and(|call_path| { + call_path.as_slice() == [name, "setter"] || call_path.as_slice() == [name, "deleter"] }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { - name, - decorator_list, - .. - }) => { - // Is this a setter or deleter? - if decorator_list.iter().any(|decorator| { - collect_call_path(&decorator.expression).is_some_and(|call_path| { - call_path.as_slice() == [name, "setter"] - || call_path.as_slice() == [name, "deleter"] - }) - }) { - return Visibility::Private; - } - - // Is the method non-private? - if !name.starts_with('_') { - return Visibility::Public; - } - - // Is this a magic method? - if name.starts_with("__") && name.ends_with("__") { - return Visibility::Public; - } - - Visibility::Private - } - _ => panic!("Found non-FunctionDef in method_visibility"), + }) { + return Visibility::Private; } + + // Is the method non-private? + if !name.starts_with('_') { + return Visibility::Public; + } + + // Is this a magic method? + if name.starts_with("__") && name.ends_with("__") { + return Visibility::Public; + } + + Visibility::Private } pub(crate) fn class_visibility(stmt: &Stmt) -> Visibility { - match stmt { - Stmt::ClassDef(ast::StmtClassDef { name, .. }) => { - if name.starts_with('_') { - Visibility::Private - } else { - Visibility::Public - } - } - _ => panic!("Found non-ClassDef in function_visibility"), + let Stmt::ClassDef(ast::StmtClassDef { name, .. }) = stmt else { + panic!("Found non-ClassDef in class_visibility") + }; + + if name.starts_with('_') { + Visibility::Private + } else { + Visibility::Public } } diff --git a/crates/ruff_python_semantic/src/definition.rs b/crates/ruff_python_semantic/src/definition.rs index aec18d62f5..e3a5bfd649 100644 --- a/crates/ruff_python_semantic/src/definition.rs +++ b/crates/ruff_python_semantic/src/definition.rs @@ -84,7 +84,6 @@ impl<'a> Member<'a> { pub fn name(&self) -> Option<&'a str> { match &self.stmt { Stmt::FunctionDef(ast::StmtFunctionDef { name, .. }) - | Stmt::AsyncFunctionDef(ast::StmtAsyncFunctionDef { name, .. }) | Stmt::ClassDef(ast::StmtClassDef { name, .. }) => Some(name), _ => None, } diff --git a/crates/ruff_python_semantic/src/globals.rs b/crates/ruff_python_semantic/src/globals.rs index a17b79b6f8..ffaf1b16f9 100644 --- a/crates/ruff_python_semantic/src/globals.rs +++ b/crates/ruff_python_semantic/src/globals.rs @@ -80,7 +80,7 @@ impl<'a> StatementVisitor<'a> for GlobalsVisitor<'a> { self.0.insert(name.as_str(), *range); } } - Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) | Stmt::ClassDef(_) => { + Stmt::FunctionDef(_) | Stmt::ClassDef(_) => { // Don't recurse. } _ => walk_stmt(self, stmt), diff --git a/crates/ruff_python_semantic/src/model.rs b/crates/ruff_python_semantic/src/model.rs index 2dc7e6e69b..3ae1b06bb2 100644 --- a/crates/ruff_python_semantic/src/model.rs +++ b/crates/ruff_python_semantic/src/model.rs @@ -478,7 +478,7 @@ impl<'a> SemanticModel<'a> { } } - seen_function |= scope.kind.is_any_function(); + seen_function |= scope.kind.is_function(); import_starred = import_starred || scope.uses_star_imports(); } @@ -540,7 +540,7 @@ impl<'a> SemanticModel<'a> { } } - seen_function |= scope.kind.is_any_function(); + seen_function |= scope.kind.is_function(); } None @@ -1015,11 +1015,8 @@ impl<'a> SemanticModel<'a> { /// Return `true` if the model is in an async context. pub fn in_async_context(&self) -> bool { for scope in self.current_scopes() { - if scope.kind.is_async_function() { - return true; - } - if scope.kind.is_function() { - return false; + if let ScopeKind::Function(ast::StmtFunctionDef { is_async, .. }) = scope.kind { + return *is_async; } } false diff --git a/crates/ruff_python_semantic/src/scope.rs b/crates/ruff_python_semantic/src/scope.rs index b80d612784..3de3b058f9 100644 --- a/crates/ruff_python_semantic/src/scope.rs +++ b/crates/ruff_python_semantic/src/scope.rs @@ -178,19 +178,12 @@ bitflags! { pub enum ScopeKind<'a> { Class(&'a ast::StmtClassDef), Function(&'a ast::StmtFunctionDef), - AsyncFunction(&'a ast::StmtAsyncFunctionDef), Generator, Module, Type, Lambda(&'a ast::ExprLambda), } -impl ScopeKind<'_> { - pub const fn is_any_function(&self) -> bool { - matches!(self, ScopeKind::Function(_) | ScopeKind::AsyncFunction(_)) - } -} - /// Id uniquely identifying a scope in a program. /// /// Using a `u32` is sufficient because Ruff only supports parsing documents with a size of max `u32::max`