tensorrt_llm/_torch/models/modeling_llama.py (15 changes: 8 additions & 7 deletions)
@@ -419,11 +419,12 @@ def __init__(
                 overridden_tp_size=1 if self.enable_attention_dp else None,
                 layer_idx=layer_idx,
             )
 
+            # TODO(TRTLLM-7809): Fix fusion with PP>1
             self.fusion_config.PRE_MLP_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
-            self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
+            ) and not self.enable_attention_dp and self.enable_fusion and not model_config.mapping.has_pp(
+            )
+            self.fusion_config.POST_MLP_FUSION = self.fusion_config.PRE_MLP_FUSION
 
         else:
             self.feed_forward = Llama4MoE(
                 num_experts=config.num_local_experts,
@@ -437,9 +438,9 @@
                 layer_idx=layer_idx)
 
             self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
-            self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
+            ) and not self.enable_attention_dp and self.enable_fusion and not model_config.mapping.has_pp(
+            )
+            self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
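
Both hunks apply the same guard: pre- and post-fusion (MLP and MoE alike) now require tensor parallelism, no attention DP, fusion enabled, and, until TRTLLM-7809 is resolved, no pipeline parallelism, with the POST flag simply mirroring the PRE flag. Below is a minimal sketch of that predicate; the `StubMapping` class and `fusion_enabled` helper are illustrative stand-ins, not TensorRT-LLM's actual code, and assume `has_tp()`/`has_pp()` simply report whether `tp_size`/`pp_size` exceed 1.

```python
# Illustrative stand-in for tensorrt_llm's Mapping, reduced to the two
# predicates the diff relies on; the real class carries much more state.
class StubMapping:
    def __init__(self, tp_size: int = 1, pp_size: int = 1):
        self.tp_size = tp_size
        self.pp_size = pp_size

    def has_tp(self) -> bool:
        return self.tp_size > 1

    def has_pp(self) -> bool:
        return self.pp_size > 1


def fusion_enabled(mapping: StubMapping, enable_attention_dp: bool,
                   enable_fusion: bool) -> bool:
    # Mirrors the new condition: fusion needs TP, is skipped under
    # attention DP, and (per TRTLLM-7809) is disabled whenever PP > 1.
    return (mapping.has_tp() and not enable_attention_dp and enable_fusion
            and not mapping.has_pp())


# tp4: fusion stays on; tp4pp2: fusion is now forced off.
assert fusion_enabled(StubMapping(tp_size=4), False, True)
assert not fusion_enabled(StubMapping(tp_size=4, pp_size=2), False, True)
```
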
tests/integration/defs/accuracy/test_llm_api_pytorch.py (4 changes: 2 additions & 2 deletions)
@@ -696,8 +696,8 @@ def test_chunked_prefill(self, attn_backend):
     @parametrize_with_ids("cuda_graph", [False, True])
     @pytest.mark.parametrize(
         "tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4), (8, 1, 8), (4, 1, 1),
-                                    (4, 1, 2), (4, 1, 4)],
-        ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4"])
+                                    (4, 1, 2), (4, 1, 4), (4, 2, 1)],
+        ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4", "tp4pp2"])
     def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size):
         if get_device_memory() < 140000 and get_device_count() < 8:
             pytest.skip("Not enough memory for this test")
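
The new `(4, 2, 1)`/`tp4pp2` case exercises exactly the configuration the fusion guard above targets. Here is a self-contained sketch of how that parametrization expands; only the tuples and ids come from the diff, while the test name and body are hypothetical checks of the id-to-tuple pairing.

```python
import pytest

@pytest.mark.parametrize(
    "tp_size,pp_size,ep_size",
    [(8, 1, 1), (8, 1, 4), (8, 1, 8), (4, 1, 1), (4, 1, 2), (4, 1, 4),
     (4, 2, 1)],
    ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4", "tp4pp2"])
def test_parallel_config(tp_size, pp_size, ep_size):
    # Hypothetical body: each id maps to a world size of tp_size * pp_size,
    # which is 8 for the tp8* and tp4pp2 cases and 4 for the remaining tp4*
    # cases; expert parallelism never exceeds the TP degree in this list.
    assert tp_size * pp_size in (4, 8)
    assert ep_size <= tp_size
```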