9 changes: 9 additions & 0 deletions tests/full_tests/ci_gsm8k_tests.sh
@@ -158,6 +158,14 @@ run_gsm8k_granite_test() {
echo "✅ Test with granite-8b passed."
}

# GSM8K on granite-8b (unified attn)
run_gsm8k_granite_test_unified_attn() {
echo "➡️ Testing GSM8K on granite-8b with unified attention..."
VLLM_UNIFIED_ATTN=True VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 \
pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/granite-8b.yaml"
echo "✅ Test with granite-8b unified attention passed."
}

# GSM8K on granite-8b with async scheduling
run_gsm8k_granite_async_test() {
echo "➡️ Testing GSM8K on granite-8b with async scheduling..."
@@ -230,6 +238,7 @@ launch_all_tests() {
run_compressed_w4a16_channelwise_test
run_compressed_w4a16_moe_gidx_test
run_gsm8k_granite_test
run_gsm8k_granite_test_unified_attn
run_gsm8k_granite_async_test
run_gsm8k_deepseek_test
run_gsm8k_qwen3_30b_test
8 changes: 6 additions & 2 deletions vllm_gaudi/extension/unified.py
@@ -357,8 +357,12 @@ def create(total_tokens: torch.tensor, block_table: torch.tensor, block_size: in

group_ids, group_offsets = indices_and_offsets(num_ctx_blocks)
block_ids = fetch_2d(block_table, group_ids, group_offsets)
block_usages = torch.clamp(
    total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1, 1, block_size)
# NOTE(kzawora): Originally, we were clamping
# total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1.
# I'm not sure why the +1 was there, but in non-block-aligned prefix-prefill scenarios
# it made the causal mask not cover the first unused token
# (e.g. with a context of 28 tokens, slot 28 stayed unmasked, making the effective context length 29).
block_usages = torch.clamp(total_tokens.index_select(0, group_ids) - group_offsets * block_size, 1, block_size)

ctx = Context(group_ids, group_offsets, block_ids, block_usages)
all_shapes = [v.shape for v in ctx._values() if torch.is_tensor(v)]
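A minimal standalone sketch of why dropping the +1 matters, using an assumed block_size of 128 and a single request with a 28-token context (values chosen for illustration, not taken from this change):

```python
import torch

# Assumed illustrative values: one request, 28 context tokens, 128-token blocks.
block_size = 128
total_tokens = torch.tensor([28])
group_ids = torch.tensor([0])      # the single context-block group
group_offsets = torch.tensor([0])  # first block of that request

# Old formula: the +1 reports 29 slots as used, so the causal mask leaves
# the first unused slot (index 28) uncovered.
old_usage = torch.clamp(total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1, 1, block_size)

# New formula: exactly 28 slots are reported as used.
new_usage = torch.clamp(total_tokens.index_select(0, group_ids) - group_offsets * block_size, 1, block_size)

print(old_usage.item(), new_usage.item())  # 29 28
```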
12 changes: 12 additions & 0 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -2825,6 +2825,18 @@ def unified_execute_model(
self.input_batch.token_ids_cpu_tensor.index_put_((batch.logits_groups_cpu, batch.new_token_positions_cpu),
                                                 sampled_token_ids_cpu)

######### UPDATE REQUEST STATE WITH GENERATED TOKENS #########
num_reqs = len(selected_req_ids)
for req_id in self.input_batch.req_ids[:num_reqs]:
    req_state = self.requests[req_id]
    i = self.input_batch.req_id_to_index[req_id]
    seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id])
    token_ids = sampled_token_ids[i]
    num_tokens = len(token_ids)
    self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids
    self.input_batch.num_tokens[i] += num_tokens
    req_state.output_token_ids.extend(token_ids)

model_runner_output = ModelRunnerOutput(
req_ids=batch.req_ids_cpu,
req_id_to_index=self.input_batch.req_id_to_index,
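For readability, a simplified, self-contained sketch of the bookkeeping the new loop performs, using plain Python containers and made-up request state rather than the actual input-batch and request objects from the runner:

```python
# Hypothetical standalone illustration of the per-request update above.
requests = {"req-0": {"num_computed_tokens": 10, "output_token_ids": []}}
num_scheduled_tokens = {"req-0": 1}   # tokens scheduled for req-0 this step
token_ids_cpu = [[0] * 32]            # per-request token buffer (row 0 = req-0)
num_tokens = [10]                     # tokens currently held per request
sampled_token_ids = [[42]]            # newly sampled token(s) for req-0

for i, (req_id, state) in enumerate(requests.items()):
    # The sampled tokens land right after the tokens processed so far.
    seq_len = state["num_computed_tokens"] + num_scheduled_tokens[req_id]
    new_ids = sampled_token_ids[i]
    token_ids_cpu[i][seq_len:seq_len + len(new_ids)] = new_ids
    num_tokens[i] += len(new_ids)
    state["output_token_ids"].extend(new_ids)

print(token_ids_cpu[0][11], num_tokens[0])  # 42 11
```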