@@ -1,4 +1,5 @@
 import copy
+import os
 from typing import Dict, List, Optional, Tuple, Union

 import torch
@@ -337,7 +338,7 @@ def forward(
         assert shared_output.size() == routed_output.size(
         ), f'unmatched tensor shape'
         final_hidden_states = shared_output + routed_output
-        if not self.enable_attention_dp and self.mapping.tp_size > 1:
+        if not self.enable_attention_dp and self.mapping.has_tp():
            final_hidden_states = self.all_reduce(
                final_hidden_states, all_reduce_params=final_all_reduce_params)

@@ -367,9 +368,6 @@ def __init__(
         self.fusion_config = EagerFusionConfig()
         # self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
         # )
-        # TODO: re-enable these fusions
-        self.fusion_config.PRE_MOE_FUSION = False
-        self.fusion_config.POST_MLP_FUSION = False

         nope_layer = config.no_rope_layers[layer_idx] == 0
         attention_chunk_size = getattr(config, "attention_chunk_size",
@@ -387,6 +385,20 @@ def __init__(
         self.is_mlp_layer = (layer_idx +
                              1) % config.interleave_moe_layer_step != 0

+        self.enable_fusion = os.environ.get(
+            "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
+
+        if self.is_nvfp4:
+            self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+            self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+        # TODO: enable fp8 quant fusion later
+        # elif self.is_fp8_quant:
+        #     self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        #     self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        else:
+            self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+
         if self.is_mlp_layer:
             self.feed_forward = GatedMLP(
                 hidden_size=config.hidden_size,
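Note on the block above: eager fusion is gated by the TRTLLM_LLAMA_EAGER_FUSION_DISABLED environment variable, and the fused allreduce op is picked from the checkpoint's quantization mode (NVFP4 fuses the quantization step; FP8 is still a TODO; everything else uses the plain residual + RMSNorm fusion). A minimal standalone sketch of that selection, where the trimmed enum and the select_fusion_ops helper are illustrative stand-ins rather than part of this change:

import os
from enum import Enum, auto


class AllReduceFusionOp(Enum):
    # Stand-in for the TensorRT-LLM enum; only the members used here.
    RESIDUAL_RMS_NORM = auto()
    RESIDUAL_RMS_NORM_QUANT_NVFP4 = auto()


def select_fusion_ops(is_nvfp4: bool):
    """Return (enable_fusion, pre_op, post_op) the way the layer init above does."""
    # Fusion stays on unless the user exports TRTLLM_LLAMA_EAGER_FUSION_DISABLED=1.
    enable_fusion = os.environ.get("TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
    if is_nvfp4:
        op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
    else:
        op = AllReduceFusionOp.RESIDUAL_RMS_NORM
    return enable_fusion, op, op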
@@ -399,8 +411,11 @@ def __init__(
                 layer_idx=layer_idx,
             )

-            # self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
-            # )
+            self.fusion_config.PRE_MLP_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
+            self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
+
         else:
             self.feed_forward = Llama4MoE(
                 num_experts=config.num_local_experts,
@@ -413,8 +428,10 @@ def __init__(
                 dtype=config.torch_dtype,
                 layer_idx=layer_idx)

-            # self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
-            # )
+            self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
+            self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion

         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -432,6 +449,17 @@ def __init__(

         self.moe_allreduce = MoEAllReduce(self.mapping)

+        self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
+                                       or self.fusion_config.PRE_MLP_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+        self.disable_feed_forward_allreduce = (
+            self.fusion_config.POST_MOE_FUSION
+            or self.fusion_config.POST_MLP_FUSION or self.mapping.tp_size == 1
+            or self.enable_attention_dp)
+
+        print(f"init Llama4DecoderLayer")
+
     def forward(
         self,
         position_ids: torch.IntTensor,
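The two disable_*_allreduce flags above precompute, once at construction time, the condition that previously appeared inline in forward: the per-module allreduce is skipped whenever a fused allreduce will handle it, there is no tensor parallelism, or attention data parallelism is enabled. A one-function sketch of the same predicate (the helper name is illustrative):

def skip_module_allreduce(fusion_enabled: bool, tp_size: int, attention_dp: bool) -> bool:
    # True -> the module's own allreduce is disabled because a fused allreduce
    # (or the absence of tensor parallelism) makes it redundant.
    return fusion_enabled or tp_size == 1 or attention_dp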
@@ -461,34 +489,43 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
-            all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION or self.mapping.tp_size == 1
-                or self.enable_attention_dp)),
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )

-        if self.fusion_config.PRE_MOE_FUSION:
-            hidden_states, residual = self.all_reduce(
+        if self.is_nvfp4 or self.is_fp8_quant:
+            scale = self.self_attn.qkv_proj.input_scale
+        else:
+            scale = None
+
+        if self.fusion_config.PRE_MLP_FUSION or self.fusion_config.PRE_MOE_FUSION:
+            allreduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    fusion_op=self.pre_feed_forward_fusion_op,
                     residual=residual,
                     norm_weight=self.post_attention_layernorm.weight,
+                    scale=scale,
                     eps=self.post_attention_layernorm.variance_epsilon,
                 ))
         else:
             # Fully Connected
-            hidden_states, residual = self.post_attention_layernorm(
+            allreduce_output = self.post_attention_layernorm(
                 hidden_states, residual)

+        if self.is_nvfp4:
+            act_fp4, act_sf, residual = allreduce_output
+            hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+        else:
+            hidden_states, residual = allreduce_output
+
         hidden_states = self.feed_forward(
             hidden_states,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
             all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
-            final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.POST_MOE_FUSION
-                or self.fusion_config.POST_MLP_FUSION
-                or self.mapping.tp_size == 1 or self.enable_attention_dp)),
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_feed_forward_allreduce),
             cutlass_min_latency_mode=cutlass_min_latency_mode,
         )

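A point worth noting about the unpacking above: with the NVFP4 fusion op the fused allreduce returns the quantized activation, its scale factors, and the updated residual, while the non-quantized path keeps returning the usual (hidden_states, residual) pair. A minimal sketch of that unpacking, using a stand-in dataclass in place of the real Fp4QuantizedTensor wrapper:

from dataclasses import dataclass
from typing import Tuple, Union

import torch


@dataclass
class Fp4QuantizedTensor:
    # Stand-in for the real wrapper: packed FP4 data plus per-block scale factors.
    fp4_tensor: torch.Tensor
    scaling_factor: torch.Tensor


def unpack_allreduce_output(
        output: Union[Tuple[torch.Tensor, torch.Tensor],
                      Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        is_nvfp4: bool):
    if is_nvfp4:
        act_fp4, act_sf, residual = output  # 3-tuple from the quantizing fusion op
        hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
    else:
        hidden_states, residual = output  # plain RESIDUAL_RMS_NORM path
    return hidden_states, residual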
@@ -500,16 +537,23 @@ def forward(
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)

-        if (self.fusion_config.POST_MOE_FUSION
+        if (
+                self.fusion_config.POST_MOE_FUSION
                 or self.fusion_config.POST_MLP_FUSION
-            ) and self.next_layer_layernorm is not None:
+        ) and self.next_layer_layernorm is not None and self.next_attn is not None:
+            # Get the scale for the next allreduce fusion op
+            if self.is_nvfp4 or self.is_fp8_quant:
+                scale = self.next_attn.qkv_proj.input_scale
+            else:
+                scale = None
+
             if cutlass_min_latency_mode:
                 shared_output = hidden_states[0]
                 hidden_states_activated_experts = hidden_states[1]
                 num_activated_experts_per_node = hidden_states[2]
                 experts_to_token_score = hidden_states[3]

-                hidden_states, residual = self.moe_allreduce(
+                allreduce_output = self.moe_allreduce(
                     residual,
                     self.next_layer_layernorm.weight,
                     device_num_experts=num_activated_experts_per_node,
@@ -519,18 +563,30 @@ def forward(
                     eps=self.next_layer_layernorm.variance_epsilon,
                 )
             else:
-                hidden_states, residual = self.all_reduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                        residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        eps=self.next_layer_layernorm.variance_epsilon,
-                    ))
+                if self.fusion_config.POST_MLP_FUSION or self.fusion_config.POST_MOE_FUSION:
+                    allreduce_output = self.all_reduce(
+                        hidden_states,
+                        all_reduce_params=AllReduceParams(
+                            fusion_op=self.post_feed_forward_fusion_op,
+                            residual=residual,
+                            norm_weight=self.next_layer_layernorm.weight,
+                            scale=scale,
+                            eps=self.next_layer_layernorm.variance_epsilon,
+                        ))
+                else:
+                    raise ValueError("Unknown fusion config")
         elif self.next_layer_layernorm:
-            hidden_states, residual = self.next_layer_layernorm(
+            print(f"{self.layer_idx}, {self.next_layer_layernorm}")
+            allreduce_output = self.next_layer_layernorm(
                 hidden_states, residual)

+        print(f"in forward")
+        if self.is_nvfp4:
+            act_fp4, act_sf, residual = allreduce_output
+            hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+        else:
+            hidden_states, residual = allreduce_output
+
         return hidden_states, residual


@@ -544,6 +600,14 @@ def __init__(
         super().__init__()
         config = model_config.pretrained_config
         self.layer_idx = layer_idx
+        self.mapping = model_config.mapping
+        self.enable_attention_dp = model_config.mapping.enable_attention_dp
+        self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
+        )
+        self.is_fp8_quant = self.is_quanted and model_config.quant_config.quant_mode.has_fp8_qdq(
+        )
+        self.is_nvfp4 = self.is_quanted and model_config.quant_config.quant_mode.has_nvfp4(
+        )

         self.self_attn = LlamaAttention(
             model_config,
@@ -566,11 +630,43 @@ def __init__(
                                                  eps=config.rms_norm_eps,
                                                  dtype=config.torch_dtype)

+        self.all_reduce = AllReduce(mapping=model_config.mapping)
+
+        self.next_layer_layernorm: RMSNorm = None
+        self.next_attn: LlamaAttention = None
+
         self.attention_mask = PredefinedAttentionMask.CAUSAL
         # If the model is being used as an encoder model (prefill only) we use a full attention mask
         if not model_config.is_generation:
             self.attention_mask = PredefinedAttentionMask.FULL

+        self.enable_fusion = os.environ.get(
+            "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
+        self.PRE_MLP_FUSION = self.mapping.has_tp(
+        ) and not self.enable_attention_dp and self.enable_fusion
+        self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion
+
+        if self.is_nvfp4:
+            self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+            self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+        # TODO: enable fp8 quant fusion later
+        # elif self.is_fp8_quant:
+        #     self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        #     self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        else:
+            self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+
+        # TODO: POST_MLP_FUSION is disabled for now to avoid a large accuracy drop
+        self.POST_MLP_FUSION = False
+
+        self.disable_attn_allreduce = (self.PRE_MLP_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+        self.disable_mlp_allreduce = (self.POST_MLP_FUSION
+                                      or self.mapping.tp_size == 1
+                                      or self.enable_attention_dp)
+
     def forward(
         self,
         position_ids: torch.IntTensor,
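If the new eager-fusion path needs to be switched off at runtime (for debugging or A/B comparison), note that TRTLLM_LLAMA_EAGER_FUSION_DISABLED is read once per layer at construction time, so it has to be set before the model is built. A minimal sketch of how that would look from Python:

import os

# Must be set before the decoder layers are constructed; "1" turns the
# PRE/POST fusion configs above off, falling back to an unfused allreduce
# followed by the ordinary residual-add + layernorm.
os.environ["TRTLLM_LLAMA_EAGER_FUSION_DISABLED"] = "1"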
@@ -583,30 +679,78 @@ def forward(
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)

         # Self Attention
         hidden_states = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             attention_mask=self.attention_mask,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )
-
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states, **kwargs)
+        if self.PRE_MLP_FUSION:
+            if self.is_nvfp4:
+                scale = self.mlp.gate_up_proj.input_scale
+            else:
+                scale = None
+            all_reduce_output = self.all_reduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=self.pre_mlp_fusion_op,
+                    residual=residual,
+                    norm_weight=self.post_attention_layernorm.weight,
+                    scale=scale,
+                    eps=self.post_attention_layernorm.variance_epsilon,
+                ))
+            if self.is_nvfp4:
+                act_fp4, act_sf, residual = all_reduce_output
+                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+            else:
+                hidden_states, residual = all_reduce_output
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
+
+        hidden_states = self.mlp(
+            hidden_states,
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_mlp_allreduce),
+            **kwargs,
+        )
+
         if spec_metadata is not None:
             # We save the hidden states in the spec metadata here. In _prepare_draft_tokens,
             # PyExecutor will extract these from the model engine's spec metadata.
             # They will be passed to the draft model engine on the first draft iteration.
             # TODO: can we support multiple model outputs instead?
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
+        if self.POST_MLP_FUSION and self.next_attn is not None:
+            if self.is_nvfp4:
+                scale = self.next_attn.qkv_proj.input_scale
+            else:
+                scale = None
+            all_reduce_output = self.all_reduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=self.post_mlp_fusion_op,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    scale=scale,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                ))
+            if self.is_nvfp4:
+                act_fp4, act_sf, residual = all_reduce_output
+                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+            else:
+                hidden_states, residual = all_reduce_output
+        elif self.next_layer_layernorm:
+            hidden_states, residual = self.next_layer_layernorm(
+                hidden_states, residual)
+
         return hidden_states, residual


@@ -727,7 +871,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):

         if self.has_custom_embed_tokens:
             with torch.no_grad():
-                if model_config.mapping.tp_size > 1:
+                if model_config.mapping.has_tp():
                     weight = split_matrix_tp(
                         weight,
                         model_config.mapping.tp_size,
@@ -775,7 +919,6 @@ def forward(
             lora_params=lora_params,
         )

-        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states


@@ -788,6 +931,18 @@ def __init__(
     ):
         super().__init__(LlamaModel(model_config), model_config)

+    def load_weights(self, weights: Dict):
+        super().load_weights(weights)
+
+        for idx, layer in enumerate(
+                self.model.layers[:self.config.num_hidden_layers]):
+            if idx == self.config.num_hidden_layers - 1:
+                layer.next_layer_layernorm = self.model.norm
+            else:
+                layer.next_layer_layernorm = self.model.layers[
+                    idx + 1].input_layernorm
+                layer.next_attn = self.model.layers[idx + 1].self_attn
+

 class Llama4InputProcessor(InputProcessor):

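The load_weights override above wires each decoder layer to its successor so the post-feed-forward fused allreduce can also apply the next layer's input RMSNorm (and pick up the next attention's input scale for quantized fusion); the last layer points at the final model norm instead. A standalone sketch of the same wiring, using minimal stand-in layer objects rather than the real modules:

from types import SimpleNamespace


def wire_next_layer_refs(layers, final_norm):
    """Mirror of the loop in load_weights above, over generic layer objects."""
    for idx, layer in enumerate(layers):
        if idx == len(layers) - 1:
            layer.next_layer_layernorm = final_norm
        else:
            layer.next_layer_layernorm = layers[idx + 1].input_layernorm
            layer.next_attn = layers[idx + 1].self_attn


# Tiny usage example with stand-in objects.
layers = [
    SimpleNamespace(input_layernorm=f"ln{i}", self_attn=f"attn{i}")
    for i in range(3)
]
wire_next_layer_refs(layers, final_norm="model.norm")
assert layers[0].next_layer_layernorm == "ln1"
assert layers[-1].next_layer_layernorm == "model.norm"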