|
14 | 14 | register_node_visitor,
|
15 | 15 | )
|
16 | 16 | from executorch.backends.arm.tosa_mapping import TosaArg
|
17 |
| -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args |
| 17 | +from executorch.backends.arm.tosa_quant_utils import ( |
| 18 | + build_rescale, |
| 19 | + search_quant_arg_downstream, |
| 20 | + search_quant_arg_upstream, |
| 21 | +) |
18 | 22 |
|
19 | 23 | from executorch.backends.arm.tosa_utils import build_reshape
|
20 |
| -from executorch.exir.dialects._ops import ops as exir_ops |
21 | 24 | from serializer.tosa_serializer import TosaOp
|
22 | 25 |
|
23 | 26 |
|
@@ -67,12 +70,7 @@ def define_node(
|
67 | 70 | input_zp = 0
|
68 | 71 | if is_quant_node:
|
69 | 72 | input_node = node.all_input_nodes[1]
|
70 |
| - # rank > 2 linear layer |
71 |
| - if input_node.target == exir_ops.edge.aten.view_copy.default: |
72 |
| - quant_node = input_node.all_input_nodes[0] |
73 |
| - else: |
74 |
| - quant_node = input_node |
75 |
| - input_zp = get_quant_node_args(quant_node).zp |
| 73 | + input_zp = search_quant_arg_upstream(input_node).zp |
76 | 74 | attr.ConvAttribute(
|
77 | 75 | pad=pad_attr,
|
78 | 76 | stride=stride_attr,
|
@@ -107,24 +105,16 @@ def define_node(
|
107 | 105 | # Read inputs' parent nodes
|
108 | 106 | _, input_node, weight_node = node.all_input_nodes
|
109 | 107 |
|
110 |
| - # rank > 2 linear layer |
111 |
| - if input_node.target == exir_ops.edge.aten.view_copy.default: |
112 |
| - quant_node = input_node.all_input_nodes[0] |
113 |
| - input_scale = get_quant_node_args(quant_node).scale |
114 |
| - consumer_node = list(node.users)[0] |
115 |
| - consumer_consumer_node = list(consumer_node.users)[0] |
116 |
| - quant_args = get_quant_node_args(consumer_consumer_node) |
117 |
| - consumer_node_scale = quant_args.scale |
118 |
| - consumer_node_node_zp = quant_args.zp |
119 |
| - else: |
120 |
| - input_scale = get_quant_node_args(input_node).scale |
121 |
| - consumer_node = list(node.users)[0] |
122 |
| - quant_args = get_quant_node_args(consumer_node) |
123 |
| - consumer_node_scale = quant_args.scale |
124 |
| - consumer_node_node_zp = quant_args.zp |
| 108 | + qargs = search_quant_arg_upstream(input_node) |
| 109 | + input_scale = qargs.scale |
| 110 | + consumer_node = list(node.users)[0] |
| 111 | + quant_args = search_quant_arg_downstream(consumer_node) |
| 112 | + |
| 113 | + consumer_node_scale = quant_args.scale |
| 114 | + consumer_node_node_zp = quant_args.zp |
125 | 115 |
|
126 | 116 | weight_node_q_node = weight_node.all_input_nodes[0]
|
127 |
| - weight_scale = get_quant_node_args(weight_node_q_node).scale |
| 117 | + weight_scale = search_quant_arg_upstream(weight_node_q_node).scale |
128 | 118 |
|
129 | 119 | output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale
|
130 | 120 |
|
|
0 commit comments