diff --git a/.ci/scripts/gather_test_models.py b/.ci/scripts/gather_test_models.py index e22e1965678..078561c9d85 100755 --- a/.ci/scripts/gather_test_models.py +++ b/.ci/scripts/gather_test_models.py @@ -25,6 +25,7 @@ "resnet50": "linux.12xlarge", "llava": "linux.12xlarge", "llama3_2_vision_encoder": "linux.12xlarge", + "llama3_2_text_decoder": "linux.12xlarge", # This one causes timeout on smaller runner, the root cause is unclear (T161064121) "dl3": "linux.12xlarge", "emformer_join": "linux.12xlarge", diff --git a/examples/models/__init__.py b/examples/models/__init__.py index d3f2a74f4d9..842b87241cc 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -19,6 +19,7 @@ "llama2": ("llama", "Llama2Model"), "llama": ("llama", "Llama2Model"), "llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"), + "llama3_2_text_decoder": ("llama3_2_vision", "Llama3_2Decoder"), "lstm": ("lstm", "LSTMModel"), "mobilebert": ("mobilebert", "MobileBertModelExample"), "mv2": ("mobilenet_v2", "MV2Model"), diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index c4334443f23..7ebdf95418d 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -24,8 +24,6 @@ from executorch.devtools.etrecord import generate_etrecord -from executorch.examples.models.llama.llama_transformer import ModelArgs - from executorch.extension.llm.export.builder import DType, LLMEdgeManager from executorch.extension.llm.export.partitioner_lib import ( @@ -82,7 +80,7 @@ EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"] -TORCHTUNE_DEFINED_MODELS = [] +TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"] class WeightType(Enum): @@ -138,7 +136,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--model", default="llama3", choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS, - help="The Lllama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.", + help="The Lllama model to export. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.", ) parser.add_argument( "-E", @@ -819,16 +817,18 @@ def _load_llama_model_metadata( use_kv_cache: bool, use_sdpa_with_kv_cache: bool, enable_dynamic_shape: bool, - model_args: ModelArgs, + max_seq_len: int, + n_layers: int, + vocab_size: int, metadata_str: Optional[str] = None, ): is_fairseq2 = weight_type == WeightType.FAIRSEQ2 metadata = { "get_bos_id": 3 if is_fairseq2 else 1, "get_eos_ids": [3] if is_fairseq2 else [2], - "get_max_seq_len": model_args.max_seq_len, - "get_n_layers": model_args.n_layers, - "get_vocab_size": model_args.vocab_size, + "get_max_seq_len": max_seq_len, + "get_n_layers": n_layers, + "get_vocab_size": vocab_size, "use_kv_cache": use_kv_cache, "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache, "enable_dynamic_shape": enable_dynamic_shape, @@ -885,27 +885,31 @@ def _load_llama_model( module_name = "llama" model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py. elif modelname in TORCHTUNE_DEFINED_MODELS: - raise NotImplementedError( - "Torchtune Llama models are not yet supported in ExecuTorch export." 
- ) + if modelname == "llama3_2_vision": + module_name = "llama3_2_vision" + model_class_name = "Llama3_2Decoder" + else: + raise ValueError(f"{modelname} is not a valid Llama model.") else: raise ValueError(f"{modelname} is not a valid Llama model.") - model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model( - module_name, - model_class_name, - checkpoint=checkpoint, - checkpoint_dir=checkpoint_dir, - params=params_path, - use_kv_cache=use_kv_cache, - use_sdpa_with_kv_cache=use_sdpa_with_kv_cache, - generate_full_logits=generate_full_logits, - fairseq2=weight_type == WeightType.FAIRSEQ2, - max_seq_len=max_seq_len, - enable_dynamic_shape=enable_dynamic_shape, - input_prune_map_path=input_prune_map_path, - output_prune_map_path=output_prune_map_path, - args=args, + model, example_inputs, example_kwarg_inputs, dynamic_shapes = ( + EagerModelFactory.create_model( + module_name, + model_class_name, + checkpoint=checkpoint, + checkpoint_dir=checkpoint_dir, + params=params_path, + use_kv_cache=use_kv_cache, + use_sdpa_with_kv_cache=use_sdpa_with_kv_cache, + generate_full_logits=generate_full_logits, + fairseq2=weight_type == WeightType.FAIRSEQ2, + max_seq_len=max_seq_len, + enable_dynamic_shape=enable_dynamic_shape, + input_prune_map_path=input_prune_map_path, + output_prune_map_path=output_prune_map_path, + args=args, + ) ) if dtype_override: assert isinstance( @@ -937,12 +941,13 @@ def _load_llama_model( return LLMEdgeManager( model=model, modelname=modelname, - max_seq_len=model.params.max_seq_len, + max_seq_len=model.max_seq_len, dtype=dtype, use_kv_cache=use_kv_cache, generate_full_logits=generate_full_logits, example_inputs=example_inputs, example_kwarg_inputs=example_kwarg_inputs, + dynamic_shapes=dynamic_shapes, enable_dynamic_shape=enable_dynamic_shape, calibration_tasks=calibration_tasks, calibration_limit=calibration_limit, @@ -955,7 +960,9 @@ def _load_llama_model( use_kv_cache, use_sdpa_with_kv_cache, enable_dynamic_shape, - model.params, + model.max_seq_len, + model.n_layers, + model.vocab_size, metadata_str, ), args=args, diff --git a/examples/models/llama3_2_vision/__init__.py b/examples/models/llama3_2_vision/__init__.py index 2e073404dbb..b22aa07869c 100644 --- a/examples/models/llama3_2_vision/__init__.py +++ b/examples/models/llama3_2_vision/__init__.py @@ -4,9 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from .text_decoder.model import Llama3_2Decoder from .vision_encoder import FlamingoVisionEncoderModel, VisionEncoderConfig __all__ = [ "FlamingoVisionEncoderModel", + "Llama3_2Decoder", "VisionEncoderConfig", ] diff --git a/examples/models/llama3_2_vision/text_decoder/model.py b/examples/models/llama3_2_vision/text_decoder/model.py new file mode 100644 index 00000000000..73d79bb08e0 --- /dev/null +++ b/examples/models/llama3_2_vision/text_decoder/model.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import json +from typing import Any, Dict + +import torch +from executorch.examples.models.checkpoint import ( + get_checkpoint_dtype, + get_default_model_resource_dir, +) + +from executorch.examples.models.model_base import EagerModelBase +from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder +from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune + + +def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]: + """ + Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains + weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale". + To load the text decoder on its own, the "decoder" prefix needs to be removed. + """ + return { + ".".join(weight.split(".")[1:]): value + for weight, value in checkpoint.items() + if weight.startswith("decoder") + } + + +class Llama3_2Decoder(EagerModelBase): + """ + Just the text decoder portions of the Llama3.2 multimodal model. + """ + + def __init__(self, **kwargs): + # Set member vars from kwargs. + self.max_seq_len = kwargs.get( + "max_seq_len", 8192 + ) # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment. + self.encoder_max_seq_len = kwargs.get( + "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1) + ) # Same as above. + self.generate_full_logits = kwargs.get("generate_full_logits", False) + self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False) + self.output_prune_map_path = kwargs.get("output_prune_map_path", None) + self.use_kv_cache = kwargs.get("use_kv_cache", False) + self.verbose = kwargs.get("verbose", False) + self.args = kwargs.get("args", None) + + ckpt_dir = get_default_model_resource_dir(__file__) + # Single checkpoint file. + checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth") + # Sharded checkpoint. + checkpoint_dir = kwargs.get("checkpoint_dir", None) + params_path = kwargs.get("params", ckpt_dir / "demo_config.json") + + self.causal_mask = torch.tril( + torch.ones( + size=(self.max_seq_len, self.max_seq_len), + dtype=torch.bool, + ) + ) + self.input_pos = torch.arange(self.max_seq_len) + + # Load checkpoint and params. + device = "cpu" + if checkpoint_dir is not None: + raise NotImplementedError( + "Sharded checkpoint not yet supported for Llama3_2Decoder." + ) + else: + checkpoint = torch.load( + checkpoint_path, map_location=device, weights_only=False, mmap=True + ) + checkpoint = llama3_vision_meta_to_tune(checkpoint) + checkpoint = to_decoder_checkpoint(checkpoint) + with open(params_path, "r") as f: + params = json.loads(f.read()) + + # Find dtype from checkpoint. (skip for now) + self.dtype = get_checkpoint_dtype(checkpoint) + + # Load model. + # Cannot use "with torch.device("meta"):" because it causes some exceptions during export, + # i.e. the model isn't fully initialized or something. + self.model_ = llama3_2_vision_decoder( + vocab_size=params["vocab_size"], + num_layers=params["n_layers"], + fusion_interval=params["fusion_interval"], + num_special_tokens=params["n_special_tokens"], + num_heads=params["n_heads"], + num_kv_heads=params["n_kv_heads"], + embed_dim=params["dim"], + max_seq_len=self.max_seq_len, + encoder_max_seq_len=self.encoder_max_seq_len, + rope_base=params["rope_theta"], + intermediate_dim=params["intermediate_dim"], + ) + # Save params for future use. 
+ for param_name, param_val in params.items(): + setattr(self.model_, param_name, param_val) + + # Quantize. (skip for now) + + # Load checkpoint. + missing, unexpected = self.model_.load_state_dict( + checkpoint, + strict=False, + assign=True, + ) + if kwargs.get("verbose", False): + print("============= missing keys ================") + print(missing) + print("============= /missing ================") + print("============= unexpected keys ================") + print(unexpected) + print("============= /unexpected ================") + + # Prune the output layer if output_prune_map is provided. + output_prune_map = None + if self.output_prune_map_path is not None: + from executorch.examples.models.llama2.source_transformation.prune_output import ( + prune_output_vocab, + ) + + with open(self.output_prune_map_path, "r") as f: + output_prune_map = json.load(f) + # Change keys from string to int (json only supports string keys) + output_prune_map = {int(k): v for (k, v) in output_prune_map.items()} + + self.model_ = prune_output_vocab(self.model_, output_prune_map) + + # if self.use_kv_cache: + # print("Setting up KV cache on the model...") + # self.model_.setup_caches( + # batch_size=1, + # dtype=self.dtype, + # ) + + def get_eager_model(self) -> torch.nn.Module: + if self.dtype: + return self.model_.to(self.dtype) + else: + return self.model_.to(torch.float16) + + def get_example_inputs(self): + return (torch.ones(1, 64, dtype=torch.long),) + + def get_example_kwarg_inputs(self): + # TODO: add input_pos and mask when after making cache work. + return { + # "mask": self.causal_mask[None, 64, None, :], + # "encoder_input": None, + # "encoder_mask": None, + # "input_pos": self.input_pos[None, 64] + } + + def get_dynamic_shapes(self): + batch_size = 1 + dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len) + dynamic_shapes = { + "tokens": {0: batch_size, 1: dim_seq_len}, + # "encoder_input": {0: 1, 1: dim_enc, 2: 4096}, + # "encoder_mask": {0: 1, 1: dim, 2: dim_enc}, + # "mask": {0: batch_size, 1: dim_seq_len, 2: self.max_seq_len}, + # "input_pos" : {0: batch_size, 1: dim_seq_len}, + } + return dynamic_shapes diff --git a/examples/models/llama3_2_vision/text_decoder/params/demo_config.json b/examples/models/llama3_2_vision/text_decoder/params/demo_config.json new file mode 100644 index 00000000000..694df17d945 --- /dev/null +++ b/examples/models/llama3_2_vision/text_decoder/params/demo_config.json @@ -0,0 +1,18 @@ +{ + "dim": 4096, + "ffn_dim_multiplier": 1.3, + "fusion_interval": 4, + "intermediate_dim": 14336, + "multiple_of": 1024, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 32, + "n_special_tokens": 8, + "norm_eps": 1e-05, + "rope_theta": 500000.0, + "use_scaled_rope": true, + "vision_chunk_size": 560, + "vision_max_num_chunks": 4, + "vocab_size": 128256, + "vision_num_cross_attention_layers": 8 +}
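
Notes and sketches (illustrative, not part of the patch):

The new "llama3_2_text_decoder" registry entry and the TORCHTUNE_DEFINED_MODELS branch route "llama3_2_vision" through the same EagerModelFactory path used by the existing Llama models, which now also hands back dynamic_shapes. A minimal sketch of that call is below; it assumes EagerModelFactory is importable from executorch.examples.models.model_factory (the import is not shown in this diff), and the checkpoint/params paths are placeholders.

# Sketch of the factory call added in _load_llama_model(); the import path is
# assumed and the file paths are placeholders, not verified values.
from executorch.examples.models.model_factory import EagerModelFactory

model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
    EagerModelFactory.create_model(
        "llama3_2_vision",      # module_name
        "Llama3_2Decoder",      # model_class_name
        checkpoint="/path/to/consolidated.pth",   # placeholder path
        params="/path/to/demo_config.json",       # placeholder path
        use_kv_cache=False,
        max_seq_len=8192,
    )
)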
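
The fused Llama 3.2 vision checkpoint keeps weights under "encoder."/"decoder." prefixes, so to_decoder_checkpoint() drops the encoder entries and strips the "decoder." prefix so the remaining keys match the standalone text decoder's state dict. A self-contained illustration, using made-up key names (real checkpoints use the torchtune names produced by llama3_vision_meta_to_tune):

# The function body is the one from text_decoder/model.py; the keys in `fused`
# are invented for the example.
from typing import Any, Dict

def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    return {
        ".".join(weight.split(".")[1:]): value
        for weight, value in checkpoint.items()
        if weight.startswith("decoder")
    }

fused = {
    "encoder.layers.0.attn.q_proj.weight": "<tensor>",
    "decoder.tok_embeddings.weight": "<tensor>",
    "decoder.norm.scale": "<tensor>",
}
print(to_decoder_checkpoint(fused))
# {'tok_embeddings.weight': '<tensor>', 'norm.scale': '<tensor>'}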
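
_load_llama_model_metadata() now takes plain max_seq_len / n_layers / vocab_size ints instead of a ModelArgs object, because the TorchTune-backed decoder has no ModelArgs: Llama3_2Decoder copies the demo_config.json fields onto the module with setattr, and _load_llama_model reads model.max_seq_len, model.n_layers and model.vocab_size as plain attributes. A toy illustration with a stand-in module (the attribute values mirror demo_config.json and the default max_seq_len kwarg):

# Stand-in module; the real code sets these attributes on the torchtune
# decoder returned by llama3_2_vision_decoder().
import torch

decoder = torch.nn.Module()
for name, value in {"n_layers": 32, "vocab_size": 128256, "max_seq_len": 8192}.items():
    setattr(decoder, name, value)

# The refactored metadata helper consumes these as plain ints.
metadata = {
    "get_max_seq_len": decoder.max_seq_len,
    "get_n_layers": decoder.n_layers,
    "get_vocab_size": decoder.vocab_size,
}
print(metadata)  # {'get_max_seq_len': 8192, 'get_n_layers': 32, 'get_vocab_size': 128256}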
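
get_dynamic_shapes() marks the token sequence dimension dynamic with torch.export.Dim, and the spec is now threaded from EagerModelFactory into LLMEdgeManager. A minimal sketch of what such a spec means when handed to torch.export, using a tiny embedding module as a stand-in for the decoder (not the real model):

# Only dim 1 of "tokens" is dynamic, as in Llama3_2Decoder.get_dynamic_shapes();
# the batch dimension stays static at 1.
import torch

class TinyTokenModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(128, 16)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.emb(tokens)

example_inputs = (torch.ones(1, 64, dtype=torch.long),)
token_dim = torch.export.Dim("token_dim", min=1, max=8192)
ep = torch.export.export(
    TinyTokenModel(),
    example_inputs,
    dynamic_shapes={"tokens": {1: token_dim}},
)
print(ep.graph_signature)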