Implement BatchedReduceAdd for arbitrary axes (#2958)
Summary:
**Description**
This commit extends the implementation of `BatchedReduceAdd` in the OpenCL backend
so that it can handle any reduction axis, not just 0. This can be useful
when the first dimension is the batch dimension, and a reduction needs
to be performed within each example.
The existing implementation for axis = 0 computes each slice element in
parallel and linearizes the slice for simplicity (i.e. creates a 1D
global workspace as large as the number of elements in the output). This
implementation generalizes this concept by creating a global workspace
with rank equal to the number of dimensions of the output and computing
each output element in parallel. The slice sizes of the input and output
shapes are precomputed on the host and passed in as kernel arguments so that
the kernel can compute the correct offsets into the input and output
buffer by multiplying its set of global IDs with those slice sizes.
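The offset arithmetic described above can be sketched on the host in plain Python. This is only an illustration under assumptions, not the OpenCL kernel from this commit: the names `slice_sizes` and `batched_reduce_add` are hypothetical, row-major layout is assumed, and the loop over `ids` stands in for the work items of the global workspace.

```python
from itertools import product


def slice_sizes(dims):
    """Row-major slice sizes: sizes[i] is the number of elements
    skipped when the index along dimension i increases by one."""
    sizes = [1] * len(dims)
    for i in range(len(dims) - 2, -1, -1):
        sizes[i] = sizes[i + 1] * dims[i + 1]
    return sizes


def batched_reduce_add(flat_input, in_dims, axis):
    """Sum a flat row-major buffer along `axis` (hypothetical helper)."""
    in_dims = list(in_dims)
    out_dims = in_dims[:axis] + in_dims[axis + 1:]

    # Precomputed "on the host", as the commit message describes.
    in_slices = slice_sizes(in_dims)
    out_slices = slice_sizes(out_dims)
    # Input slice sizes for the dimensions that survive the reduction,
    # so that they line up one-to-one with the global IDs.
    kept_in_slices = in_slices[:axis] + in_slices[axis + 1:]

    out_size = 1
    for d in out_dims:
        out_size *= d
    flat_output = [0] * out_size

    # Each tuple of global IDs plays the role of one work item.
    for ids in product(*(range(d) for d in out_dims)):
        in_base = sum(g * s for g, s in zip(ids, kept_in_slices))
        out_offset = sum(g * s for g, s in zip(ids, out_slices))
        acc = 0
        # Walk the reduced axis using its own slice size as the step.
        for k in range(in_dims[axis]):
            acc += flat_input[in_base + k * in_slices[axis]]
        flat_output[out_offset] = acc
    return flat_output, out_dims


if __name__ == "__main__":
    data = list(range(24))  # a 2 x 3 x 4 tensor holding 0..23
    result, shape = batched_reduce_add(data, [2, 3, 4], axis=2)
    print(shape, result)
```

In this sketch each output element depends only on its own set of IDs and the precomputed slice sizes, which is why the iterations can run independently of one another.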
**Test Plan**
This commit enables the existing non-zero axis `BatchedReduceAdd` unit
test for OpenCL and modifies it to test `axis=2`. All unit tests pass.
Pull Request resolved: #2958
Differential Revision: D15462357
Pulled By: SplitInfinity
fbshipit-source-id: c1cb526ad12fbb000c01215d531cd5dd6a0c0929