Skip to content

Commit e8db57d

Browse files
committed
[ExecuTorch][Weight Sharing][XNNPACK] Serialize constant tensors into named data map
We serialize tensors into the named data map, and return the output in preprocess result. Allowing for XNNPACK to share tensors with the same name (instead of duplicating). A key change here is with fused tensors. For BN and Convolution Fusion, we fuse the conv weights and bias with the BN parameters creating new tensors. We then create get_attr nodes for these new parameters. Due to the graph.fx interpreter in export pass base, the new names we create for these new tensors are lost each time. As a result, at the end we introduce a new pass to preserve the names we created. This seems a little hacky for now, but is the only way to preserve the new fused names. Differential Revision: [D70315207](https://our.internmc.facebook.com/intern/diff/D70315207/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D70315207/)! [ghstack-poisoned]
1 parent c5cfe62 commit e8db57d

File tree

9 files changed

+100
-26
lines changed

9 files changed

+100
-26
lines changed

backends/xnnpack/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ python_library(
1919
"//executorch/exir/passes:const_prop_pass",
2020
"//executorch/exir/passes:memory_format_ops_pass",
2121
"//executorch/exir/program:program",
22+
"//executorch/backends/transforms:utils",
2223
],
2324
)

backends/xnnpack/_passes/fuse_batch_norm_with_conv.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,13 @@
99
import torch
1010

1111
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
12+
from executorch.backends.transforms.utils import (
13+
create_constant_placeholder,
14+
delete_constant_placeholder,
15+
)
16+
from torch.export.graph_signature import InputKind
1217

13-
from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
18+
from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node, get_tensor_name
1419
from executorch.exir import ExportedProgram
1520
from executorch.exir.dialects._ops import ops as exir_ops
1621
from executorch.exir.pass_base import PassResult
@@ -28,7 +33,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass):
2833

2934
def call(self, graph_module: torch.fx.GraphModule):
3035
graph = graph_module.graph
31-
counter = 0
36+
constant_placeholders_to_delete = set()
3237
for conv in graph.nodes:
3338
# We want to discover a chain of conv -> batch_norm.
3439
# Only proceed if the current node is a conv node, and has a single
@@ -55,9 +60,11 @@ def call(self, graph_module: torch.fx.GraphModule):
5560
assert len(conv.args) == 9
5661

5762
conv_weight = get_param_tensor(self.exported_program, conv.args[1])
63+
conv_weight_name = get_tensor_name(self.exported_program, conv.args[1])
5864
assert conv_weight is not None
5965

6066
conv_bias = get_param_tensor(self.exported_program, conv.args[2])
67+
conv_bias_name = get_tensor_name(self.exported_program, conv.args[2])
6168

6269
# Get the parameters from the batchnorm op
6370
assert (
@@ -95,32 +102,57 @@ def call(self, graph_module: torch.fx.GraphModule):
95102
bn_bias,
96103
is_transpose,
97104
)
105+
fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
106+
fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_")
98107

99108
# Modify the graph by updating the weight and bias of conv op
100109
# with the fused weight and bias params, and replacing all the users
101110
# of getitem(batchnorm) with the conv op.
102-
with graph.inserting_before(conv):
103-
fused_weight_name = f"_fused_with_bn_weight_{counter}"
104-
graph_module.register_parameter(fused_weight_name, fused_weight)
105-
fused_weight_node = graph.get_attr(fused_weight_name)
106-
fused_bias_name = f"_fused_with_bn_bias_{counter}"
107-
graph_module.register_parameter(fused_bias_name, fused_bias)
108-
fused_bias_node = graph.get_attr(fused_bias_name)
109-
110-
# Update the weight and bias of conv op
111-
conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
112-
conv_args[1] = fused_weight_node
113-
conv_args[2] = fused_bias_node
114-
conv.args = tuple(conv_args)
111+
with graph.inserting_before(conv.args[1]):
112+
fused_conv_weight_node = create_constant_placeholder(
113+
exp_program=self.exported_program,
114+
graph=graph_module.graph,
115+
kind=InputKind.PARAMETER,
116+
name=fused_weight_name,
117+
data=fused_weight
118+
)
119+
if fused_bias is not None:
120+
fused_conv_bias_node = create_constant_placeholder(
121+
exp_program=self.exported_program,
122+
graph=graph_module.graph,
123+
kind=InputKind.PARAMETER,
124+
name=fused_bias_name,
125+
data=fused_bias
126+
)
127+
else:
128+
fused_conv_bias_node = None
129+
130+
conv.args = (
131+
conv.args[0],
132+
fused_conv_weight_node,
133+
fused_conv_bias_node,
134+
*conv.args[3:]
135+
)
136+
137+
115138
# Remove any use of batchnorm from the graph
116139
for user in bn.users.copy():
117140
assert user.target == operator.getitem
118141
user.replace_all_uses_with(conv)
119142
graph.erase_node(user)
120143

121144
graph.erase_node(bn)
145+
constant_placeholders_to_delete.update(
146+
conv.args[1:3] + bn.args[1:5]
147+
)
122148

123-
counter += 1
149+
if len(constant_placeholders_to_delete) > 0:
150+
graph_module.graph.eliminate_dead_code()
151+
for node in constant_placeholders_to_delete:
152+
if (node is not None) and (
153+
len(node.users) == 0
154+
):
155+
delete_constant_placeholder(self.exported_program, node)
124156

125157
graph_module.recompile()
126158
# To Regenerate meta data and shape information, retrace module

backends/xnnpack/operators/node_visitor.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717

1818
from executorch.backends.xnnpack.operators.quant_params import QuantParams
19+
from executorch.exir._serialize._named_data_store import NamedDataStore
1920

2021
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
2122
ConstantDataOffset,
@@ -30,11 +31,15 @@
3031
XNNTensorValue,
3132
XValue,
3233
)
34+
from executorch.backends.xnnpack.utils.xnnpack_constants import (
35+
UINT64_MAX
36+
)
3337
from executorch.backends.xnnpack.utils.utils import (
3438
check_or_raise,
3539
get_input_node,
3640
get_param_tensor,
3741
is_param_node,
42+
get_tensor_name,
3843
PERM_NCHW_TO_NHWC,
3944
)
4045

@@ -86,11 +91,11 @@ def __init__(
8691
self,
8792
exported_program: ExportedProgram,
8893
external_ids: Dict,
89-
constant_data_bytes: bytearray,
94+
named_data_store: NamedDataStore,
9095
) -> None:
9196
self._external_ids = external_ids or {}
9297
self._exported_program = exported_program or None
93-
self._constant_data_bytes = constant_data_bytes
98+
self._named_data_store = named_data_store
9499

95100
@property
96101
def external_ids(self) -> Dict:
@@ -579,12 +584,13 @@ def get_serialized_buffer_index(
579584
ctypes.POINTER(array_type),
580585
).contents
581586

582-
offset = len(self._constant_data_bytes)
587+
named_key = get_tensor_name(self.exported_program, get_attr_node)
588+
if named_key == "":
589+
raise ValueError(f"Tensor from node: {get_attr_node} has no name")
590+
583591
size = const_val.untyped_storage().nbytes()
584-
xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size))
585-
self._constant_data_bytes.extend(
586-
_pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT))
587-
)
592+
xnn_graph.constant_data.append(ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key))
593+
self._named_data_store.add_named_data(named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT)
588594

589595
return buffer_idx
590596

backends/xnnpack/serialization/schema.fbs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,20 @@ table XNNLeakyReLU {
316316
table ConstantDataOffset {
317317
// Constant data offsets are relative to the constant data base offset provided
318318
// in the XNNPACKHeader.
319+
// named_key and offset are mutually exclusive, meaning only one of these values
320+
// are valid. If the named key is a non-empty string, then the offset must be UINT64_MAX.
321+
// If the offset is not UINT64_MAX, then the named key must be an empty string
319322
offset: uint64;
320323

321324
// The size in bytes of valid data starting at the offset. The constant data
322325
// may be followed by padding before the next piece of constant data
323326
size: uint64;
327+
328+
// unique string id used to query the offset from the named data store.
329+
// named_key and offset are mutually exclusive, meaning only one of these values
330+
// are valid. If the named key is a non-empty string, then the offset must be UINT64_MAX.
331+
// If the offset is not UINT64_MAX, then the named key must be an empty string
332+
named_key: string;
324333
}
325334

326335
table XNNGraph {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ class XValue:
470470
class ConstantDataOffset:
471471
offset: int
472472
size: int
473+
named_key: str = ""
473474

474475

475476
@dataclass

backends/xnnpack/utils/gen_xnnpack_constants.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,6 @@
2626
} > xnnpack_constants.py
2727

2828
echo UINT32_MAX = 4294967295 >> xnnpack_constants.py
29+
echo UINT64_MAX = 18446744073709551615 >> xnnpack_constants.py
2930
awk '/^#define\s+XNN_/ { print $2,"=",$3} ' "$1"/include/xnnpack.h >> xnnpack_constants.py
3031
if ! grep -qc "^XNN_" xnnpack_constants.py; then false; fi

backends/xnnpack/utils/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,24 @@ def get_param_tensor(
131131
raise RuntimeError(f"unsupported param type, {node.op}.")
132132

133133

134+
def get_tensor_name(
135+
exp_prog: ExportedProgram, node: torch.fx.Node
136+
) -> str:
137+
if node is None:
138+
return ""
139+
if is_param(exp_prog, node):
140+
return exp_prog.graph_signature.inputs_to_parameters[node.name]
141+
elif is_buffer(exp_prog, node):
142+
return exp_prog.graph_signature.inputs_to_buffers[node.name]
143+
elif is_lifted_tensor_constant(exp_prog, node):
144+
return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name]
145+
else:
146+
assert(isinstance(node.target, str))
147+
return node.target
148+
149+
return ""
150+
151+
134152
def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
135153
"""
136154
Returns the source fn of the given node, return None if something goes wrong

backends/xnnpack/utils/xnnpack_constants.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66

77
# Auto-generated by gen_xnnpack_constants.sh script. Do not modify
88
UINT32_MAX = 4294967295
9+
UINT64_MAX = 18446744073709551615
10+
XNN_EXTRA_BYTES = 128
911
XNN_EXTRA_BYTES = 16
1012
XNN_MAX_TENSOR_DIMS = 6
13+
XNN_INVALID_VALUE_ID = UINT32_MAX
1114
XNN_FLAG_HINT_SPARSE_INFERENCE = 0x00000001
1215
XNN_FLAG_HINT_FP16_INFERENCE = 0x00000002
1316
XNN_FLAG_FORCE_FP16_INFERENCE = 0x00000004
@@ -26,7 +29,8 @@
2629
XNN_FLAG_YIELD_WORKERS = 0x00000010
2730
XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER = 0x00000020
2831
XNN_FLAG_KEEP_DIMS = 0x00000040
29-
XNN_EXTRA_QUANTIZATION_PARAMS = 8
32+
XNN_EXTRA_QUANTIZATION_PARAMS = 10
33+
XNN_MIN_BLOCKSIZE = 32
3034
XNN_VALUE_FLAG_EXTERNAL_INPUT = 0x00000001
3135
XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 0x00000002
3236
XNN_VALUE_FLAG_PERSISTENT = 0x00000004

backends/xnnpack/xnnpack_preprocess.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
PreprocessResult,
3939
)
4040
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
41+
from executorch.exir._serialize._named_data_store import NamedDataStore
4142
from torch.export.exported_program import ExportedProgram
4243

4344
DEFAULT_DEBUG_HANDLE = 65535
@@ -103,7 +104,7 @@ def preprocess(
103104
edge_program: ExportedProgram,
104105
compile_specs: List[CompileSpec],
105106
) -> PreprocessResult:
106-
107+
named_data_store = NamedDataStore()
107108
xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()
108109

109110
# Need to wrap EP here because xnnpack does addmm to linear
@@ -162,7 +163,7 @@ def preprocess(
162163
)
163164

164165
constant_data_bytes = bytearray()
165-
node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes)
166+
node_visitors = get_node_visitors(ep, node_to_external_map, named_data_store)
166167

167168
for node in graph_module.graph.nodes:
168169
if node.op == "call_function":
@@ -191,4 +192,5 @@ def preprocess(
191192
xnnpack_graph, constant_data_bytes
192193
),
193194
debug_handle_map={},
195+
data_store_output=named_data_store.get_named_data_store_output(),
194196
)

0 commit comments

Comments
 (0)