Export TorchTune llama3_2_vision in ET #5911


Merged Nov 14, 2024 (30 commits)

Commits
0b5a9a7
Add kwarg example inputs to eager model base
jackzhxng Sep 30, 2024
a9647d2
Create create new method for example kwarg inputs instead
jackzhxng Oct 7, 2024
fa3b1d2
Add kwarg example inputs to eager model base
jackzhxng Sep 30, 2024
e8715ba
Lint
jackzhxng Oct 8, 2024
a6f96a2
Accept model type parameter in export_llama
jackzhxng Oct 5, 2024
328c72c
Remove future implementation
jackzhxng Oct 5, 2024
ec80bba
Lint
jackzhxng Oct 15, 2024
c9bbe12
Create create new method for example kwarg inputs instead
jackzhxng Oct 7, 2024
99d5bfb
Accept model type parameter in export_llama
jackzhxng Oct 5, 2024
1fb2236
Torchtune llama3_2_vision model in ET, no quantization
jackzhxng Oct 5, 2024
e0c4b8a
Fix vision model example input
jackzhxng Oct 8, 2024
e145bd1
Lint
jackzhxng Oct 22, 2024
ed906cb
Kv cache
jackzhxng Oct 25, 2024
6dd47e7
Merge branch 'main' into jz/tt-llama
jackzhxng Oct 25, 2024
1825972
Update READMEs
jackzhxng Oct 25, 2024
196499a
Change model default arg
jackzhxng Oct 25, 2024
96ba40b
Update eager runner and eval llama
jackzhxng Oct 25, 2024
18a82e1
Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
jackzhxng Oct 25, 2024
0f3035d
Fix tests
jackzhxng Oct 25, 2024
e677e14
Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
jackzhxng Oct 25, 2024
b1f6678
Fix tests again
jackzhxng Oct 28, 2024
13d004b
Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
jackzhxng Oct 28, 2024
de45c48
Strict = True
jackzhxng Oct 31, 2024
2fe7bd8
Merge branch 'main' into jz/tt-llama-2
jackzhxng Nov 13, 2024
64dcbda
Lint
jackzhxng Nov 13, 2024
a89d6b2
Fix merge
jackzhxng Nov 13, 2024
e5428de
Move to subdir
jackzhxng Nov 14, 2024
bf33485
Merge remote-tracking branch 'origin/main' into jz/tt-llama-2
jackzhxng Nov 14, 2024
7a0101f
Add automatically generated export tests
jackzhxng Nov 14, 2024
9777e23
Fix internal pyre warning
jackzhxng Nov 14, 2024
Files changed
1 change: 1 addition & 0 deletions .ci/scripts/gather_test_models.py
@@ -25,6 +25,7 @@
"resnet50": "linux.12xlarge",
"llava": "linux.12xlarge",
"llama3_2_vision_encoder": "linux.12xlarge",
"llama3_2_text_decoder": "linux.12xlarge",
# This one causes timeout on smaller runner, the root cause is unclear (T161064121)
"dl3": "linux.12xlarge",
"emformer_join": "linux.12xlarge",
1 change: 1 addition & 0 deletions examples/models/__init__.py
@@ -19,6 +19,7 @@
"llama2": ("llama", "Llama2Model"),
"llama": ("llama", "Llama2Model"),
"llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
"llama3_2_text_decoder": ("llama3_2_vision", "Llama3_2Decoder"),
"lstm": ("lstm", "LSTMModel"),
"mobilebert": ("mobilebert", "MobileBertModelExample"),
"mv2": ("mobilenet_v2", "MV2Model"),
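For reference, each registry entry above maps a model name to a (module, class) pair under examples/models. A minimal sketch of how such an entry could be resolved (the actual factory code is not part of this diff, so the import logic below is an illustration only):

import importlib

# Illustrative only: resolve a registry entry such as the new
# ("llama3_2_vision", "Llama3_2Decoder") pair into a model class.
def resolve_model_class(module_name: str, class_name: str):
    module = importlib.import_module(f"executorch.examples.models.{module_name}")
    return getattr(module, class_name)

Llama3_2Decoder = resolve_model_class("llama3_2_vision", "Llama3_2Decoder")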
63 changes: 35 additions & 28 deletions examples/models/llama/export_llama_lib.py
@@ -24,8 +24,6 @@

from executorch.devtools.etrecord import generate_etrecord

from executorch.examples.models.llama.llama_transformer import ModelArgs

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

from executorch.extension.llm.export.partitioner_lib import (
@@ -82,7 +80,7 @@


EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
TORCHTUNE_DEFINED_MODELS = []
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


class WeightType(Enum):
@@ -138,7 +136,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"--model",
default="llama3",
choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
help="The Lllama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
help="The Lllama model to export. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
)
parser.add_argument(
"-E",
@@ -819,16 +817,18 @@ def _load_llama_model_metadata(
use_kv_cache: bool,
use_sdpa_with_kv_cache: bool,
enable_dynamic_shape: bool,
model_args: ModelArgs,
max_seq_len: int,
n_layers: int,
vocab_size: int,
metadata_str: Optional[str] = None,
):
is_fairseq2 = weight_type == WeightType.FAIRSEQ2
metadata = {
"get_bos_id": 3 if is_fairseq2 else 1,
"get_eos_ids": [3] if is_fairseq2 else [2],
"get_max_seq_len": model_args.max_seq_len,
"get_n_layers": model_args.n_layers,
"get_vocab_size": model_args.vocab_size,
"get_max_seq_len": max_seq_len,
"get_n_layers": n_layers,
"get_vocab_size": vocab_size,
"use_kv_cache": use_kv_cache,
"use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
"enable_dynamic_shape": enable_dynamic_shape,
@@ -885,27 +885,31 @@ def _load_llama_model(
module_name = "llama"
model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
elif modelname in TORCHTUNE_DEFINED_MODELS:
raise NotImplementedError(
"Torchtune Llama models are not yet supported in ExecuTorch export."
)
if modelname == "llama3_2_vision":
module_name = "llama3_2_vision"
model_class_name = "Llama3_2Decoder"
else:
raise ValueError(f"{modelname} is not a valid Llama model.")
else:
raise ValueError(f"{modelname} is not a valid Llama model.")

model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
module_name,
model_class_name,
checkpoint=checkpoint,
checkpoint_dir=checkpoint_dir,
params=params_path,
use_kv_cache=use_kv_cache,
use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
generate_full_logits=generate_full_logits,
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
enable_dynamic_shape=enable_dynamic_shape,
input_prune_map_path=input_prune_map_path,
output_prune_map_path=output_prune_map_path,
args=args,
model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
EagerModelFactory.create_model(
module_name,
model_class_name,
checkpoint=checkpoint,
checkpoint_dir=checkpoint_dir,
params=params_path,
use_kv_cache=use_kv_cache,
use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
generate_full_logits=generate_full_logits,
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
enable_dynamic_shape=enable_dynamic_shape,
input_prune_map_path=input_prune_map_path,
output_prune_map_path=output_prune_map_path,
args=args,
)
)
if dtype_override:
assert isinstance(
@@ -937,12 +941,13 @@ def _load_llama_model(
return LLMEdgeManager(
model=model,
modelname=modelname,
max_seq_len=model.params.max_seq_len,
max_seq_len=model.max_seq_len,
dtype=dtype,
use_kv_cache=use_kv_cache,
generate_full_logits=generate_full_logits,
example_inputs=example_inputs,
example_kwarg_inputs=example_kwarg_inputs,
dynamic_shapes=dynamic_shapes,
enable_dynamic_shape=enable_dynamic_shape,
calibration_tasks=calibration_tasks,
calibration_limit=calibration_limit,
@@ -955,7 +960,9 @@
use_kv_cache,
use_sdpa_with_kv_cache,
enable_dynamic_shape,
model.params,
model.max_seq_len,
Contributor (review comment):
Previously all the models were getting it from model.params. Is it guaranteed that all models will have these available on them as attributes directly? Hopefully CI catches it if they don't.

Contributor Author:
Yeah, hopefully CI catches it, but it should be okay since there are only two model.py files using this export_llama_lib at the moment, and both have these attributes defined.

model.n_layers,
model.vocab_size,
metadata_str,
),
args=args,
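With this change, passing --model llama3_2_vision to export_llama takes the TorchTune branch above, and EagerModelFactory now also returns dynamic shapes. A trimmed sketch of that call follows; the factory import path and the file paths are assumptions for illustration, and the keyword arguments are cut down to a few of those forwarded in _load_llama_model:

from executorch.examples.models.model_factory import EagerModelFactory  # assumed import path

# Placeholder paths; kwargs trimmed relative to the full call above.
model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
    EagerModelFactory.create_model(
        "llama3_2_vision",   # module_name for the TorchTune-defined model
        "Llama3_2Decoder",   # model_class_name registered in examples/models/__init__.py
        checkpoint="/path/to/consolidated.pth",
        params="/path/to/demo_config.json",
        use_kv_cache=False,
        max_seq_len=8192,
    )
)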
2 changes: 2 additions & 0 deletions examples/models/llama3_2_vision/__init__.py
@@ -4,9 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .text_decoder.model import Llama3_2Decoder
from .vision_encoder import FlamingoVisionEncoderModel, VisionEncoderConfig

__all__ = [
"FlamingoVisionEncoderModel",
"Llama3_2Decoder",
"VisionEncoderConfig",
]
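With the package export above, the text decoder wrapper can be imported alongside the existing vision encoder:

# Both wrappers are now exposed from the same package.
from executorch.examples.models.llama3_2_vision import (
    FlamingoVisionEncoderModel,
    Llama3_2Decoder,
)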
174 changes: 174 additions & 0 deletions examples/models/llama3_2_vision/text_decoder/model.py
@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
from typing import Any, Dict

import torch
from executorch.examples.models.checkpoint import (
get_checkpoint_dtype,
get_default_model_resource_dir,
)

from executorch.examples.models.model_base import EagerModelBase
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune


def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
"""
Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
To load the text decoder on its own, the "decoder" prefix needs to be removed.
"""
return {
".".join(weight.split(".")[1:]): value
for weight, value in checkpoint.items()
if weight.startswith("decoder")
}


class Llama3_2Decoder(EagerModelBase):
"""
Just the text decoder portions of the Llama3.2 multimodal model.
"""

def __init__(self, **kwargs):
# Set member vars from kwargs.
self.max_seq_len = kwargs.get(
"max_seq_len", 8192
) # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
self.encoder_max_seq_len = kwargs.get(
"encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
) # Same as above.
self.generate_full_logits = kwargs.get("generate_full_logits", False)
self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
self.use_kv_cache = kwargs.get("use_kv_cache", False)
self.verbose = kwargs.get("verbose", False)
self.args = kwargs.get("args", None)

ckpt_dir = get_default_model_resource_dir(__file__)
# Single checkpoint file.
checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
# Sharded checkpoint.
checkpoint_dir = kwargs.get("checkpoint_dir", None)
params_path = kwargs.get("params", ckpt_dir / "demo_config.json")

self.causal_mask = torch.tril(
torch.ones(
size=(self.max_seq_len, self.max_seq_len),
dtype=torch.bool,
)
)
self.input_pos = torch.arange(self.max_seq_len)

# Load checkpoint and params.
device = "cpu"
if checkpoint_dir is not None:
raise NotImplementedError(
"Sharded checkpoint not yet supported for Llama3_2Decoder."
)
else:
checkpoint = torch.load(
checkpoint_path, map_location=device, weights_only=False, mmap=True
)
checkpoint = llama3_vision_meta_to_tune(checkpoint)
checkpoint = to_decoder_checkpoint(checkpoint)
with open(params_path, "r") as f:
params = json.loads(f.read())

# Find dtype from checkpoint. (skip for now)
self.dtype = get_checkpoint_dtype(checkpoint)

# Load model.
# Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
# i.e. the model isn't fully initialized or something.
self.model_ = llama3_2_vision_decoder(
vocab_size=params["vocab_size"],
num_layers=params["n_layers"],
fusion_interval=params["fusion_interval"],
num_special_tokens=params["n_special_tokens"],
num_heads=params["n_heads"],
num_kv_heads=params["n_kv_heads"],
embed_dim=params["dim"],
max_seq_len=self.max_seq_len,
encoder_max_seq_len=self.encoder_max_seq_len,
rope_base=params["rope_theta"],
intermediate_dim=params["intermediate_dim"],
)
# Save params for future use.
for param_name, param_val in params.items():
setattr(self.model_, param_name, param_val)

# Quantize. (skip for now)

# Load checkpoint.
missing, unexpected = self.model_.load_state_dict(
checkpoint,
strict=False,
assign=True,
)
if kwargs.get("verbose", False):
print("============= missing keys ================")
print(missing)
print("============= /missing ================")
print("============= unexpected keys ================")
print(unexpected)
print("============= /unexpected ================")

# Prune the output layer if output_prune_map is provided.
output_prune_map = None
if self.output_prune_map_path is not None:
from executorch.examples.models.llama2.source_transformation.prune_output import (
prune_output_vocab,
)

with open(self.output_prune_map_path, "r") as f:
output_prune_map = json.load(f)
# Change keys from string to int (json only supports string keys)
output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}

self.model_ = prune_output_vocab(self.model_, output_prune_map)

# if self.use_kv_cache:
# print("Setting up KV cache on the model...")
# self.model_.setup_caches(
# batch_size=1,
# dtype=self.dtype,
# )

def get_eager_model(self) -> torch.nn.Module:
if self.dtype:
return self.model_.to(self.dtype)
else:
return self.model_.to(torch.float16)

def get_example_inputs(self):
return (torch.ones(1, 64, dtype=torch.long),)

def get_example_kwarg_inputs(self):
# TODO: add input_pos and mask after making the cache work.
return {
# "mask": self.causal_mask[None, 64, None, :],
# "encoder_input": None,
# "encoder_mask": None,
# "input_pos": self.input_pos[None, 64]
}

def get_dynamic_shapes(self):
batch_size = 1
dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
dynamic_shapes = {
"tokens": {0: batch_size, 1: dim_seq_len},
# "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
# "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
# "mask": {0: batch_size, 1: dim_seq_len, 2: self.max_seq_len},
# "input_pos" : {0: batch_size, 1: dim_seq_len},
}
return dynamic_shapes
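A minimal end-to-end sketch of exporting the wrapper above with torch.export; the checkpoint and params paths are placeholders, and KV cache and quantization stay at their defaults, as in the class:

import torch

from executorch.examples.models.llama3_2_vision import Llama3_2Decoder

# Minimal sketch, assuming local copies of the demo checkpoint and config.
wrapper = Llama3_2Decoder(
    checkpoint="/path/to/demo_rand_params.pth",
    params="/path/to/demo_config.json",
    max_seq_len=8192,
)
decoder = wrapper.get_eager_model()

# Only the token dimension is dynamic for now; mask/input_pos stay disabled
# until the KV cache path is wired up, matching the methods above.
exported = torch.export.export(
    decoder,
    wrapper.get_example_inputs(),               # (tokens,) with shape [1, 64]
    kwargs=wrapper.get_example_kwarg_inputs(),  # currently empty
    dynamic_shapes=wrapper.get_dynamic_shapes(),
)
print(exported)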
New file: 18 additions & 0 deletions (the params JSON read by the text decoder model above)
@@ -0,0 +1,18 @@
{
"dim": 4096,
"ffn_dim_multiplier": 1.3,
"fusion_interval": 4,
"intermediate_dim": 14336,
"multiple_of": 1024,
"n_heads": 32,
"n_kv_heads": 8,
"n_layers": 32,
"n_special_tokens": 8,
"norm_eps": 1e-05,
"rope_theta": 500000.0,
"use_scaled_rope": true,
"vision_chunk_size": 560,
"vision_max_num_chunks": 4,
"vocab_size": 128256,
"vision_num_cross_attention_layers": 8
}
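This new config appears to be the demo_config.json that model.py reads by default; its keys map onto the TorchTune builder arguments (n_layers to num_layers, n_heads to num_heads, dim to embed_dim, rope_theta to rope_base, and so on). A compact sketch of that mapping, with the file path as a placeholder:

import json

from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder

# Mirrors the constructor in model.py; the path is a placeholder.
with open("/path/to/demo_config.json") as f:
    params = json.load(f)

decoder = llama3_2_vision_decoder(
    vocab_size=params["vocab_size"],            # 128256
    num_layers=params["n_layers"],              # 32
    num_heads=params["n_heads"],                # 32
    num_kv_heads=params["n_kv_heads"],          # 8
    embed_dim=params["dim"],                    # 4096
    fusion_interval=params["fusion_interval"],  # 4
    num_special_tokens=params["n_special_tokens"],
    rope_base=params["rope_theta"],
    intermediate_dim=params["intermediate_dim"],
    max_seq_len=8192,
    encoder_max_seq_len=int(4 * (448 / 14) ** 2 + 1),
)

The remaining keys (ffn_dim_multiplier, multiple_of, norm_eps, use_scaled_rope, and the vision_* entries) are not consumed by this decoder builder call.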