diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index 941d20c95a1..3fcfe6edd35 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -23,7 +23,17 @@ class MatchArgRanksPass(ExportPass): """ For ops in 'targeted_ops', make sure that the inputs share the same rank. - New dimensions are inserted at from the beginning of the + New dimensions are inserted from the beginning of the inputs that have a + lower rank to match the input with the highest rank. + + Example: + input0 = shape(4, 3, 2) + input1 = shape(2) + input2 = shape(3, 1) + Becomes: + input0 = shape(4, 3, 2) + input1 = shape(1, 1, 2) + input2 = shape(1, 3, 1) """ def __init__(self, exported_program): @@ -54,34 +64,6 @@ def _match_op_rank(self, graph_module, node, arg, max_rank): ) node.replace_input_with(arg, view) - def _match_buffer_rank(self, arg, max_rank): - """ - Change arg's fake tensor meta to match max_rank if: - - arg is found in inputs_to_buffers or inputs_to_parameters. - """ - fake_tensor = get_first_fake_tensor(arg) - shape = fake_tensor.shape - rank = len(shape) - new_shape = list([1] * (max_rank - rank) + list(shape)) - - buffer_name = None - if arg.name in self.exported_program.graph_signature.inputs_to_buffers: - buffer_name = self.exported_program.graph_signature.inputs_to_buffers[ - arg.name - ] - elif arg.name in self.exported_program.graph_signature.inputs_to_parameters: - buffer_name = self.exported_program.graph_signature.inputs_to_parameters[ - arg.name - ] - if buffer_name: - new_tensor = self.exported_program.state_dict[buffer_name].reshape( - new_shape - ) - self.exported_program.state_dict[buffer_name] = new_tensor - arg.meta["val"] = fake_tensor.fake_mode.from_tensor( - new_tensor, static_shapes=True - ) - def call(self, graph_module: GraphModule) -> PassResult: for node in graph_module.graph.nodes: node = cast(Node, node) @@ -105,12 +87,7 @@ def call(self, graph_module: GraphModule) -> PassResult: if rank == max_rank: continue - # If the argument is call_function, match shape by inserting view node. - if arg.op == "call_function": - self._match_op_rank(graph_module, node, arg, max_rank) - else: - # If the argument is a buffer or parameter, adjust shape by changing the fake tensor meta. - self._match_buffer_rank(arg, max_rank) + self._match_op_rank(graph_module, node, arg, max_rank) graph_module.recompile() graph_module = super().call(graph_module).graph_module