Skip to content

Cherry-pick the commit to make TORCH_(CUDABLAS|CUSOLVER)_CHECK usable in custom extensions #909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion aten/src/ATen/cuda/CUDABlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/irange.h>
#include <c10/macros/Export.h>

#ifdef __HIP_PLATFORM_HCC__
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
Expand Down Expand Up @@ -100,7 +102,7 @@ namespace at {
namespace cuda {
namespace blas {

const char* _cublasGetErrorEnum(cublasStatus_t error) {
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error) {
if (error == CUBLAS_STATUS_SUCCESS) {
return "CUBLAS_STATUS_SUCCESS";
}
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/cuda/CUDASolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/CUDASolver.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/macros/Export.h>

#ifdef CUDART_VERSION

namespace at {
namespace cuda {
namespace solver {

const char* cusolverGetErrorMessage(cusolverStatus_t status) {
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status) {
switch (status) {
case CUSOLVER_STATUS_SUCCESS: return "CUSOLVER_STATUS_SUCCES";
case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUSOLVER_STATUS_NOT_INITIALIZED";
Expand Down
5 changes: 3 additions & 2 deletions aten/src/ATen/cuda/Exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cublas_v2.h>
#include <cusparse.h>
#include <c10/macros/Export.h>

#ifdef CUDART_VERSION
#include <cusolver_common.h>
Expand Down Expand Up @@ -39,7 +40,7 @@ class CuDNNError : public c10::Error {
} while (0)

namespace at { namespace cuda { namespace blas {
const char* _cublasGetErrorEnum(cublasStatus_t error);
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
}}} // namespace at::cuda::blas

#define TORCH_CUDABLAS_CHECK(EXPR) \
Expand All @@ -66,7 +67,7 @@ const char *cusparseGetErrorString(cusparseStatus_t status);
#ifdef CUDART_VERSION

namespace at { namespace cuda { namespace solver {
const char* cusolverGetErrorMessage(cusolverStatus_t status);
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
}}} // namespace at::cuda::solver

#define TORCH_CUSOLVER_CHECK(EXPR) \
Expand Down
17 changes: 17 additions & 0 deletions test/cpp_extensions/cublas_extension.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <iostream>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <cublas_v2.h>

torch::Tensor noop_cublas_function(torch::Tensor x) {
cublasHandle_t handle;
TORCH_CUDABLAS_CHECK(cublasCreate(&handle));
TORCH_CUDABLAS_CHECK(cublasDestroy(handle));
return x;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("noop_cublas_function", &noop_cublas_function, "a cublas function");
}
17 changes: 17 additions & 0 deletions test/cpp_extensions/cusolver_extension.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <cusolverDn.h>


torch::Tensor noop_cusolver_function(torch::Tensor x) {
cusolverDnHandle_t handle;
TORCH_CUSOLVER_CHECK(cusolverDnCreate(&handle));
TORCH_CUSOLVER_CHECK(cusolverDnDestroy(handle));
return x;
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("noop_cusolver_function", &noop_cusolver_function, "a cusolver function");
}
13 changes: 13 additions & 0 deletions test/cpp_extensions/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@
'nvcc': ['-O2']})
ext_modules.append(extension)

if torch.cuda.is_available() and CUDA_HOME is not None:
cublas_extension = CUDAExtension(
name='torch_test_cpp_extension.cublas_extension',
sources=['cublas_extension.cpp']
)
ext_modules.append(cublas_extension)

cusolver_extension = CUDAExtension(
name='torch_test_cpp_extension.cusolver_extension',
sources=['cusolver_extension.cpp']
)
ext_modules.append(cusolver_extension)

setup(
name='torch_test_cpp_extension',
packages=['torch_test_cpp_extension'],
Expand Down
18 changes: 18 additions & 0 deletions test/test_cpp_extensions_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,24 @@ def test_cuda_extension(self):
# 2 * sigmoid(0) = 2 * 0.5 = 1
self.assertEqual(z, torch.ones_like(z))

@common.skipIfRocm
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
def test_cublas_extension(self):
from torch_test_cpp_extension import cublas_extension

x = torch.zeros(100, device="cuda", dtype=torch.float32)
z = cublas_extension.noop_cublas_function(x)
self.assertEqual(z, x)

@common.skipIfRocm
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
def test_cusolver_extension(self):
from torch_test_cpp_extension import cusolver_extension

x = torch.zeros(100, device="cuda", dtype=torch.float32)
z = cusolver_extension.noop_cusolver_function(x)
self.assertEqual(z, x)

@unittest.skipIf(IS_WINDOWS, "Not available on Windows")
def test_no_python_abi_suffix_sets_the_correct_library_name(self):
# For this test, run_test.py will call `python setup.py install` in the
Expand Down