diff --git a/torchaudio/_extension.py b/torchaudio/_extension.py index 97763abfae..a82816867e 100644 --- a/torchaudio/_extension.py +++ b/torchaudio/_extension.py @@ -100,4 +100,19 @@ def _init_extension(): pass +def _check_cuda_version(): + version = torch.ops.torchaudio.cuda_version() + if version is not None and torch.version.cuda is not None: + version_str = str(version) + ta_version = f"{version_str[:-3]}.{version_str[-2]}" + t_version = torch.version.cuda + if ta_version != t_version: + raise RuntimeError( + "Detected that PyTorch and TorchAudio were compiled with different CUDA versions. " + f"PyTorch has CUDA version {t_version} whereas TorchAudio has CUDA version {ta_version}. " + "Please install the TorchAudio version that matches your PyTorch version." + ) + + _init_extension() +_check_cuda_version() diff --git a/torchaudio/csrc/utils.cpp b/torchaudio/csrc/utils.cpp index 9fa807fd7d..6078fc13c0 100644 --- a/torchaudio/csrc/utils.cpp +++ b/torchaudio/csrc/utils.cpp @@ -1,5 +1,9 @@ #include +#ifdef USE_CUDA +#include +#endif + namespace torchaudio { namespace { @@ -30,12 +34,21 @@ bool is_ffmpeg_available() { #endif } +c10::optional cuda_version() { +#ifdef USE_CUDA + return CUDA_VERSION; +#else + return {}; +#endif +} + } // namespace TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def("torchaudio::is_sox_available", &is_sox_available); m.def("torchaudio::is_kaldi_available", &is_kaldi_available); m.def("torchaudio::is_ffmpeg_available", &is_ffmpeg_available); + m.def("torchaudio::cuda_version", &cuda_version); } } // namespace torchaudio