@@ -327,16 +327,14 @@ def run_selected_tests(workflow_name, test_run_test_path, overall_logs_path_curr
327
327
328
328
return selected_results_dict
329
329
330
- def run_test_and_summarize_results () -> Dict [str , Any ]:
331
- # parse args
332
- args = parse_args ()
333
- pytorch_root_dir = str (args .pytorch_root )
334
- priority_tests = bool (args .priority_tests )
335
- test_config = list [str ](args .test_config )
336
- default_list = list [str ](args .default_list )
337
- distributed_list = list [str ](args .distributed_list )
338
- inductor_list = list [str ](args .inductor_list )
339
- skip_rerun = bool (args .skip_rerun )
330
+ def run_test_and_summarize_results (
331
+ pytorch_root_dir : str ,
332
+ priority_tests : bool ,
333
+ test_config : List [str ],
334
+ default_list : List [str ],
335
+ distributed_list : List [str ],
336
+ inductor_list : List [str ],
337
+ skip_rerun : bool ) -> Dict [str , Any ]:
340
338
341
339
# copy current environment variables
342
340
_environ = dict (os .environ )
@@ -388,8 +386,8 @@ def run_test_and_summarize_results() -> Dict[str, Any]:
388
386
CONSOLIDATED_LOG_FILE_PATH = overall_logs_path_current_run + CONSOLIDATED_LOG_FILE_NAME
389
387
390
388
# Check multi gpu availability if distributed tests are enabled
391
- if ("distributed" in args . test_config ) or len (args . distributed_list ) != 0 :
392
- check_num_gpus_for_distributed ();
389
+ if ("distributed" in test_config ) or len (distributed_list ) != 0 :
390
+ check_num_gpus_for_distributed ()
393
391
394
392
# Install test requirements
395
393
command = "pip3 install -r requirements.txt && pip3 install -r .ci/docker/requirements-ci.txt"
@@ -511,7 +509,8 @@ def check_num_gpus_for_distributed():
511
509
assert num_gpus_visible > 1 , "Number of visible GPUs should be >1 to run distributed unit tests"
512
510
513
511
def main ():
514
- all_tests_results = run_test_and_summarize_results ()
512
+ args = parse_args ()
513
+ all_tests_results = run_test_and_summarize_results (args .pytorch_root , args .priority_tests , args .test_config , args .default_list , args .distributed_list , args .inductor_list , args .skip_rerun )
515
514
pprint (dict (all_tests_results ))
516
515
517
516
if __name__ == "__main__" :
0 commit comments