File tree 7 files changed +73
-4
lines changed 7 files changed +73
-4
lines changed Original file line number Diff line number Diff line change 4
4
5
5
#include < ATen/cuda/CUDABlas.h>
6
6
#include < ATen/cuda/Exceptions.h>
7
+ #include < c10/util/irange.h>
8
+ #include < c10/macros/Export.h>
7
9
8
10
#ifdef __HIP_PLATFORM_HCC__
9
11
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
@@ -100,7 +102,7 @@ namespace at {
100
102
namespace cuda {
101
103
namespace blas {
102
104
103
- const char * _cublasGetErrorEnum (cublasStatus_t error) {
105
+ C10_EXPORT const char * _cublasGetErrorEnum (cublasStatus_t error) {
104
106
if (error == CUBLAS_STATUS_SUCCESS) {
105
107
return " CUBLAS_STATUS_SUCCESS" ;
106
108
}
Original file line number Diff line number Diff line change 2
2
#include < ATen/NativeFunctions.h>
3
3
#include < ATen/cuda/CUDASolver.h>
4
4
#include < c10/cuda/CUDACachingAllocator.h>
5
+ #include < c10/macros/Export.h>
5
6
6
7
#ifdef CUDART_VERSION
7
8
8
9
namespace at {
9
10
namespace cuda {
10
11
namespace solver {
11
12
12
- const char * cusolverGetErrorMessage (cusolverStatus_t status) {
13
+ C10_EXPORT const char * cusolverGetErrorMessage (cusolverStatus_t status) {
13
14
switch (status) {
14
15
case CUSOLVER_STATUS_SUCCESS: return " CUSOLVER_STATUS_SUCCES" ;
15
16
case CUSOLVER_STATUS_NOT_INITIALIZED: return " CUSOLVER_STATUS_NOT_INITIALIZED" ;
Original file line number Diff line number Diff line change 2
2
3
3
#include < cublas_v2.h>
4
4
#include < cusparse.h>
5
+ #include < c10/macros/Export.h>
5
6
6
7
#ifdef CUDART_VERSION
7
8
#include < cusolver_common.h>
@@ -39,7 +40,7 @@ class CuDNNError : public c10::Error {
39
40
} while (0 )
40
41
41
42
namespace at { namespace cuda { namespace blas {
42
- const char * _cublasGetErrorEnum (cublasStatus_t error);
43
+ C10_EXPORT const char * _cublasGetErrorEnum (cublasStatus_t error);
43
44
}}} // namespace at::cuda::blas
44
45
45
46
#define TORCH_CUDABLAS_CHECK (EXPR ) \
@@ -66,7 +67,7 @@ const char *cusparseGetErrorString(cusparseStatus_t status);
66
67
#ifdef CUDART_VERSION
67
68
68
69
namespace at { namespace cuda { namespace solver {
69
- const char * cusolverGetErrorMessage (cusolverStatus_t status);
70
+ C10_EXPORT const char * cusolverGetErrorMessage (cusolverStatus_t status);
70
71
}}} // namespace at::cuda::solver
71
72
72
73
#define TORCH_CUSOLVER_CHECK (EXPR ) \
Original file line number Diff line number Diff line change
1
+ #include < iostream>
2
+
3
+ #include < torch/extension.h>
4
+ #include < ATen/cuda/CUDAContext.h>
5
+
6
+ #include < cublas_v2.h>
7
+
8
+ torch::Tensor noop_cublas_function (torch::Tensor x) {
9
+ cublasHandle_t handle;
10
+ TORCH_CUDABLAS_CHECK (cublasCreate (&handle));
11
+ TORCH_CUDABLAS_CHECK (cublasDestroy (handle));
12
+ return x;
13
+ }
14
+
15
+ PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
16
+ m.def (" noop_cublas_function" , &noop_cublas_function, " a cublas function" );
17
+ }
Original file line number Diff line number Diff line change
1
+ #include < torch/extension.h>
2
+ #include < ATen/cuda/CUDAContext.h>
3
+
4
+ #include < cusolverDn.h>
5
+
6
+
7
+ torch::Tensor noop_cusolver_function (torch::Tensor x) {
8
+ cusolverDnHandle_t handle;
9
+ TORCH_CUSOLVER_CHECK (cusolverDnCreate (&handle));
10
+ TORCH_CUSOLVER_CHECK (cusolverDnDestroy (handle));
11
+ return x;
12
+ }
13
+
14
+
15
+ PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
16
+ m.def (" noop_cusolver_function" , &noop_cusolver_function, " a cusolver function" );
17
+ }
Original file line number Diff line number Diff line change 48
48
'nvcc' : ['-O2' ]})
49
49
ext_modules .append (extension )
50
50
51
+ if torch .cuda .is_available () and CUDA_HOME is not None :
52
+ cublas_extension = CUDAExtension (
53
+ name = 'torch_test_cpp_extension.cublas_extension' ,
54
+ sources = ['cublas_extension.cpp' ]
55
+ )
56
+ ext_modules .append (cublas_extension )
57
+
58
+ cusolver_extension = CUDAExtension (
59
+ name = 'torch_test_cpp_extension.cusolver_extension' ,
60
+ sources = ['cusolver_extension.cpp' ]
61
+ )
62
+ ext_modules .append (cusolver_extension )
63
+
51
64
setup (
52
65
name = 'torch_test_cpp_extension' ,
53
66
packages = ['torch_test_cpp_extension' ],
Original file line number Diff line number Diff line change @@ -80,6 +80,24 @@ def test_cuda_extension(self):
80
80
# 2 * sigmoid(0) = 2 * 0.5 = 1
81
81
self .assertEqual (z , torch .ones_like (z ))
82
82
83
+ @common .skipIfRocm
84
+ @unittest .skipIf (not TEST_CUDA , "CUDA not found" )
85
+ def test_cublas_extension (self ):
86
+ from torch_test_cpp_extension import cublas_extension
87
+
88
+ x = torch .zeros (100 , device = "cuda" , dtype = torch .float32 )
89
+ z = cublas_extension .noop_cublas_function (x )
90
+ self .assertEqual (z , x )
91
+
92
+ @common .skipIfRocm
93
+ @unittest .skipIf (not TEST_CUDA , "CUDA not found" )
94
+ def test_cusolver_extension (self ):
95
+ from torch_test_cpp_extension import cusolver_extension
96
+
97
+ x = torch .zeros (100 , device = "cuda" , dtype = torch .float32 )
98
+ z = cusolver_extension .noop_cusolver_function (x )
99
+ self .assertEqual (z , x )
100
+
83
101
@unittest .skipIf (IS_WINDOWS , "Not available on Windows" )
84
102
def test_no_python_abi_suffix_sets_the_correct_library_name (self ):
85
103
# For this test, run_test.py will call `python setup.py install` in the
You can’t perform that action at this time.
0 commit comments