diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 9736bd0e166ba9..6904b6d72b2368 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -4,6 +4,8 @@
 
 #include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/Exceptions.h>
+#include <c10/macros/Export.h>
+#include <c10/util/Exception.h>
 
 #ifdef __HIP_PLATFORM_HCC__
 #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
@@ -100,7 +102,7 @@ namespace at {
 namespace cuda {
 namespace blas {
 
-const char* _cublasGetErrorEnum(cublasStatus_t error) {
+C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error) {
   if (error == CUBLAS_STATUS_SUCCESS) {
     return "CUBLAS_STATUS_SUCCESS";
   }
diff --git a/aten/src/ATen/cuda/CUDASolver.cpp b/aten/src/ATen/cuda/CUDASolver.cpp
index 355238f6b2dcf4..8b906d4931feba 100644
--- a/aten/src/ATen/cuda/CUDASolver.cpp
+++ b/aten/src/ATen/cuda/CUDASolver.cpp
@@ -2,6 +2,7 @@
 #include <ATen/Context.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDASolver.h>
+#include <c10/macros/Export.h>
 
 #ifdef CUDART_VERSION
 
@@ -9,7 +10,7 @@ namespace at {
 namespace cuda {
 namespace solver {
 
-const char* cusolverGetErrorMessage(cusolverStatus_t status) {
+C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status) {
   switch (status) {
     case CUSOLVER_STATUS_SUCCESS: return "CUSOLVER_STATUS_SUCCES";
     case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUSOLVER_STATUS_NOT_INITIALIZED";
diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h
index 1414e319656bd3..d3fdfec69c9a62 100644
--- a/aten/src/ATen/cuda/Exceptions.h
+++ b/aten/src/ATen/cuda/Exceptions.h
@@ -2,6 +2,7 @@
 
 #include <cublas_v2.h>
 #include <cusparse.h>
+#include <c10/macros/Export.h>
 
 #ifdef CUDART_VERSION
 #include <cusolver_common.h>
@@ -39,7 +40,7 @@ class CuDNNError : public c10::Error {
   }                                                                 \
 } while (0)
 namespace at { namespace cuda { namespace blas {
-const char* _cublasGetErrorEnum(cublasStatus_t error);
+C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
 }}} // namespace at::cuda::blas
 
 #define TORCH_CUDABLAS_CHECK(EXPR)                                  \
@@ -66,7 +67,7 @@ const char *cusparseGetErrorString(cusparseStatus_t status);
 
 #ifdef CUDART_VERSION
 namespace at { namespace cuda { namespace solver {
-const char* cusolverGetErrorMessage(cusolverStatus_t status);
+C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
 }}} // namespace at::cuda::solver
 
 #define TORCH_CUSOLVER_CHECK(EXPR)                                  \
diff --git a/test/cpp_extensions/cublas_extension.cpp b/test/cpp_extensions/cublas_extension.cpp
new file mode 100644
index 00000000000000..61945b1aa223a3
--- /dev/null
+++ b/test/cpp_extensions/cublas_extension.cpp
@@ -0,0 +1,17 @@
+#include <iostream>
+
+#include <torch/extension.h>
+#include <ATen/cuda/Exceptions.h>
+
+#include <cublas_v2.h>
+
+torch::Tensor noop_cublas_function(torch::Tensor x) {
+  cublasHandle_t handle;
+  TORCH_CUDABLAS_CHECK(cublasCreate(&handle));
+  TORCH_CUDABLAS_CHECK(cublasDestroy(handle));
+  return x;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("noop_cublas_function", &noop_cublas_function, "a cublas function");
+}
diff --git a/test/cpp_extensions/cusolver_extension.cpp b/test/cpp_extensions/cusolver_extension.cpp
new file mode 100644
index 00000000000000..515d09958a8d27
--- /dev/null
+++ b/test/cpp_extensions/cusolver_extension.cpp
@@ -0,0 +1,17 @@
+#include <torch/extension.h>
+#include <ATen/cuda/Exceptions.h>
+
+#include <cusolverDn.h>
+
+
+torch::Tensor noop_cusolver_function(torch::Tensor x) {
+  cusolverDnHandle_t handle;
+  TORCH_CUSOLVER_CHECK(cusolverDnCreate(&handle));
+  TORCH_CUSOLVER_CHECK(cusolverDnDestroy(handle));
+  return x;
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("noop_cusolver_function", &noop_cusolver_function, "a cusolver function");
+}
diff --git a/test/cpp_extensions/setup.py b/test/cpp_extensions/setup.py
index 7888d0e3a88bbd..96301510f05adf 100644
--- a/test/cpp_extensions/setup.py
+++ b/test/cpp_extensions/setup.py
@@ -48,6 +48,19 @@
             'nvcc': ['-O2']})
     ext_modules.append(extension)
 
+if torch.cuda.is_available() and CUDA_HOME is not None:
+    cublas_extension = CUDAExtension(
+        name='torch_test_cpp_extension.cublas_extension',
+        sources=['cublas_extension.cpp']
+    )
+    ext_modules.append(cublas_extension)
+
+    cusolver_extension = CUDAExtension(
+        name='torch_test_cpp_extension.cusolver_extension',
+        sources=['cusolver_extension.cpp']
+    )
+    ext_modules.append(cusolver_extension)
+
 setup(
     name='torch_test_cpp_extension',
     packages=['torch_test_cpp_extension'],
diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py
index cf35e6b13265d9..979f23d0698709 100644
--- a/test/test_cpp_extensions_aot.py
+++ b/test/test_cpp_extensions_aot.py
@@ -80,6 +80,24 @@ def test_cuda_extension(self):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))
 
+    @common.skipIfRocm
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_cublas_extension(self):
+        from torch_test_cpp_extension import cublas_extension
+
+        x = torch.zeros(100, device="cuda", dtype=torch.float32)
+        z = cublas_extension.noop_cublas_function(x)
+        self.assertEqual(z, x)
+
+    @common.skipIfRocm
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_cusolver_extension(self):
+        from torch_test_cpp_extension import cusolver_extension
+
+        x = torch.zeros(100, device="cuda", dtype=torch.float32)
+        z = cusolver_extension.noop_cusolver_function(x)
+        self.assertEqual(z, x)
+
     @unittest.skipIf(IS_WINDOWS, "Not available on Windows")
     def test_no_python_abi_suffix_sets_the_correct_library_name(self):
         # For this test, run_test.py will call `python setup.py install` in the