|
15 | 15 | from tensorflow.python.framework import tensor_util
|
16 | 16 | import tensorflow as tf
|
17 | 17 | import numpy as np
|
| 18 | +from packaging.version import Version |
18 | 19 | from tf2onnx.tflite.TensorType import TensorType as TFLiteTensorType
|
19 | 20 | from tf2onnx.tflite.Model import Model
|
20 | 21 | from tf2onnx.flexbuffers import read_flexbuffer
|
21 |
| -from tf2onnx.tf_utils import read_tf_node_def_attrs |
| 22 | +from tf2onnx.tf_utils import read_tf_node_def_attrs, get_tf_version |
22 | 23 | from tf2onnx.graph import Graph
|
23 | 24 | from tf2onnx import utils
|
24 | 25 |
|
@@ -196,14 +197,14 @@ def read_tflite_model(tflite_path):
|
196 | 197 | try:
|
197 | 198 | interpreter = tf.lite.Interpreter(tflite_path)
|
198 | 199 | interpreter.allocate_tensors()
|
199 |
| - tensor_cnt = model.Subgraphs(0).TensorsLength() |
200 |
| - for i in range(tensor_cnt): |
201 |
| - name = model.Subgraphs(0).Tensors(i).Name().decode() |
202 |
| - details = interpreter._get_tensor_details(i) # pylint: disable=protected-access |
203 |
| - if "shape_signature" in details: |
204 |
| - tensor_shapes[name] = details["shape_signature"].tolist() |
205 |
| - elif "shape" in details: |
206 |
| - tensor_shapes[name] = details["shape"].tolist() |
| 200 | + tensor_details = interpreter.get_tensor_details() |
| 201 | + |
| 202 | + for tensor_detail in tensor_details: |
| 203 | + name = tensor_detail.get('name') |
| 204 | + if "shape_signature" in tensor_detail: |
| 205 | + tensor_shapes[name] = tensor_detail["shape_signature"].tolist() |
| 206 | + elif "shape" in tensor_detail: |
| 207 | + tensor_shapes[name] = tensor_detail["shape"].tolist() |
207 | 208 | except Exception as e: # pylint: disable=broad-except
|
208 | 209 | logger.warning("Error loading model into tflite interpreter: %s", e)
|
209 | 210 | tflite_graphs = get_model_subgraphs(model)
|
|
0 commit comments