diff --git a/ml-agents/tests/yamato/training_int_tests.py b/ml-agents/tests/yamato/training_int_tests.py index 2bec3c7435..5ef043b0a7 100644 --- a/ml-agents/tests/yamato/training_int_tests.py +++ b/ml-agents/tests/yamato/training_int_tests.py @@ -26,11 +26,10 @@ def run_training(python_version: str, csharp_version: str) -> bool: f"Running training with python={python_version or latest} and c#={csharp_version or latest}" ) output_dir = "models" if python_version else "results" - nn_file_expected = f"./{output_dir}/{run_id}/3DBall.nn" onnx_file_expected = f"./{output_dir}/{run_id}/3DBall.onnx" frozen_graph_file_expected = f"./{output_dir}/{run_id}/3DBall/frozen_graph_def.pb" - if os.path.exists(nn_file_expected): + if os.path.exists(onnx_file_expected): # Should never happen - make sure nothing leftover from an old test. print("Artifacts from previous build found!") return False @@ -96,21 +95,16 @@ def run_training(python_version: str, csharp_version: str) -> bool: if csharp_version is None and python_version is None: model_artifacts_dir = os.path.join(get_base_output_path(), "models") os.makedirs(model_artifacts_dir, exist_ok=True) - shutil.copy(nn_file_expected, model_artifacts_dir) shutil.copy(onnx_file_expected, model_artifacts_dir) shutil.copy(frozen_graph_file_expected, model_artifacts_dir) - if ( - res.returncode != 0 - or not os.path.exists(nn_file_expected) - or not os.path.exists(onnx_file_expected) - ): + if res.returncode != 0 or not os.path.exists(onnx_file_expected): print("mlagents-learn run FAILED!") return False if csharp_version is None and python_version is None: # Use abs path so that loading doesn't get confused - model_path = os.path.abspath(os.path.dirname(nn_file_expected)) + model_path = os.path.abspath(os.path.dirname(onnx_file_expected)) for extension in ["nn", "onnx"]: inference_ok = run_inference(env_path, model_path, extension) if not inference_ok: