Skip to content

[OptimizeForJS] Optimize 64-bit div by constant #4055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 42 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
bf6eb7e
init
MaxGraey Aug 4, 2021
414a0eb
fix
MaxGraey Aug 4, 2021
c985cdc
comments
MaxGraey Aug 4, 2021
95f3472
fix (wip)
MaxGraey Aug 4, 2021
86a42b3
wip
MaxGraey Aug 4, 2021
1cb12f6
fix
MaxGraey Aug 4, 2021
d282466
lint
MaxGraey Aug 4, 2021
37a416f
lint
MaxGraey Aug 4, 2021
860860c
more tests and fixes
MaxGraey Aug 4, 2021
14c8039
lint
MaxGraey Aug 4, 2021
5b1790d
cleanups
MaxGraey Aug 4, 2021
397919f
refactor
MaxGraey Aug 4, 2021
a3fca5f
special case for signed Power of Two divisors
MaxGraey Aug 5, 2021
5de26a0
fix
MaxGraey Aug 5, 2021
c544a2d
lint
MaxGraey Aug 5, 2021
2167b6e
more tests
MaxGraey Aug 5, 2021
2e789a7
comment reminder rules for now
MaxGraey Aug 5, 2021
59092b3
minor refactoring
MaxGraey Aug 5, 2021
0b924ba
fix
MaxGraey Aug 5, 2021
3de6a63
no skip
MaxGraey Aug 5, 2021
89eac9d
refactor
MaxGraey Aug 5, 2021
984b530
skip negative divisors for unsigned divs
MaxGraey Aug 5, 2021
49ec6dd
skip const-by-const divs
MaxGraey Aug 5, 2021
69457d7
fix
MaxGraey Aug 5, 2021
88f1f5f
Merge branch 'main' into opt-for-js-div-by-const
MaxGraey Aug 5, 2021
aa3e3cc
fixes
MaxGraey Aug 5, 2021
cda2348
lint
MaxGraey Aug 5, 2021
c5971d0
fix
MaxGraey Aug 5, 2021
b7dfb82
fix
MaxGraey Aug 5, 2021
19e63e5
fix
MaxGraey Aug 5, 2021
a8f78af
lint
MaxGraey Aug 5, 2021
448cfb8
Merge branch 'main' into opt-for-js-div-by-const
MaxGraey Aug 5, 2021
25ac50d
Merge branch 'main' into opt-for-js-div-by-const
MaxGraey Aug 5, 2021
bda17ad
clarify comment
MaxGraey Aug 5, 2021
2f5ea6a
add optimize instructions after optimize-for-js
MaxGraey Aug 6, 2021
be3c454
update wasm2js fixtures
MaxGraey Aug 6, 2021
6f3b04d
add test for i64(x) / -4
MaxGraey Aug 6, 2021
d97849e
add div by smin test
MaxGraey Aug 6, 2021
bafc225
embed full license content for header and cpp
MaxGraey Aug 9, 2021
014993c
suggestions
MaxGraey Aug 9, 2021
1ce29d5
add comment for wasm-intrinsics.wat
MaxGraey Aug 11, 2021
6ae253f
remove empry gap
MaxGraey Aug 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/asmjs/shared-constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ cashew::IString WASM_FETCH_HIGH_BITS("__wasm_fetch_high_bits");
cashew::IString INT64_TO_32_HIGH_BITS("i64toi32_i32$HIGH_BITS");
cashew::IString WASM_NEAREST_F32("__wasm_nearest_f32");
cashew::IString WASM_NEAREST_F64("__wasm_nearest_f64");
cashew::IString WASM_I64_MUL_HIGH("__wasm_i64_mulh");
cashew::IString WASM_I64_MUL("__wasm_i64_mul");
cashew::IString WASM_I64_SDIV("__wasm_i64_sdiv");
cashew::IString WASM_I64_UDIV("__wasm_i64_udiv");
Expand Down
1 change: 1 addition & 0 deletions src/asmjs/shared-constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ extern cashew::IString WASM_FETCH_HIGH_BITS;
extern cashew::IString INT64_TO_32_HIGH_BITS;
extern cashew::IString WASM_NEAREST_F32;
extern cashew::IString WASM_NEAREST_F64;
extern cashew::IString WASM_I64_MUL_HIGH;
extern cashew::IString WASM_I64_MUL;
extern cashew::IString WASM_I64_SDIV;
extern cashew::IString WASM_I64_UDIV;
Expand Down
289 changes: 287 additions & 2 deletions src/passes/OptimizeForJS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,94 @@
#include <pass.h>
#include <wasm.h>

#include "abi/js.h"
#include "asmjs/shared-constants.h"
#include "passes/intrinsics-module.h"
#include "wasm-builder.h"
#include "wasm-s-parser.h"
#include <ir/abstract.h>
#include <ir/literal-utils.h>
#include <ir/localize.h>
#include <ir/match.h>
#include <ir/module-utils.h>
#include <limits>
#include <support/bits.h>
#include <support/div-by-const.h>

namespace wasm {

struct OptimizeForJSPass : public WalkerPass<PostWalker<OptimizeForJSPass>> {
bool isFunctionParallel() override { return true; }
bool requireMulhIntrinsic;

bool isFunctionParallel() override { return false; }

Pass* create() override { return new OptimizeForJSPass; }

void doWalkModule(Module* module) {
super::doWalkModule(module);

if (requireMulhIntrinsic) {
Module intrinsics;
std::string input(IntrinsicsModuleWast);
SExpressionParser parser(const_cast<char*>(input.c_str()));
Element& root = *parser.root;
SExpressionWasmBuilder builder(intrinsics, *root[0], IRProfile::Normal);
auto* func = intrinsics.getFunction(WASM_I64_MUL_HIGH);
doWalkFunction(ModuleUtils::copyFunction(func, *module));
}
}

void visitBinary(Binary* curr) {
using namespace Abstract;
using namespace Match;
{
// Rewrite popcnt(x) == 1 ==> !!x & !(x & (x - 1))
// popcnt(x) == 1 ==> !!x & !(x & (x - 1))
Expression* x;
if (matches(curr, binary(Eq, unary(Popcnt, any(&x)), ival(1)))) {
rewritePopcntEqualOne(x);
}
}
{
// i64(x) / C ==> mulh(x, M') >> S'
// where M' and S' are magic constants
Const* c;
Expression* x;
if (matches(curr, binary(DivU, any(&x), i64(&c)))) {
requireMulhIntrinsic = true;
rewriteDivByConstU64(x, (uint64_t)c->value.geti64());
}
}
{
// i64(x) / C ==> mulh(x, M') >> S'
// where M' and S' are magic constants
Const* c;
Expression* x;
if (matches(curr, binary(DivS, any(&x), i64(&c)))) {
requireMulhIntrinsic = true;
rewriteDivByConstS64(x, c->value.geti64());
}
}
// TODO: implement same approach for reminders
// {
// // i64(x) % C ==> mulh(x, M') >> S'
// // where M' and S' are magic constants
// Const* c;
// Expression* x;
// if (matches(curr, binary(RemU, any(&x), i64(&c)))) {
// // requireMulhIntrinsic = true;
// rewriteRemByConstU64(x, (uint64_t)c->value.geti64());
// }
// }
// {
// // i64(x) % C ==> mulh(x, M') >> S'
// // where M' and S' are magic constants
// Const* c;
// Expression* x;
// if (matches(curr, binary(RemS, any(&x), i64(&c)))) {
// // requireMulhIntrinsic = true;
// rewriteRemByConstS64(x, c->value.geti64());
// }
// }
}

void rewritePopcntEqualOne(Expression* expr) {
Expand All @@ -65,6 +130,226 @@ struct OptimizeForJSPass : public WalkerPass<PostWalker<OptimizeForJSPass>> {
builder.makeLocalGet(temp.index, type),
builder.makeConst(Literal::makeOne(type.getBasic())))))));
}

void rewriteDivByConstU64(Expression* dividend, uint64_t divisor) {
// skip if divisor is power of two or negative value
// also skip if dividend is const. All this handled by optimize
// instructions pass.
if (Bits::isPowerOf2(divisor) || int64_t(divisor) < 0LL ||
dividend->is<Const>()) {
return;
}

Builder builder(*getModule());

if (divisor == 0) {
// dividend / 0 -> 0
//
// This valid case in JavaScript
replaceCurrent(builder.makeConst(uint64_t(0)));
return;
}

const unsigned shift = Bits::countTrailingZeroes(divisor);

Type type = dividend->type;
Localizer temp(dividend, getFunction(), getModule());
Index tempIndex = temp.index;

uint64_t shiftedDivisor = divisor;
Expression* shiftedDividend;

if (shift) {
shiftedDivisor >>= shift;
shiftedDividend =
builder.makeBinary(ShrUInt64,
builder.makeLocalGet(tempIndex, type),
builder.makeConst(uint64_t(shift)));
} else {
shiftedDividend = builder.makeLocalGet(tempIndex, type);
}

const auto payload = unsignedDivisionByConstant(shiftedDivisor, shift);

// quotient = mulh(dividend, M')
Expression* quotient =
builder.makeCall(WASM_I64_MUL_HIGH,
{shiftedDividend, builder.makeConst(payload.multiplier)},
type);

if (payload.add) {
// t1 = dividend - quotient
// t2 = (t1 >> 1) + quotient
// res = t2 >> (S' - 1)
assert(payload.shift > 0);
quotient = builder.makeBinary(
ShrUInt64,
builder.makeBinary(
AddInt64,
builder.makeBinary(
ShrUInt64,
builder.makeBinary(
SubInt64, builder.makeLocalGet(tempIndex, type), quotient),
builder.makeConst(uint64_t(1))),
builder.makeLocalGet(tempIndex, type)),
builder.makeConst(uint64_t(payload.shift - 1)));
} else {
// res = quot >> shift
quotient = builder.makeBinary(
ShrUInt64, quotient, builder.makeConst(uint64_t(payload.shift)));
}

// use following control flow logic:
//
// if (!(dividend >> 32)) { // if high word is empty
// return i64(i32(dividend) / i32(C)) // or 0 if C > 2 ** 32
// } else {
// return mulh(dividend, M') >> S'
// }
Expression* quotient32;
Expression* cond = builder.makeUnary(
EqZInt64,
builder.makeBinary(
ShrUInt64, temp.expr, builder.makeConst(uint64_t(32))));

if (divisor <= (uint64_t)std::numeric_limits<uint32_t>::max()) {
// i64(i32(dividend) / i32(C))
quotient32 = builder.makeUnary(
ExtendUInt32,
builder.makeBinary(
DivUInt32,
builder.makeUnary(WrapInt64, builder.makeLocalGet(tempIndex, type)),
builder.makeConst(uint32_t(divisor))));
} else {
// i32(dividend) / C, where C > 2 ** 32 -> 0
quotient32 = builder.makeConst(uint64_t(0));
}

replaceCurrent(builder.makeIf(cond, quotient32, quotient));
}

void rewriteDivByConstS64(Expression* dividend, int64_t divisor) {
if (divisor == std::numeric_limits<int64_t>::min() ||
dividend->is<Const>()) {
return;
}

if (divisor == 1LL) {
// dividend / 1 -> dividend
replaceCurrent(dividend);
return;
}

Builder builder(*getModule());

if (divisor == 0LL) {
// dividend / 0 -> 0
//
// This valid case in JavaScript
replaceCurrent(builder.makeConst(uint64_t(0)));
return;
}

if (divisor == -1LL) {
// dividend / -1 -> 0 - dividend
//
// Note: i64.min / -1 special case leads to overflow (trap) in WebAssembly
// but valid in JavaScript (0 - i64.min -> i64.min).
replaceCurrent(
builder.makeBinary(SubInt64, builder.makeConst(uint64_t(0)), dividend));
return;
}

Localizer temp(dividend, getFunction(), getModule());
Type type = dividend->type;
Index tempIndex = temp.index;

int64_t absoluteDivisor = std::abs(divisor);

if (Bits::isPowerOf2(absoluteDivisor)) {
// dividend / +C_pot ->
// +((x < 0 ? (x + (C_pot - 1)) : x) >> ctz(abs(C_pot)))
//
// dividend / -C_pot ->
// -((x < 0 ? (x + (C_pot - 1)) : x) >> ctz(abs(C_pot)))

// x < 0
Expression* cond =
builder.makeBinary(LtSInt64, temp.expr, builder.makeConst(int64_t(0)));
Expression* ifTrue =
builder.makeBinary(AddInt64,
builder.makeLocalGet(tempIndex, type),
builder.makeConst(int64_t(absoluteDivisor - 1LL)));
Expression* ifFalse = builder.makeLocalGet(tempIndex, type);

Expression* quotient = builder.makeBinary(
ShrSInt64,
builder.makeSelect(cond, ifTrue, ifFalse),
builder.makeConst(int64_t(Bits::countTrailingZeroes(absoluteDivisor))));

if (divisor < 0) {
quotient =
builder.makeBinary(SubInt64, builder.makeConst(int64_t(0)), quotient);
}

replaceCurrent(quotient);
return;
}

const auto payload = signedDivisionByConstant(uint64_t(divisor));

// quotient = mulh(dividend, M')
Expression* quotient =
builder.makeCall(WASM_I64_MUL_HIGH,
{builder.makeLocalGet(tempIndex, type),
builder.makeConst(payload.multiplier)},
type);

if (divisor > 0 && int64_t(payload.multiplier) < 0) {
quotient = builder.makeBinary(
AddInt64, quotient, builder.makeLocalGet(tempIndex, type));
} else if (divisor < 0 && int64_t(payload.multiplier) > 0) {
quotient = builder.makeBinary(
SubInt64, quotient, builder.makeLocalGet(tempIndex, type));
}

quotient = builder.makeBinary(
AddInt64,
builder.makeBinary(
ShrSInt64, quotient, builder.makeConst(int64_t(payload.shift))),
builder.makeBinary(ShrUInt64,
builder.makeLocalGet(tempIndex, type),
builder.makeConst(int64_t(63))));

// use following control flow logic:
//
// if (!(dividend >> 32)) { // if high word is empry
// return i64(i32(dividend) / i32(C)) // or 0 if C > 2 ** 32
// } else {
// return mulh(dividend, M') >> S'
// }
Expression* quotient32;
Expression* cond = builder.makeUnary(
EqZInt64,
builder.makeBinary(
ShrUInt64, temp.expr, builder.makeConst(uint64_t(32))));

if ((uint64_t)std::abs(divisor) <=
(uint64_t)std::numeric_limits<uint32_t>::max()) {
// i64(i32(dividend) / i32(C))
quotient32 = builder.makeUnary(
ExtendSInt32,
builder.makeBinary(
DivSInt32,
builder.makeUnary(WrapInt64, builder.makeLocalGet(tempIndex, type)),
builder.makeConst(int32_t(divisor))));
} else {
// i32(dividend) / C, where C > 2 ** 32 -> 0
quotient32 = builder.makeConst(int64_t(0));
}

replaceCurrent(builder.makeIf(cond, quotient32, quotient));
}
};

Pass* createOptimizeForJSPass() { return new OptimizeForJSPass(); }
Expand Down
Loading