[Core] Optimize topp/topk calculation in sampler #12156
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Signed-off-by: Artur Fierka <[email protected]>

Force-pushed from 28f6461 to 7987523
Hey @afierka-intel, thanks a lot for the contribution.
I like the idea of optimizing the scalar case and agree with avoiding the sort, but in that regard @njhill has recently landed a similar torch.topk-based optimization in #15478.
Overall I am not sure whether the overhead needed to provide an exact implementation (handling "boundary" values) is actually worth it versus an "approximate" fast version.
    set VLLM_HANDLE_TOPK_DUPLICATES to 1 or true.
    """
    _padded_k = 0
    _handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
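For readers of the excerpt, here is a hedged sketch of how such a boolean switch is commonly parsed; the default value and exact parsing below are assumptions, since the line is cut off in the diff:

```python
import os

# hypothetical parsing; the PR's actual default is not visible in the excerpt
_handle_duplicates = os.getenv('VLLM_HANDLE_TOPK_DUPLICATES',
                               'false').lower() in ('1', 'true')
```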
I think we use `vllm.envs` for env vars
    self._top_k_scalar = top_k_scalar
    self._top_p_scalar = top_p_scalar

    self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
is the increment arbitrary?
    return self.should_modify_greedy_probs_inplace


class ApplyToppTopkScalar:
shouldn't this be a `nn.Module` too?
    self._top_k_scalar = top_k_scalar
    self._top_p_scalar = top_p_scalar

    self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
this is not yet used, right?
@afierka-intel I think that this may be redundant now that #15478 is merged. That one uses torch.topk but also retains the existing behavior of including all tokens tied for kth place (I'm not sure that matters in any case). It's also more general and works for mixed top-k values.
I think that PR is only for V1; we still want to pursue this PR for v0.
@xuechendi perhaps that impl could be reused though? I think the duplicate handling shouldn't be necessary.
Hello @njhill. Thank you for all your constructive feedback :) I think your general idea is OK and I'm trying to follow that path. However, I found that the two topp/topk sampling methods sometimes give different results. Debugging the problem; I'll get back to you soon.
Short update. Found the culprit causing the different results: a slightly different API between the implementations, so I had to adjust the topp and topk parameters. Further perf analysis showed a significant perf difference between our implementation and the v1 one. Debugging the problem.
Hello again @njhill! I've finished my performance analysis and the PR implementation works better on HPU than a backport of the v1 vllm/v1/sample/ops/topk_topp_sampler.py: performance shows a 1.13x gain over the existing v0 topk_topp_sampler and a 1.10x gain over the v1 topk_topp_sampler. Update: I forgot to check some scenarios. Benchmarking in progress.
Hi @njhill. After a deep perf analysis of all 3 top-p/top-k samplers, i.e. the v0 default method, the v1 default method, and our implementation, we decided to remove our implementation and use the default one. There are some scenarios where our implementation shines, but we prefer the beauty and simplicity of the upstream version. Thank you for your review and your time!
As a conclusion of this review (vllm-project#12156) and further comprehensive performance analysis, we decided to bring back the original top-p/top-k implementation.

Signed-off-by: Artur Fierka <[email protected]>
One line description
Use topk instead of sort for topp/topk calculation under certain conditions (scalar value of p and k).
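For orientation, here is a minimal sketch (assumed helper names, not the PR's actual code) contrasting a sort-based filter with a topk-based one for a single scalar k and p shared by all rows. The naive topk version does not handle ties at the k-th value, which is exactly the boundary case the padding logic described below addresses:

```python
import torch


def top_k_top_p_via_sort(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    """Baseline: full ascending sort of the vocab dimension (O(V log V))."""
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)
    # top-k: drop everything strictly below the k-th largest value
    top_k_mask = sorted_logits < sorted_logits[:, -k].unsqueeze(-1)
    sorted_logits = sorted_logits.masked_fill(top_k_mask, float("-inf"))
    # top-p: drop the low-probability tail whose cumulative mass stays <= 1 - p
    probs_sum = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    top_p_mask = probs_sum <= (1.0 - p)
    top_p_mask[:, -1] = False  # always keep the most likely token
    sorted_logits = sorted_logits.masked_fill(top_p_mask, float("-inf"))
    return torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)


def top_k_top_p_via_topk(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    """Optimized: only the k largest values are materialized (O(V + k log k))."""
    vals, idx = logits.topk(k, dim=-1)  # descending order
    probs = vals.softmax(dim=-1)
    cum_probs = probs.cumsum(dim=-1)
    # keep tokens while the mass accumulated *before* them is still below p
    keep = (cum_probs - probs) < p
    vals = vals.masked_fill(~keep, float("-inf"))
    out = torch.full_like(logits, float("-inf"))
    return out.scatter_(-1, idx, vals)
```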
Details
Instead of using `k` for topk, we use `_padded_k`, which is strictly larger than `k` and monotonically non-decreasing. We need `_padded_k > k` for cases where the smallest of the top-k values also occurs beyond position k (for example, for [9,8,8,8,7,7,7] with k=3 we get [9,8,8,8], which is 4 values instead of 3). To prevent excessive recompilations, any time we require an expansion of `_padded_k` we increment it by a fixed constant `_increment` (usually >1), giving a bucketed approach that avoids a proliferation of shapes.

Basic outline
1. Perform topk with `_padded_k` candidates instead of `k`.
2. Count, per row, how many times the smallest of the top-k values occurs (`num_duplicates_of_smallest_of_topk`).
3. Take the maximum of these counts over the batch (`max_num_duplicates_of_smallest_of_topk`).
4. Check whether `_padded_k` is big enough to contain `max_num_duplicates_of_smallest_of_topk`; if not, expand `_padded_k` and redo the topk with the expanded `_padded_k`.
5. Use the (possibly expanded) `_padded_k` candidates for the rest of the top-p/top-k computation.

A minimal sketch of this expansion loop is given right after this list.
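The sketch below illustrates the outline; it is not the PR's actual code, and the helper name `topk_with_duplicates` and the `_increment` value are assumptions:

```python
import torch

_increment = 16  # bucket size used to grow _padded_k (illustrative value)


def topk_with_duplicates(logits: torch.Tensor, k: int, padded_k: int):
    """Top-k that never drops tokens tied with the k-th value."""
    padded_k = max(padded_k, k)
    while True:
        padded_k = min(padded_k, logits.shape[-1])
        vals, idx = logits.topk(padded_k, dim=-1)  # descending order
        # k-th largest value per row, and how often it occurs in the whole row
        kth_val = vals[:, k - 1].unsqueeze(-1)
        num_duplicates_of_smallest_of_topk = (logits == kth_val).sum(dim=-1)
        max_dup = int(num_duplicates_of_smallest_of_topk.max())
        # worst case: k-1 strictly larger values plus every copy of the k-th value
        needed = k - 1 + max_dup
        if needed <= padded_k or padded_k == logits.shape[-1]:
            return vals, idx, padded_k
        # grow in buckets of _increment so only a few distinct shapes are compiled
        padded_k += ((needed - padded_k + _increment - 1) // _increment) * _increment
```

The returned `padded_k` would be remembered across calls (hence "monotonically non-decreasing"), so later batches start from the already-expanded size.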
Perf benefit
The feature gives a 49% increase in throughput in the case with warmup, and a 30% increase in throughput in the case without warmup.
Extra Notes
- We check in `_init_sampling_tensors` whether it is the scalar case. This has a minor perf hit; ideally, if someone could tell us that it's a scalar from the top itself...
- There is a `fliplr` in the code which could be removed if we could compute a reverse cumsum. However, the formula for reverse cumsum as expressed here, `x + torch.sum(x, dim=1, keepdims=True) - torch.cumsum(x, dim=1)`, is numerically unstable because of the addition/subtraction. It works well enough on ints and large numbers, but not on small probability values. (See the sketch after this list.)
- `k` affects the gains we might get from this. For example, in the experiment shown above with k=20, throughput increases around 30%. But if k=2000 instead of 20, the gain is 14%. Thus the gain % might decrease with increasing k, as topk would probably converge asymptotically to sort's performance for large k. However, in practice k is pretty small.
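To illustrate the `fliplr` note above, here is a small standalone sketch (not from the PR) comparing the flip-based reverse cumulative sum with the closed-form expression:

```python
import torch

x = torch.tensor([[0.3, 0.25, 0.2, 0.15, 0.1]])

# flip-based reverse (right-to-left) cumulative sum: two flips, but numerically tame
rev_cumsum_flip = torch.fliplr(torch.cumsum(torch.fliplr(x), dim=1))

# closed form: current element + row total - prefix sum
rev_cumsum_formula = x + torch.sum(x, dim=1, keepdim=True) - torch.cumsum(x, dim=1)

# Both agree on this small example, but the closed form relies on cancelling
# large partial sums, which loses precision for small float probabilities.
print(rev_cumsum_flip)  # tensor([[1.0000, 0.7000, 0.4500, 0.2500, 0.1000]])
print(torch.allclose(rev_cumsum_flip, rev_cumsum_formula))  # True
```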