Export TorchTune llama3_2_vision in ET #5911
Merged

Commits (30, all by jackzhxng):
0b5a9a7  Add kwarg example inputs to eager model base
a9647d2  Create create new method for example kwarg inputs instead
fa3b1d2  Add kwarg example inputs to eager model base
e8715ba  Lint
a6f96a2  Accept model type parameter in export_llama
328c72c  Remove future implementation
ec80bba  Lint
c9bbe12  Create create new method for example kwarg inputs instead
99d5bfb  Accept model type parameter in export_llama
1fb2236  Torchtune llama3_2_vision model in ET, no quantization
e0c4b8a  Fix vision model example input
e145bd1  Lint
ed906cb  Kv cache
6dd47e7  Merge branch 'main' into jz/tt-llama
1825972  Update READMEs
196499a  Change model default arg
96ba40b  Update eager runner and eval llama
18a82e1  Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
0f3035d  Fix tests
e677e14  Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
b1f6678  Fix tests again
13d004b  Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
de45c48  Strict = True
2fe7bd8  Merge branch 'main' into jz/tt-llama-2
64dcbda  Lint
a89d6b2  Fix merge
e5428de  Move to subdir
bf33485  Merge remote-tracking branch 'origin/main' into jz/tt-llama-2
7a0101f  Add automatically generated export tests
9777e23  Fix internal pyre warning
New file (+174 lines): the Llama3_2Decoder text decoder model definition.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
from typing import Any, Dict

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.model_base import EagerModelBase
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune


def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
    weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
    To load the text decoder on its own, the "decoder" prefix needs to be removed.
    """
    return {
        ".".join(weight.split(".")[1:]): value
        for weight, value in checkpoint.items()
        if weight.startswith("decoder")
    }


class Llama3_2Decoder(EagerModelBase):
    """
    Just the text decoder portions of the Llama3.2 multimodal model.
    """

    def __init__(self, **kwargs):
        # Set member vars from kwargs.
        self.max_seq_len = kwargs.get(
            "max_seq_len", 8192
        )  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
        self.encoder_max_seq_len = kwargs.get(
            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
        )  # Same as above.
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.verbose = kwargs.get("verbose", False)
        self.args = kwargs.get("args", None)

        ckpt_dir = get_default_model_resource_dir(__file__)
        # Single checkpoint file.
        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
        # Sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)
        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")

        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )
        self.input_pos = torch.arange(self.max_seq_len)

        # Load checkpoint and params.
        device = "cpu"
        if checkpoint_dir is not None:
            raise NotImplementedError(
                "Sharded checkpoint not yet supported for Llama3_2Decoder."
            )
        else:
            checkpoint = torch.load(
                checkpoint_path, map_location=device, weights_only=False, mmap=True
            )
        checkpoint = llama3_vision_meta_to_tune(checkpoint)
        checkpoint = to_decoder_checkpoint(checkpoint)
        with open(params_path, "r") as f:
            params = json.loads(f.read())

        # Find dtype from checkpoint. (skip for now)
        self.dtype = get_checkpoint_dtype(checkpoint)

        # Load model.
        # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
        # i.e. the model isn't fully initialized or something.
        self.model_ = llama3_2_vision_decoder(
            vocab_size=params["vocab_size"],
            num_layers=params["n_layers"],
            fusion_interval=params["fusion_interval"],
            num_special_tokens=params["n_special_tokens"],
            num_heads=params["n_heads"],
            num_kv_heads=params["n_kv_heads"],
            embed_dim=params["dim"],
            max_seq_len=self.max_seq_len,
            encoder_max_seq_len=self.encoder_max_seq_len,
            rope_base=params["rope_theta"],
            intermediate_dim=params["intermediate_dim"],
        )
        # Save params for future use.
        for param_name, param_val in params.items():
            setattr(self.model_, param_name, param_val)

        # Quantize. (skip for now)

        # Load checkpoint.
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if kwargs.get("verbose", False):
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the output layer if output_prune_map is provided.
        output_prune_map = None
        if self.output_prune_map_path is not None:
            from executorch.examples.models.llama2.source_transformation.prune_output import (
                prune_output_vocab,
            )

            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (json only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

        # if self.use_kv_cache:
        #     print("Setting up KV cache on the model...")
        #     self.model_.setup_caches(
        #         batch_size=1,
        #         dtype=self.dtype,
        #     )

    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            return self.model_.to(self.dtype)
        else:
            return self.model_.to(torch.float16)

    def get_example_inputs(self):
        return (torch.ones(1, 64, dtype=torch.long),)

    def get_example_kwarg_inputs(self):
        # TODO: add input_pos and mask when after making cache work.
        return {
            # "mask": self.causal_mask[None, 64, None, :],
            # "encoder_input": None,
            # "encoder_mask": None,
            # "input_pos": self.input_pos[None, 64]
        }

    def get_dynamic_shapes(self):
        batch_size = 1
        dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
        dynamic_shapes = {
            "tokens": {0: batch_size, 1: dim_seq_len},
            # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
            # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
            # "mask": {0: batch_size, 1: dim_seq_len, 2: self.max_seq_len},
            # "input_pos" : {0: batch_size, 1: dim_seq_len},
        }
        return dynamic_shapes
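
As a point of reference, here is a minimal usage sketch (not part of the diff; the import path and the presence of the demo checkpoint/config are assumptions based on this PR's file layout) showing how the wrapper above could be driven through torch.export with its own example inputs and dynamic shapes:

# Hypothetical sketch: exporting the Llama3_2Decoder defined above.
# Assumes the module lives under examples/models/llama3_2_vision/text_decoder/
# and that demo_rand_params.pth / demo_config.json are available next to it.
import torch

from executorch.examples.models.llama3_2_vision.text_decoder.model import Llama3_2Decoder

wrapper = Llama3_2Decoder()                           # loads the demo checkpoint and config by default
eager_model = wrapper.get_eager_model()               # torchtune decoder, cast to the checkpoint dtype (or fp16)
example_inputs = wrapper.get_example_inputs()         # (tokens of shape [1, 64],)
example_kwargs = wrapper.get_example_kwarg_inputs()   # currently empty; kv cache inputs are TODO
dynamic_shapes = wrapper.get_dynamic_shapes()         # token dim is dynamic up to max_seq_len

exported_program = torch.export.export(
    eager_model,
    example_inputs,
    kwargs=example_kwargs,
    dynamic_shapes=dynamic_shapes,
    strict=True,
)
exported_program.graph_module.print_readable()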
examples/models/llama3_2_vision/text_decoder/params/demo_config.json (new file, 18 additions, 0 deletions):

{
    "dim": 4096,
    "ffn_dim_multiplier": 1.3,
    "fusion_interval": 4,
    "intermediate_dim": 14336,
    "multiple_of": 1024,
    "n_heads": 32,
    "n_kv_heads": 8,
    "n_layers": 32,
    "n_special_tokens": 8,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "use_scaled_rope": true,
    "vision_chunk_size": 560,
    "vision_max_num_chunks": 4,
    "vocab_size": 128256,
    "vision_num_cross_attention_layers": 8
}
Review comment:
Previously all the models were getting it from model.params. Is it guaranteed that all models will have these available on them as attributes directly? Hopefully CI catches it if they don't.
Reply:
Yeah, hopefully CI catches it, but it should be okay since there are only two model.pys using this export_llama_lib at the moment, and both have it defined.
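
To make the pattern under discussion concrete, here is a small hypothetical sketch (the bare nn.Module is a stand-in for the torchtune decoder; the values are taken from demo_config.json above) of the attribute-copying loop this PR uses, which is what lets export_llama_lib read hyperparameters directly off the model instead of through model.params:

import torch

# Hypothetical illustration of the access pattern discussed in the review thread.
params = {"n_heads": 32, "n_kv_heads": 8, "dim": 4096}

model = torch.nn.Module()  # stand-in for the torchtune decoder
for param_name, param_val in params.items():
    setattr(model, param_name, param_val)  # same loop as in Llama3_2Decoder.__init__

# Downstream code can then read the values as plain attributes:
assert model.n_heads == 32  # rather than going through model.params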