No more patches #30


Merged: 9 commits merged on Jul 13, 2018
4 changes: 4 additions & 0 deletions aten/src/ATen/native/cuda/SoftMax.cu
@@ -118,8 +118,12 @@ void SpatialSoftMax_getLaunchSizes(
uint32_t block_threads = block.x * block.y;
smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t);
int max_active_blocks;
#ifdef __HIP_PLATFORM_HCC__
max_active_blocks = 16;
#else
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
k, block_threads, smem_size);
#endif
max_active_blocks *= at::globalContext().getCurrentDeviceProperties()->multiProcessorCount;
grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
}
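
On HIP the block count is pinned to 16 per SM, apparently because cudaOccupancyMaxActiveBlocksPerMultiprocessor had no HIP counterpart at the time, while the CUDA path keeps querying the runtime and then scales by the SM count. For reference, a minimal standalone sketch of that occupancy-driven sizing on the CUDA side (my_kernel and launch_with_occupancy are hypothetical names, not from this PR):

#include <cuda_runtime.h>
#include <algorithm>

// Hypothetical kernel, used only to illustrate the occupancy query.
__global__ void my_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * in[i];
}

void launch_with_occupancy(const float* in, float* out, int n, cudaStream_t stream) {
  const int block_threads = 256;
  const size_t smem_size = 0;

  // How many blocks of this kernel can be resident on one SM?
  int max_active_blocks = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks, my_kernel, block_threads, smem_size);

  // Scale to the whole device, as the diff does via multiProcessorCount.
  int sm_count = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, 0);

  const int needed = (n + block_threads - 1) / block_threads;
  const int grid = std::min(needed, max_active_blocks * sm_count);
  my_kernel<<<grid, block_threads, smem_size, stream>>>(in, out, n);
}
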
15 changes: 15 additions & 0 deletions aten/src/THC/THCDeviceUtils.cuh
@@ -53,6 +53,13 @@ __device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int
#endif
}

#ifdef __HIP_PLATFORM_HCC__
//To handle ambiguity, add a type double version.
__device__ __forceinline__ double WARP_SHFL_XOR(double value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
//(HIP doesn't support double)
return (double) __shfl_xor((float) value, laneMask, width);
}
#endif
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
@@ -83,6 +90,14 @@ __device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width
#endif
}

#ifdef __HIP_PLATFORM_HCC__
//To handle ambiguity, add a type double version.
__device__ __forceinline__ double WARP_SHFL_DOWN(double value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
//(HIP doesn't support double)
return (double) __shfl_down((float) value, delta, width);
}
#endif
template <typename T>
__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
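
The explicit double overloads exist because the generic template would forward a double straight to __shfl_xor / __shfl_down, which HIP did not support for double; the overloads instead round-trip through float, trading precision for compilability. These helpers are the usual building blocks of warp-level reductions; a minimal sketch of such a butterfly sum (warpReduceSum is illustrative, not part of the PR):

// Illustrative warp-wide sum built on the WARP_SHFL_XOR helper above.
// After the loop every lane in the warp holds the same total.
template <typename T>
__device__ __forceinline__ T warpReduceSum(T val) {
  for (int lane_mask = warpSize / 2; lane_mask > 0; lane_mask /= 2) {
    // Each iteration pairs lanes whose IDs differ in exactly one bit.
    val += WARP_SHFL_XOR(val, lane_mask);
  }
  return val;
}
// Note: for T = double on HIP this resolves to the float-based overload above,
// so the accumulation is only single precision there.
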
14 changes: 7 additions & 7 deletions aten/src/THC/generic/THCTensorRandom.cu
@@ -490,11 +490,11 @@ THC_API void THCTensor_(clampedRandom)(THCState* state, THCTensor *self_, int64_
#if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
if (range > 1ULL << 32) {
generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, size, data, min_val, range);
gen->state.gen_states, static_cast<int32_t>(size), data, min_val, range);
} else {
#endif
generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, size, data, min_val, range);
gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int32_t>(min_val), static_cast<uint32_t>(range));
#if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
}
#endif
@@ -520,19 +520,19 @@ THC_API void THCTensor_(random)(THCState* state, THCTensor *self_)

#if defined(THC_REAL_IS_HALF)
generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, size, data, 0UL, (1UL << HLF_MANT_DIG) + 1);
gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int32_t>(0UL), static_cast<uint32_t>((1UL << HLF_MANT_DIG) + 1));
#elif defined(THC_REAL_IS_FLOAT)
generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, size, data, 0UL, (1UL << FLT_MANT_DIG) + 1);
gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int32_t>(0UL), static_cast<uint32_t>((1UL << FLT_MANT_DIG) + 1));
#elif defined(THC_REAL_IS_DOUBLE)
generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, size, data, 0ULL, (1ULL << DBL_MANT_DIG) + 1);
gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int64_t>(0ULL), static_cast<uint64_t>((1ULL << DBL_MANT_DIG) + 1));
#elif defined(THC_REAL_IS_LONG)
generate_random_64<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, size, data, 0ULL, static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int64_t>(0ULL), static_cast<uint64_t>(std::numeric_limits<real>::max()) + 1);
#else
generate_random<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
gen->state.gen_states, size, data, 0UL, static_cast<uint32_t>(std::numeric_limits<real>::max()) + 1);
gen->state.gen_states, static_cast<int32_t>(size), data, static_cast<int32_t>(0UL), static_cast<uint32_t>(std::numeric_limits<real>::max()) + 1);
#endif

THCTensor_(freeCopyTo)(state, self, self_);
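
The added static_casts make every launch argument's static type match the kernel's parameter list exactly, presumably so the call still resolves once the hipify script (see hipify-python.py below) rewrites the <<<...>>> launch into HIP's macro form, which forwards arguments without CUDA's usual implicit conversions; that motivation is an assumption, the diff itself does not state it. A small sketch of the pattern with a hypothetical 32-bit kernel signature:

#include <cstdint>

// Hypothetical kernel with 32-bit parameters, to illustrate the cast pattern.
__global__ void generate_random_sketch(int32_t size, float* data,
                                       int32_t min_val, uint32_t range) {
  // body omitted; only the signature matters for the illustration
}

void launch_sketch(int64_t size, float* data, int64_t min_val, uint64_t range,
                   cudaStream_t stream) {
  // Cast at the call site so the argument types are identical to the kernel
  // parameters; the launch then survives rewriting into HIP's launch macro.
  generate_random_sketch<<<64, 256, 0, stream>>>(
      static_cast<int32_t>(size), data,
      static_cast<int32_t>(min_val), static_cast<uint32_t>(range));
}
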
4 changes: 2 additions & 2 deletions aten/src/THCUNN/FeatureLPPooling.cu
@@ -193,7 +193,7 @@ featureLPPoolingUpdateOutput(const THCDeviceTensor<T, 4> input,

if (Stride < Width) {
// Shift registers for calculating the next point
RegisterUtils<T, Width>::shiftLeft<Stride>(in);
RegisterUtils<T, Width>::template shiftLeft<Stride>(in);
}
}
}
@@ -377,7 +377,7 @@ featureLPPoolingUpdateGradInput(const THCDeviceTensor<T, 4> gradOutput,

if (Stride < Width) {
// Shift registers for calculating the next point
RegisterUtils<T, Width>::shiftLeft<Stride>(in);
RegisterUtils<T, Width>::template shiftLeft<Stride>(in);
}
}
}
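
The only change here is the added template keyword. RegisterUtils<T, Width> is a dependent type inside these function templates, so without the keyword the compiler has to parse shiftLeft<Stride> as a comparison; nvcc has traditionally been lenient about this, but the clang-based HIP toolchain enforces the rule. A minimal self-contained illustration (RegisterUtilsSketch and advanceWindow are placeholder names, not the real RegisterUtils):

// Minimal illustration of the dependent-name rule behind this change.
template <typename T, int Width>
struct RegisterUtilsSketch {
  template <int Shift>
  static __device__ __forceinline__ void shiftLeft(T* v) {
    for (int i = 0; i + Shift < Width; ++i) {
      v[i] = v[i + Shift];  // shift register contents left by Shift slots
    }
  }
};

template <typename T, int Width, int Stride>
__device__ void advanceWindow(T (&in)[Width]) {
  // RegisterUtilsSketch<T, Width> depends on T and Width, so 'template' is
  // required to tell the compiler that shiftLeft names a member template.
  RegisterUtilsSketch<T, Width>::template shiftLeft<Stride>(in);
}
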
27 changes: 0 additions & 27 deletions tools/amd_build/disabled_features.yaml
@@ -1,13 +1,6 @@
{
"disable_unsupported_hip_calls":
[
{
"path": "aten/src/THC/generic/THCTensorSort.cu",
"functions": {
"thrust::copy": ";",
"thrust::stable_sort_by_key": ";"
}
},
{
"path": "aten/src/THC/THCBlas.cu",
"functions": {
@@ -23,18 +16,6 @@
"HIPBLAS_DATA_HALF": "0"
}
},
{
"path": "aten/src/THCUNN/SoftMaxCommon.cuh",
"functions": {
"cudaOccupancyMaxActiveBlocksPerMultiprocessor": "16"
}
},
{
"path": "aten/src/ATen/native/cuda/SoftMax.cu",
"functions": {
"cudaOccupancyMaxActiveBlocksPerMultiprocessor": "16"
}
},
{
"path": "aten/src/THC/THCStream.cpp",
"functions": {
@@ -142,7 +123,6 @@
"aten/src/ATen/native/cuda/CuFFTUtils.h",
"aten/src/ATen/native/cuda/CuFFTPlanCache.h",
"aten/src/ATen/native/cuda/SpectralOps.cu",
"aten/src/THCUNN/RReLU.cu",
"aten/src/ATen/native/cuda/Distributions.cu"
],
"disabled_functions": [
@@ -186,13 +166,6 @@
"THNN_(SparseLinear_accGradParameters)"
]
},
{
"path": "aten/src/THCUNN/generic/RReLU.cu",
"functions": [
"THNN_(RReLU_updateOutput)",
"THNN_(RReLU_updateGradInput)"
]
},
{
"path": "aten/src/THCUNN/generic/LookupTable.cu",
"functions": [

This file was deleted.

29 changes: 0 additions & 29 deletions tools/amd_build/patches/a_aten_src_THC_THCDeviceUtils.cuh.patch

This file was deleted.

This file was deleted.

13 changes: 9 additions & 4 deletions tools/amd_build/pyHIPIFY/hipify-python.py
@@ -772,7 +772,9 @@ def add_static_casts(directory, extensions, KernelTemplateParams):
kernel_name_with_template = KernelTemplateParams[kernel_name]["kernel_with_template"]
argument_types = KernelTemplateParams[kernel_name]["arg_types"]

old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]]
old_kernel_launch = input_source[arguments[5]["start"]:arguments[-1]["end"]]
full_old_kernel_launch = input_source[arguments[0]["start"]:arguments[-1]["end"]]
full_new_kernel_launch = full_old_kernel_launch
new_kernel_launch = old_kernel_launch

kernel_params = argument_strings[5:]
@@ -788,14 +790,17 @@ def replace_arg(match):
# Update to static_cast, account for cases where argument is at start/end of string
new_kernel_launch = re.sub(r'(^|\W)({0})(\W|$)'.format(re.escape(the_arg)), replace_arg, new_kernel_launch)

# replace kernel arguments in full kernel launch arguments w/ static_cast ones
full_new_kernel_launch = full_new_kernel_launch.replace(old_kernel_launch, new_kernel_launch)

# Add template type
if "THCUNN" in filepath.split("/") and "generic" not in filepath.split("/"):
kernel_name_with_template = kernel_name_with_template.replace("<real>", "<Dtype>")
new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template),
lambda x: kernel_name_with_template, new_kernel_launch)
full_new_kernel_launch = re.sub(r'\b{0}\b'.format(original_kernel_name_with_template),
lambda x: kernel_name_with_template, full_new_kernel_launch)

# Replace Launch
new_output_source = new_output_source.replace(old_kernel_launch, new_kernel_launch)
new_output_source = new_output_source.replace(full_old_kernel_launch, full_new_kernel_launch)

# Overwrite file contents
fileobj.seek(0)