
Commit cbe146d

sxu authored and facebook-github-bot committed
Pass ForwardOptions from the top-level module and also return any relevant state as output
Summary: Pass a `ForwardOptions` argument (introduced by #8128) from the top-level transformer, consolidate some existing inputs into it, and return any optional updates from the attention implementation.

Differential Revision: D69080123
1 parent d0b8fe3 commit cbe146d
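
The summary references the `ForwardOptions` type from #8128, whose definition is not part of this diff. Judging only from how the new code uses it (`.get()`, `.update()`, and `**` unpacking), a minimal dict-like sketch might look like the following; the three field names are the ones this diff actually touches, and anything beyond that is an assumption.

```python
# Illustrative sketch only: the real ForwardOptions is defined in
# examples/models/llama/attention.py (introduced by #8128) and may carry
# more fields than the three this diff touches.
from typing import Optional, TypedDict

import torch


class ForwardOptions(TypedDict, total=False):
    # Scalar tensor indicating the size of the cache window (previously a
    # standalone input_pos argument on Transformer.forward).
    input_pos: Optional[torch.LongTensor]
    # Rotary embedding tables for the current positions; the top-level
    # Transformer computes these once and writes them back into the dict.
    freqs_cos: Optional[torch.Tensor]
    freqs_sin: Optional[torch.Tensor]
```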

File tree: 2 files changed, +22 −18 lines


examples/models/llama/attention.py

Lines changed: 1 addition & 1 deletion
@@ -252,4 +252,4 @@ def forward(
 
         output = self.wo(output)
 
-        return output
+        return output, None
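
With this change the attention forward returns a pair rather than a bare tensor: the second element is an optional state update, and the default path simply reports `None`. The toy module below is not the ExecuTorch attention implementation; it only sketches the `(output, update)` contract that callers now have to unpack.

```python
# Toy illustration of the new return contract: attention-style modules return
# (output, optional_state_update). This is a sketch, not ExecuTorch code.
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.wo = nn.Linear(dim, dim)

    def forward(
        self, x: torch.Tensor, **kwargs: Any
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        output = self.wo(x)
        # An implementation with state to propagate (e.g. an updated cache
        # position) would return it in the second slot; this one has none.
        return output, None
```

A variant that does track state could return a dict in the second slot instead of `None`, which the transformer folds back into its options, as the next file shows.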

examples/models/llama/llama_transformer.py

Lines changed: 21 additions & 17 deletions
@@ -12,7 +12,10 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
+from executorch.examples.models.llama.attention import (
+    ATTENTION_REGISTRY,
+    ForwardOptions,
+)
 
 from executorch.examples.models.llama.model_args import ModelArgs
 
@@ -148,17 +151,17 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos
+    def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN
+        h, attn_options_update = self.attention.forward(
+            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
 
         h = x + h
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out, attn_options_update
 
 
 class Transformer(nn.Module):
@@ -185,27 +188,28 @@ def __init__(self, params: ModelArgs):
     def forward(
         self,
         tokens: Optional[torch.LongTensor] = None, # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None, # Scalar tensor indicating size of window of the caches
         h: Optional[torch.FloatTensor] = None, # embeddings
+        attn_options: Optional[ForwardOptions] = None,
     ) -> torch.Tensor:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
+        if (
+            attn_options.get("freqs_cos") is None
+            and attn_options.get("freqs_sin") is None
+            and (input_pos := attn_options.get("input_pos")) is not None
+        ):
+            seqlen = h.shape[1]
+            freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
+            attn_options.update({"freqs_cos": freqs_cos, "freqs_sin": freqs_sin})
 
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h, attn_options_update = layer(h, **attn_options)
+            if attn_options_update is not None:
+                attn_options.update(**attn_options_update)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
@@ -237,4 +241,4 @@ def forward(
             expanded_logits[:, list(self.output_prune_map.values())] = logits
             logits = expanded_logits
 
-        return logits
+        return logits, attn_options_update
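
Putting the two files together, the top-level forward now consumes one options dict, folds any per-layer update back into it before the next layer runs, and finally returns the last update alongside the logits. The standalone sketch below mimics that threading loop with hypothetical `layer_a`/`layer_b` functions, since constructing the real `Transformer` needs a `ModelArgs` configuration that is outside this diff.

```python
# Sketch of the options-threading pattern adopted by Transformer.forward:
# each layer receives the current options as kwargs and may hand back an
# update that is folded into the options for subsequent layers.
# layer_a and layer_b are hypothetical stand-ins for the real TransformerBlock.
from typing import Any, Dict, Optional, Tuple

import torch


def layer_a(h: torch.Tensor, **opts: Any) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
    # Pretend this layer advances a cached position it wants to publish.
    pos = opts.get("input_pos", torch.tensor(0))
    return h + 1, {"input_pos": pos + h.shape[1]}


def layer_b(h: torch.Tensor, **opts: Any) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
    # This layer has no state to report back.
    return h * 2, None


attn_options: Dict[str, Any] = {"input_pos": torch.tensor(0)}
h = torch.zeros(1, 4, 8)  # (batch, seq, dim)

for layer in (layer_a, layer_b):
    h, attn_options_update = layer(h, **attn_options)
    if attn_options_update is not None:
        attn_options.update(**attn_options_update)

print(attn_options["input_pos"])  # tensor(4): layer_a's update was threaded through
```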
