
Conversation

@yyttt6 (Contributor) commented Sep 26, 2025

Summary by CodeRabbit

  • New Features

    • Planner-driven vectorization for atomic-add loops that adapts to device capability and reports chosen vector size.
    • Improved lowering with fused parallel loops and dynamic layout inference to enable safer, more effective thread-level vectorization.
  • Refactor

    • Simplified vectorization API and decoupled planning from rewriting for clearer, more maintainable behavior.
    • Consolidated analysis and transformation into a streamlined pipeline for more reliable vectorization decisions.


coderabbitai bot commented Sep 26, 2025

Walkthrough

Refactors AtomicAdd vectorization into a planner-driven pipeline and rewrites AtomicAdd lowering to use SIMT fusion, layout inference, thread-loop partitioning, and planner-based vectorization; public VectorizeAtomicAdd signature is simplified and new planner/plan-result types are introduced.

Changes

Cohort / File(s) — Summary of Changes

AtomicAdd vectorization planner & API — src/transform/atomicadd_vectorize.*
Replaces ad-hoc vectorization with a planner-based flow. Adds AtomicAddVectorizePlanner and AtomicAddVectorizePlanResult. Updates the VectorizeAtomicAdd signature to (const For&, int). Moves vector-size inference to a PostOrderVisit over AtomicAdd calls and dtypes; the planner returns vector_size, dynamic, and condition. The rewriter is now constructed from the plan result; thread-bound/stride parameters are removed, and a LOG(INFO) reports the chosen vector size.

AtomicAdd lowering pipeline rewrite — src/op/atomic_add.cc
Reworks Lower to build/fuse SIMT loops, collect loop nests, infer loop/buffer layout, plan vectorization via AtomicAddVectorizePlanner, adjust vector width for divisibility/coalescing, and invoke VectorizeAtomicAdd. Adds AtomicAddNode::InferLayout(...) and helper visitors/utilities (e.g., AtomicLoopNestCollector, ComputeLoopLayoutFromBuffer). Replaces the previous inline For-node mutation path with a planner+rewriter pipeline and new includes.
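The planner/rewriter split summarized above can be sketched as a small Python model. This is not the actual C++ TVM pass: the loop is reduced to an integer extent, and plan/rewrite are illustrative stand-ins, but the PlanResult fields (vector_size, dynamic, condition) mirror the ones named in the summary, and the shrink-until-divisible step mirrors the "adjust vector width for divisibility" step in Lower.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class PlanResult:
    vector_size: int          # lanes chosen by the planner
    dynamic: bool             # True if legality depends on a runtime value
    condition: Optional[str]  # runtime guard expression when dynamic

def plan(loop_extent: int, max_lanes: int) -> PlanResult:
    """Planning step, decoupled from rewriting: pick the widest lane
    count that divides the loop extent, up to the device limit."""
    vec = max_lanes
    while vec > 1 and loop_extent % vec != 0:
        vec //= 2  # divisibility adjustment, as described for Lower
    return PlanResult(vector_size=vec, dynamic=False, condition=None)

def rewrite(loop_extent: int, result: PlanResult) -> Tuple[int, int]:
    """Rewriting step consumes the plan: split the loop into an outer
    trip count and a vectorized inner width."""
    return (loop_extent // result.vector_size, result.vector_size)
```

Keeping plan and rewrite as separate steps is what lets the rewriter be "constructed from the plan result" rather than re-deriving the vector size itself.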

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Caller
    participant Lower as AtomicAddNode::Lower
    participant SIMT as SIMT Builder/Fuser
    participant Collector as AtomicLoopNestCollector
    participant Layout as InferLayout
    participant Planner as AtomicAddVectorizePlanner
    participant Vectorize as VectorizeAtomicAdd
    participant Rewriter
    participant IR as Resulting IR

    Caller->>Lower: Lower(args, analyzer)
    Lower->>SIMT: Build + fuse SIMT loops
    SIMT-->>Lower: Fused loop
    Lower->>Collector: Collect loop nest & indices
    Collector-->>Lower: Loop metadata
    Lower->>Layout: ComputeLoopLayoutFromBuffer(...)
    Layout-->>Lower: Layout + guard predicate (optional)
    Lower->>Planner: Plan(For, compute_capability)
    Note right of Planner: Analyze AtomicAdd calls & dtypes → vector_size, dynamic, condition
    Planner-->>Vectorize: PlanResult
    Vectorize->>Rewriter: Rewrite using PlanResult
    Rewriter-->>Vectorize: Vectorized For
    Vectorize-->>Lower: Vectorized loop (maybe guarded)
    Lower-->>Caller: Lowered IR

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

Little paws plan vectors four and neat,
I hop through loops with tiny feet.
I fuse, infer layouts, pick a size,
Rewrite the hops with joyful eyes.
Atomic carrots, aligned and bright—🥕🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 23.53%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped: CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed — The title indicates a fix for the atomicadd auto-vectorize variable identification error, which aligns with a key bug addressed by the changeset, but it does not reflect the broader refactoring into a planner-based pipeline. It remains clear about the bug being fixed yet omits the significant restructuring of the vectorization logic.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment



👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀


Summary of Changes

Hello @yyttt6, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a bug in the atomicadd auto-vectorization pass, specifically addressing an error in how variables are identified within expressions. The core change involves refactoring the variable parsing logic to be more robust and accurate, ensuring correct vectorization. As a result of this fix, a previously failing and experimental test case related to tile_atomic_add has been re-enabled, confirming the resolution of the underlying issue.

Highlights

  • Refactored Variable Identification: The logic for identifying variables and their strides within AtomicAdd expressions has been significantly improved for robustness and accuracy during auto-vectorization.
  • Introduced ParseIndex Helper Function: A new lambda function, ParseIndex, was added to centralize and simplify the extraction of variables and their associated strides from complex expressions, making the parsing more reliable.
  • Re-enabled Test Case: A previously commented-out and failing test for tile_atomic_add functionality (test_tile_atomic_add) has been re-enabled, indicating the successful resolution of the underlying bug.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature / Command / Description
  • Code Review — /gemini review: Performs a code review for the current pull request in its current state.
  • Pull Request Summary — /gemini summary: Provides a summary of the current pull request in its current state.
  • Comment — @gemini-code-assist: Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help — /gemini help: Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist bot left a comment


Code Review

This pull request fixes a bug in identifying variables for atomicadd auto-vectorization by introducing a more robust ParseIndex function. The changes are a definite improvement over the previous, more brittle implementation. I've identified a potential issue in how multiple AtomicAdd calls within a loop are handled, which could lead to incorrect behavior. My review includes a suggestion to make this logic more robust. Additionally, it's good to see that a previously failing test case has been re-enabled as part of this fix.


@coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (5)
testing/python/language/test_tilelang_language_atomic_add.py (1)

375-376: Enable test: good; consider making it hardware-agnostic (float16) to avoid cc>=90 dependency

AtomicAddx4 for float32 is only selected when compute capability >= 90. On CI GPUs < 90 (e.g., A100 cc=80), this path may not vectorize and could cause flakiness for the tile-atomic path. Two options:

  • Portable: call with float16 so vectorization is available broadly.
  • Alternatively, gate/skip on device capability.

Apply this minimal change for portability:

-def test_tile_atomic_add():
-    run_tile_atomic_add(8, 128, 128, 32, 32)
+def test_tile_atomic_add():
+    run_tile_atomic_add(8, 128, 128, 32, 32, dtype="float16")

Also, consider removing or gating the debug prints in run_tile_atomic_add to keep test output clean (prints at Lines 58, 72, 73).
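The "gate/skip on device capability" option above can be sketched with stdlib unittest. This is a hedged sketch: get_compute_capability is a hypothetical helper (a real test would query the device, e.g. via torch.cuda.get_device_capability), and run_tile_atomic_add is assumed from the test file's context; the skip body never executes on devices below cc 90.

```python
import unittest

def get_compute_capability() -> int:
    # Hypothetical helper: a real implementation would query the device.
    # Here we model an A100 (cc = 80) to show the gate firing.
    return 80

class TestTileAtomicAdd(unittest.TestCase):

    @unittest.skipIf(get_compute_capability() < 90,
                     "float32 AtomicAddx4 path requires compute capability >= 90")
    def test_tile_atomic_add_float32(self):
        # run_tile_atomic_add is the helper from the test file; this body
        # is skipped (and never runs) on devices below cc 90.
        run_tile_atomic_add(8, 128, 128, 32, 32)  # noqa: F821 (hypothetical)
```

The portable alternative (calling with dtype="float16") avoids the gate entirely, at the cost of not exercising the float32 x4 path.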

src/transform/atomicadd_vectorize.cc (4)

322-347: ParseIndex is too strict; accept const-expr strides and avoid false negatives

Requiring exactly one MulNode with a Var and an IntImm will miss common canonical forms:

  • Stride may be a foldable const expr or come via casts (not a bare IntImm).
  • Extra harmless multiplies like x*1 can appear pre-simplification.
  • You only need a unique var*const match; other non-relevant muls shouldn’t invalidate the parse.

Refine by simplifying first, using as_const_int, and relaxing the check to “exactly one legal var*const mul” regardless of other muls:

-  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
-                       int &stride_out) -> bool {
+  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
+                       int &stride_out) -> bool {
     int mul_count = 0, legal_mul_count = 0;
     stride_out = -1;
     var_out = PrimExpr();
-    PostOrderVisit(idx, [&](const ObjectRef &obj) {
+    // Simplify to eliminate x*1 and fold-able constants.
+    arith::Analyzer az;
+    PrimExpr sidx = az.Simplify(idx);
+    PostOrderVisit(sidx, [&](const ObjectRef &obj) {
       if (const MulNode *mul = obj.as<MulNode>()) {
         mul_count++;
-        const VarNode *var = nullptr;
-        const IntImmNode *imm = nullptr;
-        if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
-          var_out = mul->a;
-          stride_out = imm->value;
-          legal_mul_count++;
-        } else if ((var = mul->b.as<VarNode>()) &&
-                   (imm = mul->a.as<IntImmNode>())) {
-          var_out = mul->b;
-          stride_out = imm->value;
-          legal_mul_count++;
-        }
+        const VarNode *var = nullptr;
+        const int64_t *c = nullptr;
+        if ((var = mul->a.as<VarNode>()) && (c = as_const_int(mul->b))) {
+          var_out = mul->a;
+          stride_out = static_cast<int>(*c);
+          legal_mul_count++;
+        } else if ((var = mul->b.as<VarNode>()) && (c = as_const_int(mul->a))) {
+          var_out = mul->b;
+          stride_out = static_cast<int>(*c);
+          legal_mul_count++;
+        }
       }
     });
-    if (mul_count == 1 && legal_mul_count == 1)
-      return true;
-    return false;
+    return legal_mul_count == 1;
   };

Note: this uses as_const_int and Analyzer::Simplify from tvm/arith/analyzer.h, which is already included in this file.
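The relaxed rule being proposed ("exactly one legal var*const multiply, regardless of other multiplies") can be modeled outside TVM. Below is a Python sketch over a toy expression tree; Var, Const, Mul, and Add are stand-ins for TVM's VarNode, IntImmNode, MulNode, and AddNode, and the recursive walk plays the role of PostOrderVisit.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Toy expression nodes standing in for TVM's TIR node types.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Const:
    value: int

@dataclass(frozen=True)
class Mul:
    a: object
    b: object

@dataclass(frozen=True)
class Add:
    a: object
    b: object

def parse_index(expr) -> Optional[Tuple[str, int]]:
    """Return (var, stride) iff exactly one var * constant product occurs
    anywhere in the expression; otherwise None (the relaxed ParseIndex)."""
    matches = []

    def visit(e):
        if isinstance(e, Mul):
            # Accept the product in either operand order.
            for x, y in ((e.a, e.b), (e.b, e.a)):
                if isinstance(x, Var) and isinstance(y, Const):
                    matches.append((x.name, y.value))
                    break
        if isinstance(e, (Mul, Add)):
            visit(e.a)
            visit(e.b)

    visit(expr)
    # Only a unique match is unambiguous; zero or several means bail out.
    return matches[0] if len(matches) == 1 else None
```

For example, bx*64 + tx parses to ("bx", 64), while bx*64 + by*8 is rejected as ambiguous, matching the "unique var*const" requirement.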


362-368: Accumulate vectorize_size_max across multiple AtomicAdd sites

If the loop body contains multiple AtomicAdd calls, you currently overwrite vectorize_size_max. Prefer taking the max to avoid under-vectorizing later calls.

-          DataType dtype = bufload->dtype;
-          vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+          DataType dtype = bufload->dtype;
+          vectorize_size_max = std::max(
+              vectorize_size_max, GetVectorizeSizeMax(compute_capability, dtype));

You’ll need:

  • Add at top: #include <algorithm> (for std::max).
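The accumulate-rather-than-overwrite point generalizes: when the visitor sees several AtomicAdd sites, each site's limit should be folded in with max. A minimal Python model of the suggestion (get_vectorize_size_max is an illustrative stand-in for the C++ GetVectorizeSizeMax; the float32 threshold matches the cc >= 90 behavior discussed in this thread, the other widths are placeholders):

```python
def get_vectorize_size_max(compute_capability: int, dtype: str) -> int:
    # Illustrative limits only; real values come from GetVectorizeSizeMax.
    if dtype == "float32":
        return 4 if compute_capability >= 90 else 1
    if dtype == "float16":
        return 2
    return 1

def plan_sites(dtypes, compute_capability: int) -> int:
    vectorize_size_max = 1
    for dtype in dtypes:  # one entry per AtomicAdd site in the loop body
        # Fold with max instead of overwriting, so the last site visited
        # cannot silently shrink an earlier, wider limit.
        vectorize_size_max = max(
            vectorize_size_max,
            get_vectorize_size_max(compute_capability, dtype))
    return vectorize_size_max
```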

362-368: Guard against mis-identifying non-block vars as bx/by

ParseIndex will happily match any var*const product (e.g., a loop variable such as i1). Before accepting, assert that the extracted vars are actual block indices (thread/block bindings) for safety; otherwise bail out. For example:

  • Verify var_out.as<VarNode>() is bound in thread_binding as blockIdx.{x,y} (or matches the expected bx/by symbols in this pass’ context).
  • If that metadata isn’t available here, at least ensure the two extracted vars differ and neither is the loop var of inner_for_.

This avoids rewriting with incorrect axes on more complex index expressions.


380-385: Extra sanity checks before rewriting

Before constructing the rewriter:

  • Ensure bx_var and by_var are Vars: if (!bx_var.as<VarNode>() || !by_var.as<VarNode>()) return for_node;
  • Optionally ensure stride_x > 0 && stride_y > 0.

This prevents emitting malformed truncdiv/truncmod expressions.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ec24561 and 5a02ea8.

📒 Files selected for processing (2)
  • src/transform/atomicadd_vectorize.cc (2 hunks)
  • testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: format-check
  • GitHub Check: bot-task


@coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic_add.py (1)

375-377: Remove noisy debug prints before re-enabling this test

Reactivating test_tile_atomic_add now exercises run_tile_atomic_add, which still contains print(kernel.get_kernel_source()) plus dumps of both 128×128 tensors. That’s hundreds of thousands of characters on every run and will swamp CI logs without adding assertion value. Please drop or gate those prints before merging.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5a02ea8 and 1255ccb.

📒 Files selected for processing (1)
  • testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task

For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
                       const Range &thread_bounds, int compute_capability) {

  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,

I think we should add some comments for this function.


@coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/atomic_add.cc (1)

35-46: Code duplication: GetArchInt is duplicated from src/target/utils.cc.

The GetArchInt function is duplicated from src/target/utils.cc (lines 17-25) with a different implementation. The local version includes a fallback to return 0 for non-sm_ architectures, while the version in src/target/utils.cc uses ICHECK to enforce the sm_ prefix.

Consider one of the following approaches:

  1. Preferred: Import and use the existing GetArchInt from src/target/utils.cc if the stricter validation is acceptable, or
  2. Update the version in src/target/utils.cc to include the fallback behavior and use it consistently across the codebase.

Apply this diff to use the existing function:

-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}

And update the include at the top of the file if not already present:

 #include "../target/utils.h"
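For reference, the two GetArchInt behaviors being compared are just different strictness levels of parsing the target's arch string. A hedged Python sketch of both variants (the real code is C++ and uses ICHECK/std::stoi):

```python
def get_arch_int_lenient(arch: str) -> int:
    """Models the local variant in atomic_add.cc: non-sm_ archs return 0."""
    if arch.startswith("sm_"):
        return int(arch[3:])
    return 0

def get_arch_int_strict(arch: str) -> int:
    """Models src/target/utils.cc: the sm_ prefix is enforced (ICHECK)."""
    if not arch.startswith("sm_"):
        raise AssertionError(f"expected sm_ prefix, got {arch!r}")
    return int(arch[3:])
```

The choice between them is exactly the review's point: the strict variant fails fast on non-CUDA targets, while the lenient one silently degrades to 0.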
🧹 Nitpick comments (3)
src/op/atomic_add.cc (1)

372-372: Consider removing or adjusting the log level.

The LOG(INFO) statement prints the vectorized loop IR to the console. This may be useful during development but could be noisy in production.

Consider one of the following:

  1. Remove the log statement if it was added for debugging purposes only.
  2. Change to VLOG(1) or a higher verbosity level to reduce noise in production logs.
  3. If this is intentional diagnostic output, add a comment explaining why it's logged at INFO level.

Apply this diff to change to verbose logging:

-  LOG(INFO) << vectorized_thread_loop;
+  VLOG(1) << "Vectorized thread loop: " << vectorized_thread_loop;
src/transform/atomicadd_vectorize.cc (2)

33-55: Consider adding documentation for BufferIndiceSimplify.

The BufferIndiceSimplify class lacks documentation. Adding a brief comment explaining its purpose would improve maintainability.

Apply this diff to add documentation:

+/// \brief Simplifies buffer load and store indices using an analyzer.
+///
+/// This mutator visits BufferLoad and BufferStore nodes and simplifies
+/// their indices by applying the analyzer's Simplify method to each index.
 class BufferIndiceSimplify : public StmtExprMutator {

174-231: Consider adding documentation for the run() method.

The run() method implements complex loop transformation logic but lacks documentation explaining the transformation steps and the role of loop_layout and analyzer.

Apply this diff to add documentation:

+  /// \brief Transform and vectorize the for loop using the provided layout.
+  ///
+  /// \param for_node The original For loop to transform
+  /// \param loop_layout Fragment describing the loop layout transformation
+  /// \param analyzer Analyzer for simplifying indices and binding loop variables
+  /// \return Transformed and vectorized For loop
   For run(For for_node, const Fragment &loop_layout,
           arith::Analyzer *analyzer) {
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 202add4 and 6ceb5e0.

📒 Files selected for processing (3)
  • src/op/atomic_add.cc (1 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
  • src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/atomic_add.cc (2)
src/transform/atomicadd_vectorize.cc (2)
  • VectorizeAtomicAdd (308-343)
  • VectorizeAtomicAdd (308-310)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (1)
src/transform/loop_partition.cc (5)
  • BufferIndiceSimplify (38-38)
  • LoopPragmaUnroll (201-205)
  • LoopPragmaUnroll (201-201)
  • PartitionLoop (61-105)
  • PartitionLoop (61-62)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: bot-task
  • GitHub Check: format-check
  • GitHub Check: format-check
🔇 Additional comments (5)
src/op/atomic_add.cc (1)

364-366: Call site correctly updated.

The call to VectorizeAtomicAdd now includes the required analyzer and loop_layout parameters, matching the updated API signature.

src/transform/atomicadd_vectorize.cc (3)

33-55: BufferIndiceSimplify implementation looks correct.

The mutator correctly simplifies indices for both BufferLoad and BufferStore nodes using the provided analyzer. The implementation properly uses CopyOnWrite() and maps indices through the analyzer's Simplify method.


308-343: VectorizeAtomicAdd implementation looks correct.

The function properly:

  1. Determines the maximum vectorization size based on compute capability and data type.
  2. Uses PartitionLoop to simplify the loop before planning.
  3. Plans the vectorization with the planner.
  4. Returns the original loop if vectorization is not beneficial (vectorize_hint == 1).
  5. Applies the vectorization rewrite with the provided analyzer and loop layout.

174-231: Verify vectorization across all loop levels and add tests. The run() method divides every transformed loop extent by vector_size_, not just the innermost; no existing tests cover this behavior—please confirm this is intentional and add multi-level loop vectorization test cases.

src/transform/atomicadd_vectorize.h (1)

17-19: Approve VectorizeAtomicAdd signature update
All call sites, including src/op/atomic_add.cc, have been updated to match the new parameters.


yyttt6 commented Oct 3, 2025

I refactored part of the AtomicAdd auto vectorization code.
This change fixes the previous error of incorrect variable matching, and makes the AtomicAdd auto vectorization more robust and adaptable to a wider range of cases.


@coderabbitai bot left a comment


Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4526725 and 2dab926.

📒 Files selected for processing (4)
  • src/op/atomic_add.cc (1 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
  • src/transform/atomicadd_vectorize.h (1 hunks)
  • tilelang/language/atomic.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/transform/atomicadd_vectorize.h
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/atomicadd_vectorize.cc (1)
src/transform/loop_partition.cc (5)
  • BufferIndiceSimplify (38-38)
  • LoopPragmaUnroll (201-205)
  • LoopPragmaUnroll (201-201)
  • PartitionLoop (61-105)
  • PartitionLoop (61-62)
src/op/atomic_add.cc (1)
src/transform/atomicadd_vectorize.cc (2)
  • VectorizeAtomicAdd (308-343)
  • VectorizeAtomicAdd (308-310)
tilelang/language/atomic.py (1)
tilelang/language/customize.py (1)
  • buffer_load_to_tile_region (46-68)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check

Comment on lines +249 to +253
const BufferLoad dst_node =
Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
const BufferLoad value_node =
Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());


⚠️ Potential issue | 🔴 Critical

Downcast expects an ObjectRef, not a raw node pointer

Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) doesn’t compile—the overload requires an ObjectRef, but as<...>() returns const BufferLoadNode*. Use the original PrimExpr instead (we already checked it’s a BufferLoad), e.g.:

-            const BufferLoad dst_node =
-                Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-            const BufferLoad value_node =
-                Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+            const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+            const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);

Without this change the file fails to build.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-            const BufferLoad dst_node =
-                Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-            const BufferLoad value_node =
-                Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+            const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+            const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 249 to 253, the code calls
Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) and similarly for
args[2], but Downcast expects an ObjectRef not a raw node pointer; replace the
.as<BufferLoadNode>() calls and pass the original PrimExprs (node->args[1] and
node->args[2]) directly to Downcast<BufferLoad>(), relying on the existing type
checks that confirmed these are BufferLoad instances so the Downcast will be
valid.

Comment on lines 193 to 211
    src_extent = list(get_extent(value))
    dst_extent = list(get_extent(dst))
    legal = True

    if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
        legal = False
    elif (dst_extent and src_extent):
        if len(dst_extent) > len(src_extent):
            dst_extent_dims = [x for x in dst_extent if x != 1]
            if dst_extent_dims != src_extent:
                legal = False
        else:
            if dst_extent != src_extent:
                legal = False
    else:
        src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
        dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
        extent = max(dst_extent, src_extent)
        dst_extent = src_extent = extent

⚠️ Potential issue | 🔴 Critical

Guard get_extent() results before wrapping in list(...)

get_extent() still returns None for scalar PrimExpr inputs (e.g., atomic_add(dst, 1)), so list(get_extent(...)) raises a TypeError before we can fall back to the extern path. This regresses the scalar code path.

Please keep the raw result, check for None, and only convert to list when defined before the length/shape logic.

-    src_extent = list(get_extent(value))
-    dst_extent = list(get_extent(dst))
+    src_extent_raw = get_extent(value)
+    dst_extent_raw = get_extent(dst)
+    src_extent = list(src_extent_raw) if src_extent_raw is not None else None
+    dst_extent = list(dst_extent_raw) if dst_extent_raw is not None else None
     legal = True
 
-    if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
+    if dst_extent is None and src_extent is None:
+        legal = False
+    elif dst_extent is None:
+        dst_extent = [1] * len(src_extent)
+    elif src_extent is None:
+        src_extent = [1] * len(dst_extent)
+    elif len(dst_extent) < len(src_extent):
         legal = False

Make sure the remaining branches avoid len(None) as well.

📝 Committable suggestion


Suggested change

-    src_extent = list(get_extent(value))
-    dst_extent = list(get_extent(dst))
+    src_extent_raw = get_extent(value)
+    dst_extent_raw = get_extent(dst)
+    src_extent = list(src_extent_raw) if src_extent_raw is not None else None
+    dst_extent = list(dst_extent_raw) if dst_extent_raw is not None else None
     legal = True
 
-    if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
+    if dst_extent is None and src_extent is None:
+        legal = False
+    elif dst_extent is None:
+        dst_extent = [1] * len(src_extent)
+    elif src_extent is None:
+        src_extent = [1] * len(dst_extent)
+    elif len(dst_extent) < len(src_extent):
         legal = False
-    elif (dst_extent and src_extent):
+    elif dst_extent and src_extent:
         if len(dst_extent) > len(src_extent):
             dst_extent_dims = [x for x in dst_extent if x != 1]
             if dst_extent_dims != src_extent:
                 legal = False
         else:
             if dst_extent != src_extent:
                 legal = False
     else:
         src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
         dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
         extent = max(dst_extent, src_extent)
         dst_extent = src_extent = extent
🤖 Prompt for AI Agents
In tilelang/language/atomic.py around lines 193 to 211, guard the raw results of
get_extent() before calling list(...) to avoid TypeError on scalar PrimExprs:
assign src_extent_raw = get_extent(value) and dst_extent_raw = get_extent(dst),
check if each is None before converting to list; only call list(...) when the
raw extent is not None, and ensure every branch uses None-checks rather than
len(None) (e.g., when one side is None treat it as a scalar extent
[1]*len(other) or follow the extern-path fallback), then perform the same
dimension comparisons and the elementwise max logic operating on actual lists so
no branch ever calls len(...) or iterates over a None.
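The None-guarding pattern the bot asks for can be isolated into a small helper. A sketch under the semantics the review describes for get_extent (it returns None for scalar PrimExpr inputs); the helper name and the (src, dst, legal) return shape are illustrative, not the actual tilelang API:

```python
def normalize_extents(src_extent_raw, dst_extent_raw):
    """Guard raw extents before list() so scalar (None) inputs fall back
    to a rank-matched [1, ...] extent; return (src, dst, legal)."""
    src = list(src_extent_raw) if src_extent_raw is not None else None
    dst = list(dst_extent_raw) if dst_extent_raw is not None else None
    if src is None and dst is None:
        return None, None, False       # both scalar: take the extern path
    if dst is None:
        dst = [1] * len(src)           # treat scalar dst as broadcastable
    elif src is None:
        src = [1] * len(dst)           # treat scalar src as broadcastable
    if len(dst) < len(src):
        return src, dst, False         # dst cannot have lower rank
    return src, dst, True
```

Every branch checks for None before calling len or list, so atomic_add(dst, 1)-style scalar inputs no longer raise a TypeError on the way to the extern fallback.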


@coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)

52-205: Fix AtomicAdd rewrites: size guard and correct Downcast usage.

Two problems here break compilation and correctness:

  1. We only check args.size() >= 2, yet we always read args[2]; AtomicAdd calls must therefore guard for >= 3 before touching the value operand.
  2. Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) is invalid—the overload expects an ObjectRef, not a raw node pointer. This reintroduces the build failure from the previous review.

Please tighten the guards and pass the original PrimExpr to Downcast:

-  if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+  if (node->op == builtin::call_extern() && node->args.size() >= 3) {
     if (const auto *func_name = node->args[0].as<StringImmNode>()) {
       if (func_name->value == "AtomicAdd") {
         const BufferLoadNode *temp_dst_node =
             node->args[1].as<BufferLoadNode>();
         const BufferLoadNode *temp_value_node =
             node->args[2].as<BufferLoadNode>();
         if (!temp_dst_node || !temp_value_node) {
           return StmtExprMutator::VisitExpr_(node);
         }
-        const BufferLoad dst_node =
-            Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-        const BufferLoad value_node =
-            Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+        const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+        const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between badf9c1 and c317f26.

📒 Files selected for processing (3)
  • src/op/atomic_add.cc (2 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
  • src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/atomicadd_vectorize.h (2)
src/transform/loop_vectorize.h (2)
  • tvm (31-49)
  • tl (32-48)
src/transform/atomicadd_vectorize.cc (17)
  • VectorizeAtomicAdd (236-246)
  • VectorizeAtomicAdd (236-236)
  • AtomicAddVectorizePlanner (15-15)
  • Plan (17-44)
  • Plan (18-18)
  • node (162-184)
  • node (162-162)
  • node (186-228)
  • node (186-186)
  • VisitStmt_ (46-49)
  • VisitStmt_ (46-46)
  • VisitExpr_ (51-71)
  • VisitExpr_ (51-51)
  • GetVectorizeSizeMax (73-85)
  • GetVectorizeSizeMax (73-74)
  • UpdateVectorSize (87-127)
  • UpdateVectorSize (87-88)
src/transform/atomicadd_vectorize.cc (2)
src/transform/atomicadd_vectorize.h (1)
  • AtomicAddVectorizePlanner (36-55)
src/transform/loop_vectorize.cc (4)
  • indices (157-189)
  • indices (157-157)
  • IndiceCanVectorize (257-298)
  • IndiceCanVectorize (257-259)
src/op/atomic_add.cc (6)
src/op/parallel.cc (8)
  • Lower (184-187)
  • Lower (184-185)
  • VisitStmt_ (130-146)
  • VisitStmt_ (130-130)
  • VisitStmt_ (148-160)
  • VisitStmt_ (148-148)
  • VisitExpr_ (162-173)
  • VisitExpr_ (162-162)
src/op/copy.cc (6)
  • Lower (791-823)
  • Lower (791-791)
  • Lower (1776-1898)
  • Lower (1776-1777)
  • MakeSIMTLoop (299-344)
  • MakeSIMTLoop (299-299)
src/op/fill.cc (4)
  • Lower (171-206)
  • Lower (171-171)
  • MakeSIMTLoop (136-151)
  • MakeSIMTLoop (136-136)
src/op/reduce.cc (4)
  • Lower (152-318)
  • Lower (152-152)
  • Lower (413-437)
  • Lower (413-413)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (6)
  • VisitStmt_ (46-49)
  • VisitStmt_ (46-46)
  • VisitExpr_ (51-71)
  • VisitExpr_ (51-51)
  • VectorizeAtomicAdd (236-246)
  • VectorizeAtomicAdd (236-236)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check

Comment on lines +27 to +33
const auto *func_name = call->args[0].as<StringImmNode>();
if (func_name && func_name->value == "AtomicAdd") {
  DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
  vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
}
}
}

⚠️ Potential issue | 🔴 Critical

Guard the dtype probe before dereferencing.

call->args[1] is not guaranteed to be a BufferLoad (it can be wrapped in an IfThenElse by the lowering pipeline), so as<BufferLoadNode>() may return nullptr. Dereferencing it here will crash during planning. Please bail out unless the cast succeeds, e.g.:

-          DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
-          vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+          if (const auto *dst = call->args[1].as<BufferLoadNode>()) {
+            vectorize_size_max =
+                GetVectorizeSizeMax(compute_capability, dst->dtype);
+          }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
const auto *func_name = call->args[0].as<StringImmNode>();
if (func_name && func_name->value == "AtomicAdd") {
  DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
  vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
}
}
}

const auto *func_name = call->args[0].as<StringImmNode>();
if (func_name && func_name->value == "AtomicAdd") {
  if (const auto *dst = call->args[1].as<BufferLoadNode>()) {
    vectorize_size_max =
        GetVectorizeSizeMax(compute_capability, dst->dtype);
  }
}
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 27 to 33, the code
dereferences call->args[1].as<BufferLoadNode>() without checking it may be
nullptr (it can be an IfThenElse or other node), which can crash; store the
result of as<BufferLoadNode>() in a local pointer, check it is non-null before
accessing dtype, and if the cast fails simply skip setting vectorize_size_max
(bail out of this branch) so you don't dereference a null pointer.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c317f26 and 12e437c.

📒 Files selected for processing (2)
  • src/op/atomic_add.cc (3 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/atomic_add.cc (5)
src/op/parallel.cc (8)
  • InferLayout (212-527)
  • InferLayout (212-213)
  • VisitStmt_ (130-146)
  • VisitStmt_ (130-130)
  • VisitStmt_ (148-160)
  • VisitStmt_ (148-148)
  • VisitExpr_ (162-173)
  • VisitExpr_ (162-162)
src/op/copy.cc (6)
  • InferLayout (399-498)
  • InferLayout (399-400)
  • InferLayout (1955-1958)
  • InferLayout (1955-1956)
  • MakeSIMTLoop (299-344)
  • MakeSIMTLoop (299-299)
src/op/parallel.h (1)
  • ParallelOp (151-158)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (6)
  • VisitStmt_ (45-48)
  • VisitStmt_ (45-45)
  • VisitExpr_ (50-70)
  • VisitExpr_ (50-50)
  • VectorizeAtomicAdd (235-244)
  • VectorizeAtomicAdd (235-235)
src/transform/atomicadd_vectorize.cc (2)
src/transform/atomicadd_vectorize.h (1)
  • AtomicAddVectorizePlanner (36-55)
src/transform/loop_vectorize.cc (4)
  • indices (157-189)
  • indices (157-157)
  • IndiceCanVectorize (257-298)
  • IndiceCanVectorize (257-259)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-metal

Comment on lines +401 to +409
auto GetArchInt = [&](const Target &tgt) -> int {
  int arch_int = 0;
  if (auto s = tgt->GetAttr<String>("arch")) {
    std::string arch = s.value();
    if (arch.rfind("sm_", 0) == 0)
      arch_int = std::stoi(arch.substr(3));
  }
  return arch_int;
};

🛠️ Refactor suggestion | 🟠 Major

Drop the shadowing GetArchInt lambda.

We already have a file-scope GetArchInt(Target) (Lines 37-48). Redefining an identical lambda here is redundant, risks divergence, and silently bypasses the shared helper (e.g. any future validation fixes). Call the existing function instead of shadowing it locally.

🤖 Prompt for AI Agents
In src/op/atomic_add.cc around lines 401 to 409, there is a locally defined
GetArchInt lambda that shadows an existing file-scope GetArchInt(Target) (lines
37-48); remove this redundant lambda and replace any uses in this scope with a
direct call to the file-scope GetArchInt(tgt) helper so the shared
implementation (and any future validations) are preserved; ensure the lambda
definition is deleted and all call sites use GetArchInt(tgt) without adding new
duplicates.

Comment on lines +517 to +533
if (plan.dynamic && plan.condition.defined()) {
pred = plan.condition;
}
DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec
<< " loop_layout=" << loop_layout->DebugOutput();
return {loop_layout, pred};
};

auto ret = AtomicAddInferLayout(transformed_loop,
{T.target, T.thread_bounds, T.layout_map,
analyzer, false, T.buffer_remap});
Fragment loop_layout = ret.loop_layout;
auto thread_loop =
PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
auto vectorized_thread_loop =
VectorizeAtomicAdd(thread_loop, GetArchInt(target));
return vectorized_thread_loop;

⚠️ Potential issue | 🟠 Major

Propagate planner predicate for dynamic vectorization.

AtomicAddVectorizePlanner::Plan can return dynamic=true with condition set. We drop that information: VectorizeAtomicAdd may leave the loop unvectorized, but when dynamic is true the planner expects us to wrap the vectorized body with the guard before executing. Please reintroduce the predicate (e.g. wrap vectorized_thread_loop with IfThenElse(pred, body) when present) so dynamic plans stay correct.

🤖 Prompt for AI Agents
In src/op/atomic_add.cc around lines 517 to 533, the planner predicate produced
for dynamic vectorization is being dropped after AtomicAddInferLayout; capture
the returned predicate from ret and, if it is defined, wrap the final
vectorized_thread_loop in an If/IfThenElse guard that uses that predicate so the
dynamic plan executes the guarded vectorized body only when the condition holds;
specifically, obtain pred from the planner return, and replace the direct return
of vectorized_thread_loop with a conditional wrapper (IfThenElse(pred,
vectorized_thread_loop)) when pred is present, otherwise return the loop
unchanged.

Comment on lines +29 to +30
DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);

⚠️ Potential issue | 🔴 Critical

Guard the BufferLoad cast before dereferencing.

call->args[1].as<BufferLoadNode>() can return nullptr (planner already saw AtomicAdd nodes wrapped by predicates elsewhere). Dereferencing without a check crashes the planner on valid inputs. Bail out unless the cast succeeds before touching dtype.

🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 29-30, the code
dereferences call->args[1].as<BufferLoadNode>() without checking for null which
can crash; modify the code to first store the result of
call->args[1].as<BufferLoadNode>() into a local pointer, check if it is nullptr,
and if so bail out (e.g., return false / skip vectorization) before accessing
dtype; only call GetVectorizeSizeMax when the cast succeeded.

Comment on lines 194 to +203
            node->args[1].as<BufferLoadNode>();
-       const BufferLoadNode *old_value_node =
+       const BufferLoadNode *temp_value_node =
            node->args[2].as<BufferLoadNode>();
-       if (!old_dst_node || !old_value_node) {
+       if (!temp_dst_node || !temp_value_node) {
          return StmtExprMutator::VisitExpr_(node);
        }
        Array<PrimExpr> dst_indices, value_indices;
        if ((extent_tx_ * vector_size_) > stride_x_) {
          dst_indices.push_back(
              by_var_ * stride_y_ +
              iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
              truncdiv(tx_var_, stride_x_ / vector_size_));
          dst_indices.push_back(
              bx_var_ * stride_x_ +
              truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
          value_indices.push_back(
              iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
              truncdiv(tx_var_ * vector_size_, stride_x_));
          value_indices.push_back(
              truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
        } else {
          dst_indices.push_back(
              by_var_ * stride_y_ +
              truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
              truncdiv(tx_var_, stride_x_ / vector_size_));
          dst_indices.push_back(
              bx_var_ * stride_x_ +
              truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
                  (extent_tx_ * vector_size_) +
              truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
          value_indices.push_back(
              truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
              truncdiv(tx_var_, stride_x_ / vector_size_));
          value_indices.push_back(
              truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
                  (extent_tx_ * vector_size_) +
              truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
        }
        const BufferLoad dst_node =
            Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
        const BufferLoad value_node =
            Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());

⚠️ Potential issue | 🔴 Critical

Pass the original PrimExprs to Downcast.

Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) calls the overload with a raw pointer, which doesn’t compile. We already confirmed the operands are BufferLoad, so pass node->args[1] / node->args[2] directly to Downcast. This matches the required ObjectRef signature and fixes the build.

🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 194 to 203, the code calls
Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) and
Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>()), which uses the
raw-pointer overload and fails to compile; replace these with
Downcast<BufferLoad>(node->args[1]) and Downcast<BufferLoad>(node->args[2])
respectively so the ObjectRef (PrimExpr) overload is used (keep the preceding
as<BufferLoadNode>() checks to validate the type).

