diff --git a/crates/ruff_python_formatter/src/statement/stmt_async_for.rs b/crates/ruff_python_formatter/src/statement/stmt_async_for.rs index 2718289fb8..906f19c5e0 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_async_for.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_async_for.rs @@ -1,5 +1,7 @@ -use crate::{not_yet_implemented, FormatNodeRule, PyFormatter}; -use ruff_formatter::{write, Buffer, FormatResult}; +use crate::prelude::*; +use crate::statement::stmt_for::AnyStatementFor; +use crate::{FormatNodeRule, PyFormatter}; +use ruff_formatter::FormatResult; use rustpython_parser::ast::StmtAsyncFor; #[derive(Default)] @@ -7,6 +9,15 @@ pub struct FormatStmtAsyncFor; impl FormatNodeRule for FormatStmtAsyncFor { fn fmt_fields(&self, item: &StmtAsyncFor, f: &mut PyFormatter) -> FormatResult<()> { - write!(f, [not_yet_implemented(item)]) + AnyStatementFor::from(item).fmt(f) + } + + fn fmt_dangling_comments( + &self, + _node: &StmtAsyncFor, + _f: &mut PyFormatter, + ) -> FormatResult<()> { + // Handled in `fmt_fields` + Ok(()) } } diff --git a/crates/ruff_python_formatter/src/statement/stmt_for.rs b/crates/ruff_python_formatter/src/statement/stmt_for.rs index b03283cabc..4d009ca818 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_for.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_for.rs @@ -4,9 +4,10 @@ use crate::expression::maybe_parenthesize_expression; use crate::expression::parentheses::Parenthesize; use crate::prelude::*; use crate::{FormatNodeRule, PyFormatter}; -use ruff_formatter::{write, Buffer, FormatResult}; -use ruff_python_ast::node::AstNode; -use rustpython_parser::ast::{Expr, Ranged, Stmt, StmtFor}; +use ruff_formatter::{format_args, write, Buffer, FormatResult}; +use ruff_python_ast::node::AnyNodeRef; +use ruff_text_size::TextRange; +use rustpython_parser::ast::{Expr, Ranged, Stmt, StmtAsyncFor, StmtFor, Suite}; #[derive(Debug)] struct ExprTupleWithoutParentheses<'a>(&'a Expr); @@ -26,19 +27,85 @@ impl Format> for ExprTupleWithoutParentheses<'_> { #[derive(Default)] pub struct FormatStmtFor; -impl FormatNodeRule for FormatStmtFor { - fn fmt_fields(&self, item: &StmtFor, f: &mut PyFormatter) -> FormatResult<()> { - let StmtFor { - range: _, - target, - iter, - body, - orelse, - type_comment: _, - } = item; +pub(super) enum AnyStatementFor<'a> { + For(&'a StmtFor), + AsyncFor(&'a StmtAsyncFor), +} + +impl<'a> AnyStatementFor<'a> { + const fn is_async(&self) -> bool { + matches!(self, AnyStatementFor::AsyncFor(_)) + } + + fn target(&self) -> &Expr { + match self { + AnyStatementFor::For(stmt) => &stmt.target, + AnyStatementFor::AsyncFor(stmt) => &stmt.target, + } + } + + #[allow(clippy::iter_not_returning_iterator)] + fn iter(&self) -> &Expr { + match self { + AnyStatementFor::For(stmt) => &stmt.iter, + AnyStatementFor::AsyncFor(stmt) => &stmt.iter, + } + } + + fn body(&self) -> &Suite { + match self { + AnyStatementFor::For(stmt) => &stmt.body, + AnyStatementFor::AsyncFor(stmt) => &stmt.body, + } + } + + fn orelse(&self) -> &Suite { + match self { + AnyStatementFor::For(stmt) => &stmt.orelse, + AnyStatementFor::AsyncFor(stmt) => &stmt.orelse, + } + } +} + +impl Ranged for AnyStatementFor<'_> { + fn range(&self) -> TextRange { + match self { + AnyStatementFor::For(stmt) => stmt.range(), + AnyStatementFor::AsyncFor(stmt) => stmt.range(), + } + } +} + +impl<'a> From<&'a StmtFor> for AnyStatementFor<'a> { + fn from(value: &'a StmtFor) -> Self { + AnyStatementFor::For(value) + } +} + +impl<'a> From<&'a StmtAsyncFor> for AnyStatementFor<'a> { + fn from(value: &'a StmtAsyncFor) -> Self { + AnyStatementFor::AsyncFor(value) + } +} + +impl<'a> From<&AnyStatementFor<'a>> for AnyNodeRef<'a> { + fn from(value: &AnyStatementFor<'a>) -> Self { + match value { + AnyStatementFor::For(stmt) => AnyNodeRef::StmtFor(stmt), + AnyStatementFor::AsyncFor(stmt) => AnyNodeRef::StmtAsyncFor(stmt), + } + } +} + +impl Format> for AnyStatementFor<'_> { + fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { + let target = self.target(); + let iter = self.iter(); + let body = self.body(); + let orelse = self.orelse(); let comments = f.context().comments().clone(); - let dangling_comments = comments.dangling_comments(item.as_any_node_ref()); + let dangling_comments = comments.dangling_comments(self); let body_start = body.first().map_or(iter.end(), Stmt::start); let or_else_comments_start = dangling_comments.partition_point(|comment| comment.slice().end() < body_start); @@ -49,13 +116,15 @@ impl FormatNodeRule for FormatStmtFor { write!( f, [ + self.is_async() + .then_some(format_args![text("async"), space()]), text("for"), space(), - ExprTupleWithoutParentheses(target.as_ref()), + ExprTupleWithoutParentheses(target), space(), text("in"), space(), - maybe_parenthesize_expression(iter, item, Parenthesize::IfBreaks), + maybe_parenthesize_expression(iter, self, Parenthesize::IfBreaks), text(":"), trailing_comments(trailing_condition_comments), block_indent(&body.format()) @@ -84,6 +153,12 @@ impl FormatNodeRule for FormatStmtFor { Ok(()) } +} + +impl FormatNodeRule for FormatStmtFor { + fn fmt_fields(&self, item: &StmtFor, f: &mut PyFormatter) -> FormatResult<()> { + AnyStatementFor::from(item).fmt(f) + } fn fmt_dangling_comments(&self, _node: &StmtFor, _f: &mut PyFormatter) -> FormatResult<()> { // Handled in `fmt_fields` diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__starred_for_target.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__starred_for_target.py.snap deleted file mode 100644 index dad8c213dc..0000000000 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_310__starred_for_target.py.snap +++ /dev/null @@ -1,123 +0,0 @@ ---- -source: crates/ruff_python_formatter/tests/fixtures.rs -input_file: crates/ruff_python_formatter/resources/test/fixtures/black/py_310/starred_for_target.py ---- -## Input - -```py -for x in *a, *b: - print(x) - -for x in a, b, *c: - print(x) - -for x in *a, b, c: - print(x) - -for x in *a, b, *c: - print(x) - -async for x in *a, *b: - print(x) - -async for x in *a, b, *c: - print(x) - -async for x in a, b, *c: - print(x) - -async for x in ( - *loooooooooooooooooooooong, - very, - *loooooooooooooooooooooooooooooooooooooooooooooooong, -): - print(x) -``` - -## Black Differences - -```diff ---- Black -+++ Ruff -@@ -10,18 +10,10 @@ - for x in *a, b, *c: - print(x) - --async for x in *a, *b: -- print(x) -+NOT_YET_IMPLEMENTED_StmtAsyncFor - --async for x in *a, b, *c: -- print(x) -+NOT_YET_IMPLEMENTED_StmtAsyncFor - --async for x in a, b, *c: -- print(x) -+NOT_YET_IMPLEMENTED_StmtAsyncFor - --async for x in ( -- *loooooooooooooooooooooong, -- very, -- *loooooooooooooooooooooooooooooooooooooooooooooooong, --): -- print(x) -+NOT_YET_IMPLEMENTED_StmtAsyncFor -``` - -## Ruff Output - -```py -for x in *a, *b: - print(x) - -for x in a, b, *c: - print(x) - -for x in *a, b, c: - print(x) - -for x in *a, b, *c: - print(x) - -NOT_YET_IMPLEMENTED_StmtAsyncFor - -NOT_YET_IMPLEMENTED_StmtAsyncFor - -NOT_YET_IMPLEMENTED_StmtAsyncFor - -NOT_YET_IMPLEMENTED_StmtAsyncFor -``` - -## Black Output - -```py -for x in *a, *b: - print(x) - -for x in a, b, *c: - print(x) - -for x in *a, b, c: - print(x) - -for x in *a, b, *c: - print(x) - -async for x in *a, *b: - print(x) - -async for x in *a, b, *c: - print(x) - -async for x in a, b, *c: - print(x) - -async for x in ( - *loooooooooooooooooooooong, - very, - *loooooooooooooooooooooooooooooooooooooooooooooooong, -): - print(x) -``` - - diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__fmtskip8.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__fmtskip8.py.snap index 4d65173c25..2a56e66b7b 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__fmtskip8.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__fmtskip8.py.snap @@ -74,7 +74,7 @@ async def test_async_with(): ```diff --- Black +++ Ruff -@@ -1,62 +1,62 @@ +@@ -1,62 +1,63 @@ # Make sure a leading comment is not removed. -def some_func( unformatted, args ): # fmt: skip +def some_func(unformatted, args): # fmt: skip @@ -129,8 +129,8 @@ async def test_async_with(): async def test_async_for(): - async for i in some_async_iter( unformatted, args ): # fmt: skip -- print("Do something") -+ NOT_YET_IMPLEMENTED_StmtAsyncFor # fmt: skip ++ async for i in some_async_iter(unformatted, args): # fmt: skip + print("Do something") -try : # fmt: skip @@ -203,7 +203,8 @@ for i in some_iter(unformatted, args): # fmt: skip async def test_async_for(): - NOT_YET_IMPLEMENTED_StmtAsyncFor # fmt: skip + async for i in some_async_iter(unformatted, args): # fmt: skip + print("Do something") try: