32 changes: 32 additions & 0 deletions tests/neuron/2_core/test_mistral.py
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams


def test_mistral():
    llm = LLM(model="mistralai/Mistral-7B-v0.1",
              tensor_parallel_size=2,
              max_num_seqs=4,
              max_model_len=512,
              use_v2_block_manager=True,
              override_neuron_config={
                  "sequence_parallel_enabled": False,
                  "skip_warmup": True
              },
              device="neuron")

    prompts = [
        "The president of the United States is",
        "The capital of France is",
    ]
    outputs = llm.generate(prompts, SamplingParams(top_k=1))

    expected_outputs = [
        " the most powerful person in the world. He is the head of state "
        "and head",
        " a city of many faces. It is a city of history, culture, art"
    ]

    for expected_output, output in zip(expected_outputs, outputs):
        generated_text = output.outputs[0].text
        assert (expected_output == generated_text)
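
The assertions above rely on SamplingParams(top_k=1), which forces greedy decoding, so the generated text is deterministic for a given compiled model and can be compared verbatim against golden strings. As a rough sketch that is not part of this PR, the golden outputs could be regenerated with a hypothetical helper along these lines, reusing the same LLM construction as in test_mistral():

from typing import List

from vllm import LLM, SamplingParams


def dump_golden_outputs(llm: LLM, prompts: List[str]) -> List[str]:
    # Hypothetical helper, not in the PR: top_k=1 forces greedy decoding,
    # so repeated runs on the same compiled model yield identical text.
    outputs = llm.generate(prompts, SamplingParams(top_k=1))
    # Each RequestOutput carries the completions for one prompt; take the
    # text of the first (and only) completion.
    return [output.outputs[0].text for output in outputs]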
3 changes: 3 additions & 0 deletions vllm/model_executor/model_loader/neuronx_distributed.py
@@ -48,6 +48,9 @@
 # Models supported by Neuronx distributed for inference.
 _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = {
     "LlamaForCausalLM":
     ("neuronx_distributed_inference.models.llama.modeling_llama",
      "NeuronLlamaForCausalLM"),
+    "MistralForCausalLM":
+    ("neuronx_distributed_inference.models.llama.modeling_llama",
+     "NeuronLlamaForCausalLM"),
     "DbrxForCausalLM":
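
Mistral-7B uses the same decoder architecture as Llama, so the new MistralForCausalLM entry simply points at the existing NeuronLlamaForCausalLM class from neuronx_distributed_inference. Below is a minimal sketch, under the assumption that the loader resolves registry entries by importing the recorded module and class; the actual vLLM loader code may differ in detail:

import importlib
from typing import Dict, Tuple

# Illustrative subset of the registry, mirroring the entry added in this PR.
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = {
    "MistralForCausalLM":
    ("neuronx_distributed_inference.models.llama.modeling_llama",
     "NeuronLlamaForCausalLM"),
}


def _resolve_neuron_model_cls(architecture: str):
    # Hypothetical helper: look up the registered (module, class) pair and
    # import it lazily, failing clearly for unsupported architectures.
    if architecture not in _NEURON_SUPPORTED_MODELS:
        raise ValueError(
            f"Model architecture {architecture} is not supported on Neuron.")
    module_path, class_name = _NEURON_SUPPORTED_MODELS[architecture]
    module = importlib.import_module(module_path)
    return getattr(module, class_name)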
3 changes: 1 addition & 2 deletions vllm/platforms/neuron.py
@@ -51,8 +51,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         assert (vllm_config.lora_config
                 is None), "LoRA is not supported for Neuron backend."
 
-        cache_config = vllm_config.cache_config
-        if cache_config:
+        if vllm_config.cache_config and vllm_config.model_config:
             # neuron needs block_size = max_model_len
             vllm_config.cache_config.block_size = \
                 vllm_config.model_config.max_model_len  # type: ignore
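
The single condition above replaces the old two-line guard and additionally checks model_config, presumably so the block-size override is skipped when no model config has been attached yet rather than failing on a None access; that rationale is inferred, not stated in the PR. A minimal standalone sketch of the resulting behavior, with the surrounding platform class omitted:

def _pin_neuron_block_size(cache_config, model_config) -> None:
    # Hypothetical sketch of the updated guard: only touch the cache config
    # when both configs are present. On Neuron, block_size is pinned to
    # max_model_len so a single KV-cache block spans the whole sequence.
    if cache_config and model_config:
        cache_config.block_size = model_config.max_model_len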