@@ -55,7 +55,7 @@ def check_nightly_binaries_date(package: str) -> None:
55
55
f"Expected { module ['name' ]} to be less then { NIGHTLY_ALLOWED_DELTA } days. But its { date_m_delta } "
56
56
)
57
57
58
- def cuda_runtime_error ():
58
+ def test_cuda_runtime_errors_captured ():
59
59
cuda_exception_missed = True
60
60
try :
61
61
torch ._assert_async (torch .tensor (0 , device = "cuda" ))
@@ -65,22 +65,13 @@ def cuda_runtime_error():
65
65
print (f"Caught CUDA exception with success: { e } " )
66
66
cuda_exception_missed = False
67
67
else :
68
- raise ( e )
68
+ raise e
69
69
if (cuda_exception_missed ):
70
70
raise RuntimeError ( f"Expected CUDA RuntimeError but have not received!" )
71
71
72
72
def smoke_test_cuda (package : str ) -> None :
73
73
if not torch .cuda .is_available () and is_cuda_system :
74
74
raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
75
- if torch .cuda .is_available ():
76
- if torch .version .cuda != gpu_arch_ver :
77
- raise RuntimeError (
78
- f"Wrong CUDA version. Loaded: { torch .version .cuda } Expected: { gpu_arch_ver } "
79
- )
80
- print (f"torch cuda: { torch .version .cuda } " )
81
- # todo add cudnn version validation
82
- print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
83
- print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
84
75
85
76
if (package == 'all' and is_cuda_system ):
86
77
for module in MODULES :
@@ -94,6 +85,19 @@ def smoke_test_cuda(package: str) -> None:
94
85
version = imported_module ._extension ._check_cuda_version ()
95
86
print (f"{ module ['name' ]} CUDA: { version } " )
96
87
88
+ if torch .cuda .is_available ():
89
+ if torch .version .cuda != gpu_arch_ver :
90
+ raise RuntimeError (
91
+ f"Wrong CUDA version. Loaded: { torch .version .cuda } Expected: { gpu_arch_ver } "
92
+ )
93
+ print (f"torch cuda: { torch .version .cuda } " )
94
+ # todo add cudnn version validation
95
+ print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
96
+ print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
97
+
98
+ # This check has to be run last, since it messes up the CUDA runtime
99
+ test_cuda_runtime_errors_captured ()
100
+
97
101
98
102
def smoke_test_conv2d () -> None :
99
103
import torch .nn as nn
@@ -114,6 +118,7 @@ def smoke_test_conv2d() -> None:
114
118
with torch .cuda .amp .autocast ():
115
119
out = conv (x )
116
120
121
+
117
122
def smoke_test_modules ():
118
123
for module in MODULES :
119
124
if module ["repo" ]:
@@ -141,7 +146,6 @@ def main() -> None:
141
146
)
142
147
options = parser .parse_args ()
143
148
print (f"torch: { torch .__version__ } " )
144
- smoke_test_cuda (options .package )
145
149
smoke_test_conv2d ()
146
150
147
151
if options .package == "all" :
@@ -151,9 +155,7 @@ def main() -> None:
151
155
if installation_str .find ("nightly" ) != - 1 :
152
156
check_nightly_binaries_date (options .package )
153
157
154
- # This check has to be run last, since its messing up CUDA runtime
155
- if torch .cuda .is_available ():
156
- cuda_runtime_error ()
158
+ smoke_test_cuda (options .package )
157
159
158
160
159
161
if __name__ == "__main__" :
0 commit comments