[OpenCL] Implement BatchedReduceAdd for arbitrary axes
Description:
This commit extends the implementation of BatchedReduceAdd in the OpenCL backend
so that it can handle any reduction axis, not just 0. This is useful
when the first dimension is the batch dimension and a reduction needs
to be performed within each example.
The existing implementation for axis = 0 computes each slice element in
parallel and linearizes the slice for simplicity (i.e. creates a 1D
global workspace as large as the number of elements in the output). This
implementation generalizes that concept by creating a global workspace
with rank equal to the number of dimensions of the output and computing
each output element in parallel. The slice sizes of the input and output
shapes are precomputed on the host and passed in as kernel arguments so
that the kernel can compute the correct offsets into the input and output
buffers by multiplying its set of global IDs with those slice sizes.
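
The offset arithmetic described above can be illustrated with a small
host-side sketch. It is not the kernel from this commit: the names
sliceSizes and batchedReduceAdd are hypothetical, "slice size" is
assumed to mean the row-major stride of a dimension, and a plain loop
stands in for the N-dimensional global workspace (on the device each
iteration would be one work-item reading its IDs via get_global_id).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: row-major "slice size" (stride) of each dimension,
// i.e. how many elements one step along that dimension skips.
std::vector<size_t> sliceSizes(const std::vector<size_t> &dims) {
  std::vector<size_t> sizes(dims.size(), 1);
  for (size_t i = dims.size(); i-- > 1;) {
    sizes[i - 1] = sizes[i] * dims[i];
  }
  return sizes;
}

// Host emulation (assumption, not the real kernel): sum `in` along `axis`.
std::vector<float> batchedReduceAdd(const std::vector<float> &in,
                                    const std::vector<size_t> &inDims,
                                    size_t axis) {
  // The output shape is the input shape with the reduced axis removed.
  std::vector<size_t> outDims;
  for (size_t i = 0; i < inDims.size(); ++i) {
    if (i != axis) {
      outDims.push_back(inDims[i]);
    }
  }

  // Precomputed on the host; these would be passed as kernel arguments.
  const std::vector<size_t> inSizes = sliceSizes(inDims);
  const std::vector<size_t> outSizes = sliceSizes(outDims);

  size_t outCount = 1;
  for (size_t d : outDims) {
    outCount *= d;
  }
  std::vector<float> out(outCount, 0.0f);

  std::vector<size_t> ids(outDims.size(), 0);
  for (size_t item = 0; item < outCount; ++item) {
    // Decode the emulated work-item into one ID per output dimension.
    size_t rem = item;
    for (size_t d = 0; d < outDims.size(); ++d) {
      ids[d] = rem / outSizes[d];
      rem %= outSizes[d];
    }

    // Offsets are the global IDs multiplied by the matching slice sizes.
    size_t inOffset = 0;
    size_t outOffset = 0;
    for (size_t d = 0; d < outDims.size(); ++d) {
      const size_t inDim = d < axis ? d : d + 1; // skip the reduced axis
      inOffset += ids[d] * inSizes[inDim];
      outOffset += ids[d] * outSizes[d];
    }

    // Walk the reduced axis using its own slice size as the step.
    float sum = 0.0f;
    for (size_t k = 0; k < inDims[axis]; ++k) {
      sum += in[inOffset + k * inSizes[axis]];
    }
    out[outOffset] = sum;
  }
  return out;
}
```

For a {2, 3, 2} input, reducing along axis 1 yields a {2, 2} output and
reducing along axis 2 yields a {2, 3} output; in both cases the only
shape-dependent data the loop body needs is the two slice-size arrays.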
Testing:
This commit enables the existing non-zero axis BatchedReduceAdd unit
test for OpenCL and augments it to test axis=2. All unit tests pass.