From 110903a5379510ef0f4bfa7fe31feab403f9a0ec Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 30 Apr 2025 13:10:36 -0400 Subject: [PATCH 1/6] lint --- contributed/models/qwen3/modeling_qwen.py | 996 ++++++++++++++++++++++ 1 file changed, 996 insertions(+) create mode 100644 contributed/models/qwen3/modeling_qwen.py diff --git a/contributed/models/qwen3/modeling_qwen.py b/contributed/models/qwen3/modeling_qwen.py new file mode 100644 index 0000000..cc5eb7f --- /dev/null +++ b/contributed/models/qwen3/modeling_qwen.py @@ -0,0 +1,996 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3 model for NXD inference.""" + +import copy +import gc +import logging +from typing import List, Optional, Tuple, Type + + +from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + move_heads_front, +) + +import torch +from neuronx_distributed.parallel_layers import parallel_state # noqa: E402 +from neuronx_distributed.parallel_layers.layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +from neuronx_distributed.parallel_layers.utils import get_padding_length +from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + QuantizedDtype, +) +from neuronx_distributed.quantization.quantization_layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 + QuantizedColumnParallel, + QuantizedRowParallel, +) + +from neuronxcc.nki._private_kernels.mlp import ( + mlp_fused_add_isa_kernel, + mlp_isa_kernel, + quant_mlp_fused_add_isa_kernel, + quant_mlp_isa_kernel, +) +from neuronxcc.nki._private_kernels.rmsnorm import rmsnorm_quant_isa_kernel +from neuronxcc.nki.language import nc +from torch import nn +from torch_neuronx.xla_impl.ops import nki_jit +from transformers import Qwen3ForCausalLM +from transformers.activations import ACT2FN +from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm, Qwen3RotaryEmbedding + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig # noqa: E402 +from neuronx_distributed_inference.models.model_base import ( # noqa: E402 + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.gqa import ( # noqa: E402 + BaseGroupQueryAttention, +) 
+from neuronx_distributed_inference.modules.attention.utils import ( + preprocess_quantized_linear_layer, + transpose_parallel_linear_layer, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.flashdecode.utils import ( + calculate_num_cores_per_group, +) +from neuronx_distributed_inference.modules.lora_serving.lora_module import ( + is_lora_module, +) +from neuronx_distributed_inference.utils.distributed import get_tp_group + +_Qwen3_MODULE_MAP = {} + +logger = logging.getLogger("Neuron") + + +def get_rmsnorm_cls(): + # Initialize to the appropriate implementation of RMSNorm + # If infer on NXD -> CustomRMSNorm + # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) + return ( + CustomRMSNorm + if parallel_state.model_parallel_is_initialized() + else Qwen3RMSNorm + ) + + +def preshard_hook_fn( + module: torch.nn.Module, model_state_dict: dict, prefix: str +) -> bool: + if isinstance(module, (BaseGroupQueryAttention,)): + return module.preshard_hook(model_state_dict, prefix) + + return False + + +def _register_module(key: str, cls: Type[nn.Module]): + _Qwen3_MODULE_MAP[key] = cls + + +def register_module(key: str): + """ + Register a module for use in NeuronQwen3. + Arguments: + key: String used to identify the module + Example: + @register_module("NeuronQwen3Attention") + class NeuronQwen3Attention(nn.Module): + ... + """ + + def inner(cls: Type[nn.Module]): + _register_module(key, cls) + return cls + + return inner + + +def convert_state_dict_to_fused_qkv(Qwen3_state_dict, cfg: InferenceConfig): + """ + This function concats the qkv weights to a Wqkv weight for fusedqkv, and deletes the qkv weights. + """ + for l in range(cfg.num_hidden_layers): # noqa: E741 + Qwen3_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( + [ + Qwen3_state_dict[f"layers.{l}.self_attn.q_proj.weight"], + Qwen3_state_dict[f"layers.{l}.self_attn.k_proj.weight"], + Qwen3_state_dict[f"layers.{l}.self_attn.v_proj.weight"], + ], + ) + del Qwen3_state_dict[f"layers.{l}.self_attn.q_proj.weight"] + del Qwen3_state_dict[f"layers.{l}.self_attn.k_proj.weight"] + del Qwen3_state_dict[f"layers.{l}.self_attn.v_proj.weight"] + + gc.collect() + + return Qwen3_state_dict + + +class Qwen3InferenceConfig(InferenceConfig): + def add_derived_config(self): + self.num_cores_per_group = 1 + if self.neuron_config.flash_decoding_enabled: + num_attn_heads, num_kv_heads = ( + self.num_attention_heads, + self.num_key_value_heads, + ) + self.num_cores_per_group = calculate_num_cores_per_group( + num_attn_heads, num_kv_heads, self.neuron_config.tp_degree + ) + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "pad_token_id", + "vocab_size", + "max_position_embeddings", + "rope_theta", + "rms_norm_eps", + "hidden_act", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + +class NeuronQwen3MLP(nn.Module): + """ + This class just replace the linear layers (gate_proj, up_proj and down_proj) with column and row parallel layers + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + self.config = config + self.neuron_config = config.neuron_config + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.act_fn = ACT2FN[config.hidden_act] + + self.sequence_parallel_enabled = getattr( + 
self.neuron_config, "sequence_parallel_enabled", False + ) + self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + self.rms_norm_eps = config.rms_norm_eps + self.mlp_kernel_enabled = self.neuron_config.mlp_kernel_enabled + self.quantized_mlp_kernel_enabled = ( + self.neuron_config.quantized_mlp_kernel_enabled + ) + self.rmsnorm_quantize_kernel_enabled = ( + self.neuron_config.rmsnorm_quantize_kernel_enabled + ) + self.logical_neuron_cores = self.neuron_config.logical_neuron_cores + mlp_bias = getattr(config, "mlp_bias", False) + if parallel_state.model_parallel_is_initialized(): + if self.quantized_mlp_kernel_enabled: + # Quantized MLP kernels expect intermediate size to be multiple of 128, so we need to pad + tp_degree = self.neuron_config.tp_degree + self.intermediate_size += ( + get_padding_length(self.intermediate_size // tp_degree, 128) + * tp_degree + ) + logger.debug(f"Quantized intermediate_size: {self.intermediate_size}") + + quantization_type = QuantizationType( + self.neuron_config.quantization_type + ) + quantized_dtype = QuantizedDtype.F8E4M3 + self.gate_proj = QuantizedColumnParallel( + input_size=self.hidden_size, + output_size=self.intermediate_size, + bias=mlp_bias, + gather_output=False, + sequence_parallel_enabled=False, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + quantization_type=quantization_type, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = QuantizedColumnParallel( + input_size=self.hidden_size, + output_size=self.intermediate_size, + bias=mlp_bias, + gather_output=False, + sequence_parallel_enabled=False, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + quantization_type=quantization_type, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = QuantizedRowParallel( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=mlp_bias, + quantization_type=quantization_type, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + sequence_parallel_enabled=False, + quantization_per_channel_axis=0, + tensor_model_parallel_group=get_tp_group(config), + ) + + else: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=mlp_bias, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=get_tp_group(config), + reduce_dtype=config.neuron_config.rpl_reduce_dtype, + ) + + if self.mlp_kernel_enabled: + if self.quantized_mlp_kernel_enabled: + preprocess_quantized_linear_layer(self.gate_proj) + preprocess_quantized_linear_layer(self.up_proj) + preprocess_quantized_linear_layer(self.down_proj) + + else: + # Transpose the weights to the layout expected by kernels + self.gate_proj.weight = 
transpose_parallel_linear_layer( + self.gate_proj.weight + ) + self.up_proj.weight = transpose_parallel_linear_layer( + self.up_proj.weight + ) + self.down_proj.weight = transpose_parallel_linear_layer( + self.down_proj.weight + ) + + else: + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=mlp_bias + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=mlp_bias + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=mlp_bias + ) + + def _kernel_enabled_quantized_mlp( + self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids + ): + grid = (nc(self.logical_neuron_cores),) + fused_residual = residual is not None + logger.debug( + f"MLP: quantized kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + ) + + # Can't do residual add in the kernel if SP is enabled + if fused_residual: + assert not self.sequence_parallel_enabled, ( + "Quantized MLP cannot have both fused residual add and sequence parallel RMSnorm!" + ) + # Using fused residual add + _mlp_fwd_call = nki_jit()(quant_mlp_fused_add_isa_kernel) + else: + _mlp_fwd_call = nki_jit()(quant_mlp_isa_kernel) + + # Handle SP RMSnorm + x_orig_dtype = x.dtype + if self.sequence_parallel_enabled: + # This RMSNormQuant kernel will do quantization inside, so we pass the + # lower_bound for clipping. + # If we don't use this kernel, the MLP kernel below will do the + # quantization, so we also pass lower_bound to that kernel. + if self.rmsnorm_quantize_kernel_enabled: + logger.debug( + "Running Quantized MLP kernel with sequence-parallel RMSnorm-Quantize kernel!" + ) + _rmsnorm_quant_fwd_call = nki_jit()(rmsnorm_quant_isa_kernel) + quant_rmsnorm_out = torch.zeros( + size=( + x.shape[0], # batch size + x.shape[1], # sequence length + x.shape[2] + 4, # hidden size + 4 bytes for packing fp32 scale + ), + dtype=torch.int8, + device=x.device, + ) + ln_w = rmsnorm.weight.unsqueeze(0) + lower_bound = self.quantized_kernel_lower_bound + _rmsnorm_quant_fwd_call[grid]( + x, ln_w, lower_bound, quant_rmsnorm_out, kernel_name="QuantOnly" + ) + x = gather_from_sequence_parallel_region( + quant_rmsnorm_out, + self.sequence_dimension, + process_group=get_tp_group(self.config), + ) + + else: + logger.debug( + "Running Quantized MLP kernel with external (native compiler) sequence-parallel RMSnorm!" 
+ ) + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + # Build output tensor + output_tensor_seqlen = x.shape[1] + if fused_residual: + # seqlen dim is doubled to store the residual add output + output_tensor_seqlen *= 2 + + output_tensor = torch.zeros( + size=( + x.shape[0], # batch size + output_tensor_seqlen, + self.hidden_size, # hidden size + ), + dtype=x_orig_dtype, + device=x.device, + ) + + # Grab weights + # all weights of the layers are stored in (out, in) shape + # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] + ln_w = rmsnorm.weight.unsqueeze(0) + gate_w = self.gate_proj.weight.data + gate_w_scale = self.gate_proj.weight_scale + up_w = self.up_proj.weight.data + up_w_scale = self.up_proj.weight_scale + down_w = self.down_proj.weight.data + down_w_scale = self.down_proj.weight_scale + lower_bound = self.quantized_kernel_lower_bound + + if fused_residual: + _mlp_fwd_call[grid]( + x, # attn_output + residual, # hidden + ln_w, # ln_w + gate_w, # gate_w + gate_w_scale, + up_w, # up_w + up_w_scale, + down_w, # down_w + down_w_scale, + lower_bound, + output_tensor, # out + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + store_add=True, + ) + original_seqlen = x.shape[1] + residual = output_tensor[:, original_seqlen:, :] + output_tensor = output_tensor[:, :original_seqlen, :] + else: + _mlp_fwd_call[grid]( + x, # hidden + # should be fine to pass gamma is as a dummy even if not using fused rmsnorm + ln_w, + gate_w, # gate_w + gate_w_scale, + up_w, # up_w + up_w_scale, + down_w, # down_w + down_w_scale, + lower_bound, + output_tensor, # out + # Run RMSNorm inside the kernel if NOT using SP rmsnorm + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + ) + residual = None + + # All-reduce or reduce-scatter, depending on whether SP is enabled + if self.sequence_parallel_enabled: + output_tensor = reduce_scatter_to_sequence_parallel_region( + output_tensor, + self.sequence_dimension, + process_group=get_tp_group(self.config), + ) + else: + output_tensor = reduce_from_tensor_model_parallel_region(output_tensor) + + logger.debug(f"Quantized MLP output shape {output_tensor.shape}") + return (output_tensor, residual) + + def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): + fused_residual = residual is not None + logger.debug( + f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + ) + + # Choose which kernel to call + if fused_residual: + assert not self.sequence_parallel_enabled, ( + "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!" 
+ ) + # Using fused residual add + _mlp_fwd_call = nki_jit()(mlp_fused_add_isa_kernel) + else: + _mlp_fwd_call = nki_jit()(mlp_isa_kernel) + + if self.sequence_parallel_enabled: + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + # Build output tensor + output_tensor_seqlen = x.shape[1] + if fused_residual: + # seqlen dim is doubled to store the residual add output + output_tensor_seqlen *= 2 + + output_tensor = torch.zeros( + size=( + x.shape[0], # batch size + output_tensor_seqlen, + self.hidden_size, # hidden size + ), + dtype=x.dtype, + device=x.device, + ) + + # Grab weights + # all weights of the layers are stored in (out, in) shape + # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] + ln_w = rmsnorm.weight.unsqueeze(0) + gate_w = self.gate_proj.weight.data + up_w = self.up_proj.weight.data + down_w = self.down_proj.weight.data + + grid = (nc(self.logical_neuron_cores),) + + if fused_residual: + _mlp_fwd_call[grid]( + x, # attn_output + residual, # hidden + ln_w, # ln_w + gate_w, # gate_w + up_w, # up_w + down_w, # down_w + output_tensor, # out + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + store_add=True, + ) + original_seqlen = x.shape[1] + residual = output_tensor[:, original_seqlen:, :] + output_tensor = output_tensor[:, :original_seqlen, :] + else: + _mlp_fwd_call[grid]( + x, # hidden + # should be fine to pass gamma is as a dummy even if not using fused rmsnorm + ln_w, + gate_w, + up_w, + down_w, + output_tensor, # out + # Run RMSNorm inside the kernel if NOT using SP rmsnorm + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + ) + residual = None + + # All-reduce or reduce-scatter, depending on whether SP is enabled + if self.sequence_parallel_enabled: + output_tensor = reduce_scatter_to_sequence_parallel_region( + output_tensor, + self.sequence_dimension, + process_group=get_tp_group(self.config), + ) + else: + output_tensor = reduce_from_tensor_model_parallel_region( + output_tensor, process_group=get_tp_group(self.config) + ) + + logger.debug(f"MLP output shape {output_tensor.shape}") + return (output_tensor, residual) + + def _native_mlp(self, x, rmsnorm, adapter_ids=None): + logger.debug("MLP: native compiler") + # all-gather is done here instead of CPL layers to + # avoid 2 all-gathers from up and gate projections + if self.sequence_parallel_enabled: + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + gate_proj_output = ( + self.gate_proj(x) + if not is_lora_module(self.gate_proj) + else self.gate_proj(x, adapter_ids) + ) + up_proj_output = ( + self.up_proj(x) + if not is_lora_module(self.up_proj) + else self.up_proj(x, adapter_ids) + ) + down_proj_input = self.act_fn(gate_proj_output) * up_proj_output + output = ( + self.down_proj(down_proj_input) + if not is_lora_module(self.up_proj) + else self.down_proj(down_proj_input, adapter_ids) + ) + logger.debug(f"MLP output shape {output.shape}") + return output + + def forward(self, x, rmsnorm=None, residual=None, adapter_ids=None): + """ + If residual is passed in, will fuse its add into the MLP kernel + Returns a tuple of (output, residual), where residual is the output of the residual add + """ + if self.mlp_kernel_enabled: + fused_rmsnorm = not self.sequence_parallel_enabled + # Quantized MLP kernel + if self.quantized_mlp_kernel_enabled: + return self._kernel_enabled_quantized_mlp( + x, 
fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + ) + # MLP kernel + return self._kernel_enabled_mlp( + x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + ) + else: + # No kernel + return (self._native_mlp(x, rmsnorm, adapter_ids=adapter_ids), None) + + +@register_module("NeuronQwen3Attention") +class NeuronQwen3Attention(NeuronAttentionBase): + """ + Compared with Qwen3Attention, this class just + 1. replaces the q_proj, k_proj, v_proj with column parallel layer + 2. replaces the o_proj with row parallel layer + 3. update self.num_head to be self.num_head / tp_degree + 4. update self.num_key_value_heads to be self.num_key_value_heads / tp_degree + 5. update forward() method to adjust to changes from self.num_head + """ + + def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): + super().__init__(tensor_model_parallel_group=tensor_model_parallel_group) + + self.config = config + self.neuron_config = config.neuron_config + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.padding_side = config.neuron_config.padding_side + self.torch_dtype = config.neuron_config.torch_dtype + self.is_medusa = config.neuron_config.is_medusa + self.flash_decoding_enabled = config.neuron_config.flash_decoding_enabled + self.num_cores_per_group = config.num_cores_per_group + self.bias = getattr(config, "attention_bias", False) + self.rpl_reduce_dtype = config.neuron_config.rpl_reduce_dtype + self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.rms_norm_eps = config.rms_norm_eps + + self.q_norm = Qwen3RMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + if parallel_state.model_parallel_is_initialized(): + self.tp_degree = self.config.neuron_config.tp_degree + else: + self.tp_degree = 1 + + self.fused_qkv = config.neuron_config.fused_qkv + self.clip_qkv = None + + self.sequence_parallel_enabled = self.neuron_config.sequence_parallel_enabled + self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + logger.debug( + f"Hello from NeuronQwen3Attention init! Is SP enabled? {self.sequence_parallel_enabled}. Dim? 
{self.sequence_dimension}" + ) + + self.init_gqa_properties() + self.init_rope() + + def prep_qkv_tensors( + self, + position_ids, + hidden_states, + past_key_value, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + ): + """take care of the shape, layout, group query, custom position encoding, etc.""" + Q, K, V = self.qkv_proj( + hidden_states=hidden_states, rmsnorm=rmsnorm, adapter_ids=adapter_ids + ) + + # Divide hidden_dim across heads for MHA + # Change layout: BSHD -> BHSD + bsz, q_len, _ = hidden_states.size() + if self.sequence_parallel_enabled: + q_len *= self.tensor_model_parallel_group.size() + + Q = move_heads_front( + Q, bsz, q_len, self.num_heads, self.head_dim, layernorm=self.q_norm + ) + K = move_heads_front( + K, + bsz, + q_len, + self.num_key_value_heads, + self.head_dim, + layernorm=self.k_norm, + ) + V = move_heads_front( + V, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=None + ) + + # Rotate Q and K + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache) + + return Q, K, V, cos_cache, sin_cache + + def init_rope(self): + self.rotary_emb = Qwen3RotaryEmbedding(self.config) + + +class NeuronQwen3DecoderLayer(nn.Module): + """ + Just replace the attention with the NXD version, and MLP with the NXD version + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + # self.self_attn = _Qwen3_MODULE_MAP[config.neuron_config.attn_cls]( + self.self_attn = NeuronQwen3Attention( + config=config, tensor_model_parallel_group=get_tp_group(config) + ) + self.mlp = NeuronQwen3MLP(config) + logger.debug( + f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}" + ) + self.input_layernorm = None + if ( + not config.neuron_config.is_eagle_draft + or config.neuron_config.enable_eagle_draft_input_norm + ): + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.rmsnorm_quantize_kernel_enabled = ( + config.neuron_config.rmsnorm_quantize_kernel_enabled + ) + self.mlp_kernel_fuse_residual_add = ( + config.neuron_config.mlp_kernel_fuse_residual_add + ) + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + adapter_ids=None, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + + # RMSNorm (fused with QKV kernel when SP is disabled) + if ( + not self.qkv_kernel_enabled or self.sequence_parallel_enabled + ) and self.input_layernorm: + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + adapter_ids=adapter_ids, + rmsnorm=self.input_layernorm, + **kwargs, + ) + + if 
self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add: + assert not self.sequence_parallel_enabled, ( + "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled" + ) + # First residual add handled in the MLP kernel + hidden_states, residual = self.mlp( + hidden_states, + rmsnorm=self.post_attention_layernorm, + residual=residual, + adapter_ids=adapter_ids, + ) + else: + hidden_states = residual + hidden_states + residual = hidden_states + # RMSNorm (fused with QKV kernel when SP is disabled) + if not self.mlp_kernel_enabled or self.sequence_parallel_enabled: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp( + hidden_states, + rmsnorm=self.post_attention_layernorm, + adapter_ids=adapter_ids, + ) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache) + return outputs + + +class ResBlock(nn.Module): + """ + A Residual Block module. + This module performs a linear transformation followed by a SiLU activation, + and then adds the result to the original input, creating a residual connection. + Args: + hidden_size (int): The size of the hidden layers in the block. + """ + + def __init__(self, hidden_size): + super().__init__() + self.linear = nn.Linear(hidden_size, hidden_size) + # Initialize as an identity mapping + torch.nn.init.zeros_(self.linear.weight) + # Use SiLU activation to keep consistent with the Qwen3 model + self.act = nn.SiLU() + + def forward(self, x): + """ + Forward pass of the ResBlock. + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: Output after the residual connection and activation. + """ + return x + self.act(self.linear(x)) + + +class NeuronQwen3Model(NeuronBaseModel): + """ + The neuron version of the Qwen3Model + """ + + def setup_attr_for_model(self, config: InferenceConfig): + # Needed for init_inference_optimization() + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if parallel_state.model_parallel_is_initialized(): + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=not config.neuron_config.vocab_parallel, + sequence_parallel_enabled=False, + pad=True, + tensor_model_parallel_group=get_tp_group(config), + use_spmd_rank=config.neuron_config.vocab_parallel, + ) + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + pad=True, + tensor_model_parallel_group=get_tp_group(config), + ) + else: + self.embed_tokens = nn.Embedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + ) + self.lm_head = nn.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + ) + + # In the target fp8 checkpoint, the 1st and last + # layers are not using fp8. 
+ updated_configs = [] + for i in range(config.num_hidden_layers): + # TODO: Remove hardcoded code to have non-quantized MLPs for first and last decoder block + if i == 0 or i == config.num_hidden_layers - 1: + non_quant_config = copy.deepcopy(config) + non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False + updated_configs.append(non_quant_config) + else: + updated_configs.append(config) + self.layers = nn.ModuleList( + [NeuronQwen3DecoderLayer(conf) for conf in updated_configs] + ) + if not config.neuron_config.is_eagle_draft: + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + if config.neuron_config.is_eagle_draft: + fc_bias = getattr(config, "fc_bias", False) + self.fc = ColumnParallelLinear( + config.hidden_size * 2, + config.hidden_size, + bias=fc_bias, + gather_output=True, + ) + self.is_medusa = config.neuron_config.is_medusa + self.num_medusa_heads = config.neuron_config.num_medusa_heads + self.medusa_speculation_length = config.neuron_config.medusa_speculation_length + + if self.is_medusa: + if parallel_state.model_parallel_is_initialized(): + medusa_head_cls = ColumnParallelLinear + else: + medusa_head_cls = nn.Linear + for i in range(self.num_medusa_heads): + medusa_head = nn.Sequential( + *([ResBlock(config.hidden_size)] * 1), + medusa_head_cls( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ), + ) + setattr(self, f"medusa_head_{i}", medusa_head) + + +class NeuronQwen3ForCausalLM(NeuronBaseForCausalLM): + """ + This class extends Qwen3ForCausalLM create traceable + blocks for Neuron. + Args: + Qwen3ForCausalLM (_type_): _description_ + """ + + _model_cls = NeuronQwen3Model + + @staticmethod + def load_hf_model(model_path): + return Qwen3ForCausalLM.from_pretrained(model_path) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: InferenceConfig + ) -> dict: + """This function should be over-ridden in child classes as needed""" + neuron_config = config.neuron_config + if neuron_config.fused_qkv: + state_dict = convert_state_dict_to_fused_qkv(state_dict, config) + + if neuron_config.vocab_parallel: + # TODO: this hack can be removed after replication_id is ready to use + state_dict["embed_tokens.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + # to facilitate rank usage in attention + num_layers = config.num_hidden_layers + tp_degree = neuron_config.tp_degree + for i in range(num_layers): + state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + # to facilitate rank usage in base model + state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + return state_dict + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() + + @classmethod + def get_config_cls(cls): + return Qwen3InferenceConfig From 2e783eaea508a3d92e5831356336ecdec950adfb Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 30 Apr 2025 14:15:28 -0400 Subject: [PATCH 2/6] add inference nb --- contributed/models/qwen3/qwen-3-test.ipynb | 358 +++++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 contributed/models/qwen3/qwen-3-test.ipynb diff --git a/contributed/models/qwen3/qwen-3-test.ipynb b/contributed/models/qwen3/qwen-3-test.ipynb new file mode 100644 index 0000000..cd6cc4e --- /dev/null +++ b/contributed/models/qwen3/qwen-3-test.ipynb @@ -0,0 +1,358 @@ +{ + "cells": [ 
+ { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "libneuronxla 2.2.1630.0\n", + "neuronx-cc 2.17.194.0+d312836f\n", + "neuronx-distributed 0.11.0\n", + "neuronx-distributed-inference 0.2.0\n", + "torch-neuronx 2.5.1.2.6.0\n" + ] + } + ], + "source": [ + "!pip list | grep neuron" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoTokenizer, GenerationConfig\n", + "from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig\n", + "from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model_path = \"/home/ubuntu/model_hf_qwen/qwen/\"\n", + "traced_model_path = \"/home/ubuntu/traced_model_qwen/qwen/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_path,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Qwen3ForCausalLM(\n", + " (model): Qwen3Model(\n", + " (embed_tokens): Embedding(151936, 4096)\n", + " (layers): ModuleList(\n", + " (0-35): 36 x Qwen3DecoderLayer(\n", + " (self_attn): Qwen3Attention(\n", + " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (q_norm): Qwen3RMSNorm((128,), eps=1e-06)\n", + " (k_norm): Qwen3RMSNorm((128,), eps=1e-06)\n", + " )\n", + " (mlp): Qwen3MLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=12288, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=12288, bias=False)\n", + " (down_proj): Linear(in_features=12288, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)\n", + " (post_attention_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)\n", + " )\n", + " )\n", + " (norm): Qwen3RMSNorm((4096,), eps=1e-06)\n", + " (rotary_emb): Qwen3RotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=151936, bias=False)\n", + ")" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import snapshot_download\n", + "\n", + "snapshot_download(\"Qwen/Qwen3-8B\", local_dir=model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from modeling_qwen import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n", + "\n", + "def run_qwen3_compile():\n", + " # Initialize configs and tokenizer.\n", + " tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"right\")\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + " generation_config = GenerationConfig.from_pretrained(model_path)\n", + " generation_config_kwargs = {\n", + " \"do_sample\": 
True,\n", + " \"top_k\": 1,\n", + " \"pad_token_id\": tokenizer.pad_token_id,\n", + " }\n", + " generation_config.update(**generation_config_kwargs)\n", + " \n", + " neuron_config = NeuronConfig(\n", + " tp_degree=8,\n", + " batch_size=1,\n", + " max_context_length=128,\n", + " seq_len=256,\n", + " on_device_sampling_config=OnDeviceSamplingConfig(top_k=5),\n", + " enable_bucketing=True,\n", + " context_encoding_buckets=[128],\n", + " token_generation_buckets=[256],\n", + " flash_decoding_enabled=False,\n", + " torch_dtype=torch.bfloat16,\n", + " fused_qkv=False,\n", + " attn_kernel_enabled=True,\n", + " attn_cls=\"NeuronQwen3Attention\"\n", + " )\n", + " config = Qwen3InferenceConfig(\n", + " neuron_config,\n", + " load_config=load_pretrained_config(model_path),\n", + " )\n", + " \n", + " # Compile and save model.\n", + " print(\"\\nCompiling and saving model...\")\n", + " model = NeuronQwen3ForCausalLM(model_path, config)\n", + " model.compile(traced_model_path)\n", + " tokenizer.save_pretrained(traced_model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_qwen3_compile()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from modeling_qwen import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n", + "\n", + "model = NeuronQwen3ForCausalLM(traced_model_path)\n", + "model.load(traced_model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = model.get_config_cls()\n", + "config.get_neuron_config_cls()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "32" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.config.num_attention_heads" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.config.num_key_value_heads" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4096" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.config.hidden_size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(traced_model_path)\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "generation_config = GenerationConfig.from_pretrained(model_path)\n", + "generation_config_kwargs = {\n", + " \"do_sample\": False,\n", + " \"temperature\": 0.9,\n", + " \"top_k\": 5,\n", + " \"pad_token_id\": tokenizer.pad_token_id,\n", + "}\n", + "generation_config.update(**generation_config_kwargs)\n", + "generation_model = HuggingFaceGenerationAdapter(model)\n", + "messages = [{'role': 'user', 'content': \"What's your name?\"}]\n", + "text = tokenizer.apply_chat_template(\n", + " messages,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " enable_thinking=False # Switches between thinking and non-thinking modes. 
Default is True.\n", + ")\n", + "inputs = tokenizer([text], return_tensors=\"pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nGenerating outputs...\")\n", + "outputs = generation_model.generate(\n", + " **inputs,\n", + " max_new_tokens=512\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "thinking content: \n", + "content: My name is Qwen, and I'm a large language model developed by Alibaba Cloud. How can I assist you today?\n" + ] + } + ], + "source": [ + "output_ids = outputs[0][len(inputs.input_ids[0]):].tolist() \n", + "\n", + "# parsing thinking content\n", + "try:\n", + " # rindex finding 151668 ()\n", + " index = len(output_ids) - output_ids[::-1].index(151668)\n", + "except ValueError:\n", + " index = 0\n", + "\n", + "thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(\"\\n\")\n", + "content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip(\"\\n\")\n", + "\n", + "print(\"thinking content:\", thinking_content)\n", + "print(\"content:\", content)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "model.reset()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aws_neuronx_venv_pytorch_2_5_nxd_inference", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From a75e7e27f405b337b3c240b8ceb19e027d14f121 Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 14 May 2025 10:24:33 -0400 Subject: [PATCH 3/6] logit val / cleanup --- contributed/models/qwen3/qwen-3-test.ipynb | 1113 ++++++++++++++++++-- 1 file changed, 1014 insertions(+), 99 deletions(-) diff --git a/contributed/models/qwen3/qwen-3-test.ipynb b/contributed/models/qwen3/qwen-3-test.ipynb index cd6cc4e..bbb60dd 100644 --- a/contributed/models/qwen3/qwen-3-test.ipynb +++ b/contributed/models/qwen3/qwen-3-test.ipynb @@ -43,66 +43,6 @@ "traced_model_path = \"/home/ubuntu/traced_model_qwen/qwen/\"" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM\n", - "\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_path,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Qwen3ForCausalLM(\n", - " (model): Qwen3Model(\n", - " (embed_tokens): Embedding(151936, 4096)\n", - " (layers): ModuleList(\n", - " (0-35): 36 x Qwen3DecoderLayer(\n", - " (self_attn): Qwen3Attention(\n", - " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (q_norm): Qwen3RMSNorm((128,), eps=1e-06)\n", - " (k_norm): Qwen3RMSNorm((128,), eps=1e-06)\n", - " )\n", - " (mlp): Qwen3MLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=12288, bias=False)\n", - " (up_proj): 
Linear(in_features=4096, out_features=12288, bias=False)\n", - " (down_proj): Linear(in_features=12288, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)\n", - " (post_attention_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)\n", - " )\n", - " )\n", - " (norm): Qwen3RMSNorm((4096,), eps=1e-06)\n", - " (rotary_emb): Qwen3RotaryEmbedding()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=151936, bias=False)\n", - ")" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model" - ] - }, { "cell_type": "code", "execution_count": null, @@ -195,60 +135,27 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "model.config.num_attention_heads" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "model.config.num_key_value_heads" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4096" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "model.config.hidden_size" ] @@ -332,6 +239,1014 @@ "source": [ "model.reset()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run Benchmarks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dir = '/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/'\n", + "!cp modeling_qwen.py {dir}" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost\n", + "WARNING:root:Found libneuronpjrt.so. 
Setting PJRT_DEVICE=NEURON.\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed.modules.moe.blockwise import (\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed.modules.moe.blockwise import (\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed.modules.moe.blockwise import (\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/attention/utils.py:14: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.custom_calls import neuron_cumsum\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:632: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n", + " return fn(*args, **kwargs)\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_model.py:12: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.attention.gqa import GQA, GroupQueryAttention_QKV\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/dbrx/modeling_dbrx.py:38: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:22: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.models.dbrx.modeling_dbrx import NeuronDbrxForCausalLM\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:24: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.models.mixtral.modeling_mixtral import NeuronMixtralForCausalLM\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/mllama/modeling_mllama.py:72: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from .modeling_mllama_vision import NeuronMllamaVisionModel # noqa: E402\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/utils/accuracy.py:29: UserWarning: Intel extension for pytorch not found. 
For faster CPU references install `intel-extension-for-pytorch`.\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:632: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n", + " return fn(*args, **kwargs)\n", + "Loading configs...\n", + "WARNING:root:NeuronConfig init: Unexpected keyword arguments: {'model_type': 'qwen3', 'task_type': 'causal-lm', 'model_path': '/home/ubuntu/model_hf_qwen/qwen/', 'compiled_model_path': '/home/ubuntu/traced_model_qwen/qwen/logit', 'benchmark': True, 'check_accuracy_mode': , 'divergence_difference_tol': 0.001, 'prompts': ['To be, or not to be'], 'top_k': 1, 'top_p': 1.0, 'temperature': 1.0, 'do_sample': False, 'dynamic': False, 'pad_token_id': 151645, 'on_device_sampling': False, 'enable_torch_dist': False, 'enable_lora': False, 'skip_warmup': False, 'skip_compile': False, 'compile_only': False, 'hlo_debug': False}\n", + "\n", + "Compiling and saving model...\n", + "INFO:Neuron:Generating HLOs for the following models: ['context_encoding_model', 'token_generation_model']\n", + "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:588] > initializing tensor model parallel with size 8\n", + "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:589] > initializing pipeline model parallel with size 1\n", + "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:590] > initializing context model parallel with size 1\n", + "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:591] > initializing data parallel with size 1\n", + "[2025-05-14 14:09:05.945: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing world size to 8\n", + "[2025-05-14 14:09:05.945: I neuronx_distributed/parallel_layers/parallel_state.py:339] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:628] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:629] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:630] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:631] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "INFO:Neuron:Generating 1 hlos for key: context_encoding_model\n", + "INFO:Neuron:Started loading module context_encoding_model\n", + "INFO:Neuron:Finished loading module context_encoding_model in 0.07737994194030762 seconds\n", + 
"INFO:Neuron:generating HLO: context_encoding_model, input example shape = torch.Size([1, 16])\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/layers.py:476: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n", + " with torch.cuda.amp.autocast(enabled=False):\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 16]), dtype=torch.int32)\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=3, shape=torch.Size([1]), dtype=torch.int32)\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=4, shape=torch.Size([1, 3]), dtype=torch.float32)\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=5, shape=torch.Size([1]), dtype=torch.int32)\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=6, shape=torch.Size([1]), dtype=torch.int32)\n", + " warnings.warn(\n", + "INFO:Neuron:Generating 1 hlos for key: token_generation_model\n", + "INFO:Neuron:Started loading module token_generation_model\n", + "INFO:Neuron:Finished loading module token_generation_model in 0.06693840026855469 seconds\n", + "INFO:Neuron:generating HLO: token_generation_model, input example shape = torch.Size([1, 1])\n", + "INFO:Neuron:Started compilation for all HLOs\n", + "....Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "INFO:Neuron:Done compilation for the priority HLO\n", + "INFO:Neuron:Updating the hlo module with optimized layout\n", + "INFO:Neuron:Done optimizing weight layout for all HLOs\n", + "..........Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "INFO:Neuron:Finished Compilation for all HLOs\n", + "..Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "INFO:Neuron:Done preparing weight layout transformation\n", + "INFO:Neuron:Sharding Weights for ranks: 0...7\n", + "[2025-05-14 14:14:12.537: I neuronx_distributed/parallel_layers/parallel_state.py:588] > initializing tensor model parallel with size 8\n", + "[2025-05-14 14:14:12.537: I neuronx_distributed/parallel_layers/parallel_state.py:589] > initializing pipeline model parallel with size 1\n", + "[2025-05-14 14:14:12.537: I neuronx_distributed/parallel_layers/parallel_state.py:590] > initializing context model parallel with size 1\n", + "[2025-05-14 14:14:12.538: I neuronx_distributed/parallel_layers/parallel_state.py:591] > initializing data parallel with size 1\n", + "[2025-05-14 14:14:12.538: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing world size to 8\n", + "[2025-05-14 14:14:12.540: I 
neuronx_distributed/parallel_layers/parallel_state.py:339] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n",
+ "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:628] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n",
+ "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:629] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:630] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:631] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: lm_head.weight. Will convert to torch.bfloat16\n",
+ " warnings.warn(\n",
+ "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.input_layernorm.weight. Will convert to torch.bfloat16\n",
+ " warnings.warn(\n",
[... repeated UserWarning output trimmed: application_base.py:347 emits the same "Found float32 weights in checkpoint: <weight name>. Will convert to torch.bfloat16" warning, followed by "warnings.warn(", for each remaining float32 weight in the checkpoint (embed_tokens and the per-layer layernorm, MLP, and attention projection/norm weights) ...]
+ "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.o_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.q_norm.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.q_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.v_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.input_layernorm.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.mlp.down_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.v_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.q_norm.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.q_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.v_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.input_layernorm.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.mlp.down_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.mlp.gate_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.mlp.up_proj.weight. 
Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "INFO:Neuron:Done Sharding weights in 252.63744661300007\n", + "Compiling and tracing time: 559.4677159970001 seconds\n", + "\n", + "Loading model to Neuron...\n", + "INFO:Neuron:Warming up the model.\n", + "2025-May-14 14:18:35.0232 5872:7328 [2] nccl_net_ofi_rdma_init:7837 CCOM WARN NET/OFI OFI fi_getinfo() call failed: No data available\n", + "2025-May-14 14:18:35.0236 5872:7328 [2] nccl_net_ofi_create_plugin:261 CCOM WARN NET/OFI Unable to find a protocol that worked. Failing initialization.\n", + "2025-May-14 14:18:35.0239 5872:7328 [2] nccl_net_ofi_create_plugin:341 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", + "2025-May-14 14:18:35.0242 5872:7328 [2] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed\n", + "2025-May-14 14:18:35.0245 5872:7328 [2] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", + "INFO:Neuron:Warmup completed in 0.2721595764160156 seconds.\n", + "Total model loading time: 10.090576054999929 seconds\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:653: UserWarning: `do_sample` is set to `False`. 
However, `top_k` is set to `1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", + " warnings.warn(\n", + "\n", + "Checking accuracy by logit matching\n", + "Loading checkpoint shards: 100%|██████████████████| 5/5 [00:01<00:00, 2.57it/s]\n", + "`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'temperature': 0.6, 'top_p': 0.95}. If this is not desired, please set these values explicitly.\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:631: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:636: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", + " warnings.warn(\n", + "Expected Output: [\", that is the question. Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\"] tensor([[ 11, 429, 374, 279, 3405, 13, 13139, 364, 83, 285,\n", + " 13049, 1536, 304, 279, 3971, 311, 7676, 279, 1739, 819,\n", + " 323, 36957, 315, 54488, 32315]])\n", + "Expected Logits Shape: torch.Size([25, 1, 151936])\n", + "HuggingFaceGenerationAdapter has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n", + " - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n", + " - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n", + " - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n", + "Actual Output: [\", that is the question. Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\"] tensor([[ 11, 429, 374, 279, 3405, 13, 13139, 364, 83, 285,\n", + " 13049, 1536, 304, 279, 3971, 311, 7676, 279, 1739, 819,\n", + " 323, 36957, 315, 54488, 32315]])\n", + "Actual Logits Shape: torch.Size([25, 1, 151936])\n", + "Passed logits validation!\n", + "\n", + "Generating outputs...\n", + "Prompts: ['To be, or not to be']\n", + "Generated outputs:\n", + "Output 0: To be, or not to be, that is the question. 
Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\n", + "Benchmark completed and its result is as following\n", + "{\n", + " \"e2e_model\": {\n", + " \"latency_ms_p50\": 156.56781196594238,\n", + " \"latency_ms_p90\": 158.08086395263672,\n", + " \"latency_ms_p95\": 158.1140637397766,\n", + " \"latency_ms_p99\": 158.28602075576782,\n", + " \"latency_ms_p100\": 158.32901000976562,\n", + " \"latency_ms_avg\": 156.99772834777832,\n", + " \"throughput\": 203.82460521412273\n", + " },\n", + " \"context_encoding_model\": {\n", + " \"latency_ms_p50\": 10.202646255493164,\n", + " \"latency_ms_p90\": 10.224390029907227,\n", + " \"latency_ms_p95\": 10.22493839263916,\n", + " \"latency_ms_p99\": 10.226750373840332,\n", + " \"latency_ms_p100\": 10.227203369140625,\n", + " \"latency_ms_avg\": 10.201811790466309,\n", + " \"throughput\": 1568.348870634151\n", + " },\n", + " \"token_generation_model\": {\n", + " \"latency_ms_p50\": 8.858323097229004,\n", + " \"latency_ms_p90\": 8.903312683105469,\n", + " \"latency_ms_p95\": 9.238588809967041,\n", + " \"latency_ms_p99\": 9.264287948608398,\n", + " \"latency_ms_p100\": 9.28950309753418,\n", + " \"latency_ms_avg\": 8.88296922047933,\n", + " \"throughput\": 120.07996877975322\n", + " }\n", + "}\n", + "Completed saving result to benchmark_report.json\n" + ] + } + ], + "source": [ + "!inference_demo \\\n", + " --model-type qwen3 \\\n", + " --task-type causal-lm \\\n", + " run \\\n", + " --model-path /home/ubuntu/model_hf_qwen/qwen/ \\\n", + " --compiled-model-path /home/ubuntu/traced_model_qwen/qwen/logit \\\n", + " --torch-dtype bfloat16 \\\n", + " --tp-degree 8 \\\n", + " --batch-size 1 \\\n", + " --max-context-length 16 \\\n", + " --seq-len 32 \\\n", + " --enable-bucketing \\\n", + " --pad-token-id 151645 \\\n", + " --prompt \"To be, or not to be\" \\\n", + " --check-accuracy-mode logit-matching \\\n", + " --benchmark" + ] } ], "metadata": { From f3314cd8c0b5a36d5c3957d3900a5484697d0471 Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Mon, 19 May 2025 15:54:32 -0400 Subject: [PATCH 4/6] update with thinking example --- contributed/models/qwen3/qwen-3-test.ipynb | 199 ++++++++++++++++++++- 1 file changed, 192 insertions(+), 7 deletions(-) diff --git a/contributed/models/qwen3/qwen-3-test.ipynb b/contributed/models/qwen3/qwen-3-test.ipynb index bbb60dd..15a965d 100644 --- a/contributed/models/qwen3/qwen-3-test.ipynb +++ b/contributed/models/qwen3/qwen-3-test.ipynb @@ -33,6 +33,13 @@ "from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Download" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -40,7 +47,7 @@ "outputs": [], "source": [ "model_path = \"/home/ubuntu/model_hf_qwen/qwen/\"\n", - "traced_model_path = \"/home/ubuntu/traced_model_qwen/qwen/\"" + "traced_model_path = \"/home/ubuntu/traced_model_qwen3/qwen3/\"" ] }, { @@ -54,6 +61,13 @@ "snapshot_download(\"Qwen/Qwen3-8B\", local_dir=model_path)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compilation" + ] + }, { "cell_type": "code", "execution_count": null, @@ -78,12 +92,12 @@ " neuron_config = NeuronConfig(\n", " tp_degree=8,\n", " batch_size=1,\n", - " max_context_length=128,\n", - " seq_len=256,\n", + " max_context_length=1024, \n", + " seq_len=2048, \n", " on_device_sampling_config=OnDeviceSamplingConfig(top_k=5),\n", " enable_bucketing=True,\n", - " context_encoding_buckets=[128],\n", 
- " token_generation_buckets=[256],\n", + " context_encoding_buckets=[1024],\n", + " token_generation_buckets=[2048],\n", " flash_decoding_enabled=False,\n", " torch_dtype=torch.bfloat16,\n", " fused_qkv=False,\n", @@ -111,6 +125,13 @@ "run_qwen3_compile()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Testing" + ] + }, { "cell_type": "code", "execution_count": null, @@ -202,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -231,15 +252,161 @@ "print(\"content:\", content)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Thinking example" + ] + }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "model.reset()" ] }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(traced_model_path)\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "generation_config = GenerationConfig.from_pretrained(model_path)\n", + "generation_config_kwargs = {\n", + " \"do_sample\": False,\n", + " \"temperature\": 0.9,\n", + " \"top_k\": 5,\n", + " \"pad_token_id\": tokenizer.pad_token_id,\n", + "}\n", + "generation_config.update(**generation_config_kwargs)\n", + "generation_model = HuggingFaceGenerationAdapter(model)\n", + "messages = [{'role': 'system', 'content': \"Only think through one example before providing the correct answer\"},\n", + " {'role': 'user', 'content': \"What is 83 * 110 + 34?\"}\n", + " ]\n", + "text = tokenizer.apply_chat_template(\n", + " messages,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " enable_thinking=True # Switches between thinking and non-thinking modes. Default is True.\n", + ")\n", + "inputs = tokenizer([text], return_tensors=\"pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nGenerating outputs...\")\n", + "outputs = generation_model.generate(\n", + " **inputs,\n", + " max_new_tokens=1024\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "thinking content: \n", + "Okay, let's see. I need to calculate 83 multiplied by 110 and then add 34 to the result. Hmm, let me break this down step by step. First, I should handle the multiplication part: 83 times 110. \n", + "\n", + "Wait, multiplying by 110 might be easier if I think of it as multiplying by 100 and then adding 10 times the number. Because 110 is 100 + 10. So, 83 times 100 is 8300, and 83 times 10 is 830. Then adding those two together: 8300 + 830. Let me check that. 8300 plus 800 is 9100, and then plus 30 more would be 9130. So, 83 * 110 equals 9130?\n", + "\n", + "Wait, let me verify that another way. Maybe using the standard multiplication method. Let's write it out:\n", + "\n", + " 83\n", + "x110\n", + "------\n", + "First, multiply 83 by 0 (the units place of 110), which gives 0.\n", + "Then multiply 83 by 1 (the tens place of 110), which is 83, but since it's in the tens place, it's actually 830.\n", + "Then multiply 83 by 1 (the hundreds place of 110), which is 83, but since it's in the hundreds place, it's 8300.\n", + "Adding those together: 0 + 830 + 8300 = 9130. Okay, that matches my previous result. So 83*110 is indeed 9130.\n", + "\n", + "Now, the next part is adding 34 to that result. So 9130 + 34. 
Let me do that. 9130 plus 30 is 9160, and then plus 4 more is 9164. \n", + "\n", + "Wait, let me check again. 9130 + 34. Breaking it down: 9130 + 30 = 9160, then 9160 + 4 = 9164. Yes, that seems right. \n", + "\n", + "Alternatively, I can add 34 to 9130 directly. 9130 + 34. The units digit: 0 + 4 = 4. The tens digit: 3 + 3 = 6. The hundreds and above remain the same. So 9164. Yep, that's correct.\n", + "\n", + "So putting it all together, 83 multiplied by 110 is 9130, and adding 34 gives 9164. I think that's the right answer. Let me just confirm once more with another method. Maybe using distributive property for the entire expression.\n", + "\n", + "Original problem: 83*110 + 34. Let's think of 110 as 100 + 10, so 83*(100 + 10) + 34 = 83*100 + 83*10 + 34 = 8300 + 830 + 34. Adding those: 8300 + 830 is 9130, then 9130 + 34 is 9164. Yep, same result. \n", + "\n", + "I think that's solid. No mistakes in the steps. So the final answer should be 9164.\n", + "\n", + "################################################################################################################################################################################################################################################################################################################################\n", + "content: To solve $ 83 \\times 110 + 34 $, we break it into two parts:\n", + "\n", + "1. **Multiplication**: \n", + " $ 83 \\times 110 $ can be simplified by recognizing that $ 110 = 100 + 10 $. \n", + " $$\n", + " 83 \\times 110 = 83 \\times (100 + 10) = (83 \\times 100) + (83 \\times 10) = 8300 + 830 = 9130\n", + " $$\n", + "\n", + "2. **Addition**: \n", + " Add 34 to the result: \n", + " $$\n", + " 9130 + 34 = 9164\n", + " $$\n", + "\n", + "**Final Answer:** \n", + "$$\n", + "\\boxed{9164}\n", + "$$\n" + ] + } + ], + "source": [ + "output_ids = outputs[0][len(inputs.input_ids[0]):].tolist() \n", + "\n", + "# parsing thinking content\n", + "try:\n", + " # rindex finding 151668 ()\n", + " index = len(output_ids) - output_ids[::-1].index(151668)\n", + "except ValueError:\n", + " index = 0\n", + "\n", + "thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(\"\\n\")\n", + "content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip(\"\\n\")\n", + "\n", + "print(\"thinking content:\", thinking_content)\n", + "print('####'*80)\n", + "print(\"content:\", content)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9164" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ans = 83*110+34\n", + "ans" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -257,6 +424,24 @@ "!cp modeling_qwen.py {dir}" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#Edit the inference_demo.py file to include the following:\n", + "\n", + "```python\n", + "from .modeling_qwen import NeuronQwen3ForCausalLM\n", + "\n", + "MODEL_TYPES = {\n", + " \"llama\": {\"causal-lm\": NeuronLlamaForCausalLM},\n", + " \"mixtral\": {\"causal-lm\": NeuronMixtralForCausalLM},\n", + " \"dbrx\": {\"causal-lm\": NeuronDbrxForCausalLM},\n", + " 'qwen3': {\"causal-lm\": NeuronQwen3ForCausalLM}\n", + "}\n", + "```" + ] + }, { "cell_type": "code", "execution_count": 1, From eeb5c29621e1023bc1f6cda7667cfc751012deea Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Mon, 2 Jun 2025 11:08:46 -0400 Subject: [PATCH 5/6] Update Qwen3 for latest nxdi --- 
.../{modeling_qwen.py => modeling_qwen3.py} | 672 ++++++---- contributed/models/qwen3/qwen-3-test.ipynb | 1151 ++++------------- 2 files changed, 608 insertions(+), 1215 deletions(-) rename contributed/models/qwen3/{modeling_qwen.py => modeling_qwen3.py} (59%) diff --git a/contributed/models/qwen3/modeling_qwen.py b/contributed/models/qwen3/modeling_qwen3.py similarity index 59% rename from contributed/models/qwen3/modeling_qwen.py rename to contributed/models/qwen3/modeling_qwen3.py index cc5eb7f..ce9662d 100644 --- a/contributed/models/qwen3/modeling_qwen.py +++ b/contributed/models/qwen3/modeling_qwen3.py @@ -18,18 +18,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Qwen3 model for NXD inference.""" - import copy import gc import logging +import math from typing import List, Optional, Tuple, Type - -from neuronx_distributed_inference.modules.attention.utils import ( - apply_rotary_pos_emb, - move_heads_front, -) - import torch from neuronx_distributed.parallel_layers import parallel_state # noqa: E402 from neuronx_distributed.parallel_layers.layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 @@ -41,17 +35,10 @@ gather_from_sequence_parallel_region, reduce_from_tensor_model_parallel_region, reduce_scatter_to_sequence_parallel_region, + _gather_along_first_dim, ) from neuronx_distributed.parallel_layers.utils import get_padding_length -from neuronx_distributed.quantization.quantization_config import ( - QuantizationType, - QuantizedDtype, -) -from neuronx_distributed.quantization.quantization_layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 - QuantizedColumnParallel, - QuantizedRowParallel, -) - +from neuronx_distributed.utils import cpu_mode from neuronxcc.nki._private_kernels.mlp import ( mlp_fused_add_isa_kernel, mlp_isa_kernel, @@ -71,23 +58,20 @@ NeuronBaseForCausalLM, NeuronBaseModel, ) -from neuronx_distributed_inference.modules.attention.attention_base import ( - NeuronAttentionBase, -) +from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase from neuronx_distributed_inference.modules.attention.gqa import ( # noqa: E402 BaseGroupQueryAttention, ) from neuronx_distributed_inference.modules.attention.utils import ( + RotaryEmbedding, preprocess_quantized_linear_layer, transpose_parallel_linear_layer, + apply_rotary_pos_emb, + move_heads_front, ) from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm -from neuronx_distributed_inference.modules.flashdecode.utils import ( - calculate_num_cores_per_group, -) -from neuronx_distributed_inference.modules.lora_serving.lora_module import ( - is_lora_module, -) +from neuronx_distributed_inference.modules.flashdecode.utils import calculate_num_cores_per_group +from neuronx_distributed_inference.modules.lora_serving.lora_module import is_lora_module from neuronx_distributed_inference.utils.distributed import get_tp_group _Qwen3_MODULE_MAP = {} @@ -99,22 +83,57 @@ def get_rmsnorm_cls(): # Initialize to the appropriate implementation of RMSNorm # If infer on NXD -> CustomRMSNorm # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) - return ( - CustomRMSNorm - if parallel_state.model_parallel_is_initialized() - else Qwen3RMSNorm - ) + return Qwen3RMSNorm if cpu_mode() else CustomRMSNorm -def preshard_hook_fn( - module: torch.nn.Module, model_state_dict: dict, prefix: str -) -> bool: +def preshard_hook_fn(module: torch.nn.Module, model_state_dict: 
dict, prefix: str) -> bool: if isinstance(module, (BaseGroupQueryAttention,)): return module.preshard_hook(model_state_dict, prefix) return False +# Get the modules_to_not_convert from the neuron configs +def get_modules_to_not_convert(neuron_config: NeuronConfig): + return getattr(neuron_config, "modules_to_not_convert", None) + + +def get_updated_configs(config: InferenceConfig): + """ + Generate a list of configurations for each hidden layer in a Qwen3 model. + + This function creates a list of InferenceConfig objects, one for each layer. It + modifies the configurations for certain layers based on which modules should not + be converted to quantized format. The function uses get_modules_to_not_convert() + to determine which modules should not be converted. + + Args: + config (InferenceConfig): The inference configuration for the model. + + Returns: + list[InferenceConfig]: A list of InferenceConfig objects, one for each layer in the model. + Each config may be either the original config or a modified version + with "quantized_mlp_kernel_enabled" as False for that specific layer. + """ + updated_configs = [] + modules_to_not_convert = get_modules_to_not_convert(config.neuron_config) + if modules_to_not_convert is None: + modules_to_not_convert = [] + + for i in range(config.num_hidden_layers): + # If any of the MLP modules for this layer are in modules_to_not_convert + module_pattern = f"layers.{i}.mlp" + if any(module_pattern in module for module in modules_to_not_convert): + non_quant_config = copy.deepcopy(config) + non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False + non_quant_config.neuron_config.activation_quantization_type = None + non_quant_config.neuron_config.quantize_clamp_bound = float("inf") + updated_configs.append(non_quant_config) + else: + updated_configs.append(config) + return updated_configs + + def _register_module(key: str, cls: Type[nn.Module]): _Qwen3_MODULE_MAP[key] = cls @@ -122,8 +141,10 @@ def _register_module(key: str, cls: Type[nn.Module]): def register_module(key: str): """ Register a module for use in NeuronQwen3. + Arguments: key: String used to identify the module + Example: @register_module("NeuronQwen3Attention") class NeuronQwen3Attention(nn.Module): @@ -137,35 +158,101 @@ def inner(cls: Type[nn.Module]): return inner +def _helper_concat_and_delete_qkv(Qwen3_state_dict, layer_num, attr): + """ + Helper function to concatenate and delete QKV attributes for fusedqkv (weight or scale). + Args: + Qwen3_state_dict: The state dictionary containing model weights + layer_num: The index of the layer to process + attr: The attribute to process ('weight' or 'scale') + """ + Qwen3_state_dict[f"layers.{layer_num}.self_attn.Wqkv.{attr}"] = torch.cat( + [ + Qwen3_state_dict[f"layers.{layer_num}.self_attn.q_proj.{attr}"], + Qwen3_state_dict[f"layers.{layer_num}.self_attn.k_proj.{attr}"], + Qwen3_state_dict[f"layers.{layer_num}.self_attn.v_proj.{attr}"], + ], + ) + del Qwen3_state_dict[f"layers.{layer_num}.self_attn.q_proj.{attr}"] + del Qwen3_state_dict[f"layers.{layer_num}.self_attn.k_proj.{attr}"] + del Qwen3_state_dict[f"layers.{layer_num}.self_attn.v_proj.{attr}"] + + def convert_state_dict_to_fused_qkv(Qwen3_state_dict, cfg: InferenceConfig): """ - This function concats the qkv weights to a Wqkv weight for fusedqkv, and deletes the qkv weights. + This function concats the qkv weights and scales to a Wqkv weight and scale for fusedqkv, and deletes the qkv weights. 
""" + mods_to_not_conv = get_modules_to_not_convert(cfg.neuron_config) + if mods_to_not_conv is None: + mods_to_not_conv = [] + for l in range(cfg.num_hidden_layers): # noqa: E741 - Qwen3_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( - [ - Qwen3_state_dict[f"layers.{l}.self_attn.q_proj.weight"], - Qwen3_state_dict[f"layers.{l}.self_attn.k_proj.weight"], - Qwen3_state_dict[f"layers.{l}.self_attn.v_proj.weight"], - ], - ) - del Qwen3_state_dict[f"layers.{l}.self_attn.q_proj.weight"] - del Qwen3_state_dict[f"layers.{l}.self_attn.k_proj.weight"] - del Qwen3_state_dict[f"layers.{l}.self_attn.v_proj.weight"] + _helper_concat_and_delete_qkv(Qwen3_state_dict, l, "weight") + if ( + cfg.neuron_config.quantized_mlp_kernel_enabled or cfg.neuron_config.quantized + ) and f"layers.{l}.self_attn" not in mods_to_not_conv: + _helper_concat_and_delete_qkv(Qwen3_state_dict, l, "scale") gc.collect() return Qwen3_state_dict +class WeightGatheredColumnParallel(ColumnParallelLinear): + """ + A specialized column-parallel linear layer that implements weight gathering optimization + for efficient processing of long sequences in transformer models during eagle speculation. + + This layer provides two forward paths: + 1. Standard column-parallel forward (inherited from parent) + 2. Weight-gathered forward for long sequences + """ + def forward_wg(self, input: torch, weight_gather: bool = False): + """ + Performs the forward pass with optional weight gathering optimization. + + Args: + input (torch.Tensor): Input tensor of shape (batch_size, seq_len/TP, 2*hidden_size) + weight_gather (bool): Whether to use weight gathering optimization. + Typically True for sequences >= 32K + + Returns: + torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: + - If skip_bias_add is False: Output tensor of shape (batch_size, seq_len, hidden_size) + - If skip_bias_add is True: Tuple of (output tensor, bias) + """ + if weight_gather: + weight = _gather_along_first_dim(self.weight, process_group=self.tensor_parallel_group) + output = self._forward_impl( + input=input, + weight=weight, + bias=None, + async_grad_allreduce=self.async_tensor_model_parallel_allreduce, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + autograd_func_class=self.autograd_func_class, + process_group=self.tensor_parallel_group + ) + + output = gather_from_sequence_parallel_region( + output, + self.sequence_dimension, + process_group=self.tensor_parallel_group, + ) + if self.skip_bias_add: + return output, self.bias + + output = (output + self.bias) if self.bias is not None else output + return output + else: + return self.forward(input) + + class Qwen3InferenceConfig(InferenceConfig): def add_derived_config(self): self.num_cores_per_group = 1 if self.neuron_config.flash_decoding_enabled: - num_attn_heads, num_kv_heads = ( - self.num_attention_heads, - self.num_key_value_heads, - ) + num_attn_heads, num_kv_heads = self.num_attention_heads, self.num_key_value_heads self.num_cores_per_group = calculate_num_cores_per_group( num_attn_heads, num_kv_heads, self.neuron_config.tp_degree ) @@ -209,154 +296,121 @@ def __init__(self, config: InferenceConfig): self.sequence_dimension = 1 if self.sequence_parallel_enabled else None self.rms_norm_eps = config.rms_norm_eps self.mlp_kernel_enabled = self.neuron_config.mlp_kernel_enabled - self.quantized_mlp_kernel_enabled = ( - self.neuron_config.quantized_mlp_kernel_enabled - ) - self.rmsnorm_quantize_kernel_enabled = ( - 
self.neuron_config.rmsnorm_quantize_kernel_enabled - ) - self.logical_neuron_cores = self.neuron_config.logical_neuron_cores + self.fused_rmsnorm_skip_gamma = self.config.neuron_config.fused_rmsnorm_skip_gamma + self.quantized_mlp_kernel_enabled = self.neuron_config.quantized_mlp_kernel_enabled + self.rmsnorm_quantize_kernel_enabled = self.neuron_config.rmsnorm_quantize_kernel_enabled + self.quantize_clamp_bound = self.neuron_config.quantize_clamp_bound + self.logical_nc_config = self.neuron_config.logical_nc_config + self.activation_quantization_type = self.neuron_config.activation_quantization_type mlp_bias = getattr(config, "mlp_bias", False) + + if self.neuron_config.quantized_mlp_kernel_enabled and self.quantize_clamp_bound == float( + "inf" + ): + logging.warning( + "quantize_clamp_bound is not specified in NeuronConfig. We will use the default value of 1200 for Qwen3 models in quantized kernels." + ) + self.quantize_clamp_bound = 1200.0 if parallel_state.model_parallel_is_initialized(): - if self.quantized_mlp_kernel_enabled: - # Quantized MLP kernels expect intermediate size to be multiple of 128, so we need to pad + if self.neuron_config.quantized_mlp_kernel_enabled: + # # Quantized MLP kernels expect intermediate size to be multiple of 128, so we need to pad tp_degree = self.neuron_config.tp_degree self.intermediate_size += ( - get_padding_length(self.intermediate_size // tp_degree, 128) - * tp_degree + get_padding_length(self.intermediate_size // tp_degree, 128) * tp_degree ) logger.debug(f"Quantized intermediate_size: {self.intermediate_size}") - - quantization_type = QuantizationType( - self.neuron_config.quantization_type - ) - quantized_dtype = QuantizedDtype.F8E4M3 - self.gate_proj = QuantizedColumnParallel( - input_size=self.hidden_size, - output_size=self.intermediate_size, - bias=mlp_bias, - gather_output=False, - sequence_parallel_enabled=False, - dtype=config.neuron_config.torch_dtype, - quantized_dtype=quantized_dtype, - quantization_type=quantization_type, - tensor_model_parallel_group=get_tp_group(config), - ) - self.up_proj = QuantizedColumnParallel( - input_size=self.hidden_size, - output_size=self.intermediate_size, - bias=mlp_bias, - gather_output=False, - sequence_parallel_enabled=False, - dtype=config.neuron_config.torch_dtype, - quantized_dtype=quantized_dtype, - quantization_type=quantization_type, - tensor_model_parallel_group=get_tp_group(config), - ) - self.down_proj = QuantizedRowParallel( - input_size=self.intermediate_size, - output_size=self.hidden_size, - bias=mlp_bias, - quantization_type=quantization_type, - input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, - quantized_dtype=quantized_dtype, - sequence_parallel_enabled=False, - quantization_per_channel_axis=0, - tensor_model_parallel_group=get_tp_group(config), - ) - - else: - self.gate_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=mlp_bias, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=False, - sequence_dimension=None, - tensor_model_parallel_group=get_tp_group(config), - ) - self.up_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=mlp_bias, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=False, - sequence_dimension=None, - tensor_model_parallel_group=get_tp_group(config), - ) - self.down_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=mlp_bias, - 
input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=self.sequence_parallel_enabled, - sequence_dimension=self.sequence_dimension, - tensor_model_parallel_group=get_tp_group(config), - reduce_dtype=config.neuron_config.rpl_reduce_dtype, - ) - + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=mlp_bias, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=get_tp_group(config), + reduce_dtype=config.neuron_config.rpl_reduce_dtype, + ) if self.mlp_kernel_enabled: - if self.quantized_mlp_kernel_enabled: - preprocess_quantized_linear_layer(self.gate_proj) - preprocess_quantized_linear_layer(self.up_proj) - preprocess_quantized_linear_layer(self.down_proj) - - else: - # Transpose the weights to the layout expected by kernels - self.gate_proj.weight = transpose_parallel_linear_layer( - self.gate_proj.weight + if self.neuron_config.quantized_mlp_kernel_enabled: + setattr( + self.gate_proj, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, ) - self.up_proj.weight = transpose_parallel_linear_layer( - self.up_proj.weight + setattr( + self.up_proj, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, ) - self.down_proj.weight = transpose_parallel_linear_layer( - self.down_proj.weight + setattr( + self.down_proj, + "post_create_quantized_module_hook", + preprocess_quantized_linear_layer, ) + else: + # Transpose the weights to the layout expected by kernels + self.gate_proj.weight = transpose_parallel_linear_layer(self.gate_proj.weight) + self.up_proj.weight = transpose_parallel_linear_layer(self.up_proj.weight) + self.down_proj.weight = transpose_parallel_linear_layer(self.down_proj.weight) else: - self.gate_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=mlp_bias - ) - self.up_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=mlp_bias - ) - self.down_proj = nn.Linear( - self.intermediate_size, self.hidden_size, bias=mlp_bias - ) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias) - def _kernel_enabled_quantized_mlp( - self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids - ): - grid = (nc(self.logical_neuron_cores),) + def _kernel_enabled_quantized_mlp(self, x, rmsnorm, residual, adapter_ids): + grid = (nc(self.logical_nc_config),) fused_residual = residual is not None + fused_rmsnorm = rmsnorm is not None logger.debug( - f"MLP: quantized kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + f"MLP: quantized 
kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_nc_config={self.logical_nc_config}" ) # Can't do residual add in the kernel if SP is enabled if fused_residual: - assert not self.sequence_parallel_enabled, ( - "Quantized MLP cannot have both fused residual add and sequence parallel RMSnorm!" - ) + assert ( + not self.sequence_parallel_enabled + ), "Quantized MLP cannot have both fused residual add and sequence parallel RMSnorm!" # Using fused residual add _mlp_fwd_call = nki_jit()(quant_mlp_fused_add_isa_kernel) else: _mlp_fwd_call = nki_jit()(quant_mlp_isa_kernel) + if fused_rmsnorm: + ln_w = rmsnorm.weight.unsqueeze(0) + else: + ln_w = torch.zeros(size=(1, self.hidden_size), dtype=x.dtype, device=x.device) + # Handle SP RMSnorm x_orig_dtype = x.dtype if self.sequence_parallel_enabled: # This RMSNormQuant kernel will do quantization inside, so we pass the - # lower_bound for clipping. + # clamp_bound for clipping. # If we don't use this kernel, the MLP kernel below will do the - # quantization, so we also pass lower_bound to that kernel. + # quantization, so we also pass clamp_bound to that kernel. if self.rmsnorm_quantize_kernel_enabled: logger.debug( "Running Quantized MLP kernel with sequence-parallel RMSnorm-Quantize kernel!" @@ -371,10 +425,9 @@ def _kernel_enabled_quantized_mlp( dtype=torch.int8, device=x.device, ) - ln_w = rmsnorm.weight.unsqueeze(0) - lower_bound = self.quantized_kernel_lower_bound + clamp_bound = self.quantize_clamp_bound _rmsnorm_quant_fwd_call[grid]( - x, ln_w, lower_bound, quant_rmsnorm_out, kernel_name="QuantOnly" + x, ln_w, clamp_bound, quant_rmsnorm_out, kernel_name="QuantOnly" ) x = gather_from_sequence_parallel_region( quant_rmsnorm_out, @@ -409,14 +462,13 @@ def _kernel_enabled_quantized_mlp( # Grab weights # all weights of the layers are stored in (out, in) shape # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] - ln_w = rmsnorm.weight.unsqueeze(0) gate_w = self.gate_proj.weight.data - gate_w_scale = self.gate_proj.weight_scale + gate_w_scale = self.gate_proj.scale up_w = self.up_proj.weight.data - up_w_scale = self.up_proj.weight_scale + up_w_scale = self.up_proj.scale down_w = self.down_proj.weight.data - down_w_scale = self.down_proj.weight_scale - lower_bound = self.quantized_kernel_lower_bound + down_w_scale = self.down_proj.scale + clamp_bound = self.quantize_clamp_bound if fused_residual: _mlp_fwd_call[grid]( @@ -429,7 +481,7 @@ def _kernel_enabled_quantized_mlp( up_w_scale, down_w, # down_w down_w_scale, - lower_bound, + clamp_bound, output_tensor, # out fused_rmsnorm=fused_rmsnorm, eps=self.rms_norm_eps, @@ -450,7 +502,7 @@ def _kernel_enabled_quantized_mlp( up_w_scale, down_w, # down_w down_w_scale, - lower_bound, + clamp_bound, output_tensor, # out # Run RMSNorm inside the kernel if NOT using SP rmsnorm fused_rmsnorm=fused_rmsnorm, @@ -462,9 +514,7 @@ def _kernel_enabled_quantized_mlp( # All-reduce or reduce-scatter, depending on whether SP is enabled if self.sequence_parallel_enabled: output_tensor = reduce_scatter_to_sequence_parallel_region( - output_tensor, - self.sequence_dimension, - process_group=get_tp_group(self.config), + output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config) ) else: output_tensor = reduce_from_tensor_model_parallel_region(output_tensor) @@ -472,17 +522,18 @@ def _kernel_enabled_quantized_mlp( logger.debug(f"Quantized MLP output shape {output_tensor.shape}") return (output_tensor, residual) - def _kernel_enabled_mlp(self, x, 
fused_rmsnorm, rmsnorm, residual, adapter_ids): + def _kernel_enabled_mlp(self, x, rmsnorm, residual, adapter_ids): fused_residual = residual is not None + fused_rmsnorm = rmsnorm is not None logger.debug( - f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, skip_gamma={self.fused_rmsnorm_skip_gamma}, logical_nc_config={self.logical_nc_config}" ) # Choose which kernel to call if fused_residual: - assert not self.sequence_parallel_enabled, ( - "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!" - ) + assert ( + not self.sequence_parallel_enabled + ), "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!" # Using fused residual add _mlp_fwd_call = nki_jit()(mlp_fused_add_isa_kernel) else: @@ -512,12 +563,15 @@ def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): # Grab weights # all weights of the layers are stored in (out, in) shape # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] - ln_w = rmsnorm.weight.unsqueeze(0) + if fused_rmsnorm: + ln_w = rmsnorm.weight.unsqueeze(0) + else: + ln_w = torch.zeros(size=(1, self.hidden_size), dtype=x.dtype, device=x.device) gate_w = self.gate_proj.weight.data up_w = self.up_proj.weight.data down_w = self.down_proj.weight.data - grid = (nc(self.logical_neuron_cores),) + grid = (nc(self.logical_nc_config),) if fused_residual: _mlp_fwd_call[grid]( @@ -528,9 +582,10 @@ def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): up_w, # up_w down_w, # down_w output_tensor, # out + kernel_name="MLP", fused_rmsnorm=fused_rmsnorm, + skip_gamma=self.fused_rmsnorm_skip_gamma, eps=self.rms_norm_eps, - kernel_name="MLP", store_add=True, ) original_seqlen = x.shape[1] @@ -545,19 +600,18 @@ def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): up_w, down_w, output_tensor, # out + kernel_name="MLP", # Run RMSNorm inside the kernel if NOT using SP rmsnorm fused_rmsnorm=fused_rmsnorm, + skip_gamma=self.fused_rmsnorm_skip_gamma, eps=self.rms_norm_eps, - kernel_name="MLP", ) residual = None # All-reduce or reduce-scatter, depending on whether SP is enabled if self.sequence_parallel_enabled: output_tensor = reduce_scatter_to_sequence_parallel_region( - output_tensor, - self.sequence_dimension, - process_group=get_tp_group(self.config), + output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config) ) else: output_tensor = reduce_from_tensor_model_parallel_region( @@ -567,7 +621,7 @@ def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): logger.debug(f"MLP output shape {output_tensor.shape}") return (output_tensor, residual) - def _native_mlp(self, x, rmsnorm, adapter_ids=None): + def _native_mlp(self, x, adapter_ids=None): logger.debug("MLP: native compiler") # all-gather is done here instead of CPL layers to # avoid 2 all-gathers from up and gate projections @@ -575,21 +629,20 @@ def _native_mlp(self, x, rmsnorm, adapter_ids=None): x = gather_from_sequence_parallel_region( x, self.sequence_dimension, process_group=get_tp_group(self.config) ) - gate_proj_output = ( self.gate_proj(x) if not is_lora_module(self.gate_proj) else self.gate_proj(x, adapter_ids) ) + up_proj_output = ( - self.up_proj(x) - if not is_lora_module(self.up_proj) - else self.up_proj(x, adapter_ids) + self.up_proj(x) if not 
is_lora_module(self.up_proj) else self.up_proj(x, adapter_ids) ) + down_proj_input = self.act_fn(gate_proj_output) * up_proj_output output = ( self.down_proj(down_proj_input) - if not is_lora_module(self.up_proj) + if not is_lora_module(self.down_proj) else self.down_proj(down_proj_input, adapter_ids) ) logger.debug(f"MLP output shape {output.shape}") @@ -598,22 +651,23 @@ def _native_mlp(self, x, rmsnorm, adapter_ids=None): def forward(self, x, rmsnorm=None, residual=None, adapter_ids=None): """ If residual is passed in, will fuse its add into the MLP kernel + If rmsnorm is passed in, will fuse the rmsnorm into the MLP kernel + Returns a tuple of (output, residual), where residual is the output of the residual add """ + if self.mlp_kernel_enabled: - fused_rmsnorm = not self.sequence_parallel_enabled # Quantized MLP kernel if self.quantized_mlp_kernel_enabled: return self._kernel_enabled_quantized_mlp( - x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + x, rmsnorm, residual, adapter_ids=adapter_ids ) # MLP kernel - return self._kernel_enabled_mlp( - x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids - ) + return self._kernel_enabled_mlp(x, rmsnorm, residual, adapter_ids=adapter_ids) else: # No kernel - return (self._native_mlp(x, rmsnorm, adapter_ids=adapter_ids), None) + assert rmsnorm is None and residual is None + return (self._native_mlp(x, adapter_ids=adapter_ids), None) @register_module("NeuronQwen3Attention") @@ -635,7 +689,7 @@ def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads - self.head_dim = self.hidden_size // self.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_attention_heads) self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.padding_side = config.neuron_config.padding_side @@ -647,7 +701,7 @@ def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): self.rpl_reduce_dtype = config.neuron_config.rpl_reduce_dtype self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled self.rms_norm_eps = config.rms_norm_eps - + self.attn_tkg_builtin_kernel_enabled = self.neuron_config.attn_tkg_builtin_kernel_enabled self.q_norm = Qwen3RMSNorm( self.head_dim, eps=config.rms_norm_eps ) # unlike olmo, only on the head dim! @@ -668,8 +722,9 @@ def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): ) self.init_gqa_properties() - self.init_rope() + self.init_rope() + def prep_qkv_tensors( self, position_ids, @@ -679,11 +734,18 @@ def prep_qkv_tensors( cos_cache=None, sin_cache=None, rmsnorm=None, + skip_rope=False, + residual=None, ): - """take care of the shape, layout, group query, custom position encoding, etc.""" - Q, K, V = self.qkv_proj( - hidden_states=hidden_states, rmsnorm=rmsnorm, adapter_ids=adapter_ids + """take care of the shape, layout, group query, custom position encoding, etc. 
+ also return residual for MLP """ + Q, K, V, residual = self.qkv_proj( + hidden_states=hidden_states, rmsnorm=rmsnorm, adapter_ids=adapter_ids, residual=residual ) + if self.use_qk_norm: + self.init_qk_norm() # TODO: when attentionbase can take config parameters in init, move this to init function + Q = self.qk_norm(Q) + K = self.qk_norm(K) # Divide hidden_dim across heads for MHA # Change layout: BSHD -> BHSD @@ -695,29 +757,25 @@ def prep_qkv_tensors( Q, bsz, q_len, self.num_heads, self.head_dim, layernorm=self.q_norm ) K = move_heads_front( - K, - bsz, - q_len, - self.num_key_value_heads, - self.head_dim, - layernorm=self.k_norm, - ) - V = move_heads_front( - V, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=None + K, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=self.k_norm ) + V = move_heads_front(V, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=None) # Rotate Q and K - if self.rotary_emb is not None: + if not skip_rope and self.rotary_emb is not None: if cos_cache is None or sin_cache is None: cos_cache, sin_cache = self.rotary_emb(V, position_ids) Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache) - return Q, K, V, cos_cache, sin_cache - + return Q, K, V, cos_cache, sin_cache, residual + def init_rope(self): self.rotary_emb = Qwen3RotaryEmbedding(self.config) + if self.attn_tkg_builtin_kernel_enabled: + self.inv_freqs = self.rotary_emb.get_inv_freqs().unsqueeze(1) + class NeuronQwen3DecoderLayer(nn.Module): """ @@ -727,10 +785,11 @@ class NeuronQwen3DecoderLayer(nn.Module): def __init__(self, config: InferenceConfig): super().__init__() self.hidden_size = config.hidden_size - # self.self_attn = _Qwen3_MODULE_MAP[config.neuron_config.attn_cls]( + self.self_attn = NeuronQwen3Attention( config=config, tensor_model_parallel_group=get_tp_group(config) ) + self.mlp = NeuronQwen3MLP(config) logger.debug( f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}" @@ -750,15 +809,20 @@ def __init__(self, config: InferenceConfig): ) self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled - self.rmsnorm_quantize_kernel_enabled = ( - config.neuron_config.rmsnorm_quantize_kernel_enabled - ) - self.mlp_kernel_fuse_residual_add = ( - config.neuron_config.mlp_kernel_fuse_residual_add - ) + self.quantized_mlp_kernel_enabled = config.neuron_config.quantized_mlp_kernel_enabled + self.rmsnorm_quantize_kernel_enabled = config.neuron_config.rmsnorm_quantize_kernel_enabled + self.mlp_kernel_fuse_residual_add = config.neuron_config.mlp_kernel_fuse_residual_add + self.qkv_kernel_fuse_residual_add = config.neuron_config.qkv_kernel_fuse_residual_add self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.is_prefill_stage = config.neuron_config.is_prefill_stage self.config = config + if self.is_prefill_stage and self.config.neuron_config.is_mlp_quantized(): + # for CTE, quantized MLP kernel does not support fused rmsnorm + self.mlp_kernel_fused_rmsnorm = False + else: + self.mlp_kernel_fused_rmsnorm = not self.sequence_parallel_enabled + def forward( self, hidden_states: torch.Tensor, @@ -766,33 +830,49 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, adapter_ids=None, + rotary_position_ids: Optional[torch.LongTensor] = None, + residual: Optional[torch.Tensor] = None, # residual from previous layer used by QKV **kwargs, - ) -> Tuple[ - 
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - residual = hidden_states - + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]], Optional[torch.FloatTensor], Optional[torch.FloatTensor], Optional[torch.FloatTensor]]: + entry_hidden_states = hidden_states # RMSNorm (fused with QKV kernel when SP is disabled) - if ( - not self.qkv_kernel_enabled or self.sequence_parallel_enabled - ) and self.input_layernorm: + if (not self.qkv_kernel_enabled or self.sequence_parallel_enabled) and self.input_layernorm: hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + # produced another residual used by MLP + attn_output = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, adapter_ids=adapter_ids, rmsnorm=self.input_layernorm, + rotary_position_ids=rotary_position_ids, + residual=residual, **kwargs, ) + if attn_output.residual is None: + residual = entry_hidden_states # input to attention + else: + # residual will only be returned by attn/qkv if fuse add qkv kernel is enabled + assert self.qkv_kernel_fuse_residual_add, \ + "residual add before qkv should be computed in the previous layer, \ + unless qkv_kernel_fuse_residual_add is specified" + assert ( + not self.sequence_parallel_enabled + ), "qkv_kernel_fuse_residual_add should be off when sequence parallelism is enabled" + assert ( + self.qkv_kernel_enabled + ), "qkv_kernel_fuse_residual_add should be used with qkv_kernel_enabled" + residual = attn_output.residual + + hidden_states = attn_output.hidden_states if self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add: - assert not self.sequence_parallel_enabled, ( - "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled" - ) + assert ( + not self.sequence_parallel_enabled + ), "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled" # First residual add handled in the MLP kernel hidden_states, residual = self.mlp( hidden_states, @@ -804,25 +884,35 @@ def forward( hidden_states = residual + hidden_states residual = hidden_states # RMSNorm (fused with QKV kernel when SP is disabled) - if not self.mlp_kernel_enabled or self.sequence_parallel_enabled: + if self.mlp_kernel_enabled and self.mlp_kernel_fused_rmsnorm: + rmsnorm = self.post_attention_layernorm + else: hidden_states = self.post_attention_layernorm(hidden_states) + rmsnorm = None hidden_states, _ = self.mlp( hidden_states, - rmsnorm=self.post_attention_layernorm, + rmsnorm=rmsnorm, adapter_ids=adapter_ids, ) - hidden_states = residual + hidden_states + # if fuse residual add with qkv, we leave this add to the next layer's QKV + # unless it is the last layer in which case we add it here + if not self.qkv_kernel_fuse_residual_add: + hidden_states = residual + hidden_states + residual = None # set to None to prevent it from being used again - outputs = (hidden_states, present_key_value, cos_cache, sin_cache) + # also return residual for QKV in the next layer + outputs = (hidden_states, attn_output.present_key_value, attn_output.cos_cache, attn_output.sin_cache, residual) return outputs class ResBlock(nn.Module): """ A Residual Block module. + This module performs a linear transformation followed by a SiLU activation, and then adds the result to the original input, creating a residual connection. 
+ Args: hidden_size (int): The size of the hidden layers in the block. """ @@ -838,8 +928,10 @@ def __init__(self, hidden_size): def forward(self, x): """ Forward pass of the ResBlock. + Args: x (torch.Tensor): Input tensor. + Returns: torch.Tensor: Output after the residual connection and activation. """ @@ -853,9 +945,7 @@ class NeuronQwen3Model(NeuronBaseModel): def setup_attr_for_model(self, config: InferenceConfig): # Needed for init_inference_optimization() - self.on_device_sampling = ( - config.neuron_config.on_device_sampling_config is not None - ) + self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None self.tp_degree = config.neuron_config.tp_degree self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads @@ -874,7 +964,8 @@ def init_model(self, config: InferenceConfig): self.padding_idx, dtype=config.neuron_config.torch_dtype, shard_across_embedding=not config.neuron_config.vocab_parallel, - sequence_parallel_enabled=False, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + sequence_dimension=1, pad=True, tensor_model_parallel_group=get_tp_group(config), use_spmd_rank=config.neuron_config.vocab_parallel, @@ -900,30 +991,18 @@ def init_model(self, config: InferenceConfig): bias=False, ) - # In the target fp8 checkpoint, the 1st and last - # layers are not using fp8. - updated_configs = [] - for i in range(config.num_hidden_layers): - # TODO: Remove hardcoded code to have non-quantized MLPs for first and last decoder block - if i == 0 or i == config.num_hidden_layers - 1: - non_quant_config = copy.deepcopy(config) - non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False - updated_configs.append(non_quant_config) - else: - updated_configs.append(config) - self.layers = nn.ModuleList( - [NeuronQwen3DecoderLayer(conf) for conf in updated_configs] - ) + updated_configs = get_updated_configs(config) + + self.layers = nn.ModuleList([NeuronQwen3DecoderLayer(conf) for conf in updated_configs]) + if not config.neuron_config.is_eagle_draft: self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) if config.neuron_config.is_eagle_draft: fc_bias = getattr(config, "fc_bias", False) - self.fc = ColumnParallelLinear( - config.hidden_size * 2, - config.hidden_size, - bias=fc_bias, - gather_output=True, + # replicate fc weights since activations are sequence sharded + self.fc = WeightGatheredColumnParallel( + config.hidden_size * 2, config.hidden_size, bias=fc_bias, gather_output=True, sequence_dimension=1 ) self.is_medusa = config.neuron_config.is_medusa self.num_medusa_heads = config.neuron_config.num_medusa_heads @@ -951,6 +1030,7 @@ class NeuronQwen3ForCausalLM(NeuronBaseForCausalLM): """ This class extends Qwen3ForCausalLM create traceable blocks for Neuron. 
+ Args: Qwen3ForCausalLM (_type_): _description_ """ @@ -958,31 +1038,61 @@ class NeuronQwen3ForCausalLM(NeuronBaseForCausalLM): _model_cls = NeuronQwen3Model @staticmethod - def load_hf_model(model_path): - return Qwen3ForCausalLM.from_pretrained(model_path) + def load_hf_model(model_path, **kwargs): + return Qwen3ForCausalLM.from_pretrained(model_path, **kwargs) @staticmethod - def convert_hf_to_neuron_state_dict( - state_dict: dict, config: InferenceConfig - ) -> dict: + def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) -> dict: """This function should be over-ridden in child classes as needed""" + neuron_config = config.neuron_config + # to facilitate rank usage in attention + num_layers = config.num_hidden_layers + tp_degree = neuron_config.tp_degree + for i in range(num_layers): + state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + """ + for every layer do the following transformations + gate_w_prime = (gate_w.T * gamma).T + up_w_prime = (up_w.T * gamma).T + """ + if ( + neuron_config.fused_rmsnorm_skip_gamma + and not neuron_config.sequence_parallel_enabled + ): + if neuron_config.mlp_kernel_enabled: + # MLP + state_dict[f"layers.{i}.mlp.gate_proj.weight"] = state_dict[ + f"layers.{i}.mlp.gate_proj.weight" + ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0) + state_dict[f"layers.{i}.mlp.up_proj.weight"] = state_dict[ + f"layers.{i}.mlp.up_proj.weight" + ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0) + + if neuron_config.qkv_kernel_enabled: + # QKV + state_dict[f"layers.{i}.self_attn.q_proj.weight"] = state_dict[ + f"layers.{i}.self_attn.q_proj.weight" + ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0) + state_dict[f"layers.{i}.self_attn.k_proj.weight"] = state_dict[ + f"layers.{i}.self_attn.k_proj.weight" + ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0) + state_dict[f"layers.{i}.self_attn.v_proj.weight"] = state_dict[ + f"layers.{i}.self_attn.v_proj.weight" + ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0) + if neuron_config.fused_qkv: state_dict = convert_state_dict_to_fused_qkv(state_dict, config) if neuron_config.vocab_parallel: # TODO: this hack can be removed after replication_id is ready to use state_dict["embed_tokens.rank_util.rank"] = torch.arange( - 0, neuron_config.local_ranks_size + 0, neuron_config.local_ranks_size, dtype=torch.int32 ) - # to facilitate rank usage in attention - num_layers = config.num_hidden_layers - tp_degree = neuron_config.tp_degree - for i in range(num_layers): - state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( - 0, tp_degree, dtype=torch.int32 - ) # to facilitate rank usage in base model state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) return state_dict diff --git a/contributed/models/qwen3/qwen-3-test.ipynb b/contributed/models/qwen3/qwen-3-test.ipynb index 15a965d..a1b033d 100644 --- a/contributed/models/qwen3/qwen-3-test.ipynb +++ b/contributed/models/qwen3/qwen-3-test.ipynb @@ -1,5 +1,27 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip uninstall transformers --y\n", + "!pip install transformers==4.51.3\n", + "\n", + "# Installing collected packages: transformers\n", + "# ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "# neuronx-distributed-inference 0.3.5591+f50feae2 requires transformers==4.48.*, but you have transformers 4.51.3 which is incompatible.\n", + "# Successfully installed transformers-4.51.3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### You may ignore the error that nxdi is not compatible with transformers ==4.48.*" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -9,16 +31,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "libneuronxla 2.2.1630.0\n", - "neuronx-cc 2.17.194.0+d312836f\n", - "neuronx-distributed 0.11.0\n", - "neuronx-distributed-inference 0.2.0\n", - "torch-neuronx 2.5.1.2.6.0\n" + "libneuronxla 2.2.3493.0+78c3e78c\n", + "neuronx-cc 2.18.121.0+9e31e41a\n", + "neuronx-distributed 0.12.12111+cdd84048\n", + "neuronx-distributed-inference 0.3.5591+f50feae2\n", + "torch-neuronx 2.6.0.2.7.5413+113e6810\n", + "transformers 4.51.3\n" ] } ], "source": [ - "!pip list | grep neuron" + "!pip list | grep neuron\n", + "!pip list | grep transformers" ] }, { @@ -42,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -74,7 +98,7 @@ "metadata": {}, "outputs": [], "source": [ - "from modeling_qwen import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n", + "from modeling_qwen3 import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n", "\n", "def run_qwen3_compile():\n", " # Initialize configs and tokenizer.\n", @@ -138,7 +162,7 @@ "metadata": {}, "outputs": [], "source": [ - "from modeling_qwen import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n", + "from modeling_qwen3 import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n", "\n", "model = NeuronQwen3ForCausalLM(traced_model_path)\n", "model.load(traced_model_path)" @@ -146,9 +170,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "neuronx_distributed_inference.models.config.NeuronConfig" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "config = model.get_config_cls()\n", "config.get_neuron_config_cls()" @@ -156,27 +191,60 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "32" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model.config.num_attention_heads" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "8" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model.config.num_key_value_heads" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "4096" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model.config.hidden_size" ] @@ -205,25 +273,18 @@ " add_generation_prompt=True,\n", " enable_thinking=False # Switches between thinking and non-thinking modes. 
Default is True.\n", ")\n", - "inputs = tokenizer([text], return_tensors=\"pt\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nGenerating outputs...\")\n", + "inputs = tokenizer([text], return_tensors=\"pt\")\n", + "input_ids = inputs['input_ids'] \n", + "\n", "outputs = generation_model.generate(\n", - " **inputs,\n", + " input_ids=input_ids,\n", " max_new_tokens=512\n", ")" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -261,7 +322,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -270,7 +331,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -294,25 +355,17 @@ " add_generation_prompt=True,\n", " enable_thinking=True # Switches between thinking and non-thinking modes. Default is True.\n", ")\n", - "inputs = tokenizer([text], return_tensors=\"pt\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nGenerating outputs...\")\n", + "inputs = tokenizer([text], return_tensors=\"pt\")\n", + "input_ids = inputs['input_ids'] \n", "outputs = generation_model.generate(\n", - " **inputs,\n", + " input_ids=input_ids,\n", " max_new_tokens=1024\n", ")" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -388,7 +441,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -397,7 +450,7 @@ "9164" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -420,8 +473,8 @@ "metadata": {}, "outputs": [], "source": [ - "dir = '/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/'\n", - "!cp modeling_qwen.py {dir}" + "dir = '/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/'\n", + "!cp modeling_qwen3.py {dir}" ] }, { @@ -444,923 +497,145 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost\n", - "WARNING:root:Found libneuronpjrt.so. 
Setting PJRT_DEVICE=NEURON.\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", " from neuronx_distributed.modules.moe.blockwise import (\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", " from neuronx_distributed.modules.moe.blockwise import (\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", " from neuronx_distributed.modules.moe.blockwise import (\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/attention/utils.py:14: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/attention/utils.py:14: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", " from neuronx_distributed_inference.modules.custom_calls import neuron_cumsum\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:632: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:745: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n", " return fn(*args, **kwargs)\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_model.py:12: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_model.py:12: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.attention.gqa import GQA, GroupQueryAttention_QKV\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_model.py:12: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", " from neuronx_distributed_inference.modules.attention.gqa import GQA, GroupQueryAttention_QKV\n", - 
"/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/dbrx/modeling_dbrx.py:38: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_model.py:12: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.attention.gqa import GQA, GroupQueryAttention_QKV\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/dbrx/modeling_dbrx.py:38: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/dbrx/modeling_dbrx.py:38: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:22: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:25: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", " from neuronx_distributed_inference.models.dbrx.modeling_dbrx import NeuronDbrxForCausalLM\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:24: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:27: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", " from neuronx_distributed_inference.models.mixtral.modeling_mixtral import NeuronMixtralForCausalLM\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/mllama/modeling_mllama.py:72: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/mllama/modeling_mllama.py:72: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", " from .modeling_mllama_vision import NeuronMllamaVisionModel # noqa: E402\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/utils/accuracy.py:29: UserWarning: Intel extension for pytorch not found. For faster CPU references install `intel-extension-for-pytorch`.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/utils/accuracy.py:29: UserWarning: Intel extension for pytorch not found. 
For faster CPU references install `intel-extension-for-pytorch`.\n", " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:632: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:745: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n", " return fn(*args, **kwargs)\n", "Loading configs...\n", - "WARNING:root:NeuronConfig init: Unexpected keyword arguments: {'model_type': 'qwen3', 'task_type': 'causal-lm', 'model_path': '/home/ubuntu/model_hf_qwen/qwen/', 'compiled_model_path': '/home/ubuntu/traced_model_qwen/qwen/logit', 'benchmark': True, 'check_accuracy_mode': , 'divergence_difference_tol': 0.001, 'prompts': ['To be, or not to be'], 'top_k': 1, 'top_p': 1.0, 'temperature': 1.0, 'do_sample': False, 'dynamic': False, 'pad_token_id': 151645, 'on_device_sampling': False, 'enable_torch_dist': False, 'enable_lora': False, 'skip_warmup': False, 'skip_compile': False, 'compile_only': False, 'hlo_debug': False}\n", + "WARNING:root:NeuronConfig init: Unexpected keyword arguments: {'model_type': 'qwen3', 'task_type': 'causal-lm', 'model_path': '/home/ubuntu/model_hf_qwen/qwen/', 'compiled_model_path': '/home/ubuntu/traced_model_qwen/qwen/logit', 'benchmark': True, 'check_accuracy_mode': , 'divergence_difference_tol': 0.001, 'prompts': ['To be, or not to be'], 'top_k': 1, 'top_p': 1.0, 'temperature': 1.0, 'do_sample': False, 'dynamic': False, 'pad_token_id': 151645, 'on_device_sampling': False, 'enable_torch_dist': False, 'enable_lora': False, 'max_loras': 1, 'max_lora_rank': 16, 'skip_warmup': False, 'skip_compile': False, 'compile_only': False, 'compile_dry_run': False, 'hlo_debug': False}\n", "\n", "Compiling and saving model...\n", "INFO:Neuron:Generating HLOs for the following models: ['context_encoding_model', 'token_generation_model']\n", - "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:588] > initializing tensor model parallel with size 8\n", - "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:589] > initializing pipeline model parallel with size 1\n", - "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:590] > initializing context model parallel with size 1\n", - "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:591] > initializing data parallel with size 1\n", - "[2025-05-14 14:09:05.945: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing world size to 8\n", - "[2025-05-14 14:09:05.945: I neuronx_distributed/parallel_layers/parallel_state.py:339] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", - "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:628] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", - "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:629] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:630] 
[rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:631] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 14:42:25.152: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing tensor model parallel with size 8\n", + "[2025-06-02 14:42:25.152: I neuronx_distributed/parallel_layers/parallel_state.py:593] > initializing pipeline model parallel with size 1\n", + "[2025-06-02 14:42:25.152: I neuronx_distributed/parallel_layers/parallel_state.py:594] > initializing context model parallel with size 1\n", + "[2025-06-02 14:42:25.152: I neuronx_distributed/parallel_layers/parallel_state.py:595] > initializing data parallel with size 1\n", + "[2025-06-02 14:42:25.153: I neuronx_distributed/parallel_layers/parallel_state.py:596] > initializing world size to 8\n", + "[2025-06-02 14:42:25.153: I neuronx_distributed/parallel_layers/parallel_state.py:343] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", + "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", + "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:634] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:635] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:636] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:637] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", "INFO:Neuron:Generating 1 hlos for key: context_encoding_model\n", "INFO:Neuron:Started loading module context_encoding_model\n", - "INFO:Neuron:Finished loading module context_encoding_model in 0.07737994194030762 seconds\n", + "INFO:Neuron:Finished loading module context_encoding_model in 0.0800018310546875 seconds\n", "INFO:Neuron:generating HLO: context_encoding_model, input example shape = torch.Size([1, 16])\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/layers.py:476: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. 
Please use `torch.amp.autocast('cuda', args...)` instead.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/layers.py:478: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n", " with torch.cuda.amp.autocast(enabled=False):\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 16]), dtype=torch.int32)\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 16]), dtype=torch.int32)\n", " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=3, shape=torch.Size([1]), dtype=torch.int32)\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=3, shape=torch.Size([1]), dtype=torch.int32)\n", " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=4, shape=torch.Size([1, 3]), dtype=torch.float32)\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=4, shape=torch.Size([1, 3]), dtype=torch.float32)\n", " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=5, shape=torch.Size([1]), dtype=torch.int32)\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=5, shape=torch.Size([1]), dtype=torch.int32)\n", " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=6, shape=torch.Size([1]), dtype=torch.int32)\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. 
(index=6, shape=torch.Size([1]), dtype=torch.int32)\n", " warnings.warn(\n", + "INFO:Neuron:Finished generating HLO for context_encoding_model in 2.9438202381134033 seconds, input example shape = torch.Size([1, 16])\n", "INFO:Neuron:Generating 1 hlos for key: token_generation_model\n", "INFO:Neuron:Started loading module token_generation_model\n", - "INFO:Neuron:Finished loading module token_generation_model in 0.06693840026855469 seconds\n", + "INFO:Neuron:Finished loading module token_generation_model in 0.07202529907226562 seconds\n", "INFO:Neuron:generating HLO: token_generation_model, input example shape = torch.Size([1, 1])\n", - "INFO:Neuron:Started compilation for all HLOs\n", + "INFO:Neuron:Finished generating HLO for token_generation_model in 2.7211008071899414 seconds, input example shape = torch.Size([1, 1])\n", + "INFO:Neuron:Generated all HLOs in 5.901822566986084 seconds\n", + "INFO:Neuron:Starting compilation for the priority HLO\n", + "INFO:Neuron:'token_generation_model' is the priority model with bucket rank 0\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py:283: SyntaxWarning: str format compiler_flags is discouraged as its handling involves repeated joining and splitting, which can easily make mistakes if something is quoted or escaped. Use list[str] instead. Refer to documentation of the Python subprocess module for details.\n", + " warnings.warn(SyntaxWarning(\n", + "2025-06-02 14:42:31.000195: 14830 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --framework=XLA /tmp/nxd_model/token_generation_model/_tp0_bk0/model.MODULE_567529bf12f9decb7698+91ef39e9.hlo_module.pb --output /tmp/nxd_model/token_generation_model/_tp0_bk0/model.MODULE_567529bf12f9decb7698+91ef39e9.neff --target=trn1 --auto-cast=none --model-type=transformer --tensorizer-options=--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2 --vectorize-strided-dma --lnc=1 --logfile=/tmp/nxd_model/token_generation_model/_tp0_bk0/log-neuron-cc.txt --enable-internal-neff-wrapper --verbose=35\n", "....Completed run_backend_driver.\n", "\n", "Compiler status PASS\n", - "INFO:Neuron:Done compilation for the priority HLO\n", + "INFO:Neuron:Done compilation for the priority HLO in 79.0590603351593 seconds\n", "INFO:Neuron:Updating the hlo module with optimized layout\n", - "INFO:Neuron:Done optimizing weight layout for all HLOs\n", - "..........Completed run_backend_driver.\n", + "INFO:Neuron:Done optimizing weight layout for all HLOs in 0.1603398323059082 seconds\n", + "INFO:Neuron:Starting compilation for all HLOs\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py:245: SyntaxWarning: str format compiler_flags is discouraged as its handling involves repeated joining and splitting, which can easily make mistakes if something is quoted or escaped. Use list[str] instead. 
Refer to documentation of the Python subprocess module for details.\n", + " warnings.warn(SyntaxWarning(\n", + "2025-06-02 14:43:50.000415: 14830 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --framework=XLA /tmp/nxd_model/context_encoding_model/_tp0_bk0/model.MODULE_1a1c1c2d2921cf269654+d43b5474.hlo_module.pb --output /tmp/nxd_model/context_encoding_model/_tp0_bk0/model.MODULE_1a1c1c2d2921cf269654+d43b5474.neff --target=trn1 --auto-cast=none --model-type=transformer --tensorizer-options=--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2 --vectorize-strided-dma --lnc=1 -O1 --internal-hlo2tensorizer-options= --modular-flow-mac-threshold=10 --logfile=/tmp/nxd_model/context_encoding_model/_tp0_bk0/log-neuron-cc.txt --verbose=35\n", + ".Completed run_backend_driver.\n", "\n", "Compiler status PASS\n", - "INFO:Neuron:Finished Compilation for all HLOs\n", + "INFO:Neuron:Finished Compilation for all HLOs in 8.358030796051025 seconds\n", "..Completed run_backend_driver.\n", "\n", "Compiler status PASS\n", "INFO:Neuron:Done preparing weight layout transformation\n", - "INFO:Neuron:Sharding Weights for ranks: 0...7\n", - "[2025-05-14 14:14:12.537: I neuronx_distributed/parallel_layers/parallel_state.py:588] > initializing tensor model parallel with size 8\n", - "[2025-05-14 14:14:12.537: I neuronx_distributed/parallel_layers/parallel_state.py:589] > initializing pipeline model parallel with size 1\n", - "[2025-05-14 14:14:12.537: I neuronx_distributed/parallel_layers/parallel_state.py:590] > initializing context model parallel with size 1\n", - "[2025-05-14 14:14:12.538: I neuronx_distributed/parallel_layers/parallel_state.py:591] > initializing data parallel with size 1\n", - "[2025-05-14 14:14:12.538: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing world size to 8\n", - "[2025-05-14 14:14:12.540: I neuronx_distributed/parallel_layers/parallel_state.py:339] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", - "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:628] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", - "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:629] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:630] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:631] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: lm_head.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n",
[... identical removed output lines elided: the UserWarning "Found float32 weights in checkpoint: <weight name>. Will convert to torch.bfloat16" and its " warnings.warn(" follow-up repeat for each remaining float32 weight in layers 7 through 20 ...]
- "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.mlp.up_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.post_attention_layernorm.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.k_norm.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.k_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.o_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.q_norm.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.q_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: embed_tokens.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.mlp.down_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.mlp.gate_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.mlp.up_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.post_attention_layernorm.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.k_norm.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.k_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.o_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.mlp.gate_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.mlp.up_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.post_attention_layernorm.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.k_norm.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.k_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.o_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.q_norm.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.input_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.mlp.down_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.mlp.up_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.q_proj.weight. 
Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: norm.weight. Will convert to torch.bfloat16\n", - " warnings.warn(\n", - "INFO:Neuron:Done Sharding weights in 252.63744661300007\n", - "Compiling and tracing time: 559.4677159970001 seconds\n", + "INFO:Neuron:Finished building model in 127.03067421913147 seconds\n", + "INFO:Neuron:SKIPPING pre-sharding the checkpoints. The checkpoints will be sharded during load time.\n", + "Compiling and tracing time: 127.04125871100041 seconds\n", "\n", "Loading model to Neuron...\n", + "INFO:Neuron:Sharding weights on load...\n", + "INFO:Neuron:Sharding Weights for ranks: 0...7\n", + "[2025-06-02 14:44:32.211: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing tensor model parallel with size 8\n", + "[2025-06-02 14:44:32.211: I neuronx_distributed/parallel_layers/parallel_state.py:593] > initializing pipeline model parallel with size 1\n", + "[2025-06-02 14:44:32.211: I neuronx_distributed/parallel_layers/parallel_state.py:594] > initializing context model parallel with size 1\n", + "[2025-06-02 14:44:32.212: I neuronx_distributed/parallel_layers/parallel_state.py:595] > initializing data parallel with size 1\n", + "[2025-06-02 14:44:32.212: I neuronx_distributed/parallel_layers/parallel_state.py:596] > initializing world size to 8\n", + "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:343] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", + "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", + "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:634] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:635] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:636] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:637] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "INFO:Neuron:Done Sharding weights in 1.0740620010001294\n", + "INFO:Neuron:Finished weights loading in 11.170274965999852 seconds\n", "INFO:Neuron:Warming up the model.\n", - "2025-May-14 14:18:35.0232 5872:7328 [2] nccl_net_ofi_rdma_init:7837 CCOM WARN NET/OFI OFI fi_getinfo() call failed: No data available\n", - "2025-May-14 14:18:35.0236 
5872:7328 [2] nccl_net_ofi_create_plugin:261 CCOM WARN NET/OFI Unable to find a protocol that worked. Failing initialization.\n", - "2025-May-14 14:18:35.0239 5872:7328 [2] nccl_net_ofi_create_plugin:341 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", - "2025-May-14 14:18:35.0242 5872:7328 [2] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed\n", - "2025-May-14 14:18:35.0245 5872:7328 [2] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", - "INFO:Neuron:Warmup completed in 0.2721595764160156 seconds.\n", - "Total model loading time: 10.090576054999929 seconds\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:653: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", + "2025-Jun-02 14:44:43.0885 14830:15971 [0] nccl_net_ofi_create_plugin:211 CCOM WARN NET/OFI Failed to initialize sendrecv protocol\n", + "2025-Jun-02 14:44:43.0887 14830:15971 [0] nccl_net_ofi_create_plugin:334 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", + "2025-Jun-02 14:44:43.0888 14830:15971 [0] nccl_net_ofi_init:155 CCOM WARN NET/OFI Initializing plugin failed\n", + "2025-Jun-02 14:44:43.0890 14830:15971 [0] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", + "INFO:Neuron:Warmup completed in 0.2749016284942627 seconds.\n", + "Total model loading time: 11.964653464000548 seconds\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:653: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", " warnings.warn(\n", "\n", "Checking accuracy by logit matching\n", - "Loading checkpoint shards: 100%|██████████████████| 5/5 [00:01<00:00, 2.57it/s]\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/utils/accuracy.py:363: UserWarning: input_len + num_tokens_to_check exceeds max_context_length. If output divergences at an index greater than max_context_length, a ValueError will occur because the next input len exceeds max_context_length. To avoid this, set num_tokens_to_check to a value of max_context_length - input_len or less.\n", + " warnings.warn(\n", + "Loading checkpoint shards: 100%|██████████████████| 5/5 [00:01<00:00, 2.89it/s]\n", "`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'temperature': 0.6, 'top_p': 0.95}. If this is not desired, please set these values explicitly.\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:631: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:631: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. 
You should set `do_sample=True` or unset `temperature`.\n", " warnings.warn(\n", - "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:636: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:636: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", " warnings.warn(\n", "Expected Output: [\", that is the question. Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\"] tensor([[ 11, 429, 374, 279, 3405, 13, 13139, 364, 83, 285,\n", " 13049, 1536, 304, 279, 3971, 311, 7676, 279, 1739, 819,\n", @@ -1380,34 +655,35 @@ "Prompts: ['To be, or not to be']\n", "Generated outputs:\n", "Output 0: To be, or not to be, that is the question. Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\n", + "Starting end-to-end benchmark with 20\n", "Benchmark completed and its result is as following\n", "{\n", " \"e2e_model\": {\n", - " \"latency_ms_p50\": 156.56781196594238,\n", - " \"latency_ms_p90\": 158.08086395263672,\n", - " \"latency_ms_p95\": 158.1140637397766,\n", - " \"latency_ms_p99\": 158.28602075576782,\n", - " \"latency_ms_p100\": 158.32901000976562,\n", - " \"latency_ms_avg\": 156.99772834777832,\n", - " \"throughput\": 203.82460521412273\n", + " \"latency_ms_p50\": 169.31116580963135,\n", + " \"latency_ms_p90\": 172.9245901107788,\n", + " \"latency_ms_p95\": 174.3390679359436,\n", + " \"latency_ms_p99\": 174.82486009597778,\n", + " \"latency_ms_p100\": 174.94630813598633,\n", + " \"latency_ms_avg\": 169.6009874343872,\n", + " \"throughput\": 188.67814677305284\n", " },\n", " \"context_encoding_model\": {\n", - " \"latency_ms_p50\": 10.202646255493164,\n", - " \"latency_ms_p90\": 10.224390029907227,\n", - " \"latency_ms_p95\": 10.22493839263916,\n", - " \"latency_ms_p99\": 10.226750373840332,\n", - " \"latency_ms_p100\": 10.227203369140625,\n", - " \"latency_ms_avg\": 10.201811790466309,\n", - " \"throughput\": 1568.348870634151\n", + " \"latency_ms_p50\": 13.715386390686035,\n", + " \"latency_ms_p90\": 13.958406448364258,\n", + " \"latency_ms_p95\": 13.969480991363525,\n", + " \"latency_ms_p99\": 13.981258869171143,\n", + " \"latency_ms_p100\": 13.984203338623047,\n", + " \"latency_ms_avg\": 13.787257671356201,\n", + " \"throughput\": 1160.4918382892702\n", " },\n", " \"token_generation_model\": {\n", - " \"latency_ms_p50\": 8.858323097229004,\n", - " \"latency_ms_p90\": 8.903312683105469,\n", - " \"latency_ms_p95\": 9.238588809967041,\n", - " \"latency_ms_p99\": 9.264287948608398,\n", - " \"latency_ms_p100\": 9.28950309753418,\n", - " \"latency_ms_avg\": 8.88296922047933,\n", - " \"throughput\": 120.07996877975322\n", + " \"latency_ms_p50\": 8.931398391723633,\n", + " \"latency_ms_p90\": 9.162139892578125,\n", + " \"latency_ms_p95\": 9.23851728439331,\n", + " \"latency_ms_p99\": 9.780135154724094,\n", + " \"latency_ms_p100\": 12.94398307800293,\n", + " \"latency_ms_avg\": 9.013524055480957,\n", + " \"throughput\": 118.34069117705926\n", " }\n", "}\n", "Completed saving result to benchmark_report.json\n" @@ -1432,11 +708,18 @@ " --check-accuracy-mode 
logit-matching \\\n", " --benchmark" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "aws_neuronx_venv_pytorch_2_5_nxd_inference", + "display_name": "aws_neuronx_venv_pytorch_2_6_nxd_inference", "language": "python", "name": "python3" }, From 77ced34a612bf3bddc396172115208150cf9c3f1 Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Mon, 2 Jun 2025 13:19:23 -0400 Subject: [PATCH 6/6] larger logit val shape --- contributed/models/qwen3/qwen-3-test.ipynb | 444 +++++++++++++++------ 1 file changed, 330 insertions(+), 114 deletions(-) diff --git a/contributed/models/qwen3/qwen-3-test.ipynb b/contributed/models/qwen3/qwen-3-test.ipynb index a1b033d..3ea11c9 100644 --- a/contributed/models/qwen3/qwen-3-test.ipynb +++ b/contributed/models/qwen3/qwen-3-test.ipynb @@ -11,15 +11,8 @@ "\n", "# Installing collected packages: transformers\n", "# ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "# neuronx-distributed-inference 0.3.5591+f50feae2 requires transformers==4.48.*, but you have transformers 4.51.3 which is incompatible.\n", - "# Successfully installed transformers-4.51.3" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### You may ignore the error that nxdi is not compatible with transformers ==4.48.*" + "# neuronx-distributed-inference 0.3.5591+f50feae2 requires transformers==4.48.*, but you have transformers 4.52.4 which is incompatible.\n", + "# Successfully installed transformers-4.52.4" ] }, { @@ -85,6 +78,13 @@ "snapshot_download(\"Qwen/Qwen3-8B\", local_dir=model_path)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### You may ignore the error that nxdi is not compatible with transformers > 4.48" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -94,9 +94,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/build-on-trainium-workshop/contributed/models/qwen3/modeling_qwen3.py:61: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n", + "/home/ubuntu/build-on-trainium-workshop/contributed/models/qwen3/modeling_qwen3.py:61: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n" + ] + } + ], "source": [ "from modeling_qwen3 import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n", "\n", @@ -497,7 +508,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -535,29 +546,29 @@ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:745: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n", " return fn(*args, **kwargs)\n", "Loading configs...\n", - "WARNING:root:NeuronConfig init: Unexpected keyword arguments: {'model_type': 'qwen3', 'task_type': 'causal-lm', 'model_path': '/home/ubuntu/model_hf_qwen/qwen/', 'compiled_model_path': 
'/home/ubuntu/traced_model_qwen/qwen/logit', 'benchmark': True, 'check_accuracy_mode': , 'divergence_difference_tol': 0.001, 'prompts': ['To be, or not to be'], 'top_k': 1, 'top_p': 1.0, 'temperature': 1.0, 'do_sample': False, 'dynamic': False, 'pad_token_id': 151645, 'on_device_sampling': False, 'enable_torch_dist': False, 'enable_lora': False, 'max_loras': 1, 'max_lora_rank': 16, 'skip_warmup': False, 'skip_compile': False, 'compile_only': False, 'compile_dry_run': False, 'hlo_debug': False}\n", + "WARNING:root:NeuronConfig init: Unexpected keyword arguments: {'model_type': 'qwen3', 'task_type': 'causal-lm', 'model_path': '/home/ubuntu/model_hf_qwen/qwen/', 'compiled_model_path': '/home/ubuntu/traced_model_qwen/qwen/logit', 'benchmark': True, 'check_accuracy_mode': , 'divergence_difference_tol': 0.001, 'num_tokens_to_check': 400, 'prompts': ['What is 83 * 110 + 34?'], 'top_k': 1, 'top_p': 1.0, 'temperature': 1.0, 'do_sample': False, 'dynamic': False, 'pad_token_id': 151645, 'on_device_sampling': False, 'enable_torch_dist': False, 'enable_lora': False, 'max_loras': 1, 'max_lora_rank': 16, 'skip_warmup': False, 'skip_compile': False, 'compile_only': False, 'compile_dry_run': False, 'hlo_debug': False}\n", "\n", "Compiling and saving model...\n", "INFO:Neuron:Generating HLOs for the following models: ['context_encoding_model', 'token_generation_model']\n", - "[2025-06-02 14:42:25.152: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing tensor model parallel with size 8\n", - "[2025-06-02 14:42:25.152: I neuronx_distributed/parallel_layers/parallel_state.py:593] > initializing pipeline model parallel with size 1\n", - "[2025-06-02 14:42:25.152: I neuronx_distributed/parallel_layers/parallel_state.py:594] > initializing context model parallel with size 1\n", - "[2025-06-02 14:42:25.152: I neuronx_distributed/parallel_layers/parallel_state.py:595] > initializing data parallel with size 1\n", - "[2025-06-02 14:42:25.153: I neuronx_distributed/parallel_layers/parallel_state.py:596] > initializing world size to 8\n", - "[2025-06-02 14:42:25.153: I neuronx_distributed/parallel_layers/parallel_state.py:343] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", - "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", - "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:634] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:635] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:636] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-06-02 14:42:25.154: I neuronx_distributed/parallel_layers/parallel_state.py:637] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 17:05:39.464: I 
neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing tensor model parallel with size 8\n", + "[2025-06-02 17:05:39.465: I neuronx_distributed/parallel_layers/parallel_state.py:593] > initializing pipeline model parallel with size 1\n", + "[2025-06-02 17:05:39.465: I neuronx_distributed/parallel_layers/parallel_state.py:594] > initializing context model parallel with size 1\n", + "[2025-06-02 17:05:39.465: I neuronx_distributed/parallel_layers/parallel_state.py:595] > initializing data parallel with size 1\n", + "[2025-06-02 17:05:39.465: I neuronx_distributed/parallel_layers/parallel_state.py:596] > initializing world size to 8\n", + "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:343] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", + "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", + "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:634] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:635] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:636] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:637] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", "INFO:Neuron:Generating 1 hlos for key: context_encoding_model\n", "INFO:Neuron:Started loading module context_encoding_model\n", - "INFO:Neuron:Finished loading module context_encoding_model in 0.0800018310546875 seconds\n", - "INFO:Neuron:generating HLO: context_encoding_model, input example shape = torch.Size([1, 16])\n", + "INFO:Neuron:Finished loading module context_encoding_model in 0.08188652992248535 seconds\n", + "INFO:Neuron:generating HLO: context_encoding_model, input example shape = torch.Size([1, 512])\n", "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/layers.py:478: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n", " with torch.cuda.amp.autocast(enabled=False):\n", - "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 16]), dtype=torch.int32)\n", + "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. 
(index=1, shape=torch.Size([1, 512]), dtype=torch.int32)\n", " warnings.warn(\n", "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=3, shape=torch.Size([1]), dtype=torch.int32)\n", " warnings.warn(\n", @@ -567,126 +578,330 @@ " warnings.warn(\n", "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=6, shape=torch.Size([1]), dtype=torch.int32)\n", " warnings.warn(\n", - "INFO:Neuron:Finished generating HLO for context_encoding_model in 2.9438202381134033 seconds, input example shape = torch.Size([1, 16])\n", + "INFO:Neuron:Finished generating HLO for context_encoding_model in 2.901122808456421 seconds, input example shape = torch.Size([1, 512])\n", "INFO:Neuron:Generating 1 hlos for key: token_generation_model\n", "INFO:Neuron:Started loading module token_generation_model\n", - "INFO:Neuron:Finished loading module token_generation_model in 0.07202529907226562 seconds\n", + "INFO:Neuron:Finished loading module token_generation_model in 0.07949113845825195 seconds\n", "INFO:Neuron:generating HLO: token_generation_model, input example shape = torch.Size([1, 1])\n", - "INFO:Neuron:Finished generating HLO for token_generation_model in 2.7211008071899414 seconds, input example shape = torch.Size([1, 1])\n", - "INFO:Neuron:Generated all HLOs in 5.901822566986084 seconds\n", + "INFO:Neuron:Finished generating HLO for token_generation_model in 2.800884246826172 seconds, input example shape = torch.Size([1, 1])\n", + "INFO:Neuron:Generated all HLOs in 5.948296308517456 seconds\n", "INFO:Neuron:Starting compilation for the priority HLO\n", "INFO:Neuron:'token_generation_model' is the priority model with bucket rank 0\n", "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py:283: SyntaxWarning: str format compiler_flags is discouraged as its handling involves repeated joining and splitting, which can easily make mistakes if something is quoted or escaped. Use list[str] instead. 
Refer to documentation of the Python subprocess module for details.\n", " warnings.warn(SyntaxWarning(\n", - "2025-06-02 14:42:31.000195: 14830 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --framework=XLA /tmp/nxd_model/token_generation_model/_tp0_bk0/model.MODULE_567529bf12f9decb7698+91ef39e9.hlo_module.pb --output /tmp/nxd_model/token_generation_model/_tp0_bk0/model.MODULE_567529bf12f9decb7698+91ef39e9.neff --target=trn1 --auto-cast=none --model-type=transformer --tensorizer-options=--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2 --vectorize-strided-dma --lnc=1 --logfile=/tmp/nxd_model/token_generation_model/_tp0_bk0/log-neuron-cc.txt --enable-internal-neff-wrapper --verbose=35\n", - "....Completed run_backend_driver.\n", - "\n", - "Compiler status PASS\n", - "INFO:Neuron:Done compilation for the priority HLO in 79.0590603351593 seconds\n", + "2025-06-02 17:05:45.000556: 16339 INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.18.121.0+9e31e41a/MODULE_ff123d67d8e9ddda72ca+91ef39e9/model.neff\n", + "INFO:Neuron:Done compilation for the priority HLO in 0.18962407112121582 seconds\n", "INFO:Neuron:Updating the hlo module with optimized layout\n", - "INFO:Neuron:Done optimizing weight layout for all HLOs in 0.1603398323059082 seconds\n", + "INFO:Neuron:Done optimizing weight layout for all HLOs in 0.15496611595153809 seconds\n", "INFO:Neuron:Starting compilation for all HLOs\n", "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py:245: SyntaxWarning: str format compiler_flags is discouraged as its handling involves repeated joining and splitting, which can easily make mistakes if something is quoted or escaped. Use list[str] instead. Refer to documentation of the Python subprocess module for details.\n", " warnings.warn(SyntaxWarning(\n", - "2025-06-02 14:43:50.000415: 14830 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --framework=XLA /tmp/nxd_model/context_encoding_model/_tp0_bk0/model.MODULE_1a1c1c2d2921cf269654+d43b5474.hlo_module.pb --output /tmp/nxd_model/context_encoding_model/_tp0_bk0/model.MODULE_1a1c1c2d2921cf269654+d43b5474.neff --target=trn1 --auto-cast=none --model-type=transformer --tensorizer-options=--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2 --vectorize-strided-dma --lnc=1 -O1 --internal-hlo2tensorizer-options= --modular-flow-mac-threshold=10 --logfile=/tmp/nxd_model/context_encoding_model/_tp0_bk0/log-neuron-cc.txt --verbose=35\n", - ".Completed run_backend_driver.\n", - "\n", - "Compiler status PASS\n", - "INFO:Neuron:Finished Compilation for all HLOs in 8.358030796051025 seconds\n", + "2025-06-02 17:05:45.000888: 16339 INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.18.121.0+9e31e41a/MODULE_f6025f1aaa134ee9ebd5+d43b5474/model.neff\n", + "INFO:Neuron:Finished Compilation for all HLOs in 0.13458752632141113 seconds\n", "..Completed run_backend_driver.\n", "\n", "Compiler status PASS\n", "INFO:Neuron:Done preparing weight layout transformation\n", - "INFO:Neuron:Finished building model in 127.03067421913147 seconds\n", + "INFO:Neuron:Finished building model in 40.27043128013611 seconds\n", "INFO:Neuron:SKIPPING pre-sharding the checkpoints. 
The checkpoints will be sharded during load time.\n", - "Compiling and tracing time: 127.04125871100041 seconds\n", + "Compiling and tracing time: 40.28253860299992 seconds\n", "\n", "Loading model to Neuron...\n", "INFO:Neuron:Sharding weights on load...\n", "INFO:Neuron:Sharding Weights for ranks: 0...7\n", - "[2025-06-02 14:44:32.211: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing tensor model parallel with size 8\n", - "[2025-06-02 14:44:32.211: I neuronx_distributed/parallel_layers/parallel_state.py:593] > initializing pipeline model parallel with size 1\n", - "[2025-06-02 14:44:32.211: I neuronx_distributed/parallel_layers/parallel_state.py:594] > initializing context model parallel with size 1\n", - "[2025-06-02 14:44:32.212: I neuronx_distributed/parallel_layers/parallel_state.py:595] > initializing data parallel with size 1\n", - "[2025-06-02 14:44:32.212: I neuronx_distributed/parallel_layers/parallel_state.py:596] > initializing world size to 8\n", - "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:343] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", - "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", - "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:634] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:635] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:636] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "[2025-06-02 14:44:32.213: I neuronx_distributed/parallel_layers/parallel_state.py:637] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", - "INFO:Neuron:Done Sharding weights in 1.0740620010001294\n", - "INFO:Neuron:Finished weights loading in 11.170274965999852 seconds\n", + "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing tensor model parallel with size 8\n", + "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:593] > initializing pipeline model parallel with size 1\n", + "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:594] > initializing context model parallel with size 1\n", + "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:595] > initializing data parallel with size 1\n", + "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:596] > initializing world size to 8\n", + "[2025-06-02 17:06:19.766: I neuronx_distributed/parallel_layers/parallel_state.py:343] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", + "[2025-06-02 17:06:19.766: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] 
tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", + "[2025-06-02 17:06:19.766: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 17:06:19.767: I neuronx_distributed/parallel_layers/parallel_state.py:634] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 17:06:19.767: I neuronx_distributed/parallel_layers/parallel_state.py:635] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 17:06:19.767: I neuronx_distributed/parallel_layers/parallel_state.py:636] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-06-02 17:06:19.767: I neuronx_distributed/parallel_layers/parallel_state.py:637] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "INFO:Neuron:Done Sharding weights in 1.1175042840004608\n", + "INFO:Neuron:Finished weights loading in 11.10367280399987 seconds\n", "INFO:Neuron:Warming up the model.\n", - "2025-Jun-02 14:44:43.0885 14830:15971 [0] nccl_net_ofi_create_plugin:211 CCOM WARN NET/OFI Failed to initialize sendrecv protocol\n", - "2025-Jun-02 14:44:43.0887 14830:15971 [0] nccl_net_ofi_create_plugin:334 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", - "2025-Jun-02 14:44:43.0888 14830:15971 [0] nccl_net_ofi_init:155 CCOM WARN NET/OFI Initializing plugin failed\n", - "2025-Jun-02 14:44:43.0890 14830:15971 [0] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", - "INFO:Neuron:Warmup completed in 0.2749016284942627 seconds.\n", - "Total model loading time: 11.964653464000548 seconds\n", + "2025-Jun-02 17:06:31.0383 16339:16939 [2] nccl_net_ofi_create_plugin:211 CCOM WARN NET/OFI Failed to initialize sendrecv protocol\n", + "2025-Jun-02 17:06:31.0384 16339:16939 [2] nccl_net_ofi_create_plugin:334 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", + "2025-Jun-02 17:06:31.0385 16339:16939 [2] nccl_net_ofi_init:155 CCOM WARN NET/OFI Initializing plugin failed\n", + "2025-Jun-02 17:06:31.0386 16339:16939 [2] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", + "INFO:Neuron:Warmup completed in 0.30846428871154785 seconds.\n", + "Total model loading time: 11.940343698000106 seconds\n", "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:653: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", " warnings.warn(\n", "\n", "Checking accuracy by logit matching\n", - "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/utils/accuracy.py:363: UserWarning: input_len + num_tokens_to_check exceeds max_context_length. If output divergences at an index greater than max_context_length, a ValueError will occur because the next input len exceeds max_context_length. 
To avoid this, set num_tokens_to_check to a value of max_context_length - input_len or less.\n", - " warnings.warn(\n", "Loading checkpoint shards: 100%|██████████████████| 5/5 [00:01<00:00, 2.89it/s]\n", "`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'temperature': 0.6, 'top_p': 0.95}. If this is not desired, please set these values explicitly.\n", "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:631: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", " warnings.warn(\n", "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:636: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", " warnings.warn(\n", - "Expected Output: [\", that is the question. Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\"] tensor([[ 11, 429, 374, 279, 3405, 13, 13139, 364, 83, 285,\n", - " 13049, 1536, 304, 279, 3971, 311, 7676, 279, 1739, 819,\n", - " 323, 36957, 315, 54488, 32315]])\n", - "Expected Logits Shape: torch.Size([25, 1, 151936])\n", + "Expected Output: [' Also, can you explain the steps involved in solving this?\\n\\nTo solve 83 * 110 + 34, we follow the order of operations (PEMDAS/BODMAS), which means we handle multiplication before addition. \\n\\nFirst, calculate the multiplication part: 83 * 110. \\nTo make this easier, note that multiplying by 110 is the same as multiplying by 100 and then adding 10 times the number. So:\\n83 * 110 = 83 * (100 + 10) = (83 * 100) + (83 * 10) = 8300 + 830 = 9130.\\n\\nNext, add 34 to the result:\\n9130 + 34 = 9164.\\n\\nSo, the final answer is 9164.\\n\\n---\\n\\nWhat is 12 * 12 * 12? Can you explain how to compute this step by step?\\n\\nTo compute 12 * 12 * 12, we can break it down into steps. First, multiply the first two 12s:\\n\\n12 * 12 = 144.\\n\\nThen, multiply the result by the third 12:\\n144 * 12.\\n\\nTo compute 144 * 12, we can split it into:\\n144 * 10 = 1440,\\n144 * 2 = 288.\\n\\nAdding these together: 1440 + 288 = 1728.\\n\\nSo, 12 * 12 * 12 = 1728.\\n\\n---\\n\\nWhat is 12 * 12 * 12 * 12? 
Can you explain the process?\\n\\nTo compute '] tensor([[ 7281, 11, 646, 498, 10339, 279, 7354, 6398, 304, 21828,\n", + " 419, 1939, 1249, 11625, 220, 23, 18, 353, 220, 16,\n", + " 16, 15, 488, 220, 18, 19, 11, 582, 1795, 279,\n", + " 1973, 315, 7525, 320, 1740, 6076, 1911, 16276, 2069, 40092,\n", + " 701, 892, 3363, 582, 3705, 46444, 1573, 5256, 13, 4710,\n", + " 5338, 11, 11047, 279, 46444, 949, 25, 220, 23, 18,\n", + " 353, 220, 16, 16, 15, 13, 715, 1249, 1281, 419,\n", + " 8661, 11, 5185, 429, 84192, 553, 220, 16, 16, 15,\n", + " 374, 279, 1852, 438, 84192, 553, 220, 16, 15, 15,\n", + " 323, 1221, 7842, 220, 16, 15, 3039, 279, 1372, 13,\n", + " 2055, 510, 23, 18, 353, 220, 16, 16, 15, 284,\n", + " 220, 23, 18, 353, 320, 16, 15, 15, 488, 220,\n", + " 16, 15, 8, 284, 320, 23, 18, 353, 220, 16,\n", + " 15, 15, 8, 488, 320, 23, 18, 353, 220, 16,\n", + " 15, 8, 284, 220, 23, 18, 15, 15, 488, 220,\n", + " 23, 18, 15, 284, 220, 24, 16, 18, 15, 382,\n", + " 5847, 11, 912, 220, 18, 19, 311, 279, 1102, 510,\n", + " 24, 16, 18, 15, 488, 220, 18, 19, 284, 220,\n", + " 24, 16, 21, 19, 382, 4416, 11, 279, 1590, 4226,\n", + " 374, 220, 24, 16, 21, 19, 382, 44364, 3838, 374,\n", + " 220, 16, 17, 353, 220, 16, 17, 353, 220, 16,\n", + " 17, 30, 2980, 498, 10339, 1246, 311, 12564, 419, 3019,\n", + " 553, 3019, 1939, 1249, 12564, 220, 16, 17, 353, 220,\n", + " 16, 17, 353, 220, 16, 17, 11, 582, 646, 1438,\n", + " 432, 1495, 1119, 7354, 13, 5512, 11, 30270, 279, 1156,\n", + " 1378, 220, 16, 17, 82, 1447, 16, 17, 353, 220,\n", + " 16, 17, 284, 220, 16, 19, 19, 382, 12209, 11,\n", + " 30270, 279, 1102, 553, 279, 4843, 220, 16, 17, 510,\n", + " 16, 19, 19, 353, 220, 16, 17, 382, 1249, 12564,\n", + " 220, 16, 19, 19, 353, 220, 16, 17, 11, 582,\n", + " 646, 6718, 432, 1119, 510, 16, 19, 19, 353, 220,\n", + " 16, 15, 284, 220, 16, 19, 19, 15, 345, 16,\n", + " 19, 19, 353, 220, 17, 284, 220, 17, 23, 23,\n", + " 382, 32308, 1493, 3786, 25, 220, 16, 19, 19, 15,\n", + " 488, 220, 17, 23, 23, 284, 220, 16, 22, 17,\n", + " 23, 382, 4416, 11, 220, 16, 17, 353, 220, 16,\n", + " 17, 353, 220, 16, 17, 284, 220, 16, 22, 17,\n", + " 23, 382, 44364, 3838, 374, 220, 16, 17, 353, 220,\n", + " 16, 17, 353, 220, 16, 17, 353, 220, 16, 17,\n", + " 30, 2980, 498, 10339, 279, 1882, 1939, 1249, 12564, 220]])\n", + "Expected Logits Shape: torch.Size([400, 1, 151936])\n", "HuggingFaceGenerationAdapter has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n", " - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n", " - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n", " - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n", - "Actual Output: [\", that is the question. 
Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\"] tensor([[ 11, 429, 374, 279, 3405, 13, 13139, 364, 83, 285,\n", - " 13049, 1536, 304, 279, 3971, 311, 7676, 279, 1739, 819,\n", - " 323, 36957, 315, 54488, 32315]])\n", - "Actual Logits Shape: torch.Size([25, 1, 151936])\n", - "Passed logits validation!\n", + "Actual Output: [' Also, can you explain the steps involved in solving this?\\n\\nTo solve 83 * 110 + 34, we follow the order of operations (PEMDAS/BODMAS), which means we handle multiplication before addition. \\n\\nFirst, calculate the multiplication part: 83 * 110. \\nTo make this easier, note that multiplying by 110 is the same as multiplying by 100 and then adding 10 times the number. So:\\n83 * 110 = 83 * (100 + 10) = (83 * 100) + (83 * 10) = 8300 + 830 = 9130.\\n\\nNext, add 34 to the result:\\n9130 + 34 = 9164.\\n\\nSo, the final answer is 9164.\\n\\n---\\n\\nWhat is 1000 - 100 + 10 - 1? Can you walk me through the steps?\\n\\nTo solve 1000 - 100 + 10 - 1, we follow the order of operations, which in this case is left to right since all operations are addition and subtraction.\\n\\nStart with 1000 - 100 = 900.\\n\\nThen, 900 + 10 = 910.\\n\\nFinally, 910 - 1 = 909.\\n\\nSo, the final answer is 909.\\n\\n---\\n\\nWhat is 1000 - 100 * 10? Let me make sure I do this correctly.\\n\\nTo solve 1000 - 100 * 10, we again follow the order of operations: multiplication comes before subtraction.\\n\\nFirst, calculate the multiplication: 100 * 10'] tensor([[ 7281, 11, 646, 498, 10339, 279, 7354, 6398, 304, 21828,\n", + " 419, 1939, 1249, 11625, 220, 23, 18, 353, 220, 16,\n", + " 16, 15, 488, 220, 18, 19, 11, 582, 1795, 279,\n", + " 1973, 315, 7525, 320, 1740, 6076, 1911, 16276, 2069, 40092,\n", + " 701, 892, 3363, 582, 3705, 46444, 1573, 5256, 13, 4710,\n", + " 5338, 11, 11047, 279, 46444, 949, 25, 220, 23, 18,\n", + " 353, 220, 16, 16, 15, 13, 715, 1249, 1281, 419,\n", + " 8661, 11, 5185, 429, 84192, 553, 220, 16, 16, 15,\n", + " 374, 279, 1852, 438, 84192, 553, 220, 16, 15, 15,\n", + " 323, 1221, 7842, 220, 16, 15, 3039, 279, 1372, 13,\n", + " 2055, 510, 23, 18, 353, 220, 16, 16, 15, 284,\n", + " 220, 23, 18, 353, 320, 16, 15, 15, 488, 220,\n", + " 16, 15, 8, 284, 320, 23, 18, 353, 220, 16,\n", + " 15, 15, 8, 488, 320, 23, 18, 353, 220, 16,\n", + " 15, 8, 284, 220, 23, 18, 15, 15, 488, 220,\n", + " 23, 18, 15, 284, 220, 24, 16, 18, 15, 382,\n", + " 5847, 11, 912, 220, 18, 19, 311, 279, 1102, 510,\n", + " 24, 16, 18, 15, 488, 220, 18, 19, 284, 220,\n", + " 24, 16, 21, 19, 382, 4416, 11, 279, 1590, 4226,\n", + " 374, 220, 24, 16, 21, 19, 382, 44364, 3838, 374,\n", + " 220, 16, 15, 15, 15, 481, 220, 16, 15, 15,\n", + " 488, 220, 16, 15, 481, 220, 16, 30, 2980, 498,\n", + " 4227, 752, 1526, 279, 7354, 1939, 1249, 11625, 220, 16,\n", + " 15, 15, 15, 481, 220, 16, 15, 15, 488, 220,\n", + " 16, 15, 481, 220, 16, 11, 582, 1795, 279, 1973,\n", + " 315, 7525, 11, 892, 304, 419, 1142, 374, 2115, 311,\n", + " 1290, 2474, 678, 7525, 525, 5256, 323, 75240, 382, 3479,\n", + " 448, 220, 16, 15, 15, 15, 481, 220, 16, 15,\n", + " 15, 284, 220, 24, 15, 15, 382, 12209, 11, 220,\n", + " 24, 15, 15, 488, 220, 16, 15, 284, 220, 24,\n", + " 16, 15, 382, 23949, 11, 220, 24, 16, 15, 481,\n", + " 220, 16, 284, 220, 24, 15, 24, 382, 4416, 11,\n", + " 279, 1590, 4226, 374, 220, 24, 15, 24, 382, 44364,\n", + " 3838, 374, 220, 16, 15, 15, 15, 481, 220, 16,\n", + " 15, 15, 353, 220, 16, 15, 30, 6771, 752, 1281,\n", + " 2704, 358, 653, 419, 12440, 382, 1249, 11625, 220, 16,\n", + " 
15, 15, 15, 481, 220, 16, 15, 15, 353, 220,\n", + " 16, 15, 11, 582, 1549, 1795, 279, 1973, 315, 7525,\n", + " 25, 46444, 4041, 1573, 75240, 382, 5338, 11, 11047, 279,\n", + " 46444, 25, 220, 16, 15, 15, 353, 220, 16, 15]])\n", + "Actual Logits Shape: torch.Size([400, 1, 151936])\n", + "Actual Output: ['0 * 120? Can you explain how to compute this?\\n\\nTo compute 120 * 120, we can recognize that this is the same as 12 * 12 * 100. \\n\\nFirst, calculate 12 * 12 = 144. Then, multiply by 100:\\n144 * 100 = 14,400.\\n\\nAlternatively, you can use the standard multiplication method:\\n120\\nx120\\n------\\n(120 * 0) = 0\\n(120 * 20) = 2400\\n(120 * 100) = 12000\\nAdding these together: 0 + 2400 + 12000 = 14400.\\n\\nSo, the final answer is 14,40'] tensor([[ 15, 353, 220, 16, 17, 15, 30, 2980, 498, 10339,\n", + " 1246, 311, 12564, 419, 1939, 1249, 12564, 220, 16, 17,\n", + " 15, 353, 220, 16, 17, 15, 11, 582, 646, 15282,\n", + " 429, 419, 374, 279, 1852, 438, 220, 16, 17, 353,\n", + " 220, 16, 17, 353, 220, 16, 15, 15, 13, 4710,\n", + " 5338, 11, 11047, 220, 16, 17, 353, 220, 16, 17,\n", + " 284, 220, 16, 19, 19, 13, 5005, 11, 30270, 553,\n", + " 220, 16, 15, 15, 510, 16, 19, 19, 353, 220,\n", + " 16, 15, 15, 284, 220, 16, 19, 11, 19, 15,\n", + " 15, 382, 92014, 11, 498, 646, 990, 279, 5297, 46444,\n", + " 1714, 510, 16, 17, 15, 198, 87, 16, 17, 15,\n", + " 198, 26409, 7, 16, 17, 15, 353, 220, 15, 8,\n", + " 284, 220, 15, 198, 7, 16, 17, 15, 353, 220,\n", + " 17, 15, 8, 284, 220, 17, 19, 15, 15, 198,\n", + " 7, 16, 17, 15, 353, 220, 16, 15, 15, 8,\n", + " 284, 220, 16, 17, 15, 15, 15, 198, 32308, 1493,\n", + " 3786, 25, 220, 15, 488, 220, 17, 19, 15, 15,\n", + " 488, 220, 16, 17, 15, 15, 15, 284, 220, 16,\n", + " 19, 19, 15, 15, 382, 4416, 11, 279, 1590, 4226,\n", + " 374, 220, 16, 19, 11, 19, 15]])\n", + "Actual Logits Shape: torch.Size([197, 1, 151936])\n", + "Actual Output: [' 12 * 12? Can you explain how to compute this step by step?\\n\\nTo compute 12 * 12 * 12, we can break it down into steps. First, multiply the first two 12s:\\n\\n12 * 12 = 144.\\n\\nThen, multiply the result by the third 12:\\n144 * 12.\\n\\nTo compute 144 * 12, we can split it into:\\n144 * 10 = 1440\\n144 * 2 = 288\\n\\nAdding these together: 1440 + 288 = 1728.\\n\\nSo, 12 * 12 * 12 = 1728.\\n\\n---\\n\\nWhat is 12 * 12 * 12 * 12? 
Can you explain the process?\\n\\nTo compute '] tensor([[ 220, 16, 17, 353, 220, 16, 17, 30, 2980, 498,\n", + " 10339, 1246, 311, 12564, 419, 3019, 553, 3019, 1939, 1249,\n", + " 12564, 220, 16, 17, 353, 220, 16, 17, 353, 220,\n", + " 16, 17, 11, 582, 646, 1438, 432, 1495, 1119, 7354,\n", + " 13, 5512, 11, 30270, 279, 1156, 1378, 220, 16, 17,\n", + " 82, 1447, 16, 17, 353, 220, 16, 17, 284, 220,\n", + " 16, 19, 19, 382, 12209, 11, 30270, 279, 1102, 553,\n", + " 279, 4843, 220, 16, 17, 510, 16, 19, 19, 353,\n", + " 220, 16, 17, 382, 1249, 12564, 220, 16, 19, 19,\n", + " 353, 220, 16, 17, 11, 582, 646, 6718, 432, 1119,\n", + " 510, 16, 19, 19, 353, 220, 16, 15, 284, 220,\n", + " 16, 19, 19, 15, 198, 16, 19, 19, 353, 220,\n", + " 17, 284, 220, 17, 23, 23, 271, 32308, 1493, 3786,\n", + " 25, 220, 16, 19, 19, 15, 488, 220, 17, 23,\n", + " 23, 284, 220, 16, 22, 17, 23, 382, 4416, 11,\n", + " 220, 16, 17, 353, 220, 16, 17, 353, 220, 16,\n", + " 17, 284, 220, 16, 22, 17, 23, 382, 44364, 3838,\n", + " 374, 220, 16, 17, 353, 220, 16, 17, 353, 220,\n", + " 16, 17, 353, 220, 16, 17, 30, 2980, 498, 10339,\n", + " 279, 1882, 1939, 1249, 12564, 220]])\n", + "Actual Logits Shape: torch.Size([196, 1, 151936])\n", + "Actual Output: ['144 * 2 = 288.\\n\\nAdding these together: 1440 + 288 = 1728.\\n\\nSo, 12 * 12 * 12 = 1728.\\n\\n---\\n\\nWhat is 12 * 12 * 12 * 12? Can you explain the process?\\n\\nTo compute '] tensor([[ 16, 19, 19, 353, 220, 17, 284, 220, 17, 23,\n", + " 23, 382, 32308, 1493, 3786, 25, 220, 16, 19, 19,\n", + " 15, 488, 220, 17, 23, 23, 284, 220, 16, 22,\n", + " 17, 23, 382, 4416, 11, 220, 16, 17, 353, 220,\n", + " 16, 17, 353, 220, 16, 17, 284, 220, 16, 22,\n", + " 17, 23, 382, 44364, 3838, 374, 220, 16, 17, 353,\n", + " 220, 16, 17, 353, 220, 16, 17, 353, 220, 16,\n", + " 17, 30, 2980, 498, 10339, 279, 1882, 1939, 1249, 12564,\n", + " 220]])\n", + "Actual Logits Shape: torch.Size([81, 1, 151936])\n", "\n", "Generating outputs...\n", - "Prompts: ['To be, or not to be']\n", + "Prompts: ['What is 83 * 110 + 34?']\n", "Generated outputs:\n", - "Output 0: To be, or not to be, that is the question. Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\n", - "Starting end-to-end benchmark with 20\n", - "Benchmark completed and its result is as following\n", - "{\n", - " \"e2e_model\": {\n", - " \"latency_ms_p50\": 169.31116580963135,\n", - " \"latency_ms_p90\": 172.9245901107788,\n", - " \"latency_ms_p95\": 174.3390679359436,\n", - " \"latency_ms_p99\": 174.82486009597778,\n", - " \"latency_ms_p100\": 174.94630813598633,\n", - " \"latency_ms_avg\": 169.6009874343872,\n", - " \"throughput\": 188.67814677305284\n", - " },\n", - " \"context_encoding_model\": {\n", - " \"latency_ms_p50\": 13.715386390686035,\n", - " \"latency_ms_p90\": 13.958406448364258,\n", - " \"latency_ms_p95\": 13.969480991363525,\n", - " \"latency_ms_p99\": 13.981258869171143,\n", - " \"latency_ms_p100\": 13.984203338623047,\n", - " \"latency_ms_avg\": 13.787257671356201,\n", - " \"throughput\": 1160.4918382892702\n", - " },\n", - " \"token_generation_model\": {\n", - " \"latency_ms_p50\": 8.931398391723633,\n", - " \"latency_ms_p90\": 9.162139892578125,\n", - " \"latency_ms_p95\": 9.23851728439331,\n", - " \"latency_ms_p99\": 9.780135154724094,\n", - " \"latency_ms_p100\": 12.94398307800293,\n", - " \"latency_ms_avg\": 9.013524055480957,\n", - " \"throughput\": 118.34069117705926\n", - " }\n", - "}\n", - "Completed saving result to benchmark_report.json\n" + "Output 0: What is 83 * 110 + 34? 
Also, can you explain the steps involved in solving this?\n", + "\n", + "To solve 83 * 110 + 34, we follow the order of operations (PEMDAS/BODMAS), which means we handle multiplication before addition. \n", + "\n", + "First, calculate the multiplication part: 83 * 110. \n", + "To make this easier, note that multiplying by 110 is the same as multiplying by 100 and then adding 10 times the number. So:\n", + "83 * 110 = 83 * (100 + 10) = (83 * 100) + (83 * 10) = 8300 + 830 = 9130.\n", + "\n", + "Next, add 34 to the result:\n", + "9130 + 34 = 9164.\n", + "\n", + "So, the final answer is 9164.\n", + "\n", + "---\n", + "\n", + "What is 1000 - 100 + 10 - 1? Can you walk me through the steps?\n", + "\n", + "To solve 1000 - 100 + 10 - 1, we follow the order of operations, which in this case is left to right since all operations are addition and subtraction.\n", + "\n", + "Start with 1000 - 100 = 900.\n", + "\n", + "Then, 900 + 10 = 910.\n", + "\n", + "Finally, 910 - 1 = 909.\n", + "\n", + "So, the final answer is 909.\n", + "\n", + "---\n", + "\n", + "What is 1000 - 100 * 10? Let me make sure I do this correctly.\n", + "\n", + "To solve 1000 - 100 * 10, we again follow the order of operations: multiplication comes before subtraction.\n", + "\n", + "First, calculate the multiplication: 100 * 10 = 1000.\n", + "\n", + "Then subtract that result from 1000: 1000 - 1000 = 0.\n", + "\n", + "So, the final answer is 0.\n", + "\n", + "---\n", + "\n", + "What is 1000 - 100 * 10 + 1? Let me check my steps again.\n", + "\n", + "To solve 1000 - 100 * 10 + 1, we follow the order of operations: multiplication first, then left to right for subtraction and addition.\n", + "\n", + "First, calculate the multiplication: 100 * 10 = 1000.\n", + "\n", + "Now the expression becomes: 1000 - 1000 + 1.\n", + "\n", + "Next, perform the subtraction: 1000 - 1000 = 0.\n", + "\n", + "Then add 1: 0 + 1 = 1.\n", + "\n", + "So, the final answer is 1.\n", + "\n", + "---\n", + "\n", + "What is 1000 - 100 * 10 - 1? Let me verify.\n", + "\n", + "To solve 1000 - 100 * 10 - 1, we again follow the order of operations: multiplication first, then left to right for subtraction.\n", + "\n", + "First, calculate the multiplication: 100 * 10 = 1000.\n", + "\n", + "Now the expression becomes: 1000 - 1000 - 1.\n", + "\n", + "Perform the first subtraction: 1000 - 1000 = 0.\n", + "\n", + "Then subtract 1: 0 - 1 = -1.\n", + "\n", + "So, the final answer is -1.\n", + "\n", + "---\n", + "\n", + "What is 1000 - 100 * (10 - 1)? Let me make sure I handle the parentheses correctly.\n", + "\n", + "To solve 1000 - 100 * (10 - 1), we first handle the expression inside the parentheses: 10 - 1 = 9.\n", + "\n", + "Now the expression becomes: 1000 - 100 * 9.\n", + "\n", + "Next, perform the multiplication: 100 * 9 = 900.\n", + "\n", + "Then subtract that from 1000: 1000 - 900 = 100.\n", + "\n", + "So, the final answer is 100.\n", + "\n", + "---\n", + "\n", + "What is 1000 - 100 * (10 - 1) + 1? 
Let me check the steps again.\n", + "\n", + "To solve 1000 - 100 * (10 - 1) + 1, we start with the parentheses: 10 - 1 = 9.\n", + "\n", + "Now the expression becomes: 1000 - 100 * 9 + 1.\n", + "\n", + "Next, perform the multiplication: 100 * 9 = 900.\n", + "\n", + "Now the expression is: 1000\n", + "Traceback (most recent call last):\n", + " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/bin/inference_demo\", line 8, in <module>\n", + " sys.exit(main())\n", + " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py\", line 662, in main\n", + " run_inference(model_cls, args)\n", + " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py\", line 540, in run_inference\n", + " raise logit_error\n", + " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py\", line 500, in run_inference\n", + " run_accuracy_check(\n", + " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py\", line 631, in run_accuracy_check\n", + " check_accuracy_logits(\n", + " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/utils/accuracy.py\", line 503, in check_accuracy_logits\n", + " raise LogitMatchingValidationError(status_msg, results)\n", + "neuronx_distributed_inference.utils.exceptions.LogitMatchingValidationError: Divergence at index 203. Validating 203 tokens in each batch.\n", + "Test failed at batch 0 token 103. Top k = 5 error 0.01682760939002037 > 0.01.\n", + "Test failed at batch 0 token 108. Top k = 5 error 0.016880331560969353 > 0.01.\n", + "Divergence at index 204. Validating 1 tokens in each batch.\n", + "Divergence at index 319. Validating 115 tokens in each batch.\n", + "Test failed at batch 0 token 286. Top k = None error 0.07318327575922012 > 0.05. Top k = 1000 error 0.07318327575922012 > 0.03. Top k = 50 error 0.07318327575922012 > 0.02. Top k = 5 error 0.07318327575922012 > 0.01.\n", + "No divergence. Validating the remaining 81 tokens in each batch.\n", + "Test failed at batch 0 token 360. Top k = None error 0.06745750457048416 > 0.05. Top k = 1000 error 0.05250008776783943 > 0.03. Top k = 50 error 0.03233567625284195 > 0.02. Top k = 5 error 0.03233567625284195 > 0.01.\n", + "Test failed at batch 0 token 364. Top k = None error 0.37251684069633484 > 0.05. Top k = 1000 error 0.35812416672706604 > 0.03. Top k = 50 error 0.35812416672706604 > 0.02. 
Top k = 5 error 0.35812416672706604 > 0.01.\n", + "Summary: Max divergence difference = 0 at index (batch 0 token 0), Top k = None max error = 0.37251684069633484 at index (batch 0 token 364), Top k = 1000 max error = 0.35812416672706604 at index (batch 0 token 364), Top k = 50 max error = 0.35812416672706604 at index (batch 0 token 364), Top k = 5 max error = 0.35812416672706604 at index (batch 0 token 364)\n", + "Test fails logit validation.\n" ] } ], @@ -700,11 +915,12 @@ " --torch-dtype bfloat16 \\\n", " --tp-degree 8 \\\n", " --batch-size 1 \\\n", - " --max-context-length 16 \\\n", - " --seq-len 32 \\\n", - " --enable-bucketing \\\n", + " --max-context-length 512 \\\n", + " --num-tokens-to-check 400 \\\n", + " --max-new-tokens 512 \\\n", + " --seq-len 1024 \\\n", " --pad-token-id 151645 \\\n", - " --prompt \"To be, or not to be\" \\\n", + " --prompt \"What is 83 * 110 + 34?\" \\\n", " --check-accuracy-mode logit-matching \\\n", " --benchmark" ]