From 98c425bd1fdf380d79f7e61bbfa3fc136f2b6fe1 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Fri, 4 Oct 2024 20:43:30 -0700 Subject: [PATCH 1/7] Add et version of TorchTune MHA for swapping with custom op --- examples/models/llama2/export_llama_lib.py | 2 + .../torchtune/attention.py | 101 +++++++ .../torchtune/modules/mha.py | 284 ++++++++++++++++++ .../torchtune/modules/sdpa.py | 90 ++++++ 4 files changed, 477 insertions(+) create mode 100644 examples/models/llama2/source_transformation/torchtune/attention.py create mode 100644 examples/models/llama2/source_transformation/torchtune/modules/mha.py create mode 100644 examples/models/llama2/source_transformation/torchtune/modules/sdpa.py diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 29518b641f4..d9ab495278a 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -67,6 +67,7 @@ replace_sdpa_with_flex_sdpa, replace_sdpa_with_simple_sdpa, ) +from .source_transformation.torchtune.attention import replace_mha_with_inference_mha IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -936,6 +937,7 @@ def _get_source_transforms( # noqa if args.use_sdpa_with_kv_cache: transforms.append(replace_sdpa_with_custom_op) + transforms.append(replace_mha_with_inference_mha) if args.quantize_kv_cache: assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" diff --git a/examples/models/llama2/source_transformation/torchtune/attention.py b/examples/models/llama2/source_transformation/torchtune/attention.py new file mode 100644 index 00000000000..9f84bd3f000 --- /dev/null +++ b/examples/models/llama2/source_transformation/torchtune/attention.py @@ -0,0 +1,101 @@ +import torch +import torchtune.modules.attention as TorchTuneAttention +from executorch.examples.models.llama2.source_transformation.torchtune.modules.mha import MultiHeadAttention +from executorch.examples.models.llama2.source_transformation.torchtune.modules.sdpa import SDPA + +def _replace_mha_with_inference_mha(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, TorchTuneAttention.MultiHeadAttention): + setattr( + module, + name, + MultiHeadAttention( + embed_dim=child.embed_dim, + num_heads=child.num_heads, + num_kv_heads=child.num_kv_heads, + head_dim=child.head_dim, + q_proj=child.q_proj, + k_proj=child.k_proj, + v_proj=child.v_proj, + output_proj=child.output_proj, + pos_embeddings=child.pos_embedding, + q_norm=child.q_norm, + k_norm=child.k_norm, + kv_cache=child.kv_cache, + max_seq_len=child.max_seq_len, + is_causal=child.is_causal, + attn_dropout=child.attn_dropout, + ), + ) + else: + replace_mha_with_inference_mha(child) + +def replace_mha_with_inference_mha(module: torch.nn.Module): + """ + Replace TorchTune's MHA with an inference friendly version of MHA that + separates out the inference-related parts for further optimization. + """ + _replace_mha_with_inference_mha(module) + return module + +# class SDPACustom(torch.nn.Module): +# def __init__( +# self, +# kv_cache: KVCache, +# dim: int, +# ): +# super().__init__() +# # Custom op only supports float32 currently. Converting to/from float32 is +# # faster than not having the op. +# self.kv_cache = kv_cache.to(torch.float) +# self.dim = dim + +# def forward( +# self, +# input_pos: torch.Tensor, +# q: torch.Tensor, +# k: torch.Tensor, +# v: torch.Tensor, +# bsz, +# seqlen, +# mask, +# ): +# # Custom op only supports float32 currently. Converting to/from float32 is +# # faster than not having the op. +# input_dtype = q.dtype +# q = q.to(dtype=torch.float) +# k = k.to(dtype=torch.float) +# v = v.to(dtype=torch.float) +# output = torch.ops.llama.sdpa_with_kv_cache( +# q, +# k, +# v, +# self.kv_cache.k_cache, +# self.kv_cache.v_cache, +# input_pos[-1].item(), +# seqlen, +# None, # Attention mask +# 0, # dropout probability. Ignored by the code +# True, # is_causal +# ) +# return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype) + + +# def _replace_sdpa_with_custom_op(module: torch.nn.Module): +# for name, child in module.named_children(): +# if isinstance(child, SDPA): +# setattr( +# module, +# name, +# SDPACustom(child.kv_cache, child.dim), +# ) +# else: +# _replace_sdpa_with_custom_op(child) + + +# def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: +# from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa + +# _replace_sdpa_with_custom_op(module) +# return module + diff --git a/examples/models/llama2/source_transformation/torchtune/modules/mha.py b/examples/models/llama2/source_transformation/torchtune/modules/mha.py new file mode 100644 index 00000000000..855fe420ce4 --- /dev/null +++ b/examples/models/llama2/source_transformation/torchtune/modules/mha.py @@ -0,0 +1,284 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +import torch +from torch import nn +from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention +from torchtune.modules.kv_cache import KVCache +from executorch.examples.models.llama2.source_transformation.torchtune.modules.sdpa import SDPA + +logger = logging.getLogger(__name__) + + +class MultiHeadAttention(nn.Module): + """Multi-headed attention layer with support for grouped query + attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1. + + GQA is a version of multiheaded attention (MHA) which uses fewer + key/value heads than query heads by grouping n query heads for each + key and value head. Multi-Query Attention is an extreme + version where we have a single key and value head shared by all + query heads. + + Following is an example of MHA, GQA and MQA with num_heads = 4 + + (credit for the documentation: + `litgpt.Config `_). + + + :: + + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ │ │ │ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + MHA GQA MQA + n_kv_heads =4 n_kv_heads=2 n_kv_heads=1 + + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, + for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. + head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. + q_proj (nn.Module): projection layer for query. + k_proj (nn.Module): projection layer for key. + v_proj (nn.Module): projection layer for value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied + before updating from kv_cache. This means it will only support token wide normalization and not + batch or sequence wide normalization. + k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. + is_causal (bool): sets the default mask to causal when no mask is provided + attn_dropout (float): dropout value passed onto the + scaled_dot_product_attention function. This argument is ignored if the + self.training is False. Default value is 0.0. + + Raises: + ValueError: If ``num_heads % num_kv_heads != 0`` + ValueError: If ``embed_dim % num_heads != 0`` + ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` + ValueError: if q_norm is defined without k_norm or vice versa + """ + + def __init__( + self, + *, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_proj: nn.Module, + k_proj: nn.Module, + v_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: Optional[nn.Module] = None, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + if bool(q_norm) ^ bool(k_norm): + raise ValueError("q and k norm must be set together") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + # Number of queries per k, v + self.q_per_kv = self.num_heads // self.num_kv_heads + + # Set layers + self.kv_cache = kv_cache + self.q_proj = q_proj + self.k_proj = k_proj + self.v_proj = v_proj + self.output_proj = output_proj + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_embeddings = pos_embeddings + + # Use flex attention if supported and we are sample packing + self._attention_call = _sdpa_or_flex_attention() + self._sdpa = SDPA( + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + head_dim=self.head_dim, + q_per_kv=self.q_per_kv, + attn_dropout=self.attn_dropout, + is_causal=self.is_causal, + attention_fn=self._attention_call, + kv_cache=self.kv_cache, + ) + + def setup_cache( + self, batch_size: int, dtype: torch.dtype, max_seq_len: int + ) -> None: + """Setup key value caches for attention calculation. If called + after kv_cache is already setup, this will be skipped. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + max_seq_len (int): maximum sequence length model will be run with. + """ + # Don't overwrite user defined kv_cache from init + if self.kv_cache is not None: + logger.warning( + "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." + ) + else: + self.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=max_seq_len, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + + def reset_cache(self): + """Reset the key value caches.""" + if self.kv_cache is None: + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." + ) + self.kv_cache.reset() + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape [b x s_x x d] for the query + y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input + for k and v. For self attention, x=y. Optional only with kv_cache enabled. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either: + + A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, + or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. + A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means + token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask + is used by default. + + A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence + created via `create_block_mask `_. We use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. + Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Raises: + ValueError: If no ``y`` input and ``kv_cache`` is not enabled. + + Returns: + torch.Tensor: output tensor with attention applied + + Notation used for tensor shapes: + - b: batch size + - s_x: sequence length for x + - s_y: sequence length for y + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + """ + # x has shape [b, s_x, d] + # y has shape [b, s_y, d] + b, s_x, _ = x.shape + s_y = y.shape[1] if y is not None else 0 + + # q has shape [b, s_x, num_heads * head_dim] + q = self.q_proj(x) + q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim) + + # Apply positional embeddings + if self.pos_embeddings is not None: + q = self.pos_embeddings(q, input_pos=input_pos) + + # Normalize q + if self.q_norm is not None: + q = self.q_norm(q) + + if y is None: + if self.kv_cache is None: + raise ValueError( + "Must provide y input or use kv_cache to enable streaming decoding" + ) + k = self.kv_cache.k_cache + v = self.kv_cache.v_cache + else: + # Update k and v shape, positional embeddings, and normalization + + # k has shape [b, s_y, num_kv_heads * head_dim] + # v has shape [b, s_y, num_kv_heads * head_dim] + k = self.k_proj(y) + v = self.v_proj(y) + + # Apply positional embeddings + # k: [b, s_y, n_kv, h_d] + k = k.view(b, s_y, -1, self.head_dim) + v = v.view(b, s_y, -1, self.head_dim) + if self.pos_embeddings is not None: + k = self.pos_embeddings(k, input_pos=input_pos) + + # Normalize k + if self.k_norm is not None: + k = self.k_norm(k) + + # Update key-value cache + if self.kv_cache is not None: + self._sdpa.kv_cache_update(input_pos, k, v) + + output = self._sdpa(q, k, v, b, s_x) + return self.output_proj(output) diff --git a/examples/models/llama2/source_transformation/torchtune/modules/sdpa.py b/examples/models/llama2/source_transformation/torchtune/modules/sdpa.py new file mode 100644 index 00000000000..3f0bb324a63 --- /dev/null +++ b/examples/models/llama2/source_transformation/torchtune/modules/sdpa.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +from torch import nn, Tensor + + +class SDPA(nn.Module): + """ + The core of SDPA which can be optimized and can be swapped + out for a more efficient implementations. Split into + kv cache update and core sdpa (foward) components because + they are easier to optimize separately. + """ + + def __init__( + self, + num_kv_heads: int, + num_heads: int, + head_dim: int, + q_per_kv: int, + attn_dropout: float, + is_causal: bool, + attention_fn, + kv_cache, + ) -> None: + super().__init__() + self.num_kv_heads = num_kv_heads + self.num_heads = num_heads + self.head_dim = head_dim + self.q_per_kv = q_per_kv + self.attn_dropout = attn_dropout + self.is_causal = is_causal + self._attention_fn = attention_fn + self._kv_cache = kv_cache + + def kv_cache_update( + self, + input_pos: Tensor, + k: Tensor, + v: Tensor, + ) -> Tuple[Tensor, Tensor]: + k, v = self._kv_cache.update(input_pos, k, v) + return k, v + + def forward( + self, + q: Tensor, # [b, s, n_h, h_d] + k: Tensor, # [b, s, n_kv, h_d] + v: Tensor, # [b, s, n_kv, h_d] + bsz: int, + seq_len: int, + mask: Tensor = None, + ) -> Tensor: + # View + expand + reshape bring num_kv_heads to num_heads for k and v + # to match q. + + # k: [bsz, seq_len, n_kv, 1, h_d] + # v: [bsz, seq_len, n_kv, 1, h_d] + k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + + # Expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + + # [bsz, s, n_h, h_d] + k = k.reshape(bsz, seq_len, -1, self.head_dim) + v = v.reshape(bsz, seq_len, -1, self.head_dim) + + # [bsz, n_h, s, h_d] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + output = self._attention_fn( + q, + k, + v, + mask=mask, + dropout_p=self.attn_dropout, + is_causal=self._kv_cache is None and mask is None and self.is_causal, + ) + # Reshape the output to be the same shape as the input + return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) From 1fe035658451d57b002b309bc6acbf292c6d1d50 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Wed, 16 Oct 2024 17:57:18 -0700 Subject: [PATCH 2/7] Recent TT updates --- .../llama2/source_transformation/torchtune/modules/mha.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/models/llama2/source_transformation/torchtune/modules/mha.py b/examples/models/llama2/source_transformation/torchtune/modules/mha.py index 855fe420ce4..a3b54b1cbd6 100644 --- a/examples/models/llama2/source_transformation/torchtune/modules/mha.py +++ b/examples/models/llama2/source_transformation/torchtune/modules/mha.py @@ -70,9 +70,8 @@ class MultiHeadAttention(nn.Module): max_seq_len (int): maximum sequence length supported by the model. This is needed to compute the RoPE Cache. Default: 4096. is_causal (bool): sets the default mask to causal when no mask is provided - attn_dropout (float): dropout value passed onto the - scaled_dot_product_attention function. This argument is ignored if the - self.training is False. Default value is 0.0. + attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function. + This argument is ignored if self.training is False. Default value is 0.0. Raises: ValueError: If ``num_heads % num_kv_heads != 0`` @@ -147,7 +146,7 @@ def __init__( num_heads=self.num_heads, head_dim=self.head_dim, q_per_kv=self.q_per_kv, - attn_dropout=self.attn_dropout, + attn_dropout=self.attn_dropout if self.training else 0.0, is_causal=self.is_causal, attention_fn=self._attention_call, kv_cache=self.kv_cache, From 8afb8e1ba84940fed41fe97f1c4536a4da07a801 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Wed, 16 Oct 2024 18:12:18 -0700 Subject: [PATCH 3/7] Match up mha with TT --- .../source_transformation/torchtune/modules/mha.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/models/llama2/source_transformation/torchtune/modules/mha.py b/examples/models/llama2/source_transformation/torchtune/modules/mha.py index a3b54b1cbd6..bad4cde0f8c 100644 --- a/examples/models/llama2/source_transformation/torchtune/modules/mha.py +++ b/examples/models/llama2/source_transformation/torchtune/modules/mha.py @@ -126,8 +126,6 @@ def __init__( self.head_dim = head_dim self.max_seq_len = max_seq_len self.is_causal = is_causal - # Number of queries per k, v - self.q_per_kv = self.num_heads // self.num_kv_heads # Set layers self.kv_cache = kv_cache @@ -145,7 +143,7 @@ def __init__( num_kv_heads=self.num_kv_heads, num_heads=self.num_heads, head_dim=self.head_dim, - q_per_kv=self.q_per_kv, + q_per_kv=self.num_heads // self.num_kv_heads, attn_dropout=self.attn_dropout if self.training else 0.0, is_causal=self.is_causal, attention_fn=self._attention_call, @@ -239,7 +237,10 @@ def forward( # q has shape [b, s_x, num_heads * head_dim] q = self.q_proj(x) - q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim) + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) # Apply positional embeddings if self.pos_embeddings is not None: From a91666d82bd403b8c89a703e50f452642ac685f9 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 22 Oct 2024 02:40:26 -0700 Subject: [PATCH 4/7] Split sdpa into custom op and quantized kv cache --- examples/models/llama2/export_llama_lib.py | 38 +++++++- .../quantized_kv_cache.py | 64 ++++++++++++- .../llama2/source_transformation/sdpa.py | 61 ++++++++++++- .../torchtune/attention.py | 74 ++------------- .../torchtune/modules/mha.py | 90 +++++++++++++++++-- .../torchtune/modules/sdpa.py | 90 ------------------- 6 files changed, 249 insertions(+), 168 deletions(-) delete mode 100644 examples/models/llama2/source_transformation/torchtune/modules/sdpa.py diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index d9ab495278a..d2f1b248a08 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -54,6 +54,7 @@ ) from .source_transformation.quantized_kv_cache import ( replace_kv_cache_with_quantized_kv_cache, + replace_torchtune_kv_cache_with_quantized_kv_cache, ) from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm @@ -65,6 +66,7 @@ replace_sdpa_with_coreml_sdpa, replace_sdpa_with_custom_op, replace_sdpa_with_flex_sdpa, + replace_sdpa_with_sdpa_only_custom_op, replace_sdpa_with_simple_sdpa, ) from .source_transformation.torchtune.attention import replace_mha_with_inference_mha @@ -237,7 +239,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--use_sdpa_with_kv_cache", default=False, action="store_true", - help="Whether to use sdpa_with_kv_cache update op when using kv cache", + help="Whether to use a custom sdpa + kv_cache update when kv cache is enabled.", ) parser.add_argument( "--disable_dynamic_shape", @@ -582,6 +584,18 @@ def _validate_args(args): if args.num_sharding > 0 and not args.qnn: raise ValueError("Model shard is only supported with qnn backend now.") + if args.model in TORCHTUNE_DEFINED_MODELS: + if args.use_sdpa_with_kv_cache: + if not args.use_kv_cache and not args.quantize_kv_cache: + raise ValueError( + f"TorchTune-defined {args.model} only works with custom SDPA op + quantized KV cache at the moment. Please enable use_kv_cache and quantize_kv_cache when use_sdpa_with_kv_cache is enabled." + ) + if args.use_kv_cache: + if not args.quantize_kv_cache: + raise ValueError( + f"TorchTune-defined {args.model} only works with quantized KV cache at the moment. Please enable quantize_kv_cache when use_kv_cache is enabled." + ) + def _export_llama(args) -> LLMEdgeManager: # noqa: C901 _validate_args(args) @@ -884,6 +898,7 @@ def _load_llama_model( def _get_source_transforms( # noqa modelname: str, dtype_override: Optional[DType], args ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]: + is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS transforms = [] if args.use_spin_quant: @@ -936,12 +951,27 @@ def _get_source_transforms( # noqa transforms.append(materialze_broadcast_of_rope_freq_cis) if args.use_sdpa_with_kv_cache: - transforms.append(replace_sdpa_with_custom_op) - transforms.append(replace_mha_with_inference_mha) + if is_torchtune_model: + assert ( + args.use_kv_cache and args.quantize_kv_cache + ), "use_sdpa_with_kv_cache requires use_kv_cache=True and quantize_kv_cache=True for TorchTune at the moment." + transforms.append(replace_mha_with_inference_mha) + transforms.append(replace_sdpa_with_sdpa_only_custom_op) + else: + transforms.append(replace_sdpa_with_custom_op) if args.quantize_kv_cache: assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" - transforms.append(replace_kv_cache_with_quantized_kv_cache) + if is_torchtune_model: + transforms.append( + lambda module: replace_torchtune_kv_cache_with_quantized_kv_cache( + module, + is_transposed=not args.use_sdpa_with_kv_cache, + enable_dynamic_shape=args.enable_dynamic_shape, + ) + ) + else: + transforms.append(replace_kv_cache_with_quantized_kv_cache) if args.use_kv_cache: if args.qnn: diff --git a/examples/models/llama2/source_transformation/quantized_kv_cache.py b/examples/models/llama2/source_transformation/quantized_kv_cache.py index 8eec7846d3c..57025255e5a 100644 --- a/examples/models/llama2/source_transformation/quantized_kv_cache.py +++ b/examples/models/llama2/source_transformation/quantized_kv_cache.py @@ -11,6 +11,7 @@ import torch.nn as nn from executorch.examples.models.llama2.llama_transformer import KVCache from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torchtune.modules.kv_cache import KVCache as TorchTuneKVCache """ @@ -207,8 +208,31 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType): kv_cache.enable_dynamic_shape, ) + @classmethod + def from_torchtune_float( + cls, + kv_cache, + cache_type: QuantizedCacheType, + is_transposed: bool, + enable_dynamic_shape: bool, + ): + cache_shape = kv_cache.k_cache.shape + if kv_cache.is_tranposed: + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape + else: + max_batch_size, max_seq_length, n_heads, head_dim = cache_shape + return cls( + max_batch_size, + max_seq_length, + n_heads, + head_dim, + cache_type, + is_transposed, + enable_dynamic_shape, + ) -def replace_kv_cache_with_quantized_kv_cache(module): + +def replace_kv_cache_with_quantized_kv_cache(module: nn.Module) -> nn.Module: logging.warning( "Replacing KVCache with QuantizedKVCache. This modifies the model in place." ) @@ -222,3 +246,41 @@ def replace_kv_cache_with_quantized_kv_cache(module): else: replace_kv_cache_with_quantized_kv_cache(child) return module + + +def replace_torchtune_kv_cache_with_quantized_kv_cache( + module: nn.Module, is_transposed: bool, enable_dynamic_shape: bool +) -> nn.Module: + """ + Replace TorchTune KVCache with Executorch's quantized KVCache. + + Args: + is_transposed: whether q, k, and v are transposed. Should set to false when sdpa custom op source transform is enabled. + enable_dynamic_shape: whether dynamic shapes are enabled. + + Returns: + The passed in model. + """ + logging.warning( + "Replacing KVCache with QuantizedKVCache. This modifies the model in place." + ) + for name, child in module.named_children(): + if isinstance(child, TorchTuneKVCache): + cache_shape = child.k_cache.shape + if is_transposed: + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape + else: + max_batch_size, max_seq_length, n_heads, head_dim = cache_shape + setattr( + module, + name, + QuantizedKVCache.from_torchtune_float( + child, + QuantizedCacheType.AffineAsymmetric, + is_transposed, + enable_dynamic_shape, + ), + ) + else: + replace_kv_cache_with_quantized_kv_cache(child) + return module diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py index bda6966fa16..40b237b0fa6 100644 --- a/examples/models/llama2/source_transformation/sdpa.py +++ b/examples/models/llama2/source_transformation/sdpa.py @@ -80,7 +80,7 @@ def forward( input_pos[0].item(), seqlen, None, # Attention mask - 0, # dropout probability. Ignored by the code + 0, # Dropout probability, ignored by the code True, # is_causal ) return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype) @@ -105,6 +105,65 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: return module +class SDPAOnlyCustom(torch.nn.Module): + """ + Just the custom SDPA op, no KV cache update included. Can only be used + in conjunction with a quantized KV cache. + """ + + def __init__( + self, + ): + super().__init__() + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + seqlen: int, + mask: torch.Tensor = None, + ): + # Custom op only supports float32 currently. Converting to/from float32 is + # faster than not having the op. + input_dtype = q.dtype + q = q.to(dtype=torch.float) + k = k.to(dtype=torch.float) + v = v.to(dtype=torch.float) + output = torch.ops.llama.custom_sdpa( + q, + k, + v, + input_pos[0].item(), + None, # Attention mask + 0, # Dropout probability, ignored by the code. + True, # is_causal + ) + return output.view(bsz, seqlen, -1).to(dtype=input_dtype) + + +def _replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, SDPA): + assert ( + child.kv_cache.cache_fp_type == torch.float32 + ), "Only float32 is supported for custom SDPA" + setattr( + module, + name, + SDPAOnlyCustom(), + ) + else: + _replace_sdpa_with_sdpa_only_custom_op(child) + + +def replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module) -> torch.nn.Module: + _replace_sdpa_with_sdpa_only_custom_op(module) + return module + + class SDPASimple(torch.nn.Module): def __init__( diff --git a/examples/models/llama2/source_transformation/torchtune/attention.py b/examples/models/llama2/source_transformation/torchtune/attention.py index 9f84bd3f000..0dd62eafd6d 100644 --- a/examples/models/llama2/source_transformation/torchtune/attention.py +++ b/examples/models/llama2/source_transformation/torchtune/attention.py @@ -1,9 +1,11 @@ import torch import torchtune.modules.attention as TorchTuneAttention -from executorch.examples.models.llama2.source_transformation.torchtune.modules.mha import MultiHeadAttention -from executorch.examples.models.llama2.source_transformation.torchtune.modules.sdpa import SDPA +from executorch.examples.models.llama2.source_transformation.torchtune.modules.mha import ( + MultiHeadAttention, +) -def _replace_mha_with_inference_mha(module: torch.nn.Module): + +def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None: for name, child in module.named_children(): if isinstance(child, TorchTuneAttention.MultiHeadAttention): setattr( @@ -18,7 +20,7 @@ def _replace_mha_with_inference_mha(module: torch.nn.Module): k_proj=child.k_proj, v_proj=child.v_proj, output_proj=child.output_proj, - pos_embeddings=child.pos_embedding, + pos_embeddings=child.pos_embeddings, q_norm=child.q_norm, k_norm=child.k_norm, kv_cache=child.kv_cache, @@ -30,72 +32,10 @@ def _replace_mha_with_inference_mha(module: torch.nn.Module): else: replace_mha_with_inference_mha(child) -def replace_mha_with_inference_mha(module: torch.nn.Module): +def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module: """ Replace TorchTune's MHA with an inference friendly version of MHA that separates out the inference-related parts for further optimization. """ _replace_mha_with_inference_mha(module) return module - -# class SDPACustom(torch.nn.Module): -# def __init__( -# self, -# kv_cache: KVCache, -# dim: int, -# ): -# super().__init__() -# # Custom op only supports float32 currently. Converting to/from float32 is -# # faster than not having the op. -# self.kv_cache = kv_cache.to(torch.float) -# self.dim = dim - -# def forward( -# self, -# input_pos: torch.Tensor, -# q: torch.Tensor, -# k: torch.Tensor, -# v: torch.Tensor, -# bsz, -# seqlen, -# mask, -# ): -# # Custom op only supports float32 currently. Converting to/from float32 is -# # faster than not having the op. -# input_dtype = q.dtype -# q = q.to(dtype=torch.float) -# k = k.to(dtype=torch.float) -# v = v.to(dtype=torch.float) -# output = torch.ops.llama.sdpa_with_kv_cache( -# q, -# k, -# v, -# self.kv_cache.k_cache, -# self.kv_cache.v_cache, -# input_pos[-1].item(), -# seqlen, -# None, # Attention mask -# 0, # dropout probability. Ignored by the code -# True, # is_causal -# ) -# return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype) - - -# def _replace_sdpa_with_custom_op(module: torch.nn.Module): -# for name, child in module.named_children(): -# if isinstance(child, SDPA): -# setattr( -# module, -# name, -# SDPACustom(child.kv_cache, child.dim), -# ) -# else: -# _replace_sdpa_with_custom_op(child) - - -# def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: -# from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa - -# _replace_sdpa_with_custom_op(module) -# return module - diff --git a/examples/models/llama2/source_transformation/torchtune/modules/mha.py b/examples/models/llama2/source_transformation/torchtune/modules/mha.py index bad4cde0f8c..7512388fd2e 100644 --- a/examples/models/llama2/source_transformation/torchtune/modules/mha.py +++ b/examples/models/llama2/source_transformation/torchtune/modules/mha.py @@ -11,13 +11,17 @@ from torch import nn from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention from torchtune.modules.kv_cache import KVCache -from executorch.examples.models.llama2.source_transformation.torchtune.modules.sdpa import SDPA logger = logging.getLogger(__name__) class MultiHeadAttention(nn.Module): - """Multi-headed attention layer with support for grouped query + """ + NOTE: copied from Torchtune's mha.py. Should be mostly 1:1 except + that SDPA is factored out so that it can be swapped for more + efficient ExecuTorch-defined SDPA ops. + + Multi-headed attention layer with support for grouped query attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1. GQA is a version of multiheaded attention (MHA) which uses fewer @@ -71,7 +75,7 @@ class MultiHeadAttention(nn.Module): This is needed to compute the RoPE Cache. Default: 4096. is_causal (bool): sets the default mask to causal when no mask is provided attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function. - This argument is ignored if self.training is False. Default value is 0.0. + Default value is 0.0. Raises: ValueError: If ``num_heads % num_kv_heads != 0`` @@ -150,6 +154,11 @@ def __init__( kv_cache=self.kv_cache, ) + # this flag indicates whether to update the kv-cache during forward + # passes. when disabled, we can have the cache setup but still + # perform normal forward passes + self.cache_enabled = False + def setup_cache( self, batch_size: int, dtype: torch.dtype, max_seq_len: int ) -> None: @@ -174,6 +183,7 @@ def setup_cache( head_dim=self.head_dim, dtype=dtype, ) + self.cache_enabled = True def reset_cache(self): """Reset the key value caches.""" @@ -277,8 +287,78 @@ def forward( k = self.k_norm(k) # Update key-value cache - if self.kv_cache is not None: - self._sdpa.kv_cache_update(input_pos, k, v) + if self.kv_cache is not None and self.cache_enabled: + k, v = self._sdpa.kv_cache.update(input_pos, k, v) output = self._sdpa(q, k, v, b, s_x) return self.output_proj(output) + + +class SDPA(nn.Module): + """ + TorchTune's SDPA which can be optimized and can be swapped + out for a more efficient implementations. + """ + + def __init__( + self, + num_kv_heads: int, + num_heads: int, + head_dim: int, + q_per_kv: int, + attn_dropout: float, + is_causal: bool, + attention_fn, + kv_cache, + ) -> None: + super().__init__() + self.num_kv_heads = num_kv_heads + self.num_heads = num_heads + self.head_dim = head_dim + self.q_per_kv = q_per_kv + self.attn_dropout = attn_dropout + self.is_causal = is_causal + self._attention_fn = attention_fn + self.kv_cache = kv_cache + + def forward( + self, + q: torch.Tensor, # [b, s, n_h, h_d] + k: torch.Tensor, # [b, s, n_kv, h_d] + v: torch.Tensor, # [b, s, n_kv, h_d] + bsz: int, + seq_len: int, + mask: torch.Tensor = None, + ) -> torch.Tensor: + # View + expand + reshape bring num_kv_heads to num_heads for k and v + # to match q. + + # k: [bsz, seq_len, n_kv, 1, h_d] + # v: [bsz, seq_len, n_kv, 1, h_d] + k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) + + # Expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) + + # [bsz, s, n_h, h_d] + k = k.reshape(bsz, seq_len, -1, self.head_dim) + v = v.reshape(bsz, seq_len, -1, self.head_dim) + + # [bsz, n_h, s, h_d] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + output = self._attention_fn( + q, + k, + v, + mask=mask, + dropout_p=self.attn_dropout, + is_causal=self.kv_cache is None and mask is None and self.is_causal, + ) + # Reshape the output to be the same shape as the input + return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) diff --git a/examples/models/llama2/source_transformation/torchtune/modules/sdpa.py b/examples/models/llama2/source_transformation/torchtune/modules/sdpa.py deleted file mode 100644 index 3f0bb324a63..00000000000 --- a/examples/models/llama2/source_transformation/torchtune/modules/sdpa.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Tuple - -from torch import nn, Tensor - - -class SDPA(nn.Module): - """ - The core of SDPA which can be optimized and can be swapped - out for a more efficient implementations. Split into - kv cache update and core sdpa (foward) components because - they are easier to optimize separately. - """ - - def __init__( - self, - num_kv_heads: int, - num_heads: int, - head_dim: int, - q_per_kv: int, - attn_dropout: float, - is_causal: bool, - attention_fn, - kv_cache, - ) -> None: - super().__init__() - self.num_kv_heads = num_kv_heads - self.num_heads = num_heads - self.head_dim = head_dim - self.q_per_kv = q_per_kv - self.attn_dropout = attn_dropout - self.is_causal = is_causal - self._attention_fn = attention_fn - self._kv_cache = kv_cache - - def kv_cache_update( - self, - input_pos: Tensor, - k: Tensor, - v: Tensor, - ) -> Tuple[Tensor, Tensor]: - k, v = self._kv_cache.update(input_pos, k, v) - return k, v - - def forward( - self, - q: Tensor, # [b, s, n_h, h_d] - k: Tensor, # [b, s, n_kv, h_d] - v: Tensor, # [b, s, n_kv, h_d] - bsz: int, - seq_len: int, - mask: Tensor = None, - ) -> Tensor: - # View + expand + reshape bring num_kv_heads to num_heads for k and v - # to match q. - - # k: [bsz, seq_len, n_kv, 1, h_d] - # v: [bsz, seq_len, n_kv, 1, h_d] - k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) - v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) - - # Expand the key and value tensors to have the same shape - # as the query tensor by copying values across the relevant dim - if self.num_heads != self.num_kv_heads: - k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) - v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) - - # [bsz, s, n_h, h_d] - k = k.reshape(bsz, seq_len, -1, self.head_dim) - v = v.reshape(bsz, seq_len, -1, self.head_dim) - - # [bsz, n_h, s, h_d] - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - output = self._attention_fn( - q, - k, - v, - mask=mask, - dropout_p=self.attn_dropout, - is_causal=self._kv_cache is None and mask is None and self.is_causal, - ) - # Reshape the output to be the same shape as the input - return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) From 310b3a39fb9ece6a4e358d2e81d5fc1e1923d71f Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Fri, 1 Nov 2024 10:43:53 -0700 Subject: [PATCH 5/7] Move llama2 -> llama --- examples/models/llama/export_llama_lib.py | 1 + .../source_transformation/torchtune/attention.py | 0 .../source_transformation/torchtune/modules/mha.py | 0 3 files changed, 1 insertion(+) rename examples/models/{llama2 => llama}/source_transformation/torchtune/attention.py (100%) rename examples/models/{llama2 => llama}/source_transformation/torchtune/modules/mha.py (100%) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index e5458ef7c3d..4a94d3d148d 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -962,6 +962,7 @@ def _get_source_transforms( # noqa if args.expand_rope_table: transforms.append(materialze_broadcast_of_rope_freq_cis) + transforms.append(replace_mha_with_inference_mha) if args.use_sdpa_with_kv_cache: if is_torchtune_model: assert ( diff --git a/examples/models/llama2/source_transformation/torchtune/attention.py b/examples/models/llama/source_transformation/torchtune/attention.py similarity index 100% rename from examples/models/llama2/source_transformation/torchtune/attention.py rename to examples/models/llama/source_transformation/torchtune/attention.py diff --git a/examples/models/llama2/source_transformation/torchtune/modules/mha.py b/examples/models/llama/source_transformation/torchtune/modules/mha.py similarity index 100% rename from examples/models/llama2/source_transformation/torchtune/modules/mha.py rename to examples/models/llama/source_transformation/torchtune/modules/mha.py From 458785282ac5931fde80123305913b3a692661e1 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Fri, 1 Nov 2024 11:52:33 -0700 Subject: [PATCH 6/7] Lint and print --- examples/models/llama/export_llama_lib.py | 8 ++++++-- .../llama/source_transformation/torchtune/attention.py | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 4a94d3d148d..8472e66b9c1 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -70,10 +70,10 @@ replace_sdpa_with_simple_sdpa, ) -from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb - from .source_transformation.torchtune.attention import replace_mha_with_inference_mha +from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb + IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -1019,4 +1019,8 @@ def _get_source_transforms( # noqa if args.vulkan: transforms.append(replace_with_vulkan_rotary_emb) + print( + f"Performing the following source transformations: {[transform.__name__ for transform in transforms]}" + ) + return transforms diff --git a/examples/models/llama/source_transformation/torchtune/attention.py b/examples/models/llama/source_transformation/torchtune/attention.py index 0dd62eafd6d..f6062275f93 100644 --- a/examples/models/llama/source_transformation/torchtune/attention.py +++ b/examples/models/llama/source_transformation/torchtune/attention.py @@ -32,6 +32,7 @@ def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None: else: replace_mha_with_inference_mha(child) + def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module: """ Replace TorchTune's MHA with an inference friendly version of MHA that From 3145bde57f8c8ec84153b29c2f29deb570a7a65f Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Fri, 1 Nov 2024 12:05:32 -0700 Subject: [PATCH 7/7] Revert portion to move to next PR --- examples/models/llama/export_llama_lib.py | 46 +------------ .../quantized_kv_cache.py | 64 +------------------ .../llama/source_transformation/sdpa.py | 61 +----------------- 3 files changed, 5 insertions(+), 166 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 8472e66b9c1..0b1946f0cb6 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -54,7 +54,6 @@ ) from .source_transformation.quantized_kv_cache import ( replace_kv_cache_with_quantized_kv_cache, - replace_torchtune_kv_cache_with_quantized_kv_cache, ) from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm @@ -66,15 +65,10 @@ replace_sdpa_with_coreml_sdpa, replace_sdpa_with_custom_op, replace_sdpa_with_flex_sdpa, - replace_sdpa_with_sdpa_only_custom_op, replace_sdpa_with_simple_sdpa, ) - -from .source_transformation.torchtune.attention import replace_mha_with_inference_mha - from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb - IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -243,7 +237,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--use_sdpa_with_kv_cache", default=False, action="store_true", - help="Whether to use a custom sdpa + kv_cache update when kv cache is enabled.", + help="Whether to use sdpa_with_kv_cache update op when using kv cache", ) parser.add_argument( "--disable_dynamic_shape", @@ -595,18 +589,6 @@ def _validate_args(args): if args.num_sharding > 0 and not args.qnn: raise ValueError("Model shard is only supported with qnn backend now.") - if args.model in TORCHTUNE_DEFINED_MODELS: - if args.use_sdpa_with_kv_cache: - if not args.use_kv_cache and not args.quantize_kv_cache: - raise ValueError( - f"TorchTune-defined {args.model} only works with custom SDPA op + quantized KV cache at the moment. Please enable use_kv_cache and quantize_kv_cache when use_sdpa_with_kv_cache is enabled." - ) - if args.use_kv_cache: - if not args.quantize_kv_cache: - raise ValueError( - f"TorchTune-defined {args.model} only works with quantized KV cache at the moment. Please enable quantize_kv_cache when use_kv_cache is enabled." - ) - def _export_llama(args) -> LLMEdgeManager: # noqa: C901 _validate_args(args) @@ -910,7 +892,6 @@ def _load_llama_model( def _get_source_transforms( # noqa modelname: str, dtype_override: Optional[DType], args ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]: - is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS transforms = [] if args.use_spin_quant: @@ -962,29 +943,12 @@ def _get_source_transforms( # noqa if args.expand_rope_table: transforms.append(materialze_broadcast_of_rope_freq_cis) - transforms.append(replace_mha_with_inference_mha) if args.use_sdpa_with_kv_cache: - if is_torchtune_model: - assert ( - args.use_kv_cache and args.quantize_kv_cache - ), "use_sdpa_with_kv_cache requires use_kv_cache=True and quantize_kv_cache=True for TorchTune at the moment." - transforms.append(replace_mha_with_inference_mha) - transforms.append(replace_sdpa_with_sdpa_only_custom_op) - else: - transforms.append(replace_sdpa_with_custom_op) + transforms.append(replace_sdpa_with_custom_op) if args.quantize_kv_cache: assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" - if is_torchtune_model: - transforms.append( - lambda module: replace_torchtune_kv_cache_with_quantized_kv_cache( - module, - is_transposed=not args.use_sdpa_with_kv_cache, - enable_dynamic_shape=args.enable_dynamic_shape, - ) - ) - else: - transforms.append(replace_kv_cache_with_quantized_kv_cache) + transforms.append(replace_kv_cache_with_quantized_kv_cache) if args.use_kv_cache: if args.qnn: @@ -1019,8 +983,4 @@ def _get_source_transforms( # noqa if args.vulkan: transforms.append(replace_with_vulkan_rotary_emb) - print( - f"Performing the following source transformations: {[transform.__name__ for transform in transforms]}" - ) - return transforms diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py index e9399195c74..6d92a45e800 100644 --- a/examples/models/llama/source_transformation/quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/quantized_kv_cache.py @@ -11,7 +11,6 @@ import torch.nn as nn from executorch.examples.models.llama.llama_transformer import KVCache from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 -from torchtune.modules.kv_cache import KVCache as TorchTuneKVCache """ @@ -208,31 +207,8 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType): kv_cache.enable_dynamic_shape, ) - @classmethod - def from_torchtune_float( - cls, - kv_cache, - cache_type: QuantizedCacheType, - is_transposed: bool, - enable_dynamic_shape: bool, - ): - cache_shape = kv_cache.k_cache.shape - if kv_cache.is_tranposed: - max_batch_size, n_heads, max_seq_length, head_dim = cache_shape - else: - max_batch_size, max_seq_length, n_heads, head_dim = cache_shape - return cls( - max_batch_size, - max_seq_length, - n_heads, - head_dim, - cache_type, - is_transposed, - enable_dynamic_shape, - ) - -def replace_kv_cache_with_quantized_kv_cache(module: nn.Module) -> nn.Module: +def replace_kv_cache_with_quantized_kv_cache(module): logging.warning( "Replacing KVCache with QuantizedKVCache. This modifies the model in place." ) @@ -246,41 +222,3 @@ def replace_kv_cache_with_quantized_kv_cache(module: nn.Module) -> nn.Module: else: replace_kv_cache_with_quantized_kv_cache(child) return module - - -def replace_torchtune_kv_cache_with_quantized_kv_cache( - module: nn.Module, is_transposed: bool, enable_dynamic_shape: bool -) -> nn.Module: - """ - Replace TorchTune KVCache with Executorch's quantized KVCache. - - Args: - is_transposed: whether q, k, and v are transposed. Should set to false when sdpa custom op source transform is enabled. - enable_dynamic_shape: whether dynamic shapes are enabled. - - Returns: - The passed in model. - """ - logging.warning( - "Replacing KVCache with QuantizedKVCache. This modifies the model in place." - ) - for name, child in module.named_children(): - if isinstance(child, TorchTuneKVCache): - cache_shape = child.k_cache.shape - if is_transposed: - max_batch_size, n_heads, max_seq_length, head_dim = cache_shape - else: - max_batch_size, max_seq_length, n_heads, head_dim = cache_shape - setattr( - module, - name, - QuantizedKVCache.from_torchtune_float( - child, - QuantizedCacheType.AffineAsymmetric, - is_transposed, - enable_dynamic_shape, - ), - ) - else: - replace_kv_cache_with_quantized_kv_cache(child) - return module diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 2ac33d82616..f8362648f32 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -80,7 +80,7 @@ def forward( input_pos[0].item(), seqlen, None, # Attention mask - 0, # Dropout probability, ignored by the code + 0, # dropout probability. Ignored by the code True, # is_causal ) return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype) @@ -105,65 +105,6 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: return module -class SDPAOnlyCustom(torch.nn.Module): - """ - Just the custom SDPA op, no KV cache update included. Can only be used - in conjunction with a quantized KV cache. - """ - - def __init__( - self, - ): - super().__init__() - - def forward( - self, - input_pos: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - bsz: int, - seqlen: int, - mask: torch.Tensor = None, - ): - # Custom op only supports float32 currently. Converting to/from float32 is - # faster than not having the op. - input_dtype = q.dtype - q = q.to(dtype=torch.float) - k = k.to(dtype=torch.float) - v = v.to(dtype=torch.float) - output = torch.ops.llama.custom_sdpa( - q, - k, - v, - input_pos[0].item(), - None, # Attention mask - 0, # Dropout probability, ignored by the code. - True, # is_causal - ) - return output.view(bsz, seqlen, -1).to(dtype=input_dtype) - - -def _replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module): - for name, child in module.named_children(): - if isinstance(child, SDPA): - assert ( - child.kv_cache.cache_fp_type == torch.float32 - ), "Only float32 is supported for custom SDPA" - setattr( - module, - name, - SDPAOnlyCustom(), - ) - else: - _replace_sdpa_with_sdpa_only_custom_op(child) - - -def replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module) -> torch.nn.Module: - _replace_sdpa_with_sdpa_only_custom_op(module) - return module - - class SDPASimple(torch.nn.Module): def __init__(