
Commit d2528ec

Merge pull request #30 from iotamudelta/patchesbegone
No more patches
2 parents ac5448f + 48d194d commit d2528ec

9 files changed, +37 −134 lines

aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 4 additions & 0 deletions
@@ -118,8 +118,12 @@ void SpatialSoftMax_getLaunchSizes(
   uint32_t block_threads = block.x * block.y;
   smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t);
   int max_active_blocks;
+#ifdef __HIP_PLATFORM_HCC__
+  max_active_blocks = 16;
+#else
   cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
                                                 k, block_threads, smem_size);
+#endif
   max_active_blocks *= at::globalContext().getCurrentDeviceProperties()->multiProcessorCount;
   grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
 }
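
This in-source guard replaces a text substitution that hipify previously applied from disabled_features.yaml (removed later in this commit). A minimal standalone sketch of the same pattern, with a hypothetical helper name that is not part of this commit:

#include <cstddef>

// Sketch only: query occupancy through the CUDA API where it exists;
// under HIP/HCC, where the occupancy API was unavailable at the time,
// fall back to a fixed estimate of 16 blocks per multiprocessor.
template <typename KernelT>
int estimated_max_active_blocks(KernelT k, int block_threads, size_t smem_size) {
#ifdef __HIP_PLATFORM_HCC__
  (void)k; (void)block_threads; (void)smem_size;
  return 16;  // fixed assumption; no occupancy query on this platform
#else
  int max_active_blocks = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
                                                k, block_threads, smem_size);
  return max_active_blocks;
#endif
}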

aten/src/THC/THCDeviceUtils.cuh

Lines changed: 15 additions & 0 deletions
@@ -53,6 +53,13 @@ __device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int
 #endif
 }
 
+#ifdef __HIP_PLATFORM_HCC__
+//To handle ambiguity, add a type double version.
+__device__ __forceinline__ double WARP_SHFL_XOR(double value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
+  //(HIP doesn't support double)
+  return (double) __shfl_xor((float) value, laneMask, width);
+}
+#endif
 template <typename T>
 __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
 {
@@ -83,6 +90,14 @@ __device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width
 #endif
 }
 
+#ifdef __HIP_PLATFORM_HCC__
+//To handle ambiguity, add a type double version.
+__device__ __forceinline__ double WARP_SHFL_DOWN(double value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+  //(HIP doesn't support double)
+  return (double) __shfl_down((float) value, delta, width);
+}
+#endif
 template <typename T>
 __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
 {
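
The "ambiguity" the comments refer to: HIP at the time provided __shfl_xor and __shfl_down overloads for int and float only, so calling them with a double inside the generic template was ambiguous (double converts equally well to int and to float). The added non-template overloads are exact matches for double, and overload resolution prefers a non-template exact match over a function template, so doubles now route explicitly through the float path, at reduced precision. A standalone sketch of that resolution rule, with hypothetical names:

#include <cstdio>

// Non-template overload: an exact match for double, preferred over the
// template below during overload resolution.
double shuffle(double v) { std::printf("double overload\n"); return v; }

template <typename T>
T shuffle(T v) { std::printf("template\n"); return v; }

int main() {
  shuffle(1.0);   // double: picks the non-template overload
  shuffle(1.0f);  // float: deduces T = float, picks the template
  return 0;
}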

aten/src/THC/generic/THCTensorRandom.cu

Lines changed: 7 additions & 7 deletions
@@ -490,11 +490,11 @@ THC_API void THCTensor_(clampedRandom)(THCState* state, THCTensor *self_, int64_
 #if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
   if (range > 1ULL << 32) {
     generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-        gen->state.gen_states, size, data, min_val, range);
+        gen->state.gen_states, static_cast<int32_t>(size), data, min_val, range);
   } else {
 #endif
     generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-        gen->state.gen_states, size, data, min_val, range);
+        gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int32_t>(min_val), static_cast<uint32_t>(range));
 #if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
   }
 #endif
@@ -520,19 +520,19 @@ THC_API void THCTensor_(random)(THCState* state, THCTensor *self_)
 
 #if defined(THC_REAL_IS_HALF)
   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0UL, (1UL << HLF_MANT_DIG) + 1);
+      gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int32_t>(0UL), static_cast<uint32_t>((1UL << HLF_MANT_DIG) + 1));
 #elif defined(THC_REAL_IS_FLOAT)
   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0UL, (1UL << FLT_MANT_DIG) + 1);
+      gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int32_t>(0UL), static_cast<uint32_t>((1UL << FLT_MANT_DIG) + 1));
 #elif defined(THC_REAL_IS_DOUBLE)
   generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0ULL, (1ULL << DBL_MANT_DIG) + 1);
+      gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int64_t>(0ULL), static_cast<uint64_t>((1ULL << DBL_MANT_DIG) + 1));
 #elif defined(THC_REAL_IS_LONG)
   generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0ULL, static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
+      gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int64_t>(0ULL), static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
 #else
   generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, 0UL, static_cast<uint32_t>(std::numeric_limits<real>::max()) + 1);
+      gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int32_t>(0UL), static_cast<uint32_t>(std::numeric_limits<real>::max()) + 1);
 #endif
 
   THCTensor_(freeCopyTo)(state, self, self_);
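
These casts were previously applied by the now-deleted patch file a_aten_src_THC_generic_THCTensorRandom.cu.patch. Writing them directly in the CUDA source is behavior-preserving for the CUDA build (the implicit conversions would narrow identically), while the hcc-era HIP launch machinery required each argument's type to match the kernel parameter exactly. A self-contained sketch, where the kernel signature and constants are assumptions for illustration, not the real generate_random declaration:

#include <cstddef>
#include <cstdint>

constexpr int kNumBlocks = 64;   // stand-ins for NUM_BLOCKS / BLOCK_SIZE
constexpr int kBlockSize = 256;

// Hypothetical kernel mirroring the 32-bit parameter widths the casts target.
__global__ void fill_random(int32_t size, float* data,
                            int32_t min_val, uint32_t range) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < size) data[i] = float(min_val + (i % range));  // RNG stand-in
}

void launch(float* data, ptrdiff_t size) {
  // `size` is 64-bit; casting at the call site makes every argument type
  // exact, so the same line survives hipification unchanged.
  fill_random<<<kNumBlocks, kBlockSize>>>(
      static_cast<int32_t>(size), data,
      static_cast<int32_t>(0), static_cast<uint32_t>((1u << 24) + 1));
}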

aten/src/THCUNN/FeatureLPPooling.cu

Lines changed: 2 additions & 2 deletions
@@ -193,7 +193,7 @@ featureLPPoolingUpdateOutput(const THCDeviceTensor<T, 4> input,
 
     if (Stride < Width) {
       // Shift registers for calculating the next point
-      RegisterUtils<T, Width>::shiftLeft<Stride>(in);
+      RegisterUtils<T, Width>::template shiftLeft<Stride>(in);
     }
   }
 }
@@ -377,7 +377,7 @@ featureLPPoolingUpdateGradInput(const THCDeviceTensor<T, 4> gradOutput,
 
     if (Stride < Width) {
       // Shift registers for calculating the next point
-      RegisterUtils<T, Width>::shiftLeft<Stride>(in);
+      RegisterUtils<T, Width>::template shiftLeft<Stride>(in);
     }
   }
 }
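
The only change in both hunks is the `template` keyword. RegisterUtils<T, Width> is a dependent type, so the compiler cannot assume shiftLeft names a member template; without the disambiguator, `shiftLeft<Stride>(in)` parses as a pair of comparisons. nvcc historically tolerated the non-conforming spelling; the HIP toolchain does not, hence the fix (and the deleted FeatureLPPooling patch below). A standalone sketch with a hypothetical RegUtils:

// Hypothetical stand-in for RegisterUtils, showing the dependent-name rule.
template <typename T, int Width>
struct RegUtils {
  template <int Stride>
  static void shiftLeft(T (&regs)[Width]) {
    for (int i = 0; i + Stride < Width; ++i)
      regs[i] = regs[i + Stride];
  }
};

template <typename T, int Width, int Stride>
void advance(T (&regs)[Width]) {
  // `template` is mandatory here: without it the compiler parses
  // shiftLeft < Stride as a comparison, not a template argument list.
  RegUtils<T, Width>::template shiftLeft<Stride>(regs);
}

int main() {
  float regs[4] = {0.f, 1.f, 2.f, 3.f};
  advance<float, 4, 1>(regs);  // regs becomes {1, 2, 3, 3}
  return 0;
}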

tools/amd_build/disabled_features.yaml

Lines changed: 0 additions & 27 deletions
@@ -1,13 +1,6 @@
 {
     "disable_unsupported_hip_calls":
     [
-        {
-            "path": "aten/src/THC/generic/THCTensorSort.cu",
-            "functions": {
-                "thrust::copy": ";",
-                "thrust::stable_sort_by_key": ";"
-            }
-        },
         {
             "path": "aten/src/THC/THCBlas.cu",
             "functions": {
@@ -23,18 +16,6 @@
                 "HIPBLAS_DATA_HALF": "0"
             }
         },
-        {
-            "path": "aten/src/THCUNN/SoftMaxCommon.cuh",
-            "functions": {
-                "cudaOccupancyMaxActiveBlocksPerMultiprocessor": "16"
-            }
-        },
-        {
-            "path": "aten/src/ATen/native/cuda/SoftMax.cu",
-            "functions": {
-                "cudaOccupancyMaxActiveBlocksPerMultiprocessor": "16"
-            }
-        },
         {
             "path": "aten/src/THC/THCStream.cpp",
             "functions": {
@@ -142,7 +123,6 @@
         "aten/src/ATen/native/cuda/CuFFTUtils.h",
         "aten/src/ATen/native/cuda/CuFFTPlanCache.h",
         "aten/src/ATen/native/cuda/SpectralOps.cu",
-        "aten/src/THCUNN/RReLU.cu",
         "aten/src/ATen/native/cuda/Distributions.cu"
     ],
     "disabled_functions": [
@@ -186,13 +166,6 @@
             "THNN_(SparseLinear_accGradParameters)"
         ]
     },
-    {
-        "path": "aten/src/THCUNN/generic/RReLU.cu",
-        "functions": [
-            "THNN_(RReLU_updateOutput)",
-            "THNN_(RReLU_updateGradInput)"
-        ]
-    },
     {
         "path": "aten/src/THCUNN/generic/LookupTable.cu",
         "functions": [

tools/amd_build/patches/a_aten_src_THCUNN_FeatureLPPooling.cu.patch

Lines changed: 0 additions & 22 deletions
This file was deleted.

tools/amd_build/patches/a_aten_src_THC_THCDeviceUtils.cuh.patch

Lines changed: 0 additions & 29 deletions
This file was deleted.

tools/amd_build/patches/a_aten_src_THC_generic_THCTensorRandom.cu.patch

Lines changed: 0 additions & 43 deletions
This file was deleted.

tools/amd_build/pyHIPIFY/hipify-python.py

Lines changed: 9 additions & 4 deletions
@@ -772,7 +772,9 @@ def add_static_casts(directory, extensions, KernelTemplateParams):
             kernel_name_with_template = KernelTemplateParams[kernel_name]["kernel_with_template"]
             argument_types = KernelTemplateParams[kernel_name]["arg_types"]
 
-            old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]]
+            old_kernel_launch = input_source[arguments[5]["start"]:arguments[-1]["end"]]
+            full_old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]]
+            full_new_kernel_launch = full_old_kernel_launch
             new_kernel_launch = old_kernel_launch
 
             kernel_params = argument_strings[5:]
@@ -788,14 +790,17 @@ def replace_arg(match):
                 # Update to static_cast, account for cases where argument is at start/end of string
                 new_kernel_launch = re.sub(r'(^|\W)({0})(\W|$)'.format(re.escape(the_arg)), replace_arg, new_kernel_launch)
 
+            # replace kernel arguments in full kernel launch arguments w/ static_cast ones
+            full_new_kernel_launch = full_new_kernel_launch.replace(old_kernel_launch, new_kernel_launch)
+
             # Add template type
             if "THCUNN" in filepath.split("/") and "generic" not in filepath.split("/"):
                 kernel_name_with_template = kernel_name_with_template.replace("<real>", "<Dtype>")
-            new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template),
-                                       lambda x: kernel_name_with_template, new_kernel_launch)
+            full_new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template),
+                                            lambda x: kernel_name_with_template, full_new_kernel_launch)
 
             # Replace Launch
-            new_output_source = new_output_source.replace(old_kernel_launch, new_kernel_launch)
+            new_output_source = new_output_source.replace(full_old_kernel_launch, full_new_kernel_launch)
 
             # Overwrite file contents
             fileobj.seek(0)
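
Previously the rewrite string spanned the whole launch expression from argument 0, so the static_cast substitution could also hit an argument name that happened to appear in the grid/block/shared-mem/stream slots. The fix applies the casts only to the substring covering the kernel arguments (argument_strings[5:] onward), splices the result back into the full launch, and only then substitutes the templated kernel name and replaces the launch in the output. In a hipified launch those first five slots are the configuration, as in this sketch (hipLaunchKernelGGL is HIP's actual launch macro; the kernel itself is made up):

#include <hip/hip_runtime.h>
#include <cstdint>

// Made-up kernel for illustration.
__global__ void fill_ones(int32_t n, float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 1.0f;
}

void launch(float* out, int64_t n, hipStream_t stream) {
  // Slots 0-4 (kernel, grid, block, shared-mem bytes, stream) correspond to
  // arguments[0:5] in add_static_casts and are left untouched; only the
  // kernel arguments after the stream receive static_casts.
  hipLaunchKernelGGL(fill_ones, dim3(64), dim3(256), 0, stream,
                     static_cast<int32_t>(n), out);
}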
