Skip to content

Commit e7df318

Browse files
committed
Add support for multiple tables and update CFI
This commits makes a number of changes to the WebAssembly format, some of which exceed the feature set desired for the MVP. (1) It adds support for updated table definitions, including the default, elementType, initial, and max attributes, plus a name. Currently, the initial and max attributes must be equal to the number of elements. The elementType attribute is interpreted as a FunctionType index, and type homogeneity is enforced on table elements, unless the specified FunctionType has name "anyfunc", which corresponds to a FunctionType with a none parameter and return type none. Format: (table <name> [default] <type> <entries>) (2) It adds support for multiple tables. If tables are used, currently the first table must be default, and the remainder must not. Example: (table "foo" default (type $FUNCSIG$i) $a) (table "bla" (type $anyfunc) $b $c $d) (3) Indirect calls have an immediate argument that specifies the index of the function call table. Example: (call_indirect "foo" $FUNCSIG$i (get_local $1)) (4) Corresponding upstream LLVM changes are required to use multiple tables, but the updated format is backwards compatible. Example: i32.call_indirect $0=, $pop0 i32.call_indirect.1 $0=, $pop0, $1, $2, $3 (5) Generating WebAssembly from code built with Clang/LLVM CFI now utilizes multiple tables. This is the only enabled use case for multiple tables; all others will default to a single table, if tables are used. The value passed in the .indidx assembler directive is now interpreted as the index of the indirect call table to assign.
1 parent 96e226d commit e7df318

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+528
-319
lines changed

src/asm2wasm.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -660,11 +660,11 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
660660
// TODO: when not using aliasing function pointers, we could merge them by noticing that
661661
// index 0 in each table is the null func, and each other index should only have one
662662
// non-null func. However, that breaks down when function pointer casts are emulated.
663-
functionTableStarts[name] = wasm.table.names.size(); // this table starts here
663+
functionTableStarts[name] = wasm.getDefaultTable()->values.size(); // this table starts here
664664
Ref contents = value[1];
665665
for (unsigned k = 0; k < contents->size(); k++) {
666666
IString curr = contents[k][1]->getIString();
667-
wasm.table.names.push_back(curr);
667+
wasm.getDefaultTable()->values.push_back(curr);
668668
}
669669
} else {
670670
abort_on("invalid var element", pair);
@@ -1404,6 +1404,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
14041404
}
14051405
// function pointers
14061406
auto ret = allocator.alloc<CallIndirect>();
1407+
ret->table = wasm.getDefaultTable()->name;
14071408
Ref target = ast[1];
14081409
assert(target[0] == SUB && target[1][0] == NAME && target[2][0] == BINARY && target[2][1] == AND && target[2][3][0] == NUM); // FUNCTION_TABLE[(expr) & mask]
14091410
ret->target = process(target[2]); // TODO: as an optimization, we could look through the mask

src/ast_utils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ struct ExpressionManipulator {
265265
return ret;
266266
}
267267
Expression* visitCallIndirect(CallIndirect *curr) {
268-
auto* ret = builder.makeCallIndirect(curr->fullType, curr->target, {}, curr->type);
268+
auto* ret = builder.makeCallIndirect(curr->table, curr->fullType, curr->target, {}, curr->type);
269269
for (Index i = 0; i < curr->operands.size(); i++) {
270270
ret->operands.push_back(copy(curr->operands[i]));
271271
}
@@ -459,6 +459,7 @@ struct ExpressionAnalyzer {
459459
break;
460460
}
461461
case Expression::Id::CallIndirectId: {
462+
CHECK(CallIndirect, table);
462463
PUSH(CallIndirect, target);
463464
CHECK(CallIndirect, fullType);
464465
CHECK(CallIndirect, operands.size());
@@ -661,6 +662,7 @@ struct ExpressionAnalyzer {
661662
break;
662663
}
663664
case Expression::Id::CallIndirectId: {
665+
HASH_NAME(CallIndirect, table);
664666
PUSH(CallIndirect, target);
665667
HASH_NAME(CallIndirect, fullType);
666668
HASH(CallIndirect, operands.size());

src/binaryen-c.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressio
586586
}
587587
BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value) {
588588
auto* ret = Builder(*((Module*)module)).makeReturn((Expression*)value);
589-
589+
590590
if (tracing) {
591591
auto id = noteExpression(ret);
592592
std::cout << " expressions[" << id << "] = BinaryenReturn(the_module, expressions[" << expressions[value] << "]);\n";
@@ -730,7 +730,7 @@ void BinaryenSetFunctionTable(BinaryenModuleRef module, BinaryenFunctionRef* fun
730730

731731
auto* wasm = (Module*)module;
732732
for (BinaryenIndex i = 0; i < numFuncs; i++) {
733-
wasm->table.names.push_back(((Function*)funcs[i])->name);
733+
wasm->getDefaultTable()->values.push_back(((Function*)funcs[i])->name);
734734
}
735735
}
736736

src/passes/DuplicateFunctionElimination.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,12 @@ struct DuplicateFunctionElimination : public Pass {
123123
replacerRunner.add<FunctionReplacer>(&replacements);
124124
replacerRunner.run();
125125
// replace in table
126-
for (auto& name : module->table.names) {
127-
auto iter = replacements.find(name);
128-
if (iter != replacements.end()) {
129-
name = iter->second;
126+
for (auto& curr : module->tables) {
127+
for (auto& name : curr->values) {
128+
auto iter = replacements.find(name);
129+
if (iter != replacements.end()) {
130+
name = iter->second;
131+
}
130132
}
131133
}
132134
// replace in start

src/passes/Print.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
222222
printCallBody(curr);
223223
}
224224
void visitCallIndirect(CallIndirect *curr) {
225-
printOpening(o, "call_indirect ") << curr->fullType;
225+
printOpening(o, "call_indirect ");
226+
printText(o, curr->table.str);
227+
o << ' ' << curr->fullType;
226228
incIndent();
227229
printFullLine(curr->target);
228230
for (auto operand : curr->operands) {
@@ -555,8 +557,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
555557
decIndent();
556558
}
557559
void visitTable(Table *curr) {
558-
printOpening(o, "table");
559-
for (auto name : curr->names) {
560+
printOpening(o, "table ");
561+
printText(o, curr->name.str) << ' ';
562+
if (curr->isDefault)
563+
o << "default" << ' ';
564+
visitFunctionType(curr->elementType, true);
565+
for (auto name : curr->values) {
560566
o << ' ';
561567
printName(name);
562568
}
@@ -621,9 +627,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
621627
visitExport(child.get());
622628
o << maybeNewLine;
623629
}
624-
if (curr->table.names.size() > 0) {
630+
for (auto& child : curr->tables) {
625631
doIndent(o, indent);
626-
visitTable(&curr->table);
632+
visitTable(child.get());
627633
o << maybeNewLine;
628634
}
629635
for (auto& child : curr->functions) {

src/passes/RemoveUnusedFunctions.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ struct RemoveUnusedFunctions : public Pass {
3939
root.push_back(module->getFunction(curr->value));
4040
}
4141
// For now, all functions that can be called indirectly are marked as roots.
42-
for (auto& curr : module->table.names) {
43-
root.push_back(module->getFunction(curr));
42+
for (auto& child : module->tables) {
43+
for (auto& curr : child->values) {
44+
root.push_back(module->getFunction(curr));
45+
}
4446
}
4547
// Compute function reachability starting from the root set.
4648
DirectCallGraphAnalyzer analyzer(module, root);

src/passes/ReorderFunctions.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ struct ReorderFunctions : public WalkerPass<PostWalker<ReorderFunctions, Visitor
3838
for (auto& curr : module->exports) {
3939
counts[curr->value]++;
4040
}
41-
for (auto& curr : module->table.names) {
42-
counts[curr]++;
41+
for (auto& child : module->tables) {
42+
for (auto& curr : child->values) {
43+
counts[curr]++;
44+
}
4345
}
4446
std::sort(module->functions.begin(), module->functions.end(), [this](
4547
const std::unique_ptr<Function>& a,

src/s2wasm.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ class S2WasmBuilder {
318318
return cashew::IString(str.c_str(), false);
319319
}
320320

321+
uint32_t getTable() {
322+
if (!match(".")) return 0;
323+
return getInt();
324+
}
325+
321326
std::vector<char> getQuoted() {
322327
assert(*s == '"');
323328
s++;
@@ -622,7 +627,7 @@ class S2WasmBuilder {
622627
};
623628
wasm::Builder builder(*wasm);
624629
std::vector<NameType> params;
625-
int64_t indirectIndex = -1;
630+
uint64_t indirectIndex = 0;
626631
WasmType resultType = none;
627632
std::vector<NameType> vars;
628633

@@ -643,9 +648,6 @@ class S2WasmBuilder {
643648
} else if (match(".indidx")) {
644649
indirectIndex = getInt64();
645650
skipWhitespace();
646-
if (indirectIndex < 0) {
647-
abort_on("indidx");
648-
}
649651
} else if (match(".local")) {
650652
while (1) {
651653
Name name = getNextId();
@@ -859,6 +861,7 @@ class S2WasmBuilder {
859861
auto makeCall = [&](WasmType type) {
860862
if (match("_indirect")) {
861863
// indirect call
864+
uint32_t table = getTable();
862865
Name assign = getAssign();
863866
int num = getNumInputs();
864867
auto inputs = getInputs(num);
@@ -867,7 +870,7 @@ class S2WasmBuilder {
867870
std::vector<Expression*> operands(++input, inputs.end());
868871
auto* funcType = ensureFunctionType(getSig(type, operands), wasm);
869872
assert(type == funcType->result);
870-
auto* indirect = builder.makeCallIndirect(funcType, target, std::move(operands));
873+
auto* indirect = builder.makeCallIndirect(linkerObj->getIndirectTable(table, funcType)->name, funcType, target, std::move(operands));
871874
setOutput(indirect, assign);
872875
} else {
873876
// non-indirect call

src/wasm-binary.h

Lines changed: 62 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
484484
writeSignatures();
485485
writeImports();
486486
writeFunctionSignatures();
487-
writeFunctionTable();
487+
writeFunctionTables();
488488
writeMemory();
489489
writeExports();
490490
writeStart();
@@ -559,14 +559,22 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
559559
finishSection(start);
560560
}
561561

562-
int32_t getFunctionTypeIndex(Name type) {
562+
Index getFunctionTypeIndex(Name type) {
563563
// TODO: optimize
564564
for (size_t i = 0; i < wasm->functionTypes.size(); i++) {
565565
if (wasm->functionTypes[i]->name == type) return i;
566566
}
567567
abort();
568568
}
569569

570+
Index getFunctionTableIndex(Name type) {
571+
// TODO: optimize
572+
for (size_t i = 0; i < wasm->tables.size(); i++) {
573+
if (wasm->tables[i]->name == type) return i;
574+
}
575+
abort();
576+
}
577+
570578
void writeImports() {
571579
if (wasm->imports.size() == 0) return;
572580
if (debug) std::cerr << "== writeImports" << std::endl;
@@ -670,7 +678,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
670678

671679
void writeExports() {
672680
if (wasm->exports.size() == 0) return;
673-
if (debug) std::cerr << "== writeexports" << std::endl;
681+
if (debug) std::cerr << "== writeExports" << std::endl;
674682
auto start = startSection(BinaryConsts::Section::ExportTable);
675683
o << U32LEB(wasm->exports.size());
676684
for (auto& curr : wasm->exports) {
@@ -709,8 +717,8 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
709717
assert(mappedImports.count(name));
710718
return mappedImports[name];
711719
}
712-
713-
std::map<Name, uint32_t> mappedFunctions; // name of the Function => index
720+
721+
std::map<Name, uint32_t> mappedFunctions; // name of the Function => entry index
714722
uint32_t getFunctionIndex(Name name) {
715723
if (!mappedFunctions.size()) {
716724
// Create name => index mapping.
@@ -723,13 +731,21 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
723731
return mappedFunctions[name];
724732
}
725733

726-
void writeFunctionTable() {
727-
if (wasm->table.names.size() == 0) return;
728-
if (debug) std::cerr << "== writeFunctionTable" << std::endl;
734+
void writeFunctionTables() {
735+
if (wasm->tables.size() == 0) return;
736+
if (debug) std::cerr << "== writeFunctionTables" << std::endl;
729737
auto start = startSection(BinaryConsts::Section::FunctionTable);
730-
o << U32LEB(wasm->table.names.size());
731-
for (auto name : wasm->table.names) {
732-
o << U32LEB(getFunctionIndex(name));
738+
o << U32LEB(wasm->tables.size());
739+
for (auto& curr : wasm->tables) {
740+
if (debug) std::cerr << "write one" << std::endl;
741+
o << int8_t(curr->isDefault);
742+
o << U32LEB(getFunctionTypeIndex(curr->elementType->name));
743+
assert(curr->initial == curr->values.size() && curr->initial == curr->max);
744+
o << U32LEB(curr->initial);
745+
o << U32LEB(curr->max);
746+
for (auto name : curr->values) {
747+
o << U32LEB(getFunctionIndex(name));
748+
}
733749
}
734750
finishSection(start);
735751
}
@@ -909,7 +925,7 @@ class WasmBinaryWriter : public Visitor<WasmBinaryWriter, void> {
909925
for (auto* operand : curr->operands) {
910926
recurse(operand);
911927
}
912-
o << int8_t(BinaryConsts::CallIndirect) << U32LEB(curr->operands.size()) << U32LEB(getFunctionTypeIndex(curr->fullType));
928+
o << int8_t(BinaryConsts::CallIndirect) << U32LEB(curr->operands.size()) << U32LEB(getFunctionTypeIndex(curr->fullType)) << U32LEB(getFunctionTableIndex(curr->table));
913929
}
914930
void visitGetLocal(GetLocal *curr) {
915931
if (debug) std::cerr << "zz node: GetLocal " << (o.size() + 1) << std::endl;
@@ -1231,7 +1247,7 @@ class WasmBinaryBuilder {
12311247
else if (match(BinaryConsts::Section::Functions)) readFunctions();
12321248
else if (match(BinaryConsts::Section::ExportTable)) readExports();
12331249
else if (match(BinaryConsts::Section::DataSegments)) readDataSegments();
1234-
else if (match(BinaryConsts::Section::FunctionTable)) readFunctionTable();
1250+
else if (match(BinaryConsts::Section::FunctionTable)) readFunctionTables();
12351251
else if (match(BinaryConsts::Section::Names)) readNames();
12361252
else {
12371253
std::cerr << "unfamiliar section: ";
@@ -1427,6 +1443,11 @@ class WasmBinaryBuilder {
14271443
assert(numResults == 1);
14281444
curr->result = getWasmType();
14291445
}
1446+
// TODO: Handle "anyfunc" properly. This sets the name to "anyfunc" if
1447+
// it does not already exist, and matches the expected type signature.
1448+
if (!wasm.checkFunctionType(FunctionType::kAnyFunc) && FunctionType::isAnyFuncType(curr)) {
1449+
curr->name = FunctionType::kAnyFunc;
1450+
}
14301451
wasm.addFunctionType(curr);
14311452
}
14321453
}
@@ -1441,7 +1462,7 @@ class WasmBinaryBuilder {
14411462
curr->name = Name(std::string("import$") + std::to_string(i));
14421463
auto index = getU32LEB();
14431464
assert(index < wasm.functionTypes.size());
1444-
curr->type = wasm.getFunctionType(index);
1465+
curr->type = wasm.functionTypes[index].get();
14451466
assert(curr->type->name.is());
14461467
curr->module = getInlineString();
14471468
curr->base = getInlineString();
@@ -1596,9 +1617,12 @@ class WasmBinaryBuilder {
15961617
}
15971618
}
15981619

1599-
for (size_t index : functionTable) {
1600-
assert(index < wasm.functions.size());
1601-
wasm.table.names.push_back(wasm.functions[index]->name);
1620+
for (auto& pair : functionTable) {
1621+
assert(pair.first < wasm.tables.size());
1622+
assert(pair.second < wasm.functions.size());
1623+
assert(wasm.tables[pair.first]->values.size() <= wasm.tables[pair.first]->max);
1624+
assert(wasm.tables[pair.first]->elementType->name == FunctionType::kAnyFunc || wasm.tables[pair.first]->elementType == wasm.getFunctionType(wasm.functions[pair.second]->type));
1625+
wasm.tables[pair.first]->values.push_back(wasm.functions[pair.second]->name);
16021626
}
16031627
}
16041628

@@ -1618,14 +1642,28 @@ class WasmBinaryBuilder {
16181642
}
16191643
}
16201644

1621-
std::vector<size_t> functionTable;
1645+
std::vector<std::pair<size_t, size_t>> functionTable;
16221646

1623-
void readFunctionTable() {
1624-
if (debug) std::cerr << "== readFunctionTable" << std::endl;
1625-
auto num = getU32LEB();
1626-
for (size_t i = 0; i < num; i++) {
1647+
void readFunctionTables() {
1648+
if (debug) std::cerr << "== readFunctionTables" << std::endl;
1649+
size_t numTables = getU32LEB();
1650+
for (size_t i = 0; i < numTables; i++) {
1651+
if (debug) std::cerr << "read one" << std::endl;
1652+
auto curr = new Table;
1653+
auto flag = getInt8();
1654+
assert((!i && flag) || (i && !flag));
1655+
curr->isDefault = flag;
16271656
auto index = getU32LEB();
1628-
functionTable.push_back(index);
1657+
assert(index < functionTypes.size());
1658+
curr->elementType = wasm.getFunctionType(index);
1659+
curr->initial = getU32LEB();
1660+
curr->max = getU32LEB();
1661+
assert(curr->initial == curr->max);
1662+
for (size_t j = 0; j < curr->initial; j++) {
1663+
auto index = getU32LEB();
1664+
functionTable.push_back(std::make_pair<>(i, index));
1665+
}
1666+
wasm.addTable(curr);
16291667
}
16301668
}
16311669

@@ -1843,6 +1881,7 @@ class WasmBinaryBuilder {
18431881
curr->fullType = fullType->name;
18441882
auto num = fullType->params.size();
18451883
assert(num == arity);
1884+
curr->table = wasm.getTable(getU32LEB())->name;
18461885
curr->operands.resize(num);
18471886
for (size_t i = 0; i < num; i++) {
18481887
curr->operands[num - i - 1] = popExpression();

0 commit comments

Comments
 (0)