Skip to content

Commit 11f4d99

Browse files
BLOrange-AMDjithunnair-amd
authored andcommitted
Removed args inside function (#1595)
Fixes SWDEV-475071 (cherry picked from commit 041aa1b47978154de63edc6b7ffcdea218a847a3)
1 parent abff421 commit 11f4d99

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

.automation_scripts/run_pytorch_unit_tests.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -327,16 +327,14 @@ def run_selected_tests(workflow_name, test_run_test_path, overall_logs_path_curr
327327

328328
return selected_results_dict
329329

330-
def run_test_and_summarize_results() -> Dict[str, Any]:
331-
# parse args
332-
args = parse_args()
333-
pytorch_root_dir = str(args.pytorch_root)
334-
priority_tests = bool(args.priority_tests)
335-
test_config = list[str](args.test_config)
336-
default_list = list[str](args.default_list)
337-
distributed_list = list[str](args.distributed_list)
338-
inductor_list = list[str](args.inductor_list)
339-
skip_rerun = bool(args.skip_rerun)
330+
def run_test_and_summarize_results(
331+
pytorch_root_dir: str,
332+
priority_tests: bool,
333+
test_config: List[str],
334+
default_list: List[str],
335+
distributed_list: List[str],
336+
inductor_list: List[str],
337+
skip_rerun: bool) -> Dict[str, Any]:
340338

341339
# copy current environment variables
342340
_environ = dict(os.environ)
@@ -388,8 +386,8 @@ def run_test_and_summarize_results() -> Dict[str, Any]:
388386
CONSOLIDATED_LOG_FILE_PATH = overall_logs_path_current_run + CONSOLIDATED_LOG_FILE_NAME
389387

390388
# Check multi gpu availability if distributed tests are enabled
391-
if ("distributed" in args.test_config) or len(args.distributed_list) != 0:
392-
check_num_gpus_for_distributed();
389+
if ("distributed" in test_config) or len(distributed_list) != 0:
390+
check_num_gpus_for_distributed()
393391

394392
# Install test requirements
395393
command = "pip3 install -r requirements.txt && pip3 install -r .ci/docker/requirements-ci.txt"
@@ -511,7 +509,8 @@ def check_num_gpus_for_distributed():
511509
assert num_gpus_visible > 1, "Number of visible GPUs should be >1 to run distributed unit tests"
512510

513511
def main():
514-
all_tests_results = run_test_and_summarize_results()
512+
args = parse_args()
513+
all_tests_results = run_test_and_summarize_results(args.pytorch_root, args.priority_tests, args.test_config, args.default_list, args.distributed_list, args.inductor_list, args.skip_rerun)
515514
pprint(dict(all_tests_results))
516515

517516
if __name__ == "__main__":

0 commit comments

Comments
 (0)