From 2d9ca06dafe36b93d5d2c2d27106397d2d2951a3 Mon Sep 17 00:00:00 2001 From: Crucifixion-Fxl Date: Thu, 22 May 2025 12:35:38 +0800 Subject: [PATCH 1/5] [Bugfix] Migrate to REGEX Library to prevent catastrophic backtracking Signed-off-by: Crucifixion-Fxl --- .github/scripts/cleanup_pr_body.sh | 2 +- benchmarks/benchmark_serving_structured_output.py | 6 +++--- benchmarks/kernels/graph_machete_bench.py | 2 +- docs/source/conf.py | 2 +- docs/source/generate_examples.py | 3 ++- examples/offline_inference/prithvi_geospatial_mae.py | 2 +- requirements/common.txt | 1 + requirements/docs.txt | 1 + requirements/nightly_torch_test.txt | 2 +- setup.py | 3 +-- tests/entrypoints/llm/test_guided_generate.py | 2 +- tests/entrypoints/openai/test_chat.py | 2 +- tests/entrypoints/openai/test_completion.py | 3 +-- tests/entrypoints/openai/test_prompt_validation.py | 12 ++++++------ tests/models/multimodal/generation/test_phi4mm.py | 2 +- .../multimodal/generation/vlm_utils/model_utils.py | 2 +- tests/tool_use/test_tool_choice_required.py | 4 ++-- .../entrypoints/llm/test_struct_output_generate.py | 2 +- tests/v1/entrypoints/openai/test_completion.py | 2 +- tests/v1/sample/utils.py | 3 ++- vllm/collect_env.py | 2 +- vllm/config.py | 2 +- vllm/engine/arg_utils.py | 2 +- vllm/entrypoints/openai/api_server.py | 2 +- vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 2 +- .../openai/tool_parsers/deepseekv3_tool_parser.py | 3 ++- .../tool_parsers/granite_20b_fc_tool_parser.py | 2 +- .../openai/tool_parsers/hermes_tool_parser.py | 2 +- .../openai/tool_parsers/jamba_tool_parser.py | 2 +- .../openai/tool_parsers/llama_tool_parser.py | 2 +- .../openai/tool_parsers/mistral_tool_parser.py | 2 +- .../openai/tool_parsers/phi4mini_tool_parser.py | 2 +- .../openai/tool_parsers/pythonic_tool_parser.py | 2 +- vllm/lora/models.py | 2 +- vllm/lora/utils.py | 2 +- vllm/model_executor/guided_decoding/utils.py | 2 +- .../guided_decoding/xgrammar_decoding.py | 2 +- .../layers/quantization/compressed_tensors/utils.py | 2 +- vllm/model_executor/layers/quantization/modelopt.py | 2 +- .../layers/quantization/quark/utils.py | 3 ++- .../layers/quantization/utils/gptq_utils.py | 2 +- vllm/model_executor/model_loader/tensorizer.py | 2 +- vllm/model_executor/models/mimo_mtp.py | 2 +- vllm/model_executor/models/minimax_text_01.py | 2 +- vllm/model_executor/models/phi3v.py | 2 +- vllm/model_executor/models/qwen_vl.py | 2 +- vllm/model_executor/models/transformers.py | 2 +- vllm/multimodal/processing.py | 2 +- vllm/reasoning/granite_reasoning_parser.py | 2 +- vllm/transformers_utils/tokenizers/mistral.py | 2 +- vllm/utils.py | 2 +- vllm/v1/structured_output/utils.py | 2 +- 53 files changed, 65 insertions(+), 61 deletions(-) mode change 100755 => 100644 .github/scripts/cleanup_pr_body.sh mode change 100755 => 100644 setup.py diff --git a/.github/scripts/cleanup_pr_body.sh b/.github/scripts/cleanup_pr_body.sh old mode 100755 new mode 100644 index 3246c6f9bc4b..8d65936fba1d --- a/.github/scripts/cleanup_pr_body.sh +++ b/.github/scripts/cleanup_pr_body.sh @@ -26,7 +26,7 @@ sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}" # Remove HTML
section that includes text of "PR Checklist (Click to Expand)" python3 - < None: # vllm_flash_attn python code: # Regex from # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)` - import re compiled_regex = re.compile( r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") file_members += list( diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index fdbdccd4654c..dd5d17885eb9 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re import weakref from enum import Enum import jsonschema import pytest +import regex as re from pydantic import BaseModel from vllm.distributed import cleanup_dist_env_and_memory diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index a10b42ea3a4b..2509ef0d280a 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -2,13 +2,13 @@ # imports for guided decoding tests import json -import re from typing import Optional import jsonschema import openai # use the official client for correctness check import pytest import pytest_asyncio +import regex as re import requests import torch from openai import BadRequestError, OpenAI diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 1d9aa4972b70..9d12f27a2b87 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - # imports for guided decoding tests import json -import re import shutil from tempfile import TemporaryDirectory from typing import Optional @@ -11,6 +9,7 @@ import openai # use the official client for correctness check import pytest import pytest_asyncio +import regex as re # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index f889189a9968..e384915899d3 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # imports for guided decoding tests -import re - import openai import pytest +import regex as re from ...utils import RemoteOpenAIServer @@ -32,7 +31,7 @@ async def test_out_of_vocab_token_ids(): client = remote_server.get_async_client() with pytest.raises(openai.BadRequestError, - match=re.compile('.*out of vocabulary.*')): + match=re.compile('.*out of vocabulary.*').pattern): await client.completions.create(model=model_name, prompt=[999999], max_tokens=5, @@ -46,9 +45,10 @@ async def test_reject_multistep_with_guided_decoding(): with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - with pytest.raises(openai.BadRequestError, - match=re.compile( - '.*Guided decoding .* multi-step decoding.*')): + with pytest.raises( + openai.BadRequestError, + match=re.compile( + '.*Guided decoding .* multi-step decoding.*').pattern): await client.completions.create( model=model_name, prompt="Hello", diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index 11460a1a8d2b..04df062b776f 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ 
b/tests/models/multimodal/generation/test_phi4mm.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import os -import re from collections.abc import Sequence from typing import Optional import librosa import pytest +import regex as re from huggingface_hub import snapshot_download from transformers import AutoTokenizer diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index b71400fc8312..743c7f947697 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -3,11 +3,11 @@ for manipulating the input / output of HF & vLLM test runners, which are typically specific to a small subset of models. """ -import re import types from pathlib import PosixPath from typing import Optional, Union +import regex as re import torch from PIL.Image import Image from transformers import (AutoConfig, AutoTokenizer, BatchFeature, diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index 2ab87a0ef41f..291769848145 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from copy import deepcopy from unittest.mock import MagicMock import pytest +import regex as re from pydantic import TypeAdapter from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -333,4 +333,4 @@ def test_streaming_output_valid(output, empty_params, delta_len): combined_messages += message.tool_calls[0].function.arguments combined_messages += "}]" assert json.loads(combined_messages) == output - assert json.dumps(json.loads(combined_messages)) == output_json + assert json.dumps(json.loads(combined_messages)) == output_json \ No newline at end of file diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 25bbcd901d6a..5f1fff200de3 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -4,12 +4,12 @@ from __future__ import annotations import json -import re from enum import Enum from typing import TYPE_CHECKING, Any import jsonschema import pytest +import regex as re from pydantic import BaseModel from tests.reasoning.utils import run_reasoning_extraction diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 3ffc54f520b4..333ad23795f3 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -import re from typing import Optional import openai # use the official client for correctness check import pytest import pytest_asyncio +import regex as re from openai import BadRequestError from tests.utils import RemoteOpenAIServer diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index f540895bbf14..932b652aea32 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -import re from enum import Enum from typing import Optional +import regex as re + from vllm import CompletionOutput diff --git a/vllm/collect_env.py b/vllm/collect_env.py index 85746b7ef606..86eb465b8f65 100644 --- a/vllm/collect_env.py +++ b/vllm/collect_env.py @@ -815,4 +815,4 @@ def main(): if __name__ == 
'__main__': - main() + main() \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index a185a75c6bf3..24ef675a9c99 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -6,7 +6,6 @@ import hashlib import inspect import json -import re import textwrap import uuid import warnings @@ -20,6 +19,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, Protocol, TypeVar, Union, cast, get_args, get_origin) +import regex as re import torch from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f0c6b15b79da..eeba9b30bd0a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,6 @@ import argparse import dataclasses import json -import re import sys import threading import warnings @@ -13,6 +12,7 @@ from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union, cast, get_args, get_origin) +import regex as re import torch from typing_extensions import TypeIs, deprecated diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0ab6fcdca1a4..2da89b4f5944 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -7,7 +7,6 @@ import inspect import multiprocessing import os -import re import signal import socket import tempfile @@ -21,6 +20,7 @@ from typing import Annotated, Optional, Union import prometheus_client +import regex as re import uvloop from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 5ab2356a0898..be4f76ab5347 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3,11 +3,11 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import json -import re import time from http import HTTPStatus from typing import Annotated, Any, ClassVar, Literal, Optional, Union +import regex as re import torch from fastapi import HTTPException, UploadFile from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ee18e0b0a454..bc11686d7be8 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -2,7 +2,6 @@ import asyncio import json -import re import time from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import Sequence as GenericSequence @@ -10,6 +9,7 @@ import jinja2 import partial_json_parser +import regex as re from fastapi import Request from pydantic import TypeAdapter diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index bd8e87e4cee8..14e743e13a72 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -import re from collections.abc import Sequence from typing import Union +import regex as re + from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py 
b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index b93de6b41817..600ccbcf35d0 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from json import JSONDecoder from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import random_tool_call_id diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index e56a8ef7193c..2b9f9852bcb3 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import random_tool_call_id diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 6cac6f8163bf..e882ca2605e2 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import random_tool_call_id diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 9307034f40d6..561402a72bd4 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from json import JSONDecoder from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 9dbfe85ecc68..fecad7e653ab 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from random import choices from string import ascii_letters, digits from typing import Union import partial_json_parser +import regex as re from partial_json_parser.core.options import Allow from pydantic import Field diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index abf70a5e85c4..798f346fc97d 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re from collections.abc import Sequence from typing import Any, Optional +import regex as re from transformers import 
PreTrainedTokenizerBase from vllm.entrypoints.chat_utils import random_tool_call_id diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index bb91a35af3be..22018c0d4f4f 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -2,10 +2,10 @@ import ast import json -import re from collections.abc import Sequence from typing import Any, Union +import regex as re from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 959fe4a672a6..7e8321691f57 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -3,11 +3,11 @@ import copy import math import os -import re from collections.abc import Sequence from dataclasses import dataclass, field from typing import Any, Callable, Optional, Union +import regex as re import safetensors.torch import torch from torch import nn diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index b66850d4304f..619dd3bdc40a 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import os -import re from typing import Optional, Union import huggingface_hub +import regex as re from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, HFValidationError, RepositoryNotFoundError) from torch import nn diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py index 1ad1ef8fbf16..3f77cf394d9a 100644 --- a/vllm/model_executor/guided_decoding/utils.py +++ b/vllm/model_executor/guided_decoding/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -import re +import regex as re def has_xgrammar_unsupported_json_features(schema: dict) -> bool: diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 8e40da4b3aa9..cd6029c14239 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -4,10 +4,10 @@ from __future__ import annotations import json -import re from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any +import regex as re import torch import vllm.envs diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index ccd54281ceb7..75e81c4dd49d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -import re from collections.abc import Iterable, Mapping from types import MappingProxyType from typing import Optional +import regex as re from compressed_tensors import CompressionFormat from torch.nn import Module diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 13957a96deca..97167cb5833d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -228,7 +228,7 @@ def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": exclude_modules, group_size) def is_layer_excluded(self, prefix: str, exclude_modules: list): - import re + import regex as re for pattern in exclude_modules: regex_str = 
pattern.replace('.', r'\.').replace('*', r'.*') if re.fullmatch(regex_str, prefix): diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index d1d293b01791..5e56bcb7564c 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -import re from collections.abc import Iterable, Mapping from types import MappingProxyType from typing import Any, Optional +import regex as re + def deep_compare(dict1: Any, dict2: Any) -> bool: if type(dict1) is not type(dict2): diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py index ff7a8169e6fb..36161d13b24f 100644 --- a/vllm/model_executor/layers/quantization/utils/gptq_utils.py +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -import re from copy import deepcopy from typing import Optional, Union +import regex as re import torch from vllm.config import QuantizationConfig diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 0ff35b3a6dca..e9fff705f1d4 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -4,13 +4,13 @@ import dataclasses import io import os -import re import time from collections.abc import Generator from dataclasses import dataclass from functools import partial from typing import BinaryIO, Optional, Union +import regex as re import torch from torch import nn from transformers import PretrainedConfig diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index adcfcaa6b1e6..cbca6a4c8f9d 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -250,7 +250,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params def map_model_name_to_mtp_param_name(self, name: str) -> str: - import re + import regex as re name_without_prefix = [ "token_layernorm", "hidden_layernorm", "input_proj", "final_layernorm" diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 0285402dadf7..9dffe96fc545 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -2,10 +2,10 @@ """Inference-only MiniMaxText01 model.""" import copy import math -import re from collections.abc import Iterable from typing import Optional, Union +import regex as re import torch import torch.distributed import torch.nn.functional as F diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index bb4d46be3f99..b757e661d771 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -14,10 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import re from collections.abc import Iterable, Mapping, Sequence from typing import Any, Literal, Optional, TypedDict, Union +import regex as re import torch import torch.nn as nn from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 3701153bace5..57a66b793711 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -7,12 +7,12 @@ import copy import math -import re import unicodedata from collections.abc import Collection, Mapping, Sequence, Set from functools import lru_cache, partial from typing import Callable, Literal, Optional, TypedDict, Union +import regex as re import torch from torch import nn from torchvision import transforms diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index a8f30b2f27bf..4247ce640703 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -14,10 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Wrapper around `transformers` models""" -import re from collections.abc import Iterable from typing import Literal, Optional, Union +import regex as re import torch from torch import nn from transformers import AutoModel, PretrainedConfig, PreTrainedModel diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 320a26f37555..f9e8db8d38a2 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 import json -import re import sys from abc import ABC, abstractmethod from collections import defaultdict @@ -12,6 +11,7 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, TypeVar, Union, cast) +import regex as re import torch from typing_extensions import assert_never diff --git a/vllm/reasoning/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py index 0dae02d33fec..07a63e294df4 100644 --- a/vllm/reasoning/granite_reasoning_parser.py +++ b/vllm/reasoning/granite_reasoning_parser.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -import re from collections.abc import Sequence from typing import Optional, Union +import regex as re from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 551c2d55b4fc..05de6a603655 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 import os -import re from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union, cast import huggingface_hub +import regex as re from huggingface_hub import HfApi, hf_hub_download from vllm.logger import init_logger diff --git a/vllm/utils.py b/vllm/utils.py index 0cd90c130d3e..d8f099995003 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -19,7 +19,6 @@ import multiprocessing import os import pickle -import re import signal import socket import subprocess @@ -54,6 +53,7 @@ import numpy as np import numpy.typing as npt import psutil +import regex as re import torch import torch.types import yaml diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 
f33f4972e103..111e92dc0990 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -2,7 +2,7 @@ from __future__ import annotations -import re +import regex as re def grammar_is_likely_lark(grammar_str: str) -> bool: From 8de207a3ae4747b7833a62539aee75d4f569cb25 Mon Sep 17 00:00:00 2001 From: Crucifixion-Fxl Date: Thu, 22 May 2025 21:32:24 +0800 Subject: [PATCH 2/5] [Bugfix] Migrate to REGEX Library to prevent catastrophic backtracking V2 Signed-off-by: Crucifixion-Fxl --- benchmarks/benchmark_serving_structured_output.py | 2 +- pyproject.toml | 1 + requirements/build.txt | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 7619f76b7623..6a50f47d3951 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -689,7 +689,7 @@ def _eval_correctness_choice(expected, actual): def _eval_correctness_regex(expected, actual): import regex as re - return re.match(args.re, actual) is not None + return re.match(args.regex, actual) is not None def _eval_correctness(expected, actual): if args.structure_type == "guided_json": diff --git a/pyproject.toml b/pyproject.toml index 0b803a26b658..6a2d9c44d414 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ requires = [ "setuptools-scm>=8.0", "torch == 2.7.0", "wheel", + "regex", "jinja2", ] build-backend = "setuptools.build_meta" diff --git a/requirements/build.txt b/requirements/build.txt index 5edc593b9270..320e5b892584 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -7,3 +7,4 @@ setuptools-scm>=8 torch==2.7.0 wheel jinja2>=3.1.6 +regex From a2e228987d29a680f4130f6c9a2b237d1cf68caa Mon Sep 17 00:00:00 2001 From: Crucifixion-Fxl Date: Fri, 23 May 2025 08:16:26 +0800 Subject: [PATCH 3/5] [Bugfix] Migrate to REGEX Library to prevent catastrophic backtracking Signed-off-by: Crucifixion-Fxl --- .buildkite/release-pipeline.yaml | 2 +- .../scripts/hardware_ci/run-neuron-test.sh | 4 +- .buildkite/test-pipeline.yaml | 5 +- .github/CODEOWNERS | 2 + CMakeLists.txt | 6 +- .../kernels/benchmark_paged_attention.py | 6 +- csrc/cutlass_extensions/common.hpp | 9 - csrc/moe/moe_ops.h | 4 +- csrc/moe/moe_permute_unpermute_op.cu | 43 +- .../moe_permute_unpermute_kernel.cu | 10 +- csrc/moe/torch_bindings.cpp | 6 +- .../cutlass_w8a8/scaled_mm_entry.cu | 2 +- csrc/rocm/attention.cu | 2051 +++++++++++++++-- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 8 +- docker/Dockerfile | 18 +- docker/Dockerfile.neuron | 7 +- docker/Dockerfile.rocm_base | 2 +- docker/Dockerfile.s390x | 32 +- docs/source/design/v1/torch_compile.md | 10 +- .../features/automatic_prefix_caching.md | 76 +- .../multimodal_inputs.md | 0 docs/source/features/prompt_embeds.md | 44 + docs/source/features/tool_calling.md | 11 +- .../installation/gpu/xpu.inc.md | 1 - docs/source/getting_started/quickstart.md | 5 + docs/source/index.md | 4 +- docs/source/models/supported_models.md | 7 +- docs/source/serving/prompt_embeds.md | 142 -- .../disagg_vllm_launcher.sh | 2 +- .../automatic_prefix_caching.py | 98 + .../disaggregated-prefill-v1/README.md | 1 + .../decode_example.py | 76 +- .../prefill_example.py | 90 +- .../prompt_embed_inference.py | 103 + .../openai_chat_completion_client.py | 24 +- ...enai_chat_completion_structured_outputs.py | 7 +- ...etion_structured_outputs_structural_tag.py | 7 +- .../openai_completion_client.py | 20 +- 
...ompt_embed_inference_with_openai_client.py | 86 + .../tool_chat_template_llama4_pythonic.jinja | 100 +- pyproject.toml | 4 +- requirements/cpu.txt | 3 +- requirements/test.in | 1 + requirements/test.txt | 22 +- requirements/tpu.txt | 10 +- .../test_basic_correctness.py | 75 +- tests/conftest.py | 9 + tests/distributed/test_events.py | 9 +- tests/distributed/test_shm_broadcast.py | 10 +- .../openai/correctness/test_mteb.py | 42 + .../entrypoints/openai/test_openai_schema.py | 57 +- .../test_llama4_pythonic_tool_parser.py | 193 ++ tests/kernels/attention/test_attention.py | 8 +- tests/kernels/moe/test_moe.py | 18 + .../kernels/moe/test_moe_permute_unpermute.py | 4 +- tests/lora/test_lora_functions.py | 2 +- .../models/language/generation/test_hybrid.py | 2 +- tests/models/language/pooling/test_gte.py | 2 - tests/models/language/pooling/test_nomic.py | 1 - .../pooling/test_snowflake_arctic_embed.py | 1 - tests/models/quantization/test_nvfp4.py | 6 +- tests/models/registry.py | 5 +- tests/models/test_utils.py | 70 + tests/neuron/2_core/test_mistral.py | 62 + tests/quantization/test_auto_round.py | 30 + tests/quantization/test_bitsandbytes.py | 60 +- tests/test_outputs.py | 14 + tests/tool_use/utils.py | 2 +- tests/v1/core/test_kv_cache_utils.py | 71 +- tests/v1/core/test_prefix_caching.py | 36 +- tests/v1/sample/test_topk_topp_sampler.py | 136 +- tests/v1/test_oracle.py | 2 +- tests/v1/worker/test_gpu_input_batch.py | 39 +- tests/v1/worker/test_gpu_model_runner.py | 57 +- tools/install_nixl.sh | 109 + vllm/attention/backends/rocm_flash_attn.py | 3 +- .../ops/chunked_prefill_paged_decode.py | 3 +- .../attention/ops/triton_unified_attention.py | 4 +- vllm/compilation/backends.py | 206 +- vllm/compilation/base_piecewise_backend.py | 71 + vllm/compilation/cuda_piecewise_backend.py | 213 ++ vllm/config.py | 43 +- .../device_communicators/cpu_communicator.py | 8 +- .../device_communicators/shm_broadcast.py | 18 +- .../kv_transfer/kv_connector/utils.py | 5 +- .../kv_connector/v1/multi_connector.py | 2 +- .../kv_connector/v1/nixl_connector.py | 25 +- .../v1/shared_storage_connector.py | 6 +- vllm/distributed/utils.py | 154 +- vllm/engine/arg_utils.py | 8 +- vllm/entrypoints/cli/openai.py | 29 + .../openai/tool_parsers/__init__.py | 4 +- .../granite_20b_fc_tool_parser.py | 12 +- .../tool_parsers/granite_tool_parser.py | 12 +- .../tool_parsers/internlm2_tool_parser.py | 12 +- .../openai/tool_parsers/jamba_tool_parser.py | 16 +- .../llama4_pythonic_tool_parser.py | 303 +++ .../openai/tool_parsers/llama_tool_parser.py | 12 +- .../tool_parsers/phi4mini_tool_parser.py | 9 +- .../tool_parsers/pythonic_tool_parser.py | 9 +- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 146 ++ vllm/model_executor/layers/fused_moe/layer.py | 3 +- .../layers/fused_moe/moe_pallas.py | 20 +- .../layers/fused_moe/moe_permute_unpermute.py | 4 + vllm/model_executor/layers/linear.py | 9 + .../layers/mamba/mamba_mixer2.py | 165 +- .../layers/quantization/__init__.py | 7 +- .../layers/quantization/auto_round.py | 306 +++ .../layers/quantization/ipex_quant.py | 2 - .../layers/quantization/modelopt.py | 2 +- .../model_loader/bitsandbytes_loader.py | 16 +- .../model_loader/neuronx_distributed.py | 62 +- .../model_executor/model_loader/tensorizer.py | 5 + vllm/model_executor/models/bloom.py | 79 +- vllm/model_executor/models/exaone.py | 5 +- vllm/model_executor/models/falcon_h1.py | 684 ++++++ vllm/model_executor/models/granite.py | 21 +- vllm/model_executor/models/grok1.py | 10 +- vllm/model_executor/models/llava_next.py | 6 +- 
vllm/model_executor/models/llava_onevision.py | 6 +- vllm/model_executor/models/minimax_text_01.py | 10 +- vllm/model_executor/models/mixtral.py | 7 +- vllm/model_executor/models/mixtral_quant.py | 10 +- vllm/model_executor/models/nemotron.py | 16 +- vllm/model_executor/models/olmo.py | 16 +- vllm/model_executor/models/olmo2.py | 22 +- vllm/model_executor/models/olmoe.py | 5 +- vllm/model_executor/models/orion.py | 11 +- vllm/model_executor/models/phi4mm.py | 4 +- vllm/model_executor/models/phimoe.py | 5 +- vllm/model_executor/models/qwen2_moe.py | 5 +- vllm/model_executor/models/qwen3_moe.py | 5 +- vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/solar.py | 16 +- vllm/model_executor/models/stablelm.py | 10 +- vllm/model_executor/models/starcoder2.py | 5 +- vllm/model_executor/models/utils.py | 18 +- vllm/multimodal/hasher.py | 2 +- vllm/outputs.py | 9 + vllm/platforms/cpu.py | 11 + vllm/platforms/cuda.py | 4 + vllm/platforms/hpu.py | 11 + vllm/platforms/interface.py | 7 + vllm/platforms/neuron.py | 14 +- vllm/platforms/rocm.py | 52 +- vllm/platforms/tpu.py | 11 + vllm/platforms/xpu.py | 11 + vllm/sequence.py | 14 +- vllm/transformers_utils/configs/eagle.py | 6 +- vllm/utils.py | 6 + .../attention/backends/mla/rocm_aiter_mla.py | 40 +- vllm/v1/core/kv_cache_manager.py | 34 +- vllm/v1/core/kv_cache_utils.py | 13 +- vllm/v1/core/sched/output.py | 12 +- vllm/v1/core/sched/scheduler.py | 45 +- vllm/v1/engine/core.py | 8 +- vllm/v1/executor/multiproc_executor.py | 2 +- vllm/v1/kv_cache_interface.py | 42 - vllm/v1/spec_decode/eagle.py | 6 +- vllm/v1/worker/block_table.py | 47 - vllm/v1/worker/gpu_input_batch.py | 13 +- vllm/v1/worker/gpu_model_runner.py | 274 +-- vllm/v1/worker/tpu_model_runner.py | 35 +- vllm/worker/model_runner.py | 111 +- 164 files changed, 6239 insertions(+), 1815 deletions(-) rename docs/source/{serving => features}/multimodal_inputs.md (100%) create mode 100644 docs/source/features/prompt_embeds.md delete mode 100644 docs/source/serving/prompt_embeds.md create mode 100644 examples/offline_inference/automatic_prefix_caching.py create mode 100644 examples/offline_inference/prompt_embed_inference.py create mode 100644 examples/online_serving/prompt_embed_inference_with_openai_client.py create mode 100644 tests/entrypoints/openai/correctness/test_mteb.py create mode 100644 tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py create mode 100644 tests/neuron/2_core/test_mistral.py create mode 100644 tests/quantization/test_auto_round.py create mode 100644 tests/test_outputs.py create mode 100644 tools/install_nixl.sh create mode 100644 vllm/compilation/base_piecewise_backend.py create mode 100644 vllm/compilation/cuda_piecewise_backend.py create mode 100644 vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/auto_round.py create mode 100644 vllm/model_executor/models/falcon_h1.py diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 2118cf4595eb..b3c27e2c99c2 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -64,7 +64,7 @@ steps: - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" plugins: - docker-login#v3.0.0: - username: vllm + username: vllmbot password-env: DOCKERHUB_TOKEN env: DOCKER_BUILDKIT: "1" diff --git 
a/.buildkite/scripts/hardware_ci/run-neuron-test.sh b/.buildkite/scripts/hardware_ci/run-neuron-test.sh index ec6a080eb499..c0b9dd8dadba 100644 --- a/.buildkite/scripts/hardware_ci/run-neuron-test.sh +++ b/.buildkite/scripts/hardware_ci/run-neuron-test.sh @@ -11,13 +11,14 @@ container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" HF_CACHE="$(realpath ~)/huggingface" mkdir -p "${HF_CACHE}" HF_MOUNT="/root/.cache/huggingface" +HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN) NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache" mkdir -p "${NEURON_COMPILE_CACHE_URL}" NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache" # Try building the docker image -aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws # prune old image and containers to save disk space, and only once a day # by using a timestamp file in tmp. @@ -47,6 +48,7 @@ trap remove_docker_container EXIT docker run --rm -it --device=/dev/neuron0 --network bridge \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "HF_TOKEN=${HF_TOKEN}" \ -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \ -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ --name "${container_name}" \ diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 461fb6d30c45..0e4a0e2a531b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -59,6 +59,7 @@ steps: - pytest -v -s async_engine # AsyncLLMEngine - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py + - pytest -v -s test_outputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils - pytest -v -s worker # Worker @@ -125,7 +126,7 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ - pytest -v -s entrypoints/test_chat_utils.py - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -138,6 +139,7 @@ steps: - vllm/core/ - tests/distributed/test_utils - tests/distributed/test_pynccl + - tests/distributed/test_events - tests/spec_decode/e2e/test_integration_dist_tp4 - tests/compile/test_basic_correctness - examples/offline_inference/rlhf.py @@ -156,6 +158,7 @@ steps: - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py + - pytest -v -s distributed/test_events.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests diff --git 
a/.github/CODEOWNERS b/.github/CODEOWNERS index 76aa5f7a35d5..a37bdb0f4d9e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -13,6 +13,7 @@ /vllm/model_executor/guided_decoding @mgoin @russellb /vllm/multimodal @DarkLight1337 @ywang96 /vllm/vllm_flash_attn @LucasWilkinson +/vllm/lora @jeejeelee CMakeLists.txt @tlrmchlsmth # vLLM V1 @@ -40,3 +41,4 @@ CMakeLists.txt @tlrmchlsmth /tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb /tests/v1/structured_output @mgoin @russellb /tests/weight_loading @mgoin @youkaichao +/tests/lora @jeejeelee diff --git a/CMakeLists.txt b/CMakeLists.txt index a6c54be9530b..ffb801d62619 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,7 +30,11 @@ set(ignoreMe "${VLLM_PYTHON_PATH}") set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. -set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL) + set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") +else() + set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") +endif() # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201") diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 17432159c94e..54f05e723226 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -84,7 +84,10 @@ def main( if version == "v2": if current_platform.is_rocm(): global PARTITION_SIZE - PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM + if not args.custom_paged_attn and not current_platform.is_navi(): + PARTITION_SIZE = 1024 + else: + PARTITION_SIZE = PARTITION_SIZE_ROCM num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), @@ -159,6 +162,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: scale, block_tables, seq_lens, + None, block_size, max_seq_len, alibi_slopes, diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index 0877da52435e..195872e8edd3 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -15,15 +15,6 @@ cutlassGetStatusString(error)); \ } -/** - * Panic wrapper for unwinding CUDA runtime errors - */ -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ - } - inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { int max_shared_mem_per_block_opt_in = 0; cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 0bae119a7c46..8fda434d452f 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -28,4 +28,6 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); -#endif \ No newline at end of file +#endif + +bool moe_permute_unpermute_supported(); \ No newline at end of file diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 76d5f0eab021..9a7465261abf 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -5,6 +5,9 @@ #include 
"permute_unpermute_kernels/dispatch.h" #include "core/registration.h" +// moe_permute kernels require at least CUDA 12.0 +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + void moe_permute( const torch::Tensor& input, // [n_token, hidden] const torch::Tensor& topk_weights, //[n_token, topk] @@ -127,7 +130,45 @@ void moe_unpermute( }); } +#else + +void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, + torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indicies, + const std::optional& expert_map, + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, + torch::Tensor& src_row_id2dst_row_id_map, + torch::Tensor& m_indices) { + TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); +} + +void moe_unpermute(const torch::Tensor& input, + const torch::Tensor& topk_weights, torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indicies, + const std::optional& expert_map, + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, + torch::Tensor& src_row_id2dst_row_id_map, + torch::Tensor& m_indices) { + TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); +} + +#endif + +bool moe_permute_unpermute_supported() { +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + return true; +#else + return false; +#endif +} + TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); m.impl("moe_unpermute", &moe_unpermute); -} \ No newline at end of file +} diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index aa353d0f0437..de2c153882d9 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -1,6 +1,9 @@ #include "moe_permute_unpermute_kernel.h" +// moe_permute kernels require at least CUDA 12.0 +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + // CubKeyValueSorter definition begin CubKeyValueSorter::CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} @@ -131,9 +134,6 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size, int num_experts) { auto tidx = threadIdx.x; auto bidx = blockIdx.x; - auto lidx = tidx & 31; - auto widx = tidx >> 5; - auto warp_count = (blockDim.x + 31) >> 5; auto offset = bidx * blockDim.x; auto bound = min(offset + blockDim.x, size); extern __shared__ int smem_expert_map[]; @@ -226,4 +226,6 @@ void getMIndices(int64_t* expert_first_token_offset, expert_first_token_offset, align_expert_first_token_offset, m_indices, num_local_expert, align_block_size); } -} \ No newline at end of file +} + +#endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 810026d034c0..7d35ec79ead4 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Calculate the result of moe by summing up the partial results // from all selected experts. - m.def("moe_sum(Tensor! input, Tensor output) -> ()"); + m.def("moe_sum(Tensor input, Tensor! 
output) -> ()"); m.impl("moe_sum", torch::kCUDA, &moe_sum); // Aligning the number of tokens to be processed by each expert such @@ -77,7 +77,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " "expert_first_token_offset, int n_expert, int n_local_expert,int " "topk, Tensor! hidden_states)->()"); - // conditionally compiled so impl registration is in source file + + m.def("moe_permute_unpermute_supported() -> bool"); + m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); #endif } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 3c258ddce61e..e9b408fbf2ee 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -123,7 +123,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { } bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { - // CUTLASS groped FP8 kernels need at least CUDA 12.3 + // CUTLASS grouped FP8 kernels need at least CUDA 12.3 // and SM90 (Hopper) #if defined CUDA_VERSION diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 8cc5a0f4f218..f1e7da164199 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -30,6 +30,14 @@ #define __HIP__GFX9__ #endif +#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__)) + #define __HIP__GFX11__ +#endif + +#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__)) + #define __HIP__GFX12__ +#endif + #if defined(NDEBUG) #undef NDEBUG #include @@ -43,7 +51,7 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#if defined(__HIP__GFX9__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 @@ -1482,191 +1490,1690 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } -#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +#elif defined(__HIP__GFX11__) -// clang-format off -template -__global__ -__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int* __restrict__ query_start_loc_ptr, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, const float* k_scale, const float* v_scale) { - UNREACHABLE_CODE +using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; + +using bit16x8 
= __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t; +union b16x8_u { + bit16x8 u16x8; + _B16x4 xy[2]; +}; +typedef b16x8_u _B16x8; + +using bit16x16 = + __attribute__((__vector_size__(16 * sizeof(uint16_t)))) uint16_t; +union b16x16_u { + bit16x16 u16x16; + _B16x8 xy[2]; +}; +typedef b16x16_u _B16x16; + +using _B8x8 = uint2; +using bit8_t = uint8_t; + +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; + +template +__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x16& inpA, + const bit16x16& inpB, + const floatx8& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(inpA, inpB, inpC); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(inpA, inpB, inpC); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { + if constexpr (std::is_same::value) { + union h2cvt { + __half2 h2[4]; + _B16x8 b16x8; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); + u.h2[2] = __float22half2_rn(make_float2(inp[4], inp[5])); + u.h2[3] = __float22half2_rn(make_float2(inp[6], inp[7])); + return u.b16x8; + } else if constexpr (std::is_same::value) { + union b2cvt { + __hip_bfloat162 b2[4]; + _B16x8 b16x8; + } u; + + u.b2[0] = __float22bfloat162_rn(make_float2(inp[0], inp[1])); + u.b2[1] = __float22bfloat162_rn(make_float2(inp[2], inp[3])); + u.b2[2] = __float22bfloat162_rn(make_float2(inp[4], inp[5])); + u.b2[3] = __float22bfloat162_rn(make_float2(inp[6], inp[7])); + + return u.b16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } } +// clang-format off template + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO> __global__ -__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int* __restrict__ query_start_loc_ptr, // [num_seqs] +__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ 
query_start_loc_ptr, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] int max_ctx_blocks, const float* k_scale, const float* v_scale) { - UNREACHABLE_CODE -} + // clang-format on + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx11 + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane2id = laneid % 2; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; -// Grid: (num_heads, num_seqs). -template -__global__ -__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int* __restrict__ query_start_loc_ptr, // [num_seqs] - const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { - UNREACHABLE_CODE -} -// clang-format on + const int seq_idx = blockIdx.x; + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. 
+ if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) { + return; + } -#endif // defined(__HIP__GFX9__) TODO: Add NAVI support + const int partition_idx = blockIdx.y; -#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ - paged_attention_ll4mi_QKV_mfma16_kernel \ - <<>>( \ - query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ - max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ - kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ - max_ctx_blocks, k_scale_ptr, v_scale_ptr); + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 -#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ - paged_attention_ll4mi_QKV_mfma4_kernel \ - <<>>( \ - query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ - max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ - kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ - max_ctx_blocks, k_scale_ptr, v_scale_ptr); + const int max_num_partitions = gridDim.y; -#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ - paged_attention_ll4mi_reduce_kernel \ - <<>>( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ - context_lens_ptr, query_start_loc_ptr, max_num_partitions, \ - fp8_out_scale_ptr); + const int context_len = context_lens[seq_idx]; // length of a seq -template -void paged_attention_custom_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, const int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& context_lens, - const std::optional& query_start_loc, int max_context_len, - const std::optional& alibi_slopes, torch::Tensor& k_scale, - torch::Tensor& v_scale, const std::optional& fp8_out_scale) { - int num_seqs = block_tables.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); + const int partition_start_token_idx = partition_idx * T_PAR_SIZE; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } - // NOTE: query start location is optional for V0 decode should not be used. - // If batch contains mix of prefills and decode, prefills should be skipped. - const int* query_start_loc_ptr = - query_start_loc - ? reinterpret_cast(query_start_loc.value().data_ptr()) - : nullptr; + constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2); - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + // shared_logits is used for multiple purposes + __shared__ _B16x16 shared_logits[NWARPS][2][16][2]; - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); - const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); - // NOTE: fp8_out_scale is optional. - const auto fp8_out_scale_ptr = - fp8_out_scale - ? static_cast(fp8_out_scale.value().data_ptr()) - : nullptr; - OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + // for QK wmma16x16, layout is QHead/Tokenx16 across every 16 lanes, + // 32 Bytes HeadElements in each lane, 2x16B HeadElements across a row of warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16 / 2; // rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across + // warp - const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + _B16x16 Qlocal[QKHELOOP / 2]; // note that 16 contiguous elements of Q should + // be fetched per lane for 16 bit cache types - // partition size is fixed at 256 since both mfma4 and mfma16 kernels support - // it mfma4 kernel also supports partition size 512 - constexpr int PARTITION_SIZE = 256; - const int max_num_partitions = - DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - const int gqa_ratio = num_heads / num_kv_heads; - assert(num_heads % num_kv_heads == 0); - assert(head_size == HEAD_SIZE); + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); - constexpr int NTHR = 256; - dim3 grid(num_seqs, max_num_partitions, num_kv_heads); - dim3 block(NTHR); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each wmma16x16x16 instruction processes 16 tokens - // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 - switch (gqa_ratio) { - case 1: - LAUNCH_CUSTOM_ATTENTION_MFMA4(1); - break; - case 2: - LAUNCH_CUSTOM_ATTENTION_MFMA4(2); - break; - case 3: - LAUNCH_CUSTOM_ATTENTION_MFMA4(3); - break; - case 4: - LAUNCH_CUSTOM_ATTENTION_MFMA4(4); - break; + _B16x16 Klocal[TLOOP] + [QKHELOOP / 2]; // can be interpreted as B8x16 for 8 bit types + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + // for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each wmma takes QH16xT16x16HE across warp + // repeat wmma across QKHELOOP dimension + // output layout from QKwmma : QH16xT8x2 16 
qheads across 16 lanes, 16 tokens + // across 2 rows x 8 tokens per lane + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); + + if (GQA_RATIO == 1) { + const int local_qhead_idx = lane16id % GQA_RATIO; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const scalar_t* q_ptr = + q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; + if (lane16id < GQA_RATIO) { + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH * 2; + const _B16x16* q_fetch_ptr_32B = + reinterpret_cast(q_fetch_ptr); + Qlocal[qkhe_depth] = *q_fetch_ptr_32B; + } + } + } else { + // fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 2 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const scalar_t* q_ptr = + q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + + const int offset1 = + lane16id / + 2; // 16 contiguous chunks of head elems are spread across 8x2lanes + shared_logits[offset1][lane2id][local_qhead_idx][0].xy[0] = tmp; + } + + __syncthreads(); + + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + Qlocal[qkhe_depth].xy[0] = + shared_logits[qkhe_depth][0][lane16id % GQA_RATIO][0].xy[0]; + Qlocal[qkhe_depth].xy[1] = + shared_logits[qkhe_depth][1][lane16id % GQA_RATIO][0].xy[0]; + } + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + // fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? 
kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + constexpr int KX = 16 / sizeof(cache_t); + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = 0; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + Klocal[token_depth][qkhe_depth / 2].xy[qkhe_depth % 2] = *k_fetch_ptr_16B; + } + } + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 32/1 = 32 vtokens per lane + constexpr int VBLOCKS_PER_LANE = 2; // assumes block size >=16 + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = DIVIDE_ROUND_UP( + (HEAD_SIZE / 16), NWARPS); // head_size distributed across warps; each + // wmma instr works on 16 head elements + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? 
vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x16 Vlocal[VTLOOP][VHELOOP] + [VTLANELOOP / 2]; // this can be interpreted as B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + // v fetches are 16head elems across lanes x (16x2) tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth] + [vfetch_depth / VBLOCKS_PER_LANE]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + + (vfetch_depth % VBLOCKS_PER_LANE) * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth / 2].xy[vfetch_depth % 2] = + *v_fetch_ptr_16B; + } + } + } + + floatx8 dout[TLOOP]; + // qk wmma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + dout[token_depth] = gcn_wmma16x16x16_instr( + Klocal[token_depth][qkhe_depth].u16x16, Qlocal[qkhe_depth].u16x16, + dout[token_depth]); + } + dout[token_depth] *= scale; + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + const int qkout_token_idx = + partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = (local_token_idx + 2 * i < context_len) + ? dout[token_depth][i] + : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16)); + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = (local_token_idx + 2 * i < context_len) + ? 
__expf(dout[token_depth][i] - qk_max) + : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + exp_sum += __shfl_xor(exp_sum, 16); + + __syncthreads(); + + if (laneid < 16) { + shared_qk_max[warpid][lane16id] = qk_max; + shared_exp_sum[warpid][lane16id] = exp_sum; + } + + __syncthreads(); + + // calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_qk_max[w][lane16id]; + partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max); + partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w]; + } + + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // write logits to shared mem + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] *= inv_sum_scale; + shared_logits[warpid][token_depth][lane16id][0].xy[rowid] = + from_floatx8(dout[token_depth]); + } + __syncthreads(); + + _B16x8 swp_buf[TLOOP][2]; + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + swp_buf[token_depth][0] = + shared_logits[warpid][token_depth][lane16id][0].xy[0]; + swp_buf[token_depth][1] = + shared_logits[warpid][token_depth][lane16id][0].xy[1]; + } + + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + shared_logits[warpid][token_depth][lane16id][0].xy[rowid].u16x8[i] = + swp_buf[token_depth][i % 2].u16x8[4 * rowid + (i / 2)]; + } + } + + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + + (wg_start_head_idx + qhead_idx) * max_num_partitions + + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + _B16x8 outelems[VHELOOP]; + // Softmax V wmma + // v layout: 16he across lanes x (16x2) tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx8 tmp_out = {0}; + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP / 2; + vfetch_depth++) { + const int offset = vfetch_depth; + // if output format is 16 qheads across 16 lanes, 16 head elems spread + // across rows + tmp_out = gcn_wmma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x16, + shared_logits[vtoken_depth][offset][lane16id][0].u16x16, tmp_out); + } + } + outelems[vhe_depth] = from_floatx8(tmp_out); + } + + __syncthreads(); + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid] = + outelems[vhe_depth]; // lane16 id head dimension; rowid head element + // dimension + } + + __syncthreads(); + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + swp_buf[vhe_depth][0] = shared_logits[warpid][vhe_depth][lane16id][0].xy[0]; + swp_buf[vhe_depth][1] = shared_logits[warpid][vhe_depth][lane16id][0].xy[1]; + } + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + shared_logits[warpid][vhe_depth][lane16id][0].xy[rowid].u16x8[i] = + 
swp_buf[vhe_depth][i % 2].u16x8[4 * rowid + (i / 2)]; + } + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO2]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + const int offset1 = (head_elem_idx / 16) % NWARPS; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 8) % 2; // num_he % num_row + vout[h] = + shared_logits[offset1][offset2][local_head_idx][0].xy[offset3]; + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { + const auto num_heads = gridDim.x; + const auto head_idx = blockIdx.x; + const auto seq_idx = blockIdx.y; + + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. 
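+  // This kernel merges the per-partition partial outputs produced by the
+  // QKV kernel above (tmp_out): each partition's contribution is reweighted
+  // by exp(partition_max_logit - global_max_logit) * partition_exp_sum, and
+  // the accumulated result is divided by the global exp-sum at the end.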
+ if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { + return; + } + + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 32; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? 
j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? 
query_start_loc_ptr[seq_idx] : seq_idx); + OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE + + static_cast(head_idx) * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#elif defined(__HIP__GFX12__) + +using floatx8 = __attribute__((__vector_size__(8 * sizeof(float)))) float; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; + +using bit16x8 = __attribute__((__vector_size__(8 * sizeof(uint16_t)))) uint16_t; +union b16x8_u { + bit16x8 u16x8; + _B16x4 xy[2]; +}; +typedef b16x8_u _B16x8; + +using _B8x8 = uint2; +using bit8_t = uint8_t; + +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; + +template +__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x8& inpA, + const bit16x8& inpB, + const floatx8& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(inpA, inpB, inpC); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(inpA, inpB, inpC); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float_b16(const bit16_t& inp) { + union tmpcvt { + bit16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + t16.u = inp; + if constexpr (std::is_same::value) { + return (float)t16.f; + } else if constexpr (std::is_same::value) { + return __bfloat162float(t16.b); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { + if constexpr (std::is_same::value) { + union h2cvt { + __half2 h2[4]; + _B16x8 b16x8; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); + u.h2[2] = __float22half2_rn(make_float2(inp[4], inp[5])); + u.h2[3] = __float22half2_rn(make_float2(inp[6], inp[7])); + return u.b16x8; + } else if constexpr (std::is_same::value) { + union b2cvt { + __hip_bfloat162 b2[4]; + _B16x8 b16x8; + } u; + + u.b2[0] = __float22bfloat162_rn(make_float2(inp[0], inp[1])); + u.b2[1] = __float22bfloat162_rn(make_float2(inp[2], inp[3])); + u.b2[2] = __float22bfloat162_rn(make_float2(inp[4], inp[5])); + u.b2[3] = __float22bfloat162_rn(make_float2(inp[6], inp[7])); + + return u.b16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +// clang-format off +template +__global__ +__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const 
int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + // clang-format on + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; // 8 warps on gfx11 + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane2id = laneid % 2; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int seq_idx = blockIdx.x; + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. + if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { + return; + } + const int partition_idx = blockIdx.y; + + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 + + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; // length of a seq + + const int partition_start_token_idx = partition_idx * T_PAR_SIZE; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + + constexpr int GQA_RATIO2 = DIVIDE_ROUND_UP(GQA_RATIO, 2); + + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + // shared_logits is used for multiple purposes + __shared__ _B16x8 shared_logits[NWARPS][2][16][2]; + + // for QK wmma16x16_gfx12, layout is QHead/Tokenx16 across every 16 lanes, + // 16 Bytes HeadElements in each lane, 2x16B HeadElements across 2 rows of + // warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16; // rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across + // warp + + _B16x8 Qlocal[QKHELOOP]; // note that 16 contiguous elements of Q should + // be fetched per lane for 16 bit cache types + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each wmma16x16x16 instruction processes 16 tokens + + _B16x8 Klocal[TLOOP] + [QKHELOOP]; // can be interpreted as B8x16 for 8 bit types + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + // for QK wmma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each wmma takes QH16xT16x16HE across warp + // repeat wmma across QKHELOOP dimension + // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens + // across 2 rows x 8 tokens per lane + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? 
query_start_loc_ptr[seq_idx] : seq_idx); + + if (GQA_RATIO == 1) { + const int local_qhead_idx = lane16id % GQA_RATIO; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const scalar_t* q_ptr = q + query_start_off * q_stride + + global_qhead_idx * HEAD_SIZE + + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + if (lane16id < GQA_RATIO) { + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + Qlocal[qkhe_depth] = *q_fetch_ptr_16B; + } + } + } else { + // fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 2 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const scalar_t* q_ptr = + q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + + const int offset1 = + lane16id / + 2; // 16 contiguous chunks of head elems are spread across 8x2lanes + shared_logits[offset1][lane2id][local_qhead_idx][0] = tmp; + } + + __syncthreads(); + + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + Qlocal[qkhe_depth] = + shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][0]; + } + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + // fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? 
kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + constexpr int KX = 16 / sizeof(cache_t); + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } + } + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 32/2 = 16 vtokens per lane + constexpr int VBLOCKS_PER_LANE = 1; // assumes block size >=16 + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = DIVIDE_ROUND_UP( + (HEAD_SIZE / 16), NWARPS); // head_size distributed across warps; each + // wmma instr works on 16 head elements + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? 
vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP] + [VTLANELOOP]; // this can be interpreted as B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((rowid * VTOKENS_PER_LANE) % BLOCK_SIZE); + + // v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = 0; + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } + } + } + + floatx8 dout[TLOOP]; + // qk wmma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + dout[token_depth] = gcn_wmma16x16x16_instr( + Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8, + dout[token_depth]); + } + dout[token_depth] *= scale; + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + const int qkout_token_idx = + partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 8; + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = + (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, 16)); + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 8; i++) { + const float tmp = (local_token_idx + i < context_len) + ? 
__expf(dout[token_depth][i] - qk_max) + : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + exp_sum += __shfl_xor(exp_sum, 16); + + __syncthreads(); + + if (laneid < 16) { + shared_qk_max[warpid][lane16id] = qk_max; + shared_exp_sum[warpid][lane16id] = exp_sum; + } + + __syncthreads(); + + // calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_qk_max[w][lane16id]; + partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max); + partition_exp_sum += shared_exp_sum[w][lane16id] * warp_qk_max_exp[w]; + } + + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // write logits to shared mem + #pragma unroll + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] *= inv_sum_scale; + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx8(dout[token_depth]); + } + + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + + (wg_start_head_idx + qhead_idx) * max_num_partitions + + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + _B16x8 outelems[VHELOOP]; + // Softmax V wmma + // v layout: 16he across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx8 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int offset = rowid * VTLANELOOP + vfetch_depth; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // if output format is 16 qheads across 16 lanes, 16 head elems spread + // across rows + tmp_out = gcn_wmma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x8, + shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8, + tmp_out); + } + } + outelems[vhe_depth] = from_floatx8(tmp_out); + } + + __syncthreads(); + + #pragma unroll + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + shared_logits[warpid][vhe_depth][lane16id][rowid] = + outelems[vhe_depth]; // lane16 id head dimension; rowid head element + // dimension + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO2]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + const int offset1 = (head_elem_idx / 16) % NWARPS; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 8) % 2; // num_he % num_row + vout[h] = shared_logits[offset1][offset2][local_head_idx][offset3]; + } + + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO2; h++) { + const int local_head_idx = 2 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = 
wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { + const auto num_heads = gridDim.x; + const auto head_idx = blockIdx.x; + const auto seq_idx = blockIdx.y; + + // NOTE queries with sequence len > 1 are prefills and taken care by another + // kernel. + if (query_start_loc_ptr != nullptr && + (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) { + return; + } + + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? 
partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 32; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. 
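+  // shared_exp_sums[j] already carries the exp(max_logit_j - global_max) *
+  // exp_sum_j factor, so the weighted sum below followed by a single divide
+  // by the global exp-sum yields the correctly normalized output.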
+ float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + + const int64_t query_start_off = static_cast( + query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); + OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE + + static_cast(head_idx) * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#else + +// clang-format off +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + 
OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ query_start_loc_ptr, // [num_seqs] + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { + UNREACHABLE_CODE +} +// clang-format on + +#endif + +#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma16_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ + kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ + max_ctx_blocks, k_scale_ptr, v_scale_ptr); + +#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma4_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \ + max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ + kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \ + max_ctx_blocks, k_scale_ptr, v_scale_ptr); + +#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ + paged_attention_ll4mi_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ + context_lens_ptr, query_start_loc_ptr, max_num_partitions, \ + fp8_out_scale_ptr); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + const std::optional& query_start_loc, int max_context_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const std::optional& fp8_out_scale) { + int num_seqs = block_tables.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: query start location is optional for V0 decode should not be used. + // If batch contains mix of prefills and decode, prefills should be skipped. + const int* query_start_loc_ptr = + query_start_loc + ? reinterpret_cast(query_start_loc.value().data_ptr()) + : nullptr; + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + // NOTE: fp8_out_scale is optional. + const auto fp8_out_scale_ptr = + fp8_out_scale + ? static_cast(fp8_out_scale.value().data_ptr()) + : nullptr; + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + // partition size is fixed at 256 since both mfma4 and mfma16 kernels support + // it mfma4 kernel also supports partition size 512 + constexpr int PARTITION_SIZE = 256; + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + + constexpr int NTHR = 256; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION_MFMA4(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION_MFMA4(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION_MFMA4(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION_MFMA4(4); + break; case 5: LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break; @@ -1744,13 +3251,195 @@ void paged_attention_custom_launcher( } } -#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ - PSIZE, ALIBI_ENABLED) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ - max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); +template +void paged_attention_custom_launcher_navi( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + const std::optional& query_start_loc, int max_context_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale) { + int num_seqs = block_tables.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: query start location is optional for V0 decode should not be used. + // If batch contains mix of prefills and decode, prefills should be skipped. + const int* query_start_loc_ptr = + query_start_loc + ? reinterpret_cast(query_start_loc.value().data_ptr()) + : nullptr; + + // NOTE: Navi does not support alibi_slopes. 
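+  // The kernels still accept an alibi_slopes pointer, but this launcher
+  // always passes nullptr for it on gfx11/gfx12.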
+ const float* alibi_slopes_ptr = nullptr; + + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + // NOTE: Navi does not support fp8. + const auto fp8_out_scale_ptr = nullptr; + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + constexpr int PARTITION_SIZE = 256; + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + + constexpr int NTHR = 256; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION_MFMA16(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION_MFMA16(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION_MFMA16(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION_MFMA16(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION_MFMA16(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION_MFMA16(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION_MFMA16(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION_MFMA16(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION_MFMA16(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION_MFMA16(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION_MFMA16(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION_MFMA16(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION_MFMA16(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + const int warp_size = 32; + const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, warp_size); + // reduction kernel supports upto 16 NPAR_loops * 32 (warp_size) * 256 + // (partition size) = 128K context length + switch (npar_loops) { + case 1: + LAUNCH_CUSTOM_REDUCTION(1); + break; + case 2: + LAUNCH_CUSTOM_REDUCTION(2); + break; + case 3: + LAUNCH_CUSTOM_REDUCTION(3); + break; + case 4: + LAUNCH_CUSTOM_REDUCTION(4); + break; + case 5: + LAUNCH_CUSTOM_REDUCTION(5); + break; + case 6: + LAUNCH_CUSTOM_REDUCTION(6); + break; + case 7: + LAUNCH_CUSTOM_REDUCTION(7); + break; + case 8: + LAUNCH_CUSTOM_REDUCTION(8); + break; + case 9: + LAUNCH_CUSTOM_REDUCTION(9); + break; + case 10: + LAUNCH_CUSTOM_REDUCTION(10); + break; + case 11: + LAUNCH_CUSTOM_REDUCTION(11); + break; + case 12: + LAUNCH_CUSTOM_REDUCTION(12); + break; + case 13: + LAUNCH_CUSTOM_REDUCTION(13); + break; + case 14: + LAUNCH_CUSTOM_REDUCTION(14); + break; + case 15: + LAUNCH_CUSTOM_REDUCTION(15); + break; + case 16: + LAUNCH_CUSTOM_REDUCTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported npar_loops: ", 
npar_loops); + break; + } +} + +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ + PSIZE, ALIBI_ENABLED) \ + if (!is_navi) { \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ + max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ + } else { \ + paged_attention_custom_launcher_navi< \ + T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ + max_context_len, alibi_slopes, k_scale, v_scale); \ + } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ OUTT, PSIZE) \ @@ -1807,6 +3496,24 @@ void paged_attention_custom_launcher( break; \ } +bool is_navi_gpu() { + static bool is_cached = false; + static bool result; + + if (!is_cached) { + int device_id; + hipDeviceProp_t deviceProp; + hipGetDevice(&device_id); + hipGetDeviceProperties(&deviceProp, device_id); + + std::string arch = deviceProp.gcnArchName; + result = arch.find("gfx11") == 0 || arch.find("gfx12") == 0; + is_cached = true; + } + + return result; +} + // clang-format off void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] @@ -1827,6 +3534,8 @@ void paged_attention( torch::Tensor& v_scale, const std::optional& fp8_out_scale) { // clang-format on + bool is_navi = is_navi_gpu(); + const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh index 9c8a50332ad0..c22523da4e43 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh @@ -8,6 +8,8 @@ #include +#include "cuda_utils.h" + #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -95,9 +97,9 @@ struct cutlass_sparse_3x_gemm { // clang-format off using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, - ElementAB, cutlass::layout::RowMajor, AlignmentAB, - ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, + ElementAB, cutlass::layout::RowMajor, AlignmentAB, + ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, ElementAcc, TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp; diff --git a/docker/Dockerfile b/docker/Dockerfile index 97a7879da876..cc3499d1f0a9 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -189,6 +189,8 @@ WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive ARG TARGETPLATFORM +SHELL ["/bin/bash", "-c"] + RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment @@ -261,8 +263,11 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'; \ else \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \ - fi && \ - export FLASHINFER_ENABLE_AOT=1; \ + fi; \ + CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ + if [ "$CUDA_MAJOR" -lt 12 ]; then \ + export FLASHINFER_ENABLE_SM90=0; \ + fi; \ uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \ fi COPY examples examples @@ -273,7 +278,7 
@@ RUN --mount=type=cache,target=/root/.cache/uv \ . /etc/environment && \ uv pip list -# Although we build Flashinfer with AOT mode, there's still +# Even when we build Flashinfer with AOT mode, there's still # some issues w.r.t. JIT compilation. Therefore we need to # install build dependencies for JIT compilation. # TODO: Remove this once FlashInfer AOT wheel is fixed @@ -301,8 +306,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" # install development dependencies (for testing) -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/dev.txt +RUN --mount=type=cache,target=/root/.cache/uv \ + CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ + if [ "$CUDA_MAJOR" -ge 12 ]; then \ + uv pip install --system -r requirements/dev.txt; \ + fi # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docker/Dockerfile.neuron b/docker/Dockerfile.neuron index 2b63fe301bac..259dc5a23f78 100644 --- a/docker/Dockerfile.neuron +++ b/docker/Dockerfile.neuron @@ -1,6 +1,6 @@ # default base image # https://gallery.ecr.aws/neuron/pytorch-inference-neuronx -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.22.0-ubuntu22.04" +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04" FROM $BASE_IMAGE @@ -22,8 +22,7 @@ WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity -RUN python3 -m pip install sentencepiece transformers==4.48.0 -U -RUN python3 -m pip install neuronx-cc==2.17.194.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install pytest # uninstall transformers-neuronx package explicitly to avoid version conflict @@ -49,6 +48,8 @@ RUN python3 -m pip install -e tests/vllm_test_utils # FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps +RUN python3 -m pip install sentencepiece transformers==4.48.0 -U + # overwrite entrypoint to run bash script RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 222b9c158e5e..45efcbde698b 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="5a77249" +ARG AITER_BRANCH="c1debd8" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 9c10cd56b594..4e89bb3057c5 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -84,16 +84,40 @@ RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \ rustup default stable && \ rustup show +FROM python-install AS torch +ARG TORCH_VERSION=2.7.0 +ENV export _GLIBCXX_USE_CXX11_ABI=1 +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV 
PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" + +WORKDIR /tmp + +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + git clone https://github.com/pytorch/pytorch.git && \ + cd pytorch && \ + git checkout v2.7.0 && \ + git submodule sync && \ + git submodule update --init --recursive && \ + uv pip install cmake ninja && \ + uv pip install -r requirements.txt && \ + python setup.py bdist_wheel + + FROM python-install AS torch-vision # Install torchvision -ARG TORCH_VERSION=2.7.0.dev20250304 +ARG TORCH_VERSION=2.7.0 ARG TORCH_VISION_VERSION=v0.20.1 WORKDIR /tmp RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ git clone https://github.com/pytorch/vision.git && \ cd vision && \ git checkout $TORCH_VISION_VERSION && \ - uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \ + TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \ + uv pip install -v $TORCH_WHL_FILE && \ python setup.py bdist_wheel FROM python-install AS hf-xet-builder @@ -138,15 +162,17 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \ --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ --mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \ + --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ sed -i '/^torch/d' requirements/build.txt && \ ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \ + TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ $HF_XET_WHL_FILE \ - --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + $TORCH_WHL_FILE \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ -r requirements/cpu.txt diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index 4d8ce0fd9227..64b6f0cc0a9b 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -99,7 +99,9 @@ This time, Inductor compilation is completely bypassed, and we will load from di The above example just uses Inductor to compile for a general shape (i.e. symbolic shape). We can also use Inductor to compile for some of the specific shapes, for example: -`vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'compile_sizes': [1, 2, 4, 8]}"` +``` +vllm serve meta-llama/Llama-3.2-1B --compilation_config '{"compile_sizes": [1, 2, 4, 8]}' +``` Then it will also compile a specific kernel just for batch size `1, 2, 4, 8`. At this time, all of the shapes in the computation graph are static and known, and we will turn on auto-tuning to tune for max performance. This can be slow when you run it for the first time, but the next time you run it, we can directly bypass the tuning and run the tuned kernel. @@ -134,12 +136,14 @@ The cudagraphs are captured and managed by the compiler backend, and replayed wh By default, vLLM will try to determine a set of sizes to capture cudagraph. 
You can also override it using the config `cudagraph_capture_sizes`: -`vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"` +``` +vllm serve meta-llama/Llama-3.2-1B --compilation-config '{"cudagraph_capture_sizes": [1, 2, 4, 8]}' +``` Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture. ### Full Cudagraph capture -It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config "{'full_cuda_graph': True}"` +It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config '{"full_cuda_graph": true}'`. Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled. diff --git a/docs/source/features/automatic_prefix_caching.md b/docs/source/features/automatic_prefix_caching.md index 59016d7fcf6b..5c5b37c2a071 100644 --- a/docs/source/features/automatic_prefix_caching.md +++ b/docs/source/features/automatic_prefix_caching.md @@ -14,81 +14,7 @@ Technical details on how vLLM implements APC can be found [here](#design-automat Set `enable_prefix_caching=True` in vLLM engine to enable APC. Here is an example: -```python -import time -from vllm import LLM, SamplingParams - - -# A prompt containing a large markdown table. The table is randomly generated by GPT-4. -LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. 
Here is a table as follows.\n# Table\n" + """ -| ID | Name | Age | Occupation | Country | Email | Phone Number | Address | -|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------| -| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL | -| 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON | -| 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK | -| 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW | -| 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ | -| 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE | -| 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY | -| 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC | -| 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK | -| 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC| -| 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ | -| 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE | -| 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA | -| 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB | -| 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK | -| 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD | -| 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ | -| 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE | -| 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA | -| 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON | -| 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK | -| 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA | -| 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ| -| 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE | -| 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO | -| 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC | -| 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK | -| 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA | -| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ | -| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE | -""" - - -def get_generation_time(llm, sampling_params, prompts): - # time the generation - start_time = 
time.time() - output = llm.generate(prompts, sampling_params=sampling_params) - end_time = time.time() - # print the output and generation time - print(f"Output: {output[0].outputs[0].text}") - print(f"Generation time: {end_time - start_time} seconds.") - - -# set enable_prefix_caching=True to enable APC -llm = LLM( - model='lmsys/longchat-13b-16k', - enable_prefix_caching=True -) - -sampling_params = SamplingParams(temperature=0, max_tokens=100) - -# Querying the age of John Doe -get_generation_time( - llm, - sampling_params, - LONG_PROMPT + "Question: what is the age of John Doe? Your answer: The age of John Doe is ", -) - -# Querying the age of Zack Blue -# This query will be faster since vllm avoids computing the KV cache of LONG_PROMPT again. -get_generation_time( - llm, - sampling_params, - LONG_PROMPT + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ", -) -``` + ## Example workloads diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/features/multimodal_inputs.md similarity index 100% rename from docs/source/serving/multimodal_inputs.md rename to docs/source/features/multimodal_inputs.md diff --git a/docs/source/features/prompt_embeds.md b/docs/source/features/prompt_embeds.md new file mode 100644 index 000000000000..9d7b242bbe51 --- /dev/null +++ b/docs/source/features/prompt_embeds.md @@ -0,0 +1,44 @@ +# Prompt Embedding Inputs + +This page teaches you how to pass prompt embedding inputs to vLLM. + +## What are prompt embeddings? + +The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer) then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), this step of converting token ids to prompt embeddings happens via a look-up from a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary. + +:::{note} +Prompt embeddings are currently only supported in the v0 engine. +::: + +## Offline Inference + +To input multi-modal data, follow this schema in {class}`vllm.inputs.EmbedsPrompt`: + +- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings. This has the shape (sequence_length, hidden_size), where sequence length is the number of tokens embeddings and hidden_size is the hidden size (embedding size) of the model. + +### Hugging Face Transformers Inputs + +You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples: + + + +## Online Serving + +Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package. + +When a mixture of `'prompt_embeds'` and `'prompt'` inputs are provided in a single request, the prompt embeds are always returned first. + +Prompt embeddings are passed in as base64 encoded torch tensors. 
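+
+For example, a prompt embedding tensor can be serialized with `torch.save` and base64-encoded before being placed in the `prompt_embeds` field of the request body. The snippet below is a minimal sketch (the tensor here is random and purely illustrative):
+
+```python
+import base64
+import io
+
+import torch
+
+# Assume `prompt_embeds` has shape (sequence_length, hidden_size);
+# the values below are random and only for illustration.
+prompt_embeds = torch.randn(8, 2048)
+
+# Serialize the tensor and base64-encode the raw bytes.
+buffer = io.BytesIO()
+torch.save(prompt_embeds, buffer)
+buffer.seek(0)
+encoded_embeds = base64.b64encode(buffer.read()).decode("utf-8")
+
+# `encoded_embeds` can now be sent as the `prompt_embeds` field of a
+# Completions request (see the client example below).
+```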
+ +### Transformers Inputs via OpenAI Client + +First, launch the OpenAI-compatible server: + +```bash +vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \ + --max-model-len 4096 --enable-prompt-embeds +``` + +Then, you can use the OpenAI client as follows: + + diff --git a/docs/source/features/tool_calling.md b/docs/source/features/tool_calling.md index 2795b769345e..f76128406bfd 100644 --- a/docs/source/features/tool_calling.md +++ b/docs/source/features/tool_calling.md @@ -158,13 +158,13 @@ All Llama 3.1, 3.2 and 4 models should be supported. * `meta-llama/Llama-3.2-*` * `meta-llama/Llama-4-*` -The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. +The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for llama 4 models, it is recommended to use the `llama4_pythonic` tool parser. Other tool calling formats like the built in python tool calling or custom tool calling are not supported. Known issues: -1. Parallel tool calls are not supported. +1. Parallel tool calls are not supported for llama 3, but it is supported in llama 4 models. 2. The model can generate parameters with a wrong format, such as generating an array serialized as string instead of an array. @@ -177,11 +177,10 @@ images. Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}` -VLLM also provides a JSON based chat template for Llama 4: -* - this is based on the "official" chat template for the Llama 4 -models, but tweaked so that it works better with vLLM. +VLLM also provides a pythonic and JSON based chat template for Llama 4, but pythonic tool calling is recommended: +* - this is based on the [official chat template](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/) for the Llama 4 models. -For Llama 4 use `--tool-call-parser llama4_json examples/tool_chat_template_llama4_json.jinja`. +For Llama 4 model, use `--tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`. 
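+
+For example, a full serving command might look like the following (the model name is illustrative, and `--enable-auto-tool-choice` is assumed to be set as with the other tool parsers):
+
+```bash
+vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct \
+    --enable-auto-tool-choice \
+    --tool-call-parser llama4_pythonic \
+    --chat-template examples/tool_chat_template_llama4_pythonic.jinja
+```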
#### IBM Granite diff --git a/docs/source/getting_started/installation/gpu/xpu.inc.md b/docs/source/getting_started/installation/gpu/xpu.inc.md index 4ab41a21c2a1..74937a184227 100644 --- a/docs/source/getting_started/installation/gpu/xpu.inc.md +++ b/docs/source/getting_started/installation/gpu/xpu.inc.md @@ -66,7 +66,6 @@ XPU platform supports **tensor parallel** inference/serving and also supports ** python -m vllm.entrypoints.openai.api_server \ --model=facebook/opt-13b \ --dtype=bfloat16 \ - --device=xpu \ --max_model_len=1024 \ --distributed-executor-backend=ray \ --pipeline-parallel-size=2 \ diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index 298ba59f7d8b..42468ff73c2c 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -82,6 +82,11 @@ llm = LLM(model="facebook/opt-125m") :::{note} By default, vLLM downloads models from [Hugging Face](https://huggingface.co/). If you would like to use models from [ModelScope](https://www.modelscope.cn), set the environment variable `VLLM_USE_MODELSCOPE` before initializing the engine. + +```shell +export VLLM_USE_MODELSCOPE=True +``` + ::: Now, the fun part! The outputs are generated using `llm.generate`. It adds the input prompts to the vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of `RequestOutput` objects, which include all of the output tokens. diff --git a/docs/source/index.md b/docs/source/index.md index 7e5b73c96896..db2192e87dcf 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -90,6 +90,8 @@ models/extensions/index :maxdepth: 1 features/quantization/index +features/multimodal_inputs +features/prompt_embeds features/lora features/tool_calling features/reasoning_outputs @@ -118,8 +120,6 @@ training/rlhf.md serving/offline_inference serving/openai_compatible_server serving/serve_args -serving/multimodal_inputs -serving/prompt_embeds serving/distributed_serving serving/metrics serving/engine_args diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 80eccfd034af..6022dfb9c2c6 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -54,7 +54,7 @@ For a model to be compatible with the Transformers backend for vLLM it must: If the compatible model is: -- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for or `--trust-remode-code` for the . +- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for or `--trust-remote-code` for the . - in a local directory, simply pass directory path to `model=` for or `vllm serve ` for the . This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM! @@ -392,6 +392,11 @@ Specified using `--task generate`. * `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. * ✅︎ * ✅︎ +- * `FalconH1ForCausalLM` + * Falcon-H1 + * `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. + * ✅︎ + * ✅︎ - * `GemmaForCausalLM` * Gemma * `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. diff --git a/docs/source/serving/prompt_embeds.md b/docs/source/serving/prompt_embeds.md deleted file mode 100644 index 483ca16648a4..000000000000 --- a/docs/source/serving/prompt_embeds.md +++ /dev/null @@ -1,142 +0,0 @@ -# Prompt Embedding Inputs - -This page teaches you how to pass prompt embedding inputs to vLLM. 
- -## What are prompt embeddings? - -The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer) then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), this step of converting token ids to prompt embeddings happens via a look-up from a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary. - -:::{note} -Prompt embeddings are currently only supported in the v0 engine. -::: - -## Offline Inference - -To input multi-modal data, follow this schema in {class}`vllm.inputs.EmbedsPrompt`: - -- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings. This has the shape (sequence_length, hidden_size), where sequence length is the number of tokens embeddings and hidden_size is the hidden size (embedding size) of the model. - -### Hugging Face Transformers Inputs - -You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples: - -```python -from vllm import LLM -import transformers - -model_name = "meta-llama/Llama-3.2-1B-Instruct" - -# Transformers -tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) -transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name) - -llm = LLM(model=model_name, enable_prompt_embeds=True) - -# Refer to the HuggingFace repo for the correct format to use -chat = [{"role": "user", "content": "Please tell me about the capital of France."}] -token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') - -prompt_embeds = embedding_layer(token_ids).squeeze(0) - -# Single prompt inference -outputs = llm.generate({ - "prompt_embeds": prompt_embeds, -}) - -for o in outputs: - generated_text = o.outputs[0].text - print(generated_text) - -# Batch inference - -chats = [ - [{"role": "user", "content": "Please tell me about the capital of France."}], - [{"role": "user", "content": "When is the day longest during the year?"}], - [{"role": "user", "content": "Where is bigger, the moon or the sun?"}] -] - -token_ids_list = [ - tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') for chat in chats -] -prompt_embeds_list = [embedding_layer(token_ids).squeeze(0) for token_ids in token_ids_list] - -outputs = llm.generate( - [ - { - "prompt_embeds": prompt_embeds, - } for prompt_embeds in prompt_embeds_list - ] -) - -for o in outputs: - generated_text = o.outputs[0].text - print(generated_text) -``` - -## Online Serving - -Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package. - -When a mixture of `'prompt_embeds'` and `'prompt'` inputs are provided in a single request, the prompt embeds are always returned first. - -Prompt embeddings are passed in as base64 encoded torch tensors. 
- -### Transformers Inputs via OpenAI Client - -First, launch the OpenAI-compatible server: - -```bash -vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \ - --max-model-len 4096 --enable-prompt-embeds -``` - -Then, you can use the OpenAI client as follows: - -```python -from openai import OpenAI -import transformers -import torch - -openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" - -client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, -) - -model_name = "meta-llama/Llama-3.2-1B-Instruct" - -# Transformers -tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) -transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name) - - -# Refer to the HuggingFace repo for the correct format to use -chat = [{"role": "user", "content": "Please tell me about the capital of France."}] -token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') - -prompt_embeds = embedding_layer(token_ids).squeeze(0) - -# Prompt embeddings -buffer = io.BytesIO() -torch.save(prompt_embeds, buffer) -buffer.seek(0) -binary_data = buffer.read() -encoded_embeds = base64.b64encode(binary_data).decode('utf-8') - - -completion = client_with_prompt_embeds.completions.create( - model=model_name, - # NOTE: The OpenAI client does not allow `None` as an input to - # `prompt`. Use an empty string if you have no text prompts. - prompt="", - max_tokens=5, - temperature=0.0, - # NOTE: The OpenAI client allows passing in extra JSON body via the - # `extra_body` argument. - extra_body={"prompt_embeds": encoded_embeds} -) - -print(completion.choices[0].text) -``` diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh index 831ef0bb574b..5719fa821292 100644 --- a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh @@ -54,6 +54,6 @@ elif [[ $1 == "decoder" ]]; then else echo "Invalid role: $1" - echo "Should be either prefill, decode" + echo "Should be either prefiller, decoder" exit 1 fi diff --git a/examples/offline_inference/automatic_prefix_caching.py b/examples/offline_inference/automatic_prefix_caching.py new file mode 100644 index 000000000000..6d05d0b99d80 --- /dev/null +++ b/examples/offline_inference/automatic_prefix_caching.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Demonstration script for Automatic Prefix Caching (APC) in vLLM. + +Automatic Prefix Caching (APC) allows the vLLM engine to reuse cached +KV (key-value) pairs from previous prompts if a new query shares the same +prefix. This reduces redundant computation and improves inference speed. + +To enable APC, set `enable_prefix_caching=True` when initializing the +vLLM engine. + +This script uses a long Markdown table as the shared prompt prefix and +compares the generation time for two queries that share the same prefix +but ask different questions. + +Run: +python examples/offline_inference/automatic_prefix_caching.py +""" +import time + +from vllm import LLM, SamplingParams + +# ruff: noqa: E501 +# A prompt containing a large markdown table. The table is randomly generated by GPT-4. +LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. 
Here is a table as follows.\n# Table\n" + """ +| ID | Name | Age | Occupation | Country | Email | Phone Number | Address | +|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------| +| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL | +| 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON | +| 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK | +| 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW | +| 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ | +| 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE | +| 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY | +| 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC | +| 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK | +| 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC| +| 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ | +| 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE | +| 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA | +| 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB | +| 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK | +| 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD | +| 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ | +| 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE | +| 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA | +| 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON | +| 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK | +| 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA | +| 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ| +| 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE | +| 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO | +| 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC | +| 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK | +| 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA | +| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ | +| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE | +""" + + +def get_generation_time(llm, sampling_params, prompts): + # time the generation + start_time = 
time.time() + output = llm.generate(prompts, sampling_params=sampling_params) + end_time = time.time() + # print the output and generation time + print("-" * 30) + print(f"Output: {output[0].outputs[0].text}") + print(f"Generation time: {end_time - start_time} seconds.") + print("-" * 30) + + +def main(): + # set enable_prefix_caching=True to enable APC + llm = LLM(model='lmsys/longchat-13b-16k', enable_prefix_caching=True) + + sampling_params = SamplingParams(temperature=0, max_tokens=100) + + # Querying the age of John Doe + get_generation_time( + llm, + sampling_params, + LONG_PROMPT + + "Question: what is the age of John Doe? Your answer: The age of John Doe is ", + ) + + # Querying the age of Zack Blue + # This query will be faster since vllm avoids computing the KV cache of LONG_PROMPT again. + get_generation_time( + llm, + sampling_params, + LONG_PROMPT + + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ", + ) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/disaggregated-prefill-v1/README.md b/examples/offline_inference/disaggregated-prefill-v1/README.md index f708eb253838..9cbdb19820f5 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/README.md +++ b/examples/offline_inference/disaggregated-prefill-v1/README.md @@ -5,5 +5,6 @@ This example contains scripts that demonstrate disaggregated prefill in the offl ## Files - `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially. + - Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`. - `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`. - `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`. 
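+
+For example, a minimal end-to-end run (assuming you start from the repository root) looks like:
+
+```bash
+cd examples/offline_inference/disaggregated-prefill-v1
+bash run.sh
+```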
diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py index 11918f72feec..531c96f176a3 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py @@ -3,35 +3,47 @@ from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig -# Read prompts from output.txt -prompts = [] -try: - with open("output.txt") as f: - for line in f: - prompts.append(line.strip()) - print(f"Loaded {len(prompts)} prompts from output.txt") -except FileNotFoundError: - print("Error: output.txt file not found") - exit(-1) - -sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - -llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - gpu_memory_utilization=0.8, - max_num_batched_tokens=64, - max_num_seqs=16, - kv_transfer_config=KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={ - "shared_storage_path": "local_storage" - })) #, max_model_len=2048, max_num_batched_tokens=2048) - -# 1ST generation (prefill instance) -outputs = llm.generate(prompts, sampling_params) - -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +def read_prompts(): + """Read prompts from output.txt""" + prompts = [] + try: + with open("output.txt") as f: + for line in f: + prompts.append(line.strip()) + print(f"Loaded {len(prompts)} prompts from output.txt") + return prompts + except FileNotFoundError: + print("Error: output.txt file not found") + exit(-1) + + +def main(): + prompts = read_prompts() + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + max_num_batched_tokens=64, + max_num_seqs=16, + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage" + })) #, max_model_len=2048, max_num_batched_tokens=2048) + + # 1ST generation (prefill instance) + outputs = llm.generate(prompts, sampling_params) + + print("-" * 30) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 30) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py index 798128301e0f..24b7b1d8fdbe 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py @@ -3,42 +3,54 @@ from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig -context = "Hi " * 1000 -context2 = "Hey " * 500 -prompts = [ - context + "Hello, my name is", - context + "The capital of France is", - context2 + "Your name is", - context2 + "The capital of China is", -] - -sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - -llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - gpu_memory_utilization=0.8, - kv_transfer_config=KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={ - 
"shared_storage_path": "local_storage" - })) #, max_model_len=2048, max_num_batched_tokens=2048) - -# 1ST generation (prefill instance) -outputs = llm.generate( - prompts, - sampling_params, -) - -new_prompts = [] -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - new_prompts.append(prompt + generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - -# Write new_prompts to output.txt -with open("output.txt", "w") as f: - for prompt in new_prompts: - f.write(prompt + "\n") -print(f"Saved {len(new_prompts)} prompts to output.txt") + +def read_prompts(): + context = "Hi " * 1000 + context2 = "Hey " * 500 + return [ + context + "Hello, my name is", + context + "The capital of France is", + context2 + "Your name is", + context2 + "The capital of China is", + ] + + +def main(): + prompts = read_prompts() + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage" + })) #, max_model_len=2048, max_num_batched_tokens=2048) + + # 1ST generation (prefill instance) + outputs = llm.generate( + prompts, + sampling_params, + ) + + new_prompts = [] + print("-" * 30) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 30) + + # Write new_prompts to output.txt + with open("output.txt", "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") + print(f"Saved {len(new_prompts)} prompts to output.txt") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/prompt_embed_inference.py b/examples/offline_inference/prompt_embed_inference.py new file mode 100644 index 000000000000..99c5a682fb27 --- /dev/null +++ b/examples/offline_inference/prompt_embed_inference.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Demonstrates how to generate prompt embeddings using +Hugging Face Transformers and use them as input to vLLM +for both single and batch inference. + +Model: meta-llama/Llama-3.2-1B-Instruct +Note: This model is gated on Hugging Face Hub. 
+ You must request access to use it: + https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct + +Requirements: +- vLLM +- transformers + +Run: + python examples/offline_inference/prompt_embed_inference.py +""" + +import torch +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizer) + +from vllm import LLM + + +def init_tokenizer_and_llm(model_name: str): + tokenizer = AutoTokenizer.from_pretrained(model_name) + transformers_model = AutoModelForCausalLM.from_pretrained(model_name) + embedding_layer = transformers_model.get_input_embeddings() + llm = LLM(model=model_name, enable_prompt_embeds=True) + return tokenizer, embedding_layer, llm + + +def get_prompt_embeds(chat: list[dict[str, + str]], tokenizer: PreTrainedTokenizer, + embedding_layer: torch.nn.Module): + token_ids = tokenizer.apply_chat_template(chat, + add_generation_prompt=True, + return_tensors='pt') + prompt_embeds = embedding_layer(token_ids).squeeze(0) + return prompt_embeds + + +def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, + embedding_layer: torch.nn.Module): + chat = [{ + "role": "user", + "content": "Please tell me about the capital of France." + }] + prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer) + + outputs = llm.generate({ + "prompt_embeds": prompt_embeds, + }) + + print("\n[Single Inference Output]") + print("-" * 30) + for o in outputs: + print(o.outputs[0].text) + print("-" * 30) + + +def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, + embedding_layer: torch.nn.Module): + chats = [[{ + "role": "user", + "content": "Please tell me about the capital of France." + }], + [{ + "role": "user", + "content": "When is the day longest during the year?" + }], + [{ + "role": "user", + "content": "Where is bigger, the moon or the sun?" + }]] + + prompt_embeds_list = [ + get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats + ] + + outputs = llm.generate([{ + "prompt_embeds": embeds + } for embeds in prompt_embeds_list]) + + print("\n[Batch Inference Outputs]") + print("-" * 30) + for i, o in enumerate(outputs): + print(f"Q{i+1}: {chats[i][0]['content']}") + print(f"A{i+1}: {o.outputs[0].text}\n") + print("-" * 30) + + +def main(): + model_name = "meta-llama/Llama-3.2-1B-Instruct" + tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name) + single_prompt_inference(llm, tokenizer, embedding_layer) + batch_prompt_inference(llm, tokenizer, embedding_layer) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_client.py b/examples/online_serving/openai_chat_completion_client.py index 74e0c045d621..bf99777d5697 100644 --- a/examples/online_serving/openai_chat_completion_client.py +++ b/examples/online_serving/openai_chat_completion_client.py @@ -3,6 +3,9 @@ NOTE: start a supported chat completion model server with `vllm serve`, e.g. vllm serve meta-llama/Llama-2-7b-chat-hf """ + +import argparse + from openai import OpenAI # Modify OpenAI's API key and API base to use vLLM's API server. 
@@ -24,7 +27,15 @@ }] -def main(): +def parse_args(): + parser = argparse.ArgumentParser(description="Client for vLLM API server") + parser.add_argument("--stream", + action="store_true", + help="Enable streaming response") + return parser.parse_args() + + +def main(args): client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") api_key=openai_api_key, @@ -34,16 +45,23 @@ def main(): models = client.models.list() model = models.data[0].id + # Chat Completion API chat_completion = client.chat.completions.create( messages=messages, model=model, + stream=args.stream, ) print("-" * 50) print("Chat completion results:") - print(chat_completion) + if args.stream: + for c in chat_completion: + print(c) + else: + print(chat_completion) print("-" * 50) if __name__ == "__main__": - main() + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_chat_completion_structured_outputs.py b/examples/online_serving/openai_chat_completion_structured_outputs.py index 660369e55d40..722d747a69bf 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs.py @@ -12,6 +12,9 @@ from openai import BadRequestError, OpenAI from pydantic import BaseModel +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + # Guided decoding by Choice (list of possible options) def guided_choice_completion(client: OpenAI, model: str): @@ -134,8 +137,8 @@ def extra_backend_options_completion(client: OpenAI, model: str): def main(): client: OpenAI = OpenAI( - base_url="http://localhost:8000/v1", - api_key="-", + base_url=openai_api_base, + api_key=openai_api_key, ) model = client.models.list().data[0].id diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py index 42aa12c451c0..08f939942508 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py @@ -7,11 +7,14 @@ # to enforce the format of a tool call response, but it could be used for # any structured output within a subset of the response. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + def main(): client = OpenAI( - base_url="http://localhost:8000/v1", - api_key="-", + base_url=openai_api_base, + api_key=openai_api_key, ) messages = [{ diff --git a/examples/online_serving/openai_completion_client.py b/examples/online_serving/openai_completion_client.py index 6ab7619bff19..77f721921da2 100644 --- a/examples/online_serving/openai_completion_client.py +++ b/examples/online_serving/openai_completion_client.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import argparse + from openai import OpenAI # Modify OpenAI's API key and API base to use vLLM's API server. 
@@ -7,7 +9,15 @@ openai_api_base = "http://localhost:8000/v1" -def main(): +def parse_args(): + parser = argparse.ArgumentParser(description="Client for vLLM API server") + parser.add_argument("--stream", + action="store_true", + help="Enable streaming response") + return parser.parse_args() + + +def main(args): client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") api_key=openai_api_key, @@ -18,18 +28,17 @@ def main(): model = models.data[0].id # Completion API - stream = False completion = client.completions.create( model=model, prompt="A robot may not injure a human being", echo=False, n=2, - stream=stream, + stream=args.stream, logprobs=3) print("-" * 50) print("Completion results:") - if stream: + if args.stream: for c in completion: print(c) else: @@ -38,4 +47,5 @@ def main(): if __name__ == "__main__": - main() + args = parse_args() + main(args) diff --git a/examples/online_serving/prompt_embed_inference_with_openai_client.py b/examples/online_serving/prompt_embed_inference_with_openai_client.py new file mode 100644 index 000000000000..ea580f1b432b --- /dev/null +++ b/examples/online_serving/prompt_embed_inference_with_openai_client.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +vLLM OpenAI-Compatible Client with Prompt Embeddings + +This script demonstrates how to: +1. Generate prompt embeddings using Hugging Face Transformers +2. Encode them in base64 format +3. Send them to a vLLM server via the OpenAI-compatible Completions API + +Run the vLLM server first: +vllm serve meta-llama/Llama-3.2-1B-Instruct \ + --task generate \ + --max-model-len 4096 \ + --enable-prompt-embeds + +Run the client: +python examples/online_serving/prompt_embed_inference_with_openai_client.py + +Model: meta-llama/Llama-3.2-1B-Instruct +Note: This model is gated on Hugging Face Hub. + You must request access to use it: + https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct + +Dependencies: +- transformers +- torch +- openai +""" +import base64 +import io + +import torch +import transformers +from openai import OpenAI + + +def main(): + client = OpenAI( + api_key="EMPTY", + base_url="http://localhost:8000/v1", + ) + + model_name = "meta-llama/Llama-3.2-1B-Instruct" + + # Transformers + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + transformers_model = transformers.AutoModelForCausalLM.from_pretrained( + model_name) + + # Refer to the HuggingFace repo for the correct format to use + chat = [{ + "role": "user", + "content": "Please tell me about the capital of France." + }] + token_ids = tokenizer.apply_chat_template(chat, + add_generation_prompt=True, + return_tensors='pt') + + embedding_layer = transformers_model.get_input_embeddings() + prompt_embeds = embedding_layer(token_ids).squeeze(0) + + # Prompt embeddings + buffer = io.BytesIO() + torch.save(prompt_embeds, buffer) + buffer.seek(0) + binary_data = buffer.read() + encoded_embeds = base64.b64encode(binary_data).decode('utf-8') + + completion = client.completions.create( + model=model_name, + # NOTE: The OpenAI client does not allow `None` as an input to + # `prompt`. Use an empty string if you have no text prompts. + prompt="", + max_tokens=5, + temperature=0.0, + # NOTE: The OpenAI client allows passing in extra JSON body via the + # `extra_body` argument. 
+ extra_body={"prompt_embeds": encoded_embeds}) + + print("-" * 30) + print(completion.choices[0].text) + print("-" * 30) + + +if __name__ == "__main__": + main() diff --git a/examples/tool_chat_template_llama4_pythonic.jinja b/examples/tool_chat_template_llama4_pythonic.jinja index bd18a35bdda9..bbed3d8205e0 100644 --- a/examples/tool_chat_template_llama4_pythonic.jinja +++ b/examples/tool_chat_template_llama4_pythonic.jinja @@ -1,16 +1,17 @@ {{- bos_token }} -{%- if custom_tools is defined %} +{%- if custom_tools is defined and custom_tools%} {%- set tools = custom_tools %} {%- endif %} -{%- if not tools_in_user_message is defined %} - {%- set tools_in_user_message = false %} -{%- endif %} -{%- if not tools is defined %} +{%- if tools is defined and tools %} + {%- set tool_definition = tool_definition ~ (tools | tojson(indent=4)) %} +{%- else %} {%- set tools = none %} {%- endif %} + {#- This block extracts the system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %} + {%- set user_provided_system_message = true %} {%- if messages[0]['content'] is string %} {%- set system_message = messages[0]['content']|trim %} {%- else %} @@ -18,68 +19,33 @@ {%- endif %} {%- set messages = messages[1:] %} {%- else %} - {%- if tools is not none %} - {#- Add default tool system message when tools are provided #} - {%- set system_message = "You are a helpful assistant with tool calling " - "capabilities. Only reply with a tool call if the function exists in the " - "library provided by the user. If it doesn't exist, just reply directly in " - "natural language. When you receive a tool call response, use the output to " - "format an answer to the original user question." %} + {%- if tools is not none %} + {#- Since not system_message was provided by user, if tool is provided, system_message is now default tool system message #} + {#- This system message is from llama website:https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/ #} + {%- set system_message = "You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If a function is not in the list, respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location=\"Vancouver\"), calculate_route(start=\"Boston\", end=\"New York\")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location=\"New York\")\nINCORRECT: Let me check the weather: [get_weather(location=\"New York\")]\nINCORRECT: [get_events(location=\"Singapore\")] <- If function not in list\n\n2. 
RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\". Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don't repeat tool response verbatim\n- Don't add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n" %} {%- else %} {%- set system_message = "" %} {%- endif %} {%- endif %} - -{#- System message if the user supplied one, or if tools are used (default tool system message) #} +{#- Now writing the system message: use the user provided system message if user_provided_system_message, else default tool system message if tools presented #} {%- if system_message %} {#- always use user provided system message to override default tool system message #} {{- "<|header_start|>system<|header_end|>\n\n" }} {{- system_message }} - {%- if tools is not none and not tools_in_user_message %} - {{- "Tools: You have access to the following tools. You might need to use one " - "or more function/tool calls to fulfill the task. \n" - "If none are needed, then proceed to the response.\n\n" - "Tool Call Syntax: You can call tools using the following syntax:\n" - "[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n" - "Do not include anything else when calling the tools with the syntax above.\n\n" - "Here is a list of functions in JSON format that you can invoke.\n " }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} + {%- if user_provided_system_message and tools %} + {{- "\nHere is a list of functions in JSON format that you can invoke. 
Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n" }} + {{- tool_definition -}} + {%- elif tool_definition %} + {{- tool_definition -}} {%- endif %} {{- "<|eot|>" }} {%- endif %} -{#- Custom tools are passed in a user message with some extra guidance #} -{%- if tools_in_user_message and tools is not none %} - {#- Extract the first user message so we can plug it in here #} - {%- if messages | length != 0 %} - {%- if messages[0]['content'] is string %} - {%- set first_user_message = messages[0]['content']|trim %} - {%- else %} - {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} - {%- endif %} - {%- set messages = messages[1:] %} - {%- else %} - {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} - {%- endif %} - {{- '<|header_start|>user<|header_end|>\n\n' -}} - {{- first_user_message}} - {{- "\nHere is a list of functions in JSON format that you can invoke:"}} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} - {{- "Should you decide to return the function call(s), put them in the format " - "of [func_name1(params_name1=params_value1, params_name2=params_value2, " - "...), ...]\nDo not include anything else when calling the tools with the " - "syntax above." }} -{%- endif %} - +{#- Now deal with all other messages #} {%- for message in messages %} - {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {#- Base case: messages that are not from tool role and has empty tool_call list #} + {%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} {%- if message['content'] is string %} {{- message['content'] }} {%- else %} @@ -91,10 +57,12 @@ {%- endif %} {%- endfor %} {%- endif %} - {{- "<|eot|>" }} - {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %} - {%- set tool_call = message.tool_calls[0].function %} - {{- '<|header_start|>assistant<|header_end|>\n\n' -}} + {{- "<|eot|>" }} + {#- Tool case: messages has non-empty tool_call list, must from assistant #} + {%- elif 'tool_calls' in message %} + {#- assume tool_calls are always coming from assistant #} + {%- if message.role == 'assistant' %} + {{- '<|header_start|>assistant<|header_end|>\n\n' -}} {%- if message['content'] is string %} {{- message['content'] }} {%- else %} @@ -106,32 +74,36 @@ {%- endif %} {%- endfor %} {%- endif %} + {{- "[" }} {%- for tool_call in message.tool_calls %} {%- if tool_call.function is defined %} {%- set tool_call = tool_call.function %} {%- endif %} - {{- tool_call.name + '(' -}} + {{- tool_call.name + '(' -}} {%- for param in tool_call.arguments %} - {{- param + '=' -}} + {{- param + '="' -}} {{- "%s" | format(tool_call.arguments[param]) -}} + {{- '"' -}} {% if not loop.last %}, {% endif %} {%- endfor %} {{- ')' -}} {% if not loop.last %}, {% endif %} {%- endfor %} - {{- "<|eom|>" }} + {{- "]<|eot|>" }} +{%- endif %} +{#- Tool_response case: messages are from tool_response #} {%- elif message.role == "tool" or message.role == "ipython" %} {{- "<|header_start|>ipython<|header_end|>\n\n" }} {%- if message.content is string %} - {{- message.content | tojson }} + {{- message.content | tojson }} {%- else %} {%- for content in 
message['content'] %} {%- if content['type'] == 'text' %} - {{- content['text'] | tojson }} + {{- content['text'] | tojson }} {%- endif %} {%- endfor %} {%- endif %} - {{- "<|eom|>" }} + {{- "<|eot|>" }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} diff --git a/pyproject.toml b/pyproject.toml index 6a2d9c44d414..a12f545dc91f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,8 @@ dynamic = [ "version", "dependencies", "optional-dependencies"] [project.urls] Homepage="https://github.com/vllm-project/vllm" -Documentation="https://vllm.readthedocs.io/en/latest/" -Slack="http://slack.vllm.ai/" +Documentation="https://docs.vllm.ai/en/latest/" +Slack="https://slack.vllm.ai/" [project.scripts] vllm = "vllm.entrypoints.cli.main:main" diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 752931158a05..d4191888382c 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -2,11 +2,12 @@ -r common.txt # Dependencies for CPUs +packaging>=24.2 +setuptools>=77.0.3,<80.0.0 --extra-index-url https://download.pytorch.org/whl/cpu torch==2.7.0+cpu; platform_machine == "x86_64" torch==2.7.0; platform_system == "Darwin" torch==2.7.0; platform_machine == "ppc64le" or platform_machine == "aarch64" -torch==2.7.0.dev20250304; platform_machine == "s390x" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" diff --git a/requirements/test.in b/requirements/test.in index cdc7c563f087..87af61769038 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -33,6 +33,7 @@ num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test +mteb>=1.38.11, <2 # required for mteb test transformers==4.51.3 tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. 
diff --git a/requirements/test.txt b/requirements/test.txt index 9a15d9a0d824..89d477017342 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -99,6 +99,7 @@ datasets==3.0.2 # via # evaluate # lm-eval + # mteb decorator==5.1.1 # via librosa dill==0.3.8 @@ -124,6 +125,8 @@ email-validator==2.2.0 # via pydantic encodec==0.1.1 # via vocos +eval-type-backport==0.2.2 + # via mteb evaluate==0.4.3 # via lm-eval fastparquet==2024.11.0 @@ -291,6 +294,8 @@ msgpack==1.1.0 # via # librosa # ray +mteb==1.38.11 + # via -r requirements/test.in multidict==6.1.0 # via # aiohttp @@ -331,6 +336,7 @@ numpy==1.26.4 # librosa # matplotlib # mistral-common + # mteb # numba # numexpr # opencv-python-headless @@ -443,6 +449,8 @@ plotly==5.24.1 # via genai-perf pluggy==1.5.0 # via pytest +polars==1.29.0 + # via mteb pooch==1.8.2 # via librosa portalocker==2.10.1 @@ -476,6 +484,7 @@ pydantic==2.9.2 # via # datamodel-code-generator # mistral-common + # mteb pydantic-core==2.23.4 # via pydantic pygments==2.18.0 @@ -522,6 +531,8 @@ python-dateutil==2.9.0.post0 # typepy python-rapidjson==1.20 # via tritonclient +pytrec-eval-terrier==0.5.7 + # via mteb pytz==2024.2 # via # pandas @@ -564,6 +575,7 @@ requests==2.32.3 # huggingface-hub # lm-eval # mistral-common + # mteb # pooch # ray # responses @@ -580,6 +592,7 @@ rfc3987==1.3.8 rich==13.9.4 # via # genai-perf + # mteb # typer rouge-score==0.1.2 # via lm-eval @@ -607,16 +620,20 @@ scikit-learn==1.5.2 # via # librosa # lm-eval + # mteb # sentence-transformers scipy==1.13.1 # via # librosa + # mteb # scikit-learn # sentence-transformers # statsmodels # vocos sentence-transformers==3.2.1 - # via -r requirements/test.in + # via + # -r requirements/test.in + # mteb sentencepiece==0.2.0 # via mistral-common setuptools==77.0.3 @@ -696,6 +713,7 @@ torch==2.7.0+cu128 # fastsafetensors # lm-eval # mamba-ssm + # mteb # peft # runai-model-streamer # sentence-transformers @@ -720,6 +738,7 @@ tqdm==4.66.6 # evaluate # huggingface-hub # lm-eval + # mteb # nltk # peft # pqdm @@ -759,6 +778,7 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common + # mteb # pqdm # pydantic # pydantic-core diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 11501bc5d92f..3b204a8f9905 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -18,9 +18,9 @@ setuptools==78.1.0 --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.8.0.dev20250430 -torchvision==0.22.0.dev20250430 -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.8.0.dev20250518 +torchvision==0.22.0.dev20250518 +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; 
python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 9f3b0e8ae079..86b5e1e0ab7c 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -8,12 +8,13 @@ from unittest.mock import Mock import pytest +import torch -from vllm import LLM +from vllm import LLM, envs from vllm.platforms import current_platform from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 -from ..conftest import VllmRunner +from ..conftest import HfRunner, VllmRunner from ..models.utils import check_outputs_equal from ..utils import multi_gpu_test @@ -43,11 +44,26 @@ def test_vllm_gc_ed(): assert weak_llm() is None +def _fix_prompt_embed_outputs( + vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner, + example_prompts: list[str]) -> list[tuple[list[int], str]]: + fixed_vllm_outputs = [] + for vllm_output, hf_input, prompt in zip( + vllm_outputs, hf_model.get_inputs(example_prompts), + example_prompts): + hf_input_ids = hf_input["input_ids"].tolist()[0] + fixed_vllm_outputs.append( + (hf_input_ids + vllm_output[0][len(hf_input_ids):], + prompt + vllm_output[1])) + return fixed_vllm_outputs + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models( monkeypatch: pytest.MonkeyPatch, hf_runner, @@ -56,8 +72,13 @@ def test_models( dtype: str, max_tokens: int, enforce_eager: bool, + enable_prompt_embeds: bool, ) -> None: + if enable_prompt_embeds and envs.is_set( + "VLLM_USE_V1") and envs.VLLM_USE_V1: + pytest.skip("enable_prompt_embeds is not supported in v1.") + if backend == "FLASHINFER" and current_platform.is_rocm(): pytest.skip("Flashinfer does not support ROCm/HIP.") @@ -78,14 +99,25 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + if enable_prompt_embeds: + with torch.no_grad(): + prompt_embeds = hf_model.get_prompt_embeddings( + example_prompts) with VllmRunner(model, max_model_len=8192, dtype=dtype, enforce_eager=enforce_eager, + enable_prompt_embeds=enable_prompt_embeds, gpu_memory_utilization=0.7) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) + if enable_prompt_embeds: + vllm_outputs = vllm_model.generate_greedy( + prompt_embeds, max_tokens) + vllm_outputs = _fix_prompt_embed_outputs( + vllm_outputs, hf_model, example_prompts) + else: + vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -108,6 +140,7 @@ def test_models( ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"), ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), ]) +@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models_distributed( monkeypatch: pytest.MonkeyPatch, hf_runner, @@ -117,14 +150,22 @@ def test_models_distributed( distributed_executor_backend: str, attention_backend: str, test_suite: str, + enable_prompt_embeds: bool, ) -> None: + if enable_prompt_embeds and envs.is_set( + "VLLM_USE_V1") and 
envs.VLLM_USE_V1: + pytest.skip("enable_prompt_embeds is not supported in v1.") + if test_suite != TARGET_TEST_SUITE: pytest.skip(f"Skip test for {test_suite}") with monkeypatch.context() as monkeypatch_context: if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa - # test Ray Compiled Graph + if enable_prompt_embeds: + pytest.skip( + "enable_prompt_embeds does not work with ray compiled dag." + ) monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") @@ -147,12 +188,26 @@ def test_models_distributed( dtype=dtype, tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + if enable_prompt_embeds: + with hf_runner(model, dtype=dtype) as hf_model: + with torch.no_grad(): + prompt_embeds = hf_model.get_prompt_embeddings( + example_prompts) + vllm_outputs = vllm_model.generate_greedy( + prompt_embeds, max_tokens) + vllm_outputs = _fix_prompt_embed_outputs( + vllm_outputs, hf_model, example_prompts) + hf_outputs = hf_model.generate_greedy( + example_prompts, max_tokens) + else: + vllm_outputs = vllm_model.generate_greedy( + example_prompts, max_tokens) + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy( + example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, diff --git a/tests/conftest.py b/tests/conftest.py index c5700179c228..19c2c6247129 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -430,6 +430,15 @@ def get_inputs( return all_inputs + def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]: + all_inputs = self.get_inputs(prompts) + embeddings = [] + for inputs in all_inputs: + input_ids = self.wrap_device(inputs)["input_ids"] + embedding = self.model.get_input_embeddings()(input_ids).squeeze(0) + embeddings.append(embedding) + return embeddings + def classify(self, prompts: list[str]) -> list[str]: # output is final logits all_inputs = self.get_inputs(prompts) diff --git a/tests/distributed/test_events.py b/tests/distributed/test_events.py index 15bcfdb8555f..8de1aa20eabd 100644 --- a/tests/distributed/test_events.py +++ b/tests/distributed/test_events.py @@ -119,13 +119,12 @@ def test_topic_filtering(publisher_config): """ publisher_config.replay_endpoint = None - cfg = publisher_config.model_copy() - cfg.topic = "foo" - pub = EventPublisherFactory.create(cfg) + publisher_config.topic = "foo" + pub = EventPublisherFactory.create(publisher_config) from .conftest import MockSubscriber - sub_foo = MockSubscriber(cfg.endpoint, None, "foo") - sub_bar = MockSubscriber(cfg.endpoint, None, "bar") + sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo") + sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar") try: time.sleep(0.1) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index 711c2441f34b..f9eacc11d75f 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -9,7 +9,7 @@ from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import 
get_ip, get_open_port, update_environment_variables +from vllm.utils import get_open_port, update_environment_variables def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]: @@ -60,12 +60,12 @@ def worker_fn(): rank = dist.get_rank() if rank == 0: port = get_open_port() - ip = get_ip() + ip = '127.0.0.1' dist.broadcast_object_list([ip, port], src=0) else: recv = [None, None] dist.broadcast_object_list(recv, src=0) - ip, port = recv + ip, port = recv # type: ignore stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) @@ -107,10 +107,10 @@ def worker_fn(): if pg == dist.group.WORLD: dist.barrier() - print("torch distributed passed the test!") + print(f"torch distributed passed the test! Rank {rank}") else: pg.barrier() - print("StatelessProcessGroup passed the test!") + print(f"StatelessProcessGroup passed the test! Rank {rank}") def test_shm_broadcast(): diff --git a/tests/entrypoints/openai/correctness/test_mteb.py b/tests/entrypoints/openai/correctness/test_mteb.py new file mode 100644 index 000000000000..b702e0acd38b --- /dev/null +++ b/tests/entrypoints/openai/correctness/test_mteb.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +import math +import os + +import pytest + +from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS, + OpenAIClientMtebEncoder, + run_mteb_embed_task, + run_mteb_embed_task_st) +from tests.utils import RemoteOpenAIServer + +os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" + +MODEL_NAME = "BAAI/bge-m3" +DTYPE = "float16" +MAIN_SCORE = 0.7873427091972599 + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", "embed", "--dtype", DTYPE, "--enforce-eager", + "--max-model-len", "512" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +def test_mteb(server): + client = server.get_client() + encoder = OpenAIClientMtebEncoder(MODEL_NAME, client) + vllm_main_score = run_mteb_embed_task(encoder, MTEB_EMBED_TASKS) + st_main_score = MAIN_SCORE or run_mteb_embed_task_st( + MODEL_NAME, MTEB_EMBED_TASKS) + + print("VLLM main score: ", vllm_main_score) + print("SentenceTransformer main score: ", st_main_score) + print("Difference: ", st_main_score - vllm_main_score) + + assert math.isclose(st_main_score, vllm_main_score, rel_tol=1e-4) diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 5c585d54c429..cae2a3b59553 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Final + import pytest import schemathesis +from hypothesis import settings from schemathesis import GenerationConfig from ...utils import RemoteOpenAIServer @@ -9,6 +12,8 @@ MODEL_NAME = "HuggingFaceTB/SmolVLM-256M-Instruct" MAXIMUM_IMAGES = 2 +DEFAULT_TIMEOUT_SECONDS: Final[int] = 10 +LONG_TIMEOUT_SECONDS: Final[int] = 60 @pytest.fixture(scope="module") @@ -42,8 +47,58 @@ def get_schema(server): schema = schemathesis.from_pytest_fixture("get_schema") +@schemathesis.hook +def before_generate_case(context: schemathesis.hooks.HookContext, strategy): + op = context.operation + assert op is not None + + def no_file_type(case: schemathesis.models.Case): + """ + This filter skips test cases for the `POST /tokenize` endpoint where the + HTTP request body uses `"type": "file"` in any message's content. 
+ We expect these cases to fail because that type isn't implemented here + https://github.com/vllm-project/vllm/blob/0b34593017953051b3225b1483ce0f4670e3eb0e/vllm/entrypoints/chat_utils.py#L1038-L1095 + + Example test cases that are skipped: + curl -X POST -H 'Content-Type: application/json' \ + -d '{"messages": [{"role": "assistant"}, {"content": [{"file": {}, "type": "file"}], "role": "user"}]}' \ + http://localhost:8000/tokenize + + curl -X POST -H 'Content-Type: application/json' \ + -d '{"messages": [{"content": [{"file": {}, "type": "file"}], "role": "user"}]}' \ + http://localhost:8000/tokenize + """ # noqa: E501 + if (op.method.lower() == "post" and op.path == "/tokenize" + and hasattr(case, "body") and isinstance(case.body, dict) + and "messages" in case.body + and isinstance(case.body["messages"], list) + and len(case.body["messages"]) > 0): + for message in case.body["messages"]: + if not isinstance(message, dict): + continue + content = message.get("content", []) + if not isinstance(content, list) or len(content) == 0: + continue + if any(item.get("type") == "file" for item in content): + return False + return True + + return strategy.filter(no_file_type) + + @schema.parametrize() @schema.override(headers={"Content-Type": "application/json"}) +@settings(deadline=LONG_TIMEOUT_SECONDS * 1000) def test_openapi_stateless(case: schemathesis.Case): + key = ( + case.operation.method.upper(), + case.operation.path, + ) + timeout = { + # requires a longer timeout + ("POST", "/v1/chat/completions"): + LONG_TIMEOUT_SECONDS, + }.get(key, DEFAULT_TIMEOUT_SECONDS) + #No need to verify SSL certificate for localhost - case.call_and_validate(verify=False) + case.call_and_validate(verify=False, timeout=timeout) diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py new file mode 100644 index 000000000000..92ba1376e200 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +import pytest + +from tests.entrypoints.openai.tool_parsers.utils import ( + run_tool_extraction, run_tool_extraction_streaming) +from vllm.entrypoints.openai.protocol import FunctionCall +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager + +# Test cases similar to pythonic parser but with Llama4 specific format +SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]" +SIMPLE_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "LA", "metric": "C"}', +) +MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', " + "age=9, " + "address={'city': 'LA', 'state': 'CA'}, " + "role=None, " + "passed_test=True, " + "aliases=['John', 'Johnny'])]") +MORE_TYPES_FUNCTION_CALL = FunctionCall( + name="register_user", + arguments='{"name": "Doe", ' + '"age": 9, ' + '"address": {"city": "LA", "state": "CA"}, ' + '"role": null, ' + '"passed_test": true, ' + '"aliases": ["John", "Johnny"]}', +) +PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]" +PARAMETERLESS_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{}', +) +EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]" +EMPTY_DICT_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"additional_data": {}}', +) +EMPTY_LIST_FUNCTION_OUTPUT = "[do_something_cool(steps=[])]" +EMPTY_LIST_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + 
arguments='{"steps": []}', +) +ESCAPED_STRING_FUNCTION_OUTPUT = ( + r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]") +ESCAPED_STRING_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', +) +PYTHON_TAG_FUNCTION_OUTPUT = ( + "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>") + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_tool_call(streaming: bool): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + model_output = "How can I help you today?" + + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=streaming) + + assert content == model_output + assert len(tool_calls) == 0 + + +test_str = "<|python_start|>" +test_str += "[get_weather(city='LA', metric='C')," +test_str += "register_user(name='Doe', age=9)]" +TEST_CASES = [ + pytest.param(True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="simple_streaming"), + pytest.param(False, + SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], + id="simple_nonstreaming"), + pytest.param(True, + MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming"), + pytest.param(False, + MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming"), + pytest.param(True, + PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming"), + pytest.param(False, + PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming"), + pytest.param(True, + EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming"), + pytest.param(False, + EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming"), + pytest.param(True, + EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming"), + pytest.param(False, + EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming"), + pytest.param(True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming"), + pytest.param(False, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_nonstreaming"), + pytest.param( + True, + "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_streaming"), + pytest.param( + False, + "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_nonstreaming"), + pytest.param(True, + PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], + id="python_tag_streaming"), + pytest.param(False, + PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], + id="python_tag_nonstreaming"), + pytest.param(True, + test_str, [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_streaming"), + pytest.param(False, + "<|python_start|>[get_weather(city='LA', metric='C'), " + + "register_user(name='Doe', age=9)]", [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_nonstreaming"), +] + + +@pytest.mark.parametrize("streaming, model_output, 
expected_tool_calls", + TEST_CASES) +def test_tool_call(streaming: bool, model_output: str, + expected_tool_calls: list[FunctionCall]): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=streaming) + + assert len(tool_calls) == len(expected_tool_calls) + for actual, expected in zip(tool_calls, expected_tool_calls): + assert actual.type == "function" + assert actual.function == expected + + +def test_streaming_tool_call_with_large_steps(): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + model_output_deltas = [ + "<|python_start|>[get_weather(city='LA', metric='C'), " + "get_weather(), " + "do_something_cool(steps=[])]<|python_end|>", + ] + + reconstructor = run_tool_extraction_streaming( + tool_parser, model_output_deltas, assert_one_tool_per_delta=False) + + assert reconstructor.other_content == "" + assert len(reconstructor.tool_calls) == 3 + assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL + assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL + assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index e5650136f258..d9f956fbc7c0 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -148,6 +148,11 @@ def test_paged_attention( or (version == "rocm" and head_size not in (64, 128))): pytest.skip() + if (version == "rocm" and current_platform.is_navi() + and (kv_cache_dtype == "fp8" or head_size != 128 + or block_size != 16 or use_alibi)): + pytest.skip() + global PARTITION_SIZE current_platform.seed_everything(seed) @@ -275,6 +280,7 @@ def test_paged_attention( scale, block_tables, seq_lens, + None, block_size, max_seq_len, alibi_slopes, @@ -286,7 +292,7 @@ def test_paged_attention( opcheck(torch.ops._rocm_C.paged_attention, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, + seq_lens, None, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 43ddc79fcb81..299279390fe0 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck(): opcheck(torch.ops._moe_C.moe_align_block_size, (topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): + input = torch.randn((m, topk, k), device="cuda", dtype=dtype) + actual = torch.empty((m, k), device="cuda", dtype=dtype) + + expected = input.sum(dim=1) + torch.ops._moe_C.moe_sum(input, actual) + + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) + + opcheck(torch.ops._moe_C.moe_sum, (input, actual)) diff --git 
a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index dfcd61f77587..10e6ac64df87 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.layer import determine_expert_map from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - moe_permute, moe_unpermute) + moe_permute, moe_permute_unpermute_supported, moe_unpermute) from vllm.platforms import current_platform NUM_EXPERTS = [16, 64] @@ -167,6 +167,8 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, n_expert: int, ep_size: int, dtype: torch.dtype, align_block_size: Optional[int]): + if not moe_permute_unpermute_supported(): + pytest.skip("moe_permute_unpermute is not supported on this platform.") fill_invalid_expert = 0 ep_rank = np.random.randint(0, ep_size) expert_map = None diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index 204624a0540a..7ae33a848a0a 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -69,7 +69,7 @@ def run_check(fn, args, expected: list): run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11]) run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11]) - # Remove all LoRAs + # Remove all LoRAs. run_check(llm.remove_lora, 13, [12, 10, 11]) run_check(llm.remove_lora, 12, [10, 11]) run_check(llm.remove_lora, 11, [10]) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 9b7a42acece5..604cb854b32f 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -31,7 +31,7 @@ # not compatible with pip-compile. 
"pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", - "hmellor/bamba-tiny-random", + "hmellor/tiny-random-BambaForCausalLM", ] # Avoid OOM diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 3ccf2999664c..b60d27aaa72b 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -58,8 +58,6 @@ @pytest.mark.parametrize("model_info", MODELS) def test_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: - pytest.skip("Skipping mteb test.") - from .mteb_utils import mteb_test_embed_models vllm_extra_kwargs: dict[str, Any] = {} diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py index 6e9de30f977d..28df32e0c230 100644 --- a/tests/models/language/pooling/test_nomic.py +++ b/tests/models/language/pooling/test_nomic.py @@ -23,7 +23,6 @@ @pytest.mark.parametrize("model_info", MODELS) def test_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: - pytest.skip("Skipping mteb test.") from .mteb_utils import mteb_test_embed_models mteb_test_embed_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index 7d9c3c73d852..5679e0e1ce00 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -46,7 +46,6 @@ def test_models_mteb( vllm_runner, model_info: EmbedModelInfo, ) -> None: - pytest.skip("Skipping mteb test.") from .mteb_utils import mteb_test_embed_models mteb_test_embed_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index f94f3457c377..510858c2d7ef 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -41,8 +41,8 @@ reason= "Prevent unstable test based on golden strings from breaking the build " " and test input model being too large and hanging the system.") -@pytest.mark.skipif(not is_quant_method_supported("nvfp4"), - reason="nvfp4 is not supported on this GPU type.") +@pytest.mark.skipif(not is_quant_method_supported("modelopt_fp4"), + reason="modelopt_fp4 is not supported on this GPU type.") @pytest.mark.parametrize("model_name", MODELS) def test_models(example_prompts, model_name) -> None: model = LLM( @@ -50,7 +50,7 @@ def test_models(example_prompts, model_name) -> None: max_model_len=MAX_MODEL_LEN, trust_remote_code=True, enforce_eager=True, - quantization="nvfp4", + quantization="modelopt_fp4", ) tokenizer = AutoTokenizer.from_pretrained(model_name) diff --git a/tests/models/registry.py b/tests/models/registry.py index 84abd42e9231..911a58e99d4c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -124,7 +124,7 @@ def check_available_online( "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", - extras={"tiny": "hmellor/bamba-tiny-random"}), # noqa: E501 + extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", {"1b": "bigscience/bloomz-1b1"}), "ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b", @@ -147,6 +147,9 @@ def check_available_online( "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "Fairseq2LlamaForCausalLM": 
_HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), + "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct", + is_available_online=False, + min_transformers_version="4.52.2"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index d61c7d2d5000..a16384efe195 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -77,3 +77,73 @@ def weight_generator(): assert torch.all( new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 + + +def test_module_skip_prefix(): + """Ensure the auto weight loader can skip prefix.""" + mod = ModuleWithNestedBatchNorm() + # Run some data through the module with batchnorm + mod(torch.Tensor([[1, 2], [3, 4]])) + + # Try to load the weights to a new instance + def weight_generator(): + # weights needed to be filtered out + redundant_weights = { + "prefix.bn.weight": torch.Tensor([1, 2]), + "prefix.bn.bias": torch.Tensor([3, 4]), + } + yield from (mod.state_dict() | redundant_weights).items() + + new_mod = ModuleWithNestedBatchNorm() + + assert not torch.all( + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + assert not torch.all( + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 + + loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."]) + loader.load_weights(weight_generator()) + + # Ensure the stats are updated + assert torch.all( + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + assert torch.all( + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 + + +def test_module_skip_substr(): + """Ensure the auto weight loader can skip prefix.""" + mod = ModuleWithNestedBatchNorm() + # Run some data through the module with batchnorm + mod(torch.Tensor([[1, 2], [3, 4]])) + + # Try to load the weights to a new instance + def weight_generator(): + # weights needed to be filtered out + redundant_weights = { + "nested_mod.0.substr.weight": torch.Tensor([1, 2]), + "nested_mod.0.substr.bias": torch.Tensor([3, 4]), + "nested_mod.substr.weight": torch.Tensor([1, 2]), + "nested_mod.substr.bias": torch.Tensor([3, 4]), + } + yield from (mod.state_dict() | redundant_weights).items() + + new_mod = ModuleWithNestedBatchNorm() + + assert not torch.all( + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + assert not torch.all( + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 + + loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."]) + loader.load_weights(weight_generator()) + + # Ensure the stats are updated + assert torch.all( + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + assert torch.all( + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 diff --git a/tests/neuron/2_core/test_mistral.py b/tests/neuron/2_core/test_mistral.py new file mode 100644 index 000000000000..cc3b53a9d7c9 --- /dev/null +++ b/tests/neuron/2_core/test_mistral.py @@ -0,0 +1,62 @@ +# 
SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams + + +def test_mistral(): + llm = LLM(model="mistralai/Mistral-7B-v0.1", + tensor_parallel_size=2, + max_num_seqs=4, + max_model_len=128, + use_v2_block_manager=True, + override_neuron_config={ + "sequence_parallel_enabled": False, + "skip_warmup": True + }, + device="neuron") + + # Send more prompts than the compiled batch size (4) and request + # varying generation lengths to test accuracy related to Neuron + # specific sequence id sorting. + prompts = [ + "The president of the United States is", + "The capital of France is", + "What is Annapurna labs?", + "I believe the meaning of life is", + "Tell me a story about a brave knight", + "Hello, my name is Llama", + ] + + sampling_params = [ + SamplingParams(top_k=1, max_tokens=10), + SamplingParams(top_k=1, max_tokens=20), + SamplingParams(top_k=1, max_tokens=30), + SamplingParams(top_k=1, max_tokens=40), + SamplingParams(top_k=1, max_tokens=50), + SamplingParams(top_k=1, max_tokens=60) + ] + + outputs = llm.generate(prompts, sampling_params) + + expected_outputs = [ + " the most powerful person in the world. He is", + " a city of many faces. It is a city of history, culture, art, " + "fashion, and", + "\n\nAnnapurna Labs is a semiconductor company that was founded " + "in 2013 by Amazon. The company is", + " to be happy.\n\nI believe that happiness is a choice.\n\nI " + "believe that happiness is a state of mind.\n\nI believe that " + "happiness is a journey.\n\nI believe", + " who rescued a princess from a dragon.\n\nTell me a story about" + " a princess who rescued herself from a dragon.\n\nTell me a " + "story about a princess who rescued herself from a dragon and " + "then rescued a knight from", + " and I am a 10 year old male. I am a very friendly and " + "affectionate boy who loves to be around people. I am a very " + "active boy who loves to play and run around. I am a very smart " + "boy who loves to learn new things. I am a very loyal boy" + ] + + for expected_output, output in zip(expected_outputs, outputs): + generated_text = output.outputs[0].text + assert (expected_output == generated_text) diff --git a/tests/quantization/test_auto_round.py b/tests/quantization/test_auto_round.py new file mode 100644 index 000000000000..81ceecdb45d6 --- /dev/null +++ b/tests/quantization/test_auto_round.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test model set-up and inference for quantized HF models supported + on the AutoRound. + + Validating the configuration and printing results for manual checking. + + Run `pytest tests/quantization/test_auto_round.py`. 
+""" + +import pytest + +from vllm.platforms import current_platform + +MODELS = [ + "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc", ##auto_round:auto_gptq + "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound" ##auto_round:auto_awq +] + + +@pytest.mark.skipif(not current_platform.is_cpu() + and not current_platform.is_xpu() + and not current_platform.is_cuda(), + reason="only supports CPU/XPU/CUDA backend.") +@pytest.mark.parametrize("model", MODELS) +def test_auto_round(vllm_runner, model): + with vllm_runner(model) as llm: + output = llm.generate_greedy(["The capital of France is"], + max_tokens=8) + assert output + print(f"{output[0][1]}") diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 8d9ae282153c..e8ddfd7fc779 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -8,9 +8,11 @@ import pytest import torch +from transformers import BitsAndBytesConfig from tests.quantization.utils import is_quant_method_supported +from ..models.utils import check_embeddings_close from ..utils import compare_two_settings, create_new_process_for_each_test models_4bit_to_test = [ @@ -19,6 +21,10 @@ "quantize inflight model with both HF and Mistral format weights") ] +models_4bit_to_embedding_test = [ + ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"), +] + models_pre_qaunt_4bit_to_test = [ ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', 'read pre-quantized 4-bit FP4 model'), @@ -39,7 +45,8 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: - hf_model_kwargs = {"load_in_4bit": True} + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True)) validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], model_name, False, hf_model_kwargs) @@ -77,7 +84,8 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: - hf_model_kwargs = {"load_in_4bit": True} + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True)) validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], @@ -113,6 +121,54 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None: compare_two_settings(model_name, common_args, pp_args) +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", + models_4bit_to_embedding_test) +@pytest.mark.parametrize("dtype", ["half"]) +@create_new_process_for_each_test() +def test_4bit_bnb_embedding_model( + model_name, + description, + hf_runner, + vllm_runner, + example_prompts, + dtype: str, +) -> None: + + # The example_prompts has ending "\n", for example: + # "Write a short story about a robot that dreams for the first time.\n" + # sentence_transformers will strip the input texts, see: + # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159 + # This makes the input_ids different between hf_model and vllm_model. + # So we need to strip the input texts to avoid test failing. 
+ example_prompts = [str(s).strip() for s in example_prompts] + + # Inflight 4bit quantization + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True)) + with hf_runner( + model_name, + dtype=dtype, + model_kwargs=hf_model_kwargs, + is_sentence_transformer=True, + ) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model_name, + task="embed", + dtype=dtype, + quantization="bitsandbytes") as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=5e-2, + ) + + def log_generated_texts(prompts, outputs, runner_name): logged_texts = [] for i, (_, generated_text) in enumerate(outputs): diff --git a/tests/test_outputs.py b/tests/test_outputs.py new file mode 100644 index 000000000000..c41bd6723ba1 --- /dev/null +++ b/tests/test_outputs.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.outputs import RequestOutput + + +def test_request_output_forward_compatible(): + output = RequestOutput(request_id="test_request_id", + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[], + finished=False, + example_arg_added_in_new_version="some_value") + assert output is not None diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index c14eaf71e978..efa6455c41df 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -88,7 +88,7 @@ def ensure_system_prompt(messages: list[dict[str, Any]], "meta-llama/Llama-4-Scout-17B-16E-Instruct", "arguments": [ "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "pythonic", "--chat-template", + "--tool-call-parser", "llama4_pythonic", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_llama4_pythonic.jinja"), "-tp", "4" diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 43a27da2dbe4..1e2767e2d198 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -19,8 +19,7 @@ hash_request_tokens, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor, - SlidingWindowSpec) + KVCacheGroupSpec, KVCacheTensor) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -55,14 +54,12 @@ def new_kv_cache_spec(block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, - use_mla=False, - sliding_window=None): + use_mla=False): return FullAttentionSpec(block_size=block_size, num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, - use_mla=use_mla, - sliding_window=sliding_window) + use_mla=use_mla) def test_none_hash(monkeypatch): @@ -495,68 +492,6 @@ def test_unify_kv_cache_configs(): unify_kv_cache_configs(diff_kv_cache_config) -def test_merge_kv_cache_spec(): - same_layer_specs = [ - new_kv_cache_spec(num_kv_heads=32), - new_kv_cache_spec(num_kv_heads=32), - ] - merged_layer_spec = same_layer_specs[0].merge(same_layer_specs) - assert merged_layer_spec.block_size == 16 - assert merged_layer_spec.num_kv_heads == 32 - assert merged_layer_spec.head_size == 64 - assert merged_layer_spec.dtype == torch.float32 - assert merged_layer_spec.sliding_window is None - - different_layer_specs = [ - new_kv_cache_spec(num_kv_heads=32), - new_kv_cache_spec(num_kv_heads=16), - ] - with pytest.raises(AssertionError): - different_layer_specs[0].merge(different_layer_specs) - - full_spec = 
new_kv_cache_spec(num_kv_heads=32) - different_type_layer_specs = [ - full_spec, - SlidingWindowSpec( - block_size=full_spec.block_size, - num_kv_heads=full_spec.num_kv_heads, - head_size=full_spec.head_size, - dtype=full_spec.dtype, - use_mla=full_spec.use_mla, - sliding_window=1, - ), - ] - with pytest.raises(AssertionError): - different_type_layer_specs[0].merge(different_type_layer_specs) - with pytest.raises(AssertionError): - different_type_layer_specs[1].merge(different_type_layer_specs) - - different_sliding_window_layer_specs = [ - new_kv_cache_spec(num_kv_heads=32), - new_kv_cache_spec(num_kv_heads=32, sliding_window=1), - new_kv_cache_spec(num_kv_heads=32, sliding_window=2), - ] - with pytest.raises(ValueError): - different_sliding_window_layer_specs[0].merge( - different_sliding_window_layer_specs) - - same_sliding_window_layer_specs = [ - new_kv_cache_spec(num_kv_heads=32, sliding_window=1), - new_kv_cache_spec(num_kv_heads=32, sliding_window=1), - ] - merged_layer_spec = same_sliding_window_layer_specs[0].merge( - same_sliding_window_layer_specs) - assert merged_layer_spec.sliding_window == 1 - - same_sliding_window_layer_spec_with_none = [ - new_kv_cache_spec(num_kv_heads=32, sliding_window=1), - new_kv_cache_spec(num_kv_heads=32, sliding_window=None), - ] - merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge( - same_sliding_window_layer_spec_with_none) - assert merged_layer_spec.sliding_window == 1 - - @pytest.mark.parametrize( ("model_id", "max_model_len", "want_estimated_max_len"), [ ("Qwen/Qwen1.5-7B", 16385, 16384), diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3da27786b1f2..2d7411381e16 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -84,7 +84,7 @@ def test_prefill(hash_algo): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] # Check full block metadata parent_block_hash = None @@ -107,13 +107,13 @@ def test_prefill(hash_algo): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == [5] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -141,13 +141,13 @@ def test_prefill(hash_algo): req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[6]] + assert blocks.get_block_ids() == [6] # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. 
@@ -171,7 +171,7 @@ def test_prefill(hash_algo): len(computed_blocks.blocks) * 16, computed_blocks) # This block ID order also checks the eviction order. - assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]] + assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None @@ -208,7 +208,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] req0_block_hashes = [b.block_hash for b in blocks.blocks] # Check full block metadata @@ -233,13 +233,13 @@ def test_prefill_plp(): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == [5] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -277,11 +277,11 @@ def test_prefill_plp(): block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks] == req0_block_hashes - assert block_ids != [[1, 2, 3, 4]] + assert block_ids != [1, 2, 3, 4] # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. - for block_id in block_ids[0]: + for block_id in block_ids: assert manager.block_pool.blocks[block_id].ref_cnt == 1 manager.free(req2) @@ -307,7 +307,7 @@ def test_decode(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] # Append slots without allocating a new block. req0.num_computed_tokens = 55 @@ -379,12 +379,12 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert computed_blocks.get_block_ids() == [[1, 2]] + assert computed_blocks.get_block_ids() == [1, 2] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[10]] + assert blocks.get_block_ids() == [10] assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -625,7 +625,7 @@ def test_mm_prefix_caching(): blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. @@ -686,7 +686,7 @@ def test_cache_key_salting(): blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. 
@@ -797,7 +797,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) blocks = manager.allocate_slots(req0, 55) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -808,7 +808,7 @@ def test_reset_prefix_cache(): blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == [5] # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index a8a713d446b7..220f05c7ff1c 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -16,31 +16,40 @@ FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available +@pytest.fixture(autouse=True) +def reset_default_device(): + """ + Explicitly set the default device, which can affect subsequent tests. + Adding this fixture helps avoid this problem. + """ + original_device = torch.get_default_device() + yield + torch.set_default_device(original_device) + + def test_topk_impl_equivalance(): - with torch.device(DEVICE): - generator = Generator(device=DEVICE).manual_seed(33) + torch.set_default_device(DEVICE) + generator = Generator(device=DEVICE).manual_seed(33) - logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) + logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) - # Random top-k values between 1 and 9. - k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator) + # Random top-k values between 1 and 9. + k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator) - # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). - k.masked_fill_( - torch.randint(0, - 2, (BATCH_SIZE, ), - generator=generator, - dtype=bool), VOCAB_SIZE) + # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). 
+ k.masked_fill_( + torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool), + VOCAB_SIZE) - # Top-k only implementation - result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) + # Top-k only implementation + result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) - # Top-p + top-k - no_op_top_p = torch.tensor([1.0]) - result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p) + # Top-p + top-k + no_op_top_p = torch.tensor([1.0]) + result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p) - assert torch.allclose(result1, result2) + assert torch.allclose(result1, result2) def test_flashinfer_sampler(): @@ -58,50 +67,49 @@ def test_flashinfer_sampler(): pytest.skip( "FlashInfer not installed or not available on this platform.") - with torch.device(DEVICE): - generator = Generator(device=DEVICE).manual_seed(42) - - # Generate random logits - logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) - - # Generate various top-k and top-p values - k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator) - p_values = torch.rand( - (BATCH_SIZE, ), - generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0] - - # Sometimes disable top-k (k=vocab_size) - k_values.masked_fill_( - torch.randint(0, - 2, (BATCH_SIZE, ), - generator=generator, - dtype=torch.bool), VOCAB_SIZE) - - # Sometimes disable top-p (p=1.0) - p_values.masked_fill_( - torch.randint(0, - 2, (BATCH_SIZE, ), - generator=generator, - dtype=torch.bool), 1.0) - - python_logits = apply_top_k_top_p( - logits=logits.clone(), - k=k_values, - p=p_values, - ) - python_probs = torch.softmax(python_logits, dim=-1) - - # FlashInfer only exposed renorm interfaces for probs so convert first - flashinfer_probs = torch.softmax(logits.clone(), dim=-1) - flashinfer_probs = top_k_renorm_probs( - probs=flashinfer_probs, - top_k=k_values, - ) - flashinfer_probs = top_p_renorm_probs( - probs=flashinfer_probs, - top_p=p_values, - ) - - # Compare the results - assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \ - "FlashInfer and Python sampling implementations do not match!" + torch.set_default_device(DEVICE) + generator = Generator(device=DEVICE).manual_seed(42) + + # Generate random logits + logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) + + # Generate various top-k and top-p values + k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator) + p_values = torch.rand( + (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0] + + # Sometimes disable top-k (k=vocab_size) + k_values.masked_fill_( + torch.randint(0, + 2, (BATCH_SIZE, ), + generator=generator, + dtype=torch.bool), VOCAB_SIZE) + + # Sometimes disable top-p (p=1.0) + p_values.masked_fill_( + torch.randint(0, + 2, (BATCH_SIZE, ), + generator=generator, + dtype=torch.bool), 1.0) + + python_logits = apply_top_k_top_p( + logits=logits.clone(), + k=k_values, + p=p_values, + ) + python_probs = torch.softmax(python_logits, dim=-1) + + # FlashInfer only exposed renorm interfaces for probs so convert first + flashinfer_probs = torch.softmax(logits.clone(), dim=-1) + flashinfer_probs = top_k_renorm_probs( + probs=flashinfer_probs, + top_k=k_values, + ) + flashinfer_probs = top_p_renorm_probs( + probs=flashinfer_probs, + top_p=p_values, + ) + + # Compare the results + assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \ + "FlashInfer and Python sampling implementations do not match!" 
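The sampler-test refactor above replaces the `with torch.device(DEVICE):` context with a call to `torch.set_default_device(DEVICE)` and relies on the new autouse `reset_default_device` fixture to undo that global change after each test. The snippet below is a minimal, self-contained sketch of that pattern, assuming only pytest and a PyTorch version that provides `torch.get_default_device()` (2.3+); the test names are illustrative and are not part of this patch.

import pytest
import torch


@pytest.fixture(autouse=True)
def reset_default_device():
    # Capture the default device before each test and restore it afterwards, so
    # a test that calls torch.set_default_device() cannot leak into later tests.
    original_device = torch.get_default_device()
    yield
    torch.set_default_device(original_device)


def test_overrides_default_device():
    # Without the fixture, this global override would persist beyond this test.
    torch.set_default_device("meta")
    assert torch.empty(1).device.type == "meta"


def test_default_device_was_restored():
    # Runs after the test above; the fixture has already restored the original
    # default, so new tensors are created on the CPU again.
    assert torch.empty(1).device.type == "cpu"

Running these two tests with `pytest -q` passes; removing the fixture makes the second test fail because the "meta" default leaks out of the first test, which is exactly the cross-test contamination the fixture in this patch guards against.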
diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index c34c673e985e..1b77417a1bd3 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -12,7 +12,7 @@ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder "mistralai/Mamba-Codestral-7B-v0.1", # mamba - "hmellor/bamba-tiny-random", # hybrid + "hmellor/tiny-random-BambaForCausalLM", # hybrid "BAAI/bge-m3", # embedding ] diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 638f5bedcfca..7b1359c8576f 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -9,11 +9,9 @@ from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState, + InputBatch) VOCAB_SIZE = 1024 NUM_OUTPUT_TOKENS = 20 @@ -24,27 +22,6 @@ MAX_NUM_PROMPT_TOKENS = 64 -def get_kv_cache_config() -> KVCacheConfig: - return KVCacheConfig( - num_blocks=10, - tensors={ - "layer.0": KVCacheTensor(size=1024), - }, - kv_cache_groups=[ - KVCacheGroupSpec( - layer_names=["layer.0"], - kv_cache_spec=FullAttentionSpec( - block_size=1, - num_kv_heads=1, - head_size=16, - dtype=torch.float16, - use_mla=False, - ), - ), - ], - ) - - def _compare_objs(obj1, obj2): attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) attr_names = set([ @@ -64,10 +41,6 @@ def _compare_objs(obj1, obj2): elif isinstance(a, np.ndarray): if np.allclose(a, b): is_same = True - elif isinstance(a, MultiGroupBlockTable): - for a_i, b_i in zip(a.block_tables, b.block_tables): - _compare_objs(a_i, b_i) - is_same = True elif isinstance(a, (BlockTable, SamplingMetadata)): _compare_objs(a, b) is_same = True # if we make it here must be same @@ -225,7 +198,7 @@ def _construct_cached_request_state(req_id_suffix: int): sampling_params=_create_sampling_params(), mm_inputs=[], mm_positions=[], - block_ids=[[]], + block_ids=[], generator=None, num_computed_tokens=len(output_token_ids), output_token_ids=output_token_ids, @@ -247,11 +220,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, + max_num_blocks_per_req=10, max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, - kv_cache_config=get_kv_cache_config(), ) reqs: list[CachedRequestState] = [] req_id_reqs = {} @@ -337,20 +310,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, + max_num_blocks_per_req=10, max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, - kv_cache_config=get_kv_cache_config(), ) ref_input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, + max_num_blocks_per_req=10, max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, - kv_cache_config=get_kv_cache_config(), ) reqs: list[CachedRequestState] = [] diff --git 
a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index e44660525763..725747294fd8 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,16 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 +import weakref import pytest +import torch -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VllmConfig) +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor) +from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -18,34 +17,13 @@ def initialize_kv_cache(runner: GPUModelRunner): """ Only perform necessary steps in GPUModelRunner.initialize_kv_cache() """ - kv_cache_config = KVCacheConfig( - num_blocks=10, - tensors={ - "layer.0": KVCacheTensor(size=1024), - }, - kv_cache_groups=[ - KVCacheGroupSpec( - layer_names=["layer.0"], - kv_cache_spec=FullAttentionSpec( - block_size=16, - num_kv_heads=runner.model_config.get_num_kv_heads( - runner.parallel_config), - head_size=runner.model_config.get_head_size(), - dtype=runner.kv_cache_dtype, - use_mla=False, - )) - ]) - runner.kv_cache_config = kv_cache_config - runner.input_batch = InputBatch( - max_num_reqs=runner.max_num_reqs, - max_model_len=runner.max_model_len, - max_num_batched_tokens=runner.max_num_tokens, - device=runner.device, - pin_memory=runner.pin_memory, - vocab_size=runner.model_config.get_vocab_size(), - kv_cache_config=kv_cache_config, - ) - runner.initialize_attn_backend(kv_cache_config) + kv_cache_spec = FullAttentionSpec(block_size=16, + num_kv_heads=1, + head_size=64, + dtype=torch.float16, + use_mla=False) + runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()( + weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table) @pytest.fixture @@ -70,12 +48,10 @@ def model_runner(): swap_space=0, cache_dtype="auto", ) - parallel_config = ParallelConfig() vllm_config = VllmConfig( model_config=model_config, cache_config=cache_config, scheduler_config=scheduler_config, - parallel_config=parallel_config, ) device = "cuda" @@ -97,7 +73,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), - block_ids=[[0]], + block_ids=[0], num_computed_tokens=0, lora_request=None, )) @@ -135,14 +111,13 @@ def _is_sampling_metadata_changed(model_runner, def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: req_index = model_runner.input_batch.req_id_to_index[req_id] - block_table = model_runner.input_batch.block_table[0] + block_table = model_runner.input_batch.block_table req_state = model_runner.requests[req_id] - if block_table.num_blocks_per_row[req_index] != len( - req_state.block_ids[0]): + if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids): return False num_blocks = block_table.num_blocks_per_row[req_index] return (block_table.block_table_np[req_index, :num_blocks] == - req_state.block_ids[0]).all() + req_state.block_ids).all() def test_update_states_new_request(model_runner): @@ -225,7 +200,7 @@ def test_update_states_request_resumed(model_runner): 
req_id=req_id, resumed_from_preemption=False, new_token_ids=[], - new_block_ids=[[]], + new_block_ids=[], num_computed_tokens=0, ) diff --git a/tools/install_nixl.sh b/tools/install_nixl.sh new file mode 100644 index 000000000000..56717cfb77f7 --- /dev/null +++ b/tools/install_nixl.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# Usage: ./install_nixl.sh [--force] + +FORCE=false +if [ "$1" == "--force" ]; then + FORCE=true +fi + +SUDO=false +if command -v sudo >/dev/null 2>&1 && sudo -n true 2>/dev/null; then + SUDO=true +fi + +ARCH=$(uname -m) + +ROOT_DIR="/usr/local" +mkdir -p "$ROOT_DIR" +GDR_HOME="$ROOT_DIR/gdrcopy" +UCX_HOME="$ROOT_DIR/ucx" +NIXL_HOME="$ROOT_DIR/nixl" +CUDA_HOME=/usr/local/cuda + +export PATH="$GDR_HOME/bin:$UCX_HOME/bin:$NIXL_HOME/bin:$PATH" +export LD_LIBRARY_PATH="$GDR_HOME/lib:$UCX_HOME/lib:$NIXL_HOME/lib/$ARCH-linux-gnu:$LD_LIBRARY_PATH" + +TEMP_DIR="nixl_installer" +mkdir -p "$TEMP_DIR" +cd "$TEMP_DIR" + +pip install meson ninja pybind11 + +if [ ! -e "/dev/gdrdrv" ] || [ "$FORCE" = true ]; then + echo "Installing gdrcopy\n" + wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.5.tar.gz + tar xzf v2.5.tar.gz; rm v2.5.tar.gz + cd gdrcopy-2.5 + make prefix=$GDR_HOME CUDA=$CUDA_HOME all install + + if $SUDO; then + echo "Running insmod.sh with sudo" + sudo ./insmod.sh + else + echo "Skipping insmod.sh - sudo not available" + echo "Please run 'sudo ./gdrcopy-2.5/insmod.sh' manually if needed" + fi + + cd .. +else + echo "Found /dev/gdrdrv. Skipping gdrcopy installation" +fi + +if ! command -v ucx_info &> /dev/null || [ "$FORCE" = true ]; then + echo "Installing UCX" + wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz + tar xzf ucx-1.18.0.tar.gz; rm ucx-1.18.0.tar.gz + cd ucx-1.18.0 + + # Checking Mellanox NICs + MLX_OPTS="" + if lspci | grep -i mellanox > /dev/null || command -v ibstat > /dev/null; then + echo "Mellanox NIC detected, adding Mellanox-specific options" + MLX_OPTS="--with-rdmacm \ + --with-mlx5-dv \ + --with-ib-hw-tm" + fi + + ./configure --prefix=$UCX_HOME \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=$CUDA_HOME \ + --with-dm \ + --with-gdrcopy=$GDR_HOME \ + --with-verbs \ + --enable-mt \ + $MLX_OPTS + make -j + make -j install-strip + + if $SUDO; then + echo "Running ldconfig with sudo" + sudo ldconfig + else + echo "Skipping ldconfig - sudo not available" + echo "Please run 'sudo ldconfig' manually if needed" + fi + + cd .. +else + echo "Found existing UCX. Skipping UCX installation" +fi + +if ! command -v nixl_test &> /dev/null || [ "$FORCE" = true ]; then + echo "Installing NIXL" + wget https://github.com/ai-dynamo/nixl/archive/refs/tags/0.2.0.tar.gz + tar xzf 0.2.0.tar.gz; rm 0.2.0.tar.gz + cd nixl-0.2.0 + meson setup build --prefix=$NIXL_HOME -Ducx_path=$UCX_HOME + cd build + ninja + ninja install + + cd ../.. +else + echo "Found existing NIXL. 
Skipping NIXL installation" +fi diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8076c4791d3c..abcb68911a8b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -861,7 +861,8 @@ def forward( gqa_ratio = num_heads // self.num_kv_heads use_custom = use_rocm_custom_paged_attention( decode_query.dtype, head_size, block_size, gqa_ratio, - decode_meta.max_decode_seq_len, self.sliding_window) + decode_meta.max_decode_seq_len, self.sliding_window, + self.kv_cache_dtype, self.alibi_slopes) if use_custom: max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type != AttentionType.ENCODER_DECODER else diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 217db3bf965d..785799b6bf68 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -283,7 +283,8 @@ def chunked_prefill_paged_decode( use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, block_size, num_queries_per_kv, - max_seq_len, sliding_window) + max_seq_len, sliding_window, + kv_cache_dtype, alibi_slopes) if use_custom: _PARTITION_SIZE_ROCM = 256 max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 241e84ca669d..4bced779785a 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -31,8 +31,8 @@ def apply_softcap(S, x): def kernel_unified_attention_2d( output_ptr, # [num_tokens, num_query_heads, head_size] query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] - value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0c1381a565c1..8114cddcd9fa 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -6,9 +6,7 @@ import pprint import time from collections.abc import Sequence -from contextlib import ExitStack from typing import Any, Callable, Optional -from unittest.mock import patch import torch import torch.fx as fx @@ -16,13 +14,13 @@ import vllm.envs as envs from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors +from vllm.platforms import current_platform +from vllm.utils import resolve_obj_by_qualname from .compiler_interface import (CompilerInterface, EagerAdaptor, InductorAdaptor, InductorStandaloneAdaptor) from .counter import compilation_counter from .inductor_pass import InductorPass -from .monitor import end_monitoring_torch_compile from .pass_manager import PostGradPassManager logger = init_logger(__name__) @@ -297,7 +295,9 @@ def call_module(self, target: torch.fx.node.Target, num_graphs=len(self.compile_submod_names), runtime_shape=None) - self.module.__dict__[target] = PiecewiseBackend( + piecewise_backend = resolve_obj_by_qualname( + current_platform.get_piecewise_backend_cls()) + self.module.__dict__[target] = piecewise_backend( submod, self.vllm_config, self.graph_pool, index, 
len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_general_shape, self.vllm_backend) @@ -341,7 +341,7 @@ def __init__( ): global global_graph_pool if global_graph_pool is None: - global_graph_pool = torch.cuda.graph_pool_handle() + global_graph_pool = current_platform.graph_pool_handle() # TODO: in the future, if we want to use multiple # streams, it might not be safe to share a global pool. @@ -558,197 +558,3 @@ def copy_and_call(*args): return self.split_gm(*list_args) return copy_and_call - - -@dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int - need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in cudagraph_capture_sizes - - compiled: bool = False - runnable: Callable = None # type: ignore - num_finished_warmup: int = 0 - cudagraph: Optional[torch.cuda.CUDAGraph] = None - output: Optional[Any] = None - - # for cudagraph debugging, track the input addresses - # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None - - -class PiecewiseBackend: - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend): - """ - The backend for piecewise compilation. - It mainly handles the compilation and cudagraph capturing. - - We will compile `self.graph` once for the general shape, - and then compile for different shapes specified in - `compilation_config.compile_sizes`. - - Independently, we will capture cudagraph for different shapes. - - If a shape needs both compilation and cudagraph, we will - compile it first, and then capture cudagraph. 
- """ - self.graph = graph - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool - self.piecewise_compile_index = piecewise_compile_index - self.total_piecewise_compiles = total_piecewise_compiles - self.vllm_backend = vllm_backend - - self.is_first_graph = piecewise_compile_index == 0 - self.is_last_graph = ( - piecewise_compile_index == total_piecewise_compiles - 1) - - self.compile_sizes: set[int] = set( - self.compilation_config.compile_sizes) - self.cudagraph_capture_sizes: set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) if self.compilation_config.use_cudagraph else set() - - self.first_run_finished = False - - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - - self.sym_shape_indices = sym_shape_indices - - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - - # the entries for different shapes that we need to either - # compile or capture cudagraph - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} - - # to_be_compiled_sizes tracks the remaining sizes to compile, - # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() - for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.cudagraph_capture_sizes, - ) - - def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: - # no specific sizes to compile - # save the hash of the inductor graph for the next run - self.vllm_backend.compiler_manager.save_to_file() - end_monitoring_torch_compile(self.vllm_config) - - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) - - runtime_shape = args[self.sym_shape_indices[0]] - if runtime_shape not in self.concrete_size_entries: - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - - entry = self.concrete_size_entries[runtime_shape] - - if entry.runnable is None: - entry.runnable = self.compiled_graph_for_general_shape - - if entry.need_to_compile and not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) - # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( - self.graph, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=self.piecewise_compile_index, - num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape) - - # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: - self.check_for_ending_compilation() - - if not entry.use_cudagraph: - return entry.runnable(*args) - - if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - if self.is_first_graph: - logger.debug( - "Warming up %s/%s for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - runtime_shape) - return entry.runnable(*args) - - if self.is_first_graph: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every shape. 
- # We only log it in the debug mode. - logger.debug("Capturing a cudagraph for shape %s", - runtime_shape) - - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() - - with ExitStack() as stack: - if not self.is_first_graph: - # during every model forward, we will capture - # many pieces of cudagraphs (roughly one per layer). - # running gc again and again across layers will - # make the cudagraph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. - stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) - - # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - if self.is_last_graph: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph, because the output of the last graph - # will not be used by any other cuda graph. - output = weak_ref_tensors(output) - - # here we always use weak ref for the output - # to save memory - entry.output = weak_ref_tensors(output) - entry.cudagraph = cudagraph - - compilation_counter.num_cudagraph_caputured += 1 - - # important: we need to return the output, rather than - # the weak ref of the output, so that pytorch can correctly - # manage the memory during cuda graph capture - return output - - if self.is_debugging_mode: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during replay." - f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) - - entry.cudagraph.replay() - return entry.output diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py new file mode 100644 index 000000000000..84d1e1f77739 --- /dev/null +++ b/vllm/compilation/base_piecewise_backend.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Protocol + +import torch.fx as fx + +from vllm.compilation.backends import VllmBackend +from vllm.config import VllmConfig + + +class AbstractPiecewiseBackend(Protocol): + """ + PiecewiseBackend interface that allows platforms to extend + piecewise static graph. + """ + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend, **kwargs): + """ + Initializes the PiecewiseBackend class with compilation and + execution-related configurations. + + This class handles piecewise compilation, graph capturing, + and dispatching for specific input shapes. + + Args: + graph (fx.GraphModule): The graph represented in fx. + vllm_config (VllmConfig): Global configuration for vLLM. + graph_pool (Any): + Graph memory pool handle, e.g., + `torch.cuda.graph_pool_handle()`. + piecewise_compile_index (int): + Index of the current piecewise subgraph. + total_piecewise_compiles (int): + Total number of piecewise-compiled graphs. + sym_shape_indices (list[int]): + Indices of symbolic shape. 
+ compiled_graph_for_general_shape (Callable): + Callable that executes the graph compiled for general shapes. + vllm_backend (VllmBackend): + Backend compiler that manages compilation and graph runtime + for vLLM. + + Keyword Args: + kwargs: Additional keyword arguments reserved for future + extensions or custom platforms. + """ + raise NotImplementedError + + def __call__(self, *args) -> Any: + """Executes the compiled graph for given input args. + + If this is the first invocation, executes the general compiled graph + and initiates the compilation process tracking. For subsequent calls, + dynamically dispatches execution to either a compiled graph or a static + graph based on the input shape. + + Args: + *args: Variable length input arguments to be passed into the + graph. The symbolic shape is expected to be in position + `sym_shape_indices[0]`. + + Returns: + Any: Output of the executed graph. This can be from the general + compiled graph, a specialized compiled version for the given shape, + or a replayed static graph. + """ + raise NotImplementedError diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py new file mode 100644 index 000000000000..0ad480e28cd7 --- /dev/null +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch.fx as fx + +import vllm.envs as envs +from vllm.compilation.backends import VllmBackend +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import end_monitoring_torch_compile +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import weak_ref_tensors + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class ConcreteSizeEntry: + runtime_shape: int + need_to_compile: bool # the size is in compile_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes + + compiled: bool = False + runnable: Callable = None # type: ignore + num_finished_warmup: int = 0 + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + +class CUDAPiecewiseBackend: + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend): + """ + The backend for piecewise compilation. + It mainly handles the compilation and cudagraph capturing. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_config.compile_sizes`. + + Independently, we will capture cudagraph for different shapes. + + If a shape needs both compilation and cudagraph, we will + compile it first, and then capture cudagraph. 
+ """ + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = ( + piecewise_compile_index == total_piecewise_compiles - 1) + + self.compile_sizes: set[int] = set( + self.compilation_config.compile_sizes) + self.cudagraph_capture_sizes: set[int] = set( + self.compilation_config.cudagraph_capture_sizes + ) if self.compilation_config.use_cudagraph else set() + + self.first_run_finished = False + + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + + self.sym_shape_indices = sym_shape_indices + + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # the entries for different shapes that we need to either + # compile or capture cudagraph + self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, + ) + + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.vllm_backend.compiler_manager.save_to_file() + end_monitoring_torch_compile(self.vllm_config) + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + return self.compiled_graph_for_general_shape(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) + # args are real arguments + entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + runtime_shape=runtime_shape) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + if not entry.use_cudagraph: + return entry.runnable(*args) + + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if self.is_first_graph: + logger.debug( + "Warming up %s/%s for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + runtime_shape) + return entry.runnable(*args) + + if self.is_first_graph: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every shape. 
+ # We only log it in the debug mode. + logger.debug("Capturing a cudagraph for shape %s", + runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_caputured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + + entry.cudagraph.replay() + return entry.output diff --git a/vllm/config.py b/vllm/config.py index 24ef675a9c99..ed2c49d05756 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -42,7 +42,10 @@ try_get_generation_config, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect -from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, +from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, + LayerBlockType, cuda_device_count_stateless, get_cpu_memory, get_open_port, is_torch_equal_or_newer, random_uuid, resolve_obj_by_qualname) @@ -64,12 +67,6 @@ ConfigT = TypeVar("ConfigT", bound=ConfigType) -# This value is chosen to have a balance between ITL and TTFT. Note it is -# not optimized for throughput. 
-_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 -_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 -_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 - TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", "score", "reward", "transcription"] @@ -824,7 +821,7 @@ def _verify_quantization(self) -> None: optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", - "quark", "nvfp4", "bitblas", "gptq_bitblas" + "quark", "modelopt_fp4", "bitblas", "gptq_bitblas" ] if self.quantization is not None: self.quantization = cast(QuantizationMethods, @@ -2074,28 +2071,28 @@ def __post_init__(self) -> None: # so we don't reject sequences on account of a short # max_num_batched_tokens. self.max_num_batched_tokens = max( - self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS) + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) else: self.max_num_batched_tokens = ( - _DEFAULT_MAX_NUM_BATCHED_TOKENS) + DEFAULT_MAX_NUM_BATCHED_TOKENS) else: # If max_model_len is too short, use - # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value + # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value # for higher throughput. self.max_num_batched_tokens = max( - self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS) + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) if self.runner_type == "pooling": # Choose specific value for higher throughput self.max_num_batched_tokens = max( self.max_num_batched_tokens, - _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, ) if self.is_multimodal_model: # The value needs to be at least the number of multimodal tokens self.max_num_batched_tokens = max( self.max_num_batched_tokens, - _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, ) # When using default settings, @@ -2201,7 +2198,11 @@ class DeviceConfig: """Configuration for the device to use for vLLM execution.""" device: Union[Device, torch.device] = "auto" - """Device type for vLLM execution.""" + """Device type for vLLM execution. + This parameter is deprecated and will be + removed in a future release. + It will now be set automatically based + on the current platform.""" device_type: str = field(init=False) """Device type from the current platform. This is set in `__post_init__`.""" @@ -4312,18 +4313,6 @@ def __post_init__(self): "full_cuda_graph is not supported with " "cascade attention. 
Disabling cascade attention.") self.model_config.disable_cascade_attn = True - - if self.model_config and self.model_config.use_mla and \ - not (current_platform.is_cuda() or current_platform.is_rocm()): - logger.info( - "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") - self.scheduler_config.enable_chunked_prefill = False - self.scheduler_config.chunked_prefill_enabled = False - self.scheduler_config.max_num_batched_tokens = max( - self.scheduler_config.max_model_len, - _DEFAULT_MAX_NUM_BATCHED_TOKENS) - if self.cache_config is not None: self.cache_config.enable_prefix_caching = False diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index d4b34900b951..c04218cb9f39 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -22,8 +22,10 @@ def __init__(self, super().__init__(cpu_group, device, device_group, unique_name) self.dist_module = torch.distributed - if (current_platform.get_cpu_architecture() == CpuArchEnum.X86) \ - and hasattr(torch.ops._C, "init_shm_manager"): + if (current_platform.get_cpu_architecture() + == CpuArchEnum.X86) and hasattr( + torch.ops._C, + "init_shm_manager") and unique_name.startswith("tp"): self.dist_module = _CPUSHMDistributed(self) def all_reduce(self, input_): @@ -96,6 +98,8 @@ class _CPUSHMDistributed: def __init__(self, communicator: CpuCommunicator): instance_identifier = os.environ["VLLM_DIST_IDENT"] + unique_name = communicator.unique_name + instance_identifier = f"{instance_identifier}-{unique_name}" self.communicator = communicator group_ranks = [str(rank) for rank in self.communicator.ranks] diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index fa944407a703..40e57e6624d1 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -import os import pickle -import sys import time from contextlib import contextmanager from dataclasses import dataclass, field @@ -19,7 +17,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore import vllm.envs as envs -from vllm.distributed.utils import StatelessProcessGroup +from vllm.distributed.utils import StatelessProcessGroup, sched_yield from vllm.logger import init_logger from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, is_valid_ipv6_address) @@ -28,20 +26,6 @@ logger = init_logger(__name__) -# We prefer to use os.sched_yield as it results in tighter polling loops, -# measured to be around 3e-7 seconds. 
However on earlier versions of Python -# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) -USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) - or (sys.version_info[:2] == (3, 10) - and sys.version_info[2] >= 8)) - - -def sched_yield(): - if USE_SCHED_YIELD: - os.sched_yield() - else: - time.sleep(0) - class ShmRingBuffer: diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 0b0ce9828a74..b1c9c9af6e23 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -44,8 +44,9 @@ def get_model_args(self, model_executable: torch.nn.Module): head_size = model_config.qk_nope_head_dim + \ model_config.qk_rope_head_dim else: - head_size = getattr(model_config, "head_dim", - int(hidden_size // num_attention_heads)) + head_size = getattr(model_config, "head_dim", None) + if head_size is None: + head_size = int(hidden_size // num_attention_heads) return num_heads, head_size diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index cea454a0b597..0aabb260fd3d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -40,7 +40,7 @@ class MultiConnector(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) - self._connectors = [] + self._connectors: list[KVConnectorBase_V1] = [] ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "connectors") assert ktcs is not None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e6c83a0fc5bd..b00f097110b0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -209,7 +209,17 @@ def get_num_new_matched_tokens( rounded_num_prompt_tokens = round_down( len(request.prompt_token_ids), self.block_size) count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) - return count, count > 0 + if count > 0: + return count, True + + # NOTE: if count is 0 here, we have less than block_size + # tokens to pull after subtracting the local prefix cache hit. + # The remote only sends fully computed blocks, so there is + # nothing to transfer but we still need to notify the + # prefill worker so that the remote blocks are freed. + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + self._reqs_need_recv[request.request_id] = (request, []) # No remote prefill for this request. return 0, False @@ -225,10 +235,6 @@ def update_state_after_alloc(self, request: "Request", num_external_tokens, params) if params is not None and params.get("do_remote_prefill"): - # NOTE(rob): if prompt < block_size, no remote blocks - # since the remote only sends fully computed blocks, so - # skip recving for this request. num_external_tokens - # should be 0 if there are no remote blocks. if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port")): @@ -253,6 +259,15 @@ def build_connector_meta( # Loop through scheduled reqs and convert to ReqMeta. 
for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None + # For the case where there are no remote blocks to pull + # (block_ids is empty), we don't need to schedule + # an async read on the worker side. + if not block_ids: + logger.debug( + "Skipping adding request %s to NixlConnectorMetadata, " + "as there are no remote blocks to pull", req_id) + continue + meta.add_new_req( request_id=req_id, local_block_ids=block_ids, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 0421a65a2c81..0fedb6fd5ed9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -288,7 +288,7 @@ def build_connector_meta( for new_req in scheduler_output.scheduled_new_reqs: if new_req.req_id in self._requests_need_load: meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], + block_ids=new_req.block_ids, block_size=self._block_size, is_store=False) total_need_load += 1 @@ -299,7 +299,7 @@ def build_connector_meta( # the original prompt tokens. if not self._found_match_for_request(new_req): meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], + block_ids=new_req.block_ids, block_size=self._block_size, is_store=True) @@ -319,7 +319,7 @@ def build_connector_meta( # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. - block_ids = cached_req.new_block_ids[0] + block_ids = cached_req.new_block_ids meta.add_request(token_ids=token_ids, block_ids=block_ids, diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 6bb323d79d64..93a069d36c4b 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -6,9 +6,12 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import dataclasses import datetime +import os import pickle import socket +import sys import time +import uuid from collections import deque from collections.abc import Sequence from typing import Any, Optional @@ -27,6 +30,20 @@ logger = init_logger(__name__) +# We prefer to use os.sched_yield as it results in tighter polling loops, +# measured to be around 3e-7 seconds. However on earlier versions of Python +# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) +USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) + or (sys.version_info[:2] == (3, 10) + and sys.version_info[2] >= 8)) + + +def sched_yield(): + if USE_SCHED_YIELD: + os.sched_yield() + else: + time.sleep(0) + def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" @@ -212,10 +229,141 @@ def all_gather_obj(self, obj: Any) -> list[Any]: gathered_objs.append(recv_obj) return gathered_objs - def barrier(self): - """A barrier to synchronize all ranks.""" + def barrier(self, timeout: float = 30.0): + """A robust barrier to synchronize all ranks. + + + Uses a multi-phase approach to ensure all processes reach the barrier + before proceeding: + + 1. Each process signals it has reached the barrier + + 2. Each process signals that it has confirmed the arrival of all other + ranks. + + 3. Rank 0 waits for all other ranks to signal their departure to ensure + that all ranks have departed the barrier first. 
+ + Args: + timeout: Maximum time in seconds to wait for each phase (in seconds) + + + Raises: + RuntimeError: If coordination fails or times out + """ + # Generate a barrier ID that is globally unique + try: + if self.rank == 0: + barrier_id = f"barrier_{uuid.uuid4()}" + self.broadcast_obj(barrier_id, src=0) + else: + barrier_id = self.broadcast_obj(None, src=0) + except Exception as e: + raise RuntimeError("Failed to broadcast barrier_id") from e + + # Phase 1: Signal arrival at barrier + # Wait for all processes to arrive + # We need all ranks to confirm the arrival of all other ranks. + # This is the key synchronization point. + arrival_key = f"arrival_{barrier_id}_{self.rank}" + try: + self.store.set(arrival_key, b"1") + except Exception as e: + raise RuntimeError("Failed to signal barrier arrival") from e + + start_time = time.time() + processes_arrived: set[int] = set() + + while len(processes_arrived) < self.world_size: + # Check for timeout + cur_time = time.time() + if cur_time - start_time > timeout: + raise RuntimeError("Barrier timed out after %f seconds", + timeout) + + # Check for each process + for i in range(self.world_size): + if i in processes_arrived: + continue + + key = f"arrival_{barrier_id}_{i}" + try: + # Try to get the key - if it exists, we'll get a value + # If it doesn't exist, it will throw an exception + self.store.get(key) + processes_arrived.add(i) + except KeyError: + # Key doesn't exist yet + pass + except Exception as check_e: + logger.debug("Error checking key existence: %s", check_e) + sched_yield() + + # Short sleep to avoid tight polling + if len(processes_arrived) < self.world_size: + sched_yield() + + # Phase 2: Signal departure from barrier + # We only care to block at this stage in rank 0, which runs the + # server side of the TCPStore. We want to make sure that all + # clients have departed the barrier before rank 0 in case the + # next thing after the barrier is a shutdown, including tearing + # down the TCPStore. Other ranks can exit the barrier immediately + # after signaling their departure. 
+ departure_key = f"departure_{barrier_id}_{self.rank}" + try: + self.store.set(departure_key, b"1") + except Exception as e: + raise RuntimeError("Failed to signal barrier departure") from e + + if self.rank != 0: + return + + # Make rank 0 wait for all processes to signal departure + start_time = time.time() + processes_departed: set[int] = set() + + while len(processes_departed) < self.world_size: + # Check for timeout + if time.time() - start_time > timeout: + raise RuntimeError("Barrier departure timed out after %f s", + timeout) + + # Check for each process + for i in range(self.world_size): + if i in processes_departed: + continue + + key = f"departure_{barrier_id}_{i}" + try: + # Try to get the key - if it exists, we'll get a value + # If it doesn't exist, it will throw an exception + self.store.get(key) + processes_departed.add(i) + except KeyError: + # Key doesn't exist yet + pass + except Exception as check_e: + logger.debug("Error checking key existence: %s", check_e) + sched_yield() + + # Short sleep to avoid tight polling + if len(processes_departed) < self.world_size: + sched_yield() + + # Clean up keys to avoid leaking memory in the store for i in range(self.world_size): - self.broadcast_obj(None, src=i) + try: + self.store.delete_key(f"arrival_{barrier_id}_{i}") + except Exception: + logger.debug("Error deleting key: %s", + f'arrival_{barrier_id}_{i}') + + try: + self.store.delete_key(f"departure_{barrier_id}_{i}") + except Exception: + logger.debug("Error deleting key: %s", + f'departure_{barrier_id}_{i}') @staticmethod def create( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index eeba9b30bd0a..a1c58fc3c148 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -577,7 +577,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action=argparse.BooleanOptionalAction, deprecated=True, help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as " - "of v0.8.6. Use `--reasoning-parser` to specify the reasoning " + "of v0.9.0. Use `--reasoning-parser` to specify the reasoning " "parser backend instead. This flag (`--enable-reasoning`) will be " "removed in v0.10.0. 
When `--reasoning-parser` is specified, " "reasoning mode is automatically enabled.") @@ -737,7 +737,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: title="DeviceConfig", description=DeviceConfig.__doc__, ) - device_group.add_argument("--device", **device_kwargs["device"]) + device_group.add_argument("--device", + **device_kwargs["device"], + deprecated=True) # Speculative arguments speculative_group = parser.add_argument_group( @@ -977,7 +979,7 @@ def create_engine_config( from vllm.platforms import current_platform current_platform.pre_register_and_update() - device_config = DeviceConfig(device=self.device) + device_config = DeviceConfig(device=current_platform.device_type) model_config = self.create_model_config() # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index 1d1bba1d49ce..215fcf3c3e44 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -101,9 +101,18 @@ def cmd(args: argparse.Namespace) -> None: model_name, client = _interactive_cli(args) system_prompt = args.system_prompt conversation: list[ChatCompletionMessageParam] = [] + if system_prompt is not None: conversation.append({"role": "system", "content": system_prompt}) + if args.quick: + conversation.append({"role": "user", "content": args.quick}) + + chat_completion = client.chat.completions.create( + model=model_name, messages=conversation) + print(chat_completion.choices[0].message.content) + return + print("Please enter a message for the chat model:") while True: try: @@ -136,6 +145,12 @@ def subparser_init( default=None, help=("The system prompt to be added to the chat template, " "used for models that support system prompts.")) + chat_parser.add_argument("-q", + "--quick", + type=str, + metavar="MESSAGE", + help=("Send a single prompt as MESSAGE " + "and print the response, then exit.")) return chat_parser @@ -149,6 +164,13 @@ def __init__(self): @staticmethod def cmd(args: argparse.Namespace) -> None: model_name, client = _interactive_cli(args) + + if args.quick: + completion = client.completions.create(model=model_name, + prompt=args.quick) + print(completion.choices[0].text) + return + print("Please enter prompt to complete:") while True: input_prompt = input("> ") @@ -168,6 +190,13 @@ def subparser_init( "via the running API server."), usage="vllm complete [options]") _add_query_options(complete_parser) + complete_parser.add_argument( + "-q", + "--quick", + type=str, + metavar="PROMPT", + help= + "Send a single prompt and print the completion output, then exit.") return complete_parser diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index f7c7112b124f..054c0b006b2f 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -7,6 +7,7 @@ from .hermes_tool_parser import Hermes2ProToolParser from .internlm2_tool_parser import Internlm2ToolParser from .jamba_tool_parser import JambaToolParser +from .llama4_pythonic_tool_parser import Llama4PythonicToolParser from .llama_tool_parser import Llama3JsonToolParser from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser @@ -16,5 +17,6 @@ "ToolParser", "ToolParserManager", "Granite20bFCToolParser", "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", - 
"PythonicToolParser", "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser" + "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser", + "DeepSeekV3ToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 600ccbcf35d0..383e0d44de99 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -80,7 +80,8 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"]), + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False), ), ) for function_call in raw_function_calls ] @@ -166,7 +167,8 @@ def extract_tool_calls_streaming( if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) sent = len( self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] @@ -218,7 +220,8 @@ def extract_tool_calls_streaming( if cur_arguments: sent = len( self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) prev_arguments = self.prev_tool_call_arr[ self.current_tool_id].get("arguments") @@ -226,7 +229,8 @@ def extract_tool_calls_streaming( if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) if cur_args_json != prev_args_json: prefix = find_common_prefix( diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index 6710e7938c43..b8bf142530ee 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -67,7 +67,8 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"]), + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False), ), ) for function_call in raw_function_calls ] @@ -151,7 +152,8 @@ def extract_tool_calls_streaming( if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) sent = len( self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] @@ -197,7 +199,8 @@ def extract_tool_calls_streaming( if cur_arguments: sent = len( self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) prev_arguments = self.prev_tool_call_arr[ self.current_tool_id].get("arguments") @@ -205,7 +208,8 @@ def extract_tool_calls_streaming( if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) if cur_args_json != prev_args_json: prefix = find_common_prefix( prev_args_json, cur_args_json) diff --git 
a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 5abd553d884d..3f2799f8010a 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -133,7 +133,8 @@ def extract_tool_calls_streaming( delta = None # first time to get parameters elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) arguments_delta = cur_arguments_json[:cur_arguments_json. index(delta_text) + @@ -148,8 +149,10 @@ def extract_tool_calls_streaming( self.current_tool_id] += arguments_delta # both prev and cur parameters, send the increase parameters elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) @@ -190,7 +193,8 @@ def extract_tool_calls( action_dict = json.loads(action) name, parameters = action_dict['name'], json.dumps( action_dict.get('parameters', action_dict.get('arguments', - {}))) + {})), + ensure_ascii=False) if not tools or name not in [t.function.name for t in tools]: ExtractedToolCallInformation(tools_called=False, diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index e882ca2605e2..2714a545f997 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -96,8 +96,9 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"]))) - for function_call in raw_function_calls + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False), + )) for function_call in raw_function_calls ] content = model_output[:model_output. 
@@ -187,7 +188,7 @@ def extract_tool_calls_streaming( diff: Union[str, None] = current_tool_call.get("arguments") if diff: - diff = json.dumps(diff).replace( + diff = json.dumps(diff, ensure_ascii=False).replace( self.streamed_args_for_tool[self.current_tool_id], "") delta = DeltaMessage(tool_calls=[ @@ -248,7 +249,8 @@ def extract_tool_calls_streaming( "mid-arguments") delta = None elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) logger.debug("finding %s in %s", new_text, cur_arguments_json) @@ -267,8 +269,10 @@ def extract_tool_calls_streaming( self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) logger.debug("Searching for diff between \n%s\n%s", cur_args_json, prev_args_json) diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py new file mode 100644 index 000000000000..f483ac4eeee6 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 + +import ast +import json +import re +from collections.abc import Sequence +from typing import Any, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class _UnexpectedAstError(Exception): + pass + + +@ToolParserManager.register_module("llama4_pythonic") +class Llama4PythonicToolParser(ToolParser): + """ + Toolcall parser for Llama4 that produce tool calls in a pythonic style + Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic + """ + # TODO(mdepinet): Possible future improvements: + # 1. Support text + tools separated by either <|python_tag|> or \n\n + # 2. Support tools outside of a list (or separated by a semicolon). + # This depends on item 1 for consistent streaming. + # Neither of these are necessary for e.g. ToolACE, but both would help make + # Llama3.2 models more reliable. + + TOOL_CALL_REGEX = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL) + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # Rename for readability. This is NOT a tool id. + @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. 
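+
+        Falls back to returning the raw model output as plain ``content``
+        when it does not parse as a pythonic list of function calls.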
+ """ + + # remove <|python_start|> and <|python_end|> + # as Llama 4 model sometime will output those tokens + if model_output.startswith("<|python_start|>"): + model_output = model_output[len("<|python_start|>"):] + model_output = model_output.replace("<|python_end|>", "") + if not (self.TOOL_CALL_REGEX.match(model_output)): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + module = ast.parse(model_output) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=None) + else: + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + except Exception: + logger.exception("Error in extracting tool call from response.") + # Treat as regular text + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not current_text.startswith("[") and not current_text.startswith( + "<|python_start|>"): + return DeltaMessage(content=delta_text) + + try: + # remove <|python_start|> and <|python_end|> + if current_text.startswith("<|python_start|>"): + current_text = current_text[len("<|python_start|>"):] + if current_text.endswith("<|python_end|>"): + current_text = current_text[:current_text. + rfind("<|python_end|>")] + valid_and_added_text = _make_valid_python(current_text) + if valid_and_added_text is None: + return None + valid_text, added_text = valid_and_added_text + + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts): + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + tool_calls = [ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self.current_tool_index: + continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + + new_call_complete = index < len( + tool_calls) - 1 or ")]" not in added_text + if new_call_complete: + self.current_tool_index += 1 + + withheld_suffix = (added_text[:-2] + if not new_call_complete else "") + if not new_call_complete and added_text[-2] == ")": + # Function call is incomplete. Withhold the closing bracket. + withheld_suffix = withheld_suffix + "}" + # Strings get single quotes in the model-produced string. + # JSON requires double quotes. + withheld_suffix = withheld_suffix.replace("'", '"') + delta = _compute_tool_delta(self.streamed_args_for_tool[index], + new_call, index, withheld_suffix) + + if delta is not None: + tool_deltas.append(delta) + if (delta.function is not None + and delta.function.arguments is not None): + self.streamed_args_for_tool[ + index] += delta.function.arguments + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. 
+ # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if tool_deltas: + return DeltaMessage(tool_calls=tool_deltas) + elif not added_text and self.current_tool_id > 0: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage(content='') + else: + return None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + +def _get_parameter_value(val: ast.expr) -> Any: + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + raise _UnexpectedAstError( + "Dict tool call arguments must have literal keys") + return { + k.value: _get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [_get_parameter_value(v) for v in val.elts] + else: + raise _UnexpectedAstError("Tool call arguments must be literals") + + +def _handle_single_tool(call: ast.Call) -> ToolCall: + if not isinstance(call.func, ast.Name): + raise _UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = _get_parameter_value(keyword.value) + return ToolCall(type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments))) + + +def _make_valid_python(text: str) -> Union[tuple[str, str], None]: + bracket_stack = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise _UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise _UnexpectedAstError("Mismatched parentheses") + elif char == "}": + if not bracket_stack or bracket_stack.pop() != "{": + raise _UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + # Treat an escaped quote as a regular character + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + # Double quote within a single quote string or vice versa. + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. 
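+    # Returning None here makes the streaming caller emit nothing for this
+    # chunk and wait for more tokens.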
+ return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[:text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None # Incomplete property name within parameter value + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[:text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None # Incomplete parameter name + if text.endswith(","): + text = text[:-1] + if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( + "[") and not text.endswith(")"): + return None # Incomplete function name + + added_text = "" + for char in reversed(bracket_stack): + if char == "[": + added_text += "]" + elif char == "(": + added_text += ")" + elif char == "{": + added_text += "}" + elif char == "'": + added_text += "'" + elif char == '"': + added_text += '"' + + return text + added_text, added_text + + +def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, + index: int, + withheld_suffix: str) -> Union[DeltaToolCall, None]: + new_call_args = new_call.function.arguments + if withheld_suffix: + assert new_call_args.endswith(withheld_suffix) + new_call_args = new_call_args[:-len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall(id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + )) + + arg_diff = new_call_args[len(previously_sent_args):] + return DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall( + arguments=arg_diff)) if arg_diff else None diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 561402a72bd4..4eda7044cbba 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -88,7 +88,8 @@ def extract_tool_calls( # function call args are JSON but as a string arguments=json.dumps(raw_function_call["arguments"] \ if "arguments" in raw_function_call \ - else raw_function_call["parameters"]))) + else raw_function_call["parameters"], + ensure_ascii=False))) for raw_function_call in function_call_arr ] @@ -174,7 +175,8 @@ def extract_tool_calls_streaming( if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) sent = len( self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] @@ -226,7 +228,8 @@ def extract_tool_calls_streaming( if cur_arguments: sent = len( self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) prev_arguments = self.prev_tool_call_arr[ self.current_tool_id].get("arguments") @@ -234,7 +237,8 @@ def extract_tool_calls_streaming( if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) if cur_args_json != prev_args_json: prefix = find_common_prefix( diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py 
b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 798f346fc97d..b403a146716d 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -79,10 +79,11 @@ def extract_tool_calls( name=raw_function_call["name"], # function call args are JSON but as a string arguments=json.dumps( - raw_function_call["arguments"] if "arguments" in - raw_function_call else - raw_function_call["parameters"]))) - for raw_function_call in function_call_arr + raw_function_call["arguments"] + if "arguments" in raw_function_call else + raw_function_call["parameters"], + ensure_ascii=False), + )) for raw_function_call in function_call_arr ] # get any content before the tool call diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 22018c0d4f4f..548ff39d1ca4 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -200,9 +200,12 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: arguments = {} for keyword in call.keywords: arguments[keyword.arg] = _get_parameter_value(keyword.value) - return ToolCall(type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(arguments))) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments, + ensure_ascii=False)), + ) def _make_valid_python(text: str) -> Union[tuple[str, str], None]: diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..3e0ad0d5a989 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f1cb77f64eae..31efe16d1c27 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -50,8 +50,7 @@ else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): - # the iterative moe implementation is used until the moe_pallas is fixed - from .moe_torch_iterative import fused_moe as fused_moe_pallas + from .moe_pallas import fused_moe as fused_moe_pallas else: fused_moe_pallas = None # type: ignore logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index 8f28b64ed487..babeb97308a9 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -2,7 +2,23 @@ import torch import torch.nn.functional as F -from torch_xla.experimental.custom_kernel import _histogram + + +def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: + """ + Compute the histogram of a int32 tensor. The bin edges are defined by the + min and max values, with step = 1. + """ + assert input.dtype == torch.int32, "input must be of torch.int32 dtype." + assert min <= max, "min must be less than or equal to max." 
+ + def searchsorted(sorted_sequence: torch.Tensor, + values_to_search: torch.Tensor) -> torch.Tensor: + return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1) + + bin_edges = torch.linspace(min, max, max - min + 1, + dtype=input.dtype).to(input.device) + return searchsorted(bin_edges, input).to(torch.int32) def fused_moe( @@ -61,7 +77,7 @@ def fused_moe( x = torch.ops.xla.gmm(x, w2, group_sizes) x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) - x = x * topk_weights.unsqueeze_(dim=-1) + x = x * topk_weights.unsqueeze(dim=-1) x = x.sum(dim=-2) x = x.reshape(orig_shape) return x diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 270e7cf1298a..cb396f26c96e 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -182,3 +182,7 @@ def moe_unpermute( expert_first_token_offset, n_expert, n_local_expert, topk, hidden_states) return hidden_states + + +def moe_permute_unpermute_supported(): + return torch.ops._moe_C.moe_permute_unpermute_supported() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 54dd1251e59f..dd2e477f3954 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -261,6 +261,7 @@ class ReplicatedLinear(LinearBase): quant_config: Quantization configure. prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. """ def __init__( @@ -523,6 +524,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. """ def __init__( @@ -805,6 +807,7 @@ class QKVParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. """ def __init__( @@ -1155,7 +1158,13 @@ class RowParallelLinear(LinearBase): bias can be fused with other element-wise operations. We skip adding bias but instead return it. params_dtype: Data type for the parameters. + reduce_results: If true, call all-reduce on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y = X_iA_i quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.down_proj) + return_bias: If true, return bias together with outputs in forward pass. 
""" def __init__( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index bc6e6fcdd0a2..f94ab75f9a4f 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -34,7 +34,11 @@ @CustomOp.register("mixer2_gated_rms_norm") class Mixer2RMSNormGated(CustomOp): - def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): + def __init__(self, + full_hidden_size: int, + full_n_groups: int, + use_rms_norm: bool = True, + eps: float = 1e-6): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() @@ -44,11 +48,17 @@ def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): self.n_groups = full_hidden_size // self.group_size self.variance_epsilon = eps - self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) - set_weight_attrs(self.weight, - {"weight_loader": sharded_weight_loader(0)}) - assert self.full_hidden_size % self.tp_size== 0,\ - "Tensor parallel world size must divide hidden size." + self.use_rms_norm = use_rms_norm + if self.use_rms_norm: + # Register norm weight only if we're actually applying RMSNorm + self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) + set_weight_attrs(self.weight, + {"weight_loader": sharded_weight_loader(0)}) + else: + # Avoid checkpoint mismatch by skipping unused parameter + self.register_parameter("weight", None) + assert (self.full_hidden_size % self.tp_size == 0 + ), "Tensor parallel world size must divide hidden size." def forward_native( self, @@ -66,6 +76,8 @@ def forward_native( # the input and then redundantly compute the RMSNorm. input_dtype = x.dtype x = x * nn.functional.silu(gate.to(torch.float32)) + if not self.use_rms_norm: + return x.to(input_dtype) if self.n_groups == 1: if self.tp_size > 1: @@ -74,7 +86,7 @@ def forward_native( global_sums = tensor_model_parallel_all_reduce(local_sums) # Calculate the variance count = self.tp_size * x.shape[-1] - variance = (global_sums / count) + variance = global_sums / count else: variance = x.pow(2).mean(-1, keepdim=True) @@ -105,6 +117,11 @@ def forward_cuda( x: torch.Tensor, gate: torch.Tensor, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + input_dtype = x.dtype + if not self.use_rms_norm: + # Keep gate in float32 for numerical stability during silu + return x * nn.functional.silu(gate.to( + torch.float32)).to(input_dtype) if self.tp_size > 1 or self.n_groups != 1: return self.forward_native(x, gate) @@ -124,7 +141,7 @@ def forward_cuda( def extra_groups_for_head_shards(ngroups: int, tp_size: int): - """Compute the increase in group numbers to account for + """Compute the increase in group numbers to account for replication in order to accompany the head shards.""" # in the case ngoups % tp_size == 0, this will be zero @@ -182,13 +199,15 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # seem to handle slices well. # https://github.com/python/mypy/issues/2410 param.data[ - boundary:(boundary + take), # type: ignore[misc] - ...] = loaded_weight[loaded_start_idx:( # type: ignore[misc] - loaded_start_idx + take)] # type: ignore[misc] + boundary:(boundary + take), + ... 
# type: ignore[misc] + ] = loaded_weight[loaded_start_idx:(loaded_start_idx + + take) # type: ignore[misc] + ] # type: ignore[misc] # move indexing boundaries boundary += shard_size - loaded_boundary += (full_dim - extra) + loaded_boundary += full_dim - extra return loader @@ -206,19 +225,22 @@ class MambaMixer2(CustomOp): **selective** state spaces) """ - def __init__(self, - hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - use_conv_bias: bool, - use_bias: bool, - n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, - rms_norm_eps: float = 1e-5, - activation="silu", - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() # For TP, the sharding plan is as follows: @@ -238,17 +260,16 @@ def __init__(self, self.tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() - assert num_heads % self.tp_size == 0, \ - "Tensor parallel world size must divide num heads." + assert (num_heads % self.tp_size == 0 + ), "Tensor parallel world size must divide num heads." - assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ - ( - "If tensor parallel world size does not divide num_heads, " - "then num_groups must equal 1." - ) + assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( + "If tensor parallel world size does not divide num_heads, " + "then num_groups must equal 1.") - assert self.tp_size == 1 or quant_config is None, \ - "Tensor parallel currently not supported for quantized models." + assert ( + self.tp_size == 1 or quant_config is None + ), "Tensor parallel currently not supported for quantized models." 
self.ssm_state_size = ssm_state_size self.activation = activation @@ -265,8 +286,7 @@ def __init__(self, self.n_groups = n_groups + extra_groups_for_head_shards( n_groups, self.tp_size) - self.conv_dim = (intermediate_size + - 2 * self.n_groups * ssm_state_size) + self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, output_size=self.conv_dim, @@ -279,11 +299,12 @@ def __init__(self, # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = ColumnParallelLinear(input_size=hidden_size, - output_size=intermediate_size + - self.conv_dim + self.num_heads, - bias=use_bias, - quant_config=quant_config) + self.in_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config, + ) # - because in_proj is a concatenation of 3 weights, we # need to interleave them before sharding @@ -305,7 +326,8 @@ def __init__(self, # - ditto for the otther two weights below delattr(self.conv1d.bias, "weight_loader") set_weight_attrs( - self.conv1d.bias, { + self.conv1d.bias, + { "weight_loader": mamba_v2_sharded_weight_loader( [ @@ -316,18 +338,25 @@ def __init__(self, self.tp_size, tp_rank, ) - }) + }, + ) delattr(self.conv1d.weight, "weight_loader") set_weight_attrs( - self.conv1d.weight, { + self.conv1d.weight, + { "weight_loader": - mamba_v2_sharded_weight_loader([ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], self.tp_size, tp_rank) - }) + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }, + ) if quant_config is None: # - quant layers do not have a weight loader @@ -345,8 +374,10 @@ def __init__(self, head_setings, # for dt ], self.tp_size, - tp_rank) - }) + tp_rank, + ) + }, + ) # - these are TPed by heads to reduce the size of the # temporal shape @@ -357,6 +388,7 @@ def __init__(self, )) self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.use_rms_norm = use_rms_norm set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) a_weight_loader = composed_weight_loader( @@ -365,18 +397,25 @@ def __init__(self, set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - self.out_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=use_bias, - input_is_parallel=True, - quant_config=quant_config) + self.out_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config, + ) self.norm = Mixer2RMSNormGated(intermediate_size, n_groups, + self.use_rms_norm, eps=rms_norm_eps) - def forward_native(self, hidden_states: torch.Tensor, - conv_state: torch.Tensor, ssm_state: torch.Tensor): + def forward_native( + self, + hidden_states: torch.Tensor, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): pass def forward_cuda( @@ -384,6 +423,7 @@ def forward_cuda( hidden_states: torch.Tensor, mamba_cache_params: MambaCacheParams, mamba2_metadata: Mamba2Metadata, + mup_vector: Optional[torch.Tensor] = None, ): # mamba2_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill @@ -401,6 +441,10 @@ def forward_cuda( # 1. 
Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) + + if mup_vector is not None: + projected_states = projected_states * mup_vector + gate, hidden_states_B_C, dt = torch.split( projected_states, [ @@ -561,6 +605,9 @@ def forward_cuda( hidden_states = torch.vstack(ssd_output_list) # 4. gated MLP + # GatedRMSNorm internally applying SiLU to the gate + # SiLU is applied internally before normalization, unlike standard + # norm usage hidden_states = self.norm(hidden_states, gate) # 5. Final linear projection diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index a713b1e93c2d..407b9c72f41d 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -14,7 +14,7 @@ "ptpc_fp8", "fbgemm_fp8", "modelopt", - "nvfp4", + "modelopt_fp4", "marlin", "bitblas", "gguf", @@ -33,6 +33,7 @@ "quark", "moe_wna16", "torchao", + "auto-round", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -84,6 +85,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig from .aqlm import AQLMConfig + from .auto_round import AutoRoundConfig from .awq import AWQConfig from .awq_marlin import AWQMarlinConfig from .bitblas import BitBLASConfig @@ -118,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, - "nvfp4": ModelOptNvFp4Config, + "modelopt_fp4": ModelOptNvFp4Config, "marlin": MarlinConfig, "bitblas": BitBLASConfig, "gguf": GGUFConfig, @@ -138,6 +140,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "quark": QuarkConfig, "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, + "auto-round": AutoRoundConfig, } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py new file mode 100644 index 000000000000..a5e63843cf62 --- /dev/null +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fractions import Fraction +from typing import Any, Optional, Union + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class AutoRoundConfig(QuantizationConfig): + """Config class for AutoRound. 
+ Reference: https://arxiv.org/pdf/2309.05516 + """ + + SUPPORTED_BITS = {2, 3, 4, 8} + SUPPORTED_DTYPES = {"int"} + SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"} + SUPPORTED_BACKENDS = { + "auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin", "ipex" + } + + def __init__( + self, + weight_bits: int, + group_size: int, + sym: bool = True, + packing_format: str = "auto_round:auto_gptq", + block_name_to_quantize: Optional[Union[str, list[str]]] = None, + extra_config: Optional[dict[str, Any]] = None, + data_type: str = "int", + backend: str = "auto", + ) -> None: + super().__init__() + if weight_bits not in self.SUPPORTED_BITS: + raise ValueError(f"Unsupported weight_bits: {weight_bits}, " + f"currently only support {self.SUPPORTED_BITS}") + if data_type not in self.SUPPORTED_DTYPES: + raise ValueError( + f"Unsupported data_type: {data_type}," + f" currently only support {self.SUPPORTED_DTYPES}") + if packing_format not in self.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported packing_format: {packing_format}, " + f"currently only support {self.SUPPORTED_FORMATS}") + if backend not in self.SUPPORTED_BACKENDS: + raise ValueError( + f"Unsupported backend: {backend}, " + f"currently only support {self.SUPPORTED_BACKENDS}") + + self.weight_bits = weight_bits + self.group_size = group_size + self.sym = sym + self.packing_format = packing_format + self.block_name_to_quantize = (block_name_to_quantize.split(",") if + isinstance(block_name_to_quantize, str) + else block_name_to_quantize) + self.extra_config = extra_config + self.data_type = data_type + self.backend = backend + self.pack_factor = Fraction(32, weight_bits) + + def __repr__(self) -> str: + return (f"AutoRoundConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, sym={self.sym})") + + @classmethod + def get_name(cls): ## use str will trigger preci issue + return "auto-round" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantization_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": + return cls( + weight_bits=cls.get_from_keys(config, ["bits"]), + group_size=cls.get_from_keys(config, ["group_size"]), + sym=cls.get_from_keys(config, ["sym"]), + packing_format=cls.get_from_keys_or(config, ["packing_format"], + "auto_round:auto_gptq"), + block_name_to_quantize=cls.get_from_keys_or( + config, ["block_name_to_quantize", "to_quant_block_names"], + None), + extra_config=cls.get_from_keys_or(config, ["extra_config"], None), + data_type=cls.get_from_keys_or(config, ["data_type"], "int"), + backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], + "auto"), + ) + + def get_layer_config(self, layer, layer_name: str): + # Priority: extra_config > block_name_to_quantize > type fallback + if self.extra_config and layer_name in self.extra_config: + cfg = self.extra_config[layer_name] + return cfg.get("bits", self.weight_bits), cfg.get( + "group_size", self.group_size), cfg.get("sym", self.sym) + + quantized = True + if self.block_name_to_quantize: + quantized = any(name in layer_name + for name in self.block_name_to_quantize) + elif isinstance(layer, ParallelLMHead): + quantized = False + + return (self.weight_bits, self.group_size, + self.sym) if quantized else (16, -1, True) + + def check_quantized(self, weight_bits: int) -> bool: + 
return weight_bits < 16 + + def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, check_moe_marlin_supports_layer) + + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + + logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", + prefix, layer.__class__.__name__, weight_bits, group_size, + sym) + if backend == "auto" or "marlin" in backend: + if isinstance(layer, FusedMoE): + use_marlin = check_moe_marlin_supports_layer(layer, group_size) + else: + + AWQ_TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + use_marlin = ((weight_bits, sym) in AWQ_TYPE_MAP + and check_marlin_supported( + AWQ_TYPE_MAP[(weight_bits)], group_size, + not sym)) + else: + use_marlin = False + if use_marlin: + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod) + quant_args_marlin = AWQMarlinConfig(weight_bits=weight_bits, + group_size=group_size, + zero_point=not sym, + lm_head_quantized=False, + full_config={}, + modules_to_not_convert=[]) + else: + from vllm.model_executor.layers.quantization.awq import ( + AWQConfig, AWQLinearMethod) + quant_args = AWQConfig( + weight_bits=weight_bits, + group_size=group_size, + zero_point=not sym, + ) + + if isinstance(layer, FusedMoE): + if use_marlin: + return AWQMoEMethod(quant_args_marlin) + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + config = { + "linear_quant_method": "awq", + "weight_bits": weight_bits, + "group_size": group_size, + "zero_point": not sym, + } + return MoeWNA16Config.from_config(config).get_quant_method( + layer, prefix) + + if isinstance(layer, (LinearBase, ParallelLMHead)): + if use_marlin: + return AWQMarlinLinearMethod(quant_args_marlin) + else: + return AWQLinearMethod(quant_args) + return None + + def apply_gptq_quant_layer(self, + layer, + prefix: str, + backend: str = "auto"): + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, check_moe_marlin_supports_layer) + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + + logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", + prefix, layer.__class__.__name__, weight_bits, group_size, + sym) + if backend == "auto" or "marlin" in backend: + if isinstance(layer, FusedMoE): + use_marlin = check_moe_marlin_supports_layer(layer, group_size) + else: + GPTQ_TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP + and check_marlin_supported( + GPTQ_TYPE_MAP[(weight_bits, sym)], + group_size, + has_zp=not sym)) + else: + use_marlin = False + if use_marlin: + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod) + quant_args_marlin = GPTQMarlinConfig(weight_bits=weight_bits, + group_size=group_size, + is_sym=sym, + lm_head_quantized=False, + desc_act=False, + dynamic={}, + 
                                                 full_config={})
+        else:
+            from vllm.model_executor.layers.quantization.gptq import (
+                GPTQConfig, GPTQLinearMethod)
+            quant_args = GPTQConfig(weight_bits=weight_bits,
+                                    group_size=group_size,
+                                    lm_head_quantized=False,
+                                    desc_act=False,
+                                    dynamic={})
+
+        if isinstance(layer, FusedMoE):
+            if use_marlin:
+                from vllm.model_executor.layers.quantization.moe_wna16 import (
+                    MoeWNA16Config)
+                config = {
+                    "linear_quant_method": "gptq",
+                    "weight_bits": weight_bits,
+                    "group_size": group_size,
+                    "sym": sym,
+                    "lm_head_quantized": False,
+                }
+                return MoeWNA16Config.from_config(config).get_quant_method(
+                    layer, prefix)
+            return GPTQMarlinMoEMethod(quant_args_marlin)
+
+        if isinstance(layer, (LinearBase, ParallelLMHead)):
+            if use_marlin:
+                return GPTQMarlinLinearMethod(quant_args_marlin)
+            else:
+                return GPTQLinearMethod(quant_args)
+
+        return None
+
+    def apply_ipex_quant_layer(self, layer, prefix: str):
+        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
+        if not self.check_quantized(weight_bits):
+            if isinstance(layer, (LinearBase, ParallelLMHead)):
+                return UnquantizedLinearMethod()
+            else:
+                return None
+        from vllm.model_executor.layers.quantization.ipex_quant import (
+            IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod)
+        if isinstance(layer, (LinearBase, ParallelLMHead)):
+            if "awq" in self.packing_format:
+                config = IPEXConfig(method="awq",
+                                    weight_bits=weight_bits,
+                                    group_size=group_size)
+                return IPEXAWQLinearMethod(config)
+            elif "gptq" in self.packing_format:
+                config = IPEXConfig(method="gptq",
+                                    weight_bits=weight_bits,
+                                    group_size=group_size)
+                return IPEXGPTQLinearMethod(config)
+            else:
+                raise ValueError(
+                    "ipex backend only supports awq "
+                    f"and gptq formats, but got {self.packing_format}")
+        else:
+            return None
+
+    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+        if (current_platform.is_cpu() or current_platform.is_xpu()
+                or self.backend == "ipex"):
+            return self.apply_ipex_quant_layer(layer, prefix)
+        if "gptq" in self.packing_format or "gptq" in self.backend:
+            return self.apply_gptq_quant_layer(layer, prefix)
+        if "awq" in self.packing_format or "awq" in self.backend:
+            return self.apply_awq_quant_layer(layer, prefix)
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 8bce6bba460a..b7baa3d3363b 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -181,8 +181,6 @@ def apply(self,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = layer.ipex_qlinear(reshaped_x)
-        if bias is not None:
-            out.add_(bias)
         return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 97167cb5833d..1c5680f952ab 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -192,7 +192,7 @@ def __init__(
     @classmethod
     def get_name(cls) -> QuantizationMethods:
-        return "nvfp4"
+        return "modelopt_fp4"
     @classmethod
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 47a7a06bb744..6771c128c5a1 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -35,6 +35,7 @@
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     pt_weights_iterator, safetensors_weights_iterator)
+from vllm.model_executor.models import is_pooling_model
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 
@@ -133,6 +134,16 @@ def _prepare_weights(self, model_name_or_path: str,
         return hf_weights_files, use_safetensors
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
+
+        def _maybe_pool_model(module_name: str):
+            # For pooling models, prepend the `model.` prefix to the weight
+            # name when the target modules expect it.
+            if self.is_pool_model and self.target_modules[0].startswith(
+                    "model.") and not module_name.startswith("model."):
+                return "model." + module_name
+
+            return module_name
+
         if use_safetensors:
             iterator = safetensors_weights_iterator(
                 hf_weights_files,
@@ -148,6 +159,9 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
             # mapping weight names from transformers to vllm while preserving
             # original names.
             mapped_name = self.weight_mapper(org_name)
+            mapped_name = _maybe_pool_model(mapped_name)
+
+            yield org_name, mapped_name, param
 
     def _get_quantized_weights_iterator(
@@ -405,7 +419,7 @@ def _load_weights(self, model_config: ModelConfig,
             raise AttributeError(
                 f"Model {type(model).__name__} does not support BitsAndBytes "
                 "quantization yet. No 'packed_modules_mapping' found.")
-
+        self.is_pool_model = is_pooling_model(model)
 
         self.modules_mapping = ParamMapping(
             copy.deepcopy(model.packed_modules_mapping))
diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py
index 1c4f66061d1d..557feea46a90 100644
--- a/vllm/model_executor/model_loader/neuronx_distributed.py
+++ b/vllm/model_executor/model_loader/neuronx_distributed.py
@@ -48,6 +48,9 @@
 # Models supported by Neuronx distributed for inference.
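+# Note: MistralForCausalLM is mapped to NeuronLlamaForCausalLM below, so
+# Mistral models are served through the Llama modeling implementation.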
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = { "LlamaForCausalLM": + ("neuronx_distributed_inference.models.llama.modeling_llama", + "NeuronLlamaForCausalLM"), + "MistralForCausalLM": ("neuronx_distributed_inference.models.llama.modeling_llama", "NeuronLlamaForCausalLM"), "DbrxForCausalLM": @@ -84,16 +87,29 @@ def forward( input_block_ids: torch.Tensor, sampling_params: torch.Tensor, ) -> torch.Tensor: + # sort block ids sequentially for perf/neuron support reasons + sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) + input_ids = torch.index_select(input_ids, 0, sorted_indices) + positions = torch.index_select(positions, 0, sorted_indices) + sampling_params = torch.index_select(sampling_params, 0, + sorted_indices) + output = self.model(input_ids, attention_mask=None, position_ids=positions, - seq_ids=input_block_ids, + seq_ids=sorted_input_block_ids, sampling_params=sampling_params) # on-device sampling if self.config.neuron_config.on_device_sampling_config: - return output.hidden_states + output = output.hidden_states else: - return output.logits[:, -1, :] + output = output.logits[:, -1, :] + + restored_indices = torch.argsort(sorted_indices) + if input_block_ids.shape[0] != 1: + output = torch.index_select(output, 0, restored_indices) + + return output def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: @@ -143,8 +159,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): config = neuronx_model_cls.get_config_cls()( neuron_config, load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5( - config.to_json_string().encode('utf-8')).hexdigest() + hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), + usedforsecurity=False).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") elif os.path.exists(model_name_or_path): @@ -263,8 +279,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): config = neuronx_model_cls.get_config_cls()( neuron_config, load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5( - config.to_json_string().encode('utf-8')).hexdigest() + hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), + usedforsecurity=False).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") elif os.path.exists(model_name_or_path): @@ -337,14 +353,26 @@ def forward( input_block_ids: torch.Tensor, sampling_params: torch.Tensor, ) -> torch.Tensor: + # sort block ids sequentially for perf/neuron support reasons + sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) + input_ids = torch.index_select(input_ids, 0, sorted_indices) + positions = torch.index_select(positions, 0, sorted_indices) + sampling_params = torch.index_select(sampling_params, 0, + sorted_indices) + output = self.model(input_ids, attention_mask=None, position_ids=positions, - seq_ids=input_block_ids, + seq_ids=sorted_input_block_ids, sampling_params=sampling_params) + restored_indices = torch.argsort(sorted_indices) + # CTX encoding if (positions[:, 0]).sum().item() == 0: - return output.fused_outputs[0][:, 0:1] + output = output.fused_outputs[0][:, 0:1] + if input_block_ids.shape[0] != 1: + output = torch.index_select(output, 0, restored_indices) + return output # Fused Spec (Generation) accepted_tokens_with_padding = output.fused_outputs[0] @@ -359,6 +387,10 @@ def forward( -1) 
>= generated_token_counts accepted_tokens_with_padding[mask] = -1 + if input_block_ids.shape[0] != 1: + accepted_tokens_with_padding = torch.index_select( + accepted_tokens_with_padding, 0, restored_indices) + return accepted_tokens_with_padding def sample( @@ -413,6 +445,10 @@ def load_weights(self, model_name_or_path: str, draft_neuron_config.speculation_length = 0 draft_neuron_config.trace_tokengen_model = True draft_neuron_config.enable_fused_speculation = False + if getattr(config.neuron_config, "draft_model_modules_to_not_convert", + None): + draft_neuron_config.modules_to_not_convert = ( + draft_neuron_config.draft_model_modules_to_not_convert) if config.neuron_config.enable_eagle_speculation: draft_neuron_config.is_eagle_draft = True draft_neuron_config.sequence_parallel_enabled = False @@ -426,8 +462,8 @@ def load_weights(self, model_name_or_path: str, config.fused_spec_config = fused_spec_config self.config.neuron_config = neuron_config - hashed_config = hashlib.md5( - config.to_json_string().encode('utf-8')).hexdigest() + hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), + usedforsecurity=False).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") elif os.path.exists(model_name_or_path): @@ -499,7 +535,7 @@ def _get_default_neuron_config(model_config: ModelConfig, max_context_length=scheduler_config.max_model_len, seq_len=scheduler_config.max_model_len, enable_bucketing=True, - is_continuous_batching=(batch_size > 1), + is_continuous_batching=True, quantized=False, torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], padding_side="right", @@ -517,6 +553,7 @@ def _get_default_speculation_config(model_config: ModelConfig, args.""" neuron_config = dict( tp_degree=parallel_config.tensor_parallel_size, + ctx_batch_size=1, batch_size=scheduler_config.max_num_seqs, max_context_length=scheduler_config.max_model_len, seq_len=scheduler_config.max_model_len, @@ -524,6 +561,7 @@ def _get_default_speculation_config(model_config: ModelConfig, trace_tokengen_model=False, enable_fused_speculation=True, enable_bucketing=True, + is_continuous_batching=True, quantized=False, torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], on_device_sampling_config=dict( diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index e9fff705f1d4..609a180fd849 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -214,6 +214,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: group.add_argument( "--tensorizer-uri", + type=str, help="Path to serialized model tensors. Can be a local file path," " or an HTTP(S) or S3 URI.", ) @@ -226,6 +227,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) group.add_argument( "--encryption-keyfile", + type=str, default=None, help="The file path to a binary file containing a binary key to " "use for decryption. Can be a file path or S3 network URI.") @@ -239,18 +241,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "and model size. This greatly increases performance.") group.add_argument( "--s3-access-key-id", + type=str, default=None, help="The access key for the S3 bucket. Can also be set via the " "S3_ACCESS_KEY_ID environment variable.", ) group.add_argument( "--s3-secret-access-key", + type=str, default=None, help="The secret access key for the S3 bucket. 
Can also be set via " "the S3_SECRET_ACCESS_KEY environment variable.", ) group.add_argument( "--s3-endpoint", + type=str, default=None, help="The endpoint for the S3 bucket. Can also be set via the " "S3_ENDPOINT_URL environment variable.", diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index eb1085d6b40d..10424e218fbc 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -43,7 +43,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -229,6 +229,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + self.config = config self.embed_dim = config.hidden_size @@ -278,6 +279,38 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: BLOOM's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): @@ -325,35 +358,15 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if name == "lm_head.weight": - continue - if not name.startswith("transformer."): - name = "transformer." + name - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - - if "query_key_value" in name: - # NOTE: BLOOM's fused QKV's output_dim has the shape of - # (num_heads * 3 * head_size), while the - # required shape is (3 * num_heads * head_size). - # Thus, we need weight conversion. 
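[Editor's note, not part of the patch] The fused-QKV weight conversion described in the comment above is easier to follow as a standalone sketch. The sizes below are made up; output_dim = 0 mirrors the common case where the fused dimension is the first one.

import torch

num_heads, head_size, in_dim = 4, 8, 16
output_dim = 0  # the fused QKV dimension of the checkpoint weight

# HF BLOOM stores the fused weight as (num_heads * 3 * head_size, in_dim),
# i.e. Q/K/V interleaved per head.
loaded_weight = torch.randn(num_heads * 3 * head_size, in_dim)
loaded_weight_shape = loaded_weight.shape

# Re-group rows to (3 * num_heads * head_size, in_dim): all Q heads, then all
# K heads, then all V heads, which is the layout the loader requires.
converted = loaded_weight.view(loaded_weight_shape[:output_dim] +
                               (num_heads, 3, -1) +
                               loaded_weight_shape[output_dim + 1:])
converted = converted.transpose(output_dim, output_dim + 1)
converted = converted.reshape(loaded_weight_shape)

assert converted.shape == loaded_weight.shape  # same shape, rows regrouped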
- output_dim = getattr(param, "output_dim", None) - num_heads = self.config.num_attention_heads - if output_dim is not None: - loaded_weight_shape = loaded_weight.shape - loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) - loaded_weight = loaded_weight.reshape(loaded_weight_shape) - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"]) + weights = _add_transformer_prefix(weights) + return loader.load_weights(weights) + + +def _add_transformer_prefix( + weights: Iterable[tuple[str, torch.Tensor]] +) -> Iterable[tuple[str, torch.Tensor]]: + for name, tensor in weights: + if not name.startswith('transformer.'): + name = 'transformer.' + name + yield name, tensor diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 4ffd06319684..838560692bcf 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -127,8 +127,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py new file mode 100644 index 000000000000..1c0e3911fcce --- /dev/null +++ b/vllm/model_executor/models/falcon_h1.py @@ -0,0 +1,684 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only FalconH1 model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import FalconH1Config + +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import 
SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsV0Only) +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class FalconH1MLP(nn.Module): + + def __init__( + self, + config: FalconH1Config, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.intermediate_size = config.intermediate_size + self.gate_multiplier, self.down_multiplier = config.mlp_multipliers + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x[:, :self.intermediate_size // self.tp_size] *= self.gate_multiplier + x = self.act_fn(x) + x, _ = self.down_proj(x) + x = x * self.down_multiplier + return x + + +class FalconH1SSMDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconH1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + + self.d_ssm = (int(config.mamba_expand * config.hidden_size) + if config.mamba_d_ssm is None else config.mamba_d_ssm) + + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=self.d_ssm, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + quant_config=quant_config, + use_rms_norm=config.mamba_rms_norm, + ) + # n_groups is overridden later by `MambaMixer2` + self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state + self.zxbcdt_multipliers = config.ssm_multipliers + self._init_mup_vector() + + def _init_mup_vector(self): + """ + Non learnable per-block scaling vector composed of element-wise + multipliersapplied to each separate contiguous block of the output + of the linear projection (in_proj) before further processing + (gating, convolution, SSM): + + - Z block: [0 : d_ssm] → zxbcdt_multipliers[0] + - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1] + - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2] + - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S] + → zxbcdt_multipliers[3] + - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4] + + where: + - d_ssm: Dimension of state-space model latent + - G: Number of groups (n_groups) + - S: SSM state size per group + - All indices are divided by tp_size to support tensor parallelism + """ + vector_shape = (2 * self.d_ssm + 2 * self.groups_time_state_size + + self.config.mamba_n_heads) // self.tp_size + mup_vector = torch.ones(1, vector_shape) + # Z vector 0 -> d_ssm + mup_vector[:, :self.d_ssm // + self.tp_size] *= 
self.zxbcdt_multipliers[0] + # X vector d_ssm -> 2 * d_ssm + mup_vector[:, + (self.d_ssm // + self.tp_size):(2 * self.d_ssm // + self.tp_size)] *= self.zxbcdt_multipliers[1] + # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state) + mup_vector[ + :, + (2 * self.d_ssm) // + self.tp_size:(2 * self.d_ssm + self.groups_time_state_size) // + self.tp_size, + ] *= self.zxbcdt_multipliers[2] + # C vector 2 * d_ssm + (n_group * d_state) + # -> 2 * d_ssm + 2 * (n_group * d_state) + mup_vector[ + :, + (2 * self.d_ssm + self.groups_time_state_size) // + self.tp_size:(2 * self.d_ssm + 2 * self.groups_time_state_size) // + self.tp_size, + ] *= self.zxbcdt_multipliers[3] + # dt vector 2 * d_ssm + 2 * (n_group * d_state) + # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads + mup_vector[ + :, + (2 * self.d_ssm + 2 * self.groups_time_state_size) // + self.tp_size:, + ] *= self.zxbcdt_multipliers[4] + + self.register_buffer("mup_vector", mup_vector, persistent=False) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + hidden_states = self.mamba( + hidden_states, + mamba_cache_params, + mamba2_metadata=mamba2_metadata, + mup_vector=self.mup_vector, + ) + return hidden_states, residual + + +class FalconH1AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconH1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + rope_theta = getattr(config, "rope_theta", 1e11) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
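[Editor's note, not part of the patch] A quick numeric illustration of the partition-versus-replicate comment above, with made-up head counts and TP sizes.

# Case 1: more KV heads than TP ranks -> partition them across ranks.
total_num_kv_heads, tp_size = 8, 4
assert total_num_kv_heads % tp_size == 0
print(max(1, total_num_kv_heads // tp_size))  # 2 KV heads per rank

# Case 2: fewer KV heads than TP ranks -> replicate them.
total_num_kv_heads, tp_size = 2, 8
assert tp_size % total_num_kv_heads == 0
print(max(1, total_num_kv_heads // tp_size))  # 1 KV head per rank; each head
# is shared by tp_size // total_num_kv_heads = 4 ranks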
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = (config.hidden_size // self.total_num_heads if getattr( + config, "head_dim", None) is None else config.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if hasattr(config, "partial_rotary_factor"): + rotary_dim = self.head_dim * config.partial_rotary_factor + elif hasattr(config, "attn_rotary_emb"): + rotary_dim = config.attn_rotary_emb # for backward compatibility + else: + rotary_dim = self.head_dim # default + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + rope_scaling=rope_scaling, + base=rope_theta, + is_neox_style=True, + dtype=None, # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.key_multiplier = config.key_multiplier + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + k = k * self.key_multiplier + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + ) + return hidden_states, residual + + +class FalconH1ParallelHybrid(nn.Module): + """ + A hybrid decoder layer for FalconH1 where the input is processed + in parallel through both the self-attention branch and the SSM (Mamba) + branch. Their outputs are then summed to produce the final hidden state. + + This layer uses: + - FalconH1AttentionDecoderLayer for the multi-head self-attention branch. + - FalconH1SSMDecoderLayer for the state-space (Mamba) branch. 
+ """ + + def __init__( + self, + config: FalconH1Config, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Instantiate the attention branch + self.self_attn = FalconH1AttentionDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + # Instantiate the SSM branch + self.mamba = FalconH1SSMDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + ) + self.ssm_out_multiplier = config.ssm_out_multiplier + self.ssm_in_multiplier = config.ssm_in_multiplier + + self.attention_in_multiplier = config.attention_in_multiplier + self.attn_out_multiplier = config.attention_out_multiplier + + self.feed_forward = FalconH1MLP(config) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Process input through the attention branch. + # FalconH1AttentionDecoderLayer expects positions, hidden_states, + # kv_cache, attn_metadata, and residual. + attn_hidden, _ = self.self_attn( + positions=positions, + hidden_states=hidden_states * self.attention_in_multiplier, + residual=residual, + **kwargs, + ) + + # Process input through the SSM branch. + # FalconH1SSMDecoderLayer expects hidden_states, attn_metadata, + # residual, mamba_cache_params, and sequence_idx. + ssm_hidden, _ = self.mamba( + hidden_states=hidden_states * self.ssm_in_multiplier, + residual=residual, + mamba_cache_params=mamba_cache_params, + mamba2_metadata=mamba2_metadata, + **kwargs, + ) + # Sum the outputs from both branches. + # We assume both branches produce outputs of the same + # dimensionality (config.hidden_size). 
+ hidden_states = (attn_hidden * self.attn_out_multiplier) + ( + ssm_hidden * self.ssm_out_multiplier) + hidden_states = hidden_states + residual + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class FalconH1Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: FalconH1Config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank: + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + else: + self.embed_tokens = PPMissingLayer() + self.embedding_multiplier = 1.0 + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = FalconH1ParallelHybrid + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + else: + self.final_layernorm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + attn_metadata = get_forward_context().attn_metadata + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds * self.embedding_multiplier + else: + hidden_states = (self.get_input_embeddings(input_ids) * + self.embedding_multiplier) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + mamba_cache_params=layer_mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsV0Only): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], 
+ "gate_up_proj": ["gate_proj", "up_proj"], + } + + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "FalconH1 currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = FalconH1Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.tie_word_embeddings = config.tie_word_embeddings + self.unpadded_vocab_size = config.vocab_size + self.mamba_cache: Optional[MambaCacheManager] = None + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + ) + self.lm_head_multiplier = config.lm_head_multiplier + if self.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + # Used to track and store by the Mamba cache between steps. + + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + scale=config.lm_head_multiplier, + ) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + if self.mamba_cache is None: + self.mamba_cache = MambaCacheManager( + self.vllm_config, + self.lm_head.weight.dtype + if hasattr(self.lm_head, 'weight') else torch.bfloat16, + self.config.num_hidden_layers, + *self._get_mamba_cache_shape(), + ) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model( + input_ids, + positions, + mamba_cache_params, + intermediate_tensors, + inputs_embeds, + ) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> tuple[tuple[int, int], tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = (int(self.config.mamba_expand * + hidden_size) if self.config.mamba_d_ssm + is None else self.config.mamba_d_ssm) + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = 
self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size) + + # - heads and n_groups are TP-ed + conv_dim = intermediate_size + 2 * n_groups * self.config.mamba_d_state + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.mamba_n_heads, world_size), + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if "mamba" in name: + name = name.replace("mamba", "mamba.mamba") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if self.tie_word_embeddings and "lm_head" in name: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.tie_word_embeddings: + loaded_params.add("lm_head.weight") + return loaded_params diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index eed0820a5779..3524d036db22 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -122,8 +122,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = config.attention_multiplier @@ -478,18 +479,14 @@ def make_empty_intermediate_tensors( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - skip_prefixes = [ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. 
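[Editor's note, not part of the patch] To make the renaming in the FalconH1ForCausalLM.load_weights hunk above more concrete, here is a string-level sketch of how checkpoint names are rewritten before parameter lookup; the example names are hypothetical.

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(name: str):
    if "A_log" in name:
        name = name.replace("A_log", "A")
    if "mamba" in name:
        name = name.replace("mamba", "mamba.mamba")
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

print(remap("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap("model.layers.0.mamba.A_log"))
# ('model.layers.0.mamba.mamba.A', None)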
- "rotary_emb.cos_cached", - "rotary_emb.sin_cached", - ] # With tie_word_embeddings, we can skip lm_head.weight # The weight might appear unnecessarily in the files if the model is # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings: - skip_prefixes.append("lm_head.weight") + skip_prefixes = (["lm_head."] + if self.config.tie_word_embeddings else None) - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + loader = AutoWeightsLoader( + self, + skip_prefixes=skip_prefixes, + ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 6d2d16d098d4..578d31a851a9 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -550,10 +550,12 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - skip_prefixes = ["rotary_emb.inv_freq"] # Skip lm_head when tie_word_embeddings is True - if self.config.tie_word_embeddings: - skip_prefixes.append("lm_head") + skip_prefixes = (["lm_head"] + if self.config.tie_word_embeddings else None) - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + loader = AutoWeightsLoader( + self, + skip_prefixes=skip_prefixes, + ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index e731f1bfdb9a..581a32325d4c 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -135,11 +135,13 @@ def _get_num_unpadded_features( current_aspect_ratio = current_width / current_height if aspect_ratio > current_aspect_ratio: - new_height = (original_height * current_width) // original_width + new_height = int( + round(original_height * (current_width / original_width), 7)) padding = (current_height - new_height) // 2 current_height = current_height - (2 * padding) else: - new_width = (original_width * current_height) // original_height + new_width = int( + round(original_width * (current_height / original_height), 7)) padding = (current_width - new_width) // 2 current_width = current_width - (2 * padding) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 49f1ecb4be89..7ea759fd59b8 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -116,11 +116,13 @@ def _get_num_unpadded_features( current_aspect_ratio = current_width / current_height if aspect_ratio > current_aspect_ratio: - new_height = (original_height * current_width) // original_width + new_height = int( + round(original_height * (current_width / original_width), 7)) padding = (current_height - new_height) // 2 current_height = current_height - (2 * padding) else: - new_width = (original_width * current_height) // original_height + new_width = int( + round(original_width * (current_height / original_height), 7)) padding = (current_width - new_width) // 2 current_width = current_width - (2 * padding) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 9dffe96fc545..36bab9ee13b1 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -604,8 +604,9 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) - head_dim = getattr(config, "head_dim", - config.hidden_size // config.num_attention_heads) + head_dim = getattr(config, "head_dim", None) + if 
head_dim is None: + head_dim = config.hidden_size // config.num_attention_heads if hasattr(config, "max_model_len") and isinstance( config.max_model_len, int): max_position_embeddings = min(config.max_position_embeddings, @@ -861,8 +862,9 @@ def layer_fn(prefix): cache_shape=self.cache_shape) rope_theta = getattr(config, "rope_theta", 10000) - head_dim = getattr(config, "head_dim", - config.hidden_size // config.num_attention_heads) + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = config.hidden_size // config.num_attention_heads if hasattr(config, "max_model_len") and isinstance( config.max_model_len, int): max_position_embeddings = min(config.max_position_embeddings, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 1968bf9e68af..9bc7a16153e1 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -138,8 +138,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MixtralConfig has an optional head_dim argument - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -482,5 +483,5 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"]) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index b6a0c9ec6fc1..8220200d270c 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -193,8 +193,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MixtralConfig has an optional head_dim argument - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -447,8 +448,5 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["rotary_emb.inv_freq"]), - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 0b5a102ea1f2..d0999e30e1ba 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -158,8 +158,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = 
self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -502,14 +503,5 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=([ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - "rotary_emb.cos_cached", - "rotary_emb.sin_cached" - ]), - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 26ca770d8493..fcb7c619a102 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -382,19 +382,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=([ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - "rotary_emb.cos_cached", - "rotary_emb.sin_cached", - "lm_head.weight" - ] if self.config.tie_word_embeddings else [ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - "rotary_emb.cos_cached", - "rotary_emb.sin_cached" - ]), + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index e4dc0e0cc411..33adacdae5f5 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -314,7 +314,8 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -325,6 +326,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() for name, loaded_weight in weights: if is_pp_missing_parameter(name, self): continue @@ -347,6 +349,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class Olmo2ForCausalLM(nn.Module, SupportsPP): @@ -403,19 +407,7 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader( self, - skip_prefixes=([ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - "rotary_emb.cos_cached", - "rotary_emb.sin_cached", - "lm_head.weight" - ] if self.config.tie_word_embeddings else [ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. 
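[Editor's note, not part of the patch] The same head_dim change recurs across several attention modules in this patch (exaone, granite, minimax, mixtral, nemotron, solar). The point, presumably, is that a config may carry head_dim as an explicit None, in which case getattr's default is never used. A short sketch with made-up config values:

class FakeConfig:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None  # attribute exists but is unset

config = FakeConfig()

# Old pattern: the default is ignored because the attribute is present.
head_dim = getattr(config, "head_dim",
                   config.hidden_size // config.num_attention_heads)
print(head_dim)  # None -> would break the q_size/kv_size math downstream

# New pattern from the diff: fall back explicitly when the value is None.
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
    head_dim = config.hidden_size // config.num_attention_heads
print(head_dim)  # 128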
- "rotary_emb.cos_cached", - "rotary_emb.sin_cached" - ]), + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 9a07f57fd999..6364b89fb837 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -442,8 +442,5 @@ def compute_logits(self, hidden_states: torch.Tensor, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=["rotary_emb.inv_freq"], - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 1ccd1fe1f741..da2a194e6bdf 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -344,14 +344,5 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=([ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - "rotary_emb.cos_cached", - "rotary_emb.sin_cached" - ]), - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index b7bb3c45c633..418ff900ffd5 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1228,9 +1228,7 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: - weights = ((name, data) for name, data in weights - if "lora" not in name) - loader = AutoWeightsLoader(self) + loader = AutoWeightsLoader(self, skip_substrs=["lora"]) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 7f2e9fdf7c4e..d9917c26d1b1 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -660,8 +660,5 @@ def compute_logits(self, hidden_states: torch.Tensor, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["rotary_emb.inv_freq"]), - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 7cf98dc7a4ea..143b9f98b029 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -535,8 +535,5 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["rotary_emb.inv_freq"]), - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index aae5401721df..8a4c2850dda3 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -530,8 +530,5 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["rotary_emb.inv_freq"]), - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py 
b/vllm/model_executor/models/registry.py index c55f7ccd344f..61115afa76d4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -79,6 +79,7 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), + "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 8c78c846302a..fcd17cc1c2ba 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -126,8 +126,9 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr(config, "head_dim", - self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -500,14 +501,5 @@ def compute_logits(self, hidden_states: torch.Tensor, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=([ - "rotary_emb.inv_freq", - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - "rotary_emb.cos_cached", - "rotary_emb.sin_cached" - ]), - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 8c2ad6f19251..86ce813ddf3d 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -338,13 +338,5 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - skip_prefixes=[ - "rotary_emb.inv_freq", "rotary_emb.cos_cached", - "rotary_emb.sin_cached" - ], - ) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 5927afa91f49..f4ba5a8030e5 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -349,8 +349,7 @@ def load_weights(self, weights: Iterable[tuple[str, self, # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. - skip_prefixes=([ - "rotary_emb.inv_freq", "lm_head.weight" - ] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]), + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 5cc501622891..027cd748e9de 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -80,18 +80,30 @@ class AutoWeightsLoader: environment variable ``VLLM_LOGGING_LEVEL=DEBUG``. """ + # Models trained using early version ColossalAI + # may include these tensors in checkpoint. Skip them. 
+ ROTARY_EMBEDS_UNUSED_WEIGHTS = [ + "rotary_emb.inv_freq", + "rotary_emb.cos_cached", + "rotary_emb.sin_cached", + ] + def __init__( self, module: nn.Module, *, skip_prefixes: Optional[list[str]] = None, + skip_substrs: Optional[list[str]] = None, ignore_unexpected_prefixes: Optional[list[str]] = None, ) -> None: super().__init__() self.module = module self.skip_prefixes = skip_prefixes or [] + self.skip_substrs = skip_substrs or [] self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or [] + # update default skip_substrs + self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS def _groupby_prefix( self, @@ -119,7 +131,8 @@ def _get_qualname(self, prefix: str, rest: str) -> str: return ".".join((prefix, rest)) def _can_skip(self, qualname: str) -> bool: - return any(qualname.startswith(p) for p in self.skip_prefixes) + return (any(qualname.startswith(p) for p in self.skip_prefixes) + or any(substr in qualname for substr in self.skip_substrs)) def _can_ignore_unexpected(self, qualname: str) -> bool: return any( @@ -257,6 +270,9 @@ def load_weights( ) -> set[str]: if mapper is not None: weights = mapper.apply(weights) + # filter out weights with first-prefix/substr to skip in name + weights = ((name, weight) for name, weight in weights + if not self._can_skip(name)) autoloaded_weights = set(self._load_module("", self.module, weights)) return autoloaded_weights diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index 53e289370a9f..f6ab72f4e9b8 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -43,7 +43,7 @@ def serialize_item(cls, obj: object) -> bytes: "ndarray", { "dtype": obj.dtype.str, "shape": obj.shape, - "data": obj.data.tobytes(), + "data": obj.tobytes(), }) logger.warning( diff --git a/vllm/outputs.py b/vllm/outputs.py index 6cd60575b00d..05026b569691 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -9,12 +9,15 @@ import torch from typing_extensions import TypeVar, deprecated +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceGroupBase, SequenceStatus) +logger = init_logger(__name__) + @dataclass class CompletionOutput: @@ -122,7 +125,13 @@ def __init__( *, multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, kv_transfer_params: Optional[dict[str, Any]] = None, + # Forward compatibility, code that uses args added in new release can + # still run with older versions of vLLM without breaking. 
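[Editor's note, not part of the patch] The AutoWeightsLoader changes above boil down to a combined prefix/substring filter applied to checkpoint names before loading, with the old ColossalAI rotary caches skipped by default. A self-contained sketch of the _can_skip logic; the weight names are hypothetical.

ROTARY_EMBEDS_UNUSED_WEIGHTS = [
    "rotary_emb.inv_freq",
    "rotary_emb.cos_cached",
    "rotary_emb.sin_cached",
]

def can_skip(qualname, skip_prefixes=(), skip_substrs=()):
    skip_substrs = list(skip_substrs) + ROTARY_EMBEDS_UNUSED_WEIGHTS
    return (any(qualname.startswith(p) for p in skip_prefixes)
            or any(substr in qualname for substr in skip_substrs))

# Rotary caches are filtered out by default, which is why the per-model
# "rotary_emb.*" skip_prefixes entries are removed elsewhere in this patch.
print(can_skip("model.layers.0.self_attn.rotary_emb.inv_freq"))         # True
print(can_skip("lm_head.weight", skip_prefixes=["lm_head."]))           # True
print(can_skip("vision.lora_A.weight", skip_substrs=["lora"]))          # True (phi4mm case)
print(can_skip("model.layers.0.mlp.gate_proj.weight"))                  # False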
+ **kwargs: Any, ) -> None: + if kwargs: + logger.warning_once("RequestOutput: Ignoring extra arguments: %s", + str(kwargs)) self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 2d48af397636..5c0c90972b58 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -9,6 +9,7 @@ import torch from vllm.logger import init_logger +from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend @@ -177,6 +178,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.") os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls) -> bool: logger.warning("Pin memory is not supported on CPU.") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bdee8b2f821d..0bdf15959302 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -311,6 +311,10 @@ def supports_v1(cls, model_config: "ModelConfig") -> bool: def use_custom_allreduce(cls) -> bool: return True + @classmethod + def get_piecewise_backend_cls(cls) -> str: + return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 456b054b2b43..6f7c5a6d3cae 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -7,6 +7,7 @@ from vllm import envs from vllm.logger import init_logger +from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import Platform, PlatformEnum, _Backend @@ -80,6 +81,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.") os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on HPU.") diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index b09e31e9ed46..20284b4e1801 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -478,6 +478,13 @@ def get_cu_count(cls, device_id: int = 0) -> int: """ raise NotImplementedError + @classmethod + def get_piecewise_backend_cls(cls) -> str: + """ + Get piecewise backend class for piecewise graph. 
+ """ + return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 71f7c718cdf9..9cd49fd34804 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -6,6 +6,7 @@ from vllm import envs from vllm.logger import init_logger +from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import Platform, PlatformEnum @@ -51,12 +52,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: assert (vllm_config.lora_config is None), "LoRA is not supported for Neuron backend." - cache_config = vllm_config.cache_config - if cache_config: + if vllm_config.cache_config and vllm_config.model_config: # neuron needs block_size = max_model_len vllm_config.cache_config.block_size = \ vllm_config.model_config.max_model_len # type: ignore + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls) -> bool: logger.warning("Pin memory is not supported on Neuron.") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index c8b86087578d..1685c65ad0b9 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -102,26 +102,42 @@ def on_mi250_mi300() -> bool: @cache -def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, - block_size: int, gqa_ratio: int, - max_seq_len: int, - sliding_window: int) -> bool: +def use_rocm_custom_paged_attention( + qtype: torch.dtype, + head_size: int, + block_size: int, + gqa_ratio: int, + max_seq_len: int, + sliding_window: int, + kv_cache_dtype: str, + alibi_slopes: Optional[torch.Tensor] = None) -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) + ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) - # rocm custom page attention not support on gfx1* # custom paged attn always supported on V0. On V1, requires sliding window # disabled due to observed numerical discrepancy. 
- return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0 - or sliding_window == (-1, -1)) - and (qtype == torch.half or qtype == torch.bfloat16) - and (head_size == 64 or head_size == 128) - and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 - and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - and envs.VLLM_ROCM_USE_AITER)) + if ON_GFX9: + return ((not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) + and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + and envs.VLLM_ROCM_USE_AITER)) + + else: + return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) + and (qtype == torch.half or qtype == torch.bfloat16) + and head_size == 128 and block_size == 16 + and (gqa_ratio >= 3 and gqa_ratio <= 16) + and max_seq_len <= 32768 and alibi_slopes is None + and kv_cache_dtype == "auto" + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) class RocmPlatform(Platform): @@ -362,3 +378,11 @@ def use_custom_allreduce(cls) -> bool: def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( device_id).multi_processor_count + + @classmethod + def is_navi(cls) -> bool: + return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName + + @classmethod + def get_piecewise_backend_cls(cls) -> str: + return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 6c573c1b3635..0173b15697cf 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -9,6 +9,7 @@ from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import Platform, PlatformEnum, _Backend @@ -161,6 +162,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "Forcing --disable_chunked_mm_input.") scheduler_config.disable_chunked_mm_input = True + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on TPU.") diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 225e756cd7ce..785fb6ce1b79 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -5,6 +5,7 @@ import torch from vllm.logger import init_logger +from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import DeviceCapability, Platform, PlatformEnum, _Backend @@ -113,6 +114,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "ray" + if vllm_config.model_config and vllm_config.model_config.use_mla: + logger.info( + "MLA is enabled on a non-GPU platform; forcing 
chunked " + "prefill and prefix caching to be disabled.") + vllm_config.scheduler_config.enable_chunked_prefill = False + vllm_config.scheduler_config.chunked_prefill_enabled = False + vllm_config.scheduler_config.max_num_batched_tokens = max( + vllm_config.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on XPU.") diff --git a/vllm/sequence.py b/vllm/sequence.py index 5aa9ae62f542..f5f9c56a7db2 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -112,12 +112,12 @@ class RequestMetrics: will include model forward, block/sync across workers, cpu-gpu sync time and sampling time. spec_token_acceptance_counts: number of accepted speculative tokens at - each position; the first token is from + each position; the first token is from the target model and is always accepted; - e.g., when it's [10, 8, 4, 2] for a req, + e.g., when it's [10, 8, 4, 2] for a req, it means there were 10 forward passes in - total, and there were 8, 4, 2 accepted - tokens at 1st, 2nd, 3rd speculation step. + total, and there were 8, 4, 2 accepted + tokens at 1st, 2nd, 3rd speculation step. """ arrival_time: float last_token_time: float @@ -714,9 +714,9 @@ class SequenceGroup: trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request. priority: User-defined priority of the request. - draft_size: The number of speculative tokens plus one from the target + draft_size: The number of speculative tokens plus one from the target model; equal to max number of tokens a step can generate - for single-draft speculative decoding but larger than + for single-draft speculative decoding but larger than that for multi-draft SD (currently not supported). """ @@ -1123,7 +1123,7 @@ def __repr__(self) -> str: self.output_embed.shape if self.output_embed is not None else None return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}, " - f"output_embed.shape={output_embed_shape}" + f"output_embed.shape={output_embed_shape}, " f"logprobs={self.logprobs})") def __eq__(self, other: object) -> bool: diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 586d5c7f5e54..377523efefc3 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -52,13 +52,15 @@ def __init__(self, assert self.model is not None, \ "model should not be None when method is eagle" kwargs["architectures"] = [ - f"Eagle{arch}" for arch in self.model.architectures + f"Eagle{arch}" if not arch.startswith("Eagle") \ + else arch for arch in self.model.architectures ] elif method == "eagle3": assert self.model is not None, \ "model should not be None when method is eagle3" kwargs["architectures"] = [ - f"Eagle3{arch}" for arch in self.model.architectures + f"Eagle3{arch}" if not arch.startswith("Eagle3") \ + else arch for arch in self.model.architectures ] else: raise ValueError(f"Invalid method {method}. \ diff --git a/vllm/utils.py b/vllm/utils.py index d8f099995003..333fd7634a8d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -77,6 +77,12 @@ logger = init_logger(__name__) +# This value is chosen to have a balance between ITL and TTFT. Note it is +# not optimized for throughput. 
+DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 +POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 + # Exception strings for non-implemented encoder/decoder scenarios # Reminder: Please update docs/source/features/compatibility_matrix.md diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 7ce39110ac01..56ac834b4d7e 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -53,6 +53,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): # The number of entries in the last page of each request in # the paged kv cache, shape: [batch_size] paged_kv_last_page_len: Optional[torch.Tensor] = None + # The query indptr, shape: [num_decode + 1] + qo_indptr: Optional[torch.Tensor] = None class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): @@ -67,35 +69,41 @@ def __init__(self, runner, kv_cache_spec: AttentionSpec, max_model_len = self.runner.model_config.max_model_len assert max_model_len == 32768,\ "AITER MLA requires max_model_len=32768" - assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ + assert self.runner.block_size == 1, "AITER MLA " \ "only supports block size 1." def _get_paged_kv_tensors( self, block_table: torch.Tensor, seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: - page_size = self.kv_cache_spec.block_size + page_size = self.runner.block_size block_table_bounds = (seq_lens + page_size - 1) // page_size + device = self.runner.device mask = (torch.arange(block_table.size(1), dtype=block_table.dtype, - device=block_table.device).unsqueeze(0) + device=device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table[mask] paged_kv_indptr = torch.cat([ - torch.zeros(1, - dtype=block_table_bounds.dtype, - device=block_table_bounds.device), + torch.zeros(1, dtype=block_table_bounds.dtype, device=device), block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) + qo_indptr = torch.arange(0, + self._num_decodes + 1, + step=1, + dtype=torch.int32, + device=device) + return ( paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, + qo_indptr, ) def _build_decode(self, block_table_tensor: torch.Tensor, @@ -105,6 +113,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, paged_kv_indices, paged_kv_indptr, paged_last_page_len, + qo_indptr, ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens) attn_metadata = AiterMLADecodeMetadata( @@ -112,7 +121,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens=seq_lens, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_last_page_len) + paged_kv_last_page_len=paged_last_page_len, + qo_indptr=qo_indptr) return attn_metadata @@ -137,7 +147,10 @@ def __init__( alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, **mla_args) - + assert (num_heads == 16 or num_heads == 128), ( + f"Aiter MLA only supports 16 or 128 attention heads.\n" + f"Provided {num_heads} attention heads.\n" + "Try adjusting tensor_parallel_size value.") unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap ] @@ -189,7 +202,18 @@ def _forward_decode( kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + if self.num_heads == 16: + # AITER MLA decode kernel only supports + # max_seqlen_q=1 
when using 16 heads. + max_seqlen_qo = 1 + else: + # AITER MLA decode Kernel handles arbitrary + # max_seqlen_q values when using 128 heads. + assert attn_metadata.prefill is not None + max_seqlen_qo = attn_metadata.prefill.max_query_len + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.decode.qo_indptr, max_seqlen_qo, attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_last_page_len) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index da18ece7555a..598fc871110e 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -32,16 +32,9 @@ def create_empty(cls) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" return cls([]) - def get_block_ids(self) -> list[list[int]]: - """ - Converts the KVCacheBlocks instance to block_ids. - - Returns: - list[list[int]]: A two-level list where - * the outer list corresponds to KV cache groups (only 1 group now) - * each inner list contains the block_ids of the blocks in that group - """ - return [[block.block_id for block in self.blocks]] + def get_block_ids(self) -> list[int]: + """Converts the KVCacheBlocks instance to a list of block IDs.""" + return [block.block_id for block in self.blocks] def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" @@ -307,9 +300,9 @@ def get_num_common_prefix_blocks( self, request: Request, num_running_requests: int, - ) -> list[int]: + ) -> int: """Calculate the number of common prefix blocks shared by all requests - in the RUNNING state for each kv cache group. + in the RUNNING state. The function determines this by selecting any request and iterating through its blocks. A block is considered a common prefix block if its @@ -339,14 +332,11 @@ def get_num_common_prefix_blocks( requests in the current step. Returns: - list[int]: The number of common prefix blocks for each kv cache - group. + int: The number of common prefix blocks. """ assert request.status == RequestStatus.RUNNING - return [ - self.single_type_manager.get_num_common_prefix_blocks( - request.request_id, num_running_requests) - ] + return self.single_type_manager.get_num_common_prefix_blocks( + request.request_id, num_running_requests) def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. 
@@ -364,8 +354,10 @@ def take_events(self) -> list[KVCacheEvent]: """ return self.block_pool.take_events() - def get_block_ids(self, request_id: str) -> list[list[int]]: + def get_block_ids(self, request_id: str) -> list[int]: """Get the block ids of a request.""" assert request_id in self.single_type_manager.req_to_blocks - return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id] - ).get_block_ids() + return [ + block.block_id + for block in self.single_type_manager.req_to_blocks[request_id] + ] diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 403b5401be75..27c515835087 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -577,12 +577,14 @@ def create_kv_cache_group_specs( """ kv_cache_groups = [] for layer_names_one_group in grouped_layer_names: - layer_specs = [ - kv_cache_spec[layer_name] for layer_name in layer_names_one_group - ] - merged_layer_spec = layer_specs[0].merge(layer_specs) + layer_spec = kv_cache_spec[layer_names_one_group[0]] + assert all( + kv_cache_spec[layer_name] == layer_spec + for layer_name in layer_names_one_group[1:]), ( + "All layers in the same KV cache group must share the same " + "KVCacheSpec.") kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) + KVCacheGroupSpec(layer_names_one_group, layer_spec)) return kv_cache_groups @@ -681,7 +683,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): head_size=spec.head_size, dtype=spec.dtype, use_mla=spec.use_mla, - sliding_window=spec.sliding_window, ) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 257234430983..24032498e50b 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -26,7 +26,7 @@ class NewRequestData: mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams - block_ids: list[list[int]] + block_ids: list[int] num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -34,7 +34,7 @@ class NewRequestData: def from_request( cls, request: Request, - block_ids: list[list[int]], + block_ids: list[int], ) -> NewRequestData: return cls( req_id=request.request_id, @@ -85,7 +85,7 @@ class CachedRequestData: # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_token_ids: list[int] - new_block_ids: list[list[int]] + new_block_ids: list[int] num_computed_tokens: int @classmethod @@ -94,7 +94,7 @@ def from_request( request: Request, resumed_from_preemption: bool, new_token_ids: list[int], - new_block_ids: list[list[int]], + new_block_ids: list[int], ) -> CachedRequestData: return cls( req_id=request.request_id, @@ -131,9 +131,9 @@ class SchedulerOutput: # E.g., if a request has [0, 1], it could mean the vision encoder needs # to process that the request's 0-th and 1-th images in the current step. scheduled_encoder_inputs: dict[str, list[int]] - # Number of common prefix blocks for all requests in each KV cache group. + # Number of common prefix blocks for all requests. # This can be used for cascade attention. - num_common_prefix_blocks: list[int] + num_common_prefix_blocks: int # Request IDs that are finished in between the previous and the current # steps. 
This is used to notify the workers about the finished requests diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5ad05485e8f3..2152409019b9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -173,7 +173,7 @@ def schedule(self) -> SchedulerOutput: # uses structured decoding. structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, list[list[int]]] = {} + req_to_new_block_ids: dict[str, list[int]] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -345,32 +345,38 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.appendleft(request) continue + num_external_computed_tokens = 0 + load_kv_async = False + # Get already-cached tokens. if num_prealloc_computed_tokens == 0: new_computed_blocks, num_native_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_native_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens = (num_native_computed_tokens + + num_external_computed_tokens) else: # P/D: skip checking prefix cache if loaded from remote kvs. new_computed_blocks = KVCacheBlocks.create_empty() num_native_computed_tokens = 0 - # Get externally-cached tokens if using a KVConnector. - num_external_computed_tokens, load_kv_async = ( - (0, False) if self.connector is None else - self.connector.get_num_new_matched_tokens( - request, num_native_computed_tokens)) - - # Total computed tokens (local + external). - num_computed_tokens = (num_native_computed_tokens + - num_external_computed_tokens + - num_prealloc_computed_tokens) + # Total computed tokens (allocated in prior step). + num_computed_tokens = num_prealloc_computed_tokens encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget # P/D: loading remote KV, do not allocate for new work. if load_kv_async: + assert num_external_computed_tokens > 0 num_new_tokens = 0 # Number of tokens to be scheduled. else: @@ -411,7 +417,8 @@ def schedule(self) -> SchedulerOutput: # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. - if self.connector is not None: + if num_external_computed_tokens: + assert self.connector is not None self.connector.update_state_after_alloc( request, new_computed_blocks + new_blocks, @@ -477,8 +484,7 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. 
- num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) + num_common_prefix_blocks = 0 if self.running: any_request = self.running[0] num_common_prefix_blocks = ( @@ -565,7 +571,7 @@ def _make_cached_request_data( request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: list[list[int]], + new_block_ids: list[int], resumed_from_preemption: bool, ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating @@ -940,9 +946,7 @@ def _connector_finished( """ if self.connector is None: return False, None - assert len(self.kv_cache_config.kv_cache_groups - ) == 1, "KV connector only supports one KV cache group now" - block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: @@ -959,10 +963,9 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: """ if request.request_id not in self.finished_recving_kv_req_ids: return False - assert len(self.kv_cache_config.kv_cache_groups - ) == 1, "KV connector only supports one KV cache group now" + # Now that the blocks are ready, actually cache them. - block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) num_computed_tokens = len(block_ids) * self.block_size if num_computed_tokens == request.num_tokens: num_computed_tokens -= 1 diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0cf2383af1c9..64e472457ee3 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -57,6 +57,10 @@ def __init__(self, executor_fail_callback: Optional[Callable] = None): assert vllm_config.model_config.runner_type != "pooling" + # plugins need to be loaded at the engine/scheduler level too + from vllm.plugins import load_general_plugins + load_general_plugins() + self.vllm_config = vllm_config logger.info("Initializing a V1 LLM engine (v%s) with config: %s", VLLM_VERSION, vllm_config) @@ -697,7 +701,7 @@ def _init_data_parallel(self, vllm_config: VllmConfig): for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)) - self.local_dp_rank = local_dp_rank + self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.current_wave = 0 @@ -770,7 +774,7 @@ def run_busy_loop(self): local_unfinished_reqs) if not self.engines_running: - if self.local_dp_rank == 0: + if self.dp_rank == 0: # Notify client that we are pausing the loop. logger.debug("Wave %d finished, pausing engine loop.", self.current_wave) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 74b226b45424..2061806e6b36 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -50,6 +50,7 @@ def _init_executor(self) -> None: self.is_failed = False self.shutdown_event = threading.Event() self.failure_callback: Optional[FailureCallback] = None + self.io_thread_pool: Optional[ThreadPoolExecutor] = None self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size @@ -107,7 +108,6 @@ def _init_executor(self) -> None: # For pipeline parallel, we use a thread pool for asynchronous # execute_model. 
- self.io_thread_pool: Optional[ThreadPoolExecutor] = None if self.max_concurrent_batches > 1: # Note: must use only 1 IO thread to keep dequeue sequence # from the response queue diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 2747fc7fabd1..4fc0844cd1f4 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,11 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -import copy from dataclasses import dataclass -from typing import Optional import torch -from typing_extensions import Self from vllm.config import VllmConfig from vllm.logger import init_logger @@ -56,16 +53,6 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: """ raise NotImplementedError - @classmethod - def merge(cls, specs: list[Self]) -> Self: - """ - Merge a list of KVCacheSpec objects into a single KVCacheSpec object. - """ - assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), ( - "All layers in the same KV cache group must share the same " - "type_id.") - return copy.deepcopy(specs[0]) - @dataclass class AttentionSpec(KVCacheSpec): @@ -84,16 +71,6 @@ def page_size_bytes(self) -> int: @dataclass class FullAttentionSpec(AttentionSpec): - sliding_window: Optional[int] = None - """ - When hybrid allocator is disabled and the model contains both full - attention layers and sliding window attention layers, sliding - window attention are regarded as full attention in KV cache manager - (blocks are allocated for all tokens), while computed as sliding window - attention in model runner. - In this case, we use FullAttentionSpec and record the sliding window size. - Default to None for not using sliding window attention. - """ @property def type_id(self) -> str: @@ -103,25 +80,6 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len return cdiv(max_model_len, self.block_size) * self.page_size_bytes - @classmethod - def merge(cls, specs: list[Self]) -> Self: - """ - Merge a list of FullAttentionSpec objects into a single - FullAttentionSpec object. 
- """ - merged_spec = super().merge(specs) - sliding_window = set(spec.sliding_window for spec in specs - if spec.sliding_window is not None) - if len(sliding_window) == 0: - merged_spec.sliding_window = None - elif len(sliding_window) == 1: - merged_spec.sliding_window = sliding_window.pop() - else: - raise ValueError( - "All sliding window layers in the same KV cache group " - "must have the same window size.") - return merged_spec - @dataclass class SlidingWindowSpec(AttentionSpec): diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5b84bc1f5ec3..19fb2a2af7dd 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,7 +9,8 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.triton_utils import tl, triton @@ -308,6 +309,9 @@ def load_model(self, target_model: nn.Module) -> None: loaded_weights = self.model.load_weights( loader.get_all_weights(draft_model_config, self.model)) + process_weights_after_loading(self.model, draft_model_config, + target_device) + # share embed_tokens with the target model if needed if get_pp_group().world_size == 1: assert "model.embed_tokens.weight" not in loaded_weights, \ diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 0c3341691509..581d3d9bd11b 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -4,8 +4,6 @@ import torch from vllm.logger import init_logger -from vllm.utils import cdiv -from vllm.v1.kv_cache_interface import KVCacheConfig logger = init_logger(__name__) @@ -98,48 +96,3 @@ def get_cpu_tensor(self) -> torch.Tensor: def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" return self.block_table_np - - -class MultiGroupBlockTable: - """The BlockTables for each KV cache group.""" - - def __init__(self, max_num_reqs: int, max_model_len: int, - max_num_batched_tokens: int, pin_memory: bool, - device: torch.device, kv_cache_config: KVCacheConfig) -> None: - max_num_blocks_per_req = [ - cdiv(max_model_len, g.kv_cache_spec.block_size) - for g in kv_cache_config.kv_cache_groups - ] - self.block_tables = [ - BlockTable(max_num_reqs, max_num_blocks_per_req[i], - max_num_batched_tokens, pin_memory, device) - for i in range(len(kv_cache_config.kv_cache_groups)) - ] - - def append_row(self, block_ids: list[list[int]], row_idx: int) -> None: - for i, block_table in enumerate(self.block_tables): - block_table.append_row(block_ids[i], row_idx) - - def add_row(self, block_ids: list[list[int]], row_idx: int) -> None: - for i, block_table in enumerate(self.block_tables): - block_table.add_row(block_ids[i], row_idx) - - def move_row(self, src: int, tgt: int) -> None: - for block_table in self.block_tables: - block_table.move_row(src, tgt) - - def swap_row(self, src: int, tgt: int) -> None: - for block_table in self.block_tables: - block_table.swap_row(src, tgt) - - def commit(self, num_reqs: int) -> None: - for block_table in self.block_tables: - block_table.commit(num_reqs) - - def clear(self) -> None: - for block_table in self.block_tables: - block_table.clear() - - def __getitem__(self, idx: 
int) -> "BlockTable": - """Returns the BlockTable for the i-th KV cache group.""" - return self.block_tables[idx] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 570de9bddd29..871654fca366 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,11 +11,10 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values -from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import LogprobsTensors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import MultiGroupBlockTable +from vllm.v1.worker.block_table import BlockTable _SAMPLING_EPS = 1e-5 @@ -30,7 +29,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: list[list[int]] + block_ids: list[int] num_computed_tokens: int output_token_ids: list[int] @@ -59,14 +58,15 @@ def __init__( self, max_num_reqs: int, max_model_len: int, + max_num_blocks_per_req: int, max_num_batched_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, - kv_cache_config: KVCacheConfig, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.device = device self.pin_memory = pin_memory @@ -99,13 +99,12 @@ def __init__( self.num_computed_tokens_cpu_tensor.numpy() # Block table. - self.block_table = MultiGroupBlockTable( + self.block_table = BlockTable( max_num_reqs=max_num_reqs, - max_model_len=max_model_len, + max_num_blocks_per_req=max_num_blocks_per_req, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, - kv_cache_config=kv_cache_config, ) # Sampling-related. 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 201796c96ee5..759d69293a32 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,8 +12,6 @@ import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadataBuilder) from vllm.attention.layer import Attention from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import (CompilationLevel, VllmConfig, @@ -34,8 +32,8 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, cdiv, check_use_alibi, - is_pin_memory_available) + GiB_bytes, LayerBlockType, LazyLoader, cdiv, + check_use_alibi, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget @@ -53,7 +51,6 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -105,17 +102,59 @@ def __init__( self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] + # NOTE(woosuk): sliding_window is None for models with interleaved + # attention. Use interleaved_sliding_window instead. + self.sliding_window = model_config.get_sliding_window() + self.interleaved_sliding_window = getattr( + model_config.hf_text_config, "interleaved_sliding_window", None) + self.window_size = (self.sliding_window + or self.interleaved_sliding_window) + self.is_multimodal_model = model_config.is_multimodal_model + self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs # Model-related. 
+ self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) self.num_query_heads = model_config.get_num_attention_heads( parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size + self.attn_backend = get_attn_backend( + self.head_size, + self.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) + if self.attn_backend is None: + error_msg = ( + f"Error with get_attn_backend: {self.head_size=}, " + f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{self.model_config.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 GPUModelRunner.") + + if self.vllm_config.compilation_config.full_cuda_graph: + attn_backend_name = self.attn_backend.__name__ + flash_attn_version = get_flash_attn_version() + if attn_backend_name != "FlashAttentionBackend" or \ + flash_attn_version != 3: + raise ValueError( + f"full_cuda_graph is only supported with " + f"FA3. Current attention backend is {attn_backend_name}, " + f"FlashAttention version is {flash_attn_version}.") + self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support @@ -137,10 +176,8 @@ def __init__( # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] - self.attn_metadata_builders: list[AttentionMetadataBuilder] = [] - self.attn_backends: list[type[AttentionBackend]] = [] # self.kv_cache_config: KVCacheConfig - # self.input_batch: InputBatch # Persistent batch. + # self.attn_metadata_builder: AttentionMetadataBuilder # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -169,6 +206,16 @@ def __init__( # Request states. self.requests: dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=model_config.get_vocab_size(), + ) self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE @@ -263,31 +310,6 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: - """ - Update the order of requests in the batch based on the attention - backend's needs. For example, some attention backends (namely MLA) may - want to separate requests based on if the attention computation will be - compute-bound or memory-bound. - - Args: - scheduler_output: The scheduler output. - - Returns: - True if the batch was reordered, False otherwise. - """ - batch_reordered = self.attn_metadata_builders[0].reorder_batch( - self.input_batch, scheduler_output) - - # For models with multiple KV cache groups, the groups should agree on - # the same order of requests. We ensure this by only allowing the first - # group to reorder the batch and asserting that all other groups do not - # reorder the batch. 
- for i in range(1, len(self.kv_cache_config.kv_cache_groups)): - assert not self.attn_metadata_builders[i].reorder_batch( - self.input_batch, scheduler_output) - return batch_reordered - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -424,8 +446,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. - for i in range(len(self.kv_cache_config.kv_cache_groups)): - req_state.block_ids[i].extend(req_data.new_block_ids[i]) + req_state.block_ids.extend(req_data.new_block_ids) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -483,7 +504,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if removed_req_indices: self.input_batch.condense(removed_req_indices) - batch_reordered = self._may_reorder_batch(scheduler_output) + # Some attention backends (namely MLA) may want to separate requests + # based on if the attention computation will be compute-bound or + # memory-bound. This gives them a hook to do that. + batch_reordered = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() @@ -551,29 +576,21 @@ def _prepare_inputs( torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping for each KV cache group. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - block_size = kv_cache_group_spec.kv_cache_spec.block_size - block_table: BlockTable = self.input_batch.block_table[ - kv_cache_group_id] - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. - block_table_indices = ( - req_indices * block_table.max_num_blocks_per_req + - positions_np // block_size) - block_table_cpu = block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() - block_offsets = positions_np % block_size - np.add( - block_numbers * block_size, - block_offsets, - out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.input_batch.block_table. + slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. 
self.query_start_loc_np[0] = 0 @@ -615,6 +632,10 @@ def _prepare_inputs( attn_metadata: dict[str, FlashAttentionMetadata] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. + # NOTE(Chen): there is exactly one KV cache group that contains all + # attention layers in the model for now, so the current logic for + # getting attn_metadata is not related to kv_cache_group information. + # Will extend this part to support multiple KV cache groups later. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -623,19 +644,15 @@ def _prepare_inputs( if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id], - kv_cache_group_spec.kv_cache_spec, - self.attn_metadata_builders[kv_cache_group_id], + scheduler_output.num_common_prefix_blocks, ) - attn_metadata_i = ( - self.attn_metadata_builders[kv_cache_group_id].build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata)) + attn_metadata_i = self.attn_metadata_builder.build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -673,8 +690,6 @@ def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, num_common_prefix_blocks: int, - kv_cache_spec: KVCacheSpec, - attn_metadata_builder: AttentionMetadataBuilder, ) -> int: """Compute the length of the common prefix for cascade attention. @@ -693,7 +708,7 @@ def _compute_cascade_attn_prefix_len( Returns: int: Length of common prefix in tokens. """ - common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size + common_prefix_len = num_common_prefix_blocks * self.block_size if common_prefix_len == 0: # Common case. return 0 @@ -742,19 +757,15 @@ def _compute_cascade_attn_prefix_len( common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. 
- common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * - kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - assert isinstance(kv_cache_spec, AttentionSpec) - use_cascade = attn_metadata_builder.use_cascade_attention( + common_prefix_len = (common_prefix_len // self.block_size * + self.block_size) + use_cascade = self.attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, - num_kv_heads=kv_cache_spec.num_kv_heads, + num_kv_heads=self.num_kv_heads, use_alibi=self.use_alibi, - use_sliding_window=use_sliding_window, + use_sliding_window=self.window_size is not None, num_sms=self.num_sms, ) return common_prefix_len if use_cascade else 0 @@ -1640,7 +1651,7 @@ def _dummy_run( dtype=np.int32) if skip_attn: - attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None + attn_metadata = None else: query_start_loc = self.query_start_loc[:num_reqs + 1] seq_lens = self.seq_lens[:num_reqs] @@ -1648,19 +1659,13 @@ def _dummy_run( common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, seq_lens=seq_lens) - attn_metadata = {} - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - attn_metadata_i = ( - self.attn_metadata_builders[kv_cache_group_id].build( - num_reqs=num_tokens, - num_actual_tokens=num_tokens, - max_query_len=num_tokens, - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - )) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_tokens, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -1716,6 +1721,10 @@ def _dummy_sampler_run( self, hidden_states: torch.Tensor, ) -> torch.Tensor: + # The dummy hidden states may contain special values, + # like `inf` or `nan`. + # To avoid breaking the sampler, we use a random tensor here instead. + hidden_states = torch.rand_like(hidden_states) logits = self.model.compute_logits(hidden_states, None) num_reqs = logits.size(0) @@ -1890,56 +1899,6 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: - """ - Initialize the attention backends and attention metadata builders. 
- """ - assert len(self.attn_backends) == 0 and len( - self.attn_metadata_builders - ) == 0, "Attention backends are already initialized" - for i, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - if not isinstance(kv_cache_spec, AttentionSpec): - raise NotImplementedError( - "Only AttentionSpec is supported for now.") - attn_backend_i = get_attn_backend( - kv_cache_spec.head_size, - self.dtype, - kv_cache_spec.dtype, - kv_cache_spec.block_size, - self.model_config.is_attention_free, - use_mla=kv_cache_spec.use_mla, - ) - if attn_backend_i is None: - error_msg = ( - f"Error with get_attn_backend: {kv_cache_spec.head_size=}, " - f"{self.dtype=}, {kv_cache_spec.dtype=}, " - f"{kv_cache_spec.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{kv_cache_spec.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 " - "GPUModelRunner.") - - if self.vllm_config.compilation_config.full_cuda_graph: - attn_backend_name = attn_backend_i.__name__ - flash_attn_version = get_flash_attn_version() - if attn_backend_name != "FlashAttentionBackend" or \ - flash_attn_version != 3: - raise ValueError( - f"full_cuda_graph is only supported with " - f"FA3. Current attention backend is " - f"{attn_backend_name}, FlashAttention version is " - f"{flash_attn_version}.") - - block_table_i = self.input_batch.block_table[i] - attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - weakref.proxy(self), kv_cache_spec, block_table_i) - self.attn_backends.append(attn_backend_i) - self.attn_metadata_builders.append(attn_metadata_builder_i) - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -1947,21 +1906,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + if len(kv_cache_config.kv_cache_groups) > 1: + raise NotImplementedError( + "Hybrid models with more than one KV cache type are not " + "supported yet.") self.kv_cache_config = kv_cache_config - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - kv_cache_config=kv_cache_config, - ) - self.initialize_attn_backend(kv_cache_config) kv_caches: dict[str, torch.Tensor] = {} - for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + for kv_cache_group in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: tensor_config = kv_cache_config.tensors[layer_name] @@ -1976,7 +1929,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # the min of all `num_blocks`. Verify it here. 
assert num_blocks >= kv_cache_config.num_blocks if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( + kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype @@ -1996,6 +1949,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self), + kv_cache_config.kv_cache_groups[0].kv_cache_spec, + self.input_batch.block_table) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2da99696445e..b4daf5a34678 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -171,10 +171,19 @@ def __init__( self.kv_caches: list[torch.Tensor] = [] # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} - # self.input_batch: InputBatch # Persistent batch. # Request states. self.requests: dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.vocab_size, + ) # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. @@ -190,7 +199,7 @@ def __init__( self.block_table_cpu = torch.zeros( (self.max_num_reqs, self.max_num_blocks_per_req), - dtype=torch.int32, + dtype=self.input_batch.block_table.get_cpu_tensor().dtype, device="cpu") self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, @@ -515,12 +524,12 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, block_offsets, - out=self.input_batch.block_table[0]. + out=self.input_batch.block_table. slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. @@ -545,15 +554,15 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.position_ids = self.positions_cpu[: padded_total_num_scheduled_tokens].to( self.device) - self.input_batch.block_table[0].slot_mapping_cpu[ + self.input_batch.block_table.slot_mapping_cpu[ total_num_scheduled_tokens:] = _PAD_SLOT_ID slot_mapping = ( - self.input_batch.block_table[0]. + self.input_batch.block_table. 
slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( self.device)) block_tables = self.block_table_cpu[:self.max_num_reqs] block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]) + self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) block_tables = block_tables.to(self.device) query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( self.device) @@ -1254,18 +1263,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: "Hybrid models with more than one KV cache type are not " "supported yet.") - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - kv_cache_config=kv_cache_config, - ) - assert self.block_table_cpu.dtype == self.input_batch.block_table[ - 0].get_cpu_tensor().dtype - kv_caches: dict[str, torch.Tensor] = {} for kv_cache_group in kv_cache_config.kv_cache_groups: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8a294de45c81..53e79adf9aae 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,7 +23,7 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_pp_group +from vllm.distributed import broadcast_tensor_dict, get_pp_group from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, graph_capture) @@ -729,7 +729,10 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( seq_group_metadata, range(positions[0], positions[0] + len(positions))) - if not mm_kwargs: + + # M-RoPE needs mrope_positions even for text-only prompts, so only + # return early on empty mm_kwargs when this is not a prompt. + if not mm_kwargs and not inter_data.is_prompt: return inter_data.multi_modal_kwargs = mm_kwargs @@ -741,12 +744,6 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, video_grid_thw = mm_kwargs.get("video_grid_thw", None) audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", None) - assert ( - image_grid_thw is not None or video_grid_thw is not None - or audio_feature_lengths is not None), ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw' or " - "'audio_feature_lengths'.") second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) @@ -872,7 +869,7 @@ def build(self) -> ModelInputForGPU: """ # Combine and flatten intermediate data. 
input_tokens = list[int]() - inputs_embeds_lst = list[torch.Tensor]() + inputs_embeds_list = list[torch.Tensor]() token_types = list[int]() for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: @@ -880,15 +877,15 @@ def build(self) -> ModelInputForGPU: for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) if inter_data.inputs_embeds is not None: - inputs_embeds_lst.append( + inputs_embeds_list.append( inter_data.inputs_embeds.to( dtype=self.runner.model_config.dtype, device=self.runner.device)) inputs_embeds: Optional[torch.Tensor] - if len(inputs_embeds_lst) == 0: + if len(inputs_embeds_list) == 0: inputs_embeds = None else: - inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to( + inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to( dtype=self.runner.model_config.dtype, device=self.runner.device) assert len(inputs_embeds) == len(input_tokens) @@ -1893,50 +1890,60 @@ def execute_model( logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) - if not self.is_driver_worker: - return [] + if self.is_driver_worker: + if model_input.async_callback is not None: + model_input.async_callback() - if model_input.async_callback is not None: - model_input.async_callback() - - # Sample the next token. - assert isinstance(self.sampler, Sampler) - orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor - if model_input.inputs_embeds is not None: - self.sampler.include_gpu_probs_tensor = True + # Sample the next token. + assert isinstance(self.sampler, Sampler) + orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor + if model_input.inputs_embeds is not None: + self.sampler.include_gpu_probs_tensor = True - output: SamplerOutput = self.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time - and output is not None): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - # If there are multiple workers, we are still tracking the latency - # from the start time of the driver worker to the end time of the - # driver worker. The model forward time will then end up covering - # the communication time as well. - output.model_forward_time = (orig_model_forward_time + - model_forward_time) + output: SamplerOutput = self.sampler( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + # If there are multiple workers, we are still tracking the + # latency from the start time of the driver worker to the end + # time of the driver worker. The model forward time will then + # end up covering the communication time as well. 
+ output.model_forward_time = (orig_model_forward_time + + model_forward_time) if model_input.inputs_embeds is not None: - self.sampler.include_gpu_probs_tensor = \ - orig_include_gpu_probs_tensor - if output.sampled_token_ids is not None: - output.sampled_token_embeds = self.model.get_input_embeddings( - output.sampled_token_ids.squeeze(1)) - - for token_embed, sequence_group_output in zip( - output.sampled_token_embeds, output.outputs): - assert len(sequence_group_output.samples) == 1 - sequence_group_output.samples[0].output_embed = token_embed + if self.is_driver_worker: + sampled = broadcast_tensor_dict( + {"token_ids": output.sampled_token_ids}) + else: + sampled = broadcast_tensor_dict() + if sampled["token_ids"] is not None: + sampled_token_embeds = self.model.get_input_embeddings( + sampled["token_ids"].squeeze(1)) + if self.is_driver_worker: + self.sampler.include_gpu_probs_tensor = \ + orig_include_gpu_probs + + output.sampled_token_embeds = sampled_token_embeds + + for token_embed, sequence_group_output in zip( + output.sampled_token_embeds, output.outputs): + assert len(sequence_group_output.samples) == 1 + sequence_group_output.samples[ + 0].output_embed = token_embed + + if not self.is_driver_worker: + return [] if self.return_hidden_states: # we only need to pass hidden states of most recent token From 3f77bccf31d09b533cb4a8c2808207a2e91288f0 Mon Sep 17 00:00:00 2001 From: Crucifixion-Fxl Date: Fri, 23 May 2025 12:28:10 +0800 Subject: [PATCH 4/5] [Bugfix] Migrate to REGEX Library to prevent catastrophic backtracking Signed-off-by: Crucifixion-Fxl --- vllm/model_executor/model_loader/tensorizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 16e16a0cbba4..56391b4f177d 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -8,6 +8,7 @@ import json import os import time +import threading from collections.abc import Generator from dataclasses import dataclass from functools import partial From 9523077a4fb4e6df6a6153786cd8a89d179df10a Mon Sep 17 00:00:00 2001 From: Crucifixion-Fxl Date: Fri, 23 May 2025 12:33:21 +0800 Subject: [PATCH 5/5] [Bugfix] Migrate to REGEX Library to prevent catastrophic backtracking Signed-off-by: Crucifixion-Fxl --- vllm/model_executor/model_loader/tensorizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 56391b4f177d..6f9408d892c3 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -7,8 +7,8 @@ import io import json import os -import time import threading +import time from collections.abc import Generator from dataclasses import dataclass from functools import partial