Add box to stmt

This commit is contained in:
Charlie Marsh
2025-12-11 11:24:31 -05:00
parent c9155d5e72
commit cc850ec348
123 changed files with 2523 additions and 2786 deletions

View File

@@ -122,6 +122,8 @@ class Group:
add_suffix_to_is_methods: bool
anynode_is_label: str
doc: str | None
box_variants: bool
unboxed_variants: set[str]
def __init__(self, group_name: str, group: dict[str, Any]) -> None:
self.name = group_name
@@ -130,10 +132,16 @@ class Group:
self.add_suffix_to_is_methods = group.get("add_suffix_to_is_methods", False)
self.anynode_is_label = group.get("anynode_is_label", to_snake_case(group_name))
self.doc = group.get("doc")
self.box_variants = group.get("box_variants", False)
self.unboxed_variants = set(group.get("unboxed_variants", []))
self.nodes = [
Node(self, node_name, node) for node_name, node in group["nodes"].items()
]
def is_boxed(self, node_name: str) -> bool:
"""Returns True if this node should be boxed in the owned enum."""
return self.box_variants and node_name not in self.unboxed_variants
@dataclass
class Node:
@@ -321,17 +329,29 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
out.append('#[cfg_attr(feature = "get-size", derive(get_size2::GetSize))]')
out.append(f"pub enum {group.owned_enum_ty} {{")
for node in group.nodes:
out.append(f"{node.variant}({node.ty}),")
if group.is_boxed(node.name):
out.append(f"{node.variant}(Box<{node.ty}>),")
else:
out.append(f"{node.variant}({node.ty}),")
out.append("}")
for node in group.nodes:
out.append(f"""
impl From<{node.ty}> for {group.owned_enum_ty} {{
fn from(node: {node.ty}) -> Self {{
Self::{node.variant}(node)
if group.is_boxed(node.name):
out.append(f"""
impl From<{node.ty}> for {group.owned_enum_ty} {{
fn from(node: {node.ty}) -> Self {{
Self::{node.variant}(Box::new(node))
}}
}}
}}
""")
""")
else:
out.append(f"""
impl From<{node.ty}> for {group.owned_enum_ty} {{
fn from(node: {node.ty}) -> Self {{
Self::{node.variant}(node)
}}
}}
""")
out.append(f"""
impl ruff_text_size::Ranged for {group.owned_enum_ty} {{
@@ -369,6 +389,9 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
match_arm = f"Self::{variant_name}"
if group.add_suffix_to_is_methods:
is_name = to_snake_case(node.variant + group.name)
is_boxed = group.is_boxed(node.name)
# For boxed variants, we need to dereference the box
unbox = "*" if is_boxed else ""
if len(group.nodes) > 1:
out.append(f"""
#[inline]
@@ -379,7 +402,7 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
#[inline]
pub fn {is_name}(self) -> Option<{node.ty}> {{
match self {{
{match_arm}(val) => Some(val),
{match_arm}(val) => Some({unbox}val),
_ => None,
}}
}}
@@ -387,7 +410,7 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
#[inline]
pub fn expect_{is_name}(self) -> {node.ty} {{
match self {{
{match_arm}(val) => val,
{match_arm}(val) => {unbox}val,
_ => panic!("called expect on {{self:?}}"),
}}
}}
@@ -418,14 +441,14 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
#[inline]
pub fn {is_name}(self) -> Option<{node.ty}> {{
match self {{
{match_arm}(val) => Some(val),
{match_arm}(val) => Some({unbox}val),
}}
}}
#[inline]
pub fn expect_{is_name}(self) -> {node.ty} {{
match self {{
{match_arm}(val) => val,
{match_arm}(val) => {unbox}val,
}}
}}