Skip to content

Conversation

afeldman-nm
Copy link
Contributor

@afeldman-nm afeldman-nm commented Aug 19, 2024

NOTE: because SamplerOutput was refactored from sequence.py into sampler.py, this PR makes a large number of trivial import changes to all files which use SamplerOutput (i.e. many of the model files.) The only files you need to review are:

  • tests/multi-step/test_correctness.py
  • vllm/model_executor/layers/sampler.py
  • vllm/worker/multi_step_model_runner.py

The purpose of this PR is to achieve parity between the logprobs UX for multi-step scheduling, and the logprobs UX for single-step scheduling.

Multi-step scheduling defers output pythonization until after N gpu-side steps have completed.

For multi-step scheduling, currently the following pythonization steps are skipped entirely rather than being deferred:

  • The output of _sample_with_torch() in sample.py is never pythonized
  • Sampler.forward() in sample.py never pythonizes logprobs (this depends on the pythonized output of _sample_with_torch())
  • _pythonize_sampler_output() in multi_step_model_runner.py never utilizes the logprobs computed by the sampler, passing a dummy value back for logprobs instead.

Note that during profile_run(), the statements above do not apply because pythonization is performed immediately and not deferred. It is only during actual prefill and decode steps where the above difficulties with deferred pythonization apply.

This PR adds logprobs to the multi-step UX by:

  • (During profile_run()): having _pythonize_sampler_output() extract the already-computed logprobs and inject them into the pythonized output
  • (During prefill/decode): having _pythonize_sampler_output() invoke the deferred pythonization steps & inject the resultant logprobs into the pythonized output

Related to #7528 (only the logprobs support issue)

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@SolitaryThinker
Copy link
Contributor

fyi #7000 has been merged to main

@afeldman-nm afeldman-nm force-pushed the afeldman-nm/logprobs branch from 43cc1e3 to 642d31b Compare August 21, 2024 13:42
@afeldman-nm
Copy link
Contributor Author

FYI, the PR as it stands moves SamplerOutput from sequences.py to sampler.py in order to get around issues with circular imports (this also is a sensible reorganization on principle, in my opinion.) Since every model imports sampler, there are a large number of files changed - however most of these are trivial changes to the sampler import.

@afeldman-nm
Copy link
Contributor Author

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 21, 2024
@afeldman-nm
Copy link
Contributor Author

@SolitaryThinker @comaniac

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM. Comments are mostly for coding style and refactoring.
Also can you run some benchmarks to make sure there's no performance regression with logprobs is not requested?

@afeldman-nm
Copy link
Contributor Author

afeldman-nm commented Aug 27, 2024

Hi @comaniac FYI I addressed the perf regression by skipping logprobs pythonization entirely in the scenario where no logprobs are required. I updated the benchmark results in-place so you can see them. In my opinion the new results are at parity with the baseline results. A few of these metrics have significant variance & I do not think that any of the metrics are worse by a significant amount.

@afeldman-nm afeldman-nm force-pushed the afeldman-nm/logprobs branch from 81fedc1 to fbb75b7 Compare August 28, 2024 03:07
@comaniac comaniac enabled auto-merge (squash) August 28, 2024 03:58
auto-merge was automatically disabled August 29, 2024 19:48

Head branch was pushed to by a user without write access

@simon-mo simon-mo merged commit 428dd14 into vllm-project:main Aug 30, 2024
16 checks passed
@afeldman-nm
Copy link
Contributor Author

afeldman-nm commented Sep 3, 2024

(Note: the method used to configure logprobs via CLI was not implemented correctly for this test so the test must be repeated with a correct method. The impact of async output proc is probably reflected accurately in the results below, but not the impact of logprobs)

Perf test (8x multi-step + {async output proc., logprobs}, main (ec26653)):

Server (TP=1, PP=1):

python  -m vllm.entrypoints.openai.api_server     --model meta-llama/Meta-Llama-3-8B --swap-space 16     --disable-log-requests  --use-v2-block-manager   --tensor-parallel-size 1  --pipeline-parallel-size 1  --gpu-memory-utilization 0.90  --num-scheduler-steps 8 {--disable-async-output-proc} --port 9000  --worker-use-ray 2>&1 | tee output.log

Client:

python benchmarks/benchmark_serving.py     --backend vllm  {--logprobs 5}    --model meta-llama/Meta-Llama-3-8B     --dataset-name sharegpt     --dataset-path ../sharegpt.json --port 9000

Testbench:

1x NVIDIA H100 80GB

Benchmarking result (+ async output proc, - logprobs)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  25.06     
Total input tokens:                      215196    
Total generated tokens:                  128433    
Request throughput (req/s):              39.91     
Input token throughput (tok/s):          8587.63   
Output token throughput (tok/s):         5125.26   
---------------Time to First Token----------------
Mean TTFT (ms):                          5242.34   
Median TTFT (ms):                        4397.99   
P99 TTFT (ms):                           17020.34  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.01     
Median TPOT (ms):                        34.54     
P99 TPOT (ms):                           307.28    
---------------Inter-token Latency----------------
Mean ITL (ms):                           239.08    
Median ITL (ms):                         245.91    
P99 ITL (ms):                            658.52    
==================================================

Benchmarking result (- async output proc, - logprobs)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  28.48     
Total input tokens:                      215196    
Total generated tokens:                  128453    
Request throughput (req/s):              35.12     
Input token throughput (tok/s):          7556.83   
Output token throughput (tok/s):         4510.76   
---------------Time to First Token----------------
Mean TTFT (ms):                          5380.46   
Median TTFT (ms):                        4032.20   
P99 TTFT (ms):                           18331.10  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          60.71     
Median TPOT (ms):                        41.93     
P99 TPOT (ms):                           441.90    
---------------Inter-token Latency----------------
Mean ITL (ms):                           284.85    
Median ITL (ms):                         274.92    
P99 ITL (ms):                            693.19    
==================================================

Benchmarking result (+ async output proc, + logprobs)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  53.42     
Total input tokens:                      215196    
Total generated tokens:                  128159    
Request throughput (req/s):              18.72     
Input token throughput (tok/s):          4028.60   
Output token throughput (tok/s):         2399.22   
---------------Time to First Token----------------
Mean TTFT (ms):                          7341.89   
Median TTFT (ms):                        4569.37   
P99 TTFT (ms):                           27881.26  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          103.21    
Median TPOT (ms):                        83.13     
P99 TPOT (ms):                           444.63    
---------------Inter-token Latency----------------
Mean ITL (ms):                           587.77    
Median ITL (ms):                         601.43    
P99 ITL (ms):                            1291.48   
==================================================

Benchmarking result (- async output proc, + logprobs)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  53.95     
Total input tokens:                      215196    
Total generated tokens:                  128083    
Request throughput (req/s):              18.53     
Input token throughput (tok/s):          3988.47   
Output token throughput (tok/s):         2373.91   
---------------Time to First Token----------------
Mean TTFT (ms):                          7401.77   
Median TTFT (ms):                        4576.37   
P99 TTFT (ms):                           28999.51  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          111.32    
Median TPOT (ms):                        81.79     
P99 TPOT (ms):                           607.83    
---------------Inter-token Latency----------------
Mean ITL (ms):                           582.65    
Median ITL (ms):                         625.54    
P99 ITL (ms):                            1101.44   
==================================================

@afeldman-nm afeldman-nm deleted the afeldman-nm/logprobs branch September 5, 2024 15:27
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants