Merged

27 commits
c618001
Rename input data types:
DarkLight1337 Sep 21, 2024
d4a5c21
Rename `from_token_counts` to `from_prompt_token_counts`
DarkLight1337 Sep 21, 2024
74b0f24
Merge branch 'main' into rename-inputs
DarkLight1337 Sep 22, 2024
3fc063b
Merge branch 'main' into rename-inputs
DarkLight1337 Sep 23, 2024
03923ea
Cleanup
DarkLight1337 Sep 23, 2024
eb00f71
Adopt `token_inputs` helper function
DarkLight1337 Sep 23, 2024
b4f4cc1
Cleanup
DarkLight1337 Sep 23, 2024
fc1278f
Merge branch 'main' into rename-inputs
DarkLight1337 Sep 25, 2024
5c0f091
Add backward compatibility for `LLMInputs`
DarkLight1337 Sep 25, 2024
a8cb339
Add backward compatibility for `EncoderDecoderLLMInputs`
DarkLight1337 Sep 25, 2024
ab5a937
rename PromptInputs and inputs with backward compatibility (#8760)
DarkLight1337 Sep 25, 2024
4f5a5d5
Fix doc
DarkLight1337 Sep 26, 2024
d2bebaa
Merge branch 'rename-prompt-redo' into rename-inputs
DarkLight1337 Sep 27, 2024
343e4c9
Update mllama
DarkLight1337 Sep 27, 2024
3f099a1
Fix type annotation
DarkLight1337 Sep 27, 2024
d6d958c
Merge branch 'main' into rename-inputs
DarkLight1337 Sep 27, 2024
c2db5e1
Remove faulty assertion
DarkLight1337 Sep 27, 2024
54bb0cf
Fix processor call
DarkLight1337 Sep 27, 2024
6c2f55f
format
DarkLight1337 Sep 27, 2024
d80ed3b
Merge branch 'main' into rename-inputs
DarkLight1337 Sep 29, 2024
f89b108
Merge branch 'main' into rename-inputs
DarkLight1337 Sep 29, 2024
43b4116
Merge branch 'main' into rename-inputs
DarkLight1337 Sep 30, 2024
c6b0d86
Merge branch 'main' into rename-inputs
DarkLight1337 Oct 4, 2024
1b443fe
Merge branch 'main' into rename-inputs
DarkLight1337 Oct 5, 2024
1c3ce7b
Merge branch 'main' into rename-inputs
DarkLight1337 Oct 7, 2024
be96d9c
Merge branch 'main' into rename-inputs
DarkLight1337 Oct 9, 2024
ea8242f
Merge branch 'main' into rename-inputs
DarkLight1337 Oct 16, 2024
2 changes: 1 addition & 1 deletion docs/source/dev/input_processing/model_inputs_index.rst
@@ -25,7 +25,7 @@ Module Contents
LLM Engine Inputs
-----------------

.. autoclass:: vllm.inputs.LLMInputs
.. autoclass:: vllm.inputs.DecoderOnlyInputs
:members:
:show-inheritance:

28 changes: 13 additions & 15 deletions tests/models/decoder_only/vision_language/test_phi3v.py
@@ -1,12 +1,12 @@
import os
import re
from typing import Callable, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type

import pytest
import torch
from transformers import AutoImageProcessor, AutoTokenizer

from vllm.inputs import InputContext, LLMInputs
from vllm.inputs import InputContext, token_inputs
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
@@ -311,7 +311,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets,
(4, 781),
(16, 2653),
])
def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
num_crops: int, expected_max_tokens: int):
"""Ensure get_max_phi3v_image_tokens handles num_crops properly."""
# NOTE: mm_processor_kwargs on the context in this test is unused, since
@@ -343,8 +343,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
(16, 2653, 1),
(16, 2653, 2),
])
def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
num_crops: int, toks_per_img: int, num_imgs: int):
def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
toks_per_img: int, num_imgs: int):
"""Ensure dummy_data_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
@@ -374,7 +374,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
(16, 1921, 1),
(16, 1921, 2),
])
def test_input_processor_override(input_processor_for_phi3v: Callable,
def test_input_processor_override(input_processor_for_phi3v,
image_assets: _ImageAssets, model: str,
num_crops: int, expected_toks_per_img: int,
num_imgs: int):
@@ -393,16 +393,14 @@ def test_input_processor_override(input_processor_for_phi3v: Callable,
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
images = [image_assets[0].pil_image] * num_imgs

llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})

proc_llm_inputs = input_processor_for_phi3v(
ctx=ctx,
llm_inputs=llm_inputs,
num_crops=num_crops,
)
processed_inputs = input_processor_for_phi3v(ctx,
inputs,
num_crops=num_crops)

# Ensure we have the right number of placeholders per num_crops size
img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs
12 changes: 6 additions & 6 deletions tests/models/decoder_only/vision_language/test_qwen.py
@@ -5,7 +5,7 @@
import torch
from PIL.Image import Image

from vllm.inputs import InputContext, LLMInputs
from vllm.inputs import InputContext, token_inputs
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size

@@ -71,12 +71,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen,
"""Happy cases for image inputs to Qwen's multimodal input processor."""
prompt = "".join(
[f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
inputs = LLMInputs(
inputs = token_inputs(
prompt=prompt,
# When processing multimodal data for a multimodal model, the qwen
# input processor will overwrite the provided prompt_token_ids with
# the image prompts
prompt_token_ids=None,
prompt_token_ids=[],
multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
)
proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
@@ -134,9 +134,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen,
trust_remote_code=True)
prompt = "Picture 1: <img></img>\n"
prompt_token_ids = tokenizer.encode(prompt)
inputs = LLMInputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_data)
inputs = token_inputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_data)
# Should fail since we have too many or too few dimensions for embeddings
with pytest.raises(ValueError):
input_processor_for_qwen(qwen_vl_context, inputs)
18 changes: 9 additions & 9 deletions tests/multimodal/test_processor_kwargs.py
@@ -5,7 +5,7 @@
import pytest
import torch

from vllm.inputs import InputContext, LLMInputs
from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
from vllm.inputs.registry import InputRegistry
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@@ -31,7 +31,7 @@ def use_processor_mock():
"""Patches the internal model input processor with an override callable."""

def custom_processor(ctx: InputContext,
llm_inputs: LLMInputs,
inputs: DecoderOnlyInputs,
*,
num_crops=DEFAULT_NUM_CROPS):
# For testing purposes, we don't worry about the llm inputs / return
@@ -84,7 +84,7 @@ def test_default_processor_is_a_noop():
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(ctx.model_config)
proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
proc_outputs = processor(inputs=proc_inputs)
assert proc_inputs is proc_outputs

@@ -125,9 +125,9 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(
LLMInputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
assert num_crops_val == expected_seq_count


@@ -154,9 +154,9 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
processor = dummy_registry.create_input_processor(ctx.model_config)
# Should filter out the inference time kwargs
num_crops_val = processor(
LLMInputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=mm_processor_kwargs))
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=mm_processor_kwargs))
assert num_crops_val == DEFAULT_NUM_CROPS


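For reference, the updated tests above exercise the renamed processor input type: custom input processors now take `DecoderOnlyInputs` (formerly `LLMInputs`), and test inputs are built with the `token_inputs` helper instead of constructing the TypedDict directly. A minimal sketch of that pattern (the processor name and the `num_crops` default here are illustrative, not part of this change):

```python
from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs


def my_input_processor(ctx: InputContext,
                       inputs: DecoderOnlyInputs,
                       *,
                       num_crops: int = 16) -> DecoderOnlyInputs:
    # Illustrative processor using the renamed input type; it simply
    # passes the inputs through unchanged.
    return inputs


# Inputs are now assembled via the helper instead of LLMInputs(...).
example_inputs = token_inputs(prompt_token_ids=[],
                              prompt="",
                              mm_processor_kwargs={"num_crops": 4})
```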
10 changes: 5 additions & 5 deletions vllm/engine/llm_engine.py
@@ -29,8 +29,8 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptType)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -635,7 +635,7 @@ def _verify_args(self) -> None:
def _add_processed_request(
self,
request_id: str,
processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
@@ -1855,8 +1855,8 @@ def is_encoder_decoder_model(self):
def is_embedding_model(self):
return self.model_config.is_embedding_model

def _validate_model_inputs(self, inputs: Union[LLMInputs,
EncoderDecoderLLMInputs]):
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
37 changes: 29 additions & 8 deletions vllm/inputs/__init__.py
@@ -1,7 +1,8 @@
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry

INPUT_REGISTRY = InputRegistry()
@@ -19,8 +20,11 @@
"PromptType",
"SingletonPrompt",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",
"TokenInputs",
"token_inputs",
"SingletonInputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
@@ -31,14 +35,31 @@


def __getattr__(name: str):
if name == "PromptInput":
import warnings
import warnings

if name == "PromptInput":
msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.")

warnings.warn(DeprecationWarning(msg), stacklevel=2)

return PromptType

if name == "LLMInputs":
msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
"The original name will be removed in an upcoming version.")

warnings.warn(DeprecationWarning(msg), stacklevel=2)

return DecoderOnlyInputs

if name == "EncoderDecoderLLMInputs":
msg = (
"EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
"The original name will be removed in an upcoming version.")

warnings.warn(DeprecationWarning(msg), stacklevel=2)

return EncoderDecoderInputs

raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
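To illustrate the effect of the module-level `__getattr__` shim above: importing one of the old names still works, but emits a `DeprecationWarning` and resolves to the renamed type. A small sketch, assuming a vLLM build that includes this change:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    # The old name is no longer a real module attribute; access falls
    # through to __getattr__, which warns and returns DecoderOnlyInputs.
    from vllm.inputs import DecoderOnlyInputs, LLMInputs

    assert LLMInputs is DecoderOnlyInputs
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```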
78 changes: 62 additions & 16 deletions vllm/inputs/data.py
@@ -1,5 +1,5 @@
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Tuple, Union)
Optional, Tuple, Union, cast)

from typing_extensions import NotRequired, TypedDict, TypeVar

@@ -51,7 +51,7 @@ class TokensPrompt(TypedDict):

SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single LLM input:
Set of possible schemas for a single prompt:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
@@ -120,13 +120,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""


class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.

This specifies the data required for decoder-only models.
"""
class TokenInputs(TypedDict):
"""Represents token-based inputs."""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""

@@ -150,7 +145,40 @@ class LLMInputs(TypedDict):
"""


class EncoderDecoderLLMInputs(LLMInputs):
def token_inputs(
prompt_token_ids: List[int],
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values."""
inputs = TokenInputs(prompt_token_ids=prompt_token_ids)

if prompt is not None:
inputs["prompt"] = prompt
if multi_modal_data is not None:
inputs["multi_modal_data"] = multi_modal_data
if mm_processor_kwargs is not None:
inputs["mm_processor_kwargs"] = mm_processor_kwargs

return inputs


SingletonInputs = TokenInputs
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""

DecoderOnlyInputs = TokenInputs
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


class EncoderDecoderInputs(TokenInputs):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
@@ -204,11 +232,12 @@ def zip_enc_dec_prompts(
be zipped with the encoder/decoder prompts.
"""
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
if isinstance(mm_processor_kwargs, Dict):
mm_processor_kwargs = cast(Dict[str, Any], {})
if isinstance(mm_processor_kwargs, dict):
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
mm_processor_kwargs)
build_explicit_enc_dec_prompt(
encoder_prompt, decoder_prompt,
cast(Dict[str, Any], mm_processor_kwargs))
for (encoder_prompt,
decoder_prompt) in zip(enc_prompts, dec_prompts)
]
@@ -229,14 +258,31 @@ def to_enc_dec_tuple_list(


def __getattr__(name: str):
if name == "PromptInput":
import warnings
import warnings

if name == "PromptInput":
msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.")

warnings.warn(DeprecationWarning(msg), stacklevel=2)

return PromptType

if name == "LLMInputs":
msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
"The original name will be removed in an upcoming version.")

warnings.warn(DeprecationWarning(msg), stacklevel=2)

return DecoderOnlyInputs

if name == "EncoderDecoderLLMInputs":
msg = (
"EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
"The original name will be removed in an upcoming version.")

warnings.warn(DeprecationWarning(msg), stacklevel=2)

return EncoderDecoderInputs

raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
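As a usage note, the new `token_inputs` helper above only sets the optional keys that are actually provided, so downstream code can use `in` checks instead of testing for `None`. A minimal sketch with made-up token IDs:

```python
from vllm.inputs import token_inputs

inputs = token_inputs(prompt_token_ids=[1, 2, 3], prompt="Hello")

assert inputs["prompt_token_ids"] == [1, 2, 3]
assert inputs["prompt"] == "Hello"
# multi_modal_data and mm_processor_kwargs were None, so the keys are absent.
assert "multi_modal_data" not in inputs
assert "mm_processor_kwargs" not in inputs
```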