Compare commits

...

3 Commits

Author SHA1 Message Date
Douglas Creager
42945dc9dc formatter 2025-01-21 14:24:58 -05:00
Douglas Creager
20563db9c0 semantic 2025-01-21 14:24:58 -05:00
Douglas Creager
a3b5df8a64 ast 2025-01-21 14:24:36 -05:00
25 changed files with 5936 additions and 1819 deletions

1
Cargo.lock generated
View File

@@ -2917,6 +2917,7 @@ dependencies = [
"itertools 0.14.0",
"memchr",
"ruff_cache",
"ruff_index",
"ruff_macros",
"ruff_python_trivia",
"ruff_source_file",

View File

@@ -2,7 +2,7 @@ use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;
use ruff_python_ast::{ModModule, PySourceType};
use ruff_python_ast::{ModModuleId, PySourceType};
use ruff_python_parser::{parse_unchecked_source, Parsed};
use crate::files::{File, FilePath};
@@ -43,24 +43,24 @@ pub fn parsed_module(db: &dyn Db, file: File) -> ParsedModule {
/// Cheap cloneable wrapper around the parsed module.
#[derive(Clone)]
pub struct ParsedModule {
inner: Arc<Parsed<ModModule>>,
inner: Arc<Parsed<ModModuleId>>,
}
impl ParsedModule {
pub fn new(parsed: Parsed<ModModule>) -> Self {
pub fn new(parsed: Parsed<ModModuleId>) -> Self {
Self {
inner: Arc::new(parsed),
}
}
/// Consumes `self` and returns the Arc storing the parsed module.
pub fn into_arc(self) -> Arc<Parsed<ModModule>> {
pub fn into_arc(self) -> Arc<Parsed<ModModuleId>> {
self.inner
}
}
impl Deref for ParsedModule {
type Target = Parsed<ModModule>;
type Target = Parsed<ModModuleId>;
fn deref(&self) -> &Self::Target {
&self.inner

View File

@@ -0,0 +1,321 @@
use heck::ToSnakeCase;
use proc_macro2::TokenStream;
use quote::quote;
use syn::spanned::Spanned;
use syn::{Attribute, Error, Fields, Ident, ItemEnum, Result, Type, Variant};
pub(crate) fn generate_ast_enum(input: ItemEnum) -> Result<TokenStream> {
let ast_enum = AstEnum::new(input)?;
let id_enum = generate_id_enum(&ast_enum);
let node_enum = generate_node_enum(&ast_enum);
let node_enum_node_method = generate_node_enum_node_method(&ast_enum);
let node_enum_ranged_impl = generate_node_enum_ranged_impl(&ast_enum);
let variant_ids = generate_variant_ids(&ast_enum);
let storage = generate_storage(&ast_enum);
Ok(quote! {
#id_enum
#node_enum
#node_enum_node_method
#node_enum_ranged_impl
#variant_ids
#storage
})
}
/// Converts an identifier such as `ModExpression` into its snake_case form
/// (`mod_expression`), keeping the original identifier's span.
fn snake_case(node_ident: &Ident) -> Ident {
    let lowered = node_ident.to_string().to_snake_case();
    Ident::new(&lowered, node_ident.span())
}
/// Builds a new identifier by surrounding `id` with `prefix` and `suffix`
/// (either may be empty), keeping `id`'s span for good error reporting.
fn concat(prefix: &str, id: &Ident, suffix: &str) -> Ident {
    let combined = format!("{prefix}{id}{suffix}");
    Ident::new(&combined, id.span())
}
/// Describes one of the enums that holds syntax nodes (e.g. Mod, Stmt)
struct AstEnum {
    /// The base name of the enums (e.g. Mod, Stmt); the derived type names
    /// (`ModId`, `ModRef`, `ModStorage`, …) are built from this.
    base_enum_name: Ident,
    /// The syntax node variants for this enum, in declaration order
    variants: Vec<AstVariant>,
}
/// Describes one specific syntax node (e.g. ModExpression, StmtIf)
struct AstVariant {
    /// The name of the variant within its containing enum (e.g. Expression, If)
    variant_name: Ident,
    /// The struct type defining the contents of this syntax node (e.g. ModExpression, StmtIf)
    node_ty: Ident,
    /// All of the attributes attached to this variant; they are copied onto
    /// the corresponding variants of the generated enums
    attrs: Vec<Attribute>,
}
impl AstEnum {
    /// Parses the macro's input enum into an [`AstEnum`] description,
    /// propagating the first malformed-variant error encountered.
    fn new(input: ItemEnum) -> Result<AstEnum> {
        let base_enum_name = input.ident;
        let variants = input
            .variants
            .into_iter()
            .map(AstVariant::new)
            .collect::<Result<Vec<_>>>()?;
        Ok(AstEnum {
            base_enum_name,
            variants,
        })
    }

    /// Applies `f` to each variant, yielding the results lazily.
    fn map_variants<'a, B, F>(&'a self, f: F) -> impl Iterator<Item = B> + 'a
    where
        F: FnMut(&AstVariant) -> B + 'a,
    {
        self.variants.iter().map(f)
    }

    /// The name of the enum containing syntax node IDs (e.g. ModId, StmtId)
    fn id_enum_ty(&self) -> Ident {
        concat("", &self.base_enum_name, "Id")
    }

    /// The name of the enum containing references to syntax nodes (e.g. ModRef, StmtRef)
    fn ref_enum_ty(&self) -> Ident {
        concat("", &self.base_enum_name, "Ref")
    }

    /// The name of the storage type for this enum (e.g. ModStorage)
    fn enum_storage_ty(&self) -> Ident {
        concat("", &self.base_enum_name, "Storage")
    }

    /// The name of the storage field in Ast (e.g. mod_storage)
    fn enum_storage_field(&self) -> Ident {
        snake_case(&self.enum_storage_ty())
    }
}
impl AstVariant {
fn new(variant: Variant) -> Result<AstVariant> {
let Fields::Unnamed(fields) = &variant.fields else {
return Err(Error::new(
variant.fields.span(),
"Each AstNode variant must have a single unnamed field",
));
};
let mut fields = fields.unnamed.iter();
let field = fields.next().ok_or_else(|| {
Error::new(
variant.fields.span(),
"Each AstNode variant must have a single unnamed field",
)
})?;
if fields.next().is_some() {
return Err(Error::new(
variant.fields.span(),
"Each AstNode variant must have a single unnamed field",
));
}
let Type::Path(field_ty) = &field.ty else {
return Err(Error::new(
field.ty.span(),
"Each AstNode variant must wrap a simple Id type",
));
};
let node_ty = field_ty.path.require_ident()?.clone();
Ok(AstVariant {
variant_name: variant.ident,
node_ty,
attrs: variant.attrs,
})
}
/// The name of the ID type for this variant's syntax node (e.g. ModExpressionId, StmtIfId)
fn id_ty(&self) -> Ident {
concat("", &self.node_ty, "Id")
}
/// The name of the storage field in the containing enum storage type (e.g.
/// mod_expression_storage)
fn variant_storage_field(&self) -> Ident {
concat("", &snake_case(&self.node_ty), "_storage")
}
/// The name of the method that adds a new syntax node to an [Ast] (e.g. `add_mod_expression`)
}
/// Generates the enum containing syntax node IDs (e.g. ModId, StmtId)
fn generate_id_enum(ast_enum: &AstEnum) -> TokenStream {
let id_enum_ty = ast_enum.id_enum_ty();
let id_enum_variants = ast_enum.map_variants(|v| {
let AstVariant {
variant_name,
attrs,
..
} = v;
let id_ty = v.id_ty();
quote! {
#( #attrs )*
#variant_name(#id_ty)
}
});
quote! {
#[automatically_derived]
#[derive(Copy, Clone, Debug, PartialEq, is_macro::Is)]
pub enum #id_enum_ty {
#( #id_enum_variants ),*
}
}
}
/// Generates the enum containing references to syntax nodes (e.g. ModRef,
/// StmtRef), whose variants wrap `crate::Node<'a, &'a NodeTy>`.
fn generate_ref_enum(ast_enum: &AstEnum) -> TokenStream {
    let ref_enum_ty = ast_enum.ref_enum_ty();
    let variants = ast_enum.map_variants(|v| {
        let AstVariant {
            attrs,
            variant_name,
            node_ty,
            ..
        } = v;
        quote! {
            #( #attrs )*
            #variant_name(crate::Node<'a, &'a #node_ty>)
        }
    });
    // Fix: the original interpolated `#node_ident`, which is not defined
    // anywhere in this function (or file); the enum being declared is
    // `#ref_enum_ty`.
    quote! {
        #[automatically_derived]
        #[derive(Copy, Clone, Debug, PartialEq, is_macro::Is)]
        pub enum #ref_enum_ty<'a> {
            #( #variants ),*
        }
    }
}
fn generate_node_enum_node_method(ast_enum: &AstEnum) -> TokenStream {
let id_enum_ty = ast_enum.id_enum_ty();
let ref_enum_ty = ast_enum.ref_enum_ty();
let variants = ast_enum.map_variants(|v| {
let AstVariant { variant_name, .. } = v;
quote! { #id_enum_ty::#variant_name(id) => #ref_enum_ty::#variant_name(self.ast.wrap(&self.ast[id])) }
});
quote! {
#[automatically_derived]
impl<'a> crate::Node<'a, #id_enum_ty> {
#[inline]
pub fn node(&self) -> #ref_enum_ty<'a> {
match self.node {
#( #variants ),*
}
}
}
}
}
fn generate_node_enum_ranged_impl(ast_enum: &AstEnum) -> TokenStream {
let ref_enum_ty = ast_enum.ref_enum_ty();
let variants = ast_enum.map_variants(|v| {
let AstVariant { variant_name, .. } = v;
quote! { #ref_enum_ty::#variant_name(node) => node.range() }
});
quote! {
#[automatically_derived]
impl ruff_text_size::Ranged for #ref_enum_ty<'_> {
fn range(&self) -> ruff_text_size::TextRange {
match self {
#( #variants ),*
}
}
}
}
}
/// Generates the ID type for each syntax node struct in this enum.
///
/// We also define:
/// - [Index] and [IndexMut] impls so that you can index into an [Ast] using the ID type
/// - a `node` method on e.g. `Node<StmtIfId>` that returns a `Node<&StmtIf>`
/// - [Ranged] impls for the `StmtIf` and `Node<&StmtIf>`
fn generate_ids(ast_enum: &AstEnum) -> TokenStream {
    // Fix: dropped the unused `id_enum_ty` binding the original declared.
    let enum_storage_field = ast_enum.enum_storage_field();
    let variants = ast_enum.map_variants(|v| {
        let AstVariant { node_ty, .. } = v;
        let id_ty = v.id_ty();
        let variant_storage_field = v.variant_storage_field();
        quote! {
            #[automatically_derived]
            #[ruff_index::newtype_index]
            pub struct #id_ty;

            #[automatically_derived]
            impl std::ops::Index<#id_ty> for crate::Ast {
                type Output = #node_ty;
                #[inline]
                fn index(&self, id: #id_ty) -> &#node_ty {
                    &self.#enum_storage_field.#variant_storage_field[id]
                }
            }

            #[automatically_derived]
            impl std::ops::IndexMut<#id_ty> for crate::Ast {
                #[inline]
                fn index_mut(&mut self, id: #id_ty) -> &mut #node_ty {
                    &mut self.#enum_storage_field.#variant_storage_field[id]
                }
            }

            #[automatically_derived]
            impl<'a> crate::Node<'a, #id_ty> {
                #[inline]
                pub fn node(&self) -> crate::Node<'a, &'a #node_ty> {
                    self.ast.wrap(&self.ast[self.node])
                }
            }

            // Fix: the generated code cannot assume `TextRange` is in scope at
            // the expansion site, so qualify it; and the impl for the plain
            // node struct must not declare an `'a` parameter it never uses
            // (E0207: unconstrained lifetime parameter).
            #[automatically_derived]
            impl ruff_text_size::Ranged for #node_ty {
                fn range(&self) -> ruff_text_size::TextRange {
                    self.range
                }
            }

            #[automatically_derived]
            impl<'a> ruff_text_size::Ranged for crate::Node<'a, &'a #node_ty> {
                fn range(&self) -> ruff_text_size::TextRange {
                    self.as_ref().range()
                }
            }
        }
    });
    quote! { #( #variants )* }
}
fn generate_storage(ast_enum: &AstEnum) -> TokenStream {
let id_enum_ty = ast_enum.id_enum_ty();
let enum_storage_ty = ast_enum.enum_storage_ty();
let enum_storage_field = ast_enum.enum_storage_field();
let storage_fields = ast_enum.map_variants(|v| {
let AstVariant { id_ty, node_ty, .. } = v;
let variant_storage_field = v.variant_storage_field();
quote! { #variant_storage_field: ruff_index::IndexVec<#id_ty, #node_ty> }
});
let add_methods = ast_enum.map_variants(|v| {
let AstVariant { node_ty, .. } = v;
let variant_storage_field = v.variant_storage_field();
let method_name = concat("add_", vec_name, "");
quote! {
#[automatically_derived]
impl crate::Ast {
pub fn #method_name(&mut self, payload: #node_ty) -> #id_ident {
#id_ident::#variant_name(self.#storage_field.#vec_name.push(payload))
}
}
}
});
quote! {
#[automatically_derived]
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, Default, PartialEq)]
pub(crate) struct #storage_ty {
#( #storage_fields ),*
}
#( #add_methods )*
}
}

View File

@@ -14,6 +14,7 @@ license = { workspace = true }
[dependencies]
ruff_cache = { workspace = true, optional = true }
ruff_index = { workspace = true }
ruff_macros = { workspace = true, optional = true }
ruff_python_trivia = { workspace = true }
ruff_source_file = { workspace = true }

View File

@@ -67,6 +67,7 @@ class Ast:
class Group:
name: str
nodes: list[Node]
id_enum_ty: str
owned_enum_ty: str
add_suffix_to_is_methods: bool
@@ -75,6 +76,7 @@ class Group:
def __init__(self, group_name: str, group: dict[str, Any]) -> None:
self.name = group_name
self.id_enum_ty = group_name + "Id"
self.owned_enum_ty = group_name
self.ref_enum_ty = group_name + "Ref"
self.add_suffix_to_is_methods = group.get("add_suffix_to_is_methods", False)
@@ -89,12 +91,16 @@ class Group:
class Node:
name: str
variant: str
id_ty: str
ty: str
storage_field: str
def __init__(self, group: Group, node_name: str, node: dict[str, Any]) -> None:
self.name = node_name
self.variant = node.get("variant", node_name.removeprefix(group.name))
self.id_ty = node_name + "Id"
self.ty = f"crate::{node_name}"
self.storage_field = to_snake_case(node_name)
# ------------------------------------------------------------------------------
@@ -108,6 +114,75 @@ def write_preamble(out: list[str]) -> None:
""")
# ------------------------------------------------------------------------------
# ID enum
def write_ids(out: list[str], ast: Ast) -> None:
    """
    Create an ID type for each syntax node, and a per-group enum that contains a
    syntax node ID.

    ```rust
    #[ruff_index::newtype_index]
    pub struct TypeParamTypeVarId;

    #[ruff_index::newtype_index]
    pub struct TypeParamTypeVarTupleId;

    ...

    pub enum TypeParamId {
        TypeVar(TypeParamTypeVarId),
        TypeVarTuple(TypeParamTypeVarTupleId),
        ...
    }
    ```

    Also creates:
    - `impl From<TypeParamTypeVarId> for TypeParamId`
    - `impl Ranged for TypeParamTypeVar`
    - `fn TypeParamId::is_type_var() -> bool`

    If the `add_suffix_to_is_methods` group option is true, then the
    `is_type_var` method will be named `is_type_var_type_param`.
    """
    # NOTE(review): the docstring example previously showed a nonexistent
    # `#[newindex_type]` attribute and `pub struct TypeParamTypeVarTuple;`
    # (missing the `Id` suffix); both now match the code emitted below.
    #
    # One newtype ID per concrete node, plus a Ranged impl for the node
    # struct itself (every node struct stores its own `range`).
    for node in ast.all_nodes:
        out.append("")
        out.append("#[ruff_index::newtype_index]")
        out.append(f"pub struct {node.id_ty};")
        out.append(f"""
impl ruff_text_size::Ranged for {node.ty} {{
    fn range(&self) -> ruff_text_size::TextRange {{
        self.range
    }}
}}
""")
    # One ID enum per group, with an `is_*` predicate per variant and a
    # `From<NodeId>` conversion for each member node.
    for group in ast.groups:
        out.append("")
        if group.rustdoc is not None:
            out.append(group.rustdoc)
        out.append("#[derive(Clone, Copy, Debug, PartialEq, is_macro::Is)]")
        out.append(f"pub enum {group.id_enum_ty} {{")
        for node in group.nodes:
            if group.add_suffix_to_is_methods:
                is_name = to_snake_case(node.variant + group.name)
                out.append(f'#[is(name = "{is_name}")]')
            out.append(f"{node.variant}({node.id_ty}),")
        out.append("}")
        for node in group.nodes:
            out.append(f"""
impl From<{node.id_ty}> for {group.id_enum_ty} {{
    fn from(id: {node.id_ty}) -> Self {{
        Self::{node.variant}(id)
    }}
}}
""")
# ------------------------------------------------------------------------------
# Owned enum
@@ -126,7 +201,6 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
Also creates:
- `impl Ranged for TypeParam`
- `TypeParam::visit_source_order`
- `impl From<TypeParamTypeVar> for TypeParam`
- `impl Ranged for TypeParamTypeVar`
- `fn TypeParam::is_type_var() -> bool`
@@ -170,15 +244,6 @@ def write_owned_enum(out: list[str], ast: Ast) -> None:
}
""")
for node in ast.all_nodes:
out.append(f"""
impl ruff_text_size::Ranged for {node.ty} {{
fn range(&self) -> ruff_text_size::TextRange {{
self.range
}}
}}
""")
for group in ast.groups:
out.append(f"""
impl {group.owned_enum_ty} {{
@@ -210,17 +275,18 @@ def write_ref_enum(out: list[str], ast: Ast) -> None:
```rust
pub enum TypeParamRef<'a> {
TypeVar(&'a TypeParamTypeVar),
TypeVarTuple(&'a TypeParamTypeVarTuple),
TypeVar(Node<'a, &'a TypeParamTypeVar>),
TypeVarTuple(Node<'a, &'a TypeParamTypeVarTuple>),
...
}
```
Also creates:
- `impl<'a> From<&'a TypeParam> for TypeParamRef<'a>`
- `impl<'a> From<&'a TypeParamTypeVar> for TypeParamRef<'a>`
- `impl<'a> From<Node<'a, &'a TypeParam>> for TypeParamRef<'a>`
- `impl<'a> From<Node<'a, &'a TypeParamTypeVar>> for TypeParamRef<'a>`
- `impl Ranged for TypeParamRef<'_>`
- `fn TypeParamRef::is_type_var() -> bool`
- `TypeParamRef::visit_source_order`
The name of each variant can be customized via the `variant` node option. If
the `add_suffix_to_is_methods` group option is true, then the `is_type_var`
@@ -237,17 +303,17 @@ def write_ref_enum(out: list[str], ast: Ast) -> None:
if group.add_suffix_to_is_methods:
is_name = to_snake_case(node.variant + group.name)
out.append(f'#[is(name = "{is_name}")]')
out.append(f"""{node.variant}(&'a {node.ty}),""")
out.append(f"""{node.variant}(crate::Node<'a, &'a {node.ty}>),""")
out.append("}")
out.append(f"""
impl<'a> From<&'a {group.owned_enum_ty}> for {group.ref_enum_ty}<'a> {{
fn from(node: &'a {group.owned_enum_ty}) -> Self {{
match node {{
impl<'a> From<crate::Node<'a, &'a {group.owned_enum_ty}>> for {group.ref_enum_ty}<'a> {{
fn from(node: crate::Node<'a, &'a {group.owned_enum_ty}>) -> Self {{
match node.node {{
""")
for node in group.nodes:
out.append(
f"{group.owned_enum_ty}::{node.variant}(node) => {group.ref_enum_ty}::{node.variant}(node),"
f"""{group.owned_enum_ty}::{node.variant}(n) => {group.ref_enum_ty}::{node.variant}(node.ast.wrap(n)),"""
)
out.append("""
}
@@ -257,8 +323,8 @@ def write_ref_enum(out: list[str], ast: Ast) -> None:
for node in group.nodes:
out.append(f"""
impl<'a> From<&'a {node.ty}> for {group.ref_enum_ty}<'a> {{
fn from(node: &'a {node.ty}) -> Self {{
impl<'a> From<crate::Node<'a, &'a {node.ty}>> for {group.ref_enum_ty}<'a> {{
fn from(node: crate::Node<'a, &'a {node.ty}>) -> Self {{
Self::{node.variant}(node)
}}
}}
@@ -277,6 +343,112 @@ def write_ref_enum(out: list[str], ast: Ast) -> None:
}
""")
for group in ast.groups:
out.append(f"""
impl<'a> {group.ref_enum_ty}<'a> {{
#[allow(unused)]
pub(crate) fn visit_source_order<V>(self, visitor: &mut V)
where
V: crate::visitor::source_order::SourceOrderVisitor<'a> + ?Sized,
{{
match self {{
""")
for node in group.nodes:
out.append(
f"""{group.ref_enum_ty}::{node.variant}(node) => node.visit_source_order(visitor),"""
)
out.append("""
}
}
}
""")
# ------------------------------------------------------------------------------
# AST storage
def write_storage(out: list[str], ast: Ast) -> None:
    """
    Create the storage struct for all of the syntax nodes.

    ```rust
    pub(crate) struct Storage {
        ...
        pub(crate) type_param_type_var_id: IndexVec<TypeParamTypeVarId, TypeParamTypeVar>,
        pub(crate) type_param_type_var_tuple_id: IndexVec<TypeParamTypeVarTupleId, TypeParamTypeVarTuple>,
        ...
    }
    ```

    Also creates:
    - `impl AstId for TypeParamTypeVarId for Ast`
    - `impl AstIdMut for TypeParamTypeVarId for Ast`
    """
    out.append("")
    out.append("#[derive(Clone, Default, PartialEq)]")
    out.append("pub(crate) struct Storage {")
    for node in ast.all_nodes:
        out.append(
            f"""pub(crate) {node.storage_field}: ruff_index::IndexVec<{node.id_ty}, {node.ty}>,"""
        )
    out.append("}")
    # Per-node: resolve an ID to a shared or mutable node reference, and give
    # `Node<NodeId>` a convenience `node()` accessor.
    for node in ast.all_nodes:
        out.append(f"""
impl crate::ast::AstId for {node.id_ty} {{
    type Output<'a> = crate::Node<'a, &'a {node.ty}>;
    #[inline]
    fn node<'a>(self, ast: &'a crate::Ast) -> Self::Output<'a> {{
        ast.wrap(&ast.storage.{node.storage_field}[self])
    }}
}}
""")
        out.append(f"""
impl crate::ast::AstIdMut for {node.id_ty} {{
    type Output<'a> = crate::Node<'a, &'a mut {node.ty}>;
    #[inline]
    fn node_mut<'a>(self, ast: &'a mut crate::Ast) -> Self::Output<'a> {{
        ast.wrap(&mut ast.storage.{node.storage_field}[self])
    }}
}}
""")
        out.append(f"""
impl<'a> crate::Node<'a, {node.id_ty}> {{
    #[inline]
    pub fn node(self) -> crate::Node<'a, &'a {node.ty}> {{
        self.ast.node(self.node)
    }}
}}
""")
    # Per-group: resolve the group's ID enum to the matching ref-enum variant.
    for group in ast.groups:
        out.append(f"""
impl crate::ast::AstId for {group.id_enum_ty} {{
    type Output<'a> = {group.ref_enum_ty}<'a>;
    #[inline]
    fn node<'a>(self, ast: &'a crate::Ast) -> Self::Output<'a> {{
        match self {{
""")
        for node in group.nodes:
            out.append(
                f"""{group.id_enum_ty}::{node.variant}(node) => {group.ref_enum_ty}::{node.variant}(ast.node(node)),"""
            )
        out.append("""
        }
    }
}
""")
        # Fix: the original emitted `crate::Node<'a, &'a {node.ty}>` as this
        # return type, silently reusing the `node` variable left over from
        # the per-variant loop above; the AstId impl just generated returns
        # the group's ref enum, so that is the correct return type here.
        out.append(f"""
impl<'a> crate::Node<'a, {group.id_enum_ty}> {{
    #[inline]
    pub fn node(self) -> {group.ref_enum_ty}<'a> {{
        self.ast.node(self.node)
    }}
}}
""")
# ------------------------------------------------------------------------------
# AnyNodeRef
@@ -289,16 +461,16 @@ def write_anynoderef(out: list[str], ast: Ast) -> None:
```rust
pub enum AnyNodeRef<'a> {
...
TypeParamTypeVar(&'a TypeParamTypeVar),
TypeParamTypeVarTuple(&'a TypeParamTypeVarTuple),
TypeParamTypeVar(Node<'a, &'a TypeParamTypeVar>),
TypeParamTypeVarTuple(Node<'a, &'a TypeParamTypeVarTuple>),
...
}
```
Also creates:
- `impl<'a> From<&'a TypeParam> for AnyNodeRef<'a>`
- `impl<'a> From<TypeParamRef<'a>> for AnyNodeRef<'a>`
- `impl<'a> From<&'a TypeParamTypeVarTuple> for AnyNodeRef<'a>`
- `impl<'a> From<Node<'a, &'a TypeParam>> for AnyNodeRef<'a>`
- `impl<'a> From<Node<'a, &'a TypeParamTypeVarTuple>> for AnyNodeRef<'a>`
- `impl Ranged for AnyNodeRef<'_>`
- `fn AnyNodeRef::as_ptr(&self) -> std::ptr::NonNull<()>`
- `fn AnyNodeRef::visit_preorder(self, visitor &mut impl SourceOrderVisitor)`
@@ -309,20 +481,20 @@ def write_anynoderef(out: list[str], ast: Ast) -> None:
pub enum AnyNodeRef<'a> {
""")
for node in ast.all_nodes:
out.append(f"""{node.name}(&'a {node.ty}),""")
out.append(f"""{node.name}(crate::Node<'a, &'a {node.ty}>),""")
out.append("""
}
""")
for group in ast.groups:
out.append(f"""
impl<'a> From<&'a {group.owned_enum_ty}> for AnyNodeRef<'a> {{
fn from(node: &'a {group.owned_enum_ty}) -> AnyNodeRef<'a> {{
match node {{
impl<'a> From<crate::Node<'a, &'a {group.owned_enum_ty}>> for AnyNodeRef<'a> {{
fn from(node: crate::Node<'a, &'a {group.owned_enum_ty}>) -> AnyNodeRef<'a> {{
match node.node {{
""")
for node in group.nodes:
out.append(
f"{group.owned_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
f"{group.owned_enum_ty}::{node.variant}(n) => AnyNodeRef::{node.name}(node.ast.wrap(n)),"
)
out.append("""
}
@@ -347,8 +519,8 @@ def write_anynoderef(out: list[str], ast: Ast) -> None:
for node in ast.all_nodes:
out.append(f"""
impl<'a> From<&'a {node.ty}> for AnyNodeRef<'a> {{
fn from(node: &'a {node.ty}) -> AnyNodeRef<'a> {{
impl<'a> From<crate::Node<'a, &'a {node.ty}>> for AnyNodeRef<'a> {{
fn from(node: crate::Node<'a, &'a {node.ty}>) -> AnyNodeRef<'a> {{
AnyNodeRef::{node.name}(node)
}}
}}
@@ -374,7 +546,7 @@ def write_anynoderef(out: list[str], ast: Ast) -> None:
""")
for node in ast.all_nodes:
out.append(
f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(*node).cast(),"
f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(node.as_ref()).cast(),"
)
out.append("""
}
@@ -470,8 +642,10 @@ def write_nodekind(out: list[str], ast: Ast) -> None:
def generate(ast: Ast) -> list[str]:
out = []
write_preamble(out)
write_ids(out, ast)
write_owned_enum(out, ast)
write_ref_enum(out, ast)
write_storage(out, ast)
write_anynoderef(out, ast)
write_nodekind(out, ast)
return out

View File

@@ -0,0 +1,92 @@
#![allow(clippy::derive_partial_eq_without_eq)]
use std::ops::{Deref, Index};
use crate as ast;
/// Owns the backing storage for a parse tree's syntax nodes. Nodes are
/// addressed by typed IDs and resolved back to references via [`Ast::node`].
#[derive(Clone, Default, PartialEq)]
pub struct Ast {
    // The generated per-node IndexVec tables (see the `Storage` codegen).
    pub(crate) storage: ast::Storage,
}
impl Ast {
    /// Pairs `node` with a reference to this `Ast`, producing a [`Node`]
    /// that can later resolve any IDs the node contains.
    #[inline]
    pub fn wrap<T>(&self, node: T) -> Node<T> {
        Node { ast: self, node }
    }
    /// Resolves a syntax node ID into its node reference; the concrete
    /// output type is chosen by the ID's [`AstId`] impl.
    #[inline]
    pub fn node<'a, I>(&'a self, id: I) -> <I as AstId>::Output<'a>
    where
        I: AstId,
    {
        id.node(self)
    }
}
impl std::fmt::Debug for Ast {
    // Prints just "Ast": the storage tables are omitted from the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Ast");
        builder.finish()
    }
}
/// Implemented by syntax node ID types: resolves the ID into a reference to
/// its node within the given [`Ast`].
pub trait AstId {
    /// The resolved node type for this ID (e.g. a `Node<'a, &'a …>` wrapper).
    type Output<'a>;
    fn node<'a>(self, ast: &'a Ast) -> Self::Output<'a>;
}
/// Mutable counterpart of [`AstId`]: resolves the ID into a mutable
/// reference to its node within the given [`Ast`].
pub trait AstIdMut {
    /// The resolved mutable node type for this ID.
    type Output<'a>;
    fn node_mut<'a>(self, ast: &'a mut Ast) -> Self::Output<'a>;
}
/// A syntax node (or node ID) bundled with the [`Ast`] it belongs to, so
/// that IDs inside the node can be resolved without threading the `Ast`
/// around separately.
#[derive(Clone, Copy)]
pub struct Node<'ast, T> {
    // Back-reference to the owning AST.
    pub ast: &'ast Ast,
    // The wrapped payload: a node reference or a node ID.
    pub node: T,
}
impl<T> Node<'_, T> {
    /// Borrows the wrapped payload without the surrounding [`Node`].
    pub fn as_ref(&self) -> &T {
        &self.node
    }
}
impl<T> std::fmt::Debug for Node<'_, T>
where
    T: std::fmt::Debug,
{
    // Shows only the wrapped payload; the `ast` back-reference is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("Node");
        tuple.field(&self.node);
        tuple.finish()
    }
}
impl<T> Deref for Node<'_, T> {
    type Target = T;
    /// Dereferences straight to the wrapped payload, so `Node<&StmtIf>` can
    /// be used like `&StmtIf`.
    fn deref(&self) -> &Self::Target {
        &self.node
    }
}
// Marker impl: `Node` is `Eq` whenever `T` is (its `PartialEq` compares only
// the wrapped payload, which is total when `T: Eq`).
impl<T> Eq for Node<'_, T> where T: Eq {}
impl<T> std::hash::Hash for Node<'_, T>
where
    T: std::hash::Hash,
{
    // Hashes only the wrapped payload, consistent with the `PartialEq` impl,
    // which also ignores the `ast` back-reference.
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        self.node.hash(state);
    }
}
impl<T> PartialEq for Node<'_, T>
where
    T: PartialEq,
{
    // Equality ignores the `ast` back-reference and compares payloads only.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.node, &other.node);
        lhs == rhs
    }
}

View File

@@ -16,7 +16,7 @@
//! have the same shape in that they evaluate to the same value.
use crate as ast;
use crate::{Expr, Number};
use crate::{Expr, Node, Number};
use std::borrow::Cow;
use std::hash::Hash;
@@ -593,18 +593,14 @@ impl<'a> From<ast::LiteralExpressionRef<'a>> for ComparableLiteral<'a> {
match literal {
ast::LiteralExpressionRef::NoneLiteral(_) => Self::None,
ast::LiteralExpressionRef::EllipsisLiteral(_) => Self::Ellipsis,
ast::LiteralExpressionRef::BooleanLiteral(ast::ExprBooleanLiteral {
value, ..
}) => Self::Bool(value),
ast::LiteralExpressionRef::StringLiteral(ast::ExprStringLiteral { value, .. }) => {
Self::Str(value.iter().map(Into::into).collect())
ast::LiteralExpressionRef::BooleanLiteral(node) => Self::Bool(&node.value),
ast::LiteralExpressionRef::StringLiteral(node) => {
Self::Str(node.value.iter().map(Into::into).collect())
}
ast::LiteralExpressionRef::BytesLiteral(ast::ExprBytesLiteral { value, .. }) => {
Self::Bytes(value.iter().map(Into::into).collect())
}
ast::LiteralExpressionRef::NumberLiteral(ast::ExprNumberLiteral { value, .. }) => {
Self::Number(value.into())
ast::LiteralExpressionRef::BytesLiteral(node) => {
Self::Bytes(node.value.iter().map(Into::into).collect())
}
ast::LiteralExpressionRef::NumberLiteral(node) => Self::Number((&node.value).into()),
}
}
}
@@ -1437,9 +1433,9 @@ pub enum ComparableStmt<'a> {
Continue,
}
impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
fn from(stmt: &'a ast::Stmt) -> Self {
match stmt {
impl<'a> From<Node<'a, ast::StmtRef<'a>>> for ComparableStmt<'a> {
fn from(stmt: Node<'a, ast::StmtRef<'a>>) -> Self {
match stmt.node {
ast::Stmt::FunctionDef(ast::StmtFunctionDef {
is_async,
name,
@@ -1451,7 +1447,7 @@ impl<'a> From<&'a ast::Stmt> for ComparableStmt<'a> {
range: _,
}) => Self::FunctionDef(StmtFunctionDef {
is_async: *is_async,
name: name.as_str(),
name: stmt.name().as_str(),
parameters: parameters.into(),
body: body.iter().map(Into::into).collect(),
decorator_list: decorator_list.iter().map(Into::into).collect(),

View File

@@ -3,26 +3,20 @@ use std::iter::FusedIterator;
use ruff_text_size::{Ranged, TextRange};
use crate::{
self as ast, AnyNodeRef, AnyStringFlags, Expr, ExprBytesLiteral, ExprFString, ExprRef,
ExprStringLiteral, StringFlags,
self as ast, AnyNodeRef, AnyStringFlags, Ast, Expr, ExprBytesLiteral, ExprFString, ExprRef,
ExprStringLiteral, Node, StringFlags,
};
impl<'a> From<&'a Box<Expr>> for ExprRef<'a> {
fn from(value: &'a Box<Expr>) -> Self {
ExprRef::from(value.as_ref())
}
}
/// Unowned pendant to all the literal variants of [`ast::Expr`] that stores a
/// reference instead of an owned value.
#[derive(Copy, Clone, Debug, PartialEq, is_macro::Is)]
pub enum LiteralExpressionRef<'a> {
StringLiteral(&'a ast::ExprStringLiteral),
BytesLiteral(&'a ast::ExprBytesLiteral),
NumberLiteral(&'a ast::ExprNumberLiteral),
BooleanLiteral(&'a ast::ExprBooleanLiteral),
NoneLiteral(&'a ast::ExprNoneLiteral),
EllipsisLiteral(&'a ast::ExprEllipsisLiteral),
StringLiteral(Node<'a, &'a ast::ExprStringLiteral>),
BytesLiteral(Node<'a, &'a ast::ExprBytesLiteral>),
NumberLiteral(Node<'a, &'a ast::ExprNumberLiteral>),
BooleanLiteral(Node<'a, &'a ast::ExprBooleanLiteral>),
NoneLiteral(Node<'a, &'a ast::ExprNoneLiteral>),
EllipsisLiteral(Node<'a, &'a ast::ExprEllipsisLiteral>),
}
impl Ranged for LiteralExpressionRef<'_> {
@@ -83,9 +77,9 @@ impl LiteralExpressionRef<'_> {
/// literals, bytes literals, and f-strings.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum StringLike<'a> {
String(&'a ast::ExprStringLiteral),
Bytes(&'a ast::ExprBytesLiteral),
FString(&'a ast::ExprFString),
String(Node<'a, &'a ast::ExprStringLiteral>),
Bytes(Node<'a, &'a ast::ExprBytesLiteral>),
FString(Node<'a, &'a ast::ExprFString>),
}
impl<'a> StringLike<'a> {
@@ -96,18 +90,18 @@ impl<'a> StringLike<'a> {
/// Returns an iterator over the [`StringLikePart`] contained in this string-like expression.
pub fn parts(&self) -> StringLikePartIter<'a> {
match self {
StringLike::String(expr) => StringLikePartIter::String(expr.value.iter()),
StringLike::Bytes(expr) => StringLikePartIter::Bytes(expr.value.iter()),
StringLike::FString(expr) => StringLikePartIter::FString(expr.value.iter()),
StringLike::String(expr) => StringLikePartIter::String(expr.ast, expr.value.iter()),
StringLike::Bytes(expr) => StringLikePartIter::Bytes(expr.ast, expr.value.iter()),
StringLike::FString(expr) => StringLikePartIter::FString(expr.ast, expr.value.iter()),
}
}
/// Returns `true` if the string is implicitly concatenated.
pub fn is_implicit_concatenated(self) -> bool {
match self {
Self::String(ExprStringLiteral { value, .. }) => value.is_implicit_concatenated(),
Self::Bytes(ExprBytesLiteral { value, .. }) => value.is_implicit_concatenated(),
Self::FString(ExprFString { value, .. }) => value.is_implicit_concatenated(),
Self::String(node) => node.value.is_implicit_concatenated(),
Self::Bytes(node) => node.value.is_implicit_concatenated(),
Self::FString(node) => node.value.is_implicit_concatenated(),
}
}
@@ -120,26 +114,26 @@ impl<'a> StringLike<'a> {
}
}
impl<'a> From<&'a ast::ExprStringLiteral> for StringLike<'a> {
fn from(value: &'a ast::ExprStringLiteral) -> Self {
impl<'a> From<Node<'a, &'a ast::ExprStringLiteral>> for StringLike<'a> {
fn from(value: Node<'a, &'a ast::ExprStringLiteral>) -> Self {
StringLike::String(value)
}
}
impl<'a> From<&'a ast::ExprBytesLiteral> for StringLike<'a> {
fn from(value: &'a ast::ExprBytesLiteral) -> Self {
impl<'a> From<Node<'a, &'a ast::ExprBytesLiteral>> for StringLike<'a> {
fn from(value: Node<'a, &'a ast::ExprBytesLiteral>) -> Self {
StringLike::Bytes(value)
}
}
impl<'a> From<&'a ast::ExprFString> for StringLike<'a> {
fn from(value: &'a ast::ExprFString) -> Self {
impl<'a> From<Node<'a, &'a ast::ExprFString>> for StringLike<'a> {
fn from(value: Node<'a, &'a ast::ExprFString>) -> Self {
StringLike::FString(value)
}
}
impl<'a> From<&StringLike<'a>> for ExprRef<'a> {
fn from(value: &StringLike<'a>) -> Self {
impl<'a> From<StringLike<'a>> for ExprRef<'a> {
fn from(value: StringLike<'a>) -> Self {
match value {
StringLike::String(expr) => ExprRef::StringLiteral(expr),
StringLike::Bytes(expr) => ExprRef::BytesLiteral(expr),
@@ -150,12 +144,6 @@ impl<'a> From<&StringLike<'a>> for ExprRef<'a> {
impl<'a> From<StringLike<'a>> for AnyNodeRef<'a> {
fn from(value: StringLike<'a>) -> Self {
AnyNodeRef::from(&value)
}
}
impl<'a> From<&StringLike<'a>> for AnyNodeRef<'a> {
fn from(value: &StringLike<'a>) -> Self {
match value {
StringLike::String(expr) => AnyNodeRef::ExprStringLiteral(expr),
StringLike::Bytes(expr) => AnyNodeRef::ExprBytesLiteral(expr),
@@ -164,14 +152,14 @@ impl<'a> From<&StringLike<'a>> for AnyNodeRef<'a> {
}
}
impl<'a> TryFrom<&'a Expr> for StringLike<'a> {
impl<'a> TryFrom<Node<'a, &'a Expr>> for StringLike<'a> {
type Error = ();
fn try_from(value: &'a Expr) -> Result<Self, Self::Error> {
match value {
Expr::StringLiteral(value) => Ok(Self::String(value)),
Expr::BytesLiteral(value) => Ok(Self::Bytes(value)),
Expr::FString(value) => Ok(Self::FString(value)),
fn try_from(value: Node<'a, &'a Expr>) -> Result<Self, Self::Error> {
match value.node {
Expr::StringLiteral(v) => Ok(Self::String(value.ast.wrap(v))),
Expr::BytesLiteral(v) => Ok(Self::Bytes(value.ast.wrap(v))),
Expr::FString(v) => Ok(Self::FString(value.ast.wrap(v))),
_ => Err(()),
}
}
@@ -203,9 +191,9 @@ impl Ranged for StringLike<'_> {
/// An enum that holds a reference to an individual part of a string-like expression.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum StringLikePart<'a> {
String(&'a ast::StringLiteral),
Bytes(&'a ast::BytesLiteral),
FString(&'a ast::FString),
String(Node<'a, &'a ast::StringLiteral>),
Bytes(Node<'a, &'a ast::BytesLiteral>),
FString(Node<'a, &'a ast::FString>),
}
impl<'a> StringLikePart<'a> {
@@ -231,7 +219,7 @@ impl<'a> StringLikePart<'a> {
matches!(self, Self::String(_))
}
pub const fn as_string_literal(self) -> Option<&'a ast::StringLiteral> {
pub const fn as_string_literal(self) -> Option<Node<'a, &'a ast::StringLiteral>> {
match self {
StringLikePart::String(value) => Some(value),
_ => None,
@@ -243,30 +231,24 @@ impl<'a> StringLikePart<'a> {
}
}
impl<'a> From<&'a ast::StringLiteral> for StringLikePart<'a> {
fn from(value: &'a ast::StringLiteral) -> Self {
impl<'a> From<Node<'a, &'a ast::StringLiteral>> for StringLikePart<'a> {
fn from(value: Node<'a, &'a ast::StringLiteral>) -> Self {
StringLikePart::String(value)
}
}
impl<'a> From<&'a ast::BytesLiteral> for StringLikePart<'a> {
fn from(value: &'a ast::BytesLiteral) -> Self {
impl<'a> From<Node<'a, &'a ast::BytesLiteral>> for StringLikePart<'a> {
fn from(value: Node<'a, &'a ast::BytesLiteral>) -> Self {
StringLikePart::Bytes(value)
}
}
impl<'a> From<&'a ast::FString> for StringLikePart<'a> {
fn from(value: &'a ast::FString) -> Self {
impl<'a> From<Node<'a, &'a ast::FString>> for StringLikePart<'a> {
fn from(value: Node<'a, &'a ast::FString>) -> Self {
StringLikePart::FString(value)
}
}
impl<'a> From<&StringLikePart<'a>> for AnyNodeRef<'a> {
fn from(value: &StringLikePart<'a>) -> Self {
AnyNodeRef::from(*value)
}
}
impl<'a> From<StringLikePart<'a>> for AnyNodeRef<'a> {
fn from(value: StringLikePart<'a>) -> Self {
match value {
@@ -292,9 +274,9 @@ impl Ranged for StringLikePart<'_> {
/// This is created by the [`StringLike::parts`] method.
#[derive(Clone)]
pub enum StringLikePartIter<'a> {
String(std::slice::Iter<'a, ast::StringLiteral>),
Bytes(std::slice::Iter<'a, ast::BytesLiteral>),
FString(std::slice::Iter<'a, ast::FStringPart>),
String(&'a Ast, std::slice::Iter<'a, ast::StringLiteral>),
Bytes(&'a Ast, std::slice::Iter<'a, ast::BytesLiteral>),
FString(&'a Ast, std::slice::Iter<'a, ast::FStringPart>),
}
impl<'a> Iterator for StringLikePartIter<'a> {
@@ -302,15 +284,19 @@ impl<'a> Iterator for StringLikePartIter<'a> {
fn next(&mut self) -> Option<Self::Item> {
let part = match self {
StringLikePartIter::String(inner) => StringLikePart::String(inner.next()?),
StringLikePartIter::Bytes(inner) => StringLikePart::Bytes(inner.next()?),
StringLikePartIter::FString(inner) => {
StringLikePartIter::String(ast, inner) => {
StringLikePart::String(ast.wrap(inner.next()?))
}
StringLikePartIter::Bytes(ast, inner) => StringLikePart::Bytes(ast.wrap(inner.next()?)),
StringLikePartIter::FString(ast, inner) => {
let part = inner.next()?;
match part {
ast::FStringPart::Literal(string_literal) => {
StringLikePart::String(string_literal)
StringLikePart::String(ast.wrap(string_literal))
}
ast::FStringPart::FString(f_string) => {
StringLikePart::FString(ast.wrap(f_string))
}
ast::FStringPart::FString(f_string) => StringLikePart::FString(f_string),
}
}
};
@@ -320,9 +306,9 @@ impl<'a> Iterator for StringLikePartIter<'a> {
fn size_hint(&self) -> (usize, Option<usize>) {
match self {
StringLikePartIter::String(inner) => inner.size_hint(),
StringLikePartIter::Bytes(inner) => inner.size_hint(),
StringLikePartIter::FString(inner) => inner.size_hint(),
StringLikePartIter::String(_, inner) => inner.size_hint(),
StringLikePartIter::Bytes(_, inner) => inner.size_hint(),
StringLikePartIter::FString(_, inner) => inner.size_hint(),
}
}
}
@@ -330,15 +316,21 @@ impl<'a> Iterator for StringLikePartIter<'a> {
impl DoubleEndedIterator for StringLikePartIter<'_> {
fn next_back(&mut self) -> Option<Self::Item> {
let part = match self {
StringLikePartIter::String(inner) => StringLikePart::String(inner.next_back()?),
StringLikePartIter::Bytes(inner) => StringLikePart::Bytes(inner.next_back()?),
StringLikePartIter::FString(inner) => {
StringLikePartIter::String(ast, inner) => {
StringLikePart::String(ast.wrap(inner.next_back()?))
}
StringLikePartIter::Bytes(ast, inner) => {
StringLikePart::Bytes(ast.wrap(inner.next_back()?))
}
StringLikePartIter::FString(ast, inner) => {
let part = inner.next_back()?;
match part {
ast::FStringPart::Literal(string_literal) => {
StringLikePart::String(string_literal)
StringLikePart::String(ast.wrap(string_literal))
}
ast::FStringPart::FString(f_string) => {
StringLikePart::FString(ast.wrap(f_string))
}
ast::FStringPart::FString(f_string) => StringLikePart::FString(f_string),
}
}
};

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,13 @@
use std::ffi::OsStr;
use std::path::Path;
pub use ast::{Ast, Node};
pub use expression::*;
pub use generated::*;
pub use int::*;
pub use nodes::*;
pub mod ast;
pub mod comparable;
pub mod docstrings;
mod expression;

File diff suppressed because it is too large Load Diff

View File

@@ -13,28 +13,90 @@ use itertools::Itertools;
use ruff_text_size::{Ranged, TextLen, TextRange, TextSize};
use crate::ast::AstId;
use crate::name::Name;
use crate::{
int,
str::Quote,
str_prefix::{AnyStringPrefix, ByteStringPrefix, FStringPrefix, StringLiteralPrefix},
ExceptHandler, Expr, FStringElement, LiteralExpressionRef, Pattern, Stmt, TypeParam,
Ast, DecoratorId, ExceptHandler, Expr, ExprId, FStringElement, IdentifierId,
LiteralExpressionRef, Node, ParametersId, Pattern, Stmt, StmtId, TypeParam, TypeParamsId,
};
macro_rules! accessor {
($ty:ty, $field:ident, $ref_ty:ty) => {
impl $ty {
pub fn $field<'a>(&'a self, ast: &'a Ast) -> Node<'a, <$ref_ty as AstId>::Output<'a>> {
ast.wrap(ast.node(self.$field))
}
}
impl<'a> Node<'a, &'a $ty> {
pub fn $field(self) -> Node<'a, <$ref_ty as AstId>::Output<'a>> {
self.node.$field(self.ast)
}
}
};
}
macro_rules! option_accessor {
($ty:ty, $field:ident, $ref_ty:ty) => {
impl $ty {
pub fn $field<'a>(
&'a self,
ast: &'a Ast,
) -> Option<Node<'a, <$ref_ty as AstId>::Output<'a>>> {
self.$field.map(|id| ast.wrap(ast.node(*id)))
}
}
impl<'a> Node<'a, &'a $ty> {
pub fn $field(self) -> Option<Node<'a, <$ref_ty as AstId>::Output<'a>>> {
self.node.$field(self.ast)
}
}
};
}
macro_rules! vec_accessor {
($ty:ty, $field:ident, $ref_ty:ty) => {
impl $ty {
pub fn $field<'a>(
&'a self,
ast: &'a Ast,
) -> impl Iterator<Item = Node<'a, <$ref_ty as AstId>::Output<'a>>> + 'a {
self.$field.iter().map(|id| ast.wrap(ast.node(*id)))
}
}
impl<'a> Node<'a, &'a $ty> {
pub fn $field(
self,
) -> impl Iterator<Item = Node<'a, <$ref_ty as AstId>::Output<'a>>> + 'a {
self.node.$field(self.ast)
}
}
};
}
/// See also [Module](https://docs.python.org/3/library/ast.html#ast.Module)
#[derive(Clone, Debug, PartialEq)]
pub struct ModModule {
pub range: TextRange,
pub body: Vec<Stmt>,
pub body: Vec<StmtId>,
}
vec_accessor!(ModModule, body, StmtId);
/// See also [Expression](https://docs.python.org/3/library/ast.html#ast.Expression)
#[derive(Clone, Debug, PartialEq)]
pub struct ModExpression {
pub range: TextRange,
pub body: Box<Expr>,
pub body: ExprId,
}
accessor!(ModExpression, body, ExprId);
/// An AST node used to represent a IPython escape command at the statement level.
///
/// For example,
@@ -104,20 +166,27 @@ pub struct StmtIpyEscapeCommand {
pub struct StmtFunctionDef {
pub range: TextRange,
pub is_async: bool,
pub decorator_list: Vec<Decorator>,
pub name: Identifier,
pub type_params: Option<Box<TypeParams>>,
pub parameters: Box<Parameters>,
pub returns: Option<Box<Expr>>,
pub body: Vec<Stmt>,
pub decorator_list: Vec<DecoratorId>,
pub name: IdentifierId,
pub type_params: Option<TypeParamsId>,
pub parameters: ParametersId,
pub returns: Option<ExprId>,
pub body: Vec<StmtId>,
}
vec_accessor!(StmtFunctionDef, decorator_list, DecoratorId);
accessor!(StmtFunctionDef, name, IdentifierId);
option_accessor!(StmtFunctionDef, type_params, TypeParamsId);
accessor!(StmtFunctionDef, parameters, ParametersId);
option_accessor!(StmtFunctionDef, returns, ExprId);
vec_accessor!(StmtFunctionDef, body, StmtId);
/// See also [ClassDef](https://docs.python.org/3/library/ast.html#ast.ClassDef)
#[derive(Clone, Debug, PartialEq)]
pub struct StmtClassDef {
pub range: TextRange,
pub decorator_list: Vec<Decorator>,
pub name: Identifier,
pub name: IdentifierId,
pub type_params: Option<Box<TypeParams>>,
pub arguments: Option<Box<Arguments>>,
pub body: Vec<Stmt>,

View File

@@ -1,9 +1,9 @@
use crate::AnyNodeRef;
use crate::{
Alias, Arguments, BoolOp, BytesLiteral, CmpOp, Comprehension, Decorator, ElifElseClause,
ExceptHandler, Expr, FString, FStringElement, Keyword, MatchCase, Mod, Operator, Parameter,
ParameterWithDefault, Parameters, Pattern, PatternArguments, PatternKeyword, Singleton, Stmt,
StringLiteral, TypeParam, TypeParams, UnaryOp, WithItem,
ExceptHandler, Expr, FString, FStringElement, Keyword, MatchCase, Mod, ModId, Node, Operator,
Parameter, ParameterWithDefault, Parameters, Pattern, PatternArguments, PatternKeyword,
Singleton, Stmt, StringLiteral, TypeParam, TypeParams, UnaryOp, WithItem,
};
/// Visitor that traverses all nodes recursively in the order they appear in the source.
@@ -20,8 +20,8 @@ pub trait SourceOrderVisitor<'a> {
fn leave_node(&mut self, _node: AnyNodeRef<'a>) {}
#[inline]
fn visit_mod(&mut self, module: &'a Mod) {
walk_module(self, module);
fn visit_mod(&mut self, module: Node<'a, ModId>) {
walk_module(self, module.node());
}
#[inline]
@@ -172,7 +172,7 @@ pub trait SourceOrderVisitor<'a> {
}
}
pub fn walk_module<'a, V>(visitor: &mut V, module: &'a Mod)
pub fn walk_module<'a, V>(visitor: &mut V, module: Node<'a, &'a Mod>)
where
V: SourceOrderVisitor<'a> + ?Sized,
{

View File

@@ -15,6 +15,13 @@ def rustfmt(code: str) -> str:
return check_output(["rustfmt", "--emit=stdout"], input=code, text=True)
# Which nodes have been migrated over to the new IndexVec representation?
indexed_nodes = {
"ModModule",
"ModExpression",
}
# %%
# Read nodes
@@ -130,14 +137,18 @@ use ruff_python_ast as ast;
"""
for node in nodes:
if node in indexed_nodes:
node_type = f"ast::Node<&ast::{node}>"
else:
node_type = f"ast::{node}"
text = f"""
impl FormatRule<ast::{node}, PyFormatContext<'_>>
impl FormatRule<{node_type}, PyFormatContext<'_>>
for crate::{groups[group_for_node(node)]}::{to_camel_case(node)}::Format{node}
{{
#[inline]
fn fmt(
&self,
node: &ast::{node},
node: &{node_type},
f: &mut PyFormatter,
) -> FormatResult<()> {{
FormatNodeRule::<ast::{node}>::fmt(self, node, f)

View File

@@ -71,7 +71,8 @@ pub fn format_and_debug_print(source: &str, cli: &Cli, source_path: &Path) -> Re
}
if cli.print_comments {
// Print preceding, following and enclosing nodes
let decorated_comments = collect_comments(parsed.syntax(), source_code, &comment_ranges);
let decorated_comments =
collect_comments(parsed.syntax().node(), source_code, &comment_ranges);
if !decorated_comments.is_empty() {
println!("# Comment decoration: Range, Preceding, Following, Enclosing, Comment");
}

View File

@@ -1,7 +1,8 @@
use ast::helpers::comment_indentation_after;
use ruff_python_ast::whitespace::indentation;
use ruff_python_ast::{
self as ast, AnyNodeRef, Comprehension, Expr, ModModule, Parameter, Parameters, StringLike,
self as ast, AnyNodeRef, Comprehension, Expr, ModModule, Node, Parameter, Parameters,
StringLike,
};
use ruff_python_trivia::{
find_only_token_in_range, first_non_trivia_token, indentation_at_offset, BackwardsTokenizer,
@@ -966,7 +967,7 @@ fn handle_trailing_binary_like_comment<'a>(
///
/// Comments of an all empty module are leading module comments
fn handle_trailing_module_comment<'a>(
module: &'a ModModule,
module: Node<'a, &'a ModModule>,
comment: DecoratedComment<'a>,
) -> CommentPlacement<'a> {
if comment.preceding_node().is_none() && comment.following_node().is_none() {

View File

@@ -18,7 +18,7 @@ use crate::comments::{CommentsMap, SourceComment};
/// Collect the preceding, following and enclosing node for each comment without applying
/// [`place_comment`] for debugging.
pub(crate) fn collect_comments<'a>(
root: &'a Mod,
root: Mod,
source_code: SourceCode<'a>,
comment_ranges: &'a CommentRanges,
) -> Vec<DecoratedComment<'a>> {

View File

@@ -4,7 +4,7 @@ use tracing::Level;
pub use range::format_range;
use ruff_formatter::prelude::*;
use ruff_formatter::{format, write, FormatError, Formatted, PrintError, Printed, SourceCode};
use ruff_python_ast::{AnyNodeRef, Mod};
use ruff_python_ast::{AnyNodeRef, Mod, ModId};
use ruff_python_parser::{parse, AsMode, ParseError, Parsed};
use ruff_python_trivia::CommentRanges;
use ruff_text_size::Ranged;
@@ -119,7 +119,7 @@ pub fn format_module_source(
}
pub fn format_module_ast<'a>(
parsed: &'a Parsed<Mod>,
parsed: &'a Parsed<ModId>,
comment_ranges: &'a CommentRanges,
source: &'a str,
options: PyFormatOptions,

View File

@@ -1,4 +1,4 @@
use ruff_formatter::{FormatOwnedWithRule, FormatRefWithRule};
use ruff_formatter::FormatOwnedWithRule;
use ruff_python_ast::Mod;
use crate::prelude::*;
@@ -9,8 +9,8 @@ pub(crate) mod mod_module;
#[derive(Default)]
pub struct FormatMod;
impl FormatRule<Mod, PyFormatContext<'_>> for FormatMod {
fn fmt(&self, item: &Mod, f: &mut PyFormatter) -> FormatResult<()> {
impl FormatRule<Mod<'_>, PyFormatContext<'_>> for FormatMod {
fn fmt(&self, item: &Mod<'_>, f: &mut PyFormatter) -> FormatResult<()> {
match item {
Mod::Module(x) => x.format().fmt(f),
Mod::Expression(x) => x.format().fmt(f),
@@ -18,16 +18,8 @@ impl FormatRule<Mod, PyFormatContext<'_>> for FormatMod {
}
}
impl<'ast> AsFormat<PyFormatContext<'ast>> for Mod {
type Format<'a> = FormatRefWithRule<'a, Mod, FormatMod, PyFormatContext<'ast>>;
fn format(&self) -> Self::Format<'_> {
FormatRefWithRule::new(self, FormatMod)
}
}
impl<'ast> IntoFormat<PyFormatContext<'ast>> for Mod {
type Format = FormatOwnedWithRule<Mod, FormatMod, PyFormatContext<'ast>>;
impl<'ast> IntoFormat<PyFormatContext<'ast>> for Mod<'ast> {
type Format = FormatOwnedWithRule<Mod<'ast>, FormatMod, PyFormatContext<'ast>>;
fn into_format(self) -> Self::Format {
FormatOwnedWithRule::new(self, FormatMod)

View File

@@ -1,4 +1,4 @@
use ruff_python_ast::ModExpression;
use ruff_python_ast::{ModExpression, Node};
use crate::prelude::*;

View File

@@ -1,5 +1,5 @@
use ruff_formatter::write;
use ruff_python_ast::ModModule;
use ruff_python_ast::{ModModule, Node};
use ruff_python_trivia::lines_after;
use crate::prelude::*;

View File

@@ -73,7 +73,8 @@ pub use crate::token::{Token, TokenKind};
use crate::parser::Parser;
use ruff_python_ast::{
Expr, Mod, ModExpression, ModModule, PySourceType, StringFlags, StringLiteral, Suite,
Ast, Expr, Mod, ModExpression, ModId, ModModule, ModModuleId, PySourceType, StringFlags,
StringLiteral, Suite,
};
use ruff_python_trivia::CommentRanges;
use ruff_text_size::{Ranged, TextRange, TextSize};
@@ -109,7 +110,7 @@ pub mod typing;
/// let module = parse_module(source);
/// assert!(module.is_ok());
/// ```
pub fn parse_module(source: &str) -> Result<Parsed<ModModule>, ParseError> {
pub fn parse_module(source: &str) -> Result<Parsed<ModModuleId>, ParseError> {
Parser::new(source, Mode::Module)
.parse()
.try_into_module()
@@ -132,7 +133,7 @@ pub fn parse_module(source: &str) -> Result<Parsed<ModModule>, ParseError> {
/// let expr = parse_expression("1 + 2");
/// assert!(expr.is_ok());
/// ```
pub fn parse_expression(source: &str) -> Result<Parsed<ModExpression>, ParseError> {
pub fn parse_expression(source: &str) -> Result<Parsed<ModExpressionId>, ParseError> {
Parser::new(source, Mode::Expression)
.parse()
.try_into_expression()
@@ -159,7 +160,7 @@ pub fn parse_expression(source: &str) -> Result<Parsed<ModExpression>, ParseErro
pub fn parse_expression_range(
source: &str,
range: TextRange,
) -> Result<Parsed<ModExpression>, ParseError> {
) -> Result<Parsed<ModExpressionId>, ParseError> {
let source = &source[..range.end().to_usize()];
Parser::new_starts_at(source, Mode::Expression, range.start())
.parse()
@@ -273,7 +274,7 @@ pub fn parse_string_annotation(
/// let parsed = parse(source, Mode::Ipython);
/// assert!(parsed.is_ok());
/// ```
pub fn parse(source: &str, mode: Mode) -> Result<Parsed<Mod>, ParseError> {
pub fn parse(source: &str, mode: Mode) -> Result<Parsed<ModId>, ParseError> {
parse_unchecked(source, mode).into_result()
}
@@ -281,12 +282,12 @@ pub fn parse(source: &str, mode: Mode) -> Result<Parsed<Mod>, ParseError> {
///
/// This is same as the [`parse`] function except that it doesn't check for any [`ParseError`]
/// and returns the [`Parsed`] as is.
pub fn parse_unchecked(source: &str, mode: Mode) -> Parsed<Mod> {
pub fn parse_unchecked(source: &str, mode: Mode) -> Parsed<ModId> {
Parser::new(source, mode).parse()
}
/// Parse the given Python source code using the specified [`PySourceType`].
pub fn parse_unchecked_source(source: &str, source_type: PySourceType) -> Parsed<ModModule> {
pub fn parse_unchecked_source(source: &str, source_type: PySourceType) -> Parsed<ModModuleId> {
// SAFETY: Safe because `PySourceType` always parses to a `ModModule`
Parser::new(source, source_type.as_mode())
.parse()
@@ -297,15 +298,26 @@ pub fn parse_unchecked_source(source: &str, source_type: PySourceType) -> Parsed
/// Represents the parsed source code.
#[derive(Debug, PartialEq, Clone)]
pub struct Parsed<T> {
ast: Ast,
syntax: T,
tokens: Tokens,
errors: Vec<ParseError>,
}
impl<T> Parsed<T> {
impl<T> Parsed<T>
where
T: Clone,
{
/// Returns the syntax node represented by this parsed output.
pub fn syntax(&self) -> &T {
&self.syntax
pub fn syntax(&self) -> Node<T> {
self.ast.wrap(self.syntax.clone())
}
}
impl<T> Parsed<T> {
/// Returns the AST for the parsed output.
pub fn ast(&self) -> &Ast {
&self.ast
}
/// Returns all the tokens for the parsed output.
@@ -318,11 +330,6 @@ impl<T> Parsed<T> {
&self.errors
}
/// Consumes the [`Parsed`] output and returns the contained syntax node.
pub fn into_syntax(self) -> T {
self.syntax
}
/// Consumes the [`Parsed`] output and returns a list of syntax errors found during parsing.
pub fn into_errors(self) -> Vec<ParseError> {
self.errors
@@ -354,7 +361,7 @@ impl<T> Parsed<T> {
}
}
impl Parsed<Mod> {
impl Parsed<ModId> {
/// Attempts to convert the [`Parsed<Mod>`] into a [`Parsed<ModModule>`].
///
/// This method checks if the `syntax` field of the output is a [`Mod::Module`]. If it is, the
@@ -362,14 +369,15 @@ impl Parsed<Mod> {
/// returns [`None`].
///
/// [`Some(Parsed<ModModule>)`]: Some
pub fn try_into_module(self) -> Option<Parsed<ModModule>> {
pub fn try_into_module(self) -> Option<Parsed<ModModuleId>> {
match self.syntax {
Mod::Module(module) => Some(Parsed {
ModId::Module(module) => Some(Parsed {
ast: self.ast,
syntax: module,
tokens: self.tokens,
errors: self.errors,
}),
Mod::Expression(_) => None,
ModId::Expression(_) => None,
}
}
@@ -380,10 +388,11 @@ impl Parsed<Mod> {
/// Otherwise, it returns [`None`].
///
/// [`Some(Parsed<ModExpression>)`]: Some
pub fn try_into_expression(self) -> Option<Parsed<ModExpression>> {
pub fn try_into_expression(self) -> Option<Parsed<ModExpressionId>> {
match self.syntax {
Mod::Module(_) => None,
Mod::Expression(expression) => Some(Parsed {
ModId::Module(_) => None,
ModId::Expression(expression) => Some(Parsed {
ast: self.ast,
syntax: expression,
tokens: self.tokens,
errors: self.errors,
@@ -392,32 +401,22 @@ impl Parsed<Mod> {
}
}
impl Parsed<ModModule> {
impl Parsed<ModModuleId> {
/// Returns the module body contained in this parsed output as a [`Suite`].
pub fn suite(&self) -> &Suite {
&self.syntax.body
}
/// Consumes the [`Parsed`] output and returns the module body as a [`Suite`].
pub fn into_suite(self) -> Suite {
self.syntax.body
&self.ast[self.syntax].body
}
}
impl Parsed<ModExpression> {
impl Parsed<ModExpressionId> {
/// Returns the expression contained in this parsed output.
pub fn expr(&self) -> &Expr {
&self.syntax.body
&self.ast[self.syntax].body
}
/// Returns a mutable reference to the expression contained in this parsed output.
pub fn expr_mut(&mut self) -> &mut Expr {
&mut self.syntax.body
}
/// Consumes the [`Parsed`] output and returns the contained [`Expr`].
pub fn into_expr(self) -> Expr {
*self.syntax.body
&mut self.ast[self.syntax].body
}
}

View File

@@ -2,7 +2,7 @@ use std::cmp::Ordering;
use bitflags::bitflags;
use ruff_python_ast::{Mod, ModExpression, ModModule};
use ruff_python_ast::{Ast, ModExpression, ModId, ModModule};
use ruff_text_size::{Ranged, TextRange, TextSize};
use crate::parser::expression::ExpressionContext;
@@ -47,6 +47,9 @@ pub(crate) struct Parser<'src> {
/// The start offset in the source code from which to start parsing at.
start_offset: TextSize,
/// The AST that we are constructing.
ast: Ast,
}
impl<'src> Parser<'src> {
@@ -68,16 +71,21 @@ impl<'src> Parser<'src> {
prev_token_end: TextSize::new(0),
start_offset,
current_token_id: TokenId::default(),
ast: Ast::default(),
}
}
/// Consumes the [`Parser`] and returns the parsed [`Parsed`].
pub(crate) fn parse(mut self) -> Parsed<Mod> {
pub(crate) fn parse(mut self) -> Parsed<ModId> {
let syntax = match self.mode {
Mode::Expression | Mode::ParenthesizedExpression => {
Mod::Expression(self.parse_single_expression())
let expr = self.parse_single_expression();
self.ast.add_mod_expression(expr)
}
Mode::Module | Mode::Ipython => {
let module = self.parse_module();
self.ast.add_mod_module(module)
}
Mode::Module | Mode::Ipython => Mod::Module(self.parse_module()),
};
self.finish(syntax)
@@ -140,7 +148,7 @@ impl<'src> Parser<'src> {
}
}
fn finish(self, syntax: Mod) -> Parsed<Mod> {
fn finish(self, syntax: ModId) -> Parsed<ModId> {
assert_eq!(
self.current_token_kind(),
TokenKind::EndOfFile,
@@ -156,6 +164,7 @@ impl<'src> Parser<'src> {
// always results in a parse error.
if lex_errors.is_empty() {
return Parsed {
ast: self.ast,
syntax,
tokens: Tokens::new(tokens),
errors: parse_errors,
@@ -187,6 +196,7 @@ impl<'src> Parser<'src> {
merged.extend(lex_errors.map(ParseError::from));
Parsed {
ast: self.ast,
syntax,
tokens: Tokens::new(tokens),
errors: merged,

View File

@@ -2,7 +2,7 @@
use ruff_python_ast::relocate::relocate_expr;
use ruff_python_ast::str::raw_contents;
use ruff_python_ast::{Expr, ExprStringLiteral, ModExpression, StringLiteral};
use ruff_python_ast::{Expr, ExprStringLiteral, ModExpressionId, StringLiteral};
use ruff_text_size::Ranged;
use crate::{parse_expression, parse_string_annotation, ParseError, Parsed};
@@ -11,12 +11,12 @@ type AnnotationParseResult = Result<ParsedAnnotation, ParseError>;
#[derive(Debug)]
pub struct ParsedAnnotation {
parsed: Parsed<ModExpression>,
parsed: Parsed<ModExpressionId>,
kind: AnnotationKind,
}
impl ParsedAnnotation {
pub fn parsed(&self) -> &Parsed<ModExpression> {
pub fn parsed(&self) -> &Parsed<ModExpressionId> {
&self.parsed
}

View File

@@ -205,7 +205,6 @@ impl serde::Serialize for NameImport {
impl<'de> serde::de::Deserialize<'de> for NameImports {
fn deserialize<D: serde::de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
use ruff_python_ast::{self as ast, Stmt};
use ruff_python_parser::Parsed;
struct AnyNameImportsVisitor;
@@ -217,10 +216,8 @@ impl<'de> serde::de::Deserialize<'de> for NameImports {
}
fn visit_str<E: serde::de::Error>(self, value: &str) -> Result<Self::Value, E> {
let body = ruff_python_parser::parse_module(value)
.map(Parsed::into_suite)
.map_err(E::custom)?;
let [stmt] = body.as_slice() else {
let body = ruff_python_parser::parse_module(value).map_err(E::custom)?;
let [stmt] = body.suite().as_slice() else {
return Err(E::custom("Expected a single statement"));
};