diff --git a/.buildkite/scripts/hardware_ci/run-neuron-test.sh b/.buildkite/scripts/hardware_ci/run-neuron-test.sh
index c0b9dd8dadba..3d294ea5f8a7 100644
--- a/.buildkite/scripts/hardware_ci/run-neuron-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-neuron-test.sh
@@ -53,4 +53,11 @@ docker run --rm -it --device=/dev/neuron0 --network bridge \
     -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
     --name "${container_name}" \
     ${image_name} \
-    /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys"
+    /bin/bash -c "
+        python3 /workspace/vllm/examples/offline_inference/neuron.py;
+        python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
+        for f in /workspace/vllm/tests/neuron/2_core/*.py; do
+            echo 'Running test file: '\$f;
+            python3 -m pytest \$f -v --capture=tee-sys;
+        done
+    "
\ No newline at end of file
diff --git a/tests/neuron/2_core/test_eagle.py b/tests/neuron/2_core/test_eagle.py
new file mode 100644
index 000000000000..d71c88689a99
--- /dev/null
+++ b/tests/neuron/2_core/test_eagle.py
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+import shutil
+import tempfile
+
+import torch
+from huggingface_hub import snapshot_download
+from safetensors import safe_open
+
+from vllm import LLM, SamplingParams
+
+
+def patch_eagle_draft_with_lm_head(target_model_id: str,
+                                   draft_model_id: str) -> str:
+    # In NxDI, draft model checkpoint must include lm_head weights from target
+    # model. For more details see https://awsdocs-neuron.readthedocs-hosted.com
+    # /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html
+    # #eagle-checkpoint-compatibility
+    final_draft_dir = "/tmp/patched_eagle_draft"
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        target_dir = snapshot_download(repo_id=target_model_id,
+                                       local_dir=os.path.join(
+                                           tmp_dir, "target"))
+        draft_dir = snapshot_download(repo_id=draft_model_id,
+                                      local_dir=os.path.join(tmp_dir, "draft"))
+
+        lm_head_key = "lm_head.weight"
+        index_path = os.path.join(target_dir, "model.safetensors.index.json")
+        with open(index_path) as f:
+            index = json.load(f)
+        shard_name = index["weight_map"][lm_head_key]
+        target_safetensor_path = os.path.join(target_dir, shard_name)
+
+        with safe_open(target_safetensor_path, framework="pt") as f:
+            target_lm_head = f.get_tensor(lm_head_key)
+
+        draft_path = os.path.join(draft_dir, "pytorch_model.bin")
+        draft_state_dict = torch.load(draft_path, map_location="cpu")
+        draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16)
+        torch.save(draft_state_dict, draft_path)
+
+        shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True)
+
+    return final_draft_dir
+
+
+def test_eagle():
+    patched_draft_path = patch_eagle_draft_with_lm_head(
+        target_model_id="meta-llama/Llama-2-7b-hf",
+        draft_model_id="yuhuili/EAGLE-llama2-chat-7B")
+    llm = LLM(
+        model="meta-llama/Llama-2-7b-hf",
+        speculative_config={
+            "model": patched_draft_path,
+            "num_speculative_tokens": 5,
+            "max_model_len": 128
+        },
+        max_num_seqs=1,
+        max_model_len=128,
+        tensor_parallel_size=2,
+        override_neuron_config={
+            "enable_eagle_speculation": True,
+            "enable_fused_speculation": True,
+            "fused_qkv": True
+        },
+    )
+    prompts = [
+        "The president of the United States is",
+    ]
+    outputs = llm.generate(prompts, SamplingParams(top_k=1))
+    expected_output = " the head of state and head of government of " \
+        "the United States. The president direct"
+
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
+        assert (expected_output == generated_text)
+
+    print("Neuron Eagle speculation test passed.")
diff --git a/tests/neuron/2_core/test_mistral.py b/tests/neuron/2_core/test_mistral.py
index cc3b53a9d7c9..3e651502d1e2 100644
--- a/tests/neuron/2_core/test_mistral.py
+++ b/tests/neuron/2_core/test_mistral.py
@@ -12,8 +12,7 @@ def test_mistral():
               override_neuron_config={
                   "sequence_parallel_enabled": False,
                   "skip_warmup": True
-              },
-              device="neuron")
+              })
 
     # Send more prompts than the compiled batch size (4) and request
     # varying generation lengths to test accuracy related to Neuron
@@ -59,4 +58,7 @@ def test_mistral():
 
     for expected_output, output in zip(expected_outputs, outputs):
         generated_text = output.outputs[0].text
+        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
         assert (expected_output == generated_text)
+
+    print("Neuron Mistral test passed.")
diff --git a/vllm/config.py b/vllm/config.py
index 3fa1db0e8390..bdff4c70fb18 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2529,11 +2529,10 @@ def __post_init__(self):
                         "Chunked prefill and EAGLE are not compatible "
                         "when using V0.")
 
-                from vllm.platforms import current_platform
                 from vllm.transformers_utils.configs.eagle import (
                     EAGLEConfig)
                 if isinstance(self.draft_model_config.hf_config,
-                              EAGLEConfig) or current_platform.is_neuron():
+                              EAGLEConfig):
                     pass
                 else:
                     eagle_config = EAGLEConfig(