diff --git a/tests/v1/kv_connector/cpu_kv_integration/__init__.py b/tests/v1/kv_connector/cpu_kv_integration/__init__.py new file mode 100644 index 000000000000..f5bc53998307 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# Empty init file to mark directory as Python package diff --git a/tests/v1/kv_connector/cpu_kv_integration/online_test.sh b/tests/v1/kv_connector/cpu_kv_integration/online_test.sh new file mode 100644 index 000000000000..c1eee5cfc69f --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/online_test.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [model]" + exit 1 +fi + +if [[ $# -eq 1 ]]; then + echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" + MODEL="meta-llama/Llama-3.1-8B-Instruct" +else + echo "Using model: $2" + MODEL=$2 +fi + + +if [[ $1 == "prefiller" ]]; then + # Prefiller listens on port 8100 + #UCX_TLS=cuda_ipc,cuda_copy,tcp \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=0 \ + vllm serve $MODEL \ + --port 8100 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"CPUConnector","kv_role":"kv_producer","kv_connector_extra_config": {"host": "localhost", "port": "54321", "size": 40}}' + + +elif [[ $1 == "decoder" ]]; then + # Decoder listens on port 8200 + #UCX_TLS=cuda_ipc,cuda_copy,tcp \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=1 \ + vllm serve $MODEL \ + --port 8200 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"CPUConnector","kv_role":"kv_consumer","kv_connector_extra_config": {"host": "localhost", "port": "54321", "size": 40}}' + + +else + echo "Invalid role: $1" + echo "Should be either prefiller, decoder" + exit 1 +fi diff --git a/tests/v1/kv_connector/cpu_kv_integration/output.txt b/tests/v1/kv_connector/cpu_kv_integration/output.txt new file mode 100644 index 000000000000..09cf415402dc --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/output.txt @@ -0,0 +1,4 @@ +Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi 
Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hi Hello, my name is [ +Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey 
Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey Hey The capital of France is Paris +Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello 
Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello 
Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Hello Your name is not +How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How 
How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How How The capital of China is Beijing diff --git a/tests/v1/kv_connector/cpu_kv_integration/output_decode.txt b/tests/v1/kv_connector/cpu_kv_integration/output_decode.txt new file mode 100644 index 000000000000..e688555677d2 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/output_decode.txt @@ -0,0 +1,4 @@ + Hi Hi Hi Hi Hello, my name is [Your Name] and I am a [Your +Hi Hi The capital of France is Paris. The capital of France is Paris. The +Hello Hello Hello Your name is not in the list. Please check your email for +ow How The capital of China is Beijing. Beijing is a city in northern China. 
diff --git a/tests/v1/kv_connector/cpu_kv_integration/run_nsys.sh b/tests/v1/kv_connector/cpu_kv_integration/run_nsys.sh new file mode 100644 index 000000000000..dae01d303952 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/run_nsys.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +if [[ $1 == "decoder" ]]; then +echo "Running decoder" +CUDA_VISIBLE_DEVICES=7 nsys profile \ + --trace=cuda,nvtx,osrt \ + --gpu-metrics-devices=cuda-visible \ + --python-sampling=true \ + --trace-fork-before-exec=true \ + --output=decoder \ + --force-overwrite=true \ + python3 toy_decode.py + +else +echo "Running prefiller" +CUDA_VISIBLE_DEVICES=6 nsys profile \ + --trace=cuda,nvtx,osrt \ + --gpu-metrics-devices=cuda-visible \ + --python-sampling=true \ + --trace-fork-before-exec=true \ + --output=prefiller \ + --force-overwrite=true \ + python3 toy_example.py +fi diff --git a/tests/v1/kv_connector/cpu_kv_integration/temptest.py b/tests/v1/kv_connector/cpu_kv_integration/temptest.py new file mode 100644 index 000000000000..8a133ae7d902 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/temptest.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_cpu_utils import ( + NixlKVSender) + +sender = NixlKVSender(1024 * 1024 * 1024) + +sender.close() diff --git a/tests/v1/kv_connector/cpu_kv_integration/test_cpu_connector_kernels.py b/tests/v1/kv_connector/cpu_kv_integration/test_cpu_connector_kernels.py new file mode 100644 index 000000000000..06525cbba7eb --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/test_cpu_connector_kernels.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from vllm.distributed.kv_transfer.kv_connector.v1.cpu_connector import ( + d2h_page_copy, h2d_copy_leading_tokens, h2d_copy_trailing_tokens, + h2d_page_copy) + + +@pytest.fixture +def device_tensors(): + """Create sample device tensors for testing.""" + # Create tensors with shape (2, num_blocks, page_size, head_size, + # hidden_size) + num_blocks = 4 + page_size = 16 + head_size = 8 + hidden_size = 128 + + # Initialize with unique values for each position + k_tensor = torch.arange(num_blocks * page_size * head_size * hidden_size, + dtype=torch.float32, + device='cuda') + k_tensor = k_tensor.reshape(num_blocks, page_size, head_size, hidden_size) + + v_tensor = k_tensor + 1000 # Different values for v + + # Stack k and v tensors + kv_tensor = torch.stack([k_tensor, v_tensor], dim=0) + return kv_tensor + + +@pytest.fixture +def host_buffer(): + """Create host buffer for testing.""" + # Create buffer with same dimensions as device tensor but fewer blocks + num_blocks = 2 # Smaller than device tensor + page_size = 16 + head_size = 8 + hidden_size = 128 + + k_buffer = torch.zeros(num_blocks * page_size * head_size * hidden_size, + dtype=torch.float32) + k_buffer = k_buffer.reshape(num_blocks, page_size, head_size, hidden_size) + + v_buffer = torch.zeros_like(k_buffer) + + # Stack k and v buffers + kv_buffer = torch.stack([k_buffer, v_buffer], dim=0) + return kv_buffer + + +def test_d2h_page_copy(device_tensors, host_buffer): + """Test device to host copy operation.""" + # Copy blocks 1 and 3 from device to host + block_ids = [1, 3] + + d2h_page_copy(device_tensors, host_buffer, block_ids) + + # Verify copied data + for i, block_id in enumerate(block_ids): + # Check key tensor + assert torch.allclose(host_buffer[0, i].cpu(), + device_tensors[0, block_id].cpu()) + # Check value tensor + assert torch.allclose(host_buffer[1, 
i].cpu(), + device_tensors[1, block_id].cpu()) + + +def test_h2d_copy_leading_tokens(): + """Test copying leading tokens from host to device.""" + # Create sample tensors + page_size = 16 + head_size = 8 + hidden_size = 128 + + src_buffer = torch.ones((2, 1, page_size, head_size, hidden_size), + dtype=torch.float32) + # Initialize destination with a known pattern + dst_layer = torch.full((2, 1, page_size, head_size, hidden_size), + fill_value=2.0, + dtype=torch.float32, + device='cuda') + + # Copy first 8 tokens (half of page_size) + end_position = 8 + h2d_copy_leading_tokens(src_buffer, + dst_layer, + src_block_id=0, + dst_block_id=0, + end_position_in_block=end_position) + + # Verify first 8 tokens were copied + assert torch.allclose(dst_layer[0, 0, :end_position].cpu(), + src_buffer[0, 0, :end_position]) + assert torch.allclose(dst_layer[1, 0, :end_position].cpu(), + src_buffer[1, 0, :end_position]) + + # Verify remaining tokens are unchanged (should still be 2.0) + expected_unchanged = torch.full( + (page_size - end_position, head_size, hidden_size), + fill_value=2.0, + dtype=torch.float32) + assert torch.allclose(dst_layer[0, 0, end_position:].cpu(), + expected_unchanged) + assert torch.allclose(dst_layer[1, 0, end_position:].cpu(), + expected_unchanged) + + +def test_h2d_copy_trailing_tokens(): + """Test copying trailing tokens from host to device.""" + # Create sample tensors + page_size = 16 + head_size = 8 + hidden_size = 128 + + src_buffer = torch.ones((2, 1, page_size, head_size, hidden_size), + dtype=torch.float32) + # Initialize destination with a known pattern + dst_layer = torch.full((2, 1, page_size, head_size, hidden_size), + fill_value=2.0, + dtype=torch.float32, + device='cuda') + + # Copy last 8 tokens (half of page_size) + start_position = 8 + h2d_copy_trailing_tokens(src_buffer, + dst_layer, + src_block_id=0, + dst_block_id=0, + start_position_in_block=start_position) + + # Verify last 8 tokens were copied + assert torch.allclose(dst_layer[0, 0, start_position:].cpu(), + src_buffer[0, 0, start_position:]) + assert torch.allclose(dst_layer[1, 0, start_position:].cpu(), + src_buffer[1, 0, start_position:]) + + # Verify leading tokens are unchanged (should still be 2.0) + expected_unchanged = torch.full((start_position, head_size, hidden_size), + fill_value=2.0, + dtype=torch.float32) + assert torch.allclose(dst_layer[0, 0, :start_position].cpu(), + expected_unchanged) + assert torch.allclose(dst_layer[1, 0, :start_position].cpu(), + expected_unchanged) + + +def test_h2d_page_copy(): + """Test host to device page copy operation.""" + # Create sample tensors + num_blocks = 4 + page_size = 16 + head_size = 8 + hidden_size = 128 + block_size = page_size + + src_buffer = torch.ones((2, num_blocks, page_size, head_size, hidden_size), + dtype=torch.float32) + # Initialize destination with a known pattern + dst_layer = torch.full((2, num_blocks, page_size, head_size, hidden_size), + fill_value=2.0, + dtype=torch.float32, + device='cuda') + + # Test copying a range of tokens that spans multiple blocks + block_ids = [0, 1, 2, 3] + start_token_idx = 8 + stop_token_idx = 56 + + h2d_page_copy(src_buffer, dst_layer, block_ids, start_token_idx, + stop_token_idx, block_size) + + # Calculate which blocks should be fully/partially copied + start_block = start_token_idx // block_size + end_block = (stop_token_idx + block_size - 1) // block_size + start_pos = start_token_idx % block_size + end_pos = stop_token_idx % block_size + + # Expected unchanged value + expected_unchanged = 
torch.full((page_size, head_size, hidden_size), + fill_value=2.0, + dtype=torch.float32) + + # Verify copied and unchanged data for each block + for i in range(num_blocks): + if i < start_block or i >= end_block: + # Blocks outside the copy range should be unchanged + assert torch.allclose(dst_layer[:, block_ids[i]].cpu(), + expected_unchanged) + elif i == start_block: + # First block - verify both copied and unchanged parts + # Leading part should be unchanged + assert torch.allclose(dst_layer[:, block_ids[i], :start_pos].cpu(), + expected_unchanged[:start_pos]) + # Trailing part should be copied + assert torch.allclose(dst_layer[:, block_ids[i], start_pos:].cpu(), + src_buffer[:, i, start_pos:]) + elif i == end_block - 1: + # Last block - verify both copied and unchanged parts + # Leading part should be copied + assert torch.allclose(dst_layer[:, block_ids[i], :end_pos].cpu(), + src_buffer[:, i, :end_pos]) + # Trailing part should be unchanged + assert torch.allclose(dst_layer[:, block_ids[i], end_pos:].cpu(), + expected_unchanged[end_pos:]) + else: + # Middle blocks - verify full copy + assert torch.allclose(dst_layer[:, block_ids[i]].cpu(), + src_buffer[:, i]) + + +def test_h2d_page_copy_edge_cases(): + """Test edge cases for host to device page copy.""" + # Create sample tensors + num_blocks = 2 + page_size = 16 + head_size = 8 + hidden_size = 128 + block_size = page_size + + src_buffer = torch.ones((2, num_blocks, page_size, head_size, hidden_size), + dtype=torch.float32) + dst_layer = torch.zeros((2, num_blocks, page_size, head_size, hidden_size), + dtype=torch.float32, + device='cuda') + + # Test case 1: Copy exactly one block + block_ids = [0, 1] + start_token_idx = 0 + stop_token_idx = block_size + + h2d_page_copy(src_buffer, dst_layer, block_ids, start_token_idx, + stop_token_idx, block_size) + + assert torch.allclose(dst_layer[:, block_ids[0]].cpu(), src_buffer[:, 0]) + + # Test case 2: Copy partial block + dst_layer.zero_() + block_ids = [0, 1] + start_token_idx = block_size + 2 + stop_token_idx = block_size + 6 + + h2d_page_copy(src_buffer, dst_layer, block_ids, start_token_idx, + stop_token_idx, block_size) + + start_pos = start_token_idx % block_size + end_pos = stop_token_idx % block_size + + assert torch.allclose(dst_layer[:, block_ids[1], start_pos:end_pos].cpu(), + src_buffer[:, 1, start_pos:end_pos]) + + +def test_h2d_page_copy_aligned(): + """Test host to device page copy operation with block-aligned boundaries.""" + # Create sample tensors + num_blocks = 4 + page_size = 16 + head_size = 8 + hidden_size = 128 + block_size = page_size + + src_buffer = torch.ones((2, num_blocks, page_size, head_size, hidden_size), + dtype=torch.float32) + # Initialize destination with a known pattern + dst_layer = torch.full((2, num_blocks, page_size, head_size, hidden_size), + fill_value=2.0, + dtype=torch.float32, + device='cuda') + + # Test copying exactly 2 blocks (from block 1 to block 3) + block_ids = [0, 1, 2, 3] + start_token_idx = block_size # Start at beginning of block 1 + stop_token_idx = block_size * 3 # End at end of block 2 + + h2d_page_copy(src_buffer, dst_layer, block_ids, start_token_idx, + stop_token_idx, block_size) + + # Expected unchanged value + expected_unchanged = torch.full((page_size, head_size, hidden_size), + fill_value=2.0, + dtype=torch.float32) + + # Verify copied and unchanged data for each block + for i in range(num_blocks): + if i == 0 or i == 3: + # First and last blocks should be unchanged + assert torch.allclose( + dst_layer[:, block_ids[i]].cpu(), 
+ expected_unchanged), f"Block {i} should be unchanged" + else: + # Middle blocks (1 and 2) should be fully copied + assert torch.allclose( + dst_layer[:, block_ids[i]].cpu(), + src_buffer[:, i]), f"Block {i} should be fully copied" + + # Test copying a single block-aligned region + dst_layer.fill_(2.0) # Reset destination + start_token_idx = block_size * 2 # Start at beginning of block 2 + stop_token_idx = block_size * 3 # End at end of block 2 + + h2d_page_copy(src_buffer, dst_layer, block_ids, start_token_idx, + stop_token_idx, block_size) + + # Verify only block 2 was copied, others unchanged + for i in range(num_blocks): + if i == 2: + # Block 2 should be fully copied + assert torch.allclose( + dst_layer[:, block_ids[i]].cpu(), + src_buffer[:, i]), "Block 2 should be fully copied" + else: + # All other blocks should be unchanged + assert torch.allclose( + dst_layer[:, block_ids[i]].cpu(), + expected_unchanged), f"Block {i} should be unchanged" diff --git a/tests/v1/kv_connector/cpu_kv_integration/test_nixl_cpu_utils.py b/tests/v1/kv_connector/cpu_kv_integration/test_nixl_cpu_utils.py new file mode 100644 index 000000000000..d4837e0d1c56 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/test_nixl_cpu_utils.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 +import time + +import pytest +import torch +import torch.multiprocessing as mp + +import vllm.distributed.kv_transfer.kv_connector.v1.nixl_cpu_utils as utils +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_cpu_utils import ( + DestinationSpec, NixlCPUReceiver, NixlCPUSender, RingBufferAllocator, + SourceSpec) + +try: + #from nixl._api import nixl_agent as NixlWrapper + import importlib + spec = importlib.util.find_spec("nixl._api") + if spec is None: + raise ImportError("NIXL is not available") + NIXL_AVAILABLE = True +except ImportError: + NIXL_AVAILABLE = False + + +def run_receiver(buffer_config, host, base_port, rank, ready_event, + stop_event): + """Process function for running the receiver.""" + try: + # Mock tensor_model_parallel_rank for this process + utils.get_tensor_model_parallel_rank = lambda: rank + + # Create ring buffer allocator + allocator = utils.RingBufferAllocator( + size=buffer_config['buffer_size'], + align_to=buffer_config['nixl_page_size']) + + # Create and start receiver + receiver = NixlCPUReceiver( + allocator=allocator, + nixl_page_size=buffer_config['nixl_page_size']) + receiver.start_handshake_listener(host, base_port) + + # Signal receiver is ready + ready_event.set() + + # Wait for stop signal + stop_event.wait() + + # Cleanup + receiver.stop_handshake_listener() + + except Exception as e: + print(f"Receiver process error: {e}") + raise + + +def run_sender(buffer_config, host, base_port, rank, receiver_ready_event): + """Process function for running the sender.""" + try: + # Mock tensor_model_parallel_rank for this process + utils.get_tensor_model_parallel_rank = lambda: rank + + # Create ring buffer allocator + allocator = utils.RingBufferAllocator( + size=buffer_config['buffer_size'], + align_to=buffer_config['nixl_page_size']) + + # Wait for receiver to be ready + receiver_ready_event.wait() + + # Create sender and perform handshake + sender = NixlCPUSender(buffer_size=buffer_config['buffer_size'], + buffer_ptr=allocator.get_buffer_ptr(), + nixl_page_size=buffer_config['nixl_page_size']) + + dest_spec = DestinationSpec(rank=rank, host=host, base_port=base_port) + sender._nixl_handshake(dest_spec) + + # Verify handshake results + assert dest_spec.get_id() in 
sender._remote_agents + assert sender._remote_agents[dest_spec.get_id()] is not None + peer_name = sender._remote_agents[dest_spec.get_id()] + assert sender._remote_xfer_handlers[peer_name] is not None + + return True + except Exception as e: + print(f"Sender process error: {e}") + raise + + +def run_receiver_with_progress(buffer_config, + host, + base_port, + rank, + ready_event, + stop_event, + progress_interval=0.001): + """Process function for running the receiver with progress loop.""" + try: + # Mock tensor_model_parallel_rank for this process + utils.get_tensor_model_parallel_rank = lambda: rank + + # Create ring buffer allocator + allocator = utils.RingBufferAllocator( + size=buffer_config['buffer_size'], + align_to=buffer_config['nixl_page_size']) + allocator._buffer.fill_(0) + + # Create and start receiver + receiver = NixlCPUReceiver( + allocator=allocator, + nixl_page_size=buffer_config['nixl_page_size']) + receiver.start_handshake_listener(host, base_port) + + # Signal receiver is ready + ready_event.set() + + # Run progress loop until stop signal + while not receiver.get_finished(): + receiver.progress() + time.sleep(progress_interval) + + finished = receiver.get_finished(clear=True) + assert len(finished) == 1 + source_spec, vaddr = finished[0] + paddr = allocator.virtual_to_physical(vaddr) + + # Check if the numbers are all correct (should be uint8 all 1) + num_elements = source_spec.get_size() + should_1 = allocator._buffer[paddr:paddr + num_elements] + should_0_a = allocator._buffer[:paddr] + should_0_b = allocator._buffer[paddr + num_elements:] + assert (should_1 == 1).all(), "Buffer data mismatch" + if len(should_0_a) > 0: + assert (should_0_a == 0).all(), "Buffer data mismatch" + if len(should_0_b) > 0: + assert (should_0_b == 0).all(), "Buffer data mismatch" + + while not stop_event.is_set(): + receiver.progress() + time.sleep(progress_interval) + + # Cleanup + receiver.stop_handshake_listener() + + except Exception as e: + print(f"Receiver process error: {e}") + raise + + +def run_sender_with_protocol(buffer_config, host, base_port, rank, + receiver_ready_event, success_event): + """Process function for running the sender with protocol communication.""" + try: + # Mock tensor_model_parallel_rank for this process + utils.get_tensor_model_parallel_rank = lambda: rank + + # Create ring buffer allocator + allocator = utils.RingBufferAllocator( + size=buffer_config['buffer_size'], + align_to=buffer_config['nixl_page_size']) + + # Wait for receiver to be ready + receiver_ready_event.wait() + + # Create sender + sender = NixlCPUSender(buffer_size=buffer_config['buffer_size'], + buffer_ptr=allocator.get_buffer_ptr(), + nixl_page_size=buffer_config['nixl_page_size']) + + # Create destination spec and perform handshake + dest_spec = DestinationSpec(rank=rank, host=host, base_port=base_port) + sender._nixl_handshake(dest_spec) + + # Create source spec and prepare send + source_spec = SourceSpec( + request_id="test_request", + layer_id=0, + start=0, + stop=16, # Assuming we want to send 16 tokens + shape=(2, 1, 16, 8, 128), # Example shape + dtype_str="bfloat16", # Example dtype + num_all_tokens=16, + ) + + # Prepare send and wait for completion + uid = sender.prepare_send(source_spec, dest_spec) + + max_retries = 100 + retry_count = 0 + remote_agent = None + + while retry_count < max_retries: + remote_agent, receiver_paddr = \ + sender.check_and_remove_prepared_send(uid) + if remote_agent is not None: + break + time.sleep(0.1) + retry_count += 1 + + assert remote_agent is not 
None, "Failed to get remote agent" + assert receiver_paddr != -1, "Failed to get receiver virtual address" + + # Test the real send + vaddr, buffer = allocator.allocate(source_spec.get_size()) + paddr = allocator.virtual_to_physical(vaddr) + + buffer.fill_(1) # Fill with dummy data + + handle = sender.send(paddr, receiver_paddr, source_spec.get_size(), + uid, dest_spec) + + while not sender.is_send_finished(handle): + time.sleep(0.1) + print("Send completed successfully") + + if remote_agent is not None: + success_event.set() + + except Exception as e: + print(f"Sender process error: {e}") + raise + + +@pytest.mark.skipif(not NIXL_AVAILABLE, reason="NIXL is not available") +class TestNixlCPUUtils: + """Test cases for NixlCPUSender and NixlCPUReceiver.""" + + @classmethod + def setup_class(cls): + """Set up the test class.""" + pass + + @pytest.fixture + def buffer_config(self): + """Common buffer configuration for tests.""" + buffer_size = 1 << 20 # 1MB + torch_buffer = torch.zeros(buffer_size, + dtype=torch.uint8, + device='cpu') + + return { + 'buffer_size': buffer_size, + 'buffer_ptr': torch_buffer.data_ptr(), + 'nixl_page_size': 4096 # Standard page size + } + + def test_sender_creation(self, buffer_config): + """Test creation of NixlCPUSender.""" + sender = NixlCPUSender(buffer_size=buffer_config['buffer_size'], + buffer_ptr=buffer_config['buffer_ptr'], + nixl_page_size=buffer_config['nixl_page_size']) + + # Verify internal state + assert sender._buffer_size == buffer_config['buffer_size'] + assert sender._buffer_ptr == buffer_config['buffer_ptr'] + assert sender._nixl_page_size == buffer_config['nixl_page_size'] + assert isinstance(sender._remote_agents, dict) + + # Verify NIXL initialization + assert sender._nixl_wrapper is not None + assert sender._reg_dlist is not None + assert sender._local_xfer_dlist is not None + + def test_receiver_creation(self, buffer_config): + """Test creation of NixlCPUReceiver.""" + # Create ring buffer allocator + allocator = RingBufferAllocator( + size=buffer_config['buffer_size'], + align_to=buffer_config['nixl_page_size']) + + receiver = NixlCPUReceiver( + allocator=allocator, + nixl_page_size=buffer_config['nixl_page_size']) + + # Verify internal state + assert receiver._buffer_size == buffer_config['buffer_size'] + assert receiver._buffer_ptr == allocator.get_buffer_ptr() + assert receiver._nixl_page_size == buffer_config['nixl_page_size'] + assert isinstance(receiver._inflight_requests, dict) + assert isinstance(receiver._inflight_request_vaddr, dict) + assert receiver._allocator is allocator + + # Verify NIXL initialization + assert receiver._nixl_wrapper is not None + assert receiver._reg_dlist is not None + assert receiver._local_xfer_dlist is not None + + def test_nixl_handshake_multiprocess(self, buffer_config): + """Test NIXL handshake between sender and receiver in separate + processes. 
+ """ + # Setup test parameters + test_host = "127.0.0.1" + test_base_port = 50051 + test_rank = 0 + + old_start_method = mp.get_start_method(allow_none=True) + mp.set_start_method("spawn", force=True) + + # Create events for process synchronization + receiver_ready = mp.Event() + stop_receiver = mp.Event() + + # Start receiver process + receiver_process = mp.Process(target=run_receiver, + args=(buffer_config, test_host, + test_base_port, test_rank, + receiver_ready, stop_receiver)) + receiver_process.start() + + # Start sender process + sender_process = mp.Process(target=run_sender, + args=(buffer_config, test_host, + test_base_port, test_rank, + receiver_ready)) + sender_process.start() + + try: + # Wait for processes to complete + sender_process.join(timeout=20) + assert sender_process.exitcode == 0, "Sender process failed" + + finally: + # Cleanup + stop_receiver.set() + receiver_process.join(timeout=5) + + # Force terminate if processes haven't exited + if receiver_process.is_alive(): + receiver_process.terminate() + if sender_process.is_alive(): + sender_process.terminate() + + mp.set_start_method(old_start_method, force=True) + + def test_nixl_protocol_communication(self, buffer_config): + """Test the full protocol communication between sender and receiver.""" + # Setup test parameters + test_host = "127.0.0.1" + test_base_port = 50052 + test_rank = 0 + + # Set multiprocessing start method + old_start_method = mp.get_start_method(allow_none=True) + mp.set_start_method("spawn", force=True) + + # Create events for process synchronization + receiver_ready = mp.Event() + stop_receiver = mp.Event() + protocol_success = mp.Event() + + # Start receiver process with progress loop + receiver_process = mp.Process(target=run_receiver_with_progress, + args=(buffer_config, test_host, + test_base_port, test_rank, + receiver_ready, stop_receiver)) + receiver_process.start() + + # Start sender process with protocol communication + sender_process = mp.Process(target=run_sender_with_protocol, + args=(buffer_config, test_host, + test_base_port, test_rank, + receiver_ready, protocol_success)) + sender_process.start() + + try: + # Wait for protocol communication to complete + protocol_complete = protocol_success.wait(timeout=20) + assert protocol_complete, \ + "Protocol communication failed or timed out" + + # Wait for sender process to complete + sender_process.join(timeout=5) + assert sender_process.exitcode == 0, "Sender process failed" + + finally: + # Cleanup + stop_receiver.set() + receiver_process.join(timeout=5) + + # Force terminate if processes haven't exited + if receiver_process.is_alive(): + receiver_process.terminate() + if sender_process.is_alive(): + sender_process.terminate() + + mp.set_start_method(old_start_method, force=True) diff --git a/tests/v1/kv_connector/cpu_kv_integration/test_ring_buffer_allocator.py b/tests/v1/kv_connector/cpu_kv_integration/test_ring_buffer_allocator.py new file mode 100644 index 000000000000..26051fdb3e00 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/test_ring_buffer_allocator.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_cpu_utils import ( + RingBufferAllocator) + + +def test_basic_allocation(): + """Test basic allocation and deallocation behavior.""" + # Create a buffer with 1024 bytes, aligned to 256 bytes + allocator = RingBufferAllocator(size=1024, align_to=256) + + # Allocate 100 bytes - should be aligned to 256 + addr1, buffer1 = 
allocator.allocate(100) + assert addr1 >= 0 # Valid address + assert buffer1 is not None + assert len(buffer1) == 100 + assert allocator.high_watermark == 256 # Aligned to 256 + assert allocator.low_watermark == 0 + + # Allocate another 100 bytes + addr2, buffer2 = allocator.allocate(100) + assert addr2 >= 0 # Valid address + assert buffer2 is not None + assert len(buffer2) == 100 + assert allocator.high_watermark == 512 # Next aligned position + + # Verify buffers don't overlap + assert buffer1.data_ptr() + len(buffer1) <= buffer2.data_ptr() + + +def test_alignment(): + """Test that allocations are properly aligned.""" + allocator = RingBufferAllocator(size=1024, align_to=256) + + # Allocate various sizes and verify alignment + sizes = [10, 100, 200, 50] + addresses = [] + buffers = [] + + for size in sizes: + addr, buf = allocator.allocate(size) + assert addr >= 0 # Valid address + assert buf is not None + addresses.append(addr) + buffers.append(buf) + # High watermark should always be aligned to 256 + assert allocator.high_watermark % 256 == 0 + + +def test_wraparound(): + """Test buffer wraparound behavior.""" + allocator = RingBufferAllocator(size=1024, align_to=256) + + # Fill most of the buffer + addr1, buffer1 = allocator.allocate(300) # Takes 512 bytes aligned + addr2, buffer2 = allocator.allocate(300) # Takes 512 bytes aligned + assert addr1 >= 0 and addr2 >= 0 # Valid addresses + assert buffer1 is not None and buffer2 is not None + + # This allocation should fail as we don't have enough contiguous space + addr3, buffer3 = allocator.allocate(300) + assert addr3 == -1 # Invalid address + assert buffer3 is None + + # Free the first buffer + allocator.free(addr1) # Free first 512 bytes + + # Now we should be able to allocate again by wrapping around + addr4, buffer4 = allocator.allocate(200) + assert addr4 >= 0 # Valid address + assert buffer4 is not None + assert allocator.high_watermark >= allocator._size # Wrapped around + assert allocator.high_watermark % allocator._size < 512 # Using freed space + + +def test_fragmentation(): + """Test handling of fragmentation.""" + allocator = RingBufferAllocator(size=1024, align_to=256) + + # Allocate several buffers + addr1, buffer1 = allocator.allocate(100) # 256 bytes aligned + addr2, buffer2 = allocator.allocate(100) # 256 bytes aligned + addr3, buffer3 = allocator.allocate(100) # 256 bytes aligned + assert all(addr >= 0 for addr in [addr1, addr2, addr3]) # Valid addresses + assert all(buf is not None for buf in [buffer1, buffer2, buffer3]) + + # Free buffer2, creating a gap + allocator.free(addr2) # Free middle buffer + + # Try to allocate a buffer larger than the gap + addr4, buffer4 = allocator.allocate(300) + assert addr4 == -1 # Invalid address + assert buffer4 is None # Should fail due to fragmentation + + # Allocate a buffer that fits in the gap + # This should also fail as we don't track gaps in current implementation + addr5, buffer5 = allocator.allocate(100) + assert addr5 == -1 # Invalid address + assert buffer5 is None # Should fail due to fragmentation + + # Free buffer1 + allocator.free(addr1) # Free first buffer + + # Now we should be able to allocate again + addr6, buffer6 = allocator.allocate(100) + assert addr6 >= 0 # Valid address + assert buffer6 is not None + assert allocator.high_watermark >= allocator._size # Wrapped around + assert allocator.high_watermark % allocator._size < 512 # Using freed space + + +def test_full_buffer(): + """Test behavior when buffer is completely full.""" + allocator = 
RingBufferAllocator(size=1024, align_to=256) + + # Fill the entire buffer + addresses = [] + buffers = [] + while True: + addr, buf = allocator.allocate(200) + if addr == -1: # Invalid address indicates allocation failure + break + addresses.append(addr) + buffers.append(buf) + + # Verify we can't allocate more + addr, buf = allocator.allocate(10) + assert addr == -1 + assert buf is None + + # Free everything + for addr in addresses: + allocator.free(addr) + + # Should be able to allocate again + addr, buffer = allocator.allocate(200) + assert addr >= 0 # Valid address + assert buffer is not None + + +def test_invalid_free(): + """Test that freeing invalid addresses raises an error.""" + allocator = RingBufferAllocator(size=1024, align_to=256) + + # Allocate a buffer + addr, buffer = allocator.allocate(100) + assert addr >= 0 # Valid address + assert buffer is not None + + # Try to free an invalid address + with pytest.raises(AssertionError): + allocator.free(100) # Invalid address diff --git a/tests/v1/kv_connector/cpu_kv_integration/toy_decode.py b/tests/v1/kv_connector/cpu_kv_integration/toy_decode.py new file mode 100644 index 000000000000..6965f33f5eb4 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/toy_decode.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os + +# VLLM_ENABLE_V1_MULTIPROCESSING=0 +# VLLM_WORKER_MULTIPROC_METHOD=spawn +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def get_kv_transfer_params(req_id: int): + return {"prefill_request_id": str(req_id)} + + +if __name__ == "__main__": + + context = "Hi " * 1000 + context2 = "Hi " * 1000 + context3 = "Hello " * 1000 + context4 = "How " * 1000 + prompts = [ + context + "Hello, my name is", + context2 + "The capital of France is", + context3 + "Your name is", + context4 + "The capital of China is", + ] + + sampling_param_base = SamplingParams(temperature=0, + top_p=0.95, + max_tokens=10) + sampling_params = [] + for i in range(len(prompts)): + sampling_param = sampling_param_base.clone() + sampling_param.extra_args = { + "kv_transfer_params": get_kv_transfer_params(i), + } + sampling_params.append(sampling_param) + + llm = LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig( + kv_connector="CPUConnector", + kv_role="kv_consumer", + kv_connector_extra_config={ + "host": "localhost", + "port": 54321, + "size": 4, + }, + ), + #load_format="dummy", + max_model_len=2048, + max_num_batched_tokens=2048, + block_size=128, + tensor_parallel_size=1, + ) + + # 1ST generation (prefill instance) + outputs = llm.generate( + prompts, + sampling_params, + ) + + new_prompts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt[-30:] + generated_text) + #print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + # Write new_prompts to output.txt + with open("output_decode.txt", "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") + print(f"Saved {len(new_prompts)} prompts to output.txt") + + from vllm.distributed.kv_transfer import get_kv_transfer_group + try: + cpu_connector = get_kv_transfer_group() + cpu_connector.close() + except Exception: + pass diff --git a/tests/v1/kv_connector/cpu_kv_integration/toy_decoder_manager.py 
b/tests/v1/kv_connector/cpu_kv_integration/toy_decoder_manager.py new file mode 100644 index 000000000000..68b8d4e3f8d6 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/toy_decoder_manager.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import torch.multiprocessing as mp + +import vllm.distributed.kv_transfer.kv_connector.v1.nixl_cpu_utils as utils +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_cpu_utils import ( + NixlDecodeManager) + + +def main(): + """Main function to run the receiver.""" + # Setup test parameters + test_host = "127.0.0.1" + test_base_port = 54321 + test_rank = 0 + expected_layers = 32 + + # Buffer configuration + buffer_size = 1 << 30 # 1GB + + try: + # Mock tensor_model_parallel_rank for this process + utils.get_tensor_model_parallel_rank = lambda: test_rank + utils.get_tensor_model_parallel_world_size = lambda: 1 + utils.get_tp_group = lambda: None + + decoder_manager = NixlDecodeManager(buffer_size, test_host, + test_base_port) + + print(f"Receiver started on {test_host}:{test_base_port}") + + # Run progress loop until interrupted + try: + while True: + decoder_manager.progress() + finished = decoder_manager.get_finished(expected_layers) + print(f"Got {len(finished)} finished requests") + + for req_id in finished: + print(f"Processing finished request {req_id}") + for i in range(expected_layers): + decode_specs = decoder_manager.get_kv_specs(req_id, i) + for spec in decode_specs: + print( + f"Received layer {i} tokens " + f"{spec.start} - {spec.stop} request {req_id}. " + f"The shape is {spec.buffer.shape}. " + f"The digest is {spec.buffer.mean()}.") + + decoder_manager.free_request(req_id) + + allocator = decoder_manager._allocator + print("Allocator high/low watermark:", + allocator.high_watermark, allocator.low_watermark) + time.sleep(1) # Small sleep to prevent busy waiting + + except KeyboardInterrupt: + decoder_manager.close() + print("\nShutting down receiver...") + + print("Receiver stopped") + + except Exception as e: + print(f"Receiver error: {e}") + raise + + +if __name__ == "__main__": + # Set multiprocessing start method + mp.set_start_method("spawn", force=True) + main() diff --git a/tests/v1/kv_connector/cpu_kv_integration/toy_example.py b/tests/v1/kv_connector/cpu_kv_integration/toy_example.py new file mode 100644 index 000000000000..a0ab6b74c43e --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/toy_example.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os + +# VLLM_ENABLE_V1_MULTIPROCESSING=0 +# VLLM_WORKER_MULTIPROC_METHOD=spawn +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +if __name__ == "__main__": + + context = "Hi " * 1000 + context2 = "Hey " * 1000 + context3 = "Hello " * 1000 + context4 = "How " * 1000 + prompts = [ + context + "Hello, my name is", + context2 + "The capital of France is", + context3 + "Your name is", + context4 + "The capital of China is", + ] + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + llm = LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig( + kv_connector="CPUConnector", + kv_role="kv_producer", + kv_connector_extra_config={ + "host": "localhost", + "port": 54321, + "size": 4, + }, + ), + #load_format="dummy", + max_model_len=2048, + max_num_batched_tokens=2048, + 
block_size=128, + tensor_parallel_size=1, + ) + + # 1ST generation (prefill instance) + outputs = llm.generate( + prompts, + sampling_params, + ) + + new_prompts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + #print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + # Write new_prompts to output.txt + with open("output.txt", "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") + print(f"Saved {len(new_prompts)} prompts to output.txt") + + # HACK: for offline single-process inference only + # Wait for all send finishes + from vllm.distributed.kv_transfer import get_kv_transfer_group + try: + cpu_connector = get_kv_transfer_group() + cpu_connector.close() + except Exception: + pass diff --git a/tests/v1/kv_connector/cpu_kv_integration/toy_example_outdated.py b/tests/v1/kv_connector/cpu_kv_integration/toy_example_outdated.py new file mode 100644 index 000000000000..824110e087af --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/toy_example_outdated.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +@pytest.fixture +def env_setup(): + """Set up required environment variables""" + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +@pytest.fixture +def input_prompts(): + """Create test prompts""" + context = "Hi " * 10 # Reduced size for testing + context2 = "Hey " * 10 + context3 = "Hello " * 10 + context4 = "How " * 10 + return [ + context + "Hello, my name is", + context2 + "The capital of France is", + context3 + "Your name is", + context4 + "The capital of China is", + ] + + +@pytest.fixture +def llm_instance(): + """Create LLM instance with test configuration""" + return LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig( + kv_connector="CPUConnector", + kv_role="kv_producer", + kv_connector_extra_config={}, + ), + load_format="dummy", + max_model_len=2048, + max_num_batched_tokens=2048, + block_size=64, + ) + + +def test_llm_generation(env_setup, input_prompts, llm_instance, tmp_path): + """Test LLM generation and output saving""" + # Configure sampling parameters + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + # Generate outputs + outputs = llm_instance.generate(input_prompts, sampling_params) + + # Verify outputs + assert len(outputs) == len( + input_prompts), "Number of outputs should match number of prompts" + + # Process outputs + new_prompts = [] + for output in outputs: + assert hasattr(output, 'prompt'), "Output should have prompt attribute" + assert hasattr(output, + 'outputs'), "Output should have outputs attribute" + assert len(output.outputs) > 0, "Output should have generated text" + + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + + # Test file writing + output_file = tmp_path / "output.txt" + with open(output_file, "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") + + # Verify file contents + assert output_file.exists(), "Output file should be created" + with open(output_file) as f: + lines = f.readlines() + assert len(lines) == len( + input_prompts), "File should contain all prompts" + for line in lines: + assert line.strip(), "Lines should not be empty" diff --git 
a/tests/v1/kv_connector/cpu_kv_integration/toy_proxy_server.py b/tests/v1/kv_connector/cpu_kv_integration/toy_proxy_server.py new file mode 100644 index 000000000000..636ed81dd6c8 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/toy_proxy_server.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import time +from contextlib import asynccontextmanager + +import httpx +import numpy as np +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. + """ + # Startup: Initialize clients + prefiller_base_url = ( + f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1") + decoder_base_url = ( + f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1") + + app.state.prefill_client = httpx.AsyncClient(timeout=None, + base_url=prefiller_base_url) + app.state.decode_client = httpx.AsyncClient(timeout=None, + base_url=decoder_base_url) + + yield + + # Shutdown: Close clients + await app.state.prefill_client.aclose() + await app.state.decode_client.aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +class StatsCalculator: + + def __init__(self): + self._stats = [] + self._last_log_time = time.time() + + def add(self, value): + self._stats.append(value) + if time.time() - self._last_log_time > 5: + self._log_stats() + self._last_log_time = time.time() + + def _log_stats(self): + # Print average, median, and 99th percentile + np_arr = np.array(self._stats) + output_str = ( + f"\nNum requests: {len(self._stats)}" + + "\nPrefill node TTFT stats:" + + f"\n - Average (ms): {np.mean(np_arr)}" + + f"\n - Median (ms): {np.median(np_arr)}" + + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n") + print( + "===============================", + output_str, + "===============================", + ) + + +stats_calculator = StatsCalculator() +counter = 0 + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-host", type=str, default="localhost") + parser.add_argument("--prefiller-port", type=int, default=8100) + parser.add_argument("--decoder-host", type=str, default="localhost") + parser.add_argument("--decoder-port", type=int, default=8200) + args = parser.parse_args() + return args + + +# Initialize variables to hold the persistent clients +app.state.prefill_client = None +app.state.decode_client = None + + +async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Send a request to a service using a persistent client. + """ + req_data = req_data.copy() + req_data["max_tokens"] = 1 + req_data["stream"] = False + if "stream_options" in req_data: + del req_data["stream_options"] + if "max_completion_tokens" in req_data: + req_data["max_completion_tokens"] = 1 + + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + response = await client.post(endpoint, json=req_data, headers=headers) + print("Got the response:", response.json()) + response.raise_for_status() + return response + + +async def stream_service_response(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Asynchronously stream the response from a service using a persistent client. 
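For reference, a minimal client-side sketch of how a request could be sent through this proxy once both vLLM servers are up. The port comes from the --port default in this file; the model name is a placeholder, and the exact response framing depends on the downstream vLLM server.

import httpx

# The proxy listens on --port (default 8000) and internally performs a
# 1-token prefill call before streaming the decode response back.
request_body = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    "prompt": "Hello, my name is",
    "max_tokens": 32,
    "stream": True,
}

with httpx.Client(base_url="http://localhost:8000") as client:
    with client.stream("POST", "/v1/completions", json=request_body) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_bytes():
            print(chunk.decode(errors="ignore"), end="")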
+ """ + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + print("Received a new request!") + + # Send request to prefill service, ignore the response + response = await send_request_to_service(app.state.prefill_client, + "/completions", req_data) + + # Extract the needed fields + response_json = response.json() + kv_transfer_params = response_json.get('kv_transfer_params', {}) + print("Got kv_transfer_params:", kv_transfer_params) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + print("Streaming response from decode service") + async for chunk in stream_service_response(app.state.decode_client, + "/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server", + "- completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, + "/chat/completions", req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/chat/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server ", + "- chat completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == "__main__": + global global_args + global_args = parse_args() + + import uvicorn + + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/tests/v1/kv_connector/cpu_kv_integration/toy_receiver.py b/tests/v1/kv_connector/cpu_kv_integration/toy_receiver.py new file mode 100644 index 000000000000..9059151d56f5 --- /dev/null +++ b/tests/v1/kv_connector/cpu_kv_integration/toy_receiver.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import torch.multiprocessing as mp + +import vllm.distributed.kv_transfer.kv_connector.v1.nixl_cpu_utils as utils +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_cpu_utils import ( + NixlCPUReceiver, RingBufferAllocator) + + +def main(): + """Main function to run the receiver.""" + # Setup test parameters + test_host = "127.0.0.1" + test_base_port = 54321 + test_rank = 0 + + # Buffer configuration + buffer_size = 1 << 30 # 1GB + nixl_page_size = 4096 # Standard page size + + try: + # Mock tensor_model_parallel_rank for this process + 
utils.get_tensor_model_parallel_rank = lambda: test_rank + + # Create ring buffer allocator + allocator = RingBufferAllocator(size=buffer_size, + align_to=nixl_page_size) + allocator._buffer.fill_(0) + + # Create and start receiver + receiver = NixlCPUReceiver(allocator=allocator, + nixl_page_size=nixl_page_size) + receiver.start_handshake_listener(test_host, test_base_port) + + print(f"Receiver started on {test_host}:{test_base_port}") + + # Run progress loop until interrupted + try: + while True: + receiver.progress() + + # Check for finished requests + finished = receiver.get_finished(clear=True) + if finished: + for source_spec, vaddr in finished: + print( + f"Got data from request {source_spec.request_id}") + paddr = allocator.virtual_to_physical(vaddr) + + # Verify received data + num_elements = source_spec.get_size() + received_data = allocator._buffer\ + [paddr : paddr + num_elements]\ + .view(source_spec.dtype)\ + .reshape(source_spec.tensor_shape) + print(f"Received layer {source_spec.layer_id} tokens " + f"{source_spec.start} - {source_spec.stop} of " + f"request {source_spec.request_id}") + print(f"The shape is {received_data.shape}") + print(f"The digest is {received_data.mean()}") + allocator.free(vaddr) + + print("Allocator high/low watermark:", + allocator.high_watermark, allocator.low_watermark) + time.sleep(1) # Small sleep to prevent busy waiting + + except KeyboardInterrupt: + print("\nShutting down receiver...") + + # Cleanup + receiver.stop_handshake_listener() + print("Receiver stopped") + + except Exception as e: + print(f"Receiver error: {e}") + raise + + +if __name__ == "__main__": + # Set multiprocessing start method + mp.set_start_method("spawn", force=True) + main() diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 58dfa251c735..c3f763ac3801 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -126,3 +126,8 @@ def create_connector_v1( "MultiConnector", "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", "MultiConnector") + +KVConnectorFactory.register_connector( + "CPUConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.cpu_connector", + "CPUConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index f80b5eba235d..8b2e4b001cdd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -281,3 +281,10 @@ def request_finished( returned by the engine. """ return False, None + + def close(self) -> None: + """ + Close the connector. This is called when the connector is no longer + needed. 
+ """ + return diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/cpu_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/cpu_connector.py new file mode 100644 index 000000000000..c6a1f0738056 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/cpu_connector.py @@ -0,0 +1,928 @@ +# SPDX-License-Identifier: Apache-2.0 +import threading +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.cpu_connector_utils import ( + DestinationSpec, SourceSpec) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_cpu_utils import ( + NixlDecodeManager, NixlPrefillManager, NixlSendTask) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.utils import cdiv, round_down +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import KVTransferConfig + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.core.sched.output import CachedRequestData, NewRequestData + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +def d2h_page_copy(src_layer: torch.Tensor, dst_buffer: torch.Tensor, + block_ids: list[int]) -> None: + """Copy data from device to host. + + Args: + src_layer (torch.Tensor): The source layer on device, shape is + (2, num_vllm_blocks, page_size, ...remaining dims...) + dst_buffer (torch.Tensor): The destination buffer on host, shape is + (2, len(block_ids), page_size, ...remaining dims...) + block_ids (list[int]): The list of vllm block ids to copy from. + """ + block_mapping = torch.stack([ + torch.tensor(block_ids, dtype=torch.long), + torch.arange(len(block_ids), dtype=torch.long) + ], + dim=1) + ops.swap_blocks(src_layer[0], dst_buffer[0], block_mapping) + ops.swap_blocks(src_layer[1], dst_buffer[1], block_mapping) + + +def h2d_copy_part_block(src_buffer: torch.Tensor, dst_layer: torch.Tensor, + src_block_id: int, dst_block_id: int, + start_position_in_block: int, + end_position_in_block: Optional[int]) -> None: + """Copy the part of a block from host buffer to device layer. + + Args: + src_buffer (torch.Tensor): The source buffer on host, shape is + (2, len(block_ids), page_size, ...remaining dims...) + dst_layer (torch.Tensor): The destination layer on device, shape is + (2, num_vllm_blocks, page_size, ...remaining dims...) + src_block_id (int): The source block id to copy. + dst_block_id (int): The destination block id to copy. + start_position_in_block (int): The start position in the block to copy. + end_position_in_block (int): The end position in the block to copy. 
+ """ + if end_position_in_block is None: + # If end_position_in_block is None, copy until the end of the block + end_position_in_block = src_buffer[0][0].shape[0] + + dst_k = dst_layer[0][dst_block_id][ + start_position_in_block:end_position_in_block] + src_k = src_buffer[0][src_block_id][ + start_position_in_block:end_position_in_block] + dst_v = dst_layer[1][dst_block_id][ + start_position_in_block:end_position_in_block] + src_v = src_buffer[1][src_block_id][ + start_position_in_block:end_position_in_block] + dst_k.copy_(src_k, non_blocking=True) + dst_v.copy_(src_v, non_blocking=True) + + +def h2d_copy_leading_tokens(src_buffer: torch.Tensor, dst_layer: torch.Tensor, + src_block_id: int, dst_block_id: int, + end_position_in_block: int) -> None: + """Copy the leading tokens in 1 block from host buffer to device layer. + + Args: + src_buffer (torch.Tensor): The source buffer on host, shape is + (2, len(block_ids), page_size, ...remaining dims...) + dst_layer (torch.Tensor): The destination layer on device, shape is + (2, num_vllm_blocks, page_size, ...remaining dims...) + src_block_id (int): The source block id to copy. + dst_block_id (int): The destination block id to copy. + end_position_in_block (int): The end position in the block to copy. + """ + h2d_copy_part_block(src_buffer, dst_layer, src_block_id, dst_block_id, 0, + end_position_in_block) + + +def h2d_copy_trailing_tokens(src_buffer: torch.Tensor, dst_layer: torch.Tensor, + src_block_id: int, dst_block_id: int, + start_position_in_block: int) -> None: + """Copy the trailing tokens in 1 block from host buffer to device layer. + + Args: + src_buffer (torch.Tensor): The source buffer on host, shape is + (2, len(block_ids), page_size, ...remaining dims...) + dst_layer (torch.Tensor): The destination layer on device, shape is + (2, num_vllm_blocks, page_size, ...remaining dims...) + src_block_id (int): The source block id to copy. + dst_block_id (int): The destination block id to copy. + start_position_in_block (int): The start position in the block to copy. + """ + h2d_copy_part_block(src_buffer, dst_layer, src_block_id, dst_block_id, + start_position_in_block, None) + + +def h2d_page_copy(src_buffer: torch.Tensor, dst_layer: torch.Tensor, + block_ids: list[int], start_token_idx: int, + stop_token_idx: int, block_size: int) -> None: + """Copy data from host to device. + + Args: + src_buffer (torch.Tensor): The source buffer on host, shape is + (2, len(block_ids), page_size, ...remaining dims...) + dst_layer (torch.Tensor): The destination layer on device, shape is + (2, num_vllm_pages, page_size, ...remaining dims...) 
+ block_ids (list[int]): The list of vllm block ids to copy to (for all + the tokens) + start_token_idx (int): The start token index in the request + stop_token_idx (int): The stop token index in the request + block_size (int): The block size in vLLM + """ + # Step 1: build the block mapping (src_block_id, dst_block_id) + separate_first_block = start_token_idx % block_size != 0 + separate_last_block = stop_token_idx % block_size != 0 + + start_block_id = start_token_idx // block_size # inclusive + end_block_id = stop_token_idx // block_size # exclusive + src_block_ids = torch.arange(start_block_id, + end_block_id, + dtype=torch.long) + num_blocks = len(src_block_ids) + + if separate_first_block: + src_block_ids = src_block_ids[1:] + # NOTE: we don't need to add the last block id here, because the + # end_block_id is exclusive + # E.g., start = 10, stop = 50, block_size = 16, then we have + # start_block_id = 0 , separate_first_block = True + # end_block_id = 3, separate_last_block = True + # src_block_ids = [1, 2] + # We will copy token 10-15 and 48-49 from the first and last block + # separately. + + vllm_block_ids = torch.tensor(block_ids, dtype=torch.long) + dst_block_ids = vllm_block_ids[src_block_ids] + real_src_block_ids = src_block_ids - start_block_id + + # Step 2: copy the first and last block separately if needed + if start_block_id == end_block_id: + # Only one block to copy + start_position_in_block = start_token_idx % block_size + end_position_in_block = stop_token_idx % block_size + #h2d_copy_part_block(src_buffer, dst_layer, start_block_id, + # vllm_block_ids[start_block_id], + # start_position_in_block, end_position_in_block) + h2d_copy_part_block(src_buffer, dst_layer, 0, + vllm_block_ids[start_block_id], + start_position_in_block, end_position_in_block) + return + + if separate_first_block: + first_block_id_src = start_block_id + first_block_id_dst = vllm_block_ids[first_block_id_src] + start_token_idx_in_block = start_token_idx % block_size + h2d_copy_trailing_tokens(src_buffer, dst_layer, 0, first_block_id_dst, + start_token_idx_in_block) + + if separate_last_block: + last_block_id_src = end_block_id + last_block_id_dst = vllm_block_ids[last_block_id_src] + stop_token_idx_in_block = stop_token_idx % block_size + h2d_copy_leading_tokens(src_buffer, dst_layer, num_blocks - 1, + last_block_id_dst, stop_token_idx_in_block) + + # Step 3: copy the middle blocks + block_mapping = torch.stack([real_src_block_ids, dst_block_ids], dim=1) + ops.swap_blocks(src_buffer[0], dst_layer[0], block_mapping) + ops.swap_blocks(src_buffer[1], dst_layer[1], block_mapping) + + +##################################################################### +# Connector related code +##################################################################### + + +@dataclass +class PrefillRequestTracker: + """RequestTracker is used to track the state of a request. + + Attributes: + req_id (str): The id of the request. + num_saved_tokens (int): The number of tokens saved. + num_loaded_tokens (int): The number of tokens loaded. + num_computed_tokens (int): The number of tokens computed. + allocated_block_ids (list[int]): The list of allocated block ids. 
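The partial-block handling in h2d_page_copy above is easiest to verify with the same numbers as its inline comment; a small pure-Python sketch of that index arithmetic.

block_size = 16
start_token_idx, stop_token_idx = 10, 50

separate_first_block = start_token_idx % block_size != 0   # True: tokens 10-15
separate_last_block = stop_token_idx % block_size != 0     # True: tokens 48-49

start_block_id = start_token_idx // block_size             # 0 (inclusive)
end_block_id = stop_token_idx // block_size                # 3 (exclusive)

# Blocks that can be copied wholesale through the block mapping.
full_blocks = list(range(start_block_id, end_block_id))
if separate_first_block:
    full_blocks = full_blocks[1:]

print(full_blocks, separate_first_block, separate_last_block)
# [1, 2] True True -> tokens 10-15 and 48-49 are copied separately.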
+ """ + # Request id + req_id: str + + # Block ids that are already allocated for this request + allocated_block_ids: list[int] + + # Total number of tokens in the "full request" + num_all_tokens: int = 0 + + # Total number of tokens that are already seen until this step + num_total_tokens: int = 0 + + # Number of tokens that are already saved + num_saved_tokens: int = 0 + + @staticmethod + def from_new_request( + new_request: "NewRequestData", + num_tokens_to_compute: int, + ) -> "PrefillRequestTracker": + """Create the request tracker from a new request. + + Args: + new_request (NewRequestData): the new request data. + num_tokens_to_compute (int): the number of tokens that will + be 'computed', including the `num_computed_tokens` (vLLM's + local cache hit) and new tokens that will be scheduled. + """ + unfolded_block_ids = [] + for block_ids in new_request.block_ids: + unfolded_block_ids.extend(block_ids) + + return PrefillRequestTracker( + req_id=new_request.req_id, + allocated_block_ids=unfolded_block_ids, + num_all_tokens=len(new_request.prompt_token_ids), + num_total_tokens=num_tokens_to_compute, + num_saved_tokens=0, + ) + + def update(self, cached_request: "CachedRequestData") -> None: + """Update the request tracker with the cached request data. + + Args: + cached_request (CachedRequestData): the cached request data. + """ + new_block_ids = [] + for nb in cached_request.new_block_ids: + new_block_ids.extend(nb) + self.allocated_block_ids.extend(new_block_ids) + self.num_total_tokens += len(cached_request.new_token_ids) + + def update_num_saved_tokens(self, num_saved_tokens: int) -> None: + """Update the number of saved tokens. + + Args: + num_saved_tokens (int): the number of saved tokens. + """ + self.num_saved_tokens = num_saved_tokens + + +@dataclass +class PrefillReqMeta: + # Request id + req_id: str + # Blocks to save + blocks_to_save: list[int] + # The range of tokens to save + token_range: slice + # Skip first N tokens + skip_leading_tokens: int + # Skip last N tokens + skip_trailing_tokens: int + # The number of tokens in the "full request" + num_all_tokens: int + + @staticmethod + def from_request_tracker( + request_tracker: PrefillRequestTracker, + block_size: int, + ) -> "PrefillReqMeta": + """Create the request meta from the request tracker. Determine which + blocks to save and the number of leading/trailing tokens to skip for + the worker connector. + It also updates the request tracker's num_saved_tokens. + + Args: + request_tracker (PrefillRequestTracker): the request tracker. + block_size (int): the block size in vLLM. + + Returns: + PrefillReqMeta: the request meta. 
+ """ + assert request_tracker.num_total_tokens <= \ + len(request_tracker.allocated_block_ids) * block_size, \ + f"Request {request_tracker.req_id} has more tokens " + \ + "than allocated blocks" + + token_range = slice(request_tracker.num_saved_tokens, + request_tracker.num_total_tokens) + + num_saved_full_blocks = request_tracker.num_saved_tokens // block_size + num_active_blocks = cdiv(request_tracker.num_total_tokens, block_size) + + blocks_to_save = request_tracker.allocated_block_ids[\ + num_saved_full_blocks:num_active_blocks] + skip_leading_tokens = request_tracker.num_saved_tokens % block_size + skip_trailing_tokens = num_active_blocks * block_size - \ + request_tracker.num_total_tokens + logger.debug( + "Request %s: num_saved_full_blocks=%d, num_active_blocks=%d, " + "blocks_to_save=%s, skip_leading_tokens=%d, " + "skip_trailing_tokens=%d", request_tracker.req_id, + num_saved_full_blocks, num_active_blocks, blocks_to_save, + skip_leading_tokens, skip_trailing_tokens) + + # Update the request tracker with the number of saved tokens + request_tracker.update_num_saved_tokens( + request_tracker.num_total_tokens) + return PrefillReqMeta( + req_id=request_tracker.req_id, + blocks_to_save=blocks_to_save, + token_range=token_range, + skip_leading_tokens=skip_leading_tokens, + skip_trailing_tokens=skip_trailing_tokens, + num_all_tokens=request_tracker.num_all_tokens, + ) + + +@dataclass +class DecodeReqMeta: + # Request id + req_id: str + # Prefiller-side request id + prefill_req_id: str + # Allocated block ids + block_ids: list[int] + # Skip the first N tokens + skip_leading_tokens: int + # if it's ready or not + is_ready: bool = False + + +@dataclass +class CPUConnectorMetadata(KVConnectorMetadata): + prefill_meta: list[PrefillReqMeta] + decode_meta: list[DecodeReqMeta] + + def __init__(self) -> None: + super().__init__() + self.prefill_meta = [] + self.decode_meta = [] + + def add_prefill(self, prefill_meta: PrefillReqMeta) -> None: + """Add a prefill request metadata to the metadata. + + Args: + prefill_meta (PrefillReqMeta): The prefill request metadata to be + added. + """ + self.prefill_meta.append(prefill_meta) + + def add_decode(self, decode_meta: DecodeReqMeta) -> None: + """Add a decode request metadata to the metadata. + + Args: + decode_meta (DecodeReqMeta): The decode request metadata to be + added. + """ + self.decode_meta.append(decode_meta) + + +def validate_kv_transfer_config( + kv_transfer_config: Optional["KVTransferConfig"]) -> None: + """Validate the KV transfer configuration. + It expects the host and port configuration in the kv_connector_extra_config + + Args: + kv_transfer_config (Optional[KVTransferConfig]): The KV transfer + configuration to validate. + + Raises: + AssertionError: If the configuration is invalid. + """ + assert kv_transfer_config is not None, \ + "KV transfer config is not set in the vLLM config" + + extra_config = kv_transfer_config.kv_connector_extra_config + assert "host" in extra_config, \ + "CPUConnector: must have 'host' in kv_connector_extra_config" + assert "port" in extra_config, \ + "CPUConnector: must have 'port' in kv_connector_extra_config" + assert "size" in extra_config, \ + "CPUConnector: must have 'size' in kv_connector_extra_config" + + +class CPUConnector(KVConnectorBase_V1): + """CPUKVConnector is an implementation of KVConnectorBase_V1 that + provides a CPU-based KV cache sending mechanism. 
+ """ + + def __init__(self, vllm_config: "VllmConfig", + role: KVConnectorRole) -> None: + super().__init__(vllm_config, role) + + validate_kv_transfer_config(vllm_config.kv_transfer_config) + extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config + self._host = extra_config["host"] + self._port = int(extra_config["port"]) + # Convert GB to bytes and align to 4K for storage size + kv_size_in_bytes = float(extra_config["size"]) * (1 << 30) + kv_size_in_bytes = int(kv_size_in_bytes) & (~0xFFF) # Align to 4K + self._kv_size = kv_size_in_bytes + + self.kv_role = vllm_config.kv_transfer_config.kv_role + + self._block_size = vllm_config.cache_config.block_size + + if role == KVConnectorRole.SCHEDULER: + self._should_be_ready_reqs: set[str] = set() + elif role == KVConnectorRole.WORKER: + # Prefiller side sender + if self.kv_role == "kv_producer": + self._kv_sender = NixlPrefillManager(self._kv_size) + self._kv_sender_lock = threading.Lock() + self._kv_sender_stop_event = threading.Event() + self._kv_sender_thread = threading.Thread( + target=self._kv_sender_processor, + daemon=True, + ) + self._kv_sender_thread.start() + + elif self.kv_role == "kv_consumer": + self._kv_receiver = NixlDecodeManager( + self._kv_size, + self._host, + self._port, + ) + else: + raise ValueError(f"Unknown kv_role: {self.kv_role}") + + # request_id -> prefill request trackers + self._prefill_reqs: dict[str, PrefillRequestTracker] = {} + + # gpu kv caches + self._gpu_kv_caches: dict[str, torch.Tensor] = {} + self._layer_name_to_id: dict[str, int] = {} + self._layer_id_to_name: dict[int, str] = {} + self._kv_page_shape: torch.Size = torch.Size([0]) + + # separate cuda streams + self._cuda_stream = torch.cuda.Stream() + + # prefill offload tasks + self._inflight_copy_tasks: list[NixlSendTask] = [] + + # Decode request id to prefill request id mapping + self._decode_req_id_to_prefill_req_id: dict[str, str] = {} + self._prefill_req_id_to_decode_req_id: dict[str, str] = {} + + # Decode request metadata for scheduler connector + # decode request id -> DecodeReqMeta + self._decode_req_metas: dict[str, DecodeReqMeta] = {} + + # Decode h2d cuda events + # layer id -> cuda event + self._decoder_cuda_events: dict[int, torch.cuda.Event] = {} + + # In-progress kv load requests's prefill request ids + self._inflight_h2d_requests: set[str] = set() + + def _connect_request_ids(self, p_reqid: str, d_reqid: str) -> None: + self._decode_req_id_to_prefill_req_id[d_reqid] = p_reqid + self._prefill_req_id_to_decode_req_id[p_reqid] = d_reqid + + ############################################################ + # Scheduler Side Methods + ############################################################ + def _build_prefiller_meta(self, scheduler_output: SchedulerOutput, + output_meta: CPUConnectorMetadata) -> None: + """Build the prefill request metadata from the scheduler output. + + Args: + scheduler_output (SchedulerOutput): The scheduler output. + output_meta (CPUConnectorMetadata): The output metadata. 
+ """ + for finished_req_id in scheduler_output.finished_req_ids: + self._prefill_reqs.pop(finished_req_id, None) + + for request in scheduler_output.scheduled_new_reqs: + num_tokens_to_compute = request.num_computed_tokens + \ + scheduler_output.num_scheduled_tokens[request.req_id] + request_tracker = PrefillRequestTracker.from_new_request( + request, num_tokens_to_compute) + self._prefill_reqs[request.req_id] = request_tracker + + req_meta = PrefillReqMeta.from_request_tracker( + request_tracker, self._block_size) + output_meta.add_prefill(req_meta) + + for request in scheduler_output.scheduled_cached_reqs: + request_tracker = self._prefill_reqs[request.req_id] + request_tracker.update(request) + + req_meta = PrefillReqMeta.from_request_tracker( + request_tracker, self._block_size) + output_meta.add_prefill(req_meta) + + def build_decode_meta(self, scheduler_output: SchedulerOutput, + output_meta: CPUConnectorMetadata) -> None: + """Build the decode request metadata from the scheduler output. + + Args: + scheduler_output (SchedulerOutput): The scheduler output. + output_meta (CPUConnectorMetadata): The output metadata. + """ + updated_decode_req_metas = {} + for req_meta in self._decode_req_metas.values(): + if not req_meta.is_ready: + updated_decode_req_metas[req_meta.req_id] = req_meta + # NOTE (ApostaC): Even if the request is not ready, we still + # want the worker connector to know about it, so that it can + # connect the decode request id to the prefill request id + output_meta.add_decode(req_meta) + self._decode_req_metas = updated_decode_req_metas + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + # NOTE(ApostaC): For a single request, this function will be called + # two times if the first time we returned async_load flag as True. + # The second time will be the "real schedule" time + + if self.kv_role == "kv_producer": + return 0, False + + kv_transfer_params = request.kv_transfer_params + num_tokens = len(request.prompt_token_ids) + request_id = request.request_id + logger.info( + "For request %s, num_computed_tokens is %d, " + "total_num_tokens is %d", request_id, num_computed_tokens, + num_tokens) + + num_extra_tokens = round_down(num_tokens, + self._block_size) - num_computed_tokens + + if num_extra_tokens < self._block_size: + # If the request is smaller than the block size, we don't need + # to do anything special + logger.info( + "Request %s is smaller than block size %d, " + "no async loading", request_id, self._block_size) + return 0, False + + # Seen this request before, which means it should be ready this time, + # so we don't need to do async loading again + if request.request_id in self._should_be_ready_reqs: + self._should_be_ready_reqs.remove(request.request_id) + return 0, False + + if kv_transfer_params is None or \ + "prefill_request_id" not in kv_transfer_params: + logger.warning("Request %s does not have prefill_request_id", + request.request_id) + return 0, False + + prefill_request_id = kv_transfer_params["prefill_request_id"] + self._connect_request_ids(prefill_request_id, request_id) + self._should_be_ready_reqs.add(request_id) + + # NOTE: because the scheduler wants here to return "full blocks" if + # the async flag is true (see _update_waiting_for_remote_kv in + # scheduler.py). 
We need to carefully deal with it when copying + # the KV cache at worker side + return num_extra_tokens, True + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int) -> None: + """Update the state of the request after allocation. + """ + # NOTE(ApostaC): This function is called twice for the same request + # when we are using async loading. The first time is we got all the + # external "hit" blocks in `blocks`, and the second time we will have + # the remaining "last" block as a newly allocated block. + if self.kv_role == "kv_producer": + return + + if request.request_id in self._decode_req_metas: + # This is the second time we are called for the same request + # We need to mark the request as "ready" + self._decode_req_metas[request.request_id].is_ready = True + return + + if request.request_id not in self._decode_req_id_to_prefill_req_id: + # This should not happen, but just in case + logger.warning( + "Request %s does not have prefill request id, " + "skipping decode meta creation", request.request_id) + return + + p_req_id = self._decode_req_id_to_prefill_req_id[request.request_id] + block_ids = [] + for blks in blocks.get_block_ids(): + block_ids.extend(blks) + req_meta = DecodeReqMeta(req_id=request.request_id, + prefill_req_id=p_req_id, + block_ids=block_ids, + skip_leading_tokens=0, + is_ready=False) + self._decode_req_metas[request.request_id] = req_meta + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + meta = CPUConnectorMetadata() + + if self.kv_role == "kv_producer": + self._build_prefiller_meta(scheduler_output, meta) + elif self.kv_role == "kv_consumer": + self.build_decode_meta(scheduler_output, meta) + else: + raise ValueError(f"Unknown kv_role: {self.kv_role}") + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + if self.kv_role == "kv_consumer": + return False, None + # For prefiller, send back the prefiller request id + logger.info("Prefill request %s finished", request.request_id) + return False, dict(prefill_request_id=request.request_id) + + ############################################################# + # Worker Side Methods + ############################################################# + def _kv_sender_processor(self) -> None: + """Process the KV sender tasks in a separate thread.""" + while not self._kv_sender_stop_event.is_set(): + with self._kv_sender_lock: + self._kv_sender.progress() + time.sleep(0.001) # Sleep for a short time to avoid busy waiting + + def _get_layer_id(self, layer_name: str) -> int: + assert layer_name in self._layer_name_to_id, \ + f"Layer {layer_name} not found in layer name to id map" + return self._layer_name_to_id[layer_name] + + def _get_layer_name(self, layer_id: int) -> str: + assert layer_id in self._layer_id_to_name, \ + f"Layer id {layer_id} not found in layer id to name map" + return self._layer_id_to_name[layer_id] + + def _get_kv_shape(self, num_blocks: int) -> torch.Size: + return torch.Size(( + 2, + num_blocks, + ) + self._kv_page_shape) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + self._gpu_kv_caches = kv_caches + for idx, layer_name in enumerate(kv_caches): + self._layer_name_to_id[layer_name] = idx + self._layer_id_to_name[idx] = layer_name + + self._kv_page_shape = kv_caches[list(kv_caches.keys())[0]].shape[2:] + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ 
+ Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + if self.kv_role == "kv_producer": + # encoder side + return + + meta = self._get_connector_metadata() + assert isinstance(meta, CPUConnectorMetadata), \ + "Connector metadata is not of type CPUConnectorMetadata" + + ready_decode_metas = [] + total_expected_tokens = [] + for decode_meta in meta.decode_meta: + self._connect_request_ids(decode_meta.prefill_req_id, + decode_meta.req_id) + if decode_meta.is_ready: + ready_decode_metas.append(decode_meta) + total_expected_tokens.append( + len(decode_meta.block_ids) * \ + self._block_size) + self._inflight_h2d_requests.add(decode_meta.prefill_req_id) + + # Vars needed: + # decode_meta.prefill_req_id + if len(ready_decode_metas) == 0: + return + + for layer_id in range(len(self._gpu_kv_caches)): + for decode_meta, total_expected in zip(ready_decode_metas, + total_expected_tokens): + decode_specs = self._kv_receiver.get_kv_specs( + decode_meta.prefill_req_id, layer_id) + layer_name = self._layer_id_to_name[layer_id] + dst_layer = self._gpu_kv_caches[layer_name] + for decode_spec in decode_specs: + start = decode_spec.start + stop = min(decode_spec.stop, total_expected) + if start >= total_expected: + continue + src_buffer = decode_spec.buffer + block_ids = decode_meta.block_ids + + with torch.cuda.stream(self._cuda_stream): + h2d_page_copy(src_buffer, dst_layer, block_ids, start, + stop, self._block_size) + + # Record the cuda event for this layer + event = torch.cuda.Event() + event.record(self._cuda_stream) + self._decoder_cuda_events[layer_id] = event + + # TODO (ApostaC): Potential optimizations + # 1. coalesce the h2d page copy to a single call + # 2. Don't launch all the layers, but just first 2 layers + # 2.1 launch the rest of the layers during the `wait_for_layer_load` + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + if self.kv_role == "kv_producer": + # encoder side + return + + layer_id = self._get_layer_id(layer_name) + event = self._decoder_cuda_events.pop(layer_id, None) + if event is not None: + event.synchronize() + + if layer_id == len(self._gpu_kv_caches) - 1: + # Free the memory for the whole request + for p_req_id in self._inflight_h2d_requests: + logger.info("Freeing request %s, current watermark: [%d, %d]", + p_req_id, + self._kv_receiver._allocator.low_watermark, + self._kv_receiver._allocator.high_watermark) + self._kv_receiver.free_request(p_req_id) + self._inflight_h2d_requests.clear() + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. 
+ attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + if self.kv_role == "kv_consumer": + # decoder side + return + + meta = self._get_connector_metadata() + assert isinstance(meta, CPUConnectorMetadata), \ + "Connector metadata is not of type CPUConnectorMetadata" + assert self._kv_sender is not None + + for prefill_req in meta.prefill_meta: + # Create a source spec with serializable types + source_spec = SourceSpec( + request_id=prefill_req.req_id, + layer_id=self._get_layer_id(layer_name), + start=prefill_req.token_range.start, + stop=prefill_req.token_range.stop, + shape=tuple(self._get_kv_shape(len( + prefill_req.blocks_to_save))), + dtype_str=str(kv_layer.dtype).split('.') + [-1], # Convert torch.float32 -> "float32" + num_all_tokens=prefill_req.num_all_tokens, + ) + + # Create a destination spec + dest_spec = DestinationSpec( + rank=get_tensor_model_parallel_rank(), + host=self._host, + base_port=self._port, + ) + + # Create the send task + with self._kv_sender_lock: + task = self._kv_sender.create_send_task( + source_spec=source_spec, + destination_spec=dest_spec, + ) + assert isinstance(task, NixlSendTask), \ + "Send task is not of type NixlSendTask" + + # Start copying the data to the CPU buffer + buffer = task.tensor + self._cuda_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._cuda_stream): + # Copy the data from the GPU to the CPU buffer page by page + d2h_page_copy(src_layer=kv_layer, + dst_buffer=buffer, + block_ids=prefill_req.blocks_to_save) + + # record the cuda stream + task.cuda_event = torch.cuda.Event() + task.cuda_event.record(self._cuda_stream) + + self._inflight_copy_tasks.append(task) + + # TODO(ApostaC): Potential optimizations + # 1. coalesce the d2h page copy to a single call + # 2. use a single cuda event instead of a list of cuda events + # 3. use a cuda event pool to prevent the creation overhead + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + if self.kv_role == "kv_consumer": + # decoder side + return + + # Check the task states and send the tasks + for task in self._inflight_copy_tasks: + if task.cuda_event is not None: + task.cuda_event.synchronize() + #self._kv_sender.progress() + self._inflight_copy_tasks.clear() + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer, + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + if self.kv_role != "kv_consumer": + return None, None + + # decoder (kv_consumer) side + self._kv_receiver.progress() + p_ready_reqs = self._kv_receiver.get_finished(len(self._gpu_kv_caches)) + ret = set() + for p_req_id in p_ready_reqs: + if d_req_id := self._prefill_req_id_to_decode_req_id.get(p_req_id): + # We have seen the corresponding decode request before. + # Therefore, we can return the request id. + ret.add(d_req_id) + else: + # We haven't seen the corresponding decode request + # before. 
Therefore, we should make the receiver + # to return the request id again in the next + # call to get_finished. + self._kv_receiver.remove_ready_request(p_req_id) + + if ret: + logger.info("Got finished requests: %s", ret) + + return None, ret + + def close(self): + """ + Block until all the transfers are done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + if hasattr(self, "_kv_sender") and self._kv_sender is not None: + self._kv_sender_stop_event.set() + if hasattr(self, "_kv_sender_thread") and \ + self._kv_sender_thread is not None: + self._kv_sender_thread.join() + self._kv_sender.close() + + if hasattr(self, "_kv_receiver") and self._kv_receiver is not None: + self._kv_receiver.close() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/cpu_connector_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/cpu_connector_utils.py new file mode 100644 index 000000000000..5147c710ac0d --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/cpu_connector_utils.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import msgspec +import torch + +from vllm.logger import init_logger + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +@dataclass +class DestinationSpec: + """DestinationSpec is used to specify the destination of kv sending task. + + Attributes: + rank (int): The rank of the destination. + host (str): The path of the destination. + base_port (int): The base port of the destination. + """ + rank: int + host: str + base_port: int + + def __str__(self) -> str: + return f"DestinationSpec(rank={self.rank}, " + \ + f"host={self.host}, base_port={self.base_port})" + + def get_id(self) -> str: + """Get the id of the destination spec. + + Returns: + str: The id of the destination spec. + """ + return f"{self.rank}_{self.host}_{self.base_port}" + + +class SourceSpec(msgspec.Struct): + """SourceSpec is used to specify the source of kv sending task. + """ + # The request id of the kv cache + request_id: str + + # The layer id of the kv cache + layer_id: int + + # The range of tokens to be offloaded + start: int # For token_range slice + stop: int # For token_range slice + + # The shape of the offloaded KV cache tensor as a tuple + shape: tuple[int, ...] 
+ + # The dtype of the offloaded KV cache tensor as a string + dtype_str: str + + # The total number of tokens in the "full request" + num_all_tokens: int + + @property + def token_range(self) -> slice: + """Get the token range as a slice object.""" + return slice(self.start, self.stop) + + @property + def tensor_shape(self) -> torch.Size: + """Get the shape as a torch.Size object.""" + return torch.Size(self.shape) + + @property + def dtype(self) -> torch.dtype: + """Get the dtype as a torch.dtype object.""" + return getattr(torch, self.dtype_str) + + def get_size(self) -> int: + """Get the size in bytes of the corresponding kv cache.""" + return math.prod(self.shape) * self.dtype.itemsize + + def __str__(self) -> str: + return (f"SourceSpec(request_id={self.request_id}, " + f"layer_id={self.layer_id}, " + f"token_range={self.token_range}, shape={self.tensor_shape})") + + +@dataclass +class DecoderKVSpec: + # Start index of the KV cache (inclusive) + start: int + # Stop index of the KV cache (exclusive) + stop: int + # The shape of the KV cache + buffer: torch.Tensor + + +@dataclass +class SendTaskState: + """SendTaskState is used to track the state of a send task. + """ + sender_ready: bool = False + receiver_ready: bool = False + is_sending: bool = False + send_done: bool = False + + def __str__(self) -> str: + return (f"SendTaskState(sender_ready={self.sender_ready}, " + f"receiver_ready={self.receiver_ready}, " + f"is_sending={self.is_sending}, " + f"send_done={self.send_done})") + + def is_ready(self) -> bool: + """Check if the send task is ready to be sent. + + Returns: + bool: True if the send task is ready, False otherwise. + """ + return self.sender_ready and self.receiver_ready + + def is_done(self) -> bool: + """Check if the send task is done. + + Returns: + bool: True if the send task is done, False otherwise. + """ + return self.send_done + + +@dataclass +class SendTask: + """Wraps a KV Cache sending task + """ + + # A flat buffer holding the tensor data + buffer: torch.Tensor + source_spec: SourceSpec + destination_spec: DestinationSpec + state: SendTaskState + + @property + def tensor(self) -> torch.Tensor: + """Get the tensor of the send task. + + Returns: + torch.Tensor: The tensor of the send task. + """ + num_elements = self.source_spec.tensor_shape.numel() + return self.buffer.view(self.source_spec.dtype)[:num_elements].view( + self.source_spec.tensor_shape) + + def update_states(self) -> None: + """Update the states of the send task. This needs to be OVERWRITTEN in + subclasses to handle different types of send tasks. + + This function should be called periodically to ensure that the send + task is being processed. + """ + raise NotImplementedError + + def is_ready(self) -> bool: + """Check if the send task is ready to be sent. + + Returns: + bool: True if the send task is ready, False otherwise. + """ + return self.state.is_ready() + + def is_sending(self) -> bool: + """Check if the send task is currently sending. + + Returns: + bool: True if the send task is sending, False otherwise. + """ + return self.state.is_sending + + def is_done(self) -> bool: + """Check if the send task is done. + + Returns: + bool: True if the send task is done, False otherwise. + """ + return self.state.is_done() + + def mark_sending(self) -> None: + """Mark the send task as sending. + """ + self.state.is_sending = True + + +class KVSenderInterface(ABC): + """KVSenderInterface is an interface for sending KV cache data. 
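Because SourceSpec is sent over the wire, it stores the dtype as a plain string and derives the byte size from the shape product; a small sketch of that round trip with an illustrative shape.

import math
import torch

# How the worker derives the serializable dtype string (see save_kv_layer):
dtype_str = str(torch.float16).split('.')[-1]        # "float16"
dtype = getattr(torch, dtype_str)                    # torch.float16

# Staged chunk shape: (K/V, num_blocks, page_size, ...head dims...)
shape = (2, 3, 16, 8, 128)
size_in_bytes = math.prod(shape) * dtype.itemsize
print(dtype_str, size_in_bytes)                      # float16 196608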
+ """ + + def __init__(self) -> None: + self._send_tasks: list[SendTask] = [] + + def add_send_task(self, task: SendTask) -> None: + """Add a send task to the list of send tasks. + + Args: + task (SendTask): The send task to be added. + """ + self._send_tasks.append(task) + + def get_send_tasks(self) -> list[SendTask]: + """Get the list of send tasks. + + Returns: + list[SendTask]: The list of send tasks. + """ + return self._send_tasks + + def progress(self) -> None: + """A fast, non-blocking function to check and update the states of all + send tasks. This function should be called periodically to ensure that + the send tasks are being processed. + """ + # Update before going through all send tasks + self.pre_progress_hook() + + new_task_list = [] + + num_sent = 0 + num_freed = 0 + for task in self._send_tasks: + should_add = True + + if task.is_ready() and not task.is_sending(): + self.send_task(task) + task.mark_sending() + num_sent += 1 + + if task.is_done(): + self.free_task(task) + should_add = False + num_freed += 1 + + if should_add: + new_task_list.append(task) + + self._send_tasks = new_task_list + + # Update after going through all send tasks + self.post_progress_hook() + + if num_sent > 0 or num_freed > 0: + logger.debug("KVSender progress: sent %d, freed %d", num_sent, + num_freed) + + ###################################################### + # Abstract methods (to be implemented by subclasses) # + ###################################################### + + @abstractmethod + def create_send_task( + self, + source_spec: SourceSpec, + destination_spec: DestinationSpec, + ) -> SendTask: + """Create a non-ready send task with a CPU buffer allocated. + + Args: + source_spec (SourceSpec): The source specification of the send + task. + destination_spec (DestinationSpec): The destination + specification of the send task. + """ + raise NotImplementedError("create_send_task() not implemented") + + @abstractmethod + def free_task(self, task: SendTask) -> None: + """Free the send task. + Will be called in the pre-implemented progress() method. + + Args: + task (SendTask): The send task to be freed. + """ + raise NotImplementedError("free_task() not implemented") + + @abstractmethod + def send_task(self, task: SendTask) -> None: + """Send the send task after it is ready. + Will be called in the pre-implemented progress() method. + + Args: + task (SendTask): The send task to be sent. + """ + raise NotImplementedError("send_task() not implemented") + + @abstractmethod + def pre_progress_hook(self) -> None: + """Hook to be called before processing the send task. + """ + raise NotImplementedError("pre_progress_hook() not implemented") + + @abstractmethod + def post_progress_hook(self) -> None: + """Hook to be called after processing the send task. 
+ """ + raise NotImplementedError("post_progress_hook() not implemented") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_cpu_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_cpu_utils.py new file mode 100644 index 000000000000..dab1810fc765 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_cpu_utils.py @@ -0,0 +1,1174 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import threading +import time +import uuid +from collections import OrderedDict, defaultdict +from collections.abc import Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import torch +import zmq + +from vllm.distributed.kv_transfer.kv_connector.v1.cpu_connector_utils import ( + DecoderKVSpec, DestinationSpec, KVSenderInterface, SendTask, SendTaskState, + SourceSpec) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.logger import init_logger +from vllm.utils import make_zmq_path, make_zmq_socket + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + +try: + from nixl._api import nixl_agent as NixlWrapper + from nixl._api import nixl_xfer_handle + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + nixl_xfer_handle = int + +################################################################### +# Helper classes and functions +################################################################### + +DEFAULT_NIXL_PAGE_SIZE = 4096 + + +def init_nixl_agent( + buffer_size: int, + buffer_ptr: int, + nixl_page_size: int = 4096, +) -> tuple[NixlWrapper, Any, Any, Any]: + """Initialize the NIXL agent. + + Args: + buffer_size (int): The size of the buffer. + buffer_ptr (int): The pointer to the buffer. + nixl_page_size (int, optional): The page size of NIXL. Defaults to 4096. + + Returns: + NixlWrapper: The NIXL agent. + reg_dlist: the registered memory descriptor list. + xfer_dlist: the local transfer descriptor list. + prepped_xfer_handler: the prepped transfer handler. + """ + if NixlWrapper is None: + raise RuntimeError("NIXL is not available") + + # Create a NIXL agent + nixl_agent = NixlWrapper(str(uuid.uuid4())) + + # Register the memory + memory_desc = [(buffer_ptr, buffer_size, 0, "")] + reg_descs = nixl_agent.get_reg_descs(memory_desc, mem_type="DRAM") + nixl_agent.register_memory(reg_descs) + + # Create xfer handlers + xfer_desc = [] + for base_addr in range(buffer_ptr, buffer_ptr + buffer_size, + nixl_page_size): + xfer_desc.append((base_addr, nixl_page_size, 0)) + + xfer_descs = nixl_agent.get_xfer_descs(xfer_desc, mem_type="DRAM") + xfer_handler = nixl_agent.prep_xfer_dlist("", xfer_descs, mem_type="DRAM") + + return nixl_agent, reg_descs, xfer_descs, xfer_handler + + +class RingBufferAllocator: + """RingBufferAllocator is a simple ring buffer allocator for managing + memory allocation and deallocation. + """ + + def __init__(self, size: int, align_to: int = 256) -> None: + """Initialize the ring buffer allocator with the given size. + + Args: + size (int): The size of the ring buffer (in bytes). + align_to (int): The alignment size (in bytes). Default is 8. 
+ """ + self._size = size + self._buffer = torch.empty(size, dtype=torch.uint8) + self._high_watermark = 0 + self._low_watermark = 0 + self._align_to = align_to + + self._allocated: OrderedDict = OrderedDict() # Track allocated buffers + + # Register pin memory + cudart = torch.cuda.cudart() + cudart.cudaHostRegister(self._buffer.data_ptr(), size, 0) + + def _align_size(self, base: int) -> int: + """Align the given size to the nearest multiple of the alignment size. + + Args: + base (int): The size to be aligned. + + Returns: + int: The aligned size. + """ + return ((base - 1) // self._align_to + 1) * self._align_to + + def allocate(self, size: int) -> tuple[int, Optional[torch.Tensor]]: + """Allocate a buffer of the given size. + + Args: + size (int): The size of the buffer to be allocated. + + Returns: + Optional[tuple[int, torch.Tensor]]: A tuple containing the virtual + address of the allocated buffer and the buffer itself. If + allocation fails, returns None. + """ + # During allocation, we always make sure that high watermark and + # low watermark are aligned to the alignment size + aligned_size = self._align_size(size) # Align the requested size + turnaround_size = (self._high_watermark // self._size + 1) * self._size + + local_high = self._high_watermark % self._size + local_low = self._low_watermark % self._size + + if local_high >= local_low: + if local_high == local_low and \ + self._high_watermark > self._low_watermark: + # No space available + return -1, None + + # If high watermark + requested size is okay, directly allocate + if local_high + size < self._size: + address = self._high_watermark + self._allocated[address] = aligned_size + start = local_high + end = start + size + self._high_watermark += aligned_size + return address, self._buffer[start:end] + else: + # If high watermark + requested size is not okay, we need to + # wrap around and allocate again + self._high_watermark = turnaround_size + return self.allocate(size) + else: + # High watermark is below low watermark, check if we can allocate + if local_high + size < local_low: + address = self._high_watermark + self._allocated[address] = aligned_size + start = local_high + end = start + size + self._high_watermark += aligned_size + return address, self._buffer[start:end] + else: + # No space available + return -1, None + + def view_as_tensor(self, vaddr: int, dtype: torch.dtype, + shape: torch.Size) -> torch.Tensor: + """View the buffer as a tensor. + Args: + vaddr (int): The virtual address of the buffer. + dtype (torch.dtype): The data type of the tensor. + shape (torch.Size): The shape of the tensor. + Returns: + torch.Tensor: The tensor view of the buffer. + """ + assert vaddr % self._align_to == 0, \ + "Virtual address is not aligned to the alignment size" + + paddr = self.virtual_to_physical(vaddr) + size = shape.numel() * dtype.itemsize + assert paddr + size <= self._size, \ + "Physical address is out of bounds" + + # Get the tensor + return self._buffer[paddr:paddr + size].view(dtype).view(shape) + + def free(self, address: int) -> None: + """Free the buffer at the given address. + + Args: + address (int): The virtual address of the buffer to be freed, + which is returned by the allocate() method. 
+ """ + assert address in self._allocated, \ + f"Address {address} not found in allocated buffers" + + # Pop the address from the allocated dict, and update the + # low watermark + self._allocated.pop(address) + + # If there is nothing allocated, set low_watermark to high watermark + new_low_watermark = self._high_watermark + + # Else, set the low_watermark to the first address in the allocated + # dict + for addr in self._allocated: + new_low_watermark = addr + break + self._low_watermark = new_low_watermark + + @property + def high_watermark(self) -> int: + return self._high_watermark + + @property + def low_watermark(self) -> int: + return self._low_watermark + + def virtual_to_physical(self, vaddr: int) -> int: + """Convert a virtual address to a physical address. + + Args: + vaddr (int): The virtual address to be converted. + + Returns: + torch.Tensor: The physical address of the buffer. + """ + return vaddr % self._size + + def get_size(self) -> int: + """Get the size of the ring buffer. + + Returns: + int: The size of the ring buffer. + """ + return self._size + + def get_buffer_ptr(self) -> int: + """Get the pointer to the buffer. + + Returns: + int: The pointer to the buffer. + """ + return self._buffer.data_ptr() + + +################################################################### +# NIXL Related Classes +################################################################### + + +class NixlProtocolMsg(msgspec.Struct): + msg_type: str + req_uuid: str + source_spec: Optional[SourceSpec] = None + receiver_paddr: Optional[int] = None + + +def make_send_req_msg(source_spec: SourceSpec, req_uuid: str) -> bytes: + """Make the send request message. + + Args: + source_spec (SourceSpec): The source spec. + + Returns: + bytes: The send request message. + """ + # Create the request message + msg_type = "REQMSG" + receiver_paddr = None + send_req_msg = NixlProtocolMsg(msg_type=msg_type, + req_uuid=req_uuid, + source_spec=source_spec, + receiver_paddr=receiver_paddr) + # Encode the message + send_req_msg_bytes = msgspec.msgpack.encode(send_req_msg) + return send_req_msg_bytes + + +def make_receive_ready_msg( + req_uuid: str, + receiver_paddr: int, +) -> bytes: + """Make the receive ready message. + + Args: + req_uuid (str): The request uuid. + receiver_paddr (int): The receiver's physical address. + + Returns: + bytes: The receive ready message. + """ + # Create the request message + msg_type = "READYMSG" + source_spec = None + receive_ready_msg = NixlProtocolMsg(msg_type=msg_type, + req_uuid=req_uuid, + source_spec=source_spec, + receiver_paddr=receiver_paddr) + # Encode the message + receive_ready_msg_bytes = msgspec.msgpack.encode(receive_ready_msg) + return receive_ready_msg_bytes + + +def make_send_finish_msg(req_uuid: str, ) -> bytes: + """Make the send finish message. + + Args: + req_uuid (str): The request uuid. + + Returns: + bytes: The send finish message. 
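+
+    Note:
+        This is the final message of the three-step protocol: the sender
+        issues a REQMSG carrying the SourceSpec, the receiver replies with
+        a READYMSG carrying the physical address of the allocated receive
+        buffer, and the FINISHMSG is delivered as the NIXL transfer
+        notification once the WRITE completes.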
+ """ + # Create the request message + msg_type = "FINISHMSG" + source_spec = None + receiver_paddr = None + send_finish_msg = NixlProtocolMsg(msg_type=msg_type, + req_uuid=req_uuid, + source_spec=source_spec, + receiver_paddr=receiver_paddr) + # Encode the message + send_finish_msg_bytes = msgspec.msgpack.encode(send_finish_msg) + return send_finish_msg_bytes + + +class NixlCPUSender: + + def __init__( + self, + buffer_size: int, + buffer_ptr: int, + nixl_page_size: int = 4096, + ) -> None: + self._buffer_size = buffer_size + self._buffer_ptr = buffer_ptr + self._nixl_page_size = nixl_page_size + + # Destination spec id -> peer name + self._remote_agents: dict[str, str] = {} + + self._nixl_wrapper, \ + self._reg_dlist, \ + self._local_xfer_dlist, \ + self._local_xfer_handlers = \ + init_nixl_agent(buffer_size, buffer_ptr, nixl_page_size) + + # Remote xfer dlists, peer name -> prepped xfer handlers + self._remote_xfer_handlers: dict[str, Any] = {} + + # Add ZMQ context for handshakes + self._zmq_ctx = zmq.Context() + + # Requests that are ready to send + # uuid -> (remote agent name, receiver paddr) + self._ready_requests: dict[str, tuple[str, int]] = {} + + # NOTE(ApostaC): we don't track the requests that are waiting for the + # receiver to be ready, and may want to add this in the future + + # Msg decoder + self._msg_decoder = msgspec.msgpack.Decoder(NixlProtocolMsg) + + def _get_desc_idxs(self, paddr: int, size: int) -> list[int]: + """Get the sender descriptor indexes for the given physical address + and size. + + Args: + paddr (int): The physical address. + size (int): The size of the data. + + Returns: + list[int]: The list of sender descriptor indexes. + """ + # Get the sender descriptor indexes + assert paddr % self._nixl_page_size == 0, \ + "Physical address is not aligned to the page size" + start_idx = paddr // self._nixl_page_size + end_idx = (paddr + size) // self._nixl_page_size + return [i for i in range(start_idx, end_idx)] + + def send( + self, + src_paddr: int, + dst_paddr: int, + data_size: int, + req_uuid: str, + destination_spec: DestinationSpec, + ) -> nixl_xfer_handle: + """Send data from src_addr to dst_addr using NIXL. + + Args: + src_paddr (int): Source physical address. + dst_paddr (int): Destination physical address. + data_size (int): Size of the data in bytes to be sent. + req_uuid (int): The request uuid. + destination_spec (DestinationSpec): The destination spec. + + Returns: + nixl_xfer_handle: The handle of the transfer. + """ + # Get the sender descriptor indexes + desc_idxs = self._get_desc_idxs(src_paddr, data_size) + # Get the receiver descriptor indexes + r_desc_idxs = self._get_desc_idxs(dst_paddr, data_size) + # Get the remote agent name + remote_agent_name = self._remote_agents[destination_spec.get_id()] + # Get the remote xfer dlist + remote_xfer_handlers = self._remote_xfer_handlers[remote_agent_name] + # Notif msg + notif_msg = make_send_finish_msg(req_uuid) + # Transfer + handle = self._nixl_wrapper.make_prepped_xfer( + "WRITE", self._local_xfer_handlers, desc_idxs, + remote_xfer_handlers, r_desc_idxs, notif_msg) + + self._nixl_wrapper.transfer(handle) + + return handle + + def is_send_finished(self, handle: "nixl_xfer_handle") -> bool: + """Check if the send operation is finished. + + Args: + handle (nixl_xfer_handle): The handle of the transfer. + + Returns: + bool: True if the send operation is finished, False otherwise. 
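+
+        Example (an illustrative polling sketch; the variable names are
+        placeholders)::
+
+            handle = sender.send(src_paddr, dst_paddr, size, req_uuid, dest)
+            while not sender.is_send_finished(handle):
+                time.sleep(0.001)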
+ """ + status = self._nixl_wrapper.check_xfer_state(handle) + if status == "ERR": + logger.error("Error in send operation") + return False + return status == "DONE" + + def prepare_send( + self, + source_spec: SourceSpec, + destination_spec: DestinationSpec, + ) -> str: + """Prepare the send operation by allocation the receive buffer + on the destination side. + + Args: + source_spec (SourceSpec): The source spec. + destination_spec (DestinationSpec): The destination spec. + + Returns: + str: The uuid of the prepared send + """ + dest_id = destination_spec.get_id() + if dest_id not in self._remote_agents: + # Perform handshake with the destination + self._nixl_handshake(destination_spec) + + remote_agent_name = self._remote_agents[dest_id] + + # Create the request message + req_uuid = str(uuid.uuid4()) + msg = make_send_req_msg(source_spec, req_uuid) + + # Send it to the remote agent + self._nixl_wrapper.send_notif(remote_agent_name, msg) + + return req_uuid + + def check_and_remove_prepared_send( + self, + send_uuid: str, + ) -> tuple[Optional[str], int]: + """Check if the prepared send is ready to be sent. + If the send is ready, remove it from the ready requests. + + Args: + send_uuid (str): The uuid of the prepared send. + + Returns: + Optional[str]: The remote agent name if the send is ready, + None otherwise. + int: The physical address of the receiver if the send is ready, + -1 otherwise. + """ + # Update the ready requests + notifs = self._nixl_wrapper.get_new_notifs() + for remote_agent_name in notifs: + for msg in notifs[remote_agent_name]: + # Decode the message + obj = self._msg_decoder.decode(msg) + + if obj.msg_type == "READYMSG": + # Add the request to the ready requests + assert obj.receiver_paddr is not None, \ + "Receiver address is None in READYMSG" + self._ready_requests[obj.req_uuid] = (remote_agent_name, + obj.receiver_paddr) + else: + logger.error("Unexpected message type: %s", obj.msg_type) + continue + + # Check if the send uuid is in the ready requests + if send_uuid in self._ready_requests: + # Remove the request from the ready requests + remote_agent_name, vaddr = self._ready_requests.pop(send_uuid) + return remote_agent_name, vaddr + else: + return None, -1 + + def _nixl_handshake(self, destination_spec: DestinationSpec) -> None: + """Perform handshake with a remote NIXL CPU instance. + + Args: + destination_spec (DestinationSpec): The destination spec. 
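+
+        The handshake opens a ZMQ REQ socket to
+        tcp://<host>:<base_port + rank>, exchanges NIXL agent metadata
+        with the remote listener, and then requests the remote serialized
+        transfer descriptors so that a prepped transfer dlist can be built
+        for later WRITE operations.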
+ """ + assert get_tensor_model_parallel_rank() == destination_spec.rank, \ + "Got different rank in destination spec and current rank" + + port = destination_spec.base_port + destination_spec.rank + path = make_zmq_path("tcp", destination_spec.host, port) + + local_meta = self._nixl_wrapper.get_agent_metadata() + with zmq_ctx(zmq.REQ, path) as sock: + # Send query for metadata + logger.debug("Sending handshake request to %s", destination_spec) + sock.send(local_meta) + + metadata_bytes = sock.recv() + + # Get remote agent name and register it + remote_agent_name = self._nixl_wrapper.add_remote_agent( + metadata_bytes) + + # Store remote agent info + self._remote_agents[destination_spec.get_id()] = remote_agent_name + + sock.send(b"get_xfer_descs") + # Receive the remote xfer descs + s_remote_xfer_descs = sock.recv() + remote_xfer_dlist = self._nixl_wrapper.deserialize_descs( + s_remote_xfer_descs) + + remote_xfer_handlers = self._nixl_wrapper.prep_xfer_dlist( + remote_agent_name, remote_xfer_dlist, mem_type="DRAM") + + self._remote_xfer_handlers[ + remote_agent_name] = remote_xfer_handlers + + logger.debug("Successfully completed handshake with %s", + destination_spec) + + def close(self) -> None: + if not hasattr(self, "_nixl_wrapper"): + return + + if self._reg_dlist is not None: + self._nixl_wrapper.deregister_memory(self._reg_dlist) + for agent in self._remote_agents.values(): + self._nixl_wrapper.remove_remote_agent(agent) + if self._local_xfer_handlers is not None: + self._nixl_wrapper.release_dlist_handle(self._local_xfer_handlers) + for remote_xfer_handler in self._remote_xfer_handlers.values(): + self._nixl_wrapper.release_dlist_handle(remote_xfer_handler) + del self._nixl_wrapper + + +class NixlCPUReceiver: + + def __init__( + self, + allocator: RingBufferAllocator, + nixl_page_size: int = 4096, + ) -> None: + self._buffer_size = allocator.get_size() + self._buffer_ptr = allocator.get_buffer_ptr() + self._nixl_page_size = nixl_page_size + self._allocator = allocator + + assert self._allocator is not None, "Allocator is required" + + # Requests that are pending for allocation + # uuid -> tuple[SourceSpec, peer name] + self._pending_allocation: dict[str, tuple[SourceSpec, str]] = {} + + # Already allocated requests + # uuid -> SourceSpec and uuid -> virtual address + self._inflight_requests: dict[str, SourceSpec] = {} + self._inflight_request_vaddr: dict[str, int] = {} + + # Finished requests + # uuid -> tuple[SourceSpec, virtual address] + self._finished_requests: dict[str, tuple[SourceSpec, int]] = {} + + # source zmq id -> peer name + self._remote_agents: dict[str, str] = {} + + self._nixl_wrapper, \ + self._reg_dlist, \ + self._local_xfer_dlist, \ + self._local_xfer_handlers = \ + init_nixl_agent(self._buffer_size, self._buffer_ptr, + nixl_page_size) + + # Add handshake listener thread + self._handshake_listener_t: Optional[threading.Thread] = None + self._stop_listener = threading.Event() + + # Msg decoder + self._msg_decoder = msgspec.msgpack.Decoder(NixlProtocolMsg) + + def _process_msgs(self): + """Process the received messages from the NIXL agent.""" + notifs = self._nixl_wrapper.get_new_notifs() + for remote_agent_name in notifs: + for msg in notifs[remote_agent_name]: + # Decode the message + obj = self._msg_decoder.decode(msg) + if obj.msg_type == "REQMSG": + # Add the request to the pending allocation + self._pending_allocation[obj.req_uuid] = ( + obj.source_spec, remote_agent_name) + elif obj.msg_type == "FINISHMSG": + # Add the request to the finished requests + if 
obj.req_uuid in self._inflight_requests: + source_spec = self._inflight_requests.pop(obj.req_uuid) + vaddr = self._inflight_request_vaddr.pop(obj.req_uuid) + self._finished_requests[obj.req_uuid] = (source_spec, + vaddr) + else: + logger.error( + "Request %s not found in inflight requests", + obj.req_uuid) + else: + logger.error("Unexpected message type: %s", obj.msg_type) + continue + + def _process_allocation_requests(self): + """Process the allocation requests and allocate the buffers.""" + allocated_requests = [] + for req_uuid, (source_spec, peer_name) in \ + self._pending_allocation.items(): + # Try to allocate the buffer + requested_size = source_spec.get_size() + if requested_size > self._buffer_size: + raise RuntimeError( + f"Requested size {requested_size} is larger than the " + f"nixl receiver buffer size {self._buffer_size}") + + vaddr, buffer = self._allocator.allocate(requested_size) + if vaddr == -1: + #logger.debug("No space available for request %s", req_uuid) + # No space available, skip all the requests + + # NOTE: an alternative is to try allocation for other requests + # and then come back to this one, but this may create + # starvation + logger.info("No space available for request %s, skipping", + req_uuid) + break + + # Add the request to the inflight requests + self._inflight_requests[req_uuid] = source_spec + self._inflight_request_vaddr[req_uuid] = vaddr + + # Send back the ready message + paddr = self._allocator.virtual_to_physical(vaddr) + ready_msg = make_receive_ready_msg(req_uuid, paddr) + self._nixl_wrapper.send_notif(peer_name, ready_msg) + + # Add the request to the allocated requests + allocated_requests.append(req_uuid) + + # Remove the allocated requests from the pending allocation + for req_uuid in allocated_requests: + del self._pending_allocation[req_uuid] + + def progress(self) -> None: + """Process the received requests and the data + """ + self._process_msgs() + self._process_allocation_requests() + + def get_finished(self, clear=False) -> list[tuple[SourceSpec, int]]: + """Get the requests that finishes receiving. + + Args: + clear (bool): Whether to clear the finished requests or not. + + Returns: + list[tuple[SourceSpec, int]]: A list of tuples containing the + source spec and the address. + """ + ret = [(source_spec, vaddr) + for source_spec, vaddr in self._finished_requests.values()] + if clear: + self._finished_requests.clear() + return ret + + def start_handshake_listener(self, host: str, base_port: int) -> None: + """Start the background thread that listens for handshake requests. + + Args: + host (str): Host address to listen on + base_port (int): Base port number to listen on + """ + ready_event = threading.Event() + self._handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=(host, base_port, ready_event), + daemon=True, + name="nixl_cpu_handshake_listener") + self._handshake_listener_t.start() + ready_event.wait() + + def _nixl_handshake_listener(self, host: str, base_port: int, + ready_event: threading.Event) -> None: + """Background thread that listens for and responds to handshake + requests. 
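+
+        The listener binds a ZMQ ROUTER socket on
+        tcp://<host>:<base_port + tp_rank> and answers two message types:
+        a b"get_xfer_descs" query, which returns the serialized local
+        transfer descriptors, and an agent-metadata handshake, which is
+        answered with this receiver's own NIXL agent metadata.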
+ + Args: + host (str): Host address to listen on + base_port (int): Base port number to listen on + ready_event (threading.Event): Event to signal when listener is + ready + """ + # Prepare metadata + local_meta = self._nixl_wrapper.get_agent_metadata() + + # Setup ZMQ socket + port = base_port + get_tensor_model_parallel_rank() + path = make_zmq_path("tcp", host, port) + logger.debug("Starting handshake listener on path: %s", path) + + with zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + + while not self._stop_listener.is_set(): + try: + identity, _, msg = sock.recv_multipart(flags=zmq.NOBLOCK) + + if msg == b"get_xfer_descs": + # Send back the local xfer descs + s_local_xfer_descs = self._nixl_wrapper.\ + get_serialized_descs(self._local_xfer_dlist) + sock.send_multipart( + [identity, b"", s_local_xfer_descs]) + logger.debug("Sent back the local xfer descs to %s", + identity) + else: + # Send the agent metadata + remote_agent_name = self._nixl_wrapper.add_remote_agent( + msg) + self._remote_agents[identity] = remote_agent_name + logger.debug("Successfully received handshake from %s", + identity) + # Send back the local metadata to the sender + sock.send_multipart([identity, b"", local_meta]) + logger.debug("Sent local metadata back to %s", + identity) + + except zmq.error.Again: + # No message available + time.sleep(0.1) + except Exception as e: + logger.error("Error in handshake listener: %s", e) + break + logger.debug("Stopping handshake listener") + + def stop_handshake_listener(self) -> None: + """Stop the handshake listener thread.""" + if self._handshake_listener_t is not None: + self._stop_listener.set() + self._handshake_listener_t.join() + self._handshake_listener_t = None + + def close(self): + logger.info( + "Watermark information before closing: (low: %d, high: %d)", + self._allocator.low_watermark, self._allocator.high_watermark) + self.stop_handshake_listener() + if hasattr(self, "_nixl_wrapper"): + self._nixl_wrapper.deregister_memory(self._reg_dlist) + del self._nixl_wrapper + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) + finally: + if ctx is not None: + ctx.destroy(linger=0) + + +@dataclass +class NixlSendTask(SendTask): + """NixlSendTask is a send task that uses CPU memory for the buffer and + Nixl for sending. + """ + # Required fields + # virtual address of the src buffer + buffer_vaddr: int + # Parent nixl sender + parent_sender: NixlCPUSender + # nixl request uuid + request_uuid: str + + # Optional fields that will be updated later + # Cuda event for h2d copy + cuda_event: Optional[torch.cuda.Event] = None + # Destination physical address + receiver_paddr: Optional[int] = None + # nixl transfer handle + transfer_handle: Optional[nixl_xfer_handle] = None + + def __post_init__(self) -> None: + self.creation_time = time.time() + + def update_states(self) -> None: + """Update the states of the send task. 
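+
+        Three conditions are checked: the CUDA event recorded for the
+        staging copy (marks sender_ready), the READYMSG from the receiver
+        (marks receiver_ready and records the destination physical
+        address), and the state of the NIXL transfer handle (marks
+        send_done once the WRITE has completed).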
+ """ + # Check the cuda event + if not self.state.sender_ready and self.cuda_event is not None \ + and self.cuda_event.query(): + self.state.sender_ready = True + + # check if the send is ready + if not self.state.receiver_ready and self.receiver_paddr is None: + rname, rpaddr = self.parent_sender.check_and_remove_prepared_send( + self.request_uuid) + if rname is not None: + assert rpaddr != -1 + self.receiver_paddr = rpaddr + self.state.receiver_ready = True + + if not self.is_done() and self.transfer_handle is not None \ + and self.parent_sender.is_send_finished(self.transfer_handle): + self.state.send_done = True + + +class NixlPrefillManager(KVSenderInterface): + """NixlSendTask is an implementation of KVSenderInterface that provides a + ring buffer allocator for managing pin memory allocation and deallocation, + with NIXL for sending data. + """ + + def __init__(self, buffer_size: int) -> None: + super().__init__() + nixl_page_size = DEFAULT_NIXL_PAGE_SIZE + self._buffer_size = buffer_size + self._allocator = RingBufferAllocator(self._buffer_size, + nixl_page_size) + self._nixl_sender = NixlCPUSender(buffer_size, + self._allocator.get_buffer_ptr(), + nixl_page_size) + + def create_send_task( + self, + source_spec: SourceSpec, + destination_spec: DestinationSpec, + ) -> SendTask: + """Create a non-ready send task with a CPU buffer allocated. + + Args: + source_spec (SourceSpec): The source specification of the send + task. + destination_spec (DestinationSpec): The destination + specification of the send task. + """ + # Allocate a buffer for the send task + size = source_spec.get_size() + address, buffer = self._allocator.allocate(size) + while address == -1: + # If allocation fails, wait for a while to process + # and try again + time.sleep(0.001) + self.progress() + address, buffer = self._allocator.allocate(size) + assert buffer is not None, "Buffer allocation failed" + + # Prepare the send request in NixlSender + req_uuid = self._nixl_sender.prepare_send(source_spec, + destination_spec) + + # Create a send task with the allocated buffer + task = NixlSendTask(buffer=buffer, + source_spec=source_spec, + destination_spec=destination_spec, + state=SendTaskState(), + buffer_vaddr=address, + parent_sender=self._nixl_sender, + request_uuid=req_uuid) + self.add_send_task(task) + return task + + def free_task(self, task: SendTask) -> None: + """Free the send task. + Will be called in the pre-implemented progress() method. + + Args: + task (SendTask): The send task to be freed. + """ + assert isinstance(task, NixlSendTask), \ + "Task is not a NixlSendTask" + # Free the buffer in the ring buffer allocator + self._allocator.free(task.buffer_vaddr) + + def send_task(self, task: SendTask) -> None: + """Send the send task after it is ready. + Will be called in the pre-implemented progress() method. + + Args: + task (SendTask): The send task to be sent. + """ + assert isinstance(task, NixlSendTask), \ + "Task is not a NixlSendTask" + assert task.receiver_paddr is not None, \ + "Receiver physical address is not set in the task" + handle = self._nixl_sender.send( + self._allocator.virtual_to_physical(task.buffer_vaddr), + task.receiver_paddr, task.source_spec.get_size(), + task.request_uuid, task.destination_spec) + task.transfer_handle = handle + task.mark_sending() + return + + def pre_progress_hook(self) -> None: + for task in self.get_send_tasks(): + task.update_states() + + def post_progress_hook(self) -> None: + pass + + def wait_for_all_tasks(self) -> None: + """Wait for all tasks to finish. 
Mainly for debug, test, + and offline inferences. + """ + # Wait for all tasks to finish + tasks = self.get_send_tasks() + while tasks: + self.progress() + time.sleep(1) + tasks = self.get_send_tasks() + logger.info("Still waiting for %d tasks to finish", len(tasks)) + + def close(self): + self.wait_for_all_tasks() + self._nixl_sender.close() + + +class NixlDecodeManager: + + def __init__(self, buffer_size: int, host: str, port: int) -> None: + self.nixl_page_size = DEFAULT_NIXL_PAGE_SIZE + self._buffer_size = buffer_size + self._allocator = RingBufferAllocator(self._buffer_size, + self.nixl_page_size) + self._nixl_receiver = NixlCPUReceiver(self._allocator, + self.nixl_page_size) + self._nixl_receiver.start_handshake_listener(host, port) + + # How many tokens are received for each request, each layer + # (p_request_id, layer_id) -> num_tokens + self._received_tokens: dict[str, dict[int, int]] = {} + + # How many tokens are expected for each request + # p_request_id -> num_tokens + self._expected_tokens: dict[str, int] = {} + + # The detailed specs of the requests + # (p_request_id, layer_id) -> (SourceSpec, vaddr) + self._request_specs: dict[tuple[str, int], list[tuple[SourceSpec, + int]]] = {} + + # Metadata + self.rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + + # Multi process receiving check + # p_request_id -> number of ready workers + self._done_receiving_count: defaultdict[str, + int] = defaultdict(lambda: 0) + + # Already 'ready' request, we don't want to check and return it + # again. + self._already_ready_requests: set[str] = set() + + def _check_receive_and_update(self): + """Checks the KV cache receiving status and update the internal + states + """ + finished_list = self._nixl_receiver.get_finished(clear=True) + for source_spec, vaddr in finished_list: + # Get the request id and layer id + p_request_id = source_spec.request_id + layer_id = source_spec.layer_id + num_received_tokens = source_spec.stop - source_spec.start + + if p_request_id not in self._expected_tokens: + self._expected_tokens[ + p_request_id] = source_spec.num_all_tokens + + # Update the received tokens + if p_request_id not in self._received_tokens: + self._received_tokens[p_request_id] = {} + if layer_id not in self._received_tokens[p_request_id]: + self._received_tokens[p_request_id][layer_id] = 0 + self._received_tokens[p_request_id][ + layer_id] += num_received_tokens + + # Update received specs + if (p_request_id, layer_id) not in self._request_specs: + self._request_specs[(p_request_id, layer_id)] = [] + self._request_specs[(p_request_id, layer_id)].append( + (source_spec, vaddr)) + + def progress(self) -> None: + """Process the received requests and the data. Updates the internal + status and respond to the allocation requests. + """ + self._nixl_receiver.progress() + + def get_finished(self, num_expected_layers: int) -> list[str]: + """Get the prefill node request_ids of the requests that finishes + receiving (which means the KV caches of all tokens and all layers + are in CPU memory). + + By default, if a request's id will only be returned once. However, + the caller can call `remove_ready_request` to force the get_finished + to return the request id again in the next call. + + Returns: + list[str]: A list of prefill-side request ids. 
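+
+        Example (an illustrative decode-side sketch; manager, NUM_LAYERS
+        and layer_id are placeholders)::
+
+            manager.progress()
+            for p_req_id in manager.get_finished(NUM_LAYERS):
+                specs = manager.get_kv_specs(p_req_id, layer_id)
+                ...  # copy the received CPU KV into the local KV cache
+                manager.free_request(p_req_id)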
+ """ + ready_requests = [] + self._check_receive_and_update() + for p_request_id in self._expected_tokens: + if p_request_id in self._already_ready_requests: + # Already checked and ready, skip it + continue + + expected_tokens = self._expected_tokens[p_request_id] + assert p_request_id in self._received_tokens + # check if all the layers are there + if len(self._received_tokens[p_request_id]) != num_expected_layers: + continue + # check if all the tokens are there + ready = True + for layer_id in self._received_tokens[p_request_id]: + received_tokens = self._received_tokens[p_request_id][layer_id] + if received_tokens != expected_tokens: + ready = False + break + if ready: + ready_requests.append(p_request_id) + self._already_ready_requests.add(p_request_id) + + if self.world_size == 1: + return ready_requests + + # For multi-process + if self.rank == 0: + for p_request_id in ready_requests: + self._done_receiving_count[p_request_id] += 1 + + other_ranks_finished_ids: list[str] = [] + for i in range(1, self.world_size): + other_ranks_finished_ids.extend( + self.tp_group.recv_object(src=i)) + for p_request_id in other_ranks_finished_ids: + self._done_receiving_count[p_request_id] += 1 + + all_done_recving: list[str] = [] + for p_request_id in self._done_receiving_count: + if self._done_receiving_count[p_request_id] == \ + self.world_size: + all_done_recving.append(p_request_id) + + # Clear the done receiving count for the requests that are done + for p_request_id in all_done_recving: + self._done_receiving_count.pop(p_request_id) + return all_done_recving + else: + self.tp_group.send_object(ready_requests, dst=0) + return ready_requests + + def remove_ready_request(self, p_request_id: str) -> None: + """Remove the request from the 'ready' request list so that + it will be checked again in the next of get_finished. + + Args: + p_request_id (str): The prefill-side request id. + """ + self._already_ready_requests.discard(p_request_id) + + def _create_decoder_kv_spec(self, source_spec: SourceSpec, + vaddr: int) -> DecoderKVSpec: + """Create a DecoderKVSpec from the source spec and the virtual address. + """ + # Get the correct buffer + return DecoderKVSpec(start=source_spec.start, + stop=source_spec.stop, + buffer=self._allocator.view_as_tensor( + vaddr, source_spec.dtype, + source_spec.tensor_shape)) + + def get_kv_specs(self, p_request_id: str, + layer_id: int) -> list[DecoderKVSpec]: + """Get the KV specs for the given request id and layer id, which + will be used for connector to load the KV back to CPU + + Args: + p_request_id (str): The original request id from prefiller. + layer_id (int): The layer id of the request. + """ + ret: list[DecoderKVSpec] = [] + if (p_request_id, layer_id) not in self._request_specs: + logger.warning("Request %s not found in request specs", + (p_request_id, layer_id)) + return ret + + for source_spec, vaddr in self._request_specs[(p_request_id, + layer_id)]: + # Create the decoder kv spec + decoder_kv_spec = self._create_decoder_kv_spec(source_spec, vaddr) + ret.append(decoder_kv_spec) + + return ret + + def free_request(self, p_request_id): + """Free the request's memory with the given request id. + + Args: + p_request_id (str): The original request id from prefiller. 
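+
+        Freeing releases every ring-buffer region recorded for the request
+        (one per received layer chunk), clears the per-request bookkeeping,
+        and removes the id from the internal 'ready' set so stale state
+        does not linger after the request is finished.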
+ """ + # Free the memory and clear the internal states + self._expected_tokens.pop(p_request_id, None) + rcv_tokens = self._received_tokens.pop(p_request_id, None) + if rcv_tokens is not None: + for layer_id in rcv_tokens: + assert (p_request_id, layer_id) in self._request_specs, \ + "Found received tokens but no request specs" + + # Free the memory + for src_spec, vaddr in self._request_specs[(p_request_id, + layer_id)]: + self._allocator.free(vaddr) + + # Clear the request specs + self._request_specs.pop((p_request_id, layer_id), None) + + else: + logger.warning("Request %s not found in received tokens", + p_request_id) + + self.remove_ready_request(p_request_id) + + def close(self): + self._nixl_receiver.close() diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 60f1d5d8bca7..0773690964b2 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import atexit from typing import TYPE_CHECKING, Optional from vllm import envs @@ -63,6 +64,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: if envs.VLLM_USE_V1: _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( config=vllm_config, role=KVConnectorRole.WORKER) + atexit.register(_KV_CONNECTOR_AGENT.close, ) else: _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( rank=get_world_group().rank,