
Commit 8fe6997

hubertlu-tw authored and jithunnair-amd committed
Cherry-pick the commit to make TORCH_(CUDABLAS|CUSOLVER)_CHECK usable in custom extensions (ROCm#909)
1 parent b52dfa3 commit 8fe6997

File tree

aten/src/ATen/cuda/CUDABlas.cpp
aten/src/ATen/cuda/CUDASolver.cpp
aten/src/ATen/cuda/Exceptions.h
test/cpp_extensions/cublas_extension.cpp
test/cpp_extensions/cusolver_extension.cpp
test/cpp_extensions/setup.py
test/test_cpp_extensions_aot.py

7 files changed: +73 −4 lines changed
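What the change enables, in a minimal hypothetical sketch (the scale_inplace op and the use of cublasSscal are illustrative only, not part of this commit): a custom C++ extension can now wrap its own cuBLAS calls in TORCH_CUDABLAS_CHECK, because the helper function the macro calls is exported from libtorch.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <cublas_v2.h>

// Hypothetical custom op: scale a CUDA float tensor in place via cuBLAS.
// With _cublasGetErrorEnum exported, the TORCH_CUDABLAS_CHECK expansion links
// cleanly from extension code and turns a bad cublasStatus_t into a
// descriptive c10::Error.
torch::Tensor scale_inplace(torch::Tensor x, double alpha) {
  TORCH_CHECK(x.is_cuda() && x.is_contiguous() && x.dtype() == torch::kFloat,
              "expected a contiguous CUDA float tensor");
  float a = static_cast<float>(alpha);
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasSscal(
      handle, static_cast<int>(x.numel()), &a, x.data_ptr<float>(), 1));
  return x;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scale_inplace", &scale_inplace, "scale a CUDA float tensor with cuBLAS");
}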

aten/src/ATen/cuda/CUDABlas.cpp
Lines changed: 3 additions & 1 deletion

@@ -4,6 +4,8 @@
 
 #include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/Exceptions.h>
+#include <c10/util/irange.h>
+#include <c10/macros/Export.h>
 
 #ifdef __HIP_PLATFORM_HCC__
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
@@ -100,7 +102,7 @@ namespace at {
 namespace cuda {
 namespace blas {
 
-const char* _cublasGetErrorEnum(cublasStatus_t error) {
+C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error) {
   if (error == CUBLAS_STATUS_SUCCESS) {
     return "CUBLAS_STATUS_SUCCESS";
   }
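C10_EXPORT (from the newly included c10/macros/Export.h) controls symbol visibility. As a rough sketch of the idea only, not the verbatim contents of that header (the real definitions also distinguish import from export on Windows builds), it resolves to something like:

// Rough sketch of what C10_EXPORT amounts to; see c10/macros/Export.h for the
// actual, more detailed definitions.
#ifdef _WIN32
#define C10_EXPORT __declspec(dllexport)
#else
#define C10_EXPORT __attribute__((__visibility__("default")))
#endif

Without the export, a library built with hidden default visibility (or a Windows DLL) keeps _cublasGetErrorEnum private, and an extension module whose code expands TORCH_CUDABLAS_CHECK can fail at link or load time with an undefined-symbol error.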

aten/src/ATen/cuda/CUDASolver.cpp
Lines changed: 2 additions & 1 deletion

@@ -2,14 +2,15 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/cuda/CUDASolver.h>
 #include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/macros/Export.h>
 
 #ifdef CUDART_VERSION
 
 namespace at {
 namespace cuda {
 namespace solver {
 
-const char* cusolverGetErrorMessage(cusolverStatus_t status) {
+C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status) {
   switch (status) {
     case CUSOLVER_STATUS_SUCCESS: return "CUSOLVER_STATUS_SUCCES";
     case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUSOLVER_STATUS_NOT_INITIALIZED";
aten/src/ATen/cuda/Exceptions.h
Lines changed: 3 additions & 2 deletions

@@ -2,6 +2,7 @@
 
 #include <cublas_v2.h>
 #include <cusparse.h>
+#include <c10/macros/Export.h>
 
 #ifdef CUDART_VERSION
 #include <cusolver_common.h>
@@ -39,7 +40,7 @@ class CuDNNError : public c10::Error {
 } while (0)
 
 namespace at { namespace cuda { namespace blas {
-const char* _cublasGetErrorEnum(cublasStatus_t error);
+C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
 }}} // namespace at::cuda::blas
 
 #define TORCH_CUDABLAS_CHECK(EXPR) \
@@ -66,7 +67,7 @@ const char *cusparseGetErrorString(cusparseStatus_t status);
 #ifdef CUDART_VERSION
 
 namespace at { namespace cuda { namespace solver {
-const char* cusolverGetErrorMessage(cusolverStatus_t status);
+C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
 }}} // namespace at::cuda::solver
 
 #define TORCH_CUSOLVER_CHECK(EXPR) \
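These declarations are what the check macros reference. As a simplified sketch of why the export matters (the real macro bodies in this header are longer and named TORCH_CUDABLAS_CHECK / TORCH_CUSOLVER_CHECK), the cuBLAS variant is essentially:

// Simplified sketch, not the verbatim macro from Exceptions.h: the key point
// is that the expansion calls at::cuda::blas::_cublasGetErrorEnum, so that
// symbol must be visible wherever the macro is used, including in custom
// extensions built outside libtorch.
#define TORCH_CUDABLAS_CHECK_SKETCH(EXPR)                      \
  do {                                                         \
    cublasStatus_t __err = (EXPR);                             \
    TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS,                \
                "CUDA error: ",                                \
                at::cuda::blas::_cublasGetErrorEnum(__err),    \
                " when calling `" #EXPR "`");                  \
  } while (0)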
test/cpp_extensions/cublas_extension.cpp
Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+#include <iostream>
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cublas_v2.h>
+
+torch::Tensor noop_cublas_function(torch::Tensor x) {
+  cublasHandle_t handle;
+  TORCH_CUDABLAS_CHECK(cublasCreate(&handle));
+  TORCH_CUDABLAS_CHECK(cublasDestroy(handle));
+  return x;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("noop_cublas_function", &noop_cublas_function, "a cublas function");
+}
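For ad-hoc experiments, the same kind of extension can also be JIT-compiled with torch.utils.cpp_extension.load instead of going through setup.py. A small sketch, assuming the file above is saved locally as cublas_extension.cpp and a CUDA device is available:

# Sketch: JIT-compile and load the cuBLAS no-op extension shown above.
import torch
from torch.utils.cpp_extension import load

cublas_extension = load(
    name="cublas_extension",
    sources=["cublas_extension.cpp"],
    verbose=True,
)

x = torch.zeros(100, device="cuda")
print(cublas_extension.noop_cublas_function(x).shape)  # torch.Size([100])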
test/cpp_extensions/cusolver_extension.cpp
Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cusolverDn.h>
+
+
+torch::Tensor noop_cusolver_function(torch::Tensor x) {
+  cusolverDnHandle_t handle;
+  TORCH_CUSOLVER_CHECK(cusolverDnCreate(&handle));
+  TORCH_CUSOLVER_CHECK(cusolverDnDestroy(handle));
+  return x;
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("noop_cusolver_function", &noop_cusolver_function, "a cusolver function");
+}
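Both test extensions are deliberately no-ops: they only create and destroy a library handle. The point is presumably not to exercise cuBLAS or cuSOLVER themselves, but to prove that TORCH_CUDABLAS_CHECK and TORCH_CUSOLVER_CHECK compile, link, and run when used from code outside libtorch.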

test/cpp_extensions/setup.py
Lines changed: 13 additions & 0 deletions

@@ -48,6 +48,19 @@
             'nvcc': ['-O2']})
     ext_modules.append(extension)
 
+if torch.cuda.is_available() and CUDA_HOME is not None:
+    cublas_extension = CUDAExtension(
+        name='torch_test_cpp_extension.cublas_extension',
+        sources=['cublas_extension.cpp']
+    )
+    ext_modules.append(cublas_extension)
+
+    cusolver_extension = CUDAExtension(
+        name='torch_test_cpp_extension.cusolver_extension',
+        sources=['cusolver_extension.cpp']
+    )
+    ext_modules.append(cusolver_extension)
+
 setup(
     name='torch_test_cpp_extension',
     packages=['torch_test_cpp_extension'],
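The same pattern carries over to a standalone, out-of-tree extension that wants to use the TORCH_*_CHECK macros. A minimal setup.py sketch, with hypothetical package and file names that are not part of this commit:

# Sketch of a standalone extension build; 'my_blas_ext' and its source file
# are placeholders.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='my_blas_ext',
    ext_modules=[
        CUDAExtension(
            name='my_blas_ext',
            sources=['my_blas_ext.cpp'],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)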

test/test_cpp_extensions_aot.py
Lines changed: 18 additions & 0 deletions

@@ -80,6 +80,24 @@ def test_cuda_extension(self):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))
 
+    @common.skipIfRocm
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_cublas_extension(self):
+        from torch_test_cpp_extension import cublas_extension
+
+        x = torch.zeros(100, device="cuda", dtype=torch.float32)
+        z = cublas_extension.noop_cublas_function(x)
+        self.assertEqual(z, x)
+
+    @common.skipIfRocm
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_cusolver_extension(self):
+        from torch_test_cpp_extension import cusolver_extension
+
+        x = torch.zeros(100, device="cuda", dtype=torch.float32)
+        z = cusolver_extension.noop_cusolver_function(x)
+        self.assertEqual(z, x)
+
     @unittest.skipIf(IS_WINDOWS, "Not available on Windows")
     def test_no_python_abi_suffix_sets_the_correct_library_name(self):
         # For this test, run_test.py will call `python setup.py install` in the
