Skip to content

Commit 0152029

Browse files
authored
Hotfix: Circumvent tf-2.12 breaking change on tflite subgraph API to unbreak UT (#2204)
TF-2.12.0 introduced API change that breaks tf2onnx UT tests on the tflite paths, due to the addition of compulsory subgraph arg to several function's input signature: tensorflow/tensorflow@55d84d7 This commit is a temporary hotfix to unbreak related UT failure. Existing tf2onnx's use cases get tflite Interpreter's tensors from model's first subgraph only. The hotfix hard-codes subgraph index to `0` to retain the same behavior while resolves API diff. Signed-off-by: Yu Cong <[email protected]>
1 parent 5259b4a commit 0152029

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

tests/backend_test_base.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from tf2onnx.tf_loader import tf_optimize, is_tf2, get_hash_table_info
2929
from tf2onnx.tf_utils import compress_graph_def
3030
from tf2onnx.graph import ExternalTensorStorage
31-
from tf2onnx.tflite.Model import Model
3231

3332

3433
if is_tf2():
@@ -249,14 +248,10 @@ def convert_to_tflite(self, graph_def, feed_dict, outputs):
249248

250249
def tflite_has_supported_types(self, tflite_path):
251250
try:
252-
with open(tflite_path, 'rb') as f:
253-
buf = f.read()
254-
buf = bytearray(buf)
255-
model = Model.GetRootAsModel(buf, 0)
256-
tensor_cnt = model.Subgraphs(0).TensorsLength()
257251
interpreter = tf.lite.Interpreter(tflite_path)
258-
for i in range(tensor_cnt):
259-
dtype = interpreter._get_tensor_details(i)['dtype'] # pylint: disable=protected-access
252+
tensor_details = interpreter.get_tensor_details()
253+
for tensor_detail in tensor_details:
254+
dtype = tensor_detail.get('dtype')
260255
if np.dtype(dtype).kind == 'O':
261256
return False
262257
return True

tf2onnx/tflite_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,14 @@ def read_tflite_model(tflite_path):
196196
try:
197197
interpreter = tf.lite.Interpreter(tflite_path)
198198
interpreter.allocate_tensors()
199-
tensor_cnt = model.Subgraphs(0).TensorsLength()
200-
for i in range(tensor_cnt):
201-
name = model.Subgraphs(0).Tensors(i).Name().decode()
202-
details = interpreter._get_tensor_details(i) # pylint: disable=protected-access
203-
if "shape_signature" in details:
204-
tensor_shapes[name] = details["shape_signature"].tolist()
205-
elif "shape" in details:
206-
tensor_shapes[name] = details["shape"].tolist()
199+
tensor_details = interpreter.get_tensor_details()
200+
201+
for tensor_detail in tensor_details:
202+
name = tensor_detail.get('name')
203+
if "shape_signature" in tensor_detail:
204+
tensor_shapes[name] = tensor_detail["shape_signature"].tolist()
205+
elif "shape" in tensor_detail:
206+
tensor_shapes[name] = tensor_detail["shape"].tolist()
207207
except Exception as e: # pylint: disable=broad-except
208208
logger.warning("Error loading model into tflite interpreter: %s", e)
209209
tflite_graphs = get_model_subgraphs(model)

0 commit comments

Comments
 (0)