
Commit a972e73

Add abstract base class for attention mechanisms with unified interface
Differential Revision: D68956201
Pull Request resolved: #8039
Parent commit: 6897210

17 files changed: +448 −383 lines
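The new examples/models/llama/attention.py (full diff below) defines an Attention abstract base class, an ATTENTION_REGISTRY, and a register_attention decorator, with AttentionMHA registered under "mha". As a rough illustration of the unified interface, and not part of this commit, an alternative implementation could plug in as in the sketch below; the class AttentionNoop and its trivial body are hypothetical, only the base class, decorator, and forward signature come from this diff.

from typing import Any, Optional, Tuple

import torch

from executorch.examples.models.llama.attention import Attention, register_attention
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope


@register_attention("noop")  # hypothetical name, registered alongside "mha"
class AttentionNoop(Attention):
    """Hypothetical pass-through attention, only to illustrate the interface."""

    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
        # Mirrors AttentionMHA's constructor shape so callers can treat
        # registered classes interchangeably.
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # Return the input unchanged and no cache update, per the declared signature.
        return x, None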

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 2 additions & 1 deletion
@@ -37,8 +37,9 @@
     skip_annotation,
     update_spill_fill_size,
 )
+from executorch.examples.models.llama.llama_transformer import MOEFeedForward

-from executorch.examples.models.llama.llama_transformer import ModelArgs, MOEFeedForward
+from executorch.examples.models.llama.model_args import ModelArgs

 from executorch.examples.qualcomm.utils import setup_common_args_and_variables

examples/models/llama/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ runtime.python_library(
     srcs = [
         "llama_transformer.py",
         "rope.py",
+        "attention.py",
+        "model_args.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama",

examples/models/llama/attention.py

Lines changed: 253 additions & 0 deletions
@@ -0,0 +1,253 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope


class Attention(nn.Module, ABC):
    """Abstract base class for attention mechanisms with unified interface."""

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """Forward pass for attention mechanism.

        Args:
            x: Input tensor of shape (batch_size, seq_len, dim)
            freqs_cos, freqs_sin: Rotary position embedding frequencies
            mask: Optional attention mask
            input_pos: Positions for KV cache updates
            in_cache_state/out_cache_state: Cache states

        Returns:
            Tuple of (output tensor, updated cache state)
        """
        pass


ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}


def register_attention(name: str):
    """Decorator to register attention classes"""

    def decorator(cls: Type[Attention]):
        ATTENTION_REGISTRY[name.lower()] = cls
        return cls

    return decorator


class KVCache(nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_context_length: int,
        n_heads: int,
        head_dim: int,
        enable_dynamic_shape: bool,
        dtype=torch.float32,
    ):
        super().__init__()
        self.max_context_length = max_context_length
        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)

        self.max_batch_size = max_batch_size
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.enable_dynamic_shape = enable_dynamic_shape
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S], k_val: [B, H, S, D]
        if self.enable_dynamic_shape:
            start_pos = input_pos[0].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_context_length)
            dim_to_slice = 2
            seq_length = k_val.size(dim_to_slice)
            # Replace the entry in the cache for this token
            # The following lines are equivalent to:
            # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
            # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
            # when dim_to_slice is 1
            # We use .narrow() here to make the compiler happy
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)

            narrowed_k.copy_(k_val)
            narrowed_v.copy_(v_val)
            return self.k_cache, self.v_cache
        else:
            k_out = self.k_cache
            v_out = self.v_cache
            k_out[:, :, input_pos] = k_val
            v_out[:, :, input_pos] = v_val

            return k_out, v_out


class SDPA(nn.Module):
    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_rep: int,
        max_context_len: int,
        enable_dynamic_shape: bool,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.max_context_len = max_context_len
        self.enable_dynamic_shape = enable_dynamic_shape

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
        bsz,
        seqlen,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        if self.enable_dynamic_shape:
            start_pos = input_pos[-1].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_context_len)
            seq_length = q.size(2)
            # pyre-ignore: Incompatible parameter type [6]
            attn_mask = mask.narrow(0, start_pos, seq_length)
        else:
            attn_mask = mask[None, None, input_pos]

        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
        # can natively support GQA now. But needs enable_gqa=True
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


@register_attention("mha")
class AttentionMHA(Attention):
    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert self.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = self.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.head_dim
        self.max_batch_size = args.max_batch_size
        self.max_context_len = args.max_context_len
        self.dim = args.dim
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        self.layer_id = layer_id

        self.rope = rope

        causal_mask = torch.tril(
            torch.ones(
                self.max_context_len,
                self.max_context_len,
                dtype=torch.bool,
                device="cpu",
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

        if self.use_kv_cache:
            self.kv_cache = KVCache(
                args.max_batch_size,
                args.max_context_len,
                self.n_kv_heads,
                self.head_dim,
                args.enable_dynamic_shape,
            )
            self.SDPA = SDPA(
                dim=self.n_local_heads * self.head_dim,
                head_dim=self.head_dim,
                n_rep=self.n_rep,
                max_context_len=self.max_context_len,
                enable_dynamic_shape=args.enable_dynamic_shape,
            )

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        bsz, seqlen, _ = x.shape

        # QKV
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # We need view_copy elimination
        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if self.use_kv_cache:
            assert input_pos is not None
            k, v = self.kv_cache.update(input_pos, k, v)
            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
            return self.wo(output)

        # grouped multiquery attention: expand out keys and values
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        assert hasattr(self, "mask")

        mask = self.mask[:seqlen, :seqlen]

        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        output = self.wo(output)

        return output
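
For context, and not part of the commit, here is a minimal usage sketch of the registry: look up the "mha" class registered above and run one forward pass without a KV cache. The ModelArgs field values, the derived head_dim, and the rope.get_freqs call are assumptions taken from the surrounding llama example code rather than anything defined in this diff.

import torch

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

# Hypothetical small config; real values come from a checkpoint's params.json.
# Assumes ModelArgs derives head_dim as dim // n_heads when it is left unset.
args = ModelArgs(dim=256, n_heads=8, n_kv_heads=4, use_kv_cache=False)
rope = Rope(args)

attn_cls = ATTENTION_REGISTRY["mha"]  # registered by @register_attention("mha")
attn = attn_cls(args, layer_id=0, rope=rope)

seqlen = 16
x = torch.randn(1, seqlen, args.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, seqlen)  # assumed Rope helper
out = attn(x, freqs_cos, freqs_sin)  # -> (1, seqlen, args.dim)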
