Skip to content

Commit 9d220e0

Browse files
authored
Merge pull request #204 from jithunnair-amd/enable_sparse_for_rocm
Enable sparse functions for ROCm
2 parents 2cdc205 + a47ea97 commit 9d220e0

File tree

6 files changed

+22
-36
lines changed

6 files changed

+22
-36
lines changed

aten/src/ATen/cuda/CUDAContext.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,9 @@ Allocator* getCUDADeviceAllocator() {
5454
}
5555

5656
/* Handles */
57-
#ifndef __HIP_PLATFORM_HCC__
58-
cusparseHandle_t getCurrentCUDASparseHandle() {
59-
return THCState_getCurrentSparseHandle(at::globalContext().getTHCState());
60-
}
61-
#endif
57+
cusparseHandle_t getCurrentCUDASparseHandle() {
58+
return THCState_getCurrentSparseHandle(at::globalContext().getTHCState());
59+
}
6260

6361
} // namespace cuda
6462

aten/src/ATen/cuda/CUDAContext.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@ AT_API void uncheckedSetCurrentCUDAStream(CUDAStream stream);
5959
AT_API Allocator* getCUDADeviceAllocator();
6060

6161
/* Handles */
62-
#ifndef __HIP_PLATFORM_HCC__
63-
AT_API cusparseHandle_t getCurrentCUDASparseHandle();
64-
#endif
62+
AT_API cusparseHandle_t getCurrentCUDASparseHandle();
6563

6664

6765
} // namespace cuda

aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
namespace at { namespace native { namespace sparse { namespace cuda {
1111

12-
#ifndef __HIP_PLATFORM_HCC__
1312

1413
std::string cusparseGetErrorString(cusparseStatus_t status) {
1514
switch(status)
@@ -224,6 +223,5 @@ void XcoosortByRow(int64_t m, int64_t n, int64_t nnz, int *cooRows, int *cooCols
224223
CUSPARSE_CHECK(cusparseXcoosortByRow(handle, i_m, i_n, i_nnz, cooRows, cooCols, P, pBuffer));
225224
}
226225

227-
#endif
228226

229227
}}}} // namespace at::native::sparse::cuda

aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
namespace at { namespace native {
2626

2727
SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
28-
#ifndef __HIP_PLATFORM_HCC__
2928
int64_t nnz = self._nnz();
3029
if (nnz < 2) {
3130
_get_sparse_impl(self)->set_coalesced(true);
@@ -147,9 +146,6 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
147146

148147
THCudaCheck(cudaGetLastError());
149148
return dst;
150-
#else
151-
AT_ERROR("coalesce_sparse_cuda: HIP not supported");
152-
#endif
153149
}
154150

155151
}} // namespace at::native

aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ namespace at { namespace native {
2222
// Utility functions
2323
// --------------------------------------------------------------------
2424

25-
#ifndef __HIP_PLATFORM_HCC__
2625
namespace {
2726
IntTensor _to_csr_int(const LongTensor& rowIndices, int64_t dim, int64_t nnz) {
2827
IntTensor csr = at::empty({dim+1}, CUDA(kInt));
@@ -32,7 +31,6 @@ namespace {
3231
return csr;
3332
}
3433
}
35-
#endif
3634

3735
// NB: Deleted spaddcmul (aka addcmul_, but not actually wired up), spaddcdiv (not
3836
// wired at all)
@@ -42,7 +40,6 @@ namespace {
4240
// --------------------------------------------------------------------
4341

4442
Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseTensor& sparse_, const Tensor& dense, Scalar beta, Scalar alpha) {
45-
#ifndef __HIP_PLATFORM_HCC__
4643
AT_ASSERT(t.is_cuda()); // dispatch argument
4744
AT_CHECK(r_.is_cuda(), "addmm: expected 'out' to be CUDA, but got CPU");
4845
AT_CHECK(sparse_.is_cuda(), "addmm: expected 'mat1' to be CUDA, but got CPU");
@@ -141,9 +138,6 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT
141138

142139
r_.copy_(r__);
143140
return r_;
144-
#else
145-
AT_ERROR("s_addmm_out_sparse_dense_cuda: HIP not supported");
146-
#endif
147141
}
148142

149143
Tensor s_addmm_sparse_dense_cuda(
@@ -175,7 +169,6 @@ Tensor& s_addmm_sparse_dense_cuda_(
175169
// --------------------------------------------------------------------
176170

177171
SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse_, const Tensor& dense/* , Scalar alpha */) {
178-
#ifndef __HIP_PLATFORM_HCC__
179172
AT_ASSERT(sparse_.is_cuda()); // dispatch argument
180173
AT_CHECK(r_.is_cuda(), "hspmm: expected 'out' to be CUDA, but got CPU");
181174
AT_CHECK(dense.is_cuda(), "hspmm: expected 'mat2' to be CUDA, but got CPU");
@@ -231,9 +224,6 @@ SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse
231224
_get_sparse_impl(r_)->set_indices_and_values_unsafe(indices, values);
232225

233226
return r_;
234-
#else
235-
AT_ERROR("hspmm_out_sparse_cuda: HIP not supported");
236-
#endif
237227
}
238228

239229
SparseTensor hspmm_sparse_cuda(const SparseTensor& sparse, const Tensor& dense) {
@@ -248,7 +238,6 @@ SparseTensor hspmm_sparse_cuda(const SparseTensor& sparse, const Tensor& dense)
248238
// --------------------------------------------------------------------
249239

250240
Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorRef sparse_, at::Scalar value) {
251-
#ifndef __HIP_PLATFORM_HCC__
252241
const SparseTensor& sparse = sparse_.tref;
253242

254243
AT_ASSERT(dense.is_cuda()); // dispatch argument
@@ -339,17 +328,13 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
339328
THCudaCheck(cudaGetLastError());
340329

341330
return r_;
342-
#else
343-
AT_ERROR("add_out_dense_sparse_cuda: HIP not supported");
344-
#endif
345331
}
346332

347333
// --------------------------------------------------------------------
348334
// add(SparseTensor, SparseTensor, Scalar) [broadcasts]
349335
// --------------------------------------------------------------------
350336

351337
SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const SparseTensor& src, Scalar value) {
352-
#ifndef __HIP_PLATFORM_HCC__
353338
AT_ASSERT(t.is_cuda()); // dispatch argument
354339
AT_CHECK(src.is_cuda(), "add: expected 'other' to be CUDA, but got CPU");
355340
AT_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU");
@@ -396,17 +381,13 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
396381
// }
397382

398383
return r_;
399-
#else
400-
AT_ERROR("s_add_out_sparse_cuda: HIP not supported");
401-
#endif
402384
}
403385

404386
// --------------------------------------------------------------------
405387
// mul(SparseTensor, SparseTensor) [broadcasts]
406388
// --------------------------------------------------------------------
407389

408390
SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, const SparseTensor& src_) {
409-
#ifndef __HIP_PLATFORM_HCC__
410391
if (src_.dim() == 0) {
411392
return mul_out_sparse_zerodim(r_, t_, src_);
412393
} else if (t_.dim() == 0) {
@@ -474,9 +455,6 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
474455
_get_sparse_impl(r_)->set_coalesced(true);
475456

476457
return r_;
477-
#else
478-
AT_ERROR("mul_out_sparse_cuda: HIP not supported");
479-
#endif
480458
}
481459

482460
}} // namespace at::native

tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2174,6 +2174,18 @@
21742174
"cusparseOperation_t": ("hipsparseOperation_t", CONV_TYPE, API_SPARSE),
21752175
"cusparseCreate": ("hipsparseCreate", CONV_MATH_FUNC, API_SPARSE),
21762176
"cusparseDestroy": ("hipsparseDestroy", CONV_MATH_FUNC, API_SPARSE),
2177+
"cusparseXcoo2csr": ("hipsparseXcoo2csr", CONV_MATH_FUNC, API_SPARSE),
2178+
"cusparseMatDescr_t": ("hipsparseMatDescr_t", CONV_MATH_FUNC, API_SPARSE),
2179+
"cusparseCreateMatDescr": ("hipsparseCreateMatDescr", CONV_MATH_FUNC, API_SPARSE),
2180+
"cusparseScsrmm2": ("hipsparseScsrmm2", CONV_MATH_FUNC, API_SPARSE),
2181+
"cusparseDcsrmm2": ("hipsparseDcsrmm2", CONV_MATH_FUNC, API_SPARSE),
2182+
"cusparseXcsrsort_bufferSizeExt": ("hipsparseXcsrsort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE),
2183+
"cusparseXcsrsort": ("hipsparseXcsrsort", CONV_MATH_FUNC, API_SPARSE),
2184+
"cusparseXcoosort_bufferSizeExt": ("hipsparseXcoosort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE),
2185+
"cusparseXcoosortByRow": ("hipsparseXcoosortByRow", CONV_MATH_FUNC, API_SPARSE),
2186+
"cusparseSetStream": ("hipsparseSetStream", CONV_MATH_FUNC, API_SPARSE),
2187+
"cusparseCreateIdentityPermutation": ("hipsparseCreateIdentityPermutation", CONV_MATH_FUNC, API_SPARSE),
2188+
"cusparseSetMatIndexBase": ("hipsparseSetMatIndexBase", CONV_MATH_FUNC, API_SPARSE),
21772189
"CUSPARSE_STATUS_SUCCESS": ("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPARSE),
21782190
"CUSPARSE_STATUS_NOT_INITIALIZED": ("HIPSPARSE_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_SPARSE),
21792191
"CUSPARSE_STATUS_ALLOC_FAILED": ("HIPSPARSE_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_SPARSE),
@@ -2183,6 +2195,12 @@
21832195
"CUSPARSE_STATUS_INTERNAL_ERROR": ("HIPSPARSE_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_SPARSE),
21842196
"CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED": ("HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_SPARSE),
21852197
"CUSPARSE_STATUS_ARCH_MISMATCH": ("HIPSPARSE_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_SPARSE),
2198+
"CUSPARSE_STATUS_ZERO_PIVOT": ("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPARSE),
2199+
"CUSPARSE_OPERATION_TRANSPOSE": ("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPARSE),
2200+
"CUSPARSE_OPERATION_NON_TRANSPOSE": ("HIPSPARSE_OPERATION_NON_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPARSE),
2201+
"CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE": ("HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPARSE),
2202+
"CUSPARSE_INDEX_BASE_ZERO": ("HIPSPARSE_INDEX_BASE_ZERO", CONV_NUMERIC_LITERAL, API_SPARSE),
2203+
"CUSPARSE_INDEX_BASE_ONE": ("HIPSPARSE_INDEX_BASE_ONE", CONV_NUMERIC_LITERAL, API_SPARSE),
21862204
}
21872205

21882206
PYTORCH_SPECIFIC_MAPPINGS = {

0 commit comments

Comments
 (0)