Skip to content

Commit fb88a10

Browse files
author
Yu Cong
committed
Hotfix: Circumvent tf-2.12 breaking change on tflite subgraph API to unbreak UT
TF-2.12.0 introduced API change that breaks tf2onnx UT tests on the tflite paths, due to the addition of compulsory subgraph arg to several function's input signature: tensorflow/tensorflow@55d84d7 This commit is a temporary hotfix to unbreak related UT failure. Existing tf2onnx's use cases get tflite Interpreter's tensors from model's first subgraph only. The hotfix hard-codes subgraph index to `0` to retain the same behavior while resolves API diff. Signed-off-by: Yu Cong <[email protected]>
1 parent 25c977c commit fb88a10

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

tests/backend_test_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,12 @@ def tflite_has_supported_types(self, tflite_path):
252252
with open(tflite_path, 'rb') as f:
253253
buf = f.read()
254254
buf = bytearray(buf)
255+
subgraph_idx = 0 # assume only one subgraph
255256
model = Model.GetRootAsModel(buf, 0)
256-
tensor_cnt = model.Subgraphs(0).TensorsLength()
257+
tensor_cnt = model.Subgraphs(subgraph_idx).TensorsLength()
257258
interpreter = tf.lite.Interpreter(tflite_path)
258259
for i in range(tensor_cnt):
259-
dtype = interpreter._get_tensor_details(i)['dtype'] # pylint: disable=protected-access
260+
dtype = interpreter._get_tensor_details(i, subgraph_idx)['dtype'] # pylint: disable=protected-access
260261
if np.dtype(dtype).kind == 'O':
261262
return False
262263
return True

tf2onnx/tflite_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,11 @@ def read_tflite_model(tflite_path):
196196
try:
197197
interpreter = tf.lite.Interpreter(tflite_path)
198198
interpreter.allocate_tensors()
199-
tensor_cnt = model.Subgraphs(0).TensorsLength()
199+
subgraph_idx = 0 # assume only one subgraph
200+
tensor_cnt = model.Subgraphs(subgraph_idx).TensorsLength()
200201
for i in range(tensor_cnt):
201-
name = model.Subgraphs(0).Tensors(i).Name().decode()
202-
details = interpreter._get_tensor_details(i) # pylint: disable=protected-access
202+
name = model.Subgraphs(subgraph_idx).Tensors(i).Name().decode()
203+
details = interpreter._get_tensor_details(i, subgraph_idx) # pylint: disable=protected-access
203204
if "shape_signature" in details:
204205
tensor_shapes[name] = details["shape_signature"].tolist()
205206
elif "shape" in details:

0 commit comments

Comments
 (0)