1 change: 0 additions & 1 deletion vllm/core/scheduler.py
@@ -1596,7 +1596,6 @@ def schedule(
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
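With the merged multi-modal processor, per-request `mm_processor_kwargs` are consumed while the prompt is preprocessed, so `schedule()` no longer needs to copy them onto each scheduled sequence group. A hedged usage sketch of where those kwargs still enter the engine; the model name, chat template, and `num_crops` value below are illustrative placeholders, not part of this PR:

```python
# Sketch: per-request multi-modal processor kwargs travel with the prompt and
# are consumed during input preprocessing, not at schedule time.
from PIL import Image

from vllm import LLM

llm = LLM(model="microsoft/Phi-3.5-vision-instruct",  # placeholder model
          trust_remote_code=True)
outputs = llm.generate({
    "prompt": "<|user|>\n<|image_1|>\nDescribe the image.<|end|>\n<|assistant|>\n",
    "multi_modal_data": {"image": Image.new("RGB", (448, 448))},
    "mm_processor_kwargs": {"num_crops": 4},  # model-specific; example value only
})
print(outputs[0].outputs[0].text)
```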
3 changes: 1 addition & 2 deletions vllm/engine/async_llm_engine.py
@@ -493,12 +493,11 @@ async def add_request_async(
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)

preprocessed_inputs = await self.input_preprocessor.preprocess_async(
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)

if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
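The two-stage flow in `add_request_async` collapses into one: the standalone input-processor hop that ran after preprocessing is removed, and `preprocess_async` now returns the final `ProcessorInputs` directly. A schematic paraphrase of the before/after (the helper names and `engine` parameter below are illustrative, not the engine's real layout):

```python
# Schematic only: before vs. after removing the legacy input-processor stage.

async def build_inputs_before(engine, prompt, lora_request=None,
                              prompt_adapter_request=None):
    preprocessed = await engine.input_preprocessor.preprocess_async(
        prompt,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )
    # Legacy second stage: a per-model processor looked up via INPUT_REGISTRY.
    return engine.input_processor(preprocessed)


async def build_inputs_after(engine, prompt, lora_request=None,
                             prompt_adapter_request=None):
    # Single stage: the preprocessor already yields the final ProcessorInputs.
    return await engine.input_preprocessor.preprocess_async(
        prompt,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )
```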
13 changes: 3 additions & 10 deletions vllm/engine/llm_engine.py
@@ -29,8 +29,7 @@
from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputs)
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@@ -213,7 +212,6 @@ def __init__(
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
@@ -274,11 +272,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
self.tokenizer,
mm_registry)

self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
self.model_config)

self.model_executor = executor_class(vllm_config=vllm_config, )
self.model_executor = executor_class(vllm_config=vllm_config)

if self.model_config.runner_type != "pooling":
self._initialize_kv_caches()
@@ -762,12 +756,11 @@ def add_request(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))

preprocessed_inputs = self.input_preprocessor.preprocess(
processed_inputs = self.input_preprocessor.preprocess(
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)

self._add_processed_request(
request_id=request_id,
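The synchronous engine gets the same treatment: `LLMEngine.__init__` drops the `input_registry` parameter (passing it should now raise a `TypeError`) and no longer builds a per-model `input_processor`, and `add_request` calls the preprocessor in a single step. The public entry points are unchanged; a minimal sketch, assuming a placeholder model:

```python
# Minimal sketch of the unchanged public API around this refactor.
from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# add_request still takes a prompt and sampling params; internally it now runs
# InputPreprocessor.preprocess() alone, with no separate input_processor stage.
engine.add_request(
    request_id="req-0",
    prompt="The capital of France is",
    params=SamplingParams(max_tokens=8),
)

while engine.has_unfinished_requests():
    for output in engine.step():
        if output.finished:
            print(output.outputs[0].text)
```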
8 changes: 3 additions & 5 deletions vllm/inputs/__init__.py
@@ -2,10 +2,9 @@

from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonInputsAdapter, SingletonPrompt,
TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)

@@ -27,7 +26,6 @@
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"SingletonInputsAdapter",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
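For downstream code, the visible effect is a smaller public surface: `SingletonInputsAdapter` disappears from `vllm.inputs`, while the remaining names keep their import paths. A quick sketch of what still works and what now fails:

```python
# Still valid after this PR: surviving public names of vllm.inputs.
from vllm.inputs import (ProcessorInputs, SingletonInputs, TextPrompt,
                         TokensPrompt, token_inputs)

# No longer valid: the adapter was removed along with its export.
# from vllm.inputs import SingletonInputsAdapter   # raises ImportError
```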
161 changes: 2 additions & 159 deletions vllm/inputs/data.py
@@ -1,17 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast

import torch
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
from typing_extensions import NotRequired, TypedDict, TypeVar

if TYPE_CHECKING:
from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict)
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs


class TextPrompt(TypedDict):
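The module header shrinks accordingly: `torch`, `dataclass`, `cached_property`, and `assert_never` were only needed by `SingletonInputsAdapter` (removed further down), and the `TYPE_CHECKING` block now imports both names straight from `vllm.multimodal.inputs`. Reconstructed from the lines kept above, the new header is roughly:

```python
# Approximate new import header of vllm/inputs/data.py after this change.
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast

from typing_extensions import NotRequired, TypedDict, TypeVar

if TYPE_CHECKING:
    from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
```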
@@ -147,46 +141,11 @@ class TokenInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""

multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""

Reviewer comment on lines -150 to -155:
Not needed in this PR, but now that we have a clear set of different branches for input processing, we should probably add some documentation under each input/prompt type to indicate when they will be used.

multi_modal_inputs: NotRequired["MultiModalKwargs"]
"""
Optional multi-modal inputs to pass to the model,
if the model supports it.
"""

multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
"""
Placeholder ranges for the multi-modal data.
"""

multi_modal_hashes: NotRequired[list[str]]
"""
The hashes of the multi-modal data.
"""

mm_processor_kwargs: NotRequired[dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""


def token_inputs(
prompt_token_ids: list[int],
token_type_ids: Optional[list[int]] = None,
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[list[str]] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values."""
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
@@ -195,16 +154,6 @@ def token_inputs(
inputs["prompt"] = prompt
if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids
if multi_modal_data is not None:
inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None:
inputs["multi_modal_inputs"] = multi_modal_inputs
if multi_modal_hashes is not None:
inputs["multi_modal_hashes"] = multi_modal_hashes
if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None:
inputs["mm_processor_kwargs"] = mm_processor_kwargs

return inputs
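After this hunk, `token_inputs()` is a thin constructor for token-level data only; multi-modal data, hashes, placeholders, and processor kwargs no longer pass through `TokenInputs`. A small sketch of the surviving call shape:

```python
# Sketch: the reduced helper accepts only token-level fields.
from vllm.inputs import token_inputs

inputs = token_inputs(
    prompt_token_ids=[101, 2023, 2003, 1037, 7953, 102],
    prompt="this is a prompt",
)

assert inputs["type"] == "token"
assert "multi_modal_data" not in inputs  # the key was removed from TokenInputs
```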

@@ -237,112 +186,6 @@ class EncoderDecoderInputs(TypedDict):
:class:`vllm.sequence.Sequence`.
"""


@dataclass
class SingletonInputsAdapter:
"""
Unified interface to access the components of :class:`SingletonInputs`.
"""
inputs: SingletonInputs

@cached_property
def prompt(self) -> Optional[str]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("prompt")

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def prompt_token_ids(self) -> list[int]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("prompt_token_ids", [])

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def token_type_ids(self) -> list[int]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return inputs.get("token_type_ids", [])

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def prompt_embeds(self) -> Optional[torch.Tensor]:
inputs = self.inputs

if inputs["type"] == "token" or inputs["type"] == "multimodal":
return None

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_data(self) -> "MultiModalDataDict":
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_data", {})

if inputs["type"] == "multimodal":
return inputs.get("mm_kwargs", {})

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]:
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_inputs", {})

if inputs["type"] == "multimodal":
return inputs.get("mm_kwargs", {})

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_hashes(self) -> list[str]:
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_hashes", [])

if inputs["type"] == "multimodal":
# only the case when we use MultiModalInputs
return inputs.get("mm_hashes", []) # type: ignore[return-value]

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("multi_modal_placeholders", {})

if inputs["type"] == "multimodal":
return inputs.get("mm_placeholders", {})

assert_never(inputs) # type: ignore[arg-type]

@cached_property
def mm_processor_kwargs(self) -> dict[str, Any]:
inputs = self.inputs

if inputs["type"] == "token":
return inputs.get("mm_processor_kwargs", {})

if inputs["type"] == "multimodal":
return {}

assert_never(inputs) # type: ignore[arg-type]


ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
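Callers that relied on the adapter's cached properties now branch on the TypedDict's `"type"` tag themselves. A hedged sketch of the equivalent direct access; the `"multimodal"`-branch key (`mm_placeholders`) is taken from the adapter code removed above and belongs to `MultiModalInputs`:

```python
# Sketch: direct field access replaces SingletonInputsAdapter.
from typing import Optional

from vllm.inputs import SingletonInputs


def prompt_text(inputs: SingletonInputs) -> Optional[str]:
    # Both the "token" and "multimodal" variants store the text under "prompt".
    return inputs.get("prompt")


def mm_placeholders(inputs: SingletonInputs) -> dict:
    if inputs["type"] == "token":
        return {}  # token-only inputs no longer carry multi-modal fields
    if inputs["type"] == "multimodal":
        return inputs.get("mm_placeholders", {})
    raise AssertionError(f"unexpected inputs type: {inputs['type']}")
```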