diff --git a/.circleci/config.yml b/.circleci/config.yml
index f8f645edcb..555d72da70 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -18,12 +18,17 @@ jobs:
       pip_constraints:
         type: string
         description: Constraints file that is passed to "pip install". We constraint older versions of libraries for older python runtime, in order to help ensure compatibility.
+      enforce_onnx_conversion:
+        type: integer
+        default: 0
+        description: Whether to raise an exception if ONNX models couldn't be saved.
     executor: << parameters.executor >>
     working_directory: ~/repo
 
     # Run additional numpy checks on unit tests
     environment:
       TEST_ENFORCE_NUMPY_FLOAT32: 1
+      TEST_ENFORCE_ONNX_CONVERSION: << parameters.enforce_onnx_conversion >>
 
     steps:
       - checkout
@@ -217,6 +222,8 @@ workflows:
           pyversion: 3.7.3
           # Test python 3.7 with the newest supported versions
           pip_constraints: test_constraints_max_tf1_version.txt
+          # Make sure ONNX conversion passes here (recent version of tensorflow 1.x)
+          enforce_onnx_conversion: 1
       - build_python:
           name: python_3.7.3+tf2
           executor: python373
diff --git a/docs/Unity-Inference-Engine.md b/docs/Unity-Inference-Engine.md
index 4dded6b883..a7d7626e0b 100644
--- a/docs/Unity-Inference-Engine.md
+++ b/docs/Unity-Inference-Engine.md
@@ -33,7 +33,7 @@ There are currently two supported model formats:
 * ONNX (`.onnx`) files use an [industry-standard open format](https://onnx.ai/about.html) produced by the [tf2onnx package](https://github.com/onnx/tensorflow-onnx).
 
 Export to ONNX is currently considered beta. To enable it, make sure `tf2onnx>=1.5.5` is installed in pip.
-tf2onnx does not currently support tensorflow 2.0.0 or later.
+tf2onnx does not currently support tensorflow 2.0.0 or later, or earlier than 1.12.0.
 
 ## Using the Unity Inference Engine
 
diff --git a/ml-agents/mlagents/model_serialization.py b/ml-agents/mlagents/model_serialization.py
index 449a4742f0..04cdc27647 100644
--- a/ml-agents/mlagents/model_serialization.py
+++ b/ml-agents/mlagents/model_serialization.py
@@ -1,5 +1,8 @@
+from distutils.util import strtobool
+import os
 import logging
 from typing import Any, List, Set, NamedTuple
+from distutils.version import LooseVersion
 
 try:
     import onnx
@@ -18,6 +21,11 @@ from tensorflow.python.framework import graph_util
 from mlagents.trainers import tensorflow_to_barracuda as tf2bc
 
 
+if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):
+    # ONNX is only tested on 1.12.0 and later
+    ONNX_EXPORT_ENABLED = False
+
+
 logger = logging.getLogger("mlagents.trainers")
 
 POSSIBLE_INPUT_NODES = frozenset(
@@ -67,18 +75,28 @@ def export_policy_model(
         logger.info(f"Exported {settings.model_path}.nn file")
 
     # Save to onnx too (if we were able to import it)
-    if ONNX_EXPORT_ENABLED and settings.convert_to_onnx:
-        try:
-            onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
-            onnx_output_path = settings.model_path + ".onnx"
-            with open(onnx_output_path, "wb") as f:
-                f.write(onnx_graph.SerializeToString())
-            logger.info(f"Converting to {onnx_output_path}")
-        except Exception:
-            logger.exception(
-                "Exception trying to save ONNX graph. Please report this error on "
-                "https://github.com/Unity-Technologies/ml-agents/issues and "
-                "attach a copy of frozen_graph_def.pb"
+    if ONNX_EXPORT_ENABLED:
+        if settings.convert_to_onnx:
+            try:
+                onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def)
+                onnx_output_path = settings.model_path + ".onnx"
+                with open(onnx_output_path, "wb") as f:
+                    f.write(onnx_graph.SerializeToString())
+                logger.info(f"Converting to {onnx_output_path}")
+            except Exception:
+                # Make conversion errors fatal depending on environment variables (only done during CI)
+                if _enforce_onnx_conversion():
+                    raise
+                logger.exception(
+                    "Exception trying to save ONNX graph. Please report this error on "
+                    "https://github.com/Unity-Technologies/ml-agents/issues and "
+                    "attach a copy of frozen_graph_def.pb"
+                )
+
+    else:
+        if _enforce_onnx_conversion():
+            raise RuntimeError(
+                "ONNX conversion enforced, but couldn't import dependencies."
             )
 
 
@@ -203,3 +221,24 @@ def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str
     for n in nodes:
         logger.info("\t" + n)
     return nodes
+
+
+def _enforce_onnx_conversion() -> bool:
+    """
+    Whether ONNX export failures should be treated as fatal.
+
+    Controlled by the TEST_ENFORCE_ONNX_CONVERSION environment variable, which
+    is only expected to be set during CI. Unset or unparseable values mean False.
+    """
+    env_var_name = "TEST_ENFORCE_ONNX_CONVERSION"
+    if env_var_name not in os.environ:
+        return False
+
+    val = os.environ[env_var_name]
+    try:
+        # strtobool returns 1/0 rather than True/False, so coerce to the declared
+        # return type. This handles e.g. "false" converting reasonably to False
+        return bool(strtobool(val))
+    except ValueError:
+        # strtobool raises ValueError for strings it does not recognize.
+        return False
diff --git a/test_constraints_max_tf1_version.txt b/test_constraints_max_tf1_version.txt
index 9a93467de8..d14c5fe4cc 100644
--- a/test_constraints_max_tf1_version.txt
+++ b/test_constraints_max_tf1_version.txt
@@ -3,6 +3,5 @@
 # For projects with upper bounds, we should periodically update this list to the latest release version
 grpcio>=1.23.0
 numpy>=1.17.2
-# Temporary workaround for https://github.com/tensorflow/tensorflow/issues/36179 and https://github.com/tensorflow/tensorflow/issues/36188
-tensorflow>=1.14.0,<1.15.1
+tensorflow>=1.15.2,<2.0.0
 h5py>=2.10.0
diff --git a/test_constraints_min_version.txt b/test_constraints_min_version.txt
index a5282c9785..a83b513f31 100644
--- a/test_constraints_min_version.txt
+++ b/test_constraints_min_version.txt
@@ -3,5 +3,5 @@ grpcio==1.11.0
 numpy==1.14.1
 Pillow==4.2.1
 protobuf==3.6
-tensorflow==1.7
+tensorflow==1.7.0
 h5py==2.9.0
diff --git a/test_requirements.txt b/test_requirements.txt
index 8129de94df..c9680e1f66 100644
--- a/test_requirements.txt
+++ b/test_requirements.txt
@@ -3,6 +3,4 @@ pytest>4.0.0,<6.0.0
 pytest-cov==2.6.1
 pytest-xdist
 
-# Tests install onnx and tf2onnx, but this doesn't support tensorflow>=2.0.0
-# Since we test tensorflow2.0 with python3.7, exclude it based on the python version
-tf2onnx>=1.5.5; python_version < '3.7'
+tf2onnx>=1.5.5