From 4a159e88e62d295d890b72555934e0e9288cb6b8 Mon Sep 17 00:00:00 2001 From: "Edward Wang (EcoF)" Date: Wed, 22 Jun 2022 00:58:16 -0700 Subject: [PATCH] fix submodule imports by importing functions directly Summary: fixes two sporadic issues from missing attributes: - breaking circular imports - submodule not being imported explicitly Reviewed By: ehhuang Differential Revision: D37071652 fbshipit-source-id: 0680f098384b0fd21076339750e9d1a96186ede3 --- torchvision/extension.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchvision/extension.py b/torchvision/extension.py index ea837c234d3..ae1da9c0d04 100644 --- a/torchvision/extension.py +++ b/torchvision/extension.py @@ -47,10 +47,10 @@ def _check_cuda_version(): """ if not _HAS_OPS: return -1 - import torch + from torch.version import cuda as torch_version_cuda _version = torch.ops.torchvision._cuda_version() - if _version != -1 and torch.version.cuda is not None: + if _version != -1 and torch_version_cuda is not None: tv_version = str(_version) if int(tv_version) < 10000: tv_major = int(tv_version[0]) @@ -58,8 +58,7 @@ def _check_cuda_version(): else: tv_major = int(tv_version[0:2]) tv_minor = int(tv_version[3]) - t_version = torch.version.cuda - t_version = t_version.split(".") + t_version = torch_version_cuda.split(".") t_major = int(t_version[0]) t_minor = int(t_version[1]) if t_major != tv_major or t_minor != tv_minor: