Commit 30c0373

add fused dropout kernels

Natalia Gimelshein committed · 1 parent b5c8d59

File tree

4 files changed: +215 -6 lines changed

aten/src/ATen/native/cuda/Dropout.cu

Lines changed: 159 additions & 0 deletions
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"
#include "detail/IndexUtils.cuh"
#include "detail/TensorInfo.cuh"
#include "curand_kernel.h"

#include <THC/THCGeneral.h>
#include <THC/THCTensorRandom.h>
#include <THC/THCGenerator.hpp>


THCGenerator* THCRandom_getGenerator(THCState* state);

namespace at{
namespace native{

namespace {

// due to limitations of the philox generator, UNROLL has to be 4
const int UNROLL = 4;

std::pair<uint64_t, uint64_t> next_philox_seed(at::Generator* gen, uint64_t increment) {
  auto gen_ = THCRandom_getGenerator(at::globalContext().getTHCState());
  uint64_t offset = gen_->state.philox_seed_offset.fetch_add(increment);
  return std::make_pair(gen_->state.initial_seed, offset);
}

template <
    typename scalar_t,
    typename accscalar_t,
    typename IndexType,
    int ADims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(256, 8)
#endif
__global__ void
fused_dropout_kernel(cuda::detail::TensorInfo<scalar_t, IndexType> a,
                     cuda::detail::TensorInfo<scalar_t, IndexType> b,
                     cuda::detail::TensorInfo<uint8_t, IndexType> c,
                     IndexType totalElements, accscalar_t p,
                     std::pair<uint64_t, uint64_t> seeds) {
  accscalar_t pinv = accscalar_t(1)/p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(
      seeds.first,
      idx,
      seeds.second,
      &state);
  // round the loop bound up so that every thread makes the same number of
  // curand_uniform4 calls and the philox states stay in lockstep
  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL) + 1) *
      blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx;
       linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    // curand_uniform_double does not do what it promises, and there is no
    // variant for halfs, so generate floats for all types
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    // turn the uniform draws into a {0, 1} keep mask (p is the keep probability)
    rand.x = rand.x < p;
    rand.y = rand.y < p;
    rand.z = rand.z < p;
    rand.w = rand.w < p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        // convert `li` into an offset of `a`
        const IndexType aOffset =
            cuda::detail::IndexToOffset<scalar_t, IndexType, ADims>::get(li, a);
        src[ii] = a.data[aOffset];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        // convert `li` into an offset of `b`; `b` and `c` are contiguous,
        // so the same offset serves both
        const IndexType bOffset =
            cuda::detail::IndexToOffset<scalar_t, IndexType, 1>::get(li, b);
        b.data[bOffset] = src[ii]*(&rand.x)[ii]*pinv;
        c.data[bOffset] = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template<typename scalar_t, typename accscalar_t>
void masked_scale_kernel(at::Tensor& ret, const at::Tensor src, const at::Tensor mask, accscalar_t scale){
  at::cuda::CUDA_tensor_apply3<scalar_t, scalar_t, uint8_t>(ret, src, mask,
      [scale]__device__(scalar_t& ret_val, const scalar_t& src_val, const uint8_t mask_val){
        ret_val = mask_val * src_val * scale;
      });
}
} // anonymous namespace

std::tuple<Tensor,Tensor>
fused_dropout_cuda(const Tensor& self, double p, Generator * gen){
  Tensor ret = at::empty_like(self);
  Tensor mask = self.type().toScalarType(kByte).tensor(self.sizes());
  const int64_t nelem = self.numel();
  int64_t block_size = 256;
  unsigned int blocks_per_sm = at::globalContext().getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
  dim3 dim_block(block_size);
  dim3 grid((nelem + block_size - 1)/block_size);
  grid.x = std::min((unsigned int)at::globalContext().getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
  // number of philox counter slots to reserve for this launch
  int64_t nrep = ((nelem - 1)/(block_size*grid.x*UNROLL) + 1)*UNROLL;
  if (cuda::detail::canUse32BitIndexMath(self)){
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      accscalar_t pa = (accscalar_t)(p);
      auto self_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(self);
      auto ret_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(ret);
      auto mask_info = cuda::detail::getTensorInfo<uint8_t, unsigned int>(mask);
      self_info.collapseDims();
      ret_info.collapseDims();
      mask_info.collapseDims(); // ret and mask are collapsed to 1d contiguous tensors
      switch (self_info.dims) {
        case 1:
          fused_dropout_kernel<scalar_t, accscalar_t, unsigned int, 1><<<grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen, nrep));
          break;
        default:
          fused_dropout_kernel<scalar_t, accscalar_t, unsigned int, -1><<<grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen, nrep));
      }
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      accscalar_t pa = (accscalar_t)(p);
      auto self_info = cuda::detail::getTensorInfo<scalar_t, uint64_t>(self);
      auto ret_info = cuda::detail::getTensorInfo<scalar_t, uint64_t>(ret);
      auto mask_info = cuda::detail::getTensorInfo<uint8_t, uint64_t>(mask);
      self_info.collapseDims();
      ret_info.collapseDims();
      mask_info.collapseDims(); // ret and mask are collapsed to 1d contiguous tensors
      switch (self_info.dims) {
        case 1:
          fused_dropout_kernel<scalar_t, accscalar_t, uint64_t, 1><<<grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen, nrep));
          break;
        default:
          fused_dropout_kernel<scalar_t, accscalar_t, uint64_t, -1><<<grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen, nrep));
      }
    });
  }
  THCudaCheck(cudaGetLastError());
  return std::tuple<Tensor,Tensor>(ret, mask);
}

Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
  Tensor ret = at::empty_like(self);
  AT_CHECK(mask.type().scalarType() == at::ScalarType::Byte, "mask should be torch.uint8 dtype");
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "masked_scale", [&] {
    using accscalar_t = acc_type<scalar_t, true>;
    accscalar_t pa = (accscalar_t)(scale);
    masked_scale_kernel<scalar_t>(ret, self, mask, pa);
  });
  return ret;
}

}
}
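Note that p here is the keep probability, not the drop probability: each uniform draw below p yields a mask value of 1, and surviving elements are scaled by pinv = 1/p. A minimal PyTorch sketch of the semantics the kernel computes, as a reading aid rather than a drop-in equivalent (fused_dropout_reference is a hypothetical name):

import torch

def fused_dropout_reference(x, p):
    # p is the KEEP probability; each mask entry is 1 with probability p
    mask = (torch.rand_like(x, dtype=torch.float) < p).to(torch.uint8)
    # survivors are rescaled by 1/p so the expected value of x is preserved
    return x * mask.type_as(x) / p, mask

Every thread iterates over the same rounded-up range, so all Philox states advance in lockstep; that is also why UNROLL is pinned to 4, the number of floats one curand_uniform4 call produces.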

aten/src/ATen/native/native_functions.yaml

Lines changed: 9 additions & 0 deletions
@@ -495,6 +495,10 @@
 - func: dot_out(Tensor result, Tensor self, Tensor tensor) -> Tensor
   variants: function
 
+- func: fused_dropout(Tensor self, double p, Generator* generator=nullptr) -> (Tensor, Tensor)
+  dispatch:
+    CUDA: fused_dropout_cuda
+
 - func: einsum(std::string equation, TensorList tensors) -> Tensor
   variants: function
 
@@ -905,6 +909,11 @@
 - func: logsumexp_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
   variants: function
 
+- func: masked_scale(Tensor self, Tensor mask, double scale) -> Tensor
+  dispatch:
+    CUDA: masked_scale_cuda
+
+
 - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, double margin=0.0, int64_t reduction=Reduction::ElementwiseMean) -> Tensor
   variants: function
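Neither entry declares a CPU dispatch, so the new functions are reachable only for CUDA tensors. A hedged sketch of calling them directly (illustrative values; in practice only the autograd function in the last file below uses them):

import torch

x = torch.randn(1000, device='cuda')
keep_prob = 0.8  # fused_dropout takes the keep probability, not the drop probability

out, mask = x.fused_dropout(keep_prob)        # out == x * mask / keep_prob
masked = out.masked_scale(mask, keep_prob)    # mask * out * keep_prob == x * mask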

test/test_nn.py

Lines changed: 30 additions & 5 deletions
@@ -712,9 +712,10 @@ def test_no_grad(self):
         self.assertFalse(output2.requires_grad)
         self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))
 
-    def _test_dropout(self, cls, input):
+    def _test_dropout(self, cls, cuda, input):
         p = 0.2
-        input.fill_(1 - p)
+        device = torch.device("cuda") if cuda else torch.device("cpu")
+        input = input.to(device).fill_(1 - p)
 
         module = cls(p)
         input_var = torch.tensor(input, requires_grad=True)
@@ -2077,15 +2078,15 @@ def func(x):
 
     def test_Dropout(self):
         input = torch.Tensor(1000)
-        self._test_dropout(nn.Dropout, input)
+        self._test_dropout(nn.Dropout, False, input)
 
     def test_Dropout2d(self):
         b = random.randint(1, 5)
         w = random.randint(1, 5)
         h = random.randint(1, 5)
         num_features = 1000
         input = torch.Tensor(num_features, b, w, h)
-        self._test_dropout(nn.Dropout2d, input)
+        self._test_dropout(nn.Dropout2d, False, input)
 
     def test_Dropout3d(self):
         b = random.randint(1, 5)
@@ -2094,7 +2095,31 @@ def test_Dropout3d(self):
         d = random.randint(1, 2)
         num_features = 1000
         input = torch.Tensor(num_features, b, d, w, h)
-        self._test_dropout(nn.Dropout3d, input)
+        self._test_dropout(nn.Dropout3d, False, input)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_Dropout_cuda(self):
+        input = torch.Tensor(1000)
+        self._test_dropout(nn.Dropout, True, input)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_Dropout2d_cuda(self):
+        b = random.randint(1, 5)
+        w = random.randint(1, 5)
+        h = random.randint(1, 5)
+        num_features = 1000
+        input = torch.Tensor(num_features, b, w, h)
+        self._test_dropout(nn.Dropout2d, True, input)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_Dropout3d_cuda(self):
+        b = random.randint(1, 5)
+        w = random.randint(1, 5)
+        h = random.randint(1, 5)
+        d = random.randint(1, 2)
+        num_features = 1000
+        input = torch.Tensor(num_features, b, d, w, h)
+        self._test_dropout(nn.Dropout3d, True, input)
 
     def test_AlphaDropout(self):
         # generate random tensor with zero mean and unit std
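The CUDA variants reuse _test_dropout unchanged, so the fused kernel is held to the same checks as the CPU path. A hedged sketch of the kind of invariant such a test can rely on (illustrative, not the helper's actual body; assumes a CUDA device):

import torch
import torch.nn as nn

p = 0.2
x = torch.ones(100000, device='cuda')
out = nn.Dropout(p)(x)  # freshly constructed modules are in training mode

zero_frac = (out == 0).float().mean().item()
assert abs(zero_frac - p) < 0.02            # about a fraction p of entries are dropped
assert abs(out.mean().item() - 1.0) < 0.05  # rescaling by 1/(1-p) preserves the mean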

torch/nn/_functions/dropout.py

Lines changed: 17 additions & 1 deletion
@@ -1,6 +1,7 @@
 import torch
 from torch.autograd.function import InplaceFunction
 from itertools import repeat
+from torch.autograd.function import once_differentiable
 
 
 class Dropout(InplaceFunction):
@@ -15,6 +16,10 @@ def symbolic(g, input, p=0.5, train=False, inplace=False):
         r, _ = g.op("Dropout", input, ratio_f=p, is_test_i=not train, outputs=2)
         return r
 
+    @staticmethod
+    def _fused_kernel_acceptable(input, p, cls_name, inplace):
+        return input.is_cuda and p > 0 and p < 1 and not inplace and cls_name == 'Dropout'
+
     @classmethod
     def forward(cls, ctx, input, p=0.5, train=False, inplace=False):
         if p < 0 or p > 1:
@@ -23,10 +28,15 @@ def forward(cls, ctx, input, p=0.5, train=False, inplace=False):
         ctx.p = p
         ctx.train = train
         ctx.inplace = inplace
+        ctx.use_fused_kernel = Dropout._fused_kernel_acceptable(input, ctx.p, cls.__name__, ctx.inplace)
 
         if ctx.p == 0 or not ctx.train:
             return input
 
+        if ctx.use_fused_kernel:
+            output, ctx.noise = input.fused_dropout(1 - ctx.p)
+            return output
+
         if ctx.inplace:
             ctx.mark_dirty(input)
             output = input
@@ -45,7 +55,13 @@ def forward(cls, ctx, input, p=0.5, train=False, inplace=False):
 
     @staticmethod
     def backward(ctx, grad_output):
-        if ctx.p > 0 and ctx.train:
+        if ctx.use_fused_kernel:
+            if not grad_output.requires_grad:
+                return grad_output.masked_scale(ctx.noise, 1. / (1 - ctx.p)), None, None, None
+            else:
+                # use the autograd-friendly backward if double backward is required
+                return grad_output * (ctx.noise.type_as(grad_output) * (1. / (1 - ctx.p))), None, None, None
+        elif ctx.p > 0 and ctx.train:
             return grad_output * ctx.noise, None, None, None
         else:
             return grad_output, None, None, None
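The two fused backward branches compute the same gradient: masked_scale is a single fused kernel, while the eager expression is built from differentiable ops so a second backward can trace through it. A small equivalence sketch (assumes a CUDA build that exposes the new native functions):

import torch

p = 0.2
g = torch.randn(1000, device='cuda')
mask = (torch.rand(1000, device='cuda') < (1 - p)).to(torch.uint8)

fused = g.masked_scale(mask, 1. / (1 - p))      # fused kernel: mask * g * scale
eager = g * (mask.type_as(g) * (1. / (1 - p)))  # autograd-friendly equivalent
assert torch.allclose(fused, eager)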
