diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index 8c06481de4..0dfe01f358 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -117,19 +117,6 @@ def xpu_pre_process(
         xpu_forward_meta.total_enc_len,
     ) = get_infer_param(seq_lens_encoder, seq_lens_decoder)
 
-    # Adjust batch
-    # print(f"=========================adjust_batch before update=========================")
-    # print(f"ids_remove_padding : {ids_remove_padding}")
-    # print(f"cum_offsets : {cum_offsets}")
-    # print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}")
-    # print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}")
-    # print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}")
-    # print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}")
-    # print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}")
-    # print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}")
-    # print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}")
-    # print(f"xpu_forward_meta.dec_batch : {xpu_forward_meta.decoder_batch_map}")
-
     adjusted_input = adjust_batch(
         ids_remove_padding.reshape([-1, 1]),
         cum_offsets,
@@ -144,16 +131,6 @@ def xpu_pre_process(
         None,  # output_padding_offset
         -1,  # max_input_length
     )
-    # print(f"=========================adjust_batch after update=========================")
-    # print(f"ids_remove_padding : {ids_remove_padding}")
-    # print(f"cum_offsets : {cum_offsets}")
-    # print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}")
-    # print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}")
-    # print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}")
-    # print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}")
-    # print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}")
-    # print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}")
-    # print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}")
 
     adjusted_input = adjusted_input.squeeze(1)
 
@@ -228,21 +205,6 @@ def xpu_post_process(
 
     with paddle.framework._no_check_dy2st_diff():
         if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output:
-            # print(f"============================================update_inputs_v1 before update=========================================")
-            # print(f"model_output.stop_flags : {model_output.stop_flags}")
-            # print(f"model_output.not_need_stop : {model_output.not_need_stop}")
-            # print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}")
-            # print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}")
-            # print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}")
-            # print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}")
-            # print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}")
-            # print(f"sampled_token_ids : {sampled_token_ids}")
-            # print(f"model_output.input_ids : {model_output.input_ids}")
-            # print(f"model_output.stop_nums : {model_output.stop_nums}")
-            # print(f"model_output.next_tokens : {model_output.next_tokens}")
-            # print(f"model_output.is_block_step : {model_output.is_block_step}")
-            # print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}")
-            # print(f"block_size : {block_size}")
             update_inputs_v1(
                 model_output.stop_flags,
                 model_output.not_need_stop,
@@ -259,21 +221,7 @@ def xpu_post_process(
                 model_output.is_block_step,
                 block_size,
             )
-            # print(f"============================================update_inputs_v1 after update=========================================")
-            # print(f"model_output.stop_flags : {model_output.stop_flags}")
-            # print(f"model_output.not_need_stop : {model_output.not_need_stop}")
-            # print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}")
-            # print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}")
-            # print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}")
-            # print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}")
-            # print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}")
-            # print(f"sampled_token_ids : {sampled_token_ids}")
-            # print(f"model_output.input_ids : {model_output.input_ids}")
-            # print(f"model_output.stop_nums : {model_output.stop_nums}")
-            # print(f"model_output.next_tokens : {model_output.next_tokens}")
-            # print(f"model_output.is_block_step : {model_output.is_block_step}")
-            # print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}")
-            # print(f"block_size : {block_size}")
+
         else:
             update_inputs(
                 model_output.stop_flags,
@@ -383,15 +331,18 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
 
         req_len = len(req_dicts)
         has_prefill_task = False
+        has_decode_task = False
         for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
-                logger.debug(f"Handle prefill request {request} at idx {idx}")
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
                 input_ids = request.prompt_token_ids + request.output_token_ids
+                logger.debug(
+                    f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
+                )
                 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                     input_ids[prefill_start_index:prefill_end_index]
                 )
@@ -401,6 +352,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                 self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                     request.block_tables, dtype="int32"
                 )
+                if self.share_inputs["is_block_step"][idx]:  # has tasks to continue to decode
+                    has_decode_task = True
                 self.share_inputs["stop_flags"][idx : idx + 1] = False
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
                 self.seq_lens_this_time_buffer[idx : idx + 1] = length
@@ -459,7 +412,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
                 request.get("stop_token_ids"), dtype="int64"
             )
-        if has_prefill_task:
+        if has_prefill_task or has_decode_task:
             self.share_inputs["not_need_stop"][0] = True
         self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]