Commit 086fd84

Meghan Lele authored and facebook-github-bot committed
Implement BatchedReduceAdd for arbitrary axes (#2958)
Summary:

**Description**

This commit extends the implementation of `BatchedReduceAdd` in the OpenCL backend so that it can handle any reduction axis, not just axis 0. This is useful when the first dimension is the batch dimension and a reduction needs to be performed within each example.

The existing implementation for axis = 0 computes each slice element in parallel and linearizes the slice for simplicity (i.e. it creates a 1D global workspace as large as the number of elements in the output). This implementation generalizes that concept by creating a global workspace with rank equal to the number of dimensions of the output and computing each output element in parallel. The slice sizes of the input and output shapes are precomputed on the host and passed in as kernel arguments so that the kernel can compute the correct offsets into the input and output buffers by multiplying its set of global IDs with those slice sizes.

**Test Plan**

This commit enables the existing non-zero-axis `BatchedReduceAdd` unit test for OpenCL and modifies it to also test `axis=2`. All unit tests pass.

Pull Request resolved: #2958

Differential Revision: D15462357

Pulled By: SplitInfinity

fbshipit-source-id: c1cb526ad12fbb000c01215d531cd5dd6a0c0929
1 parent 1abe4e5 · commit 086fd84
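As a concrete illustration of the scheme the summary describes, the following minimal C++ sketch (illustrative only, not code from this commit; all names besides the concepts they mirror are made up) reproduces the host-side slice-size precomputation and the kernel's offset arithmetic for the worked example used in kernels.cl: an input of shape {3, 4, 5} reduced over axis 1 into an output of shape {3, 5}.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::vector<size_t> batchDims = {3, 4, 5};
  const size_t axis = 1;

  // Host side: strides of every input dimension except the reduce axis
  // (batchSliceSizes), plus the stride at the reduce axis (axisSliceSize).
  std::vector<size_t> batchSliceSizes(batchDims.size() - 1, 1);
  size_t currentSliceSize = 1, axisSliceSize = 1;
  for (ptrdiff_t i = batchDims.size() - 1, j = batchSliceSizes.size() - 1;
       i >= 0; --i) {
    if (static_cast<size_t>(i) == axis) {
      axisSliceSize = currentSliceSize;
    } else {
      batchSliceSizes[j--] = currentSliceSize;
    }
    currentSliceSize *= batchDims[i];
  }
  assert(batchSliceSizes[0] == 20 && batchSliceSizes[1] == 1);
  assert(axisSliceSize == 5);

  // Host side: strides of the output shape {3, 5}.
  const std::vector<size_t> destDims = {3, 5};
  std::vector<size_t> destSliceSizes(destDims.size(), 1);
  for (size_t i = 2, e = destDims.size(); i <= e; ++i) {
    destSliceSizes[e - i] = destSliceSizes[e - i + 1] * destDims[e - i + 1];
  }
  assert(destSliceSizes[0] == 5 && destSliceSizes[1] == 1);

  // "Device" side: each iteration of the outer two loops stands in for one
  // OpenCL work-item with global IDs {i, j}; there is one per output element.
  std::vector<float> batch(3 * 4 * 5);
  for (size_t k = 0; k < batch.size(); ++k) {
    batch[k] = static_cast<float>(k);
  }
  std::vector<float> dest(3 * 5, 0.0f);
  for (size_t i = 0; i < destDims[0]; ++i) {
    for (size_t j = 0; j < destDims[1]; ++j) {
      size_t batchOffset = i * batchSliceSizes[0] + j * batchSliceSizes[1];
      size_t destOffset = i * destSliceSizes[0] + j * destSliceSizes[1];
      for (size_t n = 0; n < batchDims[axis]; ++n) {
        dest[destOffset] += batch[n * axisSliceSize + batchOffset];
      }
    }
  }

  // dest[0] sums the input elements {0, 0..3, 0} = 0 + 5 + 10 + 15.
  assert(dest[0] == 30.0f);
  std::cout << "dest[0] = " << dest[0] << "\n";
  return 0;
}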

3 files changed (+141, −22 lines)

lib/Backends/OpenCL/OpenCL.cpp

Lines changed: 79 additions & 5 deletions
@@ -913,18 +913,92 @@ llvm::Error OpenCLFunction::execute(ExecutionContext *context) {
     }

     if (auto *BRA = dyn_cast<BatchedReduceAddInst>(&I)) {
-      assert(BRA->getAxis() == 0 && "No current support for non-zero axis.");
+      auto axis = BRA->getAxis();
+
+      // Determine and store the slice sizes of each input dimension excluding
+      // the reduce axis into batchSliceSizes. These are used by the kernel to
+      // index correctly into the input buffer. If the input has one dimension
+      // (that is also the reduce axis), store one slice of size 1 into
+      // batchSliceSizes.
+      auto batchDims = BRA->getBatch()->getType()->dims();
+      auto numBatchDims = batchDims.size();
+      std::vector<size_t> batchSliceSizes(
+          numBatchDims > 1 ? numBatchDims - 1 : 1, 1);
+      size_t currentSliceSize = 1, axisSliceSize = 1;
+      for (ssize_t i = numBatchDims - 1, j = batchSliceSizes.size() - 1;
+           i >= 0; --i) {
+        if (i == axis) {
+          axisSliceSize = currentSliceSize;
+        } else {
+          batchSliceSizes[j--] = currentSliceSize;
+        }
+
+        currentSliceSize *= batchDims[i];
+      }
+
+      // Determine and store the slice sizes of each output dimension excluding
+      // the reduce axis into destSliceSizes. These are used by the kernel to
+      // index correctly into the output buffer. If the output has zero
+      // dimensions, store one slice of size 1 into destSliceSizes.
+      auto destDims = BRA->getDest()->getType()->dims();
+      std::vector<size_t> destDimsVec(destDims.begin(), destDims.end());
+      if (destDims.empty()) {
+        destDimsVec.emplace_back(1);
+      }
+      auto numDestDims = destDimsVec.size();
+      std::vector<size_t> destSliceSizes(numDestDims > 0 ? numDestDims : 1, 1);
+      for (size_t i = 2, e = destDimsVec.size(); i <= e; ++i) {
+        destSliceSizes[e - i] =
+            destSliceSizes[e - i + 1] * destDimsVec[e - i + 1];
+      }
+
+      // Allocate device buffers for batchSliceSizes and destSliceSizes.
+      size_t batchSlicesBufSize = batchSliceSizes.size() * sizeof(size_t);
+      size_t destSlicesBufSize = destSliceSizes.size() * sizeof(size_t);
+      cl_mem batchSlicesBuf = allocDeviceBuffer(batchSlicesBufSize);
+      cl_mem destSlicesBuf = allocDeviceBuffer(destSlicesBufSize);
+
+      // Copy batchSliceSizes and destSliceSizes from host to device.
+      cl_event writeBatchSlicesEvent{nullptr}, writeDestSlicesEvent{nullptr};
+      cl_int err = clEnqueueWriteBuffer(
+          commands_, batchSlicesBuf, /*blocking_write=*/CL_FALSE, /*offset=*/0,
+          batchSlicesBufSize, batchSliceSizes.data(),
+          /* num_events_in_wait_list */ 0,
+          /* event_list */ nullptr,
+          /* event */ kernelProfiling_ ? &writeBatchSlicesEvent : nullptr);
+      GLOW_ASSERT(err == CL_SUCCESS && "Unable to copy BRA data to the device");
+      if (kernelProfiling_) {
+        kernelLaunches_.emplace_back(KernelLaunch("batchedReduceAddSliceData",
+                                                  "batchedReduceAddSliceData",
+                                                  writeBatchSlicesEvent));
+      }

+      err = clEnqueueWriteBuffer(
+          commands_, destSlicesBuf, /*blocking_write=*/CL_FALSE, /*offset=*/0,
+          destSlicesBufSize, destSliceSizes.data(),
+          /* num_events_in_wait_list */ 0,
+          /* event_list */ nullptr,
+          /* event */ kernelProfiling_ ? &writeDestSlicesEvent : nullptr);
+      GLOW_ASSERT(err == CL_SUCCESS && "Unable to copy BRA data to the device");
+      if (kernelProfiling_) {
+        kernelLaunches_.emplace_back(KernelLaunch("batchedReduceAddSliceData",
+                                                  "batchedReduceAddSliceData",
+                                                  writeDestSlicesEvent));
+      }
+
+      // Wait for the writes to finish.
+      clFinish(commands_);
+
+      // Create kernel and set arguments.
       cl_kernel kernel = createKernel(kernelName);
       setKernelArg(kernel, 0, deviceBuffer_);
       auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

-      auto bdim = flattenCdr(BRA->getBatch()->dims());
-      setKernelArg<cl_uint>(kernel, numArgs + 1, bdim.first);
-      setKernelArg<cl_uint>(kernel, numArgs + 2, bdim.second);
+      setKernelArg(kernel, numArgs + 1, batchSlicesBuf);
+      setKernelArg(kernel, numArgs + 2, destSlicesBuf);
+      setKernelArg<cl_uint>(kernel, numArgs + 3, batchDims[axis]);
+      setKernelArg<cl_uint>(kernel, numArgs + 4, axisSliceSize);

       // Parallelize on each element in the slice.
-      enqueueKernel(I.getName(), commands_, kernel, deviceId_, {bdim.second},
+      enqueueKernel(I.getName(), commands_, kernel, deviceId_, destDimsVec,
                     kernelLaunches_);
       continue;
     }
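Two details are worth noting in the host code above. First, both clEnqueueWriteBuffer calls are non-blocking (CL_FALSE), and OpenCL may read the source host memory at any point until the copy completes, so the clFinish is what makes it safe for the stack-allocated batchSliceSizes and destSliceSizes vectors to go out of scope afterwards. Second, the kernel arguments line up one-to-one with the new batchedreduceaddW signature in kernels.cl below: argument 0 is mem, the arguments set by setKernelArgsForBuffers are the dest and batch offsets into the device buffer, and numArgs + 1 through numArgs + 4 are batchSliceSizes, destSliceSizes, numSlices (passed as batchDims[axis]), and axisSliceSize.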

lib/Backends/OpenCL/kernels.cl

Lines changed: 43 additions & 9 deletions
@@ -619,18 +619,52 @@ __kernel void elementcmplteW(__global void *mem, cl_uint32_t dest,
 }

 __kernel void batchedreduceaddK(__global float *dest, __global float *batch,
-                                cl_uint32_t numSlice, cl_uint32_t sliceSize) {
-  size_t s = get_global_id(0);
-  dest[s] = 0;
-  for (size_t n = 0; n < numSlice; n++) {
-    dest[s] += batch[n * sliceSize + s];
+                                __global cl_host_size_t *batchSliceSizes,
+                                __global cl_host_size_t *destSliceSize,
+                                cl_uint32_t numSlices,
+                                cl_uint32_t axisSliceSize) {
+  size_t workDim = get_work_dim();
+
+  // This is the component of the offset into batch that depends only on the
+  // kernel's global IDs.
+  size_t batchOffset = 0;
+
+  // This is the offset into dest. It depends only on the kernel's global IDs.
+  size_t destOffset = 0;
+
+  // Compute batchOffset and destOffset by multiplying the kernel's global IDs
+  // with the corresponding batch and dest slice sizes.
+  //
+  // For example, suppose the input shape is {3, 4, 5} and the reduce axis is 1.
+  // Then, the output shape is {3, 5}. In this case, batchSliceSizes is {4 * 5 =
+  // 20, 1} (axis 1 is missing) and destSliceSizes is {5, 1}. The global
+  // workspace this kernel was launched with has dimensions {3, 5} (one for each
+  // output element). A kernel with IDs {i, j} will add together elements
+  // {i, 0..4, j} and store the result in element {i, j}, so (i * 20 + j * 1)
+  // will be a component of every offset it uses to access batch, and (i * 5 + j
+  // * 1) will be the offset it uses to access dest. This is precisely what
+  // batchOffset and destOffset are. The loop below precomputes these offsets
+  // before the actual reduce.
+  for (size_t i = 0; i < workDim; ++i) {
+    size_t id = get_global_id(i);
+    batchOffset += id * batchSliceSizes[i];
+    destOffset += id * destSliceSize[i];
+  }
+
+  // Perform the actual reduce. Add the slice number * the slice size at the
+  // axis index to batchOffset to get the elements to add together.
+  dest[destOffset] = 0;
+  for (size_t n = 0; n < numSlices; n++) {
+    dest[destOffset] += batch[n * axisSliceSize + batchOffset];
   }
 }

-__kernel void batchedreduceaddW(__global void *mem, cl_uint32_t dest,
-                                cl_uint32_t batch, cl_uint32_t numSlice,
-                                cl_uint32_t sliceSize) {
-  batchedreduceaddK(&mem[dest], &mem[batch], numSlice, sliceSize);
+__kernel void
+batchedreduceaddW(__global void *mem, cl_uint32_t dest, cl_uint32_t batch,
+                  __global void *batchSliceSizes, __global void *destSliceSizes,
+                  cl_uint32_t numSlices, cl_uint32_t axisSliceSize) {
+  batchedreduceaddK(&mem[dest], &mem[batch], batchSliceSizes, destSliceSizes,
+                    numSlices, axisSliceSize);
 }

 __kernel void batchedaddK(__global float *dest, __global float *batch,
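To trace the kernel on the example from its own comment (input {3, 4, 5}, reduce axis 1, batchSliceSizes = {20, 1}, destSliceSizes = {5, 1}, numSlices = 4, axisSliceSize = 5): the work-item with global IDs {1, 2} computes batchOffset = 1 * 20 + 2 * 1 = 22 and destOffset = 1 * 5 + 2 * 1 = 7, then sums batch[22], batch[27], batch[32], and batch[37] (exactly the linearized input elements {1, 0..3, 2}) into dest[7], the linearized element {1, 2} of the output.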

tests/unittests/OperatorTest.cpp

Lines changed: 19 additions & 8 deletions
@@ -794,22 +794,33 @@ static void testBatchedReduceAddWithAxis(glow::PlaceholderBindings &bindings,
   bindings.allocate(batch)->getHandle<DataType>() = {0, 1, 2, 3, 4, 5,
                                                      6, 7, 8, 9, 10, 11};

-  auto OT = uniqueTypeConditionallyQuantized(mod, DTy, {2, 2});
-  auto *R = F->createBatchedReduceAdd("reduce.add", OT, batch, /* axis */ 1);
-  auto *save = F->createSave("save", R);
-  auto *result = bindings.allocate(save->getPlaceholder());
+  auto OT1 = uniqueTypeConditionallyQuantized(mod, DTy, {2, 2});
+  auto *R1 =
+      F->createBatchedReduceAdd("reduce.add.axis.1", OT1, batch, /* axis */ 1);
+  auto OT2 = uniqueTypeConditionallyQuantized(mod, DTy, {2, 3});
+  auto *R2 =
+      F->createBatchedReduceAdd("reduce.add.axis.2", OT2, batch, /* axis */ 2);
+  auto *save1 = F->createSave("save1", R1);
+  auto *save2 = F->createSave("save2", R2);
+
+  auto *result1 = bindings.allocate(save1->getPlaceholder());
+  auto *result2 = bindings.allocate(save2->getPlaceholder());

   EE.compile(CompilationMode::Infer, F);
   EE.run(bindings);

-  auto expected = createTensorConditionallyQuantized(DTy, {2, 2});
-  expected.getHandle<DataType>() = {6, 9, 24, 27};
-  EXPECT_TRUE(result->isEqual(expected));
+  auto expected1 = createTensorConditionallyQuantized(DTy, {2, 2});
+  expected1.getHandle<DataType>() = {6, 9, 24, 27};
+  EXPECT_TRUE(result1->isEqual(expected1));
+
+  auto expected2 = createTensorConditionallyQuantized(DTy, {2, 3});
+  expected2.getHandle<DataType>() = {1, 5, 9, 13, 17, 21};
+  EXPECT_TRUE(result2->isEqual(expected2));
 }

 /// Test that batchedReduceAddWithAxis is correctly supported in FloatTy.
 TEST_P(OperatorTest, batchedReduceAddWithAxis_Float) {
-  ENABLED_BACKENDS(Interpreter, CPU);
+  ENABLED_BACKENDS(Interpreter, CPU, OpenCL);
   testBatchedReduceAddWithAxis<float>(bindings_, mod_, F_, EE_,
                                       ElemKind::FloatTy);
 }
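For reference, the expected values follow directly from the input: the batch placeholder holds 0 through 11 in a {2, 3, 2} tensor, so reducing over axis 1 sums triples with stride 2 (0+2+4 = 6, 1+3+5 = 9, 6+8+10 = 24, 7+9+11 = 27) to give the existing {2, 2} result, while reducing over axis 2 sums adjacent pairs (0+1 = 1, 2+3 = 5, 4+5 = 9, 6+7 = 13, 8+9 = 17, 10+11 = 21) to give the new {2, 3} result.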
