diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index fe687da49290..33e38eab32de 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -310,39 +310,39 @@ async def benchmark( else: raise ValueError(f"Unknown backend: {backend}") - print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = input_requests[0] - test_input = RequestFuncInput( - model=model_id, - prompt=test_prompt, - api_url=api_url, - prompt_len=test_prompt_len, - output_len=test_output_len, - best_of=best_of, - use_beam_search=use_beam_search, - ) - test_output = await request_func(request_func_input=test_input) - if not test_output.success: - raise ValueError( - "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") - else: - print("Initial test run completed. Starting main benchmark run...") - - if profile: - print("Starting profiler...") - profile_input = RequestFuncInput( - model=model_id, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - best_of=best_of, - use_beam_search=use_beam_search, - ) - profile_output = await request_func(request_func_input=profile_input) - if profile_output.success: - print("Profiler started") + #print("Starting initial single prompt test run...") + #test_prompt, test_prompt_len, test_output_len = input_requests[0] + #test_input = RequestFuncInput( + # model=model_id, + # prompt=test_prompt, + # api_url=api_url, + # prompt_len=test_prompt_len, + # output_len=test_output_len, + # best_of=best_of, + # use_beam_search=use_beam_search, + #) + #test_output = await request_func(request_func_input=test_input) + #if not test_output.success: + # raise ValueError( + # "Initial test run failed - Please make sure benchmark arguments " + # f"are correctly specified. Error: {test_output.error}") + #else: + # print("Initial test run completed. 
Starting main benchmark run...") + + #if profile: + # print("Starting profiler...") + # profile_input = RequestFuncInput( + # model=model_id, + # prompt=test_prompt, + # api_url=base_url + "/start_profile", + # prompt_len=test_prompt_len, + # output_len=test_output_len, + # best_of=best_of, + # use_beam_search=use_beam_search, + # ) + # profile_output = await request_func(request_func_input=profile_input) + # if profile_output.success: + # print("Profiler started") print(f"Traffic request rate: {request_rate}") @@ -367,20 +367,20 @@ async def benchmark( pbar=pbar))) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) - if profile: - print("Stopping profiler...") - profile_input = RequestFuncInput( - model=model_id, - prompt=test_prompt, - api_url=base_url + "/stop_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - best_of=best_of, - use_beam_search=use_beam_search, - ) - profile_output = await request_func(request_func_input=profile_input) - if profile_output.success: - print("Profiler stopped") + #if profile: + # print("Stopping profiler...") + # profile_input = RequestFuncInput( + # model=model_id, + # prompt=test_prompt, + # api_url=base_url + "/stop_profile", + # prompt_len=test_prompt_len, + # output_len=test_output_len, + # best_of=best_of, + # use_beam_search=use_beam_search, + # ) + # profile_output = await request_func(request_func_input=profile_input) + # if profile_output.success: + # print("Profiler stopped") if pbar is not None: pbar.close() diff --git a/csrc/ops.h b/csrc/ops.h index 6bf0cff23252..585d0029e033 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -54,10 +54,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables); +void advance_step( + int64_t num_prefill_tokens, int64_t num_prefills, int64_t num_seqs, + int64_t num_queries, int64_t block_size, int64_t num_prefills_with_sampling, + torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, torch::Tensor& input_positions, + torch::Tensor& seq_lens, torch::Tensor& slot_mapping, + torch::Tensor& block_tables, torch::Tensor& seq_start_loc, + c10::optional context_lens, + c10::optional const& prefill_steps_tokens, + c10::optional const& prefill_steps_slot_mapping, + c10::optional const& prefill_input_positions_update, + c10::optional const& prefill_seq_start_loc_update, + c10::optional const& prefill_advance_query, + c10::optional const& prefill_advance_tokens, + c10::optional const& prefill_token_chunk_sizes); #ifndef USE_ROCM torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 0e537ddd6c4c..84239380bb33 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -10,51 +10,120 @@ namespace prepare_inputs { -// -template -__global__ void advance_step_kernel(int num_seqs, int num_queries, - int block_size, long* input_tokens_ptr, - long const* sampled_token_ids_ptr, - long* input_positions_ptr, - int* seq_lens_ptr, long* slot_mapping_ptr, - int const* block_tables_ptr, - int64_t const block_tables_stride) { - int num_query_blocks = div_ceil(num_queries, num_threads); - - if (blockIdx.x >= num_query_blocks) { - 
return; - } +__device__ void update_decode( + int const cur_query_id, int const token_idx, int const sampled_token_idx, + long* input_tokens_ptr, long* input_positions_ptr, int* seq_lens_ptr, + long* slot_mapping_ptr, int const* block_tables_ptr, + long const* sampled_token_ids_ptr, int64_t const block_tables_stride, + int const block_size, int* context_lens_ptr) { + // Update input_tokens + input_tokens_ptr[token_idx] = sampled_token_ids_ptr[sampled_token_idx]; - int cur_query_id = blockIdx.x * num_threads + threadIdx.x; + int const seq_len = seq_lens_ptr[cur_query_id]; + int const next_seq_len = seq_len + 1; + int const next_input_pos = seq_len; - if (cur_query_id >= num_queries) { - return; + if (context_lens_ptr) { + context_lens_ptr[cur_query_id] += 1; } - // Update input_tokens - input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; - - int seq_len = seq_lens_ptr[cur_query_id]; - int next_seq_len = seq_len + 1; - int next_input_pos = next_seq_len - 1; - // Update seq_lens seq_lens_ptr[cur_query_id] = next_seq_len; // Update input_positions - input_positions_ptr[cur_query_id] = next_input_pos; + input_positions_ptr[token_idx] = next_input_pos; int const* seq_block_tables_ptr = block_tables_ptr + block_tables_stride * cur_query_id; - int block_index = next_input_pos / block_size; - int block_offset = next_input_pos % block_size; + int const block_index = next_input_pos / block_size; + int const block_offset = next_input_pos % block_size; + // TODO (varun) : CHeck if we can reuse this logic for filling prefill slot + // mapping instead of passing it as an input int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset; // Update slot_mapping - slot_mapping_ptr[cur_query_id] = slot_num; + slot_mapping_ptr[token_idx] = slot_num; +} + +__device__ void update_prefill(int const cur_query_id, int* seq_lens_ptr, + int* context_lens_ptr, + int const* prefill_token_chunk_sizes) { + seq_lens_ptr[cur_query_id] += prefill_token_chunk_sizes[cur_query_id]; + context_lens_ptr[cur_query_id] += prefill_token_chunk_sizes[cur_query_id]; +} + +template +__global__ void advance_step_kernel( + int num_prefill_tokens, int num_prefills, int num_seqs, int num_queries, + int num_prefills_with_sampling, + int block_size, long* input_tokens_ptr, long const* sampled_token_ids_ptr, + long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, + int const* block_tables_ptr, int64_t const block_tables_stride, + int* seq_start_loc, int* context_lens_ptr = nullptr, + // TODO (varun) - Rename the following as ptrs + long const* prefill_steps_tokens = nullptr, + long const* prefill_steps_slot_mapping = nullptr, + long const* prefill_input_positions_update = nullptr, + int const* prefill_seq_start_loc_update = nullptr, + bool const* prefill_advance_query = nullptr, + bool const* prefill_advance_tokens = nullptr, + int const* prefill_token_chunk_sizes = nullptr) { + // TODO : USE PREFILL ADVANCE QUERY AND TOKENS + // copy prefills + if (num_prefill_tokens > 0 && blockIdx.x == 0) { + // Update prefill input tokens and slot mapping + for (int i = threadIdx.x; i < num_prefill_tokens; i += blockDim.x) { + if (prefill_advance_tokens[i]) { + input_tokens_ptr[i] = prefill_steps_tokens[i]; + slot_mapping_ptr[i] = prefill_steps_slot_mapping[i]; + input_positions_ptr[i] += prefill_input_positions_update[i]; + } + } + } + + int num_query_blocks = div_ceil(num_queries, num_threads); + if (blockIdx.x >= num_query_blocks) { + return; + } + + int cur_query_id = blockIdx.x * num_threads + 
threadIdx.x; + if (cur_query_id >= num_queries) { + return; + } + + // seq stsart loc update is for all seqs + if (seq_start_loc && prefill_seq_start_loc_update) { + seq_start_loc[cur_query_id + 1] += + prefill_seq_start_loc_update[cur_query_id + 1]; + } + + bool const is_prefill_query_id = cur_query_id < num_prefills; + + if (is_prefill_query_id) { + // prefill update + // Note that + // - input tokens + // - input positions and, + // - slot mapping are already updated. + if (prefill_advance_query[cur_query_id]) { + update_prefill(cur_query_id, seq_lens_ptr, context_lens_ptr, + prefill_token_chunk_sizes); + } + } else { + // decode update + int const decode_token_idx = + (cur_query_id - num_prefills) + num_prefill_tokens; + // TODO (fix this ) + int const sampled_token_idx = num_prefills_with_sampling + (cur_query_id - num_prefills); + update_decode( + cur_query_id, decode_token_idx, sampled_token_idx, input_tokens_ptr, + input_positions_ptr, seq_lens_ptr, slot_mapping_ptr, block_tables_ptr, + sampled_token_ids_ptr, block_tables_stride, block_size, + context_lens_ptr); + } } -inline void verify_tensor(std::string const& name, torch::Tensor& t, +inline void verify_tensor(std::string const& name, torch::Tensor const& t, int64_t const size_0, int64_t const size_1, c10::ScalarType const type) { bool size_0_cond = true; @@ -79,28 +148,126 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t, } } -void advance_step(int num_seqs, int num_queries, int block_size, - torch::Tensor& input_tokens, // type: long - torch::Tensor& sampled_token_ids, // type: long - torch::Tensor& input_positions, // type: long - torch::Tensor& seq_lens, // type: int - torch::Tensor& slot_mapping, // type: long - torch::Tensor& block_tables) { // type: int +inline void verify_tensor_ge(std::string const& name, torch::Tensor const& t, + int64_t const size_0, int64_t const size_1, + c10::ScalarType const type) { + bool size_0_cond = true; + if (size_0 != -1) { + size_0_cond = t.size(0) >= size_0; + } + + bool size_1_cond = true; + if (size_1 != -1) { + size_1_cond = t.size(1) >= size_1; + } + + bool is_contiguous = t.is_contiguous(); + bool same_type = t.dtype() == type; + + bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; + if (!pass) { + TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), + " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), + " is not as expected: shape = [", size_0, ", ", size_1, + "], type = ", type); + } +} + +void advance_step( + int const num_prefill_tokens, int const num_prefills, int const num_seqs, + int const num_queries, int const block_size, + int const num_prefills_with_sampling, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables, // type: int + torch::Tensor& seq_start_loc, // type: int + c10::optional context_lens, // type: int + c10::optional const& prefill_steps_tokens, // type long + c10::optional const& + prefill_steps_slot_mapping, // type long + c10::optional const& + prefill_input_positions_update, // type long + c10::optional const& prefill_seq_start_loc_update, + c10::optional const& + prefill_advance_query, // type int8 + c10::optional const& + prefill_advance_tokens, // typei int8 + c10::optional const& + prefill_token_chunk_sizes) { // type int if (logging) { printf("advance_step:\n"); + printf(" 
num_prefill_tokens = %d\n", num_prefill_tokens); + printf(" num_prefills = %d\n", num_prefills); printf(" num_seqs = %d\n", num_seqs); printf(" num_queries = %d\n", num_queries); + printf(" num_prefills_with_sampling = %d\n", num_prefills_with_sampling); printf(" block_size = %d\n", block_size); } + + if (num_prefills > 0) { + TORCH_CHECK(num_prefill_tokens > 0); + TORCH_CHECK(context_lens.has_value()); + TORCH_CHECK(prefill_steps_tokens.has_value()); + TORCH_CHECK(prefill_steps_slot_mapping.has_value()); + TORCH_CHECK(prefill_input_positions_update.has_value()); + TORCH_CHECK(prefill_seq_start_loc_update.has_value()); + TORCH_CHECK(prefill_advance_query.has_value()) + TORCH_CHECK(prefill_advance_tokens.has_value()) + TORCH_CHECK(prefill_token_chunk_sizes.has_value()); + } else { + TORCH_CHECK(num_prefill_tokens == 0); + // TORCH_CHECK(!context_lens.has_value()); + TORCH_CHECK(!prefill_steps_tokens.has_value()); + TORCH_CHECK(!prefill_steps_slot_mapping.has_value()); + TORCH_CHECK(!prefill_input_positions_update.has_value()); + TORCH_CHECK(!prefill_seq_start_loc_update.has_value()); + TORCH_CHECK(!prefill_advance_query.has_value()) + TORCH_CHECK(!prefill_advance_tokens.has_value()) + TORCH_CHECK(!prefill_token_chunk_sizes.has_value()); + } + + int const num_decode_tokens = num_seqs - num_prefills; + int const num_decodes = num_queries - num_prefills; + int const expected_num_input_tokens = num_prefill_tokens + num_decode_tokens; + // Verify all tensors - verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); - verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, + verify_tensor("input_tokens", input_tokens, expected_num_input_tokens, -1, at::kLong); - verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("sampled_token_ids", sampled_token_ids, + num_decodes + num_prefills_with_sampling, 1, at::kLong); + verify_tensor("input_positions", input_positions, expected_num_input_tokens, + -1, at::kLong); verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); - verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); + verify_tensor("slot_mapping", slot_mapping, expected_num_input_tokens, -1, + at::kLong); verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); + if (num_prefills > 0) { + verify_tensor("seq_start_loc", seq_start_loc, num_seqs + 1, -1, at::kInt); + verify_tensor("context_lens", context_lens.value(), num_seqs, -1, at::kInt); + verify_tensor_ge("prefill_steps_tokens", prefill_steps_tokens.value(), + num_prefill_tokens, -1, at::kLong); + verify_tensor_ge("prefill_steps_slot_mapping", + prefill_steps_slot_mapping.value(), num_prefill_tokens, -1, + at::kLong); + verify_tensor("prefill_input_positions_update", + prefill_input_positions_update.value(), num_prefill_tokens, + -1, at::kLong); + // TODO (varun) : This should probably be long ? 
+ verify_tensor("prefill_seq_start_loc_update", + prefill_seq_start_loc_update.value(), num_seqs + 1, -1, + at::kInt); + verify_tensor("prefill_advance_query", + prefill_advance_query.value(), num_prefills, -1, at::kChar); + verify_tensor("prefill_advance_tokens", + prefill_advance_tokens.value(), num_prefill_tokens, -1, at::kChar); + verify_tensor("prefill_token_chunk_sizes", + prefill_token_chunk_sizes.value(), num_prefills, -1, + at::kInt); + } int dev = sampled_token_ids.get_device(); cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); @@ -109,23 +276,69 @@ void advance_step(int num_seqs, int num_queries, int block_size, cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); advance_step_kernel<<>>( - num_seqs, num_queries, block_size, + num_prefill_tokens, num_prefills, num_seqs, num_queries, num_prefills_with_sampling, block_size, reinterpret_cast(input_tokens.data_ptr()), reinterpret_cast(sampled_token_ids.data_ptr()), reinterpret_cast(input_positions.data_ptr()), reinterpret_cast(seq_lens.data_ptr()), reinterpret_cast(slot_mapping.data_ptr()), reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0)); + block_tables.stride(0), reinterpret_cast(seq_start_loc.data_ptr()), + context_lens.has_value() + ? reinterpret_cast(context_lens->data_ptr()) + : nullptr, + prefill_steps_tokens.has_value() + ? reinterpret_cast(prefill_steps_tokens->data_ptr()) + : nullptr, + prefill_steps_slot_mapping.has_value() + ? reinterpret_cast( + prefill_steps_slot_mapping->data_ptr()) + : nullptr, + prefill_input_positions_update.has_value() + ? reinterpret_cast( + prefill_input_positions_update->data_ptr()) + : nullptr, + prefill_seq_start_loc_update.has_value() + ? reinterpret_cast( + prefill_seq_start_loc_update->data_ptr()) + : nullptr, + prefill_advance_query.has_value() + ? reinterpret_cast( + prefill_advance_query->data_ptr()) + : nullptr, + prefill_advance_tokens.has_value() + ? reinterpret_cast( + prefill_advance_tokens->data_ptr()) + : nullptr, + prefill_token_chunk_sizes.has_value() + ? 
reinterpret_cast(prefill_token_chunk_sizes->data_ptr()) + : nullptr); } } // namespace prepare_inputs -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables) { - prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, - sampled_token_ids, input_positions, seq_lens, - slot_mapping, block_tables); +void advance_step( + int64_t num_prefill_tokens, int64_t num_prefills, int64_t num_seqs, + int64_t num_queries, int64_t block_size, int64_t num_prefills_with_sampling, + torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, torch::Tensor& input_positions, + torch::Tensor& seq_lens, torch::Tensor& slot_mapping, + torch::Tensor& block_tables, + // TODO varun : make this a reference + torch::Tensor& seq_start_loc, c10::optional context_lens, + c10::optional const& prefill_steps_tokens, + c10::optional const& prefill_steps_slot_mapping, + c10::optional const& prefill_input_positions_update, + c10::optional const& prefill_seq_start_loc_update, + c10::optional const& prefill_advance_query, + c10::optional const& prefill_advance_tokens, + c10::optional const& prefill_token_chunk_sizes) { + prepare_inputs::advance_step( + num_prefill_tokens, num_prefills, num_seqs, num_queries, block_size, + num_prefills_with_sampling, + input_tokens, sampled_token_ids, input_positions, seq_lens, slot_mapping, + block_tables, seq_start_loc, context_lens, prefill_steps_tokens, + prefill_steps_slot_mapping, prefill_input_positions_update, + prefill_seq_start_loc_update, prefill_advance_query, prefill_advance_tokens, + prefill_token_chunk_sizes); } \ No newline at end of file diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh index f21574681b1a..81914e5a3a30 100644 --- a/csrc/prepare_inputs/advance_step.cuh +++ b/csrc/prepare_inputs/advance_step.cuh @@ -12,7 +12,7 @@ namespace prepare_inputs { static constexpr int max_threads = 256; -static constexpr bool logging = false; +static constexpr bool logging = false; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } diff --git a/tests/multi_step/test_correctness.py b/tests/multi_step/test_correctness.py index bc14311c6642..3ef927fd7326 100644 --- a/tests/multi_step/test_correctness.py +++ b/tests/multi_step/test_correctness.py @@ -1,16 +1,41 @@ # Test the AsyncLLMEngine with multi-step-decoding +import os +from collections import namedtuple +from enum import Enum from typing import List import pytest from ..utils import RemoteOpenAIServer + +class MultiStepChunkedPrefillPolicy(Enum): + # When prompt and decode sequences are scheduled together, + # the DEFAULT policy is to run the prompt and decodes sequences + # together only for the first step and run just the decode sequences + # in the rest of the steps. + DEFAULT = 1 + # In FORCE_SINGLE_STEP policy, we force the scheduled sequences to + # run a single step and then re-schedule. 
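+    # Illustrative contrast of the two policies described above (hypothetical
+    # batch): with --num-scheduler-steps 8 and a batch mixing prefills and
+    # decodes, DEFAULT runs the mixed batch for the first step and only the
+    # decode sequences for the remaining 7 steps, whereas FORCE_SINGLE_STEP
+    # runs a single step and returns to the scheduler.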
+ FORCE_SINGLE_STEP = 2 + INVALID = 3 + + +ChunkedPrefillTestArgType = namedtuple('ChunkedPrefillTestArgType', + ['enabled', 'policy']) + MODELS = [ "JackFram/llama-160m", ] NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps NUM_PROMPTS = [10] +CHUNKED_PREFILL_ARGS = [ + ChunkedPrefillTestArgType(False, MultiStepChunkedPrefillPolicy.INVALID), + ChunkedPrefillTestArgType(True, MultiStepChunkedPrefillPolicy.DEFAULT), + ChunkedPrefillTestArgType(True, + MultiStepChunkedPrefillPolicy.FORCE_SINGLE_STEP) +] DEFAULT_SERVER_ARGS: List[str] = [ "--disable-log-requests", @@ -23,17 +48,36 @@ ] -async def completions_with_server_args(prompts: List[str], model_name: str, - server_cli_args: List[str]): +class EnvContextManager(): + + def __init__(self, env: dict): + self.os_env = dict(os.environ) + self.add_env = dict(env) + + def __enter__(self): + os.environ.update(self.add_env) + + def __exit__(self, *args, **kwargs): + os.environ.clear() + os.environ.update(self.os_env) + + +async def completions_with_server_args(prompts: List[str], + model_name: str, + server_cli_args: List[str], + with_env: dict = {}): # noqa: B006 + # env setup + os.environ.update(with_env) outputs = None - with RemoteOpenAIServer(model_name, server_cli_args) as server: - client = server.get_async_client() - outputs = await client.completions.create(model=model_name, - prompt=prompts, - temperature=0, - stream=False, - max_tokens=5) + with EnvContextManager(with_env) as _: # noqa: SIM117 + with RemoteOpenAIServer(model_name, server_cli_args) as server: + client = server.get_async_client() + outputs = await client.completions.create(model=model_name, + prompt=prompts, + temperature=0, + stream=False, + max_tokens=5) assert outputs is not None return outputs @@ -47,10 +91,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str, @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("chunked_prefill", CHUNKED_PREFILL_ARGS) @pytest.mark.asyncio async def test_multi_step(example_prompts, model: str, tp_size: int, - pp_size: int, eager_mode: int, - num_scheduler_steps: int, num_prompts: int): + pp_size: int, eager_mode: bool, + num_scheduler_steps: int, num_prompts: int, + chunked_prefill: ChunkedPrefillTestArgType): prompts = example_prompts if len(prompts) < num_prompts: @@ -65,6 +111,14 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, if eager_mode: ms_server_args.append("--enforce-eager") + test_env = {} + if chunked_prefill.enabled: + ms_server_args.append("--enable-chunked-prefill") + if chunked_prefill.policy == \ + MultiStepChunkedPrefillPolicy.FORCE_SINGLE_STEP: + test_env[ + 'VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY'] = '1' + distributed_args = [ "--tensor-parallel-size", str(tp_size), @@ -75,7 +129,7 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ref_completions = await completions_with_server_args( prompts, model, server_args + distributed_args) test_completions = await completions_with_server_args( - prompts, model, ms_server_args + distributed_args) + prompts, model, ms_server_args + distributed_args, test_env) def get_text_generations(completions): return [x.text for x in completions.choices] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b89a90ef0f70..b16b6cefb328 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -160,16 +160,34 @@ def fused_add_rms_norm(input: 
torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) -def advance_step(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, + #prefill_input_positions_update = model_input.prefill_input_positions_update, + #prefill_seq_start_loc_update = model_input.prefill_seq_start_loc_update) +def advance_step(num_prefill_tokens: int, num_prefills: int, num_seqs: int, + num_queries: int, block_size: int, + num_prefills_with_sampling: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, input_positions: torch.Tensor, seq_lens: torch.Tensor, - slot_mapping: torch.Tensor, - block_tables: torch.Tensor) -> None: + slot_mapping: torch.Tensor, block_tables: torch.Tensor, + seq_start_loc: torch.Tensor, + context_lens: Optional[torch.Tensor], + prefill_steps_tokens: Optional[torch.Tensor], + prefill_steps_slot_mapping: Optional[torch.Tensor], + prefill_input_positions_update: Optional[torch.Tensor], + prefill_seq_start_loc_update: Optional[torch.Tensor], + prefill_advance_query: Optional[torch.Tensor], + prefill_advance_tokens: Optional[torch.Tensor], + prefill_token_chunk_sizes: Optional[torch.Tensor]) -> None: """Advance a step on GPU for existing inputs for a multi-step runner""" - return torch.ops._C.advance_step(num_seqs, num_queries, block_size, - input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, - block_tables) + return torch.ops._C.advance_step( + num_prefill_tokens, num_prefills, num_seqs, num_queries, block_size, + num_prefills_with_sampling, + input_tokens, sampled_token_ids, input_positions, seq_lens, + slot_mapping, block_tables, seq_start_loc, context_lens, + prefill_steps_tokens, prefill_steps_slot_mapping, + prefill_input_positions_update, prefill_seq_start_loc_update, + prefill_advance_query, prefill_advance_tokens, + prefill_token_chunk_sizes) # quantization ops diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 30ce715d5d05..632ff964f35e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,4 +1,5 @@ """Attention layer with FlashAttention.""" +import dataclasses from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type @@ -246,8 +247,8 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.num_prefills == 0: return None - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata + #if self._cached_prefill_metadata is not None: + # return self._cached_prefill_metadata assert self.seq_lens is not None assert self.seq_lens_tensor is not None @@ -256,6 +257,8 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: assert self.block_tables is not None assert self.seq_start_loc is not None + assert self.query_start_loc.shape[0] >= self.num_prefills + 1 + self._cached_prefill_metadata = FlashAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, @@ -300,9 +303,65 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, ) + return self._cached_decode_metadata - def advance_step(self, num_seqs: int, num_queries: int): + # TODO (varun) : Try using decode_metadata here. We hit some asserts in + # advance_step - but that seems resolvable. 
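+    # Worked example of the index rebasing done in without_prefills below
+    # (hypothetical values): with num_prefills == 2 and
+    # query_start_loc == [0, 3, 7, 8, 9], the decode-only view keeps
+    # [7, 8, 9] and subtracts its first element to get [0, 1, 2], since the
+    # prefill tokens are dropped from the front of every token-indexed tensor.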
+ @staticmethod + def without_prefills(m: "FlashAttentionMetadata") \ + -> "FlashAttentionMetadata": + """ + Extract all information related to decodes from the given attention + metadata. + """ + + num_prefills = m.num_prefills + num_prefill_tokens = m.num_prefill_tokens + if num_prefills == 0: + # Simply return a copy + return dataclasses.replace(m) + + # Slice into GPU tensors to remove prefill related information + query_start_loc = None + seq_start_loc = None + if m.query_start_loc is not None and m.seq_start_loc is not None: + query_start_loc = m.query_start_loc[num_prefills:] + seq_start_loc = m.seq_start_loc[num_prefills:] + # query_start_loc and seq_start_loc store indices for + # indexing into some other tensor. As we are removing + # all the prefill related information from all the tensors, + # the decode information would now start from 0. Therefore, + # offset the indices in query_start_loc and seq_start_loc + query_start_loc = query_start_loc - query_start_loc[0] + seq_start_loc = seq_start_loc - seq_start_loc[0] + + # All the other tensors can be sliced in-place + return FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=m.num_decode_tokens, + slot_mapping=m.slot_mapping[num_prefill_tokens:], + seq_lens=m.seq_lens[num_prefills:] + if m.seq_lens is not None else None, + seq_lens_tensor=m.seq_lens_tensor[num_prefills:] + if m.seq_lens_tensor is not None else None, + max_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=m.max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=m.context_lens_tensor[num_prefills:] + if m.context_lens_tensor is not None else None, + block_tables=m.block_tables[num_prefills:] + if m.block_tables is not None else None, + use_cuda_graph=False) + + def advance_step(self, + num_seqs: int, + num_queries: int, + prefill_token_chunk_sizes: Optional[List[int]] = None, + prefill_do_samples: Optional[List[int]] = None): """ Update metadata in-place to advance one decode step. 
""" @@ -317,18 +376,10 @@ def advance_step(self, num_seqs: int, num_queries: int): assert num_seqs > num_queries assert self.use_cuda_graph - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - assert self.seq_lens is not None assert len(self.seq_lens) == num_seqs assert self.seq_lens_tensor is not None assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - assert self.max_decode_seq_len == max(self.seq_lens) assert self.query_start_loc is not None assert self.query_start_loc.shape == (num_queries + 1, ) @@ -341,11 +392,50 @@ def advance_step(self, num_seqs: int, num_queries: int): assert self.block_tables is not None assert self.block_tables.shape[0] == num_seqs + has_prefills: bool = self.num_prefills > 0 + has_decodes: bool = num_seqs - self.num_prefills > 0 + + if has_prefills: + assert self.slot_mapping.shape == (num_seqs - self.num_prefills + + self.num_prefill_tokens, ) + assert prefill_token_chunk_sizes is not None + assert sum(prefill_token_chunk_sizes) == self.num_prefill_tokens + assert len(prefill_token_chunk_sizes) == self.num_prefills + assert len(prefill_do_samples) == self.num_prefills + assert self.num_decode_tokens == num_seqs - self.num_prefills + else: + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + # Update query lengths. Note that we update only queries and not seqs, # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): + if has_prefills: + assert prefill_token_chunk_sizes is not None + for idx, tcs in enumerate(prefill_token_chunk_sizes): + if not prefill_do_samples[idx]: + # do_sample prefills end in first step + self.seq_lens[idx] += tcs + for i in range(self.num_prefills, num_queries): self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) + + self.max_decode_seq_len = max( + self.seq_lens[self.num_prefills:]) if has_decodes else 0 + self.max_prefill_seq_len = max( + self.seq_lens[:self.num_prefills]) if has_prefills else 0 + + if has_prefills: + assert prefill_token_chunk_sizes is not None + self.max_query_len = max(prefill_token_chunk_sizes) + else: + self.max_query_len = 1 + + # Trigger recompute + self._cached_decode_metadata = None + self._cached_prefill_metadata = None class FlashAttentionMetadataBuilder( @@ -687,6 +777,7 @@ def forward( else: # prefix-enabled attention assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) output[: num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa diff --git a/vllm/config.py b/vllm/config.py index 4cbdde5e113a..6a1e0b437dc0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -922,6 +922,10 @@ def __init__(self, self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps self.send_delta_data = send_delta_data + # get from env + # TODO (varun) : Is there a better way ? 
+ self.multi_step_chunked_prefill_max_token_chunk: int = \ + envs.VLLM_MULTI_STEP_CHUNKED_PREFILL_MAX_TOKEN_CHUNK self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 3b716e32032c..b2cb087472e8 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +import vllm.envs as envs from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger @@ -983,7 +984,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: [s.seq_group for s in swapped_in.prefill_seq_groups]) # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) - return SchedulerOutputs( + + scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + running_scheduled.prefill_seq_groups + swapped_in.prefill_seq_groups + @@ -1005,6 +1007,95 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: len(running_scheduled.swapped_out)), ) + if self.scheduler_config.is_multi_step: + # multi step scheduler outputs + scheduler_outputs = self._make_multi_step_scheduler_outputs( + scheduler_outputs) + + return scheduler_outputs + + def _make_multi_step_scheduler_outputs( + self, scheduler_outputs: SchedulerOutputs): + """ + Given a set of scheduler outputs, determine if the prefill sequences + in the list have more chunks that could be processed in the multi-step. + """ + assert self.scheduler_config.is_multi_step + # dont support beam search + assert all([ + len(sg.seq_group.seqs) == 1 + for sg in scheduler_outputs.scheduled_seq_groups + ]) + + assert all([ sg.token_chunk_size <= self.scheduler_config.multi_step_chunked_prefill_max_token_chunk \ + for sg in scheduler_outputs.scheduled_seq_groups]) + + # If it is all decode dont do anything. multi-step is all-set !! + if scheduler_outputs.num_prefill_groups == 0: + return scheduler_outputs + + if envs.VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY: + has_prefills = len([ + None for sg in scheduler_outputs.scheduled_seq_groups + if sg.seq_group.is_prefill() + ]) + if has_prefills: + for sg in scheduler_outputs.scheduled_seq_groups: + sg.seq_group.init_multi_step(1) + return scheduler_outputs + + ## Default policy + def get_max_prefill_chunk_steps( + prefill_seq_groups: Iterable[ScheduledSequenceGroup]): + max_steps = self.scheduler_config.num_scheduler_steps + prefill_num_uncomputed = [ + psg.seq_group.get_num_uncomputed_tokens() + for psg in prefill_seq_groups + ] + token_chunk_sizes = [ + psg.token_chunk_size for psg in prefill_seq_groups + + ] + + import math + steps = [ + int(math.ceil(float(x) / float(y))) + for x, y in zip(prefill_num_uncomputed, token_chunk_sizes) + ] + # ignore the last chunk as it would require output sampling. 
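+            # Worked example (hypothetical numbers): with max_steps = 8,
+            # prefill_num_uncomputed = [20, 6] and token_chunk_sizes = [4, 2],
+            # steps = [5, 3]; dropping the sampling chunk below gives [4, 2],
+            # clamping to max_steps keeps [4, 2], and the result is
+            # min([4, 2]) == 2 steps that all scheduled prefills can advance
+            # without needing to sample.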
+ steps = [max(1, x - 1) for x in steps] + steps = [min(max_steps, x) for x in steps] + max_prefill_chunk_steps = min(steps) + return max_prefill_chunk_steps + + prefill_seq_groups: List[ScheduledSequenceGroup] = [ + sg for sg in scheduler_outputs.scheduled_seq_groups + if sg.seq_group.is_prefill() + ] + max_prefill_steps = get_max_prefill_chunk_steps( + prefill_seq_groups) + assert max_prefill_steps >= 1 + assert max_prefill_steps <= self.scheduler_config.num_scheduler_steps + + prefill_num_uncomputed_chunk = [ + (psg.seq_group.get_num_uncomputed_tokens(), psg.token_chunk_size) + for psg in prefill_seq_groups + ] + #prefill_num_uncomputed_chunk = sorted(prefill_num_uncomputed_chunk) + #print (f"prefill max steps {max_prefill_steps}") + #for uncomputed, chunk in prefill_num_uncomputed_chunk: + # print (f" - {uncomputed} - {chunk}") + + # update all the sequence to run only until + # max_scheduled_prefill_chunk_steps. We curtail the decodes + # intentionally so the decodes runs are not inefficient. + for sg in scheduler_outputs.scheduled_seq_groups: + #sg.seq_group.init_multi_step(1) + #sg.seq_group.init_multi_step(max_prefill_steps) + sg.seq_group.init_multi_step(self.scheduler_config.num_scheduler_steps) + + return scheduler_outputs + def _schedule(self) -> SchedulerOutputs: """Schedule queued requests.""" if self.scheduler_config.chunked_prefill_enabled: @@ -1028,7 +1119,7 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: return self.block_manager.can_append_slots( seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), + num_lookahead_slots=self._get_num_lookahead_slots(is_prefill) + 1, ) def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: @@ -1158,6 +1249,28 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: else: seq_group.metrics.scheduler_time = scheduler_time + ## debug + #from vllm.sequence import SequenceStage + #print("Scheduler :: \n") + #for idx, sg in enumerate(scheduler_outputs.scheduled_seq_groups): + # assert len(sg.seq_group.seqs_dict) == 1 + # sgml = seq_group_metadata_list[idx] + # seq_id = list(sg.seq_group.seqs_dict.keys())[0] + # num_uncomputed = sg.seq_group.get_num_uncomputed_tokens() + # num_computed = sg.seq_group.seqs[0].data._num_computed_tokens + # prompt_ids = sg.seq_group.prompt_token_ids + # num_steps = sg.seq_group.state.num_steps + # stage = sg.seq_group.seqs[0].data.stage + # stage = "prefill" if stage == SequenceStage.PREFILL else "decode" + # token_chunk_size = sg.token_chunk_size + # print((f" - id {seq_id} |" + # f" steps {num_steps} |" + # f" stage {stage} |" + # f" token-chunk {token_chunk_size} |" + # f" do-sample {sgml.do_sample} ->" + # f" #prompts {len(prompt_ids)} ->" + # f" #computed {num_computed}\n")) + return seq_group_metadata_list, scheduler_outputs def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: @@ -1202,11 +1315,13 @@ def _append_slots( the new source and destination block indices for the appended slots. 
""" - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill=seq_group.is_prefill()) + seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - cows = self.block_manager.append_slots(seq, num_lookahead_slots) + cows = self.block_manager.append_slots(seq, num_lookahead_slots + 1) if len(cows) > 0: blocks_to_copy.extend(cows) @@ -1342,8 +1457,42 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, Returns 0 if the new token cannot be computed due to token budget. """ - num_new_tokens = 0 + def get_chunk_size(seqs: List[Sequence]) -> int: + """ + multistep chunkped prefill new tokens + """ + # beam search not supported + assert len(seqs) == 1 + + max_steps = self.scheduler_config.num_scheduler_steps + max_chunk_size = self.scheduler_config.multi_step_chunked_prefill_max_token_chunk + + seq = seqs[0] + new_tokens = seq.get_num_new_tokens() + if not seq.is_prefill(): + return new_tokens + if new_tokens == 1: + return new_tokens + + if new_tokens <= max_steps: + # Do the entire thing so it doesn;t curtail other sequences + return new_tokens + + import math + # we dont want to deal with the last chunk - it has sampling implications + to_chunk = new_tokens - 1 + chunk_size = to_chunk // max_steps + chunk_size = min(max_chunk_size, chunk_size) + chunk_size = max(1, chunk_size) + return chunk_size + seqs = seq_group.get_seqs(status=status) + + if self.scheduler_config.chunked_prefill_enabled and \ + self.scheduler_config.is_multi_step: + return get_chunk_size(seqs) + + num_new_tokens = 0 for seq in seqs: num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7f45c3d06375..7af249ffb9f3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -888,13 +888,9 @@ def create_engine_config(self, ) -> EngineConfig: disable_logprobs=self.disable_logprobs_during_spec_decoding, ) - if self.num_scheduler_steps > 1: - if speculative_config is not None: - raise ValueError("Speculative decoding is not supported with " - "multi-step (--num-scheduler-steps > 1)") - if self.enable_chunked_prefill: - raise ValueError("Chunked prefill is not supported with " - "multi-step (--num-scheduler-steps > 1)") + if self.num_scheduler_steps > 1 and speculative_config is not None: + raise ValueError("Speculative decoding is not supported with " + "multi-step (--num-scheduler-steps > 1)") # make sure num_lookahead_slots is set the higher value depending on # if we are using speculative decoding or multi-step diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8812b853c066..cebc1365ef53 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -35,7 +35,7 @@ from vllm.utils import print_warning_once logger = init_logger(__name__) -ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S +ENGINE_ITERATION_TIMEOUT_S = 1000000 #envs.VLLM_ENGINE_ITERATION_TIMEOUT_S class AsyncEngineDeadError(RuntimeError): @@ -300,10 +300,11 @@ async def step_async( seq_group_metadata_list, scheduler_outputs = self.scheduler[ virtual_engine].schedule() - if (self.scheduler_config.is_multi_step - and scheduler_outputs.num_lookahead_slots > 0): + remaining_steps = self._remaining_steps(seq_group_metadata_list) + if self.scheduler_config.is_multi_step and \ + remaining_steps is not None and remaining_steps > 1: # cache the 
scheduler outputs for the next iteration if we have - # lookahead slots + # one. self._cache_scheduler_outputs_for_multi_step( virtual_engine, seq_group_metadata_list, scheduler_outputs) @@ -346,7 +347,8 @@ async def step_async( # Finish the current step for all the sequence groups. if self.scheduler_config.is_multi_step: for seq_group in seq_group_metadata_list: - seq_group.finish_step() + if seq_group.state.remaining_steps > 0: + seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): # clear the cache if we have finished all the steps @@ -367,25 +369,32 @@ async def step_async( return request_outputs - def _has_remaining_steps( + def _remaining_steps( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] - ) -> bool: + ) -> Optional[int]: if (not self.scheduler_config.is_multi_step or not seq_group_metadata_list): - return False + return None - # TODO(will) this is a sanity check for nowto make sure that all the - # seqs are on the same steps. Eventually we will want to do some sort of - # dynamic scheduling when doing multi-step decoding. - ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + remaining_steps = seq_group_metadata_list[0].state.remaining_steps if any([ - seq_group.state.remaining_steps != ref_remaining_steps + seq_group.state.remaining_steps != remaining_steps for seq_group in seq_group_metadata_list[1:] ]): raise AssertionError(("All running sequence groups should " "have the same remaining steps.")) - return ref_remaining_steps > 0 + return remaining_steps + + def _has_remaining_steps( + self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> bool: + remaining_steps: Optional[int] = self._remaining_steps( + seq_group_metadata_list) + if remaining_steps is None: + return False + + return remaining_steps > 0 def _cache_scheduler_outputs_for_multi_step( self, virtual_engine: int, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f72902c37218..6e0422792307 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -38,7 +38,7 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, - SequenceStatus) + SequenceStatus, CompletionSequenceGroupOutput) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -1199,13 +1199,46 @@ def _process_model_outputs( output_by_sequence_group = create_output_by_sequence_group( output, num_seq_groups=len(scheduled_seq_groups)) + #print (f"output by seqeunce group ...\n") + #for osg in output_by_sequence_group: + # print (f"{osg}") + + # Update + if self.scheduler_config.is_multi_step: + for scheduled_seq_group, outputs, seq_group_meta in zip( + scheduled_seq_groups, output_by_sequence_group, + seq_group_metadata_list): + if not self.scheduler_config.chunked_prefill_enabled: + # multi step prefill updates + scheduled_seq_group.seq_group.update_num_computed_tokens( + scheduled_seq_group.token_chunk_size * + seq_group_meta.state.num_steps) + else: + # chunked prefill case + if seq_group_meta.is_prompt: + if seq_group_meta.do_sample: + scheduled_seq_group.seq_group.update_num_computed_tokens( + scheduled_seq_group.token_chunk_size) + # update outputs + assert all([op == outputs[0] for op in outputs]) + for idx in range(1, len(outputs)): + outputs[idx] = CompletionSequenceGroupOutput(samples=[], prompt_logprobs=None) + else: + 
scheduled_seq_group.seq_group.update_num_computed_tokens( + scheduled_seq_group.token_chunk_size * + seq_group_meta.state.num_steps) + # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs, seq_group_meta in zip( scheduled_seq_groups, output_by_sequence_group, seq_group_metadata_list): seq_group = scheduled_seq_group.seq_group - seq_group.update_num_computed_tokens( - scheduled_seq_group.token_chunk_size) + + if not self.scheduler_config.is_multi_step: + # TODO (varun) : Is this required ? + seq_group.update_num_computed_tokens( + scheduled_seq_group.token_chunk_size) + if output is not None and len(output) > 0: for o in output: if (isinstance(o, SamplerOutput) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 6c472528a7a9..3ded19c797d8 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -75,12 +75,12 @@ def process_outputs(self, sequence_group: SequenceGroup, "Beam search not supported in multi-step decoding.") seq = seqs[0] - # Since there's only one sequence per sequence group, we can take the - # first sample. - samples = [output.samples[0] for output in outputs] + # TODO (Varun) : Pass in an output_token_id of -1 instead of returning + # 0 samples. + samples = [output.samples[0] for output in outputs if output.samples] # -1 means the output token is not valid (eg. due to spec decode - # rejecting tokens). + # rejecting tokens) valid_samples = [ sample for sample in samples if sample.output_token != -1 ] @@ -124,6 +124,10 @@ def _process_seq_outputs(self, seq: Sequence, token_id=output_token_id, logprobs=output_logprob, ) + # TODO (Varun) : move this outside the loop ? + # We need it here so maybe_stop_sequence can limit on + # the sampling_params.max_token arg + seq.data.update_num_computed_tokens(1) new_char_count = 0 if sampling_params.detokenize: diff --git a/vllm/envs.py b/vllm/envs.py index 24e09ee0e055..7c55aa06b0d5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -60,6 +60,8 @@ VLLM_ALLOW_ENGINE_USE_RAY: bool = False VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY: bool = False + VLLM_MULTI_STEP_CHUNKED_PREFILL_MAX_TOKEN_CHUNK: int = 4 def get_default_cache_root(): @@ -400,6 +402,20 @@ def get_default_config_root(): "VLLM_TORCH_PROFILER_DIR": lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + + # Applicable when multi-decodes (--num-scheduler-steps) and chunked-prefill + # (--enable-chunked-prefill) are both enabled. When prefills are scheduled + # together with decode sequences, this flag forces the engine to single-step + # the model execution for all the sequences. The default behaviour is to, + # run the both the prefill and decode sequence for the first step and run + # only the decode sequences for the rest of the steps. 
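+    # The value is parsed by the lambda below: "1" or "true"
+    # (case-insensitive) enables the single-step policy, anything else keeps
+    # the default policy. tests/multi_step/test_correctness.py exercises this
+    # by exporting the variable as "1" for its FORCE_SINGLE_STEP case.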
+ "VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY": + lambda: os.environ.get( + "VLLM_MULTI_STEP_CHUNKED_PREFILL_SINGLE_STEP_POLICY", "False").lower( + ) in ("true", "1"), + "VLLM_MULTI_STEP_CHUNKED_PREFILL_MAX_TOKEN_CHUNK": + lambda: int( + os.getenv("VLLM_MULTI_STEP_CHUNKED_PREFILL_MAX_TOKEN_CHUNK", "4")), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 3ba15573c217..c49b7bdf0eb4 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -40,6 +40,7 @@ def apply(self, def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: + #print (f"input range [{input_.min()} {input_.max()}]") return F.embedding(input_, layer.weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index 2fe8ae9d7b27..e598f592f04b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1296,28 +1296,31 @@ class ExecuteModelRequest( @property def is_first_multi_step(self) -> bool: - # TODO(will) make this be able to handle batches with variable number of - # steps + # All sequences should start their lifetimes at step 0 assert len(self.seq_group_metadata_list) > 0 - first_seq_group = self.seq_group_metadata_list[0] - assert first_seq_group.state is not None - return first_seq_group.state.current_step == 0 + state = self.seq_group_metadata_list[0].state + assert state is not None + return state.current_step == 0 @property def is_last_step(self) -> bool: - # TODO(will) make this be able to handle batches with variable number of - # steps + # Assumptions: + # 1. All the prefills appear before any of the decodes + # in self.seq_group_metadata_list. + # 2. All the decode sequences have the same num_steps and that, + # the num_steps of the decode sequences >= num_steps of the + # prefill sequences. 
+ assert len(self.seq_group_metadata_list) > 0 - first_seq_group = self.seq_group_metadata_list[0] - assert first_seq_group.state is not None - return first_seq_group.state.remaining_steps == 1 + state = self.seq_group_metadata_list[-1].state + assert state is not None + return state.remaining_steps == 1 @property def current_step(self) -> int: - # TODO(will) make this be able to handle batches with variable number of - # steps + # Assumptions : refer to the comment in `is_last_step` assert len(self.seq_group_metadata_list) > 0 - state = self.seq_group_metadata_list[0].state + state = self.seq_group_metadata_list[-1].state assert state is not None return state.current_step diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 2dfbacfb7b75..9b6b2aa83528 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -94,7 +94,6 @@ def sampler_output( assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] - self._append_new_tokens( model_output, expanded_request.seq_group_metadata_list) model_outputs.append(model_output) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5d930919b8ae..e48e65b69b4c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -17,6 +17,7 @@ import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState +from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.attention.backends.utils import CommonAttentionState from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, @@ -129,6 +130,50 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): # used by the driver worker. is_prompt: Optional[bool] = None + @staticmethod + def without_prefills(m: "ModelInputForGPUWithSamplingMetadata", + sampling_metadata_decodes: SamplingMetadata) \ + -> "ModelInputForGPUWithSamplingMetadata": + + assert m.attn_metadata is not None + num_prefills = m.attn_metadata.num_prefills + num_prefill_tokens = m.attn_metadata.num_prefill_tokens + if num_prefills == 0: + assert m.sampling_metadata == sampling_metadata_decodes + return dataclasses.replace(m) + + # Prefill related data in the following datastructures are not handled. 
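+        # The assert below guards that limitation: LoRA, prompt-adapter and
+        # multi-modal inputs must be absent or empty before the prefill
+        # portion of the model input is sliced away.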
+ assert all([ + m.lora_mapping is None, m.lora_requests is None + or len(m.lora_requests) == 0, m.prompt_adapter_mapping is None, + m.prompt_adapter_requests is None + or len(m.prompt_adapter_requests) == 0, + m.multi_modal_kwargs is None or len(m.multi_modal_kwargs) == 0 + ]) + assert isinstance(m.attn_metadata, FlashAttentionMetadata) + assert (m.input_tokens is not None) + assert (m.input_positions is not None) + + return ModelInputForGPUWithSamplingMetadata( + input_tokens=m.input_tokens[num_prefill_tokens:], + input_positions=m.input_positions[num_prefill_tokens:], + seq_lens=m.seq_lens[num_prefills:] + if m.seq_lens is not None else None, + query_lens=m.query_lens[num_prefills:] + if m.query_lens is not None else None, + lora_mapping=None, + lora_requests=set(), + attn_metadata=FlashAttentionMetadata.without_prefills( + m.attn_metadata), + prompt_adapter_mapping=None, + prompt_adapter_requests=set(), + multi_modal_kwargs={}, + request_ids_to_seq_ids=m.request_ids_to_seq_ids, + finished_requests_ids=m.finished_requests_ids, + virtual_engine=m.virtual_engine, + sampling_metadata=sampling_metadata_decodes, + is_prompt=False) + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, @@ -1348,9 +1393,14 @@ def prepare_model_input( # Sampling metadata is only required for the final pp group generators = self.get_generators(finished_requests_ids) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, model_input.seq_lens, - model_input.query_lens, self.device, self.pin_memory, - generators, self.sampling_metadata_cache) + seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory, + generators, + # TODO(varun) : Fix sampling metadata cache impl. + None) else: sampling_metadata = None is_prompt = (seq_group_metadata_list[0].is_prompt diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 90c39407d726..a4821c3c88bf 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -49,18 +49,21 @@ def _init_attn_metadata_from_tensor_dict( def _init_sampling_metadata_from_tensor_dict( # type: ignore - tensor_dict: Dict[str, Any]) -> Dict[str, Any]: + tensor_dict: Dict[str, Any], + sampling_metadata_key: str = "sampling_metadata", + selected_token_ids_key: str = "selected_token_indices") -> Dict[str, + Any]: """ Helper method to initialize SamplingMetadata based on broadcastable SamplingMetadata fields. """ from vllm.model_executor import SamplingMetadata - selected_token_indices = tensor_dict.pop("selected_token_indices", None) + selected_token_indices = tensor_dict.pop(selected_token_ids_key, None) # An empty SamplingMetadata to signal that the worker should skip # sampling. if selected_token_indices is not None: - tensor_dict["sampling_metadata"] = SamplingMetadata( + tensor_dict[sampling_metadata_key] = SamplingMetadata( seq_groups=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, @@ -71,13 +74,14 @@ def _init_sampling_metadata_from_tensor_dict( # type: ignore def _add_sampling_metadata_broadcastable_dict( tensor_dict: Dict[str, Any], - sampling_metadata: Optional["SamplingMetadata"]) -> None: + sampling_metadata: Optional["SamplingMetadata"], + selected_token_ids_key: str = "selected_token_indices") -> None: """ Helper method to update tensor_dict with broadcastable SamplingMetadata fields. 
""" if sampling_metadata is not None: - tensor_dict["selected_token_indices"] = ( + tensor_dict[selected_token_ids_key] = ( sampling_metadata.selected_token_indices) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 521205eca05a..9e9a855a456d 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union try: from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -11,20 +11,23 @@ import torch from vllm import _custom_ops as ops +from vllm.attention.backends.utils import compute_slot_mapping from vllm.distributed import get_pp_group from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SamplerOutput, SequenceGroupMetadata, SequenceOutput) +from vllm.utils import async_tensor_h2d from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( - BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, + BroadcastableModelInput, _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, _init_frozen_model_input_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) -from ..model_executor.model_loader.tensorizer import TensorizerConfig - if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -52,6 +55,12 @@ class ModelOutput: sampled_token_ids: Optional[torch.Tensor] = None pythonized: bool = False + # Metadata required to run the pythonization. + # This information is passed on from the model_input during ModelOutput + # creation. Please look at the comments in StatefulModelInput. 
+ sampling_metadata: Optional[SamplingMetadata] = None + num_empty_prefill_step_outputs: Optional[int] = None + def pythonize(self, input_metadata: "StatefulModelInput", copy_stream: torch.cuda.Stream, pinned_sampled_token_buffer: torch.Tensor) -> None: @@ -85,7 +94,11 @@ def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", if blocking: self.sampler_output_ready_event.synchronize() with torch.cuda.stream(copy_stream): - _pythonize_sampler_output(input_metadata, self.sampler_output, + assert self.sampling_metadata is not None + assert self.num_empty_prefill_step_outputs is not None + _pythonize_sampler_output(self.sampling_metadata, + self.num_empty_prefill_step_outputs, + self.sampler_output, pinned_sampled_token_buffer, self.sampled_token_ids) return True @@ -112,9 +125,51 @@ class StatefulModelInput(BroadcastableModelInput): num_seqs: int = -1 num_queries: int = -1 + num_prefills_with_sampling: int = -1 + prefill_do_samples: Optional[List[bool]] = None + + # multi-step chunked prefill tokens and slot mapping + prefill_steps_tokens: Optional[torch.Tensor] = None + prefill_steps_slot_mapping: Optional[torch.Tensor] = None + prefill_input_positions_update: Optional[torch.Tensor] = None + prefill_seq_start_loc_update: Optional[torch.Tensor] = None + prefill_token_chunk_sizes_tensor: Optional[torch.Tensor] = None + prefill_advance_query_tensor: Optional[torch.Tensor] = None + prefill_advance_tokens_tensor: Optional[torch.Tensor] = None + token_chunk_sizes: Optional[List[int]] = None + + # Multi-Step + Chunked-Prefill related args. + # When the initially scheduled sequences have both prefill and decode + # sequences, the first iteration of the multi-step processes with all + # the sequences. However, further iterations only process the decode + # sequences. + # + # For example: + # Let [S1, S2, S3, S4, S5, S6] be the scheduled set of sequences. + # let {S1, S2, S3} be prefills. Assume S2 doesn't need sampling, but S1 and + # S3 does. + # let {S4, S5, S6} be decodes. All decode sequences need sampling. + # Step 1: execute_model processes all sequences and the corresponding + # pythonize_sampler_output will produce results {R1, R3, R4, R5, R6} (Rx + # is the result for the xth sequence) + # Step 2-n: execute_model only processes sequences {S4, S5, S6} and the + # corresponding pythonize_sampler_output will produce results + # {[], [], R4, R5, R6} + + # Use sampling_metadata_decodes for decode-exclusive iterations. + sampling_metadata_decodes: Optional[SamplingMetadata] = None + # When pythonizing sampler outputs for the decode-exclusive steps, + # populate the sampler output with `num_empty_prefill_step_outputs` + # empty outputs. 
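A minimal, self-contained sketch of the padding described in the comment above (illustrative only — `pad_decode_step_outputs` and its argument names are hypothetical, not part of vLLM): a decode-only step reports one empty placeholder per expected prefill output position before the decode results, so the per-step result list stays aligned with the sequence order of the first, mixed step.

    # Hypothetical helper, for illustration only.
    from typing import List

    def pad_decode_step_outputs(
            decode_results: List[List[int]],
            num_empty_prefill_step_outputs: int) -> List[List[int]]:
        # One empty entry per prefill output position, then the decode results.
        padded: List[List[int]] = [
            [] for _ in range(num_empty_prefill_step_outputs)
        ]
        padded.extend(decode_results)
        return padded

    # Example from the comment above: steps 2..n only run the decodes
    # {S4, S5, S6}, so their per-step results look like {[], [], R4, R5, R6}.
    assert pad_decode_step_outputs([[4], [5], [6]], 2) == [[], [], [4], [5], [6]]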
+ num_empty_prefill_step_outputs: int = 0 + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: assert self.frozen_model_input is not None tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() + _add_sampling_metadata_broadcastable_dict( + tensor_dict, + self.sampling_metadata_decodes, + selected_token_ids_key="selected_token_indices_decodes") new_tensor_dict = { 'last_sampled_token_ids': self.last_sampled_token_ids, 'current_step': self.current_step, @@ -123,6 +178,20 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: 'is_first_multi_step': self.is_first_multi_step, 'num_seqs': self.num_seqs, 'num_queries': self.num_queries, + 'num_prefills_with_sampling' : self.num_prefills_with_sampling, + 'prefill_do_samples': self.prefill_do_samples, + 'prefill_steps_tokens': self.prefill_steps_tokens, + 'prefill_steps_slot_mapping': self.prefill_steps_slot_mapping, + 'prefill_input_positions_update': + self.prefill_input_positions_update, + 'prefilll_seq_start_loc_update': self.prefill_seq_start_loc_update, + 'prefill_token_chunk_sizes_tensor': + self.prefill_token_chunk_sizes_tensor, + 'prefill_advance_query_tensor' : self.prefill_advance_query_tensor, + 'prefill_advance_tokens_tensor' : self.prefill_advance_tokens_tensor, + 'token_chunk_sizes': self.token_chunk_sizes, + 'num_empty_prefill_step_outputs': + self.num_empty_prefill_step_outputs, } tensor_dict.update(new_tensor_dict) return tensor_dict @@ -133,7 +202,12 @@ def from_broadcasted_tensor_dict( tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, ) -> "StatefulModelInput": + # base model runner's sampling_metadata tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + # SatefulModelInput's sampling_metadata_decodes + tensor_dict = _init_sampling_metadata_from_tensor_dict( + tensor_dict, "sampling_metadata_decodes", + "selected_token_indices_decodes") if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) @@ -165,11 +239,16 @@ def wait_previous_step(self): def add_sampler_output(self, sampler_output: SamplerOutput, sampled_token_ids: Optional[torch.Tensor] = None): + assert self.frozen_model_input is not None self.cached_outputs.append( - ModelOutput(sampler_output=sampler_output, - sampler_output_ready_event=None, - sampled_token_ids=sampled_token_ids, - pythonized=False)) + ModelOutput( + sampler_output=sampler_output, + sampler_output_ready_event=None, + sampled_token_ids=sampled_token_ids, + pythonized=False, + sampling_metadata=self.frozen_model_input.sampling_metadata, + num_empty_prefill_step_outputs=self. + num_empty_prefill_step_outputs)) # MutableModelInputForGPUWithMultiStepMetadata is not subclass of @@ -199,20 +278,213 @@ def make_model_input_from_broadcasted_tensor_dict( )) return model_input + + def make_prefill_steps_data( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + num_prefill_tokens: int) -> \ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + # for multi step chunked prefill propagation - get the, + # 1. Prefill tokens required for steps 1 to n + # 2. Prefill slot mapping required for steps 1 to n + # 3. Prefill position update required for steps 1 to n + # 4. 
seq_start_loc update tensor + + # filter prompt seqs + prefill_sgml = [ + sgml for sgml in seq_group_metadata_list if sgml.is_prompt + ] + if len(prefill_sgml) == 0: + # No prefills + return (None, None, None, None, None, None, None, []) + + num_multi_step = prefill_sgml[0].state.num_steps + assert (num_multi_step == self.scheduler_config.num_scheduler_steps) + assert all( + [sgml.state.num_steps == num_multi_step for sgml in prefill_sgml]) + if num_multi_step == 1: + # Single step - we dont need to advance prefills + return (None, None, None, None, None, None, None, []) + + # Populate input_tokens + input_tokens = [] + + #seq_lens = [sgml.seq_data[0].get_num_computed_tokens() + # for sgml in prefill_sgml] + # Assert that we only have one sequence in every seq-group. + # i.e. dont support beam search. + assert all([len(sgml.seq_data) == 1 for sgml in prefill_sgml]) + # TODO (varun) : find a better way to do this. + seq_ids = [list(sgml.seq_data.keys())[0] for sgml in prefill_sgml] + + seq_lens = [ + sgml.seq_data[seq_id].get_num_computed_tokens() + for seq_id, sgml in zip(seq_ids, prefill_sgml) + ] + + + token_chunk_sizes = [sgml.token_chunk_size for sgml in prefill_sgml] + + advance_prefill_query = [not sgml.do_sample for sgml in prefill_sgml] + advance_prefill_token = [] + for idx, aq in enumerate(advance_prefill_query): + advance_prefill_token.extend([aq] * token_chunk_sizes[idx]) + + # the 0th step is already computed + input_token_offsets = [ + sl + tcs for (sl, tcs) in zip(seq_lens, token_chunk_sizes) + ] + # because the 0th step data is already computed - start from 1 + for _ in range(1, num_multi_step): + for idx, sgml in enumerate(prefill_sgml): + seq_id = seq_ids[idx] + offt = input_token_offsets[idx] + tcs = token_chunk_sizes[idx] + do_advance = advance_prefill_query[idx] + + if do_advance: + assert len(sgml.seq_data[seq_id].get_token_ids()) >= offt + tcs + input_tokens.extend( + sgml.seq_data[seq_id].get_token_ids()[offt:offt + tcs]) + input_token_offsets[idx] += tcs + else: + input_tokens.extend([-1] * tcs) + assert len(input_tokens) == num_prefill_tokens * (num_multi_step - 1) + + # Populate slot mapping + input_slot_mapping = [] + for step_idx in range(1, num_multi_step): + for idx, sgml in enumerate(prefill_sgml): + # TODO (varun) : Is the calculation of sgml_context_len and + # sgml_seq_len valid ? + # Are there any nuances we are missing ? 
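As a sanity check on the `sgml_context_len` / `sgml_seq_len` arithmetic above, here is a standalone sketch (assumptions: a fixed `token_chunk_size` per sequence and no sliding window; `prefill_chunk_window` is a hypothetical name, not part of vLLM) of the token window a chunked prefill covers at multi-step iteration `step_idx`:

    from typing import Tuple

    def prefill_chunk_window(num_computed_tokens: int, token_chunk_size: int,
                             step_idx: int) -> Tuple[int, int]:
        # Tokens already in the KV cache before this step (step 0 is handled
        # by the normal prepare_model_input path, so step_idx starts at 1).
        context_len = num_computed_tokens + token_chunk_size * step_idx
        # Tokens in the sequence once this step's chunk has been written.
        seq_len = context_len + token_chunk_size
        return context_len, seq_len

    # A fresh prompt with chunk size 512: step 1 covers positions [512, 1024),
    # step 2 covers [1024, 1536), and so on.
    assert prefill_chunk_window(0, 512, 1) == (512, 1024)
    assert prefill_chunk_window(0, 512, 2) == (1024, 1536)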
+ seq_id = seq_ids[idx] + tcs = token_chunk_sizes[idx] + sgml_context_len = seq_lens[idx] + (tcs * step_idx) + sgml_seq_len = sgml_context_len + tcs + do_advance = advance_prefill_query[idx] + + part_slot_mapping: List[int] = [] + if do_advance: + compute_slot_mapping( + is_profile_run=False, + slot_mapping=part_slot_mapping, + seq_id=seq_id, + context_len=sgml_context_len, + seq_len=sgml_seq_len, + start_idx= + 0, # TODO (Varun) : Assert that sliding window is none + block_size=self.block_size, + block_tables=sgml.block_tables) + else: + part_slot_mapping = [-1] * tcs + input_slot_mapping.extend(part_slot_mapping) + + assert len( + input_slot_mapping) == num_prefill_tokens * (num_multi_step - 1) + + #for idx, x in enumerate(input_slot_mapping): + # if x == -1: + # part_idx = idx % num_prefill_tokens + # assert not advance_prefill_token[part_idx] + + #slots_min = min(input_slot_mapping) + #slots_max = max(input_slot_mapping) + #print (f"input slot mapping min {slots_min} max {slots_max}") + #if slots_max / 16 >= 27911: + # print ("Failure !!") + #assert slots_max / 16 < 27911 + + + # Populate position update tensor + input_positions_update = [] + for tcs in token_chunk_sizes: + input_positions_update.extend([tcs] * tcs) + assert len(input_positions_update) == num_prefill_tokens + + # populate seq_start_loc update tensor + num_seqs = len(seq_group_metadata_list) + num_prefills = len(prefill_sgml) + seq_start_loc_update = [0] * (num_seqs + 1) + for idx in range(1, num_seqs + 1): + prev_tcs = token_chunk_sizes[idx - + 1] if (idx - 1) < num_prefills else 1 + seq_start_loc_update[idx] = seq_start_loc_update[idx - + 1] + prev_tcs + + # Update to redo prefills + redo_prefill_offset = 0 if advance_prefill_query[0] else token_chunk_sizes[0] + for idx in range(1, num_seqs + 1): + seq_start_loc_update[idx] -= redo_prefill_offset + if idx < num_prefills and not advance_prefill_query[idx]: + redo_prefill_offset += token_chunk_sizes[idx] + + # Async transfer to GPU + prefill_input_tokens = async_tensor_h2d(input_tokens, torch.long, + self.device, self.pin_memory) + prefill_slot_mapping = async_tensor_h2d(input_slot_mapping, torch.long, + self.device, self.pin_memory) + prefill_input_positions_update = async_tensor_h2d( + input_positions_update, torch.long, self.device, self.pin_memory) + prefill_token_chunk_sizes_tensor = async_tensor_h2d( + token_chunk_sizes, torch.int, self.device, self.pin_memory) + prefill_seq_start_loc_update = async_tensor_h2d( + seq_start_loc_update, torch.int32, self.device, self.pin_memory) + + prefill_advance_query = async_tensor_h2d( + advance_prefill_query, torch.int8, self.device, self.pin_memory) + prefill_advance_tokens = async_tensor_h2d( + advance_prefill_token, torch.int8, self.device, self.pin_memory) + + return (prefill_input_tokens, prefill_slot_mapping, + prefill_input_positions_update, prefill_seq_start_loc_update, + prefill_token_chunk_sizes_tensor, + prefill_advance_query, prefill_advance_tokens, + token_chunk_sizes) + def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None ) -> StatefulModelInput: - frozen_model_input = self._base_model_runner.prepare_model_input( + frozen_model_input: ModelInputForGPUWithSamplingMetadata = \ + self._base_model_runner.prepare_model_input( seq_group_metadata_list, virtual_engine, finished_requests_ids) + assert frozen_model_input.attn_metadata is not None + (prefill_steps_tokens, prefill_slot_mapping, + 
prefill_input_positions_update, prefill_seq_start_loc_update, + prefill_token_chunk_sizes, prefill_advance_query, prefill_advance_tokens, token_chunk_sizes) = \ + self.make_prefill_steps_data(seq_group_metadata_list, + frozen_model_input.attn_metadata.num_prefill_tokens) + + prefills_with_sampling = 0 + for sgml in seq_group_metadata_list: + if sgml.is_prompt and sgml.do_sample: + prefills_with_sampling += 1 + + + prefill_do_samples = [sgml.do_sample for sgml in seq_group_metadata_list if sgml.is_prompt] + + assert frozen_model_input.seq_lens is not None + assert frozen_model_input.query_lens is not None model_input = StatefulModelInput( frozen_model_input=frozen_model_input, num_seqs=len(frozen_model_input.seq_lens), num_queries=len(frozen_model_input.query_lens), - ) + prefill_steps_tokens=prefill_steps_tokens, + prefill_steps_slot_mapping=prefill_slot_mapping, + prefill_input_positions_update=prefill_input_positions_update, + prefill_seq_start_loc_update=prefill_seq_start_loc_update, + prefill_token_chunk_sizes_tensor=prefill_token_chunk_sizes, + prefill_advance_query_tensor = prefill_advance_query, + prefill_advance_tokens_tensor = prefill_advance_tokens, + token_chunk_sizes=token_chunk_sizes, + num_prefills_with_sampling = prefills_with_sampling, + prefill_do_samples = prefill_do_samples, + sampling_metadata_decodes=None) return model_input @torch.inference_mode() @@ -228,13 +500,13 @@ def execute_model( metadata """ assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None + assert model_input.frozen_model_input is not None # path for warm up runs if not model_input.is_multi_step: return self._base_model_runner.execute_model( - frozen_model_input, kv_caches, intermediate_tensors, num_steps) + model_input.frozen_model_input, kv_caches, + intermediate_tensors, num_steps) # make sure we skip the sampler on the lask rank and only pythonize # if CPU is ahead. @@ -248,9 +520,9 @@ def execute_model( self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( True) - if frozen_model_input.sampling_metadata: - frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( - True) + if model_input.frozen_model_input.sampling_metadata: + model_input.frozen_model_input.sampling_metadata. \ + skip_sampler_cpu_output = (True) # some pre-execute model logic for multi-step: # - if it's the first step, we need to reset the sampling tensors @@ -258,6 +530,8 @@ def execute_model( # appended sampler output from last iteration # - also maybe pythonize if CPU is ahead of GPU + #print (f"model execute {model_input.current_step}") + current_stream = torch.cuda.current_stream() if not model_input.is_first_multi_step: # Explicitly block on the previous step's forward to make sure we @@ -268,14 +542,16 @@ def execute_model( # might clobber enqueued forwards. 
(prevents CPU from running too # far ahead if needed) model_input.wait_previous_step() + # Update prefill tokens and slot mappings model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) # Execute the model - output = self._base_model_runner.execute_model(frozen_model_input, - kv_caches, - intermediate_tensors, - num_steps=1) + output = self._base_model_runner.execute_model( + model_input.frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1) # record the event for the current step so that the next step can sync model_input.record_step_event(current_stream) @@ -285,6 +561,8 @@ def execute_model( output ) == 1, "MultiStepModelRunner requires single-step base_models" + assert model_input.frozen_model_input is not None + # event for the pythonization so that we only pythonize if the # tensors are ready. May be able to be combined with the step event output_ready_event = torch.cuda.Event() @@ -293,8 +571,14 @@ def execute_model( output[0].sampled_token_ids_cpu = output[ 0].sampled_token_ids.cpu() model_input.cached_outputs.append( - ModelOutput(output[0], output_ready_event, - output[0].sampled_token_ids, False)) + ModelOutput(output[0], + output_ready_event, + output[0].sampled_token_ids, + pythonized=False, + sampling_metadata=model_input.frozen_model_input. + sampling_metadata, + num_empty_prefill_step_outputs=model_input. + num_empty_prefill_step_outputs)) # make sure we dont try to serialize any GPU tensors output[0].sampled_token_ids = None output[0].sampled_token_probs = None @@ -326,25 +610,6 @@ def execute_model( # should be [SamplerOutput] return output - def _update_sampling_metadata(self, sampling_metadata, num_seqs, - num_queries): - - assert sampling_metadata.num_prompts == 0 - assert len(sampling_metadata.seq_groups) == num_queries - assert sampling_metadata.selected_token_indices.shape == ( - num_queries, ) - # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 - - # Verify that all sequences are decodes - for i in range(num_queries): - seq_group = sampling_metadata.seq_groups[i] - - assert seq_group.is_prompt is False # No prompt - assert seq_group.prompt_logprob_indices == [] # No prompt - assert seq_group.sample_indices == [i] # Simple - assert seq_group.seq_len is None # Decode - assert seq_group.query_len is None # Decode - def _advance_step(self, model_input: StatefulModelInput, out: SamplerOutput) -> StatefulModelInput: frozen_model_input = model_input.frozen_model_input @@ -353,25 +618,89 @@ def _advance_step(self, model_input: StatefulModelInput, num_seqs = model_input.num_seqs num_queries = model_input.num_queries + assert num_seqs > 0 assert num_queries > 0 assert num_seqs >= num_queries attn_metadata = frozen_model_input.attn_metadata assert isinstance(attn_metadata, FlashAttentionMetadata) - attn_metadata.advance_step(num_seqs, num_queries) + attn_metadata.advance_step(num_seqs, num_queries, + model_input.token_chunk_sizes, + model_input.prefill_do_samples) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_prefills = attn_metadata.num_prefills + + prefill_step_offset = \ + num_prefill_tokens * (model_input.current_step - 1) \ + if self.scheduler_config.chunked_prefill_enabled else 0 + + if model_input.prefill_steps_tokens is not None: + assert num_prefill_tokens > 0 + assert attn_metadata.context_lens_tensor is not None + assert model_input.prefill_steps_slot_mapping is not None + assert model_input.prefill_steps_tokens.shape[ + 0] >= prefill_step_offset + 
num_prefill_tokens + assert model_input.prefill_steps_slot_mapping.shape[ + 0] >= prefill_step_offset + num_prefill_tokens + if num_prefill_tokens > 0: + assert model_input.prefill_steps_tokens is not None and \ + model_input.prefill_steps_slot_mapping is not None + #tmp = model_input.prefill_steps_slot_mapping[prefill_step_offset:prefill_step_offset + num_prefill_tokens] + #print (f"advancing .... slot idx max {torch.max(tmp)}") + # Update GPU tensors + assert model_input.cached_outputs[-1].sampled_token_ids is not None ops.advance_step( + num_prefill_tokens=num_prefill_tokens, + num_prefills=num_prefills, num_seqs=num_seqs, num_queries=num_queries, block_size=self.block_size, + num_prefills_with_sampling = model_input.num_prefills_with_sampling, input_tokens=frozen_model_input.input_tokens, sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids, input_positions=frozen_model_input.input_positions, seq_lens=attn_metadata.seq_lens_tensor, slot_mapping=attn_metadata.slot_mapping, - block_tables=attn_metadata.block_tables) + block_tables=attn_metadata.block_tables, + seq_start_loc=attn_metadata.seq_start_loc, + context_lens=attn_metadata.context_lens_tensor, + prefill_steps_tokens=model_input. + prefill_steps_tokens[prefill_step_offset:] + if model_input.prefill_steps_tokens is not None else None, + prefill_steps_slot_mapping=model_input. + prefill_steps_slot_mapping[prefill_step_offset:] + if model_input.prefill_steps_slot_mapping is not None else None, + prefill_input_positions_update=model_input. + prefill_input_positions_update, + prefill_seq_start_loc_update=model_input. + prefill_seq_start_loc_update, + prefill_advance_query=model_input.prefill_advance_query_tensor, + prefill_advance_tokens = model_input.prefill_advance_tokens_tensor, + prefill_token_chunk_sizes=model_input. + prefill_token_chunk_sizes_tensor) + + + #torch.cuda.synchronize() + #has_prefills = num_prefills > 0 + #has_decodes = num_seqs > num_prefills + #if has_prefills: + # prefill_max_slot_mapping = torch.max(model_input.frozen_model_input.attn_metadata.slot_mapping.flatten()[:num_prefill_tokens]) + # assert prefill_max_slot_mapping / 16 < 27911 + #if has_decodes: + # decode_max_slot_mapping = torch.max(model_input.frozen_model_input.attn_metadata.slot_mapping.flatten()[num_prefill_tokens:]) + # if decode_max_slot_mapping / 16 >= 27911: + # torch.set_printoptions(profile="full") + # print (f"num prefills {num_prefills} ") + # print (f'all decode slot mapping {model_input.frozen_model_input.attn_metadata.slot_mapping.flatten()[num_prefill_tokens:]}') + # print (f"decode input pos {frozen_model_input.input_positions[num_prefill_tokens:]}") + # print (f"block tables {attn_metadata.block_tables} ") + # torch.set_printoptions(profile="default") + + # assert decode_max_slot_mapping / 16 < 27911 if frozen_model_input.seq_lens is not None: for i in range(num_queries): @@ -409,22 +738,22 @@ def vocab_size(self) -> int: return self._base_model_runner.vocab_size -def _pythonize_sampler_output(model_input: StatefulModelInput, - output: SamplerOutput, - pinned_sampled_token_buffer: torch.Tensor, - sampled_token_ids: torch.Tensor) -> None: +def _pythonize_sampler_output( + sampling_metadata: SamplingMetadata, + num_empty_prefill_step_outputs: int, output: SamplerOutput, + pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: torch.Tensor) -> SamplerOutput: """ This function is only called when the output tensors are ready. 
     See ModelOutput
     """
-
-    assert model_input.frozen_model_input is not None
-
-    frozen_model_input = model_input.frozen_model_input
-    assert frozen_model_input.sampling_metadata is not None
+    assert sampling_metadata is not None
     # samples generation should have been skipped
     assert not output.outputs
 
-    pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
+    # Don't use num_queries, since some of the sequences may not need
+    # sampling (e.g. chunked-prefill sequences).
+    n_sampled_token_ids = sampled_token_ids.shape[0]
+    pinned_buffer = pinned_sampled_token_buffer[:n_sampled_token_ids]
 
     # CPU GPU sync
     pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
 
@@ -432,17 +761,25 @@ def _pythonize_sampler_output(model_input: StatefulModelInput,
     # this will not block as the tensors are already on CPU
     samples_list = pinned_buffer.tolist()
 
-    sampling_metadata = frozen_model_input.sampling_metadata
+    for _ in range(num_empty_prefill_step_outputs):
+        output.outputs.append(CompletionSequenceGroupOutput([], None))
+
+    samples_it = iter(samples_list)
+    for sg_idx, seq_group in enumerate(sampling_metadata.seq_groups):
 
-    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
-                                          samples_list):
-        seq_ids = seq_group.seq_ids
-        next_token_ids = sample_result
-        parent_ids = [0]
-        seq_outputs: List[SequenceOutput] = []
         if seq_group.sampling_params.logits_processors:
             assert len(seq_group.sampling_params.logits_processors) == 0, (
                 "Logits Processors are not supported in multi-step decoding")
+
+        skip_sequence = not seq_group.do_sample
+        if skip_sequence:
+            output.outputs.append(CompletionSequenceGroupOutput([], None))
+            continue
+
+        seq_ids = seq_group.seq_ids
+        next_token_ids = next(samples_it)
+        parent_ids = [0]
+        seq_outputs: List[SequenceOutput] = []
         for parent_id, next_token_id in zip(parent_ids, next_token_ids):
             # TODO(will): support logprobs
             # Hard coded logprob
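The skip/consume pattern in the loop above reduces to the following standalone sketch (the `distribute_samples` helper is hypothetical, not the vLLM API): groups that did not sample get an empty output, and only sampling groups consume an entry from the flat list of sampled token ids.

    from typing import List

    def distribute_samples(do_sample_flags: List[bool],
                           samples_list: List[List[int]]) -> List[List[int]]:
        samples_it = iter(samples_list)
        results: List[List[int]] = []
        for do_sample in do_sample_flags:
            # Non-sampling groups (e.g. chunked prefills that are not at their
            # last chunk) get an empty output and do not consume a sample.
            results.append(next(samples_it) if do_sample else [])
        return results

    # Three groups where the middle one skips sampling:
    assert distribute_samples([True, False, True], [[7], [9]]) == [[7], [], [9]]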