Skip to content

[WebAssembly] Implement the wide-arithmetic proposal #111598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
@@ -5098,6 +5098,8 @@ def msimd128 : Flag<["-"], "msimd128">, Group<m_wasm_Features_Group>;
def mno_simd128 : Flag<["-"], "mno-simd128">, Group<m_wasm_Features_Group>;
def mtail_call : Flag<["-"], "mtail-call">, Group<m_wasm_Features_Group>;
def mno_tail_call : Flag<["-"], "mno-tail-call">, Group<m_wasm_Features_Group>;
// Driver flags that turn the WebAssembly "wide-arithmetic" feature on / off.
// Both belong to m_wasm_Features_Group, like the other -m<feature> /
// -mno-<feature> WebAssembly flags next to them.
def mwide_arithmetic : Flag<["-"], "mwide-arithmetic">, Group<m_wasm_Features_Group>;
def mno_wide_arithmetic : Flag<["-"], "mno-wide-arithmetic">, Group<m_wasm_Features_Group>;
def mexec_model_EQ : Joined<["-"], "mexec-model=">, Group<m_wasm_Features_Driver_Group>,
Values<"command,reactor">,
HelpText<"Execution model (WebAssembly only)">,
12 changes: 12 additions & 0 deletions clang/lib/Basic/Targets/WebAssembly.cpp
Original file line number Diff line number Diff line change
@@ -59,6 +59,7 @@ bool WebAssemblyTargetInfo::hasFeature(StringRef Feature) const {
.Case("sign-ext", HasSignExt)
.Case("simd128", SIMDLevel >= SIMD128)
.Case("tail-call", HasTailCall)
.Case("wide-arithmetic", HasWideArithmetic)
.Default(false);
}

@@ -102,6 +103,8 @@ void WebAssemblyTargetInfo::getTargetDefines(const LangOptions &Opts,
Builder.defineMacro("__wasm_simd128__");
if (HasTailCall)
Builder.defineMacro("__wasm_tail_call__");
if (HasWideArithmetic)
Builder.defineMacro("__wasm_wide_arithmetic__");

Builder.defineMacro("__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1");
Builder.defineMacro("__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2");
@@ -166,6 +169,7 @@ bool WebAssemblyTargetInfo::initFeatureMap(
Features["multimemory"] = true;
Features["nontrapping-fptoint"] = true;
Features["tail-call"] = true;
Features["wide-arithmetic"] = true;
setSIMDLevel(Features, RelaxedSIMD, true);
};
if (CPU == "generic") {
@@ -293,6 +297,14 @@ bool WebAssemblyTargetInfo::handleTargetFeatures(
HasTailCall = false;
continue;
}
if (Feature == "+wide-arithmetic") {
HasWideArithmetic = true;
continue;
}
if (Feature == "-wide-arithmetic") {
HasWideArithmetic = false;
continue;
}

Diags.Report(diag::err_opt_not_valid_with_opt)
<< Feature << "-target-feature";
1 change: 1 addition & 0 deletions clang/lib/Basic/Targets/WebAssembly.h
Original file line number Diff line number Diff line change
@@ -65,6 +65,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo {
bool HasReferenceTypes = false;
bool HasSignExt = false;
bool HasTailCall = false;
bool HasWideArithmetic = false;

std::string ABI;

6 changes: 6 additions & 0 deletions clang/test/Driver/wasm-features.c
Original file line number Diff line number Diff line change
@@ -94,3 +94,9 @@

// TAIL-CALL: "-target-feature" "+tail-call"
// NO-TAIL-CALL: "-target-feature" "-tail-call"

// RUN: %clang --target=wasm32-unknown-unknown -### %s -mwide-arithmetic 2>&1 | FileCheck %s -check-prefix=WIDE-ARITH
// RUN: %clang --target=wasm32-unknown-unknown -### %s -mno-wide-arithmetic 2>&1 | FileCheck %s -check-prefix=NO-WIDE-ARITH

// WIDE-ARITH: "-target-feature" "+wide-arithmetic"
// NO-WIDE-ARITH: "-target-feature" "-wide-arithmetic"
12 changes: 12 additions & 0 deletions clang/test/Preprocessor/wasm-target-features.c
Original file line number Diff line number Diff line change
@@ -154,6 +154,7 @@
// MVP-NOT: #define __wasm_sign_ext__ 1{{$}}
// MVP-NOT: #define __wasm_simd128__ 1{{$}}
// MVP-NOT: #define __wasm_tail_call__ 1{{$}}
// MVP-NOT: #define __wasm_wide_arithmetic__ 1{{$}}

// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm32-unknown-unknown -mcpu=generic \
@@ -184,6 +185,7 @@
// GENERIC-NOT: #define __wasm_relaxed_simd__ 1{{$}}
// GENERIC-NOT: #define __wasm_simd128__ 1{{$}}
// GENERIC-NOT: #define __wasm_tail_call__ 1{{$}}
// GENERIC-NOT: #define __wasm_wide_arithmetic__ 1{{$}}

// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm32-unknown-unknown -mcpu=bleeding-edge \
@@ -206,6 +208,7 @@
// BLEEDING-EDGE-INCLUDE-DAG: #define __wasm_sign_ext__ 1{{$}}
// BLEEDING-EDGE-INCLUDE-DAG: #define __wasm_simd128__ 1{{$}}
// BLEEDING-EDGE-INCLUDE-DAG: #define __wasm_tail_call__ 1{{$}}
// BLEEDING-EDGE-INCLUDE-DAG: #define __wasm_wide_arithmetic__ 1{{$}}

// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm32-unknown-unknown -mcpu=bleeding-edge -mno-simd128 \
@@ -215,3 +218,12 @@
// RUN: | FileCheck %s -check-prefix=BLEEDING-EDGE-NO-SIMD128
//
// BLEEDING-EDGE-NO-SIMD128-NOT: #define __wasm_simd128__ 1{{$}}

// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm32-unknown-unknown -mwide-arithmetic \
// RUN: | FileCheck %s -check-prefix=WIDE-ARITHMETIC
// RUN: %clang -E -dM %s -o - 2>&1 \
// RUN: -target wasm64-unknown-unknown -mwide-arithmetic \
// RUN: | FileCheck %s -check-prefix=WIDE-ARITHMETIC
//
// WIDE-ARITHMETIC: #define __wasm_wide_arithmetic__ 1{{$}}
4 changes: 4 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssembly.td
Original file line number Diff line number Diff line change
@@ -78,6 +78,10 @@ def FeatureTailCall :
SubtargetFeature<"tail-call", "HasTailCall", "true",
"Enable tail call instructions">;

// Subtarget feature "wide-arithmetic": when enabled it sets the subtarget's
// HasWideArithmetic field to "true".
def FeatureWideArithmetic :
SubtargetFeature<"wide-arithmetic", "HasWideArithmetic", "true",
"Enable wide-arithmetic instructions">;

//===----------------------------------------------------------------------===//
// Architectures.
//===----------------------------------------------------------------------===//
4 changes: 4 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISD.def
Original file line number Diff line number Diff line change
@@ -44,6 +44,10 @@ HANDLE_NODETYPE(TRUNC_SAT_ZERO_U)
HANDLE_NODETYPE(DEMOTE_ZERO)
HANDLE_NODETYPE(MEMORY_COPY)
HANDLE_NODETYPE(MEMORY_FILL)
// Wide-arithmetic nodes: 128-bit add / sub on (lo, hi) i64 pairs, and signed /
// unsigned i64 multiplies that produce two i64 results.
HANDLE_NODETYPE(I64_ADD128)
HANDLE_NODETYPE(I64_SUB128)
HANDLE_NODETYPE(I64_MUL_WIDE_S)
HANDLE_NODETYPE(I64_MUL_WIDE_U)

// Memory intrinsics
HANDLE_MEM_NODETYPE(GLOBAL_GET)
71 changes: 71 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
@@ -167,6 +167,13 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setOperationAction(Op, T, Expand);
}

if (Subtarget->hasWideArithmetic()) {
setOperationAction(ISD::ADD, MVT::i128, Custom);
setOperationAction(ISD::SUB, MVT::i128, Custom);
setOperationAction(ISD::SMUL_LOHI, MVT::i64, Custom);
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Custom);
}

if (Subtarget->hasNontrappingFPToInt())
for (auto Op : {ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT})
for (auto T : {MVT::i32, MVT::i64})
@@ -1443,6 +1450,10 @@ void WebAssemblyTargetLowering::ReplaceNodeResults(
// Do not add any results, signifying that N should not be custom lowered.
// EXTEND_VECTOR_INREG is implemented for some vectors, but not all.
break;
case ISD::ADD:
case ISD::SUB:
Results.push_back(Replace128Op(N, DAG));
break;
default:
llvm_unreachable(
"ReplaceNodeResults not implemented for this op for WebAssembly!");
@@ -1519,6 +1530,9 @@ SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op,
return DAG.UnrollVectorOp(Op.getNode());
case ISD::CLEAR_CACHE:
report_fatal_error("llvm.clear_cache is not supported on wasm");
case ISD::SMUL_LOHI:
case ISD::UMUL_LOHI:
return LowerMUL_LOHI(Op, DAG);
}
}

@@ -1617,6 +1631,63 @@ SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op,
return Op;
}

/// Custom-lower an i64 ISD::SMUL_LOHI / ISD::UMUL_LOHI to the matching
/// wide-arithmetic node (I64_MUL_WIDE_S / I64_MUL_WIDE_U).
///
/// The WebAssembly node is built with two i64 results. Result 0 and result 1
/// are forwarded, in that order, as the two results of the node being lowered.
/// Requires the wide-arithmetic subtarget feature (asserted).
SDValue WebAssemblyTargetLowering::LowerMUL_LOHI(SDValue Op,
                                                 SelectionDAG &DAG) const {
  assert(Subtarget->hasWideArithmetic());
  assert(Op.getValueType() == MVT::i64);
  SDLoc DL(Op);
  unsigned Opcode;
  switch (Op.getOpcode()) {
  case ISD::UMUL_LOHI:
    Opcode = WebAssemblyISD::I64_MUL_WIDE_U;
    break;
  case ISD::SMUL_LOHI:
    Opcode = WebAssemblyISD::I64_MUL_WIDE_S;
    break;
  default:
    llvm_unreachable("unexpected opcode");
  }
  SDValue LHS = Op.getOperand(0);
  SDValue RHS = Op.getOperand(1);
  // The I64_MUL_WIDE_* instructions declare their outputs as ($lo, $hi), so
  // result 0 of the node is the low half and result 1 is the high half. Name
  // the values accordingly; the merged order {result 0, result 1} is unchanged.
  SDValue Lo =
      DAG.getNode(Opcode, DL, DAG.getVTList(MVT::i64, MVT::i64), LHS, RHS);
  SDValue Hi(Lo.getNode(), 1);
  SDValue Ops[] = {Lo, Hi};
  return DAG.getMergeValues(Ops, DL);
}

/// Build the replacement for an i128 ISD::ADD / ISD::SUB node using the
/// wide-arithmetic I64_ADD128 / I64_SUB128 nodes.
///
/// Each i128 operand is split into two i64 halves with EXTRACT_ELEMENT
/// (index 0 and index 1). The four halves feed the wide node in the order
/// (lhs[0], lhs[1], rhs[0], rhs[1]), and its two i64 results are glued back
/// into a single value with BUILD_PAIR. Requires the wide-arithmetic
/// subtarget feature (asserted).
SDValue WebAssemblyTargetLowering::Replace128Op(SDNode *N,
                                                SelectionDAG &DAG) const {
  assert(Subtarget->hasWideArithmetic());
  assert(N->getValueType(0) == MVT::i128);

  // Pick the WebAssembly node that corresponds to the generic opcode.
  unsigned WideOpc;
  switch (N->getOpcode()) {
  case ISD::ADD:
    WideOpc = WebAssemblyISD::I64_ADD128;
    break;
  case ISD::SUB:
    WideOpc = WebAssemblyISD::I64_SUB128;
    break;
  default:
    llvm_unreachable("unexpected opcode");
  }

  SDLoc DL(N);
  SDValue Idx0 = DAG.getConstant(0, DL, MVT::i64);
  SDValue Idx1 = DAG.getConstant(1, DL, MVT::i64);

  // Extracts one i64 half of an i128 value.
  auto ExtractHalf = [&](SDValue Wide, SDValue Idx) {
    return DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i64, Wide, Idx);
  };

  SDValue Lhs = N->getOperand(0);
  SDValue Rhs = N->getOperand(1);
  SDValue LhsLo = ExtractHalf(Lhs, Idx0);
  SDValue LhsHi = ExtractHalf(Lhs, Idx1);
  SDValue RhsLo = ExtractHalf(Rhs, Idx0);
  SDValue RhsHi = ExtractHalf(Rhs, Idx1);

  // The wide node has two i64 results; the second is reached through the same
  // SDNode with result number 1.
  SDValue ResLo = DAG.getNode(WideOpc, DL, DAG.getVTList(MVT::i64, MVT::i64),
                              LhsLo, LhsHi, RhsLo, RhsHi);
  SDValue ResHi(ResLo.getNode(), 1);
  return DAG.getNode(ISD::BUILD_PAIR, DL, N->getVTList(), ResLo, ResHi);
}

SDValue WebAssemblyTargetLowering::LowerCopyToReg(SDValue Op,
SelectionDAG &DAG) const {
SDValue Src = Op.getOperand(2);
2 changes: 2 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
Original file line number Diff line number Diff line change
@@ -138,6 +138,8 @@ class WebAssemblyTargetLowering final : public TargetLowering {
SDValue LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerLoad(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerStore(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMUL_LOHI(SDValue Op, SelectionDAG &DAG) const;
SDValue Replace128Op(SDNode *N, SelectionDAG &DAG) const;

// Custom DAG combine hooks
SDValue
4 changes: 4 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
Original file line number Diff line number Diff line change
@@ -84,6 +84,10 @@ def HasTailCall :
Predicate<"Subtarget->hasTailCall()">,
AssemblerPredicate<(all_of FeatureTailCall), "tail-call">;

// Predicate for the wide-arithmetic feature: the C++ condition is
// Subtarget->hasWideArithmetic(), and the assembler predicate requires
// FeatureWideArithmetic (reported under the name "wide-arithmetic").
def HasWideArithmetic :
Predicate<"Subtarget->hasWideArithmetic()">,
AssemblerPredicate<(all_of FeatureWideArithmetic), "wide-arithmetic">;

//===----------------------------------------------------------------------===//
// WebAssembly-specific DAG Node Types.
//===----------------------------------------------------------------------===//
45 changes: 45 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrInteger.td
Original file line number Diff line number Diff line change
@@ -129,3 +129,48 @@ def : Pat<(select (i32 (seteq I32:$cond, 0)), I32:$lhs, I32:$rhs),
(SELECT_I32 I32:$rhs, I32:$lhs, I32:$cond)>;
def : Pat<(select (i32 (seteq I32:$cond, 0)), I64:$lhs, I64:$rhs),
(SELECT_I64 I64:$rhs, I64:$lhs, I32:$cond)>;

// Wide-arithmetic instructions, all guarded by HasWideArithmetic.
// Every instruction here has two i64 outputs, ($lo, $hi). The inline pattern
// list is empty ([]) for each of them; instruction selection is driven by the
// separate `def : Pat<...>` records on the WebAssemblyISD nodes instead.
// NOTE(review): the argument layout of the `I` multiclass is defined outside
// this hunk; the second (outs)/(ins) pair and the second asm string are
// presumably the stack form. Confirm against WebAssemblyInstrFormats.td.
let Predicates = [HasWideArithmetic] in {
// i64.add128: inputs are two (lo, hi) i64 pairs; opcode 0xfc13.
defm I64_ADD128 : I<(outs I64:$lo, I64:$hi), (ins I64:$lhs_lo, I64:$lhs_hi, I64:$rhs_lo, I64:$rhs_hi),
(outs), (ins),
[],
"i64.add128\t$lo, $hi, $lhs_lo, $lhs_hi, $rhs_lo, $rhs_hi",
"i64.add128",
0xfc13>;
// i64.sub128: same operand shape as i64.add128; opcode 0xfc14.
defm I64_SUB128 : I<(outs I64:$lo, I64:$hi), (ins I64:$lhs_lo, I64:$lhs_hi, I64:$rhs_lo, I64:$rhs_hi),
(outs), (ins),
[],
"i64.sub128\t$lo, $hi, $lhs_lo, $lhs_hi, $rhs_lo, $rhs_hi",
"i64.sub128",
0xfc14>;
// i64.mul_wide_s: two i64 inputs, ($lo, $hi) outputs; opcode 0xfc15.
defm I64_MUL_WIDE_S : I<(outs I64:$lo, I64:$hi), (ins I64:$lhs, I64:$rhs),
(outs), (ins),
[],
"i64.mul_wide_s\t$lo, $hi, $lhs, $rhs",
"i64.mul_wide_s",
0xfc15>;
// i64.mul_wide_u: two i64 inputs, ($lo, $hi) outputs; opcode 0xfc16.
defm I64_MUL_WIDE_U : I<(outs I64:$lo, I64:$hi), (ins I64:$lhs, I64:$rhs),
(outs), (ins),
[],
"i64.mul_wide_u\t$lo, $hi, $lhs, $rhs",
"i64.mul_wide_u",
0xfc16>;
} // Predicates = [HasWideArithmetic]

// DAG nodes for 128-bit add / sub: 2 results and 4 operands, with no extra
// type constraints in the profile.
def wasm_binop128_t : SDTypeProfile<2, 4, []>;
def wasm_add128 : SDNode<"WebAssemblyISD::I64_ADD128", wasm_binop128_t>;
def wasm_sub128 : SDNode<"WebAssemblyISD::I64_SUB128", wasm_binop128_t>;

// Select the 128-bit nodes straight onto the instructions, passing the four
// i64 operands through in the same order.
def : Pat<(wasm_add128 I64:$a, I64:$b, I64:$c, I64:$d),
(I64_ADD128 $a, $b, $c, $d)>;
def : Pat<(wasm_sub128 I64:$a, I64:$b, I64:$c, I64:$d),
(I64_SUB128 $a, $b, $c, $d)>;

// DAG nodes for the wide multiplies: 2 results and 2 operands, with no extra
// type constraints in the profile.
def wasm_mul_wide_t : SDTypeProfile<2, 2, []>;
def wasm_mul_wide_s : SDNode<"WebAssemblyISD::I64_MUL_WIDE_S", wasm_mul_wide_t>;
def wasm_mul_wide_u : SDNode<"WebAssemblyISD::I64_MUL_WIDE_U", wasm_mul_wide_t>;

// Select the wide-multiply nodes onto the signed / unsigned instructions.
def : Pat<(wasm_mul_wide_s I64:$x, I64:$y),
(I64_MUL_WIDE_S $x, $y)>;
def : Pat<(wasm_mul_wide_u I64:$x, I64:$y),
(I64_MUL_WIDE_U $x, $y)>;
2 changes: 2 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h
Original file line number Diff line number Diff line change
@@ -51,6 +51,7 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo {
bool HasReferenceTypes = false;
bool HasSignExt = false;
bool HasTailCall = false;
bool HasWideArithmetic = false;

/// What processor and OS we're targeting.
Triple TargetTriple;
@@ -106,6 +107,7 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo {
bool hasSignExt() const { return HasSignExt; }
bool hasSIMD128() const { return SIMDLevel >= SIMD128; }
bool hasTailCall() const { return HasTailCall; }
bool hasWideArithmetic() const { return HasWideArithmetic; }

/// Parses features string setting specified subtarget options. Definition of
/// function is auto generated by tblgen.
132 changes: 132 additions & 0 deletions llvm/test/CodeGen/WebAssembly/wide-arithmetic.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mattr=+wide-arithmetic < %s | FileCheck %s

target triple = "wasm32-unknown-unknown"

; A full i128 add is expected to become a single i64.add128; its two results
; are then stored at offsets 8 and 0 from the i32 in local 0.
define i128 @add_i128(i128 %a, i128 %b) {
; CHECK-LABEL: add_i128:
; CHECK: .functype add_i128 (i32, i64, i64, i64, i64) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 1
; CHECK-NEXT: local.get 2
; CHECK-NEXT: local.get 3
; CHECK-NEXT: local.get 4
; CHECK-NEXT: i64.add128
; CHECK-NEXT: local.set 3
; CHECK-NEXT: local.set 4
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 3
; CHECK-NEXT: i64.store 8
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 4
; CHECK-NEXT: i64.store 0
; CHECK-NEXT: # fallthrough-return
%c = add i128 %a, %b
ret i128 %c
}

; A full i128 subtract is expected to become a single i64.sub128; its two
; results are then stored at offsets 8 and 0 from the i32 in local 0.
define i128 @sub_i128(i128 %a, i128 %b) {
; CHECK-LABEL: sub_i128:
; CHECK: .functype sub_i128 (i32, i64, i64, i64, i64) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 1
; CHECK-NEXT: local.get 2
; CHECK-NEXT: local.get 3
; CHECK-NEXT: local.get 4
; CHECK-NEXT: i64.sub128
; CHECK-NEXT: local.set 3
; CHECK-NEXT: local.set 4
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 3
; CHECK-NEXT: i64.store 8
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 4
; CHECK-NEXT: i64.store 0
; CHECK-NEXT: # fallthrough-return
%c = sub i128 %a, %b
ret i128 %c
}

; A full i128 multiply is expected to use one i64.mul_wide_u on locals 1 and 3,
; then two plain i64.mul / i64.add pairs whose sum is stored at offset 8.
define i128 @mul_i128(i128 %a, i128 %b) {
; CHECK-LABEL: mul_i128:
; CHECK: .functype mul_i128 (i32, i64, i64, i64, i64) -> ()
; CHECK-NEXT: .local i64
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: local.get 3
; CHECK-NEXT: i64.mul_wide_u
; CHECK-NEXT: local.set 5
; CHECK-NEXT: i64.store 0
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 5
; CHECK-NEXT: local.get 1
; CHECK-NEXT: local.get 4
; CHECK-NEXT: i64.mul
; CHECK-NEXT: i64.add
; CHECK-NEXT: local.get 2
; CHECK-NEXT: local.get 3
; CHECK-NEXT: i64.mul
; CHECK-NEXT: i64.add
; CHECK-NEXT: i64.store 8
; CHECK-NEXT: # fallthrough-return
%c = mul i128 %a, %b
ret i128 %c
}

; Multiplying two i64 values sign-extended to i128 is expected to select a
; single i64.mul_wide_s with no further arithmetic.
define i128 @i64_mul_wide_s(i64 %a, i64 %b) {
; CHECK-LABEL: i64_mul_wide_s:
; CHECK: .functype i64_mul_wide_s (i32, i64, i64) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 1
; CHECK-NEXT: local.get 2
; CHECK-NEXT: i64.mul_wide_s
; CHECK-NEXT: local.set 1
; CHECK-NEXT: local.set 2
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i64.store 8
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 2
; CHECK-NEXT: i64.store 0
; CHECK-NEXT: # fallthrough-return
%a128 = sext i64 %a to i128
%b128 = sext i64 %b to i128
%c = mul i128 %a128, %b128
ret i128 %c
}

; Multiplying two i64 values zero-extended to i128 is expected to select a
; single i64.mul_wide_u with no further arithmetic.
define i128 @i64_mul_wide_u(i64 %a, i64 %b) {
; CHECK-LABEL: i64_mul_wide_u:
; CHECK: .functype i64_mul_wide_u (i32, i64, i64) -> ()
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 1
; CHECK-NEXT: local.get 2
; CHECK-NEXT: i64.mul_wide_u
; CHECK-NEXT: local.set 1
; CHECK-NEXT: local.set 2
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i64.store 8
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 2
; CHECK-NEXT: i64.store 0
; CHECK-NEXT: # fallthrough-return
%a128 = zext i64 %a to i128
%b128 = zext i64 %b to i128
%c = mul i128 %a128, %b128
ret i128 %c
}

; When only the truncated low 64 bits of an i128 product are used, a plain
; i64.mul is expected and no wide-arithmetic instruction is emitted.
define i64 @mul_i128_only_lo(i128 %a, i128 %b) {
; CHECK-LABEL: mul_i128_only_lo:
; CHECK: .functype mul_i128_only_lo (i64, i64, i64, i64) -> (i64)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 2
; CHECK-NEXT: i64.mul
; CHECK-NEXT: # fallthrough-return
%c = mul i128 %a, %b
%d = trunc i128 %c to i64
ret i64 %d
}