Feat/mtp opt 2 lamport allgather #6500
base: main
Conversation
Signed-off-by: Amey Naik <[email protected]>
Signed-off-by: Amey naik <[email protected]>
📝 Walkthrough

Adds an ALLGATHER fusion path: new enum/member values, pattern traits and params (`allgather_out`), kernel branching for AllGather vs AllReduce, C++/Python bindings and operator plumbing, workspace selection, tests, and model integration to support concatenated per-rank AllGather outputs.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant PyFrontEnd
    participant PyAllReduceOp
    participant CppRuntime
    participant Kernel
    participant NCCL
    User->>PyFrontEnd: call AllReduce(fusion_op=ALLGATHER)
    PyFrontEnd->>PyAllReduceOp: select workspace (allgather_workspace if available)
    PyAllReduceOp->>CppRuntime: invoke allreduce_fusion_op (kAllGather)
    CppRuntime->>Kernel: launch allreduce_fusion_kernel (AllGather path)
    Kernel->>Kernel: gather per-rank slices and write to params.allgather_out (concatenated)
    Kernel-->>CppRuntime: return allgather_out buffer
    CppRuntime-->>PyAllReduceOp: wrap buffer as tensor
    PyAllReduceOp-->>User: concatenated per-rank tensor
```
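For orientation, here is a rough sketch of how the Python-facing path in this diagram might be exercised. The import locations and the `AllReduceParams(fusion_op=...)` call shape follow the unit tests added in this PR, but treat the exact API as an assumption rather than a finalized interface.

```python
import torch

# Assumed import locations: AllReduceFusionOp lives in tensorrt_llm.functional,
# AllReduce/AllReduceParams in tensorrt_llm._torch.distributed (re-export assumed).
from tensorrt_llm.functional import AllReduceFusionOp
from tensorrt_llm._torch.distributed import AllReduce, AllReduceParams


def gather_local_slice(allreduce: AllReduce, local_slice: torch.Tensor) -> torch.Tensor:
    """Each rank passes its own [1, hidden] slice; the fused kernel returns the
    per-rank slices concatenated along dim 0, i.e. a [tp_size, hidden] tensor."""
    return allreduce(
        local_slice,
        all_reduce_params=AllReduceParams(fusion_op=AllReduceFusionOp.ALLGATHER),
    )
```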
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
/bot run --disable-fail-fast
Actionable comments posted: 4
🧹 Nitpick comments (8)
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1)
29-34: LGTM! Correct implementation for ALLGATHER fake operator.

The logic correctly expands the first tensor dimension by the group size, which matches the expected behavior of an AllGather operation. However, please address the FIXME comment on line 34.

Consider removing or explaining the FIXME comment - the implementation appears correct for a fake AllGather operation.
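For reference, a fake (meta) implementation of an allgather-style op generally just allocates an output whose first dimension is multiplied by the group size. The sketch below is illustrative only and does not reproduce the exact registration name or signature used in cpp_custom_ops.py.

```python
import torch


def allgather_fake(input: torch.Tensor, group_size: int) -> torch.Tensor:
    # Fake/meta op: no communication happens here, only the output shape matters.
    # Dim 0 grows by the TP group size because per-rank slices are concatenated.
    out_shape = list(input.shape)
    out_shape[0] *= group_size
    return input.new_empty(out_shape)
```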
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (1)
693-698: Remove commented code or convert to documentation.

The commented-out code should either be removed if it's no longer needed, or converted to a proper documentation comment explaining why special block size handling for AllGather was considered but not implemented.

```diff
- // // Override block size to 1024 for AllGather operations
- // if (params.pattern == AllReduceFusionPattern::kAllGather) {
- //     block_size = 1024;
- // }
-
+ // Note: Special block size handling for AllGather was considered but determined
+ // unnecessary as the existing calculation works well for all patterns.
```

tensorrt_llm/_torch/distributed/ops.py (1)
35-51: Consider workspace sharing between allreduce and allgather operations.

The new `get_allgather_workspace` function duplicates the entire logic of `get_allreduce_workspace`. Since both operations use the same workspace allocation method (`allocate_allreduce_fusion_workspace`) with identical parameters, consider sharing the workspace to reduce memory usage.

```diff
-def get_allgather_workspace(mapping: Mapping) -> torch.LongTensor:
-    if not hasattr(_thread_local, f'allgather_workspaces_{mapping.pp_rank}'):
-        setattr(_thread_local, f'allgather_workspaces_{mapping.pp_rank}', {})
-
-    allgather_workspaces = getattr(_thread_local,
-                                   f'allgather_workspaces_{mapping.pp_rank}')
-    if mapping not in allgather_workspaces:
-        # Allocate dedicated workspace for ALLGATHER operations
-        # Use a potentially larger workspace size for allgather operations
-        ipc_buffers, workspace = CustomAllReduceHelper.allocate_allreduce_fusion_workspace(
-            mapping,
-            CustomAllReduceHelper.max_workspace_size_auto(
-                mapping.tp_size, support_deterministic=False),
-        )
-        allgather_workspaces[mapping] = (ipc_buffers, workspace)
-    return allgather_workspaces[mapping][1]
+# Consider using the existing allreduce workspace for allgather operations
+# since they use the same allocation method and size
+get_allgather_workspace = get_allreduce_workspace
```

cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu (1)
556-576: Consider removing the half-specific method to reduce duplication.

The templated `set_input_data_typed` method can handle all data types including `half`. The non-templated version appears redundant. Remove the half-specific method and update callers to use the templated version:

```diff
- // Public method to set input data for testing
- void set_input_data(int message_size, std::function<half(int)> data_generator)
- {
-     auto input_data = m_allreduce_in.host_data<half>();
-     for (int i = 0; i < message_size; ++i)
-     {
-         input_data[i] = data_generator(i);
-     }
-     m_allreduce_in.h2d();
- }
-
  // Template method to set input data for any data type
  template <typename T>
  void set_input_data_typed(int message_size, std::function<T(int)> data_generator)
```

tests/unittest/_torch/multi_gpu/test_allreduce.py (3)
109-109: Remove debug print statement before merging.

Debug print statements should be removed from production test code.

```diff
- print(f"DBG AMEY: run_allreduce_op: i am here")
```
276-312: Remove debug prints from reference implementation.

The reference implementation contains debug print statements that should be removed.

```diff
 def ref_allgather(x, res):
-    print(f"[Rank {tensor_parallel_rank}] Starting ref_allgather")
     # Reference implementation: manually create expected AllGather result
     all_rank_data = []
     for rank in range(tensor_parallel_size):
         # Create with default dtype first, then convert
         rank_data_orig = torch.arange(
             rank * 10,
             rank * 10 + 8,
             device="cuda"
         ).reshape(1, 8)
-        print(f"[Rank {tensor_parallel_rank}] Original ref data for rank {rank}: {rank_data_orig} (dtype: {rank_data_orig.dtype})")
-
         rank_data = rank_data_orig.to(dtype)
-        print(f"[Rank {tensor_parallel_rank}] Converted ref data for rank {rank}: {rank_data}")
-        print(f"[Rank {tensor_parallel_rank}] ref rank_data dtype: {rank_data.dtype}")
-
         # Pad or reshape to match expected input dimensions
         if rank_data.shape != x.shape:
             rank_data = rank_data.expand(x.shape[0], -1)
             if rank_data.shape[1] < x.shape[1]:
                 # Repeat pattern to fill the hidden dimension
                 repeat_factor = x.shape[1] // 8
                 remainder = x.shape[1] % 8
                 rank_data = rank_data.repeat(1, repeat_factor)
                 if remainder > 0:
                     rank_data = torch.cat([rank_data, rank_data[:, :remainder]], dim=1)
-
-        print(f"[Rank {tensor_parallel_rank}] Final ref data for rank {rank}: shape {rank_data.shape}, values {rank_data}")
         all_rank_data.append(rank_data)

     # Concatenate along the first dimension (batch dimension)
     gathered_result = torch.cat(all_rank_data, dim=0)
-    print(f"[Rank {tensor_parallel_rank}] Reference AllGather result:")
-    print(f"[Rank {tensor_parallel_rank}] Shape: {gathered_result.shape}")
-    print(f"[Rank {tensor_parallel_rank}] Values: {gathered_result} dtype: {gathered_result.dtype}")
     return [gathered_result]
```
393-393: Remove debug print statement.

```diff
- print(f"DBG AMEY: test_allreduce_fusion_patterns: seq_len={seq_len}, hidden_size={hidden_size}, fusion_op={fusion_op}")
```
tensorrt_llm/_torch/speculative/mtp.py (1)
1101-1132: Consider refactoring to reduce code duplication between padded and simple versions.

The `_simple` variants of these methods contain nearly identical logic to their padded counterparts, with only the padding/slicing logic removed. This duplication could lead to maintenance issues. Consider refactoring to share common logic:

```python
def _get_local_max_and_indices(self, logits):
    """Common logic for getting local max values and indices."""
    local_max_values, local_argmax = torch.max(logits, dim=-1, keepdim=True)
    vocab_per_rank = logits.shape[-1]
    max_index_per_rank = local_argmax.type(
        torch.int32) + (self.model_config.mapping.tp_rank * vocab_per_rank)
    max_index_per_rank_float = max_index_per_rank.float()
    local_max_values_float32 = local_max_values.float()
    combined = torch.stack(
        [max_index_per_rank_float, local_max_values_float32],
        dim=-1).flatten(-2)
    return combined

def get_local_max_and_combined(self, logits):
    combined = self._get_local_max_and_indices(logits)
    original_last_dim = combined.shape[-1]
    # Add padding if needed for ALLGATHER kernel requirements
    if combined.numel() < 4:
        # ... padding logic ...
        pass
    return original_last_dim, combined

def get_local_max_and_combined_simple(self, logits):
    """Simple version without padding for fallback allgather"""
    return self._get_local_max_and_indices(logits)
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (3 hunks)
- cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h (4 hunks)
- cpp/tensorrt_llm/kernels/customAllReduceKernels.h (1 hunks)
- cpp/tensorrt_llm/pybind/runtime/bindings.cpp (1 hunks)
- cpp/tensorrt_llm/thop/allreduceOp.cpp (5 hunks)
- cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu (8 hunks)
- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1 hunks)
- tensorrt_llm/_torch/distributed/ops.py (5 hunks)
- tensorrt_llm/_torch/models/modeling_deepseekv3.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (4 hunks)
- tensorrt_llm/functional.py (2 hunks)
- tests/unittest/_torch/multi_gpu/test_allreduce.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without reflection.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tensorrt_llm/_torch/models/modeling_deepseekv3.py
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
tensorrt_llm/_torch/distributed/ops.py
tensorrt_llm/functional.py
tests/unittest/_torch/multi_gpu/test_allreduce.py
tensorrt_llm/_torch/speculative/mtp.py
**/*.{cpp,h,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
tensorrt_llm/_torch/models/modeling_deepseekv3.py
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
cpp/tensorrt_llm/pybind/runtime/bindings.cpp
cpp/tensorrt_llm/kernels/customAllReduceKernels.h
tensorrt_llm/_torch/distributed/ops.py
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h
cpp/tensorrt_llm/thop/allreduceOp.cpp
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
tensorrt_llm/functional.py
tests/unittest/_torch/multi_gpu/test_allreduce.py
cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu
tensorrt_llm/_torch/speculative/mtp.py
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}
: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible, as the latter are not visible to the compiler.
A variable that is not modified after its initialization should be declared as const.
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in a compilation target must have filenames that are case-insensitive unique.
All types (including class names) are camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace use camel case prefixed by a lower case 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace use camel case prefixed by a lower case 's' (e.g., sMutableStaticGlobal).
Locally visible static variable uses camel case with lowercase prefix 's' as the first letter of the name (e.g., static std::once_flag sFlag;).
Class member variables use camelcase prefixed with an 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope ...
Files:
cpp/tensorrt_llm/pybind/runtime/bindings.cpp
cpp/tensorrt_llm/kernels/customAllReduceKernels.h
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h
cpp/tensorrt_llm/thop/allreduceOp.cpp
**/*.{h,hpp}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.
Files:
cpp/tensorrt_llm/kernels/customAllReduceKernels.h
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h
🧠 Learnings (3)
📓 Common learnings
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Learnt from: yiqingy0
PR: NVIDIA/TensorRT-LLM#5198
File: jenkins/mergeWaiveList.py:0-0
Timestamp: 2025-07-22T08:33:49.109Z
Learning: In the TensorRT-LLM waive list merging system, removed lines are always located at the end of the merge waive lists, which is why the mergeWaiveList.py script uses reverse traversal - it's an optimization for this specific domain constraint.
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
Learnt from: yiqingy0
PR: #5198
File: jenkins/mergeWaiveList.py:0-0
Timestamp: 2025-07-22T08:33:49.109Z
Learning: In the TensorRT-LLM waive list merging system, removed lines are always located at the end of the merge waive lists, which is why the mergeWaiveList.py script uses reverse traversal - it's an optimization for this specific domain constraint.
tensorrt_llm/_torch/speculative/mtp.py (1)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor()
is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation()
to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
🪛 Ruff (0.12.2)
tests/unittest/_torch/multi_gpu/test_allreduce.py
287-287: Line too long (135 > 120)
(E501)
304-304: Line too long (127 > 120)
(E501)
393-393: Line too long (123 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (29)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
986-986: LGTM! Cosmetic improvement.

Removing extraneous blank lines improves code readability and maintains consistent formatting.
cpp/tensorrt_llm/pybind/runtime/bindings.cpp (1)
435-435: LGTM! Correctly adds Python binding for the new ALLGATHER enum value.

The addition aligns with the C++ enum definition and follows the established pattern for exposing enum values to Python.
cpp/tensorrt_llm/kernels/customAllReduceKernels.h (1)
86-86: LGTM! Correctly updated operator<< overload for ALLGATHER.

The string output case for `ALLGATHER` is properly added to maintain consistency with the enum definition.

tensorrt_llm/functional.py (2)
3889-3897: LGTM: Enum extension correctly implements ALLGATHER support.

The addition of `ALLGATHER = 1` and subsequent value shifts are properly implemented. This change aligns with the broader AllGather fusion operation implementation described in the AI summary.
3924-3925: LGTM: Assertion correctly allows ALLGATHER without residual requirement.

The updated assertion logic properly exempts `ALLGATHER` fusion operations from requiring a residual tensor, which makes sense as AllGather operations don't inherently need residual data unlike other fusion operations.

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (2)
508-526: LGTM: Well-structured AllGather support in kernel.

The kernel changes correctly implement AllGather functionality by:

- Writing gathered data from all ranks to the `allgather_out` buffer with proper indexing
- Conditionally executing AllReduce sum and fusion operations only when needed
- Using template traits to determine which operations to perform

The separation of AllGather output (lines 510-518) and AllReduce processing (lines 521-525) is clean and maintains the existing fusion logic.
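To make the output layout concrete, the gathered buffer is equivalent to concatenating each rank's local message in rank order. A small PyTorch emulation of that layout (purely illustrative; this is not the kernel code):

```python
import torch


def emulate_allgather_out(per_rank_inputs) -> torch.Tensor:
    # allgather_out holds [rank 0 slice | rank 1 slice | ... | rank N-1 slice],
    # which is the same as concatenating the per-rank inputs along dim 0.
    return torch.cat(list(per_rank_inputs), dim=0)


# Example: two ranks, each contributing a [1, 4] slice -> output shape is [2, 4].
print(emulate_allgather_out([torch.zeros(1, 4), torch.ones(1, 4)]).shape)
```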
765-768: LGTM: Correct pattern dispatch implementation.

The dispatch logic correctly adds support for the `kAllGather` pattern, following the established pattern and integrating cleanly with the existing dispatch system.

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h (4)
78-106: LGTM: Well-implemented trait system extension.

The trait system changes are correctly implemented:

- Macro updated to accept the new `hasAllGatherOut` parameter
- All pattern specializations consistently updated with appropriate trait values
- New `HasAllGatherOut` template variable provides convenient access
- `kAllGather` pattern correctly sets `kHasAllGatherOut = true`

The implementation maintains consistency with the existing trait system design.
115-116: LGTM: Consistent template variable addition.

The new `HasAllGatherOut` template variable follows the established pattern and provides convenient access to the trait value, maintaining consistency with other trait accessors.
134-134: LGTM: Well-documented parameter addition.

The new `allgather_out` parameter is appropriately added with a clear comment explaining its purpose. The positioning in the struct is logical and maintains readability.
58-66: Ensure binary compatibility after inserting new enum entry

Adding `kAllGather = 1` shifts all subsequent `AllReduceFusionPattern` values, which can break any existing compiled consumers that serialize or switch on the old numeric values. Although our own source references the symbols (and will recompile correctly), external modules or persisted data may now be invalid.

Key locations impacted:

- cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h (enum definition)
- cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (dispatch on `kAllGather`, `kARResidualRMSNorm*`)
- cpp/tensorrt_llm/thop/allreduceOp.cpp (pattern assignments)
- cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu (static asserts and TEST_AR_FUSION)
- cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu (static asserts and TEST_AR_FUSION)
Next steps:
- Confirm that all downstream consumers (including any plugins or serialized checkpoints) will be recompiled together.
- If backward compatibility is required, consider appending `kAllGather` at the end of the enum or explicitly assigning old numeric values to preserve existing ordering (a sketch of the appended variant follows this list).
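To illustrate the appended-value alternative on the Python side (a sketch only, written against the mirrored `AllReduceFusionOp` enum in functional.py; the pre-PR numbering is inferred from the renumbered values shown above, and the C++ `AllReduceFusionPattern` would need the same treatment):

```python
from enum import IntEnum


class AllReduceFusionOp(IntEnum):
    # Existing values keep their pre-PR numbering...
    NONE = 0
    RESIDUAL_RMS_NORM = 1
    LAST_PROCESS_FOR_UB = 2
    RESIDUAL_RMS_PREPOST_NORM = 3
    RESIDUAL_RMS_NORM_QUANT_FP8 = 4
    RESIDUAL_RMS_NORM_QUANT_NVFP4 = 5
    RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 6
    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 7
    MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM = 8
    # ...and the new op is appended, so previously serialized integers stay valid.
    ALLGATHER = 9
```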
cpp/tensorrt_llm/thop/allreduceOp.cpp (5)
409-409: LGTM: Correct parameter initialization.

The `allgather_out` parameter is properly initialized to `nullptr`, following the established pattern for other output parameters.
440-449: LGTM: Correct AllGather operation implementation.

The AllGather handling is well-implemented:

- Output tensor correctly expands the first dimension by `tp_size` (group size)
- Proper assignment of the `allgather_out` pointer
- Correct fusion pattern selection

This matches the expected semantics where AllGather concatenates data from all ranks along the first dimension.
534-534: LGTM: Correct exclusion logic for AllGather.

The condition correctly excludes `ALLGATHER` operations from residual and norm weight processing, as AllGather operations don't require these additional inputs, similar to `NONE` operations.
557-557: LGTM: Appropriate output packing for AllGather.

The output packing correctly returns a single tensor for AllGather operations, which is appropriate as AllGather produces only the gathered output tensor.
891-891: LGTM: Appropriate heuristic selection for AllGather.

Adding `ALLGATHER` to the operations that always return the `MIN_LATENCY` strategy is correct, as AllGather operations are specialized and benefit from the optimized fused kernel implementation rather than fallback strategies.

tensorrt_llm/_torch/distributed/ops.py (2)
428-439: LGTM! Workspace allocation follows the established pattern.

The initialization and allocation of `allgather_workspace` correctly follows the same pattern as the existing `allreduce_workspace`, ensuring consistency in the codebase.
493-509: Workspace selection logic is well-implemented.

The dynamic workspace selection based on fusion operation type is correctly implemented with proper fallback handling. This ensures the ALLGATHER fusion operation gets its dedicated workspace when available.
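As a sketch of what that selection amounts to (the attribute names mirror the workspaces discussed above, but the exact layout in ops.py is an assumption):

```python
from tensorrt_llm.functional import AllReduceFusionOp


def select_workspace(fusion_op, allreduce_workspace, allgather_workspace=None):
    """Illustrative helper: pick the workspace for a fused collective call."""
    if fusion_op == AllReduceFusionOp.ALLGATHER and allgather_workspace is not None:
        # Dedicated buffer allocated by get_allgather_workspace(mapping).
        return allgather_workspace
    # Every other fusion op keeps using the shared allreduce workspace.
    return allreduce_workspace
```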
cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu (10)
1-15: Copyright header looks good.

The copyright year range includes 2025, which aligns with the current requirements.
283-284: Buffer allocation for AllGather is correct.

The size calculation `m_message_size * m_world_size * sizeof(DType)` correctly accounts for the concatenated output from all ranks.
387-405: Verification logic correctly handles AllGather pattern.

The conditional compilation using `ar_fusion::HasAllGatherOut<Pattern>` properly differentiates between AllReduce and AllGather patterns, with appropriate reference output generation and comparison tolerances.
464-468: NCCL AllGather wrapper is correctly implemented.

The method properly calls the NCCL API with correct parameters for benchmarking purposes.
495-536: Debug output method is well-structured.

The method provides clear visualization of AllGather results with appropriate rank-0 restrictions to avoid cluttered output.
678-758: AllGather test with half precision is comprehensive and well-structured.

The test effectively validates both correctness and performance with clear debugging output and deterministic data generation.
785-785: Clarify the allocation strategy with max(token_num, 128).

The comment mentions "Use max token_num for allocation" but doesn't explain why 128 is used as the minimum. This appears to be a workaround for an issue mentioned in line 778 about avoiding hanging.
Could you clarify why a minimum of 128 is needed for allocation? Is this related to workspace requirements or a known issue with smaller sizes?
842-844: Performance comparison output effectively demonstrates the speedup.

The format clearly shows the performance benefit of the fusion kernel over NCCL AllGather, aligning with the PR's objective of ~2-3x improvement for small message sizes.
862-906: Float AllReduce test provides good basic validation.

The test correctly validates the AllReduce pattern with float precision using deterministic data.
652-654: Test coverage has been significantly reduced.

The test parameters have been reduced to only test with `hidden_dim=1024` and `token_num=1`. While this speeds up test execution, it reduces coverage for edge cases and different problem sizes.

Is this reduction intentional for faster CI runs, or should we maintain broader test coverage? Consider adding a comment explaining why the coverage was reduced.
tests/unittest/_torch/multi_gpu/test_allreduce.py (1)
386-388: Remove debug print and clarify dtype change.

The debug print should be removed, and the dtype change from `bfloat16` to `float32` needs clarification.

```diff
- print("DBG AMEY: test_allreduce_fusion_patterns: i am here")
- # dtype = torch.bfloat16
- dtype = torch.float32
+ dtype = torch.bfloat16
```

Why was the dtype changed to `float32`? If this is necessary for the ALLGATHER operation, please add a comment explaining the rationale. Otherwise, consider reverting to `bfloat16` to maintain consistency with the original test design.
```cpp
    ALLGATHER = 1,
    RESIDUAL_RMS_NORM = 2,
    LAST_PROCESS_FOR_UB = 3,
    RESIDUAL_RMS_PREPOST_NORM = 4,
    RESIDUAL_RMS_NORM_QUANT_FP8 = 5,
    RESIDUAL_RMS_NORM_QUANT_NVFP4 = 6,
    RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 7,
    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 8,
    MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM = 9,
};
```
💡 Verification agent
🧩 Analysis chain
LGTM! Correctly adds ALLGATHER enum value with proper renumbering.
The addition of `ALLGATHER = 1` is well-placed and all subsequent enum values are appropriately renumbered. This maintains enum consistency while introducing the new fusion operation.
Please verify that all references to the renumbered enum values across the codebase have been updated consistently, as this is a breaking change.
🏁 Script executed:
```bash
#!/bin/bash
# Description: Verify that enum values are used consistently across the codebase
# Expected: All enum references should use symbolic names, not hardcoded integers

# Search for potential hardcoded enum values that might need updating
rg -A 2 -B 2 "AllReduceFusionOp::" --type cpp --type cu
```
Length of output: 82
🏁 Script executed:
```bash
#!/bin/bash
# Description: Re-run enum reference check without unsupported --type flag
# Search C++, header, and CUDA files for AllReduceFusionOp usage
rg -A2 -B2 "AllReduceFusionOp::" --glob '*.cpp' --glob '*.h' --glob '*.cu'
```
Length of output: 42196
All enum usages are symbolic and safe; nanobind layer missing ALLGATHER
- C++/CUDA code and the `pybind` binding reference `AllReduceFusionOp::ALLGATHER` correctly; no hardcoded integers remain.
- In `cpp/tensorrt_llm/nanobind/runtime/bindings.cpp`, the `nb::enum_<AllReduceFusionOp>` list currently omits the new `ALLGATHER` entry.
Please add:
.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
to complete the nanobind mapping.
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/nanobind/runtime/bindings.cpp, locate the
nb::enum_<AllReduceFusionOp> definition and add the line .value("ALLGATHER",
tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER) to include the ALLGATHER
enum in the nanobind bindings, ensuring the Python layer correctly maps this new
enum value.
PR_Github #13594 [ run ] triggered by Bot
PR_Github #13594 [ run ] completed with state
Actionable comments posted: 1
🔭 Outside diff range comments (1)
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h (1)
2-2: Update copyright year to 2025 for consistency.

The copyright year should be updated to 2022-2025 to match the .cu file and comply with the coding guideline that requires the current year in copyright headers.

```diff
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (3 hunks)
- cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h (4 hunks)
- cpp/tensorrt_llm/kernels/customAllReduceKernels.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- cpp/tensorrt_llm/kernels/customAllReduceKernels.h
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}
: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible.
A variable that is not modified after its initialization should be declared as const.
Except 0 (used for checking signness/existence/emptiness), nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do..while, or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with the first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in a compilation target must have case-insensitive unique filenames.
All types (including class names) should use camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces should use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camel case prefixed by 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace should use camel case prefixed by 's' (e.g., sMutableStaticGlobal).
Locally visible static variables should use camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;).
Class member variables should use camel case prefixed with 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic-number/literal constants should be uppercase snake case with prefix...
Files:
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h
**/*.{h,hpp}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.
Files:
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (2)
508-525: LGTM! AllGather implementation correctly handles per-rank data concatenation.

The conditional branching between AllGather and AllReduce paths is well-structured. The AllGather output correctly concatenates data from all ranks in sequential order before any reduction operations.
765-768: LGTM! AllGather dispatch case properly integrated.

The new AllGather pattern dispatch follows the established pattern and is correctly positioned in the dispatch chain.
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h (3)
57-65: LGTM! AllGather pattern correctly integrated into the enum.

The new kAllGather pattern is logically positioned and the renumbering of subsequent patterns is consistent.
78-105: LGTM! Pattern traits properly extended for AllGather and NormOut variants.

The trait system correctly captures the distinct characteristics of each fusion pattern, with AllGather having only kHasAllGatherOut enabled and the new NormOut variants properly configured.
134-134: LGTM! AllGather output field properly added to params struct.

The new `allgather_out` field is correctly positioned and follows the existing pattern for output fields.
Signed-off-by: Amey naik <[email protected]>
/bot run --disable-fail-fast
PR_Github #14861 [ run ] triggered by Bot
Actionable comments posted: 2
♻️ Duplicate comments (3)
tests/unittest/_torch/multi_gpu/test_allreduce.py (2)
217-274: Remove debug prints and make iteration count configurable (prior feedback).

This block contains extensive debug prints and a hardcoded 100-iteration loop that slows CI and clutters logs. Make the iteration count configurable (env var) and drop prints to avoid noise and potential timeouts. This mirrors the earlier review guidance.

Apply this diff:

```diff
 def calc_allgather(x, res):
-    print(f"[Rank {tensor_parallel_rank}] Starting calc_allgather")
-    print(f"[Rank {tensor_parallel_rank}] Input x shape: {x.shape}, dtype: {x.dtype}")
-    print(f"[Rank {tensor_parallel_rank}] Input x values: {x}")
-
     # Create rank-specific input data for AllGather testing
     # Rank 0: [0,1,2,3,4,5,6,7], Rank 1: [10,11,12,13,14,15,16,17], etc.
-    print(f"[Rank {tensor_parallel_rank}] Using dtype: {dtype}")
-
     # First create with default dtype, then convert
     rank_data_orig = torch.arange(
         tensor_parallel_rank * 10,
         tensor_parallel_rank * 10 + 8,
         device="cuda"
     ).reshape(1, 8)
-    print(f"[Rank {tensor_parallel_rank}] Original rank_data: {rank_data_orig} (dtype: {rank_data_orig.dtype})")
-
     rank_data = rank_data_orig.to(dtype)  # Convert to target dtype
-
-    print(f"[Rank {tensor_parallel_rank}] Converted rank_data: {rank_data}")
-    print(f"[Rank {tensor_parallel_rank}] rank_data dtype: {rank_data.dtype}")
-
     # Pad or reshape to match expected input dimensions if needed
     if rank_data.shape != x.shape:
         rank_data = rank_data.expand(x.shape[0], -1)
         if rank_data.shape[1] < x.shape[1]:
             # Repeat pattern to fill the hidden dimension
             repeat_factor = x.shape[1] // 8
             remainder = x.shape[1] % 8
             rank_data = rank_data.repeat(1, repeat_factor)
             if remainder > 0:
                 rank_data = torch.cat([rank_data, rank_data[:, :remainder]], dim=1)
-
-    print(f"[Rank {tensor_parallel_rank}] Final rank_data shape: {rank_data.shape}")
-    print(f"[Rank {tensor_parallel_rank}] Final rank_data values: {rank_data}")
-
-    # Use AllReduce with ALLGATHER fusion to gather data from all ranks
-    print(f"[Rank {tensor_parallel_rank}] About to call allreduce with ALLGATHER")
     try:
-        # Run allreduce 100 times in a loop
+        # Run allreduce N times for performance benchmarking (default 100).
+        # Override with env var TLLM_ALLGATHER_BENCH_ITERS to adjust runtime.
+        iterations = int(os.getenv("TLLM_ALLGATHER_BENCH_ITERS", "100"))
         output = None
-        for i in range(100):
+        for i in range(iterations):
             output = allreduce(
                 rank_data,
                 all_reduce_params=AllReduceParams(
                     fusion_op=AllReduceFusionOp.ALLGATHER,
                     enable_allreduce=True,
                 ),
             )
-        print(f"[Rank {tensor_parallel_rank}] AllGather completed!")
-        print(f"[Rank {tensor_parallel_rank}] Output shape: {output.shape}")
-        print(f"[Rank {tensor_parallel_rank}] Output values: {output}")
         return [output]
     except Exception as e:
-        print(f"[Rank {tensor_parallel_rank}] AllGather failed: {e}")
-        print(f"[Rank {tensor_parallel_rank}] Exception traceback: {traceback.format_exc()}")
         raise
```
358-361: Restore broader test parameter coverage for seq_len/hidden_size (prior feedback).

Limiting to a single value reduces coverage and may mask regressions for larger messages. Restore the broader sets or document the restriction clearly.

Apply this diff:

```diff
-@pytest.mark.parametrize("seq_len", [1], #256, 8192],
+@pytest.mark.parametrize("seq_len", [1, 256, 8192],
                          ids=lambda x: f"seqlen:{x}")
-@pytest.mark.parametrize("hidden_size", [8], #7168],
+@pytest.mark.parametrize("hidden_size", [8, 7168],
                          ids=lambda x: f"hidden:{x}")
```

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (1)
693-693: Remove or document the commented-out block size override.

The commented block size override for AllGather operations should either be removed if it's no longer needed, or documented with a clear explanation of why it's disabled and under what conditions it might be re-enabled.
This is the same issue flagged in the previous review comment. The commented-out code still exists and needs to be addressed either by removal or proper documentation.
🧹 Nitpick comments (2)
tests/unittest/_torch/multi_gpu/test_allreduce.py (2)
381-381: Reduce mpi_pool_executor workers to avoid CI resource exhaustion/timeouts.

Spawning 8 workers combined with 100-iteration loops can significantly increase runtime and memory pressure, risking CI failures. Prefer 2 (or make it configurable).

Apply this diff (unless there's a specific reason to force 8):

```diff
-@pytest.mark.parametrize("mpi_pool_executor", [8], indirect=True)
+@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
```

If you need to test at higher parallelism, consider gating via an environment variable or marking as a heavier test.
385-387: Revisit dtype change to float32; consider bfloat16 or parameterizing.

Most fusion tests use bfloat16; switching to float32 here diverges and reduces coverage consistency. Either restore bfloat16 or parametrize dtype to run both where feasible.

Two options:

- Restore bfloat16 (simplest):

```diff
- # dtype = torch.bfloat16
- dtype = torch.float32
+ dtype = torch.bfloat16
```

- Or parametrize dtype (adds coverage but increases runtime):

```python
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_allreduce_fusion_patterns(seq_len, hidden_size, fusion_op,
                                   mpi_pool_executor, dtype):
    ...
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (3 hunks)
- cpp/tensorrt_llm/nanobind/runtime/bindings.cpp (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (4 hunks)
- tests/unittest/_torch/multi_gpu/test_allreduce.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tensorrt_llm/_torch/speculative/mtp.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,hpp,cc,cxx}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.{cpp,h,hpp,cc,cxx}
: Closing braces of namespaces should have a comment saying the namespace it closes (e.g., } // namespace foo).
Prefer const or constexpr variables over #defines whenever possible.
A variable that is not modified after its initialization should be declared as const.
Except 0 (used for checking signness/existence/emptiness), nullptr, true, false, all other literals should only be used for variable initialization.
Use the Allman indentation style for braces in C++ code.
Put the semicolon for an empty for or while loop in a new line.
The statement forming the body of a switch, while, do..while, or for statement shall be a compound statement (use brace-delimited statements).
If and else should always be followed by brace-delimited statements, even if empty or a single statement.
C++ filenames should use camel case with the first letter lowercase (e.g., thisIsAFilename.cpp), and all files involved in a compilation target must have case-insensitive unique filenames.
All types (including class names) should use camel case with uppercase first letter (e.g., FooBarClass).
Local variables, methods, and namespaces should use camel case with first letter lowercase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not defined in anonymous namespace should use camel case prefixed by 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number global variables that are static or defined in an anonymous namespace should use camel case prefixed by 's' (e.g., sMutableStaticGlobal).
Locally visible static variables should use camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;).
Class member variables should use camel case prefixed with 'm' (e.g., mNbFooValues). Public member variables do not require the 'm' prefix but it is encouraged for clarity.
Enumerations, global constants, static constants at class-scope, and function-scope magic-number/literal constants should be uppercase snake case with prefix...
Files:
cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
tests/unittest/_torch/multi_gpu/test_allreduce.py
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tests/unittest/_torch/multi_gpu/test_allreduce.py
🪛 Ruff (0.12.2)
tests/unittest/_torch/multi_gpu/test_allreduce.py
286-286: Line too long (135 > 120)
(E501)
303-303: Line too long (127 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
cpp/tensorrt_llm/nanobind/runtime/bindings.cpp (1)
347-358: Exposing ALLGATHER in the nanobind enum: LGTM.

The binding cleanly exposes the new enumerator and aligns with the expanded C++ enum.
tests/unittest/_torch/multi_gpu/test_allreduce.py (1)
315-316: ALLGATHER mapping added to fusion_op_to_func: LGTM.

The new mapping correctly wires the calc/ref paths for the ALLGATHER fusion op.
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (3)
1-2: LGTM! Copyright header is current and compliant.

The NVIDIA copyright header includes the current year (2025) as required by the coding guidelines.
508-526: Well-implemented AllGather vs AllReduce branching logic.

The pattern-based conditional compilation correctly separates the AllGather output generation from the AllReduce computation and fusion operations. The AllGather output layout (concatenated per-rank data) aligns with the expected format for downstream consumption.
760-763: Verify downstream support for AllGather pattern

Automated searches confirmed that:

- the dispatch branch for `AllReduceFusionPattern::kAllGather` is added in cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (lines 760-763)
- the pattern traits are defined in cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h (line 95)

However, no explicit `kAllGather`-specific handling was found in workspace-allocation or parameter-validation code. Please manually review and ensure that:

- workspace-size calculations (e.g., via `DEFINE_FUSION_PATTERN_TRAITS` and `calculateTotalWorkspaceSize`) correctly include the AllGather case
- any `params.checkParams()` or similar validation routines accept `kAllGather`
- downstream kernels or helper functions that dispatch based on pattern (e.g., switch statements over `AllReduceFusionPattern`) include the new case
.def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer); | ||
|
||
nb::enum_<tensorrt_llm::kernels::AllReduceFusionOp>(m, "AllReduceFusionOp") | ||
.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER) |
Replace tab with spaces to comply with code style (indentation rule).
C++ code must use spaces only for indentation. This line starts with a tab; clang-format/lint will likely fail.
Apply this diff:
- .value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
+ .value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER) | |
.value("ALLGATHER", tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER) |
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/nanobind/runtime/bindings.cpp around line 348, the line
begins with a tab which violates the project's C++ indentation rule; replace the
leading tab with spaces (use the repository's preferred indentation, e.g., 4
spaces) so the line starts with spaces only: .value("ALLGATHER",
tensorrt_llm::kernels::AllReduceFusionOp::ALLGATHER). Ensure there are no other
leading tabs on the same or surrounding lines.
```python
def ref_allgather(x, res):
    print(f"[Rank {tensor_parallel_rank}] Starting ref_allgather")
    # Reference implementation: manually create expected AllGather result
    all_rank_data = []
    for rank in range(tensor_parallel_size):
        # Create with default dtype first, then convert
        rank_data_orig = torch.arange(
            rank * 10,
            rank * 10 + 8,
            device="cuda"
        ).reshape(1, 8)
        print(f"[Rank {tensor_parallel_rank}] Original ref data for rank {rank}: {rank_data_orig} (dtype: {rank_data_orig.dtype})")

        rank_data = rank_data_orig.to(dtype)
        print(f"[Rank {tensor_parallel_rank}] Converted ref data for rank {rank}: {rank_data}")
        print(f"[Rank {tensor_parallel_rank}] ref rank_data dtype: {rank_data.dtype}")

        # Pad or reshape to match expected input dimensions
        if rank_data.shape != x.shape:
            rank_data = rank_data.expand(x.shape[0], -1)
            if rank_data.shape[1] < x.shape[1]:
                # Repeat pattern to fill the hidden dimension
                repeat_factor = x.shape[1] // 8
                remainder = x.shape[1] % 8
                rank_data = rank_data.repeat(1, repeat_factor)
                if remainder > 0:
                    rank_data = torch.cat([rank_data, rank_data[:, :remainder]], dim=1)

        print(f"[Rank {tensor_parallel_rank}] Final ref data for rank {rank}: shape {rank_data.shape}, values {rank_data}")
        all_rank_data.append(rank_data)

    # Concatenate along the first dimension (batch dimension)
    gathered_result = torch.cat(all_rank_data, dim=0)
    print(f"[Rank {tensor_parallel_rank}] Reference AllGather result:")
    print(f"[Rank {tensor_parallel_rank}] Shape: {gathered_result.shape}")
    print(f"[Rank {tensor_parallel_rank}] Values: {gathered_result} dtype: {gathered_result.dtype}")
    return [gathered_result]
```
Remove debug prints in ref_allgather; fix linter E501 and reduce log noise.
The repeated prints here trigger Ruff E501 (line too long) and generate excessive output. They’re not needed for test assertions.
Apply this diff:
```diff
 def ref_allgather(x, res):
-    print(f"[Rank {tensor_parallel_rank}] Starting ref_allgather")
     # Reference implementation: manually create expected AllGather result
     all_rank_data = []
     for rank in range(tensor_parallel_size):
         # Create with default dtype first, then convert
         rank_data_orig = torch.arange(
             rank * 10,
             rank * 10 + 8,
             device="cuda"
         ).reshape(1, 8)
-        print(f"[Rank {tensor_parallel_rank}] Original ref data for rank {rank}: {rank_data_orig} (dtype: {rank_data_orig.dtype})")
-
         rank_data = rank_data_orig.to(dtype)
-        print(f"[Rank {tensor_parallel_rank}] Converted ref data for rank {rank}: {rank_data}")
-        print(f"[Rank {tensor_parallel_rank}] ref rank_data dtype: {rank_data.dtype}")
-
         # Pad or reshape to match expected input dimensions
         if rank_data.shape != x.shape:
             rank_data = rank_data.expand(x.shape[0], -1)
             if rank_data.shape[1] < x.shape[1]:
                 # Repeat pattern to fill the hidden dimension
                 repeat_factor = x.shape[1] // 8
                 remainder = x.shape[1] % 8
                 rank_data = rank_data.repeat(1, repeat_factor)
                 if remainder > 0:
                     rank_data = torch.cat([rank_data, rank_data[:, :remainder]], dim=1)
-
-        print(f"[Rank {tensor_parallel_rank}] Final ref data for rank {rank}: shape {rank_data.shape}, values {rank_data}")
         all_rank_data.append(rank_data)

     # Concatenate along the first dimension (batch dimension)
     gathered_result = torch.cat(all_rank_data, dim=0)
-    print(f"[Rank {tensor_parallel_rank}] Reference AllGather result:")
-    print(f"[Rank {tensor_parallel_rank}] Shape: {gathered_result.shape}")
-    print(f"[Rank {tensor_parallel_rank}] Values: {gathered_result} dtype: {gathered_result.dtype}")
     return [gathered_result]
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def ref_allgather(x, res):
    # Reference implementation: manually create expected AllGather result
    all_rank_data = []
    for rank in range(tensor_parallel_size):
        # Create with default dtype first, then convert
        rank_data_orig = torch.arange(
            rank * 10,
            rank * 10 + 8,
            device="cuda"
        ).reshape(1, 8)
        rank_data = rank_data_orig.to(dtype)
        # Pad or reshape to match expected input dimensions
        if rank_data.shape != x.shape:
            rank_data = rank_data.expand(x.shape[0], -1)
            if rank_data.shape[1] < x.shape[1]:
                # Repeat pattern to fill the hidden dimension
                repeat_factor = x.shape[1] // 8
                remainder = x.shape[1] % 8
                rank_data = rank_data.repeat(1, repeat_factor)
                if remainder > 0:
                    rank_data = torch.cat([rank_data, rank_data[:, :remainder]], dim=1)
        all_rank_data.append(rank_data)

    # Concatenate along the first dimension (batch dimension)
    gathered_result = torch.cat(all_rank_data, dim=0)
    return [gathered_result]
```
🧰 Tools
🪛 Ruff (0.12.2)
286-286: Line too long (135 > 120)
(E501)
303-303: Line too long (127 > 120)
(E501)
🤖 Prompt for AI Agents
In tests/unittest/_torch/multi_gpu/test_allreduce.py around lines 275 to 312,
remove all debug print statements inside ref_allgather to eliminate excessive
test output and Ruff E501 violations; instead keep the function silent and only
build and return the expected gathered_result. Ensure any long lines are wrapped
or simplified to respect line-length limits, and do not introduce new prints or
logging in this helper.
PR_Github #14861 [ run ] completed with state
```cpp
        allreduce_fusion_params.pattern = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kAllReduce;
    }
    // Handle AllGather operation
    else if (mOp == AllReduceFusionOp::ALLGATHER)
```
I don't think it is a correct way to implement ALLGATHER inside the allreduce op.
```cpp
        fused_op(sum_val, tidx);

        // Handle AllGather pattern - output gathered data before reduction
        if constexpr (HasAllGatherOut<Pattern>)
```
The correct way is to factor the common part into a common function, remove the "allreduce"-related naming from that common function, and then use it to implement both allreduce and allgather.
From the code, I see you are reusing a lot of the allreduce code. But adding allgather into allreduce breaks the semantics of the allreduce op. This is even worse than a copy-paste.
If you really want to share the code, you need to make a fundamental refactor so the common code is agnostic to allreduce and allgather.
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Chores
Utilize Lamport-based allgather. A new option has been introduced in the Lamport-based allreduce to provide the output after only the gather step. When employed in the MTP module after the lm_head and local_argmax ops, the performance of Deepseek R1 BS1 improves by approximately 2 TPS on B200x8. Because the message size is small, the previously used NCCL-based allgather is approximately 2-3x slower than the newly introduced Lamport-based allgather.
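For intuition, here is a rough sketch of the pattern this enables in MTP-style sampling: each rank takes the argmax over its vocab shard, packs (global index, value) pairs, and gathers the small packed message via the Lamport-based allgather instead of NCCL. The helper names and import locations are assumptions for illustration; they mirror the mtp.py helpers and the fused-op usage shown in the tests, not the exact production code.

```python
import torch

# Assumed imports (see tests in this PR): AllReduceFusionOp from tensorrt_llm.functional,
# AllReduceParams from tensorrt_llm._torch.distributed.
from tensorrt_llm.functional import AllReduceFusionOp
from tensorrt_llm._torch.distributed import AllReduceParams


def tp_argmax(logits: torch.Tensor, allreduce, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Global argmax over vocab-parallel logits of shape [num_tokens, vocab_per_rank]."""
    # 1. Local argmax on this rank's vocab shard, offset to global vocab ids.
    local_max, local_idx = torch.max(logits, dim=-1, keepdim=True)
    vocab_per_rank = logits.shape[-1]
    global_idx = (local_idx + tp_rank * vocab_per_rank).float()

    # 2. Pack (index, value) pairs so one small message carries both.
    combined = torch.stack([global_idx, local_max.float()], dim=-1).flatten(-2)

    # 3. Lamport-based allgather: per-rank messages concatenated along dim 0.
    gathered = allreduce(
        combined,
        all_reduce_params=AllReduceParams(fusion_op=AllReduceFusionOp.ALLGATHER),
    ).view(tp_size, -1, 2)

    # 4. The rank holding the largest local max owns the global argmax.
    best_rank = torch.argmax(gathered[..., 1], dim=0)
    token_ids = torch.arange(gathered.shape[1], device=gathered.device)
    return gathered[best_rank, token_ids, 0].long()
```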