
Conversation


@yyttt6 yyttt6 commented Aug 27, 2025

Summary by CodeRabbit

  • New Features
    • Enhanced vectorized atomic-add path that leverages thread/block context and stride info to enable wider, address-based atomic operations for better performance where applicable.
  • Refactor
    • Reworked vectorization to be more context-aware of thread bounds and block/stripe strides, with a safe fallback when vectorization isn’t applicable.
  • Tests
    • Existing runtime validations preserved to ensure correctness for dynamic and edge cases.
  • Chores
    • Minor header comment typo corrected.


coderabbitai bot commented Aug 27, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Routes AtomicAdd lowering through a thread-aware VectorizeAtomicAdd and rewrites the AtomicAdd vectorization pass to accept thread/block context and strides, compute TX extent, replace outer-var with a new iter_var, emit address-based AtomicAddx2/x4 calls, and preserve fallback/dynamic checks and predicates.

Changes

Cohort / File(s) | Summary

AtomicAdd Lowering Entry (src/op/atomic_add.cc)
  Replaces prior VectorizeLoop call with VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds, GetArchInt(target)); removes related TODO/comments; retains predicate wrapping and fallback return.

AtomicAdd Vectorization Rewriter (src/transform/atomicadd_vectorize.cc)
  Expands AtomicAddVectorizeRewriter constructor signature to accept thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x; computes extent_tx_; introduces iter_var_ and rewrites inner For to iterate extent/vector_size; removes runtime tx search and uses provided tx_var_; detects bx/by multipliers and strides via AST traversal; builds dst/value indices using by/bx, iter_var_, tx_var_, and strides; emits address_of(BufferLoad(...)) and calls AtomicAddx2/AtomicAddx4; preserves dynamic/non-vectorized path and updates a truncmod condition.
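
For orientation, a minimal sketch of the new lowering entry point described above (the call itself is quoted from the summary; the surrounding declaration and return type are assumptions, not taken from the diff):

    // Inside AtomicAdd::Lower (sketch): route the per-thread loop through the
    // thread-aware vectorizer instead of the generic VectorizeLoop pass.
    For vectorized_thread_loop = VectorizeAtomicAdd(
        thread_loop,          // innermost loop performing the atomic adds
        thread_var,           // the tx thread variable
        thread_bounds,        // Range whose extent gives the TX extent
        GetArchInt(target));  // compute capability, e.g. 90 for "sm_90"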

Sequence Diagram(s)

sequenceDiagram
  participant Lower as AtomicAdd::Lower
  participant VPass as VectorizeAtomicAdd
  participant Rewriter as AtomicAddVectorizeRewriter
  participant IR as IR Builder

  Lower->>VPass: VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds, arch_int)
  Note right of VPass #DDFFDD: Analyze loop AST to detect bx/by multipliers & strides
  VPass->>Rewriter: Init(plan, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x)
  Rewriter->>IR: Replace inner For with For(iter_var, 0, extent/vector_size)
  Rewriter->>IR: Compute dst/value indices using by/bx, iter_var, tx_var, strides
  Rewriter->>IR: Create address_of(BufferLoad(dst)) and address_of(BufferLoad(val))
  Rewriter->>IR: Emit AtomicAddx2 / AtomicAddx4 calls
  VPass-->>Lower: return vectorized_thread_loop
  alt predicate present
    Lower->>IR: Wrap vectorized loop with predicate
  end
  Lower-->>Lower: Return transformed loop

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45–75 minutes

Suggested reviewers

  • LeiWang1999

Poem

I hop through loops with nimble paws,
BY and BX guide my tiny cause.
iter_var, tx — a tidy crew,
I stitch addresses, add by two. 🥕



👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @yyttt6, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a bug in the automatic vectorization of atomic add operations, which was previously causing "negative optimization." The changes re-enable and correct the VectorizeAtomicAdd pass by providing the rewriter with necessary context about thread variables, block variables, and memory strides. This ensures that vectorized atomic operations are correctly indexed and performed, improving performance and correctness.

Highlights

  • Re-enabled AtomicAdd Vectorization: The VectorizeAtomicAdd pass is now actively used in the AtomicAdd::Lower method, replacing a generic loop vectorizer and removing a "buggy implementation" TODO comment.
  • Enhanced Rewriter Context: The AtomicAddVectorizeRewriter constructor has been updated to accept additional parameters, including thread variables, block variables (bx, by), thread bounds, and memory strides (stride_x, stride_y), providing more context for accurate vectorization.
  • Corrected Memory Indexing for Vectorization: The core logic for calculating memory indices within the AtomicAddVectorizeRewriter has been significantly revised. It now dynamically computes destination and value buffer indices based on thread extents, vector size, and strides, ensuring correct memory access patterns for vectorized atomic adds.
  • Improved Stride and Block Variable Detection: The VectorizeAtomicAdd function now includes logic to identify and extract bx, by, stride_x, and stride_y values from the loop body, which are then passed to the rewriter for precise memory address calculation.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables auto-vectorization for atomic add operations, which was previously disabled. The core of the change is in src/transform/atomicadd_vectorize.cc, where a more sophisticated index remapping logic is introduced to handle vectorized accesses correctly. While the overall direction is good, I've identified a critical issue that could lead to a crash and another high-severity bug that could cause the optimization to fail silently. Please see the detailed comments for suggestions on how to address these.


critical

This line can cause a crash. node->args[2] (the value being added) can be an if_then_else node if there's a predicate on the source buffer access. In that case, as<BufferLoadNode>() will return nullptr, and dereferencing it will cause a crash.

You should add a check before this line to ensure the cast is successful, for example:

const auto* value_load_node = node->args[2].as<BufferLoadNode>();
ICHECK(value_load_node) << "The value for AtomicAdd is expected to be a BufferLoad, but got " << node->args[2]->GetTypeKey();

And then use *value_load_node here.
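
A gentler alternative, in the same spirit as the dynamic_ short-circuit suggested later in this thread, is to skip vectorization for such values instead of hard-failing. A minimal sketch (assuming the rewriter can simply return the unrewritten call):

    // Sketch: if the value operand is not a plain BufferLoad (e.g. it is an
    // if_then_else guarding the source access), keep the scalar AtomicAdd.
    const auto *value_load_node = node->args[2].as<BufferLoadNode>();
    if (value_load_node == nullptr) {
      return StmtExprMutator::VisitExpr_(node);
    }
    const BufferLoadNode &old_value_node = *value_load_node;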


high

This logic for finding the stride and block index variable is brittle. It only checks if mul->a is a VarNode and mul->b is an IntImmNode. Due to the commutative property of multiplication, the expression could also be imm * var. The current code would fail to detect that case, causing the vectorization to be skipped silently. You should handle both cases to make the logic more robust.

    if (const MulNode *mul = obj.as<MulNode>()) {
      const VarNode* var = nullptr;
      const IntImmNode* imm = nullptr;
      PrimExpr var_expr;
      if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
        var_expr = mul->a;
      } else if ((var = mul->b.as<VarNode>()) && (imm = mul->a.as<IntImmNode>())) {
        var_expr = mul->b;
      }

      if (var && imm) {
        if (var->name_hint == "bx") {
          stride_x = imm->value;
          bx_var = var_expr;
        } else if (var->name_hint == "by") {
          stride_y = imm->value;
          by_var = var_expr;
        }
      }
    }


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/op/atomic_add.cc (1)

24-35: Remove duplicate GetArchInt; use the shared util to avoid silent mis-detections.

This static redefinition weakens validation (falls back to 0) and shadows the canonical GetArchInt in ../target/utils.h, risking inconsistent behavior. Prefer the shared implementation.

Apply this diff to remove the duplicate:

- static int GetArchInt(Target target) {
-   int arch_int = 0;
-   auto s = target->GetAttr<String>("arch");
-   ICHECK(s.defined());
-   const char *arch_str = s.value().c_str();
-   if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
-     arch_int = atoi(&arch_str[3]);
-   } else {
-     arch_int = 0;
-   }
-   return arch_int;
- }
+ // Use GetArchInt from ../target/utils.h
src/transform/atomicadd_vectorize.cc (1)

260-272: Verify and Complete CUDA Externs for AtomicAddx2/x4

  • We confirmed that both AtomicAddx2 and AtomicAddx4 are defined in src/tl_templates/cuda/common.h with the expected signature
    (TL_DEVICE void AtomicAddx[N](<T>* address, <T>* val)) and parameter order (dst pointer first, then value pointer), matching your call-site in atomicadd_vectorize.cc.
  • However, while AtomicAddx2 is specialized for half_t, bfloat16_t, and float, the AtomicAddx4 overload only exists for float:
    • Missing AtomicAddx4(half_t*, half_t*) and AtomicAddx4(bfloat16_t*, bfloat16_t*) specializations.
    • If vector_size_ == 4 for half- or bfloat16-typed data, you’ll get unresolved externs at link time.

Please add the missing AtomicAddx4 device functions for half_t and bfloat16_t in src/tl_templates/cuda/common.h (mirroring the pattern used for AtomicAddx2), or constrain vector_size_ to only emit AtomicAddx4 for types with existing CUDA support.
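
A minimal sketch of what the missing overload could look like, mirroring the AtomicAddx2 pattern described above (an assumption, not the shipped header; it presumes half_t is layout-compatible with __half, the destination is 4-byte aligned, and the target is sm_60+ where atomicAdd(__half2*, __half2) exists; a bfloat16_t version would use __nv_bfloat162 on sm_80+):

    // Hypothetical AtomicAddx4 for half_t: two packed-half atomic adds.
    TL_DEVICE void AtomicAddx4(half_t *address, half_t *val) {
      __half2 *dst = reinterpret_cast<__half2 *>(address);
      __half2 *src = reinterpret_cast<__half2 *>(val);
      atomicAdd(dst, src[0]);     // elements 0-1
      atomicAdd(dst + 1, src[1]); // elements 2-3
    }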

🧹 Nitpick comments (3)
src/op/atomic_add.cc (1)

4-4: Typo in file header.

“Elment-wise” → “element-wise”.

- * Define elment-wise operators.
+ * Define element-wise operators.
src/transform/atomicadd_vectorize.cc (2)

126-129: Confirm FloorMod availability; prefer truncmod for consistency.

The planner uses FloorMod(...) but elsewhere uses truncdiv/truncmod. If FloorMod isn’t defined in this context, the build will fail.

Would you like me to swap to truncmod for consistency?

-      condition_ = (FloorMod(offset, vector_size_) == 0);
+      condition_ = (truncmod(offset, vector_size_) == 0);

316-319: Heuristic extraction of bx/by and strides is fragile (Mul order, non-const).

Current scan only catches patterns like bx * C and by * C with MulNode.a = Var and MulNode.b = IntImm. It misses C * bx or non-const strides.

Consider normalizing commutative Mul operands and accepting both operand orders; fall back to non-vectorized path if strides aren’t IntImm.

Also applies to: 356-358

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between e05a20a and 5e0d9550096139baf220c468da891ec1a3884675.

📒 Files selected for processing (2)
  • src/op/atomic_add.cc (1 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/atomic_add.cc (2)
src/transform/atomicadd_vectorize.cc (2)
  • VectorizeAtomicAdd (313-362)
  • VectorizeAtomicAdd (313-314)
src/target/utils.cc (2)
  • GetArchInt (18-26)
  • GetArchInt (18-18)
src/transform/atomicadd_vectorize.cc (1)
tilelang/language/tir/op.py (5)
  • ret (1656-1669)
  • call_extern (172-194)
  • truncdiv (2918-2941)
  • truncmod (2944-2967)
  • address_of (463-479)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: build-test-nvidia
🔇 Additional comments (2)
src/op/atomic_add.cc (1)

200-201: Good: vectorization is now capability-aware.

Switching to

VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds, GetArchInt(target))

matches the new rewriter API.

⚠️ Please ensure that every Target reaching this pass has its "arch" attribute populated (for example, "sm_90"), otherwise GetArchInt(target) will ICHECK-fail with a zero or undefined compute capability. I did not locate any C++ code that sets a default for target->attrs["arch"], so verify its initialization in the Python binder or the C++ Target constructor.
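
One quick way to sanity-check that assumption from C++ (a sketch using standard TVM APIs; the target string is only an example):

    #include <tvm/runtime/logging.h>
    #include <tvm/target/target.h>

    // Targets built from a canonical CUDA string carry the "arch" attribute
    // that GetArchInt parses ("sm_90" -> 90); a target created without it
    // would trip the ICHECK inside GetArchInt.
    void CheckArchAttr() {
      tvm::Target target("cuda -arch=sm_90");
      auto arch = target->GetAttr<tvm::runtime::String>("arch");
      ICHECK(arch.defined()) << "CUDA target is missing the 'arch' attribute";
    }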

src/transform/atomicadd_vectorize.cc (1)

182-201: Manually confirm that only AtomicAdd operations appear in the vectorized loop body

Our search for loops combining AtomicAdd calls with any other buffer loads or stores returned no hits, but absence of evidence is not evidence of absence. Since the rewrite at src/transform/atomicadd_vectorize.cc:182–201 only does

vmap.Set(fnode->loop_var, iter_var_);

(without multiplying by vector_size_), any additional uses of the original loop variable in the loop body will now refer to iter_var_ instead of iter_var_ * vector_size_, altering semantics.

Please verify that:

  • The innermost loops you plan to vectorize contain only AtomicAdd calls (no other BufferLoad, BufferStore, or arithmetic involving the original loop_var).
  • If there are other uses of loop_var, either
    • restrict vectorization to bodies with AtomicAdd-only side effects, or
    • scale the substitution to
      vmap.Set(fnode->loop_var, iter_var_ * vector_size_);
      and update all index computations accordingly.

Comment on lines 208 to 230

⚠️ Potential issue

Division-by-zero hazard in index math when stride_x_ < vector_size_ or not divisible.

Expressions like truncdiv(tx_var_, stride_x_ / vector_size_) and truncmod(..., stride_x_ / vector_size_) use a zero denominator when stride_x_ < vector_size_. They also assume stride_x_ % vector_size_ == 0.

Add a safety gate before invoking the rewriter so we bail out when unsafe:

@@
-    if (vectorize_hint == 1)
-      return for_node;
-    auto rewriter = AtomicAddVectorizeRewriter(
+    if (vectorize_hint == 1)
+      return for_node;
+    // Require positive stride_x and divisibility to avoid 0-denominator in index math.
+    if (stride_x <= 0 || (stride_x % vectorize_hint) != 0) {
+      return for_node;
+    }
+    auto rewriter = AtomicAddVectorizeRewriter(
         res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);

Optionally, also guard for undefined bx/by:

-    if (vectorize_hint == 1)
+    if (vectorize_hint == 1 || !bx_var.defined() || !by_var.defined())
       return for_node;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
            // Matrix[by * stride_y + i / (stride_x / (tx_txtent *
            // vector_size_)) + tx_var_ / (stride_x / vector_size_),
            //        bx * stride_x + (i % (stride_x / (tx_extent *
            //        vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
            //        (stride / vector_size_)) * vector_size_]
            BufferLoadNode old_dst_node =
                *(node->args[1].as<CallNode>()->args[0].as<BufferLoadNode>());
            BufferLoadNode old_value_node =
                *(node->args[2].as<BufferLoadNode>());
            Array<PrimExpr> dst_indices, value_indices;
            if ((extent_tx_ * vector_size_) > stride_x_) {
              dst_indices.push_back(
                  by_var_ * stride_y_ +
                  iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
                  truncdiv(tx_var_, stride_x_ / vector_size_));
              dst_indices.push_back(
                  bx_var_ * stride_x_ +
                  truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
              value_indices.push_back(
                  iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
                  truncdiv(tx_var_ * vector_size_, stride_x_));
              value_indices.push_back(
                  truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
            } else {
              dst_indices.push_back(
                  by_var_ * stride_y_ +
                  truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
                  truncdiv(tx_var_, stride_x_ / vector_size_));
              dst_indices.push_back(
                  bx_var_ * stride_x_ +
                  truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
                      (extent_tx_ * vector_size_) +
                  truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
              value_indices.push_back(
                  truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
                  truncdiv(tx_var_, stride_x_ / vector_size_));
              value_indices.push_back(
                  truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
                      (extent_tx_ * vector_size_) +
                  truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
            }

    // Bail out if no vectorization hint or missing block coords
    if (vectorize_hint == 1 || !bx_var.defined() || !by_var.defined())
      return for_node;
    // Require a positive, exactly divisible stride_x to avoid 0-denominator in truncdiv/truncmod.
    if (stride_x <= 0 || (stride_x % vectorize_hint) != 0) {
      return for_node;
    }
    auto rewriter = AtomicAddVectorizeRewriter(
        res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/transform/atomicadd_vectorize.cc (2)

203-260: Dynamic path emits invalid IR: iter_var_ used without loop rewrite.

VisitExpr_ rewrites AtomicAdd unconditionally. When dynamic_ == true, VisitStmt_ returns the original loop, yet the rewritten calls capture iter_var_, which is undefined. Short-circuit in the call rewriter for dynamic plans:

 PrimExpr VisitExpr_(const CallNode *node) final {
+  if (dynamic_) {
+    return StmtExprMutator::VisitExpr_(node);
+  }
   if (vector_size_ == 2 || vector_size_ == 4) {

Longer-term: guard vectorized call with condition_ and fall back to scalar AtomicAdd.


303-312: Null-deref hazard when extracting dtype from AtomicAdd arg[1].

func_name may be null; args[1] may not be address_of(BufferLoad). Add checks:

-        const auto *func_name = call->args[0].as<StringImmNode>();
-        if (func_name->value == "AtomicAdd") {
-          DataType dtype =
-              call->args[1].as<CallNode>()->args[0].as<BufferLoadNode>()->dtype;
+        const auto* func_name = call->args[0].as<StringImmNode>();
+        if (func_name && func_name->value == "AtomicAdd") {
+          const CallNode* addr = call->args[1].as<CallNode>();
+          if (!addr || addr->op != builtin::address_of() || addr->args.size() != 1) return;
+          const BufferLoadNode* bl = addr->args[0].as<BufferLoadNode>();
+          if (!bl) return;
+          DataType dtype = bl->dtype;
♻️ Duplicate comments (3)
src/transform/atomicadd_vectorize.cc (3)

314-331: Good: commutative Mul handling for bx/by stride discovery.

This resolves the brittle pattern flagged earlier.


340-343: Add safety gate: require positive/divisible stride_x before vectorization.

Prevents 0-denominator and misaligned indexing. Minimal:

-    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 || !bx_var.defined() || !by_var.defined())
+    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
+        !bx_var.defined() || !by_var.defined() ||
+        stride_x <= 0 || (stride_x % vectorize_hint) != 0)
       return for_node;

208-230: Prevent division-by-zero when vectorizing atomic adds

  • In VectorizeAtomicAdd (before invoking the rewriter), bail out if the computed stride is smaller than the planned vector size:
-    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 || !bx_var.defined() || !by_var.defined())
+    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
+        stride_x < vectorize_hint || !bx_var.defined() || !by_var.defined())
         return for_node;
  • Optionally, add an assertion in the AtomicAddVectorizeRewriter constructor to catch any remaining invalid states:
   ICHECK(tx_ext) << "thread_bounds->extent must be a constant for vectorization.";
+  ICHECK(vector_size_ > 0 && stride_x_ >= vector_size_)
+      << "Invalid vectorization: stride_x_=" << stride_x_
+      << ", vector_size_=" << vector_size_;
🧹 Nitpick comments (3)
src/transform/atomicadd_vectorize.cc (3)

128-129: Dynamic condition is unused; either consume it or drop it.

condition_ = (truncmod(offset, vector_size_) == 0) is never referenced. If you keep dynamic plans, use it to guard the vectorized path or remove to reduce confusion.


208-210: Nit: comment typo and clarity.

“tx_txtent” → “tx_extent”. Consider formatting the two-index formula legibly or moving it next to the branch that uses it.


1-1: Fix formatting to satisfy CI.

CI indicates clang-format changed the file; run clang-format and commit.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 5e0d9550096139baf220c468da891ec1a3884675 and 9f656b6.

📒 Files selected for processing (2)
  • src/op/atomic_add.cc (2 hunks)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/op/atomic_add.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
tilelang/language/tir/op.py (5)
  • truncmod (2944-2967)
  • ret (1656-1669)
  • call_extern (172-194)
  • address_of (463-479)
  • truncdiv (2918-2941)
🪛 GitHub Actions: CI Test on AMD
src/transform/atomicadd_vectorize.cc

[error] 1-1: tile-lang clang-format reformatted the file. Please review and stage the changes (src/transform/atomicadd_vectorize.cc).

🔇 Additional comments (2)
src/transform/atomicadd_vectorize.cc (2)

144-151: Good: guarded constant extent for tx and captured context.

The as_const_int guard and storing extent_tx_ is correct and addresses prior crash risks.


232-238: Good: safe reconstruction of BufferLoad and address_of.

Switching to validated nodes and rebuilding loads avoids previous UB.

Comment on lines +182 to 201
    iter_var_ = Var(node->loop_var->name_hint + "_outer");
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ == node) { // rewrite the innermost loop
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
      if (!dynamic_) {
        Var tx_var;
        PostOrderVisit(fnode->body, [&tx_var](const ObjectRef &node) {
          if (const VarNode *var = node.as<VarNode>()) {
            if (var->name_hint == "tx") {
              tx_var = GetRef<Var>(var);
            }
          }
        });
        ICHECK(tx_var.defined()) << "Failed to find tx var";
        Var outer_var = Var(old_var->name_hint + "_outer");
        Map<Var, PrimExpr> vmap;
        // Scale thread index (tx) and loop variable by vector_size to map each
        // new iteration to a vectorized chunk
        vmap.Set(tx_var, tx_var * vector_size_);
        vmap.Set(fnode->loop_var, outer_var * vector_size_);
        vmap.Set(fnode->loop_var, iter_var_);
        Stmt body = Substitute(fnode->body, vmap);
        return For(outer_var, 0, extent / vector_size_, fnode->kind, body,
        return For(iter_var_, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
      }
    }
    return ret;
  }

🛠️ Refactor suggestion

⚠️ Potential issue

Loop-splitting bug: inner loop var not scaled by vector size (drops 1 − 1/v work).

You replace the old loop var with iter_var_ but do not multiply by vector_size_. This misindexes all uses of the original loop var outside the AtomicAdd call and breaks semantics. Map the old var to iter_var_ * vector_size_ (lane 0 representative) or introduce an explicit lane var and fold it. Minimal fix:

-    Map<Var, PrimExpr> vmap;
-    vmap.Set(fnode->loop_var, iter_var_);
+    Map<Var, PrimExpr> vmap;
+    vmap.Set(fnode->loop_var, iter_var_ * make_const(DataType::Int(32), vector_size_));

Optionally also substitute tx_var_ if the body indexes with lane-local math that assumed per-lane iteration.

📝 Committable suggestion


Suggested change
    iter_var_ = Var(node->loop_var->name_hint + "_outer");
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ == node) { // rewrite the innermost loop
      For fnode = ret.as<For>().value();
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
      if (!dynamic_) {
        Map<Var, PrimExpr> vmap;
        vmap.Set(fnode->loop_var,
                 iter_var_ * make_const(DataType::Int(32), vector_size_));
        Stmt body = Substitute(fnode->body, vmap);
        return For(iter_var_, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
      }
    }
    return ret;
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 182 to 201, the innermost
loop rewrite replaces the original loop variable with iter_var_ but fails to
scale it by vector_size_, causing dropped work and incorrect indexing; fix by
mapping the original loop var to iter_var_ * vector_size_ (or introduce an
explicit lane variable and fold lane offsets) when building vmap before calling
Substitute so all uses outside the AtomicAdd see the lane-0 representative; also
consider substituting tx_var_ if the body contains lane-local arithmetic that
assumed per-lane iteration.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (4)
src/transform/atomicadd_vectorize.cc (4)

328-361: Robust bx/by detection (commutative mul) — nice.

Handles varimm and immvar; sets strides and captures var exprs. Good improvement.


197-201: Loop-splitting bug: original loop var must be scaled by vector_size_.

Mapping old loop_var → iter_var_ drops work/misindexes; multiply by vector_size_.

Apply:

-        vmap.Set(fnode->loop_var, iter_var_);
+        vmap.Set(
+            fnode->loop_var,
+            iter_var_ * make_const(fnode->loop_var.dtype(), vector_size_));

233-236: Zero-denominator hazard in index math (stride_x_ / vector_size_).

If stride_x_ < vector_size_, denominators become 0 in truncdiv/truncmod; also assumes divisibility. Add guards; otherwise vectorization must be skipped.

Suggest guarding in VectorizeAtomicAdd (see separate comment) and optionally add in-function check:

-    if (vector_size_ == 2 || vector_size_ == 4) {
+    if ((vector_size_ == 2 || vector_size_ == 4) &&
+        stride_x_ >= vector_size_ && (stride_x_ % vector_size_ == 0)) {

Also applies to: 241-242, 245-251, 256-258


370-375: Add safety gates before invoking the rewriter (avoid 0-denominator and undefined bx/by/stride).

Guard on stride_x and divisibility to prevent invalid IR; also keep existing defined checks.

Apply:

-    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
-        !bx_var.defined() || !by_var.defined())
+    if (vectorize_hint == 1 ||
+        stride_x <= 0 || (stride_x % vectorize_hint) != 0 ||
+        stride_y <= 0 ||
+        !bx_var.defined() || !by_var.defined())
       return for_node;

Optionally also guard against the else-branch ratio being 0:

+    if (const int64_t* tx_ext = as_const_int(thread_bounds->extent)) {
+      if (*tx_ext > 0 && stride_x < (*tx_ext) * vectorize_hint) {
+        return for_node;
+      }
+    }
🧹 Nitpick comments (2)
src/transform/atomicadd_vectorize.cc (2)

211-216: Nit: fix typos in comment.

“tx_txtent” → “tx_extent”, “(stride / vector_size_)” → “(stride_x_ / vector_size_)”.

Apply:

-            // Matrix[by * stride_y + i / (stride_x / (tx_txtent *
+            // Matrix[by * stride_y + i / (stride_x / (tx_extent *
@@
-            //        (stride / vector_size_)) * vector_size_]
+            //        (stride_x_ / vector_size_)) * vector_size_]

291-300: Unused/ineffective state: condition_ not applied; dynamic_ path unused.

condition_ is computed but never used to predicate the vectorized call; dynamic_ true currently yields no safe vectorization. Either remove condition_ or wire it to guard the vectorized call, else keep the scalar AtomicAdd.

Apply one of:

  • Minimal: skip rewrites when dynamic_ (see earlier diff).
  • Full: wrap new_call with Select(condition_, new_call, old_call), but that requires retaining old_call.
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 9f656b6 and 9940991.

📒 Files selected for processing (1)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
tilelang/language/tir/op.py (5)
  • truncmod (2944-2967)
  • ret (1656-1669)
  • call_extern (172-194)
  • address_of (463-479)
  • truncdiv (2918-2941)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-amd
🔇 Additional comments (4)
src/transform/atomicadd_vectorize.cc (4)

150-154: Good: added constant-extent guard for tx.

Null-check + ICHECK on thread_bounds->extent prevents crashes for symbolic extents. Looks correct.


261-271: OK: address_of(dst/value) construction matches call_extern ABI.

New AtomicAddx2/x4 args look consistent with address-of semantics.


221-227: Safer: null-checks for dst/value loads — good.

Prevents crashes when args aren’t BufferLoad. Solid.


128-129: Dynamic condition updated to truncmod — fine, but ensure it’s used.

truncmod is correct; however, without wiring condition_ into codegen (or skipping rewrite when dynamic_), this has no effect.

Would you like me to wire condition_ as a predicate into the vectorized call or keep dynamic path scalar-only?


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/transform/atomicadd_vectorize.cc (2)

334-343: Null-deref risk when extracting dtype; add shape checks.

call->args[1] may not be address_of(BufferLoad); current code can crash.

-        if (func_name->value == "AtomicAdd") {
-          DataType dtype =
-              call->args[1].as<CallNode>()->args[0].as<BufferLoadNode>()->dtype;
-          vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
-        }
+        if (func_name->value == "AtomicAdd") {
+          const CallNode* addr = call->args[1].as<CallNode>();
+          if (!addr || addr->op != builtin::address_of() || addr->args.size() != 1) {
+            return;
+          }
+          const BufferLoadNode* load = addr->args[0].as<BufferLoadNode>();
+          if (!load) {
+            return;
+          }
+          DataType dtype = load->dtype;
+          vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+        }

120-123: Avoid potential infinite/invalid loop when shrinking vector_size_.

Stop halving at 1 to avoid 0 and undefined behavior in IndiceCanVectorize.

-      while (!IndiceCanVectorize(elem_offset, thread_var, thread_extent,
-                                 vector_size_, &analyzer_)) {
+      while (vector_size_ > 1 &&
+             !IndiceCanVectorize(elem_offset, thread_var, thread_extent,
+                                 vector_size_, &analyzer_)) {
         vector_size_ /= 2;
       }
♻️ Duplicate comments (4)
src/transform/atomicadd_vectorize.cc (4)

219-230: Good: robust operand validation before rewriting AtomicAdd.


345-364: Good: commutativity handled for bx/by stride detection.


196-201: Critical: inner loop-var must be scaled by vector_size_ (drops work, wrong indexing).

Map the original loop var to iter_var_ * vector_size_. This was previously flagged and remains unfixed.

-        vmap.Set(fnode->loop_var, iter_var_);
+        vmap.Set(fnode->loop_var,
+                 iter_var_ * make_const(DataType::Int(32), vector_size_));

373-377: Add stride/divisibility gate to avoid 0-denominator in index math.

Without this, expressions like truncdiv(tx_var_, stride_x_/vector_size_) can divide by 0 when stride_x < vector_size_ or not divisible.

-    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
-        !bx_var.defined() || !by_var.defined())
+    if (vectorize_hint == 1 || stride_x <= 0 || stride_y <= 0 ||
+        (stride_x % vectorize_hint) != 0 ||
+        !bx_var.defined() || !by_var.defined())
       return for_node;
🧹 Nitpick comments (3)
src/transform/atomicadd_vectorize.cc (3)

185-185: Preserve dtype on the new loop var.

Construct iter_var_ with the same dtype as the original loop var.

-    iter_var_ = Var(node->loop_var->name_hint + "_outer");
+    iter_var_ = Var(node->loop_var->name_hint + "_outer", node->loop_var->dtype);

214-219: Fix typos in index-derivation comment.

“tx_txtent” → “tx_extent”; “(stride / vector_size_)” → “(stride_x_ / vector_size_)”.

-            // Matrix[by * stride_y + i / (stride_x / (tx_txtent *
+            // Matrix[by * stride_y + i / (stride_x / (tx_extent *
             // vector_size_)) + tx_var_ / (stride_x / vector_size_),
             //        bx * stride_x + (i % (stride_x / (tx_extent *
             //        vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
-            //        (stride / vector_size_)) * vector_size_]
+            //        (stride_x_ / vector_size_)) * vector_size_]

144-154: Consider widening extent_tx_ type to int64_t.

Avoid narrowing from thread_bounds->extent on large configurations.

-  int extent_tx_;
+  int64_t extent_tx_;

And update the assignment accordingly.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 9940991 and f3e5a13.

📒 Files selected for processing (1)
  • src/transform/atomicadd_vectorize.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
tilelang/language/tir/op.py (5)
  • truncmod (2944-2967)
  • ret (1656-1669)
  • call_extern (172-194)
  • address_of (463-479)
  • truncdiv (2918-2941)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: build-test-amd
  • GitHub Check: build-test-nvidia
  • GitHub Check: bot-task

Comment on lines 206 to +209
  PrimExpr VisitExpr_(const CallNode *node) final {
    if (dynamic_) {
      return StmtExprMutator::VisitExpr_(node);
    }

🛠️ Refactor suggestion

Defensive check inside rewriter to avoid 0-denominator at call rewrite sites.

Even with outer gating, keep a local guard to be safe.

   PrimExpr VisitExpr_(const CallNode *node) final {
     if (dynamic_) {
       return StmtExprMutator::VisitExpr_(node);
     }
+    if (stride_x_ <= 0 || (stride_x_ % vector_size_) != 0) {
+      return StmtExprMutator::VisitExpr_(node);
+    }
     if (vector_size_ == 2 || vector_size_ == 4) {

Also applies to: 231-262

🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 206-209 and also apply to
231-262, add a local defensive guard inside the CallNode rewrite path: detect
when the denominator (or divisor) value could be zero and short-circuit or
fallback to the original expression before performing the rewrite; concretely,
compute or extract the divisor locally, check it for zero (or
non-constant/unknown) and if zero/unsafe, return
StmtExprMutator::VisitExpr_(node) (or the unmodified node) instead of proceeding
with the transformation so no 0-division is generated.

@LeiWang1999 LeiWang1999 merged commit a7a29c0 into tile-ai:main Aug 31, 2025
5 of 6 checks passed
