Skip to content

Commit df33fd2

Browse files
author
Yu Cong
committed
Hotfix: Circumvent tf-2.12 breaking change on tflite subgraph API to unbreak UT
TF-2.12.0 introduced API change that breaks tf2onnx UT tests on the tflite paths, due to the addition of compulsory subgraph arg to several function's input signature: tensorflow/tensorflow@55d84d7 This commit is a temporary hotfix to unbreak related UT failure. Existing tf2onnx's use cases get tflite Interpreter's tensors from model's first subgraph only. The hotfix hard-codes subgraph index to `0` to retain the same behavior while resolves API diff. Signed-off-by: Yu Cong <[email protected]>
1 parent 25c977c commit df33fd2

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

tests/backend_test_base.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from tensorflow.python.ops import variables as variables_lib
1919
from tensorflow.python.ops import lookup_ops
2020
import onnx
21+
from packaging.version import Version
2122
from common import get_test_config
2223
from tfjs_runner import run_tfjs
2324
from tf2onnx import constants
@@ -26,7 +27,7 @@
2627
from tf2onnx import optimizer
2728
from tf2onnx.tf_loader import tf_reset_default_graph, tf_session, tf_placeholder, from_function, freeze_session
2829
from tf2onnx.tf_loader import tf_optimize, is_tf2, get_hash_table_info
29-
from tf2onnx.tf_utils import compress_graph_def
30+
from tf2onnx.tf_utils import compress_graph_def, get_tf_version
3031
from tf2onnx.graph import ExternalTensorStorage
3132
from tf2onnx.tflite.Model import Model
3233

@@ -249,14 +250,10 @@ def convert_to_tflite(self, graph_def, feed_dict, outputs):
249250

250251
def tflite_has_supported_types(self, tflite_path):
251252
try:
252-
with open(tflite_path, 'rb') as f:
253-
buf = f.read()
254-
buf = bytearray(buf)
255-
model = Model.GetRootAsModel(buf, 0)
256-
tensor_cnt = model.Subgraphs(0).TensorsLength()
257253
interpreter = tf.lite.Interpreter(tflite_path)
258-
for i in range(tensor_cnt):
259-
dtype = interpreter._get_tensor_details(i)['dtype'] # pylint: disable=protected-access
254+
tensor_details = interpreter.get_tensor_details()
255+
for tensor_detail in tensor_details:
256+
dtype = tensor_detail.get('dtype')
260257
if np.dtype(dtype).kind == 'O':
261258
return False
262259
return True

tf2onnx/tflite_utils.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
from tensorflow.python.framework import tensor_util
1616
import tensorflow as tf
1717
import numpy as np
18+
from packaging.version import Version
1819
from tf2onnx.tflite.TensorType import TensorType as TFLiteTensorType
1920
from tf2onnx.tflite.Model import Model
2021
from tf2onnx.flexbuffers import read_flexbuffer
21-
from tf2onnx.tf_utils import read_tf_node_def_attrs
22+
from tf2onnx.tf_utils import read_tf_node_def_attrs, get_tf_version
2223
from tf2onnx.graph import Graph
2324
from tf2onnx import utils
2425

@@ -196,14 +197,14 @@ def read_tflite_model(tflite_path):
196197
try:
197198
interpreter = tf.lite.Interpreter(tflite_path)
198199
interpreter.allocate_tensors()
199-
tensor_cnt = model.Subgraphs(0).TensorsLength()
200-
for i in range(tensor_cnt):
201-
name = model.Subgraphs(0).Tensors(i).Name().decode()
202-
details = interpreter._get_tensor_details(i) # pylint: disable=protected-access
203-
if "shape_signature" in details:
204-
tensor_shapes[name] = details["shape_signature"].tolist()
205-
elif "shape" in details:
206-
tensor_shapes[name] = details["shape"].tolist()
200+
tensor_details = interpreter.get_tensor_details()
201+
202+
for tensor_detail in tensor_details:
203+
name = tensor_detail.get('name')
204+
if "shape_signature" in tensor_detail:
205+
tensor_shapes[name] = tensor_detail["shape_signature"].tolist()
206+
elif "shape" in tensor_detail:
207+
tensor_shapes[name] = tensor_detail["shape"].tolist()
207208
except Exception as e: # pylint: disable=broad-except
208209
logger.warning("Error loading model into tflite interpreter: %s", e)
209210
tflite_graphs = get_model_subgraphs(model)

0 commit comments

Comments
 (0)