
Replace hcRNG with rocRAND. #41

Merged: 24 commits, Aug 8, 2018
Commits (24):
c1652be
Replace hcRNG with rocRAND.
iotamudelta Jul 17, 2018
61a15cb
Merge branch 'master' into rocRAND_PR
iotamudelta Jul 18, 2018
b9518ec
Merge branch 'master' into rocRAND_PR
iotamudelta Jul 19, 2018
87a3cba
Remove hcRNG dependency and get rocRAND v.1.0
iotamudelta Jul 24, 2018
0f6110f
Merge remote-tracking branch 'rocm_upstream/master' into rocRAND_PR
iotamudelta Jul 24, 2018
ec192fd
Merge pull request #73 from iotamudelta/master
iotamudelta Jul 26, 2018
fbd1d65
Merge branch 'master' into rocRAND_PR
iotamudelta Jul 26, 2018
e177794
Merge remote-tracking branch 'rocm_upstream/master' into rocRAND_PR
iotamudelta Jul 26, 2018
d1fd62c
Do not install the rocRAND release version, switch to master in
iotamudelta Jul 26, 2018
46c121c
Merge branch 'master' into enableunittests
iotamudelta Aug 6, 2018
789cf1a
Minor changes to pass flake8 tests
jithunnair-amd Aug 6, 2018
dd0d208
Merge pull request #100 from jithunnair-amd/enable_unit_tests_for_rocm
iotamudelta Aug 6, 2018
b56ed34
Merge branch 'master' into rocRAND_PR
iotamudelta Aug 6, 2018
7376e8e
Merge remote-tracking branch 'upstream/master' into rocRAND_PR
iotamudelta Aug 6, 2018
b51b51e
Merge remote-tracking branch 'rocm_upstream/master' into rocRAND_PR
iotamudelta Aug 7, 2018
e9c047e
Merge remote-tracking branch 'rocm_upstream/enableunittests' into roc…
iotamudelta Aug 7, 2018
f1c36e6
Do not set SHARED flag unconditionally here.
iotamudelta Aug 7, 2018
d306547
Merge branch 'master' into rocRAND_PR
iotamudelta Aug 7, 2018
4692427
Merge branch 'master' into rocRAND_PR
iotamudelta Aug 7, 2018
7497949
Optional input lengths in CTC op (#10228)
viswanathgs Aug 7, 2018
5bb2149
add fused dropout kernels (#9666)
Aug 7, 2018
5390476
Add tracing to custom op and simplify tracer overall (#10212)
goldsborough Aug 7, 2018
eb5975a
Merge remote-tracking branch 'upstream/master'
iotamudelta Aug 7, 2018
ca94a4a
Merge branch 'master' into rocRAND_PR
iotamudelta Aug 7, 2018
Files changed:
1 change: 0 additions & 1 deletion .jenkins/pytorch/build.sh
@@ -30,7 +30,6 @@ cmake --version
pip install -r requirements.txt || true

if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
# This is necessary in order to cross compile (or else we'll have missing GPU device).
export MAX_JOBS=4
# This is necessary in order to cross compile (or else we'll have missing GPU device).
export HCC_AMDGPU_TARGET=gfx900
4 changes: 2 additions & 2 deletions aten/src/ATen/CMakeLists.txt
@@ -257,9 +257,9 @@ ENDIF()
IF(USE_ROCM)
### Link in the ROCm libraries BLAS / RNG.
FIND_LIBRARY(HIPBLAS_LIBRARY hipblas HINTS ${HIPBLAS_PATH}/lib)
FIND_LIBRARY(HIPRNG_LIBRARY hcrng HINTS ${HIPRNG_PATH}/lib)
FIND_LIBRARY(HIPRAND_LIBRARY hiprand HINTS ${HIPRAND_PATH}/lib)

list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${HIPBLAS_LIBRARY} ${HIPRNG_LIBRARY})
list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${HIPBLAS_LIBRARY} ${HIPRAND_LIBRARY})
ENDIF()

# Include CPU paths for CUDA as well
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDAApplyUtils.cuh
@@ -1,6 +1,6 @@
#pragma once

#include "detail/IndexUtils.cuh"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include "ATen/TensorUtils.h"
#include "THC/THCAtomics.cuh"
#include "ATen/cuda/CUDAContext.h"
162 changes: 162 additions & 0 deletions aten/src/ATen/native/cuda/Dropout.cu
@@ -0,0 +1,162 @@
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include "ATen/cuda/detail/TensorInfo.cuh"
#include "curand_kernel.h"

#include <THC/THCGeneral.h>
#include <THC/THCTensorRandom.h>
#include <THC/THCGenerator.hpp>


THCGenerator* THCRandom_getGenerator(THCState* state);

namespace at{
namespace native{

namespace {

// philox generates 128 bits of randomness at a time. Kernel uses this explicitly by putting suitably transformed result into float4
// for all members of float4 to be consumed UNROLL has to be 4. Don't change!
const int UNROLL = 4;

std::pair<uint64_t, uint64_t> next_philox_seed(at::Generator* gen, uint64_t increment) {
auto gen_ = THCRandom_getGenerator(at::globalContext().getTHCState());
uint64_t offset = gen_->state.philox_seed_offset.fetch_add(increment);
return std::make_pair(gen_->state.initial_seed, offset);
}


template <
typename scalar_t,
typename accscalar_t,
typename IndexType,
int ADims>
#if __CUDA_ARCH__ >= 350
__launch_bounds__(256,8)
#endif
__global__ void
fused_dropout_kernel(cuda::detail::TensorInfo<scalar_t, IndexType> a,
cuda::detail::TensorInfo<scalar_t, IndexType> b,
cuda::detail::TensorInfo<uint8_t, IndexType> c,
IndexType totalElements, accscalar_t p, std::pair<uint64_t, uint64_t> seeds
) {

accscalar_t pinv = accscalar_t(1)/p;
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
idx,
seeds.second,
&state);
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) *
blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx;
linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x*UNROLL) {
//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything
float4 rand = curand_uniform4(&state);
scalar_t src[UNROLL];
rand.x = rand.x < p;
rand.y = rand.y < p;
rand.z = rand.z < p;
rand.w = rand.w < p;
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
// Convert `linearIndex` into an offset of `a`
const IndexType aOffset =
cuda::detail::IndexToOffset<scalar_t, IndexType, ADims>::get(li, a);
src[ii] = a.data[aOffset];
}
}
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
// Convert `linearIndex` into an offset of `b`
const IndexType bOffset =
cuda::detail::IndexToOffset<scalar_t, IndexType, 1>::get(li, b);
b.data[bOffset] = src[ii]*(&rand.x)[ii]*pinv;
c.data[bOffset] = (uint8_t)(&rand.x)[ii];
}
}
__syncthreads();
}
}

template<typename scalar_t, typename accscalar_t>
void masked_scale_kernel(at::Tensor& ret, const at::Tensor src, const at::Tensor mask, accscalar_t scale){
at::cuda::CUDA_tensor_apply3<scalar_t, scalar_t, uint8_t>(ret, src, mask, [scale]__device__(scalar_t& ret_val, const scalar_t& src_val, const uint8_t mask_val){
ret_val = (float)mask_val * src_val * scale;
});
}
} //anonymous namespace

std::tuple<Tensor,Tensor>
fused_dropout_cuda(const Tensor& self, double p, Generator * gen){
Tensor ret = at::empty_like(self);
Tensor mask = self.type().toScalarType(kByte).tensor(self.sizes());
const int64_t nelem = self.numel();
const int64_t block_size = 256;
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
dim3 dim_block(block_size);
dim3 grid((nelem + block_size -1)/block_size);
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state
int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
if (cuda::detail::canUse32BitIndexMath(self)){
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] {
using accscalar_t = acc_type<scalar_t, true>;
accscalar_t pa = (accscalar_t)(p);
auto self_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(self);
auto ret_info = cuda::detail::getTensorInfo<scalar_t, unsigned int>(ret);
auto mask_info = cuda::detail::getTensorInfo<uint8_t, unsigned int>(mask);
self_info.collapseDims();
ret_info.collapseDims();
mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor
switch (self_info.dims) {
case 1:
fused_dropout_kernel<scalar_t, accscalar_t, unsigned int, 1><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,counter_offset));
break;
default:
fused_dropout_kernel<scalar_t, accscalar_t, unsigned int, -1><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,counter_offset));
}
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] {
using accscalar_t = acc_type<scalar_t, true>;
accscalar_t pa = (accscalar_t)(p);
auto self_info = cuda::detail::getTensorInfo<scalar_t, uint64_t>(self);
auto ret_info = cuda::detail::getTensorInfo<scalar_t, uint64_t>(ret);
auto mask_info = cuda::detail::getTensorInfo<uint8_t, uint64_t>(mask);
self_info.collapseDims();
ret_info.collapseDims();
mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor
switch (self_info.dims) {
case 1:
fused_dropout_kernel<scalar_t, accscalar_t, uint64_t, 1><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,counter_offset));
break;
default:
fused_dropout_kernel<scalar_t, accscalar_t, uint64_t, -1><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,counter_offset));
}
});
}
THCudaCheck(cudaGetLastError());
return std::tuple<Tensor,Tensor>(ret, mask);
}

Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
Tensor ret = at::empty_like(self);
AT_CHECK(mask.type().scalarType() == at::ScalarType::Byte, "mask should be torch.uint8 dtype");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "masked_scale", [&] {
using accscalar_t = acc_type<scalar_t, true>;
accscalar_t pa = (accscalar_t)(scale);
masked_scale_kernel<scalar_t>(ret, self, mask, pa);
});
return ret;
}

}
}
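
For reference, the Philox offset bookkeeping in fused_dropout_cuda above can be checked on the host: each thread consumes one curand_uniform4 draw (covering UNROLL = 4 elements) per strided pass, so the generator offset advances by UNROLL per pass. The following standalone C++ sketch reproduces that arithmetic with hypothetical example sizes; it is illustrative only and not part of the diff.

// Standalone sketch (not part of this PR): host-side check of the
// counter_offset / rounded_size arithmetic used by fused_dropout_cuda.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t nelem = 1 << 20;   // example element count (assumption)
  const int64_t block_size = 256;  // matches dim_block in the launch above
  const int64_t grid_x = 320;      // example grid.x after the per-SM clamp (assumption)
  const int64_t UNROLL = 4;        // one uniform4 draw covers 4 elements

  // Each thread strides through the input in steps of block_size * grid_x * UNROLL,
  // so the Philox counter must advance by UNROLL for every pass it makes.
  const int64_t passes = (nelem - 1) / (block_size * grid_x * UNROLL) + 1;
  const int64_t counter_offset = passes * UNROLL;

  // rounded_size is the same pass count expressed in elements: the loop bound is
  // rounded up so every thread runs the same number of iterations, keeping the
  // __syncthreads() in the kernel free of divergence.
  const int64_t rounded_size = passes * block_size * grid_x * UNROLL;

  std::printf("passes=%lld counter_offset=%lld rounded_size=%lld\n",
              (long long)passes, (long long)counter_offset, (long long)rounded_size);
  return 0;
}
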
9 changes: 9 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -54,6 +54,14 @@
dispatch:
CUDA: _cudnn_init_dropout_state

- func: _fused_dropout(Tensor self, double p, Generator* generator=nullptr) -> (Tensor, Tensor)
dispatch:
CUDA: fused_dropout_cuda

- func: _masked_scale(Tensor self, Tensor mask, double scale) -> Tensor
dispatch:
CUDA: masked_scale_cuda

- func: abs(Tensor self) -> Tensor

- func: abs_(Tensor self) -> Tensor
@@ -1003,6 +1011,7 @@
- func: logsumexp_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
variants: function


- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, double margin=0.0, int64_t reduction=Reduction::ElementwiseMean) -> Tensor
variants: function

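The two declarations added above are what expose the new CUDA kernels through ATen. A minimal C++ usage sketch follows; it assumes the usual code generation from native_functions.yaml surfaces them as at::_fused_dropout and at::_masked_scale (names taken from the func: entries) and that the input lives on a CUDA device, since only CUDA dispatch is registered. It is not part of the diff.

// Minimal sketch, assuming ATen codegen exposes the ops declared above and
// that `input` is a CUDA floating-point tensor (only CUDA dispatch exists).
#include <ATen/ATen.h>
#include <tuple>

at::Tensor dropout_roundtrip(const at::Tensor& input, double keep_prob) {
  // _fused_dropout(self, p): keeps each element with probability p and scales
  // survivors by 1/p; returns the scaled output and the byte mask it drew.
  at::Tensor out, mask;
  std::tie(out, mask) = at::_fused_dropout(input, keep_prob);

  // _masked_scale(self, mask, scale) re-applies a saved mask with an explicit
  // scale, which is what the backward pass needs (scale = 1 / keep_prob).
  return at::_masked_scale(out, mask, 1.0 / keep_prob);
}
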
3 changes: 3 additions & 0 deletions aten/src/THC/THCTensorRandom.h
@@ -5,6 +5,9 @@

#include "generic/THCTensorRandom.h"
#include "THCGenerateAllTypes.h"
#ifdef __HIP_PLATFORM_HCC__
#include <hiprand_kernel.h>
#endif

typedef struct THCGenerator THCGenerator;

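The conditional include added above pulls in hipRAND's device API on ROCm builds, where it stands in for cuRAND. As an aside, a common pattern for sharing kernel code between the two backends is a small alias shim like the hypothetical one below; it is not part of this PR and assumes hipRAND mirrors the cuRAND device-side names (hiprand_init, hiprandStatePhilox4_32_10_t, hiprand_uniform4).

// Hypothetical compatibility shim, not part of this PR: picks an RNG backend
// per platform so shared kernel code can use one set of names. Assumes the
// hipRAND device API mirrors cuRAND.
#ifdef __HIP_PLATFORM_HCC__
#include <hiprand_kernel.h>
using rng_philox_state = hiprandStatePhilox4_32_10_t;
#define rng_init     hiprand_init
#define rng_uniform4 hiprand_uniform4
#else
#include <curand_kernel.h>
using rng_philox_state = curandStatePhilox4_32_10_t;
#define rng_init     curand_init
#define rng_uniform4 curand_uniform4
#endif
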
4 changes: 2 additions & 2 deletions caffe2/CMakeLists.txt
@@ -348,7 +348,7 @@ endif()
if(USE_ROCM)
# Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs.
if(BUILD_ATEN)
# Get Compile Definitions from the directory (FindHIP.CMake bug)
# Get Compile Definitions from the directory (FindHIP.cmake bug)
get_directory_property(MY_DEFINITIONS COMPILE_DEFINITIONS)
if(MY_DEFINITIONS)
foreach(_item ${MY_DEFINITIONS})
@@ -364,7 +364,7 @@ if(USE_ROCM)
ENDIF()

# FindHIP.CMake checks if the SHARED flag is set and adds extra logic accordingly.
hip_add_library(caffe2_hip SHARED ${Caffe2_HIP_SRCS})
hip_add_library(caffe2_hip ${Caffe2_HIP_SRCS})

# Since PyTorch files contain HIP headers, these flags are required for the necessary definitions to be added.
set_target_properties(caffe2_hip PROPERTIES COMPILE_FLAGS ${HIP_HIPCC_FLAGS})
11 changes: 10 additions & 1 deletion caffe2/contrib/warpctc/ctc_op.cpp
@@ -2,6 +2,11 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/operator.h"

#ifdef CAFFE2_USE_IDEEP
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif

namespace caffe2 {

namespace detail {
@@ -17,9 +22,13 @@ ctcComputeInfo workspaceInfo<CPUContext>(const CPUContext& /*context*/) {
}

REGISTER_CPU_OPERATOR(CTC, CTCOp<float, CPUContext>);
OPERATOR_SCHEMA(CTC).NumInputs(4).NumOutputs(2, 3);
OPERATOR_SCHEMA(CTC).NumInputs(3, 4).NumOutputs(2, 3);
// .EnforceInputOutputGradient({{0, 0}});

#ifdef CAFFE2_USE_IDEEP
REGISTER_IDEEP_OPERATOR(CTC, IDEEPFallbackOp<CTCOp<float, CPUContext>>);
#endif

namespace {
class GetCTCGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
33 changes: 28 additions & 5 deletions caffe2/contrib/warpctc/ctc_op.h
@@ -45,13 +45,24 @@ class CTCOp final : public Operator<Context> {
bool RunOnDevice() override {
// inputs
const auto& inputs = Input(INPUTS);
const auto maxTimeSteps = inputs.dim(0);
const auto minibatchSize = inputs.dim(1);
const auto alphabetSize = inputs.dim(2);
const auto& labels = OperatorBase::template Input<Tensor>(LABELS, CPU);
const auto& labelLengths =
OperatorBase::template Input<Tensor>(LABEL_LENGTHS, CPU);
const auto& inputLengths =
OperatorBase::template Input<Tensor>(INPUT_LENGTHS, CPU);

const int* inputLengthsData = nullptr;
if (InputSize() == 4) {
const auto& inputLengths =
OperatorBase::template Input<Tensor>(INPUT_LENGTHS, CPU);
inputLengthsData = inputLengths.template data<int>();
} else {
// Input lengths not passed in. Default to max timesteps for
// each item in minibatch.
default_input_lengths_.resize(minibatchSize, maxTimeSteps);
inputLengthsData = default_input_lengths_.data();
}

// outputs
Tensor* gradients = nullptr;
@@ -74,28 +85,40 @@ size_t workspaceSizeBytes;
size_t workspaceSizeBytes;
CTC_CHECK(get_workspace_size(
labelLengths.template data<int>(),
inputLengths.template data<int>(),
inputLengthsData,
alphabetSize,
minibatchSize,
detail::workspaceInfo(context_),
&workspaceSizeBytes));
workspace->Resize(workspaceSizeBytes);
auto* workspaceData = workspace->template mutable_data<uint8_t>();

if (is_test_ && labels.dim(0) == 0) {
// compute_ctc_loss doesn't handle empty labels well
T* costsData = costs->template mutable_data<T>();
for (int i = 0; i < costs->size(); ++i) {
costsData[i] = 0;
}
return true;
}

CTC_CHECK(compute_ctc_loss(
inputs.template data<T>(),
gradients ? gradients->template mutable_data<T>() : nullptr,
labels.template data<int>(),
labelLengths.template data<int>(),
inputLengths.template data<int>(),
inputLengthsData,
alphabetSize,
minibatchSize,
costs->template mutable_data<T>(),
workspace->template mutable_data<uint8_t>(),
workspaceData,
detail::workspaceInfo(context_)));
return true;
}

private:
bool is_test_;
std::vector<int> default_input_lengths_;

INPUT_TAGS(INPUTS, LABELS, LABEL_LENGTHS, INPUT_LENGTHS);
};
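
The change above makes the INPUT_LENGTHS blob optional: when only three inputs are supplied, every item in the minibatch defaults to the full maxTimeSteps. A standalone C++ sketch of that fallback, using hypothetical helper names, not part of the diff:

// Standalone sketch of the optional input-lengths fallback used by CTCOp:
// when no INPUT_LENGTHS blob is passed, every sequence in the minibatch is
// assumed to span the full number of timesteps.
#include <cstdio>
#include <vector>

std::vector<int> resolve_input_lengths(const int* provided,  // may be nullptr
                                       int minibatch_size,
                                       int max_time_steps) {
  if (provided != nullptr) {
    return std::vector<int>(provided, provided + minibatch_size);
  }
  // Mirrors default_input_lengths_.resize(minibatchSize, maxTimeSteps) above.
  return std::vector<int>(minibatch_size, max_time_steps);
}

int main() {
  const auto lens = resolve_input_lengths(nullptr, /*minibatch_size=*/3,
                                          /*max_time_steps=*/50);
  for (int v : lens) std::printf("%d ", v);  // prints: 50 50 50
  std::printf("\n");
  return 0;
}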