diff --git a/tests/neuron/2_core/test_mistral.py b/tests/neuron/2_core/test_mistral.py
new file mode 100644
index 000000000000..8acd082f2ded
--- /dev/null
+++ b/tests/neuron/2_core/test_mistral.py
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm import LLM, SamplingParams
+
+
+def test_mistral():
+    llm = LLM(model="mistralai/Mistral-7B-v0.1",
+              tensor_parallel_size=2,
+              max_num_seqs=4,
+              max_model_len=512,
+              use_v2_block_manager=True,
+              override_neuron_config={
+                  "sequence_parallel_enabled": False,
+                  "skip_warmup": True
+              },
+              device="neuron")
+
+    prompts = [
+        "The president of the United States is",
+        "The capital of France is",
+    ]
+    outputs = llm.generate(prompts, SamplingParams(top_k=1))
+
+    expected_outputs = [
+        " the most powerful person in the world. He is the head of state "
+        "and head",
+        " a city of many faces. It is a city of history, culture, art"
+    ]
+
+    for expected_output, output in zip(expected_outputs, outputs):
+        generated_text = output.outputs[0].text
+        assert (expected_output == generated_text)
diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py
index f879c99ac2ef..034c45824c2b 100644
--- a/vllm/model_executor/model_loader/neuronx_distributed.py
+++ b/vllm/model_executor/model_loader/neuronx_distributed.py
@@ -48,6 +48,9 @@
 # Models supported by Neuronx distributed for inference.
 _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = {
     "LlamaForCausalLM":
+    ("neuronx_distributed_inference.models.llama.modeling_llama",
+     "NeuronLlamaForCausalLM"),
+    "MistralForCausalLM":
     ("neuronx_distributed_inference.models.llama.modeling_llama",
      "NeuronLlamaForCausalLM"),
     "DbrxForCausalLM":
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index 71f7c718cdf9..e08337b8391d 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -51,8 +51,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         assert (vllm_config.lora_config is
                 None), "LoRA is not supported for Neuron backend."
 
-        cache_config = vllm_config.cache_config
-        if cache_config:
+        if vllm_config.cache_config and vllm_config.model_config:
             # neuron needs block_size = max_model_len
             vllm_config.cache_config.block_size = \
                 vllm_config.model_config.max_model_len  # type: ignore