2
2
import re
3
3
import sys
4
4
from pathlib import Path
5
-
5
+ import argparse
6
6
import torch
7
- import torchaudio
8
7
9
- # the following import would invoke
10
- # _check_cuda_version()
11
- # via torchvision.extension._check_cuda_version()
12
- import torchvision
8
+
13
9
14
10
gpu_arch_ver = os .getenv ("GPU_ARCH_VER" )
15
11
gpu_arch_type = os .getenv ("GPU_ARCH_TYPE" )
@@ -38,20 +34,21 @@ def get_anaconda_output_for_package(pkg_name_str):
38
34
return output .strip ().split ('\n ' )[- 1 ]
39
35
40
36
41
- def check_nightly_binaries_date () -> None :
37
+ def check_nightly_binaries_date (package : str ) -> None :
42
38
torch_str = torch .__version__
43
- ta_str = torchaudio .__version__
44
- tv_str = torchvision .__version__
45
-
46
39
date_t_str = re .findall ("dev\d+" , torch .__version__ )
47
- date_ta_str = re .findall ("dev\d+" , torchaudio .__version__ )
48
- date_tv_str = re .findall ("dev\d+" , torchvision .__version__ )
49
40
50
- # check that the above three lists are equal and none of them is empty
51
- if not date_t_str or not date_t_str == date_ta_str == date_tv_str :
52
- raise RuntimeError (
53
- f"Expected torch, torchaudio, torchvision to be the same date. But they are from { date_t_str } , { date_ta_str } , { date_tv_str } respectively"
54
- )
41
+ if (package == "all" ):
42
+ ta_str = torchaudio .__version__
43
+ tv_str = torchvision .__version__
44
+ date_ta_str = re .findall ("dev\d+" , torchaudio .__version__ )
45
+ date_tv_str = re .findall ("dev\d+" , torchvision .__version__ )
46
+
47
+ # check that the above three lists are equal and none of them is empty
48
+ if not date_t_str or not date_t_str == date_ta_str == date_tv_str :
49
+ raise RuntimeError (
50
+ f"Expected torch, torchaudio, torchvision to be the same date. But they are from { date_t_str } , { date_ta_str } , { date_tv_str } respectively"
51
+ )
55
52
56
53
# check that the date is recent, at this point, date_torch_str is not empty
57
54
binary_date_str = date_t_str [0 ][3 :]
@@ -65,8 +62,7 @@ def check_nightly_binaries_date() -> None:
65
62
f"the binaries are from { binary_date_obj } and are more than 2 days old!"
66
63
)
67
64
68
-
69
- def smoke_test_cuda () -> None :
65
+ def smoke_test_cuda (package : str ) -> None :
70
66
if not torch .cuda .is_available () and is_cuda_system :
71
67
raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
72
68
if torch .cuda .is_available ():
@@ -79,23 +75,23 @@ def smoke_test_cuda() -> None:
79
75
print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
80
76
print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
81
77
82
- if installation_str . find ( "nightly" ) != - 1 :
83
- # just print out cuda version, as version check were already performed during import
84
- print ( f"torchvision cuda: { torch . ops . torchvision . _cuda_version () } " )
85
- print (f"torchaudio cuda: { torch .ops .torchaudio . cuda_version ()} " )
86
- else :
87
- # torchaudio runtime added the cuda verison check on 09/23/2022 via
88
- # https://github.com/pytorch/audio/pull/2707
89
- # so relying on anaconda output for pytorch-test and pytorch channel
90
- torchaudio_allstr = get_anaconda_output_for_package ( torchaudio . __name__ )
91
- if (
92
- is_cuda_system
93
- and "cu" + str ( gpu_arch_ver ). replace ( "." , "" ) not in torchaudio_allstr
94
- ):
95
- raise RuntimeError (
96
- f"CUDA version issue. Loaded: { torchaudio_allstr } Expected: { gpu_arch_ver } "
97
- )
98
-
78
+ if ( package == 'all' ) :
79
+ if installation_str . find ( "nightly" ) != - 1 :
80
+ # just print out cuda version, as version check were already performed during import
81
+ print (f"torchvision cuda: { torch .ops .torchvision . _cuda_version ()} " )
82
+ print ( f"torchaudio cuda: { torch . ops . torchaudio . cuda_version () } " )
83
+ else :
84
+ # torchaudio runtime added the cuda verison check on 09/23/2022 via
85
+ # https://github.com/ pytorch/audio/pull/2707
86
+ # so relying on anaconda output for pytorch-test and pytorch channel
87
+ torchaudio_allstr = get_anaconda_output_for_package ( torchaudio . __name__ )
88
+ if (
89
+ is_cuda_system
90
+ and "cu" + str ( gpu_arch_ver ). replace ( "." , "" ) not in torchaudio_allstr
91
+ ):
92
+ raise RuntimeError (
93
+ f"CUDA version issue. Loaded: { torchaudio_allstr } Expected: { gpu_arch_ver } "
94
+ )
99
95
100
96
def smoke_test_conv2d () -> None :
101
97
import torch .nn as nn
@@ -180,24 +176,37 @@ def smoke_test_torchaudio() -> None:
180
176
181
177
182
178
def main () -> None :
183
- # todo add torch, torchvision and torchaudio tests
179
+ parser = argparse .ArgumentParser ()
180
+ parser .add_argument (
181
+ "--package" ,
182
+ help = "Package to include in smoke testing" ,
183
+ type = str ,
184
+ choices = ["all" , "torchonly" ],
185
+ default = "all" ,
186
+ )
187
+
184
188
print (f"torch: { torch .__version__ } " )
185
- print (f"torchvision: { torchvision .__version__ } " )
186
- print (f"torchaudio: { torchaudio .__version__ } " )
187
- smoke_test_cuda ()
189
+ smoke_test_cuda (options .package )
190
+ smoke_test_conv2d ()
188
191
189
192
# only makes sense to check nightly package where dates are known
190
193
if installation_str .find ("nightly" ) != - 1 :
191
194
check_nightly_binaries_date ()
192
195
193
- smoke_test_conv2d ()
194
- smoke_test_torchaudio ()
195
- smoke_test_torchvision ()
196
- smoke_test_torchvision_read_decode ()
197
- smoke_test_torchvision_resnet50_classify ()
198
- if torch .cuda .is_available ():
199
- smoke_test_torchvision_resnet50_classify ("cuda" )
200
-
196
+ if options .package == "all" :
197
+ import torchaudio
198
+ # the following import would invoke
199
+ # _check_cuda_version()
200
+ # via torchvision.extension._check_cuda_version()
201
+ import torchvision
202
+ print (f"torchvision: { torchvision .__version__ } " )
203
+ print (f"torchaudio: { torchaudio .__version__ } " )
204
+ smoke_test_torchaudio ()
205
+ smoke_test_torchvision ()
206
+ smoke_test_torchvision_read_decode ()
207
+ smoke_test_torchvision_resnet50_classify ()
208
+ if torch .cuda .is_available ():
209
+ smoke_test_torchvision_resnet50_classify ("cuda" )
201
210
202
211
if __name__ == "__main__" :
203
212
main ()
0 commit comments