diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 91168a388d3..66eeb10989f 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -175,9 +175,16 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.max_batch_size = args.max_batch_size
         self.max_context_len = args.max_context_len
         self.dim = args.dim
-        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.attention_qkv_bias = args.attention_qkv_bias
+        self.wq = nn.Linear(
+            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
+        self.wk = nn.Linear(
+            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
+        self.wv = nn.Linear(
+            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
 
         self.layer_id = layer_id
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index e1c4edb8e93..28804839815 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -21,6 +21,7 @@ class ModelArgs:
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
     attention_type: str = "mha"  # Attention type, registered in attention.py
+    attention_qkv_bias: bool = False
     use_kv_cache: bool = False  # Use key/value cache
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 01352f404df..e081c442032 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -114,6 +114,7 @@ def apply_rotary_emb_to_k(
     return xk_out.type_as(xk)
 
 
+# Wrap apply_rotary_emb in a module to enable it to be module swapped out.
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -213,14 +214,20 @@ class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
+
+        # Choose the appropriate RoPE implementation
         if self.params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.apply_rotary_emb = hf_apply_rotary_emb
         else:
             self.precompute_freqs_cis = partial(
                 precompute_freqs_cis,
                 use_scaled=self.params.use_scaled_rope,
                 scale_factor=self.params.rope_scale_factor,
             )
+            self.apply_rotary_emb = RotaryEmbedding()
+
+        # Precompute frequencies
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
             (
@@ -232,10 +239,6 @@ def __init__(self, params: ModelArgs):
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
-        if self.params.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()
 
     def forward(
         self,
diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 3a5f88ad3f3..43db873fb65 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -207,22 +207,23 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.dim = config.dim
         self.head_dim = config.head_dim
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
+        self.attention_qkv_bias = config.attention_qkv_bias
 
         self.wqs = nn.ModuleList(
             [
-                nn.Linear(self.dim, self.head_dim, bias=False)
+                nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
                 for _ in range(self.n_heads)
             ]
         )
         self.wks = nn.ModuleList(
             [
-                nn.Linear(self.dim, self.head_dim, bias=False)
+                nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
                 for _ in range(self.n_kv_heads)
             ]
         )
         self.wvs = nn.ModuleList(
             [
-                nn.Linear(self.dim, self.head_dim, bias=False)
+                nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
                 for _ in range(self.n_kv_heads)
             ]
         )
diff --git a/examples/models/qwen2_5/1_5b_config.json b/examples/models/qwen2_5/1_5b_config.json
new file mode 100644
index 00000000000..64daca5a7cd
--- /dev/null
+++ b/examples/models/qwen2_5/1_5b_config.json
@@ -0,0 +1,14 @@
+{
+  "dim": 1536,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 8960,
+  "n_heads": 12,
+  "n_kv_heads": 2,
+  "n_layers": 28,
+  "norm_eps": 1e-06,
+  "rope_theta": 1000000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 151936,
+  "use_hf_rope": true,
+  "attention_qkv_bias": true
+}
diff --git a/examples/models/qwen2_5/README.md b/examples/models/qwen2_5/README.md
new file mode 100644
index 00000000000..9bf791a35ed
--- /dev/null
+++ b/examples/models/qwen2_5/README.md
@@ -0,0 +1,63 @@
+## Summary
+Qwen 2.5 is the latest iteration of the Qwen series of large language models (LLMs) developed by Alibaba. At the moment, only the 1.5B model is supported, with plans to add the 0.5B and 3B versions in the future.
+
+## Instructions
+
+Qwen 2.5 uses the same example code as Llama; only the checkpoint, model params, and tokenizer differ. Please see the [Llama README page](../llama/README.md) for details.
+
+All commands for exporting and running Llama on various backends should also work for Qwen 2.5 by swapping in the following args:
+```
+--model qwen2_5
+--params examples/models/qwen2_5/1_5b_config.json
+--checkpoint <path-to-converted-checkpoint.pth>
+```
+
+### Generate the Checkpoint
+The original checkpoint can be obtained from HuggingFace:
+```
+huggingface-cli download Qwen/Qwen2.5-1.5B
+```
+
+We then convert it to Meta's checkpoint format:
+```
+python examples/models/qwen2_5/convert_weights.py <path-to-checkpoint-dir> <output-path>
+```
+
+### Example export and run
+Here is a basic example of exporting and running Qwen 2.5; please refer to the [Llama README page](../llama/README.md) for more advanced usage.
+
+Export to XNNPACK, no quantization:
+```
+# No quantization
+# Set this path to point to the converted checkpoint
+QWEN_CHECKPOINT=path/to/checkpoint.pth
+
+python -m examples.models.llama.export_llama \
+  --model "qwen2_5" \
+  --checkpoint "${QWEN_CHECKPOINT:?}" \
+  --params examples/models/qwen2_5/1_5b_config.json \
+  -kv \
+  --use_sdpa_with_kv_cache \
+  -d fp32 \
+  -X \
+  --metadata '{"get_bos_id":151643, "get_eos_ids":[151643]}' \
+  --output_name="qwen2_5-1_5b.pte" \
+  --verbose
+```
+
+Run using the executor runner:
+```
+# The C++ runner is currently a work in progress; the HuggingFace JSON tokenizer still needs to be enabled in C++.
+# In the meantime, you can run with the example Python runner via pybindings:
+
+python -m examples.models.llama.runner.native \
+  --model qwen2_5 \
+  --pte <path-to-exported-pte> \
+  -kv \
+  --tokenizer <path-to-downloaded-hf-files>/tokenizer.json \
+  --tokenizer_config <path-to-downloaded-hf-files>/tokenizer_config.json \
+  --prompt "Who is the founder of Meta?" \
+  --params examples/models/qwen2_5/1_5b_config.json \
+  --max_len 64 \
+  --temperature 0
+```
diff --git a/examples/models/qwen2_5/convert_weights.py b/examples/models/qwen2_5/convert_weights.py
new file mode 100644
index 00000000000..6b6c0bbdfe2
--- /dev/null
+++ b/examples/models/qwen2_5/convert_weights.py
@@ -0,0 +1,88 @@
+import argparse
+from typing import Dict
+
+import torch
+
+from torchtune.models.convert_weights import get_mapped_key
+
+from torchtune.training import FullModelHFCheckpointer
+
+# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
+_QWEN_2_FROM_META = {
+    "tok_embeddings.weight": "tok_embeddings.weight",
+    "norm.weight": "norm.scale",
+    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
+    "layers.{}.attention.wk.bias": "layers.{}.attn.k_proj.bias",
+    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
+    "layers.{}.attention.wq.bias": "layers.{}.attn.q_proj.bias",
+    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
+    "layers.{}.attention.wv.bias": "layers.{}.attn.v_proj.bias",
+    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
+    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
+    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
+    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
+    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
+    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
+}
+
+
+def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from torchtune's format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+ """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _QWEN_2_FROM_META.items()} + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + # 0.5b and 1.5b models share the same weights for tok_embeddings and output embeddings, see https://github.com/QwenLM/Qwen2.5/issues/733. + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Qwen2 weights to Meta format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing checkpoint files", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + + # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. + checkpointer = FullModelHFCheckpointer( + # checkpoint_dir="/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/", + checkpoint_dir=args.input_dir, + checkpoint_files=["model.safetensors"], + output_dir=".", + model_type="QWEN2", + ) + + print("Loading checkpoint...") + sd = checkpointer.load_checkpoint() + + print("Converting checkpoint...") + sd = qwen_2_tune_to_meta(sd["model"]) + # torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth") + + torch.save(sd, args.output) + print(f"Checkpoint saved to {args.output}") + + +if __name__ == "__main__": + main()