Skip to content

Commit f8aa318

Browse files
authored
Enable sparse functionality on ROCm (#241)
* Enable sparse functions for ROCm * Reenable test_sparse unit tests that are now passing in ROCm (#208) * Reenable test_sparse unit tests that are now passing * It's a flaky test for us - skip.
1 parent bcc2a05 commit f8aa318

File tree

5 files changed

+18
-40
lines changed

5 files changed

+18
-40
lines changed

aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
namespace at { namespace native { namespace sparse { namespace cuda {
1111

12-
#ifndef __HIP_PLATFORM_HCC__
1312

1413
std::string cusparseGetErrorString(cusparseStatus_t status) {
1514
switch(status)
@@ -224,6 +223,5 @@ void XcoosortByRow(int64_t m, int64_t n, int64_t nnz, int *cooRows, int *cooCols
224223
CUSPARSE_CHECK(cusparseXcoosortByRow(handle, i_m, i_n, i_nnz, cooRows, cooCols, P, pBuffer));
225224
}
226225

227-
#endif
228226

229227
}}}} // namespace at::native::sparse::cuda

aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
namespace at { namespace native {
2626

2727
SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
28-
#ifndef __HIP_PLATFORM_HCC__
2928
int64_t nnz = self._nnz();
3029
if (self.is_coalesced()) {
3130
return self;
@@ -151,9 +150,6 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
151150

152151
THCudaCheck(cudaGetLastError());
153152
return dst;
154-
#else
155-
AT_ERROR("coalesce_sparse_cuda: HIP not supported");
156-
#endif
157153
}
158154

159155
}} // namespace at::native

aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ namespace at { namespace native {
2222
// Utility functions
2323
// --------------------------------------------------------------------
2424

25-
#ifndef __HIP_PLATFORM_HCC__
2625
namespace {
2726
IntTensor _to_csr_int(const LongTensor& rowIndices, int64_t dim, int64_t nnz) {
2827
IntTensor csr = at::empty({dim+1}, CUDA(kInt));
@@ -32,7 +31,6 @@ namespace {
3231
return csr;
3332
}
3433
}
35-
#endif
3634

3735
// NB: Deleted spaddcmul (aka addcmul_, but not actually wired up), spaddcdiv (not
3836
// wired at all)
@@ -42,7 +40,6 @@ namespace {
4240
// --------------------------------------------------------------------
4341

4442
Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseTensor& sparse_, const Tensor& dense, Scalar beta, Scalar alpha) {
45-
#ifndef __HIP_PLATFORM_HCC__
4643
AT_ASSERT(t.is_cuda()); // dispatch argument
4744
AT_CHECK(r_.is_cuda(), "addmm: expected 'out' to be CUDA, but got CPU");
4845
AT_CHECK(sparse_.is_cuda(), "addmm: expected 'mat1' to be CUDA, but got CPU");
@@ -142,9 +139,6 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT
142139

143140
r_.copy_(r__);
144141
return r_;
145-
#else
146-
AT_ERROR("s_addmm_out_sparse_dense_cuda: HIP not supported");
147-
#endif
148142
}
149143

150144
Tensor s_addmm_sparse_dense_cuda(
@@ -176,7 +170,6 @@ Tensor& s_addmm_sparse_dense_cuda_(
176170
// --------------------------------------------------------------------
177171

178172
SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse_, const Tensor& dense/* , Scalar alpha */) {
179-
#ifndef __HIP_PLATFORM_HCC__
180173
AT_ASSERT(sparse_.is_cuda()); // dispatch argument
181174
AT_CHECK(r_.is_cuda(), "hspmm: expected 'out' to be CUDA, but got CPU");
182175
AT_CHECK(dense.is_cuda(), "hspmm: expected 'mat2' to be CUDA, but got CPU");
@@ -232,9 +225,6 @@ SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse
232225
_get_sparse_impl(r_)->set_indices_and_values_unsafe(indices, values);
233226

234227
return r_;
235-
#else
236-
AT_ERROR("hspmm_out_sparse_cuda: HIP not supported");
237-
#endif
238228
}
239229

240230
SparseTensor hspmm_sparse_cuda(const SparseTensor& sparse, const Tensor& dense) {
@@ -249,7 +239,6 @@ SparseTensor hspmm_sparse_cuda(const SparseTensor& sparse, const Tensor& dense)
249239
// --------------------------------------------------------------------
250240

251241
Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorRef sparse_, at::Scalar value) {
252-
#ifndef __HIP_PLATFORM_HCC__
253242
const SparseTensor& sparse = sparse_.tref;
254243

255244
AT_ASSERT(dense.is_cuda()); // dispatch argument
@@ -344,17 +333,13 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR
344333
THCudaCheck(cudaGetLastError());
345334

346335
return r_;
347-
#else
348-
AT_ERROR("add_out_dense_sparse_cuda: HIP not supported");
349-
#endif
350336
}
351337

352338
// --------------------------------------------------------------------
353339
// add(SparseTensor, SparseTensor, Scalar) [broadcasts]
354340
// --------------------------------------------------------------------
355341

356342
SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const SparseTensor& src, Scalar value) {
357-
#ifndef __HIP_PLATFORM_HCC__
358343
AT_ASSERT(t.is_cuda()); // dispatch argument
359344
AT_CHECK(src.is_cuda(), "add: expected 'other' to be CUDA, but got CPU");
360345
AT_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU");
@@ -401,17 +386,13 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
401386
// }
402387

403388
return r_;
404-
#else
405-
AT_ERROR("s_add_out_sparse_cuda: HIP not supported");
406-
#endif
407389
}
408390

409391
// --------------------------------------------------------------------
410392
// mul(SparseTensor, SparseTensor) [broadcasts]
411393
// --------------------------------------------------------------------
412394

413395
SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, const SparseTensor& src_) {
414-
#ifndef __HIP_PLATFORM_HCC__
415396
if (src_.dim() == 0) {
416397
return mul_out_sparse_zerodim(r_, t_, src_);
417398
} else if (t_.dim() == 0) {
@@ -480,9 +461,6 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
480461
_get_sparse_impl(r_)->set_coalesced(true);
481462

482463
return r_;
483-
#else
484-
AT_ERROR("mul_out_sparse_cuda: HIP not supported");
485-
#endif
486464
}
487465

488466
}} // namespace at::native

test/test_sparse.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,6 @@ def test_tensor(x, exp_i, exp_v):
479479
exp_v = self.ValueTensor(2, 3, 0)
480480
test_tensor(x, exp_i, exp_v)
481481

482-
@skipIfRocm
483482
def test_clone(self):
484483
def test_shape(sparse_dims, nnz, with_size):
485484
x = self._gen_sparse(sparse_dims, nnz, with_size)[0]
@@ -824,7 +823,6 @@ def test_spadd_hybrid(self):
824823
self._test_spadd_shape(0, [50, 30, 0], [2, 0])
825824
self._test_spadd_shape(10, [50, 30, 20], [2, 0])
826825

827-
@skipIfRocm
828826
def test_norm(self):
829827
def test_shape(sparse_dims, nnz, with_size):
830828
x, _, _ = self._gen_sparse(sparse_dims, nnz, with_size)
@@ -924,7 +922,6 @@ def test_basic_ops_hybrid(self):
924922
self._test_basic_ops_shape(0, 0, [10, 10, 10], [2, 0])
925923
self._test_basic_ops_shape(0, 0, [10, 10, 0], [2, 0])
926924

927-
@skipIfRocm
928925
def test_add_dense_sparse_mismatch(self):
929926
def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size):
930927
x = torch.zeros(dense_size, dtype=self.value_dtype, device=self.device)
@@ -1198,7 +1195,6 @@ def test_storage_not_null(self):
11981195

11991196
@cuda_only
12001197
@unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
1201-
@skipIfRocm
12021198
def test_same_gpu(self):
12031199
def check_device(x, device_id):
12041200
self.assertEqual(x.get_device(), device_id)
@@ -1308,7 +1304,6 @@ def test_factory(self):
13081304
self.assertEqual(device, sparse_tensor._values().device)
13091305
self.assertEqual(True, sparse_tensor.requires_grad)
13101306

1311-
@skipIfRocm
13121307
def test_factory_size_check(self):
13131308
indices = self.IndexTensor([[1, 2],
13141309
[0, 2]])
@@ -1374,7 +1369,6 @@ def test_factory_empty_indices(self):
13741369
expected_indices = torch.empty((4, 0), dtype=torch.long, device=device)
13751370
self.assertEqual(tensor._indices(), expected_indices)
13761371

1377-
@skipIfRocm
13781372
def test_factory_nnz(self):
13791373
indices = self.IndexTensor([[0]]) # (sparseDims, nnz): (1, 1)
13801374
values = self.ValueTensor([[1, 1], [1, 1]]) # (nnz, ...): (2, 2)
@@ -1408,7 +1402,6 @@ def test_shape(i_shape, v_shape, size, expected_size):
14081402
test_shape([3, 0], [0, 2, 4, 0], [0, 0, 0, 2, 4, 0], [0, 0, 0, 2, 4, 0])
14091403
test_shape([3, 0], [0, 2, 4, 0], [1, 2, 3, 2, 4, 0], [1, 2, 3, 2, 4, 0])
14101404

1411-
@skipIfRocm
14121405
def test_factory_dense_dims(self):
14131406
indices = self.IndexTensor([[0]])
14141407
values = self.ValueTensor([[[1, 1, 1], [1, 1, 1]]])
@@ -1439,7 +1432,6 @@ def test_factory_type_inference(self):
14391432
self.assertEqual(torch.int64, t.dtype)
14401433

14411434
@cuda_only
1442-
@skipIfRocm
14431435
def test_factory_device_type_inference(self):
14441436
# both indices/values are CUDA
14451437
shape = (1, 3)
@@ -1552,7 +1544,6 @@ def test_empty_full(self):
15521544
TestTorch._test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, None)
15531545
TestTorch._test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0'))
15541546

1555-
@skipIfRocm
15561547
def test_is_sparse(self):
15571548
x = torch.randn(3, 3)
15581549
self.assertFalse(x.is_sparse)
@@ -1602,7 +1593,6 @@ def _test_resize_shape(self, x_i, x_v, x_size, y_i, y_v, y_size):
16021593
self.assertEqual(x.to_dense().view(-1)[0:x_v_numel].view(x_v),
16031594
x_dense.view(-1)[0:x_v_numel].view(x_v))
16041595

1605-
@skipIfRocm
16061596
def test_resize(self):
16071597
# 1. Expand the size of some dense dimensions [Supported]
16081598
self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3],
@@ -1693,7 +1683,6 @@ def setUp(self):
16931683

16941684
class TestSparseOneOff(TestCase):
16951685
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
1696-
@skipIfRocm
16971686
def test_cuda_from_cpu(self):
16981687
with self.assertRaisesRegex(
16991688
RuntimeError,
@@ -1717,7 +1706,6 @@ def test_cuda_from_cpu(self):
17171706
[0, 4, 4, 0])
17181707

17191708
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
1720-
@skipIfRocm
17211709
def test_cuda_sparse_cpu_dense_add(self):
17221710
x = torch.zeros(3, 4, 4)
17231711
sparse_y = torch.cuda.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(),

tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2174,6 +2174,18 @@
21742174
("cusparseOperation_t", ("hipsparseOperation_t", CONV_TYPE, API_SPARSE)),
21752175
("cusparseCreate", ("hipsparseCreate", CONV_MATH_FUNC, API_SPARSE)),
21762176
("cusparseDestroy", ("hipsparseDestroy", CONV_MATH_FUNC, API_SPARSE)),
2177+
("cusparseXcoo2csr", ("hipsparseXcoo2csr", CONV_MATH_FUNC, API_SPARSE)),
2178+
("cusparseMatDescr_t", ("hipsparseMatDescr_t", CONV_MATH_FUNC, API_SPARSE)),
2179+
("cusparseCreateMatDescr", ("hipsparseCreateMatDescr", CONV_MATH_FUNC, API_SPARSE)),
2180+
("cusparseScsrmm2", ("hipsparseScsrmm2", CONV_MATH_FUNC, API_SPARSE)),
2181+
("cusparseDcsrmm2", ("hipsparseDcsrmm2", CONV_MATH_FUNC, API_SPARSE)),
2182+
("cusparseXcsrsort_bufferSizeExt", ("hipsparseXcsrsort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE)),
2183+
("cusparseXcsrsort", ("hipsparseXcsrsort", CONV_MATH_FUNC, API_SPARSE)),
2184+
("cusparseXcoosort_bufferSizeExt", ("hipsparseXcoosort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE)),
2185+
("cusparseXcoosortByRow", ("hipsparseXcoosortByRow", CONV_MATH_FUNC, API_SPARSE)),
2186+
("cusparseSetStream", ("hipsparseSetStream", CONV_MATH_FUNC, API_SPARSE)),
2187+
("cusparseCreateIdentityPermutation", ("hipsparseCreateIdentityPermutation", CONV_MATH_FUNC, API_SPARSE)),
2188+
("cusparseSetMatIndexBase", ("hipsparseSetMatIndexBase", CONV_MATH_FUNC, API_SPARSE)),
21772189
("CUSPARSE_STATUS_SUCCESS", ("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPARSE)),
21782190
("CUSPARSE_STATUS_NOT_INITIALIZED", ("HIPSPARSE_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_SPARSE)),
21792191
("CUSPARSE_STATUS_ALLOC_FAILED", ("HIPSPARSE_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_SPARSE)),
@@ -2183,6 +2195,12 @@
21832195
("CUSPARSE_STATUS_INTERNAL_ERROR", ("HIPSPARSE_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_SPARSE)),
21842196
("CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", ("HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_SPARSE)),
21852197
("CUSPARSE_STATUS_ARCH_MISMATCH", ("HIPSPARSE_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_SPARSE)),
2198+
("CUSPARSE_STATUS_ZERO_PIVOT", ("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPARSE)),
2199+
("CUSPARSE_OPERATION_TRANSPOSE", ("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPARSE)),
2200+
("CUSPARSE_OPERATION_NON_TRANSPOSE", ("HIPSPARSE_OPERATION_NON_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPARSE)),
2201+
("CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE", ("HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPARSE)),
2202+
("CUSPARSE_INDEX_BASE_ZERO", ("HIPSPARSE_INDEX_BASE_ZERO", CONV_NUMERIC_LITERAL, API_SPARSE)),
2203+
("CUSPARSE_INDEX_BASE_ONE", ("HIPSPARSE_INDEX_BASE_ONE", CONV_NUMERIC_LITERAL, API_SPARSE)),
21862204
])
21872205

21882206
PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict([

0 commit comments

Comments
 (0)