
Commit d69c2ec

Committed May 12, 2025
Arm backend: Add TOSA support for GroupNorm
- Decompose groupnorm into a sequence of supported operators
- Have some numerical issues with BI profile
- Fix docstring in decompose_layernorm_pass
- Add "native_group_norm.default" to CUSTOM_EDGE_OPS

Change-Id: I3f70388c12b8d9afd52876840b6c008a1b0bec4e
Signed-off-by: Yufeng Shi <[email protected]>
1 parent b11807c · commit d69c2ec

File tree: 7 files changed, +361 −4 lines
 

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa
+from .decompose_groupnorm_pass import DecomposeGroupNormPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linear_pass import DecomposeLinearPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,7 @@
     DecomposeCosineSimilarityPass,
     DecomposeDivPass,
     DecomposeGeluPass,
+    DecomposeGroupNormPass,
     DecomposeLayerNormPass,
     DecomposeLeakyReLUPass,
     DecomposeLinearPass,
@@ -136,6 +137,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeLinearPass())
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeBatchNormPass())
+        self.add_pass(DecomposeGroupNormPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
@@ -202,6 +204,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
+        self.add_pass(DecomposeGroupNormPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
backends/arm/_passes/decompose_groupnorm_pass.py (new file)

Lines changed: 208 additions & 0 deletions

# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import operator

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


def get_group_norm_decomposition(op) -> tuple:
    if op == exir_ops.edge.aten.native_group_norm.default:
        return (
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.var.correction,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.rsqrt.default,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.view_copy.default,
        )
    if op == torch.ops.aten.group_norm.default:
        return (
            torch.ops.aten.mean.dim,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.var.correction,
            torch.ops.aten.full.default,
            torch.ops.aten.add.Tensor,
            torch.ops.aten.rsqrt.default,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.view_copy.default,
        )
    raise RuntimeError(f"Can't get group_norm composition for op {op}")


class DecomposeGroupNormPass(ArmPass):
    """
    groupnorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
    Decompose groupnorm(x, weight, bias, N, C, HxW, group, eps) to a sequence of:
    mean = op_mean(x, dims)  # E[x]
    var = op_var(x, dims)  # Var[x]
    numerator = op_sub(x, mean)  # (x - E[x])
    add = op_add(var, eps)  # Var[x] + eps
    rsqrt = op_rsqrt(add)  # 1 / sqrt(Var[x] + eps)
    mul = op_mul(numerator, rsqrt)  # (x - E[x]) / sqrt(Var[x] + eps)
    weights = op_mul(mul, weights)  # ((x - E[x]) / sqrt(Var[x] + eps)) * weights
    bias = op_add(weights, bias)  # ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
    where x is viewed with shape [N, group, C//group, HxW] and the reduction
    dims=[2, 3] cover the C//group and HxW axes.

    Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html
    """

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in (
                exir_ops.edge.aten.native_group_norm.default,
                torch.ops.aten.group_norm.default,
            ):
                continue

            # epsilon default value
            eps = torch.finfo().eps
            weights = None
            bias = None
            args = node.args
            meta = node.meta
            if isinstance(meta["val"], tuple):
                shape = meta["val"][0].size()
                dtype = meta["val"][0].dtype
            else:
                shape = meta["val"].size()
                dtype = meta["val"].dtype
            match len(args):
                # MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps
                case 8:
                    x, weights, bias, N, C, HxW, group, eps = args
                # BI profile: affine=[True|False], eps != 1e-5
                case 5:
                    x, group, weights, bias, eps = args
                # BI profile: affine=True, eps=1e-5
                case 4:
                    x, group, weights, bias = args
                # BI profile: affine=False, eps=1e-5
                case 2:
                    x, group = args
                # Unsupported args
                case _:
                    raise ValueError(
                        f"Unsupported group_norm argument pattern with {len(args)} args"
                    )
            N = shape[0]
            C = shape[1]
            HxW = 1
            for dim in shape[2:]:
                HxW *= dim
            channels_per_group = C // group
            grouped_shape = torch.Size([N, group, channels_per_group, HxW])
            dims = [2, 3]
            epsilon_reshaped_shape = torch.Size([1] * len(grouped_shape))
            weights_reshaped_shape = torch.Size([1, group, channels_per_group, 1])
            (
                mean_op,
                sub_op,
                var_op,
                full_op,
                add_op,
                rsqrt_op,
                mul_op,
                view_op,
            ) = get_group_norm_decomposition(node.target)
            with graph_module.graph.inserting_before(node):
                keepdim = True
                x_reshaped = create_node(
                    graph_module.graph,
                    view_op,
                    args=(x, grouped_shape),
                    from_node=node,
                )
                mean = create_node(
                    graph_module.graph, mean_op, args=(x_reshaped, dims, keepdim)
                )
                sub = create_node(graph_module.graph, sub_op, args=(x_reshaped, mean))
                var = create_node(
                    graph_module.graph,
                    var_op,
                    args=(x_reshaped, dims),
                    kwargs={"correction": 0, "keepdim": keepdim},
                    from_node=node,
                )
                full = create_node(
                    graph_module.graph,
                    full_op,
                    args=(epsilon_reshaped_shape, eps),
                    kwargs={"dtype": dtype},
                    from_node=node,
                )
                add0 = create_node(
                    graph_module.graph, add_op, args=(var, full), from_node=node
                )
                rsqrt = create_node(
                    graph_module.graph, rsqrt_op, args=(add0,), from_node=node
                )
                mul0 = create_node(
                    graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
                )
                if weights is not None:
                    weights_reshaped = create_node(
                        graph_module.graph,
                        view_op,
                        args=(weights, weights_reshaped_shape),
                        from_node=node,
                    )
                    mul1 = create_node(
                        graph_module.graph,
                        mul_op,
                        args=(
                            mul0,
                            weights_reshaped,
                        ),
                        from_node=node,
                    )
                else:
                    mul1 = mul0
                if bias is not None:
                    bias_reshaped_shape = weights_reshaped_shape
                    bias_reshaped = create_node(
                        graph_module.graph,
                        view_op,
                        args=(bias, bias_reshaped_shape),
                        from_node=node,
                    )
                    output = create_node(
                        graph_module.graph,
                        add_op,
                        args=(mul1, bias_reshaped),
                        from_node=node,
                    )
                else:
                    output = mul1

                output_reshaped = create_node(
                    graph_module.graph,
                    view_op,
                    args=(output, shape),
                    from_node=node,
                )

                # Redirect all users of the original node, including getitem
                # users of native_group_norm's tuple output, to the reshaped result.
                users = [user for user in node.users if node != user]
                node.replace_all_uses_with(output_reshaped)
                for user in users:
                    if user.target == operator.getitem:
                        user.replace_all_uses_with(output_reshaped)
                graph_module.graph.erase_node(node)
                graph_module.graph.eliminate_dead_code()
                modified = True
        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)
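
The decomposition implemented above can be checked numerically in eager PyTorch. Below is a minimal sketch (illustrative only, not part of this commit) that mirrors the view/mean/var/rsqrt/mul/add sequence and compares the result against torch.nn.functional.group_norm; the shapes and eps are arbitrary:

# Sketch: numerically verify the groupnorm decomposition (not part of this commit).
import torch
import torch.nn.functional as F

N, C, H, W = 2, 6, 4, 4
group, eps = 3, 1e-5
x = torch.randn(N, C, H, W)
weight, bias = torch.randn(C), torch.randn(C)

# View x as [N, group, C//group, HxW] and reduce over dims [2, 3], as the pass does.
xg = x.view(N, group, C // group, H * W)
mean = xg.mean(dim=[2, 3], keepdim=True)  # E[x]
var = xg.var(dim=[2, 3], unbiased=False, keepdim=True)  # Var[x] with correction=0
norm = (xg - mean) * torch.rsqrt(var + eps)  # (x - E[x]) / sqrt(Var[x] + eps)

# Affine parameters broadcast with shape [1, group, C//group, 1], matching
# weights_reshaped_shape in the pass.
norm = norm * weight.view(1, group, C // group, 1)
norm = norm + bias.view(1, group, C // group, 1)
out = norm.view(N, C, H, W)

ref = F.group_norm(x, group, weight=weight, bias=bias, eps=eps)
assert torch.allclose(out, ref, atol=1e-5)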

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -47,11 +46,12 @@ class DecomposeLayerNormPass(ArmPass):
     Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
     mean = op_mean(x, dims)  # E[x]
     var = op_var(x, dims)  # Var[x]
-    denominator = op_sub(x, mean)  # (x - E[x])
+    numerator = op_sub(x, mean)  # (x - E[x])
     add = op_add(var, eps)  # Var[x] + eps
     rsqrt = op_rsqrt(add)  # 1 / sqrt(Var[x] + eps)
-    mul = op_mul(denominator, rsqrt)  # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
-    bias = op_add(mul, bias)  # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
+    mul = op_mul(numerator, rsqrt)  # ((x - E[x]) / sqrt(Var[x] + eps))
+    weights = op_mul(mul, weights)  # ((x - E[x]) / sqrt(Var[x] + eps)) * weights
+    bias = op_add(weights, bias)  # ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias

     Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
     """

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -201,6 +201,7 @@ def is_node_supported(
             exir_ops.edge.aten.div.Scalar,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.native_layer_norm.default,
+            exir_ops.edge.aten.native_group_norm.default,
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mean.dim,
             exir_ops.edge.aten.mm.default,
@@ -270,6 +271,7 @@ def is_node_supported(
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.native_layer_norm.default,
+            exir_ops.edge.aten.native_group_norm.default,
             exir_ops.edge.aten.mean.dim,
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten._log_softmax.default,

backends/arm/scripts/parse_test_names.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
     "linear.default",
     "maximum.default",
     "adaptive_avg_pool2d.default",
+    "native_group_norm.default",
 ]
 ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

backends/arm/test/ops/test_group_norm.py (new file)

Lines changed: 142 additions & 0 deletions

# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineBI,
    EthosU85PipelineBI,
    TosaPipelineBI,
    TosaPipelineMI,
)


class GroupNorm(torch.nn.Module):

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
    ):
        super().__init__()
        self.group_norm = torch.nn.GroupNorm(
            num_groups,
            num_channels,
            eps=eps,
            affine=affine,
        )

    def forward(
        self,
        x: torch.Tensor,
    ):
        return self.group_norm(x)


input_t = tuple[torch.Tensor]
test_data_suite = {
    "rand_4_6_groups_1": ((torch.rand(4, 6),), GroupNorm(1, 6)),
    "rand_4_6_groups_2": ((torch.rand(4, 6),), GroupNorm(2, 6)),
    "rand_4_6_groups_6": ((torch.rand(4, 6),), GroupNorm(6, 6)),
    "rand_4_6_8_groups_2_eps_no_affine": (
        (torch.rand(4, 6, 8),),
        GroupNorm(2, 6, eps=1e-3, affine=False),
    ),
    "randn_1_12_8_6_groups_6_eps": (
        (torch.randn(1, 12, 8, 6),),
        GroupNorm(6, 12, eps=1e-2),
    ),
    "randn_1_12_8_6_groups_12": ((torch.randn(1, 12, 8, 6),), GroupNorm(12, 12)),
    "rand_6_8_10_12_groups_1": ((torch.rand(6, 8, 10, 12),), GroupNorm(1, 8)),
    "rand_6_8_10_12_groups_4_no_affine": (
        (torch.rand(6, 8, 10, 12),),
        GroupNorm(4, 8, affine=False),
    ),
    "rand_6_8_10_12_groups_8": ((torch.rand(6, 8, 10, 12),), GroupNorm(8, 8)),
}


@common.parametrize("test_data", test_data_suite)
def test_native_group_norm_tosa_MI(test_data):
    aten_op = "torch.ops.aten.group_norm.default"
    exir_op = "executorch_exir_dialects_edge__ops_aten_native_group_norm_default"
    pipeline = TosaPipelineMI[input_t](
        test_data[1],
        test_data[0],
        aten_op=aten_op,
        exir_op=exir_op,
    )
    pipeline.run()


@common.parametrize(
    "test_data",
    test_data_suite,
    xfails={
        "randn_1_12_8_6_groups_12": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
        "rand_6_8_10_12_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
        "rand_6_8_10_12_groups_4_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
        "rand_6_8_10_12_groups_8": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
    },
    strict=False,
)
def test_native_group_norm_tosa_BI(test_data):
    aten_op = "torch.ops.aten.sub.Tensor"  # 'sub' op arbitrarily chosen to confirm groupnorm was decomposed
    exir_op = "executorch_exir_dialects_edge__ops_aten_native_group_norm_default"
    pipeline = TosaPipelineBI[input_t](
        test_data[1],
        test_data[0],
        aten_op=aten_op,
        exir_op=exir_op,
    )
    pipeline.run()


@common.parametrize(
    "test_data",
    test_data_suite,
    xfails={
        "randn_1_12_8_6_groups_12": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
        "rand_6_8_10_12_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
        "rand_6_8_10_12_groups_4_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
        "rand_6_8_10_12_groups_8": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
    },
    strict=False,
)
@common.XfailIfNoCorstone300
def test_native_group_norm_u55_BI(test_data):
    pipeline = EthosU55PipelineBI[input_t](
        test_data[1],
        test_data[0],
        "torch.ops.aten.sub.Tensor",  # 'sub' op arbitrarily chosen to confirm groupnorm was decomposed
        run_on_fvp=True,
    )
    pipeline.change_args("run_method_and_compare_outputs", atol=1, qtol=1)
    pipeline.run()


@common.parametrize(
    "test_data",
    test_data_suite,
    xfails={
        "randn_1_12_8_6_groups_12": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
        "rand_6_8_10_12_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
        "rand_6_8_10_12_groups_4_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
        "rand_6_8_10_12_groups_8": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
    },
    strict=False,
)
@common.XfailIfNoCorstone320
def test_native_group_norm_u85_BI(test_data):
    pipeline = EthosU85PipelineBI[input_t](
        test_data[1],
        test_data[0],
        "torch.ops.aten.sub.Tensor",  # 'sub' op arbitrarily chosen to confirm groupnorm was decomposed
        run_on_fvp=True,
    )
    pipeline.change_args("run_method_and_compare_outputs", atol=1, qtol=1)
    pipeline.run()
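
Each test_data_suite entry pairs a tuple of sample inputs with a configured GroupNorm module, which is why the pipelines above take test_data[1] (the module) and test_data[0] (the inputs). A minimal sketch (illustrative only, not part of this commit) running one entry eagerly, assuming the definitions above are in scope:

# Sketch: run one suite entry in eager mode (not part of this commit).
inputs, model = test_data_suite["rand_4_6_groups_2"]  # ((tensor,), GroupNorm(2, 6))
y = model(*inputs)
print(y.shape)  # torch.Size([4, 6])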
