[OpenCL] Implement BatchedReduceAdd for arbitrary axes #2958

Contributor

@SplitInfinity SplitInfinity commented May 22, 2019

Description
This commit extends the implementation of BatchedReduceAdd in the OpenCL backend
so that it can handle any reduction axis, not just axis 0. This can be useful
when the first dimension is the batch dimension, and a reduction needs
to be performed within each example.

The existing implementation for axis = 0 computes each slice element in
parallel and linearizes the slice for simplicity (i.e. creates a 1D
global workspace as large as the number of elements in the output). This
implementation generalizes this concept by creating a global workspace
with rank equal to the number of dimensions of the output and computing
each output element in parallel. The slice sizes of the input and output
shapes are precomputed on the host and passed in as kernel arguments so that
the kernel can compute the correct offsets into the input and output
buffer by multiplying its set of global IDs with those slice sizes.
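
The offset arithmetic described above can be sketched on the host as follows. This is a minimal illustration, not the code from this PR: `computeSliceSizes` and `reduceAdd` are hypothetical names, the layout is assumed to be row-major, and the loop over `n` stands in for the work items that the OpenCL runtime would launch in parallel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major slice sizes: sliceSizes[i] is the number of elements spanned by
// one step along dimension i. (Hypothetical helper, not the PR's code.)
std::vector<size_t> computeSliceSizes(const std::vector<size_t> &dims) {
  std::vector<size_t> sliceSizes(dims.size(), 1);
  for (size_t i = dims.size(); i-- > 1;) {
    sliceSizes[i - 1] = sliceSizes[i] * dims[i];
  }
  return sliceSizes;
}

// Host-side emulation of the reduction. Each iteration of the outer loop
// plays the role of one work item; `ids` plays the role of its global IDs,
// one per output dimension.
std::vector<float> reduceAdd(const std::vector<float> &in,
                             const std::vector<size_t> &inDims, size_t axis) {
  assert(axis < inDims.size() && "axis must name an input dimension");

  // The output shape is the input shape with the reduced axis removed.
  std::vector<size_t> outDims;
  for (size_t i = 0; i < inDims.size(); ++i) {
    if (i != axis) {
      outDims.push_back(inDims[i]);
    }
  }

  const std::vector<size_t> inSlices = computeSliceSizes(inDims);
  const std::vector<size_t> outSlices = computeSliceSizes(outDims);

  size_t outSize = 1;
  for (size_t d : outDims) {
    outSize *= d;
  }
  std::vector<float> out(outSize, 0.0f);

  std::vector<size_t> ids(outDims.size(), 0);
  for (size_t n = 0; n < outSize; ++n) {
    // Decode n into per-dimension IDs. On a device these would come from
    // the runtime instead of being computed.
    size_t rem = n;
    for (size_t d = 0; d < outDims.size(); ++d) {
      ids[d] = rem / outSlices[d];
      rem %= outSlices[d];
    }

    // Offsets are the IDs multiplied by the matching slice sizes. The input
    // dimension index skips over the reduced axis.
    size_t inBase = 0, outOffset = 0;
    for (size_t d = 0, inDim = 0; d < outDims.size(); ++d, ++inDim) {
      if (inDim == axis) {
        ++inDim;
      }
      outOffset += ids[d] * outSlices[d];
      inBase += ids[d] * inSlices[inDim];
    }

    // Walk along the reduced axis and accumulate.
    for (size_t k = 0; k < inDims[axis]; ++k) {
      out[outOffset] += in[inBase + k * inSlices[axis]];
    }
  }
  return out;
}
```

For a 2x3x4 input holding the values 0 through 23, reducing along axis 2 produces six sums, the first of which is 0 + 1 + 2 + 3 = 6.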

Test Plan
This commit enables the existing non-zero axis BatchedReduceAdd unit
test for OpenCL and modifies it to test axis=2. All unit tests pass.

@facebook-github-bot facebook-github-bot left a comment

@SplitInfinity has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@SplitInfinity SplitInfinity requested a review from opti-mix May 22, 2019 23:25

stale bot commented Jun 6, 2019

This issue has been automatically marked as stale because it has not had recent activity. It will be closed in 15 days if no further activity occurs. Thank you for your contributions.

Contributor

@nickgg nickgg left a comment

How confident are you that the batchedReduceAddWithAxis_Float test provides enough coverage?

std::vector<size_t> batchSliceSizes(
numBatchDims > 1 ? numBatchDims - 1 : 1, 1);
size_t currentSliceSize = 1, axisSliceSize = 1;
for (size_t i = 1, j = batchSliceSizes.size() - 1, e = batchDims.size();

Variable names are a bit terse here; I would prefer end and index.

Description:
This commit extends the implementation of BatchedReduceAdd in the OpenCL backend
so that it can handle any reduction axis, not just axis 0. This can be useful
when the first dimension is the batch dimension, and a reduction needs
to be performed within each example.

The existing implementation for axis = 0 computes each slice element in
parallel and linearizes the slice for simplicity (i.e. creates a 1D
global workspace as large as the number of elements in the output). This
implementation generalizes this concept by creating a global workspace
with rank equal to the number of dimensions of the output and computing
each output element in parallel. The slice sizes of the input and output
shapes are precomputed on the host and passed in as kernel arguments so that
the kernel can compute the correct offsets into the input and output
buffer by multiplying its set of global IDs with those slice sizes.

Testing:
This commit enables the existing non-zero axis BatchedReduceAdd unit
test for OpenCL and augments it to test axis=2. All unit tests pass.
@SplitInfinity SplitInfinity force-pushed the any-dim-ocl-batched-reduce-add branch from b4c8146 to 5ff5f50 Compare June 13, 2019 01:25
Contributor Author

SplitInfinity commented Jun 13, 2019

batchedReduceAddWithAxis_Float tests only axis=1, so I modified it to test axis=2 (which in this case is the last axis, another thing that would be good to test). We have other tests for axis=0 and for the case where the result is zero-dimensional, so I'm reasonably confident now.

@facebook-github-bot facebook-github-bot left a comment

@SplitInfinity has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@nickgg nickgg left a comment

Nice one

@SplitInfinity merged this pull request in 086fd84.

SplitInfinity pushed a commit to SplitInfinity/glow that referenced this pull request Jun 14, 2019
Description:
This commit fixes two bugs in the OpenCL implementation of
BatchedReduceAddInst and adds a few comments for clarity.

The first is a segmentation fault caused by
incorporating feedback on pytorch#2958. A suggestion was made to make the loop
variable i in the loop that computes batchSliceSizes count down instead of
count up, but this was done without changing the type (which was size_t,
an unsigned type), so the loop never terminates and eventually leads to a
segmentation fault.

The second bug is an incorrect computation of destSliceSizes. Instead of
multiplying the slice size at a dimension with the number of elements in
that same dimension, the code was multiplying the former with the number
of elements in the *adjacent* dimension. This was surfaced by the unit
test added in pytorch#2958 for axis = 2.

Testing:
1) ninja check with OpenCL enabled, DEBUG mode
2) ninja check with OpenCL enabled, RELEASE mode
3) ninja check with OpenCL enabled, ASAN+UBSAN mode
facebook-github-bot pushed a commit that referenced this pull request Jun 15, 2019
Summary:
**Description**
This commit fixes two bugs in the OpenCL implementation of
`BatchedReduceAddInst` and adds a few comments for clarity.

The first is a segmentation fault caused by
incorporating feedback on #2958. A suggestion was made to make the loop
variable `i` in the loop that computes `batchSliceSizes` count down instead of
count up, but this suggestion was taken without changing the type (which was `size_t`,
an unsigned type), so the loop never terminates and eventually leads to a
segmentation fault.
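
The pitfall described above can be reproduced in isolation. The sketch below is not the loop from this commit: `suffixProducts` is a hypothetical helper, and the buggy form is shown only as a comment. It illustrates one common way to count down with an unsigned index, which is to test the index before decrementing it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Buggy pattern (shown as a comment, do not use): `i >= 0` is always true
// for an unsigned type. Once i reaches 0, the decrement wraps around to the
// largest size_t value and the next access is far out of bounds.
//
//   for (size_t i = v.size() - 1; i >= 0; --i) { use(v[i]); }

// Safe count-down idiom: compare first, then decrement. The body sees
// i = size-1, ..., 1, 0 and the loop exits cleanly, also for empty input.
// (Hypothetical helper: result[i] is the product of dims[i+1..].)
std::vector<size_t> suffixProducts(const std::vector<size_t> &dims) {
  std::vector<size_t> result(dims.size(), 1);
  size_t running = 1;
  for (size_t i = dims.size(); i-- > 0;) {
    result[i] = running;
    running *= dims[i];
  }
  // The innermost entry has no dimensions after it, so it is always 1.
  assert(dims.empty() || result.back() == 1);
  return result;
}
```

Switching the index to a signed type is another option, at the cost of signed/unsigned comparisons against container sizes.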

The second bug is an incorrect computation of `destSliceSizes`. Instead of
multiplying the slice size at a dimension with the number of elements in
that same dimension, the code was multiplying the former with the number
of elements in the *adjacent* dimension. This was surfaced by the unit
test added in #2958 for `axis = 2`.
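
To make the difference concrete, here is a small sketch that contrasts the two computations. The function names are hypothetical and the loops are not copied from the commit; the direction of the off-by-one (previous versus next dimension) is an assumption made for illustration, and row-major layout is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Intended computation: the slice size one level out is the current slice
// size times the extent of the *same* dimension.
std::vector<size_t> sliceSizesSameDim(const std::vector<size_t> &dims) {
  std::vector<size_t> s(dims.size(), 1);
  for (size_t i = dims.size(); i-- > 1;) {
    s[i - 1] = s[i] * dims[i];
  }
  return s;
}

// Illustrative faulty variant: multiplies by the extent of the *adjacent*
// dimension instead.
std::vector<size_t> sliceSizesAdjacentDim(const std::vector<size_t> &dims) {
  std::vector<size_t> s(dims.size(), 1);
  for (size_t i = dims.size(); i-- > 1;) {
    s[i - 1] = s[i] * dims[i - 1];
  }
  return s;
}

// Linear offset of an element given its per-dimension indices.
size_t linearOffset(const std::vector<size_t> &ids,
                    const std::vector<size_t> &sliceSizes) {
  assert(ids.size() == sliceSizes.size());
  size_t offset = 0;
  for (size_t d = 0; d < ids.size(); ++d) {
    offset += ids[d] * sliceSizes[d];
  }
  return offset;
}
```

For a 2x3 output, the intended slice sizes are {3, 1}, so element (1, 0) sits at offset 3. The faulty variant yields {2, 1} and places the same element at offset 2, which overlaps element (0, 2). When all dimensions have the same extent the two variants agree, so only a shape with unequal extents can expose this kind of mistake.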

**Test Plan**
1) `ninja check` with OpenCL enabled, DEBUG mode

```
      Start  1: BackendCorrectnessTest
 1/34 Test  #1: BackendCorrectnessTest ..............   Passed   21.28 sec
      Start  2: BackendTest
 2/34 Test  #2: BackendTest .........................   Passed    1.97 sec
      Start  3: BasicIRTest
 3/34 Test  #3: BasicIRTest .........................   Passed    0.05 sec
      Start  4: Caffe2ImporterTest
 4/34 Test  #4: Caffe2ImporterTest ..................   Passed    3.00 sec
      Start  5: DeviceManagerTest
 5/34 Test  #5: DeviceManagerTest ...................   Passed    0.76 sec
      Start  6: ThreadPoolExecutorTest
 6/34 Test  #6: ThreadPoolExecutorTest ..............   Passed    1.48 sec
      Start  7: Float16Test
 7/34 Test  #7: Float16Test .........................   Passed    0.01 sec
      Start  8: GemmTest
 8/34 Test  #8: GemmTest ............................   Passed    0.05 sec
      Start  9: GlowOnnxifiManagerTest
 9/34 Test  #9: GlowOnnxifiManagerTest ..............   Passed    0.06 sec
      Start 10: GradCheckTest
10/34 Test #10: GradCheckTest .......................   Passed    4.72 sec
      Start 11: GraphGradTest
11/34 Test #11: GraphGradTest .......................   Passed    0.06 sec
      Start 12: GraphOptzTest
12/34 Test #12: GraphOptzTest .......................   Passed    0.03 sec
      Start 13: GraphSchedulerTest
13/34 Test #13: GraphSchedulerTest ..................   Passed    0.01 sec
      Start 14: GraphTest
14/34 Test #14: GraphTest ...........................   Passed    1.03 sec
      Start 15: HostManagerTest
15/34 Test #15: HostManagerTest .....................   Passed    7.49 sec
      Start 16: HyphenTest
16/34 Test #16: HyphenTest ..........................   Passed    1.17 sec
      Start 17: IROptTest
17/34 Test #17: IROptTest ...........................   Passed    0.01 sec
      Start 18: ImageTest
18/34 Test #18: ImageTest ...........................   Passed    0.31 sec
      Start 19: LLVMIRGenTest
19/34 Test #19: LLVMIRGenTest .......................   Passed    0.01 sec
      Start 20: MLTest
20/34 Test #20: MLTest ..............................   Passed   46.30 sec
      Start 21: MemoryAllocatorTest
21/34 Test #21: MemoryAllocatorTest .................   Passed    0.03 sec
      Start 22: OCLTest
22/34 Test #22: OCLTest .............................   Passed    0.24 sec
      Start 23: OnnxImporterTest
23/34 Test #23: OnnxImporterTest ....................   Passed    0.12 sec
      Start 24: OperatorGradTest
24/34 Test #24: OperatorGradTest ....................   Passed    0.05 sec
      Start 25: OperatorTest
25/34 Test #25: OperatorTest ........................   Passed   14.47 sec
      Start 26: PartitionerTest
26/34 Test #26: PartitionerTest .....................   Passed    0.05 sec
      Start 28: ProvisionerTest
27/34 Test #28: ProvisionerTest .....................   Passed    1.00 sec
      Start 29: QuantizationTest
28/34 Test #29: QuantizationTest ....................   Passed    7.46 sec
      Start 30: TensorsTest
29/34 Test #30: TensorsTest .........................   Passed    0.36 sec
      Start 31: TensorPoolTest
30/34 Test #31: TensorPoolTest ......................   Passed    0.01 sec
      Start 32: ThreadPoolTest
31/34 Test #32: ThreadPoolTest ......................   Passed    0.01 sec
      Start 33: TraceEventsTest
32/34 Test #33: TraceEventsTest .....................   Passed   10.62 sec
      Start 34: TypeAToTypeBFunctionConverterTest
33/34 Test #34: TypeAToTypeBFunctionConverterTest ...   Passed    0.06 sec
      Start 35: UtilsTest
34/34 Test #35: UtilsTest ...........................   Passed    0.02 sec

100% tests passed, 0 tests failed out of 34

Total Test time (real) = 124.33 sec
```

2) `ninja check` with OpenCL enabled, RELEASE mode
```
      Start  1: BackendCorrectnessTest
 1/34 Test  #1: BackendCorrectnessTest ..............   Passed   11.51 sec
      Start  2: BackendTest
 2/34 Test  #2: BackendTest .........................   Passed    1.53 sec
      Start  3: BasicIRTest
 3/34 Test  #3: BasicIRTest .........................   Passed    0.02 sec
      Start  4: Caffe2ImporterTest
 4/34 Test  #4: Caffe2ImporterTest ..................   Passed    0.62 sec
      Start  5: DeviceManagerTest
 5/34 Test  #5: DeviceManagerTest ...................   Passed    0.83 sec
      Start  6: ThreadPoolExecutorTest
 6/34 Test  #6: ThreadPoolExecutorTest ..............   Passed    0.71 sec
      Start  7: Float16Test
 7/34 Test  #7: Float16Test .........................   Passed    0.01 sec
      Start  8: GemmTest
 8/34 Test  #8: GemmTest ............................   Passed    0.31 sec
      Start  9: GlowOnnxifiManagerTest
 9/34 Test  #9: GlowOnnxifiManagerTest ..............   Passed    0.33 sec
      Start 10: GradCheckTest
10/34 Test #10: GradCheckTest .......................   Passed    1.90 sec
      Start 11: GraphGradTest
11/34 Test #11: GraphGradTest .......................   Passed    0.32 sec
      Start 12: GraphOptzTest
12/34 Test #12: GraphOptzTest .......................   Passed    0.03 sec
      Start 13: GraphSchedulerTest
13/34 Test #13: GraphSchedulerTest ..................   Passed    0.02 sec
      Start 14: GraphTest
14/34 Test #14: GraphTest ...........................   Passed    0.59 sec
      Start 15: HostManagerTest
15/34 Test #15: HostManagerTest .....................   Passed   10.61 sec
      Start 16: HyphenTest
16/34 Test #16: HyphenTest ..........................   Passed    4.18 sec
      Start 17: IROptTest
17/34 Test #17: IROptTest ...........................   Passed    0.04 sec
      Start 18: ImageTest
18/34 Test #18: ImageTest ...........................   Passed    0.10 sec
      Start 19: LLVMIRGenTest
19/34 Test #19: LLVMIRGenTest .......................   Passed    0.71 sec
      Start 20: MLTest
20/34 Test #20: MLTest ..............................   Passed   52.44 sec
      Start 21: MemoryAllocatorTest
21/34 Test #21: MemoryAllocatorTest .................   Passed    0.03 sec
      Start 22: OCLTest
22/34 Test #22: OCLTest .............................   Passed    0.96 sec
      Start 23: OnnxImporterTest
23/34 Test #23: OnnxImporterTest ....................   Passed    0.89 sec
      Start 24: OperatorGradTest
24/34 Test #24: OperatorGradTest ....................   Passed    0.76 sec
      Start 25: OperatorTest
25/34 Test #25: OperatorTest ........................   Passed   33.00 sec
      Start 26: PartitionerTest
26/34 Test #26: PartitionerTest .....................   Passed    0.79 sec
      Start 28: ProvisionerTest
27/34 Test #28: ProvisionerTest .....................   Passed    3.00 sec
      Start 29: QuantizationTest
28/34 Test #29: QuantizationTest ....................   Passed   19.64 sec
      Start 30: TensorsTest
29/34 Test #30: TensorsTest .........................   Passed    0.09 sec
      Start 31: TensorPoolTest
30/34 Test #31: TensorPoolTest ......................   Passed    0.04 sec
      Start 32: ThreadPoolTest
31/34 Test #32: ThreadPoolTest ......................   Passed    0.04 sec
      Start 33: TraceEventsTest
32/34 Test #33: TraceEventsTest .....................   Passed   13.18 sec
      Start 34: TypeAToTypeBFunctionConverterTest
33/34 Test #34: TypeAToTypeBFunctionConverterTest ...   Passed    0.87 sec
      Start 35: UtilsTest
34/34 Test #35: UtilsTest ...........................   Passed    0.04 sec

100% tests passed, 0 tests failed out of 34

Total Test time (real) = 160.15 sec
```
3) `ninja check` with OpenCL enabled, ASAN+UBSAN mode
```
      Start  1: BackendCorrectnessTest
 1/34 Test  #1: BackendCorrectnessTest ..............   Passed   65.05 sec
      Start  2: BackendTest
 2/34 Test  #2: BackendTest .........................   Passed    5.42 sec
      Start  3: BasicIRTest
 3/34 Test  #3: BasicIRTest .........................   Passed    0.09 sec
      Start  4: Caffe2ImporterTest
 4/34 Test  #4: Caffe2ImporterTest ..................   Passed   11.51 sec
      Start  5: DeviceManagerTest
 5/34 Test  #5: DeviceManagerTest ...................   Passed    1.93 sec
      Start  6: ThreadPoolExecutorTest
 6/34 Test  #6: ThreadPoolExecutorTest ..............   Passed    5.08 sec
      Start  7: Float16Test
 7/34 Test  #7: Float16Test .........................   Passed    0.03 sec
      Start  8: GemmTest
 8/34 Test  #8: GemmTest ............................   Passed    0.22 sec
      Start  9: GlowOnnxifiManagerTest
 9/34 Test  #9: GlowOnnxifiManagerTest ..............   Passed    0.18 sec
      Start 10: GradCheckTest
10/34 Test #10: GradCheckTest .......................   Passed   15.40 sec
      Start 11: GraphGradTest
11/34 Test #11: GraphGradTest .......................   Passed    0.22 sec
      Start 12: GraphOptzTest
12/34 Test #12: GraphOptzTest .......................   Passed    0.12 sec
      Start 13: GraphSchedulerTest
13/34 Test #13: GraphSchedulerTest ..................   Passed    0.03 sec
      Start 14: GraphTest
14/34 Test #14: GraphTest ...........................   Passed    3.00 sec
      Start 15: HostManagerTest
15/34 Test #15: HostManagerTest .....................   Passed   13.79 sec
      Start 16: HyphenTest
16/34 Test #16: HyphenTest ..........................   Passed    3.47 sec
      Start 17: IROptTest
17/34 Test #17: IROptTest ...........................   Passed    0.05 sec
      Start 18: ImageTest
18/34 Test #18: ImageTest ...........................   Passed    1.08 sec
      Start 19: LLVMIRGenTest
19/34 Test #19: LLVMIRGenTest .......................   Passed    0.05 sec
      Start 20: MLTest
20/34 Test #20: MLTest ..............................   Passed  141.01 sec
      Start 21: MemoryAllocatorTest
21/34 Test #21: MemoryAllocatorTest .................   Passed    0.08 sec
      Start 22: OCLTest
22/34 Test #22: OCLTest .............................   Passed    0.64 sec
      Start 23: OnnxImporterTest
23/34 Test #23: OnnxImporterTest ....................   Passed    0.51 sec
      Start 24: OperatorGradTest
24/34 Test #24: OperatorGradTest ....................   Passed    0.14 sec
      Start 25: OperatorTest
25/34 Test #25: OperatorTest ........................   Passed   35.78 sec
      Start 26: PartitionerTest
26/34 Test #26: PartitionerTest .....................   Passed    0.20 sec
      Start 28: ProvisionerTest
27/34 Test #28: ProvisionerTest .....................   Passed    2.25 sec
      Start 29: QuantizationTest
28/34 Test #29: QuantizationTest ....................   Passed   17.17 sec
      Start 30: TensorsTest
29/34 Test #30: TensorsTest .........................   Passed    1.28 sec
      Start 31: TensorPoolTest
30/34 Test #31: TensorPoolTest ......................   Passed    0.03 sec
      Start 32: ThreadPoolTest
31/34 Test #32: ThreadPoolTest ......................   Passed    0.05 sec
      Start 33: TraceEventsTest
32/34 Test #33: TraceEventsTest .....................   Passed   32.11 sec
      Start 34: TypeAToTypeBFunctionConverterTest
33/34 Test #34: TypeAToTypeBFunctionConverterTest ...   Passed    0.15 sec
      Start 35: UtilsTest
34/34 Test #35: UtilsTest ...........................   Passed    0.07 sec

100% tests passed, 0 tests failed out of 34

Total Test time (real) = 358.24 sec
```
Pull Request resolved: #3118

Differential Revision: D15836207

Pulled By: SplitInfinity

fbshipit-source-id: 7bfa3c6ed5583d6a8f42b1f712f359e8e1d10b47
@SplitInfinity SplitInfinity deleted the any-dim-ocl-batched-reduce-add branch July 11, 2019 18:13