#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"
#include "detail/IndexUtils.cuh"
#include "detail/TensorInfo.cuh"
#include "curand_kernel.h"

#include <THC/THCGeneral.h>
#include <THC/THCTensorRandom.h>
#include <THC/THCGenerator.hpp>


THCGenerator* THCRandom_getGenerator(THCState* state);

namespace at{
namespace native{

namespace {

// The Philox generator hands out four random values per call (curand_uniform4
// returns a float4), so the unroll factor has to be 4.
const int UNROLL = 4;
// Note: the passed-in Generator is currently unused; the seed comes from the
// global THC Philox state, and its offset counter is advanced by `increment`.
std::pair<uint64_t, uint64_t> next_philox_seed(at::Generator* gen, uint64_t increment) {
  auto gen_ = THCRandom_getGenerator(at::globalContext().getTHCState());
  uint64_t offset = gen_->state.philox_seed_offset.fetch_add(increment);
  return std::make_pair(gen_->state.initial_seed, offset);
}


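// Fused dropout kernel: in a single pass it reads the input `a`, writes the
// scaled output `b` (kept elements are multiplied by 1/p) and the uint8 keep
// mask `c`. Each thread walks the flattened tensor with a grid-stride loop,
// consuming four random values (one curand_uniform4 call) per iteration.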
template <
    typename scalar_t,
    typename accscalar_t,
    typename IndexType,
    int ADims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(256,8)
#endif
__global__ void
fused_dropout_kernel(cuda::detail::TensorInfo<scalar_t, IndexType> a,
                     cuda::detail::TensorInfo<scalar_t, IndexType> b,
                     cuda::detail::TensorInfo<uint8_t, IndexType> c,
                     IndexType totalElements, accscalar_t p, std::pair<uint64_t, uint64_t> seeds
                     ) {

  accscalar_t pinv = accscalar_t(1)/p;
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  // All threads share the same seed and offset; each thread gets its own
  // Philox subsequence, selected by its global linear index.
  curand_init(
      seeds.first,
      idx,
      seeds.second,
      &state);
  // Round the trip count up so that every thread executes the same number of
  // loop iterations; the per-element bounds checks below handle the tail, and
  // the __syncthreads() at the end of the loop therefore cannot deadlock.
  IndexType rounded_size = ((totalElements - 1)/(blockDim.x*gridDim.x*UNROLL)+1)*blockDim.x*gridDim.x*UNROLL;
  for (IndexType linearIndex = idx;
       linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x*UNROLL) {
    // curand_uniform_double was not doing what it promises anyway, and there
    // is nothing for halves, so generate float for everything.
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    // Turn the four uniforms into 0/1 keep flags: an element is kept when its
    // random draw is below the keep probability p.
    rand.x = rand.x < p;
    rand.y = rand.y < p;
    rand.z = rand.z < p;
    rand.w = rand.w < p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        // Convert `li` into an offset of `a`
        const IndexType aOffset =
            cuda::detail::IndexToOffset<scalar_t, IndexType, ADims>::get(li, a);
        src[ii] = a.data[aOffset];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        // Convert `li` into an offset of `b`; `b` and `c` are contiguous, so
        // the same offset is reused for the mask.
        const IndexType bOffset =
            cuda::detail::IndexToOffset<scalar_t, IndexType, 1>::get(li, b);
        b.data[bOffset] = src[ii]*(&rand.x)[ii]*pinv;
        c.data[bOffset] = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

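// Applies a precomputed dropout mask and a scale factor in one elementwise
// pass: ret = mask * src * scale. Presumably this is what the dropout backward
// uses to scale the incoming gradient by the saved mask and 1/p.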
template<typename scalar_t, typename accscalar_t>
void masked_scale_kernel(at::Tensor& ret, const at::Tensor src, const at::Tensor mask, accscalar_t scale){
  at::cuda::CUDA_tensor_apply3<scalar_t, scalar_t, uint8_t>(
      ret, src, mask,
      [scale]__device__(scalar_t& ret_val, const scalar_t& src_val, const uint8_t mask_val){
        ret_val = mask_val * src_val * scale;
      });
}
} //anonymous namespace

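// Host entry point for the fused dropout forward pass. The launch is sized to
// saturate the device (the grid is capped at blocks_per_sm * multiProcessorCount)
// and the kernel's grid-stride loop covers the remaining elements; nrep is the
// increment passed to next_philox_seed so subsequent launches draw fresh
// randomness.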
std::tuple<Tensor,Tensor>
fused_dropout_cuda(const Tensor& self, double p, Generator * gen){
  Tensor ret = at::empty_like(self);
  Tensor mask = self.type().toScalarType(kByte).tensor(self.sizes());
  const int64_t nelem = self.numel();
  int64_t block_size = 256;
  unsigned int blocks_per_sm = at::globalContext().getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
  dim3 dim_block(block_size);
  dim3 grid((nelem + block_size -1)/block_size);
  grid.x = std::min((unsigned int)at::globalContext().getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
  int64_t nrep = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
  if (cuda::detail::canUse32BitIndexMath(self)){
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      accscalar_t pa = (accscalar_t)(p);
      auto self_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(self);
      auto ret_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(ret);
      auto mask_info = cuda::detail::getTensorInfo<uint8_t, unsigned int>(mask);
      self_info.collapseDims();
      ret_info.collapseDims();
      mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor
      switch (self_info.dims) {
        case 1:
          fused_dropout_kernel<scalar_t, accscalar_t, unsigned int, 1><<<grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,nrep));
          break;
        default:
          fused_dropout_kernel<scalar_t, accscalar_t, unsigned int, -1><<<grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,nrep));
      }
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      accscalar_t pa = (accscalar_t)(p);
      auto self_info = cuda::detail::getTensorInfo<scalar_t, uint64_t>(self);
      auto ret_info = cuda::detail::getTensorInfo<scalar_t, uint64_t>(ret);
      auto mask_info = cuda::detail::getTensorInfo<uint8_t, uint64_t>(mask);
      self_info.collapseDims();
      ret_info.collapseDims();
      mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor
      switch (self_info.dims) {
        case 1:
          fused_dropout_kernel<scalar_t, accscalar_t, uint64_t, 1><<<grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,nrep));
          break;
        default:
          fused_dropout_kernel<scalar_t, accscalar_t, uint64_t, -1><<<grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,nrep));
      }
    });
  }
  THCudaCheck(cudaGetLastError());
  return std::tuple<Tensor,Tensor>(ret, mask);
}

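// Host entry point for applying a saved dropout mask: checks that the mask is
// a uint8 tensor and dispatches masked_scale_kernel on the output dtype.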
Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
  Tensor ret = at::empty_like(self);
  AT_CHECK(mask.type().scalarType() == at::ScalarType::Byte, "mask should be torch.uint8 dtype");
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "masked_scale", [&] {
    using accscalar_t = acc_type<scalar_t, true>;
    accscalar_t pa = (accscalar_t)(scale);
    masked_scale_kernel<scalar_t>(ret, self, mask, pa);
  });
  return ret;
}

} // namespace native
} // namespace at
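
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original commit): how the two entry points are
// meant to pair up for a dropout forward/backward round trip. Everything below
// is illustrative only; the helper name `dropout_forward_backward_example` is
// hypothetical, the direct calls to the *_cuda functions and the null
// Generator are assumptions for the sketch (real callers would go through the
// ATen dispatcher), and `p` is the probability of *keeping* an element.
#include <tuple>    // std::tie; included here only to keep the sketch self-contained
#include <utility>  // std::pair, std::make_pair

std::pair<at::Tensor, at::Tensor> dropout_forward_backward_example(
    const at::Tensor& input, const at::Tensor& grad_output, double p) {
  // Forward: one fused pass produces the scaled output and the uint8 keep-mask.
  at::Tensor output, mask;
  std::tie(output, mask) =
      at::native::fused_dropout_cuda(input, p, /*gen=*/nullptr);
  // Backward: reuse the saved mask; grad_input = grad_output * mask * (1/p).
  at::Tensor grad_input =
      at::native::masked_scale_cuda(grad_output, mask, 1.0 / p);
  return std::make_pair(output, grad_input);
}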