diff --git a/lib/Backends/OpenCL/OpenCL.cpp b/lib/Backends/OpenCL/OpenCL.cpp index 9890c836f3..2d07666aa4 100644 --- a/lib/Backends/OpenCL/OpenCL.cpp +++ b/lib/Backends/OpenCL/OpenCL.cpp @@ -916,22 +916,28 @@ llvm::Error OpenCLFunction::execute(ExecutionContext *context) { auto axis = BRA->getAxis(); // Determine and store the slice sizes of each input dimension excluding - // the reduce axis into batchSliceSizes. These are used by the kernel to - // index correctly into the input buffer. If the input has one dimension - // (that is also the reduce axis), store one slice of size 1 into - // batchSliceSizes. + // the reduce axis into batchSliceSizes. Determine also the slice size on + // the reduce axis and store that separately. These are used by the kernel + // to index correctly into the input buffer. If the input has one + // dimension (that is also the reduce axis), store one slice of size 1 + // into batchSliceSizes. auto batchDims = BRA->getBatch()->getType()->dims(); auto numBatchDims = batchDims.size(); std::vector batchSliceSizes( numBatchDims > 1 ? numBatchDims - 1 : 1, 1); size_t currentSliceSize = 1, axisSliceSize = 1; - for (size_t i = batchSliceSizes.size() - 1, j = i; i >= 0; ++i) { + for (ssize_t i = batchDims.size() - 1, j = batchSliceSizes.size() - 1; + i >= 0; --i) { + // If i is the reduce axis, currentSliceSize is the slice size at the + // reduce axis. Store it in axisSliceSize and not in batchSliceSizes. If + // not, do the opposite. if (i == axis) { axisSliceSize = currentSliceSize; } else { batchSliceSizes[j--] = currentSliceSize; } + // Compute the slice size for the next iteration. currentSliceSize *= batchDims[i]; } @@ -946,8 +952,14 @@ llvm::Error OpenCLFunction::execute(ExecutionContext *context) { } auto numDestDims = destDimsVec.size(); std::vector destSliceSizes(numDestDims > 0 ? numDestDims : 1, 1); - for (size_t i = 2, e = destDimsVec.size(); i <= e; ++i) { - destSliceSizes[e - i] = destSliceSizes[e - i + 1] * destDimsVec[e - i]; + + // Start i at destDimsVec.size() - 2 because the last slice size is always + // known to be 1. + for (ssize_t i = destDimsVec.size() - 2; i >= 0; --i) { + // The slice size of the current dimension is the slice size of the + // previous dimension multiplied by the number of elements in that + // dimension. + destSliceSizes[i] = destSliceSizes[i + 1] * destDimsVec[i + 1]; } // Allocate device buffers for batchSliceSizes and destSliceSizes.