Skip to content

Commit b2d797f

Browse files
authored
[TypedFunctionReferences] Add Typed Function References feature and use the types (#3388)
This adds the new feature and starts to use the new types where relevant. We use them even without the feature being enabled, as we don't know the features during wasm loading - but the hope is that given the type is a subtype, it should all work out. In practice, if you print out the internal type you may see a typed function reference-specific type for a ref.func for example, instead of a generic funcref, but it should not affect anything else. This PR does not support non-nullable types, that is, everything is nullable for now. As suggested by @tlively this is simpler for now and leaves nullability for later work (which will apparently require let or something else, and many passes may need to be changed). To allow this PR to work, we need to provide a type on creating a RefFunc. The wasm-builder.h internal API is updated for this, as are the C and JS APIs, which are breaking changes. cc @dcodeIO We must also write and read function types properly. This PR improves collectSignatures to find all the types, and also to sort them by the dependencies between them (as we can't emit X in the binary if it depends on Y, and Y has not been emitted - we need to give Y's index). This sorting ends up changing a few test outputs. InstrumentLocals support for printing function types that are not funcref is disabled for now, until we figure out how to make that work and/or decide if it's important enough to work on. The fuzzer has various fixes to emit valid types for things (mostly whitespace there). Also two drive-by fixes to call makeTrivial where it should be (when we fail to create a specific node, we can't just try to make another node, in theory it could infinitely recurse). Binary writing changes here to replace calls to a standalone function to write out a type with one that is called on the binary writer object itself, which maintains a mapping of type indexes (getFunctionSignatureByIndex).
1 parent 6829433 commit b2d797f

31 files changed

+1262
-1098
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ full changeset diff at the end of each section.
1515
Current Trunk
1616
-------------
1717

18+
- `RefFunc` C and JS API constructors (`BinaryenRefFunc` and `ref.func`
19+
respectively) now take an extra `type` parameter, similar to `RefNull`. This
20+
is necessary for typed function references support.
1821
- JS API functions for atomic notify/wait instructions are renamed.
1922
- `module.atomic.notify` -> `module.memory.atomic.notify`
2023
- `module.i32.atomic.wait` -> `module.memory.atomic.wait32`

src/binaryen-c.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,9 +1187,11 @@ BinaryenExpressionRef BinaryenRefIsNull(BinaryenModuleRef module,
11871187
Builder(*(Module*)module).makeRefIsNull((Expression*)value));
11881188
}
11891189

1190-
BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module,
1191-
const char* func) {
1192-
return static_cast<Expression*>(Builder(*(Module*)module).makeRefFunc(func));
1190+
BinaryenExpressionRef
1191+
BinaryenRefFunc(BinaryenModuleRef module, const char* func, BinaryenType type) {
1192+
Type type_(type);
1193+
return static_cast<Expression*>(
1194+
Builder(*(Module*)module).makeRefFunc(func, type_));
11931195
}
11941196

11951197
BinaryenExpressionRef BinaryenRefEq(BinaryenModuleRef module,

src/binaryen-c.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,8 @@ BINARYEN_API BinaryenExpressionRef BinaryenRefNull(BinaryenModuleRef module,
792792
BINARYEN_API BinaryenExpressionRef
793793
BinaryenRefIsNull(BinaryenModuleRef module, BinaryenExpressionRef value);
794794
BINARYEN_API BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module,
795-
const char* func);
795+
const char* func,
796+
BinaryenType type);
796797
BINARYEN_API BinaryenExpressionRef BinaryenRefEq(BinaryenModuleRef module,
797798
BinaryenExpressionRef left,
798799
BinaryenExpressionRef right);

src/ir/ReFinalize.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,11 @@ void ReFinalize::visitMemorySize(MemorySize* curr) { curr->finalize(); }
126126
void ReFinalize::visitMemoryGrow(MemoryGrow* curr) { curr->finalize(); }
127127
void ReFinalize::visitRefNull(RefNull* curr) { curr->finalize(); }
128128
void ReFinalize::visitRefIsNull(RefIsNull* curr) { curr->finalize(); }
129-
void ReFinalize::visitRefFunc(RefFunc* curr) { curr->finalize(); }
129+
void ReFinalize::visitRefFunc(RefFunc* curr) {
130+
// TODO: should we look up the function and update the type from there? This
131+
// could handle a change to the function's type, but is also not really what
132+
// this class has been meant to do.
133+
}
130134
void ReFinalize::visitRefEq(RefEq* curr) { curr->finalize(); }
131135
void ReFinalize::visitTry(Try* curr) { curr->finalize(); }
132136
void ReFinalize::visitThrow(Throw* curr) { curr->finalize(); }

src/ir/module-utils.h

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -414,16 +414,29 @@ collectSignatures(Module& wasm,
414414
Counts& counts;
415415

416416
TypeCounter(Counts& counts) : counts(counts) {}
417+
417418
void visitExpression(Expression* curr) {
418-
if (auto* call = curr->dynCast<CallIndirect>()) {
419+
if (curr->is<RefNull>()) {
420+
maybeNote(curr->type);
421+
} else if (auto* call = curr->dynCast<CallIndirect>()) {
419422
counts[call->sig]++;
420423
} else if (Properties::isControlFlowStructure(curr)) {
421-
// TODO: Allow control flow to have input types as well
424+
maybeNote(curr->type);
422425
if (curr->type.isTuple()) {
426+
// TODO: Allow control flow to have input types as well
423427
counts[Signature(Type::none, curr->type)]++;
424428
}
425429
}
426430
}
431+
432+
void maybeNote(Type type) {
433+
if (type.isRef()) {
434+
auto heapType = type.getHeapType();
435+
if (heapType.isSignature()) {
436+
counts[heapType.getSignature()]++;
437+
}
438+
}
439+
}
427440
};
428441
TypeCounter(counts).walk(func->body);
429442
};
@@ -434,6 +447,14 @@ collectSignatures(Module& wasm,
434447
Counts counts;
435448
for (auto& curr : wasm.functions) {
436449
counts[curr->sig]++;
450+
for (auto type : curr->vars) {
451+
if (type.isRef()) {
452+
auto heapType = type.getHeapType();
453+
if (heapType.isSignature()) {
454+
counts[heapType.getSignature()]++;
455+
}
456+
}
457+
}
437458
}
438459
for (auto& curr : wasm.events) {
439460
counts[curr->sig]++;
@@ -444,10 +465,61 @@ collectSignatures(Module& wasm,
444465
counts[innerPair.first] += innerPair.second;
445466
}
446467
}
468+
469+
// TODO: recursively traverse each reference type, which may have a child type
470+
// that is itself a reference type.
471+
472+
// We must sort all the dependencies of a signature before it. For example,
473+
// (func (param (ref (func)))) must appear after (func). To do that, find the
474+
// depth of dependencies of each signature. For example, if A depends on B
475+
// which depends on C, then A's depth is 2, B's is 1, and C's is 0 (assuming
476+
// no other dependencies).
477+
Counts depthOfDependencies;
478+
std::unordered_map<Signature, std::unordered_set<Signature>> isDependencyOf;
479+
// To calculate the depth of dependencies, we'll do a flow analysis, visiting
480+
// each signature as we find out new things about it.
481+
std::set<Signature> toVisit;
482+
for (auto& pair : counts) {
483+
auto sig = pair.first;
484+
depthOfDependencies[sig] = 0;
485+
toVisit.insert(sig);
486+
for (Type type : {sig.params, sig.results}) {
487+
for (auto element : type) {
488+
if (element.isRef()) {
489+
auto heapType = element.getHeapType();
490+
if (heapType.isSignature()) {
491+
isDependencyOf[heapType.getSignature()].insert(sig);
492+
}
493+
}
494+
}
495+
}
496+
}
497+
while (!toVisit.empty()) {
498+
auto iter = toVisit.begin();
499+
auto sig = *iter;
500+
toVisit.erase(iter);
501+
// Anything that depends on this has a depth of dependencies equal to this
502+
// signature's, plus this signature itself.
503+
auto newDepth = depthOfDependencies[sig] + 1;
504+
if (newDepth > counts.size()) {
505+
Fatal() << "Cyclic signatures detected, cannot sort them.";
506+
}
507+
for (auto& other : isDependencyOf[sig]) {
508+
if (depthOfDependencies[other] < newDepth) {
509+
// We found something new to propagate.
510+
depthOfDependencies[other] = newDepth;
511+
toVisit.insert(other);
512+
}
513+
}
514+
}
515+
// Sort by frequency and then simplicity, and also keeping every signature
516+
// before things that depend on it.
447517
std::vector<std::pair<Signature, size_t>> sorted(counts.begin(),
448518
counts.end());
449519
std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
450-
// order by frequency then simplicity
520+
if (depthOfDependencies[a.first] != depthOfDependencies[b.first]) {
521+
return depthOfDependencies[a.first] < depthOfDependencies[b.first];
522+
}
451523
if (a.second != b.second) {
452524
return a.second > b.second;
453525
}

src/js/binaryen.js-post.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2112,8 +2112,8 @@ function wrapModule(module, self = {}) {
21122112
'is_null'(value) {
21132113
return Module['_BinaryenRefIsNull'](module, value);
21142114
},
2115-
'func'(func) {
2116-
return preserveStack(() => Module['_BinaryenRefFunc'](module, strToStack(func)));
2115+
'func'(func, type) {
2116+
return preserveStack(() => Module['_BinaryenRefFunc'](module, strToStack(func), type));
21172117
},
21182118
'eq'(left, right) {
21192119
return Module['_BinaryenRefEq'](module, left, right);

src/passes/InstrumentLocals.cpp

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -135,45 +135,48 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
135135
Builder builder(*getModule());
136136
Name import;
137137
auto type = curr->value->type;
138-
if (type.isFunction()) {
139-
import = set_funcref;
140-
} else {
141-
TODO_SINGLE_COMPOUND(curr->value->type);
142-
switch (type.getBasic()) {
143-
case Type::i32:
144-
import = set_i32;
145-
break;
146-
case Type::i64:
147-
return; // TODO
148-
case Type::f32:
149-
import = set_f32;
150-
break;
151-
case Type::f64:
152-
import = set_f64;
153-
break;
154-
case Type::v128:
155-
import = set_v128;
156-
break;
157-
case Type::externref:
158-
import = set_externref;
159-
break;
160-
case Type::exnref:
161-
import = set_exnref;
162-
break;
163-
case Type::anyref:
164-
import = set_anyref;
165-
break;
166-
case Type::eqref:
167-
import = set_eqref;
168-
break;
169-
case Type::i31ref:
170-
import = set_i31ref;
171-
break;
172-
case Type::unreachable:
173-
return; // nothing to do here
174-
default:
175-
WASM_UNREACHABLE("unexpected type");
176-
}
138+
if (type.isFunction() && type != Type::funcref) {
139+
// FIXME: support typed function references
140+
return;
141+
}
142+
TODO_SINGLE_COMPOUND(curr->value->type);
143+
switch (type.getBasic()) {
144+
case Type::i32:
145+
import = set_i32;
146+
break;
147+
case Type::i64:
148+
return; // TODO
149+
case Type::f32:
150+
import = set_f32;
151+
break;
152+
case Type::f64:
153+
import = set_f64;
154+
break;
155+
case Type::v128:
156+
import = set_v128;
157+
break;
158+
case Type::funcref:
159+
import = set_funcref;
160+
break;
161+
case Type::externref:
162+
import = set_externref;
163+
break;
164+
case Type::exnref:
165+
import = set_exnref;
166+
break;
167+
case Type::anyref:
168+
import = set_anyref;
169+
break;
170+
case Type::eqref:
171+
import = set_eqref;
172+
break;
173+
case Type::i31ref:
174+
import = set_i31ref;
175+
break;
176+
case Type::unreachable:
177+
return; // nothing to do here
178+
default:
179+
WASM_UNREACHABLE("unexpected type");
177180
}
178181
curr->value = builder.makeCall(import,
179182
{builder.makeConst(int32_t(id++)),

src/tools/fuzzing.h

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,10 @@ class TranslateToFuzzReader {
321321
}
322322
return Type(types);
323323
}
324+
if (type.isFunction() && type != Type::funcref) {
325+
// TODO: specific typed function references types.
326+
return type;
327+
}
324328
SmallVector<Type, 2> options;
325329
options.push_back(type); // includes itself
326330
TODO_SINGLE_COMPOUND(type);
@@ -653,6 +657,10 @@ class TranslateToFuzzReader {
653657
Index numVars = upToSquared(MAX_VARS);
654658
for (Index i = 0; i < numVars; i++) {
655659
auto type = getConcreteType();
660+
if (type.isRef() && !type.isNullable()) {
661+
// We can't use a non-nullable type as a var, which is null-initialized.
662+
continue;
663+
}
656664
funcContext->typeLocals[type].push_back(params.size() +
657665
func->vars.size());
658666
func->vars.push_back(type);
@@ -1371,7 +1379,6 @@ class TranslateToFuzzReader {
13711379
}
13721380

13731381
Expression* makeCall(Type type) {
1374-
// seems ok, go on
13751382
int tries = TRIES;
13761383
bool isReturn;
13771384
while (tries-- > 0) {
@@ -1392,7 +1399,7 @@ class TranslateToFuzzReader {
13921399
return builder.makeCall(target->name, args, type, isReturn);
13931400
}
13941401
// we failed to find something
1395-
return make(type);
1402+
return makeTrivial(type);
13961403
}
13971404

13981405
Expression* makeCallIndirect(Type type) {
@@ -1418,7 +1425,7 @@ class TranslateToFuzzReader {
14181425
i = 0;
14191426
}
14201427
if (i == start) {
1421-
return make(type);
1428+
return makeTrivial(type);
14221429
}
14231430
}
14241431
// with high probability, make sure the type is valid otherwise, most are
@@ -2018,12 +2025,28 @@ class TranslateToFuzzReader {
20182025
if (!wasm.functions.empty() && !oneIn(wasm.functions.size())) {
20192026
target = pick(wasm.functions).get();
20202027
}
2021-
return builder.makeRefFunc(target->name);
2028+
auto type = Type(HeapType(target->sig), /* nullable = */ true);
2029+
return builder.makeRefFunc(target->name, type);
20222030
}
20232031
if (type == Type::i31ref) {
20242032
return builder.makeI31New(makeConst(Type::i32));
20252033
}
2026-
return builder.makeRefNull(type);
2034+
if (oneIn(2) && type.isNullable()) {
2035+
return builder.makeRefNull(type);
2036+
}
2037+
// TODO: randomize the order
2038+
for (auto& func : wasm.functions) {
2039+
// FIXME: RefFunc type should be non-nullable, but we emit nullable
2040+
// types for now.
2041+
if (type == Type(HeapType(func->sig), /* nullable = */ true)) {
2042+
return builder.makeRefFunc(func->name, type);
2043+
}
2044+
}
2045+
// We failed to find a function, so create a null reference if we can.
2046+
if (type.isNullable()) {
2047+
return builder.makeRefNull(type);
2048+
}
2049+
WASM_UNREACHABLE("un-handleable non-nullable type");
20272050
}
20282051
if (type.isTuple()) {
20292052
std::vector<Expression*> operands;
@@ -2972,6 +2995,7 @@ class TranslateToFuzzReader {
29722995
Type::anyref,
29732996
Type::eqref,
29742997
Type::i31ref));
2998+
// TODO: emit typed function references types
29752999
}
29763000

29773001
Type getSingleConcreteType() { return pick(getSingleConcreteTypes()); }
@@ -2997,12 +3021,24 @@ class TranslateToFuzzReader {
29973021

29983022
Type getEqReferenceType() { return pick(getEqReferenceTypes()); }
29993023

3024+
Type getMVPType() {
3025+
return pick(items(FeatureOptions<Type>().add(
3026+
FeatureSet::MVP, Type::i32, Type::i64, Type::f32, Type::f64)));
3027+
}
3028+
30003029
Type getTupleType() {
30013030
std::vector<Type> elements;
3002-
size_t numElements = 2 + upTo(MAX_TUPLE_SIZE - 1);
3003-
elements.resize(numElements);
3004-
for (size_t i = 0; i < numElements; ++i) {
3005-
elements[i] = getSingleConcreteType();
3031+
size_t maxElements = 2 + upTo(MAX_TUPLE_SIZE - 1);
3032+
for (size_t i = 0; i < maxElements; ++i) {
3033+
auto type = getSingleConcreteType();
3034+
// Don't add a non-nullable type into a tuple, as currently we can't spill
3035+
// them into locals (that would require a "let").
3036+
if (!type.isNullable()) {
3037+
elements.push_back(type);
3038+
}
3039+
}
3040+
while (elements.size() < 2) {
3041+
elements.push_back(getMVPType());
30063042
}
30073043
return Type(elements);
30083044
}

src/tools/tool-options.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ struct ToolOptions : public Options {
8989
.addFeature(FeatureSet::Multivalue, "multivalue functions")
9090
.addFeature(FeatureSet::GC, "garbage collection")
9191
.addFeature(FeatureSet::Memory64, "memory64")
92+
.addFeature(FeatureSet::TypedFunctionReferences,
93+
"typed function references")
9294
.add("--no-validation",
9395
"-n",
9496
"Disables validation, assumes inputs are correct",

0 commit comments

Comments
 (0)