@@ -131,6 +131,9 @@ def visitSum(self, sum, name):
131
131
if is_simple (sum ):
132
132
info .has_userdata = False
133
133
else :
134
+ for t in sum .types :
135
+ self .typeinfo [t .name ] = TypeInfo (t .name )
136
+ self .add_children (t .name , t .fields )
134
137
if len (sum .types ) > 1 :
135
138
info .boxed = True
136
139
if sum .attributes :
@@ -205,16 +208,49 @@ def simple_sum(self, sum, name, depth):
205
208
206
209
def sum_with_constructors (self , sum , name , depth ):
207
210
typeinfo = self .typeinfo [name ]
208
- generics , generics_applied = self .get_generics (name , "U = ()" , "U" )
209
211
enumname = rustname = get_rust_type (name )
210
212
# all the attributes right now are for location, so if it has attrs we
211
213
# can just wrap it in Located<>
212
214
if sum .attributes :
213
215
enumname = rustname + "Kind"
216
+
217
+ for t in sum .types :
218
+ if not t .fields :
219
+ continue
220
+ self .emit_attrs (depth )
221
+ self .typeinfo [t ] = TypeInfo (t )
222
+ t_generics , t_generics_applied = self .get_generics (t .name , "U = ()" , "U" )
223
+ payload_name = f"{ rustname } { t .name } "
224
+ self .emit (f"pub struct { payload_name } { t_generics } {{" , depth )
225
+ for f in t .fields :
226
+ self .visit (f , typeinfo , "pub " , depth + 1 , t .name )
227
+ self .emit ("}" , depth )
228
+ self .emit (
229
+ textwrap .dedent (
230
+ f"""
231
+ impl{ t_generics_applied } From<{ payload_name } { t_generics_applied } > for { enumname } { t_generics_applied } {{
232
+ fn from(payload: { payload_name } { t_generics_applied } ) -> Self {{
233
+ { enumname } ::{ t .name } (payload)
234
+ }}
235
+ }}
236
+ """
237
+ ),
238
+ depth ,
239
+ )
240
+
241
+ generics , generics_applied = self .get_generics (name , "U = ()" , "U" )
214
242
self .emit_attrs (depth )
215
243
self .emit (f"pub enum { enumname } { generics } {{" , depth )
216
244
for t in sum .types :
217
- self .visit (t , typeinfo , depth + 1 )
245
+ if t .fields :
246
+ t_generics , t_generics_applied = self .get_generics (
247
+ t .name , "U = ()" , "U"
248
+ )
249
+ self .emit (
250
+ f"{ t .name } ({ rustname } { t .name } { t_generics_applied } )," , depth + 1
251
+ )
252
+ else :
253
+ self .emit (f"{ t .name } ," , depth + 1 )
218
254
self .emit ("}" , depth )
219
255
if sum .attributes :
220
256
self .emit (
@@ -238,13 +274,18 @@ def visitField(self, field, parent, vis, depth, constructor=None):
238
274
if fieldtype and fieldtype .has_userdata :
239
275
typ = f"{ typ } <U>"
240
276
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
241
- if fieldtype and fieldtype .boxed and (not (parent .product or field .seq ) or field .opt ):
277
+ if (
278
+ fieldtype
279
+ and fieldtype .boxed
280
+ and (not (parent .product or field .seq ) or field .opt )
281
+ ):
242
282
typ = f"Box<{ typ } >"
243
283
if field .opt or (
244
284
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
245
285
# the expression to be unpacked goes in `values` with a `None` at the corresponding
246
286
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
247
- constructor == "Dict" and field .name == "keys"
287
+ constructor == "Dict"
288
+ and field .name == "keys"
248
289
):
249
290
typ = f"Option<{ typ } >"
250
291
if field .seq :
@@ -344,14 +385,21 @@ def visitSum(self, sum, name, depth):
344
385
)
345
386
if is_located :
346
387
self .emit ("fold_located(folder, node, |folder, node| {" , depth )
347
- enumname += "Kind"
388
+ rustname = enumname + "Kind"
389
+ else :
390
+ rustname = enumname
348
391
self .emit ("match node {" , depth + 1 )
349
392
for cons in sum .types :
350
- fields_pattern = self .make_pattern (cons .fields )
393
+ fields_pattern = self .make_pattern (
394
+ enumname , rustname , cons .name , cons .fields
395
+ )
351
396
self .emit (
352
- f"{ enumname } ::{ cons .name } {{ { fields_pattern } }} => {{" , depth + 2
397
+ f"{ fields_pattern [0 ]} {{ { fields_pattern [1 ]} }} { fields_pattern [2 ]} => {{" ,
398
+ depth + 2 ,
399
+ )
400
+ self .gen_construction (
401
+ fields_pattern [0 ], cons .fields , fields_pattern [2 ], depth + 3
353
402
)
354
- self .gen_construction (f"{ enumname } ::{ cons .name } " , cons .fields , depth + 3 )
355
403
self .emit ("}" , depth + 2 )
356
404
self .emit ("}" , depth + 1 )
357
405
if is_located :
@@ -381,23 +429,33 @@ def visitProduct(self, product, name, depth):
381
429
)
382
430
if is_located :
383
431
self .emit ("fold_located(folder, node, |folder, node| {" , depth )
384
- structname += "Data"
385
- fields_pattern = self .make_pattern (product .fields )
386
- self .emit (f"let { structname } {{ { fields_pattern } }} = node;" , depth + 1 )
387
- self .gen_construction (structname , product .fields , depth + 1 )
432
+ rustname = structname + "Data"
433
+ else :
434
+ rustname = structname
435
+ fields_pattern = self .make_pattern (rustname , structname , None , product .fields )
436
+ self .emit (f"let { rustname } {{ { fields_pattern [1 ]} }} = node;" , depth + 1 )
437
+ self .gen_construction (rustname , product .fields , "" , depth + 1 )
388
438
if is_located :
389
439
self .emit ("})" , depth )
390
440
self .emit ("}" , depth )
391
441
392
- def make_pattern (self , fields ):
393
- return "," .join (rust_field (f .name ) for f in fields )
442
+ def make_pattern (self , rustname , pyname , fieldname , fields ):
443
+ if fields :
444
+ header = f"{ pyname } ::{ fieldname } ({ rustname } { fieldname } "
445
+ footer = ")"
446
+ else :
447
+ header = f"{ pyname } ::{ fieldname } "
448
+ footer = ""
449
+
450
+ body = "," .join (rust_field (f .name ) for f in fields )
451
+ return header , body , footer
394
452
395
- def gen_construction (self , cons_path , fields , depth ):
396
- self .emit (f"Ok({ cons_path } {{" , depth )
453
+ def gen_construction (self , header , fields , footer , depth ):
454
+ self .emit (f"Ok({ header } {{" , depth )
397
455
for field in fields :
398
456
name = rust_field (field .name )
399
457
self .emit (f"{ name } : Foldable::fold({ name } , folder)?," , depth + 1 )
400
- self .emit (" })" , depth )
458
+ self .emit (f"}} { footer } )" , depth )
401
459
402
460
403
461
class FoldModuleVisitor (TypeInfoEmitVisitor ):
0 commit comments