Skip to content

[OpenCL] Fix bugs in BatchedReduceAddInst implementation #3118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions lib/Backends/OpenCL/OpenCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -916,22 +916,28 @@ llvm::Error OpenCLFunction::execute(ExecutionContext *context) {
auto axis = BRA->getAxis();

// Determine and store the slice sizes of each input dimension excluding
// the reduce axis into batchSliceSizes. These are used by the kernel to
// index correctly into the input buffer. If the input has one dimension
// (that is also the reduce axis), store one slice of size 1 into
// batchSliceSizes.
// the reduce axis into batchSliceSizes. Determine also the slice size on
// the reduce axis and store that separately. These are used by the kernel
// to index correctly into the input buffer. If the input has one
// dimension (that is also the reduce axis), store one slice of size 1
// into batchSliceSizes.
auto batchDims = BRA->getBatch()->getType()->dims();
auto numBatchDims = batchDims.size();
std::vector<size_t> batchSliceSizes(
numBatchDims > 1 ? numBatchDims - 1 : 1, 1);
size_t currentSliceSize = 1, axisSliceSize = 1;
for (size_t i = batchSliceSizes.size() - 1, j = i; i >= 0; ++i) {
for (ssize_t i = batchDims.size() - 1, j = batchSliceSizes.size() - 1;
i >= 0; --i) {
// If i is the reduce axis, currentSliceSize is the slice size at the
// reduce axis. Store it in axisSliceSize and not in batchSliceSizes. If
// not, do the opposite.
if (i == axis) {
axisSliceSize = currentSliceSize;
} else {
batchSliceSizes[j--] = currentSliceSize;
}

// Compute the slice size for the next iteration.
currentSliceSize *= batchDims[i];
}

Expand All @@ -946,8 +952,14 @@ llvm::Error OpenCLFunction::execute(ExecutionContext *context) {
}
auto numDestDims = destDimsVec.size();
std::vector<size_t> destSliceSizes(numDestDims > 0 ? numDestDims : 1, 1);
for (size_t i = 2, e = destDimsVec.size(); i <= e; ++i) {
destSliceSizes[e - i] = destSliceSizes[e - i + 1] * destDimsVec[e - i];

// Start i at destDimsVec.size() - 2 because the last slice size is always
// known to be 1.
for (ssize_t i = destDimsVec.size() - 2; i >= 0; --i) {
// The slice size of the current dimension is the slice size of the
// previous dimension multiplied by the number of elements in that
// dimension.
destSliceSizes[i] = destSliceSizes[i + 1] * destDimsVec[i + 1];
}

// Allocate device buffers for batchSliceSizes and destSliceSizes.
Expand Down