[MPS] Improve runtime complexity of roi_align
#9100
Conversation
Force-pushed from c4b01c0 to 34d749d
Can you add a small script that measures the time difference directly between the old roi_align and the new one? The one in the main thread is a bit confusing to me, since the first section has no roi_align and the second one does. (See the sketch below.)
Also, about:
"One concern I have with the approach I'm proposing here is numeric overflow of the index with large input sizes."
Have you tested it on larger input sizes and verified against the CPU that this implementation produces equivalent results?
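A minimal sketch of such a timing script (mine, not from the PR; the shapes, iteration count, and 7×7 output size are illustrative) might look like:

```python
# Sketch: time torchvision's roi_align on MPS in isolation.
import time

import torch
from torchvision.ops import roi_align

device = "mps"
x = torch.randn(2, 256, 128, 128, device=device)
# Boxes in (batch_index, x1, y1, x2, y2) format.
boxes = torch.tensor([[0.0, 10.0, 10.0, 100.0, 100.0]] * 64, device=device)

# Warm up so one-time kernel compilation is excluded from the timing.
for _ in range(3):
    roi_align(x, boxes, output_size=(7, 7))
torch.mps.synchronize()

iters = 20
start = time.perf_counter()
for _ in range(iters):
    roi_align(x, boxes, output_size=(7, 7))
torch.mps.synchronize()  # wait for the GPU before stopping the clock
print(f"{(time.perf_counter() - start) / iters * 1e3:.2f} ms per call")
```

Running the same script once against a build with the old kernel and once against this branch would give the direct before/after comparison asked for here.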
```diff
@@ -225,105 +225,96 @@ kernel void nms<DTYPE ## 4, DTYPE>( \
     uint2 tgid [[threadgroup_position_in_grid]], \
     uint2 tid2 [[thread_position_in_threadgroup]]);
-template<typename T, typename integer_t>
+template <typename T, typename integer_t>
```
Do we need the templating here for integer_t? From what I see it just registers two such ops:

```cpp
REGISTER_ROI_ALIGN_OP(float, int64_t);
REGISTER_ROI_ALIGN_OP(half, int64_t);
```

Both use int64_t, so maybe we can remove it? I know it wasn't added in this PR, but it would be a nice thing to add here.
SGTM. Fixed.
Thanks a lot for the review!
I agree that the perf outputs from the first comment are a bit confusing; the culprit looks like it's … I added a regression test.
I've tested it with values generating … Above 2^31 I get a crash on CPU with the error …
Indexing into the tensor I get valid output, e.g. for …, but I don't trust the results to be numerically correct, especially considering …
These errors can be triggered by setting …
Should we add a check on output_size against INT_MAX for MPS? We should probably add a check on CPU as well to prevent the crash, but I consider that out of scope for this PR. cc @Isalia20
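For the CPU-equivalence question above, a minimal cross-check might look like this sketch (the shapes, sampling ratio, and tolerances are illustrative, not the values discussed in this thread):

```python
# Sketch: verify the MPS roi_align output against the CPU reference.
import torch
from torchvision.ops import roi_align

x = torch.randn(2, 64, 128, 128)
boxes = torch.tensor([[0.0, 4.0, 4.0, 100.0, 100.0],
                      [1.0, 8.0, 8.0, 64.0, 64.0]])

cpu_out = roi_align(x, boxes, output_size=(7, 7), sampling_ratio=2)
mps_out = roi_align(x.to("mps"), boxes.to("mps"),
                    output_size=(7, 7), sampling_ratio=2)

# Compare on CPU; fp32 accumulation differences warrant loose tolerances.
torch.testing.assert_close(mps_out.cpu(), cpu_out, rtol=1e-4, atol=1e-4)
```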
`roi_align` on MPS has significantly inflated runtime complexity due to a bug in the looping behavior of the kernel. I've not found any other correctness issues with the current implementation, which closely follows the CUDA implementation. This PR fixes the runtime complexity; otherwise the kernel is semantically identical to before.
Note that this PR switches the dispatching to `dispatchThreads`, which has a tighter build target set than `dispatchThreadgroups`; see "Nonuniform threadgroup size" in the Metal feature set tables. Some other MPS kernels in vision are also likely affected.
Running the example code from pytorch/pytorch#124850 (comment) before: …
and after: …
One concern I have with the approach I'm proposing here is numeric overflow of the index with large input sizes.
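As a back-of-envelope sketch of that concern (the workload below is hypothetical): the kernel's flattened output index runs up to num_rois * channels * pooled_h * pooled_w, which exceeds INT_MAX for plausibly large inputs.

```python
# Sketch: when does the flattened roi_align output index overflow int32?
num_rois, channels, pooled_h, pooled_w = 70_000, 256, 14, 14
max_index = num_rois * channels * pooled_h * pooled_w
print(max_index, max_index > 2**31 - 1)  # 3512320000 True
```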
Fixes pytorch/pytorch#124850
cc @malfet @kulinseth @qqaatw