Skip to content

Refactor ast to hold data as seperated type #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
parser/src/python.rs linguist-generated
**/*.snap linguist-generated -merge
**/*.lalrpop text eol=LF
**/*.py text working-tree-encoding=UTF-8 eol=LF
**/*.rs text working-tree-encoding=UTF-8 eol=LF
129 changes: 99 additions & 30 deletions ast/asdl_rs.py
Original file line number Diff line number Diff line change
@@ -131,6 +131,9 @@ def visitSum(self, sum, name):
if is_simple(sum):
info.has_userdata = False
else:
for t in sum.types:
self.typeinfo[t.name] = TypeInfo(t.name)
self.add_children(t.name, t.fields)
if len(sum.types) > 1:
info.boxed = True
if sum.attributes:
@@ -205,16 +208,49 @@ def simple_sum(self, sum, name, depth):

def sum_with_constructors(self, sum, name, depth):
typeinfo = self.typeinfo[name]
generics, generics_applied = self.get_generics(name, "U = ()", "U")
enumname = rustname = get_rust_type(name)
# all the attributes right now are for location, so if it has attrs we
# can just wrap it in Located<>
if sum.attributes:
enumname = rustname + "Kind"

for t in sum.types:
if not t.fields:
continue
self.emit_attrs(depth)
self.typeinfo[t] = TypeInfo(t)
t_generics, t_generics_applied = self.get_generics(t.name, "U = ()", "U")
payload_name = f"{rustname}{t.name}"
self.emit(f"pub struct {payload_name}{t_generics} {{", depth)
for f in t.fields:
self.visit(f, typeinfo, "pub ", depth + 1, t.name)
self.emit("}", depth)
self.emit(
textwrap.dedent(
f"""
impl{t_generics_applied} From<{payload_name}{t_generics_applied}> for {enumname}{t_generics_applied} {{
fn from(payload: {payload_name}{t_generics_applied}) -> Self {{
{enumname}::{t.name}(payload)
}}
}}
"""
),
depth,
)

generics, generics_applied = self.get_generics(name, "U = ()", "U")
self.emit_attrs(depth)
self.emit(f"pub enum {enumname}{generics} {{", depth)
for t in sum.types:
self.visit(t, typeinfo, depth + 1)
if t.fields:
t_generics, t_generics_applied = self.get_generics(
t.name, "U = ()", "U"
)
self.emit(
f"{t.name}({rustname}{t.name}{t_generics_applied}),", depth + 1
)
else:
self.emit(f"{t.name},", depth + 1)
self.emit("}", depth)
if sum.attributes:
self.emit(
@@ -238,13 +274,18 @@ def visitField(self, field, parent, vis, depth, constructor=None):
if fieldtype and fieldtype.has_userdata:
typ = f"{typ}<U>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if fieldtype and fieldtype.boxed and (not (parent.product or field.seq) or field.opt):
if (
fieldtype
and fieldtype.boxed
and (not (parent.product or field.seq) or field.opt)
):
typ = f"Box<{typ}>"
if field.opt or (
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
# the expression to be unpacked goes in `values` with a `None` at the corresponding
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
constructor == "Dict" and field.name == "keys"
constructor == "Dict"
and field.name == "keys"
):
typ = f"Option<{typ}>"
if field.seq:
@@ -344,14 +385,21 @@ def visitSum(self, sum, name, depth):
)
if is_located:
self.emit("fold_located(folder, node, |folder, node| {", depth)
enumname += "Kind"
rustname = enumname + "Kind"
else:
rustname = enumname
self.emit("match node {", depth + 1)
for cons in sum.types:
fields_pattern = self.make_pattern(cons.fields)
fields_pattern = self.make_pattern(
enumname, rustname, cons.name, cons.fields
)
self.emit(
f"{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth + 2
f"{fields_pattern[0]} {{ {fields_pattern[1]} }} {fields_pattern[2]} => {{",
depth + 2,
)
self.gen_construction(
fields_pattern[0], cons.fields, fields_pattern[2], depth + 3
)
self.gen_construction(f"{enumname}::{cons.name}", cons.fields, depth + 3)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
if is_located:
@@ -381,23 +429,33 @@ def visitProduct(self, product, name, depth):
)
if is_located:
self.emit("fold_located(folder, node, |folder, node| {", depth)
structname += "Data"
fields_pattern = self.make_pattern(product.fields)
self.emit(f"let {structname} {{ {fields_pattern} }} = node;", depth + 1)
self.gen_construction(structname, product.fields, depth + 1)
rustname = structname + "Data"
else:
rustname = structname
fields_pattern = self.make_pattern(rustname, structname, None, product.fields)
self.emit(f"let {rustname} {{ {fields_pattern[1]} }} = node;", depth + 1)
self.gen_construction(rustname, product.fields, "", depth + 1)
if is_located:
self.emit("})", depth)
self.emit("}", depth)

def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)
def make_pattern(self, rustname, pyname, fieldname, fields):
if fields:
header = f"{pyname}::{fieldname}({rustname}{fieldname}"
footer = ")"
else:
header = f"{pyname}::{fieldname}"
footer = ""

def gen_construction(self, cons_path, fields, depth):
self.emit(f"Ok({cons_path} {{", depth)
body = ",".join(rust_field(f.name) for f in fields)
return header, body, footer

def gen_construction(self, header, fields, footer, depth):
self.emit(f"Ok({header} {{", depth)
for field in fields:
name = rust_field(field.name)
self.emit(f"{name}: Foldable::fold({name}, folder)?,", depth + 1)
self.emit("})", depth)
self.emit(f"}}{footer})", depth)


class FoldModuleVisitor(TypeInfoEmitVisitor):
@@ -514,33 +572,36 @@ def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)

def visitSum(self, sum, name, depth):
enumname = get_rust_type(name)
rustname = enumname = get_rust_type(name)
if sum.attributes:
enumname += "Kind"
rustname = enumname + "Kind"

self.emit(f"impl NamedNode for ast::{enumname} {{", depth)
self.emit(f"impl NamedNode for ast::{rustname} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::{enumname} {{", depth)
self.emit(f"impl Node for ast::{rustname} {{", depth)
self.emit(
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
)
self.emit("match self {", depth + 2)
for variant in sum.types:
self.constructor_to_object(variant, enumname, depth + 3)
self.constructor_to_object(variant, enumname, rustname, depth + 3)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
self.emit(
"fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
depth + 1,
)
self.gen_sum_fromobj(sum, name, enumname, depth + 2)
self.gen_sum_fromobj(sum, name, enumname, rustname, depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)

def constructor_to_object(self, cons, enumname, depth):
fields_pattern = self.make_pattern(cons.fields)
self.emit(f"ast::{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth)
def constructor_to_object(self, cons, enumname, rustname, depth):
self.emit(f"ast::{rustname}::{cons.name}", depth)
if cons.fields:
fields_pattern = self.make_pattern(cons.fields)
self.emit(f"( ast::{enumname}{cons.name} {{ {fields_pattern} }} )", depth)
self.emit(" => {", depth)
self.make_node(cons.name, cons.fields, depth + 1)
self.emit("}", depth)

@@ -586,15 +647,20 @@ def make_node(self, variant, fields, depth):
def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)

def gen_sum_fromobj(self, sum, sumname, enumname, depth):
def gen_sum_fromobj(self, sum, sumname, enumname, rustname, depth):
if sum.attributes:
self.extract_location(sumname, depth)

self.emit("let _cls = _object.class();", depth)
self.emit("Ok(", depth)
for cons in sum.types:
self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
self.gen_construction(f"{enumname}::{cons.name}", cons, sumname, depth + 1)
if cons.fields:
self.emit(f"ast::{rustname}::{cons.name} (ast::{enumname}{cons.name} {{", depth + 1)
self.gen_construction_fields(cons, sumname, depth + 1)
self.emit("})", depth + 1)
else:
self.emit(f"ast::{rustname}::{cons.name}", depth + 1)
self.emit("} else", depth)

self.emit("{", depth)
@@ -610,13 +676,16 @@ def gen_product_fromobj(self, product, prodname, structname, depth):
self.gen_construction(structname, product, prodname, depth + 1)
self.emit(")", depth)

def gen_construction(self, cons_path, cons, name, depth):
self.emit(f"ast::{cons_path} {{", depth)
def gen_construction_fields(self, cons, name, depth):
for field in cons.fields:
self.emit(
f"{rust_field(field.name)}: {self.decode_field(field, name)},",
depth + 1,
)

def gen_construction(self, cons_path, cons, name, depth):
self.emit(f"ast::{cons_path} {{", depth)
self.gen_construction_fields(cons, name, depth + 1)
self.emit("}", depth)

def extract_location(self, typename, depth):
Loading