Commit fd95e02

afeldman-nm, abf149, and njhill authored
[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: Andrew Feldman <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 660470e commit fd95e02

33 files changed: +3976 -352 lines
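The commit title describes extending the decoder-only ModelRunner via subclassing so that each step can also carry encoder sequences and cross-attention metadata. Below is a minimal, self-contained sketch of that subclassing pattern; every name in it (ModelRunnerSketch, EncoderDecoderModelRunnerSketch, the dict keys) is a hypothetical stand-in for illustration, not the commit's actual code.

# Hypothetical sketch of the subclassing pattern named in the commit title.
# None of these names come from vLLM itself.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ModelInput:
    # Decoder-side inputs prepared on every step.
    input_tokens: List[int] = field(default_factory=list)
    # Encoder-side inputs; populated only for encoder/decoder models.
    encoder_input_tokens: Optional[List[int]] = None
    # One cross-attention block table per sequence (decoder attending
    # to the encoder's KV cache).
    cross_block_tables: Optional[List[List[int]]] = None


class ModelRunnerSketch:
    """Stand-in for a decoder-only model runner."""

    def prepare_model_input(self, seqs) -> ModelInput:
        # Flatten decoder tokens from all sequences in the batch.
        return ModelInput(
            input_tokens=[t for s in seqs for t in s["decoder_tokens"]])


class EncoderDecoderModelRunnerSketch(ModelRunnerSketch):
    """Subclass that additionally prepares encoder sequences and
    cross-attention block tables."""

    def prepare_model_input(self, seqs) -> ModelInput:
        # Reuse the decoder-only preparation from the base class...
        model_input = super().prepare_model_input(seqs)
        # ...then layer on encoder inputs and cross-attention metadata.
        model_input.encoder_input_tokens = [
            t for s in seqs for t in s["encoder_tokens"]
        ]
        model_input.cross_block_tables = [
            s["cross_block_table"] for s in seqs
        ]
        return model_input


runner = EncoderDecoderModelRunnerSketch()
print(runner.prepare_model_input([{
    "decoder_tokens": [0, 1],
    "encoder_tokens": [2, 3, 4],
    "cross_block_table": [7],
}]))

The point of the pattern is that decoder-only input preparation stays in the base class, while encoder-specific state is layered on top by the subclass.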

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 1 deletion
@@ -148,8 +148,9 @@ steps:
   - python3 cpu_offload.py
   - python3 offline_inference_with_prefix.py
   - python3 llm_engine_example.py
-  - python3 llava_example.py
+  - python3 offline_inference_vision_language.py
   - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
+  - python3 offline_inference_encoder_decoder.py

 - label: Models Test # 1hr10min
   source_file_dependencies:
@@ -289,6 +290,7 @@ steps:
   commands:
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
+  - pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py
   - pytest -v -s distributed/test_chunked_prefill_distributed.py
   - pytest -v -s distributed/test_multimodal_broadcast.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
examples/offline_inference_encoder_decoder.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
'''
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
'''

from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.utils import zip_enc_dec_prompt_lists

dtype = "float"

# Create a BART encoder/decoder model instance
llm = LLM(
    model="facebook/bart-large-cnn",
    dtype=dtype,
)

# Get BART tokenizer
tokenizer = llm.llm_engine.get_tokenizer_group()

# Test prompts
#
# This section shows all of the valid ways to prompt an
# encoder/decoder model.
#
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
    prompt="The capital of France is"))

# - Pass a single prompt to the encoder/decoder model
#   (implicitly the encoder input prompt);
#   the decoder input prompt is assumed to be None

single_text_prompt_raw = text_prompt_raw  # Pass a string directly
single_text_prompt = text_prompt  # Pass a TextPrompt
single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt

# - Pass explicit encoder and decoder input prompts within one data structure.
#   Encoder and decoder prompts can both independently be text or tokens,
#   with no requirement that they be the same prompt type. Some example
#   prompt-type combinations are shown below; note that these are not
#   exhaustive.

enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
    # Pass encoder prompt string directly, &
    # pass decoder prompt tokens
    encoder_prompt=single_text_prompt_raw,
    decoder_prompt=single_tokens_prompt,
)
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
    # Pass TextPrompt to encoder, and
    # pass decoder prompt string directly
    encoder_prompt=single_text_prompt,
    decoder_prompt=single_text_prompt_raw,
)
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
    # Pass encoder prompt tokens directly, and
    # pass TextPrompt to decoder
    encoder_prompt=single_tokens_prompt,
    decoder_prompt=single_text_prompt,
)

# - Finally, here's a useful helper function for zipping encoder and
#   decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
#   instances
zipped_prompt_list = zip_enc_dec_prompt_lists(
    ['An encoder prompt', 'Another encoder prompt'],
    ['A decoder prompt', 'Another decoder prompt'])

# - Let's put all of the above example prompts together into one list,
#   which we will pass to the encoder/decoder LLM.
prompts = [
    single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
    enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
] + zipped_prompt_list

print(prompts)

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=20,
)

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
