Skip to content

Commit d78c794

Browse files
authored
Refactor getMaxBits() out of OptimizeInstructions and add beginnings of unit testing for it (#3019)
getMaxBits just moves around, no logic is changed. Aside from adding getMaxBits, the change in bits.h is 99% whitespace. helps #2879
1 parent b44c1af commit d78c794

File tree

3 files changed

+265
-223
lines changed

3 files changed

+265
-223
lines changed

src/ir/bits.h

Lines changed: 232 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -20,92 +20,255 @@
2020
#include "ir/literal-utils.h"
2121
#include "support/bits.h"
2222
#include "wasm-builder.h"
23+
#include <ir/load-utils.h>
2324

2425
namespace wasm {
2526

26-
struct Bits {
27-
// get a mask to keep only the low # of bits
28-
static int32_t lowBitMask(int32_t bits) {
29-
uint32_t ret = -1;
30-
if (bits >= 32) {
31-
return ret;
32-
}
33-
return ret >> (32 - bits);
27+
namespace Bits {
28+
29+
// get a mask to keep only the low # of bits
30+
inline int32_t lowBitMask(int32_t bits) {
31+
uint32_t ret = -1;
32+
if (bits >= 32) {
33+
return ret;
3434
}
35+
return ret >> (32 - bits);
36+
}
3537

36-
// checks if the input is a mask of lower bits, i.e., all 1s up to some high
37-
// bit, and all zeros from there. returns the number of masked bits, or 0 if
38-
// this is not such a mask
39-
static uint32_t getMaskedBits(uint32_t mask) {
40-
if (mask == uint32_t(-1)) {
41-
return 32; // all the bits
42-
}
43-
if (mask == 0) {
44-
return 0; // trivially not a mask
45-
}
46-
// otherwise, see if x & (x + 1) turns this into non-zero value
47-
// 00011111 & (00011111 + 1) => 0
48-
if (mask & (mask + 1)) {
49-
return 0;
50-
}
51-
// this is indeed a mask
52-
return 32 - CountLeadingZeroes(mask);
38+
// checks if the input is a mask of lower bits, i.e., all 1s up to some high
39+
// bit, and all zeros from there. returns the number of masked bits, or 0 if
40+
// this is not such a mask
41+
inline uint32_t getMaskedBits(uint32_t mask) {
42+
if (mask == uint32_t(-1)) {
43+
return 32; // all the bits
44+
}
45+
if (mask == 0) {
46+
return 0; // trivially not a mask
5347
}
48+
// otherwise, see if x & (x + 1) turns this into non-zero value
49+
// 00011111 & (00011111 + 1) => 0
50+
if (mask & (mask + 1)) {
51+
return 0;
52+
}
53+
// this is indeed a mask
54+
return 32 - CountLeadingZeroes(mask);
55+
}
56+
57+
// gets the number of effective shifts a shift operation does. In
58+
// wasm, only 5 bits matter for 32-bit shifts, and 6 for 64.
59+
inline Index getEffectiveShifts(Index amount, Type type) {
60+
if (type == Type::i32) {
61+
return amount & 31;
62+
} else if (type == Type::i64) {
63+
return amount & 63;
64+
}
65+
WASM_UNREACHABLE("unexpected type");
66+
}
5467

55-
// gets the number of effective shifts a shift operation does. In
56-
// wasm, only 5 bits matter for 32-bit shifts, and 6 for 64.
57-
static Index getEffectiveShifts(Index amount, Type type) {
58-
if (type == Type::i32) {
59-
return amount & 31;
60-
} else if (type == Type::i64) {
61-
return amount & 63;
68+
inline Index getEffectiveShifts(Expression* expr) {
69+
auto* amount = expr->cast<Const>();
70+
if (amount->type == Type::i32) {
71+
return getEffectiveShifts(amount->value.geti32(), Type::i32);
72+
} else if (amount->type == Type::i64) {
73+
return getEffectiveShifts(amount->value.geti64(), Type::i64);
74+
}
75+
WASM_UNREACHABLE("unexpected type");
76+
}
77+
78+
inline Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) {
79+
if (value->type == Type::i32) {
80+
if (bytes == 1 || bytes == 2) {
81+
auto shifts = bytes == 1 ? 24 : 16;
82+
Builder builder(wasm);
83+
return builder.makeBinary(
84+
ShrSInt32,
85+
builder.makeBinary(
86+
ShlInt32,
87+
value,
88+
LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)),
89+
LiteralUtils::makeFromInt32(shifts, Type::i32, wasm));
90+
}
91+
assert(bytes == 4);
92+
return value; // nothing to do
93+
} else {
94+
assert(value->type == Type::i64);
95+
if (bytes == 1 || bytes == 2 || bytes == 4) {
96+
auto shifts = bytes == 1 ? 56 : (bytes == 2 ? 48 : 32);
97+
Builder builder(wasm);
98+
return builder.makeBinary(
99+
ShrSInt64,
100+
builder.makeBinary(
101+
ShlInt64,
102+
value,
103+
LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)),
104+
LiteralUtils::makeFromInt32(shifts, Type::i64, wasm));
62105
}
63-
WASM_UNREACHABLE("unexpected type");
106+
assert(bytes == 8);
107+
return value; // nothing to do
64108
}
109+
}
65110

66-
static Index getEffectiveShifts(Expression* expr) {
67-
auto* amount = expr->cast<Const>();
68-
if (amount->type == Type::i32) {
69-
return getEffectiveShifts(amount->value.geti32(), Type::i32);
70-
} else if (amount->type == Type::i64) {
71-
return getEffectiveShifts(amount->value.geti64(), Type::i64);
111+
// getMaxBits() helper that has pessimistic results for the bits used in locals.
112+
struct DummyLocalInfoProvider {
113+
Index getMaxBitsForLocal(LocalGet* get) {
114+
if (get->type == Type::i32) {
115+
return 32;
72116
}
73-
WASM_UNREACHABLE("unexpected type");
117+
if (get->type == Type::i32) {
118+
return 64;
119+
}
120+
WASM_UNREACHABLE("type has no integer bit size");
74121
}
122+
};
75123

76-
static Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) {
77-
if (value->type == Type::i32) {
78-
if (bytes == 1 || bytes == 2) {
79-
auto shifts = bytes == 1 ? 24 : 16;
80-
Builder builder(wasm);
81-
return builder.makeBinary(
82-
ShrSInt32,
83-
builder.makeBinary(
84-
ShlInt32,
85-
value,
86-
LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)),
87-
LiteralUtils::makeFromInt32(shifts, Type::i32, wasm));
124+
// Returns the maximum amount of bits used in an integer expression
125+
// not extremely precise (doesn't look into add operands, etc.)
126+
// LocalInfoProvider is an optional class that can provide answers about
127+
// local.get.
128+
template<typename LocalInfoProvider = DummyLocalInfoProvider>
129+
Index getMaxBits(Expression* curr,
130+
LocalInfoProvider* localInfoProvider = nullptr) {
131+
if (auto* const_ = curr->dynCast<Const>()) {
132+
switch (curr->type.getSingle()) {
133+
case Type::i32:
134+
return 32 - const_->value.countLeadingZeroes().geti32();
135+
case Type::i64:
136+
return 64 - const_->value.countLeadingZeroes().geti64();
137+
default:
138+
WASM_UNREACHABLE("invalid type");
139+
}
140+
} else if (auto* binary = curr->dynCast<Binary>()) {
141+
switch (binary->op) {
142+
// 32-bit
143+
case AddInt32:
144+
case SubInt32:
145+
case MulInt32:
146+
case DivSInt32:
147+
case DivUInt32:
148+
case RemSInt32:
149+
case RemUInt32:
150+
case RotLInt32:
151+
case RotRInt32:
152+
return 32;
153+
case AndInt32:
154+
return std::min(getMaxBits(binary->left, localInfoProvider),
155+
getMaxBits(binary->right, localInfoProvider));
156+
case OrInt32:
157+
case XorInt32:
158+
return std::max(getMaxBits(binary->left, localInfoProvider),
159+
getMaxBits(binary->right, localInfoProvider));
160+
case ShlInt32: {
161+
if (auto* shifts = binary->right->dynCast<Const>()) {
162+
return std::min(Index(32),
163+
getMaxBits(binary->left, localInfoProvider) +
164+
Bits::getEffectiveShifts(shifts));
165+
}
166+
return 32;
88167
}
89-
assert(bytes == 4);
90-
return value; // nothing to do
91-
} else {
92-
assert(value->type == Type::i64);
93-
if (bytes == 1 || bytes == 2 || bytes == 4) {
94-
auto shifts = bytes == 1 ? 56 : (bytes == 2 ? 48 : 32);
95-
Builder builder(wasm);
96-
return builder.makeBinary(
97-
ShrSInt64,
98-
builder.makeBinary(
99-
ShlInt64,
100-
value,
101-
LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)),
102-
LiteralUtils::makeFromInt32(shifts, Type::i64, wasm));
168+
case ShrUInt32: {
169+
if (auto* shift = binary->right->dynCast<Const>()) {
170+
auto maxBits = getMaxBits(binary->left, localInfoProvider);
171+
auto shifts =
172+
std::min(Index(Bits::getEffectiveShifts(shift)),
173+
maxBits); // can ignore more shifts than zero us out
174+
return std::max(Index(0), maxBits - shifts);
175+
}
176+
return 32;
177+
}
178+
case ShrSInt32: {
179+
if (auto* shift = binary->right->dynCast<Const>()) {
180+
auto maxBits = getMaxBits(binary->left, localInfoProvider);
181+
if (maxBits == 32) {
182+
return 32;
183+
}
184+
auto shifts =
185+
std::min(Index(Bits::getEffectiveShifts(shift)),
186+
maxBits); // can ignore more shifts than zero us out
187+
return std::max(Index(0), maxBits - shifts);
188+
}
189+
return 32;
190+
}
191+
// 64-bit TODO
192+
// comparisons
193+
case EqInt32:
194+
case NeInt32:
195+
case LtSInt32:
196+
case LtUInt32:
197+
case LeSInt32:
198+
case LeUInt32:
199+
case GtSInt32:
200+
case GtUInt32:
201+
case GeSInt32:
202+
case GeUInt32:
203+
case EqInt64:
204+
case NeInt64:
205+
case LtSInt64:
206+
case LtUInt64:
207+
case LeSInt64:
208+
case LeUInt64:
209+
case GtSInt64:
210+
case GtUInt64:
211+
case GeSInt64:
212+
case GeUInt64:
213+
case EqFloat32:
214+
case NeFloat32:
215+
case LtFloat32:
216+
case LeFloat32:
217+
case GtFloat32:
218+
case GeFloat32:
219+
case EqFloat64:
220+
case NeFloat64:
221+
case LtFloat64:
222+
case LeFloat64:
223+
case GtFloat64:
224+
case GeFloat64:
225+
return 1;
226+
default: {
103227
}
104-
assert(bytes == 8);
105-
return value; // nothing to do
228+
}
229+
} else if (auto* unary = curr->dynCast<Unary>()) {
230+
switch (unary->op) {
231+
case ClzInt32:
232+
case CtzInt32:
233+
case PopcntInt32:
234+
return 6;
235+
case ClzInt64:
236+
case CtzInt64:
237+
case PopcntInt64:
238+
return 7;
239+
case EqZInt32:
240+
case EqZInt64:
241+
return 1;
242+
case WrapInt64:
243+
return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
244+
default: {
245+
}
246+
}
247+
} else if (auto* set = curr->dynCast<LocalSet>()) {
248+
// a tee passes through the value
249+
return getMaxBits(set->value, localInfoProvider);
250+
} else if (auto* get = curr->dynCast<LocalGet>()) {
251+
return localInfoProvider->getMaxBitsForLocal(get);
252+
} else if (auto* load = curr->dynCast<Load>()) {
253+
// if signed, then the sign-extension might fill all the bits
254+
// if unsigned, then we have a limit
255+
if (LoadUtils::isSignRelevant(load) && !load->signed_) {
256+
return 8 * load->bytes;
106257
}
107258
}
108-
};
259+
switch (curr->type.getSingle()) {
260+
case Type::i32:
261+
return 32;
262+
case Type::i64:
263+
return 64;
264+
case Type::unreachable:
265+
return 64; // not interesting, but don't crash
266+
default:
267+
WASM_UNREACHABLE("invalid type");
268+
}
269+
}
270+
271+
} // namespace Bits
109272

110273
} // namespace wasm
111274

0 commit comments

Comments
 (0)