Skip to content

Commit c4796fb

Browse files
committed
Add CUDA version check
1 parent 7ba7cf4 commit c4796fb

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

torchaudio/_extension.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,19 @@ def _init_extension():
100100
pass
101101

102102

103+
def _check_cuda_version():
104+
version = torch.ops.torchaudio.cuda_version()
105+
if version != -1 and torch.version.cuda is not None:
106+
version_str = str(version)
107+
ta_version = f"{version_str[:-3]}.{version_str[-2]}"
108+
t_version = torch.version.cuda
109+
if ta_version != t_version:
110+
raise RuntimeError(
111+
"Detected that PyTorch and TorchAudio were compiled with different CUDA versions. "
112+
f"PyTorch has CUDA version {t_version} whereas TorchAudio has CUDA version {ta_version}. "
113+
"Please install the TorchAudio version that matches your PyTorch version."
114+
)
115+
116+
103117
_init_extension()
118+
_check_cuda_version()

torchaudio/csrc/utils.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include <torch/script.h>
22

3+
#ifdef USE_CUDA
4+
#include <cuda.h>
5+
#endif
6+
37
namespace torchaudio {
48

59
namespace {
@@ -30,12 +34,21 @@ bool is_ffmpeg_available() {
3034
#endif
3135
}
3236

37+
int64_t cuda_version() {
38+
#ifdef USE_CUDA
39+
return CUDA_VERSION;
40+
#else
41+
return -1;
42+
#endif
43+
}
44+
3345
} // namespace
3446

3547
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
3648
m.def("torchaudio::is_sox_available", &is_sox_available);
3749
m.def("torchaudio::is_kaldi_available", &is_kaldi_available);
3850
m.def("torchaudio::is_ffmpeg_available", &is_ffmpeg_available);
51+
m.def("torchaudio::cuda_version", &cuda_version);
3952
}
4053

4154
} // namespace torchaudio

0 commit comments

Comments
 (0)