diff --git a/src/asm2wasm.h b/src/asm2wasm.h index 319116a4284..b30a1f6daa8 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -2008,7 +2008,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->offset = 0; ret->align = view.bytes; ret->ptr = processUnshifted(ast[2], view.bytes); - ret->type = getType(view.bytes, !view.integer); + ret->type = Type::get(view.bytes, !view.integer); return ret; } else if (what == UNARY_PREFIX) { if (ast[1] == PLUS) { diff --git a/src/ir/ExpressionAnalyzer.cpp b/src/ir/ExpressionAnalyzer.cpp index 5128f525bf4..7355d1856fc 100644 --- a/src/ir/ExpressionAnalyzer.cpp +++ b/src/ir/ExpressionAnalyzer.cpp @@ -151,7 +151,7 @@ template void visitImmediates(Expression* curr, T& visitor) { } void visitLoad(Load* curr) { visitor.visitInt(curr->bytes); - if (curr->type != unreachable && curr->bytes < getTypeSize(curr->type)) { + if (curr->type != unreachable && curr->bytes < curr->type.getByteSize()) { visitor.visitInt(curr->signed_); } visitor.visitAddress(curr->offset); diff --git a/src/ir/load-utils.h b/src/ir/load-utils.h index 037f5297b19..a3bf79c6020 100644 --- a/src/ir/load-utils.h +++ b/src/ir/load-utils.h @@ -31,7 +31,7 @@ inline bool isSignRelevant(Load* load) { if (load->type == unreachable) { return false; } - return !type.isFloat() && load->bytes < getTypeSize(type); + return !type.isFloat() && load->bytes < type.getByteSize(); } // check if a load can be signed (which some opts want to do) diff --git a/src/passes/Asyncify.cpp b/src/passes/Asyncify.cpp index 8e583863cda..8f929d9fe2f 100644 --- a/src/passes/Asyncify.cpp +++ b/src/passes/Asyncify.cpp @@ -1090,7 +1090,7 @@ struct AsyncifyLocals : public WalkerPass> { Index total = 0; for (Index i = 0; i < numPreservableLocals; i++) { auto type = func->getLocalType(i); - auto size = getTypeSize(type); + auto size = type.getByteSize(); total += size; } auto* block = builder->makeBlock(); @@ -1101,7 +1101,7 @@ struct AsyncifyLocals : public WalkerPass> { Index offset = 0; for 
(Index i = 0; i < numPreservableLocals; i++) { auto type = func->getLocalType(i); - auto size = getTypeSize(type); + auto size = type.getByteSize(); assert(size % STACK_ALIGN == 0); // TODO: higher alignment? block->list.push_back(builder->makeLocalSet( @@ -1130,7 +1130,7 @@ struct AsyncifyLocals : public WalkerPass> { Index offset = 0; for (Index i = 0; i < numPreservableLocals; i++) { auto type = func->getLocalType(i); - auto size = getTypeSize(type); + auto size = type.getByteSize(); assert(size % STACK_ALIGN == 0); // TODO: higher alignment? block->list.push_back( diff --git a/src/passes/AvoidReinterprets.cpp b/src/passes/AvoidReinterprets.cpp index 5aa1d338b8a..f1b3d96d8a9 100644 --- a/src/passes/AvoidReinterprets.cpp +++ b/src/passes/AvoidReinterprets.cpp @@ -32,7 +32,8 @@ static bool canReplaceWithReinterpret(Load* load) { // a reinterpret of the same address. A partial load would see // more bytes and possibly invalid data, and an unreachable // pointer is just not interesting to handle. - return load->type != unreachable && load->bytes == getTypeSize(load->type); + return load->type != Type::unreachable && + load->bytes == load->type.getByteSize(); } static Load* getSingleLoad(LocalGraph* localGraph, LocalGet* get) { @@ -116,7 +117,7 @@ struct AvoidReinterprets : public WalkerPass> { // We should use another load here, to avoid reinterprets. info.ptrLocal = Builder::addVar(func, i32); info.reinterpretedLocal = - Builder::addVar(func, reinterpretType(load->type)); + Builder::addVar(func, load->type.reinterpret()); } else { unoptimizables.insert(load); } @@ -150,8 +151,8 @@ struct AvoidReinterprets : public WalkerPass> { auto& info = iter->second; // A reinterpret of a get of a load - use the new local. 
Builder builder(*module); - replaceCurrent(builder.makeLocalGet( - info.reinterpretedLocal, reinterpretType(load->type))); + replaceCurrent(builder.makeLocalGet(info.reinterpretedLocal, + load->type.reinterpret())); } } } @@ -185,7 +186,7 @@ struct AvoidReinterprets : public WalkerPass> { load->offset, load->align, ptr, - reinterpretType(load->type)); + load->type.reinterpret()); } } finalOptimizer(infos, localGraph, getModule()); diff --git a/src/passes/ConstHoisting.cpp b/src/passes/ConstHoisting.cpp index 9eb8538781e..dbb3853d846 100644 --- a/src/passes/ConstHoisting.cpp +++ b/src/passes/ConstHoisting.cpp @@ -88,7 +88,7 @@ struct ConstHoisting : public WalkerPass> { } case f32: case f64: { - size = getTypeSize(value.type); + size = value.type.getByteSize(); break; } case v128: // v128 not implemented yet diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 3bad8f9604c..5efd1fd28a9 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -186,7 +186,7 @@ struct PrintExpressionContents o << ".atomic"; } o << ".load"; - if (curr->type != unreachable && curr->bytes < getTypeSize(curr->type)) { + if (curr->type != unreachable && curr->bytes < curr->type.getByteSize()) { if (curr->bytes == 1) { o << '8'; } else if (curr->bytes == 2) { @@ -233,7 +233,7 @@ struct PrintExpressionContents } static void printRMWSize(std::ostream& o, Type type, uint8_t bytes) { prepareColor(o) << forceConcrete(type) << ".atomic.rmw"; - if (type != unreachable && bytes != getTypeSize(type)) { + if (type != unreachable && bytes != type.getByteSize()) { if (bytes == 1) { o << '8'; } else if (bytes == 2) { @@ -269,7 +269,7 @@ struct PrintExpressionContents o << "xchg"; break; } - if (curr->type != unreachable && curr->bytes != getTypeSize(curr->type)) { + if (curr->type != unreachable && curr->bytes != curr->type.getByteSize()) { o << "_u"; } restoreNormalColor(o); @@ -281,7 +281,7 @@ struct PrintExpressionContents prepareColor(o); printRMWSize(o, curr->type, curr->bytes); o << 
"cmpxchg"; - if (curr->type != unreachable && curr->bytes != getTypeSize(curr->type)) { + if (curr->type != unreachable && curr->bytes != curr->type.getByteSize()) { o << "_u"; } restoreNormalColor(o); diff --git a/src/passes/SafeHeap.cpp b/src/passes/SafeHeap.cpp index 37610a90b7f..fc6706f3e54 100644 --- a/src/passes/SafeHeap.cpp +++ b/src/passes/SafeHeap.cpp @@ -176,7 +176,7 @@ struct SafeHeap : public Pass { load.type = type; for (Index bytes : {1, 2, 4, 8, 16}) { load.bytes = bytes; - if (bytes > getTypeSize(type) || (type == f32 && bytes != 4) || + if (bytes > type.getByteSize() || (type == f32 && bytes != 4) || (type == f64 && bytes != 8) || (type == v128 && bytes != 16)) { continue; } @@ -212,7 +212,7 @@ struct SafeHeap : public Pass { store.type = none; for (Index bytes : {1, 2, 4, 8, 16}) { store.bytes = bytes; - if (bytes > getTypeSize(valueType) || + if (bytes > valueType.getByteSize() || (valueType == f32 && bytes != 4) || (valueType == f64 && bytes != 8) || (valueType == v128 && bytes != 16)) { diff --git a/src/passes/SpillPointers.cpp b/src/passes/SpillPointers.cpp index 10fa588452e..758ca322336 100644 --- a/src/passes/SpillPointers.cpp +++ b/src/passes/SpillPointers.cpp @@ -78,7 +78,7 @@ struct SpillPointers PointerMap pointerMap; for (Index i = 0; i < func->getNumLocals(); i++) { if (func->getLocalType(i) == ABI::PointerType) { - auto offset = pointerMap.size() * getTypeSize(ABI::PointerType); + auto offset = pointerMap.size() * ABI::PointerType.getByteSize(); pointerMap[i] = offset; } } @@ -140,7 +140,7 @@ struct SpillPointers // get the stack space, and set the local to it ABI::getStackSpace(spillLocal, func, - getTypeSize(ABI::PointerType) * pointerMap.size(), + ABI::PointerType.getByteSize() * pointerMap.size(), *getModule()); } } @@ -184,9 +184,9 @@ struct SpillPointers // add the spills for (auto index : toSpill) { block->list.push_back( - builder.makeStore(getTypeSize(ABI::PointerType), + builder.makeStore(ABI::PointerType.getByteSize(), 
pointerMap[index], - getTypeSize(ABI::PointerType), + ABI::PointerType.getByteSize(), builder.makeLocalGet(spillLocal, ABI::PointerType), builder.makeLocalGet(index, ABI::PointerType), ABI::PointerType)); diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index bda58de9d97..571f0d1a538 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -1719,7 +1719,7 @@ template class ModuleInstanceBase { if (timeout.breaking()) { return timeout; } - auto bytes = getTypeSize(curr->expectedType); + auto bytes = curr->expectedType.getByteSize(); auto addr = instance.getFinalAddress(ptr.value, bytes); auto loaded = instance.doAtomicLoad(addr, bytes, curr->expectedType); NOTE_EVAL1(loaded); diff --git a/src/wasm-type.h b/src/wasm-type.h index c4d719a35d5..53ef39ef833 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -79,6 +79,21 @@ class Type { // Allows for using Types in switch statements constexpr operator uint32_t() const { return id; } + + // Returns the type size in bytes. Only single types are supported. + unsigned getByteSize() const; + + // Reinterpret an integer type to a float type with the same size and vice + // versa. Only single integer and float types are supported. + Type reinterpret() const; + + // Returns the feature set required to use this type. + FeatureSet getFeatures() const; + + // Returns a number type based on its size in bytes and whether it is a float + // type. 
+ static Type get(unsigned byteSize, bool float_); + std::string toString() const; }; @@ -123,11 +138,6 @@ constexpr Type anyref = Type::anyref; constexpr Type exnref = Type::exnref; constexpr Type unreachable = Type::unreachable; -unsigned getTypeSize(Type type); -FeatureSet getFeatures(Type type); -Type getType(unsigned size, bool float_); -Type reinterpretType(Type type); - } // namespace wasm template<> class std::hash { diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index fa90532d994..82eb51d7e39 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -2974,7 +2974,7 @@ bool WasmBinaryBuilder::maybeVisitAtomicWait(Expression*& out, uint8_t code) { curr->ptr = popNonVoidExpression(); Address readAlign; readMemoryAccess(readAlign, curr->offset); - if (readAlign != getTypeSize(curr->expectedType)) { + if (readAlign != curr->expectedType.getByteSize()) { throwError("Align of AtomicWait must match size"); } curr->finalize(); @@ -2994,7 +2994,7 @@ bool WasmBinaryBuilder::maybeVisitAtomicNotify(Expression*& out, uint8_t code) { curr->ptr = popNonVoidExpression(); Address readAlign; readMemoryAccess(readAlign, curr->offset); - if (readAlign != getTypeSize(curr->type)) { + if (readAlign != curr->type.getByteSize()) { throwError("Align of AtomicNotify must match size"); } curr->finalize(); diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 1319d80fc3e..20aff209124 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -1268,7 +1268,7 @@ SExpressionWasmBuilder::makeLoad(Element& s, Type type, bool isAtomic) { auto* ret = allocator.alloc(); ret->isAtomic = isAtomic; ret->type = type; - ret->bytes = parseMemBytes(extra, getTypeSize(type)); + ret->bytes = parseMemBytes(extra, type.getByteSize()); ret->signed_ = extra[0] && extra[1] == 's'; size_t i = parseMemAttributes(s, &ret->offset, &ret->align, ret->bytes); ret->ptr = parseExpression(s[i]); @@ -1282,7 +1282,7 @@ 
SExpressionWasmBuilder::makeStore(Element& s, Type type, bool isAtomic) { auto ret = allocator.alloc(); ret->isAtomic = isAtomic; ret->valueType = type; - ret->bytes = parseMemBytes(extra, getTypeSize(type)); + ret->bytes = parseMemBytes(extra, type.getByteSize()); size_t i = parseMemAttributes(s, &ret->offset, &ret->align, ret->bytes); ret->ptr = parseExpression(s[i]); ret->value = parseExpression(s[i + 1]); @@ -1294,7 +1294,7 @@ Expression* SExpressionWasmBuilder::makeAtomicRMWOrCmpxchg(Element& s, Type type) { const char* extra = findMemExtra( *s[0], 11 /* after "type.atomic.rmw" */, /* isAtomic = */ false); - auto bytes = parseMemBytes(extra, getTypeSize(type)); + auto bytes = parseMemBytes(extra, type.getByteSize()); extra = strchr(extra, '.'); // after the optional '_u' and before the opcode if (!extra) { throw ParseException("malformed atomic rmw instruction"); } diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index 7d29b0dcdaf..939ee8c9368 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -140,14 +140,85 @@ bool Type::operator<(const Type& other) const { [](const Type& a, const Type& b) { return uint32_t(a) < uint32_t(b); }); } -bool Signature::operator<(const Signature& other) const { - if (results < other.results) { - return true; - } else if (other.results < results) { - return false; - } else { - return params < other.params; +unsigned Type::getByteSize() const { + assert(isSingle() && "getByteSize only works with single types"); + Type singleType = *expand().begin(); + switch (singleType) { + case Type::i32: + return 4; + case Type::i64: + return 8; + case Type::f32: + return 4; + case Type::f64: + return 8; + case Type::v128: + return 16; + case Type::anyref: // anyref type is opaque + case Type::exnref: // exnref type is opaque + case Type::none: + case Type::unreachable: + WASM_UNREACHABLE("invalid type"); + } + WASM_UNREACHABLE("invalid type"); +} + +Type Type::reinterpret() const { + assert(isSingle() && "reinterpret 
only works with single types"); + Type singleType = *expand().begin(); + switch (singleType) { + case Type::i32: + return f32; + case Type::i64: + return f64; + case Type::f32: + return i32; + case Type::f64: + return i64; + case Type::v128: + case Type::anyref: + case Type::exnref: + case Type::none: + case Type::unreachable: + WASM_UNREACHABLE("invalid type"); + } + WASM_UNREACHABLE("invalid type"); +} + +FeatureSet Type::getFeatures() const { + FeatureSet feats = FeatureSet::MVP; + for (Type t : expand()) { + switch (t) { + case Type::v128: + feats |= FeatureSet::SIMD; + break; + case Type::anyref: + feats |= FeatureSet::ReferenceTypes; + break; + case Type::exnref: + feats |= FeatureSet::ExceptionHandling; + break; + default: + break; + } + } + return feats; +} + +Type Type::get(unsigned byteSize, bool float_) { + if (byteSize < 4) { + return Type::i32; + } + if (byteSize == 4) { + return float_ ? Type::f32 : Type::i32; + } + if (byteSize == 8) { + return float_ ? Type::f64 : Type::i64; + } + if (byteSize == 16) { + return Type::v128; + } + WASM_UNREACHABLE("invalid size"); } namespace { @@ -170,6 +241,22 @@ template std::string genericToString(const T& t) { } // anonymous namespace +std::string Type::toString() const { return genericToString(*this); } + +std::string ParamType::toString() const { return genericToString(*this); } + +std::string ResultType::toString() const { return genericToString(*this); } + +bool Signature::operator<(const Signature& other) const { + if (results < other.results) { + return true; + } else if (other.results < results) { + return false; + } else { + return params < other.params; + } +} + std::ostream& operator<<(std::ostream& os, Type type) { switch (type) { case Type::none: @@ -226,87 +313,4 @@ std::ostream& operator<<(std::ostream& os, Signature sig) { return os << "Signature(" << sig.params << " => " << sig.results << ")"; } -std::string Type::toString() const { return genericToString(*this); } - -std::string 
ParamType::toString() const { return genericToString(*this); } - -std::string ResultType::toString() const { return genericToString(*this); } - -unsigned getTypeSize(Type type) { - switch (type) { - case Type::i32: - return 4; - case Type::i64: - return 8; - case Type::f32: - return 4; - case Type::f64: - return 8; - case Type::v128: - return 16; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque - case Type::none: - case Type::unreachable: - WASM_UNREACHABLE("invalid type"); - } - WASM_UNREACHABLE("invalid type"); -} - -FeatureSet getFeatures(Type type) { - FeatureSet feats = FeatureSet::MVP; - for (Type t : type.expand()) { - switch (t) { - case v128: - feats |= FeatureSet::SIMD; - break; - case anyref: - feats |= FeatureSet::ReferenceTypes; - break; - case exnref: - feats |= FeatureSet::ExceptionHandling; - break; - default: - break; - } - } - return feats; -} - -Type getType(unsigned size, bool float_) { - if (size < 4) { - return Type::i32; - } - if (size == 4) { - return float_ ? Type::f32 : Type::i32; - } - if (size == 8) { - return float_ ? 
Type::f64 : Type::i64; - } - if (size == 16) { - return Type::v128; - } - WASM_UNREACHABLE("invalid size"); -} - -Type reinterpretType(Type type) { - switch (type) { - case Type::i32: - return f32; - case Type::i64: - return f64; - case Type::f32: - return i32; - case Type::f64: - return i64; - case Type::v128: - case Type::anyref: - case Type::exnref: - case Type::none: - case Type::unreachable: - WASM_UNREACHABLE("invalid type"); - } - WASM_UNREACHABLE("invalid type"); -} - } // namespace wasm diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index dcd56892c9f..55e115d958d 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -693,7 +693,7 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { } void FunctionValidator::visitConst(Const* curr) { - shouldBeTrue(getFeatures(curr->type) <= getModule()->features, + shouldBeTrue(curr->type.getFeatures() <= getModule()->features, curr, "all used features should be allowed"); } @@ -1763,15 +1763,15 @@ void FunctionValidator::visitFunction(Function* curr) { "Multivalue functions not allowed yet"); FeatureSet features; for (auto type : curr->sig.params.expand()) { - features |= getFeatures(type); + features |= type.getFeatures(); shouldBeTrue(type.isConcrete(), curr, "params must be concretely typed"); } for (auto type : curr->sig.results.expand()) { - features |= getFeatures(type); + features |= type.getFeatures(); shouldBeTrue(type.isConcrete(), curr, "results must be concretely typed"); } for (auto type : curr->vars) { - features |= getFeatures(type); + features |= type.getFeatures(); shouldBeTrue(type.isConcrete(), curr, "vars must be concretely typed"); } shouldBeTrue(features <= getModule()->features, @@ -2005,7 +2005,7 @@ static void validateExports(Module& module, ValidationInfo& info) { static void validateGlobals(Module& module, ValidationInfo& info) { ModuleUtils::iterDefinedGlobals(module, [&](Global* curr) { - info.shouldBeTrue(getFeatures(curr->type) <= 
module.features, + info.shouldBeTrue(curr->type.getFeatures() <= module.features, curr->name, "all used types should be allowed"); info.shouldBeTrue(