Feat/mtp opt 2 lamport allgather #6500
The first change adds an `ALLGATHER` value to the `AllReduceFusionOp` enum, renumbering every value after it:

```diff
@@ -69,21 +69,23 @@ enum class AllReduceFusionOp : int8_t
 enum class AllReduceFusionOp : int8_t
 {
     NONE = 0,
-    RESIDUAL_RMS_NORM = 1,
-    LAST_PROCESS_FOR_UB = 2,
-    RESIDUAL_RMS_PREPOST_NORM = 3,
-    RESIDUAL_RMS_NORM_QUANT_FP8 = 4,
-    RESIDUAL_RMS_NORM_QUANT_NVFP4 = 5,
-    RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 6,
-    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 7,
-    MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM = 8,
+    ALLGATHER = 1,
+    RESIDUAL_RMS_NORM = 2,
+    LAST_PROCESS_FOR_UB = 3,
+    RESIDUAL_RMS_PREPOST_NORM = 4,
+    RESIDUAL_RMS_NORM_QUANT_FP8 = 5,
+    RESIDUAL_RMS_NORM_QUANT_NVFP4 = 6,
+    RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 7,
+    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 8,
+    MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM = 9,
 };
```
Comment on lines +72 to 81 (automated verification agent):

LGTM! This correctly adds the `ALLGATHER` enum value with proper renumbering. Since inserting `ALLGATHER = 1` shifts every subsequent value, please verify that all references to the renumbered enum values across the codebase have been updated consistently, as this is a breaking change.

First verification attempt (failed: `rg` does not support `--type cu`):

```bash
#!/bin/bash
# Search for potential hardcoded enum values that might need updating.
# Expected: all enum references use symbolic names, not hardcoded integers.
rg -A 2 -B 2 "AllReduceFusionOp::" --type cpp --type cu
```

Length of output: 82

Re-run without the unsupported `--type` flag, covering C++, header, and CUDA files:

```bash
#!/bin/bash
rg -A2 -B2 "AllReduceFusionOp::" --glob '*.cpp' --glob '*.h' --glob '*.cu'
```

Length of output: 42196

Result: all enum usages are symbolic and safe, but the nanobind layer is missing `ALLGATHER`. Please add

```cpp
.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
```

to complete the nanobind mapping.
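To make the breaking-change concern concrete, here is a minimal illustration (not code from the PR) of the difference the verification scripts look for: symbolic references pick up the renumbering at compile time, while any hardcoded or persisted integer silently changes meaning.

```cpp
#include <cstdint>

// Abbreviated copy of the renumbered enum, for illustration only.
enum class AllReduceFusionOp : int8_t
{
    NONE = 0,
    ALLGATHER = 1,
    RESIDUAL_RMS_NORM = 2, // was 1 before this PR
};

int main()
{
    // Symbolic reference: follows the renumbering automatically.
    auto op = AllReduceFusionOp::RESIDUAL_RMS_NORM;

    // Hardcoded integer: a config or serialized value written as 1 under the
    // old numbering now decodes to ALLGATHER instead of RESIDUAL_RMS_NORM.
    auto stale = static_cast<AllReduceFusionOp>(1);

    return op == stale ? 1 : 0; // returns 0: the stored value changed meaning
}
```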
The `operator<<` printer in the same header gains the new case:

```diff
 inline std::ostream& operator<<(std::ostream& os, AllReduceFusionOp op)
 {
     switch (op)
     {
     case AllReduceFusionOp::NONE: os << "NONE"; break;
+    case AllReduceFusionOp::ALLGATHER: os << "ALLGATHER"; break;
     case AllReduceFusionOp::RESIDUAL_RMS_NORM: os << "RESIDUAL_RMS_NORM"; break;
     case AllReduceFusionOp::LAST_PROCESS_FOR_UB: os << "LAST_PROCESS_FOR_UB"; break;
     case AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM: os << "RESIDUAL_RMS_PREPOST_NORM"; break;
```
In the nanobind bindings:
```diff
@@ -345,6 +345,7 @@ void initBindings(nb::module_& m)
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer);

     nb::enum_<tensorrt_llm::kernels::AllReduceFusionOp>(m, "AllReduceFusionOp")
+        .value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
         .value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE)
         .value("RESIDUAL_RMS_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM)
         .value("LAST_PROCESS_FOR_UB", tensorrt_llm::kernels::AllReduceFusionOp::LAST_PROCESS_FOR_UB)
```

Review comment on the added `.value("ALLGATHER", ...)` line:

Replace the tab with spaces to comply with the code style (indentation rule). C++ code must use spaces only for indentation; this line starts with a tab, so clang-format/lint will likely fail. The fix is whitespace-only:

```diff
-	.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
+        .value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
```
In the `AllreduceOp` torch op:
```diff
@@ -441,6 +441,7 @@ class AllreduceOp
         allreduce_fusion_params.rms_gamma = nullptr;

         allreduce_fusion_params.allreduce_out = nullptr;
+        allreduce_fusion_params.allgather_out = nullptr; // Initialize AllGather output pointer
         allreduce_fusion_params.quant_out = nullptr;
         allreduce_fusion_params.scale_out = nullptr;
         allreduce_fusion_params.residual_out = nullptr;
```

```diff
@@ -471,6 +472,16 @@ class AllreduceOp
             allreduce_fusion_params.allreduce_out = reduce_out.mutable_data_ptr();
             allreduce_fusion_params.pattern = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kAllReduce;
         }
+        // Handle AllGather operation
+        else if (mOp == AllReduceFusionOp::ALLGATHER)
+        {
+            // For AllGather, create output tensor with expanded first dimension
+            auto output_shape = input.sizes().vec();
+            output_shape[0] = output_shape[0] * tp_size; // Expand by group size
+            reduce_out = torch::empty(output_shape, input.options());
+            allreduce_fusion_params.allgather_out = reduce_out.mutable_data_ptr();
+            allreduce_fusion_params.pattern = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kAllGather;
+        }
         // Handle allreduce fusion here
         // Prepare required output tensors for each fusion pattern
         else if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
```

Review comment on the `else if (mOp == AllReduceFusionOp::ALLGATHER)` line:

I don't think this is the correct way to implement ALLGATHER inside the allreduce op.
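A note on the shape logic in the new branch: for a rank-local input of shape `[m, n]` and a tensor-parallel group of `tp_size` ranks, allgather stacks every rank's slice along dim 0, so the output is `[m * tp_size, n]`. A standalone sketch of that calculation (the helper name is hypothetical, not from the PR):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the shape expansion in the ALLGATHER branch above: each of the
// tp_size ranks contributes one slice, stacked along the first dimension.
std::vector<int64_t> allgatherOutputShape(std::vector<int64_t> shape, int64_t tp_size)
{
    assert(!shape.empty() && tp_size > 0);
    shape[0] *= tp_size;
    return shape;
}

int main()
{
    auto out = allgatherOutputShape({8, 4096}, /*tp_size=*/4);
    assert(out[0] == 32 && out[1] == 4096); // [8, 4096] -> [32, 4096]
    return 0;
}
```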
```diff
@@ -555,7 +566,7 @@ class AllreduceOp
         allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
         allreduce_fusion_params.allreduce_in = input.data_ptr();

-        if (mOp != AllReduceFusionOp::NONE)
+        if (mOp != AllReduceFusionOp::NONE && mOp != AllReduceFusionOp::ALLGATHER)
         {
             allreduce_fusion_params.residual_in = residual.value().data_ptr();
             allreduce_fusion_params.rms_gamma = norm_weight.value().data_ptr();
```
```diff
@@ -578,6 +589,7 @@ class AllreduceOp
         switch (mOp)
         {
         case AllReduceFusionOp::NONE: return {reduce_out};
+        case AllReduceFusionOp::ALLGATHER: return {reduce_out};
         case AllReduceFusionOp::RESIDUAL_RMS_NORM: return {norm_out, residual_out};
         case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8: return {quant_out, residual_out};
         case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8: return {norm_out, quant_out, residual_out};
```
```diff
@@ -920,6 +932,7 @@ class AllreduceOp
         {
         case AllReduceFusionOp::NONE:
         case AllReduceFusionOp::RESIDUAL_RMS_NORM: break;
+        case AllReduceFusionOp::ALLGATHER:
         case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8:
         case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
         case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4:
```
Review comment (follow-up):

The correct way is to factor the common part into a shared function, remove the "allreduce"-specific naming from that function, and then implement both allreduce and allgather on top of it.
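To make that suggestion concrete, here is a minimal standalone sketch of the proposed structure; every name in it (`CollectivePattern`, `FusionParams`, `prepareCollectiveParams`, and the two wrappers) is hypothetical, not the actual TensorRT-LLM API:

```cpp
#include <cstdint>

enum class CollectivePattern : int8_t
{
    kAllReduce,
    kAllGather,
};

struct FusionParams
{
    void* collective_in = nullptr;  // op-neutral name instead of allreduce_in
    void* collective_out = nullptr; // one output slot shared by both ops
    CollectivePattern pattern = CollectivePattern::kAllReduce;
};

// Common helper: the wiring both collectives share, with no "allreduce"
// in its name or its field names.
FusionParams prepareCollectiveParams(void* input, void* output, CollectivePattern pattern)
{
    FusionParams params;
    params.collective_in = input;
    params.collective_out = output;
    params.pattern = pattern;
    return params;
}

// Each op becomes a thin wrapper that only selects the pattern.
FusionParams prepareAllReduce(void* input, void* output)
{
    return prepareCollectiveParams(input, output, CollectivePattern::kAllReduce);
}

FusionParams prepareAllGather(void* input, void* output)
{
    return prepareCollectiveParams(input, output, CollectivePattern::kAllGather);
}
```

With this split, the ALLGATHER path no longer special-cases fields named after allreduce; both ops go through one op-neutral setup and differ only in the pattern they select.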