Skip to content

Commit fe8a34a

Browse files
committed
[TRTLLM-6445] feat: Enable AllReduce associated fusion patterns in Llama3/4.
* Added support for controlling fusion optimizations via environment variables. * Applied AR+Residual + RMS_NORM + Quant fp4/fp8 fusion. This is also compatible with the speculative decoding capturing in these models. * Some improvements for the two-shot allreduce kernel. * Disable fusion for small models with a hidden size no greater than 4096 to avoid accuracy drop issues. Signed-off-by: Yukun He <[email protected]>
1 parent 428e340 commit fe8a34a

File tree

2 files changed

+203
-32
lines changed

2 files changed

+203
-32
lines changed

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ __global__ void __launch_bounds__(1024) allreduce_fusion_kernel_oneshot_lamport(
520520
}
521521

522522
template <AllReduceFusionPattern Pattern, typename DType, int NRanks, bool Fp32Acc>
523-
__global__ void allreduce_fusion_kernel_twoshot_sync(
523+
__global__ void __launch_bounds__(1024) allreduce_fusion_kernel_twoshot_sync(
524524
AllReduceFusionParams params, std::array<int, NRanks> begin_tokens, std::array<int, NRanks> token_num_per_ranks)
525525
{
526526
IndexHelper<DType> index_helper(params);

0 commit comments

Comments
 (0)