migrate PReLU to ATen #11758
Conversation
Force-pushed from d5c5bd3 to 9496032
Force-pushed from 456e60f to ede17ee
Not sure about the perf, but shouldn't you delete the old THNN implementation now?
Can we get a list of changes, i.e., what you changed when you did the port?
Force-pushed from a85d2f9 to b9909af
What's the plan with …?
@fmassa it will replace both CPU and GPU tensor apply.
Should we instead try using …?
@ezyang sure, I took a closer look at the forward implementation on CPU when weight.numel() = 1, and it looks the same to me. Here's the previous implementation: pytorch/aten/src/THNN/generic/PReLU.c, lines 5 to 23 at 3da8d71.
auto strides = input.strides();

// case1: shared weight for all channels
if (weight_num == 1) {
Tensor weight_grad_collector = at::empty_like(input);

// case1: shared parameter for all channels
if (weight_num == 1) {
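For context on the backward kernel this excerpt comes from (standard derivation, not text from the PR): PReLU is f(x) = x for x > 0 and a·x otherwise, so df/dx is 1 for x > 0 and a otherwise, while df/da is 0 for x > 0 and x otherwise. The input-sized weight_grad_collector presumably holds the per-element grad_out · df/da values before they are reduced over each channel to form the weight gradient.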
@fmassa I wish I could use TensorIterator instead, but I haven't found a direct replacement for …
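For reference, a rough sketch of how the shared-weight forward looks when expressed with at::CPU_tensor_apply2 from ATen/CPUApplyUtils.h; this is an assumption for illustration (including the hypothetical function name), not code from this PR, and the apply helper visits elements serially, which is where the perf concern discussed here comes from:

```cpp
#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>

// Sketch only: shared-weight (weight.numel() == 1) PReLU forward via the
// ATen CPU apply helper. CPU_tensor_apply2 walks (result, input) element
// pairs one by one, so large inputs get no parallelism here.
at::Tensor prelu_shared_weight_apply(const at::Tensor& input, const at::Tensor& weight) {
  auto result = at::empty_like(input);
  AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_cpu", [&] {
    const scalar_t w = *weight.data<scalar_t>();
    at::CPU_tensor_apply2<scalar_t, scalar_t>(
        result, input, [w](scalar_t& out, const scalar_t& in) {
          out = (in > 0) ? in : w * in;
        });
  });
  return result;
}
```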
aten/src/ATen/native/Activation.cpp (Outdated)
for (j = 0; j < channel_size; j++) {
  for (k = 0; k < input_stride1; k++) {
    int64_t pos = i * input_stride0 + j * input_stride1 + k;
    result_data[pos] = (input_data[pos] > 0) ? input_data[pos] : weight_data[j] * input_data[pos];
aten/src/ATen/native/Activation.cpp (Outdated)
  input_stride0,
  input_stride1);
});
// update weight_grad
  input_stride1,
  input_numel);
});
// update weight_grad
By "CPU_tensor_apply" you all are referring to TH? We have an equivalent within ATen as well. Presumably that will be replaced by TensorIterator as well? @colesbury |
In addition to this, we also have Parallel.h, which allows us to abstract the OMP calls.
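As an illustration of that point, a minimal sketch (an assumption, not code from this PR) of the shared-weight loop written with at::parallel_for from ATen/Parallel.h instead of a raw OMP pragma; the names input_data, result_data, weight_val, and input_numel are taken from the kernel excerpts above, and the grain-size argument plays the role of the `if (input_numel > 1000)` clause:

```cpp
#include <ATen/Parallel.h>

// Sketch only: [0, input_numel) is split into chunks that are distributed
// across threads; 1000 is the grain size below which the loop stays serial.
at::parallel_for(0, input_numel, 1000, [&](int64_t begin, int64_t end) {
  for (int64_t i = begin; i < end; i++) {
    const scalar_t x = input_data[i];
    result_data[i] = (x > 0) ? x : weight_val * x;
  }
});
```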
aten/src/ATen/native/Activation.cpp (Outdated)
#pragma omp parallel for private(i) if (input_numel > 1000)
for (i = 0; i < input_numel; i++) {
  scalar_t input_data_val = input_data[i];
  result_data[i] = (input_data_val > 0) ? input_data_val : weight_val * input_data_val;
aten/src/ATen/native/Activation.cpp (Outdated)
// multiply values at each channel with weight[channel]
#pragma omp parallel for private(i) if (input_numel > 1000)
for (i = 0; i < input_numel; i++) {
  int64_t channel = (i % input_stride0) / input_stride1;
aten/src/ATen/native/Activation.cpp (Outdated)
for (i = 0; i < input_numel; i++) {
  scalar_t input_data_val = input_data[i];
  scalar_t grad_out_data_val = grad_out_data[i];
  input_grad_data[i] = (input_data_val > 0) ? grad_out_data_val : weight_val * grad_out_data_val;
Another thing to (potentially) consider here is writing this in terms of vec256 within native/cpu for the CPU kernel; see #11565 as an example. This is much more involved, however, and mostly benefits CPU-bound kernels.
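For a sense of what that might look like, a hypothetical Vec256-style inner loop for the shared-weight case; the Vec256<float> member names (loadu/store, the size constant) and the maximum/minimum helpers are assumptions based on aten/src/ATen/cpu/vec256, and this is not code from this PR:

```cpp
#include <ATen/cpu/vec256/vec256.h>

using Vec = at::vec256::Vec256<float>;

// Sketch only: PReLU(x) = max(x, 0) + w * min(x, 0), computed one Vec256
// lane-group at a time, with a scalar tail loop for the remainder.
void prelu_shared_weight_vec(float* out, const float* in, float w, int64_t n) {
  const Vec w_vec(w);
  const Vec zero(0.f);
  int64_t i = 0;
  for (; i + Vec::size <= n; i += Vec::size) {
    Vec x = Vec::loadu(in + i);
    Vec y = at::vec256::maximum(x, zero) + w_vec * at::vec256::minimum(x, zero);
    y.store(out + i);
  }
  for (; i < n; i++) {
    out[i] = (in[i] > 0) ? in[i] : w * in[i];
  }
}
```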
@cpuhrsch I extracted the kernel and confirmed that I was able to get the compiler to generate fully vectorized code up to AVX2 using the suggestions I posted.
@resistor - the only issue is that this won't be compiled with -mavx/-mavx2 etc. if it doesn't live within native/cpu for the OSS version, where we rely on a dispatch so we don't assume AVX/AVX2 capabilities for any CPU. But we can easily move it into there and use the autovectorization. Having said that, just because the compiler vectorizes it doesn't mean it's faster, because it might try to maintain invariants we don't care about. For example, we might want to trade off adding one number at a time against adding 8 floats into 8 accumulators at a time. This won't give bit-identical results, but it's much faster and will still be reproducible.
@cpuhrsch - Makes sense re: dispatch. Re: autovectorization, for this particular kernel there's not a lot of room for variation, since it's a purely element-wise filter. Getting the compiler to do the right thing for an accumulator loop is a lot trickier and more fragile.
@resistor - on element-wise operations we also have vml.h, which defines a set of functions that can act as an interface to Intel's VML, or a loop + SLEEF if that's not available. Intel VML is much faster and also parallelizes (but of course doesn't let you merge operations). It's also worth checking whether some parts here can be gutted and replaced by ops within MKL/MKL-DNN/iDeep etc. Having said this, in my mind, none of this is required for this PR, and there's still value in removing and simplifying code from TH/THNN and putting it into native.
After making some changes to the CPU kernel (removing the mod & div, splitting lines for compiler optimization, etc.), the runtime now looks much better.
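Roughly, removing the mod & div amounts to computing the channel index with nested loops instead of per element. An illustrative sketch of that kind of restructuring (batch_size is an assumed name; this is not the exact diff):

```cpp
// Sketch only: the innermost loop is a straight-line expression with no
// "%" or "/", which the compiler can vectorize much more easily.
#pragma omp parallel for collapse(2) if (input_numel > 1000)
for (int64_t b = 0; b < batch_size; b++) {
  for (int64_t c = 0; c < channel_size; c++) {
    const scalar_t w = weight_data[c];
    const int64_t base = b * input_stride0 + c * input_stride1;
    for (int64_t k = 0; k < input_stride1; k++) {
      const scalar_t x = input_data[base + k];
      result_data[base + k] = (x > 0) ? x : w * x;
    }
  }
}
```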
Sorry, I was looking at the wrong results. Just updated…
@cpuhrsch yeah, …
@cpuhrsch by …
auto strides = input.strides();

// case1: shared weight for all channels
if (weight_num == 1) {
at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t>(
  input,
  result,
  [=] __device__ (
2. deprecate legacy PReLU and tests
Force-pushed from e836b58 to 9533134
…prove performance
Force-pushed from 313fc38 to b4b5ae0
Other than the CPU performance improvement with Vec256 (I can probably do that in a separate PR), do I need to make further changes to this PR? @ezyang
Nope, thank you for debugging the perf issue, it is much appreciated.
weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
- fixes pytorch/pytorch#10723
- migrate PReLU to ATen and deprecate legacy PReLU
- performance:

CPU with weight.numel() = 1
```
>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
100 loops, best of 100: 9.43 ms per loop
>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
10 loops, best of 100: 24.4 ms per loop

>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 695 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 2.47 ms per loop
```

CPU with weight.numel() = channels
```
>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 603 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 13.3 ms per loop

>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 655 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 2.45 ms per loop
```

CUDA with weight.numel() = 1
```
>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 187 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.01 ms per loop

>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
1000 loops, best of 100: 195 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.28 ms per loop
```

CUDA with weight.numel() = channel
```
>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
1000 loops, best of 100: 174 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.27 ms per loop

>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 181 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.26 ms per loop
```

The huge performance regression on CPU when weight.numel() = 1 is addressed by replacing at::CPU_tensor_apply* with parallelized kernels.

ezyang SsnL zou3519 soumith

Pull Request resolved: pytorch/pytorch#11758
Differential Revision: D9995799
Pulled By: weiyangfb
fbshipit-source-id: d289937c78075f46a54dafbde92fab0cc4b5b86e
Updated runtimes after removing the DEBUG flag when building PyTorch.
@ezyang @ssnl @zou3519 @soumith