migrate PReLU to ATen #11758

Closed
wants to merge 3 commits

Conversation

weiyangfb
Contributor

@weiyangfb weiyangfb commented Sep 17, 2018

CPU with weight.numel() = 1

# === previous ===
>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 1.5 ms per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
10 loops, best of 100: 14.9 ms per loop

# === current ====
>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 179 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 378 µs per loop

CPU with weight.numel() = channels

# === previous ===
>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 351 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
10 loops, best of 100: 7.14 ms per loop

# === current ====
>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 260 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 854 µs per loop

CUDA with weight.numel() = 1

# === previous ===
>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 97.2 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
1000 loops, best of 100: 1.3 ms per loop

# === current ====
>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 78.9 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
1000 loops, best of 100: 1.27 ms per loop

CUDA with weight.numel() = channels

# === previous ===
>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 101 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
1000 loops, best of 100: 1.67 ms per loop

# === current ====
>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 76.9 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
1000 loops, best of 100: 1.12 ms per loop

Updated runtimes after removing the DEBUG flag when building PyTorch.

@ezyang @ssnl @zou3519 @soumith

@weiyangfb weiyangfb force-pushed the prelu_segfault branch 6 times, most recently from d5c5bd3 to 9496032 on September 18, 2018 at 03:21
@weiyangfb weiyangfb changed the title from "[wip] migrate PReLU to ATen" to "migrate PReLU to ATen" on Sep 18, 2018
@weiyangfb weiyangfb force-pushed the prelu_segfault branch 5 times, most recently from 456e60f to ede17ee on September 18, 2018 at 04:18
@ezyang
Contributor

ezyang commented Sep 18, 2018

Not sure about the perf, but shouldn't you delete the old THNN implementation now?

@ezyang
Contributor

ezyang commented Sep 18, 2018

Can we get a summary of the changes, i.e., what you changed when you did the port?

@weiyangfb weiyangfb force-pushed the prelu_segfault branch 2 times, most recently from a85d2f9 to b9909af on September 18, 2018 at 05:53
@fmassa
Member

fmassa commented Sep 18, 2018

What's the plan with TensorIterator? Is it going to replace CPU_tensor_apply?

@soumith
Member

soumith commented Sep 18, 2018

@fmassa it will replace both CPU and GPU tensor apply

@fmassa
Member

fmassa commented Sep 18, 2018

Should we try using TensorIterator here instead of CPU_tensor_apply, or is it not ready yet?

@weiyangfb
Contributor Author

@ezyang Sure, I took a closer look at the CPU forward implementation when weight.numel() = 1, and it looks the same as before to me. Here's the previous implementation:

void THNN_(PReLU_updateOutput)(
          THNNState *state,
          THTensor *input,
          THTensor *output,
          THTensor *weight)
{
  THTensor_(resizeAs)(output, input);
  int64_t nOutputPlane = THTensor_(numel)(weight);
  if (nOutputPlane == 1)
  {
    // handle shared parameter case
    scalar_t w = *weight->data<scalar_t>();
    TH_TENSOR_APPLY2(scalar_t, output, scalar_t, input,
      const scalar_t r = (*input_data > 0) ? 1 : w;
      *output_data = *input_data * r;
    );
    return;
  }

auto strides = input.strides();

// case1: shared weight for all channels
if (weight_num == 1) {

Tensor weight_grad_collector = at::empty_like(input);

// case1: shared parameter for all channels
if (weight_num == 1) {

@weiyangfb
Contributor Author

@fmassa I wish I could use TensorIterator here, but I haven't found a direct replacement for CPU_tensor_apply*; for now I just parallelized the CPU kernels to mitigate the performance issues.

for (j = 0; j < channel_size; j++) {
  for (k = 0; k < input_stride1; k++) {
    int64_t pos = i * input_stride0 + j * input_stride1 + k;
    result_data[pos] = (input_data[pos] > 0) ? input_data[pos] : weight_data[j] * input_data[pos];

@weiyangfb
Contributor Author

@ezyang I will keep the THNN & THCUNN code for now since @cpuhrsch is planning to nuke /legacy/nn

input_stride0,
input_stride1);
});
// update weight_grad

input_stride1,
input_numel);
});
// update weight_grad

@cpuhrsch
Contributor

By "CPU_tensor_apply" you all are referring to TH? We have an equivalent within ATen as well. Presumably that will be replaced by TensorIterator as well? @colesbury

@cpuhrsch
Contributor

In addition to this, we also have Parallel.h, which allows us to abstract the OMP calls.

#pragma omp parallel for private(i) if (input_numel > 1000)
for (i = 0; i < input_numel; i++) {
  scalar_t input_data_val = input_data[i];
  result_data[i] = (input_data_val > 0) ? input_data_val : weight_val * input_data_val;
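For reference, here is a minimal sketch (not part of this PR) of how the shared-weight loop quoted above could be written against Parallel.h's at::parallel_for instead of a raw OpenMP pragma. The helper name is made up, and the grain size of 1000 simply mirrors the `if (input_numel > 1000)` guard:

```
// Hypothetical sketch only: shared-weight PReLU forward via at::parallel_for.
#include <ATen/Parallel.h>
#include <cstdint>

template <typename scalar_t>
void prelu_shared_forward_sketch(
    scalar_t* result_data,
    const scalar_t* input_data,
    scalar_t weight_val,
    int64_t input_numel) {
  at::parallel_for(0, input_numel, /*grain_size=*/1000,
      [&](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; i++) {
          const scalar_t v = input_data[i];
          result_data[i] = (v > 0) ? v : weight_val * v;
        }
      });
}
```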

// multiply values at each channel with weight[channel]
#pragma omp parallel for private(i) if (input_numel > 1000)
for (i = 0; i < input_numel; i++) {
  int64_t channel = (i % input_stride0) / input_stride1;

for (i = 0; i < input_numel; i++) {
  scalar_t input_data_val = input_data[i];
  scalar_t grad_out_data_val = grad_out_data[i];
  input_grad_data[i] = (input_data_val > 0) ? grad_out_data_val : weight_val * grad_out_data_val;

@cpuhrsch
Contributor

Another thing to (potentially) consider here is writing this in terms of vec256 within native/cpu for the CPU kernel - see #11565 as an example. This is much more involved, however, and mostly benefits CPU-bound kernels.
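To make that suggestion concrete, here is a rough sketch (not from this PR) of the shared-weight forward written with Vec256. The header path, namespace, and the loadu/store/maximum/minimum helpers follow ATen's vec256 code, but exact names may differ between versions, so treat them as assumptions:

```
// Hypothetical sketch only: shared-weight PReLU forward with Vec256<float>.
#include <ATen/cpu/vec256/vec256.h>
#include <cstdint>

void prelu_shared_forward_vec_sketch(
    float* result, const float* input, float w, int64_t n) {
  using Vec = at::vec256::Vec256<float>;
  constexpr int64_t kVecSize = 8;  // Vec256<float> holds 8 floats (256 bits)
  const Vec zero(0.0f), weight(w);
  int64_t i = 0;
  for (; i + kVecSize <= n; i += kVecSize) {
    Vec x = Vec::loadu(input + i);
    // PReLU: y = max(x, 0) + w * min(x, 0)
    Vec y = at::vec256::maximum(x, zero) + weight * at::vec256::minimum(x, zero);
    y.store(result + i);
  }
  for (; i < n; i++) {  // scalar tail for the remaining elements
    result[i] = input[i] > 0 ? input[i] : w * input[i];
  }
}
```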

@resistor
Contributor

@cpuhrsch I extracted the kernel and confirmed that I was able to get the compiler to generate fully vectorized code up to AVX2 using the suggestions I posted.

@cpuhrsch
Contributor

@resistor - the only issue is that this won't be compiled with -mavx/-mavx2 etc. if it doesn't live within native/cpu for the OSS version, where we rely on a dispatch so that we don't assume avx/avx2 capabilities for any CPU. But we can easily move it in there and use the autovectorization. Having said that, just because the compiler vectorizes it doesn't mean it's faster, because it might try to maintain invariants we don't care about. For example, we might want to trade off adding one number at a time against adding 8 floats to 8 accumulators at a time. This won't give the same results, but it's much faster and will still be reproducible.
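As a small, hypothetical illustration of that trade-off: summing with eight independent accumulators reorders the additions, so the last bits can differ from a single-accumulator loop, but the result stays deterministic and vectorizes well:

```
// Hypothetical example: eight partial sums instead of one accumulator.
#include <cstdint>

float sum_with_8_accumulators(const float* data, int64_t n) {
  float acc[8] = {0};
  int64_t i = 0;
  for (; i + 8 <= n; i += 8) {
    for (int j = 0; j < 8; j++) {
      acc[j] += data[i + j];  // independent chains the compiler can vectorize
    }
  }
  for (; i < n; i++) {
    acc[0] += data[i];  // scalar tail
  }
  return ((acc[0] + acc[1]) + (acc[2] + acc[3])) +
         ((acc[4] + acc[5]) + (acc[6] + acc[7]));
}
```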

@resistor
Contributor

@cpuhrsch - Makes sense re: dispatch.

re: autovectorization, for this particular kernel there's not a lot of room for variation, since it's a purely element-wise filter. Getting the compiler to do the right thing for an accumulator loop is a lot trickier and more fragile.

@cpuhrsch
Contributor

@resistor - on element-wise operations we also have vml.h that defines a set of functions that can act as an interface to Intel's VML, or a loop + SLEEF if that's not available. Intel VML is much faster and also parallelizes (but of course doesn't let you merge operations). It's also worth checking whether some parts here can be gutted and replaced by Ops within MKL/MKL-DNN/iDeep etc.

Having said this, in my mind none of this is required for this PR, and there's still value in moving and simplifying code from TH/THNN and putting it into native.

@weiyangfb
Contributor Author

weiyangfb commented Sep 19, 2018

After making some changes to the CPU kernel (removing mod & div, splitting lines for compiler optimization, etc.), the runtimes now look much nicer:

>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit y = m(x)
1000 loops, best of 3: 764 µs per loop

>>> y = m(x).sum()
>>> %timeit y.backward(retain_graph=True)
100 loops, best of 3: 3.06 ms per loop

>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit y = m(x)
1000 loops, best of 3: 877 µs per loop

>>> y = m(x).sum()
>>> %timeit y.backward(retain_graph=True)
100 loops, best of 3: 4.53 ms per loop

Sorry I was looking at the wrong results. Just updated...
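To illustrate the "removing mod & div" change, here is a hypothetical sketch of a channel-wise forward loop that carries the channel index explicitly instead of recomputing `(i % input_stride0) / input_stride1` per element. Names mirror the snippets quoted earlier in this thread; the actual kernel in the PR may differ:

```
// Hypothetical sketch: channel index carried explicitly (no per-element mod/div).
#include <cstdint>

template <typename scalar_t>
void prelu_channelwise_forward_sketch(
    scalar_t* result_data,
    const scalar_t* input_data,
    const scalar_t* weight_data,
    int64_t batch_size,       // input.size(0)
    int64_t channel_size,     // input.size(1), == weight.numel()
    int64_t input_stride1) {  // elements per (batch, channel) slice
  for (int64_t b = 0; b < batch_size; b++) {
    for (int64_t c = 0; c < channel_size; c++) {
      const scalar_t w = weight_data[c];
      const int64_t base = (b * channel_size + c) * input_stride1;
      for (int64_t k = 0; k < input_stride1; k++) {
        const scalar_t v = input_data[base + k];
        result_data[base + k] = (v > 0) ? v : w * v;
      }
    }
  }
}
```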

@weiyangfb
Contributor Author

@cpuhrsch yeah, parallel_for is very nice to use, but I also want to do a reduction inside the same loop, so I hacked it up this way. It would be very nice if I could learn and use vec256 too! I will look into the example use case.
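For illustration only, a sketch of the kind of fused reduction being referred to: in the shared-weight case the weight gradient is a sum of input * grad_out over the non-positive inputs, which an OpenMP reduction clause can accumulate in the same loop. This is an assumption about the shape of the kernel, not the PR's actual code:

```
// Hypothetical sketch: shared-weight gradient reduced inside the parallel loop.
#include <cstdint>

template <typename scalar_t>
scalar_t prelu_shared_weight_grad_sketch(
    const scalar_t* input_data,
    const scalar_t* grad_out_data,
    int64_t input_numel) {
  scalar_t weight_grad = 0;
  #pragma omp parallel for reduction(+:weight_grad) if (input_numel > 1000)
  for (int64_t i = 0; i < input_numel; i++) {
    // d(output)/d(weight) is input when input <= 0, and 0 otherwise
    weight_grad += (input_data[i] > 0) ? scalar_t(0)
                                       : input_data[i] * grad_out_data[i];
  }
  return weight_grad;
}
```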

@weiyangfb
Contributor Author

@cpuhrsch by CPU_tensor_apply I meant the one in ATen. At the beginning I had some performance issues with it; I guess it was mainly because all the optimization tricks weren't carried through. It would be nice to have TensorIterator here as well, but I haven't found an easy way to replace CPU_tensor_apply directly.

auto strides = input.strides();

// case1: shared weight for all channels
if (weight_num == 1) {

at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t>(
input,
result,
[=] __device__ (
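For readers unfamiliar with CUDA_tensor_apply2, here is a hypothetical completion of the truncated call above for the shared-weight case. The argument order (input, then result) follows the quoted call; the lambda body and the `weight_data` pointer are assumptions, not necessarily what this PR does:

```
// Hypothetical sketch only; `weight_data` is assumed to be a device pointer
// to the single shared weight, and scalar_t to come from an AT_DISPATCH block.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

template <typename scalar_t>
void prelu_cuda_shared_sketch(at::Tensor& input, at::Tensor& result,
                              const scalar_t* weight_data) {
  at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t>(
      input,
      result,
      [=] __device__ (scalar_t& input_val, scalar_t& result_val) {
        const scalar_t w = weight_data[0];
        result_val = (input_val > 0) ? input_val : w * input_val;
      });
}
```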

@weiyangfb
Contributor Author

Other than the CPU performance improvement with Vec256 (which I can probably do in a separate PR), do I need to make further changes to this PR? @ezyang

Contributor

@ezyang ezyang left a comment

Nope, thank you for debugging the perf issue, it is much appreciated.

Contributor

@facebook-github-bot facebook-github-bot left a comment

weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Sep 22, 2018
Summary:
- fixes pytorch/pytorch#10723
- migrate PReLU to ATen and deprecate legacy PReLU
- performance:

CPU with weight.numel() = 1
```
>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
100 loops, best of 100: 9.43 ms per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
10 loops, best of 100: 24.4 ms per loop

>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 695 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 2.47 ms per loop
```

CPU with weight.numel() = channels
```
>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 603 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 13.3 ms per loop

>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 655 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 2.45 ms per loop
```

CUDA with weight.numel() = 1
```
>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 187 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.01 ms per loop

>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
1000 loops, best of 100: 195 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.28 ms per loop
```

CUDA with weight.numel() = channel
```
>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
1000 loops, best of 100: 174 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.27 ms per loop

>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 181 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.26 ms per loop
```

The huge performance regression on CPU when weight.numel() = 1 is addressed by replacing at::CPU_tensor_apply* with parallelized kernels.

ezyang SsnL zou3519  soumith
Pull Request resolved: pytorch/pytorch#11758

Differential Revision: D9995799

Pulled By: weiyangfb

fbshipit-source-id: d289937c78075f46a54dafbde92fab0cc4b5b86e
@ezyang ezyang added the merged label Jun 26, 2019
Development

Successfully merging this pull request may close these issues.

Segfault during backward when using PReLU
8 participants