@@ -505,8 +505,24 @@ __global__ void __launch_bounds__(1024) allreduce_fusion_kernel_oneshot_lamport(
done &= !is_neg_zero(vals[r]);
}
}
float4 sum_val = allreduce_sum<DType, NRanks, Fp32Acc>(vals);
fused_op(sum_val, tidx);

// Handle AllGather pattern - output gathered data before reduction
if constexpr (HasAllGatherOut<Pattern>)
Review comment (Collaborator):

The correct approach is to factor the common part into a shared helper, remove the "allreduce"-specific naming from that helper, and then implement both allreduce and allgather on top of it (a sketch follows this hunk).

{
// Output gathered data: [rank0_data][rank1_data][rank2_data]...[rankN_data]
for (int r = 0; r < NRanks; ++r)
{
int output_idx = r * tot_access + idx;
reinterpret_cast<float4*>(params.allgather_out)[output_idx] = vals[r];
}
}

// Handle AllReduce pattern - compute sum and apply fusion operations
if constexpr (HasAllReduceOut<Pattern> || HasResidual<Pattern> || HasRMSNorm<Pattern>)
{
float4 sum_val = allreduce_sum<DType, NRanks, Fp32Acc>(vals);
fused_op(sum_val, tidx);
}
}

comm.update(params.size * NRanks);
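
To make the suggestion above concrete, here is a minimal sketch — not the actual TensorRT-LLM code — of how the Lamport-synchronized load could become a neutrally named helper that both allreduce and allgather consume. The helper names (gather_rank_values, scatter_allgather_out) and the rank_ptrs parameter are illustrative assumptions; is_neg_zero, allreduce_sum, fused_op, params, tot_access, idx, and tidx are taken from the hunk above.

template <int NRanks>
__device__ __forceinline__ void gather_rank_values(float4 (&vals)[NRanks], float4* const* rank_ptrs, int idx)
{
    // Spin until every rank has published its slot; -0.0f is the Lamport
    // "not yet written" sentinel used by the existing loop.
    bool done = false;
    while (!done)
    {
        done = true;
#pragma unroll
        for (int r = 0; r < NRanks; ++r)
        {
            vals[r] = rank_ptrs[r][idx];
            done &= !is_neg_zero(vals[r]);
        }
    }
}

template <int NRanks>
__device__ __forceinline__ void scatter_allgather_out(float4 const (&vals)[NRanks], void* out, int tot_access, int idx)
{
    // Output layout: [rank0][rank1]...[rankN-1], each block tot_access float4
    // wide, so rank r's value for this thread lands at r * tot_access + idx.
#pragma unroll
    for (int r = 0; r < NRanks; ++r)
    {
        reinterpret_cast<float4*>(out)[r * tot_access + idx] = vals[r];
    }
}

With those helpers, the kernel body for both patterns would reduce to roughly:

gather_rank_values<NRanks>(vals, rank_ptrs, idx);
if constexpr (HasAllGatherOut<Pattern>)
{
    scatter_allgather_out<NRanks>(vals, params.allgather_out, tot_access, idx);
}
if constexpr (HasAllReduceOut<Pattern> || HasResidual<Pattern> || HasRMSNorm<Pattern>)
{
    float4 sum_val = allreduce_sum<DType, NRanks, Fp32Acc>(vals);
    fused_op(sum_val, tidx);
}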
@@ -674,6 +690,7 @@ void allreduce_fusion_kernel_launcher(AllReduceFusionParams const& params)
}
TLLM_CHECK(oneshot || threads_per_block >= params.nranks);
int block_size = threads_per_block;

TLLM_CHECK(block_size <= 1024 && cluster_size > 0);

int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
@@ -740,6 +757,10 @@ void allreduce_fusion_op(AllReduceFusionParams const& params)
{ \
DISPATCH_ACC_TYPE(DType, AllReduceFusionPattern::kAllReduce, NRanks); \
} \
else if (params.pattern == AllReduceFusionPattern::kAllGather) \
{ \
DISPATCH_ACC_TYPE(DType, AllReduceFusionPattern::kAllGather, NRanks); \
} \
else if (params.pattern == AllReduceFusionPattern::kARResidualRMSNorm) \
{ \
DISPATCH_ACC_TYPE(DType, AllReduceFusionPattern::kARResidualRMSNorm, NRanks); \
@@ -55,13 +55,14 @@ static constexpr int kBarrierFlagCount = 256;
enum class AllReduceFusionPattern : int
{
kAllReduce = 0,
kARResidualRMSNorm = 1,
kARResidualRMSNormFP8Quant = 2,
kARResidualRMSNormFP4Quant = 3,
kAllGather = 1,
kARResidualRMSNorm = 2,
kARResidualRMSNormFP8Quant = 3,
kARResidualRMSNormFP4Quant = 4,
// The difference between these two and the standard version is that the NormOut version outputs the result of the
// norm.
kARResidualRMSNormOutFP8Quant = 4,
kARResidualRMSNormOutFP4Quant = 5
kARResidualRMSNormOutFP8Quant = 5,
kARResidualRMSNormOutFP4Quant = 6,
};

enum class QuantType : int
@@ -75,29 +76,33 @@ template <AllReduceFusionPattern Pattern>
struct FusionPatternTraits;

#define DEFINE_FUSION_PATTERN_TRAITS( \
pattern, hasAllReduceOut, hasResidual, hasResidualOut, hasRMSNorm, hasNormOut, quantType) \
pattern, hasAllReduceOut, hasAllGatherOut, hasResidual, hasResidualOut, hasRMSNorm, hasNormOut, quantType) \
template <> \
struct FusionPatternTraits<pattern> \
{ \
static constexpr bool kHasAllReduceOut = hasAllReduceOut; \
static constexpr bool kHasAllGatherOut = hasAllGatherOut; \
static constexpr bool kHasResidual = hasResidual; \
static constexpr bool kHasResidualOut = hasResidualOut; \
static constexpr bool kHasRMSNorm = hasRMSNorm; \
static constexpr bool kHasNormOut = hasNormOut; \
static constexpr QuantType kQuantType = quantType; \
};

DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kAllReduce, true, false, false, false, false, QuantType::kNone);
DEFINE_FUSION_PATTERN_TRAITS(
AllReduceFusionPattern::kARResidualRMSNorm, false, true, true, true, true, QuantType::kNone);
AllReduceFusionPattern::kAllReduce, true, false, false, false, false, false, QuantType::kNone);
DEFINE_FUSION_PATTERN_TRAITS(
AllReduceFusionPattern::kARResidualRMSNormFP8Quant, false, true, true, true, false, QuantType::kFP8);
AllReduceFusionPattern::kAllGather, false, true, false, false, false, false, QuantType::kNone);
DEFINE_FUSION_PATTERN_TRAITS(
AllReduceFusionPattern::kARResidualRMSNormFP4Quant, false, true, true, true, false, QuantType::kFP4);
AllReduceFusionPattern::kARResidualRMSNorm, false, false, true, true, true, true, QuantType::kNone);
DEFINE_FUSION_PATTERN_TRAITS(
AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant, false, true, true, true, true, QuantType::kFP8);
AllReduceFusionPattern::kARResidualRMSNormFP8Quant, false, false, true, true, true, false, QuantType::kFP8);
DEFINE_FUSION_PATTERN_TRAITS(
AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant, false, true, true, true, true, QuantType::kFP4);
AllReduceFusionPattern::kARResidualRMSNormFP4Quant, false, false, true, true, true, false, QuantType::kFP4);
DEFINE_FUSION_PATTERN_TRAITS(
AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant, false, false, true, true, true, true, QuantType::kFP8);
DEFINE_FUSION_PATTERN_TRAITS(
AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant, false, false, true, true, true, true, QuantType::kFP4);
#undef DEFINE_FUSION_PATTERN_TRAITS

template <AllReduceFusionPattern Pattern>
@@ -107,6 +112,8 @@ constexpr bool HasRMSNorm = FusionPatternTraits<Pattern>::kHasRMSNorm;
template <AllReduceFusionPattern Pattern>
constexpr bool HasAllReduceOut = FusionPatternTraits<Pattern>::kHasAllReduceOut;
template <AllReduceFusionPattern Pattern>
constexpr bool HasAllGatherOut = FusionPatternTraits<Pattern>::kHasAllGatherOut;
template <AllReduceFusionPattern Pattern>
constexpr bool HasResidualOut = FusionPatternTraits<Pattern>::kHasResidualOut;
template <AllReduceFusionPattern Pattern>
constexpr bool HasNormOut = FusionPatternTraits<Pattern>::kHasNormOut;
@@ -124,6 +131,7 @@ struct AllReduceFusionParams
void* allreduce_in;
void* residual_in;
void* allreduce_out;
void* allgather_out; // New field for AllGather output
void* residual_out;
void* norm_out;
void* quant_out;
18 changes: 10 additions & 8 deletions cpp/tensorrt_llm/kernels/customAllReduceKernels.h
@@ -69,21 +69,23 @@ enum class AllReduceStrategyConfig : int8_t
enum class AllReduceFusionOp : int8_t
{
NONE = 0,
RESIDUAL_RMS_NORM = 1,
LAST_PROCESS_FOR_UB = 2,
RESIDUAL_RMS_PREPOST_NORM = 3,
RESIDUAL_RMS_NORM_QUANT_FP8 = 4,
RESIDUAL_RMS_NORM_QUANT_NVFP4 = 5,
RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 6,
RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 7,
MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM = 8,
ALLGATHER = 1,
RESIDUAL_RMS_NORM = 2,
LAST_PROCESS_FOR_UB = 3,
RESIDUAL_RMS_PREPOST_NORM = 4,
RESIDUAL_RMS_NORM_QUANT_FP8 = 5,
RESIDUAL_RMS_NORM_QUANT_NVFP4 = 6,
RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 7,
RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 8,
MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM = 9,
};
Comment on lines +72 to 81 — review comment (Contributor, verification agent):

LGTM! Correctly adds ALLGATHER enum value with proper renumbering.

The addition of ALLGATHER = 1 is well-placed and all subsequent enum values are appropriately renumbered. This maintains enum consistency while introducing the new fusion operation.

Please verify that all references to the renumbered enum values across the codebase have been updated consistently, as this is a breaking change.


🏁 Script executed:

#!/bin/bash
# Description: Verify that enum values are used consistently across the codebase
# Expected: All enum references should use symbolic names, not hardcoded integers

# Search for potential hardcoded enum values that might need updating
rg -A 2 -B 2 "AllReduceFusionOp::" --type cpp --type cu

Length of output: 82


🏁 Script executed:

#!/bin/bash
# Description: Re-run enum reference check without unsupported --type flag
# Search C++, header, and CUDA files for AllReduceFusionOp usage
rg -A2 -B2 "AllReduceFusionOp::" --glob '*.cpp' --glob '*.h' --glob '*.cu'

Length of output: 42196


All enum usages are symbolic and safe; nanobind layer missing ALLGATHER

  • C++/CUDA code and the pybind binding reference AllReduceFusionOp::ALLGATHER correctly—no hardcoded integers remain.
  • In cpp/tensorrt_llm/nanobind/runtime/bindings.cpp, the nb::enum_<AllReduceFusionOp> list currently omits the new ALLGATHER entry.

Please add:

.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)

to complete the nanobind mapping.

🤖 Prompt for AI Agents
In cpp/tensorrt_llm/nanobind/runtime/bindings.cpp, locate the
nb::enum_<AllReduceFusionOp> definition and add the line .value("ALLGATHER",
tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER) to include the ALLGATHER
enum in the nanobind bindings, ensuring the Python layer correctly maps this new
enum value.


inline std::ostream& operator<<(std::ostream& os, AllReduceFusionOp op)
{
switch (op)
{
case AllReduceFusionOp::NONE: os << "NONE"; break;
case AllReduceFusionOp::ALLGATHER: os << "ALLGATHER"; break;
case AllReduceFusionOp::RESIDUAL_RMS_NORM: os << "RESIDUAL_RMS_NORM"; break;
case AllReduceFusionOp::LAST_PROCESS_FOR_UB: os << "LAST_PROCESS_FOR_UB"; break;
case AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM: os << "RESIDUAL_RMS_PREPOST_NORM"; break;
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
@@ -345,6 +345,7 @@ void initBindings(nb::module_& m)
.def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer);

nb::enum_<tensorrt_llm::kernels::AllReduceFusionOp>(m, "AllReduceFusionOp")
.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
Review comment (Contributor) — potential issue:

Replace tab with spaces to comply with code style (indentation rule).

C++ code must use spaces only for indentation. This line starts with a tab; clang-format/lint will likely fail.

Apply this diff:

-	.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
+        .value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/nanobind/runtime/bindings.cpp around line 348, the line
begins with a tab which violates the project's C++ indentation rule; replace the
leading tab with spaces (use the repository's preferred indentation, e.g., 4
spaces) so the line starts with spaces only: .value("ALLGATHER",
tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER). Ensure there are no other
leading tabs on the same or surrounding lines.

.value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE)
.value("RESIDUAL_RMS_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM)
.value("LAST_PROCESS_FOR_UB", tensorrt_llm::kernels::AllReduceFusionOp::LAST_PROCESS_FOR_UB)
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/pybind/runtime/bindings.cpp
@@ -440,6 +440,7 @@ void initBindings(pybind11::module_& m)

py::enum_<tensorrt_llm::kernels::AllReduceFusionOp>(m, "AllReduceFusionOp")
.value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE)
.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
.value("RESIDUAL_RMS_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM)
.value("LAST_PROCESS_FOR_UB", tensorrt_llm::kernels::AllReduceFusionOp::LAST_PROCESS_FOR_UB)
.value("RESIDUAL_RMS_PREPOST_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM)
15 changes: 14 additions & 1 deletion cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -441,6 +441,7 @@ class AllreduceOp
allreduce_fusion_params.rms_gamma = nullptr;

allreduce_fusion_params.allreduce_out = nullptr;
allreduce_fusion_params.allgather_out = nullptr; // Initialize AllGather output pointer
allreduce_fusion_params.quant_out = nullptr;
allreduce_fusion_params.scale_out = nullptr;
allreduce_fusion_params.residual_out = nullptr;
Expand Down Expand Up @@ -471,6 +472,16 @@ class AllreduceOp
allreduce_fusion_params.allreduce_out = reduce_out.mutable_data_ptr();
allreduce_fusion_params.pattern = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kAllReduce;
}
// Handle AllGather operation
else if (mOp == AllReduceFusionOp::ALLGATHER)
Review comment (Collaborator):

I don't think it is correct to implement ALLGATHER inside the allreduce op (see the standalone-op sketch after this hunk).

{
// For AllGather, create output tensor with expanded first dimension
auto output_shape = input.sizes().vec();
output_shape[0] = output_shape[0] * tp_size; // Expand by group size
reduce_out = torch::empty(output_shape, input.options());
allreduce_fusion_params.allgather_out = reduce_out.mutable_data_ptr();
allreduce_fusion_params.pattern = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kAllGather;
}
// Handle allreduce fusion here
// Prepare required output tensors for each fusion pattern
else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
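
As an illustration of the concern above, the gather could live in its own op that reuses the same fusion params and launcher rather than branching inside AllreduceOp. The class below is a hypothetical sketch only — AllgatherOp and its run() signature do not exist in the codebase, and fields that AllreduceOp normally fills (dtype, hidden_dim, stream, and so on) are omitted for brevity.

// Hypothetical sketch; assumes the TensorRT-LLM headers that declare
// AllReduceFusionParams and allreduce_fusion_op, plus the usual torch headers.
class AllgatherOp
{
public:
    explicit AllgatherOp(int tp_size)
        : mTpSize(tp_size)
    {
    }

    std::vector<torch::Tensor> run(torch::Tensor const& input, torch::Tensor const& workspace)
    {
        // The output holds every rank's shard, so the first dimension grows by tp_size.
        auto output_shape = input.sizes().vec();
        output_shape[0] *= mTpSize;
        auto gather_out = torch::empty(output_shape, input.options());

        tensorrt_llm::kernels::ar_fusion::AllReduceFusionParams params{};
        params.nranks = mTpSize;
        params.size = input.numel(); // element count, following AllreduceOp's convention (assumption)
        params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
        params.allreduce_in = input.data_ptr();
        params.allgather_out = gather_out.mutable_data_ptr();
        params.pattern = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kAllGather;

        tensorrt_llm::kernels::ar_fusion::allreduce_fusion_op(params);
        return {gather_out};
    }

private:
    int mTpSize;
};

Keeping the gather in a separate op would leave AllreduceOp's residual/norm plumbing untouched and make the kernel-side common-function refactor (sketched earlier) the only shared piece.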
@@ -555,7 +566,7 @@ class AllreduceOp
allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
allreduce_fusion_params.allreduce_in = input.data_ptr();

if (mOp != AllReduceFusionOp::NONE)
if (mOp != AllReduceFusionOp::NONE && mOp != AllReduceFusionOp::ALLGATHER)
{
allreduce_fusion_params.residual_in = residual.value().data_ptr();
allreduce_fusion_params.rms_gamma = norm_weight.value().data_ptr();
@@ -578,6 +589,7 @@ switch (mOp)
switch (mOp)
{
case AllReduceFusionOp::NONE: return {reduce_out};
case AllReduceFusionOp::ALLGATHER: return {reduce_out};
case AllReduceFusionOp::RESIDUAL_RMS_NORM: return {norm_out, residual_out};
case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8: return {quant_out, residual_out};
case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8: return {norm_out, quant_out, residual_out};
@@ -920,6 +932,7 @@ class AllreduceOp
{
case AllReduceFusionOp::NONE:
case AllReduceFusionOp::RESIDUAL_RMS_NORM: break;
case AllReduceFusionOp::ALLGATHER:
case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8:
case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4: