
Commit ce1b50c

Author: Martin Yuan (committed)

Add abstract base class for attention mechanisms with unified interface

1 parent 15c772c · commit ce1b50c

File tree: 12 files changed (+426 −372 lines)


backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 2 additions & 1 deletion

@@ -38,7 +38,8 @@
     update_spill_fill_size,
 )
 
-from executorch.examples.models.llama.llama_transformer import ModelArgs, MOEFeedForward
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.llama_transformer import MOEFeedForward
 
 from executorch.examples.qualcomm.utils import setup_common_args_and_variables
examples/models/llama/attention.py

Lines changed: 243 additions & 0 deletions (new file)
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Any, Dict, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope


class Attention(nn.Module, ABC):
    """Abstract base class for attention mechanisms with unified interface."""

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """Forward pass for attention mechanism.

        Args:
            x: Input tensor of shape (batch_size, seq_len, dim)
            freqs_cos, freqs_sin: Rotary position embedding frequencies
            mask: Optional attention mask
            input_pos: Positions for KV cache updates
            in_cache_state/out_cache_state: Cache states

        Returns:
            Tuple of (output tensor, updated cache state)
        """
        pass


ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}


def register_attention(name: str):
    """Decorator to register attention classes"""

    def decorator(cls: Type[Attention]):
        ATTENTION_REGISTRY[name.lower()] = cls
        return cls

    return decorator

class KVCache(nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        enable_dynamic_shape: bool,
        dtype=torch.float32,
    ):
        super().__init__()
        self.max_seq_length = max_seq_length
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

        self.max_batch_size = max_batch_size
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.enable_dynamic_shape = enable_dynamic_shape
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S], k_val: [B, H, S, D]
        if self.enable_dynamic_shape:
            start_pos = input_pos[0].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_length)
            dim_to_slice = 2
            seq_length = k_val.size(dim_to_slice)
            # Replace the entry in the cache for this token
            # The following lines are equivalent to:
            # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
            # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
            # when dim_to_slice is 1
            # We use .narrow() here to make the compiler happy
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)

            narrowed_k.copy_(k_val)
            narrowed_v.copy_(v_val)
            return self.k_cache, self.v_cache
        else:
            k_out = self.k_cache
            v_out = self.v_cache
            k_out[:, :, input_pos] = k_val
            v_out[:, :, input_pos] = v_val

            return k_out, v_out

class SDPA(nn.Module):
    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_rep: int,
        max_seq_len: int,
        enable_dynamic_shape: bool,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.max_seq_len = max_seq_len
        self.enable_dynamic_shape = enable_dynamic_shape

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
        bsz,
        seqlen,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        if self.enable_dynamic_shape:
            start_pos = input_pos[-1].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_len)
            seq_length = q.size(2)
            # pyre-ignore: Incompatible parameter type [6]
            attn_mask = mask.narrow(0, start_pos, seq_length)
        else:
            attn_mask = mask[None, None, input_pos]

        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
        # can natively support GQA now. But needs enable_gqa=True
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@register_attention("mha")
class AttentionMHA(Attention):
    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert self.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = self.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.head_dim
        self.max_batch_size = args.max_batch_size
        self.max_seq_len = args.max_seq_len
        self.dim = args.dim
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        self.layer_id = layer_id

        self.rope = rope

        causal_mask = torch.tril(
            torch.ones(
                self.max_seq_len,
                self.max_seq_len,
                dtype=torch.bool,
                device="cpu",
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

        if self.use_kv_cache:
            self.kv_cache = KVCache(
                args.max_batch_size,
                args.max_seq_len,
                self.n_kv_heads,
                self.head_dim,
                args.enable_dynamic_shape,
            )
            self.SDPA = SDPA(
                dim=self.n_local_heads * self.head_dim,
                head_dim=self.head_dim,
                n_rep=self.n_rep,
                max_seq_len=self.max_seq_len,
                enable_dynamic_shape=args.enable_dynamic_shape,
            )

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        in_cache_state: Optional[Any] = None,
        out_cache_state: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        bsz, seqlen, _ = x.shape

        # QKV
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # We need view_copy elimination
        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if self.use_kv_cache:
            assert input_pos is not None
            k, v = self.kv_cache.update(input_pos, k, v)
            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
            return self.wo(output)

        # grouped multiquery attention: expand out keys and values
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        assert hasattr(self, "mask")

        mask = self.mask[:seqlen, :seqlen]

        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        output = self.wo(output)

        return output
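
A minimal usage sketch of the registry introduced above, assuming the new file is importable as executorch.examples.models.llama.attention. The NullAttention class and the "null" key are hypothetical, made up here only to show how register_attention and ATTENTION_REGISTRY resolve an Attention implementation by lower-cased name; the commit itself registers only AttentionMHA, under "mha".

import torch

from executorch.examples.models.llama.attention import (
    ATTENTION_REGISTRY,
    Attention,
    register_attention,
)


# Hypothetical toy implementation, registered under a made-up "null" key.
@register_attention("null")
class NullAttention(Attention):
    """Ignores positional inputs and returns x unchanged with no cache state."""

    def forward(
        self,
        x,
        freqs_cos,
        freqs_sin,
        mask=None,
        input_pos=None,
        in_cache_state=None,
        out_cache_state=None,
    ):
        return x, None


# Resolve an implementation by name and call it like any nn.Module.
attn_cls = ATTENTION_REGISTRY["null"]
attn = attn_cls()
out, cache_state = attn(torch.randn(1, 4, 64), torch.empty(0), torch.empty(0))
print(out.shape)  # torch.Size([1, 4, 64])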
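
KVCache and SDPA are the pieces AttentionMHA wires together when use_kv_cache is enabled. Below is a minimal sketch of a single decode step on the static-shape path (enable_dynamic_shape=False), using made-up sizes (batch 1, 4 query heads, 2 KV heads, head_dim 8, cache length 16) and the same assumed import path.

import torch

from executorch.examples.models.llama.attention import KVCache, SDPA

# Static-shape variant: cache writes use index assignment and SDPA indexes
# the causal mask with input_pos directly.
cache = KVCache(
    max_batch_size=1,
    max_seq_length=16,
    n_heads=2,  # number of KV heads
    head_dim=8,
    enable_dynamic_shape=False,
)
sdpa = SDPA(dim=4 * 8, head_dim=8, n_rep=2, max_seq_len=16, enable_dynamic_shape=False)
mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))

input_pos = torch.tensor([0])  # decoding the first token
q = torch.randn(1, 4, 1, 8)    # (bs, n_local_heads, seqlen, head_dim), post-RoPE
k = torch.randn(1, 2, 1, 8)    # (bs, n_local_kv_heads, seqlen, head_dim), post-RoPE
v = torch.randn(1, 2, 1, 8)

# update() writes this step's K/V into the cache and returns the full caches,
# shaped (1, 2, 16, 8); SDPA repeats them n_rep times to match the 4 query heads.
k_all, v_all = cache.update(input_pos, k, v)
out = sdpa(input_pos, q, k_all, v_all, bsz=1, seqlen=1, mask=mask)
print(out.shape)  # torch.Size([1, 1, 32]) == (bsz, seqlen, n_local_heads * head_dim)

With enable_dynamic_shape=True, the same calls go through the narrow()/copy_() and mask.narrow() branches instead, which, per the in-code comments, keeps the cache update and mask slicing friendly to export/compilation when the sequence position is dynamic.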
