diff --git a/tests/full_tests/ci_gsm8k_tests.sh b/tests/full_tests/ci_gsm8k_tests.sh
index 8c0136c2..18ef9b9c 100644
--- a/tests/full_tests/ci_gsm8k_tests.sh
+++ b/tests/full_tests/ci_gsm8k_tests.sh
@@ -158,6 +158,14 @@ run_gsm8k_granite_test() {
     echo "✅ Test with granite-8b passed."
 }
 
+# GSM8K on granite-8b (unified attn)
+run_gsm8k_granite_test_unified_attn() {
+    echo "➡️ Testing GSM8K on granite-8b with unified attention..."
+    VLLM_UNIFIED_ATTN=True VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 \
+    pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/granite-8b.yaml"
+    echo "✅ Test with granite-8b unified attention passed."
+}
+
 # GSM8K on granite-8b with async scheduling
 run_gsm8k_granite_async_test() {
     echo "➡️ Testing GSM8K on granite-8b with async scheduling..."
@@ -230,6 +238,7 @@ launch_all_tests() {
     run_compressed_w4a16_channelwise_test
     run_compressed_w4a16_moe_gidx_test
     run_gsm8k_granite_test
+    run_gsm8k_granite_test_unified_attn
     run_gsm8k_granite_async_test
     run_gsm8k_deepseek_test
     run_gsm8k_qwen3_30b_test
diff --git a/vllm_gaudi/extension/unified.py b/vllm_gaudi/extension/unified.py
index 6ed2a3ad..b9942e1f 100644
--- a/vllm_gaudi/extension/unified.py
+++ b/vllm_gaudi/extension/unified.py
@@ -357,8 +357,12 @@ def create(total_tokens: torch.tensor, block_table: torch.tensor, block_size: in
 
     group_ids, group_offsets = indices_and_offsets(num_ctx_blocks)
     block_ids = fetch_2d(block_table, group_ids, group_offsets)
-    block_usages = torch.clamp(
-        total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1, 1, block_size)
+    # NOTE(kzawora): Originally, we clamped
+    # total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1.
+    # It's unclear why the +1 was there, but in non-block-aligned prefix-prefill scenarios
+    # it made the causal mask not cover the first unused token
+    # (e.g. with context 28, the 28th slot was unmasked, so the effective context length was 29).
+    block_usages = torch.clamp(total_tokens.index_select(0, group_ids) - group_offsets * block_size, 1, block_size)
 
     ctx = Context(group_ids, group_offsets, block_ids, block_usages)
     all_shapes = [v.shape for v in ctx._values() if torch.is_tensor(v)]
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 4049b09a..3e9adab8 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -2825,6 +2825,18 @@ def unified_execute_model(
         self.input_batch.token_ids_cpu_tensor.index_put_((batch.logits_groups_cpu, batch.new_token_positions_cpu),
                                                          sampled_token_ids_cpu)
 
+        ######### UPDATE REQUEST STATE WITH GENERATED TOKENS #########
+        num_reqs = len(selected_req_ids)
+        for req_id in self.input_batch.req_ids[:num_reqs]:
+            req_state = self.requests[req_id]
+            i = self.input_batch.req_id_to_index[req_id]
+            seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id])
+            token_ids = sampled_token_ids[i]
+            num_tokens = len(token_ids)
+            self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids
+            self.input_batch.num_tokens[i] += num_tokens
+            req_state.output_token_ids.extend(token_ids)
+
         model_runner_output = ModelRunnerOutput(
             req_ids=batch.req_ids_cpu,
             req_id_to_index=self.input_batch.req_id_to_index,
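
For reviewers, here is a minimal standalone sketch (not part of the patch) that reproduces the off-by-one described in the NOTE in `vllm_gaudi/extension/unified.py`. The concrete values (one request with a 28-token context, `block_size=16`) and the helper names `ctx_lens`, `old_usages`, `new_usages` are illustrative assumptions; only the two `torch.clamp` expressions are taken from the diff.

```python
import torch

# One request with 28 context tokens split across ceil(28 / 16) = 2 blocks.
total_tokens = torch.tensor([28])
block_size = 16
group_ids = torch.tensor([0, 0])      # both blocks belong to request 0
group_offsets = torch.tensor([0, 1])  # block index within the request

ctx_lens = total_tokens.index_select(0, group_ids) - group_offsets * block_size

old_usages = torch.clamp(ctx_lens + 1, 1, block_size)  # formula before this patch
new_usages = torch.clamp(ctx_lens, 1, block_size)      # formula after this patch

print(old_usages.tolist(), old_usages.sum().item())  # [16, 13] -> 29 slots left unmasked
print(new_usages.tolist(), new_usages.sum().item())  # [16, 12] -> 28, the real context length
```

With the `+1`, the second block reports 13 used slots, so the causal mask leaves 29 positions uncovered for a 28-token context; dropping the `+1` makes the per-block usages sum to the true context length.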