Skip to content

Implement relaxed SIMD dot product instructions #4586

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion scripts/gen-s-parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,11 @@
("f32x4.relaxed_max", "makeBinary(s, BinaryOp::RelaxedMaxVecF32x4)"),
("f64x2.relaxed_min", "makeBinary(s, BinaryOp::RelaxedMinVecF64x2)"),
("f64x2.relaxed_max", "makeBinary(s, BinaryOp::RelaxedMaxVecF64x2)"),
("i16x8.relaxed_q15mulr_s", "makeBinary(s, BinaryOp::RelaxedQ15MulrSVecI16x8)"),
("i16x8.relaxed_q15mulr_s", "makeBinary(s, BinaryOp::RelaxedQ15MulrSVecI16x8)"),
("i16x8.dot_i8x16_i7x16_s", "makeBinary(s, BinaryOp::DotI8x16I7x16SToVecI16x8)"),
("i16x8.dot_i8x16_i7x16_u", "makeBinary(s, BinaryOp::DotI8x16I7x16UToVecI16x8)"),
("i32x4.dot_i8x16_i7x16_add_s", "makeSIMDTernary(s, SIMDTernaryOp::DotI8x16I7x16AddSToVecI32x4)"),
("i32x4.dot_i8x16_i7x16_add_u", "makeSIMDTernary(s, SIMDTernaryOp::DotI8x16I7x16AddUToVecI32x4)"),

# reference types instructions
("ref.null", "makeRefNull(s)"),
Expand Down
33 changes: 30 additions & 3 deletions src/gen-s-parser.inc
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,17 @@ switch (op[0]) {
case 'b':
if (strcmp(op, "i16x8.bitmask") == 0) { return makeUnary(s, UnaryOp::BitmaskVecI16x8); }
goto parse_error;
case 'd': {
switch (op[22]) {
case 's':
if (strcmp(op, "i16x8.dot_i8x16_i7x16_s") == 0) { return makeBinary(s, BinaryOp::DotI8x16I7x16SToVecI16x8); }
goto parse_error;
case 'u':
if (strcmp(op, "i16x8.dot_i8x16_i7x16_u") == 0) { return makeBinary(s, BinaryOp::DotI8x16I7x16UToVecI16x8); }
goto parse_error;
default: goto parse_error;
}
}
case 'e': {
switch (op[7]) {
case 'q':
Expand Down Expand Up @@ -1692,9 +1703,25 @@ switch (op[0]) {
case 'b':
if (strcmp(op, "i32x4.bitmask") == 0) { return makeUnary(s, UnaryOp::BitmaskVecI32x4); }
goto parse_error;
case 'd':
if (strcmp(op, "i32x4.dot_i16x8_s") == 0) { return makeBinary(s, BinaryOp::DotSVecI16x8ToVecI32x4); }
goto parse_error;
case 'd': {
switch (op[11]) {
case '1':
if (strcmp(op, "i32x4.dot_i16x8_s") == 0) { return makeBinary(s, BinaryOp::DotSVecI16x8ToVecI32x4); }
goto parse_error;
case '8': {
switch (op[26]) {
case 's':
if (strcmp(op, "i32x4.dot_i8x16_i7x16_add_s") == 0) { return makeSIMDTernary(s, SIMDTernaryOp::DotI8x16I7x16AddSToVecI32x4); }
goto parse_error;
case 'u':
if (strcmp(op, "i32x4.dot_i8x16_i7x16_add_u") == 0) { return makeSIMDTernary(s, SIMDTernaryOp::DotI8x16I7x16AddUToVecI32x4); }
goto parse_error;
default: goto parse_error;
}
}
default: goto parse_error;
}
}
case 'e': {
switch (op[7]) {
case 'q':
Expand Down
4 changes: 4 additions & 0 deletions src/ir/cost.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, CostType> {
case SwizzleVecI8x16:
case RelaxedSwizzleVecI8x16:
case RelaxedQ15MulrSVecI16x8:
case DotI8x16I7x16SToVecI16x8:
case DotI8x16I7x16UToVecI16x8:
ret = 1;
break;
case InvalidBinary:
Expand Down Expand Up @@ -541,6 +543,8 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, CostType> {
case RelaxedFmsVecF32x4:
case RelaxedFmaVecF64x2:
case RelaxedFmsVecF64x2:
case DotI8x16I7x16AddSToVecI32x4:
case DotI8x16I7x16AddUToVecI32x4:
ret = 1;
break;
}
Expand Down
2 changes: 2 additions & 0 deletions src/literal.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,8 @@ class Literal {
Literal minUI32x4(const Literal& other) const;
Literal maxSI32x4(const Literal& other) const;
Literal maxUI32x4(const Literal& other) const;
Literal dotSI8x16toI16x8(const Literal& other) const;
Literal dotUI8x16toI16x8(const Literal& other) const;
Literal dotSI16x8toI32x4(const Literal& other) const;
Literal extMulLowSI32x4(const Literal& other) const;
Literal extMulHighSI32x4(const Literal& other) const;
Expand Down
12 changes: 12 additions & 0 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,12 @@ struct PrintExpressionContents
case RelaxedFmsVecF64x2:
o << "f64x2.relaxed_fms";
break;
case DotI8x16I7x16AddSToVecI32x4:
o << "i32x4.dot_i8x16_i7x16_add_s";
break;
case DotI8x16I7x16AddUToVecI32x4:
o << "i32x4.dot_i8x16_i7x16_add_u";
break;
}
restoreNormalColor(o);
}
Expand Down Expand Up @@ -1854,6 +1860,12 @@ struct PrintExpressionContents
case RelaxedQ15MulrSVecI16x8:
o << "i16x8.relaxed_q15mulr_s";
break;
case DotI8x16I7x16SToVecI16x8:
o << "i16x8.dot_i8x16_i7x16_s";
break;
case DotI8x16I7x16UToVecI16x8:
o << "i16x8.dot_i8x16_i7x16_u";
break;

case InvalidBinary:
WASM_UNREACHABLE("unvalid binary operator");
Expand Down
4 changes: 4 additions & 0 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,10 @@ enum ASTNodes {
F64x2RelaxedMin = 0xd4,
F64x2RelaxedMax = 0xee,
I16x8RelaxedQ15MulrS = 0x111,
I16x8DotI8x16I7x16S = 0x112,
I16x8DotI8x16I7x16U = 0x113,
I32x4DotI8x16I7x16AddS = 0x114,
I32x4DotI8x16I7x16AddU = 0x115,

// bulk memory opcodes

Expand Down
7 changes: 6 additions & 1 deletion src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,11 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
case RelaxedSwizzleVecI8x16:
return left.swizzleI8x16(right);

case DotI8x16I7x16SToVecI16x8:
return left.dotSI8x16toI16x8(right);
case DotI8x16I7x16UToVecI16x8:
return left.dotUI8x16toI16x8(right);

case InvalidBinary:
WASM_UNREACHABLE("invalid binary op");
}
Expand Down Expand Up @@ -1124,7 +1129,7 @@ class ExpressionRunner : public OverriddenVisitor<SubType, Flow> {
case RelaxedFmsVecF64x2:
return a.relaxedFmsF64x2(b, c);
default:
// TODO: implement signselect
// TODO: implement signselect and dot_add
WASM_UNREACHABLE("not implemented");
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,8 @@ enum BinaryOp {
RelaxedMinVecF64x2,
RelaxedMaxVecF64x2,
RelaxedQ15MulrSVecI16x8,
DotI8x16I7x16SToVecI16x8,
DotI8x16I7x16UToVecI16x8,

InvalidBinary
};
Expand Down Expand Up @@ -552,6 +554,8 @@ enum SIMDTernaryOp {
LaneselectI16x8,
LaneselectI32x4,
LaneselectI64x2,
DotI8x16I7x16AddSToVecI32x4,
DotI8x16I7x16AddUToVecI32x4,
};

enum RefIsOp {
Expand Down
30 changes: 23 additions & 7 deletions src/wasm/literal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2372,17 +2372,33 @@ Literal Literal::pmaxF64x2(const Literal& other) const {
return binary<2, &Literal::getLanesF64x2, &Literal::pmax>(*this, other);
}

Literal Literal::dotSI16x8toI32x4(const Literal& other) const {
LaneArray<8> lhs = getLanesSI16x8();
LaneArray<8> rhs = other.getLanesSI16x8();
LaneArray<4> result;
for (size_t i = 0; i < 4; ++i) {
result[i] = Literal(lhs[i * 2].geti32() * rhs[i * 2].geti32() +
lhs[i * 2 + 1].geti32() * rhs[i * 2 + 1].geti32());
// Generic dot-product helper shared by the Literal::dot* methods.
//
// Template parameters:
//   Lanes     - number of lanes in the result.
//   Factor    - number of adjacent input lanes folded into each result lane.
//   IntoLanes - Literal member function that splits a literal into
//               Lanes * Factor lanes; it is applied to both operands.
//
// For each result lane i, the products lhs[i * Factor + j] * rhs[i * Factor + j]
// for j in [0, Factor) are added together, starting from zero. All arithmetic
// is performed on the lanes' geti32() values.
// NOTE(review): the sum is accumulated with int32_t arithmetic; confirm that
// no combination of input lanes can overflow it, since signed overflow would
// be undefined behavior.
template<size_t Lanes,
size_t Factor,
LaneArray<Lanes * Factor> (Literal::*IntoLanes)() const>
static Literal dot(const Literal& left, const Literal& right) {
Comment on lines +2375 to +2378
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a brief comment about what this does can be helpful, like, "lane-wise multiply signed n-bit integers in the two input vectors and add adjacent pairs of the full (n*2)-bit results". I wasn't aware that adjacent pairs are added in the result before checking the instruction description myself.

LaneArray<Lanes* Factor> lhs = (left.*IntoLanes)();
LaneArray<Lanes* Factor> rhs = (right.*IntoLanes)();
LaneArray<Lanes> result;
for (size_t i = 0; i < Lanes; ++i) {
// Start each result lane at zero, then accumulate its Factor products.
result[i] = Literal(int32_t(0));
for (size_t j = 0; j < Factor; ++j) {
result[i] = Literal(result[i].geti32() + lhs[i * Factor + j].geti32() *
rhs[i * Factor + j].geti32());
}
}
return Literal(result);
}

// Dot product producing 8 result lanes: both operands are split with
// getLanesSI8x16 and every 2 adjacent lane products are summed by dot().
Literal Literal::dotSI8x16toI16x8(const Literal& other) const {
  constexpr size_t ResultLanes = 8;
  constexpr size_t LanesPerResult = 2;
  return dot<ResultLanes, LanesPerResult, &Literal::getLanesSI8x16>(*this,
                                                                     other);
}
// Dot product producing 8 result lanes: both operands are split with
// getLanesUI8x16 and every 2 adjacent lane products are summed by dot().
Literal Literal::dotUI8x16toI16x8(const Literal& other) const {
  constexpr size_t ResultLanes = 8;
  constexpr size_t LanesPerResult = 2;
  return dot<ResultLanes, LanesPerResult, &Literal::getLanesUI8x16>(*this,
                                                                     other);
}
// Dot product producing 4 result lanes: both operands are split with
// getLanesSI16x8 and every 2 adjacent lane products are summed by dot().
Literal Literal::dotSI16x8toI32x4(const Literal& other) const {
  constexpr size_t ResultLanes = 4;
  constexpr size_t LanesPerResult = 2;
  return dot<ResultLanes, LanesPerResult, &Literal::getLanesSI16x8>(*this,
                                                                     other);
}

Literal Literal::bitselectV128(const Literal& left,
const Literal& right) const {
return andV128(left).orV128(notV128().andV128(right));
Expand Down
16 changes: 16 additions & 0 deletions src/wasm/wasm-binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5601,6 +5601,14 @@ bool WasmBinaryBuilder::maybeVisitSIMDBinary(Expression*& out, uint32_t code) {
curr = allocator.alloc<Binary>();
curr->op = RelaxedQ15MulrSVecI16x8;
break;
case BinaryConsts::I16x8DotI8x16I7x16S:
curr = allocator.alloc<Binary>();
curr->op = DotI8x16I7x16SToVecI16x8;
break;
case BinaryConsts::I16x8DotI8x16I7x16U:
curr = allocator.alloc<Binary>();
curr->op = DotI8x16I7x16UToVecI16x8;
break;
default:
return false;
}
Expand Down Expand Up @@ -6075,6 +6083,14 @@ bool WasmBinaryBuilder::maybeVisitSIMDTernary(Expression*& out, uint32_t code) {
curr = allocator.alloc<SIMDTernary>();
curr->op = RelaxedFmsVecF64x2;
break;
case BinaryConsts::I32x4DotI8x16I7x16AddS:
curr = allocator.alloc<SIMDTernary>();
curr->op = DotI8x16I7x16AddSToVecI32x4;
break;
case BinaryConsts::I32x4DotI8x16I7x16AddU:
curr = allocator.alloc<SIMDTernary>();
curr->op = DotI8x16I7x16AddUToVecI32x4;
break;
default:
return false;
}
Expand Down
14 changes: 14 additions & 0 deletions src/wasm/wasm-stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,12 @@ void BinaryInstWriter::visitSIMDTernary(SIMDTernary* curr) {
case RelaxedFmsVecF64x2:
o << U32LEB(BinaryConsts::F64x2RelaxedFms);
break;
case DotI8x16I7x16AddSToVecI32x4:
o << U32LEB(BinaryConsts::I32x4DotI8x16I7x16AddS);
break;
case DotI8x16I7x16AddUToVecI32x4:
o << U32LEB(BinaryConsts::I32x4DotI8x16I7x16AddU);
break;
}
}

Expand Down Expand Up @@ -1846,6 +1852,14 @@ void BinaryInstWriter::visitBinary(Binary* curr) {
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I16x8RelaxedQ15MulrS);
break;
case DotI8x16I7x16SToVecI16x8:
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I16x8DotI8x16I7x16S);
break;
case DotI8x16I7x16UToVecI16x8:
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I16x8DotI8x16I7x16U);
break;

case InvalidBinary:
WASM_UNREACHABLE("invalid binary op");
Expand Down
4 changes: 3 additions & 1 deletion src/wasm/wasm-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1635,7 +1635,9 @@ void FunctionValidator::visitBinary(Binary* curr) {
case NarrowUVecI32x4ToVecI16x8:
case SwizzleVecI8x16:
case RelaxedSwizzleVecI8x16:
case RelaxedQ15MulrSVecI16x8: {
case RelaxedQ15MulrSVecI16x8:
case DotI8x16I7x16SToVecI16x8:
case DotI8x16I7x16UToVecI16x8: {
shouldBeEqualOrFirstIsUnreachable(
curr->left->type, Type(Type::v128), curr, "v128 op");
shouldBeEqualOrFirstIsUnreachable(
Expand Down
112 changes: 112 additions & 0 deletions test/lit/relaxed-simd.wast
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,88 @@
)
)

;; CHECK-BINARY: (func $i16x8.dot_i8x16_i7x16_s (param $0 v128) (param $1 v128) (result v128)
;; CHECK-BINARY-NEXT: (i16x8.dot_i8x16_i7x16_s
;; CHECK-BINARY-NEXT: (local.get $0)
;; CHECK-BINARY-NEXT: (local.get $1)
;; CHECK-BINARY-NEXT: )
;; CHECK-BINARY-NEXT: )
;; CHECK-TEXT: (func $i16x8.dot_i8x16_i7x16_s (param $0 v128) (param $1 v128) (result v128)
;; CHECK-TEXT-NEXT: (i16x8.dot_i8x16_i7x16_s
;; CHECK-TEXT-NEXT: (local.get $0)
;; CHECK-TEXT-NEXT: (local.get $1)
;; CHECK-TEXT-NEXT: )
;; CHECK-TEXT-NEXT: )
(func $i16x8.dot_i8x16_i7x16_s (param $0 v128) (param $1 v128) (result v128)
;; Applies i16x8.dot_i8x16_i7x16_s to the two v128 parameters and returns the result.
(i16x8.dot_i8x16_i7x16_s
(local.get $0)
(local.get $1)
)
)

;; CHECK-BINARY: (func $i16x8.dot_i8x16_i7x16_u (param $0 v128) (param $1 v128) (result v128)
;; CHECK-BINARY-NEXT: (i16x8.dot_i8x16_i7x16_u
;; CHECK-BINARY-NEXT: (local.get $0)
;; CHECK-BINARY-NEXT: (local.get $1)
;; CHECK-BINARY-NEXT: )
;; CHECK-BINARY-NEXT: )
;; CHECK-TEXT: (func $i16x8.dot_i8x16_i7x16_u (param $0 v128) (param $1 v128) (result v128)
;; CHECK-TEXT-NEXT: (i16x8.dot_i8x16_i7x16_u
;; CHECK-TEXT-NEXT: (local.get $0)
;; CHECK-TEXT-NEXT: (local.get $1)
;; CHECK-TEXT-NEXT: )
;; CHECK-TEXT-NEXT: )
(func $i16x8.dot_i8x16_i7x16_u (param $0 v128) (param $1 v128) (result v128)
;; Applies i16x8.dot_i8x16_i7x16_u to the two v128 parameters and returns the result.
(i16x8.dot_i8x16_i7x16_u
(local.get $0)
(local.get $1)
)
)

;; CHECK-BINARY: (func $i32x4.dot_i8x16_i7x16_add_s (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-BINARY-NEXT: (i32x4.dot_i8x16_i7x16_add_s
;; CHECK-BINARY-NEXT: (local.get $0)
;; CHECK-BINARY-NEXT: (local.get $1)
;; CHECK-BINARY-NEXT: (local.get $2)
;; CHECK-BINARY-NEXT: )
;; CHECK-BINARY-NEXT: )
;; CHECK-TEXT: (func $i32x4.dot_i8x16_i7x16_add_s (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-TEXT-NEXT: (i32x4.dot_i8x16_i7x16_add_s
;; CHECK-TEXT-NEXT: (local.get $0)
;; CHECK-TEXT-NEXT: (local.get $1)
;; CHECK-TEXT-NEXT: (local.get $2)
;; CHECK-TEXT-NEXT: )
;; CHECK-TEXT-NEXT: )
(func $i32x4.dot_i8x16_i7x16_add_s (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; Applies i32x4.dot_i8x16_i7x16_add_s to the three v128 parameters and returns the result.
(i32x4.dot_i8x16_i7x16_add_s
(local.get $0)
(local.get $1)
(local.get $2)
)
)

;; CHECK-BINARY: (func $i32x4.dot_i8x16_i7x16_add_u (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-BINARY-NEXT: (i32x4.dot_i8x16_i7x16_add_u
;; CHECK-BINARY-NEXT: (local.get $0)
;; CHECK-BINARY-NEXT: (local.get $1)
;; CHECK-BINARY-NEXT: (local.get $2)
;; CHECK-BINARY-NEXT: )
;; CHECK-BINARY-NEXT: )
;; CHECK-TEXT: (func $i32x4.dot_i8x16_i7x16_add_u (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-TEXT-NEXT: (i32x4.dot_i8x16_i7x16_add_u
;; CHECK-TEXT-NEXT: (local.get $0)
;; CHECK-TEXT-NEXT: (local.get $1)
;; CHECK-TEXT-NEXT: (local.get $2)
;; CHECK-TEXT-NEXT: )
;; CHECK-TEXT-NEXT: )
(func $i32x4.dot_i8x16_i7x16_add_u (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; Applies i32x4.dot_i8x16_i7x16_add_u to the three v128 parameters and returns the result.
(i32x4.dot_i8x16_i7x16_add_u
(local.get $0)
(local.get $1)
(local.get $2)
)
)

)
;; CHECK-NODEBUG: (type $v128_v128_v128_=>_v128 (func (param v128 v128 v128) (result v128)))

Expand Down Expand Up @@ -507,3 +589,33 @@
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )

;; CHECK-NODEBUG: (func $18 (param $0 v128) (param $1 v128) (result v128)
;; CHECK-NODEBUG-NEXT: (i16x8.dot_i8x16_i7x16_s
;; CHECK-NODEBUG-NEXT: (local.get $0)
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )

;; CHECK-NODEBUG: (func $19 (param $0 v128) (param $1 v128) (result v128)
;; CHECK-NODEBUG-NEXT: (i16x8.dot_i8x16_i7x16_u
;; CHECK-NODEBUG-NEXT: (local.get $0)
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )

;; CHECK-NODEBUG: (func $20 (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-NODEBUG-NEXT: (i32x4.dot_i8x16_i7x16_add_s
;; CHECK-NODEBUG-NEXT: (local.get $0)
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: (local.get $2)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )

;; CHECK-NODEBUG: (func $21 (param $0 v128) (param $1 v128) (param $2 v128) (result v128)
;; CHECK-NODEBUG-NEXT: (i32x4.dot_i8x16_i7x16_add_u
;; CHECK-NODEBUG-NEXT: (local.get $0)
;; CHECK-NODEBUG-NEXT: (local.get $1)
;; CHECK-NODEBUG-NEXT: (local.get $2)
;; CHECK-NODEBUG-NEXT: )
;; CHECK-NODEBUG-NEXT: )