From 24bf59b0affe1586e030dc724f516352cdccd534 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Wed, 16 Apr 2025 13:42:59 +0000 Subject: [PATCH 01/24] Added test to functionally verify match between HF and vLLM Signed-off-by: Thomas Ortner --- .../language/test_granitemoehybrid.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/models/decoder_only/language/test_granitemoehybrid.py diff --git a/tests/models/decoder_only/language/test_granitemoehybrid.py b/tests/models/decoder_only/language/test_granitemoehybrid.py new file mode 100644 index 000000000000..fff27c6ca47d --- /dev/null +++ b/tests/models/decoder_only/language/test_granitemoehybrid.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from ...utils import check_logprobs_close + +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_model_equivalence_to_hf_greedy( + hf_runner, + vllm_runner, + example_prompts, + dtype: str, + max_tokens: int, + num_logprobs: int, +): + # Path of the checkpoints + DIR = '/block/granite/granite-hybridmoe-7b-a1b-base-pipecleaner-hf' + + with hf_runner(DIR, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(DIR, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + +if __name__ == "__main__": + pytest.main(["tests/models/decoder_only/language/test_granitemoehybrid.py"]) From b7e89a0279d4fe791eadc92b1f982b406c18d132 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 16 Apr 2025 14:03:51 +0000 Subject: [PATCH 02/24] GraniteMoeHybrid model Signed-off-by: Stanislaw Wozniak --- .../model_executor/models/granitemoehybrid.py | 593 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 2 files changed, 594 insertions(+) create mode 100644 vllm/model_executor/models/granitemoehybrid.py diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py new file mode 100644 index 000000000000..fd6f3360d176 --- /dev/null +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -0,0 +1,593 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only GraniteMoeHybrid model.""" +# Added by the IBM Team, 2025 +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import GraniteMoeHybridConfig + +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, + QKVParallelLinear, ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.quantization import QuantizationConfig +from 
vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant, SupportsV0Only) +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +from .granitemoe import GraniteMoeMoE +from .granitemoeshared import GraniteMoeSharedMLP + + +class GraniteMoeHybridMambaDecoderLayer(nn.Module): + + def __init__(self, + config: GraniteMoeHybridConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + + self.self_attn = MambaMixer2(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + quant_config=quant_config) + + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + + self.shared_mlp = None if \ + getattr(config, 'shared_intermediate_size', 0) == 0 \ + else GraniteMoeSharedMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.shared_mlp" + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, mamba_cache_params, + mamba2_metadata) + hidden_states = residual + hidden_states * self.residual_multiplier + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if self.shared_mlp is None: + hidden_states = self.block_sparse_moe(hidden_states) + else: + # create a copy since block_sparse_moe modifies in-place + moe_hidden_states = hidden_states.clone() + moe_hidden_states = self.block_sparse_moe(moe_hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + del moe_hidden_states + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states, residual + + +class GraniteMoeHybridAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteMoeHybridConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + 
quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + + self.self_attn = GraniteMoeHybridMultiheadLatentAttention( + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + + self.shared_mlp = None if \ + getattr(config, 'shared_intermediate_size', 0) == 0 \ + else GraniteMoeSharedMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.shared_mlp" + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if self.shared_mlp is None: + hidden_states = self.block_sparse_moe(hidden_states) + else: + # create a copy since block_sparse_moe modifies in-place + moe_hidden_states = hidden_states.clone() + moe_hidden_states = self.block_sparse_moe(moe_hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + del moe_hidden_states + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states, residual + + +class GraniteMoeHybridMultiheadLatentAttention(nn.Module): + + def __init__( + self, + config: GraniteMoeHybridConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.causal = True + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.attention_bias = config.attention_bias + self.query_compression_size = config.mla_query_comp_size + self.key_value_compression_size = config.mla_key_value_comp_size + self.attention_multiplier = config.attention_multiplier + self.softmax_dropout_p = config.mla_softmax_dropout + + self.softmax_dropout = nn.Identity() if config.mla_softmax_dropout == 0 else nn.Dropout(config.mla_softmax_dropout) + self.dropout = nn.Identity() if config.mla_dropout == 0 else nn.Dropout(config.mla_dropout) + self.head_dim = self.hidden_size // self.num_heads + + self.c_attn_down_projection = ReplicatedLinear(self.hidden_size, + self.query_compression_size + 2 * self.key_value_compression_size, + bias=self.attention_bias, + quant_config=quant_config) + + self.query_up_projection = ReplicatedLinear(self.query_compression_size, + self.hidden_size, + bias=self.attention_bias, + quant_config=quant_config) + + self.key_up_projection = ReplicatedLinear(self.key_value_compression_size, + self.hidden_size, + bias=self.attention_bias, + quant_config=quant_config) + + self.value_up_projection = ReplicatedLinear(self.key_value_compression_size, + self.hidden_size, + 
bias=self.attention_bias, + quant_config=quant_config) + + self.c_proj = ReplicatedLinear(self.hidden_size, + self.hidden_size, + bias=self.attention_bias, + quant_config=quant_config) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.attention_multiplier, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + + hidden_states = self.c_attn_down_projection(hidden_states)[0] + query, key, value = hidden_states.split( + (self.query_compression_size, self.key_value_compression_size, self.key_value_compression_size), dim=-1 + ) + query = self.query_up_projection(query)[0] + key = self.key_up_projection(key)[0] + value = self.value_up_projection(value)[0] + + hidden_states = self.attn(query, key, value) + del query, key, value + + hidden_states = self.c_proj(hidden_states)[0] + hidden_states = self.dropout(hidden_states) + return hidden_states + + + +ALL_DECODER_LAYER_TYPES = { + "multihead_latent_attention": GraniteMoeHybridAttentionDecoderLayer, + "mamba2": GraniteMoeHybridMambaDecoderLayer, +} + +class GraniteMoeHybridModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layer_types[layer_idx]] + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + attn_metadata = get_forward_context().attn_metadata + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + input_ids=input_ids, + attn_metadata=attn_metadata, + ) + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states = hidden_states * self.embedding_multiplier + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + residual = None + num_attn = 0 + for i in 
range(len(self.layers)): + layer = self.layers[i] + if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): + num_attn += 1 + + layer_mamba_cache_params = None + if isinstance(layer, GraniteMoeHybridMambaDecoderLayer): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_attn) + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=layer_mamba_cache_params, + mamba2_metadata=mamba2_metadata + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, + weights: Iterable[Tuple[str, torch.Tensor]] + ) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + def _load(n, p): + param = params_dict[n] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, p) + loaded_params.add(n) + + def _load_expert(n, p, name, shard_id, expert_id): + param = params_dict[n] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id) + loaded_params.add(n) + + for n, p in weights: + if "A_log" in n: + n = n.replace("A_log", "A") + + # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215 + # Mapping different experts' layout: from HF (input_linear, output_linear, router) to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w1.weight") + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w3.weight") + w1_param, w3_param = p[e].chunk(2, dim=0) + _load_expert(n.replace('.input_linear.','.experts.w13_'), w1_param, w1_name, shard_id='w1', expert_id=e) + _load_expert(n.replace('.input_linear.','.experts.w13_'), w3_param, w3_name, shard_id='w3', expert_id=e) + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + f".block_sparse_moe.experts.{e}.w2.weight") + w2_param = p[e] + _load_expert(n.replace('.output_linear.', '.experts.w2_'), w2_param, w2_name, shard_id='w2', expert_id=e) + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + _load(gate_name, p) + else: + _load(n,p) + + return loaded_params + + +class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsV0Only, SupportsQuant): + #LoRA + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"] + } + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + # layer_types in hf_config are "multihead_latent_attention" or "mamba2" + # vLLM cache initialization expects that property layers_block_type returns exactly "attention" or "mamba", so we remap the strings here: + def _layers_block_type(self): + result = [] + for l 
in self.layer_types: + if 'attention' in l: + result.append('attention') + if 'mamba' in l: + result.append('mamba') + return result + #inject custom property getter code: + vllm_config.model_config.hf_config.__class__.layers_block_type = property(lambda self: _layers_block_type(self)) + + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "GraniteMoeHybrid currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = GraniteMoeHybridModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head")) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=1 / + self.config.logits_scaling) + + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.model_config.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = self.config.mamba_expand * hidden_size + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size)) + + 
# - heads and n_groups are TP-ed + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.mamba_n_heads, world_size), + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 156a201de35a..1bfb45e13c0f 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -65,6 +65,7 @@ "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), + "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501 "GritLM": ("gritlm", "GritLM"), "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), From e0136b31f8d4a8cf68e3d4bb4c44584938993c20 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Mon, 21 Apr 2025 18:54:23 +0000 Subject: [PATCH 03/24] Removed MLA and added RoPE Signed-off-by: Thomas Ortner --- .../language/test_granitemoehybrid.py | 16 ++-- .../model_executor/models/granitemoehybrid.py | 74 +++++++++---------- 2 files changed, 48 insertions(+), 42 deletions(-) diff --git a/tests/models/decoder_only/language/test_granitemoehybrid.py b/tests/models/decoder_only/language/test_granitemoehybrid.py index fff27c6ca47d..0d8d1894806c 100644 --- a/tests/models/decoder_only/language/test_granitemoehybrid.py +++ b/tests/models/decoder_only/language/test_granitemoehybrid.py @@ -3,6 +3,14 @@ import pytest from ...utils import check_logprobs_close +# Path of the checkpoints +MODELS = [ + "/block/granite/granite-4.0-tiny-base-pipecleaner-hf", + # "/code/granite/granite-4_0-small-base-pipecleaner-hf", + # "/code/granitegranite-4_0-medium-base-pipecleaner-hf", +] + +@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @@ -10,18 +18,16 @@ def test_model_equivalence_to_hf_greedy( hf_runner, vllm_runner, example_prompts, + model: str, dtype: str, max_tokens: int, num_logprobs: int, ): - # Path of the checkpoints - DIR = '/block/granite/granite-hybridmoe-7b-a1b-base-pipecleaner-hf' - - with hf_runner(DIR, dtype=dtype) as hf_model: + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - with vllm_runner(DIR, dtype=dtype) as vllm_model: + with 
vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index fd6f3360d176..514f2df35f06 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -13,15 +13,14 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - QKVParallelLinear, ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import ( MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -34,7 +33,7 @@ from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant, SupportsV0Only) -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -55,7 +54,7 @@ def __init__(self, self.hidden_size = config.hidden_size self.residual_multiplier = config.residual_multiplier - self.self_attn = MambaMixer2(hidden_size= config.hidden_size, + self.mamba = MambaMixer2(hidden_size= config.hidden_size, ssm_state_size = config.mamba_d_state, conv_kernel_size = config.mamba_d_conv, intermediate_size = config.mamba_expand *\ @@ -100,7 +99,7 @@ def forward( ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(hidden_states, mamba_cache_params, + hidden_states = self.mamba(hidden_states, mamba_cache_params, mamba2_metadata) hidden_states = residual + hidden_states * self.residual_multiplier @@ -205,44 +204,48 @@ def __init__( self.causal = True self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.attention_bias = config.attention_bias - self.query_compression_size = config.mla_query_comp_size - self.key_value_compression_size = config.mla_key_value_comp_size + self.attention_multiplier = config.attention_multiplier - self.softmax_dropout_p = config.mla_softmax_dropout - self.softmax_dropout = nn.Identity() if config.mla_softmax_dropout == 0 else nn.Dropout(config.mla_softmax_dropout) - self.dropout = nn.Identity() if config.mla_dropout == 0 else nn.Dropout(config.mla_dropout) - self.head_dim = self.hidden_size // self.num_heads - - self.c_attn_down_projection = ReplicatedLinear(self.hidden_size, - self.query_compression_size + 2 * self.key_value_compression_size, - bias=self.attention_bias, - quant_config=quant_config) - - self.query_up_projection = 
ReplicatedLinear(self.query_compression_size, - self.hidden_size, + self.q_proj = ReplicatedLinear( self.hidden_size, + self.num_heads * self.head_dim, bias=self.attention_bias, quant_config=quant_config) - self.key_up_projection = ReplicatedLinear(self.key_value_compression_size, - self.hidden_size, + self.k_proj = ReplicatedLinear( self.hidden_size, + self.num_key_value_heads * self.head_dim, bias=self.attention_bias, quant_config=quant_config) - self.value_up_projection = ReplicatedLinear(self.key_value_compression_size, - self.hidden_size, + self.v_proj = ReplicatedLinear( self.hidden_size, + self.num_key_value_heads * self.head_dim, bias=self.attention_bias, quant_config=quant_config) - self.c_proj = ReplicatedLinear(self.hidden_size, + self.o_proj = ReplicatedLinear(self.hidden_size, self.hidden_size, bias=self.attention_bias, quant_config=quant_config) + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type == "rope": + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=int(config.rope_theta), + rope_scaling=config.rope_scaling if hasattr(config, "rope_scaling") and config.rope_scaling is not None else None, + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, self.head_dim, self.attention_multiplier, + num_kv_heads=self.num_key_value_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn") @@ -251,29 +254,26 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: - hidden_states = self.c_attn_down_projection(hidden_states)[0] - query, key, value = hidden_states.split( - (self.query_compression_size, self.key_value_compression_size, self.key_value_compression_size), dim=-1 - ) - query = self.query_up_projection(query)[0] - key = self.key_up_projection(key)[0] - value = self.value_up_projection(value)[0] + query = self.q_proj(hidden_states)[0] + key = self.k_proj(hidden_states)[0] + value = self.v_proj(hidden_states)[0] + + if self.position_embedding_type == "rope": + query, key = self.rotary_emb(positions, query, key) hidden_states = self.attn(query, key, value) del query, key, value - hidden_states = self.c_proj(hidden_states)[0] - hidden_states = self.dropout(hidden_states) + hidden_states = self.o_proj(hidden_states)[0] return hidden_states ALL_DECODER_LAYER_TYPES = { - "multihead_latent_attention": GraniteMoeHybridAttentionDecoderLayer, - "mamba2": GraniteMoeHybridMambaDecoderLayer, + "attention": GraniteMoeHybridAttentionDecoderLayer, + "mamba": GraniteMoeHybridMambaDecoderLayer, } class GraniteMoeHybridModel(nn.Module): From 4cfef42ab8b11770b460a91289544db831b9acbb Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Mon, 21 Apr 2025 18:56:27 +0000 Subject: [PATCH 04/24] Updated basic examples Signed-off-by: Thomas Ortner --- examples/offline_inference/basic/basic.py | 8 +++-- examples/offline_inference/basic/basic_HF.py | 36 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 examples/offline_inference/basic/basic_HF.py diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index ae5ae7cb4834..2aab89dc9271 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import torch from vllm import LLM, SamplingParams # Sample prompts. 
@@ -11,15 +12,18 @@ ] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +# sampling_params = SamplingParams(temperature=0.) +DIR = '/block/granite/granite-4.0-tiny-base-pipecleaner-hf' +# DIR = '/block/granite/granite-hybridmoe-7b-a1b-base-pipecleaner-hf' def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m") + llm = LLM(model=DIR, dtype=torch.float16, gpu_memory_utilization=0.5)#, enforce_eager=True) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) + outputs = llm.generate(prompts[1], sampling_params) # Print the outputs. print("\nGenerated Outputs:\n" + "-" * 60) for output in outputs: diff --git a/examples/offline_inference/basic/basic_HF.py b/examples/offline_inference/basic/basic_HF.py new file mode 100644 index 000000000000..ff53a6d5338d --- /dev/null +++ b/examples/offline_inference/basic/basic_HF.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +DIR = '/block/granite/granite-4.0-tiny-base-pipecleaner-hf' +# DIR = '/code/granite/granite-4_0-small-base-pipecleaner-hf' +# DIR = '/code/granite/granite-4_0-medium-base-pipecleaner-hf' + +def main(): + tokenizer = AutoTokenizer.from_pretrained(DIR) + inputs = tokenizer(prompts[1], return_tensors="pt").to("cuda") + + model = AutoModelForCausalLM.from_pretrained(DIR, torch_dtype=torch.float16).to("cuda") + + outputs_ids = model.generate(**inputs, max_new_tokens=20) + + # Print the outputs. 
+ outputs_str = tokenizer.batch_decode(outputs_ids, skip_special_tokens=True) + print("\nGenerated Outputs:\n" + "-" * 60) + prompt = prompts[1] + print(f"Prompt: {prompt!r}") + print(f"Output: {outputs_str!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() From 859e4736c78b7dcfb982a55c7cb177cef74f32f0 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 24 Apr 2025 12:42:13 +0000 Subject: [PATCH 05/24] TensorParallel and cleanup Signed-off-by: Stanislaw Wozniak --- .../language/test_granitemoehybrid.py | 14 ++--- .../model_executor/models/granitemoehybrid.py | 58 ++++++------------- 2 files changed, 26 insertions(+), 46 deletions(-) diff --git a/tests/models/decoder_only/language/test_granitemoehybrid.py b/tests/models/decoder_only/language/test_granitemoehybrid.py index 0d8d1894806c..110b278b1fbc 100644 --- a/tests/models/decoder_only/language/test_granitemoehybrid.py +++ b/tests/models/decoder_only/language/test_granitemoehybrid.py @@ -5,9 +5,9 @@ # Path of the checkpoints MODELS = [ - "/block/granite/granite-4.0-tiny-base-pipecleaner-hf", - # "/code/granite/granite-4_0-small-base-pipecleaner-hf", - # "/code/granitegranite-4_0-medium-base-pipecleaner-hf", + "/code/granite/granite-4_0-tiny-base-pipecleaner-hf", + #"/code/granite/granite-4_0-small-base-pipecleaner-hf", + # "/code/granite/granite-4_0-medium-base-pipecleaner-hf", ] @pytest.mark.parametrize("model", MODELS) @@ -23,13 +23,13 @@ def test_model_equivalence_to_hf_greedy( max_tokens: int, num_logprobs: int, ): - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 514f2df35f06..2d7050baa26e 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import ( MambaMixer2, extra_groups_for_head_shards) @@ -132,7 +132,7 @@ def __init__( self.hidden_size = config.hidden_size self.residual_multiplier = config.residual_multiplier - self.self_attn = GraniteMoeHybridMultiheadLatentAttention( + self.self_attn = GraniteMoeHybridAttention( config, cache_config=cache_config, quant_config=quant_config, @@ -191,7 +191,7 @@ def forward( return hidden_states, residual -class GraniteMoeHybridMultiheadLatentAttention(nn.Module): +class GraniteMoeHybridAttention(nn.Module): def __init__( self, @@ -203,33 +203,35 @@ def __init__( super().__init__() self.causal = True self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - 
self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.attention_bias = config.attention_bias - self.attention_multiplier = config.attention_multiplier + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads - self.q_proj = ReplicatedLinear( self.hidden_size, + self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.head_dim, bias=self.attention_bias, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.q_proj") - self.k_proj = ReplicatedLinear( self.hidden_size, + self.k_proj = ColumnParallelLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.k_proj") - self.v_proj = ReplicatedLinear( self.hidden_size, + self.v_proj = ColumnParallelLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.v_proj") - self.o_proj = ReplicatedLinear(self.hidden_size, + self.o_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=self.attention_bias, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.o_proj") self.position_embedding_type = config.position_embedding_type if self.position_embedding_type == "rope": @@ -434,15 +436,7 @@ def _load_expert(n, p, name, shard_id, expert_id): class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsV0Only, SupportsQuant): - #LoRA - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": ["up_proj", "down_proj"] - } + packed_modules_mapping = {} embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", @@ -450,20 +444,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, Suppor embedding_padding_modules = ["lm_head"] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - - # layer_types in hf_config are "multihead_latent_attention" or "mamba2" - # vLLM cache initialization expects that property layers_block_type returns exactly "attention" or "mamba", so we remap the strings here: - def _layers_block_type(self): - result = [] - for l in self.layer_types: - if 'attention' in l: - result.append('attention') - if 'mamba' in l: - result.append('mamba') - return result - #inject custom property getter code: - vllm_config.model_config.hf_config.__class__.layers_block_type = property(lambda self: _layers_block_type(self)) - config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config From bcb5e77a31b9ec2bc3e326b7edd8c43de257c16e Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Thu, 24 Apr 2025 12:58:00 +0000 Subject: [PATCH 06/24] Fixing previous commit Signed-off-by: Stanislaw Wozniak --- vllm/model_executor/models/granitemoehybrid.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 2d7050baa26e..0ccca7cd8860 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import 
RMSNorm -from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.linear import ReplicatedLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import ( MambaMixer2, extra_groups_for_head_shards) @@ -209,25 +209,25 @@ def __init__( self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads - self.q_proj = ColumnParallelLinear(self.hidden_size, + self.q_proj = ReplicatedLinear( self.hidden_size, self.num_heads * self.head_dim, bias=self.attention_bias, quant_config=quant_config, prefix=f"{prefix}.q_proj") - self.k_proj = ColumnParallelLinear(self.hidden_size, + self.k_proj = ReplicatedLinear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias, quant_config=quant_config, prefix=f"{prefix}.k_proj") - self.v_proj = ColumnParallelLinear(self.hidden_size, + self.v_proj = ReplicatedLinear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias, quant_config=quant_config, prefix=f"{prefix}.v_proj") - self.o_proj = RowParallelLinear( self.hidden_size, + self.o_proj = ReplicatedLinear( self.hidden_size, self.hidden_size, bias=self.attention_bias, quant_config=quant_config, From 18bb63da94ed81aaf29fb9851ce5f89c1321b8c1 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 30 Apr 2025 08:18:00 +0000 Subject: [PATCH 07/24] Cleanup Co-authored-by: Thomas Ortner Signed-off-by: Stanislaw Wozniak --- examples/offline_inference/basic/basic_HF.py | 36 ------------- .../language/test_granitemoehybrid.py | 9 ++-- tests/models/registry.py | 1 + .../model_executor/models/granitemoehybrid.py | 50 ++++++++++--------- 4 files changed, 31 insertions(+), 65 deletions(-) delete mode 100644 examples/offline_inference/basic/basic_HF.py diff --git a/examples/offline_inference/basic/basic_HF.py b/examples/offline_inference/basic/basic_HF.py deleted file mode 100644 index ff53a6d5338d..000000000000 --- a/examples/offline_inference/basic/basic_HF.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -DIR = '/block/granite/granite-4.0-tiny-base-pipecleaner-hf' -# DIR = '/code/granite/granite-4_0-small-base-pipecleaner-hf' -# DIR = '/code/granite/granite-4_0-medium-base-pipecleaner-hf' - -def main(): - tokenizer = AutoTokenizer.from_pretrained(DIR) - inputs = tokenizer(prompts[1], return_tensors="pt").to("cuda") - - model = AutoModelForCausalLM.from_pretrained(DIR, torch_dtype=torch.float16).to("cuda") - - outputs_ids = model.generate(**inputs, max_new_tokens=20) - - # Print the outputs. 
- outputs_str = tokenizer.batch_decode(outputs_ids, skip_special_tokens=True) - print("\nGenerated Outputs:\n" + "-" * 60) - prompt = prompts[1] - print(f"Prompt: {prompt!r}") - print(f"Output: {outputs_str!r}") - print("-" * 60) - - -if __name__ == "__main__": - main() diff --git a/tests/models/decoder_only/language/test_granitemoehybrid.py b/tests/models/decoder_only/language/test_granitemoehybrid.py index 110b278b1fbc..40508edebf74 100644 --- a/tests/models/decoder_only/language/test_granitemoehybrid.py +++ b/tests/models/decoder_only/language/test_granitemoehybrid.py @@ -2,7 +2,7 @@ import pytest from ...utils import check_logprobs_close - + # Path of the checkpoints MODELS = [ "/code/granite/granite-4_0-tiny-base-pipecleaner-hf", @@ -26,17 +26,14 @@ def test_model_equivalence_to_hf_greedy( with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - + check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, name_0="hf", name_1="vllm", ) - -if __name__ == "__main__": - pytest.main(["tests/models/decoder_only/language/test_granitemoehybrid.py"]) diff --git a/tests/models/registry.py b/tests/models/registry.py index a19c43b698f1..d1074693db43 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -163,6 +163,7 @@ def check_available_online( {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), + "GraniteMoeHybridForCausalLM": _HfExamplesInfo("/code/granite/granite-4_0-tiny-base-pipecleaner-hf"), "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 0ccca7cd8860..92efd75657a1 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ReplicatedLinear, RowParallelLinear +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import ( MambaMixer2, extra_groups_for_head_shards) @@ -52,7 +52,7 @@ def __init__(self, super().__init__() self.config = config self.hidden_size = config.hidden_size - self.residual_multiplier = config.residual_multiplier + self.residual_multiplier = config.residual_multiplier self.mamba = MambaMixer2(hidden_size= config.hidden_size, ssm_state_size = config.mamba_d_state, @@ -87,7 +87,7 @@ def __init__(self, self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + eps=config.rms_norm_eps) def forward( self, @@ -102,7 +102,7 @@ def forward( hidden_states = self.mamba(hidden_states, mamba_cache_params, mamba2_metadata) hidden_states = residual + hidden_states * self.residual_multiplier - + residual = hidden_states hidden_states = 
self.post_attention_layernorm(hidden_states) if self.shared_mlp is None: @@ -207,32 +207,32 @@ def __init__( self.attention_multiplier = config.attention_multiplier self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - + self.num_key_value_heads = config.num_key_value_heads + self.q_proj = ReplicatedLinear( self.hidden_size, self.num_heads * self.head_dim, bias=self.attention_bias, quant_config=quant_config, prefix=f"{prefix}.q_proj") - + self.k_proj = ReplicatedLinear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias, quant_config=quant_config, prefix=f"{prefix}.k_proj") - + self.v_proj = ReplicatedLinear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias, quant_config=quant_config, prefix=f"{prefix}.v_proj") - + self.o_proj = ReplicatedLinear( self.hidden_size, self.hidden_size, bias=self.attention_bias, quant_config=quant_config, prefix=f"{prefix}.o_proj") - + self.position_embedding_type = config.position_embedding_type if self.position_embedding_type == "rope": self.rotary_emb = get_rope( @@ -240,10 +240,11 @@ def __init__( rotary_dim=self.head_dim, max_position=config.max_position_embeddings, base=int(config.rope_theta), - rope_scaling=config.rope_scaling if hasattr(config, "rope_scaling") and config.rope_scaling is not None else None, + rope_scaling=config.rope_scaling if hasattr(config, "rope_scaling") \ + and config.rope_scaling is not None else None, is_neox_style=True, ) - + self.attn = Attention(self.num_heads, self.head_dim, self.attention_multiplier, @@ -261,10 +262,10 @@ def forward( query = self.q_proj(hidden_states)[0] key = self.k_proj(hidden_states)[0] value = self.v_proj(hidden_states)[0] - + if self.position_embedding_type == "rope": query, key = self.rotary_emb(positions, query, key) - + hidden_states = self.attn(query, key, value) del query, key, value @@ -319,8 +320,7 @@ def get_layer(prefix: str): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - self.norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -382,7 +382,7 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]] ) -> Set[str]: params_dict = dict(self.named_parameters()) @@ -405,7 +405,8 @@ def _load_expert(n, p, name, shard_id, expert_id): n = n.replace("A_log", "A") # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215 - # Mapping different experts' layout: from HF (input_linear, output_linear, router) to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) + # Mapping different experts' layout: from HF (input_linear, output_linear, router) + # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) if n.endswith('.block_sparse_moe.input_linear.weight'): for e in range(p.size(0)): w1_name = n.replace( @@ -415,15 +416,18 @@ def _load_expert(n, p, name, shard_id, expert_id): '.block_sparse_moe.input_linear.weight', f".block_sparse_moe.experts.{e}.w3.weight") w1_param, w3_param = p[e].chunk(2, dim=0) - _load_expert(n.replace('.input_linear.','.experts.w13_'), w1_param, w1_name, 
shard_id='w1', expert_id=e) - _load_expert(n.replace('.input_linear.','.experts.w13_'), w3_param, w3_name, shard_id='w3', expert_id=e) + _load_expert(n.replace('.input_linear.','.experts.w13_'), + w1_param, w1_name, shard_id='w1', expert_id=e) + _load_expert(n.replace('.input_linear.','.experts.w13_'), + w3_param, w3_name, shard_id='w3', expert_id=e) elif n.endswith('.block_sparse_moe.output_linear.weight'): for e in range(p.size(0)): w2_name = n.replace( '.block_sparse_moe.output_linear.weight', f".block_sparse_moe.experts.{e}.w2.weight") - w2_param = p[e] - _load_expert(n.replace('.output_linear.', '.experts.w2_'), w2_param, w2_name, shard_id='w2', expert_id=e) + w2_param = p[e] + _load_expert(n.replace('.output_linear.', '.experts.w2_'), + w2_param, w2_name, shard_id='w2', expert_id=e) elif n.endswith('.block_sparse_moe.router.layer.weight'): gate_name = n.replace('.block_sparse_moe.router.layer.weight', ".block_sparse_moe.gate.weight") @@ -570,4 +574,4 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) From 7cb6a81c9e1d5c1741a3cf5d9b9ec41bdccf3dad Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 30 Apr 2025 08:26:59 +0000 Subject: [PATCH 08/24] Cleanup Co-authored-by: Thomas Ortner Signed-off-by: Stanislaw Wozniak --- examples/offline_inference/basic/basic.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 2aab89dc9271..ae5ae7cb4834 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import torch from vllm import LLM, SamplingParams # Sample prompts. @@ -12,18 +11,15 @@ ] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -# sampling_params = SamplingParams(temperature=0.) -DIR = '/block/granite/granite-4.0-tiny-base-pipecleaner-hf' -# DIR = '/block/granite/granite-hybridmoe-7b-a1b-base-pipecleaner-hf' def main(): # Create an LLM. - llm = LLM(model=DIR, dtype=torch.float16, gpu_memory_utilization=0.5)#, enforce_eager=True) + llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts[1], sampling_params) + outputs = llm.generate(prompts, sampling_params) # Print the outputs. print("\nGenerated Outputs:\n" + "-" * 60) for output in outputs: From b69ca16b6df7e215618106f9be328b08fd2bba8f Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 30 Apr 2025 19:55:47 +0000 Subject: [PATCH 09/24] Removing sampler. 
Fixing URLs Signed-off-by: Stanislaw Wozniak --- .../decoder_only/language/test_granitemoehybrid.py | 4 +--- tests/models/registry.py | 2 +- vllm/model_executor/models/granitemoehybrid.py | 11 ----------- 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/tests/models/decoder_only/language/test_granitemoehybrid.py b/tests/models/decoder_only/language/test_granitemoehybrid.py index 40508edebf74..fc81e8362b85 100644 --- a/tests/models/decoder_only/language/test_granitemoehybrid.py +++ b/tests/models/decoder_only/language/test_granitemoehybrid.py @@ -5,9 +5,7 @@ # Path of the checkpoints MODELS = [ - "/code/granite/granite-4_0-tiny-base-pipecleaner-hf", - #"/code/granite/granite-4_0-small-base-pipecleaner-hf", - # "/code/granite/granite-4_0-medium-base-pipecleaner-hf", + "ibm-research/granite-4.0-tiny-test", ] @pytest.mark.parametrize("model", MODELS) diff --git a/tests/models/registry.py b/tests/models/registry.py index d1074693db43..d5dc44905a83 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -163,7 +163,7 @@ def check_available_online( {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("/code/granite/granite-4_0-tiny-base-pipecleaner-hf"), + "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-research/granite-4.0-tiny-test"), "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 92efd75657a1..b0322b89ecd4 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -21,7 +21,6 @@ Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -488,8 +487,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Used to track and store by the Mamba cache between steps. self.mamba_cache: Optional[MambaCacheManager] = None - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -563,14 +560,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) From 6f8b1f4fd457ab7af43d8d381de7e0d042d75e98 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 30 Apr 2025 21:08:20 +0000 Subject: [PATCH 10/24] Fixing pre-hook errors. 
Signed-off-by: Stanislaw Wozniak --- tests/models/registry.py | 2 +- vllm/model_executor/models/granitemoehybrid.py | 11 ++++++----- vllm/model_executor/models/registry.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index d5dc44905a83..6e01c766022c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -163,7 +163,7 @@ def check_available_online( {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-research/granite-4.0-tiny-test"), + "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-research/granite-4.0-tiny-test"), # noqa: E501 "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index b0322b89ecd4..a32e5261989f 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -239,7 +239,8 @@ def __init__( rotary_dim=self.head_dim, max_position=config.max_position_embeddings, base=int(config.rope_theta), - rope_scaling=config.rope_scaling if hasattr(config, "rope_scaling") \ + rope_scaling=config.rope_scaling \ + if hasattr(config, "rope_scaling") \ and config.rope_scaling is not None else None, is_neox_style=True, ) @@ -404,7 +405,7 @@ def _load_expert(n, p, name, shard_id, expert_id): n = n.replace("A_log", "A") # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215 - # Mapping different experts' layout: from HF (input_linear, output_linear, router) + # Mapping different experts' layout: from HF (input_linear, output_linear, router) # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) if n.endswith('.block_sparse_moe.input_linear.weight'): for e in range(p.size(0)): @@ -425,7 +426,7 @@ def _load_expert(n, p, name, shard_id, expert_id): '.block_sparse_moe.output_linear.weight', f".block_sparse_moe.experts.{e}.w2.weight") w2_param = p[e] - _load_expert(n.replace('.output_linear.', '.experts.w2_'), + _load_expert(n.replace('.output_linear.', '.experts.w2_'), w2_param, w2_name, shard_id='w2', expert_id=e) elif n.endswith('.block_sparse_moe.router.layer.weight'): gate_name = n.replace('.block_sparse_moe.router.layer.weight', @@ -437,8 +438,8 @@ def _load_expert(n, p, name, shard_id, expert_id): return loaded_params -class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only, SupportsQuant): +class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, + SupportsPP, IsHybrid, SupportsV0Only, SupportsQuant): packed_modules_mapping = {} embedding_modules = { "embed_tokens": "input_embeddings", diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 1bfb45e13c0f..5c27cdc603e8 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -65,7 +65,7 @@ "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), - "GraniteMoeHybridForCausalLM": ("granitemoehybrid", 
"GraniteMoeHybridForCausalLM"), + "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), # noqa: E501 "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501 "GritLM": ("gritlm", "GritLM"), "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), From c3b64604b3e341646b54f1bb4a4a923b44f5b488 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 30 Apr 2025 21:42:35 +0000 Subject: [PATCH 11/24] ruff reformatting pre-commit fix Signed-off-by: Stanislaw Wozniak --- .../language/test_granitemoehybrid.py | 2 + .../model_executor/models/granitemoehybrid.py | 126 ++++++++++-------- 2 files changed, 73 insertions(+), 55 deletions(-) diff --git a/tests/models/decoder_only/language/test_granitemoehybrid.py b/tests/models/decoder_only/language/test_granitemoehybrid.py index fc81e8362b85..133931c8d439 100644 --- a/tests/models/decoder_only/language/test_granitemoehybrid.py +++ b/tests/models/decoder_only/language/test_granitemoehybrid.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest + from ...utils import check_logprobs_close # Path of the checkpoints @@ -8,6 +9,7 @@ "ibm-research/granite-4.0-tiny-test", ] + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index a32e5261989f..ad16df835b33 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -15,10 +15,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -30,14 +30,12 @@ from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant, SupportsV0Only) -from .utils import (AutoWeightsLoader, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - from .granitemoe import GraniteMoeMoE from .granitemoeshared import GraniteMoeSharedMLP +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant, SupportsV0Only) +from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory, + make_layers, maybe_prefix) class GraniteMoeHybridMambaDecoderLayer(nn.Module): @@ -208,29 +206,31 @@ def __init__( self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads - self.q_proj = ReplicatedLinear( self.hidden_size, - self.num_heads * self.head_dim, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - - self.k_proj = ReplicatedLinear( self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.k_proj") - - self.v_proj = 
ReplicatedLinear( self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.v_proj") - - self.o_proj = ReplicatedLinear( self.hidden_size, - self.hidden_size, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.q_proj = ReplicatedLinear(self.hidden_size, + self.num_heads * self.head_dim, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.k_proj = ReplicatedLinear(self.hidden_size, + self.num_key_value_heads * + self.head_dim, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.k_proj") + + self.v_proj = ReplicatedLinear(self.hidden_size, + self.num_key_value_heads * + self.head_dim, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.v_proj") + + self.o_proj = ReplicatedLinear(self.hidden_size, + self.hidden_size, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") self.position_embedding_type = config.position_embedding_type if self.position_embedding_type == "rope": @@ -273,12 +273,12 @@ def forward( return hidden_states - ALL_DECODER_LAYER_TYPES = { "attention": GraniteMoeHybridAttentionDecoderLayer, "mamba": GraniteMoeHybridMambaDecoderLayer, } + class GraniteMoeHybridModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -370,8 +370,7 @@ def forward( hidden_states=hidden_states, residual=residual, mamba_cache_params=layer_mamba_cache_params, - mamba2_metadata=mamba2_metadata - ) + mamba2_metadata=mamba2_metadata) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -382,22 +381,27 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, - weights: Iterable[Tuple[str, torch.Tensor]] - ) -> Set[str]: + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() - + def _load(n, p): param = params_dict[n] - weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, p) loaded_params.add(n) - + def _load_expert(n, p, name, shard_id, expert_id): param = params_dict[n] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + p, + name, + shard_id=shard_id, + expert_id=expert_id) loaded_params.add(n) for n, p in weights: @@ -405,8 +409,9 @@ def _load_expert(n, p, name, shard_id, expert_id): n = n.replace("A_log", "A") # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215 - # Mapping different experts' layout: from HF (input_linear, output_linear, router) - # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) + # Mapping different experts' layout: + # from HF (input_linear, output_linear, router) + # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) if n.endswith('.block_sparse_moe.input_linear.weight'): for e in range(p.size(0)): w1_name = n.replace( @@ -416,10 +421,16 @@ def _load_expert(n, p, name, shard_id, expert_id): '.block_sparse_moe.input_linear.weight', f".block_sparse_moe.experts.{e}.w3.weight") w1_param, w3_param = p[e].chunk(2, 
dim=0) - _load_expert(n.replace('.input_linear.','.experts.w13_'), - w1_param, w1_name, shard_id='w1', expert_id=e) - _load_expert(n.replace('.input_linear.','.experts.w13_'), - w3_param, w3_name, shard_id='w3', expert_id=e) + _load_expert(n.replace('.input_linear.', '.experts.w13_'), + w1_param, + w1_name, + shard_id='w1', + expert_id=e) + _load_expert(n.replace('.input_linear.', '.experts.w13_'), + w3_param, + w3_name, + shard_id='w3', + expert_id=e) elif n.endswith('.block_sparse_moe.output_linear.weight'): for e in range(p.size(0)): w2_name = n.replace( @@ -427,19 +438,23 @@ def _load_expert(n, p, name, shard_id, expert_id): f".block_sparse_moe.experts.{e}.w2.weight") w2_param = p[e] _load_expert(n.replace('.output_linear.', '.experts.w2_'), - w2_param, w2_name, shard_id='w2', expert_id=e) + w2_param, + w2_name, + shard_id='w2', + expert_id=e) elif n.endswith('.block_sparse_moe.router.layer.weight'): gate_name = n.replace('.block_sparse_moe.router.layer.weight', ".block_sparse_moe.gate.weight") _load(gate_name, p) else: - _load(n,p) + _load(n, p) return loaded_params -class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, - SupportsPP, IsHybrid, SupportsV0Only, SupportsQuant): +class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, + SupportsPP, IsHybrid, SupportsV0Only, + SupportsQuant): packed_modules_mapping = {} embedding_modules = { "embed_tokens": "input_embeddings", @@ -463,7 +478,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.scheduler_config = scheduler_config self.model = GraniteMoeHybridModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + prefix=maybe_prefix( + prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -502,11 +518,11 @@ def forward(self, **kwargs): if self.mamba_cache is None: num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) + self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( self.vllm_config, self.model_config.dtype, num_mamba_layers, *self._get_mamba_cache_shape()) - + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) From 171532cf496bb095fcb1b331a17820fff37ec1ee Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 30 Apr 2025 22:34:45 +0000 Subject: [PATCH 12/24] Skip tests until HF models become available Signed-off-by: Stanislaw Wozniak --- tests/models/decoder_only/language/test_granitemoehybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/language/test_granitemoehybrid.py b/tests/models/decoder_only/language/test_granitemoehybrid.py index 133931c8d439..8b393776524f 100644 --- a/tests/models/decoder_only/language/test_granitemoehybrid.py +++ b/tests/models/decoder_only/language/test_granitemoehybrid.py @@ -9,7 +9,7 @@ "ibm-research/granite-4.0-tiny-test", ] - +@pytest.mark.skip(reason="HF model URLs not available yet") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) From 586247a77838da6a19ce770e39555ddb63cccb25 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Wed, 30 Apr 2025 22:46:17 +0000 Subject: [PATCH 13/24] Pre-commit hook fix Signed-off-by: Stanislaw Wozniak 
--- tests/models/decoder_only/language/test_granitemoehybrid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/decoder_only/language/test_granitemoehybrid.py b/tests/models/decoder_only/language/test_granitemoehybrid.py index 8b393776524f..39ae5b560bb4 100644 --- a/tests/models/decoder_only/language/test_granitemoehybrid.py +++ b/tests/models/decoder_only/language/test_granitemoehybrid.py @@ -9,6 +9,7 @@ "ibm-research/granite-4.0-tiny-test", ] + @pytest.mark.skip(reason="HF model URLs not available yet") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) From b298dc1b62202c7b48692c51110dff5d289307d8 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Thu, 1 May 2025 16:58:58 +0000 Subject: [PATCH 14/24] Integrated code review; Moved tests; Added model to test_hybrid.py Signed-off-by: Thomas Ortner --- .../generation/test_granitemoehybrid.py | 40 +++++++++++++++++++ .../models/language/generation/test_hybrid.py | 6 +++ 2 files changed, 46 insertions(+) create mode 100644 tests/models/language/generation/test_granitemoehybrid.py diff --git a/tests/models/language/generation/test_granitemoehybrid.py b/tests/models/language/generation/test_granitemoehybrid.py new file mode 100644 index 000000000000..f667e8ce90df --- /dev/null +++ b/tests/models/language/generation/test_granitemoehybrid.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from ...utils import check_logprobs_close + +# Path of the checkpoints +MODELS = [ + "ibm-research/granite-4.0-tiny-test", +] + + +@pytest.mark.skip(reason="HF model URLs not available yet") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_model_equivalence_to_hf_greedy( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +): + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 880967b4aed1..7ae928cc88f2 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -23,6 +23,9 @@ HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", + # NOTE: ibm-research/granite-4.0-tiny-test are skipped currently as + # the HF model URLs not available yet + "ibm-research/granite-4.0-tiny-test", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's # not compatible with pip-compile. 
@@ -46,6 +49,9 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: + if model == "ibm-research/granite-4.0-tiny-test": + pytest.skip(reason="HF model URLs not available yet") + with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) From 0937f7989cef883022f4fbb0f6a729efb90ed337 Mon Sep 17 00:00:00 2001 From: Thomas Ortner Date: Fri, 2 May 2025 05:21:59 +0000 Subject: [PATCH 15/24] Added missing files Signed-off-by: Thomas Ortner --- .../language/test_granitemoehybrid.py | 40 ------------------- .../generation/test_granitemoehybrid.py | 2 +- .../models/language/generation/test_hybrid.py | 5 +-- .../model_executor/models/granitemoehybrid.py | 20 +++++----- 4 files changed, 13 insertions(+), 54 deletions(-) delete mode 100644 tests/models/decoder_only/language/test_granitemoehybrid.py diff --git a/tests/models/decoder_only/language/test_granitemoehybrid.py b/tests/models/decoder_only/language/test_granitemoehybrid.py deleted file mode 100644 index 39ae5b560bb4..000000000000 --- a/tests/models/decoder_only/language/test_granitemoehybrid.py +++ /dev/null @@ -1,40 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from ...utils import check_logprobs_close - -# Path of the checkpoints -MODELS = [ - "ibm-research/granite-4.0-tiny-test", -] - - -@pytest.mark.skip(reason="HF model URLs not available yet") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_model_equivalence_to_hf_greedy( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -): - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/models/language/generation/test_granitemoehybrid.py b/tests/models/language/generation/test_granitemoehybrid.py index f667e8ce90df..39ae5b560bb4 100644 --- a/tests/models/language/generation/test_granitemoehybrid.py +++ b/tests/models/language/generation/test_granitemoehybrid.py @@ -23,7 +23,7 @@ def test_model_equivalence_to_hf_greedy( dtype: str, max_tokens: int, num_logprobs: int, -): +): with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 7ae928cc88f2..536aad9988f6 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -25,7 +25,7 @@ "ai21labs/Jamba-tiny-dev", # NOTE: ibm-research/granite-4.0-tiny-test are skipped currently as # the HF model URLs not available yet - "ibm-research/granite-4.0-tiny-test", + # "ibm-research/granite-4.0-tiny-test", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's # not compatible with pip-compile. 
@@ -49,9 +49,6 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - if model == "ibm-research/granite-4.0-tiny-test": - pytest.skip(reason="HF model URLs not available yet") - with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index ad16df835b33..dea9a0da3127 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -232,8 +232,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.o_proj") - self.position_embedding_type = config.position_embedding_type - if self.position_embedding_type == "rope": + if config.position_embedding_type == "rope": self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -244,6 +243,8 @@ def __init__( and config.rope_scaling is not None else None, is_neox_style=True, ) + else: + self.rotary_emb = None self.attn = Attention(self.num_heads, self.head_dim, @@ -263,7 +264,7 @@ def forward( key = self.k_proj(hidden_states)[0] value = self.v_proj(hidden_states)[0] - if self.position_embedding_type == "rope": + if self.rotary_emb is not None: query, key = self.rotary_emb(positions, query, key) hidden_states = self.attn(query, key, value) @@ -349,11 +350,11 @@ def forward( hidden_states = hidden_states * self.embedding_multiplier residual = None else: - assert intermediate_tensors is not None + if intermediate_tensors is None: + raise RuntimeError('Intermediate tensors may not be None!') hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - residual = None num_attn = 0 for i in range(len(self.layers)): layer = self.layers[i] @@ -463,18 +464,19 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, embedding_padding_modules = ["lm_head"] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "GraniteMoeHybrid currently does not support prefix caching" + if cache_config.enable_prefix_caching: + raise RuntimeError( + "GraniteMoeHybrid currently does not support prefix caching") self.quant_config = vllm_config.quant_config - - super().__init__() self.config = config self.scheduler_config = scheduler_config self.model = GraniteMoeHybridModel(vllm_config=vllm_config, From 8ff06918f1e270ef33b0949ef0c7e9bde6662b4d Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 5 May 2025 07:49:20 +0000 Subject: [PATCH 16/24] Added to supported_models.md and fixed URLs Signed-off-by: Stanislaw Wozniak --- docs/source/models/supported_models.md | 5 +++++ tests/models/language/generation/test_granitemoehybrid.py | 4 ++-- tests/models/language/generation/test_hybrid.py | 6 +++--- tests/models/registry.py | 2 +- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 831f9a86d1d4..8e9e05a41bb4 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -385,6 +385,11 @@ See [this page](#generative-models) for more information on how to use generativ * 
`ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. * ✅︎ * ✅︎ + - * `GraniteMoeHybridForCausalLM` + * Granite 4.0 MoE Hybrid + * `ibm-granite/granite-4.0-tiny-preview`, etc. + * ✅︎ + * ✅︎ - * `GraniteMoeSharedForCausalLM` * Granite MoE Shared * `ibm-research/moe-7b-1b-active-shared-experts` (test model) diff --git a/tests/models/language/generation/test_granitemoehybrid.py b/tests/models/language/generation/test_granitemoehybrid.py index 39ae5b560bb4..c4697543391c 100644 --- a/tests/models/language/generation/test_granitemoehybrid.py +++ b/tests/models/language/generation/test_granitemoehybrid.py @@ -6,11 +6,11 @@ # Path of the checkpoints MODELS = [ - "ibm-research/granite-4.0-tiny-test", + "ibm-granite/granite-4.0-tiny-preview", ] -@pytest.mark.skip(reason="HF model URLs not available yet") +@pytest.mark.skip(reason="HF model is in the HF main yet") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 536aad9988f6..c9a72a2dce16 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -23,9 +23,9 @@ HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", - # NOTE: ibm-research/granite-4.0-tiny-test are skipped currently as - # the HF model URLs not available yet - # "ibm-research/granite-4.0-tiny-test", + # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as + # the HF model is in the HF main yet + # "ibm-granite/granite-4.0-tiny-preview", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's # not compatible with pip-compile. diff --git a/tests/models/registry.py b/tests/models/registry.py index 6e01c766022c..86fd2778cb8c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -163,7 +163,7 @@ def check_available_online( {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-research/granite-4.0-tiny-test"), # noqa: E501 + "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), From 495fe339e12fcf5c269510a0738500a46dc54073 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 5 May 2025 09:53:51 +0200 Subject: [PATCH 17/24] Update docs/source/models/supported_models.md Co-authored-by: Cyrus Leung Signed-off-by: Stanislaw Wozniak --- docs/source/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 8e9e05a41bb4..e2bbb231d40a 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -385,7 +385,7 @@ See [this page](#generative-models) for more information on how to use generativ * `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. 
* ✅︎ * ✅︎ - - * `GraniteMoeHybridForCausalLM` +- * `GraniteMoeHybridForCausalLM` * Granite 4.0 MoE Hybrid * `ibm-granite/granite-4.0-tiny-preview`, etc. * ✅︎ From e114f1b7cc51e284a394856e673a3648b98111ed Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 5 May 2025 09:32:14 +0000 Subject: [PATCH 18/24] Commenting out failing HF registry test as code is not yet in HF used in CI Signed-off-by: Stanislaw Wozniak --- tests/models/registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 86fd2778cb8c..c08a0e0f7be1 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -163,7 +163,8 @@ def check_available_online( {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 + # NOTE: GraniteMoeHybridForCausalLM not yet in HF main, so test_registry_imports would fail + #"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), From ca1d9990ea99d2b379ac61eab29c67ec8bdb4167 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 5 May 2025 09:53:09 +0000 Subject: [PATCH 19/24] Temporarily marking model as is_available_online=False until it appears in HF Transformers Signed-off-by: Stanislaw Wozniak --- tests/models/registry.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index c08a0e0f7be1..e0c69fb99c54 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -163,8 +163,7 @@ def check_available_online( {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - # NOTE: GraniteMoeHybridForCausalLM not yet in HF main, so test_registry_imports would fail - #"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 + "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview", is_available_online=False), # noqa: E501 "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), From dc044977ec2f1449568cfb7b088ab004f6a56c05 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 5 May 2025 12:20:37 +0000 Subject: [PATCH 20/24] Temporarily commenting out registration test until the model appears in HF Transformers Signed-off-by: Stanislaw Wozniak --- tests/models/registry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index e0c69fb99c54..fa6df071b191 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -163,7 +163,9 @@ def check_available_online( {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview", is_available_online=False), # noqa: E501 + # NOTE: GraniteMoeHybridForCausalLM not yet in HF main, + # so 
test_registry_imports would fail + #"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), From a636074a26be4e8512fb409f612afbb3282f8c0a Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 5 May 2025 12:31:24 +0000 Subject: [PATCH 21/24] Fixing pre-commit error Signed-off-by: Stanislaw Wozniak --- tests/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index fa6df071b191..9490d9d27ab0 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -163,7 +163,7 @@ def check_available_online( {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - # NOTE: GraniteMoeHybridForCausalLM not yet in HF main, + # GraniteMoeHybridForCausalLM not yet in HF main, # so test_registry_imports would fail #"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 From 8c68720d15488ceac107f9dedcd6cb110ff011df Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 5 May 2025 14:44:36 +0000 Subject: [PATCH 22/24] Marking model with next minor min_transformers_version Signed-off-by: Stanislaw Wozniak --- tests/models/registry.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 9490d9d27ab0..fc84deeadca2 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -163,9 +163,8 @@ def check_available_online( {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - # GraniteMoeHybridForCausalLM not yet in HF main, - # so test_registry_imports would fail - #"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 + "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview", # noqa: E501 + min_transformers_version="4.52.0"), # noqa: E501 "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), From 6c8e664e56a6a781151555b25a8079db7e07c68d Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 5 May 2025 18:57:00 +0200 Subject: [PATCH 23/24] Update tests/models/language/generation/test_hybrid.py Co-authored-by: Tyler Michael Smith Signed-off-by: Stanislaw Wozniak --- tests/models/language/generation/test_hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index c9a72a2dce16..9b7a42acece5 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -24,7 +24,7 @@ HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as - # the HF model is in the HF main yet + # it is not yet available in huggingface transformers # "ibm-granite/granite-4.0-tiny-preview", # NOTE: Running Plamo2 in transformers 
implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's From 95ede004682ec07a9898356eade7af160dae8af6 Mon Sep 17 00:00:00 2001 From: Stanislaw Wozniak Date: Mon, 5 May 2025 17:00:46 +0000 Subject: [PATCH 24/24] Clarifying comments Signed-off-by: Stanislaw Wozniak --- tests/models/language/generation/test_granitemoehybrid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/language/generation/test_granitemoehybrid.py b/tests/models/language/generation/test_granitemoehybrid.py index c4697543391c..da3f5e1100bf 100644 --- a/tests/models/language/generation/test_granitemoehybrid.py +++ b/tests/models/language/generation/test_granitemoehybrid.py @@ -10,7 +10,8 @@ ] -@pytest.mark.skip(reason="HF model is in the HF main yet") +@pytest.mark.skip( + reason="Granite 4.0 is not yet available in huggingface transformers") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64])
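
Note for readers of this series: once the patches above are applied, the new GraniteMoeHybridForCausalLM registration can be exercised through vLLM's offline API in the same way as the examples/offline_inference/basic/basic.py script restored in PATCH 08. The sketch below is illustrative only, not part of the series. It assumes the `ibm-granite/granite-4.0-tiny-preview` checkpoint is actually published on the Hugging Face Hub and that an installed transformers release provides `GraniteMoeHybridConfig` (>= 4.52.0, per the `min_transformers_version` set in tests/models/registry.py in PATCH 22); the dtype and greedy sampling settings simply mirror the HF-equivalence test added in this series rather than any recommended configuration.

    # Minimal usage sketch (assumption: checkpoint and a compatible
    # transformers >= 4.52.0 are available; otherwise point `model` at a
    # local GraniteMoeHybrid checkpoint directory instead).
    from vllm import LLM, SamplingParams

    prompts = [
        "Hello, my name is",
        "The future of AI is",
    ]
    # Greedy decoding (temperature 0) keeps the output deterministic, the
    # same setup the HF-vs-vLLM logprob comparison in
    # tests/models/language/generation/test_granitemoehybrid.py relies on.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)


    def main():
        # dtype here is illustrative; the tests in this series parametrize
        # over both float16 and bfloat16.
        llm = LLM(model="ibm-granite/granite-4.0-tiny-preview",
                  dtype="bfloat16")
        outputs = llm.generate(prompts, sampling_params)
        for output in outputs:
            print(f"Prompt: {output.prompt!r}")
            print(f"Generated: {output.outputs[0].text!r}")


    if __name__ == "__main__":
        main()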