[OpenCL] Implement BatchedReduceAdd for arbitrary axes #2958
@SplitInfinity has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed in 15 days if no further activity occurs. Thank you for your contributions.
326dd93
to
b4c8146
Compare
How confident are you the backedReduceAddWithAxis_Float test is enough coverage?
lib/Backends/OpenCL/OpenCL.cpp
Outdated
std::vector<size_t> batchSliceSizes(
    numBatchDims > 1 ? numBatchDims - 1 : 1, 1);
size_t currentSliceSize = 1, axisSliceSize = 1;
for (size_t i = 1, j = batchSliceSizes.size() - 1, e = batchDims.size();
Variable names are a bit terse here, would prefer `end` and `index`.
Description: This commit extends the implementation of BatchedReduceAdd in the OpenCL backend so that it can handle any reduction axis, not just axis 0. This can be useful when the first dimension is the batch dimension and a reduction needs to be performed within each example. The existing implementation for axis = 0 computes each slice element in parallel and linearizes the slice for simplicity (i.e. creates a 1D global workspace as large as the number of elements in the output). This implementation generalizes this concept by creating a global workspace with rank equal to the number of dimensions of the output and computing each output element in parallel. The slice sizes of the input and output shapes are precomputed on the host and passed in as kernel arguments so that the kernel can compute the correct offsets into the input and output buffers by multiplying its set of global IDs with those slice sizes. Testing: This commit enables the existing non-zero axis BatchedReduceAdd unit test for OpenCL and augments it to test axis=2. All unit tests pass.
b4c8146
to
5ff5f50
Compare
@SplitInfinity has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Nice one
@SplitInfinity merged this pull request in 086fd84.
Description: This commit fixes two bugs in the OpenCL implementation of BatchedReduceAddInst and adds a few comments for clarity. The first is a segmentation fault caused by incorporating feedback on pytorch#2958. A suggestion was made to make the loop variable i in the loop that computes batchSliceSizes count down instead of count up, but this was done without changing the type (which was size_t, an unsigned type), so the loop never terminates and eventually leads to a segmentation fault. The second bug is an incorrect computation of destSliceSizes. Instead of multiplying the slice size at a dimension with the number of elements in that same dimension, the code was multiplying the former with the number of elements in the *adjacent* dimension. This was surfaced by the unit test added in pytorch#2958 for axis = 2. Testing: 1) ninja check with OpenCL enabled, DEBUG mode 2) ninja check with OpenCL enabled, RELEASE mode 3) ninja check with OpenCL enabled, ASAN+UBSAN mode
Summary: **Description** This commit fixes two bugs in the OpenCL implementation of `BatchedReduceAddInst` and adds a few comments for clarity. The first is a segmentation fault caused by incorporating feedback on #2958. A suggestion was made to make the loop variable `i` in the loop that computes `batchSliceSizes` count down instead of count up, but this suggestion was taken without changing the type (which was `size_t`, an unsigned type), so the loop never terminates and eventually leads to a segmentation fault. The second bug is an incorrect computation of `destSliceSizes`. Instead of multiplying the slice size at a dimension with the number of elements in that same dimension, the code was multiplying the former with the number of elements in the *adjacent* dimension. This was surfaced by the unit test added in #2958 for `axis = 2`. **Test Plan** 1) `ninja check` with OpenCL enabled, DEBUG mode ``` Start 1: BackendCorrectnessTest 1/34 Test #1: BackendCorrectnessTest .............. Passed 21.28 sec Start 2: BackendTest 2/34 Test #2: BackendTest ......................... Passed 1.97 sec Start 3: BasicIRTest 3/34 Test #3: BasicIRTest ......................... Passed 0.05 sec Start 4: Caffe2ImporterTest 4/34 Test #4: Caffe2ImporterTest .................. Passed 3.00 sec Start 5: DeviceManagerTest 5/34 Test #5: DeviceManagerTest ................... Passed 0.76 sec Start 6: ThreadPoolExecutorTest 6/34 Test #6: ThreadPoolExecutorTest .............. Passed 1.48 sec Start 7: Float16Test 7/34 Test #7: Float16Test ......................... Passed 0.01 sec Start 8: GemmTest 8/34 Test #8: GemmTest ............................ Passed 0.05 sec Start 9: GlowOnnxifiManagerTest 9/34 Test #9: GlowOnnxifiManagerTest .............. Passed 0.06 sec Start 10: GradCheckTest 10/34 Test #10: GradCheckTest ....................... Passed 4.72 sec Start 11: GraphGradTest 11/34 Test #11: GraphGradTest ....................... 
Passed 0.06 sec Start 12: GraphOptzTest 12/34 Test #12: GraphOptzTest ....................... Passed 0.03 sec Start 13: GraphSchedulerTest 13/34 Test #13: GraphSchedulerTest .................. Passed 0.01 sec Start 14: GraphTest 14/34 Test #14: GraphTest ........................... Passed 1.03 sec Start 15: HostManagerTest 15/34 Test #15: HostManagerTest ..................... Passed 7.49 sec Start 16: HyphenTest 16/34 Test #16: HyphenTest .......................... Passed 1.17 sec Start 17: IROptTest 17/34 Test #17: IROptTest ........................... Passed 0.01 sec Start 18: ImageTest 18/34 Test #18: ImageTest ........................... Passed 0.31 sec Start 19: LLVMIRGenTest 19/34 Test #19: LLVMIRGenTest ....................... Passed 0.01 sec Start 20: MLTest 20/34 Test #20: MLTest .............................. Passed 46.30 sec Start 21: MemoryAllocatorTest 21/34 Test #21: MemoryAllocatorTest ................. Passed 0.03 sec Start 22: OCLTest 22/34 Test #22: OCLTest ............................. Passed 0.24 sec Start 23: OnnxImporterTest 23/34 Test #23: OnnxImporterTest .................... Passed 0.12 sec Start 24: OperatorGradTest 24/34 Test #24: OperatorGradTest .................... Passed 0.05 sec Start 25: OperatorTest 25/34 Test #25: OperatorTest ........................ Passed 14.47 sec Start 26: PartitionerTest 26/34 Test #26: PartitionerTest ..................... Passed 0.05 sec Start 28: ProvisionerTest 27/34 Test #28: ProvisionerTest ..................... Passed 1.00 sec Start 29: QuantizationTest 28/34 Test #29: QuantizationTest .................... Passed 7.46 sec Start 30: TensorsTest 29/34 Test #30: TensorsTest ......................... Passed 0.36 sec Start 31: TensorPoolTest 30/34 Test #31: TensorPoolTest ...................... Passed 0.01 sec Start 32: ThreadPoolTest 31/34 Test #32: ThreadPoolTest ...................... Passed 0.01 sec Start 33: TraceEventsTest 32/34 Test #33: TraceEventsTest ..................... 
Passed 10.62 sec Start 34: TypeAToTypeBFunctionConverterTest 33/34 Test #34: TypeAToTypeBFunctionConverterTest ... Passed 0.06 sec Start 35: UtilsTest 34/34 Test #35: UtilsTest ........................... Passed 0.02 sec 100% tests passed, 0 tests failed out of 34 Total Test time (real) = 124.33 sec ``` 2) `ninja check` with OpenCL enabled, RELEASE mode ``` Start 1: BackendCorrectnessTest 1/34 Test #1: BackendCorrectnessTest .............. Passed 11.51 sec Start 2: BackendTest 2/34 Test #2: BackendTest ......................... Passed 1.53 sec Start 3: BasicIRTest 3/34 Test #3: BasicIRTest ......................... Passed 0.02 sec Start 4: Caffe2ImporterTest 4/34 Test #4: Caffe2ImporterTest .................. Passed 0.62 sec Start 5: DeviceManagerTest 5/34 Test #5: DeviceManagerTest ................... Passed 0.83 sec Start 6: ThreadPoolExecutorTest 6/34 Test #6: ThreadPoolExecutorTest .............. Passed 0.71 sec Start 7: Float16Test 7/34 Test #7: Float16Test ......................... Passed 0.01 sec Start 8: GemmTest 8/34 Test #8: GemmTest ............................ Passed 0.31 sec Start 9: GlowOnnxifiManagerTest 9/34 Test #9: GlowOnnxifiManagerTest .............. Passed 0.33 sec Start 10: GradCheckTest 10/34 Test #10: GradCheckTest ....................... Passed 1.90 sec Start 11: GraphGradTest 11/34 Test #11: GraphGradTest ....................... Passed 0.32 sec Start 12: GraphOptzTest 12/34 Test #12: GraphOptzTest ....................... Passed 0.03 sec Start 13: GraphSchedulerTest 13/34 Test #13: GraphSchedulerTest .................. Passed 0.02 sec Start 14: GraphTest 14/34 Test #14: GraphTest ........................... Passed 0.59 sec Start 15: HostManagerTest 15/34 Test #15: HostManagerTest ..................... Passed 10.61 sec Start 16: HyphenTest 16/34 Test #16: HyphenTest .......................... Passed 4.18 sec Start 17: IROptTest 17/34 Test #17: IROptTest ........................... 
Passed 0.04 sec Start 18: ImageTest 18/34 Test #18: ImageTest ........................... Passed 0.10 sec Start 19: LLVMIRGenTest 19/34 Test #19: LLVMIRGenTest ....................... Passed 0.71 sec Start 20: MLTest 20/34 Test #20: MLTest .............................. Passed 52.44 sec Start 21: MemoryAllocatorTest 21/34 Test #21: MemoryAllocatorTest ................. Passed 0.03 sec Start 22: OCLTest 22/34 Test #22: OCLTest ............................. Passed 0.96 sec Start 23: OnnxImporterTest 23/34 Test #23: OnnxImporterTest .................... Passed 0.89 sec Start 24: OperatorGradTest 24/34 Test #24: OperatorGradTest .................... Passed 0.76 sec Start 25: OperatorTest 25/34 Test #25: OperatorTest ........................ Passed 33.00 sec Start 26: PartitionerTest 26/34 Test #26: PartitionerTest ..................... Passed 0.79 sec Start 28: ProvisionerTest 27/34 Test #28: ProvisionerTest ..................... Passed 3.00 sec Start 29: QuantizationTest 28/34 Test #29: QuantizationTest .................... Passed 19.64 sec Start 30: TensorsTest 29/34 Test #30: TensorsTest ......................... Passed 0.09 sec Start 31: TensorPoolTest 30/34 Test #31: TensorPoolTest ...................... Passed 0.04 sec Start 32: ThreadPoolTest 31/34 Test #32: ThreadPoolTest ...................... Passed 0.04 sec Start 33: TraceEventsTest 32/34 Test #33: TraceEventsTest ..................... Passed 13.18 sec Start 34: TypeAToTypeBFunctionConverterTest 33/34 Test #34: TypeAToTypeBFunctionConverterTest ... Passed 0.87 sec Start 35: UtilsTest 34/34 Test #35: UtilsTest ........................... Passed 0.04 sec 100% tests passed, 0 tests failed out of 34 Total Test time (real) = 160.15 sec ``` 3) `ninja check` with OpenCL enabled, ASAN+UBSAN mode ``` Start 1: BackendCorrectnessTest 1/34 Test #1: BackendCorrectnessTest .............. Passed 65.05 sec Start 2: BackendTest 2/34 Test #2: BackendTest ......................... 
Passed 5.42 sec Start 3: BasicIRTest 3/34 Test #3: BasicIRTest ......................... Passed 0.09 sec Start 4: Caffe2ImporterTest 4/34 Test #4: Caffe2ImporterTest .................. Passed 11.51 sec Start 5: DeviceManagerTest 5/34 Test #5: DeviceManagerTest ................... Passed 1.93 sec Start 6: ThreadPoolExecutorTest 6/34 Test #6: ThreadPoolExecutorTest .............. Passed 5.08 sec Start 7: Float16Test 7/34 Test #7: Float16Test ......................... Passed 0.03 sec Start 8: GemmTest 8/34 Test #8: GemmTest ............................ Passed 0.22 sec Start 9: GlowOnnxifiManagerTest 9/34 Test #9: GlowOnnxifiManagerTest .............. Passed 0.18 sec Start 10: GradCheckTest 10/34 Test #10: GradCheckTest ....................... Passed 15.40 sec Start 11: GraphGradTest 11/34 Test #11: GraphGradTest ....................... Passed 0.22 sec Start 12: GraphOptzTest 12/34 Test #12: GraphOptzTest ....................... Passed 0.12 sec Start 13: GraphSchedulerTest 13/34 Test #13: GraphSchedulerTest .................. Passed 0.03 sec Start 14: GraphTest 14/34 Test #14: GraphTest ........................... Passed 3.00 sec Start 15: HostManagerTest 15/34 Test #15: HostManagerTest ..................... Passed 13.79 sec Start 16: HyphenTest 16/34 Test #16: HyphenTest .......................... Passed 3.47 sec Start 17: IROptTest 17/34 Test #17: IROptTest ........................... Passed 0.05 sec Start 18: ImageTest 18/34 Test #18: ImageTest ........................... Passed 1.08 sec Start 19: LLVMIRGenTest 19/34 Test #19: LLVMIRGenTest ....................... Passed 0.05 sec Start 20: MLTest 20/34 Test #20: MLTest .............................. Passed 141.01 sec Start 21: MemoryAllocatorTest 21/34 Test #21: MemoryAllocatorTest ................. Passed 0.08 sec Start 22: OCLTest 22/34 Test #22: OCLTest ............................. Passed 0.64 sec Start 23: OnnxImporterTest 23/34 Test #23: OnnxImporterTest .................... 
Passed 0.51 sec Start 24: OperatorGradTest 24/34 Test #24: OperatorGradTest .................... Passed 0.14 sec Start 25: OperatorTest 25/34 Test #25: OperatorTest ........................ Passed 35.78 sec Start 26: PartitionerTest 26/34 Test #26: PartitionerTest ..................... Passed 0.20 sec Start 28: ProvisionerTest 27/34 Test #28: ProvisionerTest ..................... Passed 2.25 sec Start 29: QuantizationTest 28/34 Test #29: QuantizationTest .................... Passed 17.17 sec Start 30: TensorsTest 29/34 Test #30: TensorsTest ......................... Passed 1.28 sec Start 31: TensorPoolTest 30/34 Test #31: TensorPoolTest ...................... Passed 0.03 sec Start 32: ThreadPoolTest 31/34 Test #32: ThreadPoolTest ...................... Passed 0.05 sec Start 33: TraceEventsTest 32/34 Test #33: TraceEventsTest ..................... Passed 32.11 sec Start 34: TypeAToTypeBFunctionConverterTest 33/34 Test #34: TypeAToTypeBFunctionConverterTest ... Passed 0.15 sec Start 35: UtilsTest 34/34 Test #35: UtilsTest ........................... Passed 0.07 sec 100% tests passed, 0 tests failed out of 34 Total Test time (real) = 358.24 sec ``` Pull Request resolved: #3118 Differential Revision: D15836207 Pulled By: SplitInfinity fbshipit-source-id: 7bfa3c6ed5583d6a8f42b1f712f359e8e1d10b47
Description
This commit extends the implementation of `BatchedReduceAdd` in the OpenCL backend so that it can handle any reduction axis, not just axis 0. This can be useful when the first dimension is the batch dimension and a reduction needs to be performed within each example.
The existing implementation for axis = 0 computes each slice element in parallel and linearizes the slice for simplicity (i.e. creates a 1D global workspace as large as the number of elements in the output). This implementation generalizes this concept by creating a global workspace with rank equal to the number of dimensions of the output and computing each output element in parallel. The slice sizes of the input and output shapes are precomputed on the host and passed in as kernel arguments so that the kernel can compute the correct offsets into the input and output buffers by multiplying its set of global IDs with those slice sizes.
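The offset arithmetic described above can be modeled on the host. The following is only a sketch under stated assumptions, not the OpenCL kernel or the host code from this PR: it assumes a row-major layout, the function names `computeSliceSizes` and `reduceAddAlongAxis` are hypothetical, and a plain loop over output elements stands in for the parallel work items.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major slice sizes: sliceSizes[i] is the number of elements covered by
// one step along dimension i, i.e. the product of all dimensions after i.
std::vector<std::size_t> computeSliceSizes(const std::vector<std::size_t> &dims) {
  std::vector<std::size_t> sliceSizes(dims.size(), 1);
  for (std::size_t i = dims.size(); i-- > 1;) {
    // Multiply the slice size at dimension i by the extent of that same
    // dimension to get the slice size of the dimension before it.
    sliceSizes[i - 1] = sliceSizes[i] * dims[i];
  }
  return sliceSizes;
}

// Sums `in` (shape `dims`) along `axis`; the output shape is `dims` with
// that axis removed. Each iteration of the outer loop plays the role of
// one work item in the global workspace.
std::vector<float> reduceAddAlongAxis(const std::vector<float> &in,
                                      const std::vector<std::size_t> &dims,
                                      std::size_t axis) {
  assert(axis < dims.size());
  std::vector<std::size_t> outDims;
  for (std::size_t d = 0; d < dims.size(); ++d) {
    if (d != axis) outDims.push_back(dims[d]);
  }
  const std::vector<std::size_t> inSlices = computeSliceSizes(dims);
  const std::vector<std::size_t> outSlices = computeSliceSizes(outDims);

  std::size_t outCount = 1;
  for (std::size_t n : outDims) outCount *= n;
  std::vector<float> out(outCount, 0.0f);

  for (std::size_t o = 0; o < outCount; ++o) {
    // Recover the per-dimension IDs of this output element and map them to
    // an input offset, skipping over the reduced axis.
    std::size_t inBase = 0;
    for (std::size_t d = 0; d < outDims.size(); ++d) {
      const std::size_t id = (o / outSlices[d]) % outDims[d];
      const std::size_t inDim = d < axis ? d : d + 1;
      inBase += id * inSlices[inDim];
    }
    float sum = 0.0f;
    for (std::size_t k = 0; k < dims[axis]; ++k) {
      sum += in[inBase + k * inSlices[axis]];
    }
    out[o] = sum;
  }
  return out;
}
```

For an input of shape {2, 3, 4}, `computeSliceSizes` yields {12, 4, 1}, and reducing along axis 2 produces six outputs, each the sum of four consecutive input elements.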
Test Plan
This commit enables the existing non-zero axis `BatchedReduceAdd` unit test for OpenCL and modifies it to test `axis=2`. All unit tests pass.