diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index ec55f2f1ee0..2c51f12efec 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -236,7 +236,7 @@ def forward(
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output)
+            return self.wo(output), None
 
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
@@ -252,4 +252,4 @@ def forward(
 
         output = self.wo(output)
 
-        return output
+        return output, None
diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index 387d45bd3f5..6872222861d 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -65,7 +65,9 @@ def _model_call(self, inps):
             result_logits = []
             for pos in range(inps.shape[-1]):
                 pos_tensor = torch.tensor([pos], dtype=torch.int64)
-                logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                logits = self._model(
+                    inps[:, pos : pos + 1], {"input_pos": pos_tensor}
+                )
                 result_logits.append(logits)
             if self._generate_full_logits:
                 return torch.cat(result_logits, dim=1)
@@ -74,7 +76,9 @@ def _model_call(self, inps):
         else:
             pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
             # Batch process the whole sequence.
-            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            logits = self._model(
+                inps[:, : self._max_seq_length], {"input_pos": pos_tensor}
+            )
             return logits
 
     else:
diff --git a/examples/models/llama/evaluate/eager_eval.py b/examples/models/llama/evaluate/eager_eval.py
index b3f04ef3bb5..e50b8c193c7 100644
--- a/examples/models/llama/evaluate/eager_eval.py
+++ b/examples/models/llama/evaluate/eager_eval.py
@@ -77,7 +77,9 @@ def _model_call(self, inps):
         if self._use_kv_cache:
             pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
             # Batch process the whole sequence.
-            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            logits = self._model(
+                inps[:, : self._max_seq_length], {"input_pos": pos_tensor}
+            )
             return logits
         else:
             return self._model(inps)
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 08526dde195..43159dfa80b 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -7,12 +7,15 @@
 
 # Please refer to README.md in the same folder for more information.
 
-from typing import Optional
+from typing import Any, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
+from executorch.examples.models.llama.attention import (
+    ATTENTION_REGISTRY,
+    ForwardOptions,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
 
@@ -148,9 +151,9 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos
+    def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
+        h, attn_options_update = self.attention.forward(
+            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
         h = x + h
 
@@ -158,7 +161,7 @@ def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out, attn_options_update
 
 
 class Transformer(nn.Module):
@@ -185,27 +188,28 @@ def __init__(self, params: ModelArgs):
     def forward(
         self,
         tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
+        attn_options: Optional[ForwardOptions] = None,
         h: Optional[torch.FloatTensor] = None,  # embeddings
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[Any]]]:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
+
+        if attn_options is None:
+            attn_options = {}
 
         seqlen = h.shape[1]
-        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
+        freqs_cos, freqs_sin = self.rope.get_freqs(
+            attn_options.get("input_pos"), seqlen
+        )
 
+        attn_options_update = None
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options)
+            if attn_options_update is not None:
+                attn_options.update(**attn_options_update)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
@@ -237,4 +241,7 @@ def forward(
             expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits
 
+        if attn_options_update is not None:
+            return logits, attn_options_update
+
         return logits
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 19c7ed0b311..ccc79c3e3cb 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -289,16 +289,18 @@ def get_example_inputs_kvcache_sdpa(self):
         if self.enable_dynamic_shape:
             return (
                 torch.tensor([[2, 3, 4]], dtype=torch.long),
-                torch.tensor([0], dtype=torch.long),
+                {"input_pos": torch.tensor([0], dtype=torch.long)},
             )
         else:
             return (
                 torch.tensor(
                     [[1]], dtype=torch.long
                 ),  # tokens, with kv cache our input token length is always just 1 token.
-                torch.tensor(
-                    [0], dtype=torch.long
-                ),  # start_pos, what token of output are we on.
+                {
+                    "input_pos": torch.tensor(
+                        [0], dtype=torch.long
+                    )  # start_pos, what token of output are we on.
+                },
             )
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index 559b4e04892..54cfc283ae9 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -42,7 +42,7 @@ def forward(
         tokens: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return self.model.forward(tokens=tokens, input_pos=input_pos)
+        return self.model.forward(tokens, {"input_pos": input_pos})
 
 
 def build_args_parser() -> argparse.ArgumentParser:
diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index b76d2dcd043..a5057e5e850 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -80,7 +80,7 @@ def __init__(self, llava):
             self.text_model = llava.text_model
 
         def forward(self, input_pos, embeddings):
-            return self.text_model(None, input_pos, embeddings)
+            return self.text_model(None, {"input_pos": input_pos}, embeddings)
 
     llava_text_model = LlavaTextModel(llava)
 
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 304b49759f2..c0f8ecfc354 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -208,7 +208,7 @@ def step(
     ) -> torch.Tensor:
         """Input is one token. Return logits for next token."""
         token_embeds = self.embed_tokens(token).unsqueeze(0)
-        return self.text_model.forward(None, input_pos, token_embeds)
+        return self.text_model.forward(None, {"input_pos": input_pos}, token_embeds)
 
     def image_embedding(self, images: torch.Tensor) -> torch.Tensor:
         preprocessed_img = self.image_preprocess(images)
@@ -236,7 +236,9 @@ def prefill(
         """Avoiding the torch.where() call to find placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
         # returns the prefilled token length too, because the text model generates one logits in each forward call.
-        return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)
+        return embeds.shape[1], self.text_model.forward(
+            None, {"input_pos": torch.tensor([0])}, embeds
+        )
 
     # reference prefill using the text model in HF
     def prefill_ref(
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index be6977b639c..504be3c6343 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -170,7 +170,7 @@ def _get_dynamic_shape(self) -> Any:
             self.dynamic_shapes = ({1: dim},)
         elif self.enable_dynamic_shape:
             # Two input arguments: tokens and input_pos but input_pos is static shape
-            self.dynamic_shapes = ({1: dim}, {0: 1})
+            self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
         else:
             # Two input arguments: tokens and input_pos but both are of static shape
             self.dynamic_shapes = None
@@ -270,7 +270,7 @@ def calibrate_template(
         while token_list[-1] != tokenizer.eos_id and pos < max_len:
             logits = module(
                 torch.full((1, 1), token_list[pos]),
-                torch.tensor((pos,)),
+                {"input_pos": torch.tensor((pos,))},
             )
             pos += 1
             if pos >= len(token_list):
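
A minimal sketch of the calling convention introduced by this patch (the `model` instance and tensor values below are illustrative assumptions, not part of the diff): callers now pass a ForwardOptions-style dict in place of the bare `input_pos` tensor, and the transformer returns either plain logits or a `(logits, attn_options_update)` tuple when an attention implementation reports state to carry forward.

    # Hypothetical usage sketch; assumes `model` is an already-constructed
    # examples.models.llama Transformer with use_kv_cache enabled.
    import torch

    tokens = torch.tensor([[1]], dtype=torch.long)  # one token per decode step
    attn_options = {"input_pos": torch.tensor([0], dtype=torch.long)}

    out = model.forward(tokens, attn_options)
    # forward() returns bare logits unless an attention layer produced an
    # attn_options_update, in which case a (logits, update) tuple comes back.
    if isinstance(out, tuple):
        logits, attn_options_update = out
        attn_options.update(**attn_options_update)  # carry state into the next call
    else:
        logits = out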