|
| 1 | +from distutils.util import strtobool |
| 2 | +import os |
1 | 3 | import logging
|
2 | 4 | from typing import Any, List, Set, NamedTuple
|
| 5 | +from distutils.version import LooseVersion |
3 | 6 |
|
4 | 7 | try:
|
5 | 8 | import onnx
|
|
18 | 21 | from tensorflow.python.framework import graph_util
|
19 | 22 | from mlagents.trainers import tensorflow_to_barracuda as tf2bc
|
20 | 23 |
|
| 24 | +if LooseVersion(tf.__version__) < LooseVersion("1.12.0"): |
| 25 | + # ONNX is only tested on 1.12.0 and later |
| 26 | + ONNX_EXPORT_ENABLED = False |
| 27 | + |
| 28 | + |
21 | 29 | logger = logging.getLogger("mlagents.trainers")
|
22 | 30 |
|
23 | 31 | POSSIBLE_INPUT_NODES = frozenset(
|
@@ -67,18 +75,28 @@ def export_policy_model(
|
67 | 75 | logger.info(f"Exported {settings.model_path}.nn file")
|
68 | 76 |
|
69 | 77 | # Save to onnx too (if we were able to import it)
|
70 |
| - if ONNX_EXPORT_ENABLED and settings.convert_to_onnx: |
71 |
| - try: |
72 |
| - onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def) |
73 |
| - onnx_output_path = settings.model_path + ".onnx" |
74 |
| - with open(onnx_output_path, "wb") as f: |
75 |
| - f.write(onnx_graph.SerializeToString()) |
76 |
| - logger.info(f"Converting to {onnx_output_path}") |
77 |
| - except Exception: |
78 |
| - logger.exception( |
79 |
| - "Exception trying to save ONNX graph. Please report this error on " |
80 |
| - "https://github.com/Unity-Technologies/ml-agents/issues and " |
81 |
| - "attach a copy of frozen_graph_def.pb" |
| 78 | + if ONNX_EXPORT_ENABLED: |
| 79 | + if settings.convert_to_onnx: |
| 80 | + try: |
| 81 | + onnx_graph = convert_frozen_to_onnx(settings, frozen_graph_def) |
| 82 | + onnx_output_path = settings.model_path + ".onnx" |
| 83 | + with open(onnx_output_path, "wb") as f: |
| 84 | + f.write(onnx_graph.SerializeToString()) |
| 85 | + logger.info(f"Converting to {onnx_output_path}") |
| 86 | + except Exception: |
| 87 | + # Make conversion errors fatal depending on environment variables (only done during CI) |
| 88 | + if _enforce_onnx_conversion(): |
| 89 | + raise |
| 90 | + logger.exception( |
| 91 | + "Exception trying to save ONNX graph. Please report this error on " |
| 92 | + "https://github.com/Unity-Technologies/ml-agents/issues and " |
| 93 | + "attach a copy of frozen_graph_def.pb" |
| 94 | + ) |
| 95 | + |
| 96 | + else: |
| 97 | + if _enforce_onnx_conversion(): |
| 98 | + raise RuntimeError( |
| 99 | + "ONNX conversion enforced, but couldn't import dependencies." |
82 | 100 | )
|
83 | 101 |
|
84 | 102 |
|
@@ -203,3 +221,16 @@ def _process_graph(settings: SerializationSettings, graph: tf.Graph) -> List[str
|
203 | 221 | for n in nodes:
|
204 | 222 | logger.info("\t" + n)
|
205 | 223 | return nodes
|
| 224 | + |
| 225 | + |
| 226 | +def _enforce_onnx_conversion() -> bool: |
| 227 | + env_var_name = "TEST_ENFORCE_ONNX_CONVERSION" |
| 228 | + if env_var_name not in os.environ: |
| 229 | + return False |
| 230 | + |
| 231 | + val = os.environ[env_var_name] |
| 232 | + try: |
| 233 | + # This handles e.g. "false" converting reasonably to False |
| 234 | + return strtobool(val) |
| 235 | + except Exception: |
| 236 | + return False |
0 commit comments