63 changes: 8 additions & 55 deletions fastdeploy/worker/xpu_model_runner.py
@@ -117,19 +117,6 @@ def xpu_pre_process(
xpu_forward_meta.total_enc_len,
) = get_infer_param(seq_lens_encoder, seq_lens_decoder)

# Adjust batch
# print(f"=========================adjust_batch 更新前=========================")
# print(f"ids_remove_padding : {ids_remove_padding}")
# print(f"cum_offsets : {cum_offsets}")
# print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}")
# print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}")
# print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}")
# print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}")
# print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}")
# print(f"xpu_forward_meta.dec_batch : {xpu_forward_meta.decoder_batch_map}")

adjusted_input = adjust_batch(
ids_remove_padding.reshape([-1, 1]),
cum_offsets,
@@ -144,16 +131,6 @@ def xpu_pre_process(
None, # output_padding_offset
-1, # max_input_length
)
# print(f"=========================adjust_batch 更新后=========================")
# print(f"ids_remove_padding : {ids_remove_padding}")
# print(f"cum_offsets : {cum_offsets}")
# print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}")
# print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}")
# print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}")
# print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}")
# print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}")

adjusted_input = adjusted_input.squeeze(1)

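Context for reviewers: judging by its arguments (the padding-free id buffer, cum_offsets, and the encoder/decoder LoD and batch-index maps from get_infer_param), adjust_batch appears to regroup the flattened token buffer so prefill and decode tokens are contiguous for the XPU kernels. A toy sketch of that regrouping idea — toy_adjust_batch and all of its arguments are illustrative assumptions, not the real operator:

import numpy as np

def toy_adjust_batch(ids, seq_lens, is_prefill):
    # ids: 1-D token buffer with padding already removed
    # seq_lens[i]: tokens contributed by request i this step
    # is_prefill[i]: True while request i is in its encoder/prefill phase
    starts = np.concatenate([[0], np.cumsum(seq_lens)[:-1]])
    order = [np.arange(s, s + n)
             for phase in (True, False)
             for s, n, p in zip(starts, seq_lens, is_prefill)
             if p == phase]
    return ids[np.concatenate(order)]

ids = np.array([10, 11, 12, 20, 30, 31, 40])  # req0:3 req1:1 req2:2 req3:1
print(toy_adjust_batch(ids, [3, 1, 2, 1], [True, False, True, False]))
# -> [10 11 12 30 31 20 40]  prefill tokens first, then decode tokens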
@@ -228,21 +205,6 @@ def xpu_post_process(
with paddle.framework._no_check_dy2st_diff():
if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output:

# print(f"============================================update_inputs_v1 更新前=========================================")
# print(f"model_output.stop_flags : {model_output.stop_flags}")
# print(f"model_output.not_need_stop : {model_output.not_need_stop}")
# print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}")
# print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}")
# print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}")
# print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}")
# print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}")
# print(f"sampled_token_ids : {sampled_token_ids}")
# print(f"model_output.input_ids : {model_output.input_ids}")
# print(f"model_output.stop_nums : {model_output.stop_nums}")
# print(f"model_output.next_tokens : {model_output.next_tokens}")
# print(f"model_output.is_block_step : {model_output.is_block_step}")
# print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}")
# print(f"block_size : {block_size}")
update_inputs_v1(
model_output.stop_flags,
model_output.not_need_stop,
@@ -259,21 +221,7 @@ def xpu_post_process(
model_output.is_block_step,
block_size,
)
# print(f"============================================update_inputs_v1 更新后=========================================")
# print(f"model_output.stop_flags : {model_output.stop_flags}")
# print(f"model_output.not_need_stop : {model_output.not_need_stop}")
# print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}")
# print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}")
# print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}")
# print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}")
# print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}")
# print(f"sampled_token_ids : {sampled_token_ids}")
# print(f"model_output.input_ids : {model_output.input_ids}")
# print(f"model_output.stop_nums : {model_output.stop_nums}")
# print(f"model_output.next_tokens : {model_output.next_tokens}")
# print(f"model_output.is_block_step : {model_output.is_block_step}")
# print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}")
# print(f"block_size : {block_size}")

else:
update_inputs(
model_output.stop_flags,
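For orientation: both branches perform the same kind of in-place bookkeeping after sampling; the argument lists suggest they flag finished requests, advance decoder lengths, and recompute the global not_need_stop flag. A rough, hypothetical sketch of that step, assuming EOS-based stopping (none of these names are the FastDeploy signatures):

import numpy as np

EOS_TOKEN = 2  # assumed eos id, for illustration only

def toy_update_inputs(state, sampled_token_ids):
    # state: dict of numpy arrays mirroring the buffers passed above
    # 1. a request stops once it samples EOS (the real ops also honor
    #    stop_nums and stop sequences)
    state["stop_flags"] |= sampled_token_ids == EOS_TOKEN
    active = ~state["stop_flags"]
    # 2. surviving requests advance: decode length grows by the tokens
    #    processed this step, and the next step feeds one new token
    state["seq_lens_decoder"][active] += state["seq_lens_this_time"][active]
    state["seq_lens_this_time"][active] = 1
    state["next_tokens"][active] = sampled_token_ids[active]
    # 3. the engine keeps stepping while any request is still running
    state["not_need_stop"][0] = bool(active.any())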
@@ -383,15 +331,18 @@ def insert_tasks_v1(self, req_dicts: List[Request]):

req_len = len(req_dicts)
has_prefill_task = False
has_decode_task = False
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
if request.task_type.value == RequestType.PREFILL.value: # prefill task
logger.debug(f"Handle prefill request {request} at idx {idx}")
prefill_start_index = request.prefill_start_index
prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index
input_ids = request.prompt_token_ids + request.output_token_ids
logger.debug(
f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
)
self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
input_ids[prefill_start_index:prefill_end_index]
)
@@ -401,6 +352,8 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
if self.share_inputs["is_block_step"][idx]: # has tasks to continue to decode
has_decode_task = True
self.share_inputs["stop_flags"][idx : idx + 1] = False
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
@@ -460,7 +413,7 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
request.get("stop_token_ids"), dtype="int64"
)
if has_prefill_task:
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True

def process_prefill_inputs(self, req_dicts: List[Request]):