[Feature]: Add test for atomicadd auto vectorize and remove useless code #1019
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
Replaces string-based extern AtomicAdd calls with a new tilelang intrinsic.
Sequence Diagram(s)

```mermaid
%%{init: {"themeVariables": {"actorBkg":"#f3f4f6","noteBkg":"#fff8dc"}} }%%
sequenceDiagram
    autonumber
    participant TIR as TIR (call site)
    participant TL as tl_atomicadd_elem_op
    participant Planner as VectorizePlanner
    participant Rewriter as VectorizeRewriter
    participant Lower as Lowering
    TIR->>TL: atomicadd_elem_op(dst, src, ...)
    TL-->>TIR: intrinsic opcode recognized
    TIR->>Planner: analyze for BufferLoad / IfThenElse+BufferLoad
    Planner-->>TIR: vector_size (2, 4, or 1)
    TIR->>Rewriter: rewrite(vector_size)
    alt vector_size == 1
        Rewriter-->>TIR: keep scalar AtomicAdd (extern)
    else vector_size == 2 or 4
        Rewriter-->>TIR: emit AtomicAddx2 / AtomicAddx4, reshape operands, adjust address_of
    else
        Rewriter-->>TIR: fallback to scalar AtomicAdd extern
    end
    TIR->>Lower: lower rewritten calls
    Lower-->>TIR: device code emitted
```
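For orientation, the planner/rewriter decision flow in the diagram can be modeled as a small Python sketch. This is illustrative only: the function names and the divisibility-based alignment proxy are assumptions, not the actual TileLang/TVM implementation.

```python
# Toy model of the vectorize-plan + rewrite dispatch shown above.
# plan_vector_size stands in for the planner; rewrite_atomic_add for
# the rewriter's choice of extern call name.

def plan_vector_size(extent: int, elem_offset: int, max_size: int = 4) -> int:
    """Halve the candidate width until both the loop extent and the
    element offset are divisible by it (a crude alignment proxy)."""
    size = max_size
    while size > 1 and (extent % size or elem_offset % size):
        size //= 2
    return size

def rewrite_atomic_add(vector_size: int) -> str:
    # Widths 4 and 2 map to the widened extern calls; anything else
    # falls back to the scalar AtomicAdd extern.
    if vector_size == 4:
        return "AtomicAddx4"
    elif vector_size == 2:
        return "AtomicAddx2"
    return "AtomicAdd"
```

A width of 4 maps to AtomicAddx4, 2 to AtomicAddx2, and everything else falls back to the scalar extern, mirroring the alt branches in the diagram.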
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ❌ failed checks (2 warnings), ✅ passed checks (1 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)
Lines 217-255: Incorrect Downcast usage causes a compile error. You're downcasting from a pointer; `Downcast` expects an `ObjectRef`/`PrimExpr`. Fix both downcasts:

```diff
- const BufferLoad dst_node =
-     Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
- const BufferLoad value_node =
-     Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+ const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+ const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
```

Optionally, remove the earlier `temp_*_node` guards and call `as<BufferLoadNode>()` once for the checks, then `Downcast` from the original `PrimExpr`.
🧹 Nitpick comments (4)
src/transform/atomicadd_vectorize.cc (3)
Lines 42-71: Handle dtype inference when src (arg2) is guarded by IfThenElse. The plan currently extracts the dtype from dst (arg1) or from an IfThenElse in arg1, but masking often happens on the value side (arg2). Consider symmetric handling for arg2 to avoid unnecessary fallback. Example adjustment (sketch):

```diff
  } else if (const auto *ite = call->args[1].as<IfThenElseNode>()) {
    ...
+ } else if (const auto *ite2 = call->args[2].as<IfThenElseNode>()) {
+   if (const auto *then_load = ite2->then_case.as<BufferLoadNode>()) {
+     dtype = then_load->dtype;
+     vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+   } else if (const auto *else_load = ite2->else_case.as<BufferLoadNode>()) {
+     dtype = else_load->dtype;
+     vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+   } else {
+     vectorize_size_max = 1;
+     DLOG(WARNING) << "[AtomicAddVectorizePlanner] IfThenElse (arg2) has no BufferLoad; Fallback";
+   }
  } else {
```
Line 159: Non-English comment `// 补齐相关test` ("add the related tests") — consider translating or removing it to keep codebase comments in English.
Lines 193-215: Docstring vs implementation mismatch for the new loop var name. The doc says to create `"<old_loop_var>_outer"`, but the code uses `Var(old_var->name_hint)` (same name). Prefer an explicit `"_outer"` suffix for readability and to match the docs.

```diff
- auto new_var = Var(old_var->name_hint);
+ auto new_var = Var(old_var->name_hint + "_outer");
```

testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1)
Lines 25-26: Unused `bz` binding in both IR modules (RUF059). Prefix the unused binding with an underscore to silence lints and clarify intent:

```diff
-             threads=128) as (bx, by, bz):
+             threads=128) as (bx, by, _bz):
```

Repeat for the After module. Also applies to: line 36.
📜 Review details
Configuration used: CodeRabbit UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (3): src/op/atomic_add.cc (0 hunks), src/transform/atomicadd_vectorize.cc (5 hunks), testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1 hunks)
💤 Files with no reviewable changes (1): src/op/atomic_add.cc
🧰 Additional context used
🧬 Code graph analysis (2): src/transform/atomicadd_vectorize.cc references tilelang/language/tir/op.py (call_extern, 172-194; address_of, 463-479); the test file references tilelang/utils/target.py (determine_target, 54-99), tilelang/language/allocate.py (alloc_shared, 21-36), and tilelang/language/tir/op.py (call_extern, address_of)
🪛 Ruff (0.13.3): RUF059 at lines 25 and 36 — unpacked variable `bz` is never used; prefix it with an underscore or another dummy variable pattern
⏰ Context from checks skipped due to the 90000ms timeout (2): build-test-metal, format-check
```python
auto_target = tvm.target.Target(determine_target("auto"))
```
Make the test assert structural equality and avoid import-time target failures
- The result of `tvm.ir.structural_equal(mod, ref_mod)` is unused; the test won't fail on mismatch.
- `auto_target` is computed at import time and may raise `ValueError` on CI without CUDA/HIP/MPS.

Apply:

```diff
-auto_target = tvm.target.Target(determine_target("auto"))
+# target selection is done inside the test to allow graceful skipping
@@
-    with tvm.transform.PassContext():
-        mod = tvm.tir.transform.BindTarget(auto_target)(Before)
+    # Select target at runtime; skip if none available.
+    try:
+        auto_target = tvm.target.Target(determine_target("auto"))
+    except ValueError:
+        pytest.skip("No CUDA/HIP/MPS available; skipping atomicadd vectorize test.")
+
+    with tvm.transform.PassContext():
+        mod = tvm.tir.transform.BindTarget(auto_target)(Before)
     mod = tl.transform.LowerTileOp()(mod)
     mod = tvm.tir.transform.Simplify()(mod)
     ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
     ref_mod = tvm.tir.transform.Simplify()(ref_mod)
-    tvm.ir.structural_equal(mod, ref_mod)
+    assert tvm.ir.structural_equal(mod, ref_mod)
```

Also applies to: 54-61
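The pitfall flagged above can be demonstrated without TVM at all. In this toy sketch, `structural_equal` is a stand-in function (not the real `tvm.ir.structural_equal`): a bare comparison expression is evaluated and discarded, so the test "passes" even when the modules differ, while `assert` surfaces the mismatch.

```python
# Stand-in for a structural comparison that returns a bool.
def structural_equal(a, b) -> bool:
    return a == b

def test_without_assert() -> str:
    # Result is computed and silently dropped, just like the unused
    # tvm.ir.structural_equal(...) call in the review comment.
    structural_equal({"op": "AtomicAddx4"}, {"op": "AtomicAdd"})
    return "passed"  # reached despite the mismatch

def test_with_assert() -> str:
    try:
        assert structural_equal({"op": "AtomicAddx4"}, {"op": "AtomicAdd"})
        return "passed"
    except AssertionError:
        return "failed"  # mismatch is now caught
```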
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)
Lines 149-152: Prevent vector_size_ from reaching zero in the halving loop. If `IndiceCanVectorize` never returns true, `vector_size_` can become 0, leading to UB/division-by-zero later. Add a lower bound.

```diff
- while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
-                            inner_for_->extent, vector_size_, &analyzer_)) {
-   vector_size_ /= 2;
- }
+ while (vector_size_ > 1 &&
+        !IndiceCanVectorize(elem_offset, inner_for_->loop_var,
+                            inner_for_->extent, vector_size_, &analyzer_)) {
+   vector_size_ /= 2;
+ }
+ if (vector_size_ < 1) vector_size_ = 1;
```
♻️ Duplicate comments (3)
src/transform/atomicadd_vectorize.cc (2)
Lines 20-31: Tighten extern-call detection; avoid OOB and remove the redundant null-check. Require at least 3 args and drop the dead `if (!s)`; callers index `args[2]` later. This prevents OOB access when malformed calls slip through. Apply:

```diff
-inline bool IsAtomicAddExternCall(const CallNode *call) {
-  if (!call || !call->op.same_as(builtin::call_extern()))
-    return false;
-  if (call->args.size() < 2)
-    return false;
-  if (const auto *s = call->args[0].as<StringImmNode>()) {
-    if (!s)
-      return false;
-    return std::string(s->value) == AtomicAddExternName();
-  }
-  return false;
-}
+inline bool IsAtomicAddExternCall(const CallNode *call) {
+  if (!call || !call->op.same_as(builtin::call_extern())) return false;
+  // call_extern("AtomicAdd", dst, src) => need at least 3 args
+  if (call->args.size() < 3) return false;
+  if (const auto *s = call->args[0].as<StringImmNode>()) {
+    return std::string(s->value) == AtomicAddExternName();
+  }
+  return false;
+}
```

Lines 88-99: Guard access to args[2] (defense in depth). Add an explicit count check before indexing `args[2]` to avoid accidental OOB if detection changes.

```diff
   if (IsAtomicAddExternCall(node)) {
+    ICHECK_GE(node->args.size(), 3);
     const BufferLoadNode *buffer_load_dst = node->args[1].as<BufferLoadNode>();
     const BufferLoadNode *buffer_load_src = node->args[2].as<BufferLoadNode>();
```

testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1)
Lines 70-73: Make the test assert structural equality. Currently the result is unused; the test won't fail on mismatch. Assert it.

```diff
-    tvm.ir.structural_equal(mod, ref_mod)
-    # tvm.ir.assert_structural_equal(mod, ref_mod)
+    assert tvm.ir.structural_equal(mod, ref_mod)
+    # or: tvm.ir.assert_structural_equal(mod, ref_mod)
```
🧹 Nitpick comments (2)
src/transform/atomicadd_vectorize.cc (1)
Line 159: Prefer English-only comments in source. Replace the non-English comment with an English equivalent for consistency.

```diff
-// 补齐相关test
+// TODO: Add/align related tests
```

testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1)
Lines 23-24: Use the parameterized `threads` and prefix the unused `bz`. This avoids ARG001/RUF059 and reflects the parameterization.

```diff
-        with T.Kernel(
-                T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
-                threads=128) as (bx, by, bz):
+        with T.Kernel(
+                T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
+                threads=threads) as (bx, by, _bz):
```

Repeat for the After module. Also applies to: lines 34-35.
📜 Review details
Configuration used: CodeRabbit UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (3): src/op/atomic_add.cc (0 hunks), src/transform/atomicadd_vectorize.cc (5 hunks), testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1 hunks)
💤 Files with no reviewable changes (1): src/op/atomic_add.cc
🧰 Additional context used
🧬 Code graph analysis (2): the test file references tilelang/utils/target.py (determine_target, 54-99), tilelang/engine/lower.py (canon_target_host, 127-132), tilelang/language/allocate.py (alloc_shared, 21-36), and tilelang/language/parallel.py (Parallel, 8-28); src/transform/atomicadd_vectorize.cc references tilelang/language/tir/op.py (call_extern, 172-194; address_of, 463-479)
🪛 Ruff (0.14.0): ARG001 at line 14 — unused function argument `threads`; RUF059 at lines 23 and 34 — unpacked variable `bz` is never used; prefix it with an underscore or another dummy variable pattern
⏰ Context from checks skipped due to the 90000ms timeout (3): build-test-metal, build-test-amd, format-check
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)
Lines 149-152: Prevent vector_size_ from becoming 0 (runtime crash risk). Division by 2 can drop to 0 if vectorization is impossible; downstream checks then divide or mod by 0.

```diff
- while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
-                            inner_for_->extent, vector_size_, &analyzer_)) {
-   vector_size_ /= 2;
- }
+ while (vector_size_ > 1 &&
+        !IndiceCanVectorize(elem_offset, inner_for_->loop_var,
+                            inner_for_->extent, vector_size_, &analyzer_)) {
+   vector_size_ /= 2;
+ }
+ // If still not vectorizable at size 1, bail out gracefully.
+ if (vector_size_ < 1) vector_size_ = 1;
```
♻️ Duplicate comments (3)
src/transform/atomicadd_vectorize.cc (2)
Lines 20-31: Tighten extern-call detection; remove the redundant null-check. Require at least 3 args to avoid OOB at args[2]; drop the unnecessary `if (!s)` inside the cast branch.

```diff
-inline bool IsAtomicAddExternCall(const CallNode *call) {
-  if (!call || !call->op.same_as(builtin::call_extern()))
-    return false;
-  if (call->args.size() < 2)
-    return false;
-  if (const auto *s = call->args[0].as<StringImmNode>()) {
-    if (!s)
-      return false;
-    return std::string(s->value) == AtomicAddExternName();
-  }
-  return false;
-}
+inline bool IsAtomicAddExternCall(const CallNode *call) {
+  if (!call || !call->op.same_as(builtin::call_extern())) return false;
+  // call_extern("AtomicAdd", dst, src)
+  if (call->args.size() < 3) return false;
+  if (const auto *s = call->args[0].as<StringImmNode>()) {
+    return s->value == AtomicAddExternName();
+  }
+  return false;
+}
```

Lines 88-99: Guard access to args[2] (local safety even if the detector regresses). Add an explicit size check before indexing `args[2]`; safer and self-contained.

```diff
   if (IsAtomicAddExternCall(node)) {
+    ICHECK_GE(node->args.size(), 3);
     const BufferLoadNode *buffer_load_dst = node->args[1].as<BufferLoadNode>();
     const BufferLoadNode *buffer_load_src = node->args[2].as<BufferLoadNode>();
```

testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1)
Lines 71-75: Make the test fail on mismatch. Currently the equality result is ignored. Assert it.

```diff
-    tvm.ir.structural_equal(mod, ref_mod)
-    # tvm.ir.assert_structural_equal(mod, ref_mod)
+    tvm.ir.assert_structural_equal(mod, ref_mod)
```
🧹 Nitpick comments (5)
src/transform/atomicadd_vectorize.cc (2)
Line 126: Remove the unused variable; `access_type` is never used.

```diff
-  const DataType &access_type = buffer->dtype;
```

Lines 193-215: Preserve the loop var dtype and clarify naming. Construct the new loop var with the original dtype and a distinct name (matches the docstring).

```diff
-  auto old_var = fnode->loop_var;
-  auto new_var = Var(old_var->name_hint);
+  auto old_var = fnode->loop_var;
+  auto new_var = Var(old_var->name_hint + "_outer", old_var->dtype);
```

testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (3)
Lines 49-55: Robust SM parsing (handle missing digits like `sm_`). Guard against empty digit strings to avoid a ValueError.

```diff
   arch_part = arch.split("_")[1]
-  cc = int("".join(ch for ch in arch_part if ch.isdigit()))
+  digits = "".join(ch for ch in arch_part if ch.isdigit())
+  if not digits:
+      pytest.skip(f"Unrecognized CUDA arch: {arch}")
+  cc = int(digits)
   if cc != 90:
       pytest.skip(f"Requires Hopper (SM90), got {arch}")
```

Lines 21-24: Use the parametrized `threads` and silence the unused-`bz` lint. Wire `threads` through instead of hardcoding 128; prefix the unused `bz`.

```diff
-        with T.Kernel(
-                T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
-                threads=128) as (bx, by, bz):
+        with T.Kernel(
+                T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k,
+                threads=threads) as (bx, by, _bz):
```

Also applies to: lines 33-35.
Lines 68-70: Optional: drop noisy prints to keep CI logs clean.

```diff
-    print(mod)
-    print(ref_mod)
```
📜 Review details
Configuration used: CodeRabbit UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (2): src/transform/atomicadd_vectorize.cc (5 hunks), testing/python/transform/test_tilelang_transform_atomicadd_auto_vectorize.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2): src/transform/atomicadd_vectorize.cc references tilelang/language/tir/op.py (call_extern, 172-194; address_of, 463-479); the test file references tilelang/utils/target.py (determine_target, 54-99) and tilelang/engine/lower.py (canon_target_host, 127-132)
🪛 Ruff (0.14.0): ARG001 at line 14 — unused function argument `threads`; RUF059 at lines 23 and 34 — unpacked variable `bz` is never used; prefix it with an underscore or another dummy variable pattern
⏰ Context from checks skipped due to the 90000ms timeout (3): build-test-metal, build-test-nvidia, build-test-amd
Actionable comments posted: 3
♻️ Duplicate comments (1)
src/transform/atomicadd_vectorize.cc (1)
Lines 217-221: Fix Downcast misuse (compile error): pass an ObjectRef, not a Node*. `Downcast` expects an `ObjectRef`; using `.as()` yields a raw pointer. Apply:

```diff
- const BufferLoad dst_node =
-     Downcast<BufferLoad>(node->args[0].as<BufferLoadNode>());
- const BufferLoad value_node =
-     Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
+ const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
+ const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
```
📜 Review details
Configuration used: CodeRabbit UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (5): src/op/atomic_add.cc (2 hunks), src/op/builtin.cc (1 hunks), src/op/builtin.h (1 hunks), src/transform/atomicadd_vectorize.cc (4 hunks), src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1): src/transform/atomicadd_vectorize.cc references src/transform/loop_vectorize_dynamic.cc (several `node` visitor overloads) and tilelang/language/tir/op.py (ret, 1758-1771; address_of, 463-479; call_extern, 172-194)
⏰ Context from checks skipped due to the 90000ms timeout (3): build-test-metal, build-test-amd, build-test-nvidia
🔇 Additional comments (6)
src/transform/atomicadd_vectorize.h (1)
Line 11: Include addition looks good. Required for tl_atomicadd_elem_op usage downstream; no concerns.
src/op/builtin.h (1)
Lines 472-479: New intrinsic declaration is consistent. A two-arg, opaque intrinsic aligns with the intended atomic add element op. LGTM.
src/op/atomic_add.cc (2)
Lines 289-291: Switch to tl_atomicadd_elem_op is appropriate. Using the opcode instead of call_extern aligns with the new design. 👍
Lines 340-341: InferLayout now returns an empty map; verify no callers depend on the prior behavior. Previously it delegated to par_op_->InferLayout; now it returns {}. Please confirm no passes rely on AtomicAddNode::InferLayout producing layouts, since Lower() now computes its own loop layout.
src/op/builtin.cc (1)
Lines 278-282: Builtin registration is correct. num_inputs(2) and kOpaque are appropriate for an atomic op. Looks good.
src/transform/atomicadd_vectorize.cc (1)
Lines 178-180: Skip the rewrite when vector_size_ == 1. Good early exit that avoids unnecessary mutation.
Actionable comments posted: 1
♻️ Duplicate comments (2)
src/transform/atomicadd_vectorize.cc (2)
Lines 253-261: The fallback must pass `address_of(dst)` (and handle `IfThenElse`). The non-vectorized branch now emits `AtomicAdd(dst_value, src_value)`, but the ABI still expects `AtomicAdd(address_of(dst), value)`. With the new intrinsic, `node->args[0]` is a `BufferLoad` (or wrapped in an `IfThenElse`), so we're currently handing the callee the value of the destination element, not its address, resulting in incorrect code generation. Recreate the extern call with `builtin::address_of()` and unwrap `IfThenElse` the same way the vectorized path does.

```diff
- Array<PrimExpr> new_args;
- new_args.push_back(StringImm("AtomicAdd"));
- for (auto x : node->args)
-   new_args.push_back(x);
-
- Call new_call =
-     tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
-
- return new_call;
+ Optional<BufferLoad> dst_bl;
+ const PrimExpr &dst = node->args[0];
+ if (const auto *bl = dst.as<BufferLoadNode>()) {
+   dst_bl = Downcast<BufferLoad>(dst);
+ } else if (const auto *ite = dst.as<IfThenElseNode>()) {
+   if (const auto *then_bl = ite->then_case.as<BufferLoadNode>()) {
+     dst_bl = Downcast<BufferLoad>(ite->then_case);
+   } else if (const auto *else_bl = ite->else_case.as<BufferLoadNode>()) {
+     dst_bl = Downcast<BufferLoad>(ite->else_case);
+   }
+ }
+ if (!dst_bl.defined()) {
+   return GetRef<PrimExpr>(node);
+ }
+ Call address_of_dst =
+     Call(DataType::Handle(), builtin::address_of(), {dst_bl.value()});
+ Array<PrimExpr> new_args;
+ new_args.push_back(StringImm("AtomicAdd"));
+ new_args.push_back(address_of_dst);
+ new_args.push_back(node->args[1]);
+ return Call(node->dtype, builtin::call_extern(), new_args);
```
Lines 228-231: Fix `Downcast` misuse (pass a `PrimExpr`, not a raw `Node*`). `Downcast` expects an `ObjectRef`; passing `BufferLoadNode*` won't compile. Use the original `PrimExpr` arguments instead.

```diff
- const BufferLoad dst_node =
-     Downcast<BufferLoad>(node->args[0].as<BufferLoadNode>());
- const BufferLoad value_node =
-     Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
+ const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
+ const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
```
📜 Review details
Configuration used: CodeRabbit UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (3): src/op/builtin.cc (1 hunks), src/op/builtin.h (1 hunks), src/transform/atomicadd_vectorize.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1): src/transform/atomicadd_vectorize.cc references src/transform/lower_hopper_intrin.cc (call, 102-132), src/transform/loop_vectorize_dynamic.cc (several `node` visitor overloads), and tilelang/language/tir/op.py (ret, 1758-1771; address_of, 463-479; call_extern, 172-194)
⏰ Context from checks skipped due to the 90000ms timeout (1): Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
src/op/builtin.h (1)
Lines 504-510: Declaration looks correct. The signature and docstring align with the new 2-operand intrinsic; no concerns here.
src/op/builtin.cc (1)
Lines 298-302: Registration matches the new intrinsic. The 2-arg arity and opaque call effect align with the intended atomic add builtin.
Overall LGTM, but some comments :)
src/op/atomic_add.cc (Outdated)

```diff
-    Call atomicadd_call =
-        tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);
+    Call atomicadd_call =
+        tvm::tir::Call(dst->dtype, tvm::tl::tl_atomicadd_elem_op(), new_args);
```
replace with tl_atomicadd_elem_op()
src/op/builtin.h
Outdated
* This op is used to represent an element-wise atomic add operation in | ||
* tilelang. | ||
*/ | ||
TVM_DLL const Op &tl_atomicadd_elem_op(); |
I think we can remove the `tl` prefix, as we are already in the `tl` namespace.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/atomic_add.cc (1)
Lines 325-341: Remove the unused `par_op_` and restore layout delegation, or explicitly handle all cases. `AtomicAddNode::InferLayout` now returns an empty `LayoutMap` instead of delegating to `par_op_->InferLayout`, breaking the expected fallback. Either:
- Reintroduce delegation for non-`local.fragment` cases:

```cpp
return par_op_->InferLayout(T, level);
```

- Or remove `par_op_` from the class (and its Clone logic) and fully implement layout inference in this method.
♻️ Duplicate comments (2)
src/transform/atomicadd_vectorize.cc (2)
Lines 56-58: Fix the inconsistent error message. The warning refers to "arg1" but the code is checking `args[0]`. The message should be corrected to match the actual argument being inspected. Apply:

```diff
- DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type "
-               << call->args[1]->GetTypeKey()
+ DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg0 type "
+               << call->args[0]->GetTypeKey()
                << "; Fallback to no vectorize";
```

Lines 252-262: Fix the fallback path to use address_of() for ABI compliance. The fallback directly copies arguments without wrapping the destination in `address_of()`, which breaks the ABI contract for "AtomicAdd" that expects `AtomicAdd(address_of(dst), value)` as shown in the vectorized path (lines 233-236). Additionally, this doesn't handle the case where `dst` might be wrapped in an `IfThenElse`. Apply:

```diff
   } else {
+    // Non-vectorized fallback: AtomicAdd(address_of(dst), value)
+    Optional<BufferLoad> dst_bl;
+    const PrimExpr &dst_arg = node->args[0];
+    if (const auto *bl = dst_arg.as<BufferLoadNode>()) {
+      dst_bl = GetRef<BufferLoad>(bl);
+    } else if (const auto *ite = dst_arg.as<IfThenElseNode>()) {
+      if (const auto *then_bl = ite->then_case.as<BufferLoadNode>()) {
+        dst_bl = GetRef<BufferLoad>(then_bl);
+      } else if (const auto *else_bl = ite->else_case.as<BufferLoadNode>()) {
+        dst_bl = GetRef<BufferLoad>(else_bl);
+      }
+    }
+    if (!dst_bl.defined()) {
+      // Can't form a valid address; keep the original call
+      return GetRef<PrimExpr>(node);
+    }
+    Call address_of_dst =
+        Call(DataType::Handle(), builtin::address_of(), {dst_bl.value()});
     Array<PrimExpr> new_args;
     new_args.push_back(StringImm("AtomicAdd"));
-    for (auto x : node->args)
-      new_args.push_back(x);
-
+    new_args.push_back(address_of_dst);
+    new_args.push_back(node->args[1]);
     Call new_call =
         tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
     return new_call;
   }
```
📜 Review details
Configuration used: CodeRabbit UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (5): src/op/atomic_add.cc (2 hunks), src/op/builtin.cc (1 hunks), src/op/builtin.h (1 hunks), src/transform/atomicadd_vectorize.cc (4 hunks), src/transform/atomicadd_vectorize.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2): src/transform/atomicadd_vectorize.h, src/op/builtin.cc
🧰 Additional context used
🧬 Code graph analysis (1): src/transform/atomicadd_vectorize.cc references tilelang/language/tir/op.py (ret, 1758-1771; address_of, 463-479; call_extern, 172-194)
⏰ Context from checks skipped due to the 90000ms timeout (3): Test for Python 3.12 with Metal (on macos-latest), Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia), Test for Python 3.12 with Nightly-ROCm-7.0 (on self-hosted-amd)
🔇 Additional comments (2)
src/op/builtin.h (1)
Lines 504-510: LGTM! The new `atomicadd_elem_op()` intrinsic declaration is well documented and follows the established pattern for other intrinsics in this file.
src/op/atomic_add.cc (1)
Line 290: LGTM! The change from `builtin::call_extern()` to `atomicadd_elem_op()` correctly aligns with the new intrinsic introduced in src/op/builtin.h.
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/transform/atomicadd_vectorize.cc (1)
Lines 56-58: Fix the inconsistent warning: it references arg1 but logs arg0's type. The message says "Unexpected arg1 type" and logs `call->args[1]->GetTypeKey()`, but this else-branch is reached when `call->args[0]` is neither a `BufferLoadNode` nor an `IfThenElseNode`. The warning should reference arg0 instead. Apply:

```diff
- DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type "
-               << call->args[1]->GetTypeKey()
+ DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg0 type "
+               << call->args[0]->GetTypeKey()
                << "; Fallback to no vectorize";
```
📜 Review details
Configuration used: CodeRabbit UI · Review profile: CHILL · Plan: Pro
📒 Files selected for processing (1): src/transform/atomicadd_vectorize.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1): src/transform/atomicadd_vectorize.cc references src/transform/loop_vectorize_dynamic.cc (several `node` visitor overloads) and tilelang/language/tir/op.py (ret, 1758-1771; address_of, 463-479; call_extern, 172-194)
🔇 Additional comments (3)
src/transform/atomicadd_vectorize.cc (3)
Lines 26-34: LGTM: the arity validation is correct. The check for `call->args.size() < 2` with an early return properly prevents out-of-bounds access in later code that indexes `args[0]` and `args[1]`.
Lines 79-94: LGTM: the size check prevents out-of-bounds access. The explicit `node->args.size() < 2` check at lines 80-82 properly guards the subsequent access to `node->args[0]` and `node->args[1]`.
Lines 189-190: LGTM: skipping the rewrite when vectorization is disabled. Short-circuiting when `vector_size_ == 1` avoids unnecessary transformation overhead.
-    if (vector_size_ == 2 || vector_size_ == 4) {
-      if (node->op == builtin::call_extern() && node->args.size() >= 2) {
-        if (const auto *func_name = node->args[0].as<StringImmNode>()) {
-          if (func_name->value == "AtomicAdd") {
-            const BufferLoadNode *temp_dst_node =
-                node->args[1].as<BufferLoadNode>();
-            const BufferLoadNode *temp_value_node =
-                node->args[2].as<BufferLoadNode>();
-            if (!temp_dst_node || !temp_value_node) {
-              return StmtExprMutator::VisitExpr_(node);
-            }
-            const BufferLoad dst_node =
-                Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-            const BufferLoad value_node =
-                Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
-            Call address_of_dst =
-                Call(DataType::Handle(), builtin::address_of(), {dst_node});
-            Call address_of_value =
-                Call(DataType::Handle(), builtin::address_of(), {value_node});
-            Array<PrimExpr> new_args;
-            if (vector_size_ == 2) {
-              new_args.push_back(StringImm("AtomicAddx2"));
-            } else {
-              new_args.push_back(StringImm("AtomicAddx4"));
-            }
-            new_args.push_back(address_of_dst);
-            new_args.push_back(address_of_value);
-            Call new_call =
-                tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
-            return new_call;
-          }
-        }
-      }
+    bool legal_vectorize = true;
+    if (dynamic_)
+      legal_vectorize = false;
+    if (!(node->op == atomicadd_elem_op()))
+      legal_vectorize = false;
+    if (node->args.size() < 2)
+      legal_vectorize = false;
+    if (legal_vectorize) {
+      const BufferLoadNode *temp_dst_node = node->args[0].as<BufferLoadNode>();
+      const BufferLoadNode *temp_value_node =
+          node->args[1].as<BufferLoadNode>();
+      if (!temp_dst_node || !temp_value_node)
+        legal_vectorize = false;
+    }
+    if (legal_vectorize) {
+      const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
+      const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
+      Call address_of_dst =
+          Call(DataType::Handle(), builtin::address_of(), {dst_node});
+      Call address_of_value =
+          Call(DataType::Handle(), builtin::address_of(), {value_node});
+      Array<PrimExpr> new_args;
+      if (vector_size_ == 4) {
+        new_args.push_back(StringImm("AtomicAddx4"));
+      } else if (vector_size_ == 2) {
+        new_args.push_back(StringImm("AtomicAddx2"));
+      } else {
+        new_args.push_back(StringImm("AtomicAdd"));
+      }
+      new_args.push_back(address_of_dst);
+      new_args.push_back(address_of_value);
+      Call new_call =
+          tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
+      return new_call;
+    } else {
+      Array<PrimExpr> new_args;
+      new_args.push_back(StringImm("AtomicAdd"));
+      for (auto x : node->args)
+        new_args.push_back(x);
+      Call new_call =
+          tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
+      return new_call;
+    }
Critical: Fallback path has multiple issues preventing correct lowering.
The fallback logic at lines 250-260 has two critical problems:
1. Missing early return for non-atomicadd_elem_op calls: When node->op != atomicadd_elem_op() (line 216), the code sets legal_vectorize = false but still proceeds to emit call_extern("AtomicAdd", ...) in the fallback. Non-atomicadd_elem_op calls should remain unchanged.
2. Missing address_of() wrappers: The fallback blindly pushes node->args into call_extern("AtomicAdd", ...), but the AtomicAdd ABI requires address_of(dst) as the first argument (and possibly address_of(value) as the second). The vectorized path correctly uses address_of() at lines 231-234, but the fallback doesn't.
Apply this diff to fix both issues:
PrimExpr VisitExpr_(const CallNode *node) final {
+ // Only transform atomicadd_elem_op calls
+ if (node->op != atomicadd_elem_op()) {
+ return StmtExprMutator::VisitExpr_(node);
+ }
+
bool legal_vectorize = true;
if (dynamic_)
legal_vectorize = false;
- if (!(node->op == atomicadd_elem_op()))
- legal_vectorize = false;
if (node->args.size() < 2)
legal_vectorize = false;
if (legal_vectorize) {
const BufferLoadNode *temp_dst_node = node->args[0].as<BufferLoadNode>();
const BufferLoadNode *temp_value_node =
node->args[1].as<BufferLoadNode>();
if (!temp_dst_node || !temp_value_node)
legal_vectorize = false;
}
if (legal_vectorize) {
const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
Call address_of_dst =
Call(DataType::Handle(), builtin::address_of(), {dst_node});
Call address_of_value =
Call(DataType::Handle(), builtin::address_of(), {value_node});
Array<PrimExpr> new_args;
if (vector_size_ == 4) {
new_args.push_back(StringImm("AtomicAddx4"));
} else if (vector_size_ == 2) {
new_args.push_back(StringImm("AtomicAddx2"));
} else {
new_args.push_back(StringImm("AtomicAdd"));
}
new_args.push_back(address_of_dst);
new_args.push_back(address_of_value);
Call new_call =
tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
return new_call;
} else {
+ // Non-vectorized fallback: extract BufferLoad and emit scalar AtomicAdd
+ // with address_of() wrappers
+ Optional<BufferLoad> dst_bl;
+ const PrimExpr& dst_arg = node->args[0];
+ if (const auto* bl = dst_arg.as<BufferLoadNode>()) {
+ dst_bl = Downcast<BufferLoad>(dst_arg);
+ } else if (const auto* ite = dst_arg.as<IfThenElseNode>()) {
+ if (const auto* then_bl = ite->then_case.as<BufferLoadNode>()) {
+ dst_bl = Downcast<BufferLoad>(ite->then_case);
+ } else if (const auto* else_bl = ite->else_case.as<BufferLoadNode>()) {
+ dst_bl = Downcast<BufferLoad>(ite->else_case);
+ }
+ }
+ if (!dst_bl.defined()) {
+ // Can't extract a valid BufferLoad; keep original call
+ return GetRef<PrimExpr>(node);
+ }
+
+ // Similarly extract value BufferLoad if available
+ Optional<BufferLoad> value_bl;
+ const PrimExpr& value_arg = node->args[1];
+ if (const auto* bl = value_arg.as<BufferLoadNode>()) {
+ value_bl = Downcast<BufferLoad>(value_arg);
+ }
+
+ Call address_of_dst =
+ Call(DataType::Handle(), builtin::address_of(), {dst_bl.value()});
+
Array<PrimExpr> new_args;
new_args.push_back(StringImm("AtomicAdd"));
- for (auto x : node->args)
- new_args.push_back(x);
+ new_args.push_back(address_of_dst);
+ // For value: if it's a BufferLoad, use address_of; otherwise use as-is
+ if (value_bl.defined()) {
+ Call address_of_value =
+ Call(DataType::Handle(), builtin::address_of(), {value_bl.value()});
+ new_args.push_back(address_of_value);
+ } else {
+ new_args.push_back(node->args[1]);
+ }
Call new_call =
tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
return new_call;
}
}
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 213-260, the fallback path
must not alter calls that are not atomicadd_elem_op and must use address_of
wrappers like the vectorized path; modify the code so that if node->op !=
atomicadd_elem_op() you immediately return the original node (no
transformation), and in the fallback path construct address_of(dst) (and
address_of(value) when the second arg is a BufferLoad) instead of pushing raw
node->args: Downcast the first (and second if present) arg(s) to BufferLoad,
build Calls to builtin::address_of() for them, push the appropriate "AtomicAdd"
name, then push the address_of Call(s) followed by any remaining args, and
return the new call_extern; keep existing behavior for dynamic_ and other checks
unchanged.
Summary by CodeRabbit
New Features
Refactor
Chores