
Conversation

Contributor

@yyttt6 commented Oct 13, 2025

Summary by CodeRabbit

  • New Features

    • Added an element-wise tile-level atomic-add intrinsic.
  • Refactor

    • Switched atomic-add vectorization to intrinsic-based detection, enabling x2/x4 vectorized variants with safe fallbacks to scalar.
    • Simplified layout handling for atomic-add paths (layout inference no longer delegated).
  • Chores

    • Updated transformation flow and messages to recognize the new intrinsic and emit clearer warnings for unsupported vectorization patterns.


👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Contributor

coderabbitai bot commented Oct 13, 2025

Walkthrough

Replaces string-based extern AtomicAdd calls with a new tilelang intrinsic tl_atomicadd_elem_op (accessed via atomicadd_elem_op()), updates the atomic-add vectorization pass to target that intrinsic and rewrite calls to vector variants, and removes ParallelOp-based layout inference from AtomicAdd lowering.

Changes

Cohort / File(s) Summary
AtomicAdd lowering
src/op/atomic_add.cc
Stop emitting StringImm("AtomicAdd"); lower element-wise atomic-add using tvm::tl::tl_atomicadd_elem_op()/atomicadd_elem_op(); remove lazy par_op_ creation and return an empty layout map from InferLayout (no delegation to par_op_).
Vectorize pass
src/transform/atomicadd_vectorize.cc, src/transform/atomicadd_vectorize.h
Detect atomic-add by atomicadd_elem_op() opcode instead of extern-name string; derive vectorize size from BufferLoad or IfThenElse+BufferLoad (fallback with warnings); short-circuit when vector_size == 1; rewrite calls to AtomicAddx2/AtomicAddx4 (or fallback to scalar AtomicAdd); adjust operand shapes and address_of() usage.
Builtins registry & header
src/op/builtin.cc, src/op/builtin.h
Add new TL builtin registration tl_atomicadd_elem_op (2 inputs, TCallEffectKind::kOpaque) and expose atomicadd_elem_op() declaration in header.
Header include only
src/transform/atomicadd_vectorize.h
Add #include "../op/builtin.h" to enable builtin opcode usage.
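The width selection summarized above (x2/x4 variants with a scalar fallback) amounts to fitting lanes into a fixed-size atomic access. A minimal Python sketch of that idea, assuming a 64-bit access budget and an x4 cap (both are illustrative assumptions; the real pass derives its limits from the dtype and compute capability):

```python
def max_vectorize_size(dtype_bits: int, max_access_bits: int = 64) -> int:
    """Widest power-of-two lane count whose combined width fits the
    assumed maximum atomic access size, capped at x4 (illustrative)."""
    size = 1
    while size * 2 * dtype_bits <= max_access_bits and size * 2 <= 4:
        size *= 2
    return size

print(max_vectorize_size(16))  # fp16: 4 lanes (AtomicAddx4)
print(max_vectorize_size(32))  # fp32: 2 lanes (AtomicAddx2)
print(max_vectorize_size(64))  # fp64: 1 lane (scalar fallback)
```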

Sequence Diagram(s)

%%{init: {"themeVariables": {"actorBkg":"#f3f4f6","noteBkg":"#fff8dc"}} }%%
sequenceDiagram
    autonumber
    participant TIR as TIR (call site)
    participant TL as tl_atomicadd_elem_op
    participant Planner as VectorizePlanner
    participant Rewriter as VectorizeRewriter
    participant Lower as Lowering

    TIR->>TL: atomicadd_elem_op(dst, src, ...)
    TL-->>TIR: intrinsic opcode recognized

    TIR->>Planner: analyze for BufferLoad / IfThenElse+BufferLoad
    Planner-->>TIR: vector_size (2,4 or 1)

    TIR->>Rewriter: rewrite(vector_size)
    alt vector_size == 1
        Rewriter-->>TIR: keep scalar AtomicAdd (extern)
    else vector_size == 2 or 4
        Rewriter-->>TIR: emit AtomicAddx2 / AtomicAddx4, reshape operands, adjust address_of
    else
        Rewriter-->>TIR: fallback to scalar AtomicAdd extern
    end

    TIR->>Lower: lower rewritten calls
    Lower-->>TIR: device code emitted

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

I rabbit-hop through opcode fields with a small, clever tap,
Replaced the brittle string with one true intrinsic map.
Two or four I sniff and choose, or keep the single lane,
I nudge the tiles, reshape the loads — the kernels hum again. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Title Check (⚠️ Warning): The title claims to add a test for atomicadd auto vectorization, but the diff contains no test files; it instead removes unused code and introduces a new intrinsic and vectorization logic, making the title misleading and overly verbose. Resolution: rename the title to reflect the core changes, for example "Simplify AtomicAdd implementation and integrate atomicadd_elem_op vectorization support."
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 16.67%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
  • Description Check (✅ Passed): Check skipped - CodeRabbit's high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
  • 🧪 Generate unit tests (beta)
    • Create PR with unit tests
    • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)

217-255: Incorrect Downcast usage causes compile error

You’re downcasting from a pointer; Downcast expects an ObjectRef/PrimExpr. Fix both downcasts:

-        const BufferLoad dst_node =
-            Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-        const BufferLoad value_node =
-            Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+        const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+        const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);

Optionally, remove the earlier temp_*_node guards, call as<BufferLoadNode>() once for the checks, and then Downcast from the original PrimExpr.

🧹 Nitpick comments (4)
src/transform/atomicadd_vectorize.cc (3)

42-71: Handle dtype inference when src (arg2) is guarded by IfThenElse

Plan currently extracts dtype from dst (arg1) or an IfThenElse in arg1, but often masking happens on the value side (arg2). Consider symmetric handling for arg2 to avoid unnecessary fallback.

Example adjustment (sketch):

-        } else if (const auto *ite = call->args[1].as<IfThenElseNode>()) {
+        } else if (const auto *ite = call->args[1].as<IfThenElseNode>()) {
           ...
+        } else if (const auto* ite2 = call->args[2].as<IfThenElseNode>()) {
+          if (const auto* then_load = ite2->then_case.as<BufferLoadNode>()) {
+            dtype = then_load->dtype;
+            vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+          } else if (const auto* else_load = ite2->else_case.as<BufferLoadNode>()) {
+            dtype = else_load->dtype;
+            vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+          } else {
+            vectorize_size_max = 1;
+            DLOG(WARNING) << "[AtomicAddVectorizePlanner] IfThenElse (arg2) has no BufferLoad; Fallback";
+          }
         } else {

159-159: Non-English comment

// 补齐相关test (Chinese for "add related tests"); consider translating or removing to keep codebase comments in English.


193-215: Docstring vs implementation mismatch for new loop var name

Doc says create "<old_loop_var>_outer", but code uses Var(old_var->name_hint) (same name). Prefer explicit "_outer" for readability and to match docs.

-      auto new_var = Var(old_var->name_hint);
+      auto new_var = Var(old_var->name_hint + "_outer");
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1)

25-26: Unused bz binding in both IR modules (RUF059)

Prefix the unused binding with underscore to silence lints and clarify intent:

-            threads=128) as (bx, by, bz):
+            threads=128) as (bx, by, _bz):

Repeat for the After module.

Also applies to: 36-36

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d89ba5b and 0d5b4e3.

📒 Files selected for processing (3)
  • src/op/atomic_add.cc (0 hunks)
  • src/transform/atomicadd_vectorize.cc (5 hunks)
  • testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1 hunks)
💤 Files with no reviewable changes (1)
  • src/op/atomic_add.cc
🧰 Additional context used
🧬 Code graph analysis (2)
src/transform/atomicadd_vectorize.cc (1)
tilelang/language/tir/op.py (2)
  • call_extern (172-194)
  • address_of (463-479)
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (3)
tilelang/utils/target.py (1)
  • determine_target (54-99)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/tir/op.py (2)
  • call_extern (172-194)
  • address_of (463-479)
🪛 Ruff (0.13.3)
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py

25-25: Unpacked variable bz is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


36-36: Unpacked variable bz is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-metal
  • GitHub Check: format-check

Comment on lines 8 to 9
auto_target = tvm.target.Target(determine_target("auto"))

Contributor

⚠️ Potential issue | 🔴 Critical

Make the test assert structural equality and avoid import-time target failures

  • The result of tvm.ir.structural_equal(mod, ref_mod) is unused; test won’t fail on mismatch.
  • auto_target is computed at import time and may raise ValueError on CI without CUDA/HIP/MPS.

Apply:

- auto_target = tvm.target.Target(determine_target("auto"))
+ # target selection is done inside the test to allow graceful skipping

@@
-    with tvm.transform.PassContext():
-        mod = tvm.tir.transform.BindTarget(auto_target)(Before)
+    # Select target at runtime; skip if none available.
+    try:
+        auto_target = tvm.target.Target(determine_target("auto"))
+    except ValueError:
+        pytest.skip("No CUDA/HIP/MPS available; skipping atomicadd vectorize test.")
+
+    with tvm.transform.PassContext():
+        mod = tvm.tir.transform.BindTarget(auto_target)(Before)
         mod = tl.transform.LowerTileOp()(mod)
         mod = tvm.tir.transform.Simplify()(mod)
     ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
     ref_mod = tvm.tir.transform.Simplify()(ref_mod)
-    tvm.ir.structural_equal(mod, ref_mod)
+    assert tvm.ir.structural_equal(mod, ref_mod)

Also applies to: 54-61

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)

149-152: Prevent vector_size_ from reaching zero in halving loop

If IndiceCanVectorize never returns true, vector_size_ can become 0, leading to UB/div-by-zero later. Add a lower bound.

-    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
-                               inner_for_->extent, vector_size_, &analyzer_)) {
-      vector_size_ /= 2;
-    }
+    while (vector_size_ > 1 &&
+           !IndiceCanVectorize(elem_offset, inner_for_->loop_var,
+                               inner_for_->extent, vector_size_, &analyzer_)) {
+      vector_size_ /= 2;
+    }
+    if (vector_size_ < 1) vector_size_ = 1;
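The guarded halving loop above can be modeled in a few lines of Python. IndiceCanVectorize is stood in for by a simple divisibility check on the loop extent and base offset (an assumption for illustration only; the real predicate consults the arithmetic analyzer):

```python
def choose_vector_size(extent: int, offset: int, start: int = 4) -> int:
    """Halve the candidate width until both the loop extent and the base
    offset are divisible by it, never dropping below 1."""
    size = start
    while size > 1 and not (extent % size == 0 and offset % size == 0):
        size //= 2
    return max(size, 1)

print(choose_vector_size(8, 0))  # 4: fully aligned, keep x4
print(choose_vector_size(8, 2))  # 2: offset rules out x4
print(choose_vector_size(7, 3))  # 1: scalar fallback, never 0
```

Without the size > 1 bound, an impossible-to-vectorize input would drive the width to 0 and trigger the divide-by-zero the review warns about.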
♻️ Duplicate comments (3)
src/transform/atomicadd_vectorize.cc (2)

20-31: Tighten extern-call detection; avoid OOB and remove redundant null-check

Require at least 3 args and drop the dead if (!s); callers index args[2] later. This prevents OOB when malformed calls slip through.

Apply:

-inline bool IsAtomicAddExternCall(const CallNode *call) {
-  if (!call || !call->op.same_as(builtin::call_extern()))
-    return false;
-  if (call->args.size() < 2)
-    return false;
-  if (const auto *s = call->args[0].as<StringImmNode>()) {
-    if (!s)
-      return false;
-    return std::string(s->value) == AtomicAddExternName();
-  }
-  return false;
-}
+inline bool IsAtomicAddExternCall(const CallNode* call) {
+  if (!call || !call->op.same_as(builtin::call_extern())) return false;
+  // call_extern("AtomicAdd", dst, src) => need at least 3 args
+  if (call->args.size() < 3) return false;
+  if (const auto* s = call->args[0].as<StringImmNode>()) {
+    return std::string(s->value) == AtomicAddExternName();
+  }
+  return false;
+}

88-99: Guard access to args[2] (defense-in-depth)

Add an explicit count check before indexing args[2] to avoid accidental OOB if detection changes.

-  if (IsAtomicAddExternCall(node)) {
+  if (IsAtomicAddExternCall(node)) {
+    ICHECK_GE(node->args.size(), 3);
     const BufferLoadNode *buffer_load_dst = node->args[1].as<BufferLoadNode>();
     const BufferLoadNode *buffer_load_src = node->args[2].as<BufferLoadNode>();
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1)

70-73: Make the test assert structural equality

Currently the result is unused; test won’t fail on mismatch. Assert it.

-    tvm.ir.structural_equal(mod, ref_mod)
-    # tvm.ir.assert_structural_equal(mod, ref_mod)
+    assert tvm.ir.structural_equal(mod, ref_mod)
+    # or: tvm.ir.assert_structural_equal(mod, ref_mod)
🧹 Nitpick comments (2)
src/transform/atomicadd_vectorize.cc (1)

159-159: Prefer English-only comments in source

Replace the non-English comment with an English equivalent for consistency.

-// 补齐相关test
+// TODO: Add/align related tests
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1)

23-24: Use the parameterized threads and prefix unused bz

Avoid ARG001/RUF059 and reflect the parameterization.

-            with T.Kernel(
-                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
-                    threads=128) as (bx, by, bz):
+            with T.Kernel(
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
+                    threads=threads) as (bx, by, _bz):
@@
-            with T.Kernel(
-                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
-                    threads=128) as (bx, by, bz):
+            with T.Kernel(
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
+                    threads=threads) as (bx, by, _bz):

Also applies to: 34-35

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0d5b4e3 and 31c4189.

📒 Files selected for processing (3)
  • src/op/atomic_add.cc (0 hunks)
  • src/transform/atomicadd_vectorize.cc (5 hunks)
  • testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1 hunks)
💤 Files with no reviewable changes (1)
  • src/op/atomic_add.cc
🧰 Additional context used
🧬 Code graph analysis (2)
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (4)
tilelang/utils/target.py (1)
  • determine_target (54-99)
tilelang/engine/lower.py (1)
  • canon_target_host (127-132)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/parallel.py (1)
  • Parallel (8-28)
src/transform/atomicadd_vectorize.cc (1)
tilelang/language/tir/op.py (2)
  • call_extern (172-194)
  • address_of (463-479)
🪛 Ruff (0.14.0)
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py

14-14: Unused function argument: threads

(ARG001)


23-23: Unpacked variable bz is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


34-34: Unpacked variable bz is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: build-test-metal
  • GitHub Check: build-test-amd
  • GitHub Check: format-check

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)

149-152: Prevent vector_size_ from becoming 0 (runtime crash risk)

Division by 2 can drop to 0 if vectorization is impossible; downstream checks then divide by or mod 0.

-    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
-                               inner_for_->extent, vector_size_, &analyzer_)) {
-      vector_size_ /= 2;
-    }
+    while (vector_size_ > 1 &&
+           !IndiceCanVectorize(elem_offset, inner_for_->loop_var,
+                               inner_for_->extent, vector_size_, &analyzer_)) {
+      vector_size_ /= 2;
+    }
+    // If still not vectorizable at size 1, bail out gracefully.
+    if (vector_size_ < 1) vector_size_ = 1;
♻️ Duplicate comments (3)
src/transform/atomicadd_vectorize.cc (2)

20-31: Tighten extern-call detection; remove redundant null-check

Require at least 3 args to avoid OOB at args[2]; drop the unnecessary if (!s) inside the cast branch.

-inline bool IsAtomicAddExternCall(const CallNode *call) {
-  if (!call || !call->op.same_as(builtin::call_extern()))
-    return false;
-  if (call->args.size() < 2)
-    return false;
-  if (const auto *s = call->args[0].as<StringImmNode>()) {
-    if (!s)
-      return false;
-    return std::string(s->value) == AtomicAddExternName();
-  }
-  return false;
-}
+inline bool IsAtomicAddExternCall(const CallNode* call) {
+  if (!call || !call->op.same_as(builtin::call_extern())) return false;
+  // call_extern("AtomicAdd", dst, src)
+  if (call->args.size() < 3) return false;
+  if (const auto* s = call->args[0].as<StringImmNode>()) {
+    return s->value == AtomicAddExternName();
+  }
+  return false;
+}

88-99: Guard access to args[2] (local safety even if detector regresses)

Add an explicit size check before indexing args[2]. Safer and self‑contained.

-  if (IsAtomicAddExternCall(node)) {
+  if (IsAtomicAddExternCall(node)) {
+    ICHECK_GE(node->args.size(), 3);
     const BufferLoadNode *buffer_load_dst = node->args[1].as<BufferLoadNode>();
     const BufferLoadNode *buffer_load_src = node->args[2].as<BufferLoadNode>();
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1)

71-75: Make the test fail on mismatch

Currently the equality result is ignored. Assert it.

-    tvm.ir.structural_equal(mod, ref_mod)
-    # tvm.ir.assert_structural_equal(mod, ref_mod)
+    tvm.ir.assert_structural_equal(mod, ref_mod)
🧹 Nitpick comments (5)
src/transform/atomicadd_vectorize.cc (2)

126-126: Remove unused variable

access_type is never used.

-  const DataType &access_type = buffer->dtype;

193-215: Preserve loop var dtype and clarify naming

Construct the new loop var with the original dtype and a distinct name (matches the docstring).

-      auto old_var = fnode->loop_var;
-      auto new_var = Var(old_var->name_hint);
+      auto old_var = fnode->loop_var;
+      auto new_var = Var(old_var->name_hint + "_outer", old_var->dtype);
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (3)

49-55: Robust SM parsing (handle missing digits like sm_)

Guard against empty digit strings to avoid ValueError.

     arch_part = arch.split("_")[1]
-    cc = int("".join(ch for ch in arch_part if ch.isdigit()))
+    digits = "".join(ch for ch in arch_part if ch.isdigit())
+    if not digits:
+        pytest.skip(f"Unrecognized CUDA arch: {arch}")
+    cc = int(digits)
     if cc != 90:
         pytest.skip(f"Requires Hopper (SM90), got {arch}")
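In isolation, the suggested guard behaves like this hypothetical parse_sm helper (not part of the test file; the name and None-on-failure contract are assumptions for illustration):

```python
def parse_sm(arch: str):
    """Extract the numeric compute capability from strings like 'sm_90'
    or 'sm_90a'; return None when no digits follow the underscore."""
    part = arch.split("_", 1)[1] if "_" in arch else ""
    digits = "".join(ch for ch in part if ch.isdigit())
    return int(digits) if digits else None

print(parse_sm("sm_90a"))  # 90
print(parse_sm("sm_80"))   # 80
print(parse_sm("sm_"))     # None -> caller should skip, not crash
```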

21-24: Use the parametrized threads and silence bz unused lint

Wire through threads instead of hardcoding 128; prefix the unused bz.

-            with T.Kernel(
-                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
-                    threads=128) as (bx, by, bz):
+            with T.Kernel(
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
+                    threads=threads) as (bx, by, _bz):
@@
-            with T.Kernel(
-                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
-                    threads=128) as (bx, by, bz):
+            with T.Kernel(
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
+                    threads=threads) as (bx, by, _bz):

Also applies to: 33-35


68-70: Optional: drop noisy prints

Remove prints to keep CI logs clean.

-    print(mod)
-    print(ref_mod)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 31c4189 and 51eb185.

📒 Files selected for processing (2)
  • src/transform/atomicadd_vectorize.cc (5 hunks)
  • testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/transform/atomicadd_vectorize.cc (1)
tilelang/language/tir/op.py (2)
  • call_extern (172-194)
  • address_of (463-479)
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (2)
tilelang/utils/target.py (1)
  • determine_target (54-99)
tilelang/engine/lower.py (1)
  • canon_target_host (127-132)
🪛 Ruff (0.14.0)
testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py

14-14: Unused function argument: threads

(ARG001)


23-23: Unpacked variable bz is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


34-34: Unpacked variable bz is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: build-test-metal
  • GitHub Check: build-test-nvidia
  • GitHub Check: build-test-amd

@LeiWang1999 LeiWang1999 changed the title [Feature]: Add test for atomicadd auto vectorize and remote useless code [Feature]: Add test for atomicadd auto vectorize and remove useless code Oct 14, 2025
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
src/transform/atomicadd_vectorize.cc (1)

217-221: Fix Downcast misuse (compile error): pass ObjectRef, not Node*

Downcast expects an ObjectRef. Using .as() yields a raw pointer.

Apply:

-      const BufferLoad dst_node =
-          Downcast<BufferLoad>(node->args[0].as<BufferLoadNode>());
-      const BufferLoad value_node =
-          Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
+      const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
+      const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 51eb185 and 052aac7.

📒 Files selected for processing (5)
  • src/op/atomic_add.cc (2 hunks)
  • src/op/builtin.cc (1 hunks)
  • src/op/builtin.h (1 hunks)
  • src/transform/atomicadd_vectorize.cc (4 hunks)
  • src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (2)
src/transform/loop_vectorize_dynamic.cc (16)
  • node (80-85)
  • node (80-80)
  • node (92-96)
  • node (92-92)
  • node (98-112)
  • node (98-98)
  • node (114-120)
  • node (114-114)
  • node (122-125)
  • node (122-122)
  • node (127-135)
  • node (127-127)
  • node (263-266)
  • node (263-263)
  • node (283-286)
  • node (283-283)
tilelang/language/tir/op.py (3)
  • ret (1758-1771)
  • address_of (463-479)
  • call_extern (172-194)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: build-test-metal
  • GitHub Check: build-test-amd
  • GitHub Check: build-test-nvidia
🔇 Additional comments (6)
src/transform/atomicadd_vectorize.h (1)

11-11: Include addition looks good

Required for tl_atomicadd_elem_op usage downstream. No concerns.

src/op/builtin.h (1)

472-479: New intrinsic declaration is consistent

Two‑arg, opaque intrinsic aligns with intended atomic add element op. LGTM.

src/op/atomic_add.cc (2)

289-291: Switch to tl_atomicadd_elem_op is appropriate

Using the opcode instead of call_extern aligns with the new design. 👍


340-341: InferLayout now returns empty map — verify no callers depend on prior behavior

Previously delegated to par_op_->InferLayout; now returns {}. Please confirm no passes rely on AtomicAddNode::InferLayout producing layouts, since Lower() now computes its own loop layout.

src/op/builtin.cc (1)

278-282: Builtin registration is correct

num_inputs(2) and KOpaque are appropriate for an atomic op. Looks good.

src/transform/atomicadd_vectorize.cc (1)

178-180: Skip rewrite when vector_size_ == 1

Good early‑exit to avoid unnecessary mutation.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
src/transform/atomicadd_vectorize.cc (2)

253-261: Fallback must pass address_of(dst) (and handle IfThenElse)

The non-vectorized branch now emits AtomicAdd(dst_value, src_value), but the ABI still expects AtomicAdd(address_of(dst), value). With the new intrinsic, node->args[0] is a BufferLoad (or wrapped in an IfThenElse), so we’re currently handing the callee the value of the destination element, not its address—resulting in incorrect code generation. Recreate the extern call with builtin::address_of() and unwrap IfThenElse the same way the vectorized path does.

-      Array<PrimExpr> new_args;
-      new_args.push_back(StringImm("AtomicAdd"));
-      for (auto x : node->args)
-        new_args.push_back(x);
-
-      Call new_call =
-          tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
-
-      return new_call;
+      Optional<BufferLoad> dst_bl;
+      const PrimExpr &dst = node->args[0];
+      if (const auto *bl = dst.as<BufferLoadNode>()) {
+        dst_bl = Downcast<BufferLoad>(dst);
+      } else if (const auto *ite = dst.as<IfThenElseNode>()) {
+        if (const auto *then_bl = ite->then_case.as<BufferLoadNode>()) {
+          dst_bl = Downcast<BufferLoad>(ite->then_case);
+        } else if (const auto *else_bl = ite->else_case.as<BufferLoadNode>()) {
+          dst_bl = Downcast<BufferLoad>(ite->else_case);
+        }
+      }
+      if (!dst_bl.defined()) {
+        return GetRef<PrimExpr>(node);
+      }
+      Call address_of_dst =
+          Call(DataType::Handle(), builtin::address_of(), {dst_bl.value()});
+      Array<PrimExpr> new_args;
+      new_args.push_back(StringImm("AtomicAdd"));
+      new_args.push_back(address_of_dst);
+      new_args.push_back(node->args[1]);
+      return Call(node->dtype, builtin::call_extern(), new_args);
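Why the address matters can be shown with a toy Python model (names are illustrative, not the tilelang ABI): an atomic add needs the destination's location so the update lands in the buffer, whereas passing the loaded value alone discards the store target.

```python
def atomic_add_at(buf, index, value):
    """Toy atomic add: takes the location (buf, index) rather than a
    copy of the element, so the result is written back to the buffer."""
    old = buf[index]
    buf[index] += value
    return old

buf = [1.0, 2.0]
assert atomic_add_at(buf, 0, 3.0) == 1.0
assert buf == [4.0, 2.0]  # the add reached the destination
```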

228-231: Fix Downcast misuse (pass PrimExpr, not raw Node*)

Downcast expects an ObjectRef; passing BufferLoadNode* won’t compile. Use the original PrimExpr arguments instead.

-      const BufferLoad dst_node =
-          Downcast<BufferLoad>(node->args[0].as<BufferLoadNode>());
-      const BufferLoad value_node =
-          Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
+      const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
+      const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 052aac7 and 60880bc.

📒 Files selected for processing (3)
  • src/op/builtin.cc (1 hunks)
  • src/op/builtin.h (1 hunks)
  • src/transform/atomicadd_vectorize.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (3)
src/transform/lower_hopper_intrin.cc (2)
  • call (102-132)
  • call (102-102)
src/transform/loop_vectorize_dynamic.cc (16)
  • node (80-85)
  • node (80-80)
  • node (92-96)
  • node (92-92)
  • node (98-112)
  • node (98-98)
  • node (114-120)
  • node (114-114)
  • node (122-125)
  • node (122-122)
  • node (127-135)
  • node (127-127)
  • node (263-266)
  • node (263-263)
  • node (283-286)
  • node (283-283)
tilelang/language/tir/op.py (3)
  • ret (1758-1771)
  • address_of (463-479)
  • call_extern (172-194)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
src/op/builtin.h (1)

504-510: Declaration looks correct

Signature and docstring align with the new 2-operand intrinsic; no concerns here.

src/op/builtin.cc (1)

298-302: Registration matches the new intrinsic

2-arg arity and opaque call effect align with the intended atomic add builtin.

Member

@LeiWang1999 LeiWang1999 left a comment

Overall LGTM, but some comments :)


     Call atomicadd_call =
-        tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);
+        tvm::tir::Call(dst->dtype, tvm::tl::tl_atomicadd_elem_op(), new_args);
Member

replace with tl_atomicadd_elem_op()

src/op/builtin.h Outdated
* This op is used to represent an element-wise atomic add operation in
* tilelang.
*/
TVM_DLL const Op &tl_atomicadd_elem_op();
Member

I think we can remove the prefix tl, as we are already in the tl namespace

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/atomic_add.cc (1)

325-341: Remove unused par_op_ and restore layout delegation or explicitly handle all cases.
AtomicAddNode::InferLayout now returns an empty LayoutMap instead of delegating to par_op_->InferLayout, breaking the expected fallback. Either:

  • Reintroduce delegation for non-local.fragment cases:
    return par_op_->InferLayout(T, level);
  • Or remove par_op_ from the class (and its Clone logic) and fully implement layout inference in this method.
♻️ Duplicate comments (2)
src/transform/atomicadd_vectorize.cc (2)

56-58: Fix inconsistent error message.

The warning message refers to "arg1" but the code is checking args[0]. The message should be corrected to match the actual argument being inspected.

Apply this diff:

-          DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type "
-                        << call->args[1]->GetTypeKey()
+          DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg0 type "
+                        << call->args[0]->GetTypeKey()
                         << "; Fallback to no vectorize";

252-262: Fix fallback path to use address_of() for ABI compliance.

The fallback path directly copies arguments without wrapping the destination in address_of(), which breaks the ABI contract for "AtomicAdd" that expects AtomicAdd(address_of(dst), value) as shown in the vectorized path (lines 233-236). Additionally, this doesn't handle the case where dst might be wrapped in an IfThenElse.

Apply this diff:

     } else {
+      // Non-vectorized fallback: AtomicAdd(address_of(dst), value)
+      Optional<BufferLoad> dst_bl;
+      const PrimExpr& dst_arg = node->args[0];
+      if (const auto* bl = dst_arg.as<BufferLoadNode>()) {
+        dst_bl = GetRef<BufferLoad>(bl);
+      } else if (const auto* ite = dst_arg.as<IfThenElseNode>()) {
+        if (const auto* then_bl = ite->then_case.as<BufferLoadNode>()) {
+          dst_bl = GetRef<BufferLoad>(then_bl);
+        } else if (const auto* else_bl = ite->else_case.as<BufferLoadNode>()) {
+          dst_bl = GetRef<BufferLoad>(else_bl);
+        }
+      }
+      if (!dst_bl.defined()) {
+        // Can't form a valid address; keep original call
+        return GetRef<PrimExpr>(node);
+      }
+      Call address_of_dst =
+          Call(DataType::Handle(), builtin::address_of(), {dst_bl.value()});
       Array<PrimExpr> new_args;
       new_args.push_back(StringImm("AtomicAdd"));
-      for (auto x : node->args)
-        new_args.push_back(x);
-
+      new_args.push_back(address_of_dst);
+      new_args.push_back(node->args[1]);
       Call new_call =
           tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
-
       return new_call;
     }
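The extraction step in this suggestion — peel an `IfThenElse` wrapper and look for a buffer load in either branch, otherwise keep the call untouched — can be sketched in isolation with simplified stand-in types (these are not TVM's real node classes):

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <string>

// Simplified stand-ins for the TIR nodes involved (illustrative only).
struct Expr {
  virtual ~Expr() = default;
};
struct BufferLoad : Expr {
  std::string buffer;
  explicit BufferLoad(std::string b) : buffer(std::move(b)) {}
};
struct IfThenElse : Expr {
  std::shared_ptr<Expr> then_case, else_case;
  IfThenElse(std::shared_ptr<Expr> t, std::shared_ptr<Expr> e)
      : then_case(std::move(t)), else_case(std::move(e)) {}
};

// Mirror of the suggested dst extraction: accept a bare BufferLoad, or an
// IfThenElse whose then/else branch is a BufferLoad; otherwise report
// failure so the caller keeps the original call unchanged.
std::optional<std::string> ExtractDstBuffer(const Expr &e) {
  if (auto *bl = dynamic_cast<const BufferLoad *>(&e))
    return bl->buffer;
  if (auto *ite = dynamic_cast<const IfThenElse *>(&e)) {
    if (auto *t = dynamic_cast<const BufferLoad *>(ite->then_case.get()))
      return t->buffer;
    if (auto *f = dynamic_cast<const BufferLoad *>(ite->else_case.get()))
      return f->buffer;
  }
  return std::nullopt;
}
```

The `std::nullopt` return models the `return GetRef<PrimExpr>(node)` bail-out: when no address can be formed, the original call survives unmodified.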
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c5a3050 and 64590d8.

📒 Files selected for processing (5)
  • src/op/atomic_add.cc (2 hunks)
  • src/op/builtin.cc (1 hunks)
  • src/op/builtin.h (1 hunks)
  • src/transform/atomicadd_vectorize.cc (4 hunks)
  • src/transform/atomicadd_vectorize.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/transform/atomicadd_vectorize.h
  • src/op/builtin.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
tilelang/language/tir/op.py (3)
  • ret (1758-1771)
  • address_of (463-479)
  • call_extern (172-194)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.0 (on self-hosted-amd)
🔇 Additional comments (2)
src/op/builtin.h (1)

504-510: LGTM!

The new atomicadd_elem_op() intrinsic declaration is well-documented and follows the established pattern for other intrinsics in this file.

src/op/atomic_add.cc (1)

290-290: LGTM!

The change from builtin::call_extern() to atomicadd_elem_op() correctly aligns with the new intrinsic introduced in src/op/builtin.h.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
src/transform/atomicadd_vectorize.cc (1)

56-58: Fix inconsistent warning: references arg1 but logs arg0's type.

The warning message says "Unexpected arg1 type" and logs call->args[1]->GetTypeKey(), but this else-branch is reached when call->args[0] is neither a BufferLoadNode nor an IfThenElseNode. The warning should reference arg0 instead.

Apply this diff:

-          DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type "
-                        << call->args[1]->GetTypeKey()
+          DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg0 type "
+                        << call->args[0]->GetTypeKey()
                         << "; Fallback to no vectorize";
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 64590d8 and cc56ff9.

📒 Files selected for processing (1)
  • src/transform/atomicadd_vectorize.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (2)
src/transform/loop_vectorize_dynamic.cc (16)
  • node (80-85)
  • node (80-80)
  • node (92-96)
  • node (92-92)
  • node (98-112)
  • node (98-98)
  • node (114-120)
  • node (114-114)
  • node (122-125)
  • node (122-122)
  • node (127-135)
  • node (127-127)
  • node (263-266)
  • node (263-263)
  • node (283-286)
  • node (283-283)
tilelang/language/tir/op.py (3)
  • ret (1758-1771)
  • address_of (463-479)
  • call_extern (172-194)
🔇 Additional comments (3)
src/transform/atomicadd_vectorize.cc (3)

26-34: LGTM: Arity validation is correct.

The check for call->args.size() < 2 with an early return properly prevents out-of-bounds access in later code that indexes args[0] and args[1].


79-94: LGTM: Size check prevents out-of-bounds access.

The explicit check node->args.size() < 2 at lines 80-82 properly guards the subsequent access to node->args[0] and node->args[1].


189-190: LGTM: Optimization to skip rewriting when vectorization is disabled.

Short-circuiting when vector_size_ == 1 avoids unnecessary transformation overhead.
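The planner's decision that feeds this short-circuit can be pictured as: start from the widest legal width and shrink until both the extent and the starting offset divide evenly, ending at the scalar fallback. A standalone sketch of that policy (the divisibility rule is illustrative, not the pass's exact analysis):

```cpp
#include <cassert>

// Pick a vectorize factor for an atomic add over `extent` contiguous
// elements starting at element `offset`. Widths shrink until both the
// extent and the starting offset divide evenly, mirroring the
// x4 -> x2 -> scalar fallback chain.
int PlanAtomicAddVectorSize(int extent, int offset, int max_width) {
  for (int w = max_width; w > 1; w /= 2) {
    if (extent % w == 0 && offset % w == 0)
      return w;
  }
  return 1; // scalar AtomicAdd fallback; the rewriter then does nothing
}
```

When this returns 1, the rewriter above skips the call rewrite entirely, which is the optimization being approved here.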

Comment on lines +213 to 260
+  bool legal_vectorize = true;
+  if (dynamic_)
+    legal_vectorize = false;
+  if (!(node->op == atomicadd_elem_op()))
+    legal_vectorize = false;
+  if (node->args.size() < 2)
+    legal_vectorize = false;
+  if (legal_vectorize) {
+    const BufferLoadNode *temp_dst_node = node->args[0].as<BufferLoadNode>();
+    const BufferLoadNode *temp_value_node =
+        node->args[1].as<BufferLoadNode>();
+    if (!temp_dst_node || !temp_value_node)
+      legal_vectorize = false;
+  }
-  if (vector_size_ == 2 || vector_size_ == 4) {
-    if (node->op == builtin::call_extern() && node->args.size() >= 2) {
-      if (const auto *func_name = node->args[0].as<StringImmNode>()) {
-        if (func_name->value == "AtomicAdd") {
-          const BufferLoadNode *temp_dst_node =
-              node->args[1].as<BufferLoadNode>();
-          const BufferLoadNode *temp_value_node =
-              node->args[2].as<BufferLoadNode>();
-          if (!temp_dst_node || !temp_value_node) {
-            return StmtExprMutator::VisitExpr_(node);
-          }
-          const BufferLoad dst_node =
-              Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-          const BufferLoad value_node =
-              Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+  if (legal_vectorize) {
+    const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
+    const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);

-          Call address_of_dst =
-              Call(DataType::Handle(), builtin::address_of(), {dst_node});
-          Call address_of_value =
-              Call(DataType::Handle(), builtin::address_of(), {value_node});
-          Array<PrimExpr> new_args;
-          if (vector_size_ == 2) {
-            new_args.push_back(StringImm("AtomicAddx2"));
-          } else {
-            new_args.push_back(StringImm("AtomicAddx4"));
-          }
-          new_args.push_back(address_of_dst);
-          new_args.push_back(address_of_value);
+    Call address_of_dst =
+        Call(DataType::Handle(), builtin::address_of(), {dst_node});
+    Call address_of_value =
+        Call(DataType::Handle(), builtin::address_of(), {value_node});
+    Array<PrimExpr> new_args;
+    if (vector_size_ == 4) {
+      new_args.push_back(StringImm("AtomicAddx4"));
+    } else if (vector_size_ == 2) {
+      new_args.push_back(StringImm("AtomicAddx2"));
+    } else {
+      new_args.push_back(StringImm("AtomicAdd"));
+    }
+    new_args.push_back(address_of_dst);
+    new_args.push_back(address_of_value);

-          Call new_call =
-              tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
+    Call new_call =
+        tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);

-          return new_call;
-        }
-      }
-    }
+    return new_call;
+  } else {
+    Array<PrimExpr> new_args;
+    new_args.push_back(StringImm("AtomicAdd"));
+    for (auto x : node->args)
+      new_args.push_back(x);
+
+    Call new_call =
+        tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
+
+    return new_call;
+  }
Contributor

⚠️ Potential issue | 🔴 Critical

Critical: Fallback path has multiple issues preventing correct lowering.

The fallback logic at lines 250-260 has two critical problems:

  1. Missing early return for non-atomicadd_elem_op calls: When node->op != atomicadd_elem_op() (line 216), the code sets legal_vectorize = false but still proceeds to emit call_extern("AtomicAdd", ...) in the fallback. Non-atomicadd_elem_op calls should remain unchanged.

  2. Missing address_of() wrappers: The fallback blindly pushes node->args into call_extern("AtomicAdd", ...), but the AtomicAdd ABI requires address_of(dst) as the first argument (and possibly address_of(value) as the second). The vectorized path correctly uses address_of() at lines 231-234, but the fallback doesn't.

Apply this diff to fix both issues:

 PrimExpr VisitExpr_(const CallNode *node) final {
+  // Only transform atomicadd_elem_op calls
+  if (node->op != atomicadd_elem_op()) {
+    return StmtExprMutator::VisitExpr_(node);
+  }
+
   bool legal_vectorize = true;
   if (dynamic_)
     legal_vectorize = false;
-  if (!(node->op == atomicadd_elem_op()))
-    legal_vectorize = false;
   if (node->args.size() < 2)
     legal_vectorize = false;
   if (legal_vectorize) {
     const BufferLoadNode *temp_dst_node = node->args[0].as<BufferLoadNode>();
     const BufferLoadNode *temp_value_node =
         node->args[1].as<BufferLoadNode>();
     if (!temp_dst_node || !temp_value_node)
       legal_vectorize = false;
   }
   if (legal_vectorize) {
     const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
     const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);

     Call address_of_dst =
         Call(DataType::Handle(), builtin::address_of(), {dst_node});
     Call address_of_value =
         Call(DataType::Handle(), builtin::address_of(), {value_node});
     Array<PrimExpr> new_args;
     if (vector_size_ == 4) {
       new_args.push_back(StringImm("AtomicAddx4"));
     } else if (vector_size_ == 2) {
       new_args.push_back(StringImm("AtomicAddx2"));
     } else {
       new_args.push_back(StringImm("AtomicAdd"));
     }
     new_args.push_back(address_of_dst);
     new_args.push_back(address_of_value);

     Call new_call =
         tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);

     return new_call;
   } else {
+    // Non-vectorized fallback: extract BufferLoad and emit scalar AtomicAdd
+    // with address_of() wrappers
+    Optional<BufferLoad> dst_bl;
+    const PrimExpr& dst_arg = node->args[0];
+    if (const auto* bl = dst_arg.as<BufferLoadNode>()) {
+      dst_bl = Downcast<BufferLoad>(dst_arg);
+    } else if (const auto* ite = dst_arg.as<IfThenElseNode>()) {
+      if (const auto* then_bl = ite->then_case.as<BufferLoadNode>()) {
+        dst_bl = Downcast<BufferLoad>(ite->then_case);
+      } else if (const auto* else_bl = ite->else_case.as<BufferLoadNode>()) {
+        dst_bl = Downcast<BufferLoad>(ite->else_case);
+      }
+    }
+    if (!dst_bl.defined()) {
+      // Can't extract a valid BufferLoad; keep original call
+      return GetRef<PrimExpr>(node);
+    }
+
+    // Similarly extract value BufferLoad if available
+    Optional<BufferLoad> value_bl;
+    const PrimExpr& value_arg = node->args[1];
+    if (const auto* bl = value_arg.as<BufferLoadNode>()) {
+      value_bl = Downcast<BufferLoad>(value_arg);
+    }
+    
+    Call address_of_dst =
+        Call(DataType::Handle(), builtin::address_of(), {dst_bl.value()});
+    
     Array<PrimExpr> new_args;
     new_args.push_back(StringImm("AtomicAdd"));
-    for (auto x : node->args)
-      new_args.push_back(x);
+    new_args.push_back(address_of_dst);
+    // For value: if it's a BufferLoad, use address_of; otherwise use as-is
+    if (value_bl.defined()) {
+      Call address_of_value =
+          Call(DataType::Handle(), builtin::address_of(), {value_bl.value()});
+      new_args.push_back(address_of_value);
+    } else {
+      new_args.push_back(node->args[1]);
+    }

     Call new_call =
         tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);

     return new_call;
   }
 }
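Once a width is fixed, the rewrite in both paths reduces to a name and argument-shape change: width 4 maps to `AtomicAddx4`, 2 to `AtomicAddx2`, and anything else degrades to the scalar `AtomicAdd`, always with the destination passed by address. A minimal sketch of that mapping (illustrative helper, not the pass's code):

```cpp
#include <cassert>
#include <string>

// Map the planned vector width to the extern device function the rewriter
// names in call_extern; widths other than 2 or 4 fall back to scalar.
std::string AtomicAddExternName(int vector_size) {
  switch (vector_size) {
  case 4:
    return "AtomicAddx4";
  case 2:
    return "AtomicAddx2";
  default:
    return "AtomicAdd";
  }
}
```

Keeping the mapping total (a `default` arm rather than an unreachable branch) is what lets the fallback path stay well-defined even when planning bails out.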
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 213-260, the fallback path
must not alter calls that are not atomicadd_elem_op and must use address_of
wrappers like the vectorized path; modify the code so that if node->op !=
atomicadd_elem_op() you immediately return the original node (no
transformation), and in the fallback path construct address_of(dst) (and
address_of(value) when the second arg is a BufferLoad) instead of pushing raw
node->args: Downcast the first (and second if present) arg(s) to BufferLoad,
build Calls to builtin::address_of() for them, push the appropriate "AtomicAdd"
name, then push the address_of Call(s) followed by any remaining args, and
return the new call_extern; keep existing behavior for dynamic_ and other checks
unchanged.

@LeiWang1999 LeiWang1999 merged commit 0ff4f42 into tile-ai:main Oct 16, 2025
6 checks passed