Skip to content

Commit 8e200b7

Browse files
committed
enable 4-bit embedding on llama.py
1 parent 79be866 commit 8e200b7

File tree

8 files changed

+1097
-11
lines changed

8 files changed

+1097
-11
lines changed

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,13 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
111111
return supported
112112

113113
def __del__(self):
114-
self.qnn_manager.Destroy()
114+
# HTP op package contains some static data structures
115+
# which will trigger preparation failure in qnn_preprocess
116+
# if libQnnHtp.so is not fully unloaded
117+
# ---
118+
# currently we'll just keep manager alive for simplicity
119+
#self.qnn_manager.Destroy()
120+
pass
115121

116122

117123
class QnnPartitioner(Partitioner):
@@ -179,7 +185,12 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu
179185
# pop certain keys in meta for not affecting the passes in compilation
180186
# TODO: need to put property name in common definitions
181187
node.meta.pop(QCOM_AXIS_ORDER, "")
182-
del self.op_support_checker
188+
# HTP op package contains some static data structures
189+
# which will trigger preparation failure in qnn_preprocess
190+
# if libQnnHtp.so is not fully unloaded
191+
# ---
192+
# currently we'll just keep manager alive for simplicity
193+
#del self.op_support_checker
183194
return PartitionResult(
184195
tagged_exported_program=edge_program, partition_tags=self.partition_tags
185196
)

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
1010
from executorch.backends.qualcomm.quantizer.quantizer import (
1111
get_16a8w_qnn_ptq_config,
12+
get_16a4w_qnn_ptq_config,
1213
get_8a8w_qnn_ptq_config,
1314
get_ptq_per_channel_quant_config,
1415
QuantizationConfig,
@@ -53,6 +54,34 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
5354
)
5455

5556

57+
def annotate_linear_16a4w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
58+
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
59+
input_qspec_map = {}
60+
input_act = node.args[0]
61+
input_spec = quantization_config.input_activation
62+
input_qspec_map[input_act] = input_spec
63+
64+
weight = node.args[1]
65+
input_qspec_map[weight] = quantization_config.weight
66+
67+
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
68+
input_qspec_map=input_qspec_map,
69+
output_qspec=quantization_config.output_activation,
70+
_annotated=True,
71+
)
72+
73+
quantization_config_16a4w = get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver)
74+
for node in gm.graph.nodes:
75+
if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
76+
if "nn_module_stack" in node.meta:
77+
module_values_list = list(node.meta["nn_module_stack"].values())
78+
full_qualified_name = module_values_list[-1][0]
79+
if full_qualified_name == "output.conv":
80+
annotate_conv2d(
81+
node, quantization_config=quantization_config_16a4w
82+
)
83+
84+
5685
def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
5786
for node in gm.graph.nodes:
5887
if node.op == "output":

examples/qualcomm/oss_scripts/llama/custom_ops/embedding/Makefile

Lines changed: 364 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<!--
3+
Copyright (c) Qualcomm Innovation Center, Inc.
4+
All rights reserved
5+
6+
This source code is licensed under the BSD-style license found in the
7+
LICENSE file in the root directory of this source tree.
8+
-->
9+
<OpDefCollection
10+
PackageName="EmbeddingOpPackage"
11+
Domain="executorch"
12+
Version="1.0"
13+
>
14+
<OpDefList>
15+
<OpDef>
16+
<Name>Embedding</Name>
17+
<Description>
18+
<Content>implementation of torch.nn.Embedding</Content>
19+
</Description>
20+
21+
<Reference Source="PyTorch Documentation"
22+
Url="torch.nn.Embedding &lt;https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html&gt;"/>
23+
24+
<Input>
25+
<Name>input</Name>
26+
<Description>
27+
<Content>data table</Content>
28+
</Description>
29+
<Mandatory>true</Mandatory>
30+
<Datatype>BACKEND_SPECIFIC</Datatype>
31+
<Shape>
32+
<Rank>2D</Rank>
33+
<Text>a tensor of 2 dimension</Text>
34+
</Shape>
35+
</Input>
36+
37+
<Input>
38+
<Name>indices</Name>
39+
<Description>
40+
<Content>indices to extract data</Content>
41+
</Description>
42+
<Mandatory>true</Mandatory>
43+
<Datatype>QNN_DATATYPE_INT_32</Datatype>
44+
<Shape>
45+
<Rank>ND</Rank>
46+
<Text>a tensor of N dimension</Text>
47+
</Shape>
48+
</Input>
49+
50+
<Output>
51+
<Name>output</Name>
52+
<Description>
53+
<Content>output activation</Content>
54+
</Description>
55+
<Mandatory>true</Mandatory>
56+
<Datatype>BACKEND_SPECIFIC</Datatype>
57+
<Shape>
58+
<Rank>ND</Rank>
59+
<Text>a tensor of N dimension</Text>
60+
</Shape>
61+
</Output>
62+
63+
<!--This Op is implemented on these Backends-->
64+
<SupportedBackend>HTP</SupportedBackend>
65+
</OpDef>
66+
67+
</OpDefList>
68+
69+
<SupplementalOpDefList Backend="HTP">
70+
<SupportedOps>
71+
<OpName>Embedding</OpName>
72+
</SupportedOps>
73+
74+
<!--Embedding-->
75+
<SupplementalOpDef>
76+
<Name>Embedding</Name>
77+
78+
<Input>
79+
<Name>input</Name>
80+
<Datatype>QNN_DATATYPE_SFIXED_POINT_8</Datatype>
81+
</Input>
82+
83+
<Output>
84+
<Name>output</Name>
85+
<Datatype>QNN_DATATYPE_SFIXED_POINT_8</Datatype>
86+
</Output>
87+
</SupplementalOpDef>
88+
</SupplementalOpDefList>
89+
90+
</OpDefCollection>
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from torch.library import impl, Library
9+
10+
op_lib = Library("qaisw", "DEF")
11+
op_lib.define("embedding(Tensor table, Tensor indices) -> Tensor")
12+
13+
@impl(op_lib, "embedding", dispatch_key="CompositeExplicitAutograd")
14+
def embedding_impl(table: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
15+
return table[indices]
16+
17+
18+
class CustomEmbedding(torch.nn.Module):
19+
def __init__(self, weight):
20+
super(CustomEmbedding, self).__init__()
21+
self.weight = weight
22+
23+
def forward(self, indices):
24+
return torch.ops.qaisw.embedding.default(self.weight, indices)
25+
26+
27+
def custom_embedding_annotation(gm: torch.fx.GraphModule) -> None:
28+
import itertools
29+
from executorch.backends.qualcomm.quantizer.annotators import (
30+
_is_annotated,
31+
QUANT_ANNOTATION_KEY,
32+
)
33+
from executorch.backends.qualcomm.quantizer.qconfig import (
34+
get_16a4w_qnn_ptq_config,
35+
)
36+
from torch.ao.quantization.quantize_pt2e import QuantizationAnnotation, SharedQuantizationSpec
37+
from torch.fx import Node
38+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
39+
40+
custom_partitions = get_source_partitions(gm.graph, [torch.ops.qaisw.embedding.default])
41+
custom_partitions = list(itertools.chain(*custom_partitions.values()))
42+
quantization_config = get_16a4w_qnn_ptq_config()
43+
for custom_partition in custom_partitions:
44+
if len(custom_partition.output_nodes) > 1:
45+
raise ValueError("custom partition has more than one output node")
46+
custom_node = custom_partition.output_nodes[0]
47+
if (
48+
custom_node.op != "call_function"
49+
or custom_node.target != torch.ops.qaisw.embedding.default
50+
):
51+
raise ValueError(f"{custom_node} is not a custom operator")
52+
# skip annotation if it is already annotated
53+
if _is_annotated([custom_node]):
54+
continue
55+
56+
input_qspec_map = {}
57+
input_act = custom_node.args[0]
58+
assert isinstance(input_act, Node)
59+
input_spec = quantization_config.weight
60+
input_qspec_map[input_act] = input_spec
61+
62+
custom_node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
63+
input_qspec_map=input_qspec_map,
64+
output_qspec=SharedQuantizationSpec((input_act, custom_node)),
65+
_annotated=True,
66+
)

0 commit comments

Comments
 (0)