
Commit 543faed

apaszke authored and Rob Kunkle committed
Move dropout and alpha dropout to ATen (pytorch#10384)
Summary: zdevito ezyang

Pull Request resolved: pytorch#10384
Reviewed By: ezyang
Differential Revision: D9272583
Pulled By: apaszke
fbshipit-source-id: ed5d37b28ce9ff25800bbaa0daf066cfbf1f9921
1 parent 9c691cc commit 543faed

File tree

10 files changed: +233 additions, -175 deletions

aten/src/ATen/native/Dropout.cpp

Lines changed: 118 additions & 0 deletions (new file)

@@ -0,0 +1,118 @@
#include "ATen/ATen.h"
#include "ATen/Dispatch.h"

namespace at { namespace native {

namespace {

Tensor make_feature_noise(const Tensor& input) {
  auto input_sizes = input.sizes();
  AT_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input");
  std::vector<int64_t> sizes;
  sizes.reserve(input.dim());
  sizes.push_back(input_sizes[0]);
  sizes.push_back(input_sizes[1]);
  for (int64_t i = 2; i < input.dim(); ++i)
    sizes.push_back(1);
  return at::empty(sizes, input.options());
}

bool is_fused_kernel_acceptable(const Tensor& input, double p) {
  return input.is_cuda() && p > 0 && p < 1;
}

// NB: sure, we could have used different overloads here, but I would feel insecure
// knowing that this dispatch depends only on the constness of the references
template<bool inplace>
Tensor& multiply(Tensor& input, const Tensor& noise) {
  static_assert(inplace, "Wrong multiply overload triggered in Dropout.cpp");
  return input.mul_(noise);
}

template<bool inplace>
Tensor multiply(const Tensor& input, const Tensor& noise) {
  static_assert(!inplace, "Wrong multiply overload triggered in Dropout.cpp");
  return input.mul(noise);
}

template<bool feature_dropout, bool alpha_dropout, bool inplace, typename T>
typename std::conditional<inplace, Tensor&, Tensor>::type
_dropout_impl(T& input, double p, bool train) {
  AT_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p);
  if (p == 0 || !train) {
    return input;
  }

  if (p == 1) {
    return multiply<inplace>(input, at::zeros({}, input.options()));
  }

  at::Tensor b; // used for alpha_dropout only
  auto noise = feature_dropout ? make_feature_noise(input) : at::empty_like(input);
  noise.bernoulli_(1 - p);
  if (alpha_dropout) {
    constexpr double alpha = 1.7580993408473766;
    double a = 1. / std::sqrt((alpha * alpha * p + 1) * (1 - p));
    b = noise.add(-1).mul_(alpha * a).add_(alpha * a * p);
    noise.mul_(a);
  } else {
    noise.div_(1 - p);
  }

  if (!alpha_dropout) {
    return multiply<inplace>(input, noise);
  } else {
    return multiply<inplace>(input, noise).add_(b);
  }
}

#define ALIAS_SPECIALIZATION(ALIAS_NAME, IS_FEATURE, IS_ALPHA)                      \
template <bool inplace, typename... Args>                                           \
typename std::conditional<inplace, Tensor&, Tensor>::type                           \
ALIAS_NAME(Args&&... args) {                                                        \
  return _dropout_impl<IS_FEATURE, IS_ALPHA, inplace>(std::forward<Args>(args)...); \
}

ALIAS_SPECIALIZATION(_dropout,               false, false)
ALIAS_SPECIALIZATION(_feature_dropout,       true,  false)
ALIAS_SPECIALIZATION(_alpha_dropout,         false, true )
ALIAS_SPECIALIZATION(_feature_alpha_dropout, true,  true )

} // anonymous namespace

Tensor dropout(const Tensor& input, double p, bool train) {
  if (is_fused_kernel_acceptable(input, p)) {
    return std::get<0>(input._fused_dropout(1 - p));
  }
  return _dropout<false>(input, p, train);
}

Tensor& dropout_(Tensor& input, double p, bool train) {
  return _dropout<true>(input, p, train);
}

Tensor feature_dropout(const Tensor& input, double p, bool train) {
  return _feature_dropout<false>(input, p, train);
}

Tensor& feature_dropout_(Tensor& input, double p, bool train) {
  return _feature_dropout<true>(input, p, train);
}

Tensor alpha_dropout(const Tensor& input, double p, bool train) {
  return _alpha_dropout<false>(input, p, train);
}

Tensor& alpha_dropout_(Tensor& input, double p, bool train) {
  return _alpha_dropout<true>(input, p, train);
}

Tensor feature_alpha_dropout(const Tensor& input, double p, bool train) {
  return _feature_alpha_dropout<false>(input, p, train);
}

Tensor& feature_alpha_dropout_(Tensor& input, double p, bool train) {
  return _feature_alpha_dropout<true>(input, p, train);
}

}} // namespace at::native
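
For context, the alpha_dropout branch above mirrors the scaling used with SELU activations: kept units are scaled by a and everything is shifted by b so that a zero-mean, unit-variance input stays approximately standardized. The following numerical sanity check of those constants is an illustration only (not part of the commit); the value of p is arbitrary.

import torch

# Mirror the constants and expressions from _dropout_impl's alpha_dropout branch.
p = 0.2                                             # drop probability (illustrative value)
alpha = 1.7580993408473766                          # SELU saturation constant from the C++ code
a = 1. / ((alpha * alpha * p + 1) * (1 - p)) ** 0.5

x = torch.randn(1_000_000)                          # standardized input
noise = torch.empty_like(x).bernoulli_(1 - p)       # 1 = keep, 0 = drop
b = (noise - 1) * (alpha * a) + alpha * a * p       # same formula as the C++ "b"
out = x * (noise * a) + b                           # same as multiply(input, noise).add_(b)

print(out.mean().item(), out.var().item())          # both stay close to 0 and 1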

aten/src/ATen/native/native_functions.yaml

Lines changed: 24 additions & 0 deletions

@@ -62,6 +62,30 @@
   dispatch:
     CUDA: masked_scale_cuda
 
+- func: dropout(Tensor input, double p, bool train) -> Tensor
+  variants: function
+
+- func: dropout_(Tensor self, double p, bool train) -> Tensor
+  variants: function
+
+- func: feature_dropout(Tensor input, double p, bool train) -> Tensor
+  variants: function
+
+- func: feature_dropout_(Tensor self, double p, bool train) -> Tensor
+  variants: function
+
+- func: alpha_dropout(Tensor input, double p, bool train) -> Tensor
+  variants: function
+
+- func: alpha_dropout_(Tensor self, double p, bool train) -> Tensor
+  variants: function
+
+- func: feature_alpha_dropout(Tensor input, double p, bool train) -> Tensor
+  variants: function
+
+- func: feature_alpha_dropout_(Tensor self, double p, bool train) -> Tensor
+  variants: function
+
 - func: abs(Tensor self) -> Tensor
 
 - func: abs_(Tensor self) -> Tensor
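
Because these declarations use "variants: function", the code generator exposes them as free functions, which surface in the Python torch namespace. A hedged usage sketch of the generated bindings, for illustration only (argument order follows the schemas above):

import torch

x = torch.randn(4, 8, 16)

# With train=False these are no-ops; with train=True elements are dropped with probability p.
y1 = torch.dropout(x, 0.5, True)                # plain element-wise dropout
y2 = torch.feature_dropout(x, 0.5, True)        # zeroes entire channels along dim 1
y3 = torch.alpha_dropout(x, 0.5, True)          # SELU-compatible dropout
y4 = torch.feature_alpha_dropout(x, 0.5, True)  # channel-wise alpha dropout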

test/expect/TestJit.test_alexnet.expect

Lines changed: 19 additions & 18 deletions

@@ -44,22 +44,23 @@ graph(%0 : Double(1, 3, 224, 224)
   %46 : int = prim::Constant[value=9216](), scope: AlexNet
   %47 : int[] = prim::ListConstruct(%45, %46), scope: AlexNet
   %48 : Double(1, 9216) = aten::view(%41, %47), scope: AlexNet
-  %49 : Double(1, 9216) = ^Dropout(0.5, True, False)(%48), scope: AlexNet/Sequential[classifier]/Dropout[0]
-  %50 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %51 : int = prim::Constant[value=4096](), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %52 : int[] = prim::ListConstruct(%21, %51), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %53 : Double(1, 4096) = aten::expand(%12, %52, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %54 : Double(1, 4096) = aten::addmm(%53, %49, %50, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %55 : Double(1, 4096) = aten::threshold(%54, %23, %23), scope: AlexNet/Sequential[classifier]/ReLU[2]
-  %56 : Double(1, 4096) = ^Dropout(0.5, True, False)(%55), scope: AlexNet/Sequential[classifier]/Dropout[3]
-  %57 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
-  %58 : Double(1, 4096) = aten::expand(%14, %52, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
-  %59 : Double(1, 4096) = aten::addmm(%58, %56, %57, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
-  %60 : Double(1, 4096) = aten::threshold(%59, %23, %23), scope: AlexNet/Sequential[classifier]/ReLU[5]
-  %61 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
-  %62 : int = prim::Constant[value=1000](), scope: AlexNet/Sequential[classifier]/Linear[6]
-  %63 : int[] = prim::ListConstruct(%21, %62), scope: AlexNet/Sequential[classifier]/Linear[6]
-  %64 : Double(1, 1000) = aten::expand(%16, %63, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
-  %65 : Double(1, 1000) = aten::addmm(%64, %60, %61, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
-  return (%65);
+  %49 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
+  %50 : Double(1!, 9216) = aten::dropout(%48, %49, %21), scope: AlexNet/Sequential[classifier]/Dropout[0]
+  %51 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %52 : int = prim::Constant[value=4096](), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %53 : int[] = prim::ListConstruct(%21, %52), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %54 : Double(1, 4096) = aten::expand(%12, %53, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %55 : Double(1, 4096) = aten::addmm(%54, %50, %51, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %56 : Double(1, 4096) = aten::threshold(%55, %23, %23), scope: AlexNet/Sequential[classifier]/ReLU[2]
+  %57 : Double(1!, 4096) = aten::dropout(%56, %49, %21), scope: AlexNet/Sequential[classifier]/Dropout[3]
+  %58 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
+  %59 : Double(1, 4096) = aten::expand(%14, %53, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
+  %60 : Double(1, 4096) = aten::addmm(%59, %57, %58, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
+  %61 : Double(1, 4096) = aten::threshold(%60, %23, %23), scope: AlexNet/Sequential[classifier]/ReLU[5]
+  %62 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
+  %63 : int = prim::Constant[value=1000](), scope: AlexNet/Sequential[classifier]/Linear[6]
+  %64 : int[] = prim::ListConstruct(%21, %63), scope: AlexNet/Sequential[classifier]/Linear[6]
+  %65 : Double(1, 1000) = aten::expand(%16, %64, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
+  %66 : Double(1, 1000) = aten::addmm(%65, %61, %62, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
+  return (%66);
 }

Lines changed: 4 additions & 2 deletions

@@ -1,4 +1,6 @@
 graph(%0 : Double(2, 2)) {
-  %1 : Double(2, 2) = ^Dropout(0.6, True, False)(%0), scope: Dropout
-  return (%1);
+  %1 : float = prim::Constant[value=0.6](), scope: Dropout
+  %2 : int = prim::Constant[value=1](), scope: Dropout
+  %3 : Double(2, 2) = aten::dropout(%0, %1, %2), scope: Dropout
+  return (%3);
 }
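
The expect-file changes above show that tracing a module containing nn.Dropout now records an aten::dropout node instead of the old Python ^Dropout autograd function. A minimal sketch of inspecting such a graph, illustrative only (it uses the current torch.jit.trace call signature, which may differ from the tracing harness that generated these expect files):

import torch
import torch.nn as nn

model = nn.Dropout(p=0.6)
x = torch.randn(2, 2)

traced = torch.jit.trace(model, x)   # tracing records aten::dropout in the graph
print(traced.graph)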

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 0 deletions

@@ -228,6 +228,9 @@
   self: grad * tensor
   tensor: grad * self
 
+- name: _fused_dropout(Tensor self, double p, Generator generator)
+  self: _fused_dropout_backward(grad, result1, p)
+
 - name: eig(Tensor self, bool eigenvectors)
   self: not_implemented("eig")
tools/autograd/templates/Functions.cpp

Lines changed: 10 additions & 0 deletions

@@ -525,6 +525,16 @@ Tensor repeat_backward(Tensor grad, int64_t input_dims, IntList repeats) {
   return grad;
 }
 
+// p1m == 1 - p
+Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) {
+  if (grad.requires_grad()) {
+    // Use autograd-friendly backward if double backward is required
+    return grad * (mask.type_as(grad) * (1. / p1m));
+  } else {
+    return grad._masked_scale(mask, 1. / p1m);
+  }
+}
+
 Tensor select_equals_backward(Tensor grad, const Tensor & input, const Tensor & value) {
   auto grad_input = zeros_like(input);
   grad_input.masked_fill_(input == value, grad);
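
Note that dropout calls _fused_dropout with 1 - p, so the p1m argument here is the keep probability: the forward scales kept elements by 1 / p1m, and the backward is therefore grad * mask / p1m. A reference check in plain autograd, illustrative only (it uses an explicit mask instead of the fused CUDA kernel):

import torch

keep = 0.4                                   # keep probability, i.e. p1m = 1 - p
x = torch.randn(5, requires_grad=True)
mask = torch.empty_like(x).bernoulli_(keep)

out = x * mask / keep                        # reference dropout forward
out.backward(torch.ones_like(out))

expected = torch.ones_like(x) * mask / keep  # grad * mask / p1m
print(torch.allclose(x.grad, expected))      # True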

torch/nn/_functions/dropout.py

Lines changed: 0 additions & 146 deletions
This file was deleted.

torch/nn/backends/thnn.py

Lines changed: 0 additions & 4 deletions

@@ -22,16 +22,12 @@ def _initialize_backend():
     from .._functions.thnn import _all_functions as _thnn_functions
     from .._functions.rnn import RNN, \
         RNNTanhCell, RNNReLUCell, GRUCell, LSTMCell
-    from .._functions.dropout import Dropout, FeatureDropout
 
     backend.register_function('RNN', RNN)
     backend.register_function('RNNTanhCell', RNNTanhCell)
     backend.register_function('RNNReLUCell', RNNReLUCell)
     backend.register_function('LSTMCell', LSTMCell)
     backend.register_function('GRUCell', GRUCell)
-    backend.register_function('Dropout', Dropout)
-    backend.register_function('Dropout2d', FeatureDropout)
-    backend.register_function('Dropout3d', FeatureDropout)
     for cls in _thnn_functions:
         name = cls.__name__
         backend.register_function(name, cls)
