Skip to content

Commit c9e1ee4

Browse files
authored
[SOW MS3] test_sparse (#1034)
* [ROCM] Enable some sparse tests on ROCM Enabling: test_sampled_addmm_errors_cuda_complex128 test_sampled_addmm_errors_cuda_complex64 test_sampled_addmm_errors_cuda_float32 test_sampled_addmm_errors_cuda_float64 test_sparse_add_cuda_complex128 test_sparse_add_cuda_complex64 * [ROCm] unskip * Hipified Sparse code to enable UTs * [ROCm] Enable sparse bmm tests * [ROCm] Enable test_csr_matvec * [ROCm] Clean up some comments * [ROCm] re-adding cuda-hip mappings. * [ROCm] Added support for baddbmm and enabled test These changes require ROCm 5.2 * [ROCm] More hipify and descriptors.
1 parent 49780db commit c9e1ee4

File tree

10 files changed

+78
-23
lines changed

10 files changed

+78
-23
lines changed

aten/src/ATen/cuda/CUDASparse.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
// cuSparse Generic API added in CUDA 10.1
66
// Windows support added in CUDA 11.0
7-
// ROCm is not enabled
8-
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32)))
7+
#if defined(USE_ROCM) || (defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32))))
98
#define AT_USE_CUSPARSE_GENERIC_API() 1
109
#else
1110
#define AT_USE_CUSPARSE_GENERIC_API() 0

aten/src/ATen/cuda/CUDASparseDescriptors.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ CuSparseSpMatCsrDescriptor::CuSparseSpMatCsrDescriptor(const Tensor& input, int6
175175
value_type // data type of values
176176
));
177177

178-
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
178+
#if (defined(USE_ROCM) && ROCM_VERSION >= 50200) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11000)
179179
if (ndim == 3 && batch_offset == -1) {
180180
int batch_count =
181181
at::native::cuda_int_cast(at::native::batchCount(input), "batch_count");

aten/src/ATen/cuda/CUDASparseDescriptors.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ class CuSparseDescriptor {
4040
#if defined(USE_ROCM)
4141
// hipSPARSE doesn't define this
4242
using cusparseMatDescr = std::remove_pointer<cusparseMatDescr_t>::type;
43+
using cusparseDnMatDescr = std::remove_pointer<cusparseDnMatDescr_t>::type;
44+
using cusparseDnVecDescr = std::remove_pointer<cusparseDnVecDescr_t>::type;
45+
using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
46+
using cusparseSpGEMMDescr = std::remove_pointer<cusparseSpGEMMDescr_t>::type;
4347
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
4448
using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
4549
using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
@@ -96,11 +100,13 @@ class TORCH_CUDA_CPP_API CuSparseBsrsm2Info
96100

97101
cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
98102

103+
#if defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 50200)
99104
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
100105
: public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
101106
public:
102107
explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
103108
};
109+
#endif
104110

105111
class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
106112
: public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
@@ -116,7 +122,7 @@ class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
116122
public:
117123
explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1);
118124

119-
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
125+
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11000)
120126
std::tuple<int64_t, int64_t, int64_t> get_size() {
121127
int64_t rows, cols, nnz;
122128
TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize(
@@ -190,7 +196,7 @@ class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor
190196
};
191197
#endif
192198

193-
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
199+
#if (defined(USE_ROCM) && ROCM_VERSION >= 50200) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11000)
194200
class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
195201
: public CuSparseDescriptor<cusparseSpGEMMDescr, &cusparseSpGEMM_destroyDescr> {
196202
public:

aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ void spmm(
562562
const Scalar& beta,
563563
const Scalar& alpha,
564564
const Tensor& result) {
565-
#if !AT_USE_CUSPARSE_GENERIC_API()
565+
#if !AT_USE_CUSPARSE_GENERIC_API() || (defined(USE_ROCM) && ROCM_VERSION < 50200)
566566
addmm_out_legacy(mat1, mat2, beta, alpha, result);
567567
#else
568568
c10::MaybeOwned<Tensor> result_ = prepare_dense_matrix_for_cusparse(result);
@@ -672,12 +672,19 @@ void spgemm(
672672
const Scalar& beta,
673673
const Scalar& alpha,
674674
const at::sparse_csr::SparseCsrTensor& C) {
675-
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
675+
#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION < 11000)
676676
TORCH_CHECK(
677677
false,
678678
"Calling addmm with sparse GPU tensors requires compiling ",
679679
"PyTorch with CUDA 11+. ",
680680
"Please use PyTorch built with newer CUDA version.");
681+
#elif defined(USE_ROCM) && ROCM_VERSION < 50200
682+
//TODO: Test if this is reachable
683+
TORCH_CHECK(
684+
false,
685+
"Calling addmm with sparse GPU tensors requires compiling ",
686+
"PyTorch with ROCm 5.2+. ",
687+
"Please use PyTorch built with newer ROCm version.");
681688
#else
682689
// older versions of cusparse on Windows segfault for complex128 dtype
683690
#if defined(_WIN32) && defined(CUSPARSE_VERSION) && CUSPARSE_VERSION < 11400

aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// Using these APIs in any other systems will result in compile-time or run-time failures.
1414
// Their support will be extended in the next releases.
1515

16-
#if defined(CUDART_VERSION) && (CUSPARSE_VERSION >= 11000 || (!defined(_MSC_VER) && CUSPARSE_VERSION >= 10301))
16+
#if (defined(CUDART_VERSION) && (CUSPARSE_VERSION >= 11000 || (!defined(_MSC_VER) && CUSPARSE_VERSION >= 10301))) || (defined(USE_ROCM) && ROCM_VERSION >= 50200)
1717
#define IS_SPMM_AVAILABLE() 1
1818
#else
1919
#define IS_SPMM_AVAILABLE() 0

aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ Tensor bmm_sparse_cuda(const SparseTensor& self, const Tensor& mat2) {
703703
return bmm_out_sparse_cuda(self, mat2, result);
704704
}
705705

706-
#if !(defined(USE_ROCM) || (defined(_MSC_VER) && CUSPARSE_VERSION < 11000))
706+
#if defined(USE_ROCM) || !(defined(_MSC_VER) && CUSPARSE_VERSION < 11000)
707707
__global__ void search_end_matrix_indices_cuda_kernel(
708708
int64_t* mat_el_end_indices,
709709
int64_t num_matrices,
@@ -784,11 +784,9 @@ cudaDataType getTensorCudaDataType(Tensor self) {
784784
#endif
785785

786786
Tensor& bmm_out_sparse_cuda(const SparseTensor& self, const Tensor& mat2, Tensor& result) {
787-
#if defined(USE_ROCM)
788-
TORCH_CHECK(false, "bmm sparse-dense is not supported on HIP");
789-
#elif defined(_MSC_VER) && (CUSPARSE_VERSION < 11000)
787+
#if defined(_MSC_VER) && (CUSPARSE_VERSION < 11000)
790788
TORCH_CHECK(false, "bmm sparse-dense CUDA is not supported on Windows with cuda before 11.0");
791-
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 10010) // linux cuda >= 10.1 or windows cuda >= 11.0
789+
#elif defined(USE_ROCM) || (defined(CUDART_VERSION) && (CUDART_VERSION >= 10010)) // linux cuda >= 10.1 or windows cuda >= 11.0
792790

793791
TORCH_CHECK(!mat2.is_sparse(), "bmm_sparse: Tensor 'mat2' must be dense");
794792
TORCH_CHECK(self.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self.dense_dim());

test/test_sparse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch.testing import make_tensor
1010
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
1111
do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS, gradcheck, coalescedonoff, \
12-
DeterministicGuard, first_sample
12+
DeterministicGuard, first_sample, TEST_WITH_ROCM
1313
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
1414
from numbers import Number
1515
from typing import Dict, Any
@@ -1135,7 +1135,7 @@ def test_shape(di, dj, dk, nnz):
11351135
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
11361136
)
11371137
@unittest.skipIf(
1138-
TEST_CUDA and _get_torch_cuda_version() < (10, 1),
1138+
TEST_CUDA and _get_torch_cuda_version() < (10, 1) and not TEST_WITH_ROCM,
11391139
"bmm sparse-dense requires CUDA 10.1 or greater"
11401140
)
11411141
@coalescedonoff
@@ -1197,7 +1197,7 @@ def test_shape(num_mats, dim_i, dim_j, dim_k, nnz):
11971197
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
11981198
)
11991199
@unittest.skipIf(
1200-
_get_torch_cuda_version() < (10, 1),
1200+
_get_torch_cuda_version() < (10, 1) and not TEST_WITH_ROCM,
12011201
"bmm sparse-dense requires CUDA 10.1 or greater"
12021202
)
12031203
def test_bmm_deterministic(self, device, dtype, coalesced):

test/test_sparse_csr.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,10 @@ def test_matmul_device_mismatch(self, device, dtype):
844844
*[torch.half] if SM53OrLater else [],
845845
*[torch.bfloat16] if SM80OrLater else []))
846846
def test_csr_matvec(self, device, dtype):
847+
848+
if TEST_WITH_ROCM and (dtype == torch.half or dtype == torch.bfloat16):
849+
self.skipTest("ROCm doesn't work with half dtypes correctly.")
850+
847851
side = 100
848852
for index_dtype in [torch.int32, torch.int64]:
849853
csr = self.genSparseCSRTensor((side, side), 1000, device=device, dtype=dtype, index_dtype=index_dtype)
@@ -860,7 +864,7 @@ def test_csr_matvec(self, device, dtype):
860864
csr.matmul(bad_vec)
861865

862866
@onlyCUDA
863-
@unittest.skipIf(not CUDA11OrLater, "Only CUDA 11+ is supported")
867+
@unittest.skipIf(not (CUDA11OrLater or TEST_WITH_ROCM), "Only CUDA 11+ is supported")
864868
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
865869
def test_baddbmm(self, device, dtype):
866870
def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
@@ -1462,9 +1466,6 @@ def _test_spadd_shape(fn, nnz, shape):
14621466
def test_sparse_add(self, device, dtype):
14631467
def run_test(m, n, index_dtype):
14641468

1465-
if TEST_WITH_ROCM and dtype.is_complex:
1466-
self.skipTest("ROCm doesn't work with complex dtype correctly.")
1467-
14681469
alpha = random.random()
14691470
nnz1 = random.randint(0, m * n)
14701471
nnz2 = random.randint(0, m * n)
@@ -1668,10 +1669,9 @@ def run_test(c, a, b):
16681669
b = make_tensor((k, n), dtype=dtype, device=device)
16691670
run_test(c, a, b)
16701671

1671-
@skipCUDAIfRocm
16721672
@onlyCUDA
16731673
@skipCUDAIf(
1674-
not _check_cusparse_sddmm_available(),
1674+
not (TEST_WITH_ROCM or _check_cusparse_sddmm_available()),
16751675
"cuSparse Generic API SDDMM is not available"
16761676
)
16771677
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)

torch/testing/_internal/common_device_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1310,7 +1310,7 @@ def wrap_fn(self, *args, **kwargs):
13101310

13111311
# Skips a test on CUDA if cuSparse generic API is not available
13121312
def skipCUDAIfNoCusparseGeneric(fn):
1313-
return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")(fn)
1313+
return skipCUDAIf(not TEST_WITH_ROCM and not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")(fn)
13141314

13151315
def skipCUDAIfNoCudnn(fn):
13161316
return skipCUDAIfCudnnVersionLessThan(0)(fn)

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7884,6 +7884,51 @@
78847884
("hipsparseSetMatIndexBase", CONV_MATH_FUNC, API_SPARSE),
78857885
),
78867886
("cusparseSetMatType", ("hipsparseSetMatType", CONV_MATH_FUNC, API_SPARSE)),
7887+
("cusparseSpMV", ("hipsparseSpMV", CONV_MATH_FUNC, API_SPARSE)),
7888+
("cusparseSpMV_bufferSize", ("hipsparseSpMV_bufferSize", CONV_MATH_FUNC, API_SPARSE)),
7889+
("cusparseSpMM", ("hipsparseSpMM", CONV_MATH_FUNC, API_SPARSE)),
7890+
("cusparseSpMM_bufferSize", ("hipsparseSpMM_bufferSize", CONV_MATH_FUNC, API_SPARSE)),
7891+
("cusparseCreateDnMat", ("hipsparseCreateDnMat", CONV_MATH_FUNC, API_SPARSE)),
7892+
("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPARSE)),
7893+
("cusparseCsrSetStridedBatch", ("hipsparseCsrSetStridedBatch", CONV_MATH_FUNC, API_SPARSE)),
7894+
("cusparseCreateDnVec", ("hipsparseCreateDnVec", CONV_MATH_FUNC, API_SPARSE)),
7895+
("cusparseCreateCsr", ("hipsparseCreateCsr", CONV_MATH_FUNC, API_SPARSE)),
7896+
("cusparseDestroyDnMat", ("hipsparseDestroyDnMat", CONV_MATH_FUNC, API_SPARSE)),
7897+
("cusparseDestroyDnVec", ("hipsparseDestroyDnVec", CONV_MATH_FUNC, API_SPARSE)),
7898+
("cusparseDestroySpMat", ("hipsparseDestroySpMat", CONV_MATH_FUNC, API_SPARSE)),
7899+
("cusparseSpGEMM_destroyDescr", ("hipsparseSpGEMM_destroyDescr", CONV_MATH_FUNC, API_SPARSE)),
7900+
("cusparseCreateCoo", ("hipsparseCreateCoo", CONV_MATH_FUNC, API_SPARSE)),
7901+
("cusparseCreateCsr", ("hipsparseCreateCsr", CONV_MATH_FUNC, API_SPARSE)),
7902+
("cusparseSpGEMM_createDescr", ("hipsparseSpGEMM_createDescr", CONV_MATH_FUNC, API_SPARSE)),
7903+
("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPARSE)),
7904+
("cusparseSpGEMM_copy", ("hipsparseSpGEMM_copy", CONV_MATH_FUNC, API_SPARSE)),
7905+
("cusparseSpGEMM_compute", ("hipsparseSpGEMM_compute", CONV_MATH_FUNC, API_SPARSE)),
7906+
("cusparseSpGEMM_workEstimation", ("hipsparseSpGEMM_workEstimation", CONV_MATH_FUNC, API_SPARSE)),
7907+
("cusparseSpMatGetSize", ("hipsparseSpMatGetSize", CONV_MATH_FUNC, API_SPARSE)),
7908+
("cusparseCsrSetPointers", ("hipsparseCsrSetPointers", CONV_MATH_FUNC, API_SPARSE)),
7909+
("cusparseSpMVAlg_t", ("hipsparseSpMVAlg_t", CONV_TYPE, API_SPARSE)),
7910+
("cusparseSpMMAlg_t", ("hipsparseSpMMAlg_t", CONV_TYPE, API_SPARSE)),
7911+
("cusparseIndexType_t", ("hipsparseIndexType_t", CONV_TYPE, API_SPARSE)),
7912+
# Unsupported ("cusparseMatDescr", ("hipsparseMatDescr", CONV_TYPE, API_SPARSE)),
7913+
# Unsupported ("cusparseDnMatDescr", ("hipsparseDnMatDescr", CONV_TYPE, API_SPARSE)),
7914+
# Unsupported ("cusparseDnVecDescr", ("hipsparseDnVecDescr", CONV_TYPE, API_SPARSE)),
7915+
# Unsupported ("cusparseSpMatDescr", ("hipsparseSpMatDescr", CONV_TYPE, API_SPARSE)),
7916+
# Unsupported ("cusparseSpGEMMDescr", ("hipsparseSpGEMMDescr", CONV_TYPE, API_SPARSE)),
7917+
("cusparseDnMatDescr_t", ("hipsparseDnMatDescr_t", CONV_TYPE, API_SPARSE)),
7918+
("cusparseDnVecDescr_t", ("hipsparseDnVecDescr_t", CONV_TYPE, API_SPARSE)),
7919+
("cusparseSpMatDescr_t", ("hipsparseSpMatDescr_t", CONV_TYPE, API_SPARSE)),
7920+
("cusparseSpGEMMDescr_t", ("hipsparseSpGEMMDescr_t", CONV_TYPE, API_SPARSE)),
7921+
("CUSPARSE_INDEX_32I", ("HIPSPARSE_INDEX_32I", CONV_NUMERIC_LITERAL, API_SPARSE)),
7922+
("CUSPARSE_INDEX_64I", ("HIPSPARSE_INDEX_64I", CONV_NUMERIC_LITERAL, API_SPARSE)),
7923+
("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COLUMN", CONV_NUMERIC_LITERAL, API_SPARSE)),
7924+
("CUSPARSE_MV_ALG_DEFAULT", ("HIPSPARSE_MV_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPARSE)),
7925+
("CUSPARSE_MM_ALG_DEFAULT", ("HIPSPARSE_MM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPARSE)),
7926+
("CUSPARSE_COOMM_ALG1", ("HIPSPARSE_COOMM_ALG1", CONV_NUMERIC_LITERAL, API_SPARSE)),
7927+
("CUSPARSE_COOMM_ALG2", ("HIPSPARSE_COOMM_ALG2", CONV_NUMERIC_LITERAL, API_SPARSE)),
7928+
("CUSPARSE_COOMM_ALG3", ("HIPSPARSE_COOMM_ALG3", CONV_NUMERIC_LITERAL, API_SPARSE)),
7929+
("CUSPARSE_COOMV_ALG", ("HIPSPARSE_COOMV_ALG", CONV_NUMERIC_LITERAL, API_SPARSE)),
7930+
("CUSPARSE_CSRMM_ALG1", ("HIPSPARSE_CSRMM_ALG1", CONV_NUMERIC_LITERAL, API_SPARSE)),
7931+
("CUSPARSE_SPGEMM_DEFAULT", ("HIPSPARSE_SPGEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_SPARSE)),
78877932
(
78887933
"CUSPARSE_STATUS_SUCCESS",
78897934
("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPARSE),

0 commit comments

Comments
 (0)