diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 5935c7d3caf..498ee4ea68e 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -37,8 +37,9 @@ skip_annotation, update_spill_fill_size, ) +from executorch.examples.models.llama.llama_transformer import MOEFeedForward -from executorch.examples.models.llama.llama_transformer import ModelArgs, MOEFeedForward +from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.qualcomm.utils import setup_common_args_and_variables diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 4fe7f6cc2b1..f6b78e876c8 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -14,6 +14,8 @@ runtime.python_library( srcs = [ "llama_transformer.py", "rope.py", + "attention.py", + "model_args.py", ], _is_external_target = True, base_module = "executorch.examples.models.llama", diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py new file mode 100644 index 00000000000..c6f01dbadce --- /dev/null +++ b/examples/models/llama/attention.py @@ -0,0 +1,253 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope + + +class Attention(nn.Module, ABC): + """Abstract base class for attention mechanisms with unified interface.""" + + @abstractmethod + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + in_cache_state: Optional[Any] = None, + out_cache_state: Optional[Any] = None, + ) -> Tuple[torch.Tensor, Optional[Any]]: + """Forward pass for attention mechanism. 
+ + Args: + x: Input tensor of shape (batch_size, seq_len, dim) + freqs_cos, freqs_sin: Rotary position embedding frequencies + mask: Optional attention mask + input_pos: Positions for KV cache updates + in_cache_state/out_cache_state: Cache states + + Returns: + Tuple of (output tensor, updated cache state) + """ + pass + + +ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {} + + +def register_attention(name: str): + """Decorator to register attention classes""" + + def decorator(cls: Type[Attention]): + ATTENTION_REGISTRY[name.lower()] = cls + return cls + + return decorator + + +class KVCache(nn.Module): + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool, + dtype=torch.float32, + ): + super().__init__() + self.max_context_length = max_context_length + cache_shape = (max_batch_size, n_heads, max_context_length, head_dim) + + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.enable_dynamic_shape = enable_dynamic_shape + self.register_buffer( + "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + self.register_buffer( + "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [S], k_val: [B, H, S, D] + if self.enable_dynamic_shape: + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_context_length) + dim_to_slice = 2 + seq_length = k_val.size(dim_to_slice) + # Replace the entry in the cache for this token + # The following lines are equivalent to: + # cache_k[:bsz, start_pos : start_pos + seqlen] = xk + # cache_v[:bsz, start_pos : start_pos + seqlen] = xv + # when dim_to_slice is 1 + # We use .narrow() here to make the compiler happy + # pyre-ignore: Incompatible parameter type [6] + narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) + # pyre-ignore: Incompatible parameter type [6] + narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) + + narrowed_k.copy_(k_val) + narrowed_v.copy_(v_val) + return self.k_cache, self.v_cache + else: + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class SDPA(nn.Module): + def __init__( + self, + dim: int, + head_dim: int, + n_rep: int, + max_context_len: int, + enable_dynamic_shape: bool, + ): + super().__init__() + self.dim = dim + self.head_dim = head_dim + self.n_rep = n_rep + self.max_context_len = max_context_len + self.enable_dynamic_shape = enable_dynamic_shape + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim) + k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim) + v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim) + bsz, + seqlen, + mask: torch.Tensor, + ) -> torch.Tensor: + if self.enable_dynamic_shape: + start_pos = input_pos[-1].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_context_len) + seq_length = q.size(2) + # pyre-ignore: Incompatible parameter type [6] + attn_mask = mask.narrow(0, start_pos, seq_length) + else: + attn_mask = mask[None, None, input_pos] + + # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention + # can natively support GQA now. 
But needs enable_gqa=True + k = k.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) + + return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + +@register_attention("mha") +class AttentionMHA(Attention): + def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): + super().__init__() + self.use_kv_cache = args.use_kv_cache + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + assert self.n_heads % self.n_kv_heads == 0 + model_parallel_size = 1 + self.n_local_heads = self.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.head_dim + self.max_batch_size = args.max_batch_size + self.max_context_len = args.max_context_len + self.dim = args.dim + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + + self.layer_id = layer_id + + self.rope = rope + + causal_mask = torch.tril( + torch.ones( + self.max_context_len, + self.max_context_len, + dtype=torch.bool, + device="cpu", + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + if self.use_kv_cache: + self.kv_cache = KVCache( + args.max_batch_size, + args.max_context_len, + self.n_kv_heads, + self.head_dim, + args.enable_dynamic_shape, + ) + self.SDPA = SDPA( + dim=self.n_local_heads * self.head_dim, + head_dim=self.head_dim, + n_rep=self.n_rep, + max_context_len=self.max_context_len, + enable_dynamic_shape=args.enable_dynamic_shape, + ) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + in_cache_state: Optional[Any] = None, + out_cache_state: Optional[Any] = None, + ) -> Tuple[torch.Tensor, Optional[Any]]: + bsz, seqlen, _ = x.shape + + # QKV + q, k, v = self.wq(x), self.wk(x), self.wv(x) + # We need view_copy elimination + q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + # RoPE relative positional embeddings + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if self.use_kv_cache: + assert input_pos is not None + k, v = self.kv_cache.update(input_pos, k, v) + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) + return self.wo(output) + + # grouped multiquery attention: expand out keys and values + k = k.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + assert hasattr(self, "mask") + + mask = self.mask[:seqlen, :seqlen] + + output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + + output = self.wo(output) + + return output diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index cc6b81edc10..08526dde195 100644 --- a/examples/models/llama/llama_transformer.py +++ 
b/examples/models/llama/llama_transformer.py @@ -7,19 +7,16 @@ # Please refer to README.md in the same folder for more information. -from dataclasses import dataclass -from functools import partial -from typing import Dict, Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F -from executorch.examples.models.llama.rope import ( - hf_apply_rotary_emb, - hf_precompute_freqs_cis, - precompute_freqs_cis, - RotaryEmbedding, -) +from executorch.examples.models.llama.attention import ATTENTION_REGISTRY + +from executorch.examples.models.llama.model_args import ModelArgs + +from executorch.examples.models.llama.rope import Rope from torch import nn @@ -71,360 +68,6 @@ def forward(self, x): return output * self.weight -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - -@dataclass -class ModelArgs: - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: Optional[int] = None - vocab_size: int = -1 # defined later by tokenizer - hidden_dim: Optional[int] = None - head_dim: Optional[int] = None # Optional customized head_dim - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None - norm_eps: float = 1e-5 - max_batch_size: int = 32 - max_seq_len: int = 2048 - max_context_len: int = 2048 - moe: bool = False # True to enable the MoE (Mixture of Experts) - num_experts: int = 8 # Number of experts - num_activated_experts: int = 2 # Number of experts to activate - use_kv_cache: bool = False # Use key/value cache - use_sdpa_with_kv_cache_op: bool = ( - False # Use custom sdpa op that updates kv cache in-place - ) - # Generate logits for all inputs. When it's True, it would take big memory usage - # at runtime. Enable it only necessary (e.g., use perplexity tools that requires - # logits for all input tokens.) - generate_full_logits: bool = False - enable_dynamic_shape: bool = False # export model with dynamic shape support - # A dictionary mapping from pruned token-id to original token-id - input_prune_map: Optional[Dict[int, int]] = None - # A dictionary mapping from pruned token-id to original token-id - output_prune_map: Optional[Dict[int, int]] = None - use_hf_rope: bool = False # Use HuggingFace's RoPE implementation - rope_theta: Optional[float] = ( - None # The official name to override self.rope_freq_base. - ) - rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC. - use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1. - rope_scale_factor: int = 8 - # Additional Model Metadata needed at runtime - bos_idx: int = 1 - eos_idx: int = 3 - bos_count: int = -1 # i.e., a single EOS is used as BOS - eos_count: int = 2 - - quantization_args: Optional[dict] = None - lora_args: Optional[dict] = None - - def __post_init__(self): - if self.n_kv_heads is None: - self.n_kv_heads = self.n_heads - - # rope_theta overrides rope_freq_base since it's the official name. 
- if self.rope_theta is not None: - self.rope_freq_base = self.rope_theta - - if self.use_sdpa_with_kv_cache_op: - assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache" - - if self.hidden_dim is None: - # If hidden_dim is not explicitly set in the ModelArgs, - # then calculate implicitly based on dim and also multiple of `args.multiple_of` - multiple_of = self.multiple_of - hidden_dim = 4 * self.dim - hidden_dim = int(2 * hidden_dim / 3) - if self.ffn_dim_multiplier is not None: - hidden_dim = int(self.ffn_dim_multiplier * hidden_dim) - self.hidden_dim = find_multiple(hidden_dim, multiple_of) - - if self.head_dim is None: - self.head_dim = self.dim // self.n_heads - - -class Rope(torch.nn.Module): - def __init__(self, params: ModelArgs): - super().__init__() - self.params = params - if self.params.use_hf_rope: - self.precompute_freqs_cis = hf_precompute_freqs_cis - else: - self.precompute_freqs_cis = partial( - precompute_freqs_cis, - use_scaled=self.params.use_scaled_rope, - scale_factor=self.params.rope_scale_factor, - ) - freqs_cos, freqs_sin = self.precompute_freqs_cis( - self.params.head_dim, - ( - self.params.max_context_len # Normal llama2. - if self.params.ffn_dim_multiplier is None - else self.params.max_context_len * 2 # Sharded checkpoint. - ), - self.params.rope_freq_base, - ) - self.register_buffer("freqs_cos", freqs_cos, persistent=False) - self.register_buffer("freqs_sin", freqs_sin, persistent=False) - if self.params.use_hf_rope: - self.apply_rotary_emb = hf_apply_rotary_emb - else: - self.apply_rotary_emb = RotaryEmbedding() - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor, - ): - return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) - - def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): - """ - Get the precomputed frequencies for the given input position and sequence length. - - Args: - input_pos (torch.Tensor): The input position tensor. - seq_len (int): The sequence length. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length. - """ - if self.params.use_kv_cache: - assert ( - input_pos is not None - ), "input_pos must be provided when use_kv_cache is True" - - if self.params.enable_dynamic_shape: - # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos. - input_pos_item = input_pos[-1].item() - torch._check_is_size(input_pos_item) - torch._check(input_pos_item < self.params.max_context_len) - # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor - freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len) - # pyre-ignore: Incompatible parameter type [6] - freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len) - else: - # When not using dynamic shape, use of the .item results in - # symints, due to querying the data from tensor. - # this path avoids that for mps backend, although probably mps backend - # can support dynamic shape? 
- freqs_cos = self.freqs_cos[input_pos] - freqs_sin = self.freqs_sin[input_pos] - - else: - assert input_pos is None, "input_pos is unused when use_kv_cache is False" - freqs_cos = self.freqs_cos[:seq_len] - freqs_sin = self.freqs_sin[:seq_len] - return freqs_cos, freqs_sin - - -class KVCache(nn.Module): - def __init__( - self, - max_batch_size: int, - max_context_length: int, - n_heads: int, - head_dim: int, - enable_dynamic_shape: bool, - dtype=torch.float32, - ): - super().__init__() - self.max_context_length = max_context_length - cache_shape = (max_batch_size, n_heads, max_context_length, head_dim) - - self.max_batch_size = max_batch_size - self.n_heads = n_heads - self.head_dim = head_dim - self.enable_dynamic_shape = enable_dynamic_shape - self.register_buffer( - "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") - ) - self.register_buffer( - "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") - ) - - def update( - self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, H, S, D] - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_context_length) - dim_to_slice = 2 - seq_length = k_val.size(dim_to_slice) - # Replace the entry in the cache for this token - # The following lines are equivalent to: - # cache_k[:bsz, start_pos : start_pos + seqlen] = xk - # cache_v[:bsz, start_pos : start_pos + seqlen] = xv - # when dim_to_slice is 1 - # We use .narrow() here to make the compiler happy - # pyre-ignore: Incompatible parameter type [6] - narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) - # pyre-ignore: Incompatible parameter type [6] - narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) - - narrowed_k.copy_(k_val) - narrowed_v.copy_(v_val) - return self.k_cache, self.v_cache - else: - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - - return k_out, v_out - - -class SDPA(nn.Module): - def __init__( - self, - dim: int, - head_dim: int, - n_rep: int, - max_context_len: int, - enable_dynamic_shape: bool, - ): - super().__init__() - self.dim = dim - self.head_dim = head_dim - self.n_rep = n_rep - self.max_context_len = max_context_len - self.enable_dynamic_shape = enable_dynamic_shape - - def forward( - self, - input_pos: torch.Tensor, - q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim) - k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim) - v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim) - bsz, - seqlen, - mask: torch.Tensor, - ) -> torch.Tensor: - if self.enable_dynamic_shape: - start_pos = input_pos[-1].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_context_len) - seq_length = q.size(2) - # pyre-ignore: Incompatible parameter type [6] - attn_mask = mask.narrow(0, start_pos, seq_length) - else: - attn_mask = mask[None, None, input_pos] - - # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention - # can natively support GQA now. 
But needs enable_gqa=True - k = k.repeat_interleave(self.n_rep, dim=1) - v = v.repeat_interleave(self.n_rep, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) - - return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): - super().__init__() - self.use_kv_cache = args.use_kv_cache - self.n_heads = args.n_heads - self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads - assert self.n_heads % self.n_kv_heads == 0 - model_parallel_size = 1 - self.n_local_heads = self.n_heads // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // model_parallel_size - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.head_dim - self.max_batch_size = args.max_batch_size - self.max_context_len = args.max_context_len - self.dim = args.dim - self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) - - self.layer_id = layer_id - - self.rope = rope - - causal_mask = torch.tril( - torch.ones( - self.max_context_len, - self.max_context_len, - dtype=torch.bool, - device="cpu", - ) - ) - self.register_buffer("mask", causal_mask, persistent=False) - - if self.use_kv_cache: - self.kv_cache = KVCache( - args.max_batch_size, - args.max_context_len, - self.n_kv_heads, - self.head_dim, - args.enable_dynamic_shape, - ) - self.SDPA = SDPA( - dim=self.n_local_heads * self.head_dim, - head_dim=self.head_dim, - n_rep=self.n_rep, - max_context_len=self.max_context_len, - enable_dynamic_shape=args.enable_dynamic_shape, - ) - - def forward( - self, - x: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, - ): - bsz, seqlen, _ = x.shape - - # QKV - q, k, v = self.wq(x), self.wk(x), self.wv(x) - # We need view_copy elimination - q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) - k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - - # RoPE relative positional embeddings - q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) - - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - if self.use_kv_cache: - assert input_pos is not None - k, v = self.kv_cache.update(input_pos, k, v) - output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) - return self.wo(output) - - # grouped multiquery attention: expand out keys and values - k = k.repeat_interleave(self.n_rep, dim=1) - v = v.repeat_interleave(self.n_rep, dim=1) - - assert hasattr(self, "mask") - - mask = self.mask[:seqlen, :seqlen] - - output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) - - output = self.wo(output) - - return output - - class FeedForward(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -491,7 +134,13 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.head_dim - self.attention = Attention(args, layer_id, rope) + if args.attention_type not in ATTENTION_REGISTRY: + raise ValueError( + f"Unknown attention type: 
{args.attention_type}. " + f"Available: {list(ATTENTION_REGISTRY.keys())}" + ) + cls = ATTENTION_REGISTRY[args.attention_type] + self.attention = cls(args, layer_id, rope) if args.moe: self.block_sparse_moe = MOEFeedForward(args) else: @@ -501,7 +150,7 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN h = self.attention.forward( - self.attention_norm(x), freqs_cos, freqs_sin, input_pos + self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos ) h = x + h diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 00f59df286d..19c7ed0b311 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -15,8 +15,9 @@ get_checkpoint_dtype, get_default_model_resource_dir, ) +from executorch.examples.models.llama.llama_transformer import Transformer -from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer +from executorch.examples.models.llama.model_args import ModelArgs try: from .fairseq2 import convert_to_llama_checkpoint diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py new file mode 100644 index 00000000000..e1c4edb8e93 --- /dev/null +++ b/examples/models/llama/model_args.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass +from typing import Dict, Optional + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + hidden_dim: Optional[int] = None + head_dim: Optional[int] = None # Optional customized head_dim + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 32 + max_seq_len: int = 2048 + max_context_len: int = 2048 + moe: bool = False # True to enable the MoE (Mixture of Experts) + num_experts: int = 8 # Number of experts + num_activated_experts: int = 2 # Number of experts to activate + attention_type: str = "mha" # Attention type, registered in attention.py + use_kv_cache: bool = False # Use key/value cache + use_sdpa_with_kv_cache_op: bool = ( + False # Use custom sdpa op that updates kv cache in-place + ) + # Generate logits for all inputs. When it's True, it would take big memory usage + # at runtime. Enable it only necessary (e.g., use perplexity tools that requires + # logits for all input tokens.) + generate_full_logits: bool = False + enable_dynamic_shape: bool = False # export model with dynamic shape support + # A dictionary mapping from pruned token-id to original token-id + input_prune_map: Optional[Dict[int, int]] = None + # A dictionary mapping from pruned token-id to original token-id + output_prune_map: Optional[Dict[int, int]] = None + use_hf_rope: bool = False # Use HuggingFace's RoPE implementation + rope_theta: Optional[float] = ( + None # The official name to override self.rope_freq_base. + ) + rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC. + use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1. 
+ rope_scale_factor: int = 8 + # Additional Model Metadata needed at runtime + bos_idx: int = 1 + eos_idx: int = 3 + bos_count: int = -1 # i.e., a single EOS is used as BOS + eos_count: int = 2 + + quantization_args: Optional[dict] = None + lora_args: Optional[dict] = None + + def __post_init__(self): + if self.n_kv_heads is None: + self.n_kv_heads = self.n_heads + + # rope_theta overrides rope_freq_base since it's the official name. + if self.rope_theta is not None: + self.rope_freq_base = self.rope_theta + + if self.use_sdpa_with_kv_cache_op: + assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache" + + if self.hidden_dim is None: + # If hidden_dim is not explicitly set in the ModelArgs, + # then calculate implicitly based on dim and also multiple of `args.multiple_of` + multiple_of = self.multiple_of + hidden_dim = 4 * self.dim + hidden_dim = int(2 * hidden_dim / 3) + if self.ffn_dim_multiplier is not None: + hidden_dim = int(self.ffn_dim_multiplier * hidden_dim) + + def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + self.hidden_dim = find_multiple(hidden_dim, multiple_of) + + if self.head_dim is None: + self.head_dim = self.dim // self.n_heads diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py index cd3ddb0d3b8..01352f404df 100644 --- a/examples/models/llama/rope.py +++ b/examples/models/llama/rope.py @@ -8,9 +8,11 @@ # Different RoPE implementations import math +from functools import partial from typing import Optional, Tuple import torch +from executorch.examples.models.llama.model_args import ModelArgs # ======================== Stock Implementation ======================== @@ -205,3 +207,80 @@ def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1): sin = sin.unsqueeze(unsqueeze_dim) k_embed = (k * cos) + (rotate_half(k) * sin) return k_embed + + +class Rope(torch.nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + if self.params.use_hf_rope: + self.precompute_freqs_cis = hf_precompute_freqs_cis + else: + self.precompute_freqs_cis = partial( + precompute_freqs_cis, + use_scaled=self.params.use_scaled_rope, + scale_factor=self.params.rope_scale_factor, + ) + freqs_cos, freqs_sin = self.precompute_freqs_cis( + self.params.head_dim, + ( + self.params.max_context_len # Normal llama2. + if self.params.ffn_dim_multiplier is None + else self.params.max_context_len * 2 # Sharded checkpoint. + ), + self.params.rope_freq_base, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + if self.params.use_hf_rope: + self.apply_rotary_emb = hf_apply_rotary_emb + else: + self.apply_rotary_emb = RotaryEmbedding() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) + + def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): + """ + Get the precomputed frequencies for the given input position and sequence length. + + Args: + input_pos (torch.Tensor): The input position tensor. + seq_len (int): The sequence length. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length. 
+ """ + if self.params.use_kv_cache: + assert ( + input_pos is not None + ), "input_pos must be provided when use_kv_cache is True" + + if self.params.enable_dynamic_shape: + # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos. + input_pos_item = input_pos[-1].item() + torch._check_is_size(input_pos_item) + torch._check(input_pos_item < self.params.max_context_len) + # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor + freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len) + # pyre-ignore: Incompatible parameter type [6] + freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len) + else: + # When not using dynamic shape, use of the .item results in + # symints, due to querying the data from tensor. + # this path avoids that for mps backend, although probably mps backend + # can support dynamic shape? + freqs_cos = self.freqs_cos[input_pos] + freqs_sin = self.freqs_sin[input_pos] + + else: + assert input_pos is None, "input_pos is unused when use_kv_cache is False" + freqs_cos = self.freqs_cos[:seq_len] + freqs_sin = self.freqs_sin[:seq_len] + return freqs_cos, freqs_sin diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py index f1d40b70423..d5f065550d2 100644 --- a/examples/models/llama/source_transformation/attention.py +++ b/examples/models/llama/source_transformation/attention.py @@ -12,7 +12,7 @@ from typing import List, Optional, Tuple import torch -from executorch.examples.models.llama.llama_transformer import Attention +from executorch.examples.models.llama.attention import Attention from torch import nn diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py index d710773d007..22bd8a3e228 100644 --- a/examples/models/llama/source_transformation/attention_sink.py +++ b/examples/models/llama/source_transformation/attention_sink.py @@ -12,15 +12,12 @@ import torch -from executorch.examples.models.llama.llama_transformer import ( - Attention, - KVCache, - ModelArgs, - Rope, -) +from executorch.examples.models.llama.attention import AttentionMHA, KVCache +from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import ( apply_rotary_emb_to_k, hf_apply_rotary_emb_to_k, + Rope, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter @@ -266,7 +263,7 @@ def _replace_attention( eviction_batch_size=eviction_batch_size, ) - if isinstance(child_module, Attention): + if isinstance(child_module, AttentionMHA): kv_cache = child_module.kv_cache kv_cache_with_attention_sink = KVCacheWithAttentionSink( n_heads=kv_cache.n_heads, diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py index 650546b6dbb..023fc6800ff 100644 --- a/examples/models/llama/source_transformation/quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/quantized_kv_cache.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn -from executorch.examples.models.llama.llama_transformer import KVCache +from executorch.examples.models.llama.attention import KVCache from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index f3c297dd409..1bb7d277545 100644 --- 
a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -13,7 +13,7 @@ import torch -from executorch.examples.models.llama.llama_transformer import KVCache, SDPA +from executorch.examples.models.llama.attention import KVCache, SDPA class SDPACustom(torch.nn.Module): diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py index 5ecf3d162e3..fc882ebf4ab 100644 --- a/examples/models/llama/source_transformation/test_attention_sink.py +++ b/examples/models/llama/source_transformation/test_attention_sink.py @@ -7,7 +7,7 @@ import unittest import torch -from executorch.examples.models.llama.llama_transformer import ModelArgs +from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.source_transformation.attention_sink import ( KVCacheWithAttentionSink, diff --git a/examples/models/llama/source_transformation/test_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_quantized_kv_cache.py index fac62e73664..4252518a4ee 100644 --- a/examples/models/llama/source_transformation/test_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_quantized_kv_cache.py @@ -8,7 +8,7 @@ import torch -from executorch.examples.models.llama.llama_transformer import KVCache +from executorch.examples.models.llama.attention import KVCache from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( QuantizedCacheType, diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py index 6a1cdac32e0..35c88e10b6b 100644 --- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py @@ -8,7 +8,7 @@ import torch -from executorch.examples.models.llama.llama_transformer import KVCache +from executorch.examples.models.llama.attention import KVCache from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( CustomKVCache, diff --git a/examples/models/llama/tests/test_pre_quantization_transforms.py b/examples/models/llama/tests/test_pre_quantization_transforms.py index dc7c640dba9..345f3fad9ba 100644 --- a/examples/models/llama/tests/test_pre_quantization_transforms.py +++ b/examples/models/llama/tests/test_pre_quantization_transforms.py @@ -7,7 +7,8 @@ import unittest import torch -from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer +from executorch.examples.models.llama.llama_transformer import Transformer +from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.source_transformation.pre_quantization import ( sanitize_checkpoint_from_pre_quantization, transform_embedding_for_pre_quantization, diff --git a/examples/models/llama/tests/test_simple_sdpa.py b/examples/models/llama/tests/test_simple_sdpa.py index 3ad9f634ccf..d60bc30b7d3 100644 --- a/examples/models/llama/tests/test_simple_sdpa.py +++ b/examples/models/llama/tests/test_simple_sdpa.py @@ -7,7 +7,7 @@ import unittest import torch -from executorch.examples.models.llama.llama_transformer import KVCache, SDPA +from executorch.examples.models.llama.attention import KVCache, SDPA from executorch.examples.models.llama.source_transformation.sdpa import SDPASimple diff --git 
a/examples/models/llava/model.py b/examples/models/llava/model.py index 68a9e59e0ce..304b49759f2 100644 --- a/examples/models/llava/model.py +++ b/examples/models/llava/model.py @@ -12,7 +12,8 @@ import requests import torch -from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer +from executorch.examples.models.llama.llama_transformer import Transformer +from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( replace_kv_cache_with_custom_kv_cache,
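
The functional core of this patch is the attention registry in attention.py: register_attention() fills ATTENTION_REGISTRY, ModelArgs grows an attention_type field (default "mha"), and TransformerBlock resolves that string against the registry, raising ValueError for unknown names. As a rough sketch (not part of the patch), a downstream project could plug in its own implementation like this; IdentityAttention and the "my_attention" key are hypothetical names used only for illustration:

    from typing import Any, Optional, Tuple

    import torch

    from executorch.examples.models.llama.attention import Attention, register_attention
    from executorch.examples.models.llama.model_args import ModelArgs


    @register_attention("my_attention")  # stored lower-cased in ATTENTION_REGISTRY
    class IdentityAttention(Attention):
        """Do-nothing attention, only to exercise the registry plumbing."""

        def __init__(self, args: ModelArgs, layer_id: int, rope):
            super().__init__()
            self.layer_id = layer_id  # a real implementation would build wq/wk/wv/wo here

        def forward(
            self,
            x: torch.Tensor,
            freqs_cos: torch.Tensor,
            freqs_sin: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
            input_pos: Optional[torch.Tensor] = None,
            in_cache_state: Optional[Any] = None,
            out_cache_state: Optional[Any] = None,
        ) -> torch.Tensor:
            # Note: AttentionMHA in this patch returns a plain tensor (TransformerBlock
            # adds it straight to the residual), even though the ABC is annotated as
            # returning Tuple[torch.Tensor, Optional[Any]]; we mirror AttentionMHA here.
            return x


    # Selecting the implementation is then just a string in ModelArgs; TransformerBlock
    # looks it up in ATTENTION_REGISTRY at construction time.
    args = ModelArgs(attention_type="my_attention", use_kv_cache=False)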
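
model_args.py carries the ModelArgs dataclass over unchanged, with find_multiple folded into __post_init__ as a local helper. A small sketch (assuming the Llama-7B-style defaults shown above) of what the derivation yields when hidden_dim and head_dim are left unset:

    from executorch.examples.models.llama.model_args import ModelArgs

    # dim=4096, multiple_of=256, no ffn_dim_multiplier:
    # 4*4096 = 16384 -> int(2*16384/3) = 10922 -> rounded up to a multiple of 256 = 11008
    args = ModelArgs(dim=4096, n_heads=32)
    assert args.hidden_dim == 11008
    assert args.head_dim == 4096 // 32  # 128, derived because head_dim was not given
    assert args.n_kv_heads == 32        # defaults to n_heads when not set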
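
KVCache moves verbatim from llama_transformer.py into attention.py, so its contract is unchanged: update() writes k/v at input_pos along the sequence dimension (dim 2) and returns the full caches, which SDPA then consumes. A hedged sketch with made-up sizes, using the static-shape path (enable_dynamic_shape=False):

    import torch

    from executorch.examples.models.llama.attention import KVCache

    cache = KVCache(
        max_batch_size=1,
        max_context_length=8,
        n_heads=4,
        head_dim=16,
        enable_dynamic_shape=False,
    )
    k_val = torch.randn(1, 4, 1, 16)  # [B, H, S, D], a single new token
    v_val = torch.randn(1, 4, 1, 16)
    input_pos = torch.tensor([3])     # write at sequence position 3
    k_out, v_out = cache.update(input_pos, k_val, v_val)
    assert k_out.shape == (1, 4, 8, 16)           # the full cache is returned
    assert torch.equal(k_out[:, :, 3:4], k_val)   # the slot we just wrote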
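
Rope likewise moves into rope.py, next to the precompute/apply functions it wraps. A short usage sketch; shapes beyond the leading seq_len dimension depend on precompute_freqs_cis, so only the slicing behaviour visible in get_freqs above is asserted:

    from executorch.examples.models.llama.model_args import ModelArgs
    from executorch.examples.models.llama.rope import Rope

    args = ModelArgs(dim=256, n_heads=4, max_context_len=32, use_kv_cache=False)
    rope = Rope(args)
    # Without a KV cache, get_freqs ignores input_pos and slices the first seq_len rows.
    freqs_cos, freqs_sin = rope.get_freqs(None, seq_len=8)
    assert freqs_cos.shape[0] == 8 and freqs_sin.shape[0] == 8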
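
Finally, the import updates in source_transformation/ and the tests all follow one mechanical mapping, which downstream code outside this diff will need as well; note that the concrete MHA class is now named AttentionMHA, with Attention reserved for the new abstract base:

    # Before this patch:
    #   from executorch.examples.models.llama.llama_transformer import (
    #       Attention, KVCache, ModelArgs, Rope, SDPA, Transformer,
    #   )
    # After this patch:
    from executorch.examples.models.llama.attention import AttentionMHA, KVCache, SDPA
    from executorch.examples.models.llama.llama_transformer import Transformer
    from executorch.examples.models.llama.model_args import ModelArgs
    from executorch.examples.models.llama.rope import Rope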