Add box to stmt

This commit is contained in:
Charlie Marsh
2025-12-11 11:24:31 -05:00
parent c9155d5e72
commit cc850ec348
123 changed files with 2523 additions and 2786 deletions

View File

@@ -77,6 +77,10 @@ fields = [{ name = "body", type = "Box<Expr>" }]
add_suffix_to_is_methods = true
anynode_is_label = "statement"
doc = "See also [stmt](https://docs.python.org/3/library/ast.html#ast.stmt)"
# Box all variants except the smallest ones (Pass, Break, Continue, Return, Expr)
# to reduce enum size from 128 bytes to ~24 bytes
box_variants = true
unboxed_variants = ["StmtPass", "StmtBreak", "StmtContinue", "StmtReturn", "StmtExpr"]
[Stmt.nodes.StmtFunctionDef]
doc = """See also [FunctionDef](https://docs.python.org/3/library/ast.html#ast.FunctionDef)

View File

@@ -122,6 +122,8 @@ class Group:
add_suffix_to_is_methods: bool
anynode_is_label: str
doc: str | None
box_variants: bool
unboxed_variants: set[str]
def __init__(self, group_name: str, group: dict[str, Any]) -> None:
self.name = group_name
@@ -130,10 +132,16 @@ class Group:
self.add_suffix_to_is_methods = group.get("add_suffix_to_is_methods", False)
self.anynode_is_label = group.get("anynode_is_label", to_snake_case(group_name))
self.doc = group.get("doc")
self.box_variants = group.get("box_variants", False)
self.unboxed_variants = set(group.get("unboxed_variants", []))
self.nodes = [
Node(self, node_name, node) for node_name, node in group["nodes"].items()
]
def is_boxed(self, node_name: str) -> bool:
"""Returns True if this node should be boxed in the owned enum."""
return self.box_variants and node_name not in self.unboxed_variants
@dataclass
class Node:
@@ -321,17 +329,29 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]')
out.append(f"pub enum {group.owned_enum_ty} {{")
for node in group.nodes:
out.append(f"{node.variant}({node.ty}),")
if group.is_boxed(node.name):
out.append(f"{node.variant}(Box<{node.ty}>),")
else:
out.append(f"{node.variant}({node.ty}),")
out.append("}")
for node in group.nodes:
out.append(f"""
impl From<{node.ty}> for {group.owned_enum_ty} {{
fn from(node: {node.ty}) -> Self {{
Self::{node.variant}(node)
if group.is_boxed(node.name):
out.append(f"""
impl From<{node.ty}> for {group.owned_enum_ty} {{
fn from(node: {node.ty}) -> Self {{
Self::{node.variant}(Box::new(node))
}}
}}
}}
""")
""")
else:
out.append(f"""
impl From<{node.ty}> for {group.owned_enum_ty} {{
fn from(node: {node.ty}) -> Self {{
Self::{node.variant}(node)
}}
}}
""")
out.append(f"""
impl ruff_text_size::Ranged for {group.owned_enum_ty} {{
@@ -369,6 +389,9 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
match_arm = f"Self::{variant_name}"
if group.add_suffix_to_is_methods:
is_name = to_snake_case(node.variant + group.name)
is_boxed = group.is_boxed(node.name)
# For boxed variants, we need to dereference the box
unbox = "*" if is_boxed else ""
if len(group.nodes) > 1:
out.append(f"""
#[inline]
@@ -379,7 +402,7 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
#[inline]
pub fn {is_name}(self) -> Option<{node.ty}> {{
match self {{
{match_arm}(val) => Some(val),
{match_arm}(val) => Some({unbox}val),
_ => None,
}}
}}
@@ -387,7 +410,7 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
#[inline]
pub fn expect_{is_name}(self) -> {node.ty} {{
match self {{
{match_arm}(val) => val,
{match_arm}(val) => {unbox}val,
_ => panic!("called expect on {{self:?}}"),
}}
}}
@@ -418,14 +441,14 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
#[inline]
pub fn {is_name}(self) -> Option<{node.ty}> {{
match self {{
{match_arm}(val) => Some(val),
{match_arm}(val) => Some({unbox}val),
}}
}}
#[inline]
pub fn expect_{is_name}(self) -> {node.ty} {{
match self {{
{match_arm}(val) => val,
{match_arm}(val) => {unbox}val,
}}
}}

View File

@@ -1594,233 +1594,297 @@ pub enum ComparableStmt<'a> {
impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
fn from(stmt: &'a ast::Stmt) -> Self {
match stmt {
ast::Stmt::FunctionDef(ast::StmtFunctionDef {
is_async,
name,
parameters,
body,
decorator_list,
returns,
type_params,
range: _,
node_index: _,
}) => Self::FunctionDef(StmtFunctionDef {
is_async: *is_async,
name: name.as_str(),
parameters: parameters.into(),
body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(),
returns: returns.as_ref().map(Into::into),
type_params: type_params.as_ref().map(Into::into),
}),
ast::Stmt::ClassDef(ast::StmtClassDef {
name,
arguments,
body,
decorator_list,
type_params,
range: _,
node_index: _,
}) => Self::ClassDef(StmtClassDef {
name: name.as_str(),
arguments: arguments.as_ref().map(Into::into).unwrap_or_default(),
body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(),
type_params: type_params.as_ref().map(Into::into),
}),
ast::Stmt::FunctionDef(node) => {
let ast::StmtFunctionDef {
is_async,
name,
parameters,
body,
decorator_list,
returns,
type_params,
range: _,
node_index: _,
} = &**node;
Self::FunctionDef(StmtFunctionDef {
is_async: *is_async,
name: name.as_str(),
parameters: parameters.into(),
body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(),
returns: returns.as_ref().map(Into::into),
type_params: type_params.as_ref().map(Into::into),
})
}
ast::Stmt::ClassDef(node) => {
let ast::StmtClassDef {
name,
arguments,
body,
decorator_list,
type_params,
range: _,
node_index: _,
} = &**node;
Self::ClassDef(StmtClassDef {
name: name.as_str(),
arguments: arguments.as_ref().map(Into::into).unwrap_or_default(),
body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(),
type_params: type_params.as_ref().map(Into::into),
})
}
ast::Stmt::Return(ast::StmtReturn {
value,
range: _,
node_index: _,
}) => Self::Return(StmtReturn {
value: value.as_ref().map(Into::into),
}),
ast::Stmt::Delete(ast::StmtDelete {
targets,
range: _,
node_index: _,
}) => Self::Delete(StmtDelete {
targets: targets.iter().map(Into::into).collect(),
}),
ast::Stmt::TypeAlias(ast::StmtTypeAlias {
range: _,
node_index: _,
name,
type_params,
value,
}) => Self::TypeAlias(StmtTypeAlias {
name: name.into(),
type_params: type_params.as_ref().map(Into::into),
value: value.into(),
}),
ast::Stmt::Assign(ast::StmtAssign {
targets,
value,
range: _,
node_index: _,
}) => Self::Assign(StmtAssign {
targets: targets.iter().map(Into::into).collect(),
value: value.into(),
}),
ast::Stmt::AugAssign(ast::StmtAugAssign {
target,
op,
value,
range: _,
node_index: _,
}) => Self::AugAssign(StmtAugAssign {
target: target.into(),
op: (*op).into(),
value: value.into(),
}),
ast::Stmt::AnnAssign(ast::StmtAnnAssign {
target,
annotation,
value,
simple,
range: _,
node_index: _,
}) => Self::AnnAssign(StmtAnnAssign {
target: target.into(),
annotation: annotation.into(),
value: value.as_ref().map(Into::into),
simple: *simple,
}),
ast::Stmt::For(ast::StmtFor {
is_async,
target,
iter,
body,
orelse,
range: _,
node_index: _,
}) => Self::For(StmtFor {
is_async: *is_async,
target: target.into(),
iter: iter.into(),
body: body.iter().map(Into::into).collect(),
orelse: orelse.iter().map(Into::into).collect(),
}),
ast::Stmt::While(ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
}) => Self::While(StmtWhile {
test: test.into(),
body: body.iter().map(Into::into).collect(),
orelse: orelse.iter().map(Into::into).collect(),
}),
ast::Stmt::If(ast::StmtIf {
test,
body,
elif_else_clauses,
range: _,
node_index: _,
}) => Self::If(StmtIf {
test: test.into(),
body: body.iter().map(Into::into).collect(),
elif_else_clauses: elif_else_clauses.iter().map(Into::into).collect(),
}),
ast::Stmt::With(ast::StmtWith {
is_async,
items,
body,
range: _,
node_index: _,
}) => Self::With(StmtWith {
is_async: *is_async,
items: items.iter().map(Into::into).collect(),
body: body.iter().map(Into::into).collect(),
}),
ast::Stmt::Match(ast::StmtMatch {
subject,
cases,
range: _,
node_index: _,
}) => Self::Match(StmtMatch {
subject: subject.into(),
cases: cases.iter().map(Into::into).collect(),
}),
ast::Stmt::Raise(ast::StmtRaise {
exc,
cause,
range: _,
node_index: _,
}) => Self::Raise(StmtRaise {
exc: exc.as_ref().map(Into::into),
cause: cause.as_ref().map(Into::into),
}),
ast::Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
is_star,
range: _,
node_index: _,
}) => Self::Try(StmtTry {
body: body.iter().map(Into::into).collect(),
handlers: handlers.iter().map(Into::into).collect(),
orelse: orelse.iter().map(Into::into).collect(),
finalbody: finalbody.iter().map(Into::into).collect(),
is_star: *is_star,
}),
ast::Stmt::Assert(ast::StmtAssert {
test,
msg,
range: _,
node_index: _,
}) => Self::Assert(StmtAssert {
test: test.into(),
msg: msg.as_ref().map(Into::into),
}),
ast::Stmt::Import(ast::StmtImport {
names,
range: _,
node_index: _,
}) => Self::Import(StmtImport {
names: names.iter().map(Into::into).collect(),
}),
ast::Stmt::ImportFrom(ast::StmtImportFrom {
module,
names,
level,
range: _,
node_index: _,
}) => Self::ImportFrom(StmtImportFrom {
module: module.as_deref(),
names: names.iter().map(Into::into).collect(),
level: *level,
}),
ast::Stmt::Global(ast::StmtGlobal {
names,
range: _,
node_index: _,
}) => Self::Global(StmtGlobal {
names: names.iter().map(ast::Identifier::as_str).collect(),
}),
ast::Stmt::Nonlocal(ast::StmtNonlocal {
names,
range: _,
node_index: _,
}) => Self::Nonlocal(StmtNonlocal {
names: names.iter().map(ast::Identifier::as_str).collect(),
}),
ast::Stmt::IpyEscapeCommand(ast::StmtIpyEscapeCommand {
kind,
value,
range: _,
node_index: _,
}) => Self::IpyEscapeCommand(StmtIpyEscapeCommand { kind: *kind, value }),
}) => {
Self::Return(StmtReturn {
value: value.as_ref().map(Into::into),
})
}
ast::Stmt::Delete(node) => {
let ast::StmtDelete {
targets,
range: _,
node_index: _,
} = &**node;
Self::Delete(StmtDelete {
targets: targets.iter().map(Into::into).collect(),
})
}
ast::Stmt::TypeAlias(node) => {
let ast::StmtTypeAlias {
range: _,
node_index: _,
name,
type_params,
value,
} = &**node;
Self::TypeAlias(StmtTypeAlias {
name: name.into(),
type_params: type_params.as_ref().map(Into::into),
value: value.into(),
})
}
ast::Stmt::Assign(node) => {
let ast::StmtAssign {
targets,
value,
range: _,
node_index: _,
} = &**node;
Self::Assign(StmtAssign {
targets: targets.iter().map(Into::into).collect(),
value: value.into(),
})
}
ast::Stmt::AugAssign(node) => {
let ast::StmtAugAssign {
target,
op,
value,
range: _,
node_index: _,
} = &**node;
Self::AugAssign(StmtAugAssign {
target: target.into(),
op: (*op).into(),
value: value.into(),
})
}
ast::Stmt::AnnAssign(node) => {
let ast::StmtAnnAssign {
target,
annotation,
value,
simple,
range: _,
node_index: _,
} = &**node;
Self::AnnAssign(StmtAnnAssign {
target: target.into(),
annotation: annotation.into(),
value: value.as_ref().map(Into::into),
simple: *simple,
})
}
ast::Stmt::For(node) => {
let ast::StmtFor {
is_async,
target,
iter,
body,
orelse,
range: _,
node_index: _,
} = &**node;
Self::For(StmtFor {
is_async: *is_async,
target: target.into(),
iter: iter.into(),
body: body.iter().map(Into::into).collect(),
orelse: orelse.iter().map(Into::into).collect(),
})
}
ast::Stmt::While(node) => {
let ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
} = &**node;
Self::While(StmtWhile {
test: test.into(),
body: body.iter().map(Into::into).collect(),
orelse: orelse.iter().map(Into::into).collect(),
})
}
ast::Stmt::If(node) => {
let ast::StmtIf {
test,
body,
elif_else_clauses,
range: _,
node_index: _,
} = &**node;
Self::If(StmtIf {
test: test.into(),
body: body.iter().map(Into::into).collect(),
elif_else_clauses: elif_else_clauses.iter().map(Into::into).collect(),
})
}
ast::Stmt::With(node) => {
let ast::StmtWith {
is_async,
items,
body,
range: _,
node_index: _,
} = &**node;
Self::With(StmtWith {
is_async: *is_async,
items: items.iter().map(Into::into).collect(),
body: body.iter().map(Into::into).collect(),
})
}
ast::Stmt::Match(node) => {
let ast::StmtMatch {
subject,
cases,
range: _,
node_index: _,
} = &**node;
Self::Match(StmtMatch {
subject: subject.into(),
cases: cases.iter().map(Into::into).collect(),
})
}
ast::Stmt::Raise(node) => {
let ast::StmtRaise {
exc,
cause,
range: _,
node_index: _,
} = &**node;
Self::Raise(StmtRaise {
exc: exc.as_ref().map(Into::into),
cause: cause.as_ref().map(Into::into),
})
}
ast::Stmt::Try(node) => {
let ast::StmtTry {
body,
handlers,
orelse,
finalbody,
is_star,
range: _,
node_index: _,
} = &**node;
Self::Try(StmtTry {
body: body.iter().map(Into::into).collect(),
handlers: handlers.iter().map(Into::into).collect(),
orelse: orelse.iter().map(Into::into).collect(),
finalbody: finalbody.iter().map(Into::into).collect(),
is_star: *is_star,
})
}
ast::Stmt::Assert(node) => {
let ast::StmtAssert {
test,
msg,
range: _,
node_index: _,
} = &**node;
Self::Assert(StmtAssert {
test: test.into(),
msg: msg.as_ref().map(Into::into),
})
}
ast::Stmt::Import(node) => {
let ast::StmtImport {
names,
range: _,
node_index: _,
} = &**node;
Self::Import(StmtImport {
names: names.iter().map(Into::into).collect(),
})
}
ast::Stmt::ImportFrom(node) => {
let ast::StmtImportFrom {
module,
names,
level,
range: _,
node_index: _,
} = &**node;
Self::ImportFrom(StmtImportFrom {
module: module.as_deref(),
names: names.iter().map(Into::into).collect(),
level: *level,
})
}
ast::Stmt::Global(node) => {
let ast::StmtGlobal {
names,
range: _,
node_index: _,
} = &**node;
Self::Global(StmtGlobal {
names: names.iter().map(ast::Identifier::as_str).collect(),
})
}
ast::Stmt::Nonlocal(node) => {
let ast::StmtNonlocal {
names,
range: _,
node_index: _,
} = &**node;
Self::Nonlocal(StmtNonlocal {
names: names.iter().map(ast::Identifier::as_str).collect(),
})
}
ast::Stmt::IpyEscapeCommand(node) => {
let ast::StmtIpyEscapeCommand {
kind,
value,
range: _,
node_index: _,
} = &**node;
Self::IpyEscapeCommand(StmtIpyEscapeCommand { kind: *kind, value })
}
ast::Stmt::Expr(ast::StmtExpr {
value,
range: _,
node_index: _,
}) => Self::Expr(StmtExpr {
value: value.into(),
}),
}) => {
Self::Expr(StmtExpr {
value: value.into(),
})
}
ast::Stmt::Pass(_) => Self::Pass,
ast::Stmt::Break(_) => Self::Break,
ast::Stmt::Continue(_) => Self::Continue,

View File

@@ -123,42 +123,42 @@ impl Mod {
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]
pub enum Stmt {
FunctionDef(crate::StmtFunctionDef),
ClassDef(crate::StmtClassDef),
FunctionDef(Box<crate::StmtFunctionDef>),
ClassDef(Box<crate::StmtClassDef>),
Return(crate::StmtReturn),
Delete(crate::StmtDelete),
TypeAlias(crate::StmtTypeAlias),
Assign(crate::StmtAssign),
AugAssign(crate::StmtAugAssign),
AnnAssign(crate::StmtAnnAssign),
For(crate::StmtFor),
While(crate::StmtWhile),
If(crate::StmtIf),
With(crate::StmtWith),
Match(crate::StmtMatch),
Raise(crate::StmtRaise),
Try(crate::StmtTry),
Assert(crate::StmtAssert),
Import(crate::StmtImport),
ImportFrom(crate::StmtImportFrom),
Global(crate::StmtGlobal),
Nonlocal(crate::StmtNonlocal),
Delete(Box<crate::StmtDelete>),
TypeAlias(Box<crate::StmtTypeAlias>),
Assign(Box<crate::StmtAssign>),
AugAssign(Box<crate::StmtAugAssign>),
AnnAssign(Box<crate::StmtAnnAssign>),
For(Box<crate::StmtFor>),
While(Box<crate::StmtWhile>),
If(Box<crate::StmtIf>),
With(Box<crate::StmtWith>),
Match(Box<crate::StmtMatch>),
Raise(Box<crate::StmtRaise>),
Try(Box<crate::StmtTry>),
Assert(Box<crate::StmtAssert>),
Import(Box<crate::StmtImport>),
ImportFrom(Box<crate::StmtImportFrom>),
Global(Box<crate::StmtGlobal>),
Nonlocal(Box<crate::StmtNonlocal>),
Expr(crate::StmtExpr),
Pass(crate::StmtPass),
Break(crate::StmtBreak),
Continue(crate::StmtContinue),
IpyEscapeCommand(crate::StmtIpyEscapeCommand),
IpyEscapeCommand(Box<crate::StmtIpyEscapeCommand>),
}
impl From<crate::StmtFunctionDef> for Stmt {
fn from(node: crate::StmtFunctionDef) -> Self {
Self::FunctionDef(node)
Self::FunctionDef(Box::new(node))
}
}
impl From<crate::StmtClassDef> for Stmt {
fn from(node: crate::StmtClassDef) -> Self {
Self::ClassDef(node)
Self::ClassDef(Box::new(node))
}
}
@@ -170,103 +170,103 @@ impl From<crate::StmtReturn> for Stmt {
impl From<crate::StmtDelete> for Stmt {
fn from(node: crate::StmtDelete) -> Self {
Self::Delete(node)
Self::Delete(Box::new(node))
}
}
impl From<crate::StmtTypeAlias> for Stmt {
fn from(node: crate::StmtTypeAlias) -> Self {
Self::TypeAlias(node)
Self::TypeAlias(Box::new(node))
}
}
impl From<crate::StmtAssign> for Stmt {
fn from(node: crate::StmtAssign) -> Self {
Self::Assign(node)
Self::Assign(Box::new(node))
}
}
impl From<crate::StmtAugAssign> for Stmt {
fn from(node: crate::StmtAugAssign) -> Self {
Self::AugAssign(node)
Self::AugAssign(Box::new(node))
}
}
impl From<crate::StmtAnnAssign> for Stmt {
fn from(node: crate::StmtAnnAssign) -> Self {
Self::AnnAssign(node)
Self::AnnAssign(Box::new(node))
}
}
impl From<crate::StmtFor> for Stmt {
fn from(node: crate::StmtFor) -> Self {
Self::For(node)
Self::For(Box::new(node))
}
}
impl From<crate::StmtWhile> for Stmt {
fn from(node: crate::StmtWhile) -> Self {
Self::While(node)
Self::While(Box::new(node))
}
}
impl From<crate::StmtIf> for Stmt {
fn from(node: crate::StmtIf) -> Self {
Self::If(node)
Self::If(Box::new(node))
}
}
impl From<crate::StmtWith> for Stmt {
fn from(node: crate::StmtWith) -> Self {
Self::With(node)
Self::With(Box::new(node))
}
}
impl From<crate::StmtMatch> for Stmt {
fn from(node: crate::StmtMatch) -> Self {
Self::Match(node)
Self::Match(Box::new(node))
}
}
impl From<crate::StmtRaise> for Stmt {
fn from(node: crate::StmtRaise) -> Self {
Self::Raise(node)
Self::Raise(Box::new(node))
}
}
impl From<crate::StmtTry> for Stmt {
fn from(node: crate::StmtTry) -> Self {
Self::Try(node)
Self::Try(Box::new(node))
}
}
impl From<crate::StmtAssert> for Stmt {
fn from(node: crate::StmtAssert) -> Self {
Self::Assert(node)
Self::Assert(Box::new(node))
}
}
impl From<crate::StmtImport> for Stmt {
fn from(node: crate::StmtImport) -> Self {
Self::Import(node)
Self::Import(Box::new(node))
}
}
impl From<crate::StmtImportFrom> for Stmt {
fn from(node: crate::StmtImportFrom) -> Self {
Self::ImportFrom(node)
Self::ImportFrom(Box::new(node))
}
}
impl From<crate::StmtGlobal> for Stmt {
fn from(node: crate::StmtGlobal) -> Self {
Self::Global(node)
Self::Global(Box::new(node))
}
}
impl From<crate::StmtNonlocal> for Stmt {
fn from(node: crate::StmtNonlocal) -> Self {
Self::Nonlocal(node)
Self::Nonlocal(Box::new(node))
}
}
@@ -296,7 +296,7 @@ impl From<crate::StmtContinue> for Stmt {
impl From<crate::StmtIpyEscapeCommand> for Stmt {
fn from(node: crate::StmtIpyEscapeCommand) -> Self {
Self::IpyEscapeCommand(node)
Self::IpyEscapeCommand(Box::new(node))
}
}
@@ -374,7 +374,7 @@ impl Stmt {
#[inline]
pub fn function_def_stmt(self) -> Option<crate::StmtFunctionDef> {
match self {
Self::FunctionDef(val) => Some(val),
Self::FunctionDef(val) => Some(*val),
_ => None,
}
}
@@ -382,7 +382,7 @@ impl Stmt {
#[inline]
pub fn expect_function_def_stmt(self) -> crate::StmtFunctionDef {
match self {
Self::FunctionDef(val) => val,
Self::FunctionDef(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -411,7 +411,7 @@ impl Stmt {
#[inline]
pub fn class_def_stmt(self) -> Option<crate::StmtClassDef> {
match self {
Self::ClassDef(val) => Some(val),
Self::ClassDef(val) => Some(*val),
_ => None,
}
}
@@ -419,7 +419,7 @@ impl Stmt {
#[inline]
pub fn expect_class_def_stmt(self) -> crate::StmtClassDef {
match self {
Self::ClassDef(val) => val,
Self::ClassDef(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -485,7 +485,7 @@ impl Stmt {
#[inline]
pub fn delete_stmt(self) -> Option<crate::StmtDelete> {
match self {
Self::Delete(val) => Some(val),
Self::Delete(val) => Some(*val),
_ => None,
}
}
@@ -493,7 +493,7 @@ impl Stmt {
#[inline]
pub fn expect_delete_stmt(self) -> crate::StmtDelete {
match self {
Self::Delete(val) => val,
Self::Delete(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -522,7 +522,7 @@ impl Stmt {
#[inline]
pub fn type_alias_stmt(self) -> Option<crate::StmtTypeAlias> {
match self {
Self::TypeAlias(val) => Some(val),
Self::TypeAlias(val) => Some(*val),
_ => None,
}
}
@@ -530,7 +530,7 @@ impl Stmt {
#[inline]
pub fn expect_type_alias_stmt(self) -> crate::StmtTypeAlias {
match self {
Self::TypeAlias(val) => val,
Self::TypeAlias(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -559,7 +559,7 @@ impl Stmt {
#[inline]
pub fn assign_stmt(self) -> Option<crate::StmtAssign> {
match self {
Self::Assign(val) => Some(val),
Self::Assign(val) => Some(*val),
_ => None,
}
}
@@ -567,7 +567,7 @@ impl Stmt {
#[inline]
pub fn expect_assign_stmt(self) -> crate::StmtAssign {
match self {
Self::Assign(val) => val,
Self::Assign(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -596,7 +596,7 @@ impl Stmt {
#[inline]
pub fn aug_assign_stmt(self) -> Option<crate::StmtAugAssign> {
match self {
Self::AugAssign(val) => Some(val),
Self::AugAssign(val) => Some(*val),
_ => None,
}
}
@@ -604,7 +604,7 @@ impl Stmt {
#[inline]
pub fn expect_aug_assign_stmt(self) -> crate::StmtAugAssign {
match self {
Self::AugAssign(val) => val,
Self::AugAssign(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -633,7 +633,7 @@ impl Stmt {
#[inline]
pub fn ann_assign_stmt(self) -> Option<crate::StmtAnnAssign> {
match self {
Self::AnnAssign(val) => Some(val),
Self::AnnAssign(val) => Some(*val),
_ => None,
}
}
@@ -641,7 +641,7 @@ impl Stmt {
#[inline]
pub fn expect_ann_assign_stmt(self) -> crate::StmtAnnAssign {
match self {
Self::AnnAssign(val) => val,
Self::AnnAssign(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -670,7 +670,7 @@ impl Stmt {
#[inline]
pub fn for_stmt(self) -> Option<crate::StmtFor> {
match self {
Self::For(val) => Some(val),
Self::For(val) => Some(*val),
_ => None,
}
}
@@ -678,7 +678,7 @@ impl Stmt {
#[inline]
pub fn expect_for_stmt(self) -> crate::StmtFor {
match self {
Self::For(val) => val,
Self::For(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -707,7 +707,7 @@ impl Stmt {
#[inline]
pub fn while_stmt(self) -> Option<crate::StmtWhile> {
match self {
Self::While(val) => Some(val),
Self::While(val) => Some(*val),
_ => None,
}
}
@@ -715,7 +715,7 @@ impl Stmt {
#[inline]
pub fn expect_while_stmt(self) -> crate::StmtWhile {
match self {
Self::While(val) => val,
Self::While(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -744,7 +744,7 @@ impl Stmt {
#[inline]
pub fn if_stmt(self) -> Option<crate::StmtIf> {
match self {
Self::If(val) => Some(val),
Self::If(val) => Some(*val),
_ => None,
}
}
@@ -752,7 +752,7 @@ impl Stmt {
#[inline]
pub fn expect_if_stmt(self) -> crate::StmtIf {
match self {
Self::If(val) => val,
Self::If(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -781,7 +781,7 @@ impl Stmt {
#[inline]
pub fn with_stmt(self) -> Option<crate::StmtWith> {
match self {
Self::With(val) => Some(val),
Self::With(val) => Some(*val),
_ => None,
}
}
@@ -789,7 +789,7 @@ impl Stmt {
#[inline]
pub fn expect_with_stmt(self) -> crate::StmtWith {
match self {
Self::With(val) => val,
Self::With(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -818,7 +818,7 @@ impl Stmt {
#[inline]
pub fn match_stmt(self) -> Option<crate::StmtMatch> {
match self {
Self::Match(val) => Some(val),
Self::Match(val) => Some(*val),
_ => None,
}
}
@@ -826,7 +826,7 @@ impl Stmt {
#[inline]
pub fn expect_match_stmt(self) -> crate::StmtMatch {
match self {
Self::Match(val) => val,
Self::Match(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -855,7 +855,7 @@ impl Stmt {
#[inline]
pub fn raise_stmt(self) -> Option<crate::StmtRaise> {
match self {
Self::Raise(val) => Some(val),
Self::Raise(val) => Some(*val),
_ => None,
}
}
@@ -863,7 +863,7 @@ impl Stmt {
#[inline]
pub fn expect_raise_stmt(self) -> crate::StmtRaise {
match self {
Self::Raise(val) => val,
Self::Raise(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -892,7 +892,7 @@ impl Stmt {
#[inline]
pub fn try_stmt(self) -> Option<crate::StmtTry> {
match self {
Self::Try(val) => Some(val),
Self::Try(val) => Some(*val),
_ => None,
}
}
@@ -900,7 +900,7 @@ impl Stmt {
#[inline]
pub fn expect_try_stmt(self) -> crate::StmtTry {
match self {
Self::Try(val) => val,
Self::Try(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -929,7 +929,7 @@ impl Stmt {
#[inline]
pub fn assert_stmt(self) -> Option<crate::StmtAssert> {
match self {
Self::Assert(val) => Some(val),
Self::Assert(val) => Some(*val),
_ => None,
}
}
@@ -937,7 +937,7 @@ impl Stmt {
#[inline]
pub fn expect_assert_stmt(self) -> crate::StmtAssert {
match self {
Self::Assert(val) => val,
Self::Assert(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -966,7 +966,7 @@ impl Stmt {
#[inline]
pub fn import_stmt(self) -> Option<crate::StmtImport> {
match self {
Self::Import(val) => Some(val),
Self::Import(val) => Some(*val),
_ => None,
}
}
@@ -974,7 +974,7 @@ impl Stmt {
#[inline]
pub fn expect_import_stmt(self) -> crate::StmtImport {
match self {
Self::Import(val) => val,
Self::Import(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -1003,7 +1003,7 @@ impl Stmt {
#[inline]
pub fn import_from_stmt(self) -> Option<crate::StmtImportFrom> {
match self {
Self::ImportFrom(val) => Some(val),
Self::ImportFrom(val) => Some(*val),
_ => None,
}
}
@@ -1011,7 +1011,7 @@ impl Stmt {
#[inline]
pub fn expect_import_from_stmt(self) -> crate::StmtImportFrom {
match self {
Self::ImportFrom(val) => val,
Self::ImportFrom(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -1040,7 +1040,7 @@ impl Stmt {
#[inline]
pub fn global_stmt(self) -> Option<crate::StmtGlobal> {
match self {
Self::Global(val) => Some(val),
Self::Global(val) => Some(*val),
_ => None,
}
}
@@ -1048,7 +1048,7 @@ impl Stmt {
#[inline]
pub fn expect_global_stmt(self) -> crate::StmtGlobal {
match self {
Self::Global(val) => val,
Self::Global(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -1077,7 +1077,7 @@ impl Stmt {
#[inline]
pub fn nonlocal_stmt(self) -> Option<crate::StmtNonlocal> {
match self {
Self::Nonlocal(val) => Some(val),
Self::Nonlocal(val) => Some(*val),
_ => None,
}
}
@@ -1085,7 +1085,7 @@ impl Stmt {
#[inline]
pub fn expect_nonlocal_stmt(self) -> crate::StmtNonlocal {
match self {
Self::Nonlocal(val) => val,
Self::Nonlocal(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}
@@ -1262,7 +1262,7 @@ impl Stmt {
#[inline]
pub fn ipy_escape_command_stmt(self) -> Option<crate::StmtIpyEscapeCommand> {
match self {
Self::IpyEscapeCommand(val) => Some(val),
Self::IpyEscapeCommand(val) => Some(*val),
_ => None,
}
}
@@ -1270,7 +1270,7 @@ impl Stmt {
#[inline]
pub fn expect_ipy_escape_command_stmt(self) -> crate::StmtIpyEscapeCommand {
match self {
Self::IpyEscapeCommand(val) => val,
Self::IpyEscapeCommand(val) => *val,
_ => panic!("called expect on {self:?}"),
}
}

View File

@@ -392,43 +392,32 @@ pub fn any_over_interpolated_string_element(
pub fn any_over_stmt(stmt: &Stmt, func: &dyn Fn(&Expr) -> bool) -> bool {
match stmt {
Stmt::FunctionDef(ast::StmtFunctionDef {
parameters,
type_params,
body,
decorator_list,
returns,
..
}) => {
parameters.iter().any(|param| {
Stmt::FunctionDef(node) => {
node.parameters.iter().any(|param| {
param
.default()
.is_some_and(|default| any_over_expr(default, func))
|| param
.annotation()
.is_some_and(|annotation| any_over_expr(annotation, func))
}) || type_params.as_ref().is_some_and(|type_params| {
}) || node.type_params.as_ref().is_some_and(|type_params| {
type_params
.iter()
.any(|type_param| any_over_type_param(type_param, func))
}) || body.iter().any(|stmt| any_over_stmt(stmt, func))
|| decorator_list
}) || node.body.iter().any(|stmt| any_over_stmt(stmt, func))
|| node
.decorator_list
.iter()
.any(|decorator| any_over_expr(&decorator.expression, func))
|| returns
|| node
.returns
.as_ref()
.is_some_and(|value| any_over_expr(value, func))
}
Stmt::ClassDef(ast::StmtClassDef {
arguments,
type_params,
body,
decorator_list,
..
}) => {
Stmt::ClassDef(node) => {
// Note that e.g. `class A(*args, a=2, *args2, **kwargs): pass` is a valid class
// definition
arguments
node.arguments
.as_deref()
.is_some_and(|Arguments { args, keywords, .. }| {
args.iter().any(|expr| any_over_expr(expr, func))
@@ -436,89 +425,61 @@ pub fn any_over_stmt(stmt: &Stmt, func: &dyn Fn(&Expr) -> bool) -> bool {
.iter()
.any(|keyword| any_over_expr(&keyword.value, func))
})
|| type_params.as_ref().is_some_and(|type_params| {
|| node.type_params.as_ref().is_some_and(|type_params| {
type_params
.iter()
.any(|type_param| any_over_type_param(type_param, func))
})
|| body.iter().any(|stmt| any_over_stmt(stmt, func))
|| decorator_list
|| node.body.iter().any(|stmt| any_over_stmt(stmt, func))
|| node
.decorator_list
.iter()
.any(|decorator| any_over_expr(&decorator.expression, func))
}
Stmt::Return(ast::StmtReturn {
value,
range: _,
node_index: _,
}) => value
Stmt::Return(node) => node
.value
.as_ref()
.is_some_and(|value| any_over_expr(value, func)),
Stmt::Delete(ast::StmtDelete {
targets,
range: _,
node_index: _,
}) => targets.iter().any(|expr| any_over_expr(expr, func)),
Stmt::TypeAlias(ast::StmtTypeAlias {
name,
type_params,
value,
..
}) => {
any_over_expr(name, func)
|| type_params.as_ref().is_some_and(|type_params| {
Stmt::Delete(node) => node.targets.iter().any(|expr| any_over_expr(expr, func)),
Stmt::TypeAlias(node) => {
any_over_expr(&node.name, func)
|| node.type_params.as_ref().is_some_and(|type_params| {
type_params
.iter()
.any(|type_param| any_over_type_param(type_param, func))
})
|| any_over_expr(value, func)
|| any_over_expr(&node.value, func)
}
Stmt::Assign(ast::StmtAssign { targets, value, .. }) => {
targets.iter().any(|expr| any_over_expr(expr, func)) || any_over_expr(value, func)
Stmt::Assign(node) => {
node.targets.iter().any(|expr| any_over_expr(expr, func))
|| any_over_expr(&node.value, func)
}
Stmt::AugAssign(ast::StmtAugAssign { target, value, .. }) => {
any_over_expr(target, func) || any_over_expr(value, func)
Stmt::AugAssign(node) => {
any_over_expr(&node.target, func) || any_over_expr(&node.value, func)
}
Stmt::AnnAssign(ast::StmtAnnAssign {
target,
annotation,
value,
..
}) => {
any_over_expr(target, func)
|| any_over_expr(annotation, func)
|| value
Stmt::AnnAssign(node) => {
any_over_expr(&node.target, func)
|| any_over_expr(&node.annotation, func)
|| node
.value
.as_ref()
.is_some_and(|value| any_over_expr(value, func))
}
Stmt::For(ast::StmtFor {
target,
iter,
body,
orelse,
..
}) => {
any_over_expr(target, func)
|| any_over_expr(iter, func)
|| any_over_body(body, func)
|| any_over_body(orelse, func)
Stmt::For(node) => {
any_over_expr(&node.target, func)
|| any_over_expr(&node.iter, func)
|| any_over_body(&node.body, func)
|| any_over_body(&node.orelse, func)
}
Stmt::While(ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
}) => any_over_expr(test, func) || any_over_body(body, func) || any_over_body(orelse, func),
Stmt::If(ast::StmtIf {
test,
body,
elif_else_clauses,
range: _,
node_index: _,
}) => {
any_over_expr(test, func)
|| any_over_body(body, func)
|| elif_else_clauses.iter().any(|clause| {
Stmt::While(node) => {
any_over_expr(&node.test, func)
|| any_over_body(&node.body, func)
|| any_over_body(&node.orelse, func)
}
Stmt::If(node) => {
any_over_expr(&node.test, func)
|| any_over_body(&node.body, func)
|| node.elif_else_clauses.iter().any(|clause| {
clause
.test
.as_ref()
@@ -526,37 +487,27 @@ pub fn any_over_stmt(stmt: &Stmt, func: &dyn Fn(&Expr) -> bool) -> bool {
|| any_over_body(&clause.body, func)
})
}
Stmt::With(ast::StmtWith { items, body, .. }) => {
items.iter().any(|with_item| {
Stmt::With(node) => {
node.items.iter().any(|with_item| {
any_over_expr(&with_item.context_expr, func)
|| with_item
.optional_vars
.as_ref()
.is_some_and(|expr| any_over_expr(expr, func))
}) || any_over_body(body, func)
}) || any_over_body(&node.body, func)
}
Stmt::Raise(ast::StmtRaise {
exc,
cause,
range: _,
node_index: _,
}) => {
exc.as_ref().is_some_and(|value| any_over_expr(value, func))
|| cause
Stmt::Raise(node) => {
node.exc
.as_ref()
.is_some_and(|value| any_over_expr(value, func))
|| node
.cause
.as_ref()
.is_some_and(|value| any_over_expr(value, func))
}
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
is_star: _,
range: _,
node_index: _,
}) => {
any_over_body(body, func)
|| handlers.iter().any(|handler| {
Stmt::Try(node) => {
any_over_body(&node.body, func)
|| node.handlers.iter().any(|handler| {
let ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler {
type_,
body,
@@ -565,26 +516,19 @@ pub fn any_over_stmt(stmt: &Stmt, func: &dyn Fn(&Expr) -> bool) -> bool {
type_.as_ref().is_some_and(|expr| any_over_expr(expr, func))
|| any_over_body(body, func)
})
|| any_over_body(orelse, func)
|| any_over_body(finalbody, func)
|| any_over_body(&node.orelse, func)
|| any_over_body(&node.finalbody, func)
}
Stmt::Assert(ast::StmtAssert {
test,
msg,
range: _,
node_index: _,
}) => {
any_over_expr(test, func)
|| msg.as_ref().is_some_and(|value| any_over_expr(value, func))
Stmt::Assert(node) => {
any_over_expr(&node.test, func)
|| node
.msg
.as_ref()
.is_some_and(|value| any_over_expr(value, func))
}
Stmt::Match(ast::StmtMatch {
subject,
cases,
range: _,
node_index: _,
}) => {
any_over_expr(subject, func)
|| cases.iter().any(|case| {
Stmt::Match(node) => {
any_over_expr(&node.subject, func)
|| node.cases.iter().any(|case| {
let MatchCase {
pattern,
guard,
@@ -601,11 +545,7 @@ pub fn any_over_stmt(stmt: &Stmt, func: &dyn Fn(&Expr) -> bool) -> bool {
Stmt::ImportFrom(_) => false,
Stmt::Global(_) => false,
Stmt::Nonlocal(_) => false,
Stmt::Expr(ast::StmtExpr {
value,
range: _,
node_index: _,
}) => any_over_expr(value, func),
Stmt::Expr(node) => any_over_expr(&node.value, func),
Stmt::Pass(_) | Stmt::Break(_) | Stmt::Continue(_) => false,
Stmt::IpyEscapeCommand(_) => false,
}
@@ -631,15 +571,15 @@ pub fn is_assignment_to_a_dunder(stmt: &Stmt) -> bool {
// Check whether it's an assignment to a dunder, with or without a type
// annotation. This is what pycodestyle (as of 2.9.1) does.
match stmt {
Stmt::Assign(ast::StmtAssign { targets, .. }) => {
if let [Expr::Name(ast::ExprName { id, .. })] = targets.as_slice() {
Stmt::Assign(node) => {
if let [Expr::Name(ast::ExprName { id, .. })] = node.targets.as_slice() {
is_dunder(id)
} else {
false
}
}
Stmt::AnnAssign(ast::StmtAnnAssign { target, .. }) => {
if let Expr::Name(ast::ExprName { id, .. }) = target.as_ref() {
Stmt::AnnAssign(node) => {
if let Expr::Name(ast::ExprName { id, .. }) = node.target.as_ref() {
is_dunder(id)
} else {
false
@@ -1021,33 +961,31 @@ pub struct RaiseStatementVisitor<'a> {
impl<'a> StatementVisitor<'a> for RaiseStatementVisitor<'a> {
fn visit_stmt(&mut self, stmt: &'a Stmt) {
match stmt {
Stmt::Raise(ast::StmtRaise {
exc,
cause,
range: _,
node_index: _,
}) => {
self.raises
.push((stmt.range(), exc.as_deref(), cause.as_deref()));
Stmt::Raise(node) => {
self.raises.push((
stmt.range(),
node.exc.as_deref(),
node.cause.as_deref(),
));
}
Stmt::ClassDef(_) | Stmt::FunctionDef(_) | Stmt::Try(_) => {}
Stmt::If(ast::StmtIf {
body,
elif_else_clauses,
..
}) => {
crate::statement_visitor::walk_body(self, body);
for clause in elif_else_clauses {
Stmt::If(node) => {
crate::statement_visitor::walk_body(self, &node.body);
for clause in &node.elif_else_clauses {
self.visit_elif_else_clause(clause);
}
}
Stmt::While(ast::StmtWhile { body, .. })
| Stmt::With(ast::StmtWith { body, .. })
| Stmt::For(ast::StmtFor { body, .. }) => {
crate::statement_visitor::walk_body(self, body);
Stmt::While(node) => {
crate::statement_visitor::walk_body(self, &node.body);
}
Stmt::Match(ast::StmtMatch { cases, .. }) => {
for case in cases {
Stmt::With(node) => {
crate::statement_visitor::walk_body(self, &node.body);
}
Stmt::For(node) => {
crate::statement_visitor::walk_body(self, &node.body);
}
Stmt::Match(node) => {
for case in &node.cases {
crate::statement_visitor::walk_body(self, &case.body);
}
}
@@ -1066,10 +1004,10 @@ impl Visitor<'_> for AwaitVisitor {
fn visit_stmt(&mut self, stmt: &Stmt) {
match stmt {
Stmt::FunctionDef(_) | Stmt::ClassDef(_) => (),
Stmt::With(ast::StmtWith { is_async: true, .. }) => {
Stmt::With(node) if node.is_async => {
self.seen_await = true;
}
Stmt::For(ast::StmtFor { is_async: true, .. }) => {
Stmt::For(node) if node.is_async => {
self.seen_await = true;
}
_ => crate::visitor::walk_stmt(self, stmt),
@@ -1095,13 +1033,8 @@ impl Visitor<'_> for AwaitVisitor {
/// Return `true` if a `Stmt` is a docstring.
pub fn is_docstring_stmt(stmt: &Stmt) -> bool {
if let Stmt::Expr(ast::StmtExpr {
value,
range: _,
node_index: _,
}) = stmt
{
value.is_string_literal_expr()
if let Stmt::Expr(node) = stmt {
node.value.is_string_literal_expr()
} else {
false
}
@@ -1113,13 +1046,8 @@ pub fn on_conditional_branch<'a>(parents: &mut impl Iterator<Item = &'a Stmt>) -
if matches!(parent, Stmt::If(_) | Stmt::While(_) | Stmt::Match(_)) {
return true;
}
if let Stmt::Expr(ast::StmtExpr {
value,
range: _,
node_index: _,
}) = parent
{
if value.is_if_expr() {
if let Stmt::Expr(node) = parent {
if node.value.is_if_expr() {
return true;
}
}
@@ -1140,7 +1068,7 @@ pub fn in_nested_block<'a>(mut parents: impl Iterator<Item = &'a Stmt>) -> bool
/// Check if a node represents an unpacking assignment.
pub fn is_unpacking_assignment(parent: &Stmt, child: &Expr) -> bool {
match parent {
Stmt::With(ast::StmtWith { items, .. }) => items.iter().any(|item| {
Stmt::With(node) => node.items.iter().any(|item| {
if let Some(optional_vars) = &item.optional_vars {
if optional_vars.is_tuple_expr() {
if any_over_expr(optional_vars, &|expr| expr == child) {
@@ -1150,22 +1078,23 @@ pub fn is_unpacking_assignment(parent: &Stmt, child: &Expr) -> bool {
}
false
}),
Stmt::Assign(ast::StmtAssign { targets, value, .. }) => {
Stmt::Assign(node) => {
// In `(a, b) = (1, 2)`, `(1, 2)` is the target, and it is a tuple.
let value_is_tuple = matches!(
value.as_ref(),
node.value.as_ref(),
Expr::Set(_) | Expr::List(_) | Expr::Tuple(_)
);
// In `(a, b) = coords = (1, 2)`, `(a, b)` and `coords` are the targets, and
// `(a, b)` is a tuple. (We use "tuple" as a placeholder for any
// unpackable expression.)
let targets_are_tuples = targets
let targets_are_tuples = node
.targets
.iter()
.all(|item| matches!(item, Expr::Set(_) | Expr::List(_) | Expr::Tuple(_)));
// If we're looking at `a` in `(a, b) = coords = (1, 2)`, then we should
// identify that the current expression is in a tuple.
let child_in_tuple = targets_are_tuples
|| targets.iter().any(|item| {
|| node.targets.iter().any(|item| {
matches!(item, Expr::Set(_) | Expr::List(_) | Expr::Tuple(_))
&& any_over_expr(item, &|expr| expr == child)
});
@@ -1748,7 +1677,7 @@ mod tests {
default: Some(Box::new(constant_two.clone())),
name: Identifier::new("x", TextRange::default()),
});
let type_alias = Stmt::TypeAlias(StmtTypeAlias {
let type_alias = Stmt::TypeAlias(Box::new(StmtTypeAlias {
name: Box::new(name.clone()),
type_params: Some(Box::new(TypeParams {
type_params: vec![type_var_one, type_var_two],
@@ -1758,7 +1687,7 @@ mod tests {
value: Box::new(constant_three.clone()),
range: TextRange::default(),
node_index: AtomicNodeIndex::NONE,
});
}));
assert!(!any_over_stmt(&type_alias, &|expr| {
seen.borrow_mut().push(expr.clone());
false

View File

@@ -112,10 +112,10 @@ pub fn except(handler: &ExceptHandler, source: &str) -> TextRange {
/// Return the [`TextRange`] of the `else` token in a `For` or `While` statement.
pub fn else_(stmt: &Stmt, source: &str) -> Option<TextRange> {
let (Stmt::For(ast::StmtFor { body, orelse, .. })
| Stmt::While(ast::StmtWhile { body, orelse, .. })) = stmt
else {
return None;
let (body, orelse) = match stmt {
Stmt::For(node) => (&node.body, &node.orelse),
Stmt::While(node) => (&node.body, &node.orelse),
_ => return None,
};
if orelse.is_empty() {

View File

@@ -3626,7 +3626,9 @@ mod tests {
#[test]
#[cfg(target_pointer_width = "64")]
fn size() {
assert_eq!(std::mem::size_of::<Stmt>(), 128);
// Stmt variants are boxed to reduce the enum size from 128 to 32 bytes.
// Only Return, Expr, Pass, Break, Continue remain unboxed.
assert_eq!(std::mem::size_of::<Stmt>(), 32);
assert_eq!(std::mem::size_of::<StmtFunctionDef>(), 128);
assert_eq!(std::mem::size_of::<StmtClassDef>(), 120);
assert_eq!(std::mem::size_of::<StmtTry>(), 112);

View File

@@ -29,51 +29,41 @@ pub fn walk_body<'a, V: StatementVisitor<'a> + ?Sized>(visitor: &mut V, body: &'
pub fn walk_stmt<'a, V: StatementVisitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) {
match stmt {
Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => {
visitor.visit_body(body);
Stmt::FunctionDef(node) => {
visitor.visit_body(&node.body);
}
Stmt::For(ast::StmtFor { body, orelse, .. }) => {
visitor.visit_body(body);
visitor.visit_body(orelse);
Stmt::For(node) => {
visitor.visit_body(&node.body);
visitor.visit_body(&node.orelse);
}
Stmt::ClassDef(ast::StmtClassDef { body, .. }) => {
visitor.visit_body(body);
Stmt::ClassDef(node) => {
visitor.visit_body(&node.body);
}
Stmt::While(ast::StmtWhile { body, orelse, .. }) => {
visitor.visit_body(body);
visitor.visit_body(orelse);
Stmt::While(node) => {
visitor.visit_body(&node.body);
visitor.visit_body(&node.orelse);
}
Stmt::If(ast::StmtIf {
body,
elif_else_clauses,
..
}) => {
visitor.visit_body(body);
for clause in elif_else_clauses {
Stmt::If(node) => {
visitor.visit_body(&node.body);
for clause in &node.elif_else_clauses {
visitor.visit_elif_else_clause(clause);
}
}
Stmt::With(ast::StmtWith { body, .. }) => {
visitor.visit_body(body);
Stmt::With(node) => {
visitor.visit_body(&node.body);
}
Stmt::Match(ast::StmtMatch { cases, .. }) => {
for match_case in cases {
Stmt::Match(node) => {
for match_case in &node.cases {
visitor.visit_match_case(match_case);
}
}
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => {
visitor.visit_body(body);
for except_handler in handlers {
Stmt::Try(node) => {
visitor.visit_body(&node.body);
for except_handler in &node.handlers {
visitor.visit_except_handler(except_handler);
}
visitor.visit_body(orelse);
visitor.visit_body(finalbody);
visitor.visit_body(&node.orelse);
visitor.visit_body(&node.finalbody);
}
_ => {}
}

View File

@@ -1,41 +1,32 @@
//! Utilities for manually traversing a Python AST.
use crate::{self as ast, AnyNodeRef, ExceptHandler, Stmt};
use crate::{AnyNodeRef, ExceptHandler, Stmt};
/// Given a [`Stmt`] and its parent, return the [`ast::Suite`] that contains the [`Stmt`].
pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<EnclosingSuite<'a>> {
// TODO: refactor this to work without a parent, ie when `stmt` is at the top level
match parent {
Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => EnclosingSuite::new(body, stmt),
Stmt::ClassDef(ast::StmtClassDef { body, .. }) => EnclosingSuite::new(body, stmt),
Stmt::For(ast::StmtFor { body, orelse, .. }) => [body, orelse]
Stmt::FunctionDef(node) => EnclosingSuite::new(&node.body, stmt),
Stmt::ClassDef(node) => EnclosingSuite::new(&node.body, stmt),
Stmt::For(node) => [&node.body, &node.orelse]
.iter()
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
Stmt::While(ast::StmtWhile { body, orelse, .. }) => [body, orelse]
Stmt::While(node) => [&node.body, &node.orelse]
.iter()
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
Stmt::If(ast::StmtIf {
body,
elif_else_clauses,
..
}) => [body]
Stmt::If(node) => [&node.body]
.into_iter()
.chain(elif_else_clauses.iter().map(|clause| &clause.body))
.chain(node.elif_else_clauses.iter().map(|clause| &clause.body))
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
Stmt::With(ast::StmtWith { body, .. }) => EnclosingSuite::new(body, stmt),
Stmt::Match(ast::StmtMatch { cases, .. }) => cases
Stmt::With(node) => EnclosingSuite::new(&node.body, stmt),
Stmt::Match(node) => node
.cases
.iter()
.map(|case| &case.body)
.find_map(|body| EnclosingSuite::new(body, stmt)),
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => [body, orelse, finalbody]
Stmt::Try(node) => [&node.body, &node.orelse, &node.finalbody]
.into_iter()
.chain(
handlers
node.handlers
.iter()
.filter_map(ExceptHandler::as_except_handler)
.map(|handler| &handler.body),

View File

@@ -134,43 +134,30 @@ pub fn walk_elif_else_clause<'a, V: Visitor<'a> + ?Sized>(
pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) {
match stmt {
Stmt::FunctionDef(ast::StmtFunctionDef {
parameters,
body,
decorator_list,
returns,
type_params,
..
}) => {
for decorator in decorator_list {
Stmt::FunctionDef(node) => {
for decorator in &node.decorator_list {
visitor.visit_decorator(decorator);
}
if let Some(type_params) = type_params {
if let Some(type_params) = &node.type_params {
visitor.visit_type_params(type_params);
}
visitor.visit_parameters(parameters);
if let Some(expr) = returns {
visitor.visit_parameters(&node.parameters);
if let Some(expr) = &node.returns {
visitor.visit_annotation(expr);
}
visitor.visit_body(body);
visitor.visit_body(&node.body);
}
Stmt::ClassDef(ast::StmtClassDef {
arguments,
body,
decorator_list,
type_params,
..
}) => {
for decorator in decorator_list {
Stmt::ClassDef(node) => {
for decorator in &node.decorator_list {
visitor.visit_decorator(decorator);
}
if let Some(type_params) = type_params {
if let Some(type_params) = &node.type_params {
visitor.visit_type_params(type_params);
}
if let Some(arguments) = arguments {
if let Some(arguments) = &node.arguments {
visitor.visit_arguments(arguments);
}
visitor.visit_body(body);
visitor.visit_body(&node.body);
}
Stmt::Return(ast::StmtReturn {
value,
@@ -181,87 +168,99 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) {
visitor.visit_expr(expr);
}
}
Stmt::Delete(ast::StmtDelete {
targets,
range: _,
node_index: _,
}) => {
Stmt::Delete(node) => {
let ast::StmtDelete {
targets,
range: _,
node_index: _,
} = &**node;
for expr in targets {
visitor.visit_expr(expr);
}
}
Stmt::TypeAlias(ast::StmtTypeAlias {
range: _,
node_index: _,
name,
type_params,
value,
}) => {
Stmt::TypeAlias(node) => {
let ast::StmtTypeAlias {
range: _,
node_index: _,
name,
type_params,
value,
} = &**node;
visitor.visit_expr(value);
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
visitor.visit_expr(name);
}
Stmt::Assign(ast::StmtAssign { targets, value, .. }) => {
Stmt::Assign(node) => {
let ast::StmtAssign { targets, value, range: _, node_index: _ } = &**node;
visitor.visit_expr(value);
for expr in targets {
visitor.visit_expr(expr);
}
}
Stmt::AugAssign(ast::StmtAugAssign {
target,
op,
value,
range: _,
node_index: _,
}) => {
Stmt::AugAssign(node) => {
let ast::StmtAugAssign {
target,
op,
value,
range: _,
node_index: _,
} = &**node;
visitor.visit_expr(value);
visitor.visit_operator(op);
visitor.visit_expr(target);
}
Stmt::AnnAssign(ast::StmtAnnAssign {
target,
annotation,
value,
..
}) => {
Stmt::AnnAssign(node) => {
let ast::StmtAnnAssign {
target,
annotation,
value,
simple: _,
range: _,
node_index: _,
} = &**node;
if let Some(expr) = value {
visitor.visit_expr(expr);
}
visitor.visit_annotation(annotation);
visitor.visit_expr(target);
}
Stmt::For(ast::StmtFor {
target,
iter,
body,
orelse,
..
}) => {
Stmt::For(node) => {
let ast::StmtFor {
target,
iter,
body,
orelse,
is_async: _,
range: _,
node_index: _,
} = &**node;
visitor.visit_expr(iter);
visitor.visit_expr(target);
visitor.visit_body(body);
visitor.visit_body(orelse);
}
Stmt::While(ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
}) => {
Stmt::While(node) => {
let ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
} = &**node;
visitor.visit_expr(test);
visitor.visit_body(body);
visitor.visit_body(orelse);
}
Stmt::If(ast::StmtIf {
test,
body,
elif_else_clauses,
range: _,
node_index: _,
}) => {
Stmt::If(node) => {
let ast::StmtIf {
test,
body,
elif_else_clauses,
range: _,
node_index: _,
} = &**node;
visitor.visit_expr(test);
visitor.visit_body(body);
for clause in elif_else_clauses {
@@ -271,29 +270,32 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) {
walk_elif_else_clause(visitor, clause);
}
}
Stmt::With(ast::StmtWith { items, body, .. }) => {
Stmt::With(node) => {
let ast::StmtWith { items, body, is_async: _, range: _, node_index: _ } = &**node;
for with_item in items {
visitor.visit_with_item(with_item);
}
visitor.visit_body(body);
}
Stmt::Match(ast::StmtMatch {
subject,
cases,
range: _,
node_index: _,
}) => {
Stmt::Match(node) => {
let ast::StmtMatch {
subject,
cases,
range: _,
node_index: _,
} = &**node;
visitor.visit_expr(subject);
for match_case in cases {
visitor.visit_match_case(match_case);
}
}
Stmt::Raise(ast::StmtRaise {
exc,
cause,
range: _,
node_index: _,
}) => {
Stmt::Raise(node) => {
let ast::StmtRaise {
exc,
cause,
range: _,
node_index: _,
} = &**node;
if let Some(expr) = exc {
visitor.visit_expr(expr);
}
@@ -301,15 +303,16 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) {
visitor.visit_expr(expr);
}
}
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
is_star: _,
range: _,
node_index: _,
}) => {
Stmt::Try(node) => {
let ast::StmtTry {
body,
handlers,
orelse,
finalbody,
is_star: _,
range: _,
node_index: _,
} = &**node;
visitor.visit_body(body);
for except_handler in handlers {
visitor.visit_except_handler(except_handler);
@@ -317,27 +320,30 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) {
visitor.visit_body(orelse);
visitor.visit_body(finalbody);
}
Stmt::Assert(ast::StmtAssert {
test,
msg,
range: _,
node_index: _,
}) => {
Stmt::Assert(node) => {
let ast::StmtAssert {
test,
msg,
range: _,
node_index: _,
} = &**node;
visitor.visit_expr(test);
if let Some(expr) = msg {
visitor.visit_expr(expr);
}
}
Stmt::Import(ast::StmtImport {
names,
range: _,
node_index: _,
}) => {
Stmt::Import(node) => {
let ast::StmtImport {
names,
range: _,
node_index: _,
} = &**node;
for alias in names {
visitor.visit_alias(alias);
}
}
Stmt::ImportFrom(ast::StmtImportFrom { names, .. }) => {
Stmt::ImportFrom(node) => {
let ast::StmtImportFrom { names, module: _, level: _, range: _, node_index: _ } = &**node;
for alias in names {
visitor.visit_alias(alias);
}
@@ -348,7 +354,9 @@ pub fn walk_stmt<'a, V: Visitor<'a> + ?Sized>(visitor: &mut V, stmt: &'a Stmt) {
value,
range: _,
node_index: _,
}) => visitor.visit_expr(value),
}) => {
visitor.visit_expr(value);
}
Stmt::Pass(_) | Stmt::Break(_) | Stmt::Continue(_) | Stmt::IpyEscapeCommand(_) => {}
}
}

View File

@@ -121,43 +121,30 @@ pub fn walk_elif_else_clause<V: Transformer + ?Sized>(
pub fn walk_stmt<V: Transformer + ?Sized>(visitor: &V, stmt: &mut Stmt) {
match stmt {
Stmt::FunctionDef(ast::StmtFunctionDef {
parameters,
body,
decorator_list,
returns,
type_params,
..
}) => {
for decorator in decorator_list {
Stmt::FunctionDef(node) => {
for decorator in &mut node.decorator_list {
visitor.visit_decorator(decorator);
}
if let Some(type_params) = type_params {
if let Some(type_params) = &mut node.type_params {
visitor.visit_type_params(type_params);
}
visitor.visit_parameters(parameters);
if let Some(expr) = returns {
visitor.visit_parameters(&mut node.parameters);
if let Some(expr) = &mut node.returns {
visitor.visit_annotation(expr);
}
visitor.visit_body(body);
visitor.visit_body(&mut node.body);
}
Stmt::ClassDef(ast::StmtClassDef {
arguments,
body,
decorator_list,
type_params,
..
}) => {
for decorator in decorator_list {
Stmt::ClassDef(node) => {
for decorator in &mut node.decorator_list {
visitor.visit_decorator(decorator);
}
if let Some(type_params) = type_params {
if let Some(type_params) = &mut node.type_params {
visitor.visit_type_params(type_params);
}
if let Some(arguments) = arguments {
if let Some(arguments) = &mut node.arguments {
visitor.visit_arguments(arguments);
}
visitor.visit_body(body);
visitor.visit_body(&mut node.body);
}
Stmt::Return(ast::StmtReturn {
value,
@@ -168,116 +155,131 @@ pub fn walk_stmt<V: Transformer + ?Sized>(visitor: &V, stmt: &mut Stmt) {
visitor.visit_expr(expr);
}
}
Stmt::Delete(ast::StmtDelete {
targets,
range: _,
node_index: _,
}) => {
Stmt::Delete(node) => {
let ast::StmtDelete {
targets,
range: _,
node_index: _,
} = &mut **node;
for expr in targets {
visitor.visit_expr(expr);
}
}
Stmt::TypeAlias(ast::StmtTypeAlias {
range: _,
node_index: _,
name,
type_params,
value,
}) => {
Stmt::TypeAlias(node) => {
let ast::StmtTypeAlias {
range: _,
node_index: _,
name,
type_params,
value,
} = &mut **node;
visitor.visit_expr(value);
if let Some(type_params) = type_params {
visitor.visit_type_params(type_params);
}
visitor.visit_expr(name);
}
Stmt::Assign(ast::StmtAssign { targets, value, .. }) => {
Stmt::Assign(node) => {
let ast::StmtAssign { targets, value, range: _, node_index: _ } = &mut **node;
visitor.visit_expr(value);
for expr in targets {
visitor.visit_expr(expr);
}
}
Stmt::AugAssign(ast::StmtAugAssign {
target,
op,
value,
range: _,
node_index: _,
}) => {
Stmt::AugAssign(node) => {
let ast::StmtAugAssign {
target,
op,
value,
range: _,
node_index: _,
} = &mut **node;
visitor.visit_expr(value);
visitor.visit_operator(op);
visitor.visit_expr(target);
}
Stmt::AnnAssign(ast::StmtAnnAssign {
target,
annotation,
value,
..
}) => {
Stmt::AnnAssign(node) => {
let ast::StmtAnnAssign {
target,
annotation,
value,
simple: _,
range: _,
node_index: _,
} = &mut **node;
if let Some(expr) = value {
visitor.visit_expr(expr);
}
visitor.visit_annotation(annotation);
visitor.visit_expr(target);
}
Stmt::For(ast::StmtFor {
target,
iter,
body,
orelse,
..
}) => {
Stmt::For(node) => {
let ast::StmtFor {
target,
iter,
body,
orelse,
is_async: _,
range: _,
node_index: _,
} = &mut **node;
visitor.visit_expr(iter);
visitor.visit_expr(target);
visitor.visit_body(body);
visitor.visit_body(orelse);
}
Stmt::While(ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
}) => {
Stmt::While(node) => {
let ast::StmtWhile {
test,
body,
orelse,
range: _,
node_index: _,
} = &mut **node;
visitor.visit_expr(test);
visitor.visit_body(body);
visitor.visit_body(orelse);
}
Stmt::If(ast::StmtIf {
test,
body,
elif_else_clauses,
range: _,
node_index: _,
}) => {
Stmt::If(node) => {
let ast::StmtIf {
test,
body,
elif_else_clauses,
range: _,
node_index: _,
} = &mut **node;
visitor.visit_expr(test);
visitor.visit_body(body);
for clause in elif_else_clauses {
walk_elif_else_clause(visitor, clause);
}
}
Stmt::With(ast::StmtWith { items, body, .. }) => {
Stmt::With(node) => {
let ast::StmtWith { items, body, is_async: _, range: _, node_index: _ } = &mut **node;
for with_item in items {
visitor.visit_with_item(with_item);
}
visitor.visit_body(body);
}
Stmt::Match(ast::StmtMatch {
subject,
cases,
range: _,
node_index: _,
}) => {
Stmt::Match(node) => {
let ast::StmtMatch {
subject,
cases,
range: _,
node_index: _,
} = &mut **node;
visitor.visit_expr(subject);
for match_case in cases {
visitor.visit_match_case(match_case);
}
}
Stmt::Raise(ast::StmtRaise {
exc,
cause,
range: _,
node_index: _,
}) => {
Stmt::Raise(node) => {
let ast::StmtRaise {
exc,
cause,
range: _,
node_index: _,
} = &mut **node;
if let Some(expr) = exc {
visitor.visit_expr(expr);
}
@@ -285,15 +287,16 @@ pub fn walk_stmt<V: Transformer + ?Sized>(visitor: &V, stmt: &mut Stmt) {
visitor.visit_expr(expr);
}
}
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
is_star: _,
range: _,
node_index: _,
}) => {
Stmt::Try(node) => {
let ast::StmtTry {
body,
handlers,
orelse,
finalbody,
is_star: _,
range: _,
node_index: _,
} = &mut **node;
visitor.visit_body(body);
for except_handler in handlers {
visitor.visit_except_handler(except_handler);
@@ -301,27 +304,30 @@ pub fn walk_stmt<V: Transformer + ?Sized>(visitor: &V, stmt: &mut Stmt) {
visitor.visit_body(orelse);
visitor.visit_body(finalbody);
}
Stmt::Assert(ast::StmtAssert {
test,
msg,
range: _,
node_index: _,
}) => {
Stmt::Assert(node) => {
let ast::StmtAssert {
test,
msg,
range: _,
node_index: _,
} = &mut **node;
visitor.visit_expr(test);
if let Some(expr) = msg {
visitor.visit_expr(expr);
}
}
Stmt::Import(ast::StmtImport {
names,
range: _,
node_index: _,
}) => {
Stmt::Import(node) => {
let ast::StmtImport {
names,
range: _,
node_index: _,
} = &mut **node;
for alias in names {
visitor.visit_alias(alias);
}
}
Stmt::ImportFrom(ast::StmtImportFrom { names, .. }) => {
Stmt::ImportFrom(node) => {
let ast::StmtImportFrom { names, module: _, level: _, range: _, node_index: _ } = &mut **node;
for alias in names {
visitor.visit_alias(alias);
}
@@ -332,7 +338,9 @@ pub fn walk_stmt<V: Transformer + ?Sized>(visitor: &V, stmt: &mut Stmt) {
value,
range: _,
node_index: _,
}) => visitor.visit_expr(value),
}) => {
visitor.visit_expr(value);
}
Stmt::Pass(_) | Stmt::Break(_) | Stmt::Continue(_) | Stmt::IpyEscapeCommand(_) => {}
}
}