feat: RayExecutor #7240
Conversation
Signed-off-by: Jonas yang <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Jonas yang <[email protected]>
Signed-off-by: Jonas yang <[email protected]>
Signed-off-by: Jonas yang <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
…mmit for cacheTrans
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Yuan Tong <[email protected]>
Signed-off-by: Jonas yang <[email protected]>
Signed-off-by: Jonas yang <[email protected]>
Signed-off-by: Erin Ho <[email protected]>
📝 Walkthrough
Introduces a ProcessGroup (PG) backend alongside MPI across C++ and Python: adds a pg_utils module, broker bindings, and a unified CacheTransceiverComm. Extends thop ops to PG variants. Adds a Ray-based executor/worker, Ray-backed result queues, and device-mesh enhancements. Broad CMake/packaging updates, new examples/tests, environment gating via TLLM_DISABLE_MPI, and a ModelConfig.clone().
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant Py as Python LLMAPI
participant Exec as GenerationExecutor (Ray)
participant Ray as Ray Cluster
participant Wrk0 as RayWorkerWrapper[rank 0]
participant WrkN as RayWorkerWrapper[rank N-1]
User->>Py: LLM(..., executor_type="ray").generate_async()
Py->>Exec: create(..., tp_size, executor_type="ray")
Exec->>Ray: ensure cluster, placement group
Exec->>Ray: spawn RayWorkerWrapper actors
Ray-->>Exec: actor handles ready
Py->>Exec: submit(request)
Exec->>Wrk0: enqueue_request(request, result_queue)
par Fan-out (PP/TP)
Wrk0->>WrkN: dist init/barrier (PG path when TLLM_DISABLE_MPI=1)
end
Wrk0-->>Exec: responses via RayAsyncQueue.put_response
Exec-->>Py: GenerationResult (queue-backed)
Py-->>User: streamed/final tokens
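From the user's side, the flow in the diagram above reduces to a few calls. A minimal sketch in Python, assuming the executor_type keyword shown in this PR's diagrams and comments (the exact argument names may differ in the final API):

# Minimal, hypothetical usage of the Ray-backed path described above.
# executor_type="ray" and TLLM_DISABLE_MPI=1 are the knobs referenced in this PR.
import asyncio
from tensorrt_llm import LLM, SamplingParams

async def main():
    llm = LLM(model="/path/to/model",
              tensor_parallel_size=2,
              executor_type="ray")          # spawns RayWorkerWrapper actors per the diagram
    params = SamplingParams(max_tokens=32)
    # Responses flow back through the Ray-backed result queue.
    async for out in llm.generate_async("Hello", params, streaming=True):
        print(out.outputs[0].text)

asyncio.run(main())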
sequenceDiagram
autonumber
participant Cpp as CacheTransceiver
participant Comm as CacheTransceiverComm
participant PG as ProcessGroup (Python)
participant MPI as MpiComm
Cpp->>Comm: split(color,key)
alt MPI enabled
Comm->>MPI: MPI_Comm_split
MPI-->>Comm: sub-comm
else PG path
Comm->>PG: py import pg_utils.split
PG-->>Comm: boxed ProcessGroup subgroup
end
Cpp->>Comm: allgatherv(input, sizes)
alt MPI enabled
Comm->>MPI: MPI_Allgatherv
else PG path
Comm->>PG: PgHelper.allgatherv
end
sequenceDiagram
autonumber
participant Py as Python
participant Torch as Torch ext (pgBroker)
participant Cpp as pg_utils (C++)
Py->>Torch: init_pg(world_pg, local_pg)
Torch->>Cpp: init_pg(...)
Cpp-->>Torch: set globals
Torch-->>Py: ok
Py->>Torch: init_store(default_store)
Torch->>Cpp: init_store(store)
Cpp-->>Torch: ok
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~150–240 minutes
/bot run --disable-fail-fast
Actionable comments posted: 52
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (16)
tests/integration/defs/trt_test_alternative.py (1)
1-1
: Add NVIDIA copyright header (2025) and SPDX
Repo guideline requires the header on all .py files. Please prepend:
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
tensorrt_llm/_utils.py (1)
701-712
: Make is_trace_enabled backend-agnostic and robust; avoid global_mpi_rank() dependency.
Using global_mpi_rank() forces rank 0 in PG mode (per current implementation), preventing tracing for nonzero PG ranks. Compare the env int against mpi_rank() instead. Also make "ALL" case-insensitive and broaden exception handling minimally.
Apply this diff:
@@ def is_trace_enabled(env_var: str):
     value = os.environ.get(env_var, "-1")
-    if value == "ALL":
+    if value.upper() == "ALL":
         return True
     if value == "-1":
         return False
     try:
-        # (TODO: joyang) Need to check if this is needed.
-        return int(value) == global_mpi_rank()
-    except ValueError:
+        # Compare against backend-agnostic rank.
+        return int(value) == mpi_rank()
+    except ValueError:
         return False
If helpful, I can add a small unit test matrix (a rough pytest sketch follows the list below) covering:
- MPI enabled vs TLLM_DISABLE_MPI=1
- PG initialized vs not initialized
- env values: "-1", "ALL", "0", "1".
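A rough shape for that matrix as a pytest sketch — it assumes is_trace_enabled and mpi_rank live in tensorrt_llm._utils as referenced above, and the monkeypatched rank stands in for both the MPI and the TLLM_DISABLE_MPI=1 / PG backends:

# Illustrative test matrix; env-var name and parametrization are examples only.
import pytest
import tensorrt_llm._utils as _utils

@pytest.mark.parametrize("env_value,rank,expected", [
    ("-1", 0, False),       # tracing disabled
    ("ALL", 3, True),       # every rank traces
    ("all", 3, True),       # case-insensitive after the suggested fix
    ("0", 0, True),         # matches this rank
    ("1", 0, False),        # different rank
    ("garbage", 0, False),  # non-integer value falls through to False
])
def test_is_trace_enabled(monkeypatch, env_value, rank, expected):
    monkeypatch.setenv("TLLM_TRACE", env_value)
    # Backend-agnostic rank stub: covers MPI and PG-initialized cases alike.
    monkeypatch.setattr(_utils, "mpi_rank", lambda: rank)
    assert _utils.is_trace_enabled("TLLM_TRACE") is expected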
tensorrt_llm/_torch/models/modeling_utils.py (2)
886-914
: remove_duplicate=False can double-load shared submodules; deduplicate by module identity to avoid races and redundant work
Iterating with remove_duplicate=False yields the same module object under multiple names when submodules are shared. In serial mode this re-loads the same object multiple times; in concurrent mode it can schedule the same module for simultaneous writes, which is a race. Deduplicate the iteration by id(module) while still honoring the chosen names.
Apply this diff:
-        for name, module in tqdm(list(
-                model.named_modules(remove_duplicate=False)),
-                                  desc="Loading weights"):
-            load_single_module(name, module)
+        all_named = list(model.named_modules(remove_duplicate=False))
+        seen_ids = set()
+        unique_named = []
+        for nm, mod in all_named:
+            mid = id(mod)
+            if mid in seen_ids:
+                continue
+            seen_ids.add(mid)
+            unique_named.append((nm, mod))
+        for name, module in tqdm(unique_named, desc="Loading weights"):
+            load_single_module(name, module)
@@
-        all_modules = dict(model.named_modules(remove_duplicate=False))
+        all_named = list(model.named_modules(remove_duplicate=False))
+        seen_ids = set()
+        unique_named = []
+        for nm, mod in all_named:
+            mid = id(mod)
+            if mid in seen_ids:
+                continue
+            seen_ids.add(mid)
+            unique_named.append((nm, mod))
+        all_modules = dict(unique_named)
@@
-        pbar = tqdm(list(model.named_modules(remove_duplicate=False)),
-                    desc="Loading weights concurrently")
-        args_list = [
-            (name, module)
-            for name, module in model.named_modules(remove_duplicate=False)
-            if name not in serial_load_modules
-        ]
+        pbar = tqdm([nm for nm, _ in unique_named],
+                    desc="Loading weights concurrently")
+        args_list = [
+            (name, module)
+            for name, module in unique_named
+            if name not in serial_load_modules
+        ]
         run_concurrently(load_single_module, args_list, pbar=pbar)
535-539
: Avoid mutable default arguments (skip_modules=[])
Using a list as a default argument can leak state across calls. Prefer None with an internal normalization.
Apply this diff:
-    def load_weights(self,
-                     weights: Dict,
-                     weight_mapper: Optional["BaseWeightMapper"] = None,
-                     skip_modules: List[str] = []):
+    def load_weights(self,
+                     weights: Dict,
+                     weight_mapper: Optional["BaseWeightMapper"] = None,
+                     skip_modules: Optional[List[str]] = None):
+        if skip_modules is None:
+            skip_modules = []
And similarly update helper signatures:
-def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
-                       weights: Dict,
-                       skip_modules: List[str] = [],
+def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
+                       weights: Dict,
+                       skip_modules: Optional[List[str]] = None,
                        params_map: Optional[Dict[str, str]] = None,
                        preload_weight_modules: Optional[List[str]] = None):
+    skip_modules = skip_modules or []
-def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
-                          weights: Dict,
-                          weight_mapper: "BaseWeightMapper",
-                          skip_modules: List[str] = [],
+def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
+                          weights: Dict,
+                          weight_mapper: "BaseWeightMapper",
+                          skip_modules: Optional[List[str]] = None,
                           params_map: Optional[Dict[str, str]] = None,
                           preload_weight_modules: Optional[List[str]] = None):
+    skip_modules = skip_modules or []
tensorrt_llm/_torch/models/modeling_hyperclovax.py (1)
1-1
: Add/ensure NVIDIA copyright header (2025)
This file lacks the required header.
Insert the standard 2025 NVIDIA Apache-2.0 header at the top, before imports.
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
189-223
: Critical: Modifying state_dict() tensors won’t update model weights.
for name, param in model.state_dict().items():
iterates over a copy of tensors; in-place ops on these do not mutate the module’s parameters/buffers. As a result, dummy weights are not actually initialized. Iterate over named parameters and (if needed) buffers instead.
-        for name, param in model.state_dict().items():
-            logger.info(f"Initializing {name} with shape {param.data.shape}")
-            import hashlib
-
-            hashobj = hashlib.sha256(name.encode('ascii'))
-            name_hash = int.from_bytes(hashobj.digest(), 'big') % 10**8
-
-            generator = torch.Generator(device=param.data.device)
-            generator.manual_seed(name_hash)
-            dtype = param.data.dtype
+        from itertools import chain
+        for name, param in chain(model.named_parameters(recurse=True),
+                                  model.named_buffers(recurse=True)):
+            logger.info(f"Initializing {name} with shape {param.data.shape}")
+            # Seed per tensor by name for determinism across ranks/devices
+            import hashlib
+            hashobj = hashlib.sha256(name.encode('utf-8'))
+            name_hash = int.from_bytes(hashobj.digest(), 'big') % 10**8
+
+            generator = torch.Generator(device=param.data.device)
+            generator.manual_seed(name_hash)
+            dtype = param.data.dtype
@@
-            elif torch.is_floating_point(param):
-                param.uniform_(low, high, generator=generator)
+            elif torch.is_floating_point(param.data):
+                param.data.uniform_(low, high, generator=generator)
Optional: change info-level logs to debug to avoid significant overhead on large models.
-            logger.info(f"Initializing {name} with shape {param.data.shape}")
+            logger.debug(f"Initializing {name} with shape {param.data.shape}")
@@
-                logger.info(
+                logger.debug(
                     f"After initialization {name}: {tmp_param.mean()}, {tmp_param.std()}"
                 )
@@
-                logger.info(
+                logger.debug(
                     f"After initialization {name}: {param.mean()}, {param.std()}")
Also prefer UTF-8 encoding over ASCII for tensor names.
tensorrt_llm/_torch/models/modeling_phi4mm.py (1)
1-3
: Missing NVIDIA copyright header (2025)
Per repository guidelines, prepend the NVIDIA header to Python sources.
Apply this diff at the top of the file:
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
tensorrt_llm/_torch/models/modeling_qwen2vl.py (1)
1-3
: Missing NVIDIA copyright header (2025)
Please prepend the standard header.
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
1-3
: Add NVIDIA header alongside third‑party notice
This file contains third‑party license text (DeepSeek MIT). Per guidelines, prepend the NVIDIA header while preserving the existing attribution and license block.
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
  # --------------------------------------------------
  # Portions of this code were derived from DeepSeek‑V3:
cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.h (1)
20-41
: Missing explicit includes for used STL types.
This header declares members of types that require explicit includes:
- std::unordered_map (mAddressToConnectionId)
- std::atomic (mConnectionIdCounter)
- std::thread (mZmqRepThread)
Relying on transitive includes is brittle for headers.
Apply this diff:
 #include "ucxx/utils/ucx.h"
 #include <cstdint>
 #include <future>
+#include <unordered_map>
+#include <atomic>
+#include <thread>
cpp/tensorrt_llm/batch_manager/CMakeLists.txt (1)
112-116
: Guard PyTorch/Python linkage under the existing BUILD_PYT option and scope it privately
The batch manager’s CMakeLists currently unconditionally pulls in
torch_python
,Python3::Python
, andpg_utils
as PUBLIC dependencies, even when PyTorch features aren’t enabled. This:
- Forces PyTorch and Python linkage on all consumers and breaks builds that don’t use Torch.
- Exposes unnecessary usage requirements downstream via PUBLIC on a static library.
- Relies on an undefined
TORCH_INSTALL_PREFIX
variable forfind_library
.We already have a global
BUILD_PYT
option (defined incpp/CMakeLists.txt
at line 30) that gates all TorchScript–related logic. We should:
- Wrap all Torch/Python linkage in
if(BUILD_PYT)…endif()
.- Prefer
find_package(Torch COMPONENTS Python)
(with fallback to rawfind_library
) rather than assumingTORCH_INSTALL_PREFIX
.- Use
PRIVATE
linkage for a static library to avoid propagating usage requirements.- Only link
pg_utils
if that target exists.Affected file:
cpp/tensorrt_llm/batch_manager/CMakeLists.txt
around lines 112–116.Example refactor:
if(ENABLE_UCX) … endif() -find_library(TORCH_PYTHON_LIB torch_python REQUIRED - HINTS ${TORCH_INSTALL_PREFIX}/lib) -target_link_libraries(${BATCH_MANAGER_STATIC_TARGET} - PUBLIC ${TORCH_PYTHON_LIB} Python3::Python pg_utils) +if(BUILD_PYT) + # Prefer modern CMake target for Torch/Python, fallback to raw library + find_package(Torch QUIET COMPONENTS Python) + if(TARGET Torch::Python) + target_link_libraries(${BATCH_MANAGER_STATIC_TARGET} + PRIVATE Torch::Python Python3::Python) + else() + find_library(TORCH_PYTHON_LIB torch_python HINTS ${TORCH_INSTALL_PREFIX}/lib) + if(TORCH_PYTHON_LIB) + target_link_libraries(${BATCH_MANAGER_STATIC_TARGET} + PRIVATE ${TORCH_PYTHON_LIB} Python3::Python) + else() + message(STATUS + "Torch::Python not found; skipping PyTorch linkage for ${BATCH_MANAGER_STATIC_TARGET}") + endif() + endif() + + # Only link pg_utils if that target is defined + if(TARGET pg_utils) + target_link_libraries(${BATCH_MANAGER_STATIC_TARGET} PRIVATE pg_utils) + endif() +endif()Use the existing
BUILD_PYT
option to gate all PyTorch–related logic.cpp/tensorrt_llm/runtime/utils/mpiUtils.cpp (2)
17-19
: Missing standard library include for std::unordered_map (and chrono used later).
This file uses std::unordered_map (in getMpiDtype/getMpiOp) and std::chrono::milliseconds (in recvPoll) but doesn't include the corresponding headers. Depending on transitive includes, this can fail to build on stricter compilers.
Apply:
#include <numeric> -#include <unordered_set> +#include <unordered_set> +#include <unordered_map> @@ #include <mutex> #include <thread> #include <type_traits> +#include <chrono>
307-315
: Use correct format specifier for size_t in logs.
size is size_t but is logged with %d. On LP64, this truncates and is UB. Use %zu (or cast to unsigned long long and use %llu).
- TLLM_LOG_DEBUG("start MPI_Isend with dest %d, tag %d, size %d", dest, static_cast<int>(tag), size); + TLLM_LOG_DEBUG("start MPI_Isend with dest %d, tag %d, size %zu", dest, static_cast<int>(tag), size); @@ - TLLM_LOG_DEBUG("end MPI_Isend with dest %d, tag %d, size %d", dest, static_cast<int>(tag), size); + TLLM_LOG_DEBUG("end MPI_Isend with dest %d, tag %d, size %zu", dest, static_cast<int>(tag), size); @@ - TLLM_LOG_DEBUG("start MPI_Send with dest %d, tag %d, size %d", dest, tag, size); + TLLM_LOG_DEBUG("start MPI_Send with dest %d, tag %d, size %zu", dest, tag, size); @@ - TLLM_LOG_DEBUG("end MPI_Send with dest %d, tag %d, size %d", dest, tag, size); + TLLM_LOG_DEBUG("end MPI_Send with dest %d, tag %d, size %zu", dest, tag, size); @@ - TLLM_LOG_DEBUG("start MPI_Recv with source %d, tag %d, size %d", source, tag, size); + TLLM_LOG_DEBUG("start MPI_Recv with source %d, tag %d, size %zu", source, tag, size); @@ - TLLM_LOG_DEBUG("end MPI_Recv with source %d, tag %d, size %d", source, tag, size); + TLLM_LOG_DEBUG("end MPI_Recv with source %d, tag %d, size %zu", source, tag, size);Also applies to: 327-334, 351-359
cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.cpp (1)
213-219
: Inconsistent rank usage in logging
While most logging now uses
mRank
, these lines still usempi::MpiComm::world().getRank()
, which may be uninitialized or inconsistent when using ProcessGroup backend.- TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s", + TLLM_LOG_INFO(mRank, "UcxConnectionManager::UcxConnectionManager mZmqRepEndpoint: %s", mZmqRepEndpoint.c_str()); auto parse_result = parse_zmq_endpoint(mZmqRepEndpoint); TLLM_CHECK_WITH_INFO(parse_result.has_value(), "Failed to parse ZMQ endpoint"); auto [ip, port] = parse_result.value(); - TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "UcxConnectionManager::UcxConnectionManager ip: %s, port: %d", + TLLM_LOG_INFO(mRank, "UcxConnectionManager::UcxConnectionManager ip: %s, port: %d", ip.c_str(), port);tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
510-519
: Fix logging: self.global_rank no longer exists (removed earlier).
This will raise AttributeError the first time the profiler logs. Use
self.dist.rank
(already present) or remove the “global_rank” field.Apply this diff:
- logger.info( - f"iter = {self.model_engine.iter_counter}, " - f"global_rank = {self.global_rank}, " - f"rank = {self.dist.rank}, " + logger.info( + f"iter = {self.model_engine.iter_counter}, " + f"rank = {self.dist.rank}, "tensorrt_llm/llmapi/llm.py (1)
919-934
: TRT create() call not propagating Ray/TP parameters; likely needed for RayExecutor.
PyTorch path passes executor_type and tp_size into GenerationExecutor.create(), but the TRT path doesn’t. If Ray/TRT is intended, this is a functional gap.
Apply this diff (keeping argument order consistent with the PyTorch path if supported by the TRT executor):
self._executor = self._executor_cls.create( self._engine_dir, executor_config=self._executor_config, batched_logits_processor=self.args.batched_logits_processor, model_world_size=self.args.parallel_config.world_size, mpi_session=self.mpi_session, reuse_mpi_comm=external_mpi_comm_available( self.args.parallel_config.world_size), return_logits=return_logits, postproc_worker_config=PostprocWorkerConfig( num_postprocess_workers=self.args.num_postprocess_workers, postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir, ), is_llm_executor=True, - lora_config=lora_config) + lora_config=lora_config, + executor_type=self._executor_type, + tp_size=self.args.tensor_parallel_size, + )If TRT executor doesn’t accept these kwargs, please confirm and gate them by backend.
bool couldUseMPI() const
{
    if (!mDisableMPI.has_value())
    {
        char* val = std::getenv("TLLM_DISABLE_MPI");
        ;
        bool disable_mpi = false;
        if (val != NULL && std::string(val) == "1")
        {
            throw std::runtime_error("MPI is disabled, DON\'T USE MPI");
        }
        mDisableMPI = disable_mpi;
    }

    return mDisableMPI.value();
}
couldUseMPI() returns the wrong value and throws where a pure query is expected
- The function name implies a pure check, but it throws when disabled; this makes if (couldUseMPI()) branches unusable.
- It returns mDisableMPI.value() (i.e., “disabled?”) instead of “could use?” causing it to return false when MPI is enabled.
It also relies on std::string without including <string> and uses std::runtime_error without including <stdexcept>.
This will break the intended backend gating and can cause spurious exceptions.
Apply this fix to make the API boolean and non-throwing; let callers decide behavior:
- bool couldUseMPI() const
+ [[nodiscard]] bool couldUseMPI() const
{
if (!mDisableMPI.has_value())
{
- char* val = std::getenv("TLLM_DISABLE_MPI");
- ;
- bool disable_mpi = false;
- if (val != NULL && std::string(val) == "1")
- {
- throw std::runtime_error("MPI is disabled, DON\'T USE MPI");
- }
- mDisableMPI = disable_mpi;
+ char const* val = std::getenv("TLLM_DISABLE_MPI");
+ bool const disableMpi = (val != nullptr) && (std::strcmp(val, "1") == 0);
+ mDisableMPI = disableMpi;
}
-
- return mDisableMPI.value();
+ // couldUse == !disabled
+ return !mDisableMPI.value();
}
And add missing headers:
#include <memory>
#include <mutex>
+#include <optional>
+#include <stdexcept>
+#include <string>
#include <thread>
If you prefer to raise on misuse, add a separate ensureMPIEnabled() that throws, while couldUseMPI() remains a pure query.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
++ b/cpp/include/tensorrt_llm/runtime/utils/mpiUtils.h
@@
#include <memory>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <thread>
++ b/cpp/include/tensorrt_llm/runtime/utils/mpiUtils.h
@@ lines 429-444
[[nodiscard]] bool couldUseMPI() const
{
    if (!mDisableMPI.has_value())
    {
-       char* val = std::getenv("TLLM_DISABLE_MPI");
-       ;
-       bool disable_mpi = false;
-       if (val != NULL && std::string(val) == "1")
-       {
-           throw std::runtime_error("MPI is disabled, DON\'T USE MPI");
-       }
        char const* val = std::getenv("TLLM_DISABLE_MPI");
        bool const disableMpi = (val != nullptr) && (std::strcmp(val, "1") == 0);
        mDisableMPI = disableMpi;
    }
-
    // couldUse == not disabled
    return !mDisableMPI.value();
}
🤖 Prompt for AI Agents
In cpp/include/tensorrt_llm/runtime/utils/mpiUtils.h around lines 429-444,
couldUseMPI() currently throws, always sets mDisableMPI=false, and returns the
disabled flag instead of whether MPI can be used; also missing headers. Change
the function to be a pure, non-throwing query: read getenv("TLLM_DISABLE_MPI"),
set mDisableMPI = (val != nullptr && std::string(val) == "1"); do NOT throw; and
return !mDisableMPI.value() (i.e. true when MPI is usable). Also add the missing
#include <string> and #include <stdexcept> at the top of the header (or remove
std::runtime_error usage if you keep no throwing). If you still want a throwing
variant, add a separate ensureMPIEnabled() that checks couldUseMPI() and throws.
int mRank;
int mWorldSize;
std::atomic<UcxConnection::ConnectionIdType> mConnectionIdCounter{1};
Initialize new rank/world-size members to safe defaults to avoid UB before ctor assignment.
Uninitialized mRank/mWorldSize risk undefined behavior if accessed (e.g., via getRank()) before constructor sets them. Initialize in-class.
Apply this diff:
- int mRank;
- int mWorldSize;
+ int mRank{-1};
+ int mWorldSize{1};
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
int mRank{-1};
int mWorldSize{1};
std::atomic<UcxConnection::ConnectionIdType> mConnectionIdCounter{1};
🤖 Prompt for AI Agents
cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.h
around lines 58-60: mRank and mWorldSize are declared without in-class
initializers which can lead to undefined behavior if accessed before the
constructor assigns them; add safe default initializers (e.g., int mRank{0}; int
mWorldSize{0};) beside the existing mConnectionIdCounter initializer so they
have defined values prior to ctor assignment.
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents)
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState);
💡 Verification agent
🧩 Analysis chain
New binding looks correct; double-check trampoline coverage for resetReuseState.
Binding reset_reuse_state is consistent with the new C++ API.
- If BaseKVCacheManager::resetReuseState is pure virtual, the nanobind trampoline (class PyKvCacheManager) should also expose an override using NB_OVERRIDE_PURE to keep Python subclassing viable. If it’s not pure, no change needed.
- If needed, add:
class PyKvCacheManager : public tbk::BaseKVCacheManager
{
public:
- NB_TRAMPOLINE(tbk::BaseKVCacheManager, 28);
+ NB_TRAMPOLINE(tbk::BaseKVCacheManager, 29); // increment if you add a new override below
@@
void flushIterationEvents() override
{
NB_OVERRIDE_PURE(flushIterationEvents);
}
+
+ void resetReuseState() override
+ {
+ NB_OVERRIDE_PURE(resetReuseState);
+ }
};
You can verify purity with:
🏁 Script executed:
#!/bin/bash
set -euo pipefail
rg -n "class BaseKVCacheManager"
rg -n "virtual\s+void\s+resetReuseState\s*\(" cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Length of output: 864
Add pure-virtual override for resetReuseState in the Python trampoline
We’ve confirmed that BaseKVCacheManager::resetReuseState() is declared as a pure virtual method (= 0) in cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h:1307, so the Nanobind trampoline must provide an override to keep Python subclassing viable.
Please update the PyKvCacheManager class in cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp as follows:
- Increment the trampoline counter to account for the new override.
- Add a resetReuseState() override that calls NB_OVERRIDE_PURE.
class PyKvCacheManager : public tbk::BaseKVCacheManager
{
public:
- NB_TRAMPOLINE(tbk::BaseKVCacheManager, 28);
+ NB_TRAMPOLINE(tbk::BaseKVCacheManager, 29); // incremented for new override
void flushIterationEvents() override
{
NB_OVERRIDE_PURE(flushIterationEvents);
}
+ void resetReuseState() override
+ {
+ NB_OVERRIDE_PURE(resetReuseState);
+ }
};
With this change, Python subclasses can correctly override reset_reuse_state.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents)
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState);
// File: cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
class PyKvCacheManager : public tbk::BaseKVCacheManager
{
public:
    // Increment trampoline count to account for new pure-virtual override
    NB_TRAMPOLINE(tbk::BaseKVCacheManager, 29);
    void flushIterationEvents() override
    {
        NB_OVERRIDE_PURE(flushIterationEvents);
    }
    void resetReuseState() override
    {
        NB_OVERRIDE_PURE(resetReuseState);
    }
};
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp around lines 430
to 432, the Python trampoline PyKvCacheManager is missing a pure-virtual
override for BaseKVCacheManager::resetReuseState; update the trampoline by
incrementing its counter and adding a resetReuseState() override that calls
NB_OVERRIDE_PURE so Python subclasses can override reset_reuse_state; ensure the
new method signature matches the C++ pure-virtual (no args, void) and place it
with the other NB_OVERRIDE/NB_OVERRIDE_PURE methods in the PyKvCacheManager
class.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
Fix SPDX header formatting
The SPDX header is split across two lines incorrectly. It should be on a single line.
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
-# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/runtime/utils/CMakeLists.txt around lines 1 to 2, the SPDX
header is incorrectly split across two lines; replace the two-line header with a
single-line SPDX header containing the full copyright and license tag (e.g.,
combine the existing two lines into one properly formatted SPDX header) so the
file has one continuous SPDX header line at the top.
m.def("init_store",
    [](const py::object store_obj, std::string pybind11_abi)
    {
        auto* pStore = pybind11_conduit_v1::get_type_pointer_ephemeral<c10d::Store>(store_obj.ptr(), pybind11_abi);
        if (pStore == nullptr)
        {
            throw py::error_already_set();
        }
        init_store(c10::intrusive_ptr<c10d::Store>::reclaim_copy(pStore));
    });
🛠️ Refactor suggestion
Avoid throwing py::error_already_set when no Python error is set.
If pointer extraction returns nullptr without a pending Python error, throwing py::error_already_set() raises an internal exception. Prefer surfacing the existing error if present, else raise a typed error.
- if (pStore == nullptr)
- {
- throw py::error_already_set();
- }
+ if (pStore == nullptr)
+ {
+ if (PyErr_Occurred())
+ {
+ throw py::error_already_set();
+ }
+ throw py::type_error("Failed to resolve c10d::Store pointer from the provided object.");
+ }
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
m.def("init_store",
    [](const py::object store_obj, std::string pybind11_abi)
    {
        auto* pStore = pybind11_conduit_v1::get_type_pointer_ephemeral<c10d::Store>(
            store_obj.ptr(), pybind11_abi);
        if (pStore == nullptr)
        {
            if (PyErr_Occurred())
            {
                throw py::error_already_set();
            }
            throw py::type_error(
                "Failed to resolve c10d::Store pointer from the provided object.");
        }
        init_store(
            c10::intrusive_ptr<c10d::Store>::reclaim_copy(pStore));
    });
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/runtime/utils/pgUtilsBindings.cpp around lines 79-88, change
the nullptr handling so you check whether a Python error is pending before
throwing py::error_already_set; if PyErr_Occurred() (or
py::detail::get_internals()->python_error) is true then throw
py::error_already_set(), otherwise throw a typed C++/pybind11 exception (for
example py::type_error or std::runtime_error) with a clear message like "failed
to extract c10d::Store pointer from object" so you never raise
py::error_already_set when no Python error is set.
if executor_type == "ray":
    return GenerationExecutor._create_ray_executor(
        worker_kwargs,
        model_world_size,
        postproc_worker_config,
        is_llm_executor=is_llm_executor,
        tp_size=args.get("tp_size", 1))
elif executor_type is not None:
    assert False, "Invalid executor type"
🛠️ Refactor suggestion
Replace assert with a proper exception for invalid executor_type.
Asserts can be stripped with -O and raise AssertionError (not ideal for API). Use ValueError with the invalid value.
- elif executor_type is not None:
- assert False, "Invalid executor type"
+ elif executor_type is not None:
+ raise ValueError(f"Invalid executor type: {executor_type!r}. Only 'ray' is supported at the moment.")
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if executor_type == "ray":
    return GenerationExecutor._create_ray_executor(
        worker_kwargs,
        model_world_size,
        postproc_worker_config,
        is_llm_executor=is_llm_executor,
        tp_size=args.get("tp_size", 1))
elif executor_type is not None:
    raise ValueError(f"Invalid executor type: {executor_type!r}. Only 'ray' is supported at the moment.")
🤖 Prompt for AI Agents
In tensorrt_llm/executor/executor.py around lines 406 to 415, the code uses an
assert to signal an invalid executor_type which can be stripped in optimized
runs and raises AssertionError; replace the assert with a proper exception:
raise a ValueError that includes the invalid executor_type value in the message
(e.g., f"Invalid executor_type: {executor_type}") so callers receive a clear,
non-stripable error.
if self._done and isinstance(self.queue, ray.actor.ActorHandle):
    self.queue.unregister.remote(self.id)
💡 Verification agent
🧩 Analysis chain
Potential race condition in cleanup logic.
The cleanup check and unregister operation at lines 485-486 could race with other operations. The _done
flag is set before this check, but if another thread is still processing, it could lead to premature unregistration.
🏁 Script executed:
#!/bin/bash
# Check for other places where _done is accessed or modified
rg -n "_done.*=" --type py | grep -E "(result\.py|executor\.py)" | head -20
# Check for synchronization mechanisms around _done
rg -n -B3 -A3 "self\._done" --type py | grep -E "(Lock|Event|Condition|synchronized)" | head -20
Length of output: 1341
Ensure thread-safe cleanup around self._done before calling queue.unregister
The current pattern at tensorrt_llm/executor/result.py (around lines 485–486) is vulnerable to races because _done is toggled by one context while another may simultaneously evaluate it and invoke unregister. To prevent premature or duplicated unregister calls, serialize access with a lock or use an atomic guard:
• In the class initializer, add a lock and a flag to track cleanup:
import threading
class Result:
def __init__(…):
…
self._done = False
self._cleanup_lock = threading.Lock()
self._unregistered = False
• Around the unregister call, acquire the lock and check both guards:
- if self._done and isinstance(self.queue, ray.actor.ActorHandle):
- self.queue.unregister.remote(self.id)
+ if isinstance(self.queue, ray.actor.ActorHandle):
+ with self._cleanup_lock:
+ if self._done and not self._unregistered:
+ self.queue.unregister.remote(self.id)
+ self._unregistered = True
This ensures that no two threads can enter the unregister block simultaneously and that the remote call is only issued once.
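Put together as a standalone sketch — the queue below is a stub standing in for the ray.actor.ActorHandle, and the names simply mirror the snippet above for illustration:

# Self-contained illustration of the lock-plus-flag cleanup guard.
import threading

class _StubQueue:
    """Stands in for a Ray actor handle; counts unregister calls."""
    def __init__(self):
        self.unregister_calls = 0
    def unregister(self, request_id):
        self.unregister_calls += 1

class Result:
    def __init__(self, queue, request_id):
        self.queue = queue
        self.id = request_id
        self._done = False
        self._cleanup_lock = threading.Lock()
        self._unregistered = False

    def _maybe_unregister(self):
        # Serialize the check-and-unregister so it runs at most once.
        with self._cleanup_lock:
            if self._done and not self._unregistered:
                self._unregistered = True
                self.queue.unregister(self.id)

q = _StubQueue()
r = Result(q, request_id=7)
r._done = True
threads = [threading.Thread(target=r._maybe_unregister) for _ in range(8)]
for t in threads: t.start()
for t in threads: t.join()
assert q.unregister_calls == 1  # exactly one unregister despite 8 racing threads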
🤖 Prompt for AI Agents
In tensorrt_llm/executor/result.py around lines 485–486, the unregister call
guarded only by self._done can race; add thread-safe cleanup by importing
threading, creating self._cleanup_lock = threading.Lock() and self._unregistered
= False in the class __init__, then replace the direct if self._done ...
unregister check with code that acquires self._cleanup_lock, verifies self._done
and not self._unregistered, sets self._unregistered = True, releases the lock,
and only then (after confirming self.queue is a ray.actor.ActorHandle) calls
self.queue.unregister.remote(self.id) so unregister is issued exactly once and
is serialized across threads.
            response = ray.get(self.queue.get.remote(self.request_id))
            response = self._handle_ray_response(response)
        else:
            response = self.queue.get()
🛠️ Refactor suggestion
Add error handling for Ray remote operations.
The Ray remote operations could fail with network or serialization errors, but there's no error handling around them.
def _result_step(self, timeout: Optional[float] = None):
if isinstance(self.queue, ray.actor.ActorHandle):
- response = ray.get(self.queue.get.remote(self.request_id))
- response = self._handle_ray_response(response)
+ try:
+ response = ray.get(self.queue.get.remote(self.request_id))
+ response = self._handle_ray_response(response)
+ except ray.exceptions.RayError as e:
+ raise RuntimeError(f"Failed to get response from Ray queue: {e}")
else:
Apply similar error handling to the async version on lines 681-682.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def _result_step(self, timeout: Optional[float] = None):
    if isinstance(self.queue, ray.actor.ActorHandle):
        try:
            response = ray.get(self.queue.get.remote(self.request_id))
            response = self._handle_ray_response(response)
        except ray.exceptions.RayError as e:
            raise RuntimeError(f"Failed to get response from Ray queue: {e}")
    else:
        response = self.queue.get()
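For the async path called out above (lines 681-682), a rough sketch of the same guard, assuming Ray's awaitable-ObjectRef pattern; the method and attribute names simply mirror the sync snippet and are illustrative only:

# Hypothetical async counterpart with the same RayError handling.
async def _result_step_async(self):
    if isinstance(self.queue, ray.actor.ActorHandle):
        try:
            # Ray ObjectRefs are awaitable, so no blocking ray.get() is needed.
            response = await self.queue.get.remote(self.request_id)
        except ray.exceptions.RayError as e:
            raise RuntimeError(f"Failed to get response from Ray queue: {e}") from e
        return self._handle_ray_response(response)
    return self.queue.get()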
# 8 GPUs, no cp
(4, 2, 1, -1, -1),
(2, 4, 1, -1, -1),
# 8 GPPUs with cp
Fix typo in comment
There's a typo in the comment: "GPPUs" should be "GPUs".
- # 8 GPPUs with cp
+ # 8 GPUs with cp
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# 8 GPUs with cp
🤖 Prompt for AI Agents
In tests/unittest/_torch/ray/test_mapping.py around line 39, the inline comment
contains a typo "GPPUs" which should be corrected to "GPUs"; update the comment
text accordingly to read "8 GPUs with cp" (or similar intended wording) to fix
the spelling.
@@ -0,0 +1,63 @@
import os |
🛠️ Refactor suggestion
Add NVIDIA copyright header.
Per repo guidelines, prepend the NVIDIA copyright header.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
import os
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import os
🤖 Prompt for AI Agents
In tests/unittest/_torch/ray/test_placement.py around lines 1 to 1, the file is
missing the required NVIDIA copyright header; prepend the repository's standard
NVIDIA copyright header at the very top of the file (above any imports) to
comply with repo guidelines.
Review continued from previous batch...
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
updateKVCacheTransferBW(syncComm, it->first);
}
Verify syncComm is not null before use
The code dereferences syncComm without checking if it's null, though there's a conditional check on syncComm elsewhere. This could lead to a null pointer dereference.
if (!common::getEnvKVCacheTransferOutputPath().empty())
{
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
+ if (syncComm)
+ {
updateKVCacheTransferBW(syncComm, it->first);
+ }
}
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if (!common::getEnvKVCacheTransferOutputPath().empty())
{
    auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
    if (syncComm)
    {
        updateKVCacheTransferBW(syncComm, it->first);
    }
}
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp around lines 593 to 595,
the pointer syncComm is used without a null-check; ensure you validate syncComm
before passing it to updateKVCacheTransferBW. Modify the code to check if
syncComm is non-null and only call updateKVCacheTransferBW(syncComm, it->first)
when it is valid; if null, either skip the call or use a safe fallback
communicator (e.g., mGroupComm or mGroupDataComm as appropriate) and document
which branch you chose.
std::cout << "mRank: " << mRank << " recv socketStates[" << r << "]: " << socketStates[r].toString()
          << std::endl;
Remove debug std::cout statement before production
Debug output should not be committed to production code.
- // TODO: remove this
- std::cout << "mRank: " << mRank << " recv socketStates[" << r << "]: " << socketStates[r].toString()
- << std::endl;
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
std::cout << "mRank: " << mRank << " recv socketStates[" << r << "]: " << socketStates[r].toString()
          << std::endl;
🤖 Prompt for AI Agents
In
cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.cpp
around lines 287-288, remove the debug std::cout statement that prints mRank and
socketStates[r] so no debug output is left in production; replace it with either
nothing or, if logging is required, use the project's logging framework at an
appropriate log level (e.g., debug) and ensure it is guarded or compiled out in
production builds.
void init_store(c10::intrusive_ptr<c10d::Store> const& default_store)
{
    (void) default_store; // TODO
    TLLM_LOG_DEBUG(pg_world->getRank(), "Init store on rank %d", pg_world->getRank());
Potential null pointer dereference in init_store
The function uses pg_world->getRank() without checking if pg_world is null. If init_store is called before init_pg, this will cause a segmentation fault.
void init_store(c10::intrusive_ptr<c10d::Store> const& default_store)
{
(void) default_store; // TODO
+ if (!pg_world)
+ {
+ throw std::runtime_error("init_store called before init_pg. Process groups must be initialized first.");
+ }
TLLM_LOG_DEBUG(pg_world->getRank(), "Init store on rank %d", pg_world->getRank());
}
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
void init_store(c10::intrusive_ptr<c10d::Store> const& default_store)
{
    (void) default_store; // TODO
    if (!pg_world)
    {
        throw std::runtime_error("init_store called before init_pg. Process groups must be initialized first.");
    }
    TLLM_LOG_DEBUG(pg_world->getRank(), "Init store on rank %d", pg_world->getRank());
}
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/runtime/utils/pgUtils.cpp around lines 51 to 54, the code
calls pg_world->getRank() without verifying pg_world is non-null which can cause
a crash if init_store is invoked before init_pg; add a null check at the start
of init_store (if pg_world == nullptr) and handle it by logging a warning/error
and returning early (or use a safe rank value) so no dereference happens,
ensuring subsequent code only runs when pg_world is valid.
std::set<int> getLocalGroupTorch(std::set<int> const& group)
{
    auto const worldPg = get_world_pg();
    auto const myRank = worldPg->getRank();
    auto const localPg = get_local_pg();
    auto const myLocalRank = localPg->getRank();
    auto const localSize = static_cast<uint32_t>(localPg->getSize());

    PgHelper pgh_local{localPg};
    PgHelper pgh_world{worldPg}; // for p2p

    std::vector<int32_t> ranks(localSize, -1);
    std::vector<int32_t> localRanks(localSize, -1);

    if (group.size() >= localSize)
    {
        PGCHECK_THROW(pgh_local.allgather(&myRank, ref(ranks), {}));
        PGCHECK_THROW(pgh_local.allgather(&myLocalRank, ref(localRanks), {}));
    }
    else
    {
        int tag = static_cast<int>(MpiTag::kDefault);

        if (myRank == *group.begin())
        {
            // Leader: gather from peers (world ranks), then broadcast full localSize arrays.
            size_t cnt = 0;
            ranks[cnt++] = myRank;
            int tmp;
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
                ranks[cnt++] = tmp;
            }
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                PGCHECK_THROW(pgh_world.send(ref(ranks), *it, tag));
            }

            cnt = 0;
            localRanks[cnt++] = myLocalRank;
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
                localRanks[cnt++] = tmp;
            }
            for (auto it = std::next(group.begin()); it != group.end(); ++it)
            {
                PGCHECK_THROW(pgh_world.send(ref(localRanks), *it, tag));
            }
        }
        else
        {
            int leader = *group.begin();

            PGCHECK_THROW(pgh_world.send(&myRank, leader, tag));
            PGCHECK_THROW(pgh_world.recv(ref(ranks), leader, tag));

            PGCHECK_THROW(pgh_world.send(&myLocalRank, leader, tag));
            PGCHECK_THROW(pgh_world.recv(ref(localRanks), leader, tag));
        }
    }

    std::set<int> localGroup;
    for (size_t i = 0; i < ranks.size(); ++i)
    {
        int world_r = ranks[i];
        if (group.find(world_r) != group.end())
            localGroup.insert(localRanks[i]);
    }
    return localGroup;
}
Ref symbol is unqualified; missing include and std::ref usage cause compile errors.
getLocalGroupTorch() uses ref(...) without namespace qualification. You must include <functional> and use std::ref.
Apply this diff:
@@
-#include <unordered_set>
+#include <unordered_set>
+#include <functional>
@@
- PGCHECK_THROW(pgh_local.allgather(&myRank, ref(ranks), {}));
- PGCHECK_THROW(pgh_local.allgather(&myLocalRank, ref(localRanks), {}));
+ PGCHECK_THROW(pgh_local.allgather(&myRank, std::ref(ranks), {}));
+ PGCHECK_THROW(pgh_local.allgather(&myLocalRank, std::ref(localRanks), {}));
@@
- PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
+ PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
ranks[cnt++] = tmp;
@@
- PGCHECK_THROW(pgh_world.send(ref(ranks), *it, tag));
+ PGCHECK_THROW(pgh_world.send(std::ref(ranks), *it, tag));
@@
- PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
+ PGCHECK_THROW(pgh_world.recv(&tmp, *it, tag));
localRanks[cnt++] = tmp;
@@
- PGCHECK_THROW(pgh_world.send(ref(localRanks), *it, tag));
+ PGCHECK_THROW(pgh_world.send(std::ref(localRanks), *it, tag));
@@
- PGCHECK_THROW(pgh_world.send(&myRank, leader, tag));
- PGCHECK_THROW(pgh_world.recv(ref(ranks), leader, tag));
+ PGCHECK_THROW(pgh_world.send(&myRank, leader, tag));
+ PGCHECK_THROW(pgh_world.recv(std::ref(ranks), leader, tag));
@@
- PGCHECK_THROW(pgh_world.send(&myLocalRank, leader, tag));
- PGCHECK_THROW(pgh_world.recv(ref(localRanks), leader, tag));
+ PGCHECK_THROW(pgh_world.send(&myLocalRank, leader, tag));
+ PGCHECK_THROW(pgh_world.recv(std::ref(localRanks), leader, tag));
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -1,5 +1,6 @@
 #include <unordered_set>
+#include <functional>
 // ... other includes ...
@@ -168,9 +169,9 @@ std::set<int> getLocalGroupTorch(std::set<int> const& group)
     if (group.size() >= localSize)
     {
-        PGCHECK_THROW(pgh_local.allgather(&myRank, ref(ranks), {}));
+        PGCHECK_THROW(pgh_local.allgather(&myRank, std::ref(ranks), {}));
+        PGCHECK_THROW(pgh_local.allgather(&myLocalRank, std::ref(localRanks), {}));
     }
     else
     {
@@ -189,7 +190,7 @@ std::set<int> getLocalGroupTorch(std::set<int> const& group)
         for (auto it = std::next(group.begin()); it != group.end(); ++it)
         {
             PGCHECK_THROW(pgh_world.send(std::ref(ranks), *it, tag));
         }
         cnt = 0;
@@ -200,7 +201,7 @@ std::set<int> getLocalGroupTorch(std::set<int> const& group)
         for (auto it = std::next(group.begin()); it != group.end(); ++it)
         {
             PGCHECK_THROW(pgh_world.send(std::ref(localRanks), *it, tag));
         }
     }
     else
@@ -217,8 +218,8 @@ std::set<int> getLocalGroupTorch(std::set<int> const& group)
         PGCHECK_THROW(pgh_world.send(&myRank, leader, tag));
         PGCHECK_THROW(pgh_world.recv(std::ref(ranks), leader, tag));
         PGCHECK_THROW(pgh_world.send(&myLocalRank, leader, tag));
         PGCHECK_THROW(pgh_world.recv(std::ref(localRanks), leader, tag));
     }
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/thop/allreduceOp.cpp around lines 162 to 233, the code calls
ref(...) unqualified which causes compile errors because <functional> is not
included and ref is in std::; fix by adding #include <functional> to the file
and replace unqualified ref(...) calls with std::ref(...) (or qualify them)
wherever used so the compiler finds std::ref.
run_python_file() {
    local file_path="$1"
    shift  # Remove first argument (file_path)
    local args="$@"  # Get remaining arguments
    local file_name=$(basename "$file_path" .py)

    local timestamp=$(date +"%Y%m%d_%H%M%S")
    local log_file="$LOG_DIR/${file_name}_${timestamp}.log"

    # Create a unique test name with arguments
    local test_name="${file_name}"
    if [ $# -gt 0 ]; then
        test_name="${file_name}(${args// /_})"
    fi

    echo "Running: $file_path $args"
    echo "Log file: $log_file"
    echo "------------------------------------------"

    # Run Python file with arguments and record output
    python3 "$file_path" $args 2>&1 | tee "$log_file"

    # Check run result and store it
    if [ ${PIPESTATUS[0]} -eq 0 ]; then
        echo "✅ $file_name ran successfully"
        TEST_RESULTS["$test_name"]="PASS"
    else
        echo "❌ $file_name failed to run"
        TEST_RESULTS["$test_name"]="FAIL"
    fi
    TEST_LOG_FILES["$test_name"]="$log_file"
    echo ""
}
Fix shellcheck warnings for more robust script execution
Several shellcheck warnings should be addressed to make the script more robust:
- Line 39: "$@" is assigned to a plain string variable (SC2124)
- Lines 40, 42: declare and assign on one line can mask command-substitution return values (SC2155)
- Line 76: declare and assign on one line can mask the command-substitution return value (SC2155)
run_python_file() {
local file_path="$1"
shift # Remove first argument (file_path)
- local args="$@" # Get remaining arguments
- local file_name=$(basename "$file_path" .py)
+ local args=("$@") # Store as array for proper handling
+ local file_name
+ file_name=$(basename "$file_path" .py)
- local timestamp=$(date +"%Y%m%d_%H%M%S")
+ local timestamp
+ timestamp=$(date +"%Y%m%d_%H%M%S")
local log_file="$LOG_DIR/${file_name}_${timestamp}.log"
# Create a unique test name with arguments
local test_name="${file_name}"
if [ $# -gt 0 ]; then
- test_name="${file_name}(${args// /_})"
+ test_name="${file_name}(${args[*]// /_})"
fi
- echo "Running: $file_path $args"
+ echo "Running: $file_path ${args[*]}"
echo "Log file: $log_file"
echo "------------------------------------------"
# Run Python file with arguments and record output
- python3 "$file_path" $args 2>&1 | tee "$log_file"
+ python3 "$file_path" "${args[@]}" 2>&1 | tee "$log_file"
Also fix line 76:
run_disagg_serving_test() {
local disagg_dir="disagg_serving_local"
local test_name="disagg_serving_test"
# Store the original directory path for log files
- local original_dir=$(pwd)
+ local original_dir
+ original_dir=$(pwd)
And line 176:
- local python_output=$(python3 current_node_gpu_type.py 2>/dev/null)
+ local python_output
+ python_output=$(python3 current_node_gpu_type.py 2>/dev/null)
And line 192:
-read sm_major sm_minor gpu_count <<< $(get_sm_version)
+read sm_major sm_minor gpu_count <<< "$(get_sm_version)"
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Updated run_python_file (lines 36–68)
run_python_file() {
    local file_path="$1"
    shift # Remove first argument (file_path)
    # Store remaining args as an array
    local args=("$@")
    # Separate declaration and assignment to avoid masking return values
    local file_name
    file_name=$(basename "$file_path" .py)
    local timestamp
    timestamp=$(date +"%Y%m%d_%H%M%S")
    local log_file="$LOG_DIR/${file_name}_${timestamp}.log"
    # Create a unique test name with arguments
    local test_name="${file_name}"
    if [ $# -gt 0 ]; then
        test_name="${file_name}(${args[*]// /_})"
    fi
    echo "Running: $file_path ${args[*]}"
    echo "Log file: $log_file"
    echo "------------------------------------------"
    # Run Python file with arguments and record output
    python3 "$file_path" "${args[@]}" 2>&1 | tee "$log_file"
    # Check run result and store it
    if [ "${PIPESTATUS[0]}" -eq 0 ]; then
        echo "✅ $file_name ran successfully"
        TEST_RESULTS["$test_name"]="PASS"
    else
        echo "❌ $file_name failed to run"
        TEST_RESULTS["$test_name"]="FAIL"
    fi
    TEST_LOG_FILES["$test_name"]="$log_file"
    echo ""
}
# Fix in run_disagg_serving_test (around line 76)
run_disagg_serving_test() {
    local disagg_dir="disagg_serving_local"
    local test_name="disagg_serving_test"
    # Store the original directory path for log files
    local original_dir
    original_dir=$(pwd)
    # …rest of function…
}
# Fix masking return value in python_output assignment (around line 176)
local python_output
python_output=$(python3 current_node_gpu_type.py 2>/dev/null)
# Quote command substitution to avoid unwanted word-splitting (around line 192)
read sm_major sm_minor gpu_count <<< "$(get_sm_version)"
🧰 Tools
🪛 Shellcheck (0.10.0)
[warning] 39-39: Assigning an array to a string! Assign as array, or use * instead of @ to concatenate.
(SC2124)
[warning] 40-40: Declare and assign separately to avoid masking return values.
(SC2155)
[warning] 42-42: Declare and assign separately to avoid masking return values.
(SC2155)
        super().__init__(model_world_size, postproc_worker_config,
                         is_llm_executor)
Fix incorrect superclass initialization; breaks postprocessing config.
GenerationExecutor.__init__ expects (num_postprocess_workers, postprocess_tokenizer_dir, is_llm_executor). Passing (model_world_size, PostprocWorkerConfig, is_llm_executor) is wrong and will corrupt the configuration.
Apply this diff:
- super().__init__(model_world_size, postproc_worker_config,
- is_llm_executor)
+ super().__init__(
+ num_postprocess_workers=postproc_worker_config.num_postprocess_workers,
+ postprocess_tokenizer_dir=postproc_worker_config.postprocess_tokenizer_dir,
+ is_llm_executor=is_llm_executor,
+ )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
        super().__init__(model_world_size, postproc_worker_config,
                         is_llm_executor)
        super().__init__(
            num_postprocess_workers=postproc_worker_config.num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_worker_config.postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )
🤖 Prompt for AI Agents
In tensorrt_llm/executor/ray_executor.py around lines 45 to 47, the superclass
is being initialized with (model_world_size, postproc_worker_config,
is_llm_executor) which is incorrect; GenerationExecutor.__init__ expects
(num_postprocess_workers, postprocess_tokenizer_dir, is_llm_executor). Replace
the current super().__init__ call so it passes the postprocessing worker count
and tokenizer dir extracted from postproc_worker_config (e.g.,
postproc_worker_config.num_postprocess_workers and
postproc_worker_config.postprocess_tokenizer_dir) as the first two arguments,
keeping is_llm_executor as the third, so the postprocessing configuration is
correctly propagated.
        runtime_env = ray.runtime_env.RuntimeEnv()
        runtime_env["env_vars"] = os.environ.copy()
        runtime_env["env_vars"].update({
            "TLLM_DISABLE_MPI": "1",
            "MASTER_ADDR": self.master_address,  # head-IP for NCCL/Gloo
            "MASTER_PORT": str(self.master_port)
        })
Fix invalid RuntimeEnv mutation; use a dict or construct RuntimeEnv with kwargs.
ray.runtime_env.RuntimeEnv() is not a dict and does not support __setitem__. This will throw at runtime. Use a plain dict for runtime_env or instantiate RuntimeEnv(env_vars=...).
Apply this diff:
- runtime_env = ray.runtime_env.RuntimeEnv()
- runtime_env["env_vars"] = os.environ.copy()
- runtime_env["env_vars"].update({
- "TLLM_DISABLE_MPI": "1",
- "MASTER_ADDR": self.master_address, # head-IP for NCCL/Gloo
- "MASTER_PORT": str(self.master_port)
- })
+ runtime_env = {
+ "env_vars": {
+ **os.environ,
+ "TLLM_DISABLE_MPI": "1",
+ "MASTER_ADDR": self.master_address, # head-IP for NCCL/Gloo
+ "MASTER_PORT": str(self.master_port),
+ }
+ }
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
        runtime_env = ray.runtime_env.RuntimeEnv()
        runtime_env["env_vars"] = os.environ.copy()
        runtime_env["env_vars"].update({
            "TLLM_DISABLE_MPI": "1",
            "MASTER_ADDR": self.master_address,  # head-IP for NCCL/Gloo
            "MASTER_PORT": str(self.master_port)
        })
        runtime_env = {
            "env_vars": {
                **os.environ,
                "TLLM_DISABLE_MPI": "1",
                "MASTER_ADDR": self.master_address,  # head-IP for NCCL/Gloo
                "MASTER_PORT": str(self.master_port),
            }
        }
🤖 Prompt for AI Agents
In tensorrt_llm/executor/ray_executor.py around lines 102 to 109, the code
treats ray.runtime_env.RuntimeEnv() like a dict and tries to set items on it
which raises at runtime; replace this by either building a plain dict for
runtime_env (e.g., runtime_env = {"env_vars": os.environ.copy(), ...}) or
construct the RuntimeEnv with kwargs (e.g., RuntimeEnv(env_vars=combined_env)),
ensuring you merge os.environ.copy() with the additional keys
("TLLM_DISABLE_MPI","MASTER_ADDR","MASTER_PORT") into a single dict before
passing it to RuntimeEnv.
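For reference, the constructor-based variant mentioned above could look roughly like this (a minimal sketch; self.master_address and self.master_port are the attributes already used in this executor, and the merged-environment shape mirrors the dict suggestion):
import os
from ray.runtime_env import RuntimeEnv

# Merge the process environment with the extra variables first,
# then hand the finished dict to RuntimeEnv once.
env_vars = {
    **os.environ,
    "TLLM_DISABLE_MPI": "1",
    "MASTER_ADDR": self.master_address,  # head-IP for NCCL/Gloo
    "MASTER_PORT": str(self.master_port),
}
runtime_env = RuntimeEnv(env_vars=env_vars)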
        if async_call:
            return [(getattr(self.workers, func).remote(*args, **kwargs))
                    for worker in workers]
        else:
            return ray.get([
                getattr(worker, func).remote(*args, **kwargs)
                for worker in workers
            ])
Fix the call_all_ray_workers async branch; it currently calls getattr on a list.
getattr(self.workers, func) is wrong (self.workers is a list). This raises AttributeError and prevents async calls from executing.
Apply this diff:
- if async_call:
- return [(getattr(self.workers, func).remote(*args, **kwargs))
- for worker in workers]
+ if async_call:
+ return [getattr(worker, func).remote(*args, **kwargs)
+ for worker in workers]
else:
return ray.get([
getattr(worker, func).remote(*args, **kwargs)
for worker in workers
])
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
        if async_call:
            return [(getattr(self.workers, func).remote(*args, **kwargs))
                    for worker in workers]
        else:
            return ray.get([
                getattr(worker, func).remote(*args, **kwargs)
                for worker in workers
            ])
        if async_call:
            return [getattr(worker, func).remote(*args, **kwargs)
                    for worker in workers]
        else:
            return ray.get([
                getattr(worker, func).remote(*args, **kwargs)
                for worker in workers
            ])
🤖 Prompt for AI Agents
In tensorrt_llm/executor/ray_executor.py around lines 130 to 137, the async
branch incorrectly calls getattr(self.workers, func) (self.workers is a list)
which raises AttributeError; change the async branch to iterate over the workers
and call getattr(worker, func).remote(*args, **kwargs) for each worker (same
pattern used in the sync branch), returning the list of remote object refs.
Ensure the list comprehension mirrors the sync branch but without ray.get, i.e.,
[getattr(worker, func).remote(*args, **kwargs) for worker in workers].
        if executor_config is None:
            executor_config = tllm.ExecutorConfig(1)

        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
            processor_batched=batched_logits_processor, replicate=False)
Bug: self._executor_config may remain None despite local default assignment.
If executor_config is None on entry, you create a local ExecutorConfig but never write it back to self._executor_config. Later uses (e.g., disable_overlap_scheduler) will crash.
Apply this diff:
- self._executor_config = executor_config
+ self._executor_config = executor_config
@@
- if executor_config is None:
- executor_config = tllm.ExecutorConfig(1)
+ if executor_config is None:
+ executor_config = tllm.ExecutorConfig(1)
+ # Keep the instance field in sync
+ self._executor_config = executor_config
Also applies to: 140-143
        if self.engine.can_enqueue_requests():

            if self.await_response_thread.is_alive():
                self.await_response_thread.stop()
                self.await_response_thread.join()
            if self.dispatch_stats_thread.is_alive():
                self.dispatch_stats_thread.stop()
                self.dispatch_stats_thread.join()
            if self.dispatch_kv_cache_events_thread.is_alive():
                self.dispatch_kv_cache_events_thread.stop()
                self.dispatch_kv_cache_events_thread.join()
References to non-existent threads: await_response_thread and the dispatch_* threads are undefined here.
RayGPUWorker doesn't start these threads; accessing them will raise AttributeError. Let the engine own its background resources.
Apply this diff to remove the invalid block:
- if self.engine.can_enqueue_requests():
-
- if self.await_response_thread.is_alive():
- self.await_response_thread.stop()
- self.await_response_thread.join()
- if self.dispatch_stats_thread.is_alive():
- self.dispatch_stats_thread.stop()
- self.dispatch_stats_thread.join()
- if self.dispatch_kv_cache_events_thread.is_alive():
- self.dispatch_kv_cache_events_thread.stop()
- self.dispatch_kv_cache_events_thread.join()
+ # Engine-specific background threads are handled inside engine.shutdown()
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
        if self.engine.can_enqueue_requests():
            if self.await_response_thread.is_alive():
                self.await_response_thread.stop()
                self.await_response_thread.join()
            if self.dispatch_stats_thread.is_alive():
                self.dispatch_stats_thread.stop()
                self.dispatch_stats_thread.join()
            if self.dispatch_kv_cache_events_thread.is_alive():
                self.dispatch_kv_cache_events_thread.stop()
                self.dispatch_kv_cache_events_thread.join()
            # Engine-specific background threads are handled inside engine.shutdown()
🤖 Prompt for AI Agents
In tensorrt_llm/executor/ray_gpu_worker.py around lines 420 to 431, there is a
block that attempts to stop/join await_response_thread, dispatch_stats_thread,
and dispatch_kv_cache_events_thread which are not defined in RayGPUWorker and
will raise AttributeError; remove this entire block so the worker no longer
references non-existent threads and instead rely on the engine to own and manage
its background resources (ensure no other code paths expect these attributes).
Actionable comments posted: 13
♻️ Duplicate comments (7)
tensorrt_llm/_torch/models/modeling_utils.py (1)
869-877
: Fix KeyError on partial weight updates in manual copy path. The else-branch still indexes module_weights[n] directly. Partial updates that omit a param will raise KeyError. Guard with .get(...) and skip missing entries. This mirrors an earlier review; repeating here because it remains unfixed.
Apply this diff:
-        else:
+        else:
             module_weights = filter_weights(name, weights)
             if hasattr(module, 'load_weights'):
                 module.load_weights(weights=[module_weights])
             else:
-                for n, p in module._parameters.items():
-                    if p is not None:
-                        p.data.copy_(module_weights[n][:])
+                for n, p in module._parameters.items():
+                    if p is None:
+                        continue
+                    w = module_weights.get(n)
+                    if w is None:
+                        # Missing weight (partial update) — skip
+                        continue
+                    p.data.copy_(w[:])
examples/llm-api/ray/run_ray_tests.sh (2)
36-68
: Fix shellcheck issues in run_python_file: use arrays for args, avoid masked return values, quote expansions. Adopt array args, split declare/assign for command substitutions, and ensure proper quoting to avoid word-splitting and incorrect arg passing. This also fixes SC2124/SC2155 and improves robustness.
 run_python_file() {
     local file_path="$1"
     shift # Remove first argument (file_path)
-    local args="$@" # Get remaining arguments
-    local file_name=$(basename "$file_path" .py)
+    local args=("$@") # Store remaining arguments as an array
+    local file_name
+    file_name=$(basename "$file_path" .py)
-    local timestamp=$(date +"%Y%m%d_%H%M%S")
+    local timestamp
+    timestamp=$(date +"%Y%m%d_%H%M%S")
     local log_file="$LOG_DIR/${file_name}_${timestamp}.log"
     # Create a unique test name with arguments
     local test_name="${file_name}"
     if [ $# -gt 0 ]; then
-        test_name="${file_name}(${args// /_})"
+        test_name="${file_name}(${args[*]// /_})"
     fi
-    echo "Running: $file_path $args"
+    echo "Running: $file_path ${args[*]}"
     echo "Log file: $log_file"
     echo "------------------------------------------"
     # Run Python file with arguments and record output
-    python3 "$file_path" $args 2>&1 | tee "$log_file"
+    python3 "$file_path" "${args[@]}" 2>&1 | tee "$log_file"
     # Check run result and store it
     if [ ${PIPESTATUS[0]} -eq 0 ]; then
         echo "✅ $file_name ran successfully"
192-196
: Quote command substitution in read to avoid word splitting. Prevents SC2046 and odd parsing when output contains newlines/extra spaces.
-read sm_major sm_minor gpu_count <<< $(get_sm_version)
+read sm_major sm_minor gpu_count <<< "$(get_sm_version)"
tensorrt_llm/executor/ray_executor.py (3)
45-46
: Fix incorrect superclass initialization; breaks postprocessing config. GenerationExecutor.__init__ expects (num_postprocess_workers, postprocess_tokenizer_dir, is_llm_executor). Passing (model_world_size, PostprocWorkerConfig, is_llm_executor) corrupts configuration. This mirrors a previously raised issue.
-        super().__init__(model_world_size, postproc_worker_config,
-                         is_llm_executor)
+        super().__init__(
+            num_postprocess_workers=postproc_worker_config.num_postprocess_workers,
+            postprocess_tokenizer_dir=postproc_worker_config.postprocess_tokenizer_dir,
+            is_llm_executor=is_llm_executor,
+        )
102-108
: Fix invalid RuntimeEnv mutation; use a dict or construct RuntimeEnv with kwargs. ray.runtime_env.RuntimeEnv isn't a dict; item assignment raises at runtime. This mirrors a previously raised issue.
-        runtime_env = ray.runtime_env.RuntimeEnv()
-        runtime_env["env_vars"] = os.environ.copy()
-        runtime_env["env_vars"].update({
-            "TLLM_DISABLE_MPI": "1",
-            "MASTER_ADDR": self.master_address,  # head-IP for NCCL/Gloo
-            "MASTER_PORT": str(self.master_port)
-        })
+        runtime_env = {
+            "env_vars": {
+                **dict(os.environ),
+                "TLLM_DISABLE_MPI": "1",
+                "MASTER_ADDR": self.master_address,  # head-IP for NCCL/Gloo
+                "MASTER_PORT": str(self.master_port),
+            }
+        }
130-132
: Fix async broadcast call; getattr on a list raises AttributeError. The async branch calls getattr(self.workers, ...) but self.workers is a list. Iterate workers as in the sync branch. This mirrors a previously raised issue.
-        if async_call:
-            return [(getattr(self.workers, func).remote(*args, **kwargs))
-                    for worker in workers]
+        if async_call:
+            return [
+                getattr(worker, func).remote(*args, **kwargs)
+                for worker in workers
+            ]
386-397
: Remove references to undefined threads; let the engine own background resources. await_response_thread/dispatch_* aren't defined here and will raise AttributeError. A similar issue was flagged earlier.
-            if self.await_response_thread.is_alive():
-                self.await_response_thread.stop()
-                self.await_response_thread.join()
-            if self.dispatch_stats_thread.is_alive():
-                self.dispatch_stats_thread.stop()
-                self.dispatch_stats_thread.join()
-            if self.dispatch_kv_cache_events_thread.is_alive():
-                self.dispatch_kv_cache_events_thread.stop()
-                self.dispatch_kv_cache_events_thread.join()
+            # Engine-specific background threads are handled inside engine.shutdown()
🧹 Nitpick comments (20)
tensorrt_llm/_torch/models/modeling_utils.py (2)
1-1
: Add NVIDIA copyright header (2025) at file top. Repository guideline requires the NVIDIA copyright header on all source files. Prepend the standard header.
Example:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import contextlib
 import math
 import os
535-539
: Avoid mutable default argument for skip_modules. Both loaders use skip_modules: List[str] = [], which is mutable at function definition time. Switch to None and normalize inside the function.
Apply this diff:
-def load_weights(self, - weights: Dict, - weight_mapper: Optional["BaseWeightMapper"] = None, - skip_modules: List[str] = []): +def load_weights(self, + weights: Dict, + weight_mapper: Optional["BaseWeightMapper"] = None, + skip_modules: Optional[List[str]] = None):-def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM], - weights: Dict, - skip_modules: List[str] = [], +def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM], + weights: Dict, + skip_modules: Optional[List[str]] = None, params_map: Optional[Dict[str, str]] = None, preload_weight_modules: Optional[List[str]] = None): @@ - if not hasattr(model, 'model_config') or not isinstance( + if skip_modules is None: + skip_modules = [] + if not hasattr(model, 'model_config') or not isinstance( model.model_config, ModelConfig):-def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM], - weights: Dict, - weight_mapper: "BaseWeightMapper", - skip_modules: List[str] = [], +def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM], + weights: Dict, + weight_mapper: "BaseWeightMapper", + skip_modules: Optional[List[str]] = None, params_map: Optional[Dict[str, str]] = None, preload_weight_modules: Optional[List[str]] = None): @@ - weight_mapper.add_skip_modules(skip_modules) + if skip_modules is None: + skip_modules = [] + weight_mapper.add_skip_modules(skip_modules)Also applies to: 908-914
examples/llm-api/ray/run_ray_tests.sh (7)
1-2
: Add pipefail (and optionally nounset) for safer error handling in pipelines. set -e alone won't catch failures in the left side of pipelines. Add pipefail; consider nounset if you want to catch typos/unset vars.
 #!/bin/bash
-set -e
+set -e
+set -o pipefail
+# Optional:
+# set -u
10-18
: Make LLM_MODELS_ROOT optional; skip only tests that need it. Requiring LLM_MODELS_ROOT for the entire script is restrictive. Only the DeepSeek-V3-Lite path depends on it. Degrade gracefully: run what you can and SKIP DeepSeek when the var is unset.
-# Check for LLM_MODELS_ROOT environment variable
-if [ -z "$LLM_MODELS_ROOT" ]; then
-    echo "❌ Error: LLM_MODELS_ROOT environment variable is not set!"
-    echo "Please set the LLM_MODELS_ROOT environment variable before running this script."
-    echo "For example, export LLM_MODELS_ROOT=/home/scratch.trt_llm_data/llm-models"
-    exit 1
-else
-    echo "✅ LLM_MODELS_ROOT is set to: $LLM_MODELS_ROOT"
-fi
+# Check for LLM_MODELS_ROOT environment variable (optional)
+if [ -z "${LLM_MODELS_ROOT:-}" ]; then
+    echo "ℹ️ LLM_MODELS_ROOT is not set; tests requiring local model assets will be skipped."
+else
+    echo "✅ LLM_MODELS_ROOT is set to: $LLM_MODELS_ROOT"
+fi
167-186
: Harden get_sm_version parsing and avoid masked return values. Split declare/assign (SC2155), quote command substitutions, and validate parsed values. Fall back to 0 0 0 if parsing fails.
get_sm_version() { local sm_major=0 local sm_minor=0 local gpu_count=0 # Check if the Python script exists if [ -f "current_node_gpu_type.py" ]; then # Run the Python script and capture output - local python_output=$(python3 current_node_gpu_type.py 2>/dev/null) + local python_output + python_output=$(python3 current_node_gpu_type.py 2>/dev/null) - if [ $? -eq 0 ]; then + if [ $? -eq 0 ] && [ -n "$python_output" ]; then # Parse the output which is in format "sm_major sm_minor gpu_count" # The Python script outputs space-separated values like "9 0 2" - read sm_major sm_minor gpu_count <<< "$python_output" + if read -r sm_major sm_minor gpu_count <<< "$python_output"; then + : + else + sm_major=0; sm_minor=0; gpu_count=0 + fi fi fi echo "$sm_major $sm_minor $gpu_count" }
201-213
: Simplify SM gating and guard DeepSeek path by LLM_MODELS_ROOT presence
- Minor nit: [ "$sm_major" -ge 9 ] is sufficient; the minor check is redundant when major already ≥ 9.
- Also ensure LLM_MODELS_ROOT is set before constructing the DeepSeek path; otherwise SKIP gracefully.
- # Check SM version for DeepSeek-V3-Lite (only run on Hopper and Blackwell) - if [ "$sm_major" -ge 9 ] && [ "$sm_minor" -ge 0 ]; then + # Check SM version for DeepSeek-V3-Lite (only run on Hopper/Blackwell) + if [ "$sm_major" -ge 9 ] && [ -n "${LLM_MODELS_ROOT:-}" ]; then echo "✅ SM version $sm_major.$sm_minor is compatible with DeepSeek-V3-Lite (Hopper/Blackwell)" run_python_file "simple_ray_single_node.py" "--model_dir=$LLM_MODELS_ROOT/DeepSeek-V3-Lite/bf16" else - echo "⚠️ Skipping DeepSeek-V3-Lite test - requires Hopper or Blackwell GPU (SM 9.0+), but detected: $sm_major.$sm_minor" - echo " DeepSeek-V3-Lite test was not run due to incompatible SM version" + echo "⚠️ Skipping DeepSeek-V3-Lite test - requires SM 9.0+ and LLM_MODELS_ROOT; detected SM: $sm_major.$sm_minor; LLM_MODELS_ROOT set: ${LLM_MODELS_ROOT:+yes}${LLM_MODELS_ROOT:-no}" echo "" # Record this as a skipped test TEST_RESULTS["simple_ray_single_node(DeepSeek-V3-Lite)"]="SKIP" TEST_LOG_FILES["simple_ray_single_node(DeepSeek-V3-Lite)"]="N/A" fi
220-223
: Guard the MPI-guarding test path existence. If ../llm_inference.py is missing, this will fail the run unnecessarily. Consider checking before invoking.
-# Run MPI guarding tests
-run_python_file "../llm_inference.py"
+# Run MPI guarding tests (if present)
+if [ -f "../llm_inference.py" ]; then
+    run_python_file "../llm_inference.py"
+else
+    TEST_RESULTS["llm_inference.py"]="SKIP"
+    TEST_LOG_FILES["llm_inference.py"]="N/A"
+fi
241-254
: Stable ordering for summary output. Associative arrays iterate in arbitrary order. Sorting improves readability for CI logs.
-# Print each test result
-for test_name in "${!TEST_RESULTS[@]}"; do
+# Print each test result (sorted)
+mapfile -t __sorted_tests < <(printf '%s\n' "${!TEST_RESULTS[@]}" | sort)
+for test_name in "${__sorted_tests[@]}"; do
     result="${TEST_RESULTS[$test_name]}"
     log_file="${TEST_LOG_FILES[$test_name]}"
     if [ "$result" = "PASS" ]; then
         echo "✅ $test_name - PASS"
     elif [ "$result" = "SKIP" ]; then
         echo "⏭️ $test_name - SKIP"
     else
         echo "❌ $test_name - FAIL"
     fi
 done
256-259
: Prefer quoting the sample path; also consider printing actual per-test log paths. Tiny UX nit: quote the sample path to avoid issues with spaces; optionally list TEST_LOG_FILES map entries for direct copy/paste.
-echo "cat $SCRIPT_DIR/$LOG_DIR/[filename].log"
+echo "cat \"$SCRIPT_DIR/$LOG_DIR/[filename].log\""
+# Or list logs:
+# for k in "${!TEST_LOG_FILES[@]}"; do echo "$k -> ${TEST_LOG_FILES[$k]}"; done
tensorrt_llm/executor/ray_executor.py (4)
1-1
: Add NVIDIA copyright header (2025). All source files must start with the NVIDIA copyright header per repo guidelines.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import os
 import socket
 from typing import Dict, List, Tuple
26-31
: Minimize port race for MASTER_PORT. Binding to an ephemeral port and immediately closing it can race with other processes before Ray workers call init_process_group. Consider retry-on-failure at PG init or reserving via a small retry loop around get_free_port()/PG init.
I can propose a retry helper that attempts several ports and falls back to random selection if collisions persist. Want me to add it?
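A rough sketch of what such a helper could look like is below; the function names, retry count, and delay are illustrative and not part of the current code:
import socket
import time

def get_free_port() -> int:
    # Ask the OS for an ephemeral port; note it can still be claimed by
    # another process between close() and the later bind.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

def init_pg_with_retry(init_fn, attempts: int = 5, delay_s: float = 1.0):
    """Retry process-group initialization on a fresh port if a bind race occurs."""
    last_err = None
    for _ in range(attempts):
        port = get_free_port()
        try:
            # init_fn would export MASTER_PORT and call init_process_group.
            return init_fn(port)
        except OSError as err:
            last_err = err
            time.sleep(delay_s)
    raise RuntimeError(f"process group init failed after {attempts} attempts") from last_err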
87-92
: Avoid private Ray serialization helpers for weak refs. Using _serialization_helper/_deserialization_helper is private API and brittle across Ray versions. Prefer:
- Keep a strong ActorHandle and manage lifetime explicitly, or
- Name the actor and later retrieve via ray.get_actor(name=...).
If weak refs are required, guard with version checks and provide a fallback.
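For illustration, the named-actor route could look roughly like this, assuming RayAsyncQueue is (or is wrapped as) a Ray actor class; the actor name is a placeholder:
import ray

# Create the queue as a named, detached actor (or reuse an existing one)
# instead of shipping weak references via private serialization helpers.
queue_actor = RayAsyncQueue.options(
    name="tllm_response_queue",
    get_if_exists=True,
    lifetime="detached",
).remote()

# Any worker can later look the actor up by name, with no handle passed around.
same_actor = ray.get_actor("tllm_response_queue")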
140-145
: Docstring style: fix D205/D415 (summary line, punctuation, blank line). Ruff hints flag these sections. Tighten docstrings to Google style with a one-line summary ending in a period and a blank line before details.
-        """
-        Low-level API to the executor. Return a "future" GenerationResult
-        which can be waited.
-        Forwards the request to the workers through the request queue.
-        """
+        """Submit a request and return a GenerationResult future.
+
+        Forwards the request to the workers through the request queue.
+        """

-        """
-        Either use the existing placement group from driver script (e.g., in the case of RL FW integration),
-        or create a default PACK placement group where each bundle has tp_size GPUs.
-        - When tp_size ≤ GPUs per node, keep one TP group per node.
-        - When tp_size > GPUs per node, allow a TP group span nodes.
-        - rank 0 must be put on the driver node
-        """
+        """Resolve or create a placement group for Ray actors.
+
+        Either use an existing placement group (e.g., for RL framework
+        integration) or create a default PACK placement group where each
+        bundle has tp_size GPUs.
+        - When tp_size ≤ GPUs per node, keep one TP group per node.
+        - When tp_size > GPUs per node, allow a TP group to span nodes.
+        - Rank 0 is pinned to the driver node.
+        """
Also applies to: 210-217
tensorrt_llm/executor/ray_gpu_worker.py (4)
1-1
: Add NVIDIA copyright header (2025). All source files must start with the NVIDIA copyright header per repo guidelines.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import copy
 import json
 import os
48-52
: Small nit: assert CUDA availability or log a clear error. The torch.cuda.is_available() return value is ignored. Either assert or log a failure early.
-        torch.cuda.is_available()
+        assert torch.cuda.is_available(), "CUDA is not available but RayWorkerWrapper expects a GPU."
174-186
: Validate LoRA/prompt adapter manager availability on all ranks. Managers are instantiated only on rank 0 for tllm.Executor; _load_lora_adapter/_load_prompt_adapter rely on a local manager and local model_config. Confirm the runtime replicates/propagates loaded weights to other ranks, or initialize managers consistently across ranks when needed.
I can add rank-aware broadcast/load helpers if replication is required.
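If replication does turn out to be necessary, one possible shape for such a helper is sketched below; it assumes torch.distributed is already initialized, and the function and argument names are illustrative:
import torch.distributed as dist

def broadcast_adapter_weights(adapter_weights, src_rank: int = 0):
    """Share adapter weights loaded on src_rank with every other rank."""
    payload = [adapter_weights if dist.get_rank() == src_rank else None]
    dist.broadcast_object_list(payload, src=src_rank)
    return payload[0]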
263-276
: Wrap long lines flagged by Ruff (E501). Lines are longer than 120 chars; break them for readability and lint compliance.
-        assert (
-            not self._is_pytorch_backend
-            or self.engine.kv_cache_transceiver is not None
-        ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:<backend_type>` in config file for disaggregated serving"
+        assert (
+            (not self._is_pytorch_backend)
+            or (self.engine.kv_cache_transceiver is not None)
+        ), (
+            "kv_cache_transceiver is disabled; set "
+            "'cache_transceiver_config: backend:<backend_type>' in the config for disaggregated serving"
+        )
Also applies to: 269-276
tensorrt_llm/_torch/pyexecutor/py_executor.py (3)
831-839
: Complex conditional logic for MPI toggling. The MPI toggle logic with recv_object_from_isend vs recv_object is complex and scattered. Consider consolidating this logic. Extract to a helper method:
def _get_recv_function(self):
    """Get the appropriate receive function based on MPI configuration."""
    if self._disable_mpi:
        return self.dist.recv_object_from_isend
    return self.dist.recv_object

# Then use it:
recv_object_funct = self._get_recv_function()
sample_state.host = recv_object_funct(src=self.dist.prev_pp_rank,
                                      tag=prev_microbatch_id)
Also applies to: 848-854
848-854
: TODO comment needs attention. The TODO comment "need clean up" is vague and doesn't provide actionable information about what needs to be cleaned up.
The conditional wait logic for send handles looks complex. Would you like me to help refactor this into a cleaner abstraction or create an issue to track this cleanup task?
381-385
: Memory pool cleanup logging at DEBUG level. The virtual memory pool cleanup is logged at DEBUG level, which might not be visible in production. Consider using INFO level for resource cleanup operations.
-logger.debug("Freeing pool %s", key)
+logger.info("Freeing virtual memory pool: %s", key)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (6)
examples/llm-api/ray/run_ray_tests.sh (1 hunks)
tensorrt_llm/_torch/models/modeling_utils.py (1 hunks)
tensorrt_llm/_torch/pyexecutor/py_executor.py (13 hunks)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (2 hunks)
tensorrt_llm/executor/ray_executor.py (1 hunks)
tensorrt_llm/executor/ray_gpu_worker.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in __init__
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures
Files:
tensorrt_llm/_torch/models/modeling_utils.py
tensorrt_llm/executor/ray_gpu_worker.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
tensorrt_llm/executor/ray_executor.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)
Files:
tensorrt_llm/_torch/models/modeling_utils.py
tensorrt_llm/executor/ray_gpu_worker.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
tensorrt_llm/executor/ray_executor.py
🧬 Code graph analysis (3)
tensorrt_llm/executor/ray_gpu_worker.py (11)
tensorrt_llm/_torch/virtual_memory.py (2)
materialize_with_tag
(80-88)release_with_tag
(69-77)tensorrt_llm/_utils.py (1)
mpi_rank
(487-494)tensorrt_llm/llmapi/llm_args.py (2)
PybindMirror
(577-721)maybe_to_pybind
(587-592)tensorrt_llm/lora_manager.py (2)
cpp_lora_weights
(1163-1164)cpp_lora_config
(1167-1168)tensorrt_llm/prompt_adapter_manager.py (2)
PromptAdapterManager
(12-49)uid_to_weights
(41-42)tensorrt_llm/executor/executor.py (3)
GenerationExecutor
(76-489)shutdown
(282-283)_handle_background_error
(247-276)tensorrt_llm/executor/request.py (4)
GenerationRequest
(84-129)LoRARequest
(24-53)PromptAdapterRequest
(57-81)local_path
(80-81)tensorrt_llm/executor/result.py (5)
result
(688-699)GenerationResult
(592-759)request_id
(634-635)get
(232-238)prompt_token_ids
(638-639)tensorrt_llm/executor/utils.py (1)
RequestError
(76-77)tensorrt_llm/_torch/pyexecutor/py_executor.py (5)
enqueue_request
(419-429)can_enqueue_requests
(387-391)cancel_request
(359-365)shutdown
(367-385)wait_shutdown
(416-417)tensorrt_llm/_torch/distributed/pg_utils.py (1)
split
(5-41)
tensorrt_llm/_torch/pyexecutor/py_executor.py (5)
tensorrt_llm/executor/result.py (3)
get
(232-238)put_response
(182-185)put_response
(228-230)tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (2)
enqueue_requests
(236-240)enqueue_request
(242-248)tensorrt_llm/executor/ray_gpu_worker.py (2)
enqueue_request
(71-74)enqueue_request
(221-358)tensorrt_llm/mapping.py (2)
dist
(290-291)dist
(294-303)tensorrt_llm/_torch/distributed/communicator.py (3)
recv_object_from_isend
(359-366)recv_object
(140-141)recv_object
(323-336)
tensorrt_llm/executor/ray_executor.py (5)
tensorrt_llm/_utils.py (1)
nvtx_range_debug
(882-906)tensorrt_llm/executor/postproc_worker.py (1)
PostprocWorkerConfig
(43-50)tensorrt_llm/executor/ray_gpu_worker.py (7)
RayGPUWorker
(101-425)RayWorkerWrapper
(34-98)submit
(68-69)submit
(372-373)abort_request
(76-77)abort_request
(194-204)shutdown
(375-404)tensorrt_llm/executor/request.py (2)
GenerationRequest
(84-129)set_id
(126-129)tensorrt_llm/executor/result.py (8)
result
(688-699)GenerationResult
(592-759)RayAsyncQueue
(157-193)RaySyncQueue
(201-238)warmup
(176-180)warmup
(222-226)get
(232-238)request_id
(634-635)
🪛 Shellcheck (0.10.0)
examples/llm-api/ray/run_ray_tests.sh
[warning] 39-39: Assigning an array to a string! Assign as array, or use * instead of @ to concatenate.
(SC2124)
[warning] 40-40: Declare and assign separately to avoid masking return values.
(SC2155)
[warning] 42-42: Declare and assign separately to avoid masking return values.
(SC2155)
[warning] 76-76: Declare and assign separately to avoid masking return values.
(SC2155)
[warning] 176-176: Declare and assign separately to avoid masking return values.
(SC2155)
[warning] 192-192: Quote this to prevent word splitting.
(SC2046)
🪛 Ruff (0.12.2)
tensorrt_llm/executor/ray_gpu_worker.py
263-263: Line too long (153 > 120)
(E501)
269-269: Line too long (132 > 120)
(E501)
tensorrt_llm/_torch/pyexecutor/py_executor.py
1685-1687: Use is and is not for type comparisons, or isinstance() for isinstance checks
(E721)
tensorrt_llm/executor/ray_executor.py
140-143: 1 blank line required between summary line and description
(D205)
209-214: 1 blank line required between summary line and description
(D205)
209-214: First line should end with a period, question mark, or exclamation point
Add closing punctuation
(D415)
🔇 Additional comments (1)
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
1685-1691
: Handle other Ray response types. Currently, the code only enqueues responses of type LlmResponse. Any other response objects sent through Ray (for example, the generic tllm.Response, various preprocessing or generation responses, or inference error responses) will be silently dropped. Please verify which response classes can be returned by this executor and extend the queue-handling logic to cover them, or document why they can be safely ignored.
Potential response types to audit (examples found elsewhere in the repo):
- tllm.Response (used in tests: tests/unittest/llmapi/test_executor.py:277)
- PreprocResponse (triton_backend/all_models/tests/test_decode.py:50)
- GenerationResponse (triton_backend/all_models/tests/test_decode.py:67)
- pb_utils.InferenceResponse (in model code under triton_backend)
After identifying all valid response subclasses, update the if type(resp) == LlmResponse check to a more extensible form (e.g., isinstance(resp, (LlmResponse, OtherResponseType, …))) or a dispatch table to ensure no responses are lost.
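One possible dispatch-table shape for that check is sketched below; the handler names and the set of registered classes are placeholders until the audit above is completed:
def _handle_llm_response(resp):
    # Placeholder: enqueue the response for the client.
    ...

RESPONSE_HANDLERS = {
    LlmResponse: _handle_llm_response,
    # Register additional audited response classes here.
}

def dispatch_response(resp):
    for resp_type, handler in RESPONSE_HANDLERS.items():
        if isinstance(resp, resp_type):
            return handler(resp)
    raise TypeError(f"Unhandled response type: {type(resp).__name__}")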
# Function: run disagg serving test
run_disagg_serving_test() {
    local disagg_dir="disagg_serving_local"
    local test_name="disagg_serving_test"

    # Store the original directory path for log files
    local original_dir=$(pwd)
    local log_file="$original_dir/$LOG_DIR/disagg_serving_test.log"
    local test_log_file="$original_dir/$LOG_DIR/disagg_serving_test_output.log"

    echo "Running disagg serving test in $disagg_dir"
    echo "Log file: $log_file"
    echo "------------------------------------------"

    # Check if disagg_serving_local directory exists
    if [ ! -d "$disagg_dir" ]; then
        echo "❌ $disagg_dir directory not found"
        TEST_RESULTS["$test_name"]="FAIL"
        TEST_LOG_FILES["$test_name"]="$log_file"
        return 1
    fi

    # Change to disagg_serving_local directory
    cd "$disagg_dir"

    # Start the serving in background
    echo "Starting disagg serving..."
    bash -e disagg_serving_local.sh --executor ray > "$log_file" 2>&1 &
    local serving_pid=$!

    # Wait for the serving to start (look for startup message)
    echo "Waiting for serving to start..."
    local max_wait=120
    local wait_count=0
    while [ $wait_count -lt $max_wait ]; do
        if grep -q "INFO: Application startup complete." "$log_file" 2>/dev/null; then
            echo "✅ Serving started successfully"
            break
        fi
        sleep 2
        wait_count=$((wait_count + 2))
    done

    if [ $wait_count -ge $max_wait ]; then
        echo "❌ Timeout waiting for serving to start"
        kill $serving_pid 2>/dev/null
        cd "$original_dir"
        TEST_RESULTS["$test_name"]="FAIL"
        TEST_LOG_FILES["$test_name"]="$log_file"
        return 1
    fi

    # Wait a bit more for Uvicorn to be fully ready
    sleep 5

    # Run the test directly in the current script process
    echo ""
    echo "=========================================="
    echo "Serving is now running on http://localhost:8000"
    echo "Running disagg_serving_test.py in current process..."
    echo "=========================================="

    # Run the test directly and capture output
    echo "Running disagg_serving_test.py..."
    echo "Current directory: $(pwd)"
    echo "----------------------------------------"

    # Run the test and capture output
    python disagg_serving_test.py 2>&1 | tee "$test_log_file"
    local test_result=$?

    echo "----------------------------------------"
    echo "Test completed with exit code: $test_result"

    # Check test result and store it
    if [ $test_result -eq 0 ]; then
        echo "✅ disagg_serving_test ran successfully"
        TEST_RESULTS["$test_name"]="PASS"
    else
        echo "❌ disagg_serving_test failed with exit code $test_result"
        TEST_RESULTS["$test_name"]="FAIL"
    fi
    TEST_LOG_FILES["$test_name"]="$test_log_file"

    # Stop the serving
    echo "Stopping serving..."
    ray stop
    kill $serving_pid 2>/dev/null
    wait $serving_pid 2>/dev/null

    # Go back to original directory
    cd "$original_dir"

    echo "✅ disagg_serving_test completed"
    echo ""
}
🛠️ Refactor suggestion
Harden disagg_serving test: avoid masked return values, ensure cleanup on exit, and tolerate missing Ray
- Split declare/assign (SC2155).
- Add a trap to always clean up the background serving.
- Make ray stop tolerant of absence/failures.
- Don’t let set -e abort on wait/kill; add || true.
run_disagg_serving_test() {
local disagg_dir="disagg_serving_local"
local test_name="disagg_serving_test"
# Store the original directory path for log files
- local original_dir=$(pwd)
+ local original_dir
+ original_dir=$(pwd)
local log_file="$original_dir/$LOG_DIR/disagg_serving_test.log"
local test_log_file="$original_dir/$LOG_DIR/disagg_serving_test_output.log"
+ local serving_pid=""
+
+ # Ensure cleanup even if the function exits early
+ cleanup_disagg() {
+ if [ -n "$serving_pid" ] && kill -0 "$serving_pid" 2>/dev/null; then
+ ray stop >/dev/null 2>&1 || true
+ kill "$serving_pid" 2>/dev/null || true
+ wait "$serving_pid" 2>/dev/null || true
+ fi
+ cd "$original_dir" 2>/dev/null || true
+ }
+ trap cleanup_disagg RETURN
echo "Running disagg serving test in $disagg_dir"
echo "Log file: $log_file"
echo "------------------------------------------"
# Check if disagg_serving_local directory exists
if [ ! -d "$disagg_dir" ]; then
echo "❌ $disagg_dir directory not found"
- TEST_RESULTS["$test_name"]="FAIL"
+ TEST_RESULTS["$test_name"]="SKIP"
TEST_LOG_FILES["$test_name"]="$log_file"
return 1
fi
# Change to disagg_serving_local directory
cd "$disagg_dir"
# Start the serving in background
echo "Starting disagg serving..."
bash -e disagg_serving_local.sh --executor ray > "$log_file" 2>&1 &
- local serving_pid=$!
+ serving_pid=$!
# Wait for the serving to start (look for startup message)
echo "Waiting for serving to start..."
local max_wait=120
local wait_count=0
while [ $wait_count -lt $max_wait ]; do
+ # If the process died, break early
+ if ! kill -0 "$serving_pid" 2>/dev/null; then
+ echo "❌ Serving process exited prematurely. See log: $log_file"
+ break
+ fi
if grep -q "INFO: Application startup complete." "$log_file" 2>/dev/null; then
echo "✅ Serving started successfully"
break
fi
sleep 2
wait_count=$((wait_count + 2))
done
if [ $wait_count -ge $max_wait ]; then
echo "❌ Timeout waiting for serving to start"
- kill $serving_pid 2>/dev/null
- cd "$original_dir"
+ kill "$serving_pid" 2>/dev/null || true
TEST_RESULTS["$test_name"]="FAIL"
TEST_LOG_FILES["$test_name"]="$log_file"
return 1
fi
# Wait a bit more for Uvicorn to be fully ready
sleep 5
# Run the test directly in the current script process
echo ""
echo "=========================================="
echo "Serving is now running on http://localhost:8000"
echo "Running disagg_serving_test.py in current process..."
echo "=========================================="
# Run the test directly and capture output
echo "Running disagg_serving_test.py..."
echo "Current directory: $(pwd)"
echo "----------------------------------------"
# Run the test and capture output
- python disagg_serving_test.py 2>&1 | tee "$test_log_file"
+ python3 disagg_serving_test.py 2>&1 | tee "$test_log_file"
local test_result=$?
echo "----------------------------------------"
echo "Test completed with exit code: $test_result"
# Check test result and store it
if [ $test_result -eq 0 ]; then
echo "✅ disagg_serving_test ran successfully"
TEST_RESULTS["$test_name"]="PASS"
else
echo "❌ disagg_serving_test failed with exit code $test_result"
TEST_RESULTS["$test_name"]="FAIL"
fi
TEST_LOG_FILES["$test_name"]="$test_log_file"
# Stop the serving
echo "Stopping serving..."
- ray stop
- kill $serving_pid 2>/dev/null
- wait $serving_pid 2>/dev/null
+ ray stop >/dev/null 2>&1 || true
+ kill "$serving_pid" 2>/dev/null || true
+ wait "$serving_pid" 2>/dev/null || true
# Go back to original directory
cd "$original_dir"
echo "✅ disagg_serving_test completed"
echo ""
}
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
run_disagg_serving_test() {
    local disagg_dir="disagg_serving_local"
    local test_name="disagg_serving_test"
    # Store the original directory path for log files
    local original_dir
    original_dir=$(pwd)
    local log_file="$original_dir/$LOG_DIR/disagg_serving_test.log"
    local test_log_file="$original_dir/$LOG_DIR/disagg_serving_test_output.log"
    local serving_pid=""
    # Ensure cleanup even if the function exits early
    cleanup_disagg() {
        if [ -n "$serving_pid" ] && kill -0 "$serving_pid" 2>/dev/null; then
            ray stop >/dev/null 2>&1 || true
            kill "$serving_pid" 2>/dev/null || true
            wait "$serving_pid" 2>/dev/null || true
        fi
        cd "$original_dir" 2>/dev/null || true
    }
    trap cleanup_disagg RETURN
    echo "Running disagg serving test in $disagg_dir"
    echo "Log file: $log_file"
    echo "------------------------------------------"
    # Check if disagg_serving_local directory exists
    if [ ! -d "$disagg_dir" ]; then
        echo "❌ $disagg_dir directory not found"
        TEST_RESULTS["$test_name"]="SKIP"
        TEST_LOG_FILES["$test_name"]="$log_file"
        return 1
    fi
    # Change to disagg_serving_local directory
    cd "$disagg_dir"
    # Start the serving in background
    echo "Starting disagg serving..."
    bash -e disagg_serving_local.sh --executor ray > "$log_file" 2>&1 &
    serving_pid=$!
    # Wait for the serving to start (look for startup message)
    echo "Waiting for serving to start..."
    local max_wait=120
    local wait_count=0
    while [ $wait_count -lt $max_wait ]; do
        # If the process died, break early
        if ! kill -0 "$serving_pid" 2>/dev/null; then
            echo "❌ Serving process exited prematurely. See log: $log_file"
            break
        fi
        if grep -q "INFO: Application startup complete." "$log_file" 2>/dev/null; then
            echo "✅ Serving started successfully"
            break
        fi
        sleep 2
        wait_count=$((wait_count + 2))
    done
    if [ $wait_count -ge $max_wait ]; then
        echo "❌ Timeout waiting for serving to start"
        kill "$serving_pid" 2>/dev/null || true
        TEST_RESULTS["$test_name"]="FAIL"
        TEST_LOG_FILES["$test_name"]="$log_file"
        return 1
    fi
    # Wait a bit more for Uvicorn to be fully ready
    sleep 5
    # Run the test directly in the current script process
    echo ""
    echo "=========================================="
    echo "Serving is now running on http://localhost:8000"
    echo "Running disagg_serving_test.py in current process..."
    echo "=========================================="
    # Run the test and capture output
    echo "Running disagg_serving_test.py..."
    echo "Current directory: $(pwd)"
    echo "----------------------------------------"
    python3 disagg_serving_test.py 2>&1 | tee "$test_log_file"
    local test_result=$?
    echo "----------------------------------------"
    echo "Test completed with exit code: $test_result"
    # Check test result and store it
    if [ $test_result -eq 0 ]; then
        echo "✅ disagg_serving_test ran successfully"
        TEST_RESULTS["$test_name"]="PASS"
    else
        echo "❌ disagg_serving_test failed with exit code $test_result"
        TEST_RESULTS["$test_name"]="FAIL"
    fi
    TEST_LOG_FILES["$test_name"]="$test_log_file"
    # Stop the serving
    echo "Stopping serving..."
    ray stop >/dev/null 2>&1 || true
    kill "$serving_pid" 2>/dev/null || true
    wait "$serving_pid" 2>/dev/null || true
    # Go back to original directory
    cd "$original_dir"
    echo "✅ disagg_serving_test completed"
    echo ""
}
🧰 Tools
🪛 Shellcheck (0.10.0)
[warning] 76-76: Declare and assign separately to avoid masking return values. (SC2155)
🤖 Prompt for AI Agents
In examples/llm-api/ray/run_ray_tests.sh around lines 70-165, the
disagg_serving_test function needs hardening: split combined declare/assign
(SC2155) by separating local original_dir assignment from declaration; add a
trap at function start that ensures the background serving is killed and we cd
back to original_dir on any exit (cleanup handler should check if serving_pid is
set and kill/wait it safely); make ray stop tolerant by prefixing it with "ray
stop || true" (or checking if ray exists) and ensure kill/wait commands do not
abort the script under set -e by appending "|| true" to those calls; also ensure
TEST_RESULTS/TEST_LOG_FILES are set on all failure paths before returning so no
return value gets masked. Implement these changes so cleanup always runs, errors
in stopping ray/kill don't cause the function to fail unexpectedly, and SC2155
is avoided.
@@ -3,6 +3,7 @@
 import functools
 import gc
 import os
+import pickle  # nosec B403
Security: Document the pickle usage rationale
The pickle import with # nosec B403 suppresses security warnings. While this may be necessary for Ray's distributed communication, it's important to document why pickle is required here and what data is being serialized/deserialized, so the security implications are understood.
Consider adding a comment explaining:
-import pickle # nosec B403
+import pickle # nosec B403 - Required for Ray actor communication; only trusted data from Ray actors is deserialized
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/py_executor.py around line 6, the bare `import
pickle # nosec B403` suppresses security warnings without explanation; add a
concise comment documenting why pickle is required (e.g., for Ray actor/remote
object serialization), what exact types/structures are serialized/deserialized
here, and any mitigations (such as validating inputs, using safe alternatives
where possible, or restricting deserialization to trusted sources). Update the
comment to mention the threat model and note if future refactors should consider
safer serializers or explicit whitelist-based deserialization.
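If a stricter mitigation than a comment is ever wanted, the whitelist-based deserialization mentioned above could look roughly like the sketch below. The helper names and the allowed-type set are illustrative only and are not part of this PR.

```python
import io
import pickle  # nosec B403 - only used through the restricted loader below


class _TrustedUnpickler(pickle.Unpickler):
    """Only allow an explicit set of types to be deserialized."""

    _ALLOWED = {
        ("builtins", "dict"),
        ("builtins", "list"),
        ("builtins", "str"),
        ("builtins", "int"),
    }

    def find_class(self, module, name):
        if (module, name) in self._ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"Blocked unpickling of {module}.{name}")


def safe_loads(data: bytes):
    """Deserialize bytes from a trusted peer with a class whitelist."""
    return _TrustedUnpickler(io.BytesIO(data)).load()
```

Whether the extra restriction is worth it depends on whether the bytes ever cross a trust boundary; for purely in-cluster Ray traffic a documented rationale may be enough.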
@@ -17,11 +18,13 @@
 except ImportError:
     from cuda import cudart

+import ray
Missing Ray import protection
Ray is imported unconditionally, but it may not be installed in all environments. This could cause import failures for users who don't need Ray functionality.
Add import protection:
-import ray
+try:
+ import ray
+except ImportError:
+ ray = None
Then check if Ray is available when using Ray-specific features.
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/py_executor.py around line 21, avoid importing
ray unconditionally because it may not be installed; wrap the import in a
try/except ImportError and set a module-level flag (e.g., _has_ray = True/False)
or store the imported module in a variable (e.g., ray = None) so callers can
check availability; alternatively perform a lazy import inside Ray-specific
functions and raise a clear error if Ray is required but missing; update any
code paths that assume ray is present to guard on the flag or import result
before using Ray APIs.
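A rough sketch of that guard pattern; _require_ray and init_ray_if_needed are invented names for illustration, not existing helpers in this codebase:

```python
try:
    import ray
except ImportError:  # Ray is an optional dependency
    ray = None


def _require_ray() -> None:
    """Fail with a clear message when a Ray-only code path runs without Ray."""
    if ray is None:
        raise RuntimeError(
            "This feature requires Ray; install it with `pip install ray`.")


def init_ray_if_needed(**ray_init_args) -> None:
    _require_ray()
    if not ray.is_initialized():
        ray.init(ignore_reinit_error=True, **ray_init_args)
```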
@@ -153,10 +156,10 @@ def __init__(self,
     kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
     guided_decoder: Optional[GuidedDecoder] = None,
     garbage_collection_gen0_threshold: Optional[int] = None,
-    start_worker: bool = True):
+    start_worker: bool = True,
+    virtual_memory_pools: Optional[dict] = None):
🛠️ Refactor suggestion
Document the virtual_memory_pools parameter
The new virtual_memory_pools parameter lacks documentation about its purpose, expected structure, and lifecycle management.
Add documentation to the __init__ method:
"""
Args:
...
virtual_memory_pools: Optional dictionary of memory pools for virtual memory management.
Keys represent pool identifiers and values are pool objects that will be freed
on shutdown. Used for Ray-based distributed execution.
"""
Also applies to: 181-181, 320-324
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 160, 181 and
320-324, the new virtual_memory_pools parameter is undocumented; add a concise
docstring entry to the __init__ (and any other affected method docs) describing
that virtual_memory_pools is an optional dict mapping pool identifiers to pool
objects used for virtual memory management (e.g., Ray/distributed execution),
that values are pool objects which will be freed on shutdown, the expected
types/structure (keys as strings/ids, values as pool-like objects), and
lifecycle behavior (who owns and is responsible for freeing them). Ensure the
docstring follows existing style of Args: and includes the example note about
usage and shutdown responsibility.
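To make the lifecycle note concrete, here is a hedged sketch of the shutdown-time cleanup, assuming each pool object exposes some release method; the release() name is hypothetical and the PR's actual pool API may differ:

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def free_virtual_memory_pools(virtual_memory_pools: Optional[dict]) -> None:
    """Release each pool and drop the references; intended to run from shutdown()."""
    if not virtual_memory_pools:
        return
    for pool_id, pool in virtual_memory_pools.items():
        try:
            pool.release()  # hypothetical release hook; substitute the real pool API
        except Exception as exc:
            logger.warning("Failed to free virtual memory pool %s: %s", pool_id, exc)
    virtual_memory_pools.clear()
```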
@@ -196,6 +200,7 @@ def __init__(self,
     self.response_lock = threading.Lock()
     self.response_cv = threading.Condition(self.response_lock)
     self.responses = {}
+    self.result_wait_queues = {}
🛠️ Refactor suggestion
Clean up result_wait_queues on shutdown
The result_wait_queues dictionary is not cleaned up during shutdown, which could lead to memory leaks if Ray actor handles are retained.
Add cleanup in the shutdown method:
def shutdown(self):
"""
Signals the server to shutdown.
"""
self.executor_request_queue.enqueue_shutdown_request()
self.shutdown_event.wait()
self.worker_thread.join()
self.worker_started = False
+ # Clean up result wait queues
+ self.result_wait_queues.clear()
for manager in self.resource_manager.resource_managers.values():
Also applies to: 329-331, 428-428, 1689-1691
# (TODO: joyang) There are other types of responses, we need to sort out.
if type(
    resp
) == LlmResponse and req_id in self.result_wait_queues and self.result_wait_queues[
        req_id] is not None:
    self.result_wait_queues[req_id].put_response.remote(
        resp.client_id, resp)
Use isinstance() for type checking
Using type(resp) == LlmResponse is not idiomatic Python and won't work with subclasses.
-if type(
- resp
-) == LlmResponse and req_id in self.result_wait_queues and self.result_wait_queues[
- req_id] is not None:
+if isinstance(resp, LlmResponse) and req_id in self.result_wait_queues and self.result_wait_queues[
+ req_id] is not None:
🧰 Tools
🪛 Ruff (0.12.2)
1685-1687: Use is and is not for type comparisons, or isinstance() for isinstance checks (E721)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 1685 to 1691,
replace the exact-type check using type(resp) == LlmResponse with an isinstance
check so subclass instances are accepted; keep the existing guard for req_id in
self.result_wait_queues and non-None queue, and then call
put_response.remote(resp.client_id, resp) as before.
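A quick illustration of why the exact-type comparison trips up subclasses; the class names are invented for the example:

```python
class LlmResponse:
    pass


class StreamingLlmResponse(LlmResponse):
    pass


resp = StreamingLlmResponse()
print(type(resp) == LlmResponse)      # False: exact-type comparison ignores inheritance
print(isinstance(resp, LlmResponse))  # True: subclass instances are accepted
```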
self.result_wait_queues[req_id].put_response.remote(
    resp.client_id, resp)
Potential AttributeError when calling Ray actor method
The code calls put_response.remote() without error handling. If the Ray actor is dead or unreachable, this could raise an exception.
Add error handling:
-self.result_wait_queues[req_id].put_response.remote(
- resp.client_id, resp)
+try:
+ self.result_wait_queues[req_id].put_response.remote(
+ resp.client_id, resp)
+except Exception as e:
+ logger.error(f"Failed to send response to Ray queue for request {req_id}: {e}")
+ # Consider whether to re-raise or handle gracefully
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 1690-1691, the
call to self.result_wait_queues[req_id].put_response.remote(resp.client_id,
resp) is unprotected and can raise if the Ray actor is dead/unreachable; wrap
this remote invocation in a try/except that catches Exception, log the exception
with contextual data (req_id, resp.client_id, traceback) and the actor
reference, and handle the failure path (e.g., mark the request failed, remove or
recreate the queue/actor, or drop the response) so the process doesn’t crash
when the Ray actor call fails.
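One possible shape for the failure path described above, written as a standalone helper; the policy of dropping the dead queue handle is an assumption, not something the PR prescribes:

```python
import logging

logger = logging.getLogger(__name__)


def send_to_result_queue(result_wait_queues: dict, req_id, resp) -> None:
    """Best-effort push of a response to the per-request Ray queue actor."""
    queue = result_wait_queues.get(req_id)
    if queue is None:
        return
    try:
        queue.put_response.remote(resp.client_id, resp)
    except Exception as exc:
        logger.error(
            "Failed to send response for request %s (client %s) to its Ray queue: %s",
            req_id, resp.client_id, exc)
        # Drop the handle so later responses stop hitting a dead actor; a caller
        # could instead recreate the actor or mark the request as failed.
        result_wait_queues.pop(req_id, None)
```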
self.has_start_local_cluser = False
runtime_env = {
    "env_vars": {
        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"
    }
}

ray_init_args = {
    "include_dashboard": False,
    "namespace": "trtllm",
    "ignore_reinit_error": True,
    "runtime_env": runtime_env
}
try:
    ray.init(address="auto", **ray_init_args)
    logger.info(f"Attached to an existing Ray cluster.")
except ConnectionError:
    logger.info(f"Ray cluster not found, starting a new one.")

if not ray.is_initialized():
    ray.init(**ray_init_args)
    self.has_start_local_cluser = True
🛠️ Refactor suggestion
Typo in flag and fragile Ray attach/start logic.
- Typo: has_start_local_cluser → has_started_local_cluster.
- Catching only ConnectionError may miss Ray’s RaySystemError/RuntimeError. Make the fallback robust.
- self.has_start_local_cluser = False
+ self.has_started_local_cluster = False
@@
- try:
- ray.init(address="auto", **ray_init_args)
+ try:
+ ray.init(address="auto", **ray_init_args)
logger.info(f"Attached to an existing Ray cluster.")
- except ConnectionError:
+ except (ConnectionError, RuntimeError, ray.exceptions.RaySystemError):
logger.info(f"Ray cluster not found, starting a new one.")
@@
if not ray.is_initialized():
ray.init(**ray_init_args)
- self.has_start_local_cluser = True
+ self.has_started_local_cluster = True
@@
- if self.has_start_local_cluser:
+ if self.has_started_local_cluster:
logger.debug("Shutting down Ray cluster")
ray.shutdown()
Also applies to: 185-187
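The same attach-or-start flow as a standalone helper, with the broader exception handling and the renamed flag folded in; this is a sketch, not the PR's implementation:

```python
import logging

import ray
from ray.exceptions import RaySystemError

logger = logging.getLogger(__name__)


def attach_or_start_ray(ray_init_args: dict) -> bool:
    """Try to join an existing Ray cluster; otherwise start a local one.

    Returns True when this process started a new local cluster, so the caller
    knows it owns the eventual ray.shutdown().
    """
    try:
        ray.init(address="auto", **ray_init_args)
        logger.info("Attached to an existing Ray cluster.")
        return False
    except (ConnectionError, RuntimeError, RaySystemError):
        logger.info("Ray cluster not found, starting a new one.")
    if not ray.is_initialized():
        ray.init(**ray_init_args)
    return True
```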
torch.distributed.init_process_group(
    backend="cuda:nccl,cpu:gloo",
    init_method=f"tcp://{self.master_address}:{self.master_port}",
    world_size=world_size,
    rank=rank)
Fragile torch.distributed backend string; may raise at init_process_group.
backend="cuda:nccl,cpu:gloo" relies on the device-to-backend mapping syntax that only recent PyTorch releases accept, so older builds reject it at init_process_group. A single backend string is the portable choice: use "nccl" for CUDA, optionally falling back to "gloo" when CUDA is unavailable.
- torch.distributed.init_process_group(
- backend="cuda:nccl,cpu:gloo",
- init_method=f"tcp://{self.master_address}:{self.master_port}",
- world_size=world_size,
- rank=rank)
+ backend = "nccl" if torch.cuda.is_available() else "gloo"
+ torch.distributed.init_process_group(
+ backend=backend,
+ init_method=f"tcp://{self.master_address}:{self.master_port}",
+ world_size=world_size,
+ rank=rank,
+ )
🤖 Prompt for AI Agents
In tensorrt_llm/executor/ray_gpu_worker.py around lines 54 to 58, the backend
passed to torch.distributed.init_process_group is invalid
("cuda:nccl,cpu:gloo"); change it to use a single valid backend string by
selecting "nccl" when torch.cuda.is_available() is True and otherwise "gloo",
then pass that backend to init_process_group; ensure the rest of the init call
(init_method, world_size, rank) remains unchanged.
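For reference, the portable backend selection as a self-contained helper; this mirrors the snippet above as a sketch rather than reproducing the PR's code:

```python
import torch
import torch.distributed as dist


def init_distributed(master_address: str, master_port: int,
                     world_size: int, rank: int) -> None:
    """Initialize torch.distributed with a backend choice that works everywhere."""
    # NCCL needs a visible GPU; fall back to Gloo in CPU-only environments.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://{master_address}:{master_port}",
        world_size=world_size,
        rank=rank,
    )
```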
self._executor_config = executor_config
self._is_pytorch_backend = getattr(self._executor_config, "backend",
                                   None) == "pytorch"
assert self._is_pytorch_backend

if isinstance(engine, list):
    engine = engine[self.rank]

if executor_config is None:
    executor_config = tllm.ExecutorConfig(1)

executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
Executor config handling and backend assertion break non-PyTorch paths.
- If executor_config is None, you create a local instance but never assign it to self._executor_config.
- Asserting PyTorch backend contradicts the code paths that support tllm.Executor (C++ backend).
- Accessing self._executor_config.pytorch_backend_config without guards will fail when missing.
- self._executor_config = executor_config
- self._is_pytorch_backend = getattr(self._executor_config, "backend",
- None) == "pytorch"
- assert self._is_pytorch_backend
+ # Normalize executor config and detect backend
+ self._executor_config = executor_config or tllm.ExecutorConfig(1)
+ self._is_pytorch_backend = getattr(self._executor_config, "backend", None) == "pytorch"
@@
- if executor_config is None:
- executor_config = tllm.ExecutorConfig(1)
+ if executor_config is None:
+ executor_config = self._executor_config
@@
- is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
+ pytorch_cfg = getattr(self._executor_config, "pytorch_backend_config", None)
+ is_overlap_enabled = self._is_pytorch_backend and not getattr(
+ pytorch_cfg, "disable_overlap_scheduler", False
+ )
Also applies to: 169-193, 269-276
🤖 Prompt for AI Agents
In tensorrt_llm/executor/ray_gpu_worker.py around lines 126-137 (and also apply
same changes at 169-193 and 269-276), the code incorrectly asserts a PyTorch
backend, creates a local executor_config when None without assigning it to
self._executor_config, and unguardedly accesses pytorch_backend_config; change
to: if executor_config is None create a new tllm.ExecutorConfig and assign it to
self._executor_config (and update self._is_pytorch_backend after that), remove
the unconditional assert, compute self._is_pytorch_backend from
self._executor_config.backend, and wrap any accesses to
self._executor_config.pytorch_backend_config in conditional logic that only runs
when self._is_pytorch_backend is true (provide sensible fallbacks or raise clear
errors otherwise); apply identical fixes to the other mentioned line ranges.
WIP on final cleanups & refactoring.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.

run
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL): Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline plus the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill
kill
Kill all running builds associated with the pull request.

skip
skip --comment COMMENT
Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.