Conversation

@afierka-intel commented Jan 17, 2025

One line description

Use torch.topk instead of a full sort for top-p/top-k calculation under certain conditions (scalar values of p and k).

Details

Instead of using k for the topk, we use _padded_k, which is strictly larger than k and monotonically non-decreasing across calls.

We need _padded_k > k for cases where the smallest of the top-k values is tied with values beyond position k (for example, for [9,8,8,8,7,7,7] with k=3, keeping all ties gives [9,8,8,8], i.e. 4 values instead of 3).

To prevent excessive recompilations, any time _padded_k needs to grow we increase it by a fixed constant _increment (usually >1); this bucketed approach avoids generating many distinct shapes.

Basic outline

  1. Perform topk with _padded_k.
  2. Find the "kth" value in each row (the smallest value that must survive top-k) and count how many duplicates of it fall inside the window; this is num_duplicates_of_smallest_of_topk.
  3. Take the maximum of that count across rows; this is max_num_duplicates_of_smallest_of_topk.
  4. Check whether _padded_k is big enough to contain max_num_duplicates_of_smallest_of_topk. If not, expand _padded_k and redo the topk with the expanded width.
  5. Mask out the extra values in the _padded_k window.
  6. Move on to top-p. (A rough sketch of these steps in code follows below.)
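
Below is a minimal, self-contained PyTorch sketch of the outline above. It is illustrative rather than the PR's actual code: the class name PaddedTopKTopPSketch, the conservative window-sufficiency check, and the scatter-based top-p step are assumptions made for this example, while _padded_k and _increment follow the description.

import torch


class PaddedTopKTopPSketch:
    """Illustrative top-p/top-k filtering with torch.topk instead of a full sort."""

    def __init__(self, increment: int = 5):
        self._padded_k = 0          # grows monotonically, in buckets of `increment`
        self._increment = increment

    def __call__(self, logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
        vocab = logits.shape[1]
        if k > self._padded_k:
            self._padded_k = min(k + self._increment, vocab)

        while True:
            # 1. top-k with the padded width
            vals, idx = torch.topk(logits, k=self._padded_k, dim=1, sorted=True)
            # 2. kth value per row: the smallest value that must survive top-k,
            #    and how many copies of it landed inside the window
            kth_val = vals[:, k - 1].unsqueeze(1)
            num_dup = (vals == kth_val).sum(dim=1)
            # 3. worst case across the batch
            max_dup = int(num_dup.max())
            # 4. if ties could spill past the window, widen it (bucketed) and retry
            if k - 1 + max_dup < self._padded_k or self._padded_k == vocab:
                break
            self._padded_k = min(self._padded_k + self._increment, vocab)

        # 5. mask out window entries below the kth value (ties at kth place are
        #    kept, matching the sort-based implementation)
        vals = vals.masked_fill(vals < kth_val, float("-inf"))

        # 6. top-p on the small window rather than the full vocab
        probs = torch.softmax(vals, dim=1)
        cum = torch.cumsum(probs, dim=1)
        vals = vals.masked_fill((cum - probs) > p, float("-inf"))

        # scatter the filtered window back into a full-vocab logits tensor
        out = torch.full_like(logits, float("-inf"))
        out.scatter_(1, idx, vals)
        return out

Because _padded_k only ever grows in fixed increments, repeated calls settle into a small set of shapes, which is the bucketing that keeps recompilations bounded.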

Perf benefit

The feature gives a 49% increase in throughput in the case with warmup, and a 30% increase in throughput in the case without warmup.

Extra Notes

  1. Works only for the "scalar" case, though it might be possible to extend the basic idea (topk instead of sort) to the vector case as well. (Outline: find the max k in the top-k vector, then perform topk using that, etc.; some bucketing would likely be needed to prevent dynamic shapes.)
  2. Needs an additional check in _init_sampling_tensors to determine whether it is the scalar case. This has a minor perf hit; ideally the caller would tell us up front that the values are scalars.
  3. Some tradeoffs can be made by using a sufficiently large padded_k from the beginning (still much smaller than the vocab size) and hoping that every case lands within that bucket. Cases that don't are expected to be very, very rare: for example, with padded_k = max(2 * k, 100) and k = 50, the smallest of the top-k values would have to repeat 50 times with the same probability, which is exceedingly unlikely. Trading off this mathematical improbability lets us do only one topk op, which might be faster.
  4. There is a fliplr in the code that could be removed if we could compute a reverse cumsum directly. However, the formula for a reverse cumsum expressed as x + torch.sum(x, dim=1, keepdims=True) - torch.cumsum(x, dim=1) is numerically unstable because of the addition/subtraction: it works well enough on ints and large numbers, but not on small probability values (see the short illustration after the snippets below).
  5. The value of k affects the gains. In the experiment shown above with k=20, throughput increases by around 30%; with k=2000 instead of 20, the gain is 14%. The gain percentage might therefore decrease with increasing k, as topk would asymptotically converge to sort's performance for large k. In practice, however, k is pretty small.
  6. For larger models the gains may be smaller, as those are probably more device-bound.
  7. Cumsum may be taking long. Two alternatives worth trying are below; initial try:
import torch

# Alternative 1: cumulative sum via a masked broadcast-and-reduce.
# mask1[b] is a lower-triangular matrix of ones, so summing y * mask1 over the
# last dim reproduces torch.cumsum(y, dim=1).
y = torch.tensor([[1, 2, 3], [4, 5, 6]])
mask1 = torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]],
                      [[1, 0, 0], [1, 1, 0], [1, 1, 1]]])
torch.sum(y.unsqueeze(1) * mask1, 2)  # tensor([[1, 3, 6], [4, 9, 15]])

or

# Alternative 2: cumulative sum via conv1d with an all-ones kernel over a
# left-zero-padded sequence.
import torch.nn.functional as F
F.conv1d(torch.tensor([[[0, 0, 0, 0, 1, 2, 3, 4, 5]], [[0, 0, 0, 0, 6, 7, 8, 9, 10.0]]]),
         torch.ones([1, 1, 5], dtype=torch.float32))
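
As a side illustration of note 4 above (not part of the PR), the snippet below compares the flip-based reverse cumsum with the quoted closed-form rewrite; on small probability-like values the rewrite loses most of the tail mass to cancellation. The test tensor is an arbitrary example chosen for this illustration.

import torch

# Reverse cumsum two ways: via a double flip (what the fliplr in the code
# emulates) and via the rewrite x + sum(x) - cumsum(x) quoted in note 4.
x = torch.cat([torch.tensor([1.0]), torch.full((10,), 1e-9)]).unsqueeze(0)

rev_flip = torch.flip(torch.cumsum(torch.flip(x, dims=[1]), dim=1), dims=[1])
rev_formula = x + torch.sum(x, dim=1, keepdim=True) - torch.cumsum(x, dim=1)

print(rev_flip[0, 1].item())     # ~1e-8: the correct tail mass after the first token
print(rev_formula[0, 1].item())  # ~1e-9 in float32: cancellation drops most of it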


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@afierka-intel force-pushed the optimize-topp-topk-sampler branch from 28f6461 to 7987523 on January 20, 2025 at 18:01
@afierka-intel marked this pull request as ready for review on January 20, 2025 at 18:13
@afierka-intel (Author)

Hello @andrew, @tmm1, @markmc and @zhouyuan. Can anyone take a look at this little PR please? Thanks in advance!

@afierka-intel (Author)

Hello again @andrew, @tmm1, @markmc and @zhouyuan. Can anyone take a look at this little PR please? Thanks in advance!

1 similar comment

@Swipe4057 mentioned this pull request on Feb 5, 2025
@afierka-intel (Author)

Hello again guys! Can any of the codeowners (@andrew, @tmm1, @markmc and @zhouyuan) take a look at this little PR please? Thanks in advance!

@NickLucche (Collaborator) left a comment

Hey @afierka-intel, thanks a lot for the contribution.
I like the idea of optimizing the scalar case and agree with avoiding the sort, but in that regard @njhill has recently landed a similar torch.topk-based optimization in #15478.

Overall I am not sure whether the overhead needed for an exact implementation (handling "boundary" values) is actually worth it versus an "approximate" fast version.

set VLLM_HANDLE_TOPK_DUPLICATES to 1 or true.
"""
_padded_k = 0
_handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
Collaborator:

I think we use vllm.envs for env vars

self._top_k_scalar = top_k_scalar
self._top_p_scalar = top_p_scalar

self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
Collaborator:

is the increment arbitrary?

return self.should_modify_greedy_probs_inplace


class ApplyToppTopkScalar:
Collaborator:

shouldn't this be an nn.Module too?

self._top_k_scalar = top_k_scalar
self._top_p_scalar = top_p_scalar

self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
Collaborator:

this is not yet used, right?

@njhill (Member) commented Apr 1, 2025

@afierka-intel I think that this may be redundant now that #15478 is merged. That one uses torch.topk but also retains the existing behavior of including all tokens tied for kth place (I'm not sure that matters in any case). It's also more general and works for mixed top-k values.

@xuechendi (Contributor)

> @afierka-intel I think that this may be redundant now that #15478 is merged. That one uses torch.topk but also retains the existing behavior of including all tokens tied for kth place (I'm not sure that matters in any case). It's also more general and works for mixed top-k values.

I think that PR is only for V1; we still want to pursue this PR for v0.

@njhill (Member) commented Apr 7, 2025

@xuechendi perhaps that impl could be reused though? I think the duplicate handling shouldn't be necessary.

@afierka-intel (Author) commented Apr 29, 2025

Hello @njhill. Thank you for all your constructive feedback :) I think your general idea is sound and I am trying to follow that path. However, I found that the two top-p/top-k sampling implementations sometimes give different results. I am debugging the problem and will get back to you soon.

@afierka-intel (Author)

Short update: I found the culprit behind the differing results, a slightly different API between the implementations, and had to adjust the top-p and top-k parameters accordingly. Further analysis showed a significant performance difference between our implementation and the v1 one; I am debugging that now.

@afierka-intel (Author) commented May 5, 2025

Hello again @njhill!

I've finished my performance analysis, and this PR's implementation works better on HPU than backporting the v1 sampler (vllm/v1/sample/ops/topk_topp_sampler.py at main · vllm-project/vllm): it shows a 1.13x gain over the existing v0 topk_topp_sampler and a 1.10x gain over the v1 topk_topp_sampler.

Update: I forgot to check some scenarios. Benchmarking in progress.

@afierka-intel (Author)

Hi @njhill. After a deep performance analysis of all three top-p/top-k samplers (the v0 default method, the v1 default method, and our implementation), we decided to drop our implementation and use the default one. There are some scenarios where our implementation shines, but we prefer the beauty and simplicity of the upstream version. Thank you for your review and your time!

madamczyk-intel pushed a commit to HabanaAI/vllm-fork that referenced this pull request May 15, 2025
As a conclusion of this review (vllm-project#12156) and further comprehensive performance analysis, we decided to bring back the original top-p/top-k implementation.

Signed-off-by: Artur Fierka <[email protected]>