From 1bd94dff46fe093a4c7b8a3bdd4c23ed87c23d89 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Wed, 22 Jan 2025 16:45:29 +0800 Subject: [PATCH 1/2] fix: correct Aria model output Signed-off-by: xffxff <1247714429@qq.com> --- examples/offline_inference/vision_language.py | 3 +- vllm/model_executor/models/aria.py | 51 ++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index f9048c7735e..415439e88ed 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -28,9 +28,10 @@ def run_aria(question: str, modality: str): llm = LLM(model=model_name, max_model_len=4096, max_num_seqs=2, + dtype="bfloat16", disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - prompt = (f"<|im_start|>user\n<|img|>\n{question}" + prompt = (f"<|im_start|>user\n<|img|>{question}" "<|im_end|>\n<|im_start|>assistant\n") stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 34ac39b812d..d5179df3de1 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -30,6 +30,7 @@ from vllm.sequence import IntermediateTensors # yapf: disable +from .idefics2_vision_model import Idefics2VisionConfig from .idefics2_vision_model import ( Idefics2VisionTransformer as Idefics3VisionTransformer) # yapf: enable @@ -50,6 +51,50 @@ class AriaImagePixelInputs(TypedDict): """ +class AriaVisionTransformer(Idefics3VisionTransformer): + + def __init__( + self, + config: Idefics2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config, prefix) + self.post_layernorm = nn.Identity() + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + + # NOTE: post_layernorm is not used in Aria + if "post_layernorm" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + class AriaProjectorMLP(nn.Module): def __init__( @@ -228,8 +273,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_output = torch.nn.functional.linear(hidden_states, self.router_weight) + hidden_states_copy = hidden_states.clone() + # NOTE: hidden_states will be modified inplace by `FusedMoE` sparse_expert_output = self.experts(hidden_states, router_output) - shared_expert_output = self.shared_experts(hidden_states) + shared_expert_output = self.shared_experts(hidden_states_copy) return sparse_expert_output + shared_expert_output @@ -445,7 +492,7 @@ def __init__( quant_config = vllm_config.quant_config self.config = config - self.vision_tower = Idefics3VisionTransformer( + self.vision_tower = AriaVisionTransformer( config.vision_config, quant_config, prefix=f"{prefix}.vision_tower", From 10024a6282183d9ae4847829af920746241536cd Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Wed, 22 Jan 2025 17:42:51 +0800 Subject: [PATCH 2/2] add some comments Signed-off-by: xffxff <1247714429@qq.com> --- vllm/model_executor/models/aria.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index d5179df3de1..8c6873de136 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -60,6 +60,9 @@ def __init__( prefix: str = "", ) -> None: super().__init__(config, quant_config, prefix) + # Unlike Idefics3VisionTransformer which uses LayerNorm after the + # final layer, Aria omits this normalization, so we replace it with an + # Identity layer self.post_layernorm = nn.Identity() def load_weights(self, weights: Iterable[Tuple[str,