Skip to content

Commit 748a1dd

Browse files
authored
Merge pull request #69 from WebAssembly/multi-return
Switch (back) to multi-return, remove unit, s/expected/result/
2 parents ca2851c + e409eb6 commit 748a1dd

File tree

7 files changed

+276
-201
lines changed

7 files changed

+276
-201
lines changed

design/mvp/Binary.md

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -157,38 +157,39 @@ deftype ::= dvt:<defvaltype> => dvt
157157
| ft:<functype> => ft
158158
| ct:<componenttype> => ct
159159
| it:<instancetype> => it
160-
primvaltype ::= 0x7f => unit
161-
| 0x7e => bool
162-
| 0x7d => s8
163-
| 0x7c => u8
164-
| 0x7b => s16
165-
| 0x7a => u16
166-
| 0x79 => s32
167-
| 0x78 => u32
168-
| 0x77 => s64
169-
| 0x76 => u64
170-
| 0x75 => float32
171-
| 0x74 => float64
172-
| 0x73 => char
173-
| 0x72 => string
160+
primvaltype ::= 0x7f => bool
161+
| 0x7e => s8
162+
| 0x7d => u8
163+
| 0x7c => s16
164+
| 0x7b => u16
165+
| 0x7a => s32
166+
| 0x79 => u32
167+
| 0x78 => s64
168+
| 0x77 => u64
169+
| 0x76 => float32
170+
| 0x75 => float64
171+
| 0x74 => char
172+
| 0x73 => string
174173
defvaltype ::= pvt:<primvaltype> => pvt
175-
| 0x71 field*:vec(<field>) => (record field*)
176-
| 0x70 case*:vec(<case>) => (variant case*)
177-
| 0x6f t:<valtype> => (list t)
178-
| 0x6e t*:vec(<valtype>) => (tuple t*)
179-
| 0x6d n*:vec(<name>) => (flags n*)
180-
| 0x6c n*:vec(<name>) => (enum n*)
181-
| 0x6b t*:vec(<valtype>) => (union t*)
182-
| 0x6a t:<valtype> => (option t)
183-
| 0x69 t:<valtype> u:<valtype> => (expected t u)
184-
field ::= n:<name> t:<valtype> => (field n t)
185-
case ::= n:<name> t:<valtype> 0x0 => (case n t)
186-
| n:<name> t:<valtype> 0x1 i:<u32> => (case n t (refines case-label[i]))
174+
| 0x72 nt*:vec(<namedvaltype>) => (record (field nt)*)
175+
| 0x71 case*:vec(<case>) => (variant case*)
176+
| 0x70 t:<valtype> => (list t)
177+
| 0x6f t*:vec(<valtype>) => (tuple t*)
178+
| 0x6e n*:vec(<name>) => (flags n*)
179+
| 0x6d n*:vec(<name>) => (enum n*)
180+
| 0x6c t*:vec(<valtype>) => (union t*)
181+
| 0x6b t:<valtype> => (option t)
182+
| 0x6a t?:<casetype> u?:<casetype> => (result t? (error u)?)
183+
namedvaltype ::= n:<name> t:<valtype> => n t
184+
case ::= n:<name> t?:<casetype> 0x0 => (case n t?)
185+
| n:<name> t?:<casetype> 0x1 i:<u32> => (case n t? (refines case-label[i]))
186+
casetype ::= 0x00 =>
187+
| 0x01 t:<valtype> => t
187188
valtype ::= i:<typeidx> => i
188189
| pvt:<primvaltype> => pvt
189-
functype ::= 0x40 param*:vec(<param>) t:<valtype> => (func param* (result t))
190-
param ::= 0x00 t:<valtype> => (param t)
191-
| 0x01 n:<name> t:<valtype> => (param n t)
190+
functype ::= 0x40 p*:<funcvec> r*:<funcvec> => (func (param p)* (result r)*)
191+
funcvec ::= 0x00 t:<valtype> => [t]
192+
| 0x01 nt*:vec(<namedvaltype>) => nt*
192193
componenttype ::= 0x41 cd*:vec(<componentdecl>) => (component cd*)
193194
instancetype ::= 0x42 id*:vec(<instancedecl>) => (instance id*)
194195
componentdecl ::= 0x03 id:<importdecl> => id
@@ -219,9 +220,9 @@ Notes:
219220
in type definitions from containing components.
220221
* Validation of `externdesc` requires the various `typeidx` type constructors
221222
to match the preceding `sort`.
222-
* Validation of record field names, variant case names, flag names, and enum case
223-
names requires that the name be unique for the record, variant, flags, or enum
224-
type definition.
223+
* Validation of function parameter and result names, record field names,
224+
variant case names, flag names, and enum case names requires that the name be
225+
unique for the func, record, variant, flags, or enum type definition.
225226
* Validation of the optional `refines` clause of a variant case requires that
226227
the case index is less than the current case's index (and therefore
227228
cases are acyclic).

design/mvp/CanonicalABI.md

Lines changed: 83 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,12 @@ function to replace specialized value types with their expansion:
6868
```python
6969
def despecialize(t):
7070
match t:
71-
case Tuple(ts) : return Record([ Field(str(i), t) for i,t in enumerate(ts) ])
72-
case Unit() : return Record([])
73-
case Union(ts) : return Variant([ Case(str(i), t) for i,t in enumerate(ts) ])
74-
case Enum(labels) : return Variant([ Case(l, Unit()) for l in labels ])
75-
case Option(t) : return Variant([ Case("none", Unit()), Case("some", t) ])
76-
case Expected(ok, error) : return Variant([ Case("ok", ok), Case("error", error) ])
77-
case _ : return t
71+
case Tuple(ts) : return Record([ Field(str(i), t) for i,t in enumerate(ts) ])
72+
case Union(ts) : return Variant([ Case(str(i), t) for i,t in enumerate(ts) ])
73+
case Enum(labels) : return Variant([ Case(l, None) for l in labels ])
74+
case Option(t) : return Variant([ Case("none", None), Case("some", t) ])
75+
case Result(ok, error) : return Variant([ Case("ok", ok), Case("error", error) ])
76+
case _ : return t
7877
```
7978
The specialized value types `string` and `flags` are missing from this list
8079
because they are given specialized canonical ABI representations distinct from
@@ -98,17 +97,17 @@ def alignment(t):
9897
case Float64() : return 8
9998
case Char() : return 4
10099
case String() | List(_) : return 4
101-
case Record(fields) : return max_alignment(types_of(fields))
102-
case Variant(cases) : return max_alignment(types_of(cases) + [discriminant_type(cases)])
100+
case Record(fields) : return alignment_record(fields)
101+
case Variant(cases) : return alignment_variant(cases)
103102
case Flags(labels) : return alignment_flags(labels)
103+
```
104104

105-
def types_of(fields_or_cases):
106-
return [x.t for x in fields_or_cases]
107-
108-
def max_alignment(ts):
105+
Record alignment is tuple alignment, with the definitions split for reuse below:
106+
```python
107+
def alignment_record(fields):
109108
a = 1
110-
for t in ts:
111-
a = max(a, alignment(t))
109+
for f in fields:
110+
a = max(a, alignment(f.t))
112111
return a
113112
```
114113

@@ -117,6 +116,9 @@ covering the number of cases in the variant. Depending on the payload type,
117116
this can allow more compact representations of variants in memory. This smallest
118117
integer type is selected by the following function, used above and below:
119118
```python
119+
def alignment_variant(cases):
120+
return max(alignment(discriminant_type(cases)), max_case_alignment(cases))
121+
120122
def discriminant_type(cases):
121123
n = len(cases)
122124
assert(0 < n < (1 << 32))
@@ -125,6 +127,13 @@ def discriminant_type(cases):
125127
case 1: return U8()
126128
case 2: return U16()
127129
case 3: return U32()
130+
131+
def max_case_alignment(cases):
132+
a = 1
133+
for c in cases:
134+
if c.t is not None:
135+
a = max(a, alignment(c.t))
136+
return a
128137
```
129138

130139
As an optimization, `flags` are represented as packed bit-vectors. Like variant
@@ -164,19 +173,20 @@ def size_record(fields):
164173
for f in fields:
165174
s = align_to(s, alignment(f.t))
166175
s += size(f.t)
167-
return align_to(s, alignment(Record(fields)))
176+
return align_to(s, alignment_record(fields))
168177

169178
def align_to(ptr, alignment):
170179
return math.ceil(ptr / alignment) * alignment
171180

172181
def size_variant(cases):
173182
s = size(discriminant_type(cases))
174-
s = align_to(s, max_alignment(types_of(cases)))
183+
s = align_to(s, max_case_alignment(cases))
175184
cs = 0
176185
for c in cases:
177-
cs = max(cs, size(c.t))
186+
if c.t is not None:
187+
cs = max(cs, size(c.t))
178188
s += cs
179-
return align_to(s, alignment(Variant(cases)))
189+
return align_to(s, alignment_variant(cases))
180190

181191
def size_flags(labels):
182192
n = len(labels)
@@ -362,18 +372,21 @@ string operations.
362372
```python
363373
def load_variant(opts, ptr, cases):
364374
disc_size = size(discriminant_type(cases))
365-
disc = load_int(opts, ptr, disc_size)
375+
case_index = load_int(opts, ptr, disc_size)
366376
ptr += disc_size
367-
trap_if(disc >= len(cases))
368-
case = cases[disc]
369-
ptr = align_to(ptr, max_alignment(types_of(cases)))
370-
return { case_label_with_refinements(case, cases): load(opts, ptr, case.t) }
371-
372-
def case_label_with_refinements(case, cases):
373-
label = case.label
374-
while case.refines is not None:
375-
case = cases[find_case(case.refines, cases)]
376-
label += '|' + case.label
377+
trap_if(case_index >= len(cases))
378+
c = cases[case_index]
379+
ptr = align_to(ptr, max_case_alignment(cases))
380+
case_label = case_label_with_refinements(c, cases)
381+
if c.t is None:
382+
return { case_label: None }
383+
return { case_label: load(opts, ptr, c.t) }
384+
385+
def case_label_with_refinements(c, cases):
386+
label = c.label
387+
while c.refines is not None:
388+
c = cases[find_case(c.refines, cases)]
389+
label += '|' + c.label
377390
return label
378391

379392
def find_case(label, cases):
@@ -702,8 +715,10 @@ def store_variant(opts, v, ptr, cases):
702715
disc_size = size(discriminant_type(cases))
703716
store_int(opts, case_index, ptr, disc_size)
704717
ptr += disc_size
705-
ptr = align_to(ptr, max_alignment(types_of(cases)))
706-
store(opts, case_value, cases[case_index].t, ptr)
718+
ptr = align_to(ptr, max_case_alignment(cases))
719+
c = cases[case_index]
720+
if c.t is not None:
721+
store(opts, case_value, c.t, ptr)
707722

708723
def match_case(v, cases):
709724
assert(len(v.keys()) == 1)
@@ -771,7 +786,7 @@ def flatten(functype, context):
771786
if len(flat_params) > MAX_FLAT_PARAMS:
772787
flat_params = ['i32']
773788

774-
flat_results = flatten_type(functype.result)
789+
flat_results = flatten_types(functype.results)
775790
if len(flat_results) > MAX_FLAT_RESULTS:
776791
match context:
777792
case 'lift':
@@ -799,11 +814,20 @@ def flatten_type(t):
799814
case Float64() : return ['f64']
800815
case Char() : return ['i32']
801816
case String() | List(_) : return ['i32', 'i32']
802-
case Record(fields) : return flatten_types(types_of(fields))
817+
case Record(fields) : return flatten_record(fields)
803818
case Variant(cases) : return flatten_variant(cases)
804819
case Flags(labels) : return ['i32'] * num_i32_flags(labels)
805820
```
806821

822+
Record flattening simply flattens each field in sequence.
823+
```python
824+
def flatten_record(fields):
825+
flat = []
826+
for f in fields:
827+
flat += flatten_type(f.t)
828+
return flat
829+
```
830+
807831
Variant flattening is more involved due to the fact that each case payload can
808832
have a totally different flattening. Rather than giving up when there is a type
809833
mismatch, the Canonical ABI relies on the fact that the 4 core value types can
@@ -816,11 +840,12 @@ an `i32` into an `i64`.
816840
def flatten_variant(cases):
817841
flat = []
818842
for c in cases:
819-
for i,ft in enumerate(flatten_type(c.t)):
820-
if i < len(flat):
821-
flat[i] = join(flat[i], ft)
822-
else:
823-
flat.append(ft)
843+
if c.t is not None:
844+
for i,ft in enumerate(flatten_type(c.t)):
845+
if i < len(flat):
846+
flat[i] = join(flat[i], ft)
847+
else:
848+
flat.append(ft)
824849
return flatten_type(discriminant_type(cases)) + flat
825850

826851
def join(a, b):
@@ -929,9 +954,8 @@ high bits of an `i64` are set for a 32-bit type:
929954
def lift_flat_variant(opts, vi, cases):
930955
flat_types = flatten_variant(cases)
931956
assert(flat_types.pop(0) == 'i32')
932-
disc = vi.next('i32')
933-
trap_if(disc >= len(cases))
934-
case = cases[disc]
957+
case_index = vi.next('i32')
958+
trap_if(case_index >= len(cases))
935959
class CoerceValueIter:
936960
def next(self, want):
937961
have = flat_types.pop(0)
@@ -942,10 +966,14 @@ def lift_flat_variant(opts, vi, cases):
942966
case ('i64', 'f32') : return reinterpret_i32_as_float(wrap_i64_to_i32(x))
943967
case ('i64', 'f64') : return reinterpret_i64_as_float(x)
944968
case _ : return x
945-
v = lift_flat(opts, CoerceValueIter(), case.t)
969+
c = cases[case_index]
970+
if c.t is None:
971+
v = None
972+
else:
973+
v = lift_flat(opts, CoerceValueIter(), c.t)
946974
for have in flat_types:
947975
_ = vi.next(have)
948-
return { case_label_with_refinements(case, cases): v }
976+
return { case_label_with_refinements(c, cases): v }
949977

950978
def wrap_i64_to_i32(i):
951979
assert(0 <= i < (1 << 64))
@@ -1034,7 +1062,11 @@ def lower_flat_variant(opts, v, cases):
10341062
case_index, case_value = match_case(v, cases)
10351063
flat_types = flatten_variant(cases)
10361064
assert(flat_types.pop(0) == 'i32')
1037-
payload = lower_flat(opts, case_value, cases[case_index].t)
1065+
c = cases[case_index]
1066+
if c.t is None:
1067+
payload = []
1068+
else:
1069+
payload = lower_flat(opts, case_value, c.t)
10381070
for i,have in enumerate(payload):
10391071
want = flat_types.pop(0)
10401072
match (have.t, want):
@@ -1177,21 +1209,21 @@ def canon_lift(callee_opts, callee_instance, callee, functype, args, called_as_e
11771209
except CoreWebAssemblyException:
11781210
trap()
11791211

1180-
[result] = lift(callee_opts, MAX_FLAT_RESULTS, ValueIter(flat_results), [functype.result])
1212+
results = lift(callee_opts, MAX_FLAT_RESULTS, ValueIter(flat_results), functype.results)
11811213
def post_return():
11821214
if callee_opts.post_return is not None:
11831215
callee_opts.post_return(flat_results)
11841216
if called_as_export:
11851217
callee_instance.may_enter = True
11861218

1187-
return (result, post_return)
1219+
return (results, post_return)
11881220
```
11891221
There are a number of things to note about this definition:
11901222

11911223
Uncaught Core WebAssembly [exceptions] result in a trap at component
11921224
boundaries. Thus, if a component wishes to signal an error, it must use some
1193-
sort of explicit type such as `expected` (whose `error` case particular
1194-
language bindings may choose to map to and from exceptions).
1225+
sort of explicit type such as `result` (whose `error` case particular language
1226+
bindings may choose to map to and from exceptions).
11951227

11961228
The `called_as_export` parameter indicates whether `canon_lift` is being called
11971229
as part of a component export or whether this `canon_lift` is being called
@@ -1234,10 +1266,10 @@ def canon_lower(caller_opts, caller_instance, callee, functype, flat_args):
12341266
flat_args = ValueIter(flat_args)
12351267
args = lift(caller_opts, MAX_FLAT_PARAMS, flat_args, functype.params)
12361268

1237-
result, post_return = callee(args)
1269+
results, post_return = callee(args)
12381270

12391271
caller_instance.may_leave = False
1240-
flat_results = lower(caller_opts, MAX_FLAT_RESULTS, [result], [functype.result], flat_args)
1272+
flat_results = lower(caller_opts, MAX_FLAT_RESULTS, results, functype.results, flat_args)
12411273
caller_instance.may_leave = True
12421274

12431275
post_return()

0 commit comments

Comments
 (0)