diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 45aea1ccdb6d4..fa865718bc552 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -84,6 +84,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>; def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>; +def int_dx_wave_active_op : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i8_ty, llvm_i8_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>; def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>; def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index e8f56b18730d7..df43cae5edaed 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -793,6 +793,16 @@ def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> { let stages = [Stages]; } +def WaveActiveOp : DXILOp<119, waveActiveOp> { + let Doc = "returns the result of the operation across waves"; + let LLVMIntrinsic = int_dx_wave_active_op; + let arguments = [OverloadTy, Int8Ty, Int8Ty]; + let result = OverloadTy; + let overloads = [Overloads]; + let stages = [Stages]; + let attributes = [Attributes]; +} + def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> { let Doc = "returns 1 for the first lane in the wave"; let LLVMIntrinsic = int_dx_wave_is_first_lane; diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp index be714b5c87895..b0f54a0679de2 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -18,6 +18,9 @@ using namespace llvm; bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx) { switch (ID) { + case Intrinsic::dx_wave_active_op: { + return ScalarOpdIdx == 1 || ScalarOpdIdx == 2; + } default: return false; } @@ -26,6 +29,7 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable( Intrinsic::ID ID) const { switch (ID) { + case Intrinsic::dx_wave_active_op: case Intrinsic::dx_frac: case Intrinsic::dx_rsqrt: return true; diff --git a/llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll b/llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll new file mode 100644 index 0000000000000..d5d1e615e99af --- /dev/null +++ b/llvm/test/CodeGen/DirectX/WaveActiveOp-vec.ll @@ -0,0 +1,34 @@ +; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s + +; Test that for scalar values, WaveReadLaneAt maps down to the DirectX op + +define noundef <2 x half> @wave_active_op_v2half(<2 x half> noundef %expr) { +entry: +; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i0, i8 0, i8 0) +; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr.i1, i8 0, i8 0) + %ret = call <2 x half> @llvm.dx.wave.active.op.f16(<2 x half> %expr, i8 0, i8 0) + ret <2 x half> %ret +} + +define noundef <3 x i32> @wave_active_op_v3i32(<3 x i32> noundef %expr) { +entry: +; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i0, i8 1, i8 1) +; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i1, i8 1, i8 1) +; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr.i2, i8 1, i8 1) + %ret = call <3 x i32> @llvm.dx.wave.active.op(<3 x i32> %expr, i8 1, i8 1) + ret <3 x i32> %ret +} + +define noundef <4 x double> @wave_active_op_v4f64(<4 x double> noundef %expr) { +entry: +; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i0, i8 2, i8 0) +; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i1, i8 2, i8 0) +; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i2, i8 2, i8 0) +; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr.i3, i8 2, i8 0) + %ret = call <4 x double> @llvm.dx.wave.active.op(<4 x double> %expr, i8 2, i8 0) + ret <4 x double> %ret +} + +declare <2 x half> @llvm.dx.wave.active.op.v2f16(<2 x half>, i8, i8) +declare <3 x i32> @llvm.dx.wave.active.op.v3i32(<3 x i32>, i8, i8) +declare <4 x double> @llvm.dx.wave.active.op.v4f64(<4 x double>, i8, i8) diff --git a/llvm/test/CodeGen/DirectX/WaveActiveOp.ll b/llvm/test/CodeGen/DirectX/WaveActiveOp.ll new file mode 100644 index 0000000000000..e6cafd696d25c --- /dev/null +++ b/llvm/test/CodeGen/DirectX/WaveActiveOp.ll @@ -0,0 +1,53 @@ +; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s + +; Test that for scalar values, WaveReadLaneAt maps down to the DirectX op + +define noundef half @wave_active_op_half(half noundef %expr) { +entry: +; CHECK: call half @dx.op.waveActiveOp.f16(i32 119, half %expr, i8 0, i8 0) + %ret = call half @llvm.dx.wave.active.op.f16(half %expr, i8 0, i8 0) + ret half %ret +} + +define noundef float @wave_active_op_float(float noundef %expr) { +entry: +; CHECK: call float @dx.op.waveActiveOp.f32(i32 119, float %expr, i8 1, i8 0) + %ret = call float @llvm.dx.wave.active.op(float %expr, i8 1, i8 0) + ret float %ret +} + +define noundef double @wave_active_op_double(double noundef %expr) { +entry: +; CHECK: call double @dx.op.waveActiveOp.f64(i32 119, double %expr, i8 2, i8 0) + %ret = call double @llvm.dx.wave.active.op(double %expr, i8 2, i8 0) + ret double %ret +} + +define noundef i16 @wave_active_op_i16(i16 noundef %expr) { +entry: +; CHECK: call i16 @dx.op.waveActiveOp.i16(i32 119, i16 %expr, i8 1, i8 0) + %ret = call i16 @llvm.dx.wave.active.op.i16(i16 %expr, i8 1, i8 0) + ret i16 %ret +} + +define noundef i32 @wave_active_op_i32(i32 noundef %expr) { +entry: +; CHECK: call i32 @dx.op.waveActiveOp.i32(i32 119, i32 %expr, i8 2, i8 1) + %ret = call i32 @llvm.dx.wave.active.op.i32(i32 %expr, i8 2, i8 1) + ret i32 %ret +} + +define noundef i64 @wave_active_op_i64(i64 noundef %expr) { +entry: +; CHECK: call i64 @dx.op.waveActiveOp.i64(i32 119, i64 %expr, i8 3, i8 0) + %ret = call i64 @llvm.dx.wave.active.op.i64(i64 %expr, i8 3, i8 0) + ret i64 %ret +} + +declare half @llvm.dx.wave.active.op.f16(half, i8, i8) +declare float @llvm.dx.wave.active.op.f32(float, i8, i8) +declare double @llvm.dx.wave.active.op.f64(double, i8, i8) + +declare i16 @llvm.dx.wave.active.op.i16(i16, i8, i8) +declare i32 @llvm.dx.wave.active.op.i32(i32, i8, i8) +declare i64 @llvm.dx.wave.active.op.i64(i64, i8, i8)