Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate grouped and ungrouped nodes more clearly in AST generator #15646

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 80 additions & 74 deletions crates/ruff_python_ast/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,41 @@ def to_snake_case(node: str) -> str:
# Read AST description


def load_ast(root: Path) -> list[Group]:
def load_ast(root: Path) -> Ast:
ast_path = root.joinpath("crates", "ruff_python_ast", "ast.toml")
with ast_path.open("rb") as ast_file:
ast = tomllib.load(ast_file)
return [Group(group_name, group) for group_name, group in ast.items()]
return Ast(ast)


# ------------------------------------------------------------------------------
# Preprocess


@dataclass
class Ast:
"""
The parsed representation of the `ast.toml` file. Defines all of the Python
AST syntax nodes, and which groups (`Stmt`, `Expr`, etc.) they belong to.
"""

groups: list[Group]
ungrouped_nodes: list[Node]
all_nodes: list[Node]

def __init__(self, ast: dict[str, Any]) -> None:
self.groups = []
self.ungrouped_nodes = []
self.all_nodes = []
for group_name, group in ast.items():
group = Group(group_name, group)
self.all_nodes.extend(group.nodes)
if group_name == "ungrouped":
self.ungrouped_nodes = group.nodes
else:
self.groups.append(group)


@dataclass
class Group:
name: str
Expand Down Expand Up @@ -89,7 +113,7 @@ def write_preamble(out: list[str]) -> None:
# Owned enum


def write_owned_enum(out: list[str], groups: list[Group]) -> None:
def write_owned_enum(out: list[str], ast: Ast) -> None:
"""
Create an enum for each group that contains an owned copy of a syntax node.

Expand All @@ -112,10 +136,7 @@ def write_owned_enum(out: list[str], groups: list[Group]) -> None:
`is_type_var` method will be named `is_type_var_type_param`.
"""

for group in groups:
if group.name == "ungrouped":
continue

for group in ast.groups:
out.append("")
if group.rustdoc is not None:
out.append(group.rustdoc)
Expand Down Expand Up @@ -150,19 +171,16 @@ def write_owned_enum(out: list[str], groups: list[Group]) -> None:
}
""")

for group in groups:
for node in group.nodes:
out.append(f"""
for node in ast.all_nodes:
out.append(f"""
impl ruff_text_size::Ranged for {node.ty} {{
fn range(&self) -> ruff_text_size::TextRange {{
self.range
}}
}}
""")
""")

for group in groups:
if group.name == "ungrouped":
continue
for group in ast.groups:
out.append(f"""
impl {group.owned_enum_ty} {{
#[allow(unused)]
Expand All @@ -187,7 +205,7 @@ def write_owned_enum(out: list[str], groups: list[Group]) -> None:
# Ref enum


def write_ref_enum(out: list[str], groups: list[Group]) -> None:
def write_ref_enum(out: list[str], ast: Ast) -> None:
"""
Create an enum for each group that contains a reference to a syntax node.

Expand All @@ -211,10 +229,7 @@ def write_ref_enum(out: list[str], groups: list[Group]) -> None:
method will be named `is_type_var_type_param`.
"""

for group in groups:
if group.name == "ungrouped":
continue

for group in ast.groups:
out.append("")
if group.rustdoc is not None:
out.append(group.rustdoc)
Expand Down Expand Up @@ -269,7 +284,7 @@ def write_ref_enum(out: list[str], groups: list[Group]) -> None:
# AnyNodeRef


def write_anynoderef(out: list[str], groups: list[Group]) -> None:
def write_anynoderef(out: list[str], ast: Ast) -> None:
"""
Create the AnyNodeRef type.

Expand All @@ -295,62 +310,59 @@ def write_anynoderef(out: list[str], groups: list[Group]) -> None:
#[derive(Copy, Clone, Debug, is_macro::Is, PartialEq)]
pub enum AnyNodeRef<'a> {
""")
for group in groups:
for node in group.nodes:
out.append(f"""{node.name}(&'a {node.ty}),""")
for node in ast.all_nodes:
out.append(f"""{node.name}(&'a {node.ty}),""")
out.append("""
}
""")

for group in groups:
if group.name != "ungrouped":
out.append(f"""
for group in ast.groups:
out.append(f"""
impl<'a> From<&'a {group.owned_enum_ty}> for AnyNodeRef<'a> {{
fn from(node: &'a {group.owned_enum_ty}) -> AnyNodeRef<'a> {{
match node {{
""")
for node in group.nodes:
out.append(
f"{group.owned_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
)
out.append("""
""")
for node in group.nodes:
out.append(
f"{group.owned_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
)
out.append("""
}
}
}
""")
""")

out.append(f"""
out.append(f"""
impl<'a> From<{group.ref_enum_ty}<'a>> for AnyNodeRef<'a> {{
fn from(node: {group.ref_enum_ty}<'a>) -> AnyNodeRef<'a> {{
match node {{
""")
for node in group.nodes:
out.append(
f"{group.ref_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
)
out.append("""
""")
for node in group.nodes:
out.append(
f"{group.ref_enum_ty}::{node.variant}(node) => AnyNodeRef::{node.name}(node),"
)
out.append("""
}
}
}
""")
""")

for node in group.nodes:
out.append(f"""
for node in ast.all_nodes:
out.append(f"""
impl<'a> From<&'a {node.ty}> for AnyNodeRef<'a> {{
fn from(node: &'a {node.ty}) -> AnyNodeRef<'a> {{
AnyNodeRef::{node.name}(node)
}}
}}
""")
""")

out.append("""
impl ruff_text_size::Ranged for AnyNodeRef<'_> {
fn range(&self) -> ruff_text_size::TextRange {
match self {
""")
for group in groups:
for node in group.nodes:
out.append(f"""AnyNodeRef::{node.name}(node) => node.range(),""")
for node in ast.all_nodes:
out.append(f"""AnyNodeRef::{node.name}(node) => node.range(),""")
out.append("""
}
}
Expand All @@ -362,11 +374,10 @@ def write_anynoderef(out: list[str], groups: list[Group]) -> None:
pub fn as_ptr(&self) -> std::ptr::NonNull<()> {
match self {
""")
for group in groups:
for node in group.nodes:
out.append(
f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(*node).cast(),"
)
for node in ast.all_nodes:
out.append(
f"AnyNodeRef::{node.name}(node) => std::ptr::NonNull::from(*node).cast(),"
)
out.append("""
}
}
Expand All @@ -382,20 +393,17 @@ def write_anynoderef(out: list[str], groups: list[Group]) -> None:
{
match self {
""")
for group in groups:
for node in group.nodes:
out.append(
f"AnyNodeRef::{node.name}(node) => node.visit_source_order(visitor),"
)
for node in ast.all_nodes:
out.append(
f"AnyNodeRef::{node.name}(node) => node.visit_source_order(visitor),"
)
out.append("""
}
}
}
""")

for group in groups:
if group.name == "ungrouped":
continue
for group in ast.groups:
out.append(f"""
impl AnyNodeRef<'_> {{
pub const fn is_{group.anynode_is_label}(self) -> bool {{
Expand All @@ -416,7 +424,7 @@ def write_anynoderef(out: list[str], groups: list[Group]) -> None:
# NodeKind


def write_nodekind(out: list[str], groups: list[Group]) -> None:
def write_nodekind(out: list[str], ast: Ast) -> None:
"""
Create the NodeKind type.

Expand All @@ -437,9 +445,8 @@ def write_nodekind(out: list[str], groups: list[Group]) -> None:
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum NodeKind {
""")
for group in groups:
for node in group.nodes:
out.append(f"""{node.name},""")
for node in ast.all_nodes:
out.append(f"""{node.name},""")
out.append("""
}
""")
Expand All @@ -449,9 +456,8 @@ def write_nodekind(out: list[str], groups: list[Group]) -> None:
pub const fn kind(self) -> NodeKind {
match self {
""")
for group in groups:
for node in group.nodes:
out.append(f"""AnyNodeRef::{node.name}(_) => NodeKind::{node.name},""")
for node in ast.all_nodes:
out.append(f"""AnyNodeRef::{node.name}(_) => NodeKind::{node.name},""")
out.append("""
}
}
Expand All @@ -463,13 +469,13 @@ def write_nodekind(out: list[str], groups: list[Group]) -> None:
# Format and write output


def generate(groups: list[Group]) -> list[str]:
def generate(ast: Ast) -> list[str]:
out = []
write_preamble(out)
write_owned_enum(out, groups)
write_ref_enum(out, groups)
write_anynoderef(out, groups)
write_nodekind(out, groups)
write_owned_enum(out, ast)
write_ref_enum(out, ast)
write_anynoderef(out, ast)
write_nodekind(out, ast)
return out


Expand All @@ -486,8 +492,8 @@ def main() -> None:
root = Path(
check_output(["git", "rev-parse", "--show-toplevel"], text=True).strip()
)
groups = load_ast(root)
out = generate(groups)
ast = load_ast(root)
out = generate(ast)
write_output(root, out)


Expand Down
Loading
Loading