
Commit 9af908d

Qualcomm AI Engine Direct - GA Static Smollm3 3B (#14149)
Summary:
- e2e script for GA Static SmolLM3-3B
- perf: 16a4w block quant, token rate in kv mode ~= 30 tokens/sec (SM8750)
- acc: PPL ~= 8.345 (fp) -> 8.976 (htp) on the wikitext dataset
- add model params file & model weight converter

### Test plan
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```
1 parent a4b7de0 commit 9af908d

File tree

16 files changed (+373, -49 lines)


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 40 additions & 4 deletions
```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from enum import Enum, unique
 from typing import Sequence
 
 import torch
@@ -50,6 +51,17 @@ def annotate_down_proj(
     )
 
 
+@unique
+class StaticLLMQuantConfig(Enum):
+    """
+    Layer namespace configuration for Qualcomm's static LLaMA quantization.
+    """
+
+    wq_sha = "wq_sha"  # Query weight (single head)
+    wk_sha = "wk_sha"  # Key weight (single head)
+    wv_sha = "wv_sha"  # Value weight (single head)
+
+
 def annotate_eurobert(gm: torch.fx.GraphModule):
     """
     QNN does not support int32 -> signed 16bit quant
@@ -185,11 +197,35 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
     )
 
 
-def annotate_wv_sha(gm: torch.fx.GraphModule, quantization_config: QuantizationConfig):
+def annotate_qkv_proj_sha(
+    gm: torch.fx.GraphModule,
+    quantization_config: QuantizationConfig,
+    qkv_tags: set[StaticLLMQuantConfig],
+):
+    """
+    Annotates QKV projection layers in a GraphModule for quantization,
+    specifically layers defined in StaticLLMQuantConfig.
+
+    Args:
+        qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers
+            (e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in
+            StaticLLMQuantConfig are allowed.
+
+    Raises:
+        ValueError: If any tag in `qkv_tags` is not among the allowed enum members.
+    """
+
+    # Get all valid tags from the StaticLLMQuantConfig enum
+    allowed_tags = set(StaticLLMQuantConfig)
+    invalid_tags = qkv_tags - allowed_tags
+    if invalid_tags:
+        raise ValueError(
+            f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}"
+        )
+
     for node in gm.graph.nodes:
-        if (
-            node.target == torch.ops.aten.conv2d.default
-            and "wv_sha" in node.meta["stack_trace"]
+        if node.target == torch.ops.aten.conv2d.default and any(
+            tag.value in node.meta["stack_trace"] for tag in qkv_tags
        ):
             input_qspec_map = {}
             input_qspec_map[node.args[0]] = quantization_config.input_activation
```
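As a usage sketch (not part of this commit): the generalized annotator is meant to be bound to a concrete quantization config and a set of enum tags, typically with `functools.partial`, so it can be handed to the quantizer as a single-argument custom annotation. `make_qkv_annotator` and its defaults below are hypothetical names; only `StaticLLMQuantConfig` and `annotate_qkv_proj_sha` come from this diff, and the `QuantizationConfig` is assumed to be produced by the existing Qualcomm quantizer setup.

```python
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    StaticLLMQuantConfig,
    annotate_qkv_proj_sha,
)


def make_qkv_annotator(quant_config, tags=None):
    """Bind annotate_qkv_proj_sha to a config and tag set (hypothetical helper)."""
    if tags is None:
        # Default: annotate all three single-head projection weights.
        tags = {
            StaticLLMQuantConfig.wq_sha,
            StaticLLMQuantConfig.wk_sha,
            StaticLLMQuantConfig.wv_sha,
        }
    # The returned callable takes only the GraphModule; annotate_qkv_proj_sha
    # validates the tags and annotates every aten.conv2d node whose stack
    # trace mentions one of them.
    return partial(
        annotate_qkv_proj_sha, quantization_config=quant_config, qkv_tags=tags
    )
```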

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 81 additions & 18 deletions
```diff
@@ -5138,6 +5138,60 @@ def test_static_qwen3(self):
                     msg["inference_speed"], inference_speed_ref[self.model]
                 )
 
+    def test_qwen2_5(self):
+        if not self.required_envs([]):
+            self.skipTest("missing required envs")
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py",
+            "--prompt",
+            prompt,
+            "--decoder_model",
+            "qwen2.5_0.5B",
+            "--ptq",
+            "16a8w",
+            "--enable_spinquant_r3",
+            "--max_seq_len",
+            "128",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        golden_start_with = "My favourite condiment is iced tea."
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                if not self.compile_only:
+                    model_out = msg["result"][0]
+                    self.assertTrue(
+                        model_out.startswith(golden_start_with),
+                        f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'",
+                    )
+
     def test_static_smollm2(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
@@ -5171,6 +5225,8 @@ def test_static_smollm2(self):
             "--eval_perplexity",
             "--task",
             "wikitext",
+            "--limit",
+            "1",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -5194,22 +5250,14 @@ def test_static_smollm2(self):
                 self.assertLessEqual(msg["wiki_ppl"], 25)
                 self.assertGreaterEqual(msg["inference_speed"], 200)
 
-    def test_qwen2_5(self):
-        if not self.required_envs([]):
+    def test_static_smollm3(self):
+        if not self.required_envs():
             self.skipTest("missing required envs")
+
         prompt = "My favourite condiment is "
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py",
-            "--prompt",
-            prompt,
-            "--decoder_model",
-            "qwen2.5_0.5B",
-            "--ptq",
-            "16a8w",
-            "--enable_spinquant_r3",
-            "--max_seq_len",
-            "128",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
             "--artifact",
             self.artifact_dir,
             "--build_folder",
@@ -5220,6 +5268,21 @@ def test_qwen2_5(self):
             self.ip,
             "--port",
             str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--decoder_model",
+            "smollm3-3b",
+            "--model_mode",
+            "kv",
+            "--temperature",
+            "0",
+            "--max_seq_len",
+            "1024",
+            "--eval_perplexity",
+            "--task",
+            "wikitext",
+            "--limit",
+            "1",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -5232,7 +5295,6 @@ def test_qwen2_5(self):
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
 
-        golden_start_with = "My favourite condiment is iced tea."
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
             conn = listener.accept()
@@ -5241,11 +5303,12 @@ def test_qwen2_5(self):
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                if not self.compile_only:
-                    model_out = msg["result"][0]
-                    self.assertTrue(
-                        model_out.startswith(golden_start_with),
-                        f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'",
+                inference_speed_ref = {"SM8650": 23, "SM8750": 28}
+                self.assertLessEqual(msg["wiki_ppl"], 10)
+                self.assertLessEqual(msg["pte_size"], 2_600_000_000)  # 2.6GB
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
                     )
 
 
```
examples/models/llama/model_args.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -78,6 +78,9 @@ class ModelArgs:
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
+    no_rope_layer_interval: Optional[int] = (
+        None  # Interval at which to skip RoPE. From Rope to Nope and Back Again: A New Hybrid Attention Strategy (https://huggingface.co/papers/2501.18795).
+    )
     partial_rotary_factor: float = 1.0
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.
```
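This hunk only declares the field; the attention code that consumes it is not shown in this diff. As a hedged sketch of the convention named in the referenced paper (and used by SmolLM3), every `no_rope_layer_interval`-th layer skips RoPE. The helper below is illustrative, not code from this commit.

```python
from typing import List, Optional


def rope_layer_mask(n_layers: int, no_rope_layer_interval: Optional[int]) -> List[bool]:
    """Illustrative helper: True = layer applies RoPE, False = NoPE layer.

    Assumes the every-Nth-layer convention; e.g. with n_layers=36 and
    no_rope_layer_interval=4, layers 3, 7, 11, ... (0-indexed) skip RoPE.
    """
    if not no_rope_layer_interval:
        return [True] * n_layers
    return [(i + 1) % no_rope_layer_interval != 0 for i in range(n_layers)]
```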
Lines changed: 15 additions & 0 deletions
```diff
@@ -0,0 +1,15 @@
+{
+  "dim": 2048,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 11008,
+  "n_heads": 16,
+  "n_kv_heads": 4,
+  "n_layers": 36,
+  "norm_eps": 1e-06,
+  "rope_theta": 5000000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 128256,
+  "use_hf_rope": false,
+  "no_rope_layer_interval": 4,
+  "attention_qkv_bias": false
+}
```
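A minimal sketch of how such a params file is typically consumed, assuming it is loaded into `ModelArgs` the same way as the other llama-style param files; the file name below is illustrative, since this excerpt does not show where the new JSON lives.

```python
import json

from executorch.examples.models.llama.model_args import ModelArgs

# Illustrative path; the actual location of the new params file is not shown
# in this excerpt.
with open("smollm3_3b_params.json") as f:
    params = json.load(f)

args = ModelArgs(**params)
assert args.n_layers == 36 and args.no_rope_layer_interval == 4
```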

examples/models/smollm3/__init__.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -0,0 +1,16 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.examples.models.llama.model import Llama2Model
+from executorch.examples.models.smollm3.convert_weights import convert_weights
+
+
+class SmolLM3Model(Llama2Model):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+__all__ = [
+    "SmolLM3Model",
+    "convert_weights",
+]
```

examples/models/smollm3/convert_weights.py

Lines changed: 112 additions & 0 deletions
```diff
@@ -0,0 +1,112 @@
+import argparse
+import json
+import os
+from typing import Dict
+
+import torch
+
+from safetensors.torch import load_file
+
+from torchtune.models.convert_weights import get_mapped_key
+
+
+_SMOLLM_TO_META = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.norm.weight": "norm.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+}
+
+
+def smollm_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from HuggingFace's format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in HuggingFace's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, _SMOLLM_TO_META)
+        converted_state_dict[new_key] = value
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+
+    return converted_state_dict
+
+
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
+    index_path = os.path.join(input_dir, "model.safetensors.index.json")
+    if os.path.exists(index_path):
+        # Sharded checkpoint.
+        with open(index_path, "r") as f:
+            index = json.load(f)
+        weight_map = index["weight_map"]
+        checkpoint_shards = sorted(set(weight_map.values()))
+
+        # Load all the shards into memory
+        shard_to_weights = {}
+        for shard in checkpoint_shards:
+            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
+
+        # Merge tensors into consolidated state dict.
+        merged_state_dict = {}
+        for weight_name, shard in weight_map.items():
+            tensor = shard_to_weights[shard][weight_name]
+            merged_state_dict[weight_name] = tensor
+        return merged_state_dict
+    else:
+        # Single checkpoint.
+        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+        return state_dict
+
+
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
+    print("Loading checkpoint...")
+    sd = load_checkpoint(input_dir)
+    print("Converting checkpoint...")
+    sd = smollm_to_meta(sd)
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert SmolLM weights to Meta format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing checkpoint files",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+    convert_weights(args.input_dir, args.output)
+
+
+if __name__ == "__main__":
+    main()
```
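As a usage sketch (paths are illustrative, not from the commit): point the converter at a local download of the HuggingFace SmolLM3-3B checkpoint directory and it writes a single Meta-format checkpoint that the llama-style example flow can consume. The module also exposes the same conversion through `main()` as a CLI with positional `input_dir` and `output` arguments.

```python
from executorch.examples.models.smollm3.convert_weights import convert_weights

# Input is a directory containing either sharded *.safetensors (plus
# model.safetensors.index.json) or a pytorch_model.bin; output is one .pth
# file with Meta-style key names. Both paths below are placeholders.
convert_weights("/path/to/SmolLM3-3B", "smollm3_3b_meta.pth")
```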

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 10 additions & 3 deletions
````diff
@@ -9,7 +9,8 @@ This file provides you the instructions to run LLM Decoder model with different
 5. Phi4-mini-instruct
 6. QWEN2.5 0.5B / 1.5B
 7. QWEN3 0.6B / 1.7B
-8. SMOLLM2 135M
+8. SmolLM2 135M
+9. SmolLM3 3B
 
 
 We offer the following modes to execute the model:
@@ -113,10 +114,16 @@ Default example using hybrid mode
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
-#### SMOLLM2
+#### SmolLM2
 Default example using hybrid mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+```
+
+#### SmolLM3
+Default example using kv mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
 
````