Commit 8f24193

stephanie-wang (Stephanie Wang) authored and committed
[Core] Refactor Worker and ModelRunner to consolidate control plane communication (vllm-project#5408)
Signed-off-by: Stephanie Wang <[email protected]>
Signed-off-by: Stephanie <[email protected]>
Co-authored-by: Stephanie <[email protected]>
Signed-off-by: Alvant <[email protected]>
1 parent dbaf5b7 commit 8f24193

29 files changed: +1108 −575 lines
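
The core of the refactor is that control-plane state now travels between workers as a single serializable ModelInput dataclass rather than ad-hoc tuples. Below is a minimal sketch of the intended driver/worker flow. The as_broadcastable_tensor_dict / from_broadcasted_tensor_dict calls come from the new test in this commit; the broadcast_tensor_dict helper, the model_runner.attn_backend attribute, and the execute_model(model_input, kv_caches) signature are assumptions about surrounding vLLM code that is not shown in this diff.

# Sketch only: names not shown in this diff are assumptions, not verified
# against this commit.
from vllm.distributed import broadcast_tensor_dict  # assumed helper
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

def execute_step(model_runner, seq_group_metadata_list, kv_caches,
                 is_driver_worker: bool):
    if is_driver_worker:
        # The driver prepares the full model input locally...
        model_input = model_runner.prepare_model_input(seq_group_metadata_list)
        # ...and broadcasts only a flat tensor dict to the other workers.
        broadcast_tensor_dict(model_input.as_broadcastable_tensor_dict(), src=0)
    else:
        # Non-driver workers rebuild the same ModelInput from the dict,
        # using the attention backend to reconstruct AttentionMetadata.
        tensor_dict = broadcast_tensor_dict(src=0)
        model_input = (
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict, attn_backend=model_runner.attn_backend))
    return model_runner.execute_model(model_input, kv_caches)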

tests/worker/test_model_input.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+import dataclasses
+from typing import List, Tuple, Type
+
+import torch
+
+from vllm.attention import AttentionMetadata
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.worker.embedding_model_runner import (
+    ModelInputForGPUWithPoolingMetadata)
+from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+
+class MockAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
+    @staticmethod
+    def get_impl_cls():
+        raise NotImplementedError
+
+    @staticmethod
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return AttentionMetadata
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        pass
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        pass
+
+
+def test_model_runner_input():
+    sampling_metadata = SamplingMetadata(
+        ["seq_group"],
+        "selected_token_indices",
+        "categorized_sample_indices",
+        "num_prompts",
+    )
+    attn_metadata = AttentionMetadata(
+        num_prefills=1,
+        num_prefill_tokens=2,
+        num_decode_tokens=3,
+        slot_mapping=torch.zeros(1),
+    )
+    model_input = ModelInputForGPUWithSamplingMetadata(
+        input_tokens=torch.ones(10),
+        input_positions=torch.ones(10),
+        sampling_metadata=sampling_metadata,
+        attn_metadata=attn_metadata)
+
+    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)
+
+    # Test round trip serialization.
+    tensor_dict = model_input.as_broadcastable_tensor_dict()
+    attn_backend = MockAttentionBackend()
+    received_model_input = (
+        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
+            tensor_dict, attn_backend=attn_backend))
+    # Check that received copy has correct values.
+    assert isinstance(received_model_input,
+                      ModelInputForGPUWithSamplingMetadata)
+    assert received_model_input.input_tokens is not None
+    assert (
+        received_model_input.input_tokens == model_input.input_tokens).all()
+    assert received_model_input.input_positions is not None
+    assert (received_model_input.input_positions == model_input.input_positions
+            ).all()
+    assert received_model_input.multi_modal_kwargs is None
+    assert (received_model_input.multi_modal_kwargs ==
+            model_input.multi_modal_kwargs)
+    assert received_model_input.lora_requests is None
+    assert received_model_input.lora_requests == model_input.lora_requests
+    assert received_model_input.lora_mapping is None
+    assert received_model_input.lora_mapping == model_input.lora_mapping
+    for field in dataclasses.fields(AttentionMetadata):
+        assert getattr(received_model_input.attn_metadata, field.name,
+                       None) == getattr(attn_metadata, field.name, None)
+    # For sampling metadata, only selected_token_indices is copied.
+    assert (received_model_input.sampling_metadata.selected_token_indices ==
+            sampling_metadata.selected_token_indices)
+    assert received_model_input.sampling_metadata.seq_groups is None
+
+
+def test_embedding_model_runner_input():
+    pooling_metadata = PoolingMetadata(
+        seq_groups=[[0]],
+        seq_data={},
+        prompt_lens=[1],
+    )
+    attn_metadata = AttentionMetadata(
+        num_prefills=1,
+        num_prefill_tokens=2,
+        num_decode_tokens=3,
+        slot_mapping=torch.zeros(1),
+    )
+    model_input = ModelInputForGPUWithPoolingMetadata(
+        input_tokens=torch.ones(10),
+        input_positions=torch.ones(10),
+        pooling_metadata=pooling_metadata,
+        attn_metadata=attn_metadata)
+
+    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
+
+    # Test round trip serialization.
+    tensor_dict = model_input.as_broadcastable_tensor_dict()
+    attn_backend = MockAttentionBackend()
+    received_model_input = (
+        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
+            tensor_dict, attn_backend=attn_backend))
+    # Check that received copy has correct values.
+    assert isinstance(received_model_input,
+                      ModelInputForGPUWithPoolingMetadata)
+    assert received_model_input.input_tokens is not None
+    assert (
+        received_model_input.input_tokens == model_input.input_tokens).all()
+    assert received_model_input.input_positions is not None
+    assert (received_model_input.input_positions == model_input.input_positions
+            ).all()
+    assert received_model_input.multi_modal_kwargs is None
+    assert (received_model_input.multi_modal_kwargs ==
+            model_input.multi_modal_kwargs)
+    assert received_model_input.lora_requests is None
+    assert received_model_input.lora_requests == model_input.lora_requests
+    assert received_model_input.lora_mapping is None
+    assert received_model_input.lora_mapping == model_input.lora_mapping
+    for field in dataclasses.fields(AttentionMetadata):
+        assert getattr(received_model_input.attn_metadata, field.name,
+                       None) == getattr(attn_metadata, field.name, None)
+    # Pooling metadata is not broadcast.
+    assert received_model_input.pooling_metadata is None

tests/worker/test_model_runner.py

Lines changed: 30 additions & 27 deletions
@@ -61,12 +61,13 @@ def test_prepare_prompt(batch_size):
         expected_selected_token_indices.append(selected_token_start_idx +
                                                seq_len - 1)
         selected_token_start_idx += seq_len
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
     input_tokens = model_input.input_tokens
     input_positions = model_input.input_positions
     attn_metadata = model_input.attn_metadata
     return_seq_lens = model_input.seq_lens
-    slot_mapping = model_input.slot_mapping
+    slot_mapping = attn_metadata.slot_mapping
     assert return_seq_lens == seq_lens
     assert len(slot_mapping) == len(input_tokens)

@@ -174,10 +175,11 @@ def test_prepare_decode_cuda_graph(batch_size):
         assert seq_group_metadata.token_chunk_size == 1
         seq_group_metadata_list.append(seq_group_metadata)

-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
     input_tokens, input_positions, attn_metadata, slot_mapping = (
         model_input.input_tokens, model_input.input_positions,
-        model_input.attn_metadata, model_input.slot_mapping)
+        model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)

     expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))

@@ -259,32 +261,29 @@ def test_empty_seq_group():
         enforce_eager=False,
     )
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
-    input_tokens, input_positions, attn_metadata, slot_mapping = (
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
+    input_tokens, input_positions, attn_metadata = (
         model_input.input_tokens,
         model_input.input_positions,
         model_input.attn_metadata,
-        model_input.slot_mapping,
     )
-    assert len(input_tokens) == 0
-    assert len(input_positions) == 0
+    assert input_tokens is None
+    assert input_positions is None
     assert attn_metadata is None
-    assert len(slot_mapping) == 0
-
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
-    (input_tokens, input_positions, attn_metadata, slot_mapping,
-     return_seq_lens) = (
-         model_input.input_tokens,
-         model_input.input_positions,
-         model_input.attn_metadata,
-         model_input.slot_mapping,
-         model_input.seq_lens,
-     )
-    assert len(input_tokens) == 0
-    assert len(input_positions) == 0
+
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
+    (input_tokens, input_positions, attn_metadata, return_seq_lens) = (
+        model_input.input_tokens,
+        model_input.input_positions,
+        model_input.attn_metadata,
+        model_input.seq_lens,
+    )
+    assert input_tokens is None
+    assert input_positions is None
     assert attn_metadata is None
-    assert len(slot_mapping) == 0
-    assert len(return_seq_lens) == 0
+    assert return_seq_lens is None


 @pytest.fixture

@@ -353,8 +352,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         seq_group_metadata_list.append(seq_group_metadata)
         decode_metadata_list.append(seq_group_metadata)

-    (input_tokens, input_positions, attn_metadata, _, _, _,
-     _) = model_runner.prepare_input_tensors(seq_group_metadata_list)
+    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
+    (input_tokens, input_positions, attn_metadata) = (
+        model_input.input_tokens,
+        model_input.input_positions,
+        model_input.attn_metadata,
+    )

     prefill_meta_actual = attn_metadata.prefill_metadata
     decode_meta_actual = attn_metadata.decode_metadata

@@ -367,7 +370,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):

     # Verify attn metadata is consistent. We don't need to test individual
     # values here because they are tested above.
-    attn_metadata = model_runner._prepare_model_input(
+    attn_metadata = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list).attn_metadata

     for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),

vllm/attention/backends/abstract.py

Lines changed: 5 additions & 1 deletion
@@ -21,9 +21,13 @@ def get_impl_cls() -> Type["AttentionImpl"]:

     @staticmethod
     @abstractmethod
-    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError

+    @classmethod
+    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
+        return cls.get_metadata_cls()(*args, **kwargs)
+
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
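
With make_metadata now a concrete classmethod that instantiates whatever get_metadata_cls() returns, each backend only has to name its metadata class, as the per-backend diffs below show. A minimal sketch follows; MyBackend is a hypothetical backend not part of this commit, and the metadata fields used are the ones exercised by the new test above.

# Hypothetical backend: only get_metadata_cls() is overridden here; the other
# abstract methods (get_name, get_impl_cls, get_kv_cache_shape, ...) are
# omitted for brevity.
from typing import Type

import torch

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend


class MyBackend(AttentionBackend):

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        # A real backend returns its own AttentionMetadata subclass.
        return AttentionMetadata


# make_metadata() is inherited from AttentionBackend and simply does
# cls.get_metadata_cls()(*args, **kwargs):
metadata = MyBackend.make_metadata(num_prefills=1,
                                   num_prefill_tokens=2,
                                   num_decode_tokens=3,
                                   slot_mapping=torch.zeros(1))
assert isinstance(metadata, AttentionMetadata)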

vllm/attention/backends/blocksparse_attn.py

Lines changed: 2 additions & 2 deletions
@@ -90,8 +90,8 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
         return BlocksparseFlashAttentionImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
-        return BlocksparseFlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return BlocksparseFlashAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/flash_attn.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
-        return FlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/flashinfer.py

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,8 @@ def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
-        return FlashInferMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashInferMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/ipex_attn.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
         return IpexAttnBackendImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "IpexAttnMetadata":
-        return IpexAttnMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
+        return IpexAttnMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/pallas.py

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
         return PallasAttentionBackendImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "PallasMetadata":
-        return PallasMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["PallasMetadata"]:
+        return PallasMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
-        return ROCmFlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return ROCmFlashAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/torch_sdpa.py

Lines changed: 2 additions & 2 deletions
@@ -31,8 +31,8 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
-        return TorchSDPAMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return TorchSDPAMetadata

     @staticmethod
     def get_kv_cache_shape(
