diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index c6f01dbadce..ec55f2f1ee0 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple, Type, TypedDict
 
 import torch
 import torch.nn as nn
@@ -8,6 +8,15 @@
 from executorch.examples.models.llama.rope import Rope
 
 
+class ForwardOptions(TypedDict, total=False):
+    """Optional parameters for `Attention.forward` (compatible with Python 3.10 and later)."""
+
+    mask: Optional[torch.Tensor]
+    input_pos: Optional[torch.Tensor]
+    in_cache_state: Optional[Any]
+    out_cache_state: Optional[Any]
+
+
 class Attention(nn.Module, ABC):
     """Abstract base class for attention mechanisms with unified interface."""
 
@@ -17,19 +26,14 @@ def forward(
         x: torch.Tensor,
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        input_pos: Optional[torch.Tensor] = None,
-        in_cache_state: Optional[Any] = None,
-        out_cache_state: Optional[Any] = None,
+        **kwargs: ForwardOptions,
     ) -> Tuple[torch.Tensor, Optional[Any]]:
         """Forward pass for attention mechanism.
 
         Args:
             x: Input tensor of shape (batch_size, seq_len, dim)
             freqs_cos, freqs_sin: Rotary position embedding frequencies
-            mask: Optional attention mask
-            input_pos: Positions for KV cache updates
-            in_cache_state/out_cache_state: Cache states
+            ForwardOptions: grouped optional args (mask, input_pos, cache states)
 
         Returns:
             Tuple of (output tensor, updated cache state)
@@ -209,11 +213,9 @@ def forward(
         x: torch.Tensor,
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        input_pos: Optional[torch.Tensor] = None,
-        in_cache_state: Optional[Any] = None,
-        out_cache_state: Optional[Any] = None,
+        **kwargs: ForwardOptions,
     ) -> Tuple[torch.Tensor, Optional[Any]]:
+        input_pos = kwargs.get("input_pos")
         bsz, seqlen, _ = x.shape
 
         # QKV
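
A minimal usage sketch of the new keyword-argument interface (assumptions: `attn` is an instance of a concrete `Attention` subclass, and `freqs_cos`/`freqs_sin` are precomputed rotary-embedding tensors obtained from the existing model setup; the shapes below are illustrative only):

    import torch
    from executorch.examples.models.llama.attention import ForwardOptions

    x = torch.randn(1, 16, 2048)          # (batch_size, seq_len, dim), illustrative shape
    opts: ForwardOptions = {
        "input_pos": torch.tensor([0]),   # position for KV-cache updates
        # "mask", "in_cache_state", "out_cache_state" can be supplied the same way
    }
    out, cache_state = attn(x, freqs_cos, freqs_sin, **opts)

Because `ForwardOptions` is declared with `total=False`, any subset of its keys may be omitted; implementations read the ones they need with `kwargs.get(...)`, as the second `forward` in the diff does with `input_pos`.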