|
| 1 | +from abc import ABC, abstractmethod |
| 2 | +from typing import Any, Dict, Optional, Tuple, Type |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +import torch.nn.functional as F |
| 7 | +from executorch.examples.models.llama.model_args import ModelArgs |
| 8 | +from executorch.examples.models.llama.rope import Rope |
| 9 | + |
| 10 | + |
| 11 | +class Attention(nn.Module, ABC): |
| 12 | + """Abstract base class for attention mechanisms with unified interface.""" |
| 13 | + |
| 14 | + @abstractmethod |
| 15 | + def forward( |
| 16 | + self, |
| 17 | + x: torch.Tensor, |
| 18 | + freqs_cos: torch.Tensor, |
| 19 | + freqs_sin: torch.Tensor, |
| 20 | + mask: Optional[torch.Tensor] = None, |
| 21 | + input_pos: Optional[torch.Tensor] = None, |
| 22 | + in_cache_state: Optional[Any] = None, |
| 23 | + out_cache_state: Optional[Any] = None, |
| 24 | + ) -> Tuple[torch.Tensor, Optional[Any]]: |
| 25 | + """Forward pass for attention mechanism. |
| 26 | +
|
| 27 | + Args: |
| 28 | + x: Input tensor of shape (batch_size, seq_len, dim) |
| 29 | + freqs_cos, freqs_sin: Rotary position embedding frequencies |
| 30 | + mask: Optional attention mask |
| 31 | + input_pos: Positions for KV cache updates |
| 32 | + in_cache_state/out_cache_state: Cache states |
| 33 | +
|
| 34 | + Returns: |
| 35 | + Tuple of (output tensor, updated cache state) |
| 36 | + """ |
| 37 | + pass |
| 38 | + |
| 39 | + |
| 40 | +ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {} |
| 41 | + |
| 42 | + |
| 43 | +def register_attention(name: str): |
| 44 | + """Decorator to register attention classes""" |
| 45 | + |
| 46 | + def decorator(cls: Type[Attention]): |
| 47 | + ATTENTION_REGISTRY[name.lower()] = cls |
| 48 | + return cls |
| 49 | + |
| 50 | + return decorator |
| 51 | + |
| 52 | + |
| 53 | +class KVCache(nn.Module): |
| 54 | + def __init__( |
| 55 | + self, |
| 56 | + max_batch_size: int, |
| 57 | + max_context_length: int, |
| 58 | + n_heads: int, |
| 59 | + head_dim: int, |
| 60 | + enable_dynamic_shape: bool, |
| 61 | + dtype=torch.float32, |
| 62 | + ): |
| 63 | + super().__init__() |
| 64 | + self.max_context_length = max_context_length |
| 65 | + cache_shape = (max_batch_size, n_heads, max_context_length, head_dim) |
| 66 | + |
| 67 | + self.max_batch_size = max_batch_size |
| 68 | + self.n_heads = n_heads |
| 69 | + self.head_dim = head_dim |
| 70 | + self.enable_dynamic_shape = enable_dynamic_shape |
| 71 | + self.register_buffer( |
| 72 | + "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") |
| 73 | + ) |
| 74 | + self.register_buffer( |
| 75 | + "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") |
| 76 | + ) |
| 77 | + |
| 78 | + def update( |
| 79 | + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor |
| 80 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 81 | + # input_pos: [S], k_val: [B, H, S, D] |
| 82 | + if self.enable_dynamic_shape: |
| 83 | + start_pos = input_pos[0].item() |
| 84 | + torch._check_is_size(start_pos) |
| 85 | + torch._check(start_pos < self.max_context_length) |
| 86 | + dim_to_slice = 2 |
| 87 | + seq_length = k_val.size(dim_to_slice) |
| 88 | + # Replace the entry in the cache for this token |
| 89 | + # The following lines are equivalent to: |
| 90 | + # cache_k[:bsz, start_pos : start_pos + seqlen] = xk |
| 91 | + # cache_v[:bsz, start_pos : start_pos + seqlen] = xv |
| 92 | + # when dim_to_slice is 1 |
| 93 | + # We use .narrow() here to make the compiler happy |
| 94 | + # pyre-ignore: Incompatible parameter type [6] |
| 95 | + narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) |
| 96 | + # pyre-ignore: Incompatible parameter type [6] |
| 97 | + narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) |
| 98 | + |
| 99 | + narrowed_k.copy_(k_val) |
| 100 | + narrowed_v.copy_(v_val) |
| 101 | + return self.k_cache, self.v_cache |
| 102 | + else: |
| 103 | + k_out = self.k_cache |
| 104 | + v_out = self.v_cache |
| 105 | + k_out[:, :, input_pos] = k_val |
| 106 | + v_out[:, :, input_pos] = v_val |
| 107 | + |
| 108 | + return k_out, v_out |
| 109 | + |
| 110 | + |
| 111 | +class SDPA(nn.Module): |
| 112 | + def __init__( |
| 113 | + self, |
| 114 | + dim: int, |
| 115 | + head_dim: int, |
| 116 | + n_rep: int, |
| 117 | + max_context_len: int, |
| 118 | + enable_dynamic_shape: bool, |
| 119 | + ): |
| 120 | + super().__init__() |
| 121 | + self.dim = dim |
| 122 | + self.head_dim = head_dim |
| 123 | + self.n_rep = n_rep |
| 124 | + self.max_context_len = max_context_len |
| 125 | + self.enable_dynamic_shape = enable_dynamic_shape |
| 126 | + |
| 127 | + def forward( |
| 128 | + self, |
| 129 | + input_pos: torch.Tensor, |
| 130 | + q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim) |
| 131 | + k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim) |
| 132 | + v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim) |
| 133 | + bsz, |
| 134 | + seqlen, |
| 135 | + mask: torch.Tensor, |
| 136 | + ) -> torch.Tensor: |
| 137 | + if self.enable_dynamic_shape: |
| 138 | + start_pos = input_pos[-1].item() |
| 139 | + torch._check_is_size(start_pos) |
| 140 | + torch._check(start_pos < self.max_context_len) |
| 141 | + seq_length = q.size(2) |
| 142 | + # pyre-ignore: Incompatible parameter type [6] |
| 143 | + attn_mask = mask.narrow(0, start_pos, seq_length) |
| 144 | + else: |
| 145 | + attn_mask = mask[None, None, input_pos] |
| 146 | + |
| 147 | + # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention |
| 148 | + # can natively support GQA now. But needs enable_gqa=True |
| 149 | + k = k.repeat_interleave(self.n_rep, dim=1) |
| 150 | + v = v.repeat_interleave(self.n_rep, dim=1) |
| 151 | + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) |
| 152 | + |
| 153 | + return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) |
| 154 | + |
| 155 | + |
| 156 | +@register_attention("mha") |
| 157 | +class AttentionMHA(Attention): |
| 158 | + def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): |
| 159 | + super().__init__() |
| 160 | + self.use_kv_cache = args.use_kv_cache |
| 161 | + self.n_heads = args.n_heads |
| 162 | + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads |
| 163 | + assert self.n_heads % self.n_kv_heads == 0 |
| 164 | + model_parallel_size = 1 |
| 165 | + self.n_local_heads = self.n_heads // model_parallel_size |
| 166 | + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size |
| 167 | + self.n_rep = self.n_local_heads // self.n_local_kv_heads |
| 168 | + self.head_dim = args.head_dim |
| 169 | + self.max_batch_size = args.max_batch_size |
| 170 | + self.max_context_len = args.max_context_len |
| 171 | + self.dim = args.dim |
| 172 | + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) |
| 173 | + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) |
| 174 | + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) |
| 175 | + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) |
| 176 | + |
| 177 | + self.layer_id = layer_id |
| 178 | + |
| 179 | + self.rope = rope |
| 180 | + |
| 181 | + causal_mask = torch.tril( |
| 182 | + torch.ones( |
| 183 | + self.max_context_len, |
| 184 | + self.max_context_len, |
| 185 | + dtype=torch.bool, |
| 186 | + device="cpu", |
| 187 | + ) |
| 188 | + ) |
| 189 | + self.register_buffer("mask", causal_mask, persistent=False) |
| 190 | + |
| 191 | + if self.use_kv_cache: |
| 192 | + self.kv_cache = KVCache( |
| 193 | + args.max_batch_size, |
| 194 | + args.max_context_len, |
| 195 | + self.n_kv_heads, |
| 196 | + self.head_dim, |
| 197 | + args.enable_dynamic_shape, |
| 198 | + ) |
| 199 | + self.SDPA = SDPA( |
| 200 | + dim=self.n_local_heads * self.head_dim, |
| 201 | + head_dim=self.head_dim, |
| 202 | + n_rep=self.n_rep, |
| 203 | + max_context_len=self.max_context_len, |
| 204 | + enable_dynamic_shape=args.enable_dynamic_shape, |
| 205 | + ) |
| 206 | + |
| 207 | + def forward( |
| 208 | + self, |
| 209 | + x: torch.Tensor, |
| 210 | + freqs_cos: torch.Tensor, |
| 211 | + freqs_sin: torch.Tensor, |
| 212 | + mask: Optional[torch.Tensor] = None, |
| 213 | + input_pos: Optional[torch.Tensor] = None, |
| 214 | + in_cache_state: Optional[Any] = None, |
| 215 | + out_cache_state: Optional[Any] = None, |
| 216 | + ) -> Tuple[torch.Tensor, Optional[Any]]: |
| 217 | + bsz, seqlen, _ = x.shape |
| 218 | + |
| 219 | + # QKV |
| 220 | + q, k, v = self.wq(x), self.wk(x), self.wv(x) |
| 221 | + # We need view_copy elimination |
| 222 | + q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| 223 | + k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) |
| 224 | + v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) |
| 225 | + |
| 226 | + # RoPE relative positional embeddings |
| 227 | + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) |
| 228 | + |
| 229 | + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) |
| 230 | + k = k.transpose(1, 2) |
| 231 | + v = v.transpose(1, 2) |
| 232 | + |
| 233 | + if self.use_kv_cache: |
| 234 | + assert input_pos is not None |
| 235 | + k, v = self.kv_cache.update(input_pos, k, v) |
| 236 | + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) |
| 237 | + return self.wo(output) |
| 238 | + |
| 239 | + # grouped multiquery attention: expand out keys and values |
| 240 | + k = k.repeat_interleave(self.n_rep, dim=1) |
| 241 | + v = v.repeat_interleave(self.n_rep, dim=1) |
| 242 | + |
| 243 | + assert hasattr(self, "mask") |
| 244 | + |
| 245 | + mask = self.mask[:seqlen, :seqlen] |
| 246 | + |
| 247 | + output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) |
| 248 | + |
| 249 | + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) |
| 250 | + |
| 251 | + output = self.wo(output) |
| 252 | + |
| 253 | + return output |
0 commit comments