diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index a9bc89a2..17d3f639 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -250,32 +250,35 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, ctx_idx, ctx_range, max_num_bat # filter rules for buckets # prompt def not_over_max_model_len(bs, query, ctx): - if not bs * (query + ctx * block_size) <= max_model_len: + smaller_than_limit = bs * (query + ctx * block_size) <= max_model_len + if not smaller_than_limit: omitted_buckets.add( ("condition: bs * (query + ctx * block_size) <= max_model_len", "-> bs, query, ctx: ", bs, query, ctx)) - return bs * (query + ctx * block_size) <= max_model_len + return smaller_than_limit def not_over_max_num_batched_tokens(bs, query, ctx): - if not bs * query <= max_num_batched_tokens: + smaller_than_limit = bs * query <= max_num_batched_tokens + if not smaller_than_limit: omitted_buckets.add( ("condition: bs * query <= max_num_batched_tokens", "-> bs, query, ctx: ", bs, query, ctx)) - return bs * query <= max_num_batched_tokens + return smaller_than_limit def ctx_not_over_max_ctx_for_merged_prefill(bs, query, ctx): - if not ctx <= max_num_prefill_seqs * math.ceil( - (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size): + smaller_than_limit = ctx <= max_num_prefill_seqs * math.ceil( + (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size) + if not smaller_than_limit: omitted_buckets.add(( "ctx <= max_num_prefill_seqs * math.ceil((max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)", "-> bs, query, ctx: ", bs, query, ctx)) - return ctx <= max_num_prefill_seqs * math.ceil( - (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size) + return smaller_than_limit # decode def block_not_greater_than_max_model_len(bs, query, ctx): - if not ctx <= bs * math.ceil(max_model_len / block_size): + smaller_than_limit = ctx <= bs * math.ceil(max_model_len / block_size) + if not smaller_than_limit: omitted_buckets.add( ("condition: ctx <= bs * math.ceil(max_model_len / block_size)", "-> bs, query, ctx: ", bs, query, ctx)) - return ctx <= bs * math.ceil(max_model_len / block_size) + return smaller_than_limit def batch_size_smaller_than_blocks(bs, query, ctx): if not bs <= ctx: