diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu
index d500a47320f333..0ee5d18d1e2bde 100644
--- a/aten/src/ATen/native/cuda/SoftMax.cu
+++ b/aten/src/ATen/native/cuda/SoftMax.cu
@@ -118,8 +118,12 @@ void SpatialSoftMax_getLaunchSizes(
   uint32_t block_threads = block.x * block.y;
   smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t);
   int max_active_blocks;
+#ifdef __HIP_PLATFORM_HCC__
+  max_active_blocks = 16;
+#else
   cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
                                                 k, block_threads, smem_size);
+#endif
   max_active_blocks *= at::globalContext().getCurrentDeviceProperties()->multiProcessorCount;
   grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
 }
diff --git a/aten/src/THC/THCDeviceUtils.cuh b/aten/src/THC/THCDeviceUtils.cuh
index 8c0de78406da90..7f16455ff21801 100644
--- a/aten/src/THC/THCDeviceUtils.cuh
+++ b/aten/src/THC/THCDeviceUtils.cuh
@@ -53,6 +53,13 @@ __device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int
 #endif
 }
 
+#ifdef __HIP_PLATFORM_HCC__
+//To handle ambiguity, add a type double version.
+__device__ __forceinline__ double WARP_SHFL_XOR(double value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
+  //(HIP doesn't support double)
+  return (double) __shfl_xor((float) value, laneMask, width);
+}
+#endif
 template <typename T>
 __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
 {
@@ -83,6 +90,14 @@ __device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width
 #endif
 }
 
+#ifdef __HIP_PLATFORM_HCC__
+//To handle ambiguity, add a type double version.
+__device__ __forceinline__ double WARP_SHFL_DOWN(double value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+  //(HIP doesn't support double)
+  return (double) __shfl_down((float) value, delta, width);
+}
+#endif
 template <typename T>
 __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
 {
diff --git a/aten/src/THC/generic/THCTensorRandom.cu b/aten/src/THC/generic/THCTensorRandom.cu
index ef6ae9191585c8..3eb3c211c03d63 100644
--- a/aten/src/THC/generic/THCTensorRandom.cu
+++ b/aten/src/THC/generic/THCTensorRandom.cu
@@ -490,11 +490,11 @@ THC_API void THCTensor_(clampedRandom)(THCState* state, THCTensor *self_, int64_
 #if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
   if (range > 1ULL << 32) {
     generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-        gen->state.gen_states, size, data, min_val, range);
+        gen->state.gen_states, static_cast<int>(size), data, min_val, range);
   } else {
 #endif
     generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-        gen->state.gen_states, size, data, min_val, range);
+        gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(min_val), static_cast<uint64_t>(range));
 #if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
   }
 #endif
@@ -520,19 +520,19 @@ THC_API void THCTensor_(random)(THCState* state, THCTensor *self_)
 
 #if defined(THC_REAL_IS_HALF)
   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0UL, (1UL << HLF_MANT_DIG) + 1);
+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0UL), static_cast<uint64_t>((1UL << HLF_MANT_DIG) + 1));
 #elif defined(THC_REAL_IS_FLOAT)
   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0UL, (1UL << FLT_MANT_DIG) + 1);
+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0UL), static_cast<uint64_t>((1UL << FLT_MANT_DIG) + 1));
 #elif defined(THC_REAL_IS_DOUBLE)
   generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0ULL, (1ULL << DBL_MANT_DIG) + 1);
+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0ULL), static_cast<uint64_t>((1ULL << DBL_MANT_DIG) + 1));
 #elif defined(THC_REAL_IS_LONG)
   generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0ULL, static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0ULL), static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
 #else
   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0UL, static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0UL), static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
 #endif
 
   THCTensor_(freeCopyTo)(state, self, self_);
diff --git a/aten/src/THCUNN/FeatureLPPooling.cu b/aten/src/THCUNN/FeatureLPPooling.cu
index 4ad190fbe6b651..7026f0dd9b3805 100644
--- a/aten/src/THCUNN/FeatureLPPooling.cu
+++ b/aten/src/THCUNN/FeatureLPPooling.cu
@@ -193,7 +193,7 @@ featureLPPoolingUpdateOutput(const THCDeviceTensor<T, 4> input,
 
     if (Stride < Width) {
       // Shift registers for calculating the next point
-      RegisterUtils<T, Width>::shiftLeft<Stride>(in);
+      RegisterUtils<T, Width>::template shiftLeft<Stride>(in);
     }
   }
 }
@@ -377,7 +377,7 @@ featureLPPoolingUpdateGradInput(const THCDeviceTensor<T, 4> gradOutput,
 
    if (Stride < Width) {
      // Shift registers for calculating the next point
-      RegisterUtils<T, Width>::shiftLeft<Stride>(in);
+      RegisterUtils<T, Width>::template shiftLeft<Stride>(in);
     }
   }
 }
diff --git a/tools/amd_build/disabled_features.yaml b/tools/amd_build/disabled_features.yaml
index 46c470fceacc9a..ea76b9fe4c9004 100644
--- a/tools/amd_build/disabled_features.yaml
+++ b/tools/amd_build/disabled_features.yaml
@@ -1,13 +1,6 @@
 {
     "disable_unsupported_hip_calls": [
-        {
-            "path": "aten/src/THC/generic/THCTensorSort.cu",
-            "functions": {
-                "thrust::copy": ";",
-                "thrust::stable_sort_by_key": ";"
-            }
-        },
         {
             "path": "aten/src/THC/THCBlas.cu",
             "functions": {
@@ -23,18 +16,6 @@
                 "HIPBLAS_DATA_HALF": "0"
             }
         },
-        {
-            "path": "aten/src/THCUNN/SoftMaxCommon.cuh",
-            "functions": {
-                "cudaOccupancyMaxActiveBlocksPerMultiprocessor": "16"
-            }
-        },
-        {
-            "path": "aten/src/ATen/native/cuda/SoftMax.cu",
-            "functions": {
-                "cudaOccupancyMaxActiveBlocksPerMultiprocessor": "16"
-            }
-        },
         {
             "path": "aten/src/THC/THCStream.cpp",
             "functions": {
@@ -142,7 +123,6 @@
         "aten/src/ATen/native/cuda/CuFFTUtils.h",
         "aten/src/ATen/native/cuda/CuFFTPlanCache.h",
         "aten/src/ATen/native/cuda/SpectralOps.cu",
-        "aten/src/THCUNN/RReLU.cu",
         "aten/src/ATen/native/cuda/Distributions.cu"
     ],
     "disabled_functions": [
@@ -186,13 +166,6 @@
             "THNN_(SparseLinear_accGradParameters)"
         ]
     },
-    {
-        "path": "aten/src/THCUNN/generic/RReLU.cu",
-        "functions": [
-            "THNN_(RReLU_updateOutput)",
-            "THNN_(RReLU_updateGradInput)"
-        ]
-    },
     {
         "path": "aten/src/THCUNN/generic/LookupTable.cu",
         "functions": [
diff --git a/tools/amd_build/patches/a_aten_src_THCUNN_FeatureLPPooling.cu.patch b/tools/amd_build/patches/a_aten_src_THCUNN_FeatureLPPooling.cu.patch
deleted file mode 100644
index 858d43e26a4231..00000000000000
--- a/tools/amd_build/patches/a_aten_src_THCUNN_FeatureLPPooling.cu.patch
+++ /dev/null
@@ -1,22 +0,0 @@
-diff --git a/aten/src/THCUNN/FeatureLPPooling.cu b/aten/src/THCUNN/FeatureLPPooling.cu
-index 4ad190fbe..615ab4ec6 100644
---- a/aten/src/THCUNN/FeatureLPPooling.cu
-+++ b/aten/src/THCUNN/FeatureLPPooling.cu
-@@ -193,7 +193,7 @@ featureLPPoolingUpdateOutput(const THCDeviceTensor<T, 4> input,
- 
-     if (Stride < Width) {
-       // Shift registers for calculating the next point
--      RegisterUtils<T, Width>::shiftLeft<Stride>(in);
-+      RegisterUtils<T, Width>::template shiftLeft<Stride>(in);
-     }
-   }
- }
-@@ -377,7 +377,7 @@ featureLPPoolingUpdateGradInput(const THCDeviceTensor<T, 4> gradOutput,
- 
-     if (Stride < Width) {
-       // Shift registers for calculating the next point
--      RegisterUtils<T, Width>::shiftLeft<Stride>(in);
-+      RegisterUtils<T, Width>::template shiftLeft<Stride>(in);
-     }
-   }
- }
diff --git a/tools/amd_build/patches/a_aten_src_THC_THCDeviceUtils.cuh.patch b/tools/amd_build/patches/a_aten_src_THC_THCDeviceUtils.cuh.patch
deleted file mode 100644
index 15d399853877c9..00000000000000
--- a/tools/amd_build/patches/a_aten_src_THC_THCDeviceUtils.cuh.patch
+++ /dev/null
@@ -1,29 +0,0 @@
-diff --git a/aten/src/THC/THCDeviceUtils.cuh b/aten/src/THC/THCDeviceUtils.cuh
-index 4ae2bee07..2845b9c68 100644
---- a/aten/src/THC/THCDeviceUtils.cuh
-+++ b/aten/src/THC/THCDeviceUtils.cuh
-@@ -52,6 +52,11 @@ __device__ __forceinline__ int WARP_BALLOT(int predicate, unsigned int mask = 0x
- #endif
- }
- 
-+//To handle ambiguity, add a type double version.
-+__device__ __forceinline__ double WARP_SHFL_XOR(double value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
-+  //(HIP doesn't support double)
-+  return (double) __shfl_xor((float) value, laneMask, width);
-+}
- template <typename T>
- __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
- {
-@@ -82,6 +87,12 @@ __device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width
- #endif
- }
- 
-+//To handle ambiguity, add a type double version.
-+__device__ __forceinline__ double WARP_SHFL_DOWN(double value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
-+{
-+  //(HIP doesn't support double)
-+  return (double) __shfl_down((float) value, delta, width);
-+}
- template <typename T>
- __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
- {
diff --git a/tools/amd_build/patches/a_aten_src_THC_generic_THCTensorRandom.cu.patch b/tools/amd_build/patches/a_aten_src_THC_generic_THCTensorRandom.cu.patch
deleted file mode 100644
index 355c1e3a7c3faf..00000000000000
--- a/tools/amd_build/patches/a_aten_src_THC_generic_THCTensorRandom.cu.patch
+++ /dev/null
@@ -1,43 +0,0 @@
-diff --git a/aten/src/THC/generic/THCTensorRandom.cu b/aten/src/THC/generic/THCTensorRandom.cu
-index 906780b4f..b03e051cb 100644
---- a/aten/src/THC/generic/THCTensorRandom.cu
-+++ b/aten/src/THC/generic/THCTensorRandom.cu
-@@ -504,11 +504,11 @@ THC_API void THCTensor_(clampedRandom)(THCState* state, THCTensor *self_, int64_
- #if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
-   if (range > 1ULL << 32) {
-     generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
--        gen->state.gen_states, size, data, min_val, range);
-+        gen->state.gen_states, static_cast<int>(size), data, min_val, range);
-   } else {
- #endif
-     generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
--        gen->state.gen_states, size, data, min_val, range);
-+        gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(min_val), static_cast<uint64_t>(range));
- #if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
-   }
- #endif
-@@ -534,19 +534,19 @@ THC_API void THCTensor_(random)(THCState* state, THCTensor *self_)
- 
- #if defined(THC_REAL_IS_HALF)
-   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
--      gen->state.gen_states, size, data, 0UL, (1UL << HLF_MANT_DIG) + 1);
-+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0UL), static_cast<uint64_t>((1UL << HLF_MANT_DIG) + 1));
- #elif defined(THC_REAL_IS_FLOAT)
-   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
--      gen->state.gen_states, size, data, 0UL, (1UL << FLT_MANT_DIG) + 1);
-+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0UL), static_cast<uint64_t>((1UL << FLT_MANT_DIG) + 1));
- #elif defined(THC_REAL_IS_DOUBLE)
-   generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
--      gen->state.gen_states, size, data, 0ULL, (1ULL << DBL_MANT_DIG) + 1);
-+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0ULL), static_cast<uint64_t>((1ULL << DBL_MANT_DIG) + 1));
- #elif defined(THC_REAL_IS_LONG)
-   generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
--      gen->state.gen_states, size, data, 0ULL, static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
-+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0ULL), static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
- #else
-   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
--      gen->state.gen_states, size, data, 0UL, static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
-+      gen->state.gen_states, static_cast<int>(size), data, static_cast<uint64_t>(0UL), static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
- #endif
- 
-   THCTensor_(freeCopyTo)(state, self, self_);
diff --git a/tools/amd_build/pyHIPIFY/hipify-python.py b/tools/amd_build/pyHIPIFY/hipify-python.py
index 7aae2481099fd4..c735a01fae97d6 100755
--- a/tools/amd_build/pyHIPIFY/hipify-python.py
+++ b/tools/amd_build/pyHIPIFY/hipify-python.py
@@ -772,7 +772,9 @@ def add_static_casts(directory, extensions, KernelTemplateParams):
             kernel_name_with_template = KernelTemplateParams[kernel_name]["kernel_with_template"]
             argument_types = KernelTemplateParams[kernel_name]["arg_types"]
 
-            old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]]
+            old_kernel_launch = input_source[arguments[5]["start"]:arguments[-1]["end"]]
+            full_old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]]
+            full_new_kernel_launch = full_old_kernel_launch
             new_kernel_launch = old_kernel_launch
 
             kernel_params = argument_strings[5:]
@@ -788,14 +790,17 @@ def replace_arg(match):
                     # Update to static_cast, account for cases where argument is at start/end of string
                    new_kernel_launch = re.sub(r'(^|\W)({0})(\W|$)'.format(re.escape(the_arg)), replace_arg, new_kernel_launch)
 
+            # replace kernel arguments in full kernel launch arguments w/ static_cast ones
+            full_new_kernel_launch = full_new_kernel_launch.replace(old_kernel_launch, new_kernel_launch)
+
             # Add template type
             if "THCUNN" in filepath.split("/") and "generic" not in filepath.split("/"):
                 kernel_name_with_template = kernel_name_with_template.replace("<real>", "<Dtype>")
-            new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template),
-                                       lambda x: kernel_name_with_template, new_kernel_launch)
+            full_new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template),
+                                            lambda x: kernel_name_with_template, full_new_kernel_launch)
 
             # Replace Launch
-            new_output_source = new_output_source.replace(old_kernel_launch, new_kernel_launch)
+            new_output_source = new_output_source.replace(full_old_kernel_launch, full_new_kernel_launch)
 
             # Overwrite file contents
             fileobj.seek(0)
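
Note on the THCDeviceUtils.cuh hunks: at the time, HIP's __shfl_xor/__shfl_down intrinsics had no double overload, so any call that reached the generic template with T = double failed to compile. The patch adds non-template double overloads (guarded by __HIP_PLATFORM_HCC__) that round-trip through float; overload resolution prefers the exact non-template match, at the cost of precision in the shuffled value. A minimal sketch of an affected caller (warpReduceSum is illustrative, not part of the patch):

    // Warp-wide sum in double. On HIP this now resolves to the non-template
    // double overload (which truncates through float); on CUDA the #ifdef
    // removes it and the generic template handles double natively.
    __device__ double warpReduceSum(double val) {
      for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        val += WARP_SHFL_DOWN(val, offset);
      }
      return val;
    }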
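The FeatureLPPooling.cu change is the standard C++ dependent-name fix rather than a behavior change: inside a function template, RegisterUtils<T, Width> is a dependent type, so invoking its member template shiftLeft<Stride> requires the template disambiguator; nvcc accepted the shorthand, hcc did not. A reduced sketch of the pattern (struct body abbreviated, names mirror FeatureLPPooling.cu):

    template <typename T, int N>
    struct RegisterUtils {
      template <int Shift>
      __device__ static void shiftLeft(T arr[N]) { /* shift arr left by Shift */ }
    };

    template <typename T, int Width, int Stride>
    __device__ void step(T (&in)[Width]) {
      // Without 'template', 'shiftLeft < Stride' would parse as a comparison.
      RegisterUtils<T, Width>::template shiftLeft<Stride>(in);
    }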
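In the hipify-python.py hunks, the static_cast pass now distinguishes the kernel-argument slice from the whole launch: the first five arguments of a hipified hipLaunchKernelGGL call are the kernel name and launch configuration, so casts are computed only over arguments[5:] (old_kernel_launch/new_kernel_launch) and then spliced back into the full launch string (full_old_kernel_launch/full_new_kernel_launch), which is also where the kernel-name template rewrite is applied. Roughly, with illustrative identifiers and cast types (not taken from the patch):

    // Before add_static_casts:
    hipLaunchKernelGGL((generate_random), grid, block, 0, stream,
                       gen_states, size, data, min_val, range);

    // After: only the kernel arguments are rewritten, then substituted
    // back into the full launch string:
    hipLaunchKernelGGL((generate_random), grid, block, 0, stream,
                       gen_states, static_cast<int>(size), data,
                       static_cast<uint64_t>(min_val), static_cast<uint64_t>(range));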