
CUDA extension with TORCH_CUDABLAS_CHECK throws undefined symbol error at import #67073


Closed
xwang233 opened this issue Oct 22, 2021 · 6 comments
Labels
module: cpp-extensions (Related to torch.utils.cpp_extension), module: cuda (Related to torch.cuda, and CUDA support in general), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

xwang233 (Collaborator) commented Oct 22, 2021

🐛 Bug

CUDA extension with TORCH_CUDABLAS_CHECK throws undefined symbol error

To Reproduce

Prepare two files to build a CUDA extension:

  • cublas_ext.cpp
#include <iostream>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <cublas_v2.h>

void a_cublas_function() {
    printf("hello world\n");

    cublasHandle_t handle;
    TORCH_CUDABLAS_CHECK(cublasCreate(&handle));

    TORCH_CUDABLAS_CHECK(cublasDestroy(handle));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("a_cublas_function", &a_cublas_function, "a cublas function");
}
  • setup.py
from setuptools import setup, Extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

cublas_module = CUDAExtension(
            name='cublas_ext',
            sources=['cublas_ext.cpp']
        )

setup(
    name='cublas_ext_root',
    version='0.1',
    ext_modules=[cublas_module],
    cmdclass={
        'build_ext': BuildExtension.with_options(use_ninja=False)
    }
)

Build the CUDA extension with

pip install -v --no-cache-dir .

Run the test with

python -c 'import torch; import cublas_ext; cublas_ext.a_cublas_function();'

which produces the error message

Traceback (most recent call last):
  File "<string>", line 1, in <module>
ImportError: /home/xwang/.local/lib/python3.9/site-packages/cublas_ext.cpython-39-x86_64-linux-gnu.so: undefined symbol: _ZN2at4cuda4blas19_cublasGetErrorEnumE14cublasStatus_t
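
The mangled name in the error can be demangled to see exactly which function the extension fails to resolve at load time (assuming the binutils `c++filt` tool is available):

```shell
# Demangle the missing symbol to identify the function the extension expects
c++filt _ZN2at4cuda4blas19_cublasGetErrorEnumE14cublasStatus_t
```

This prints `at::cuda::blas::_cublasGetErrorEnum(cublasStatus_t)`, the internal helper that the `TORCH_CUDABLAS_CHECK` macro expands to. You can then check whether `libtorch_cuda.so` actually exports it with `nm -D libtorch_cuda.so | grep _cublasGetErrorEnum`.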

Expected behavior

No error is raised for TORCH_CUDABLAS_CHECK in CUDA extensions.

Environment

PyTorch is built from source at the latest master commit using GCC 10.3.

Collecting environment information...
PyTorch version: 1.11.0a0+gitf56a1a5
Is debug build: False
CUDA used to build PyTorch: 11.4
ROCM used to build PyTorch: N/A

OS: Manjaro Linux (x86_64)
GCC version: (GCC) 11.1.0
Clang version: Could not collect
CMake version: version 3.21.1
Libc version: glibc-2.33

Python version: 3.9.6 (default, Jun 30 2021, 10:22:16)  [GCC 11.1.0] (64-bit runtime)
Python platform: Linux-5.10.60-1-MANJARO-x86_64-with-glibc2.33
Is CUDA available: True
CUDA runtime version: 11.4.100
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 2070 SUPER
GPU 1: NVIDIA GeForce GTX 1070 Ti

Nvidia driver version: 470.63.01
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.2.2
/usr/lib/libcudnn_adv_infer.so.8.2.2
/usr/lib/libcudnn_adv_train.so.8.2.2
/usr/lib/libcudnn_cnn_infer.so.8.2.2
/usr/lib/libcudnn_cnn_train.so.8.2.2
/usr/lib/libcudnn_ops_infer.so.8.2.2
/usr/lib/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy==0.812
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.19.5
[pip3] torch==1.11.0a0+gitf56a1a5
[pip3] torch-tb-profiler==0.2.0
[conda] Could not collect

Additional context

N/A

cc @malfet @zou3519 @crcrpar @ptrblck @ngimel

xwang233 changed the title from "CUDA extension with TORCH_CUDABLAS_CHECK throws undefined symbol error" to "CUDA extension with TORCH_CUDABLAS_CHECK throws undefined symbol error at import" Oct 22, 2021
xwang233 added the module: cpp-extensions label Oct 22, 2021
ptrblck (Collaborator) commented Oct 24, 2021

Could this be failing because TORCH_CUDABLAS_CHECK might be missing the TORCH_CUDA_CPP_API export?
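
If that hypothesis is right, the mechanism can be illustrated outside PyTorch: when a shared library is built with `-fvisibility=hidden`, only symbols carrying an explicit default-visibility attribute (the role the `TORCH_CUDA_CPP_API` macro plays on Linux) end up in the dynamic symbol table. A minimal sketch, not PyTorch's actual build, assuming `g++` and `nm` are available:

```shell
# Two functions, one with an explicit export attribute and one without
cat > vis_demo.cpp <<'EOF'
// Exported: explicit default visibility (what an export macro expands to)
__attribute__((visibility("default")))
const char* exported_helper() { return "ok"; }

// Hidden: no attribute, so -fvisibility=hidden keeps it out of the dynamic
// symbol table, and any client resolving it sees "undefined symbol"
const char* hidden_helper() { return "ok"; }
EOF

g++ -shared -fPIC -fvisibility=hidden vis_demo.cpp -o libvisdemo.so

# Only exported_helper appears as a dynamic symbol
nm -D libvisdemo.so | grep helper
```

The same applies here: `_cublasGetErrorEnum` exists inside `libtorch_cuda.so`, but without the export annotation it is not visible to the extension's loader.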

mrshenli added the module: cuda and triaged labels Oct 24, 2021
crcrpar (Collaborator) commented Oct 25, 2021

TORCH_CUSOLVER_CHECK and TORCH_CUDASPARSE_CHECK have the same issue.
I could fix the former (the cuSOLVER check) by exporting one internal function, but not the latter; I encountered `ImportError: dynamic module does not define module export function` on my workstation.

EDIT: I saw different behaviors of TORCH_CUDASPARSE_CHECK in different environments. Let me revisit this one later.

facebook-github-bot pushed a commit that referenced this issue Oct 30, 2021
…67161)

Summary:
Make `TORCH_CUDABLAS_CHECK` and `TORCH_CUSOLVER_CHECK` available in custom extensions by exporting the internal functions called by both macros.

Rel: #67073

cc xwang233 ptrblck

Pull Request resolved: #67161

Reviewed By: jbschlosser

Differential Revision: D31984694

Pulled By: ngimel

fbshipit-source-id: 0035ecd1398078cf7d3abc23aaefda57aaa31106
byronyi commented Nov 2, 2021

Does this affect version 1.10?

crcrpar added a commit to crcrpar/pytorch that referenced this issue Nov 3, 2021 (pytorch#67161, same summary as above)
crcrpar (Collaborator) commented Nov 3, 2021

@byronyi It seems the macro has never been available in custom extensions, so yes, this includes 1.10.

facebook-github-bot pushed a commit that referenced this issue Nov 4, 2021
Summary:
Skip building extensions on Windows, following #67161 (comment)

Related issue: #67073

cc ngimel xwang233 ptrblck

Pull Request resolved: #67735

Reviewed By: bdhirsh

Differential Revision: D32141250

Pulled By: ngimel

fbshipit-source-id: 9bfdb7cf694c99f6fc8cbe9033a12429b6e4b6fe
crcrpar (Collaborator) commented Nov 10, 2021

I retried TORCH_CUDASPARSE_CHECK in a custom extension and it worked without any changes.

The code used is as follows (cuda_ext.cpp, then setup.py):

#include <iostream>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <cusparse.h>

void a_custom_function() {
    printf("hello world\n");
    cusparseHandle_t sparse_handle;
    TORCH_CUDASPARSE_CHECK(cusparseCreate(&sparse_handle));
    TORCH_CUDASPARSE_CHECK(cusparseDestroy(sparse_handle));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("a_custom_function", &a_custom_function, "a custom function");
}
# setup.py
from setuptools import setup, Extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

cublas_module = CUDAExtension(
            name='cuda_ext',
            sources=['cuda_ext.cpp']
        )

setup(
    name='cuda_ext_root',
    version='0.1',
    ext_modules=[cublas_module],
    cmdclass={
        'build_ext': BuildExtension.with_options(use_ninja=False)
    }
)

My environment is:

PyTorch version: 1.11.0a0+gitcb2a41e
Is debug build: False
CUDA used to build PyTorch: 11.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.19.6
Libc version: glibc-2.31

Python version: 3.8.10 (default, Jun  4 2021, 15:09:15)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.8.0-55-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.4.100
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080
Nvidia driver version: 470.63.01
cuDNN version: Probably one of the following:
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.2
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.20.2
[pip3] pytorch-lightning==1.5.0.dev0
[pip3] torch==1.11.0a0+gitcb2a41e
[pip3] torchmetrics==0.5.1
[conda] blas                      1.0                         mkl  
[conda] magma-cuda111             2.5.2                         1    pytorch
[conda] mkl                       2021.2.0           h06a4308_296  
[conda] mkl-include               2021.2.0           h06a4308_296  
[conda] mkl-service               2.3.0            py38h27cfd23_1  
[conda] mkl_fft                   1.3.0            py38h42c9631_2  
[conda] mkl_random                1.2.1            py38ha9443f7_2  
[conda] mypy_extensions           0.4.3                    py38_0  
[conda] numpy                     1.20.2           py38h2d18471_0  
[conda] numpy-base                1.20.2           py38hfae3a4d_0  
[conda] pytorch-lightning         1.5.0.dev0               pypi_0    pypi
[conda] torch                     1.11.0a0+gitcb2a41e           dev_0    <develop>
[conda] torchmetrics              0.5.1                    pypi_0    pypi

xwang233 (Collaborator, Author) commented

Closed as fixed. Thanks @crcrpar! 😄
