Commit a4d3136

[TRTLLM-6445] feat: Enable AllReduce-associated fusion patterns in Llama3/4.
Signed-off-by: Yukun He <[email protected]>
1 parent 48ddc3d commit a4d3136
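
The whole change hangs on one switch: eager AllReduce fusion is enabled by default and is turned off by setting TRTLLM_LLAMA_EAGER_FUSION_DISABLED=1, and a layer only opts in when it actually has a tensor-parallel group and attention data parallelism is not in use. A minimal sketch of that gating, using the names from the diff below (the helper name is illustrative; mapping.has_tp() is assumed to behave as it is used in the diff):

    import os

    def llama_eager_fusion_enabled(mapping, enable_attention_dp: bool) -> bool:
        # Fusion is on unless TRTLLM_LLAMA_EAGER_FUSION_DISABLED=1 is set.
        enable_fusion = os.environ.get(
            "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
        # Only fuse when there is a TP group to all-reduce over and
        # attention data parallelism is not active.
        return mapping.has_tp() and not enable_attention_dp and enable_fusion

The same conditions also feed the new disable_*_allreduce flags, so the standalone AllReduce inside attention and the feed-forward path is skipped exactly when a fused all-reduce will run afterwards (or when there is nothing to reduce at all).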

File tree: 1 file changed, +195 -40 lines changed

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 195 additions & 40 deletions
@@ -1,4 +1,5 @@
 import copy
+import os
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -337,7 +338,7 @@ def forward(
         assert shared_output.size() == routed_output.size(
         ), f'unmatched tensor shape'
         final_hidden_states = shared_output + routed_output
-        if not self.enable_attention_dp and self.mapping.tp_size > 1:
+        if not self.enable_attention_dp and self.mapping.has_tp():
             final_hidden_states = self.all_reduce(
                 final_hidden_states, all_reduce_params=final_all_reduce_params)
 
@@ -367,9 +368,6 @@ def __init__(
         self.fusion_config = EagerFusionConfig()
         # self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
         # )
-        # TODO: re-enable these fusions
-        self.fusion_config.PRE_MOE_FUSION = False
-        self.fusion_config.POST_MLP_FUSION = False
 
         nope_layer = config.no_rope_layers[layer_idx] == 0
         attention_chunk_size = getattr(config, "attention_chunk_size",
@@ -387,6 +385,20 @@ def __init__(
         self.is_mlp_layer = (layer_idx +
                              1) % config.interleave_moe_layer_step != 0
 
+        self.enable_fusion = os.environ.get(
+            "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
+
+        if self.is_nvfp4:
+            self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+            self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+        # TODO: enable fp8 quant fusion later
+        # elif self.is_fp8_quant:
+        #     self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        #     self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        else:
+            self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+
         if self.is_mlp_layer:
             self.feed_forward = GatedMLP(
                 hidden_size=config.hidden_size,
@@ -399,8 +411,11 @@ def __init__(
                 layer_idx=layer_idx,
             )
 
-            # self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
-            # )
+            self.fusion_config.PRE_MLP_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
+            self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
+
         else:
             self.feed_forward = Llama4MoE(
                 num_experts=config.num_local_experts,
@@ -413,8 +428,10 @@ def __init__(
                 dtype=config.torch_dtype,
                 layer_idx=layer_idx)
 
-            # self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
-            # )
+            self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
+            self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -432,6 +449,17 @@ def __init__(
 
         self.moe_allreduce = MoEAllReduce(self.mapping)
 
+        self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
+                                       or self.fusion_config.PRE_MLP_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+        self.disable_feed_forward_allreduce = (
+            self.fusion_config.POST_MOE_FUSION
+            or self.fusion_config.POST_MLP_FUSION or self.mapping.tp_size == 1
+            or self.enable_attention_dp)
+
+        print(f"init Llama4DecoderLayer")
+
     def forward(
         self,
         position_ids: torch.IntTensor,
@@ -461,34 +489,43 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
-            all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION or self.mapping.tp_size == 1
-                or self.enable_attention_dp)),
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )
 
-        if self.fusion_config.PRE_MOE_FUSION:
-            hidden_states, residual = self.all_reduce(
+        if self.is_nvfp4 or self.is_fp8_quant:
+            scale = self.self_attn.qkv_proj.input_scale
+        else:
+            scale = None
+
+        if self.fusion_config.PRE_MLP_FUSION or self.fusion_config.PRE_MOE_FUSION:
+            allreduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    fusion_op=self.pre_feed_forward_fusion_op,
                     residual=residual,
                     norm_weight=self.post_attention_layernorm.weight,
+                    scale=scale,
                     eps=self.post_attention_layernorm.variance_epsilon,
                 ))
         else:
             # Fully Connected
-            hidden_states, residual = self.post_attention_layernorm(
+            allreduce_output = self.post_attention_layernorm(
                 hidden_states, residual)
 
+        if self.is_nvfp4:
+            act_fp4, act_sf, residual = allreduce_output
+            hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+        else:
+            hidden_states, residual = allreduce_output
+
         hidden_states = self.feed_forward(
             hidden_states,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
             all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
-            final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.POST_MOE_FUSION
-                or self.fusion_config.POST_MLP_FUSION
-                or self.mapping.tp_size == 1 or self.enable_attention_dp)),
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_feed_forward_allreduce),
            cutlass_min_latency_mode=cutlass_min_latency_mode,
         )
 
@@ -500,16 +537,23 @@ def forward(
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
 
-        if (self.fusion_config.POST_MOE_FUSION
+        if (
+                self.fusion_config.POST_MOE_FUSION
                 or self.fusion_config.POST_MLP_FUSION
-            ) and self.next_layer_layernorm is not None:
+        ) and self.next_layer_layernorm is not None and self.next_attn is not None:
+            # Get the scale for the next allreduce fusion op
+            if self.is_nvfp4 or self.is_fp8_quant:
+                scale = self.next_attn.qkv_proj.input_scale
+            else:
+                scale = None
+
             if cutlass_min_latency_mode:
                 shared_output = hidden_states[0]
                 hidden_states_activated_experts = hidden_states[1]
                 num_activated_experts_per_node = hidden_states[2]
                 experts_to_token_score = hidden_states[3]
 
-                hidden_states, residual = self.moe_allreduce(
+                allreduce_output = self.moe_allreduce(
                     residual,
                     self.next_layer_layernorm.weight,
                     device_num_experts=num_activated_experts_per_node,
@@ -519,18 +563,30 @@ def forward(
                     eps=self.next_layer_layernorm.variance_epsilon,
                 )
             else:
-                hidden_states, residual = self.all_reduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                        residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        eps=self.next_layer_layernorm.variance_epsilon,
-                    ))
+                if self.fusion_config.POST_MLP_FUSION or self.fusion_config.POST_MOE_FUSION:
+                    allreduce_output = self.all_reduce(
+                        hidden_states,
+                        all_reduce_params=AllReduceParams(
+                            fusion_op=self.post_feed_forward_fusion_op,
+                            residual=residual,
+                            norm_weight=self.next_layer_layernorm.weight,
+                            scale=scale,
+                            eps=self.next_layer_layernorm.variance_epsilon,
+                        ))
+                else:
+                    raise ValueError("Unknown fusion config")
         elif self.next_layer_layernorm:
-            hidden_states, residual = self.next_layer_layernorm(
+            print(f"{self.layer_idx}, {self.next_layer_layernorm}")
+            allreduce_output = self.next_layer_layernorm(
                 hidden_states, residual)
 
+        print(f"in forward")
+        if self.is_nvfp4:
+            act_fp4, act_sf, residual = allreduce_output
+            hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+        else:
+            hidden_states, residual = allreduce_output
+
         return hidden_states, residual
 
 
@@ -544,6 +600,14 @@ def __init__(
         super().__init__()
         config = model_config.pretrained_config
         self.layer_idx = layer_idx
+        self.mapping = model_config.mapping
+        self.enable_attention_dp = model_config.mapping.enable_attention_dp
+        self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
+        )
+        self.is_fp8_quant = self.is_quanted and model_config.quant_config.quant_mode.has_fp8_qdq(
+        )
+        self.is_nvfp4 = self.is_quanted and model_config.quant_config.quant_mode.has_nvfp4(
+        )
 
         self.self_attn = LlamaAttention(
             model_config,
@@ -566,11 +630,43 @@ def __init__(
                                                  eps=config.rms_norm_eps,
                                                  dtype=config.torch_dtype)
 
+        self.all_reduce = AllReduce(mapping=model_config.mapping)
+
+        self.next_layer_layernorm: RMSNorm = None
+        self.next_attn: LlamaAttention = None
+
         self.attention_mask = PredefinedAttentionMask.CAUSAL
         # If the model is being used as an encoder model (prefill only) we use a full attention mask
         if not model_config.is_generation:
             self.attention_mask = PredefinedAttentionMask.FULL
 
+        self.enable_fusion = os.environ.get(
+            "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
+        self.PRE_MLP_FUSION = self.mapping.has_tp(
+        ) and not self.enable_attention_dp and self.enable_fusion
+        self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion
+
+        if self.is_nvfp4:
+            self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+            self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+        # TODO: enable fp8 quant fusion later
+        # elif self.is_fp8_quant:
+        #     self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        #     self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        else:
+            self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+
+        # TODO: Disable this to avoid large accuracy drop
+        self.POST_MLP_FUSION = False
+
+        self.disable_attn_allreduce = (self.PRE_MLP_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+        self.disable_mlp_allreduce = (self.POST_MLP_FUSION
+                                      or self.mapping.tp_size == 1
+                                      or self.enable_attention_dp)
+
     def forward(
         self,
         position_ids: torch.IntTensor,
@@ -583,30 +679,78 @@ def forward(
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
 
         # Self Attention
         hidden_states = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             attention_mask=self.attention_mask,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )
-
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states, **kwargs)
+        if self.PRE_MLP_FUSION:
+            if self.is_nvfp4:
+                scale = self.mlp.gate_up_proj.input_scale
+            else:
+                scale = None
+            all_reduce_output = self.all_reduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=self.pre_mlp_fusion_op,
+                    residual=residual,
+                    norm_weight=self.post_attention_layernorm.weight,
+                    scale=scale,
+                    eps=self.post_attention_layernorm.variance_epsilon,
+                ))
+            if self.is_nvfp4:
+                act_fp4, act_sf, residual = all_reduce_output
+                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+            else:
+                hidden_states, residual = all_reduce_output
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
+
+        hidden_states = self.mlp(
+            hidden_states,
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_mlp_allreduce),
+            **kwargs,
+        )
+
         if spec_metadata is not None:
             # We save the hidden states in the spec metadata here. In _prepare_draft_tokens,
             # PyExecutor will extract these from the model engine's spec metadata.
             # They will be passed to the draft model engine on the first draft iteration.
             # TODO: can we support multiple model outputs instead?
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
+        if self.POST_MLP_FUSION and self.next_attn is not None:
+            if self.is_nvfp4:
+                scale = self.next_attn.qkv_proj.input_scale
+            else:
+                scale = None
+            all_reduce_output = self.all_reduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=self.post_mlp_fusion_op,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    scale=scale,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                ))
+            if self.is_nvfp4:
+                act_fp4, act_sf, residual = all_reduce_output
+                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+            else:
+                hidden_states, residual = all_reduce_output
+        elif self.next_layer_layernorm:
+            hidden_states, residual = self.next_layer_layernorm(
+                hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -727,7 +871,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
 
         if self.has_custom_embed_tokens:
             with torch.no_grad():
-                if model_config.mapping.tp_size > 1:
+                if model_config.mapping.has_tp():
                     weight = split_matrix_tp(
                         weight,
                         model_config.mapping.tp_size,
@@ -775,7 +919,6 @@ def forward(
             lora_params=lora_params,
         )
 
-        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -788,6 +931,18 @@
     ):
         super().__init__(LlamaModel(model_config), model_config)
 
+    def load_weights(self, weights: Dict):
+        super().load_weights(weights)
+
+        for idx, layer in enumerate(
+                self.model.layers[:self.config.num_hidden_layers]):
+            if idx == self.config.num_hidden_layers - 1:
+                layer.next_layer_layernorm = self.model.norm
+            else:
+                layer.next_layer_layernorm = self.model.layers[
+                    idx + 1].input_layernorm
+                layer.next_attn = self.model.layers[idx + 1].self_attn
+
 
 class Llama4InputProcessor(InputProcessor):