23 changes: 13 additions & 10 deletions vllm_gaudi/extension/bucketing/common.py
@@ -250,32 +250,35 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, ctx_idx, ctx_range, max_num_bat
     # filter rules for buckets
     # prompt
     def not_over_max_model_len(bs, query, ctx):
-        if not bs * (query + ctx * block_size) <= max_model_len:
+        smaller_than_limit = bs * (query + ctx * block_size) <= max_model_len
+        if not smaller_than_limit:
             omitted_buckets.add(
                 ("condition: bs * (query + ctx * block_size) <= max_model_len", "-> bs, query, ctx: ", bs, query, ctx))
-        return bs * (query + ctx * block_size) <= max_model_len
+        return smaller_than_limit

     def not_over_max_num_batched_tokens(bs, query, ctx):
-        if not bs * query <= max_num_batched_tokens:
+        smaller_than_limit = bs * query <= max_num_batched_tokens
+        if not smaller_than_limit:
             omitted_buckets.add(
                 ("condition: bs * query <= max_num_batched_tokens", "-> bs, query, ctx: ", bs, query, ctx))
-        return bs * query <= max_num_batched_tokens
+        return smaller_than_limit

     def ctx_not_over_max_ctx_for_merged_prefill(bs, query, ctx):
-        if not ctx <= max_num_prefill_seqs * math.ceil(
-                (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size):
+        smaller_than_limit = ctx <= max_num_prefill_seqs * math.ceil(
+            (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)
+        if not smaller_than_limit:
             omitted_buckets.add((
                 "ctx <= max_num_prefill_seqs * math.ceil((max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)",
                 "-> bs, query, ctx: ", bs, query, ctx))
-        return ctx <= max_num_prefill_seqs * math.ceil(
-            (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)
+        return smaller_than_limit

     # decode
     def block_not_greater_than_max_model_len(bs, query, ctx):
-        if not ctx <= bs * math.ceil(max_model_len / block_size):
+        smaller_than_limit = ctx <= bs * math.ceil(max_model_len / block_size)
+        if not smaller_than_limit:
             omitted_buckets.add(
                 ("condition: ctx <= bs * math.ceil(max_model_len / block_size)", "-> bs, query, ctx: ", bs, query, ctx))
-        return ctx <= bs * math.ceil(max_model_len / block_size)
+        return smaller_than_limit

     def batch_size_smaller_than_blocks(bs, query, ctx):
         if not bs <= ctx:
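The refactor in this hunk evaluates each bucket-filter condition once, binds it to smaller_than_limit, and reuses that value for both the omitted-bucket logging and the return value, instead of repeating the comparison expression. Below is a minimal standalone sketch of the same pattern; block_size, max_model_len, omitted_buckets, and the candidate bucket list are assumed example values for illustration, not taken from the file.

# Sketch of the refactored filter pattern with assumed example values.
block_size = 128
max_model_len = 4096
omitted_buckets = set()

def not_over_max_model_len(bs, query, ctx):
    # Evaluate the limit once so the logged entry and the return value cannot diverge.
    smaller_than_limit = bs * (query + ctx * block_size) <= max_model_len
    if not smaller_than_limit:
        omitted_buckets.add(
            ("condition: bs * (query + ctx * block_size) <= max_model_len",
             "-> bs, query, ctx: ", bs, query, ctx))
    return smaller_than_limit

# Hypothetical (bs, query, ctx) candidates filtered by the predicate.
candidates = [(1, 128, 4), (4, 512, 8), (8, 1024, 16)]
kept = [b for b in candidates if not_over_max_model_len(*b)]
# kept == [(1, 128, 4)]; the other two exceed max_model_len and are recorded in omitted_buckets.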