Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,42 @@ using ExpandSMulExtendedPattern =
using ExpandUMulExtendedPattern =
ExpandMulExtendedPattern<UMulExtendedOp, false>;

struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
  using OpRewritePattern<IAddCarryOp>::OpRewritePattern;

  // Lower `spirv.IAddCarry` to a plain IAdd followed by an unsigned-overflow
  // check, because WGSL has no carry-producing add instruction.
  LogicalResult matchAndRewrite(IAddCarryOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value augend = op.getOperand1();
    Value addend = op.getOperand2();

    // Currently, WGSL only supports 32-bit integer types. Any other integer
    // types should already have been promoted/demoted to i32.
    Type operandTy = augend.getType();
    auto intTy = cast<IntegerType>(getElementTypeOrSelf(operandTy));
    if (intTy.getIntOrFloatBitWidth() != 32)
      return rewriter.notifyMatchFailure(
          loc,
          llvm::formatv("Unexpected integer type for WebGPU: '{0}'", intTy));

    Value one = rewriter.create<ConstantOp>(
        loc, operandTy, getScalarOrSplatAttr(operandTy, 1));
    Value zero = rewriter.create<ConstantOp>(
        loc, operandTy, getScalarOrSplatAttr(operandTy, 0));

    // An unsigned add wrapped around iff the result compares less than
    // either operand; select the carry bit from that comparison.
    Value sum = rewriter.create<IAddOp>(loc, augend, addend);
    Value overflowed = rewriter.create<ULessThanOp>(loc, sum, augend);
    Value carry = rewriter.create<SelectOp>(loc, overflowed, one, zero);

    // Pack (sum, carry) into the result struct the original op produced.
    rewriter.replaceOpWithNewOp<CompositeConstructOp>(
        op, op->getResultTypes().front(), llvm::ArrayRef({sum, carry}));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
Expand All @@ -191,8 +227,12 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
RewritePatternSet &patterns) {
// WGSL currently does not support extended multiplication ops, see:
// https://github.com/gpuweb/gpuweb/issues/1565.
patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern>(
patterns.getContext());
patterns.add<
// clang-format off
ExpandSMulExtendedPattern,
ExpandUMulExtendedPattern,
ExpandAddCarryPattern
>(patterns.getContext());
}
} // namespace spirv
} // namespace mlir
37 changes: 37 additions & 0 deletions mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,41 @@ spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
}

// Scalar i32 IAddCarry expands to IAdd + ULessThan + Select and is repacked
// into the original (i32, i32) result struct.
// CHECK-LABEL: func @iaddcarry_i32
// CHECK-SAME: ([[A:%.+]]: i32, [[B:%.+]]: i32)
// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant 1 : i32
// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant 0 : i32
// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (i32, i32) -> !spirv.struct<(i32, i32)>
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
spirv.func @iaddcarry_i32(%a : i32, %b : i32) -> !spirv.struct<(i32, i32)> "None" {
%0 = spirv.IAddCarry %a, %b : !spirv.struct<(i32, i32)>
spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
}

// Vector IAddCarry expands elementwise: the constants become splats, and the
// same IAdd + ULessThan + Select sequence applies per lane.
// CHECK-LABEL: func @iaddcarry_vector_i32
// CHECK-SAME: ([[A:%.+]]: vector<3xi32>, [[B:%.+]]: vector<3xi32>)
// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant dense<1> : vector<3xi32>
// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant dense<0> : vector<3xi32>
// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
spirv.func @iaddcarry_vector_i32(%a : vector<3xi32>, %b : vector<3xi32>)
-> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
%0 = spirv.IAddCarry %a, %b : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
spirv.ReturnValue %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
}

// Negative test: non-32-bit integers are not legal in WGSL, so the pattern
// must leave i16 IAddCarry untouched (notifyMatchFailure path).
// CHECK-LABEL: func @iaddcarry_i16
// CHECK-NEXT: spirv.IAddCarry
// CHECK-NEXT: spirv.ReturnValue
spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None" {
%0 = spirv.IAddCarry %a, %b : !spirv.struct<(i16, i16)>
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
}

} // end module
68 changes: 68 additions & 0 deletions mlir/test/mlir-vulkan-runner/iaddcarry_extended.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Make sure that addition with carry produces expected results
// with and without expansion to primitive add/cmp ops for WebGPU.

// RUN: mlir-vulkan-runner %s \
// RUN: --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
// RUN: --entry-point-result=void | FileCheck %s

// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
// RUN: --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
// RUN: --entry-point-result=void | FileCheck %s

// CHECK: [0, 42, 0, 42]
// CHECK: [1, 0, 1, 1]
module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {
  gpu.module @kernels {
    // One workgroup per element: computes lhs[i] + rhs[i] with
    // arith.addui_extended, storing the 32-bit sum and the carry bit
    // (widened i1 -> i32) into separate output buffers.
    gpu.func @kernel_add(%arg0 : memref<4xi32>, %arg1 : memref<4xi32>, %arg2 : memref<4xi32>, %arg3 : memref<4xi32>)
      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
      %0 = gpu.block_id x
      %lhs = memref.load %arg0[%0] : memref<4xi32>
      %rhs = memref.load %arg1[%0] : memref<4xi32>
      %sum, %carry = arith.addui_extended %lhs, %rhs : i32, i1

      %carry_i32 = arith.extui %carry : i1 to i32

      // NOTE(review): the scraped source jammed both stores onto one line;
      // restored to one op per line.
      memref.store %sum, %arg2[%0] : memref<4xi32>
      memref.store %carry_i32, %arg3[%0] : memref<4xi32>
      gpu.return
    }
  }

  func.func @main() {
    %buf0 = memref.alloc() : memref<4xi32>
    %buf1 = memref.alloc() : memref<4xi32>
    %buf2 = memref.alloc() : memref<4xi32>
    %buf3 = memref.alloc() : memref<4xi32>
    %i32_0 = arith.constant 0 : i32

    // Initialize output buffers.
    %buf4 = memref.cast %buf2 : memref<4xi32> to memref<?xi32>
    %buf5 = memref.cast %buf3 : memref<4xi32> to memref<?xi32>
    call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
    call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()

    %idx_0 = arith.constant 0 : index
    %idx_1 = arith.constant 1 : index
    %idx_4 = arith.constant 4 : index

    // Initialize input buffers. Lanes 0 and 2 overflow (carry = 1); lanes 1
    // and 3 exercise the no-carry and carry-with-nonzero-sum cases.
    %lhs_vals = arith.constant dense<[-1, 24, 4294967295, 43]> : vector<4xi32>
    %rhs_vals = arith.constant dense<[1, 18, 1, 4294967295]> : vector<4xi32>
    vector.store %lhs_vals, %buf0[%idx_0] : memref<4xi32>, vector<4xi32>
    vector.store %rhs_vals, %buf1[%idx_0] : memref<4xi32>, vector<4xi32>

    gpu.launch_func @kernels::@kernel_add
        blocks in (%idx_4, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
        args(%buf0 : memref<4xi32>, %buf1 : memref<4xi32>, %buf2 : memref<4xi32>, %buf3 : memref<4xi32>)

    // Print sums then carries; matched by the CHECK lines at the top.
    %buf_sum = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
    %buf_carry = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
    call @printMemrefI32(%buf_sum) : (memref<*xi32>) -> ()
    call @printMemrefI32(%buf_carry) : (memref<*xi32>) -> ()
    return
  }
  func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
  func.func private @printMemrefI32(%ptr : memref<*xi32>)
}