
[#4403][autodeploy] Refactor: Move more transformations to new inf optimizer, Add quantization_source to factory interface #6760


Open
Fridah-nv wants to merge 7 commits into main from user/fridah/merge-0801

Conversation

Fridah-nv
Collaborator

@Fridah-nv Fridah-nv commented Aug 8, 2025

Summary by CodeRabbit

  • New Features

    • Added MoE detection & fusion, TP/BMM/EP sharding, weight-loading transform, and a pluggable quantization-config reader; shared config exposes distributed/sharding settings.
  • Refactor

    • Unified transform APIs (SharedConfig/TransformInfo), split quantization into config- and graph-driven flows, and converted rope/transpose passes to standardized transform classes; optimizer-driven orchestration (InferenceOptimizer) replaces legacy multi-stage paths.
  • Tests

    • Tests migrated to export→optimize→validate using transformed GraphModules and the optimizer.
  • Chores

    • Cleaned imports/signatures, removed legacy direct transform calls, and updated configs.

Description

Test Coverage

TODO: add perf results after rebase

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous; skipping validation without care can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous; reusing results without care and validation can break the top of tree.

@Fridah-nv Fridah-nv requested a review from a team as a code owner August 8, 2025 16:25
@Fridah-nv Fridah-nv requested a review from suyoggupta August 8, 2025 16:25
Contributor

coderabbitai bot commented Aug 8, 2025

📝 Walkthrough

Introduces a SharedConfig-driven transform pipeline, adds a QuantConfigReader registry, splits quant transforms, implements MoE/RoPE/sharding/load-weights transforms, refactors many transforms to class-based plugins with SharedConfig in signatures, updates optimizer to supply shared config, and rewrites tests to use InferenceOptimizer on GraphModules.
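As a rough sketch of what a class-based plugin looks like under this interface (the transform name and body below are illustrative placeholders, not code from this PR; only the imports and signature follow what the PR describes):

from typing import Tuple

from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.interface import (
    BaseTransform,
    SharedConfig,
    TransformInfo,
    TransformRegistry,
)


@TransformRegistry.register("example_noop")  # placeholder name, not a real transform
class ExampleNoop(BaseTransform):
    """Illustrative transform: inspects the graph and reports zero matches."""

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        # Real transforms mutate gm.graph here; shared_config carries the
        # sharding_config plus local_rank/world_size for distributed passes.
        return gm, TransformInfo(skipped=False, num_matches=0)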

Changes

Cohort / File(s) Change Summary
Config
tensorrt_llm/_torch/auto_deploy/config/default.yaml
Reordered and expanded transforms: removed initial quantize/quantize_moe, added pattern_matcher entries (match_moe_pattern, match_rope_pattern, match_rope_layout, eliminate_redundant_transposes, optimize_rope, quantize_from_config, quantize_from_graph, quantize_moe), added sharding entries (detect_column_row_shard, detect_ep_shard, detect_dp_bmm_shard, sharding_transform_executor) and load_weights.
Quant config reader & HF model
New: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
tensorrt_llm/_torch/auto_deploy/models/hf.py
Added QuantConfigReader + registry and ModelOPT reader; AutoModel factories switch from JSON/dict parsing to reader-based flow (_quant_config_reader) and merge extra reader kwargs into model creation and cache dtype logic.
Quant transforms
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
Split Quantization into QuantizationFromConfig (quantize_from_config) and QuantizationFromGraph (quantize_from_graph); updated transforms to accept shared_config.
RoPE transforms
tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
Replaced free-function RoPE utilities with class-based transforms: MatchRopePattern, MatchRopeLayout (with config), and OptimizeRope; standardized _apply and TransformInfo reporting; removed old helper _move_node_before_first_user.
Eliminate redundant transposes
tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
New EliminateRedundantTransposes transform that finds and removes cancelling transpose pairs (preserving semantics via contiguous insertion) and runs dead-code elimination.
MoE transforms
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
New MatchMoePattern and FuseMoe transforms (plus helpers) to detect/fuse MoE subgraphs including quantized variants; registered as match_moe_pattern/fuse_moe.
Load weights transform
tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py
New LoadWeightsToDevice transform and LoadWeightsToDeviceConfig to load/init weights on specified device and move model + cached interface to that device.
Sharding transforms & utils
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
New sharding transform module (ColumnRowShard, DpBmmShard, DetectEpShard, ShardingTransformExecutor) using shared sharding config; sharding utils refactored to declarative ShardingConfig and per-transform TP/BMM/EP info classes; removed legacy detection/executor functions.
Transform interface & signatures
tensorrt_llm/_torch/auto_deploy/transform/interface.py, multiple transform/library/*.py
Added SharedConfig Pydantic model (sharding_config, local_rank, world_size). Propagated shared_config into BaseTransform.__call__ and _apply signatures; updated many transforms to accept it.
Optimizer changes & integration
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py, tensorrt_llm/_torch/auto_deploy/transformations/transform.py
InferenceOptimizer now initializes SharedConfig (local_rank/world_size, distributed-aware) and passes it to transforms; replaced legacy multi-stage inline transform/sharding flow with modular optimizer-driven invocation and minor pre-config hooks (e.g., attention/rope/load_weights propagation).
Package init cleanup
tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
Removed re-exports/imports of eliminate_redundant_transposes, rope, and sharding.
Removed legacy module
tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py
Deleted legacy function-based redundant-transpose transformer module (moved/rewritten under transform.library).
Tests updated
tests/... (quantization, rope, redundant_transposes, moe_fusion, tp/bmm/ep sharding, helpers)
Tests refactored to export models to GraphModules via torch_export_to_gm, apply transforms via InferenceOptimizer configs, and use run_test_transformed_gm; updated imports, config keys (e.g., "quantize" → "quantize_from_config"), and sharding detection to read from optimizer.shared_config.sharding_config.
Minor signature updates
multiple files (attention, build_model, cleanup_*, export_to_gm, quantize_moe, etc.)
Added shared_config parameter to many transform _apply methods and corresponding imports; removed or adjusted several legacy imports/usages.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant InferenceOptimizer
    participant TransformRegistry
    participant Transform (instance)
    participant GraphModule
    Caller->>InferenceOptimizer: provide GM + optimizer config
    InferenceOptimizer->>TransformRegistry: resolve transform names
    TransformRegistry-->>InferenceOptimizer: transform classes
    InferenceOptimizer->>Transform: instantiate & call transform(gm, cm, factory, shared_config)
    Transform->>GraphModule: inspect/mutate FX graph
    GraphModule-->>Transform: return modified GM + TransformInfo
    Transform-->>InferenceOptimizer: return GM + TransformInfo
    InferenceOptimizer-->>Caller: final transformed GM

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

Community want to contribute

Suggested reviewers

  • chzblych
  • pcastonguay
  • nv-guomingz
  • shaharmor98
  • Tabrizian
  • hyukn


@Fridah-nv Fridah-nv force-pushed the user/fridah/merge-0801 branch from 6ff930b to e4df900 on August 8, 2025 16:27
@Fridah-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #14623 [ run ] triggered by Bot

@Fridah-nv Fridah-nv changed the title from "[#4403][refactor] Move more transformations to new inf optimizer, Add quantization_source to factory interface" to "[#4403][autodeploy] Refactor: Move more transformations to new inf optimizer, Add quantization_source to factory interface" on Aug 8, 2025
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1)

100-131: Return-type and variable mismatch in from_file

  1. The return annotation is Tuple[ModelOPTQuantConfigReader, Optional[torch.dtype]], but extra_model_kwargs is a dict.
  2. You again insert the string "float16" instead of torch.float16.

Either make the second tuple element match the annotation or widen the annotation to Dict[str, Any].
Example fix:

-    ) -> Optional[Tuple["ModelOPTQuantConfigReader", Optional[torch.dtype]]]:
+    ) -> Optional[Tuple["ModelOPTQuantConfigReader", Dict[str, Any]]]:
...
-            extra_model_kwargs["torch_dtype"] = "float16"
+            extra_model_kwargs["torch_dtype"] = torch.float16
🧹 Nitpick comments (7)
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)

60-64: Attribute access assumes dataclass, not raw dict

self.ad_config.transforms["match_rope_layout"].expected_layout = … will fail if the YAML loader left the value as a plain dict.
Guard with isinstance or convert the entry to a config object before mutation, mirroring the pattern used for match_attention_layout.
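A minimal sketch of such a guard (the helper name and exact access pattern are assumptions for illustration, not the PR's code):

from typing import Any, Dict, Union


def set_expected_layout(entry: Union[Dict[str, Any], Any], expected_layout: str) -> None:
    """Set expected_layout whether the entry is a raw YAML dict or a parsed config object."""
    if isinstance(entry, dict):
        entry["expected_layout"] = expected_layout  # YAML loader left a plain dict
    else:
        setattr(entry, "expected_layout", expected_layout)  # parsed config object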

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py (1)

214-223: Redundant .to("cuda") after transform

InferenceOptimizer already retains the device placement from the exported GM; calling gm_transformed.to("cuda") in every branch is unnecessary noise in test code.

tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (1)

57-63: Type-hint doesn’t match actual content

nodes_to_eliminate: Set[Tuple[Node, Node]] is declared as 2-tuple set but a 3-tuple (t1, t2, bool) is inserted. Adjust the hint to Set[Tuple[Node, Node, bool]] or drop it.
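A minimal sketch of the corrected annotation (the initialization is shown for illustration only):

from typing import Set, Tuple

from torch.fx import Node

# 3-tuples of (t1, t2, flag), matching what the pass actually inserts
nodes_to_eliminate: Set[Tuple[Node, Node, bool]] = set()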

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

189-204: Iterating over live gm.graph.nodes while mutating

Both quantization passes insert new nodes and mutate node.target inside a for n in gm.graph.nodes: loop. Mutating while iterating can yield nondeterministic behaviour or skip nodes. Iterate over list(gm.graph.nodes) instead.
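A self-contained toy example of the safer snapshot-then-mutate pattern (the module and the no-op-add removal below are illustrative, not the PR's quantization code):

import operator

import torch
from torch import nn
from torch.fx import symbolic_trace


class Tiny(nn.Module):
    def forward(self, x):
        return torch.relu(x) + 0


gm = symbolic_trace(Tiny())

# Iterate over a snapshot of the node list; walking the live gm.graph.nodes
# while inserting/erasing nodes or retargeting them can skip or revisit nodes.
for n in list(gm.graph.nodes):
    if n.op == "call_function" and n.target is operator.add and n.args[1] == 0:
        n.replace_all_uses_with(n.args[0])
        gm.graph.erase_node(n)

gm.graph.lint()
gm.recompile()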

tensorrt_llm/_torch/auto_deploy/models/hf.py (1)

190-193: Map additional kv-cache dtypes or fail clearly

Only "float8_e4m3fn" is handled. If another dtype string appears the code silently falls back to None, later causing confusing errors downstream. Either extend the mapping or raise with a clear message.

tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (1)

240-285: need_transpose never used outside loop

The flag is computed but never referenced after the first continue, making the variable redundant and the logic slightly confusing.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

44-45: Quantize MoE ordering: confirm it runs after the chosen quantization pass

Moving quantize_moe to the end makes sense; ensure it doesn’t conflict with whichever primary quantization path was taken (config or graph) and that it doesn’t re-quantize already processed nodes.

  • If the pipeline supports dependencies, consider declaring them to lock ordering:
   quantize_moe:
     stage: pattern_matcher
+    depends_on:
+      - quantize_from_config
+      - quantize_from_graph
  • Add/confirm tests for MoE models across both quantization_source modes to catch accidental double-quantization. I can help generate those if useful.

Add a brief comment explaining why quantize_moe is intentionally last to aid maintainers.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d45236b and e4df900.

📒 Files selected for processing (12)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py (7 hunks)
💤 Files with no reviewable changes (2)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
🧠 Learnings (6)
📚 Learning: 2025-08-08T04:10:18.987Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6728
File: cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp:966-966
Timestamp: 2025-08-08T04:10:18.987Z
Learning: TensorRT plugins currently don't support padding functionality, and TensorRT is not getting new features (in maintenance mode). This means that duplicating parameters like mExpertHiddenSize in function calls, even with TODO comments, can be acceptable as pragmatic solutions within these constraints.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
📚 Learning: 2025-07-22T08:33:49.109Z
Learnt from: yiqingy0
PR: NVIDIA/TensorRT-LLM#5198
File: jenkins/mergeWaiveList.py:0-0
Timestamp: 2025-07-22T08:33:49.109Z
Learning: In the TensorRT-LLM waive list merging system, removed lines are always located at the end of the merge waive lists, which is why the mergeWaiveList.py script uses reverse traversal - it's an optimization for this specific domain constraint.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
📚 Learning: 2025-08-01T15:14:45.673Z
Learnt from: yibinl-nvidia
PR: NVIDIA/TensorRT-LLM#6506
File: examples/models/core/mixtral/requirements.txt:3-3
Timestamp: 2025-08-01T15:14:45.673Z
Learning: In TensorRT-LLM, examples directory can have different dependency versions than the root requirements.txt file. Version conflicts between root and examples dependencies are acceptable because examples are designed to be standalone and self-contained.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
📚 Learning: 2025-07-22T09:22:14.726Z
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
🧬 Code Graph Analysis (1)
tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (4)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (3)
  • BaseTransform (129-354)
  • TransformInfo (98-123)
  • TransformRegistry (357-385)
tensorrt_llm/graph_rewriting.py (1)
  • replace_input_with (373-392)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (6)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (1)

59-60: Double-check permanent skip

pytest.skip(...) unconditionally disables the test – consider marking it xfail or gating with a flag so coverage isn’t permanently lost once the linked bug is fixed.
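A sketch of the xfail alternative (test name and reason text are placeholders):

import pytest


@pytest.mark.xfail(reason="tracked bug in the quantized path; remove marker once fixed", strict=False)
def test_quantization_case():
    ...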

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (1)

10-15: Ensure helper module still ships after refactor

_is_contiguous_op / _is_transpose_op are imported from transform.library.eliminate_redundant_transposes.
That module was removed from the main pipeline; if the file itself was deleted, these imports will break the tests.
Please confirm the helpers remain, or relocate them to a stable utilities module.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (4)

30-33: RoPE matchers: confirm ordering relative to attention layout

Adding match_rope_pattern and match_rope_layout looks fine. Please confirm the intended order w.r.t. match_attention_layout (Lines 28-29). If attention layout normalization is a prerequisite for RoPE matching (or vice versa), encode that dependency or document the assumption to avoid fragile pattern misses across models.

If the pipeline supports explicit dependencies (e.g., depends_on), consider declaring them. Otherwise, verify with representative models (llama, qwen2, mistral) that both rope matchers trigger as expected with the current ordering.


34-35: No action needed: eliminate_redundant_transposes is properly registered

  • Registered in tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py at line 44 via @TransformRegistry.register("eliminate_redundant_transposes").

Pipeline lookup will succeed.


36-39: The tests spread across the repository tell a clear story of a rapidly evolving LLM-and-Torch ecosystem with several unfinished pieces and temporary workarounds:

• Core LLM tests (in tests/unittest/llmapi and test_llm.py) cover argument parsing, API apps (chat, metrics, structural tagging, speculative decoding), executor behavior (streaming outputs, error handling), and integration w/ OpenAI-style endpoints. Many tests are marked TODO or skipped pending upstream support (larger batch sizes, TP sharding v2, GPT Plugin attention, mass integration).
• Config-level gating is exercised (e.g. PeftCacheConfig, conditional flags in YAML), with tests ensuring defaults when flags are missing.
• Performance and memory-management features in the _torch subfolder (auto-tuner tactics, executor request queuing, tensor sharing, CUDA graphs, multi-modal KV cache reuse, MOE host-sharer, FP4 quantization, FP16/FP4 mixed routines) are unit-tested for correct path selection and no memory leaks.
• TRT integration tests (test_gpt.py) and quantization mode tests validate patterns and plugin behavior.
• A handful of tests manage expected failures or limits (too-small buffer errors, OOM cases) and skip lists (waives.txt) manage flaky or unsupported combinations.
• There’s a recurring pattern of “disable until X is ready” and “TODO: enable once Y lands,” indicating active development on sharding, plugins, and new backends.

Overall, the suite provides broad coverage of core functionality but carries many temporary skips and TODOs that track incomplete optimizations, upcoming feature support, and infrastructure limitations—signaling where engineering effort is still focused.


40-43: Verify gating logic for quantize_from_config vs quantize_from_graph

I couldn’t find any reference to quantization_source in the codebase, so both transforms will always run in the auto-deploy pipeline. Please ensure only one path is taken:

  • Check tensorrt_llm/_torch/auto_deploy/config/default.yaml:
    • Currently both entries are unguarded:
        quantize_from_config:
          stage: pattern_matcher
        quantize_from_graph:
          stage: pattern_matcher
      
  • If the PR introduces a quantization_source flag, confirm where it’s consumed (none found via rg -n "quantization_source").
  • To avoid double application, either:
    • Add conditional fields in default.yaml (if supported), e.g.:
         quantize_from_config:
           stage: pattern_matcher
      +    enabled_if: ${llm_args.quantization_source == "config"}
      
         quantize_from_graph:
           stage: pattern_matcher
      +    enabled_if: ${llm_args.quantization_source == "graph"}
    • Or implement gating inside the QuantizationFromConfig and QuantizationFromGraph classes in
      tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py based on the source.

Without verifying that logic exists elsewhere, please confirm and adjust so only the intended quantization transform is applied.

@tensorrt-cicd
Collaborator

PR_Github #14623 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11047 completed with status: 'SUCCESS'

@Fridah-nv Fridah-nv force-pushed the user/fridah/merge-0801 branch from e4df900 to 50199c5 on August 9, 2025 01:43
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

85-93: Return success flag from _insert_quantized_bmm

To support accurate match counting, have _insert_quantized_bmm return True when it rewrites the node, False otherwise.

 def _insert_quantized_bmm(
     gm: GraphModule,
     node: Node,
     quantization_impl: QuantizationImpl,
     is_quantized_graph: bool = False,
 ):
     """Replaces the bmm node with a new quantized bmm node."""
@@
     else:
         # If we can't determine the shape, skip quantization
-        return
+        return False
@@
     node.args = (*node.args, *scale_values)
+    return True

Also applies to: 141-166

♻️ Duplicate comments (1)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

177-183: Guard against factory=None and avoid AttributeError on get_quant_config()

Calling factory.get_quant_config() unconditionally will raise when factory is None (common in unit tests). Adjust signature and add an early return.

Apply:

-from typing import Tuple
+from typing import Tuple, Optional
@@
-    def _apply(
-        self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
-    ) -> Tuple[GraphModule, TransformInfo]:
-        quant_config = factory.get_quant_config()
+    def _apply(
+        self, gm: GraphModule, cm: CachedSequenceInterface, factory: Optional[ModelFactory]
+    ) -> Tuple[GraphModule, TransformInfo]:
+        if factory is None:
+            return gm, TransformInfo(skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True)
+        quant_config = factory.get_quant_config()
         quant_algo = quant_config.get("quant_algo")
         excluded_patterns = quant_config.get("exclude_modules", [])
+        quant_config = factory.get_quant_config()
🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3)

1-3: Missing NVIDIA copyright header

Per repository guidelines, all production .py files must include an NVIDIA copyright header with the current year.

Please add the standard header at the top of this file for 2025.


212-255: Quantize-from-graph pass: looks good, minor hygiene suggestion

Logic correctly short-circuits on non-quantized graphs and removes output quantizers after fusion.

Optionally call gm.graph.lint() and gm.recompile() after mutations if the pipeline doesn’t already do this centrally.
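A tiny self-contained illustration of that hygiene step (toy module for demonstration only):

import torch
from torch import nn
from torch.fx import symbolic_trace

gm = symbolic_trace(nn.Linear(4, 4))
# ... graph mutations would happen here ...
gm.graph.lint()  # check graph invariants after edits
gm.recompile()   # regenerate forward() from the mutated graph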


9-22: Import style deviates from project guideline: prefer module namespace imports

Guidelines specify “Always maintain the namespace when importing in Python.” Current code imports individual functions from node_utils and quantization_utils.

Prefer:

from ...utils import node_utils, quantization_utils as q_utils

# usage:
node_utils.is_bmm_op(...)
q_utils.is_quantized_graph(...)

This improves discoverability and aligns with repo conventions.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e4df900 and 50199c5.

📒 Files selected for processing (12)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py (7 hunks)
💤 Files with no reviewable changes (2)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py
🚧 Files skipped from review as they are similar to previous changes (8)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
🧰 Additional context used
🧠 Learnings (5)
📚 Learning: 2025-08-08T04:10:18.987Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6728
File: cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp:966-966
Timestamp: 2025-08-08T04:10:18.987Z
Learning: TensorRT plugins currently don't support padding functionality, and TensorRT is not getting new features (in maintenance mode). This means that duplicating parameters like mExpertHiddenSize in function calls, even with TODO comments, can be acceptable as pragmatic solutions within these constraints.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
📚 Learning: 2025-07-22T08:33:49.109Z
Learnt from: yiqingy0
PR: NVIDIA/TensorRT-LLM#5198
File: jenkins/mergeWaiveList.py:0-0
Timestamp: 2025-07-22T08:33:49.109Z
Learning: In the TensorRT-LLM waive list merging system, removed lines are always located at the end of the merge waive lists, which is why the mergeWaiveList.py script uses reverse traversal - it's an optimization for this specific domain constraint.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
📚 Learning: 2025-08-01T15:14:45.673Z
Learnt from: yibinl-nvidia
PR: NVIDIA/TensorRT-LLM#6506
File: examples/models/core/mixtral/requirements.txt:3-3
Timestamp: 2025-08-01T15:14:45.673Z
Learning: In TensorRT-LLM, examples directory can have different dependency versions than the root requirements.txt file. Version conflicts between root and examples dependencies are acceptable because examples are designed to be standalone and self-contained.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
🔇 Additional comments (1)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

30-45: Confirmed: All newly added transforms are registered

Each of the seven transforms from default.yaml has a corresponding @TransformRegistry.register(...) entry:

  • match_rope_pattern – rope.py line 124
  • match_rope_layout – rope.py line 227
  • eliminate_redundant_transposes – eliminate_redundant_transposes.py line 44
  • optimize_rope – rope.py line 343
  • quantize_from_config – quantization.py line 168
  • quantize_from_graph – quantization.py line 212
  • quantize_moe – quantize_moe.py line 134

No missing registrations were detected; the optimizer pipeline will resolve these transforms at runtime as expected.

@Fridah-nv Fridah-nv force-pushed the user/fridah/merge-0801 branch from 50199c5 to 30cd46e on August 9, 2025 02:30
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

🔭 Outside diff range comments (19)
tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py (2)

1-10: Add NVIDIA header and switch to namespace-preserving imports per guidelines

  • Missing required NVIDIA copyright header (current year).
  • Coding guidelines require keeping module namespace in imports; avoid importing classes directly.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+
 from typing import Tuple
 
 import torch
 from torch.fx import GraphModule
 
-from ...models.factory import ModelFactory
-from ...shim.interface import CachedSequenceInterface
-from ...utils.node_utils import is_op
-from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
+from ... import models
+from ...shim import interface as shim_ifc
+from ...utils import node_utils
+from .. import interface as t_ifc

And update references below:

-@TransformRegistry.register("cleanup_noop_slice")
-class CleanupNoopSlice(BaseTransform):
+@t_ifc.TransformRegistry.register("cleanup_noop_slice")
+class CleanupNoopSlice(t_ifc.BaseTransform):
@@
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
+        gm: GraphModule,
+        cm: shim_ifc.CachedSequenceInterface,
+        factory: models.factory.ModelFactory,
+        shared_config: t_ifc.SharedConfig,
+    ) -> Tuple[GraphModule, t_ifc.TransformInfo]:
@@
-            if not is_op(node, torch.ops.aten.slice):
+            if not node_utils.is_op(node, torch.ops.aten.slice):
@@
-        info = TransformInfo(skipped=False, num_matches=num_matches)
+        info = t_ifc.TransformInfo(skipped=False, num_matches=num_matches)

29-48: Avoid mutating the graph while iterating; recompile after edits

  • Erasing nodes while iterating gm.graph.nodes can skip nodes unpredictably.
  • After structural edits, recompile to keep codegen in sync.
-        for node in gm.graph.nodes:
+        for node in list(gm.graph.nodes):
@@
             node.replace_all_uses_with(in_node)
             gm.graph.erase_node(node)
             num_matches += 1
@@
+        # ensure graph invariants after mutation
+        gm.recompile()
         # store info object about the transform
tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py (2)

1-10: Add NVIDIA header and switch to namespace-preserving imports per guidelines

  • Missing mandatory NVIDIA copyright header.
  • Replace direct symbol imports with module namespace imports.
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+
 from typing import Tuple
 
 import torch
 from torch.fx import GraphModule
 
-from ...models.factory import ModelFactory
-from ...shim.interface import CachedSequenceInterface
-from ...utils.node_utils import is_op
-from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
+from ... import models
+from ...shim import interface as shim_ifc
+from ...utils import node_utils
+from .. import interface as t_ifc

Update references accordingly:

-@TransformRegistry.register("cleanup_noop_add")
-class CleanupNoopAdd(BaseTransform):
+@t_ifc.TransformRegistry.register("cleanup_noop_add")
+class CleanupNoopAdd(t_ifc.BaseTransform):
@@
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
+        gm: GraphModule,
+        cm: shim_ifc.CachedSequenceInterface,
+        factory: models.factory.ModelFactory,
+        shared_config: t_ifc.SharedConfig,
+    ) -> Tuple[GraphModule, t_ifc.TransformInfo]:
@@
-            if not is_op(node, torch.ops.aten.add):
+            if not node_utils.is_op(node, torch.ops.aten.add):
@@
-        info = TransformInfo(skipped=False, num_matches=num_matches)
+        info = t_ifc.TransformInfo(skipped=False, num_matches=num_matches)

33-51: Guard against dtype/device promotion and shape-broadcast side effects; clean dangling nodes

Removing an add can change dtype (type promotion), device, or shape due to broadcasting. Only eliminate when:

  • zero tensor dtype == true tensor dtype
  • zero tensor device == true tensor device
  • broadcasted shape equals true tensor shape

Also remove zero_node if it becomes unused. Iterate over a copy of nodes and recompile.

-        for node in gm.graph.nodes:
+        for node in list(gm.graph.nodes):
@@
-            if is_op(node.all_input_nodes[0], torch.ops.aten.zeros):
+            if node_utils.is_op(node.all_input_nodes[0], torch.ops.aten.zeros):
                 zero_node, true_node = node.all_input_nodes
-            elif is_op(node.all_input_nodes[1], torch.ops.aten.zeros):
+            elif node_utils.is_op(node.all_input_nodes[1], torch.ops.aten.zeros):
                 true_node, zero_node = node.all_input_nodes
             else:
                 continue
@@
-            # do the replacement and clean-up
-            node.replace_all_uses_with(true_node)
-            gm.graph.erase_node(node)
+            # verify dtype/device/shape safety using meta
+            true_val = (getattr(true_node, "meta", {}) or {}).get("val", None)
+            zero_val = (getattr(zero_node, "meta", {}) or {}).get("val", None)
+            if true_val is None or zero_val is None:
+                continue  # skip if we cannot prove safety
+            if getattr(true_val, "dtype", None) != getattr(zero_val, "dtype", None):
+                continue
+            if getattr(true_val, "device", None) != getattr(zero_val, "device", None):
+                continue
+            # check broadcast shape does not change result shape
+            try:
+                import torch as _t
+                out_shape = _t.broadcast_shapes(true_val.shape, zero_val.shape)
+            except Exception:
+                continue
+            if out_shape != true_val.shape:
+                continue
+
+            # safe to remove
+            node.replace_all_uses_with(true_node)
+            gm.graph.erase_node(node)
+            # optionally remove the zero producer if dead
+            if len(zero_node.users) == 0:
+                gm.graph.erase_node(zero_node)
             num_matches += 1
@@
-        # store info object about the transform
+        gm.recompile()
+        # store info object about the transform
tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py (1)

1-16: Add NVIDIA header and use namespace-preserving imports

Comply with repository import and header guidelines.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+
 """A simple wrapper transform to build a model via the model factory."""
 
 from typing import Tuple, Type
 
 from pydantic import Field
 from torch.fx import GraphModule
 
-from ...models.factory import ModelFactory
-from ...shim.interface import CachedSequenceInterface
-from ..interface import (
-    BaseTransform,
-    SharedConfig,
-    TransformConfig,
-    TransformInfo,
-    TransformRegistry,
-)
+from ... import models
+from ...shim import interface as shim_ifc
+from .. import interface as t_ifc

Update references below:

-@TransformRegistry.register("build_model")
-class BuildModel(BaseTransform):
+@t_ifc.TransformRegistry.register("build_model")
+class BuildModel(t_ifc.BaseTransform):
@@
-    def get_config_class(cls) -> Type[TransformConfig]:
-        return BuildModelConfig
+    def get_config_class(cls) -> Type[t_ifc.TransformConfig]:
+        return BuildModelConfig
@@
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
+        gm: GraphModule,
+        cm: shim_ifc.CachedSequenceInterface,
+        factory: models.factory.ModelFactory,
+        shared_config: t_ifc.SharedConfig,
+    ) -> Tuple[GraphModule, t_ifc.TransformInfo]:
@@
-        info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
+        info = t_ifc.TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (1)

1-17: Add NVIDIA header and use namespace-preserving imports

Align with repository standards.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+
 """A simple wrapper transform to export a model to a graph module."""
 
 from typing import List, Optional, Tuple, Type
 
 from pydantic import Field
 from torch.fx import GraphModule
 
-from ...export import torch_export_to_gm
-from ...models.factory import ModelFactory
-from ...shim.interface import CachedSequenceInterface
-from ..interface import (
-    BaseTransform,
-    SharedConfig,
-    TransformConfig,
-    TransformInfo,
-    TransformRegistry,
-)
+from ... import export as trt_export
+from ... import models
+from ...shim import interface as shim_ifc
+from .. import interface as t_ifc

Update references:

-@TransformRegistry.register("export_to_gm")
-class ExportToGM(BaseTransform):
+@t_ifc.TransformRegistry.register("export_to_gm")
+class ExportToGM(t_ifc.BaseTransform):
@@
-    def get_config_class(cls) -> Type[TransformConfig]:
-        return ExportToGMConfig
+    def get_config_class(cls) -> Type[t_ifc.TransformConfig]:
+        return ExportToGMConfig
tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py (1)

1-11: Add NVIDIA header and switch to namespace-preserving imports

Comply with header and import style guidelines.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+
 import math
 from typing import List, Tuple
 
 import torch
 from torch.fx import Graph, GraphModule
 from torch.utils._sympy.value_ranges import ValueRanges
 
-from ...models.factory import ModelFactory
-from ...shim.interface import CachedSequenceInterface
-from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
+from ... import models
+from ...shim import interface as shim_ifc
+from .. import interface as t_ifc

Update references:

-@TransformRegistry.register("cleanup_input_constraints")
-class CleanupInputConstraints(BaseTransform):
+@t_ifc.TransformRegistry.register("cleanup_input_constraints")
+class CleanupInputConstraints(t_ifc.BaseTransform):
tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (5)

1-1: Add NVIDIA copyright header (2025) to comply with OSS policy

Source files must include the NVIDIA copyright header with the current year.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

52-52: Use public load-state-dict pre-hook API

Avoid private _register_load_state_dict_pre_hook; use the public register_load_state_dict_pre_hook to reduce breakage risk across PyTorch versions.

-            gm._register_load_state_dict_pre_hook(partial(quant_impl.load_hook, weight_name=name))
+            gm.register_load_state_dict_pre_hook(partial(quant_impl.load_hook, weight_name=name))

91-98: Preserve node metadata and maintain graph hygiene during replacement

When replacing nodes, copy meta to retain dtype/shape info. After mutating the graph, lint and recompile to avoid stale code.

     with gm.graph.inserting_after(node):
         new_node = gm.graph.call_function(
             quantized_op,
             args=tuple(args),
         )
+        # Preserve metadata for downstream passes
+        try:
+            new_node.meta = dict(node.meta)
+        except Exception:
+            pass
         node.replace_all_uses_with(new_node)
         gm.graph.erase_node(node)

Additionally, recompile once at the end of the transform (see suggestion on Lines 179-183).


157-159: Guard against unknown quantization algorithms

Indexing quantized_moe_op_map[quant_algo] raises KeyError for unsupported values. Validate and raise a clear error.

-        quantized_op = quantized_moe_op_map[quant_algo]
+        if quant_algo not in quantized_moe_op_map:
+            raise ValueError(f"Unsupported quantization algorithm: {quant_algo}. "
+                             f"Supported: {list(quantized_moe_op_map.keys())}")
+        quantized_op = quantized_moe_op_map[quant_algo]

179-183: TransformInfo flags and recompile

  • Shapes are unchanged by quantization; set has_valid_shapes=True.
  • Since the graph was mutated, set is_clean=False (correct).
  • Recompile/lint after modifications to ensure executable GM.
-        info = TransformInfo(
-            skipped=False, num_matches=count, is_clean=False, has_valid_shapes=False
-        )
+        # Keep shapes valid after quantization
+        info = TransformInfo(
+            skipped=False, num_matches=count, is_clean=False, has_valid_shapes=True
+        )
+        # Ensure consistency after graph edits
+        try:
+            gm.graph.lint()
+        finally:
+            gm.recompile()
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)

1-1: Add NVIDIA copyright header (2025)

Production Python files require the NVIDIA copyright header.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 """
 High-level entrypoint to transform a model into an efficient inference model.
 """
tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py (1)

39-55: Handle None factory/cm and add docstring for _apply

  • If factory or cm can be None (as in tests), this will raise. Guard with explicit errors.
  • Add a short Google-style docstring.
     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
+        """
+        Load model weights and move model/CM to the target device.
+
+        Args:
+            gm: GraphModule whose modules/parameters will be materialized.
+            cm: CachedSequenceInterface to be moved to the target device.
+            factory: ModelFactory providing load-or-init functionality.
+            shared_config: Shared distributed/sharding config (unused here).
+
+        Returns:
+            Tuple of (updated GraphModule, TransformInfo).
+        """
+        if factory is None:
+            raise ValueError("LoadWeightsToDevice requires a non-None ModelFactory.")
+        if cm is None:
+            raise ValueError("LoadWeightsToDevice requires a non-None CachedSequenceInterface.")
         factory.load_or_random_init(
             gm, device=self.config.adconfig_checkpoint_device or self.config.device
         )
         move_to_device(gm, self.config.device)
         cm.to(self.config.device)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)

1-4: Missing NVIDIA copyright header

Per repository guidelines, add the NVIDIA copyright header (current year) to this source file.

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+"""The interface for all transforms.
tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (1)

1-1: Missing NVIDIA copyright header

Add the standard NVIDIA header (current year) to this source file.

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+#
 """Pattern matching for detecting repeat_kv, eager, grouped attention patterns from Huggingface models."""
tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (1)

450-475: Potential dead node: original RoPE node not erased after replacing its outputs

_optimize_explicit replaces users of the tuple outputs but only erases the getitem nodes, leaving the original rope node potentially dead. Clean it up to avoid graph bloat.

     q_rope_old.replace_all_uses_with(q_rope_new)
     k_rope_old.replace_all_uses_with(k_rope_new)

     graph.erase_node(q_rope_old)
     graph.erase_node(k_rope_old)
+    # If the original rope node is now unused, erase it as well
+    if len(node.users) == 0:
+        graph.erase_node(node)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

137-143: Erasing bias node without checking node type may crash

node_bias might be None or non-Node. Guard before erase_node.

-        gm.graph.erase_node(node_bias)
+        if isinstance(node_bias, Node):
+            gm.graph.erase_node(node_bias)

79-87: Clamp max_split_size to ≥ 1 in split_tensor to avoid zero-division

In tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py inside the split_tensor function, if

max_split_size = t.shape[d] // min_d_shape

evaluates to 0, you’ll get a ZeroDivisionError in

num_groups = math.ceil(ws / max_split_size)

or an invalid tensor_split(..., 0, ...) call.

Please update it as follows and keep the existing rank-to-chunk mapping (r // num_groups), which already distributes ranks correctly:

         # The local tensor shape has to be divisible by min_d_shape
-        max_split_size = t.shape[d] // min_d_shape
+        # avoid zero or negative splits
+        min_d_shape = max(1, min_d_shape)
+        max_split_size = max(1, t.shape[d] // min_d_shape)
♻️ Duplicate comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (1)

98-112: Fix: self-loop and global use-rewrite corrupt the graph

Replacing all uses of original_input with a newly created contiguous node makes the contiguous node self-referential and rewires unrelated users of original_input. Only rewire users of the redundant second transpose.

-        for t_node, t_comp_node, has_contiguous in nodes_to_eliminate:
-            # Replace all uses of the second transpose with the input to the first transpose
-            original_input = t_node.args[0]
-            t_comp_node.replace_all_uses_with(original_input)
-
-            # if there is a contiguous operation that we skipped, let add it after t_comp_node as new
-            # graph node that call contiguous on t_comp_node
-            if has_contiguous:
-                with graph.inserting_after(original_input):
-                    new_contiguous_node = graph.call_function(
-                        torch.ops.aten.contiguous.default, args=(original_input,)
-                    )
-                original_input.replace_all_uses_with(new_contiguous_node)
-                new_contiguous_node.replace_input_with(new_contiguous_node, original_input)
+        for t_node, t_comp_node, has_contiguous in nodes_to_eliminate:
+            # Replace uses of the second transpose with either the original input
+            # or a new contiguous(original_input) if we skipped contiguous nodes.
+            original_input = t_node.args[0]
+            replacement = original_input
+            if has_contiguous:
+                with graph.inserting_after(original_input):
+                    replacement = graph.call_function(
+                        torch.ops.aten.contiguous.default, args=(original_input,)
+                    )
+            t_comp_node.replace_all_uses_with(replacement)
tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (1)

257-260: Contract violation: _apply must return (gm, TransformInfo); early return breaks optimizer

When expected_layout is invalid, this returns None, violating the interface and crashing the pipeline. Return a skipped TransformInfo or raise a ValueError.

Apply:

-        if self.config.expected_layout.lower() not in supported:
-            return
+        if self.config.expected_layout.lower() not in supported:
+            info = TransformInfo(skipped=True, num_matches=0, is_clean=False, has_valid_shapes=False)
+            return gm, info
🧹 Nitpick comments (30)
tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py (2)

21-27: Interface alignment good; explicitly mark unused shared_config and add Google-style docstring

  • The new parameter aligns with BaseTransform._apply.
  • It’s unused; consider explicitly acknowledging to satisfy linters and future readers.
     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
+        """Remove no-op aten.slice nodes like t[:][:5] where the first slice is a no-op.
+
+        Args:
+            gm: Graph to mutate in-place.
+            cm: Cached sequence interface (unused).
+            factory: ModelFactory (unused).
+            shared_config: Shared transform config (unused).
+        """
+        # explicitly mark unused
+        _ = (cm, factory, shared_config)

35-42: Broaden slice no-op detection to handle step and alt end encodings

The current check is brittle to PyTorch version-specific IR. Consider:

  • Accept 4 or 5 args; if 5, require step == 1 (default no-op).
  • Accept end encodings like sys.maxsize or None (depending on exporter).
-            # 4 args will be (input, dim, start, end)
-            if len(node.args) != 4 or len(node.kwargs) != 0:
+            # aten.slice can be (input, dim, start, end[, step])
+            if len(node.args) not in (4, 5) or len(node.kwargs) != 0:
                 continue
@@
-            if node.args[2] != 0 or node.args[3] != torch.iinfo(torch.long).max:
+            start = node.args[2]
+            end = node.args[3]
+            step = 1 if len(node.args) == 4 else node.args[4]
+            # common encodings for full range ends
+            full_ends = {torch.iinfo(torch.long).max}
+            try:
+                import sys
+                full_ends.add(sys.maxsize)
+            except Exception:
+                pass
+            if start != 0 or step != 1 or end not in full_ends:
                 continue
tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py (1)

24-30: Interface addition looks good; mark unused shared_config and document args

Add a short docstring and explicitly mark unused params.

     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
+        """Remove add-with-zero no-ops when safe (dtype/device/shape preserved)."""
+        _ = (cm, factory, shared_config)  # unused
tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py (1)

35-41: Leverage SharedConfig.local_rank for device selection or mark as unused; avoid name collision

  • If device is GPU-backed, select rank-local device automatically.
  • Ensure submodule name is unique.
     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        # build the model
-        model = factory.build_model(self.config.device)
+        """Build the model via the factory on the configured device."""
+        device = self.config.device
+        # optional: auto-select rank device if configured for CUDA
+        if device in ("cuda", "auto"):
+            device = f"cuda:{shared_config.local_rank}"
+        # build the model
+        model = factory.build_model(device)
@@
-        gm.add_module("factory_model", model)
+        name = "factory_model"
+        if hasattr(gm, name):
+            idx = 1
+            while hasattr(gm, f"{name}_{idx}"):
+                idx += 1
+            name = f"{name}_{idx}"
+        gm.add_module(name, model)

If you prefer to keep behavior unchanged, add:

+        _ = (cm, shared_config)  # unused
tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (1)

52-58: Interface addition OK; mark unused shared_config and preserve namespace references

     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
+        _ = (factory, shared_config)  # unused

Also update namespace use at call site:

-        gm = torch_export_to_gm(
+        gm = trt_export.torch_export_to_gm(
tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (2)

12-12: Prefer namespace imports per project guidelines

The coding guidelines require maintaining module namespace on imports. Consider importing the module and qualifying symbols.

Example:

-from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
+from .. import interface as t_intf

And update references:

  • BaseTransform -> t_intf.BaseTransform
  • SharedConfig -> t_intf.SharedConfig
  • TransformInfo -> t_intf.TransformInfo
  • TransformRegistry -> t_intf.TransformRegistry

141-147: Unused parameter: shared_config

shared_config is unused here. Either use it or acknowledge explicitly to avoid linter warnings.

-        shared_config: SharedConfig,
+        shared_config: SharedConfig,  # unused
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)

26-31: Distributed rank/world size initialization timing

Capturing rank/world_size at construction can be stale if dist.init_process_group() occurs later. Consider refreshing at call time or lazily updating shared_config right before running transforms.

Option A (refresh in call):

  • If not dist.is_initialized(): keep (0,1), else query and update self.shared_config.

Option B (lazy property):

  • Compute on first __call__, then cache.
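
A minimal sketch of Option A, assuming the optimizer holds a shared_config with local_rank/world_size fields (the helper name is illustrative):

import torch.distributed as dist


def refresh_distributed_context(shared_config) -> None:
    """Re-query rank/world size right before transforms run (Option A)."""
    if dist.is_available() and dist.is_initialized():
        shared_config.local_rank = dist.get_rank()
        shared_config.world_size = dist.get_world_size()
    else:
        # keep single-process defaults when no process group exists
        shared_config.local_rank = 0
        shared_config.world_size = 1
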
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (1)

303-313: Passing None for factory and cm relies on transforms tolerating it

Current InferenceOptimizer.__call__ types cm as non-optional. Either update its annotation to Optional (preferred), or pass a lightweight stub CachedSequenceInterface in tests.

Would you like me to provide a minimal DummyCM for tests or open a follow-up to adjust the type annotation?
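
For reference, a minimal test-only stub along those lines, assuming transforms only ever call .to(device) on the interface (class and method names are illustrative):

class DummyCM:
    """Lightweight stand-in for CachedSequenceInterface in unit tests."""

    def __init__(self, device: str = "cpu"):
        self.device = device

    def to(self, device: str) -> "DummyCM":
        # mimic the only call site the transforms exercise
        self.device = device
        return self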

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

147-159: Order of same-stage transforms relies on insertion order

Both transforms are "sharding" stage. Sorting by stage preserves insertion order (stable sort), so detect_* will run before executor. If the config were built differently, this could invert the order. Keep them in the given dict order.
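
A tiny illustration of the stable-sort behavior (transform names are placeholders):

transforms = {"detect_sharding": "sharding", "sharding_transform_executor": "sharding"}
ordered = sorted(transforms.items(), key=lambda kv: kv[1])  # sorted() is stable
assert [name for name, _ in ordered] == ["detect_sharding", "sharding_transform_executor"]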

tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py (2)

1-1: Module docstring mismatches behavior

This transform loads weights to a device; update the docstring accordingly.

-"""A simple wrapper transform to build a model via the model factory."""
+"""A transform that loads model weights (from checkpoint or random init) and moves model/CM to the target device."""

11-17: Prefer namespace import per guidelines

Consider importing the interface module and referencing symbols via the namespace (consistent with project guidelines).

Example:

-from ..interface import (
-    BaseTransform,
-    SharedConfig,
-    TransformConfig,
-    TransformInfo,
-    TransformRegistry,
-)
+from .. import interface as t_intf

Update references accordingly.

tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (2)

62-63: Type hint mismatch for nodes_to_eliminate

The set stores triples (t_node, t_comp_node, has_contiguous), but the annotation is Tuple[Node, Node]. Fix the annotation for clarity.

-        nodes_to_eliminate: Set[Tuple[Node, Node]] = set()
+        nodes_to_eliminate: Set[Tuple[Node, Node, bool]] = set()

Also applies to: 95-96


114-116: Optional: recompile after DCE

After eliminate_dead_code(), consider gm.recompile() to keep code generation in sync.

-        if nodes_to_eliminate:
-            gm.graph.eliminate_dead_code()
+        if nodes_to_eliminate:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)

60-66: Avoid leaking optimizer-internal defaults into orchestrator

Plumbing checkpoint_device and device here couples this wrapper to specific transform internals. Prefer moving these defaults into the load_weights transform config or ModularInferenceOptimizer so the orchestration layer remains thin.

tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (3)

311-313: Accessing meta["val"] without guard can raise KeyError

These lines assume fake tensors exist. Add guards or ensure the transform requests shape propagation. Minimal guard keeps transform robust.

-    q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2)
-    k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2)
+    if "val" in q_node.meta:
+        q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2)
+    if "val" in k_node.meta:
+        k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2)
@@
-    q_rope_new.meta["val"] = q_rope_old.meta["val"]
-    q_rope_old.meta["val"] = q_rope_old.meta["val"].transpose(1, 2)
-    k_rope_new.meta["val"] = k_rope_old.meta["val"]
-    k_rope_old.meta["val"] = k_rope_old.meta["val"].transpose(1, 2)
+    if "val" in q_rope_old.meta:
+        q_rope_new.meta["val"] = q_rope_old.meta["val"]
+        q_rope_old.meta["val"] = q_rope_old.meta["val"].transpose(1, 2)
+    if "val" in k_rope_old.meta:
+        k_rope_new.meta["val"] = k_rope_old.meta["val"]
+        k_rope_old.meta["val"] = k_rope_old.meta["val"].transpose(1, 2)

Alternatively, ensure requires_shape_prop=True for this transform.

Also applies to: 337-340


394-396: Type hints: first arg is torch.fx.Graph, not GraphModule

Signatures for _optimize_explicit and _optimize_complex accept gm.graph. Adjust to reduce confusion.

-def _optimize_explicit(
-    graph: GraphModule, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node]
+def _optimize_explicit(
+    graph: torch.fx.Graph, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node]
 ) -> None:
@@
-def _optimize_complex(
-    graph: GraphModule, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node]
+def _optimize_complex(
+    graph: torch.fx.Graph, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node]
 ) -> None:

Also applies to: 477-479


582-609: RoPE input validation might be too strict and brittle

  • Requires head_dim % 64 == 0 and seq dim strictly torch.SymInt. Many models may not satisfy both (e.g., alternative head dims or traced constants). Consider relaxing to head_dim % 2 == 0 and allowing int | SymInt for seq.
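
A sketch of the relaxed check, assuming head_dim and the sequence dimension are read from fake-tensor metadata (the helper name is illustrative):

import torch


def _rope_inputs_supported(head_dim: int, seq_len) -> bool:
    """Accept any even head_dim and either a concrete int or a SymInt sequence dim."""
    return head_dim % 2 == 0 and isinstance(seq_len, (int, torch.SymInt))
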
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)

1-1: Missing NVIDIA copyright header

Per repository guidelines, add the NVIDIA copyright header with the current year.


169-180: replace_input_with(dist_node, node) is a no-op here

dist_node does not reference itself, so replacing its own input with node has no effect. Safe to remove. Optionally, call gm.graph.eliminate_dead_code() after rewiring.


228-244: Validation for TP sharding ops is correct but could enforce dist_op presence for COLUMN split

You already enforce add_dist for dim=1 in the inserter; align validation by requiring dist_op for SplitDimension.COLUMN.
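
A hedged sketch of the stricter check, modeled on the SplitDimension enum and fields used in this module (the standalone definitions are illustrative, not the module's actual API):

import enum
from typing import Optional


class SplitDimension(enum.IntEnum):  # stand-in for the module's enum
    ROW = 0
    COLUMN = 1


def tp_sharding_is_valid(split_dim: SplitDimension, dist_op: Optional[str]) -> bool:
    """COLUMN splits must specify a collective; ROW splits may omit it."""
    if split_dim == SplitDimension.COLUMN and dist_op is None:
        return False
    return dist_op in (None, "all_gather", "all_reduce")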


267-295: BMM validation depends on meta; guard against missing shapes

lhs_tensor.meta["val"] can be absent if shape prop wasn’t run. Add guards or require shape propagation before this transform executes.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (4)

1-1: Missing NVIDIA copyright header

Add the standard NVIDIA header at the top of this new module.


235-248: Heuristic grouping order of nodes_linear.values() is not deterministic

Using insertion order to assign split_dim (0 then 1) can be unstable across graphs. Consider ordering groups by topological position of the group’s first node to ensure consistent split assignment.
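
A minimal sketch of one way to do that, assuming nodes_linear maps a group key to a list of fx Nodes and that graph.nodes iterates in topological order:

from typing import Dict, Hashable, List

from torch.fx import Graph, Node


def groups_in_topo_order(
    graph: Graph, nodes_linear: Dict[Hashable, List[Node]]
) -> List[List[Node]]:
    """Return the linear-op groups sorted by the topological index of their first node."""
    topo_index = {node: i for i, node in enumerate(graph.nodes)}
    return sorted(nodes_linear.values(), key=lambda group: topo_index[group[0]])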


296-300: min_local_shape derived from meta["val"] without guard

If shape prop hasn’t run, meta may be missing. Add a guard or set requires_shape_prop in this transform’s config.


334-347: DP BMM sharding assumes evenly divisible but still computes remainder-specific indices

You correctly skip when remainder != 0, but the start/end computation branches include remainder handling. Simplify to avoid confusion and ensure indices are correct.
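
For the evenly divisible case the index math reduces to the sketch below (names are illustrative):

from typing import Tuple


def bmm_batch_range(batch_size: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Return the [start, end) batch slice owned by `rank`; uneven splits are skipped earlier."""
    assert batch_size % world_size == 0, "callers skip the uneven case"
    per_rank = batch_size // world_size
    start = rank * per_rank
    return start, start + per_rank
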

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (4)

1-1: Missing NVIDIA copyright header

Add the required NVIDIA header to comply with repository standards.


64-127: Typo: _find_lowest_common_ancessor → _find_lowest_common_ancestor

Rename for clarity and correctness. Update call sites accordingly.


23-44: Consider removing or nulling original expert parameters after fusing

You register new fused parameters but keep originals, which may waste memory. If safe, remove the old per-expert params or convert to buffers to avoid double allocation.
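
A hedged sketch of dropping the originals once the fused parameters are registered, assuming the per-expert parameter names are still known at that point:

from typing import Iterable

from torch.fx import GraphModule


def drop_expert_params(gm: GraphModule, param_names: Iterable[str]) -> None:
    """Delete the original per-expert parameters so only the fused copies remain."""
    for name in param_names:
        submod_path, _, attr = name.rpartition(".")
        submod = gm.get_submodule(submod_path) if submod_path else gm
        if hasattr(submod, attr):
            delattr(submod, attr)  # also removes the entry from _parameters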


15-61: Fused op registration on meta device

If parameters are on meta, ensure stacking/cat on meta tensors is supported in your PyTorch version and that later weight loading will replace them. Otherwise, gate this under actual device loads or ensure load_state_dict fully overwrites.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 30cd46e and 4320e20.

📒 Files selected for processing (24)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (4 hunks)
💤 Files with no reviewable changes (1)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
✅ Files skipped from review due to trivial changes (1)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
🧠 Learnings (6)
📚 Learning: 2025-07-22T09:22:14.726Z
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
📚 Learning: 2025-08-09T20:57:04.067Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.067Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
📚 Learning: 2025-08-08T04:10:18.987Z
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6728
File: cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp:966-966
Timestamp: 2025-08-08T04:10:18.987Z
Learning: TensorRT plugins currently don't support padding functionality, and TensorRT is not getting new features (in maintenance mode). This means that duplicating parameters like mExpertHiddenSize in function calls, even with TODO comments, can be acceptable as pragmatic solutions within these constraints.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
📚 Learning: 2025-08-08T22:03:40.685Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.685Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
🧬 Code Graph Analysis (3)
tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py (1)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
  • BaseTransform (138-372)
  • SharedConfig (51-56)
  • TransformInfo (107-132)
  • TransformRegistry (375-403)
tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (3)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (5)
  • BaseTransform (138-372)
  • SharedConfig (51-56)
  • TransformConfig (59-98)
  • TransformInfo (107-132)
  • TransformRegistry (375-403)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (7)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py (1)
  • cuda_memory_tracker (10-26)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4)
  • bfs (348-365)
  • identify_regions_between_residuals (292-345)
  • is_linear_op (240-252)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • get_scales_and_type_from_node (505-512)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (6)
  • BaseTransform (138-372)
  • SharedConfig (51-56)
  • TransformInfo (107-132)
  • TransformRegistry (375-403)
  • register (381-388)
  • _apply (362-372)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (3)
  • torch_moe (44-78)
  • torch_quant_fp8_moe (159-217)
  • torch_quant_fp4_moe (239-305)
tensorrt_llm/module.py (1)
  • register_parameter (186-190)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (30)
tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (1)

33-60: Consistent scale names verified across QuantizationImpl subclasses
All implementations of scale_names() (FP8QuantizationImpl, FP8BMMQuantizationImpl, FP4QuantizationImpl) return exactly the keys present in their respective default_scales() dicts—no mismatches detected.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (3)

9-11: Good shift to GM + Optimizer pipeline

Switching to torch_export_to_gm + InferenceOptimizer aligns tests with the new modular pipeline. Imports look correct.


341-356: LGTM: end-to-end validation via run_test_transformed_gm

Asserting the presence of fused ops and using strict state_dict loading is appropriate post-fusion. Tolerances are reasonable for quantized paths.


360-364: Solid parameter count assertion post-fusion

Verifying reduced parameter nodes after fusion is a good structural check.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)

11-11: Consolidated test helpers import

Using run_test_transformed_gm and run_sharding_pattern_detection_test keeps tests consistent with the new pipeline.


15-19: Namespace imports align with guidelines

Sharding types imported via module path; consistent with namespacing preferences. Good.


271-283: Explicitly overriding rank/world_size for pattern detection is correct

Setting optimizer.shared_config before invoking detection avoids requiring distributed init for pattern-only tests.


169-175: Combined graph check is pragmatic

Validates both distributed ops presence and per-weight local size constraints. Good coverage for sharding correctness.

tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py (1)

29-38: Minimal config/type wiring looks good

Registration and config class wiring are straightforward and consistent with other transforms.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (5)

8-8: LGTM: updated helpers usage

Switch to run_test_transformed_gm/run_sharding_pattern_detection_test is consistent with the new optimizer pipeline.


13-15: LGTM: imports aligned with new transform API

EPShardingInfo and InferenceOptimizer imports reflect the refactor.


38-49: Confirm distributed context for optimizer run

InferenceOptimizer(None, {...})(None, gm) depends on shared_config being auto-populated from the process environment. Please confirm it picks up local_rank/world_size in the spawned dist job; otherwise the detection may behave as world_size=1.

You can check at runtime by logging optimizer.shared_config in the job before applying the optimizer.


53-57: LGTM: transformed-GM test path

Passing gm_transformed directly into run_test_transformed_gm matches the new flow.


96-107: LGTM: pattern detection via optimizer shared_config

Setting local_rank/world_size on optimizer.shared_config and reading ep_transforms is the expected usage.

tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)

18-18: LGTM: ShardingConfig import location

Import path matches the utils module expected by SharedConfig.


51-57: LGTM: SharedConfig introduction

SharedConfig cleanly centralizes sharding + distributed context for transforms.


248-257: LGTM: forwarding shared_config to _apply

The plumbing is correct through the error-handling path and the normal path.


363-368: LGTM: abstract signature updated

The abstract method now enforces access to SharedConfig for all transforms.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (5)

9-9: LGTM: updated helpers usage

Adopts run_test_transformed_gm and pattern detection helper as per new pipeline.


13-15: LGTM: imports aligned with refactor

BMMShardingInfo and InferenceOptimizer are the correct new sources.


63-75: Confirm distributed context for optimizer run

As with EP sharding, ensure InferenceOptimizer derives local_rank/world_size from the environment in multiprocess runs; otherwise detection defaults to world_size=1.

You can print optimizer.shared_config within the job to verify before applying the optimizer.


82-86: LGTM: transformed-GM test path

Using gm_transformed with run_test_transformed_gm is correct.


124-136: LGTM: pattern detection via optimizer shared_config

Setting shared_config then reading sharding_config.bmm_transforms is the intended flow.

tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (5)

16-22: LGTM: import SharedConfig from interface

Imports reflect the updated transform interface.


334-339: LGTM: signature updated to accept shared_config

No functional changes; just interface alignment.


379-384: LGTM: signature updated to accept shared_config

Consistent with BaseTransform change.


409-414: LGTM: signature updated to accept shared_config

Consistent across transforms.


499-504: LGTM: signature updated to accept shared_config

Consistent throughout the attention transforms.

tensorrt_llm/_torch/auto_deploy/transformations/transform.py (2)

53-56: Ensure expected_layout matches supported values ('bsnd' or 'bnsd')

MatchRopeLayout only supports 'bsnd' or 'bnsd'. Verify that AttentionRegistry.get(...).get_attention_layout() returns one of these (case-insensitive). Otherwise, the transform may skip or error depending on implementation.

Would recommend normalizing and validating here to be defensive:

-            self.ad_config.transforms["match_rope_layout"].expected_layout = AttentionRegistry.get(
-                self.ad_config.attn_backend
-            ).get_attention_layout()
+            expected = AttentionRegistry.get(self.ad_config.attn_backend).get_attention_layout()
+            self.ad_config.transforms["match_rope_layout"].expected_layout = str(expected).lower()

67-67: Confirm new_optimizer(cm) handles shape prop and meta population

Downstream transforms rely on meta["val"]. Ensure the optimizer performs canonicalization and shape propagation before post-load fusions. If not guaranteed, set requires_shape_prop in transforms that dereference meta["val"].

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🔭 Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (2)

1-46: Add mandatory NVIDIA copyright header

All TensorRT-LLM source files must begin with the standard 2025 NVIDIA copyright notice.
Please prepend the header right before the module docstring.


472-477: Remove orphan RoPE tuple getter nodes

_optimize_explicit erases the two getitem nodes but not the original
RoPE op. Erase the root call as well to keep the graph clean and avoid
extra kernel dispatches during execution.

🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (1)

379-388: num_rope_optimizations may over-count

_optimize_explicit / _optimize_complex may early-return without changing the
graph, yet the counter is incremented unconditionally.
Increment only when the helper actually performed a replacement.
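A sketch of the counting fix, assuming the two helpers are changed to return True only when they rewired the graph (signatures are illustrative):

def count_rope_optimizations(graph, matched_nodes, cache, pos_cache, optimize_fn) -> int:
    """Count only matches that were actually rewritten."""
    num_matches = 0
    for node in matched_nodes:
        if optimize_fn(graph, node, cache, pos_cache):  # now returns a bool
            num_matches += 1
    return num_matches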

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4320e20 and c2ad719.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (2 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.067Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
📚 Learning: 2025-08-08T22:03:40.685Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.685Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (1)

264-268: Keep rope_ops as overload packets to match all variants
The is_op helper explicitly handles OpOverloadPacket by checking each overload (including .default, .Tensor, etc.). By passing the packet itself, you’ll match any overload function. Switching to only .default would actually narrow the matches and could skip other valid variants.

• File requiring no change:
– tensorrt_llm/_torch/auto_deploy/transform/library/rope.py lines 264–268

Likely an incorrect or invalid review comment.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🔭 Outside diff range comments (3)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)

28-33: Make state_dict loading robust for sharded vs. unsharded checkpoints.

The TODO is valid: current hooks will fail round-tripping un/sharded state_dicts. Recommend tagging state_dict with sharding meta or handling both layouts.

Proposed direction:

  • On save: register a state_dict hook to add a small metadata blob (e.g., state_dict["ad_sharding_meta"] with per-param shard spec).
  • On load: in _load_hook, check for that metadata to decide whether to split or pass-through without shape mismatch heuristics.

I can draft the hooks and metadata schema if helpful.


281-287: Guard against missing FX meta and avoid raising in validate().

validate() assumes meta["val"] is populated; if shape propagation didn't set it, this will KeyError. Also, an assert will abort instead of cleanly skipping.

Apply:

-        lhs_batch_size = lhs_tensor.meta["val"].shape[0]
-        rhs_batch_size = rhs_tensor.meta["val"].shape[0]
+        try:
+            lhs_meta = lhs_tensor.meta
+            rhs_meta = rhs_tensor.meta
+            lhs_shape = getattr(lhs_meta.get("tensor_meta", None), "shape", None) or lhs_meta["val"].shape
+            rhs_shape = getattr(rhs_meta.get("tensor_meta", None), "shape", None) or rhs_meta["val"].shape
+            lhs_batch_size = lhs_shape[0]
+            rhs_batch_size = rhs_shape[0]
+        except Exception:
+            ad_logger.warning("Missing or incomplete FX meta for BMM inputs. Skipping %s.", self)
+            return False
-
-        assert lhs_batch_size == rhs_batch_size, "Batch sizes of both tensors must match"
+        if lhs_batch_size != rhs_batch_size:
+            ad_logger.warning("BMM lhs/rhs batch sizes differ (%s vs %s). Skipping %s.",
+                              lhs_batch_size, rhs_batch_size, self)
+            return False

Also applies to: 289-299


168-174: Refactor sharding_utils to honor dist_op instead of deriving from dim

The current implementation in tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py selects the collective op solely by dim (0→all_gather, 1→all_reduce) and uses dist_op only as a boolean flag (add_dist). This makes the dist_op parameter misleading and brittle.

Proposed changes:

  • Pass the literal dist_op (“all_gather” or “all_reduce”) into _insert_sharded_matmul instead of only add_dist.
  • In _insert_sharded_matmul, replace the dist_lookup based on dim with a simple dispatch on dist_op:
    • if dist_op == "all_gather", use torch.ops.auto_deploy.torch_dist_all_gather (with dim arg −1)
    • if dist_op == "all_reduce", use torch.ops.auto_deploy.torch_dist_all_reduce
    • else (when add_dist is True but dist_op is invalid) raise an error
  • Keep TPShardingInfo.validate to enforce only valid combinations (ROW ⇒ all_gather; COLUMN ⇒ all_reduce).
  • Apply the same refactor at the other two sites in this file (lines ~228–247 and ~252–260) where dist_op is currently only gating add_dist.

Bullet-point locations needing update:

  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py:168–174 (initial dist_lookup)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py:228–247 (second matmul insertion)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py:252–260 (third matmul insertion)

This will make dist_op the single source of truth for which communication primitive is inserted.
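
A minimal dispatch sketch of that single source of truth, assuming the auto_deploy collectives referenced above; the exact argument wiring into _insert_sharded_matmul is illustrative:

import torch
from torch.fx import Graph, Node


def insert_collective(graph: Graph, node: Node, dist_op: str) -> Node:
    """Insert the collective selected by dist_op instead of deriving it from dim."""
    with graph.inserting_after(node):
        if dist_op == "all_gather":
            return graph.call_function(
                torch.ops.auto_deploy.torch_dist_all_gather, args=(node, -1)
            )
        if dist_op == "all_reduce":
            return graph.call_function(
                torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,)
            )
        raise ValueError(f"Invalid dist_op {dist_op!r}; expected 'all_gather' or 'all_reduce'.")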

🧹 Nitpick comments (4)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)

78-87: Fix variable naming and debug message in split_tensor (clarity + correctness of log).

The variable max_split_size is actually the number of splits, not a size. The debug message also claims “Splitting tensor to {num_groups} chunks”, but the code splits into max_split_size chunks and then groups ranks.

Apply:

-        max_split_size = t.shape[d] // min_d_shape
-        if ws > max_split_size:
-            num_groups = math.ceil(ws / max_split_size)
+        max_num_splits = t.shape[d] // min_d_shape
+        if ws > max_num_splits:
+            num_groups = math.ceil(ws / max_num_splits)
             ad_logger.debug(
-                f"World size {ws} is greater than the max split size {max_split_size}. "
-                + f"Splitting tensor to {num_groups} chunks"
+                f"World size {ws} > max_num_splits {max_num_splits}. "
+                f"Splitting into {max_num_splits} shards; grouping ranks into {num_groups} groups."
             )
-            return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups]
+            return torch.tensor_split(t, max_num_splits, dim=d)[r // num_groups]

385-405: Guard EP sharding math when num_experts % world_size != 0 and when experts_per_rank == 0.

Related to the previous comment: if experts_per_rank is 0, floordiv and comparisons will fail. The validate() fix mitigates this, but keep these paths defensive (early-return) if reached.
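
A defensive sketch of that early return (names are illustrative):

from typing import Optional, Tuple


def experts_for_rank(num_experts: int, world_size: int, rank: int) -> Optional[Tuple[int, int]]:
    """Return the [start, end) expert range for `rank`, or None when the split degenerates."""
    experts_per_rank = num_experts // world_size
    if experts_per_rank == 0:
        return None
    start = rank * experts_per_rank
    return start, start + experts_per_rank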


213-222: Include node context in logs for failed transforms.

check_and_apply() logs the transform object but not the node; adding node target/op improves debuggability.

-            ad_logger.warning(f"Skipping invalid transformation {self}.")
+            ad_logger.warning("Skipping invalid transformation %s on node %s (%s).",
+                              self, getattr(node, "name", None), getattr(node, "target", None))

477-483: Consider docstrings in Google style for public config classes.

For ShardingConfig and the Info classes, Google-style docstrings help Sphinx and downstream users.
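
For illustration, a Google-style docstring on ShardingConfig might look like this (the attribute set is hypothetical):

class ShardingConfig:
    """Container for sharding decisions collected during pattern detection.

    Attributes:
        tp_transforms: Tensor-parallel sharding infos to apply.
        bmm_transforms: Batched-matmul sharding infos to apply.
        ep_transforms: Expert-parallel sharding infos to apply.
    """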

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c2ad719 and 28b87a4.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.067Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)

356-363: Confirm that gather along dim=0 matches the BMM sharding scheme.

After slicing the batch dimension, the code gathers along dim=0. This is consistent with batch-split BMM, but please confirm downstream nodes expect the restored full batch.

If mismatches are observed, consider inserting NCCL/collectives that align with subsequent tensor layouts or add an explicit re-sharding step for consumers.
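For a single-process illustration of the scheme (torch.cat stands in for the all_gather collective; shapes are arbitrary):

import torch

full_a = torch.randn(4, 8, 16)   # (batch, M, K)
full_b = torch.randn(4, 16, 32)  # (batch, K, N)
reference = torch.bmm(full_a, full_b)

world_size = 2
shards = [
    torch.bmm(a_i, b_i)
    for a_i, b_i in zip(full_a.chunk(world_size, dim=0), full_b.chunk(world_size, dim=0))
]
# Stand-in for the gather along dim=0 inserted after the sharded bmm.
gathered = torch.cat(shards, dim=0)

assert torch.allclose(reference, gathered)

In this illustration the gathered result has the same (batch, M, N) layout as the unsharded bmm, which is what downstream consumers would need to expect.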


232-247: Confirm TP semantics: ROW vs COLUMN mapping to all_gather/all_reduce.

Common TP conventions often use all_reduce for row-split and all_gather for column-split in linear layers. Your code enforces the opposite. This may be intentional given your op definitions, but it’s worth a double-check.

If intentional, add a brief comment here documenting the convention used by auto_deploy ops to avoid confusion.
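For context, a single-process sketch of the common convention referenced above (concatenation stands in for all_gather, summation for all_reduce); whether auto_deploy's ops follow this mapping is exactly what should be confirmed:

import torch

x = torch.randn(2, 8)   # activations
w = torch.randn(8, 6)   # full weight, y = x @ w
reference = x @ w
world_size = 2

# Column-split: each rank holds a slice of the output features, so the full
# result is recovered by concatenating shards (the all_gather case).
col_parts = [x @ w_i for w_i in w.chunk(world_size, dim=1)]
assert torch.allclose(reference, torch.cat(col_parts, dim=-1))

# Row-split: each rank holds a slice of the input features and produces a
# partial sum over the full output, so the results must be summed (the all_reduce case).
row_parts = [
    x_i @ w_i
    for x_i, w_i in zip(x.chunk(world_size, dim=-1), w.chunk(world_size, dim=0))
]
assert torch.allclose(reference, sum(row_parts), atol=1e-5)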


252-260: Ensure graph is recompiled after in-graph edits.

_insert_sharded_matmul mutates the graph. Please ensure the caller or pipeline invokes gm.graph.lint() and gm.recompile() after a batch of mutations.
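As a self-contained illustration of the lint-then-recompile pattern (a toy module, unrelated to the sharding code itself):

import torch
import torch.nn.functional as F
from torch import fx

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = fx.symbolic_trace(Tiny())

# Example mutation: retarget every relu call to gelu.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        node.target = F.gelu

# After the batch of edits, validate the graph and regenerate forward() once.
gm.graph.lint()
gm.recompile()

assert torch.allclose(gm(torch.ones(3)), F.gelu(torch.ones(3)) + 1)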

Collaborator

@galagam galagam left a comment

Left some nitpicking comments, overall LGTM.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)

1-2: Add mandatory NVIDIA copyright header

Production .py files must start with the NVIDIA copyright header for the current year. Insert it before the module docstring.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
 """Interface to initialize and load HF models."""
♻️ Duplicate comments (1)
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1)

1-8: Add mandatory NVIDIA copyright header

Production modules must start with the NVIDIA header for the current year. Insert it before the module docstring.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
 """
 Quantization Config Reader Registry.
🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (3)

14-15: Remove unused torch import

torch is only referenced in the incorrect type annotation below; after fixing that, the import is unused.

-import torch

81-107: Clarify the dtype-as-string design decision inline

Per project learning, store torch dtypes as strings in config/kwargs and convert downstream. Add a brief comment to prevent regressions.

-        if quant_config.get("quant_algo") == "NVFP4":
-            quant_config["torch_dtype"] = "float16"
+        if quant_config.get("quant_algo") == "NVFP4":
+            # Store dtype as string for compatibility with OmegaConf merges; convert later.
+            quant_config["torch_dtype"] = "float16"
@@
-        if quant_config.get("quant_algo", None) == "NVFP4":
-            extra_model_kwargs["torch_dtype"] = "float16"
+        if quant_config.get("quant_algo", None) == "NVFP4":
+            # Keep string here as well; callers normalize to torch.dtype when needed.
+            extra_model_kwargs["torch_dtype"] = "float16"

69-76: Optional: rename parameter for clarity

has(reader_cls: str) takes a name, not a class. Consider renaming to name: str for readability.

-    def has(cls, reader_cls: str) -> bool:
-        return reader_cls in cls._registry
+    def has(cls, name: str) -> bool:
+        return name in cls._registry
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 28b87a4 and 6e0a7b2.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Files:

  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
🧠 Learnings (3)
📚 Learning: 2025-08-06T21:22:55.018Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T21:22:55.018Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
📚 Learning: 2025-08-09T02:04:49.580Z
Learnt from: Fridah-nv
PR: NVIDIA/TensorRT-LLM#6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.580Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
🧬 Code Graph Analysis (1)
tensorrt_llm/_torch/auto_deploy/models/hf.py (2)
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (6)
  • QuantConfigReader (17-54)
  • QuantConfigReaderRegistry (57-76)
  • get_config (23-25)
  • get (69-72)
  • from_file (42-54)
  • from_file (110-130)
tensorrt_llm/_torch/auto_deploy/models/factory.py (3)
  • model (43-45)
  • get_quant_config (95-97)
  • get (222-224)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/_torch/auto_deploy/models/hf.py (3)

33-33: Quant config reader imports look good

Importing the reader and registry here is appropriate for the new flow.


87-87: State is now a reader instead of a dict — LGTM

Storing a QuantConfigReader instance (or None) is a clean abstraction vs. in-model dicts.


180-184: Docstring clarifies behavior — LGTM

get_quant_config returning {} when not quantized is clear and safe.

tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (2)

122-130: File I/O and parsing logic — LGTM

from_file correctly short-circuits when file missing and delegates to read_config to produce both internal state and extra kwargs.


1-130: No callers expect a single return from from_file—all sites unpack the tuple as intended.
All usages of QuantConfigReader.from_file (in hf.py) correctly handle the (reader, extra_model_kwargs) return value. No further changes needed.

@Fridah-nv Fridah-nv force-pushed the user/fridah/merge-0801 branch 2 times, most recently from 3ef7311 to f8c7512 on August 11, 2025 06:40
@Fridah-nv
Collaborator Author

/bot run

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (16)
tensorrt_llm/_torch/auto_deploy/models/hf.py (4)

87-87: Use Optional[QuantConfigReader] for Python 3.8 compatibility

The type hint uses PEP 604 union syntax (|) which requires Python 3.10+. Per coding guidelines, use typing.Optional for Python 3.8+ compatibility.

-        self._quant_config_reader: QuantConfigReader | None = None
+        from typing import Optional
+        self._quant_config_reader: Optional[QuantConfigReader] = None

186-197: Replace assert with explicit mapping and error for robustness

Current assert is redundant (torch_dtype is either None or float8_e4m3fn by construction) and can be optimized out. Prefer explicit mapping with a clear error when an unsupported kv_cache_dtype appears.

-        kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype")
-        torch_dtype = torch.float8_e4m3fn if kv_cache_dtype == "float8_e4m3fn" else None
-        assert torch_dtype in (torch.float8_e4m3fn, None), (
-            f"Unsupported dtype: {torch_dtype}. Only torch.float8_e4m3fn is supported."
-        )
-
-        return CacheConfig(dtype=torch_dtype)
+        kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype")
+        dtype_map = {
+            "float8_e4m3fn": torch.float8_e4m3fn,
+        }
+        if kv_cache_dtype is None:
+            return CacheConfig(dtype=None)
+        if kv_cache_dtype not in dtype_map:
+            raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")
+        return CacheConfig(dtype=dtype_map[kv_cache_dtype])

329-332: Honor quantization_source instead of hard-coding "modelopt"

The PR summary mentions adding a quantization_source to the factory interface. Use it (with a sane default) rather than hard-coding "modelopt".

Proposed change:

-        # TODO: specified by user or auto-detect
-        reader_cls = QuantConfigReaderRegistry.get("modelopt")
+        # Prefer user-specified source; fall back to "modelopt"
+        source = getattr(self, "quantization_source", "modelopt")
+        reader_cls = QuantConfigReaderRegistry.get(source)

336-339: Normalize torch_dtype after merging extra_model_kwargs

extra_model_kwargs may contain "torch_dtype" as a string (by design). Since the __init__ normalization ran before this merge, redo the conversion here to avoid propagating a str into the model config.

-            self._quant_config_reader = reader
-            self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)
+            self._quant_config_reader = reader
+            self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)
+            # Normalize torch_dtype if provided as string (e.g., "float16")
+            if isinstance(self.model_kwargs.get("torch_dtype"), str):
+                dt = getattr(torch, self.model_kwargs["torch_dtype"], None)
+                if not isinstance(dt, torch.dtype):
+                    raise ValueError(f"Invalid torch_dtype: {self.model_kwargs['torch_dtype']}")
+                self.model_kwargs["torch_dtype"] = dt
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (5)

1-1: Missing NVIDIA copyright header (required by repo guidelines)

Per coding guidelines, all OSS source files must include an NVIDIA copyright header with the current year.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
 """Sharding config definitions for the inference optimizer."""

266-268: Remove duplicate fields (rank/world_size) — already defined in base class

BMMShardingInfo redeclares rank and world_size, which are already in ShardingTransformInfo. This is redundant and risks confusion.

 class BMMShardingInfo(ShardingTransformInfo):
     """Configuration for BMM sharding transformations."""
 
-    rank: int
-    world_size: int
     start_idx: int
     end_idx: int

327-329: Do not re-register an existing parameter — use setattr like TP path

register_parameter can raise if a parameter with the same name already exists. The TP path uses setattr, which is safe and consistent.

-                param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
-                gm.get_submodule(modname).register_parameter(param_name, param_new)
+                param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
+                setattr(gm.get_submodule(modname), param_name, param_new)

346-347: torch.fx Node has no update_arg(); replace with args tuple update

Node.update_arg is not part of the FX API and will fail at runtime.

Minimal fix:

-                bmm_node.update_arg(arg_idx, tensor_slice)
+                new_args = list(bmm_node.args)
+                new_args[arg_idx] = tensor_slice
+                bmm_node.args = tuple(new_args)

458-471: Validate num_experts/world_size to prevent divide-by-zero and invalid masking

If world_size > num_experts, experts_per_rank = num_experts // world_size evaluates to zero, which later causes a division by zero in the floordiv and produces incorrect expert masks (see the short illustration after the suggested fix below).

Apply:

     def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
         """Validate the transformation configuration."""
         if not is_op(
             node,
             (
                 torch.ops.auto_deploy.torch_moe,
                 torch.ops.auto_deploy.torch_quant_fp8_moe,
                 torch.ops.auto_deploy.torch_quant_fp4_moe,
             ),
         ):
             ad_logger.warning(f"EP sharding is only supported for MOE nodes. Skipping {self}.")
             return False
+        try:
+            num_experts = len(node.args[3])
+        except Exception:
+            ad_logger.warning("Unable to determine num_experts for %s. Skipping.", self)
+            return False
+        if num_experts == 0:
+            ad_logger.warning("num_experts == 0 for %s. Skipping.", self)
+            return False
+        if self.world_size > 1 and (num_experts // self.world_size) == 0:
+            ad_logger.warning(
+                "world_size (%s) > num_experts (%s); experts_per_rank would be 0. Skipping %s.",
+                self.world_size, num_experts, self
+            )
+            return False
         return True
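To make the failure mode concrete (plain-Python arithmetic, independent of the transform code):

num_experts, world_size = 4, 8              # more ranks than experts
experts_per_rank = num_experts // world_size
print(experts_per_rank)                     # 0

expert_id = 3
try:
    owner_rank = expert_id // experts_per_rank   # stands in for the downstream floordiv
except ZeroDivisionError as err:
    print(f"invalid EP sharding: {err}")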
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (2)

64-66: Python 3.8 compatibility: use typing.List/Tuple instead of PEP 585 generics

The project targets Python 3.8+. Built-in generics like list[Node] and tuple[...] require 3.9+. Replace with List[...]/Tuple[...].

-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
@@
-def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
+def _find_lowest_common_ancessor(nodes: List[Node]) -> Optional[Node]:
@@
-def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]:
+def _extract_linear_parameters(linear_node: Node) -> Tuple[Node, Node, Optional[dict], str]:
@@
-def _find_final_hidden_state_node(
-    pattern_output_nodes: list[Node], end_boundary: Node
+def _find_final_hidden_state_node(
+    pattern_output_nodes: List[Node], end_boundary: Node
 ) -> Optional[Node]:
@@
-def _extract_index_branches_from_expert_outputs(
-    pattern_output_nodes: list[Node],
-) -> tuple[list[Node], list[Node]]:
+def _extract_index_branches_from_expert_outputs(
+    pattern_output_nodes: List[Node],
+) -> Tuple[List[Node], List[Node]]:

Also applies to: 129-131, 268-271, 320-323


146-153: Logic bug: elif {is_op(...), is_op(...)}: always truthy

This set literal is always truthy, so the branch executes even when neither op matches. Use boolean OR. Also align return types to Node for weights and handle scales possibly being None.

-    elif {
-        is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear)
-        or is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear),
-    }:
+    elif (
+        is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear)
+        or is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear)
+    ):
         weight = linear_node.args[1]
-        scales, quant_type = get_scales_and_type_from_node(linear_node)
-        return input_node, weight, scales, quant_type
+        scales, quant_type = get_scales_and_type_from_node(linear_node)
+        return input_node, weight, scales or {}, quant_type
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (2)

81-84: Recompile the FX graph once after all transforms

Calling gm.graph.lint() and gm.recompile() once avoids stale code objects after in-place graph edits.

         ############################################################################################
         # RETURN OPTIMIZED GRAPH
         ############################################################################################
-        return gm
+        try:
+            gm.graph.lint()
+        finally:
+            gm.recompile()
+        return gm

79-79: Accept Optional cm and factory to match test usage

Tests construct InferenceOptimizer(None, ...)(None, gm). Consider loosening types to Optional for both constructor factory and call’s cm.

Suggested diffs:

  • Constructor:
-class InferenceOptimizer:
-    def __init__(self, factory: ModelFactory, config: InferenceOptimizerConfig):
+class InferenceOptimizer:
+    def __init__(self, factory: Optional[ModelFactory], config: InferenceOptimizerConfig):
         self.factory = factory
  • call signature:
-    def __call__(
-        self, cm: CachedSequenceInterface, gm: Optional[GraphModule] = None
-    ) -> GraphModule:
+    def __call__(
+        self, cm: Optional[CachedSequenceInterface], gm: Optional[GraphModule] = None
+    ) -> GraphModule:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (1)

221-229: Optimizer called with None factory/cm — ensure API reflects this

The test drives InferenceOptimizer(None, {...})(None, gm). Please ensure InferenceOptimizer’s constructor and call accept Optional[ModelFactory] and Optional[CachedSequenceInterface], respectively, to match usage.

If not yet updated, apply the diffs suggested in the optimizer review. To confirm current usages across tests:

#!/bin/bash
# Find optimizer invocations passing None
rg -n "InferenceOptimizer\\(None,|\\)\\(None,\\s*gm\\)" -A 2
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

147-159: Optimizer called with None factory/cm — align types

Same as the single-GPU test: consider updating InferenceOptimizer to accept Optional types for factory and cm to match usage.

See optimizer review for proposed diffs.

Also applies to: 270-283

tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (1)

263-356: MatchRopeLayout implementation is comprehensive but needs node cleanup.

The layout transformation logic is thorough and handles the transpose operations correctly. However, there's still a potential issue with stale nodes that was mentioned in past reviews.

After updating node arguments and replacing uses (lines 344-347), consider checking whether the original RoPE node has become unused and cleaning it up:

            q_rope_old.replace_all_uses_with(q_rope_new)
            k_rope_old.replace_all_uses_with(k_rope_new)
            q_rope_new.args = (q_rope_old, 1, 2)
            k_rope_new.args = (k_rope_old, 1, 2)
+           
+           # Clean up original node if it has no users
+           if not node.users:
+               graph.erase_node(node)

Also, consider calling gm.recompile() before returning to ensure the graph is properly updated.

🧹 Nitpick comments (5)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)

64-64: Fix typo in function name: "ancessor" → "ancestor"

The function name has a typo. It should be _find_lowest_common_ancestor not _find_lowest_common_ancessor.

-def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
+def _find_lowest_common_ancestor(nodes: list[Node]) -> Optional[Node]:

Also update the function calls on lines 403 and 407:

-            normalized_routing_weights = _find_lowest_common_ancessor(arg1_list)
+            normalized_routing_weights = _find_lowest_common_ancestor(arg1_list)
...
-            common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
+            common_ancestor2 = _find_lowest_common_ancestor(arg2_list)
tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (1)

11-17: Follow namespaced import guideline

Org guideline asks to maintain module namespaces. Consider importing the interface module and referencing types via the module.

Example:

-from ..interface import (
-    BaseTransform,
-    SharedConfig,
-    TransformConfig,
-    TransformInfo,
-    TransformRegistry,
-)
+from .. import interface as t_iface

And update references in this file:

  • BaseTransform -> t_iface.BaseTransform
  • TransformConfig -> t_iface.TransformConfig
  • TransformInfo -> t_iface.TransformInfo
  • TransformRegistry -> t_iface.TransformRegistry
  • SharedConfig -> t_iface.SharedConfig
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)

26-31: Guard for non-available distributed backends

Be defensive and check dist.is_available() before is_initialized(). Avoids edge cases in builds without distributed.

-        if not dist.is_initialized():
-            local_rank, world_size = 0, 1
-        else:
-            local_rank, world_size = dist_ad.get_rank_world_size()
+        if dist.is_available() and dist.is_initialized():
+            local_rank, world_size = dist_ad.get_rank_world_size()
+        else:
+            local_rank, world_size = 0, 1
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (1)

6-6: Prefer namespaced imports in tests as well

Maintain module namespaces per guideline. Low priority for tests, but consider:

Example:

  • from _graph_test_helpers import run_test_transformed_gm -> import _graph_test_helpers as gth
  • from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -> import tensorrt_llm._torch.auto_deploy.export as ad_export
  • Then use gth.run_test_transformed_gm and ad_export.torch_export_to_gm

Also applies to: 9-11, 14-14

tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)

18-18: Namespaced import for ShardingConfig

Per guideline, favor module import: utils.sharding_utils as sharding_utils, then refer to sharding_utils.ShardingConfig.

-from ..utils.sharding_utils import ShardingConfig
+from ..utils import sharding_utils
+# and: sharding_utils.ShardingConfig in type annotations
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6e0a7b2 and f8c7512.

📒 Files selected for processing (30)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py (7 hunks)
🔥 Files not summarized due to errors (1)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py: Error: Server error: no LLM provider could handle the message
💤 Files with no reviewable changes (2)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py
🚧 Files skipped from review as they are similar to previous changes (16)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T21:22:55.018Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T21:22:55.018Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
📚 Learning: 2025-08-09T02:04:49.580Z
Learnt from: Fridah-nv
PR: NVIDIA/TensorRT-LLM#6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.580Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/hf.py
🧬 Code Graph Analysis (11)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (8)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (6)
  • SharedConfig (51-56)
  • TransformInfo (107-132)
  • TransformRegistry (375-403)
  • register (381-388)
  • _apply (362-372)
  • get (391-393)
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (2)
  • register (61-66)
  • get (69-72)
tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (1)
  • _apply (141-183)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
  • get_quant_config (179-183)
tensorrt_llm/models/modeling_utils.py (1)
  • quant_algo (547-548)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (5)
  • should_skip_quantization (470-483)
  • QuantizationImpl (68-161)
  • is_quantized_graph (363-371)
  • get_quantization_from_linear_node (394-407)
  • remove_output_quantizers (384-391)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
  • is_linear_op (240-252)
  • is_bmm_op (255-262)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (7)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py (1)
  • cuda_memory_tracker (10-26)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4)
  • bfs (348-365)
  • identify_regions_between_residuals (292-345)
  • is_linear_op (240-252)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • get_scales_and_type_from_node (505-512)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (6)
  • BaseTransform (138-372)
  • SharedConfig (51-56)
  • TransformInfo (107-132)
  • TransformRegistry (375-403)
  • register (381-388)
  • _apply (362-372)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (3)
  • torch_moe (44-78)
  • torch_quant_fp8_moe (159-217)
  • torch_quant_fp4_moe (239-305)
tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (3)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (5)
  • BaseTransform (138-372)
  • SharedConfig (51-56)
  • TransformConfig (59-98)
  • TransformInfo (107-132)
  • TransformRegistry (375-403)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (5)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (2)
  • run_sharding_pattern_detection_test (228-242)
  • run_test_transformed_gm (68-138)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (198-284)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
  • SplitDimension (182-186)
  • TPShardingInfo (225-260)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
  • InferenceOptimizer (22-84)
tensorrt_llm/mapping.py (1)
  • local_rank (372-373)
tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py (3)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (5)
  • BaseTransform (138-372)
  • SharedConfig (51-56)
  • TransformConfig (59-98)
  • TransformInfo (107-132)
  • TransformRegistry (375-403)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
  • ShardingConfig (477-482)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (4)
  • _apply (57-91)
  • _apply (149-321)
  • _apply (335-408)
  • _apply (413-452)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)
  • run_test_transformed_gm (68-138)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (198-284)
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)
  • InferenceOptimizer (25-142)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (3)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)
  • SharedConfig (51-56)
tensorrt_llm/_torch/auto_deploy/distributed/common.py (1)
  • get_rank_world_size (86-87)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (7)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3)
  • extract_param_names_from_lin_node (149-170)
  • is_op (183-206)
  • num_users_of_weight_node (143-146)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • QuantizationImpl (68-161)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
  • check_and_apply (67-75)
tensorrt_llm/logger.py (1)
  • warning (131-132)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (49-50)
tensorrt_llm/functional.py (2)
  • Tensor (107-602)
  • slice (1222-1324)
tensorrt_llm/graph_rewriting.py (1)
  • replace_input_with (373-392)
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (8)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (3)
  • AttentionRegistry (730-753)
  • get (747-749)
  • get_attention_layout (616-617)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)
  • get (391-393)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (1)
  • get_attention_layout (400-402)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py (1)
  • get_attention_layout (29-30)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py (1)
  • get_attention_layout (1025-1026)
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py (1)
  • get_attention_layout (329-331)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (1)
  • get_attention_layout (346-348)
tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py (1)
  • get_attention_layout (217-219)
tensorrt_llm/_torch/auto_deploy/models/hf.py (4)
tensorrt_llm/_torch/auto_deploy/models/factory.py (3)
  • model (43-45)
  • get_quant_config (95-97)
  • get (222-224)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)
  • get_quant_config (37-38)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (1)
  • get_quant_config (35-36)
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1)
  • get (69-72)
🔇 Additional comments (22)
tensorrt_llm/_torch/auto_deploy/models/hf.py (2)

179-183: LGTM! Clean migration to reader-based config retrieval

The method correctly returns the config from the reader when available, with a sensible empty dict fallback for non-quantized models.


352-359: LGTM! Proper support for nested text_config in vision-language models

The addition of nested text_config.max_position_embeddings correctly supports multi-modal models that require position embeddings configuration in their text component.
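A hedged sketch of the lookup pattern being described; attribute names follow the usual Hugging Face config layout, and the helper itself is illustrative rather than the actual hf.py code:

def get_max_position_embeddings(config) -> int:
    text_config = getattr(config, "text_config", None)
    if text_config is not None:
        return text_config.max_position_embeddings
    return config.max_position_embeddings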

tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py (1)

35-41: LGTM! Correctly implements the new SharedConfig interface

The addition of shared_config: SharedConfig parameter aligns with the updated transform interface used throughout the PR. The parameter is appropriately unused in this simple wrapper transform.

tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (1)

52-58: Signature updated to accept SharedConfig — LGTM

The added shared_config parameter aligns with the new BaseTransform interface. No functional changes, safe for downstream use.

tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)

79-79: Propagating SharedConfig to transforms — LGTM

This aligns the optimizer with the new transform interface.

tensorrt_llm/_torch/auto_deploy/transform/interface.py (2)

51-57: SharedConfig introduction — LGTM

The SharedConfig model is concise and sufficient for propagating sharding and distributed context.


200-207: SharedConfig threaded through BaseTransform — LGTM

call now passes shared_config to _apply; consistent with new interface and optimizer updates.

Also applies to: 249-249, 256-256

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

169-175: Test harness usage — LGTM

run_test_transformed_gm flow looks correct; parameter count guard and combined graph checks are appropriate.

tensorrt_llm/_torch/auto_deploy/transformations/transform.py (2)

48-66: Configuration setup logic is clean and well-structured.

The pre-apply configuration for load_weights appropriately sets the checkpoint device and device based on the cached sequence interface. The conditional checks ensure the configuration is only applied when the transforms exist.


67-67: Manual Verification: Ensure All TransformRegistry Entries Are Applied

The refactor delegates transformation logic to ModularInferenceOptimizer (alias InferenceOptimizer), but we haven’t confirmed it actually picks up every transform previously registered under TransformRegistry in tensorrt_llm/_torch/auto_deploy/transform/library. Please verify:

• That the list of transform names in self.ad_config.transforms covers all keys registered via TransformRegistry.register("…") (e.g. "sharding_transform_executor", "cleanup_noop_slice", …, "quantize_from_graph").
• That InferenceOptimizer’s initialization (in tensorrt_llm/_torch/auto_deploy/transform/optimizer.py) imports or otherwise loads all library modules so the registry is fully populated before use.

Key locations to check:

  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py (around line 58: egm = new_optimizer(cm))
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (constructor of InferenceOptimizer)
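A hedged sketch of the registration check described above; run it after the optimizer (or the transform library package) has been imported so registration has already happened, and extend the tuple to the full configured list:

from tensorrt_llm._torch.auto_deploy.transform.interface import TransformRegistry

for name in ("sharding_transform_executor", "cleanup_noop_slice", "quantize_from_graph"):
    TransformRegistry.get(name)  # assumed to fail loudly if the name was never registered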
tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (6)

130-225: MatchRopePattern implementation looks correct.

The conversion from standalone function to transform class is well-executed. The _apply method properly returns the expected Tuple[GraphModule, TransformInfo] and includes appropriate metadata in the TransformInfo object.


228-261: Early return properly handles unsupported layouts.

The early return issue from previous reviews has been correctly addressed. When expected_layout is not supported, the method properly returns a TransformInfo with skipped=True instead of None, maintaining the expected return contract.


359-393: OptimizeRope implementation is well-structured.

The conversion to transform class maintains the optimization logic while properly returning TransformInfo with the number of optimizations performed. The helper functions _optimize_explicit and _optimize_complex are correctly called.


475-476: Node cleanup is properly handled in _optimize_explicit.

The explicit optimization correctly erases the old rope nodes after replacing their uses, which addresses the stale node concern from past reviews.


538-540: Node cleanup is properly handled in _optimize_complex.

The complex optimization correctly erases the old node after replacement, maintaining graph cleanliness.


584-611: Input validation logic is robust.

The _validate_rope_inputs function properly checks dtype, head dimension constraints, and shape requirements. The validation ensures only compatible inputs proceed to optimization.

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (6)

2-2: LGTM! Import additions support the new transforms.

The new imports (Tuple, SharedConfig, TransformInfo, TransformRegistry) are necessary for the transform implementations and follow proper namespace conventions.

Also applies to: 23-23


168-175: Well-structured transform class with clear documentation.

The QuantizationFromConfig class follows the established transform pattern with proper registration, inheritance, and comprehensive docstring describing its functionality.


177-217: Implementation looks solid with proper error handling.

The _apply method correctly handles the quantization flow:

  • Validates quantization config and algorithm availability
  • Properly skips excluded patterns using should_skip_quantization
  • Creates appropriate QuantizationImpl instances for linear and BMM operations
  • Returns accurate TransformInfo metadata

The TODO comment on Line 207 appropriately tracks the known issue with BMM match counting accuracy.


189-189: Good defensive coding practices.

Using .get("quant_algo", None) with a default value and explicitly initializing num_matches = 0 improves code robustness and readability.

Also applies to: 196-196


220-227: Clear separation of concerns with proper documentation.

The QuantizationFromGraph class is well-defined with a docstring that clearly distinguishes its purpose from the config-based approach: handling pre-quantized ModelOpt checkpoints.


229-266: Effective graph-based quantization detection and fusion.

The _apply method properly handles ModelOpt-quantized graphs:

  • Uses is_quantized_graph for early detection
  • Processes each linear node with get_quantization_from_linear_node for per-node quantization type
  • Appropriately calls remove_output_quantizers for cleanup
  • The comment on Line 258 helpfully explains why BMM operations aren't processed in this transform

@tensorrt-cicd
Collaborator

PR_Github #14768 [ run ] triggered by Bot

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (16)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)

201-207: Verify all subclasses have updated _apply signatures.

The base class correctly includes shared_config: SharedConfig in both __call__ and _apply. Based on the provided snippets, most subclasses appear updated, but ensure all transforms consistently use the new signature.

#!/bin/bash
# Verify all _apply methods include shared_config parameter
echo "Checking for _apply methods missing shared_config parameter:"
rg -n "def _apply\(.*\)" --glob "tensorrt_llm/_torch/auto_deploy/transform/**/*.py" | grep -v "shared_config" | head -20

Also applies to: 362-367

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (2)

207-211: TODO acknowledged for BMM quantization match counting.

The TODO correctly identifies that _insert_quantized_bmm may not actually modify the graph but the match is still counted. Consider prioritizing this fix to ensure accurate transformation metrics.

Would you like me to create a GitHub issue to track this enhancement?


177-183: Add null check for factory before accessing methods.

In unit tests, the optimizer is often invoked with factory=None, which will cause an AttributeError at line 184.

 def _apply(
     self,
     gm: GraphModule,
     cm: CachedSequenceInterface,
     factory: ModelFactory,
     shared_config: SharedConfig,
 ) -> Tuple[GraphModule, TransformInfo]:
+    if factory is None:
+        return gm, TransformInfo(
+            skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+        )
     quant_config = factory.get_quant_config()
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (2)

19-19: Fix type annotation for _quant_config.

The field is always a dict (never None), so the Optional annotation is misleading.

-        self._quant_config: Optional[Dict] = {}
+        self._quant_config: Dict[str, Any] = {}

1-8: Add mandatory NVIDIA copyright header.

All production Python files must include the NVIDIA copyright header for the current year (2025).

Add the standard NVIDIA copyright notice before the module docstring:

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tensorrt_llm/_torch/auto_deploy/models/hf.py (3)

185-196: Replace assert with explicit validation.

The assert is redundant and could be optimized out in production. Use explicit validation for better error handling.

 def get_cache_config(self):
     """Return kv cache dtype configuration."""
     if not self._quant_config_reader:
         return CacheConfig(dtype=None)

     kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype")
-    torch_dtype = torch.float8_e4m3fn if kv_cache_dtype == "float8_e4m3fn" else None
-    assert torch_dtype in (torch.float8_e4m3fn, None), (
-        f"Unsupported dtype: {torch_dtype}. Only torch.float8_e4m3fn is supported."
-    )
-
-    return CacheConfig(dtype=torch_dtype)
+    if kv_cache_dtype is None:
+        return CacheConfig(dtype=None)
+    if kv_cache_dtype != "float8_e4m3fn":
+        raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}. Only 'float8_e4m3fn' is supported.")
+    return CacheConfig(dtype=torch.float8_e4m3fn)

329-331: Use configurable quantization_source instead of hard-coding "modelopt".

The PR objectives mention adding quantization_source to the factory interface, but it's still hard-coded here.

The factory should accept a quantization_source parameter (defaulting to "modelopt") and use it here:

-        # TODO: specified by user or auto-detect
-        reader_cls = QuantConfigReaderRegistry.get("modelopt")
+        source = getattr(self, "quantization_source", "modelopt")
+        reader_cls = QuantConfigReaderRegistry.get(source)

Also update the ModelFactory.__init__ to accept and store this parameter.
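A hedged sketch of that change; the constructor below is an illustrative stand-in (the real ModelFactory takes more parameters), with quantization_source as the only new piece:

from typing import Any

class ModelFactory:
    def __init__(self, model: str, quantization_source: str = "modelopt", **kwargs: Any):
        self.model = model
        # Looked up later via QuantConfigReaderRegistry.get(self.quantization_source).
        self.quantization_source = quantization_source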


336-338: Normalize torch_dtype after merging extra_model_kwargs.

The extra_model_kwargs may contain "torch_dtype" as a string. Since this merge happens after __init__, the dtype normalization needs to be repeated.

 if reader is not None:
     self._quant_config_reader = reader
     self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)
+    # Normalize torch_dtype if it was added/updated by extra_model_kwargs
+    if "torch_dtype" in self.model_kwargs and isinstance(self.model_kwargs["torch_dtype"], str):
+        dtype = getattr(torch, self.model_kwargs["torch_dtype"], None)
+        if not isinstance(dtype, torch.dtype):
+            raise ValueError(f"Invalid torch_dtype: {self.model_kwargs['torch_dtype']}")
+        self.model_kwargs["torch_dtype"] = dtype
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (5)

1-1: Missing NVIDIA copyright header (required by repo guidelines).

Per coding guidelines, all OSS source files must include an NVIDIA copyright header with the current year.


266-268: Remove duplicate fields (rank/world_size) — already defined in base class.

BMMShardingInfo redeclares rank and world_size, which are already in ShardingTransformInfo. This is redundant and risks confusion.


327-329: Do not re-register an existing parameter — use setattr like TP path.

register_parameter can raise if a parameter with the same name already exists. The TP path uses setattr, which is safe and consistent.


346-347: torch.fx Node has no update_arg(); replace with replace_input_with or args tuple update.

Node.update_arg is not part of the FX API and will fail at runtime.


458-471: Validate num_experts/world_size to prevent divide-by-zero and invalid masking.

experts_per_rank = num_experts // world_size later can be zero if world_size > num_experts, leading to division by zero in floordiv and incorrect masks.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

42-47: Avoid double-quantization: gate selection between config vs graph sources

Both quantize_from_config and quantize_from_graph are enabled in the same stage. If a graph is ModelOpt-quantized, running the config-based pass first might be wasteful or conflicting depending on detection heuristics.

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (2)

64-66: Python 3.8 compatibility: use typing.List/Tuple instead of PEP 585 generics

The project targets Python 3.8+. Built-in generics like list[Node] and tuple[...] require 3.9+. Replace with List[...]/Tuple[...].

Also applies to: 129-131, 268-271, 320-323


146-153: Logic bug: elif {is_op(...), is_op(...)}: always truthy

This set literal is always truthy, so the branch executes even when neither op matches. Use boolean OR.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6e0a7b2 and 08169cd.

📒 Files selected for processing (30)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py (7 hunks)
💤 Files with no reviewable changes (2)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py
🚧 Files skipped from review as they are similar to previous changes (15)
  • tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
🧠 Learnings (4)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
📚 Learning: 2025-08-06T21:22:55.018Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T21:22:55.018Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
📚 Learning: 2025-08-09T02:04:49.580Z
Learnt from: Fridah-nv
PR: NVIDIA/TensorRT-LLM#6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.580Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
🧬 Code Graph Analysis (9)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (4)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1)
  • MoEOpModel (170-206)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (198-284)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
  • InferenceOptimizer (22-84)
tensorrt_llm/module.py (1)
  • named_parameters (166-171)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (4)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (2)
  • run_sharding_pattern_detection_test (228-242)
  • run_test_transformed_gm (68-138)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (198-284)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
  • BMMShardingInfo (263-362)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
  • InferenceOptimizer (22-84)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (4)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4)
  • bfs (348-365)
  • identify_regions_between_residuals (292-345)
  • is_linear_op (240-252)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • get_scales_and_type_from_node (505-512)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (3)
  • torch_moe (44-78)
  • torch_quant_fp8_moe (159-217)
  • torch_quant_fp4_moe (239-305)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (6)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (7)
  • BaseTransform (138-372)
  • SharedConfig (51-56)
  • TransformInfo (107-132)
  • TransformRegistry (375-403)
  • register (381-388)
  • _apply (362-372)
  • get (391-393)
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (2)
  • register (61-66)
  • get (69-72)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/models/hf.py (1)
  • get_quant_config (179-183)
tensorrt_llm/models/modeling_utils.py (1)
  • quant_algo (547-548)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (5)
  • should_skip_quantization (470-483)
  • QuantizationImpl (68-161)
  • is_quantized_graph (363-371)
  • get_quantization_from_linear_node (394-407)
  • remove_output_quantizers (384-391)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (6)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
  • ShardingConfig (477-482)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)
  • _apply (57-91)
  • _apply (149-321)
tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py (1)
  • _apply (24-53)
tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py (1)
  • _apply (24-56)
tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (1)
  • _apply (52-81)
tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (1)
  • _apply (141-183)
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (5)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)
  • get (391-393)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (1)
  • get_attention_layout (400-402)
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py (1)
  • get_attention_layout (329-331)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py (1)
  • get_attention_layout (346-348)
tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py (1)
  • get_attention_layout (217-219)
tensorrt_llm/_torch/auto_deploy/models/hf.py (4)
tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (6)
  • QuantConfigReader (15-54)
  • QuantConfigReaderRegistry (57-76)
  • get_config (21-23)
  • get (69-72)
  • from_file (42-54)
  • from_file (110-130)
tensorrt_llm/_torch/auto_deploy/models/factory.py (4)
  • model (43-45)
  • get_quant_config (95-97)
  • get_cache_config (99-105)
  • get (222-224)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (2)
  • CacheConfig (25-28)
  • get (747-749)
tensorrt_llm/_torch/auto_deploy/utils/_config.py (1)
  • deep_merge_dicts (13-27)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (9)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3)
  • extract_param_names_from_lin_node (149-170)
  • is_op (183-206)
  • num_users_of_weight_node (143-146)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • QuantizationImpl (68-161)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
  • check_and_apply (67-75)
tensorrt_llm/logger.py (1)
  • warning (131-132)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (49-50)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)
  • target (361-362)
tensorrt_llm/parameter.py (1)
  • Parameter (34-274)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)
  • call_function (247-274)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (1)
  • torch_moe (44-78)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (5)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (2)
  • run_sharding_pattern_detection_test (228-242)
  • run_test_transformed_gm (68-138)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (198-284)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
  • SplitDimension (182-186)
  • TPShardingInfo (225-260)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
  • InferenceOptimizer (22-84)
tensorrt_llm/mapping.py (1)
  • local_rank (372-373)
🔇 Additional comments (25)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)

51-56: LGTM! Well-structured SharedConfig implementation.

The Pydantic model properly encapsulates shared configuration with sensible defaults for single-GPU scenarios. Good use of default_factory for the mutable ShardingConfig.
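
For orientation, a minimal sketch of what such a shared-config model could look like; the defaults here are assumptions based on this review, not the repository source:

from pydantic import BaseModel, Field

from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import ShardingConfig


class SharedConfigSketch(BaseModel):
    """Hypothetical stand-in mirroring the SharedConfig described above."""

    sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
    local_rank: int = 0  # single-GPU default
    world_size: int = 1  # single-GPU default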

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

220-266: LGTM! Clean implementation of graph-based quantization.

The transform correctly identifies and processes pre-quantized graphs from ModelOpt checkpoints. Good separation of concerns between config-based and graph-based quantization.

tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1)

79-107: LGTM! Proper ModelOPT config handling.

Good validation of producer name and appropriate handling of NVFP4 dtype requirements. The string representation for torch_dtype is correct for OmegaConf compatibility.
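
As a hedged illustration of that convention (the helper name below is ours, not the module's API): the reader keeps "float16" as a plain string so OmegaConf.merge can handle it, and downstream code resolves it lazily.

import torch


def resolve_torch_dtype(name: str) -> torch.dtype:
    """Illustrative helper: convert a string such as "float16" into a torch.dtype."""
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Invalid torch_dtype: {name}")
    return dtype


assert resolve_torch_dtype("float16") is torch.float16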

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (2)

304-312: LGTM! Proper use of new transform pipeline.

The test correctly exports the model and applies the transform through InferenceOptimizer. Good use of the registry-based approach with appropriate stage configuration.


327-364: LGTM! Comprehensive MOE fusion testing.

Good test coverage including:

  • Proper transform application through InferenceOptimizer
  • Verification of fused operations in the graph
  • Parameter count validation to ensure fusion actually occurred
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (3)

9-9: LGTM! Import updates align with the new optimizer-driven architecture.

The imports have been correctly updated to use the new test helpers and the centralized sharding module structure.

Also applies to: 13-14


63-74: LGTM! Correct implementation of the new optimizer-driven workflow.

The test has been properly updated to use:

  1. torch_export_to_gm for model export
  2. InferenceOptimizer with appropriate stage configuration for BMM sharding
  3. run_test_transformed_gm for validation of the transformed GraphModule

This aligns with the architectural changes across the PR.

Also applies to: 82-89
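
A rough sketch of that flow is shown below; the toy module, tensor shapes, and the transform keys are assumptions for illustration, not the test's actual code:

import torch

from tensorrt_llm._torch.auto_deploy.export.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer


class TinyBMM(torch.nn.Module):
    """Toy stand-in for the test model; holds a batched weight used by torch.bmm."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(2, 16, 16))

    def forward(self, x):
        return torch.bmm(x, self.weight)


model = TinyBMM()
x = torch.randn(2, 8, 16)
gm = torch_export_to_gm(model, args=(x,))  # exact signature assumed from the test helpers

# Hypothetical transform keys; the real names live in default.yaml.
optimizer = InferenceOptimizer(
    None,
    {
        "detect_bmm_shard": {"stage": "sharding"},
        "sharding_transform_executor": {"stage": "sharding"},
    },
)
gm_transformed = optimizer(None, gm)
# run_test_transformed_gm(...) would then validate numerics and distributed ops.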


124-135: LGTM! Pattern detection correctly updated for the new architecture.

The detection logic properly:

  • Uses InferenceOptimizer with only the detection stage enabled
  • Sets rank/world_size on the shared_config for proper context
  • Retrieves detected transformations from optimizer.shared_config.sharding_config.bmm_transforms

This is consistent with the new SharedConfig-based transform communication pattern.
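
A compact sketch of that detection-only pattern; attribute names follow this review, while the transform key is a placeholder:

from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import BMMShardingInfo

optimizer = InferenceOptimizer(
    None,
    {"detect_bmm_shard": {"stage": "sharding"}},  # placeholder transform key
)
optimizer.shared_config.local_rank = 0   # field names per the SharedConfig described above
optimizer.shared_config.world_size = 2
_ = optimizer(None, gm)  # gm: the exported GraphModule from the previous sketch

detected = optimizer.shared_config.sharding_config.bmm_transforms
assert all(isinstance(t, BMMShardingInfo) for t in detected)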

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (3)

11-11: LGTM! Import updates align with the centralized architecture.

The imports have been correctly updated to use the new optimizer and centralized sharding module structure.

Also applies to: 15-19


147-158: LGTM! Correct implementation of optimizer-driven TP sharding workflow.

The test properly uses the new pattern:

  1. Export to GraphModule with torch_export_to_gm
  2. Configure InferenceOptimizer with TP sharding stages
  3. Validate with run_test_transformed_gm

This is consistent with the architectural refactoring.

Also applies to: 169-175


271-282: LGTM! Pattern detection correctly updated for SharedConfig architecture.

The detection logic properly:

  • Uses InferenceOptimizer with detection-only configuration
  • Sets context via shared_config
  • Retrieves results from optimizer.shared_config.sharding_config.tp_transforms
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (2)

48-66: LGTM! Configuration setup properly prepares optimizer parameters.

The configuration adjustments correctly:

  • Set the expected RoPE layout from the attention backend
  • Configure load_weights with appropriate device settings

This ensures the modular optimizer has the correct context for execution.
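
A hedged sketch of those adjustments; the config keys and the descriptor object are assumptions inferred from this review, not the module's actual code:

def prepare_optimizer_config(config: dict, attn_descriptor, device: str) -> dict:
    """Illustrative only: mirror the pre-optimizer tweaks described above."""
    # Expected RoPE layout comes from the chosen attention backend.
    config["match_rope_layout"]["expected_layout"] = attn_descriptor.get_attention_layout()
    # Weight loading needs to know the target device.
    config["load_weights"]["device"] = device
    return config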


58-67: LGTM! Clean transition to modular optimizer-driven transformation.

The code correctly:

  • Creates ModularInferenceOptimizer with the factory and configuration
  • Applies it to get the transformed GraphModule

This replaces the complex legacy multi-stage approach with a cleaner, centralized solution.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (1)

8-8: LGTM! EP sharding test correctly updated for optimizer-driven architecture.

All changes properly implement the new workflow:

  • Updated imports for centralized sharding module and optimizer
  • Transform execution via InferenceOptimizer with appropriate stage configuration
  • Pattern detection using SharedConfig-based approach
  • Validation via run_test_transformed_gm

This is consistent with the architectural refactoring across all sharding tests.

Also applies to: 13-14, 38-49, 53-61, 96-107

tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (5)

50-50: LGTM! Import updates support the new class-based transform architecture.

The imports correctly add the necessary components for:

  • Type annotations
  • Pydantic configuration classes
  • Transform interface components (BaseTransform, TransformRegistry, etc.)

Also applies to: 53-53, 60-66


130-225: LGTM! MatchRopePattern correctly implements the new transform interface.

The class properly:

  • Registers with TransformRegistry
  • Implements the required _apply method signature
  • Handles multiple RoPE pattern variants (explicit, interleaved, complex)
  • Returns appropriate TransformInfo with match count

The pattern matching logic is comprehensive and the implementation follows the new architecture correctly.
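
For readers new to the interface, a minimal registered transform looks roughly like the sketch below; the decorator usage and the transform key are assumptions, while the _apply signature and TransformInfo fields follow the signatures quoted elsewhere in this review:

from typing import Tuple

from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
from tensorrt_llm._torch.auto_deploy.transform.interface import (
    BaseTransform,
    SharedConfig,
    TransformInfo,
    TransformRegistry,
)


@TransformRegistry.register("noop_example")  # hypothetical key, not in default.yaml
class NoopExample(BaseTransform):
    """Does nothing; shows only the _apply contract."""

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        return gm, TransformInfo(
            skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True
        )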


228-235: LGTM! Well-defined configuration class for RoPE layout matching.

The configuration class properly defines the expected_layout parameter with appropriate default value and description.


237-356: LGTM! MatchRopeLayout correctly implements layout transformation with proper error handling.

The class properly:

  • Implements the BaseTransform interface with configuration support
  • Handles unsupported layouts by returning skipped TransformInfo (addresses previous review concern)
  • Performs layout transformations by inserting transpose/contiguous operations
  • Updates node arguments and metadata appropriately

The implementation is robust and follows the new architecture correctly.


359-393: LGTM! OptimizeRope correctly implements RoPE optimization transform.

The class properly:

  • Registers with TransformRegistry
  • Handles both explicit and complex RoPE optimization patterns
  • Tracks optimization count in TransformInfo
  • Follows the new transform architecture
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

182-186: LGTM! Well-defined enum improves code clarity.

The SplitDimension enum provides clear, self-documenting constants for tensor sharding dimensions.


189-223: LGTM! Well-designed abstract base class for sharding transformations.

The base class properly:

  • Uses Pydantic for configuration management
  • Defines common fields (target_node, rank, world_size)
  • Requires concrete implementations to define validation and application logic
  • Provides a convenient check_and_apply method
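
A minimal sketch of that base-class shape, assuming Pydantic and the method names used later in this review (the repository implementation may differ):

from typing import Optional

from pydantic import BaseModel
from torch.fx import GraphModule, Node


class ShardingTransformInfoSketch(BaseModel):
    """Hypothetical stand-in mirroring the ShardingTransformInfo base class."""

    target_node: str
    rank: int
    world_size: int

    def validate(self, gm: Optional[GraphModule] = None, node: Optional[Node] = None) -> bool:
        raise NotImplementedError

    def apply(self, gm: GraphModule, node: Node) -> None:
        raise NotImplementedError

    def check_and_apply(self, gm: GraphModule, node: Node) -> None:
        # Convenience helper: only mutate the graph when validation passes.
        if self.validate(gm, node):
            self.apply(gm, node)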
tensorrt_llm/_torch/auto_deploy/config/default.yaml (3)

30-47: LGTM! Pattern matcher transforms correctly configured for new architecture.

The new transforms properly support the class-based transform system:

  • MoE pattern matching and RoPE transforms
  • Quantization transforms (both config and graph-based)
  • Redundant transpose elimination

All are appropriately assigned to the pattern_matcher stage.


49-59: LGTM! Sharding transforms properly configured for the new architecture.

The sharding stage correctly includes:

  • Detection transforms for different sharding types
  • Transform executor with shape propagation enabled
  • Proper stage assignment for the sharding workflow

60-61: LGTM! Weight loading transform properly configured.

The load_weights transform is correctly assigned to the weight_load stage, supporting the new architecture.

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)

372-500: LGTM! Well-structured MoE transforms with comprehensive pattern matching.

The classes properly:

  • Implement the BaseTransform interface
  • Register with TransformRegistry
  • Handle complex MoE pattern detection and fusion
  • Support quantized variants (FP4/FP8)
  • Return appropriate TransformInfo

The architectural approach is sound and integrates well with the new transform system.

Also applies to: 503-523

@tensorrt-cicd
Collaborator

PR_Github #14768 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #11148 completed with status: 'FAILURE'

h-guo18 and others added 4 commits August 11, 2025 09:49
…ransforms to new inf optimizer (#113)

* refactor: move eliminate_transpose and rope transform to modular inference optimizer

Signed-off-by: h-guo18 <[email protected]>

* refactor: move eliminate_transpose and rope transform to modular inference optimizer

Signed-off-by: h-guo18 <[email protected]>

* merge: updates of rope.py

Signed-off-by: h-guo18 <[email protected]>

---------

Signed-off-by: h-guo18 <[email protected]>
* Move quant_config handling to _load_quantization_config

Signed-off-by: Fridah-nv <[email protected]>

move kv_cache_dtype into _quant_config in hf factory

Signed-off-by: Fridah-nv <[email protected]>

remove quant_source getter

Signed-off-by: Fridah-nv <[email protected]>

* add QuantConfigReader class

Signed-off-by: Fridah-nv <[email protected]>

minor

Signed-off-by: Fridah-nv <[email protected]>

tmp:Llama4 FP8 for BMM testing

Signed-off-by: Fridah-nv <[email protected]>

revert Llama4 FP8 patch

Signed-off-by: Fridah-nv <[email protected]>

move _quant_config into QuantConfigReader

Signed-off-by: Fridah-nv <[email protected]>

* move quantize and quantize_moe to the end of pattern matcher

Signed-off-by: Fridah-nv <[email protected]>

* delegate quant_config processing to QuantConfigReader and pass the reader to the transformation; split the transformation into config- and graph-based paths

Signed-off-by: Fridah-nv <[email protected]>

have quantConfigReader return the dtype for NVFP4

Signed-off-by: Fridah-nv <[email protected]>

move quantization target collection as a transform

Signed-off-by: Fridah-nv <[email protected]>

minor

Signed-off-by: Fridah-nv <[email protected]>

tmp:hacky fix modelopt graph based path quantizer loading

Signed-off-by: Fridah-nv <[email protected]>

fix rebase quantization to BaseTransform

Signed-off-by: Fridah-nv <[email protected]>

minor: remove QuantizationTarget

Signed-off-by: Fridah-nv <[email protected]>

rm tmp fix, rebase, minor update on interface

Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: Fridah-nv <[email protected]>

update transform docstring

Signed-off-by: Fridah-nv <[email protected]>

fix unit test

Signed-off-by: Fridah-nv <[email protected]>

* move quantization to end of transformations

Signed-off-by: Fridah-nv <[email protected]>

---------

Signed-off-by: Fridah-nv <[email protected]>

minor

Signed-off-by: Fridah-nv <[email protected]>

address feedback

Signed-off-by: Frida Hou <[email protected]>
…ular inference optimizer (#124)

* refactor: move match_moe_pattern, sharding, load_weights to modular inf optimizer

Signed-off-by: h-guo18 <[email protected]>

* refactor: move match_moe, sharding, load_weights to modular inference optimizer

Signed-off-by: h-guo18 <[email protected]>

* style: fix typo; remove redundant comments

Signed-off-by: h-guo18 <[email protected]>

* polish: keep moe skip test; move sharding configs to avoid circular imports;

Signed-off-by: h-guo18 <[email protected]>

* remove redundant comments

Signed-off-by: h-guo18 <[email protected]>

* style: rename class name

Signed-off-by: h-guo18 <[email protected]>

* fix quant_graph inputs; add type hints for shared_config

Signed-off-by: h-guo18 <[email protected]>

---------

Signed-off-by: h-guo18 <[email protected]>
h-guo18 and others added 3 commits August 11, 2025 09:49
@Fridah-nv Fridah-nv force-pushed the user/fridah/merge-0801 branch from 08169cd to c1e0da8 Compare August 11, 2025 16:49
@Fridah-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #14833 [ run ] triggered by Bot

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

52-56: Fix potential IndexError when accessing node.users and scope the user lookup after the length check.

You access the first user before verifying there is exactly one.

Apply:

-        # redirect output to skip output quantizer if any
-        user = list(node.users.keys())[0]
-        if len(node.users) == 1 and is_quantized_op(user):
-            user.replace_all_uses_with(node)
+        # redirect output to skip output quantizer if any
+        if len(node.users) == 1:
+            user = next(iter(node.users))
+            if is_quantized_op(user):
+                user.replace_all_uses_with(node)
♻️ Duplicate comments (21)
tensorrt_llm/_torch/auto_deploy/models/hf.py (4)

87-87: Use Optional[QuantConfigReader] for Python 3.8 compatibility

The | union syntax requires Python 3.10+, but the project targets Python 3.8+. Use typing.Optional instead (and add Optional to the module-level typing import rather than importing inside the method).

-        self._quant_config_reader: QuantConfigReader | None = None
+        self._quant_config_reader: Optional[QuantConfigReader] = None

329-332: Honor quantization_source instead of hard-coding "modelopt"

The PR objectives mention adding quantization_source to the factory interface. The code should use this parameter instead of hard-coding "modelopt".

First, add quantization_source parameter to the base ModelFactory.__init__:

# In tensorrt_llm/_torch/auto_deploy/models/factory.py
class ModelFactory(ABC):
    def __init__(
        self,
        ...
        max_seq_len: int = 512,
        quantization_source: str = "modelopt",
        **kwargs,
    ):
        self.quantization_source = quantization_source
        ...

Then use it here:

-        # TODO: specified by user or auto-detect
-        reader_cls = QuantConfigReaderRegistry.get("modelopt")
+        reader_cls = QuantConfigReaderRegistry.get(self.quantization_source)

336-339: Normalize torch_dtype after merging extra_model_kwargs

The merged extra_model_kwargs may contain "torch_dtype" as a string. Re-normalize it to avoid propagating strings into the model config.

         if reader is not None:
             self._quant_config_reader = reader
             self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)
+            # Re-normalize torch_dtype if it was provided as a string
+            if "torch_dtype" in self.model_kwargs and isinstance(self.model_kwargs["torch_dtype"], str):
+                dtype = getattr(torch, self.model_kwargs["torch_dtype"], None)
+                if not isinstance(dtype, torch.dtype):
+                    raise ValueError(f"Invalid torch_dtype: {self.model_kwargs['torch_dtype']}")
+                self.model_kwargs["torch_dtype"] = dtype

186-196: Simplify kv_cache_dtype handling with explicit mapping

The assert is redundant since torch_dtype can only be None or float8_e4m3fn by construction. Use an explicit mapping for better extensibility.

-        kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype")
-        torch_dtype = torch.float8_e4m3fn if kv_cache_dtype == "float8_e4m3fn" else None
-        assert torch_dtype in (torch.float8_e4m3fn, None), (
-            f"Unsupported dtype: {torch_dtype}. Only torch.float8_e4m3fn is supported."
-        )
-
-        return CacheConfig(dtype=torch_dtype)
+        kv_cache_dtype = self._quant_config_reader.get_config().get("kv_cache_dtype")
+        dtype_map = {
+            "float8_e4m3fn": torch.float8_e4m3fn,
+        }
+        if kv_cache_dtype is None:
+            return CacheConfig(dtype=None)
+        if kv_cache_dtype not in dtype_map:
+            raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}. Supported: {list(dtype_map.keys())}")
+        return CacheConfig(dtype=dtype_map[kv_cache_dtype])
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

42-45: Gate quantization methods to avoid double-quantization

Both quantize_from_config and quantize_from_graph are enabled in the same stage, which could cause conflicts if both attempt to quantize the model.

Consider one of these approaches:

  1. Add a configuration switch to enable only one method:
  quantize_from_config:
    stage: pattern_matcher
    enabled: ${quantization_source == "config"}
  quantize_from_graph:
    stage: pattern_matcher
    enabled: ${quantization_source == "graph"}
  2. Or implement detection logic in the transforms to skip when the graph is already quantized (see the sketch below).
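
A minimal sketch of option 2, assuming the is_quantized_graph helper listed in the code-graph section above:

from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import is_quantized_graph


def should_run_config_quantization(gm: GraphModule) -> bool:
    """Illustrative gate: skip the config-driven pass if the graph is already quantized."""
    return not is_quantized_graph(gm)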
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (2)

64-66: Use typing.List and typing.Tuple for Python 3.8 compatibility

Built-in generics like list[Node] and tuple[...] require Python 3.9+. The project targets Python 3.8+.

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
-def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
+def _find_lowest_common_ancessor(nodes: List[Node]) -> Optional[Node]:
 
-def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]:
+def _extract_linear_parameters(linear_node: Node) -> Tuple[Node, Node, Optional[dict], str]:
 
-def _find_final_hidden_state_node(
-    pattern_output_nodes: list[Node], end_boundary: Node
-) -> Optional[Node]:
+def _find_final_hidden_state_node(
+    pattern_output_nodes: List[Node], end_boundary: Node
+) -> Optional[Node]:
 
-def _extract_index_branches_from_expert_outputs(
-    pattern_output_nodes: list[Node],
-) -> tuple[list[Node], list[Node]]:
+def _extract_index_branches_from_expert_outputs(
+    pattern_output_nodes: List[Node],
+) -> Tuple[List[Node], List[Node]]:

Note: The return type for _extract_linear_parameters should be Node not torch.Tensor for the weight parameter, as it returns a node from the graph.

Also applies to: 129-131, 268-271, 320-323


146-153: Fix logic bug: use boolean OR instead of set literal

The set literal {is_op(...), is_op(...)} is always truthy (even when both conditions are False), causing the elif branch to always execute. Use boolean OR instead.

-    elif {
-        is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear)
-        or is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear),
-    }:
+    elif (
+        is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear)
+        or is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear)
+    ):
         weight = linear_node.args[1]
         scales, quant_type = get_scales_and_type_from_node(linear_node)
         return input_node, weight, scales or {}, quant_type
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)

79-84: Recompile the FX graph after all transforms complete

Transforms mutate gm.graph. Call lint() and recompile() after all transforms to ensure the graph is valid and avoid stale code objects.

         # iterate over all transforms sorted by stage in the config
         for t_name, t_config in self.config.items():
             # instantiate transform
             transform = TransformRegistry.get(t_name)(t_config)
             # run transform
             gm = transform(gm, cm, self.factory, self.shared_config)
 
         ############################################################################################
         # RETURN OPTIMIZED GRAPH
         ############################################################################################
+        # Ensure the graph is valid and recompile
+        try:
+            gm.graph.lint()
+        finally:
+            gm.recompile()
         return gm
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)

361-368: Verify all subclasses updated to new _apply(..., shared_config) signature

BaseTransform now calls _apply(gm, cm, factory, shared_config). Ensure every subclass matches this signature to avoid runtime errors. The prior verification agent flagged several offenders in library modules.

Run this quick scan to list any _apply definitions missing shared_config:

#!/bin/bash
# 1) Find BaseTransform subclasses whose _apply lacks shared_config.
ast-grep --pattern $'class $_(BaseTransform) {\n  $$$\n  def _apply(self, $_, $_, $_):\n    $$$\n}' || true

# 2) Broad scan of all _apply definitions without 'shared_config' token.
rg -n '^\s*def\s+_apply\s*\(' --glob '*.py' | grep -v 'shared_config' || true
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

15-18: Import TPShardingInfo/SplitDimension from canonical utils module

These types are defined under utils.sharding_utils and may not be re-exported by transform.library.sharding. Import from the canonical location to avoid fragile coupling.

-from tensorrt_llm._torch.auto_deploy.transform.library.sharding import (
-    SplitDimension,
-    TPShardingInfo,
-)
+from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import (
+    SplitDimension,
+    TPShardingInfo,
+)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3)

85-167: Return a boolean from _insert_quantized_bmm to indicate success/failure.

This enables accurate num_matches counting and graceful skips when weight shape is unknown.

Apply:

-def _insert_quantized_bmm(
+def _insert_quantized_bmm(
     gm: GraphModule,
     node: Node,
     quantization_impl: QuantizationImpl,
     is_quantized_graph: bool = False,
-):
-    """Replaces the bmm node with a new quantized bmm node."""
+) -> bool:
+    """Replaces the bmm node with a new quantized bmm node. Returns True on success."""
@@
-    else:
-        # If we can't determine the shape, skip quantization
-        return
+    else:
+        # If we can't determine the shape, skip quantization
+        return False
@@
-    node.args = (*node.args, *scale_values)
+    node.args = (*node.args, *scale_values)
+    return True

181-186: Guard against factory being None before calling get_quant_config().

Unit tests often pass factory=None; calling factory.get_quant_config() will raise.

Apply:

     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        quant_config = factory.get_quant_config()
+        if factory is None:
+            return gm, TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+        quant_config = factory.get_quant_config()

207-211: Increment num_matches only when _insert_quantized_bmm succeeds.

Currently it increments even when quantization is skipped.

Apply:

-            elif is_bmm_op(n):
-                impl = QuantizationImpl.create(quant_algo, is_bmm=True)
-                _insert_quantized_bmm(gm, n, impl, False)
-                num_matches += 1
+            elif is_bmm_op(n):
+                impl = QuantizationImpl.create(quant_algo, is_bmm=True)
+                success = _insert_quantized_bmm(gm, n, impl, False)
+                if success:
+                    num_matches += 1
tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (3)

1-1: Missing NVIDIA copyright header.

Add the standard header before the module docstring.

Apply:

+ # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ #
 """Graph transformation to eliminate redundant transpose operations in the model graph.

20-27: Use specific aten overloads for robust FX matching.

Matching on an OpOverloadPacket (e.g., torch.ops.aten.transpose) can miss nodes in FX graphs; prefer the specific overloads transpose.int and contiguous.default.

Apply:

 def _is_transpose_op(node: Node) -> bool:
     """Check if the node is a transpose operation."""
-    return is_op(node, torch.ops.aten.transpose)
+    return is_op(node, torch.ops.aten.transpose.int)
@@
 def _is_contiguous_op(node: Node) -> bool:
     """Check if the node is a contiguous operation."""
-    return is_op(node, torch.ops.aten.contiguous)
+    return is_op(node, torch.ops.aten.contiguous.default)

101-112: Fix graph rewiring to avoid creating a self-loop and corrupting the graph.

Do not replace all uses of original_input. Replace uses of the second transpose node. Insert contiguous only for that path.

Apply:

-            # Replace all uses of the second transpose with the input to the first transpose
-            original_input = t_node.args[0]
-            t_comp_node.replace_all_uses_with(original_input)
-
-            # if there is a contiguous operation that we skipped, let add it after t_comp_node as new
-            # graph node that call contiguous on t_comp_node
-            if has_contiguous:
-                with graph.inserting_after(original_input):
-                    new_contiguous_node = graph.call_function(
-                        torch.ops.aten.contiguous.default, args=(original_input,)
-                    )
-                original_input.replace_all_uses_with(new_contiguous_node)
-                new_contiguous_node.replace_input_with(new_contiguous_node, original_input)
+            # Replace all uses of the second transpose appropriately
+            original_input = t_node.args[0]
+            if has_contiguous:
+                # Insert a contiguous op on the original input and rewire only the matched path
+                with graph.inserting_after(original_input):
+                    new_contiguous_node = graph.call_function(
+                        torch.ops.aten.contiguous.default, args=(original_input,)
+                    )
+                t_comp_node.replace_all_uses_with(new_contiguous_node)
+            else:
+                # Directly bypass the second transpose
+                t_comp_node.replace_all_uses_with(original_input)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (5)

1-1: Missing NVIDIA copyright header (required).

Add the standard header at the top.

Apply:

+ # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ #
 """Sharding config definitions for the inference optimizer."""

266-268: Remove duplicate rank/world_size fields; they already exist in the base class.

Redefinition is redundant and confusing.

Apply:

 class BMMShardingInfo(ShardingTransformInfo):
     """Configuration for BMM sharding transformations."""
 
-    rank: int
-    world_size: int
     start_idx: int
     end_idx: int

327-329: Avoid register_parameter; use setattr to replace existing parameter safely.

register_parameter can raise if the name already exists.

Apply:

-                param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
-                gm.get_submodule(modname).register_parameter(param_name, param_new)
+                param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
+                setattr(gm.get_submodule(modname), param_name, param_new)

346-347: Node.update_arg is not part of FX API; use replace_input_with or rebuild args.

This will otherwise fail at runtime.

Apply:

-                # Update BMM node to use the sliced tensor
-                bmm_node.update_arg(arg_idx, tensor_slice)
+                # Update BMM node to use the sliced tensor
+                bmm_node.replace_input_with(tensor_node, tensor_slice)
+                # Alternatively:
+                # new_args = list(bmm_node.args)
+                # new_args[arg_idx] = tensor_slice
+                # bmm_node.args = tuple(new_args)

458-471: Validate num_experts and world_size to prevent divide-by-zero and invalid masking.

experts_per_rank == 0 leads to division by zero later; skip invalid configs.

Apply:

     def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
         """Validate the transformation configuration."""
         if not is_op(
             node,
             (
                 torch.ops.auto_deploy.torch_moe,
                 torch.ops.auto_deploy.torch_quant_fp8_moe,
                 torch.ops.auto_deploy.torch_quant_fp4_moe,
             ),
         ):
             ad_logger.warning(f"EP sharding is only supported for MOE nodes. Skipping {self}.")
             return False
+        # Additional safety checks
+        try:
+            num_experts = len(node.args[3])
+        except (TypeError, IndexError):
+            ad_logger.warning(f"Unable to determine num_experts for {self}. Skipping.")
+            return False
+        if num_experts <= 0:
+            ad_logger.warning(f"num_experts <= 0 for {self}. Skipping.")
+            return False
+        if self.world_size > 1 and (num_experts // self.world_size) == 0:
+            ad_logger.warning(
+                f"world_size ({self.world_size}) > num_experts ({num_experts}); "
+                f"experts_per_rank would be 0. Skipping {self}."
+            )
+            return False
         return True
🧹 Nitpick comments (4)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (1)

304-313: Optional: instantiate the optimizer to allow future shared_config overrides

You directly call the optimizer constructor and immediately invoke it. To make future tweaks (e.g., shared_config edits, debugging) easier and to mirror the pattern used in other tests, consider assigning it to a variable first.

- gm_transformed = InferenceOptimizer(
-     None,
-     {
-         "match_moe_pattern": {
-             "stage": "pattern_matcher",
-         },
-     },
- )(None, gm)
+ optimizer = InferenceOptimizer(
+     None,
+     {
+         "match_moe_pattern": {"stage": "pattern_matcher"},
+     },
+ )
+ gm_transformed = optimizer(None, gm)

And similarly for the fusion path.

Also applies to: 331-339

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

1-1: Add required NVIDIA copyright header at top of file.

Per repository guidelines, include the NVIDIA header with current year.

Apply:

+ # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ #
 from functools import partial
tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (1)

62-63: Fix typing of nodes_to_eliminate to match actual stored tuples.

You store (t_node, t_comp_node, has_contiguous) triples; annotate accordingly.

Apply:

-        nodes_to_eliminate: Set[Tuple[Node, Node]] = set()
+        nodes_to_eliminate: Set[Tuple[Node, Node, bool]] = set()
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)

171-180: Prefer specific aten overload for contiguous in pointwise_ops.

This avoids missed matches in FX graphs.

Apply:

         pointwise_ops = {
             torch.ops.aten.gelu,
             torch.ops.aten.leaky_relu,
             torch.ops.aten.mul,
             torch.ops.aten.relu,
             torch.ops.aten.sigmoid,
             torch.ops.aten.silu,
             torch.ops.aten.tanh,
-            torch.ops.aten.contiguous,
+            torch.ops.aten.contiguous.default,
         }
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 08169cd and c1e0da8.

📒 Files selected for processing (30)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py (7 hunks)
💤 Files with no reviewable changes (2)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
🚧 Files skipped from review as they are similar to previous changes (14)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rope.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py
🧠 Learnings (3)
📚 Learning: 2025-08-06T21:22:55.018Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T21:22:55.018Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
📚 Learning: 2025-08-09T02:04:49.580Z
Learnt from: Fridah-nv
PR: NVIDIA/TensorRT-LLM#6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.580Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/hf.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (9)
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)

44-67: Clean refactoring to ModularInferenceOptimizer!

The migration from the legacy multi-stage transformation path to the streamlined ModularInferenceOptimizer flow is well-executed. The pre-optimizer configuration hooks for attention and weight loading are properly handled.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (2)

9-10: Migration to export→optimizer flow looks correct

torch_export_to_gm + InferenceOptimizer import updates align with the new transform pipeline. Stages used later in the file match the Stages enum semantics.


341-356: Parameter-count assertion after fusion is sensible

The fused graph check and param-count reduction assertion are consistent with expected MoE fusion behavior.

Also applies to: 360-360

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_redundant_transposes.py (2)

6-6: Import and test-harness refactor LGTM

Switch to run_test_transformed_gm + torch_export_to_gm + InferenceOptimizer is consistent with the new pipeline. Importing helper predicates (_is_transpose_op/_is_contiguous_op) for graph checks is acceptable in tests.

Also applies to: 9-15


221-229: Optimizer stage and validation wiring look correct

Using "eliminate_redundant_transposes" at stage "pattern_matcher" and validating via the helper is aligned with the intended flow.

Also applies to: 236-245

tensorrt_llm/_torch/auto_deploy/transform/interface.py (2)

51-58: SharedConfig introduction is a good abstraction

Centralizing global knobs (sharding_config, local_rank, world_size) enables consistent behavior across transforms. Defaults are sensible.


200-207: Propagating shared_config through BaseTransform.call is correct

Signature and call-site updates are consistent. Error handling preserves history and logging; cleanup happens post-apply as configured.

Also applies to: 249-257

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

169-175: Test harness wiring looks good

Combined graph check and parameter expectations are reasonable for validating sharding and local weight sizes.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (1)

82-89: Validation logic looks correct

Dist-op presence tied to world_size and param count scaling check are appropriate.

@tensorrt-cicd
Collaborator

PR_Github #14833 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11203 completed with status: 'FAILURE'
