Skip to content

Commit ad832b8

Browse files
committed
Address comments
1 parent 9563600 commit ad832b8

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

test/smoke_test/smoke_test.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def check_nightly_binaries_date(package: str) -> None:
5555
f"Expected {module['name']} to be less then {NIGHTLY_ALLOWED_DELTA} days. But its {date_m_delta}"
5656
)
5757

58-
def cuda_runtime_error():
58+
def test_cuda_runtime_errors_captured():
5959
cuda_exception_missed=True
6060
try:
6161
torch._assert_async(torch.tensor(0, device="cuda"))
@@ -65,22 +65,13 @@ def cuda_runtime_error():
6565
print(f"Caught CUDA exception with success: {e}")
6666
cuda_exception_missed = False
6767
else:
68-
raise(e)
68+
raise e
6969
if(cuda_exception_missed):
7070
raise RuntimeError( f"Expected CUDA RuntimeError but have not received!")
7171

7272
def smoke_test_cuda(package: str) -> None:
7373
if not torch.cuda.is_available() and is_cuda_system:
7474
raise RuntimeError(f"Expected CUDA {gpu_arch_ver}. However CUDA is not loaded.")
75-
if torch.cuda.is_available():
76-
if torch.version.cuda != gpu_arch_ver:
77-
raise RuntimeError(
78-
f"Wrong CUDA version. Loaded: {torch.version.cuda} Expected: {gpu_arch_ver}"
79-
)
80-
print(f"torch cuda: {torch.version.cuda}")
81-
# todo add cudnn version validation
82-
print(f"torch cudnn: {torch.backends.cudnn.version()}")
83-
print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")
8475

8576
if(package == 'all' and is_cuda_system):
8677
for module in MODULES:
@@ -94,6 +85,19 @@ def smoke_test_cuda(package: str) -> None:
9485
version = imported_module._extension._check_cuda_version()
9586
print(f"{module['name']} CUDA: {version}")
9687

88+
if torch.cuda.is_available():
89+
if torch.version.cuda != gpu_arch_ver:
90+
raise RuntimeError(
91+
f"Wrong CUDA version. Loaded: {torch.version.cuda} Expected: {gpu_arch_ver}"
92+
)
93+
print(f"torch cuda: {torch.version.cuda}")
94+
# todo add cudnn version validation
95+
print(f"torch cudnn: {torch.backends.cudnn.version()}")
96+
print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")
97+
98+
# This check has to be run last, since its messing up CUDA runtime
99+
test_cuda_runtime_errors_captured()
100+
97101

98102
def smoke_test_conv2d() -> None:
99103
import torch.nn as nn
@@ -114,6 +118,7 @@ def smoke_test_conv2d() -> None:
114118
with torch.cuda.amp.autocast():
115119
out = conv(x)
116120

121+
117122
def smoke_test_modules():
118123
for module in MODULES:
119124
if module["repo"]:
@@ -141,7 +146,6 @@ def main() -> None:
141146
)
142147
options = parser.parse_args()
143148
print(f"torch: {torch.__version__}")
144-
smoke_test_cuda(options.package)
145149
smoke_test_conv2d()
146150

147151
if options.package == "all":
@@ -151,9 +155,7 @@ def main() -> None:
151155
if installation_str.find("nightly") != -1:
152156
check_nightly_binaries_date(options.package)
153157

154-
# This check has to be run last, since its messing up CUDA runtime
155-
if torch.cuda.is_available():
156-
cuda_runtime_error()
158+
smoke_test_cuda(options.package)
157159

158160

159161
if __name__ == "__main__":

0 commit comments

Comments
 (0)