
Commit cecf45f

lfr-0531 authored and litaotju committed
[https://nvbugs/5481385][fix] Fix max_seq_len in cuda graph warmup and intermediate_size in fused_moe_deepgemm (NVIDIA#7345)
Signed-off-by: Fanrong Li <[email protected]>
Co-authored-by: Tao Li @ NVIDIA <[email protected]>
1 parent 6aa6dcb commit cecf45f

2 files changed: +12 −4 lines changed


tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 4 additions & 4 deletions
@@ -411,7 +411,7 @@ def __init__(
 
     def get_workspace(self, m_max: int, group_size: int):
         hidden_size = self.hidden_size
-        intermediate_size = self.intermediate_size
+        intermediate_size = self.intermediate_size_per_partition
         num_experts = self.expert_size_per_partition
 
         # create workspace
@@ -564,7 +564,7 @@ def forward_chunk(
         # grouped gemm 1
         h1 = set_strides(workspace["workspace_1"],
                          self.expert_size_per_partition, m_max,
-                         self.intermediate_size * 2)
+                         self.intermediate_size_per_partition * 2)
 
         deepgemm_fp8_group_blockwise_gemm(
             d=h1,
@@ -579,9 +579,9 @@ def forward_chunk(
         # activation and quantization
         act_input_fp8 = set_strides(workspace["workspace_0"],
                                     self.expert_size_per_partition, m_max,
-                                    self.intermediate_size)
+                                    self.intermediate_size_per_partition)
 
-        scale_k = fp8_utils.ceil_div(self.intermediate_size, 128)
+        scale_k = fp8_utils.ceil_div(self.intermediate_size_per_partition, 128)
         scale_k_padded = fp8_utils.align(scale_k, 4)
         act_input_sf = set_strides(workspace["workspace_sf"],
                                    self.expert_size_per_partition,
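
Why the per-partition size matters here: with tensor parallelism, each rank holds only a 1/TP slice of the MoE FFN intermediate dimension, so sizing the DeepGEMM workspaces with the full intermediate_size over-allocates memory and gives the per-expert workspace views the wrong strides. The sketch below illustrates only the sizing arithmetic, under assumed names (tp_size, num_local_experts, moe_workspace_elems) that are not part of the TensorRT-LLM code.

# Sketch: per-rank MoE workspace sizing under tensor parallelism.
# Illustrative only; not the fused_moe_deepgemm implementation.

def intermediate_size_per_partition(intermediate_size: int, tp_size: int) -> int:
    # Each tensor-parallel rank owns a 1/tp_size slice of the FFN
    # intermediate dimension.
    assert intermediate_size % tp_size == 0
    return intermediate_size // tp_size


def moe_workspace_elems(m_max: int, num_local_experts: int,
                        intermediate_size: int, tp_size: int) -> dict:
    inter_per_rank = intermediate_size_per_partition(intermediate_size, tp_size)
    return {
        # gemm1 output; the factor 2 mirrors the "* 2" in the diff above,
        # since gemm1 produces both the gate and up projections.
        "workspace_1": num_local_experts * m_max * inter_per_rank * 2,
        # activation output that feeds gemm2
        "workspace_0": num_local_experts * m_max * inter_per_rank,
    }


# With TP=4 and a (hypothetical) full intermediate_size of 14336, sizing the
# buffers with the full value instead of the per-rank slice would make every
# buffer, and every per-expert stride, 4x too large.
print(moe_workspace_elems(m_max=1024, num_local_experts=8,
                          intermediate_size=14336, tp_size=4))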

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 8 additions & 0 deletions
@@ -583,7 +583,15 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
 
             # Add one dummy request with the maximum possible sequence length.
            # The sequence length is limited by both the max_seq_len and the number of available blocks.
+            # Also, the sequence length is limited by the max_position_embeddings.
             token_num = max(1, min(available_tokens, self.max_seq_len - 1))
+            model_config = self.model.model_config.pretrained_config
+            max_position_embeddings = getattr(model_config,
+                                              'max_position_embeddings',
+                                              None)
+            if max_position_embeddings is not None:
+                token_num = min(token_num,
+                                max_position_embeddings - draft_len)
             max_seq_len_request = kv_cache_manager.add_dummy_requests(
                 request_ids=[batch_size - 1],
                 token_nums=[token_num],
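
The change above caps the warmup dummy request at the model's declared positional range, leaving room for draft_len speculative tokens, so CUDA graph warmup never builds a sequence longer than max_position_embeddings. Below is a standalone sketch of just that clamping logic; the function name and the _Cfg config object are illustrative stand-ins, not the real model_engine or pretrained_config classes.

# Sketch of the token-count clamp used for the warmup dummy request.
# Illustrative only; names do not match the TensorRT-LLM implementation.

def warmup_token_num(available_tokens: int, max_seq_len: int,
                     draft_len: int, pretrained_config: object) -> int:
    # Usual bound: free KV-cache capacity and max_seq_len - 1.
    token_num = max(1, min(available_tokens, max_seq_len - 1))

    # Additionally respect the model's positional range, if the config
    # declares one, reserving space for draft_len speculative tokens.
    max_position_embeddings = getattr(pretrained_config,
                                      'max_position_embeddings', None)
    if max_position_embeddings is not None:
        token_num = min(token_num, max_position_embeddings - draft_len)
    return token_num


class _Cfg:
    # Hypothetical model config with a 4096-token positional range.
    max_position_embeddings = 4096


# Without the clamp, the dummy request would be 8191 tokens long and exceed
# the 4096-position range; with it, the warmup length is 4096 - 4 = 4092.
assert warmup_token_num(available_tokens=100000, max_seq_len=8192,
                        draft_len=4, pretrained_config=_Cfg()) == 4092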
