Skip to content

Successful ROCm pytorch build without hcrng and hcsparse. #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions aten/src/ATen/cudnn/Descriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#if CUDNN_VERSION < 7000

#include <curand_kernel.h>
//#include <curand_kernel.h>

/*
Note [cuDNN dropout descriptor initialization]
Expand Down Expand Up @@ -233,7 +233,8 @@ inline cudnnStatus_t cudnnRestoreDropoutDescriptor(
if (ret != CUDNN_STATUS_SUCCESS) return ret;
if (expectedStateSizeInBytes != stateSizeInBytes) return CUDNN_STATUS_INVALID_VALUE;
dropoutDesc->dropout = dropout;
dropoutDesc->nstates = (int)stateSizeInBytes/sizeof(curandState_t);
// dropoutDesc->nstates = (int)stateSizeInBytes/sizeof(curandState_t);
dropoutDesc->nstates = (int)stateSizeInBytes;
dropoutDesc->states = states;
return CUDNN_STATUS_SUCCESS;
}
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/Distributions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
#include "ATen/cuda/CUDAApplyUtils.cuh"
#include "ATen/AccumulateType.h"

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
//#include <curand.h>
//#include <curand_kernel.h>
//#include <curand_philox4x32_x.h>
#include <utility>
#include <functional>
#include <nvfunctional>
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/Embedding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ __global__ void renorm_kernel(
} else if (norm_type == 2) {
v += x * x;
} else {
v += std::pow(x, norm_type);
//v += std::pow(x, norm_type);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jithunnair-amd I don't think it's a good way to change the behavior of an operator. I'd be in favor of disabling the whole operator instead of changing its behavior just to get things compiled. It would trigger future issue later on. Moreover not comment has been provided on these changes which makes it very hard to fix them later on.

}
}

using Op = ReduceAdd<accscalar_t>;
v = reduceBlock<accscalar_t>(sdata, blockDim.x, v, Op(), 0);

if (tid == 0) {
sdata[0] = std::pow(v, static_cast<accscalar_t>(1.0 / norm_type));
//sdata[0] = std::pow(v, static_cast<accscalar_t>(1.0 / norm_type));
}
__syncthreads();

Expand Down
28 changes: 14 additions & 14 deletions aten/src/ATen/native/cuda/Gesv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ void magmaGesvBatched<double>(
dB_array, lddb, dinfo_array, batch_count, queue);
}

static magma_queue_t createMagmaQueue(const Tensor& tensor) {
auto& context = tensor.type().get_context();
magma_queue_t magma_queue;
magma_queue_create_from_cuda(
tensor.get_device(),
context.getCurrentCUDAStream(),
THCState_getCurrentBlasHandle(context.getTHCState()),
THCState_getCurrentSparseHandle(context.getTHCState()),
&magma_queue);
return magma_queue;
}
//static magma_queue_t createMagmaQueue(const Tensor& tensor) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jithunnair-amd same as the rest of the PR. No comment has been made regarding to the rationale disabling the logic crucial to the operator here. Instead of disabling the implementation of a particular operator I'd like to have the operators themselves be disabled.

// auto& context = tensor.type().get_context();
// magma_queue_t magma_queue;
// magma_queue_create_from_cuda(
// tensor.get_device(),
// context.getCurrentCUDAStream(),
// THCState_getCurrentBlasHandle(context.getTHCState()),
// THCState_getCurrentSparseHandle(context.getTHCState()),
// &magma_queue);
// return magma_queue;
//}

static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
auto result = static_cast<magma_int_t>(value);
Expand Down Expand Up @@ -116,9 +116,9 @@ AT_ERROR("gesv: MAGMA library not found in "
ipiv_array[i] = &ipiv_data[i * n];
}

magmaGesvBatched<scalar_t>(
n, nrhs, A_array, n, ipiv_array, b_array, n,
info_array, batch_size, createMagmaQueue(b));
// magmaGesvBatched<scalar_t>(
// n, nrhs, A_array, n, ipiv_array, b_array, n,
// info_array, batch_size, createMagmaQueue(b));

for (int64_t i = 0; i < batch_size; i++) {
infos[i] = info_array[i];
Expand Down
22 changes: 11 additions & 11 deletions aten/src/ATen/native/cuda/SoftMax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace {
template<typename T, typename AccumT>
struct LogSoftMaxForwardEpilogue {
__device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
: logsum(max_input + std::log(sum)) {}
: logsum(max_input /*+ std::log(sum)*/ ) {}

__device__ __forceinline__ T operator()(T input) const {
return static_cast<T>(input - logsum);
Expand All @@ -33,7 +33,7 @@ struct LogSoftMaxBackwardEpilogue {
: sum(sum) {}

__device__ __forceinline__ T operator()(T gradOutput, T output) const {
return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
return static_cast<T>(gradOutput /*- std::exp(static_cast<AccumT>(output)) * sum */ );
}

const AccumT sum;
Expand All @@ -46,7 +46,7 @@ struct SoftMaxForwardEpilogue {
, sum(sum) {}

__device__ __forceinline__ T operator()(T input) const {
return static_cast<T>(std::exp(input - max_input) / sum);
return static_cast<T>(0); // std::exp(input - max_input) / sum);
}

const AccumT max_input;
Expand Down Expand Up @@ -203,9 +203,9 @@ __global__ void cunn_SpatialSoftMaxForward(
max_input = spatialBlockReduceX<accscalar_t, Max>(sdata,max_input);

accscalar_t sum = 0;
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
- max_input);
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {}
//sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
// - max_input);
sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum);

Epilogue<scalar_t, accscalar_t> epilogue(max_input, sum);
Expand All @@ -218,9 +218,9 @@ __global__ void cunn_SpatialSoftMaxForward(
max_input = Max<accscalar_t>()(max_input, value);
}
accscalar_t sum = 0;
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
- max_input);
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {}
//sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
// - max_input);
Epilogue<scalar_t, accscalar_t> epilogue(max_input, sum);
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
Expand Down Expand Up @@ -284,7 +284,7 @@ template <typename T, typename AccumT>
struct MaxFloat
{
__device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
return ::max(max, (AccumT)v);
return /*::max(max,*/ (AccumT)v /*)*/ ;
}
};

Expand All @@ -303,7 +303,7 @@ struct SumExpFloat
: max_k(v) {}

__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + std::exp(v - max_k);
return sum; // + std::exp(v - max_k);
}

const AccumT max_k;
Expand Down
170 changes: 85 additions & 85 deletions aten/src/THC/THCGeneral.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,12 @@ void THCudaShutdown(THCState* state)
THCublasCheck(cublasDestroy(res->blasHandles[i]));
}
/* Free user defined sparse handles */
for (int i = 0; i < res->numSparseHandles; ++i) {
THCusparseCheck(cusparseDestroy(res->sparseHandles[i]));
}
// for (int i = 0; i < res->numSparseHandles; ++i) {
// THCusparseCheck(cusparseDestroy(res->sparseHandles[i]));
// }

free(res->blasHandles);
free(res->sparseHandles);
// free(res->sparseHandles);
THCStream_free((THCStream*)THCThreadLocal_get(state->currentStreams[dev]));
THCThreadLocal_free(state->currentStreams[dev]);
}
Expand Down Expand Up @@ -354,14 +354,14 @@ void THCState_reserveDeviceSparseHandles(THCState* state, int device, int numSpa
THCudaCheck(cudaGetDevice(&prevDev));
THCudaCheck(cudaSetDevice(device));

size_t size = numSparseHandles * sizeof(cusparseHandle_t);
cusparseHandle_t* handles = (cusparseHandle_t*) realloc(res->sparseHandles, size);
for (int i = res->numSparseHandles; i < numSparseHandles; ++i) {
handles[i] = NULL;
THCusparseCheck(cusparseCreate(&handles[i]));
}
res->sparseHandles = handles;
res->numSparseHandles = numSparseHandles;
// size_t size = numSparseHandles * sizeof(cusparseHandle_t);
// cusparseHandle_t* handles = (cusparseHandle_t*) realloc(res->sparseHandles, size);
// for (int i = res->numSparseHandles; i < numSparseHandles; ++i) {
// handles[i] = NULL;
// THCusparseCheck(cusparseCreate(&handles[i]));
// }
// res->sparseHandles = handles;
// res->numSparseHandles = numSparseHandles;

THCudaCheck(cudaSetDevice(prevDev));
}
Expand Down Expand Up @@ -419,16 +419,16 @@ cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int han
return res->blasHandles[handle - 1];
}

cusparseHandle_t THCState_getDeviceSparseHandle(THCState *state, int device, int handle)
{
if (handle <= 0 || handle > state->numUserSparseHandles) {
THError("%d is not a valid handle, valid range is: (1, %d)",
handle, state->numUserSparseHandles);
}
THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
THCState_reserveDeviceSparseHandles(state, device, handle);
return res->sparseHandles[handle - 1];
}
//cusparseHandle_t THCState_getDeviceSparseHandle(THCState *state, int device, int handle)
//{
// if (handle <= 0 || handle > state->numUserSparseHandles) {
// THError("%d is not a valid handle, valid range is: (1, %d)",
// handle, state->numUserSparseHandles);
// }
// THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
// THCState_reserveDeviceSparseHandles(state, device, handle);
// return res->sparseHandles[handle - 1];
//}

static THCStream* THCState_getStreamOnDevice(THCState* state, int device)
{
Expand Down Expand Up @@ -493,21 +493,21 @@ cublasHandle_t THCState_getCurrentBlasHandle(THCState *state)
return NULL;
}

cusparseHandle_t THCState_getCurrentSparseHandle(THCState *state)
{
/* This is called at the point of kernel execution.
For some debugging code or improperly instrumented kernels,
`state` is null */
if (state) {
int device;
THCudaCheck(cudaGetDevice(&device));

int handle = THCState_getCurrentSparseHandleIndex(state);
return THCState_getDeviceSparseHandle(state, device, handle);
}
THError("THCState and sparseHandles must be set as there is no default sparseHandle");
return NULL;
}
//cusparseHandle_t THCState_getCurrentSparseHandle(THCState *state)
//{
// /* This is called at the point of kernel execution.
// For some debugging code or improperly instrumented kernels,
// `state` is null */
// if (state) {
// int device;
// THCudaCheck(cudaGetDevice(&device));
//
// int handle = THCState_getCurrentSparseHandleIndex(state);
// return THCState_getDeviceSparseHandle(state, device, handle);
// }
// THError("THCState and sparseHandles must be set as there is no default sparseHandle");
// return NULL;
//}

int THCState_getCurrentBlasHandleIndex(THCState *state)
{
Expand Down Expand Up @@ -643,54 +643,54 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
}
}

void __THCusparseCheck(cusparseStatus_t status, const char *file, const int line)
{
if(status != CUSPARSE_STATUS_SUCCESS)
{
const char* errmsg = NULL;

switch(status)
{
case CUSPARSE_STATUS_NOT_INITIALIZED:
errmsg = "library not initialized";
break;

case CUSPARSE_STATUS_ALLOC_FAILED:
errmsg = "resource allocation failed";
break;

case CUSPARSE_STATUS_INVALID_VALUE:
errmsg = "an invalid numeric value was used as an argument";
break;

case CUSPARSE_STATUS_ARCH_MISMATCH:
errmsg = "an absent device architectural feature is required";
break;

case CUSPARSE_STATUS_MAPPING_ERROR:
errmsg = "an access to GPU memory space failed";
break;

case CUSPARSE_STATUS_EXECUTION_FAILED:
errmsg = "the GPU program failed to execute";
break;

case CUSPARSE_STATUS_INTERNAL_ERROR:
errmsg = "an internal operation failed";
break;

case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
errmsg = "the matrix type is not supported by this function";
break;

default:
errmsg = "unknown error";
break;
}

_THError(file, line, "cusparse runtime error : %s", errmsg);
}
}
//void __THCusparseCheck(cusparseStatus_t status, const char *file, const int line)
//{
// if(status != CUSPARSE_STATUS_SUCCESS)
// {
// const char* errmsg = NULL;
//
// switch(status)
// {
// case CUSPARSE_STATUS_NOT_INITIALIZED:
// errmsg = "library not initialized";
// break;
//
// case CUSPARSE_STATUS_ALLOC_FAILED:
// errmsg = "resource allocation failed";
// break;
//
// case CUSPARSE_STATUS_INVALID_VALUE:
// errmsg = "an invalid numeric value was used as an argument";
// break;
//
// case CUSPARSE_STATUS_ARCH_MISMATCH:
// errmsg = "an absent device architectural feature is required";
// break;
//
// case CUSPARSE_STATUS_MAPPING_ERROR:
// errmsg = "an access to GPU memory space failed";
// break;
//
// case CUSPARSE_STATUS_EXECUTION_FAILED:
// errmsg = "the GPU program failed to execute";
// break;
//
// case CUSPARSE_STATUS_INTERNAL_ERROR:
// errmsg = "an internal operation failed";
// break;
//
// case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
// errmsg = "the matrix type is not supported by this function";
// break;
//
// default:
// errmsg = "unknown error";
// break;
// }
//
// _THError(file, line, "cusparse runtime error : %s", errmsg);
// }
//}

void THCSetGCHandler(THCState *state, void (*cutorchGCFunction_)(void *data), void *data )
{
Expand Down
Loading