diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index e280db56c42257..241e3ba6735dd5 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -30,7 +30,6 @@ cmake --version pip install -r requirements.txt || true if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then - # This is necessary in order to cross compile (or else we'll have missing GPU device). export MAX_JOBS=4 # This is necessary in order to cross compile (or else we'll have missing GPU device). export HCC_AMDGPU_TARGET=gfx900 diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 25a2e6d8b501f0..de876639923227 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -257,9 +257,9 @@ ENDIF() IF(USE_ROCM) ### Link in the ROCm libraries BLAS / RNG. FIND_LIBRARY(HIPBLAS_LIBRARY hipblas HINTS ${HIPBLAS_PATH}/lib) - FIND_LIBRARY(HIPRNG_LIBRARY hcrng HINTS ${HIPRNG_PATH}/lib) + FIND_LIBRARY(HIPRAND_LIBRARY hiprand HINTS ${HIPRAND_PATH}/lib) - list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${HIPBLAS_LIBRARY} ${HIPRNG_LIBRARY}) + list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${HIPBLAS_LIBRARY} ${HIPRAND_LIBRARY}) ENDIF() # Include CPU paths for CUDA as well diff --git a/aten/src/ATen/cuda/CUDAApplyUtils.cuh b/aten/src/ATen/cuda/CUDAApplyUtils.cuh index a30bed575d3d3c..0673a88b051998 100644 --- a/aten/src/ATen/cuda/CUDAApplyUtils.cuh +++ b/aten/src/ATen/cuda/CUDAApplyUtils.cuh @@ -1,6 +1,6 @@ #pragma once -#include "detail/IndexUtils.cuh" +#include "ATen/cuda/detail/IndexUtils.cuh" #include "ATen/TensorUtils.h" #include "THC/THCAtomics.cuh" #include "ATen/cuda/CUDAContext.h" diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu new file mode 100644 index 00000000000000..2d133a70dc23b6 --- /dev/null +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -0,0 +1,162 @@ +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAApplyUtils.cuh" +#include "ATen/cuda/detail/IndexUtils.cuh" +#include "ATen/cuda/detail/TensorInfo.cuh" +#include "curand_kernel.h" + +#include +#include +#include + + +THCGenerator* THCRandom_getGenerator(THCState* state); + +namespace at{ +namespace native{ + +namespace { + +// philox generates 128 bits of randomness at a time. Kernel uses this explicitly by putting suitably transformed result into float4 +// for all members of float4 to be consumed UNROLL has to be 4. Don't change! +const int UNROLL = 4; + +std::pair next_philox_seed(at::Generator* gen, uint64_t increment) { + auto gen_ = THCRandom_getGenerator(at::globalContext().getTHCState()); + uint64_t offset = gen_->state.philox_seed_offset.fetch_add(increment); + return std::make_pair(gen_->state.initial_seed, offset); +} + + +template < + typename scalar_t, + typename accscalar_t, + typename IndexType, + int ADims> +#if __CUDA_ARCH__ >= 350 +__launch_bounds__(256,8) +#endif +__global__ void +fused_dropout_kernel(cuda::detail::TensorInfo a, + cuda::detail::TensorInfo b, + cuda::detail::TensorInfo c, + IndexType totalElements, accscalar_t p, std::pair seeds + ) { + + accscalar_t pinv = accscalar_t(1)/p; + IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init( + seeds.first, + idx, + seeds.second, + &state); + IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * + blockDim.x * gridDim.x * UNROLL; + for (IndexType linearIndex = idx; + linearIndex < rounded_size; + linearIndex += gridDim.x * blockDim.x*UNROLL) { +//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything + float4 rand = curand_uniform4(&state); + scalar_t src[UNROLL]; + rand.x = rand.x < p; + rand.y = rand.y < p; + rand.z = rand.z < p; + rand.w = rand.w < p; + for (int ii = 0; ii < UNROLL; ii++) { + IndexType li = linearIndex + blockDim.x * gridDim.x * ii; + if (li < totalElements) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = + cuda::detail::IndexToOffset::get(li, a); + src[ii] = a.data[aOffset]; + } + } + for (int ii = 0; ii < UNROLL; ii++) { + IndexType li = linearIndex + blockDim.x * gridDim.x * ii; + if (li < totalElements) { + // Convert `linearIndex` into an offset of `b` + const IndexType bOffset = + cuda::detail::IndexToOffset::get(li, b); + b.data[bOffset] = src[ii]*(&rand.x)[ii]*pinv; + c.data[bOffset] = (uint8_t)(&rand.x)[ii]; + } + } + __syncthreads(); + } +} + +template +void masked_scale_kernel(at::Tensor& ret, const at::Tensor src, const at::Tensor mask, accscalar_t scale){ + at::cuda::CUDA_tensor_apply3(ret, src, mask, [scale]__device__(scalar_t& ret_val, const scalar_t& src_val, const uint8_t mask_val){ + ret_val = (float)mask_val * src_val * scale; + }); +} +} //anonymous namespace + +std::tuple +fused_dropout_cuda(const Tensor& self, double p, Generator * gen){ + Tensor ret = at::empty_like(self); + Tensor mask = self.type().toScalarType(kByte).tensor(self.sizes()); + const int64_t nelem = self.numel(); + const int64_t block_size = 256; + unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size; + dim3 dim_block(block_size); + dim3 grid((nelem + block_size -1)/block_size); + grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); +//number of times random will be generated per thread, to offset philox counter in thc random state + int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL; + if (cuda::detail::canUse32BitIndexMath(self)){ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] { + using accscalar_t = acc_type; + accscalar_t pa = (accscalar_t)(p); + auto self_info = cuda::detail::getTensorInfo(self); + auto ret_info = cuda::detail::getTensorInfo(ret); + auto mask_info = cuda::detail::getTensorInfo(mask); + self_info.collapseDims(); + ret_info.collapseDims(); + mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor + switch (self_info.dims) { + case 1: + fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,counter_offset)); + break; + default: + fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,counter_offset)); + } + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "fused_dropout", [&] { + using accscalar_t = acc_type; + accscalar_t pa = (accscalar_t)(p); + auto self_info = cuda::detail::getTensorInfo(self); + auto ret_info = cuda::detail::getTensorInfo(ret); + auto mask_info = cuda::detail::getTensorInfo(mask); + self_info.collapseDims(); + ret_info.collapseDims(); + mask_info.collapseDims(); //ret and mask are collapsed to 1d contiguous tensor + switch (self_info.dims) { + case 1: + fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,counter_offset)); + break; + default: + fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, next_philox_seed(gen,counter_offset)); + } + }); + } + THCudaCheck(cudaGetLastError()); + return std::tuple(ret, mask); +} + +Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){ + Tensor ret = at::empty_like(self); + AT_CHECK(mask.type().scalarType() == at::ScalarType::Byte, "mask should be torch.uint8 dtype"); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.type(), "masked_scale", [&] { + using accscalar_t = acc_type; + accscalar_t pa = (accscalar_t)(scale); + masked_scale_kernel(ret, self, mask, pa); + }); + return ret; +} + +} +} diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a86ab6cb41ba52..1361848c832fd5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -54,6 +54,14 @@ dispatch: CUDA: _cudnn_init_dropout_state +- func: _fused_dropout(Tensor self, double p, Generator* generator=nullptr) -> (Tensor, Tensor) + dispatch: + CUDA: fused_dropout_cuda + +- func: _masked_scale(Tensor self, Tensor mask, double scale) -> Tensor + dispatch: + CUDA: masked_scale_cuda + - func: abs(Tensor self) -> Tensor - func: abs_(Tensor self) -> Tensor @@ -1003,6 +1011,7 @@ - func: logsumexp_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor variants: function + - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, double margin=0.0, int64_t reduction=Reduction::ElementwiseMean) -> Tensor variants: function diff --git a/aten/src/THC/THCTensorRandom.h b/aten/src/THC/THCTensorRandom.h index 5203df28c78e68..1ee539e52c9cd6 100644 --- a/aten/src/THC/THCTensorRandom.h +++ b/aten/src/THC/THCTensorRandom.h @@ -5,6 +5,9 @@ #include "generic/THCTensorRandom.h" #include "THCGenerateAllTypes.h" +#ifdef __HIP_PLATFORM_HCC__ +#include +#endif typedef struct THCGenerator THCGenerator; diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 123244d220665f..7e5dd07b62c388 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -348,7 +348,7 @@ endif() if(USE_ROCM) # Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs. if(BUILD_ATEN) - # Get Compile Definitions from the directory (FindHIP.CMake bug) + # Get Compile Definitions from the directory (FindHIP.cmake bug) get_directory_property(MY_DEFINITIONS COMPILE_DEFINITIONS) if(MY_DEFINITIONS) foreach(_item ${MY_DEFINITIONS}) @@ -364,7 +364,7 @@ if(USE_ROCM) ENDIF() # FindHIP.CMake checks if the SHARED flag is set and adds extra logic accordingly. - hip_add_library(caffe2_hip SHARED ${Caffe2_HIP_SRCS}) + hip_add_library(caffe2_hip ${Caffe2_HIP_SRCS}) # Since PyTorch files contain HIP headers, these flags are required for the necessary definitions to be added. set_target_properties(caffe2_hip PROPERTIES COMPILE_FLAGS ${HIP_HIPCC_FLAGS}) diff --git a/caffe2/contrib/warpctc/ctc_op.cpp b/caffe2/contrib/warpctc/ctc_op.cpp index 0d0176623a5934..9d06e4041b69fa 100644 --- a/caffe2/contrib/warpctc/ctc_op.cpp +++ b/caffe2/contrib/warpctc/ctc_op.cpp @@ -2,6 +2,11 @@ #include "caffe2/core/context_gpu.h" #include "caffe2/core/operator.h" +#ifdef CAFFE2_USE_IDEEP +#include +#include +#endif + namespace caffe2 { namespace detail { @@ -17,9 +22,13 @@ ctcComputeInfo workspaceInfo(const CPUContext& /*context*/) { } REGISTER_CPU_OPERATOR(CTC, CTCOp); -OPERATOR_SCHEMA(CTC).NumInputs(4).NumOutputs(2, 3); +OPERATOR_SCHEMA(CTC).NumInputs(3, 4).NumOutputs(2, 3); // .EnforceInputOutputGradient({{0, 0}}); +#ifdef CAFFE2_USE_IDEEP +REGISTER_IDEEP_OPERATOR(CTC, IDEEPFallbackOp>); +#endif + namespace { class GetCTCGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; diff --git a/caffe2/contrib/warpctc/ctc_op.h b/caffe2/contrib/warpctc/ctc_op.h index 6c27c907726b82..f94fa177ae01e9 100644 --- a/caffe2/contrib/warpctc/ctc_op.h +++ b/caffe2/contrib/warpctc/ctc_op.h @@ -45,13 +45,24 @@ class CTCOp final : public Operator { bool RunOnDevice() override { // inputs const auto& inputs = Input(INPUTS); + const auto maxTimeSteps = inputs.dim(0); const auto minibatchSize = inputs.dim(1); const auto alphabetSize = inputs.dim(2); const auto& labels = OperatorBase::template Input(LABELS, CPU); const auto& labelLengths = OperatorBase::template Input(LABEL_LENGTHS, CPU); - const auto& inputLengths = - OperatorBase::template Input(INPUT_LENGTHS, CPU); + + const int* inputLengthsData = nullptr; + if (InputSize() == 4) { + const auto& inputLengths = + OperatorBase::template Input(INPUT_LENGTHS, CPU); + inputLengthsData = inputLengths.template data(); + } else { + // Input lengths not passed in. Default to max timesteps for + // each item in minibatch. + default_input_lengths_.resize(minibatchSize, maxTimeSteps); + inputLengthsData = default_input_lengths_.data(); + } // outputs Tensor* gradients = nullptr; @@ -74,28 +85,40 @@ class CTCOp final : public Operator { size_t workspaceSizeBytes; CTC_CHECK(get_workspace_size( labelLengths.template data(), - inputLengths.template data(), + inputLengthsData, alphabetSize, minibatchSize, detail::workspaceInfo(context_), &workspaceSizeBytes)); workspace->Resize(workspaceSizeBytes); + auto* workspaceData = workspace->template mutable_data(); + + if (is_test_ && labels.dim(0) == 0) { + // compute_ctc_loss doesn't handle empty labels well + T* costsData = costs->template mutable_data(); + for (int i = 0; i < costs->size(); ++i) { + costsData[i] = 0; + } + return true; + } + CTC_CHECK(compute_ctc_loss( inputs.template data(), gradients ? gradients->template mutable_data() : nullptr, labels.template data(), labelLengths.template data(), - inputLengths.template data(), + inputLengthsData, alphabetSize, minibatchSize, costs->template mutable_data(), - workspace->template mutable_data(), + workspaceData, detail::workspaceInfo(context_))); return true; } private: bool is_test_; + std::vector default_input_lengths_; INPUT_TAGS(INPUTS, LABELS, LABEL_LENGTHS, INPUT_LENGTHS); }; diff --git a/caffe2/contrib/warpctc/ctc_ops_test.py b/caffe2/contrib/warpctc/ctc_ops_test.py index 602ab285f3bcb4..25bb0a39e3a965 100644 --- a/caffe2/contrib/warpctc/ctc_ops_test.py +++ b/caffe2/contrib/warpctc/ctc_ops_test.py @@ -19,7 +19,7 @@ def softmax(w): class CTCOpsTest(test_util.TestCase): - def verify_cost(self, device_option, is_test): + def verify_cost(self, device_option, is_test, skip_input_lengths=False): alphabet_size = 5 N = 1 T = 2 @@ -36,9 +36,12 @@ def verify_cost(self, device_option, is_test): input_lengths = np.asarray([T]).astype(np.int32) net = core.Net("test-net") + input_blobs = ["inputs", "labels", "label_lengths"] + if not skip_input_lengths: + input_blobs.append("input_lengths") output_blobs = ["costs", "workspace"] if is_test \ else ["inputs_grad_to_be_copied", "costs", "workspace"] - net.CTC(["inputs", "labels", "label_lengths", "input_lengths"], + net.CTC(input_blobs, output_blobs, is_test=is_test, device_option=device_option) @@ -47,7 +50,8 @@ def verify_cost(self, device_option, is_test): self.ws.create_blob("inputs").feed(inputs, device_option=device_option) self.ws.create_blob("labels").feed(labels) self.ws.create_blob("label_lengths").feed(label_lengths) - self.ws.create_blob("input_lengths").feed(input_lengths) + if not skip_input_lengths: + self.ws.create_blob("input_lengths").feed(input_lengths) self.ws.run(net) probs = softmax(inputs) expected = probs[0, 0, 1] * probs[1, 0, 2] @@ -68,20 +72,37 @@ def test_ctc_cost_cpu(self): self.verify_cost( caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU), is_test=False) + self.verify_cost( + caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU), + is_test=False, skip_input_lengths=True) def test_ctc_cost_gpu(self): self.verify_cost( caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA, cuda_gpu_id=0), is_test=False) + self.verify_cost( + caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA, + cuda_gpu_id=0), + is_test=False, + skip_input_lengths=True) def test_ctc_forward_only_cpu(self): self.verify_cost( caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU), is_test=True) + self.verify_cost( + caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU), + is_test=True, + skip_input_lengths=True) def test_ctc_forward_only_gpu(self): self.verify_cost( caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA, cuda_gpu_id=0), is_test=True) + self.verify_cost( + caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA, + cuda_gpu_id=0), + is_test=True, + skip_input_lengths=True) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 1cbcf6486e9b43..9991082800a250 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -540,7 +540,7 @@ if(BUILD_CAFFE2 OR BUILD_ATEN) ${rocrand_LIBRARIES} ${hiprand_LIBRARIES} ${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipblas_LIBRARIES}) # Additional libraries required by PyTorch AMD that aren't used by Caffe2 (not in Caffe2's docker image) if(BUILD_ATEN) - set(Caffe2_HIP_DEPENDENCY_LIBS ${Caffe2_HIP_DEPENDENCY_LIBS} ${hipsparse_LIBRARIES} ${hiprng_LIBRARIES}) + set(Caffe2_HIP_DEPENDENCY_LIBS ${Caffe2_HIP_DEPENDENCY_LIBS} ${hipsparse_LIBRARIES}) endif() # TODO: There is a bug in rocblas's cmake files that exports the wrong targets name in ${rocblas_LIBRARIES} list(APPEND Caffe2_HIP_DEPENDENCY_LIBS @@ -555,7 +555,8 @@ if(USE_ROCM AND NOT BUILD_CAFFE2) include_directories(SYSTEM ${HIP_PATH}/include) include_directories(SYSTEM ${HIPBLAS_PATH}/include) include_directories(SYSTEM ${HIPSPARSE_PATH}/include) - include_directories(SYSTEM ${HIPRNG_PATH}/include) + include_directories(SYSTEM ${HIPRAND_PATH}/include) + include_directories(SYSTEM ${ROCRAND_PATH}/include) include_directories(SYSTEM ${THRUST_PATH}) # load HIP cmake module and load platform id diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 481f812852afe7..0ac1a14ed332f7 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -45,13 +45,6 @@ ELSE() SET(ROCBLAS_PATH $ENV{ROCBLAS_PATH}) ENDIF() -# HIPRNG_PATH -IF(NOT DEFINED ENV{HIPRNG_PATH}) - SET(HIPRNG_PATH ${ROCM_PATH}/hcrng) -ELSE() - SET(HIPRNG_PATH $ENV{HIPRNG_PATH}) -ENDIF() - # HIPSPARSE_PATH IF(NOT DEFINED ENV{HIPSPARSE_PATH}) SET(HIPSPARSE_PATH ${ROCM_PATH}/hcsparse) @@ -138,7 +131,6 @@ IF(HIP_FOUND) # however currently it's just the lib name FIND_LIBRARY(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib) FIND_LIBRARY(hiprand_LIBRARIES hiprand HINTS ${HIPRAND_PATH}/lib) - FIND_LIBRARY(hiprng_LIBRARIES hcrng HINTS ${HIPRNG_PATH}/lib) FIND_LIBRARY(hipblas_LIBRARIES hipblas HINTS ${HIPBLAS_PATH}/lib) FIND_LIBRARY(hipsparse_LIBRARIES hipsparse HINTS ${HIPSPARSE_PATH}/lib) diff --git a/docker/caffe2/jenkins/common/install_rocm.sh b/docker/caffe2/jenkins/common/install_rocm.sh index f76cf90f92657f..a0bb222cc8ec2d 100644 --- a/docker/caffe2/jenkins/common/install_rocm.sh +++ b/docker/caffe2/jenkins/common/install_rocm.sh @@ -24,12 +24,6 @@ install_ubuntu() { rocm-profiler \ cxlactivitylogger - pushd $HOME - # install hcrng - curl https://s3.amazonaws.com/ossci-linux/hcrng-master-a8c6a0b-Linux.deb -o hcrng.deb - dpkg -i hcrng.deb - rm hcrng.deb - # hotfix a bug in hip's cmake files, this has been fixed in # https://github.com/ROCm-Developer-Tools/HIP/pull/516 but for # some reason it has not included in the latest rocm release @@ -56,13 +50,6 @@ install_hip_thrust() { git clone --recursive https://github.com/ROCmSoftwarePlatform/cub-hip.git /data/Thrust/thrust/system/cuda/detail/cub-hip } -# This will be removed after merging an upcoming PR. -install_hcrng() { - mkdir -p /opt/rocm/debians - curl https://s3.amazonaws.com/ossci-linux/hcrng-master-a8c6a0b-Linux.deb -o /opt/rocm/debians/hcrng.deb - dpkg -i /opt/rocm/debians/hcrng.deb -} - # This will be removed after merging an upcoming PR. install_hcsparse() { mkdir -p /opt/rocm/debians @@ -88,6 +75,5 @@ else fi install_hip_thrust -install_hcrng install_rocrand install_hcsparse diff --git a/setup.py b/setup.py index 47815d19ae5c7c..6a14f09e654c2c 100644 --- a/setup.py +++ b/setup.py @@ -884,12 +884,16 @@ def run(self): hcc_include_path = '/opt/rocm/hcc/include' hipblas_include_path = '/opt/rocm/hipblas/include' hipsparse_include_path = '/opt/rocm/hcsparse/include' + hiprand_include_path = '/opt/rocm/hiprand/include' + rocrand_include_path = '/opt/rocm/rocrand/include' hip_lib_path = '/opt/rocm/hip/lib' hcc_lib_path = '/opt/rocm/hcc/lib' include_dirs.append(rocm_include_path) include_dirs.append(hcc_include_path) include_dirs.append(hipblas_include_path) include_dirs.append(hipsparse_include_path) + include_dirs.append(hiprand_include_path) + include_dirs.append(rocrand_include_path) include_dirs.append(tmp_install_path + "/include/THCUNN") extra_link_args.append('-L' + hip_lib_path) extra_link_args.append('-Wl,-rpath,' + hip_lib_path) diff --git a/test/common.py b/test/common.py index 5c7999b6f986f4..f2df87969efb67 100644 --- a/test/common.py +++ b/test/common.py @@ -97,6 +97,7 @@ def _check_module_exists(name): if TEST_NUMPY: import numpy + def skipIfRocm(fn): @wraps(fn) def wrapper(*args, **kwargs): @@ -115,7 +116,6 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper - def skipIfNoLapack(fn): @wraps(fn) def wrapper(*args, **kwargs): diff --git a/test/test_nn.py b/test/test_nn.py index c8d54c58b8e5e5..5bd646661760a5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -732,9 +732,10 @@ def test_invalid_conv3d(self): input = torch.empty(1, 1, 4, 4, 4) self.assertRaises(RuntimeError, lambda: module(input)) - def _test_dropout(self, cls, input): + def _test_dropout(self, cls, cuda, input): p = 0.2 - input.fill_(1 - p) + device = torch.device("cuda") if cuda else torch.device("cpu") + input = input.to(device).fill_(1 - p) module = cls(p) input_var = torch.tensor(input, requires_grad=True) @@ -2098,7 +2099,7 @@ def func(x): def test_Dropout(self): input = torch.Tensor(1000) - self._test_dropout(nn.Dropout, input) + self._test_dropout(nn.Dropout, False, input) def test_Dropout2d(self): b = random.randint(1, 5) @@ -2106,7 +2107,7 @@ def test_Dropout2d(self): h = random.randint(1, 5) num_features = 1000 input = torch.Tensor(num_features, b, w, h) - self._test_dropout(nn.Dropout2d, input) + self._test_dropout(nn.Dropout2d, False, input) def test_Dropout3d(self): b = random.randint(1, 5) @@ -2115,7 +2116,31 @@ def test_Dropout3d(self): d = random.randint(1, 2) num_features = 1000 input = torch.Tensor(num_features, b, d, w, h) - self._test_dropout(nn.Dropout3d, input) + self._test_dropout(nn.Dropout3d, False, input) + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_Dropout_cuda(self): + input = torch.Tensor(1000) + self._test_dropout(nn.Dropout, True, input) + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_Dropout2d_cuda(self): + b = random.randint(1, 5) + w = random.randint(1, 5) + h = random.randint(1, 5) + num_features = 1000 + input = torch.Tensor(num_features, b, w, h) + self._test_dropout(nn.Dropout2d, True, input) + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_Dropout3d_cuda(self): + b = random.randint(1, 5) + w = random.randint(1, 5) + h = random.randint(1, 5) + d = random.randint(1, 2) + num_features = 1000 + input = torch.Tensor(num_features, b, d, w, h) + self._test_dropout(nn.Dropout3d, True, input) def test_AlphaDropout(self): # generate random tensor with zero mean and unit std diff --git a/tools/amd_build/disabled_features.yaml b/tools/amd_build/disabled_features.yaml index 1554c839ca7e2b..703ece84b2e780 100644 --- a/tools/amd_build/disabled_features.yaml +++ b/tools/amd_build/disabled_features.yaml @@ -78,6 +78,31 @@ "struct curandStateMtgp32*": "curandStateMtgp32*" } }, + { + "path": "aten/src/THC/THCTensorRandom.cu", + "s_constants": { + "struct curandStateMtgp32*": "curandStateMtgp32*" + } + }, + { + "path": "aten/src/THC/THCTensorRandom.h", + "s_constants": { + "struct curandStateMtgp32*": "curandStateMtgp32*" + } + }, + { + "path": "aten/src/THCUNN/generic/RReLU.cu", + "s_constants": { + "struct curandStateMtgp32*": "curandStateMtgp32*" + } + }, + { + "path": "aten/src/THC/THCGenerator.hpp", + "s_constants": { + "struct curandStateMtgp32*": "curandStateMtgp32*", + "struct mtgp32_kernel_params": "mtgp32_kernel_params" + } + }, { "path": "aten/src/ATen/native/cuda/CuFFTUtils.h", "s_constants": { @@ -126,7 +151,8 @@ "aten/src/ATen/native/cuda/CuFFTUtils.h", "aten/src/ATen/native/cuda/CuFFTPlanCache.h", "aten/src/ATen/native/cuda/SpectralOps.cu", - "aten/src/ATen/native/cuda/Distributions.cu" + "aten/src/ATen/native/cuda/Distributions.cu", + "aten/src/ATen/native/cuda/Dropout.cu" ], "disabled_functions": [ { @@ -187,6 +213,13 @@ "functions": [ "THCTensor_(getTextureObject)" ] + }, + { + "path": "aten/src/THC/THCTensorRandom.cu", + "functions": [ + "THCRandom_setRNGState", + "set_rngstate_kernel" + ] } ] } diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py index 8863337c016e9b..0348a95a8cf44e 100644 --- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py +++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py @@ -200,45 +200,45 @@ "cublasPointerMode_t": ("hipblasPointerMode_t", CONV_TYPE, API_BLAS), "cublasAtomicsMode_t": ("hipblasAtomicsMode_t", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), "cublasDataType_t": ("hipblasDataType_t", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), - "curandStatus": ("hiprngStatus_t", CONV_TYPE, API_RAND), - "curandStatus_t": ("hiprngStatus_t", CONV_TYPE, API_RAND), - "curandRngType": ("hiprngRngType_t", CONV_TYPE, API_RAND), - "curandRngType_t": ("hiprngRngType_t", CONV_TYPE, API_RAND), - "curandGenerator_st": ("hiprngGenerator_st", CONV_TYPE, API_RAND), - "curandGenerator_t": ("hiprngGenerator_t", CONV_TYPE, API_RAND), - "curandDirectionVectorSet": ("hiprngDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandDirectionVectorSet_t": ("hiprngDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandOrdering": ("hiprngOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandOrdering_t": ("hiprngOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandDistribution_st": ("hiprngDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandHistogramM2V_st": ("hiprngDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandDistribution_t": ("hiprngDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandHistogramM2V_t": ("hiprngDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandDistributionShift_st": ("hiprngDistributionShift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandDistributionShift_t": ("hiprngDistributionShift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandDistributionM2Shift_st": ("hiprngDistributionM2Shift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandDistributionM2Shift_t": ("hiprngDistributionM2Shift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandHistogramM2_st": ("hiprngHistogramM2_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandHistogramM2_t": ("hiprngHistogramM2_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandHistogramM2K_st": ("hiprngHistogramM2K_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandHistogramM2K_t": ("hiprngHistogramM2K_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandDiscreteDistribution_st": ("hiprngDiscreteDistribution_st", CONV_TYPE, API_RAND), - "curandDiscreteDistribution_t": ("hiprngDiscreteDistribution_t", CONV_TYPE, API_RAND), - "curandMethod": ("hiprngMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandMethod_t": ("hiprngMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandDirectionVectors32_t": ("hiprngDirectionVectors32_t", CONV_TYPE, API_RAND), - "curandDirectionVectors64_t": ("hiprngDirectionVectors64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandStateMtgp32_t": ("hiprngStateMtgp32_t", CONV_TYPE, API_RAND), - "curandStateMtgp32": ("hcrngStateMtgp32", CONV_TYPE, API_RAND), - "curandStateScrambledSobol64_t": ("hiprngStateScrambledSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandStateSobol64_t": ("hiprngStateSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandStateScrambledSobol32_t": ("hiprngStateScrambledSobol32_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), - "curandStateSobol32_t": ("hiprngStateSobol32_t", CONV_TYPE, API_RAND), - "curandStateMRG32k3a_t": ("hiprngStateMRG32k3a_t", CONV_TYPE, API_RAND), - "curandStatePhilox4_32_10_t": ("hiprngStatePhilox4_32_10_t", CONV_TYPE, API_RAND), - "curandStateXORWOW_t": ("hiprngStateXORWOW_t", CONV_TYPE, API_RAND), - "curandState_t": ("hiprngState_t", CONV_TYPE, API_RAND), - "curandState": ("hiprngState_t", CONV_TYPE, API_RAND) + "curandStatus": ("hiprandStatus_t", CONV_TYPE, API_RAND), + "curandStatus_t": ("hiprandStatus_t", CONV_TYPE, API_RAND), + "curandRngType": ("hiprandRngType_t", CONV_TYPE, API_RAND), + "curandRngType_t": ("hiprandRngType_t", CONV_TYPE, API_RAND), + "curandGenerator_st": ("hiprandGenerator_st", CONV_TYPE, API_RAND), + "curandGenerator_t": ("hiprandGenerator_t", CONV_TYPE, API_RAND), + "curandDirectionVectorSet": ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandDirectionVectorSet_t": ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandOrdering": ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandOrdering_t": ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandDistribution_st": ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandHistogramM2V_st": ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandDistribution_t": ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandHistogramM2V_t": ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandDistributionShift_st": ("hiprandDistributionShift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandDistributionShift_t": ("hiprandDistributionShift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandDistributionM2Shift_st": ("hiprandDistributionM2Shift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandDistributionM2Shift_t": ("hiprandDistributionM2Shift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandHistogramM2_st": ("hiprandHistogramM2_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandHistogramM2_t": ("hiprandHistogramM2_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandHistogramM2K_st": ("hiprandHistogramM2K_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandHistogramM2K_t": ("hiprandHistogramM2K_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandDiscreteDistribution_st": ("hiprandDiscreteDistribution_st", CONV_TYPE, API_RAND), + "curandDiscreteDistribution_t": ("hiprandDiscreteDistribution_t", CONV_TYPE, API_RAND), + "curandMethod": ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandMethod_t": ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandDirectionVectors32_t": ("hiprandDirectionVectors32_t", CONV_TYPE, API_RAND), + "curandDirectionVectors64_t": ("hiprandDirectionVectors64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandStateMtgp32_t": ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND), + "curandStateMtgp32": ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND), + "curandStateScrambledSobol64_t": ("hiprandStateScrambledSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandStateSobol64_t": ("hiprandStateSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandStateScrambledSobol32_t": ("hiprandStateScrambledSobol32_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + "curandStateSobol32_t": ("hiprandStateSobol32_t", CONV_TYPE, API_RAND), + "curandStateMRG32k3a_t": ("hiprandStateMRG32k3a_t", CONV_TYPE, API_RAND), + "curandStatePhilox4_32_10_t": ("hiprandStatePhilox4_32_10_t", CONV_TYPE, API_RAND), + "curandStateXORWOW_t": ("hiprandStateXORWOW_t", CONV_TYPE, API_RAND), + "curandState_t": ("hiprandState_t", CONV_TYPE, API_RAND), + "curandState": ("hiprandState_t", CONV_TYPE, API_RAND) } CUDA_INCLUDE_MAP = { @@ -254,23 +254,23 @@ "vector_types.h": ("hip/hip_vector_types.h", CONV_INCLUDE, API_RUNTIME), "cublas.h": ("hipblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS), "cublas_v2.h": ("hipblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS), - "curand.h": ("hiprng.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND), - "curand_kernel.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_discrete.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_discrete2.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_globals.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_lognormal.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_mrg32k3a.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_mtgp32.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_mtgp32_host.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_mtgp32_kernel.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_mtgp32dc_p_11213.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_normal.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_normal_static.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_philox4x32_x.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_poisson.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_precalc.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), - "curand_uniform.h": ("hiprng_kernel.h", CONV_INCLUDE, API_RAND), + "curand.h": ("hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND), + "curand_kernel.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_discrete.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_discrete2.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_globals.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_lognormal.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_mrg32k3a.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_mtgp32.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_mtgp32_host.h": ("hiprand_mtgp32_host.h", CONV_INCLUDE, API_RAND), + "curand_mtgp32_kernel.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_mtgp32dc_p_11213.h": ("rocrand_mtgp32_11213.h", CONV_INCLUDE, API_RAND), + "curand_normal.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_normal_static.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_philox4x32_x.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_poisson.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_precalc.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), + "curand_uniform.h": ("hiprand_kernel.h", CONV_INCLUDE, API_RAND), "cusparse.h": ("hipsparse.h", CONV_INCLUDE, API_RAND), "#include ": ("", CONV_INCLUDE, API_RAND, HIP_UNSUPPORTED), "#include ": ("", CONV_INCLUDE, API_RAND, HIP_UNSUPPORTED), @@ -1992,109 +1992,110 @@ "cublasDrotm_v2": ("hipblasDrotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), "cublasSrotmg_v2": ("hipblasSrotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), "cublasDrotmg_v2": ("hipblasDrotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), - "CURAND_STATUS_SUCCESS": ("HIPRNG_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_VERSION_MISMATCH": ("hiprng_STATUS_VERSION_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_NOT_INITIALIZED": ("hiprng_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_ALLOCATION_FAILED": ("hiprng_STATUS_ALLOCATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_TYPE_ERROR": ("hiprng_STATUS_TYPE_ERROR", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_OUT_OF_RANGE": ("hiprng_STATUS_OUT_OF_RANGE", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_LENGTH_NOT_MULTIPLE": ("hiprng_STATUS_LENGTH_NOT_MULTIPLE", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_DOUBLE_PRECISION_REQUIRED": ("hiprng_STATUS_DOUBLE_PRECISION_REQUIRED", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_LAUNCH_FAILURE": ("hiprng_STATUS_LAUNCH_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_PREEXISTING_FAILURE": ("hiprng_STATUS_PREEXISTING_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_INITIALIZATION_FAILED": ("hiprng_STATUS_INITIALIZATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_ARCH_MISMATCH": ("hiprng_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), - "curand_STATUS_INTERNAL_ERROR": ("hiprng_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_TEST": ("hiprng_RNG_TEST", CONV_NUMERIC_LITERAL, API_RAND), - "mtgp32dc_params_fast_11213": ("mtgp32_params_fast_11213", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_PSEUDO_DEFAULT": ("hiprng_RNG_PSEUDO_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_PSEUDO_XORWOW": ("hiprng_RNG_PSEUDO_XORWOW", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_PSEUDO_MRG32K3A": ("hiprng_RNG_PSEUDO_MRG32K3A", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_PSEUDO_MTGP32": ("hiprng_RNG_PSEUDO_MTGP32", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_PSEUDO_MT19937": ("hiprng_RNG_PSEUDO_MT19937", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_PSEUDO_PHILOX4_32_10": ("hiprng_RNG_PSEUDO_PHILOX4_32_10", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_QUASI_DEFAULT": ("hiprng_RNG_QUASI_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_QUASI_SOBOL32": ("hiprng_RNG_QUASI_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_QUASI_SCRAMBLED_SOBOL32": ("hiprng_RNG_QUASI_SCRAMBLED_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_QUASI_SOBOL64": ("hiprng_RNG_QUASI_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), - "curand_RNG_QUASI_SCRAMBLED_SOBOL64": ("hiprng_RNG_QUASI_SCRAMBLED_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), - "curand_ORDERING_PSEUDO_BEST": ("hiprng_ORDERING_PSEUDO_BEST", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_ORDERING_PSEUDO_DEFAULT": ("hiprng_ORDERING_PSEUDO_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_ORDERING_PSEUDO_SEEDED": ("hiprng_ORDERING_PSEUDO_SEEDED", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_ORDERING_QUASI_DEFAULT": ("hiprng_ORDERING_QUASI_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_DIRECTION_VECTORS_32_JOEKUO6": ("hiprng_DIRECTION_VECTORS_32_JOEKUO6", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6": ("hiprng_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_DIRECTION_VECTORS_64_JOEKUO6": ("hiprng_DIRECTION_VECTORS_64_JOEKUO6", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6": ("hiprng_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_CHOOSE_BEST": ("hiprng_CHOOSE_BEST", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_ITR": ("hiprng_ITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_KNUTH": ("hiprng_KNUTH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_HITR": ("hiprng_HITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_M1": ("hiprng_M1", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_M2": ("hiprng_M2", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_BINARY_SEARCH": ("hiprng_BINARY_SEARCH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_DISCRETE_GAUSS": ("hiprng_DISCRETE_GAUSS", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_REJECTION": ("hiprng_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_DEVICE_API": ("hiprng_DEVICE_API", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_FAST_REJECTION": ("hiprng_FAST_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_3RD": ("hiprng_3RD", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_DEFINITION": ("hiprng_DEFINITION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curand_POISSON": ("hiprng_POISSON", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), - "curandCreateGenerator": ("hiprngCreateGenerator", CONV_MATH_FUNC, API_RAND), - "curandCreateGeneratorHost": ("hiprngCreateGeneratorHost", CONV_MATH_FUNC, API_RAND), - "curandCreatePoissonDistribution": ("hiprngCreatePoissonDistribution", CONV_MATH_FUNC, API_RAND), - "curandDestroyDistribution": ("hiprngDestroyDistribution", CONV_MATH_FUNC, API_RAND), - "curandDestroyGenerator": ("hiprngDestroyGenerator", CONV_MATH_FUNC, API_RAND), - "curandGenerate": ("hiprngGenerate", CONV_MATH_FUNC, API_RAND), - "curandGenerateLogNormal": ("hiprngGenerateLogNormal", CONV_MATH_FUNC, API_RAND), - "curandGenerateLogNormalDouble": ("hiprngGenerateLogNormalDouble", CONV_MATH_FUNC, API_RAND), - "curandGenerateLongLong": ("hiprngGenerateLongLong", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), - "curandGenerateNormal": ("hiprngGenerateNormal", CONV_MATH_FUNC, API_RAND), - "curandGenerateNormalDouble": ("hiprngGenerateNormalDouble", CONV_MATH_FUNC, API_RAND), - "curandGeneratePoisson": ("hiprngGeneratePoisson", CONV_MATH_FUNC, API_RAND), - "curandGenerateSeeds": ("hiprngGenerateSeeds", CONV_MATH_FUNC, API_RAND), - "curandGenerateUniform": ("hiprngGenerateUniform", CONV_MATH_FUNC, API_RAND), - "curandGenerateUniformDouble": ("hiprngGenerateUniformDouble", CONV_MATH_FUNC, API_RAND), - "curandGetDirectionVectors32": ("hiprngGetDirectionVectors32", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), - "curandGetDirectionVectors64": ("hiprngGetDirectionVectors64", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), - "curandGetProperty": ("hiprngGetProperty", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), - "curandGetScrambleConstants32": ("hiprngGetScrambleConstants32", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), - "curandGetScrambleConstants64": ("hiprngGetScrambleConstants64", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), - "curandGetVersion": ("hiprngGetVersion", CONV_MATH_FUNC, API_RAND), - "curandSetGeneratorOffset": ("hiprngSetGeneratorOffset", CONV_MATH_FUNC, API_RAND), - "curandSetGeneratorOrdering": ("hiprngSetGeneratorOrdering", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), - "curandSetPseudoRandomGeneratorSeed": ("hiprngSetPseudoRandomGeneratorSeed", CONV_MATH_FUNC, API_RAND), - "curandSetQuasiRandomGeneratorDimensions": ("hiprngSetQuasiRandomGeneratorDimensions", CONV_MATH_FUNC, API_RAND), - "curandSetStream": ("hiprngSetStream", CONV_MATH_FUNC, API_RAND), - "curand": ("hiprng", CONV_DEVICE_FUNC, API_RAND), - "curand_init": ("hiprng_init", CONV_DEVICE_FUNC, API_RAND), - "curand_log_normal": ("hiprng_log_normal", CONV_DEVICE_FUNC, API_RAND), - "curand_log_normal_double": ("hiprng_log_normal_double", CONV_DEVICE_FUNC, API_RAND), - "curand_log_normal2": ("hiprng_log_normal2", CONV_DEVICE_FUNC, API_RAND), - "curand_log_normal2_double": ("hiprng_log_normal2_double", CONV_DEVICE_FUNC, API_RAND), - "curand_log_normal4": ("hiprng_log_normal4", CONV_DEVICE_FUNC, API_RAND), - "curand_log_normal4_double": ("hiprng_log_normal4_double", CONV_DEVICE_FUNC, API_RAND), - "curand_mtgp32_single": ("hiprng_mtgp32_single", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), - "curand_mtgp32_single_specific": ("hiprng_mtgp32_single_specific", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), - "curand_mtgp32_specific": ("hiprng_mtgp32_specific", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), - "curand_normal": ("hiprng_normal", CONV_DEVICE_FUNC, API_RAND), - "curandMakeMTGP32Constants": ("hiprngMakeMTGP32Constants", CONV_DEVICE_FUNC, API_RAND), - "curandMakeMTGP32KernelState": ("hiprngMakeMTGP32KernelState", CONV_DEVICE_FUNC, API_RAND), - "curand_normal_double": ("hiprng_normal_double", CONV_DEVICE_FUNC, API_RAND), - "curand_normal2": ("hiprng_normal2", CONV_DEVICE_FUNC, API_RAND), - "curand_normal2_double": ("hiprng_normal2_double", CONV_DEVICE_FUNC, API_RAND), - "curand_normal4": ("hiprng_normal4", CONV_DEVICE_FUNC, API_RAND), - "curand_normal4_double": ("hiprng_normal4_double", CONV_DEVICE_FUNC, API_RAND), - "curand_uniform": ("hiprng_uniform", CONV_DEVICE_FUNC, API_RAND), - "curand_uniform_double": ("hiprng_uniform_double", CONV_DEVICE_FUNC, API_RAND), - "curand_uniform2_double": ("hiprng_uniform2_double", CONV_DEVICE_FUNC, API_RAND), - "curand_uniform4": ("hiprng_uniform4", CONV_DEVICE_FUNC, API_RAND), - "curand_uniform4_double": ("hiprng_uniform4_double", CONV_DEVICE_FUNC, API_RAND), - "curand_discrete": ("hiprng_discrete", CONV_DEVICE_FUNC, API_RAND), - "curand_discrete4": ("hiprng_discrete4", CONV_DEVICE_FUNC, API_RAND), - "curand_poisson": ("hiprng_poisson", CONV_DEVICE_FUNC, API_RAND), - "curand_poisson4": ("hiprng_poisson4", CONV_DEVICE_FUNC, API_RAND), - "curand_Philox4x32_10": ("hiprng_Philox4x32_10", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED) + "CURAND_STATUS_SUCCESS": ("HIPRAND_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_VERSION_MISMATCH": ("HIPRAND_STATUS_VERSION_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_NOT_INITIALIZED": ("HIPRAND_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_ALLOCATION_FAILED": ("HIPRAND_STATUS_ALLOCATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_TYPE_ERROR": ("HIPRAND_STATUS_TYPE_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_OUT_OF_RANGE": ("HIPRAND_STATUS_OUT_OF_RANGE", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_LENGTH_NOT_MULTIPLE": ("HIPRAND_STATUS_LENGTH_NOT_MULTIPLE", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_DOUBLE_PRECISION_REQUIRED": ("HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_LAUNCH_FAILURE": ("HIPRAND_STATUS_LAUNCH_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_PREEXISTING_FAILURE": ("HIPRAND_STATUS_PREEXISTING_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_INITIALIZATION_FAILED": ("HIPRAND_STATUS_INITIALIZATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_ARCH_MISMATCH": ("HIPRAND_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + "curand_STATUS_INTERNAL_ERROR": ("HIPRAND_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_TEST": ("HIPRAND_RNG_TEST", CONV_NUMERIC_LITERAL, API_RAND), + "mtgp32dc_params_fast_11213": ("mtgp32dc_params_fast_11213", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_PSEUDO_DEFAULT": ("HIPRAND_RNG_PSEUDO_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_PSEUDO_XORWOW": ("HIPRAND_RNG_PSEUDO_XORWOW", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_PSEUDO_MRG32K3A": ("HIPRAND_RNG_PSEUDO_MRG32K3A", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_PSEUDO_MTGP32": ("HIPRAND_RNG_PSEUDO_MTGP32", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_PSEUDO_MT19937": ("HIPRAND_RNG_PSEUDO_MT19937", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_PSEUDO_PHILOX4_32_10": ("HIPRAND_RNG_PSEUDO_PHILOX4_32_10", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_QUASI_DEFAULT": ("HIPRAND_RNG_QUASI_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_QUASI_SOBOL32": ("HIPRAND_RNG_QUASI_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_QUASI_SCRAMBLED_SOBOL32": ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_QUASI_SOBOL64": ("HIPRAND_RNG_QUASI_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + "curand_RNG_QUASI_SCRAMBLED_SOBOL64": ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + "curand_ORDERING_PSEUDO_BEST": ("HIPRAND_ORDERING_PSEUDO_BEST", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_ORDERING_PSEUDO_DEFAULT": ("HIPRAND_ORDERING_PSEUDO_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_ORDERING_PSEUDO_SEEDED": ("HIPRAND_ORDERING_PSEUDO_SEEDED", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_ORDERING_QUASI_DEFAULT": ("HIPRAND_ORDERING_QUASI_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_DIRECTION_VECTORS_32_JOEKUO6": ("HIPRAND_DIRECTION_VECTORS_32_JOEKUO6", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6": ("HIPRAND_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_DIRECTION_VECTORS_64_JOEKUO6": ("HIPRAND_DIRECTION_VECTORS_64_JOEKUO6", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6": ("HIPRAND_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_CHOOSE_BEST": ("HIPRAND_CHOOSE_BEST", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_ITR": ("HIPRAND_ITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_KNUTH": ("HIPRAND_KNUTH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_HITR": ("HIPRAND_HITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_M1": ("HIPRAND_M1", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_M2": ("HIPRAND_M2", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_BINARY_SEARCH": ("HIPRAND_BINARY_SEARCH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_DISCRETE_GAUSS": ("HIPRAND_DISCRETE_GAUSS", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_REJECTION": ("HIPRAND_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_DEVICE_API": ("HIPRAND_DEVICE_API", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_FAST_REJECTION": ("HIPRAND_FAST_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_3RD": ("HIPRAND_3RD", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_DEFINITION": ("HIPRAND_DEFINITION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curand_POISSON": ("HIPRAND_POISSON", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + "curandCreateGenerator": ("hiprandCreateGenerator", CONV_MATH_FUNC, API_RAND), + "curandCreateGeneratorHost": ("hiprandCreateGeneratorHost", CONV_MATH_FUNC, API_RAND), + "curandCreatePoissonDistribution": ("hiprandCreatePoissonDistribution", CONV_MATH_FUNC, API_RAND), + "curandDestroyDistribution": ("hiprandDestroyDistribution", CONV_MATH_FUNC, API_RAND), + "curandDestroyGenerator": ("hiprandDestroyGenerator", CONV_MATH_FUNC, API_RAND), + "curandGenerate": ("hiprandGenerate", CONV_MATH_FUNC, API_RAND), + "curandGenerateLogNormal": ("hiprandGenerateLogNormal", CONV_MATH_FUNC, API_RAND), + "curandGenerateLogNormalDouble": ("hiprandGenerateLogNormalDouble", CONV_MATH_FUNC, API_RAND), + "curandGenerateLongLong": ("hiprandGenerateLongLong", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + "curandGenerateNormal": ("hiprandGenerateNormal", CONV_MATH_FUNC, API_RAND), + "curandGenerateNormalDouble": ("hiprandGenerateNormalDouble", CONV_MATH_FUNC, API_RAND), + "curandGeneratePoisson": ("hiprandGeneratePoisson", CONV_MATH_FUNC, API_RAND), + "curandGenerateSeeds": ("hiprandGenerateSeeds", CONV_MATH_FUNC, API_RAND), + "curandGenerateUniform": ("hiprandGenerateUniform", CONV_MATH_FUNC, API_RAND), + "curandGenerateUniformDouble": ("hiprandGenerateUniformDouble", CONV_MATH_FUNC, API_RAND), + "curandGetDirectionVectors32": ("hiprandGetDirectionVectors32", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + "curandGetDirectionVectors64": ("hiprandGetDirectionVectors64", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + "curandGetProperty": ("hiprandGetProperty", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + "curandGetScrambleConstants32": ("hiprandGetScrambleConstants32", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + "curandGetScrambleConstants64": ("hiprandGetScrambleConstants64", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + "curandGetVersion": ("hiprandGetVersion", CONV_MATH_FUNC, API_RAND), + "curandSetGeneratorOffset": ("hiprandSetGeneratorOffset", CONV_MATH_FUNC, API_RAND), + "curandSetGeneratorOrdering": ("hiprandSetGeneratorOrdering", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + "curandSetPseudoRandomGeneratorSeed": ("hiprandSetPseudoRandomGeneratorSeed", CONV_MATH_FUNC, API_RAND), + "curandSetQuasiRandomGeneratorDimensions": ("hiprandSetQuasiRandomGeneratorDimensions", CONV_MATH_FUNC, API_RAND), + "curandSetStream": ("hiprandSetStream", CONV_MATH_FUNC, API_RAND), + "curand": ("hiprand", CONV_DEVICE_FUNC, API_RAND), + "curand_init": ("hiprand_init", CONV_DEVICE_FUNC, API_RAND), + "curand_log_normal": ("hiprand_log_normal", CONV_DEVICE_FUNC, API_RAND), + "curand_log_normal_double": ("hiprand_log_normal_double", CONV_DEVICE_FUNC, API_RAND), + "curand_log_normal2": ("hiprand_log_normal2", CONV_DEVICE_FUNC, API_RAND), + "curand_log_normal2_double": ("hiprand_log_normal2_double", CONV_DEVICE_FUNC, API_RAND), + "curand_log_normal4": ("hiprand_log_normal4", CONV_DEVICE_FUNC, API_RAND), + "curand_log_normal4_double": ("hiprand_log_normal4_double", CONV_DEVICE_FUNC, API_RAND), + "curand_mtgp32_single": ("hiprand_mtgp32_single", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + "curand_mtgp32_single_specific": ("hiprand_mtgp32_single_specific", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + "curand_mtgp32_specific": ("hiprand_mtgp32_specific", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + "curand_normal": ("hiprand_normal", CONV_DEVICE_FUNC, API_RAND), + "curandMakeMTGP32Constants": ("hiprandMakeMTGP32Constants", CONV_DEVICE_FUNC, API_RAND), + "curandMakeMTGP32KernelState": ("hiprandMakeMTGP32KernelState", CONV_DEVICE_FUNC, API_RAND), + "curand_normal_double": ("hiprand_normal_double", CONV_DEVICE_FUNC, API_RAND), + "curand_normal2": ("hiprand_normal2", CONV_DEVICE_FUNC, API_RAND), + "curand_normal2_double": ("hiprand_normal2_double", CONV_DEVICE_FUNC, API_RAND), + "curand_normal4": ("hiprand_normal4", CONV_DEVICE_FUNC, API_RAND), + "curand_normal4_double": ("hiprand_normal4_double", CONV_DEVICE_FUNC, API_RAND), + "curand_uniform": ("hiprand_uniform", CONV_DEVICE_FUNC, API_RAND), + "curand_uniform_double": ("hiprand_uniform_double", CONV_DEVICE_FUNC, API_RAND), + "curand_uniform2_double": ("hiprand_uniform2_double", CONV_DEVICE_FUNC, API_RAND), + "curand_uniform4": ("hiprand_uniform4", CONV_DEVICE_FUNC, API_RAND), + "curand_uniform4_double": ("hiprand_uniform4_double", CONV_DEVICE_FUNC, API_RAND), + "curand_discrete": ("hiprand_discrete", CONV_DEVICE_FUNC, API_RAND), + "curand_discrete4": ("hiprand_discrete4", CONV_DEVICE_FUNC, API_RAND), + "curand_poisson": ("hiprand_poisson", CONV_DEVICE_FUNC, API_RAND), + "curand_poisson4": ("hiprand_poisson4", CONV_DEVICE_FUNC, API_RAND), + "curand_Philox4x32_10": ("hiprand_Philox4x32_10", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + "mtgp32_kernel_params": ("mtgp32_kernel_params_t", CONV_MATH_FUNC, API_RAND) } CUDA_SPARSE_MAP = { diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 2bee61b024317e..af35dc9a8a7d67 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -126,15 +126,21 @@ profiler::RecordFunction profiler("${name}");""") PRE_RECORD_TRACE = CodeTemplate("""\ -jit::tracer::PreTraceInfo trace_info; +torch::jit::Node* node = nullptr; if (jit::tracer::isTracing()) { - trace_info = jit::tracer::preRecordTrace(jit::aten::${trace_name}, ${trace_inputs}); + auto& graph = jit::tracer::getTracingState()->graph; + node = graph->create(jit::aten::${trace_name}, /*outputs=*/0); + jit::tracer::recordSourceLocation(node); + ${add_trace_inputs} + graph->appendNode(node); } """) +ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${input}", ${input});""") + POST_RECORD_TRACE = CodeTemplate("""\ if (jit::tracer::isTracing()) { - jit::tracer::postRecordTrace( trace_info, ${trace_outputs} ); + jit::tracer::postRecordTrace(node, ArrayRef(${trace_outputs}) ); } """) @@ -372,7 +378,11 @@ def emit_record_trace(env): return ('', '') local = {} - local['trace_inputs'] = sum([['"{}"'.format(arg['name']), arg['name']] for arg in declaration['arguments']], []) + + add_trace_inputs = [] + for argument in declaration['arguments']: + add_trace_inputs.append(ADD_TRACE_INPUT.substitute(input=argument['name'])) + local['add_trace_inputs'] = '\n'.join(add_trace_inputs) # Record inplace operations as out-of-place operations (e.g., # not add_ but add) diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index fd8a960b6f2fa1..da5c5456780495 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -13,6 +13,7 @@ #include "torch/csrc/jit/tracer.h" #include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/symbolic_variable.h" +#include "torch/csrc/jit/ir.h" #include "torch/csrc/utils/variadic.h" #include "torch/csrc/autograd/functions/utils.h" diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index a1dca1e2eed9da..b1a835edf1fb98 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -558,12 +558,12 @@ static void _assert_not_tracing(const char* name, const variable_list& input_var } } -static jit::tracer::PreTraceInfo _trace_pre_record( +static Node* _trace_pre_record( PyObject* op_obj, PyObject *input_objects, const variable_list& input_vars) { if (!jit::tracer::isTracing()) { - return jit::tracer::PreTraceInfo(); + return nullptr; } // Save scalar args and the calling convention @@ -593,7 +593,7 @@ static jit::tracer::PreTraceInfo _trace_pre_record( } static void _trace_post_record( - const jit::tracer::PreTraceInfo& trace_info, + Node* node, PyObject* op_obj, const variable_list& input_vars, PyObject *output_objects, @@ -610,15 +610,15 @@ static void _trace_post_record( output_vars[i] = var->cdata; } - jit::tracer::postRecordTrace(trace_info, output_vars); + jit::tracer::postRecordTrace(node, output_vars); - trace_info.n->i_(attr::inplace, is_inplace); + node->i_(attr::inplace, is_inplace); } PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const UnpackedInput& unpacked, PyObject *inputs, THPObjectPtr&& raw_output, bool is_executable, - const jit::tracer::PreTraceInfo& trace_info) { + Node* node) { bool unpack_output = ensure_tuple(raw_output); auto num_outputs = PyTuple_GET_SIZE(raw_output.get()); @@ -639,7 +639,7 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked bool is_inplace = static_cast(grad_fn->dirty_tensors); _wrap_outputs(grad_fn, inputs, raw_output, outputs, is_executable); - _trace_post_record(trace_info, op_obj, unpacked.input_vars, outputs, is_inplace); + _trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace); if (is_executable) { _save_variables(grad_fn); } else { @@ -687,7 +687,7 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs) } return process_outputs(nullptr, self, unpacked_input, _inputs, std::move(raw_output), - is_executable, jit::tracer::PreTraceInfo()); + is_executable, nullptr); END_HANDLE_TH_ERRORS } @@ -708,7 +708,7 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) InputFlags& input_info = info_pair.second; // Record input nodes if tracing - auto trace_info = _trace_pre_record(cls, inputs, unpacked_input.input_vars); + auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars); // Initialize backward function (and ctx) bool is_executable = input_info.is_executable; @@ -737,7 +737,7 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) } return process_outputs(cls, ctx, unpacked_input, inputs, std::move(tensor_outputs), - is_executable, trace_info); + is_executable, node); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h index 63e9901964aacd..f0cc29db346963 100644 --- a/torch/csrc/jit/custom_operator.h +++ b/torch/csrc/jit/custom_operator.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -56,6 +57,25 @@ FunctionSchema createFunctionSchemaFromTraits(const std::string& name) { return {name, arguments, returns}; } +template +Node* getTracedNode( + const FunctionSchema& schema, + const std::tuple& tuple) { + auto symbol = Symbol::fromQualString(schema.name); + const auto& graph = tracer::getTracingState()->graph; + Node* node = graph->create(std::move(symbol), /*outputs=*/0); + tracer::recordSourceLocation(node); + + // Hack to call addInputs for the parameter pack in a sequenced fashion. + // https://stackoverflow.com/questions/12030538/calling-a-function-for-each-variadic-template-argument-and-an-array + int _[] = {(tracer::addInputs(node, schema.arguments[Is].name.c_str(), std::get(tuple)), 0)...}; + (void)_; + + graph->appendNode(node); + + return node; +} + /// Does two things for an operator implementation and a tuple of arguments: /// 1. Pops all necessary arguments off the stack into the tuple's elements, /// 2. Unpacks the tuple and calls the operator implementation. @@ -66,12 +86,25 @@ template < typename... Types, size_t... Is> ReturnType callOperatorWithTuple( + const FunctionSchema& schema, Implementation&& implementation, Stack& stack, std::tuple& tuple, Indices) { + Node* node = nullptr; + if (jit::tracer::isTracing()) { + node = getTracedNode(schema, tuple); + } + pop(stack, std::get(tuple)...); - return std::forward(implementation)(std::get(tuple)...); + auto result = + std::forward(implementation)(std::get(tuple)...); + + if (jit::tracer::isTracing()) { + jit::tracer::postRecordTrace(node, result); + } + + return result; } void checkArgumentVector( @@ -175,9 +208,10 @@ Operator createOperator( auto schema = torch::jit::detail::inferAndCheckSchema(schemaOrName); - return Operator(schema, [implementation](Stack& stack) { + return Operator(schema, [implementation, schema](Stack& stack) { ArgumentTuple tuple; auto result = torch::jit::detail::callOperatorWithTuple( + schema, std::move(implementation), stack, tuple, diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 94f1cbd0fac446..447d5eb3219049 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1131,7 +1131,7 @@ friend struct Block; return oss.str(); } - friend std::ostream& operator<<(std::ostream & out, const Graph & g); + friend TORCH_API std::ostream& operator<<(std::ostream & out, const Graph & g); TORCH_API std::shared_ptr copy(); private: diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp index 494ea4da7a36a2..8b68da694f596f 100644 --- a/torch/csrc/jit/python_tracer.cpp +++ b/torch/csrc/jit/python_tracer.cpp @@ -69,7 +69,7 @@ std::shared_ptr createGraphByTracing( } } -PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj, +Node* preRecordPythonTrace(THPObjectPtr pyobj, std::string arg_types, at::ArrayRef inputs, pyobj_list scalar_args) { @@ -78,14 +78,10 @@ PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj, throw python_error(); } - PreTraceInfo info; - auto & state = getTracingState(); - auto & graph = state->graph; + auto & graph = getTracingState()->graph; - Node *n = info.n = graph->createPythonOp( - std::move(apply), - arg_types, - std::move(scalar_args)); + Node* n = graph->createPythonOp( + std::move(apply), arg_types, std::move(scalar_args)); recordSourceLocation(n); for (const Variable & input : inputs) { @@ -95,7 +91,7 @@ PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj, // NB: Order matters. This must append after inputs but before outputs. graph->appendNode(n); - return info; + return n; } void pythonRecordSourceLocation(Node* n) { diff --git a/torch/csrc/jit/python_tracer.h b/torch/csrc/jit/python_tracer.h index 960f48516dd4e1..ae3a63d5b4e4e2 100644 --- a/torch/csrc/jit/python_tracer.h +++ b/torch/csrc/jit/python_tracer.h @@ -10,7 +10,7 @@ void initPythonTracerBindings(PyObject *module); std::string getPythonInterpreterStackTrace(); -tracer::PreTraceInfo preRecordPythonTrace( +Node* preRecordPythonTrace( THPObjectPtr pyobj, std::string arg_types, at::ArrayRef inputs, pyobj_list scalar_args); diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp index dd523c8d741892..e4e9781bc78217 100644 --- a/torch/csrc/jit/test_jit.cpp +++ b/torch/csrc/jit/test_jit.cpp @@ -985,12 +985,12 @@ void testCustomOperators() { REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType); Stack stack; - push(stack, 2.0f, at::ones(5)); + push(stack, 2.0f, autograd::make_variable(at::ones(5))); op->getOperation()(stack); at::Tensor output; pop(stack, output); - REQUIRE(output.allclose(at::full(5, 3.0f))); + REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f)))); } { RegisterOperators reg({createOperator( @@ -1014,12 +1014,12 @@ void testCustomOperators() { REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType); Stack stack; - push(stack, 2.0f, at::ones(5)); + push(stack, 2.0f, autograd::make_variable(at::ones(5))); op->getOperation()(stack); at::Tensor output; pop(stack, output); - REQUIRE(output.allclose(at::full(5, 3.0f))); + REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f)))); } { // Check that lists work well. @@ -1050,7 +1050,7 @@ void testCustomOperators() { Stack stack; push(stack, std::vector{1, 2}); push(stack, std::vector{1.0, 2.0}); - push(stack, std::vector{at::ones(5)}); + push(stack, std::vector{autograd::make_variable(at::ones(5))}); op->getOperation()(stack); std::vector output; pop(stack, output); @@ -1089,6 +1089,52 @@ void testCustomOperators() { "for the return value in that position")); #endif // USE_CATCH } + { + auto op = createOperator( + "traced::op(float a, Tensor b) -> Tensor", + [](double a, at::Tensor b) { return a + b; }); + + std::shared_ptr state; + variable_list trace_vars_in; + std::tie(state, trace_vars_in) = tracer::enter({}); + + Stack stack; + push(stack, 2.0f, autograd::make_variable(at::ones(5))); + op.getOperation()(stack); + at::Tensor output = autograd::make_variable(at::empty({})); + pop(stack, output); + + tracer::exit({output}); + + std::string op_name("traced::op"); + bool contains_traced_op = false; + for (const auto& node : state->graph->nodes()) { + if (std::string(node->kind().toQualString()) == op_name) { + contains_traced_op = true; + break; + } + } + REQUIRE(contains_traced_op); + } + { +#ifdef USE_CATCH + // vector is not supported yet. + auto op = createOperator( + "traced::op(float[] f) -> int", + [](const std::vector& f) -> int64_t { return f.size(); }); + + std::shared_ptr state; + variable_list trace_vars_in; + std::tie(state, trace_vars_in) = tracer::enter({}); + + Stack stack; + push(stack, std::vector{1.0}); + + REQUIRE_THROWS_WITH( + op.getOperation()(stack), + StartsWith("Tracing float lists currently not supported!")); +#endif + } } TORCH_API std::string runJITCPPTests() { diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index d6219685c1acff..b261636a6e5c30 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -25,17 +25,20 @@ void genericAddInput(Node *n, T value) { } void badArgType() { - throw std::runtime_error("Found an unsupported argument type in the JIT tracer. File a bug report."); + AT_ERROR("Found an unsupported argument type in the JIT tracer. File a bug report."); } +thread_local std::shared_ptr tracing_state; + +} // namespace detail -void addInputs(Node *n, const char * name, int64_t value) { genericAddInput(n, value); } -void addInputs(Node *n, const char * name, bool value) { genericAddInput(n, value); } -void addInputs(Node *n, const char * name, double value) { genericAddInput(n, value); } -void addInputs(Node *n, const char * name, const at::Scalar& value) { genericAddInput(n, value); } +void addInputs(Node *n, const char * name, int64_t value) { detail::genericAddInput(n, value); } +void addInputs(Node *n, const char * name, bool value) { detail::genericAddInput(n, value); } +void addInputs(Node *n, const char * name, double value) { detail::genericAddInput(n, value); } +void addInputs(Node *n, const char * name, const at::Scalar& value) { detail::genericAddInput(n, value); } void addInputs(Node *n, const char * name, const at::Tensor& value) { n->addInput(getValueTrace(value)); } -void addInputs(Node *n, const char * name, const std::string& value) { badArgType(); } -void addInputs(Node *n, const char * name, const at::SparseTensorRef& value) { badArgType(); } +void addInputs(Node *n, const char * name, const std::string& value) { detail::badArgType(); } +void addInputs(Node *n, const char * name, const at::SparseTensorRef& value) { detail::badArgType(); } void addInputs(Node *n, const char * name, at::TensorList value) { Graph *g = n->owningGraph(); @@ -64,9 +67,9 @@ void addInputs(Node *n, const char * name, at::IntList value) { n->addInput(g->insertNode(g->createList(jit::IntType::get(), info))->output()); } -thread_local std::shared_ptr tracing_state; - -} // namespace detail +void addInputs(Node *n, const char * name, const ArrayRef& value) { + AT_ERROR("Tracing float lists currently not supported!"); +} const std::shared_ptr& getTracingState() { return detail::tracing_state; @@ -81,11 +84,11 @@ TracingState::TracingState() TracingState::~TracingState() = default; -void postRecordTrace(const PreTraceInfo& info, +void postRecordTrace(Node* node, at::ArrayRef outputs) { for (size_t i = 0; i < outputs.size(); i++) { auto & output = outputs[i]; - Value * value = info.n->addOutput(); + Value * value = node->addOutput(); if (output.defined()) { value->inferTypeFrom(output.data()); setValueTrace(output, value); diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index fe34b12df30fd6..337e9d892fb826 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -1,13 +1,17 @@ #pragma once +#include "torch/csrc/autograd/function_hook.h" +#include "torch/csrc/autograd/variable.h" #include "torch/csrc/jit/assertions.h" -#include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/constants.h" -#include "torch/csrc/WindowsTorchApiMacro.h" +#include "torch/csrc/jit/ir.h" +#include "torch/csrc/utils/functional.h" #include "torch/csrc/utils/functional.h" #include "torch/csrc/utils/variadic.h" -#include "torch/csrc/autograd/function_hook.h" -#include "torch/csrc/autograd/variable.h" +#include "torch/csrc/utils/variadic.h" +#include "torch/csrc/WindowsTorchApiMacro.h" + +#include #include #include @@ -187,18 +191,10 @@ inline void abandon() { setTracingState(nullptr); } -// Pre-recorded information about the trace before we actually carry -// out the trace -struct PreTraceInfo { - Node *n; -}; - TORCH_API void recordSourceLocation(Node* n); TORCH_API void setRecordSourceLocation(void (*v)(Node*)); -namespace detail { - // NB: those serve both as an intermediate steps in addInputs below, // as well as the overloads that terminate template recursion void addInputs(Node *n, const char * name, int64_t value); @@ -208,6 +204,7 @@ void addInputs(Node *n, const char * name, const at::Scalar& value); void addInputs(Node *n, const char * name, const at::Tensor& value); void addInputs(Node *n, const char * name, at::IntList value); void addInputs(Node *n, const char * name, at::TensorList value); +void addInputs(Node *n, const char * name, const ArrayRef& value); void addInputs(Node *n, const char * name, const std::string& value); void addInputs(Node *n, const char * name, const at::SparseTensorRef& value); @@ -216,34 +213,23 @@ void addInputs(Node *n, const char * name, std::array value) { throw std::runtime_error("Found an unsupported argument type in the JIT tracer. File a bug report."); } -template -void addInputs(Node *n, const char * arg_name, T arg, const char * next_arg_name, Args... args) { - addInputs(n, arg_name, arg); - addInputs(n, next_arg_name, args...); -} - -} // namespace detail +TORCH_API void postRecordTrace(Node* node, at::ArrayRef outputs); -// NB: if you change this function, you might want to take a look at -// preRecordPythonTrace from python_tracer.cpp -template -PreTraceInfo preRecordTrace(Symbol op, Args... inputs) { - PreTraceInfo info; - auto & state = getTracingState(); - auto & graph = state->graph; - - Node * n = info.n = graph->create(op, /*outputs=*/0); - recordSourceLocation(n); - - detail::addInputs(n, inputs...); - - // NB: Order matters. This must append after inputs but before outputs. - graph->appendNode(n); - - return info; +inline void postRecordTrace(Node* node, at::ArrayRef tensors) { + postRecordTrace(node, fmap(tensors)); } -TORCH_API void postRecordTrace(const PreTraceInfo& info, at::ArrayRef outputs); +template < + typename T, + typename = torch::enable_if_t< + (!std::is_convertible, ArrayRef>::value && + !std::is_convertible, ArrayRef>::value && + !std::is_convertible, Variable>::value)>> +void postRecordTrace(Node* node, T&&) { + AT_ERROR( + "Found an unsupported argument type ", at::demangle_type(), + " in the JIT tracer. File a bug report."); +} TORCH_API autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim); diff --git a/torch/nn/_functions/dropout.py b/torch/nn/_functions/dropout.py index 19b589a321f714..24647c0f83900d 100644 --- a/torch/nn/_functions/dropout.py +++ b/torch/nn/_functions/dropout.py @@ -15,6 +15,10 @@ def symbolic(g, input, p=0.5, train=False, inplace=False): r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) return r + @staticmethod + def _fused_kernel_acceptable(input, p, cls_name, inplace): + return input.is_cuda and p > 0 and p < 1 and not inplace and cls_name == 'Dropout' + @classmethod def forward(cls, ctx, input, p=0.5, train=False, inplace=False): if p < 0 or p > 1: @@ -23,10 +27,15 @@ def forward(cls, ctx, input, p=0.5, train=False, inplace=False): ctx.p = p ctx.train = train ctx.inplace = inplace + ctx.use_fused_kernel = Dropout._fused_kernel_acceptable(input, ctx.p, cls.__name__, ctx.inplace) if ctx.p == 0 or not ctx.train: return input + if ctx.use_fused_kernel: + output, ctx.noise = input._fused_dropout(1 - ctx.p) + return output + if ctx.inplace: ctx.mark_dirty(input) output = input @@ -45,7 +54,13 @@ def forward(cls, ctx, input, p=0.5, train=False, inplace=False): @staticmethod def backward(ctx, grad_output): - if ctx.p > 0 and ctx.train: + if ctx.use_fused_kernel: + if not grad_output.requires_grad: + return grad_output._masked_scale(ctx.noise, 1. / (1 - ctx.p)), None, None, None + else: + # use autograd-friendly backward if double backward is required + return grad_output * (ctx.noise.type_as(grad_output) * (1. / (1 - ctx.p))), None, None, None + elif ctx.p > 0 and ctx.train: return grad_output * ctx.noise, None, None, None else: return grad_output, None, None, None @@ -84,6 +99,7 @@ def forward(cls, ctx, input, p=0.5, train=False, inplace=False): if p < 0 or p > 1: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + ctx.use_fused_kernel = False ctx.p = p ctx.train = train ctx.inplace = inplace