[Bugfix]:Fix atomic add auto vectorize negative optimization #765
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Routes AtomicAdd lowering through a thread-aware `VectorizeAtomicAdd` and rewrites the AtomicAdd vectorization pass to accept thread/block context and strides, compute the TX extent, replace the outer loop variable with a new iter_var, emit address-based `AtomicAddx2`/`AtomicAddx4` calls, and preserve fallback/dynamic checks and predicates.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Lower as AtomicAdd::Lower
    participant VPass as VectorizeAtomicAdd
    participant Rewriter as AtomicAddVectorizeRewriter
    participant IR as IR Builder
    Lower->>VPass: VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds, arch_int)
    Note right of VPass #DDFFDD: Analyze loop AST to detect bx/by multipliers & strides
    VPass->>Rewriter: Init(plan, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x)
    Rewriter->>IR: Replace inner For with For(iter_var, 0, extent/vector_size)
    Rewriter->>IR: Compute dst/value indices using by/bx, iter_var, tx_var, strides
    Rewriter->>IR: Create address_of(BufferLoad(dst)) and address_of(BufferLoad(val))
    Rewriter->>IR: Emit AtomicAddx2 / AtomicAddx4 calls
    VPass-->>Lower: return vectorized_thread_loop
    alt predicate present
        Lower->>IR: Wrap vectorized loop with predicate
    end
    Lower-->>Lower: Return transformed loop
```
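To make the new flow concrete, here is a minimal sketch (not the actual implementation) of how `AtomicAdd::Lower` is described as calling the thread-aware pass. The exact signatures live in `src/transform/atomicadd_vectorize.cc` and `src/op/atomic_add.cc` and may differ in detail; `LowerAtomicAddLoop` and the `predicate` parameter here are illustrative names only.

```cpp
// Sketch only, assuming TVM TIR types; names follow the diagram above.
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

// Declared elsewhere in the codebase (per this PR):
For VectorizeAtomicAdd(const For& thread_loop, const Var& thread_var,
                       const Range& thread_bounds, int arch_int);
int GetArchInt(Target target);  // shared helper in src/target/utils.h

Stmt LowerAtomicAddLoop(const For& thread_loop, const Var& thread_var,
                        const Range& thread_bounds, const Target& target,
                        const PrimExpr& predicate) {
  // Route lowering through the thread-aware vectorizer; it picks
  // AtomicAddx2/x4 based on dtype and the target's SM number.
  For vectorized = VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds,
                                      GetArchInt(target));
  // Preserve the original predicate, if any, around the vectorized loop.
  if (predicate.defined()) {
    return IfThenElse(predicate, vectorized);
  }
  return vectorized;
}
```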
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45–75 minutes
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @yyttt6, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a bug in the automatic vectorization of atomic add operations, which was previously causing "negative optimization." The changes re-enable and correct the VectorizeAtomicAdd
pass by providing the rewriter with necessary context about thread variables, block variables, and memory strides. This ensures that vectorized atomic operations are correctly indexed and performed, improving performance and correctness.
Highlights
- Re-enabled AtomicAdd Vectorization: The `VectorizeAtomicAdd` pass is now actively used in the `AtomicAdd::Lower` method, replacing a generic loop vectorizer and removing a "buggy implementation" TODO comment.
- Enhanced Rewriter Context: The `AtomicAddVectorizeRewriter` constructor has been updated to accept additional parameters, including thread variables, block variables (`bx`, `by`), thread bounds, and memory strides (`stride_x`, `stride_y`), providing more context for accurate vectorization.
- Corrected Memory Indexing for Vectorization: The core logic for calculating memory indices within the `AtomicAddVectorizeRewriter` has been significantly revised. It now dynamically computes destination and value buffer indices based on thread extents, vector size, and strides, ensuring correct memory access patterns for vectorized atomic adds.
- Improved Stride and Block Variable Detection: The `VectorizeAtomicAdd` function now includes logic to identify and extract `bx`, `by`, `stride_x`, and `stride_y` values from the loop body, which are then passed to the rewriter for precise memory address calculation.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/
folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request enables auto-vectorization for atomic add operations, which was previously disabled. The core of the change is in `src/transform/atomicadd_vectorize.cc`, where a more sophisticated index remapping logic is introduced to handle vectorized accesses correctly. While the overall direction is good, I've identified a critical issue that could lead to a crash and another high-severity bug that could cause the optimization to fail silently. Please see the detailed comments for suggestions on how to address these.
src/transform/atomicadd_vectorize.cc
Outdated
This line can cause a crash. `node->args[2]` (the value being added) can be an `if_then_else` node if there's a predicate on the source buffer access. In that case, `as<BufferLoadNode>()` will return `nullptr`, and dereferencing it will cause a crash.

You should add a check before this line to ensure the cast is successful, for example:

```cpp
const auto* value_load_node = node->args[2].as<BufferLoadNode>();
ICHECK(value_load_node) << "The value for AtomicAdd is expected to be a BufferLoad, but got "
                        << node->args[2]->GetTypeKey();
```

And then use `*value_load_node` here.
src/transform/atomicadd_vectorize.cc
Outdated
This logic for finding the stride and block index variable is brittle. It only checks if `mul->a` is a `VarNode` and `mul->b` is an `IntImmNode`. Due to the commutative property of multiplication, the expression could also be `imm * var`. The current code would fail to detect that case, causing the vectorization to be skipped silently. You should handle both cases to make the logic more robust:

```cpp
if (const MulNode *mul = obj.as<MulNode>()) {
  const VarNode *var = nullptr;
  const IntImmNode *imm = nullptr;
  PrimExpr var_expr;
  if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
    var_expr = mul->a;
  } else if ((var = mul->b.as<VarNode>()) && (imm = mul->a.as<IntImmNode>())) {
    var_expr = mul->b;
  }
  if (var && imm) {
    if (var->name_hint == "bx") {
      stride_x = imm->value;
      bx_var = var_expr;
    } else if (var->name_hint == "by") {
      stride_y = imm->value;
      by_var = var_expr;
    }
  }
}
```
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/op/atomic_add.cc (1)

24-35: Remove duplicate GetArchInt; use the shared util to avoid silent mis-detections.

This static redefinition weakens validation (falls back to 0) and shadows the canonical GetArchInt in ../target/utils.h, risking inconsistent behavior. Prefer the shared implementation.

Apply this diff to remove the duplicate:

```diff
- static int GetArchInt(Target target) {
-   int arch_int = 0;
-   auto s = target->GetAttr<String>("arch");
-   ICHECK(s.defined());
-   const char *arch_str = s.value().c_str();
-   if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
-     arch_int = atoi(&arch_str[3]);
-   } else {
-     arch_int = 0;
-   }
-   return arch_int;
- }
+ // Use GetArchInt from ../target/utils.h
```
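For context, here is a minimal sketch of the stricter parsing that the shared helper is described as providing (ICHECK on a missing or malformed arch string rather than silently returning 0). The actual `GetArchInt` in `src/target/utils.cc` may differ; the name `GetArchIntSketch` is used here to avoid implying this is the real code.

```cpp
#include <string>
#include <tvm/runtime/logging.h>
#include <tvm/target/target.h>

static int GetArchIntSketch(const tvm::Target& target) {
  auto arch = target->GetAttr<tvm::String>("arch");
  ICHECK(arch.defined()) << "Target has no 'arch' attribute";
  std::string s = arch.value();
  // Expect strings of the form "sm_80", "sm_90", ...
  ICHECK(s.rfind("sm_", 0) == 0) << "Unexpected arch string: " << s;
  return std::stoi(s.substr(3));  // "sm_90" -> 90
}
```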
src/transform/atomicadd_vectorize.cc (1)

260-272: Verify and complete CUDA externs for AtomicAddx2/x4.

- We confirmed that both `AtomicAddx2` and `AtomicAddx4` are defined in `src/tl_templates/cuda/common.h` with the expected signature (`TL_DEVICE void AtomicAddx[N](<T>* address, <T>* val)`) and parameter order (dst pointer first, then value pointer), matching your call-site in `atomicadd_vectorize.cc`.
- However, while `AtomicAddx2` is specialized for `half_t`, `bfloat16_t`, and `float`, the `AtomicAddx4` overload only exists for `float`:
  - Missing `AtomicAddx4(half_t*, half_t*)` and `AtomicAddx4(bfloat16_t*, bfloat16_t*)` specializations.
  - If `vector_size_ == 4` for half- or bfloat16-typed data, you'll get unresolved externs at link time.

Please add the missing `AtomicAddx4` device functions for `half_t` and `bfloat16_t` in `src/tl_templates/cuda/common.h` (mirroring the pattern used for `AtomicAddx2`), or constrain `vector_size_` to only emit `AtomicAddx4` for types with existing CUDA support. A hypothetical sketch of such an overload follows.
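For illustration, a hypothetical `AtomicAddx4` overload for half-precision values, mirroring the x2 pattern described above. This is a sketch, not the actual contents of `src/tl_templates/cuda/common.h` (which may prefer a different strategy, e.g. a single wide CAS loop), and `half_t` is assumed here to alias CUDA's `__half`:

```cpp
#include <cuda_fp16.h>

#define TL_DEVICE __forceinline__ __device__
using half_t = __half;  // assumption: the project's half_t maps onto __half

// Add four contiguous half values atomically by issuing two __half2 atomics
// (hardware atomicAdd on __half2 is available on sm_60 and newer).
TL_DEVICE void AtomicAddx4(half_t *address, half_t *val) {
  __half2 *dst = reinterpret_cast<__half2 *>(address);
  __half2 *src = reinterpret_cast<__half2 *>(val);
  atomicAdd(dst, src[0]);
  atomicAdd(dst + 1, src[1]);
}
```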
🧹 Nitpick comments (3)

src/op/atomic_add.cc (1)

4-4: Typo in file header. "Elment-wise" → "element-wise".

```diff
- * Define elment-wise operators.
+ * Define element-wise operators.
```

src/transform/atomicadd_vectorize.cc (2)

126-129: Confirm FloorMod availability; prefer truncmod for consistency.

The planner uses FloorMod(...) but elsewhere uses truncdiv/truncmod. If FloorMod isn't defined in this context, the build will fail. Would you like me to swap to truncmod for consistency?

```diff
- condition_ = (FloorMod(offset, vector_size_) == 0);
+ condition_ = (truncmod(offset, vector_size_) == 0);
```
316-319: Heuristic extraction of bx/by and strides is fragile (Mul order, non-const).

The current scan only catches patterns like bx * C and by * C with MulNode.a = Var and MulNode.b = IntImm. It misses C * bx or non-const strides. Consider normalizing commutative Mul operands and accepting both operand orders; fall back to the non-vectorized path if strides aren't IntImm.

Also applies to: 356-358
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📥 Commits
Reviewing files that changed from the base of the PR and between e05a20a and 5e0d9550096139baf220c468da891ec1a3884675.
📒 Files selected for processing (2)

- src/op/atomic_add.cc (1 hunks)
- src/transform/atomicadd_vectorize.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)

src/op/atomic_add.cc (2)
- src/transform/atomicadd_vectorize.cc (2): VectorizeAtomicAdd (313-362), VectorizeAtomicAdd (313-314)
- src/target/utils.cc (2): GetArchInt (18-26), GetArchInt (18-18)

src/transform/atomicadd_vectorize.cc (1)
- tilelang/language/tir/op.py (5): ret (1656-1669), call_extern (172-194), truncdiv (2918-2941), truncmod (2944-2967), address_of (463-479)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: build-test-nvidia
🔇 Additional comments (2)
src/op/atomic_add.cc (1)
200-201: Good: vectorization is now capability-aware.

Switching to `VectorizeAtomicAdd(thread_loop, thread_var, thread_bounds, GetArchInt(target))` matches the new rewriter API.

⚠️ Please ensure that every `Target` reaching this pass has its `"arch"` attribute populated (for example, `"sm_90"`); otherwise `GetArchInt(target)` will ICHECK-fail with a zero or undefined compute capability. I did not locate any C++ code that sets a default for `target->attrs["arch"]`, so verify its initialization in the Python binder or the C++ Target constructor.

src/transform/atomicadd_vectorize.cc (1)
182-201: Manually confirm that only AtomicAdd operations appear in the vectorized loop body.

Our search for loops combining `AtomicAdd` calls with any other buffer loads or stores returned no hits, but absence of evidence is not evidence of absence. Since the rewrite at src/transform/atomicadd_vectorize.cc:182–201 only does `vmap.Set(fnode->loop_var, iter_var_);` (without multiplying by `vector_size_`), any additional uses of the original loop variable in the loop body will now refer to `iter_var_` instead of `iter_var_ * vector_size_`, altering semantics.

Please verify that:

- The innermost loops you plan to vectorize contain only `AtomicAdd` calls (no other `BufferLoad`, `BufferStore`, or arithmetic involving the original `loop_var`).
- If there are other uses of `loop_var`, either
  - restrict vectorization to bodies with AtomicAdd-only side effects, or
  - scale the substitution to `vmap.Set(fnode->loop_var, iter_var_ * vector_size_);` and update all index computations accordingly.
src/transform/atomicadd_vectorize.cc
Outdated
Division-by-zero hazard in index math when stride_x_ < vector_size_ or not divisible.

Expressions like truncdiv(tx_var_, stride_x_ / vector_size_) and truncmod(..., stride_x_ / vector_size_) use a zero denominator when stride_x_ < vector_size_. They also assume stride_x_ % vector_size_ == 0.

Add a safety gate before invoking the rewriter so we bail out when unsafe:

```diff
@@
- if (vectorize_hint == 1)
-   return for_node;
- auto rewriter = AtomicAddVectorizeRewriter(
+ if (vectorize_hint == 1)
+   return for_node;
+ // Require positive stride_x and divisibility to avoid 0-denominator in index math.
+ if (stride_x <= 0 || (stride_x % vectorize_hint) != 0) {
+   return for_node;
+ }
+ auto rewriter = AtomicAddVectorizeRewriter(
      res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);
```

Optionally, also guard for undefined bx/by:

```diff
- if (vectorize_hint == 1)
+ if (vectorize_hint == 1 || !bx_var.defined() || !by_var.defined())
    return for_node;
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
// Matrix[by * stride_y + i / (stride_x / (tx_txtent *
// vector_size_)) + tx_var_ / (stride_x / vector_size_),
// bx * stride_x + (i % (stride_x / (tx_extent *
// vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
// (stride / vector_size_)) * vector_size_]
BufferLoadNode old_dst_node =
    *(node->args[1].as<CallNode>()->args[0].as<BufferLoadNode>());
BufferLoadNode old_value_node =
    *(node->args[2].as<BufferLoadNode>());
Array<PrimExpr> dst_indices, value_indices;
if ((extent_tx_ * vector_size_) > stride_x_) {
  dst_indices.push_back(
      by_var_ * stride_y_ +
      iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
      truncdiv(tx_var_, stride_x_ / vector_size_));
  dst_indices.push_back(
      bx_var_ * stride_x_ +
      truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
  value_indices.push_back(
      iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
      truncdiv(tx_var_ * vector_size_, stride_x_));
  value_indices.push_back(
      truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
} else {
  dst_indices.push_back(
      by_var_ * stride_y_ +
      truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
      truncdiv(tx_var_, stride_x_ / vector_size_));
  dst_indices.push_back(
      bx_var_ * stride_x_ +
      truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
          (extent_tx_ * vector_size_) +
      truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
  value_indices.push_back(
      truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
      truncdiv(tx_var_, stride_x_ / vector_size_));
  value_indices.push_back(
      truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
          (extent_tx_ * vector_size_) +
      truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
}

// Bail out if no vectorization hint or missing block coords
if (vectorize_hint == 1 || !bx_var.defined() || !by_var.defined())
  return for_node;
// Require a positive, exactly divisible stride_x to avoid 0-denominator in truncdiv/truncmod.
if (stride_x <= 0 || (stride_x % vectorize_hint) != 0) {
  return for_node;
}
auto rewriter = AtomicAddVectorizeRewriter(
    res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);
```
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/transform/atomicadd_vectorize.cc (2)
203-260: Dynamic path emits invalid IR: `iter_var_` used without loop rewrite.

`VisitExpr_` rewrites AtomicAdd unconditionally. When `dynamic_ == true`, `VisitStmt_` returns the original loop, yet the rewritten calls capture `iter_var_`, which is undefined. Short-circuit in the call rewriter for dynamic plans:

```diff
  PrimExpr VisitExpr_(const CallNode *node) final {
+   if (dynamic_) {
+     return StmtExprMutator::VisitExpr_(node);
+   }
    if (vector_size_ == 2 || vector_size_ == 4) {
```

Longer-term: guard the vectorized call with `condition_` and fall back to scalar AtomicAdd.
303-312: Null-deref hazard when extracting dtype from AtomicAdd arg[1].

`func_name` may be null; `args[1]` may not be `address_of(BufferLoad)`. Add checks:

```diff
- const auto *func_name = call->args[0].as<StringImmNode>();
- if (func_name->value == "AtomicAdd") {
-   DataType dtype =
-       call->args[1].as<CallNode>()->args[0].as<BufferLoadNode>()->dtype;
+ const auto* func_name = call->args[0].as<StringImmNode>();
+ if (func_name && func_name->value == "AtomicAdd") {
+   const CallNode* addr = call->args[1].as<CallNode>();
+   if (!addr || addr->op != builtin::address_of() || addr->args.size() != 1) return;
+   const BufferLoadNode* bl = addr->args[0].as<BufferLoadNode>();
+   if (!bl) return;
+   DataType dtype = bl->dtype;
```
♻️ Duplicate comments (3)

src/transform/atomicadd_vectorize.cc (3)

314-331: Good: commutative Mul handling for bx/by stride discovery.

This resolves the brittle pattern flagged earlier.

340-343: Add safety gate: require positive/divisible stride_x before vectorization.

Prevents 0-denominator and misaligned indexing. Minimal:

```diff
- if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 || !bx_var.defined() || !by_var.defined())
+ if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
+     !bx_var.defined() || !by_var.defined() ||
+     stride_x <= 0 || (stride_x % vectorize_hint) != 0)
    return for_node;
```

208-230: Prevent division-by-zero when vectorizing atomic adds.

- In `VectorizeAtomicAdd` (before invoking the rewriter), bail out if the computed stride is smaller than the planned vector size:

```diff
- if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 || !bx_var.defined() || !by_var.defined())
+ if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
+     stride_x < vectorize_hint || !bx_var.defined() || !by_var.defined())
    return for_node;
```

- Optionally, add an assertion in the `AtomicAddVectorizeRewriter` constructor to catch any remaining invalid states:

```diff
  ICHECK(tx_ext) << "thread_bounds->extent must be a constant for vectorization.";
+ ICHECK(vector_size_ > 0 && stride_x_ >= vector_size_)
+     << "Invalid vectorization: stride_x_=" << stride_x_
+     << ", vector_size_=" << vector_size_;
```
🧹 Nitpick comments (3)
src/transform/atomicadd_vectorize.cc (3)
128-129: Dynamic condition is unused; either consume it or drop it.

`condition_ = (truncmod(offset, vector_size_) == 0)` is never referenced. If you keep dynamic plans, use it to guard the vectorized path, or remove it to reduce confusion.

208-210: Nit: comment typo and clarity.

`tx_txtent` → `tx_extent`. Consider formatting the two-index formula legibly or moving it next to the branch that uses it.

1-1: Fix formatting to satisfy CI.

CI indicates clang-format changed the file; run clang-format and commit.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📥 Commits
Reviewing files that changed from the base of the PR and between 5e0d9550096139baf220c468da891ec1a3884675 and 9f656b6.
📒 Files selected for processing (2)
- src/op/atomic_add.cc (2 hunks)
- src/transform/atomicadd_vectorize.cc (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/op/atomic_add.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
- tilelang/language/tir/op.py (5): truncmod (2944-2967), ret (1656-1669), call_extern (172-194), address_of (463-479), truncdiv (2918-2941)
🪛 GitHub Actions: CI Test on AMD
src/transform/atomicadd_vectorize.cc
[error] 1-1: tile-lang clang-format reformatted the file. Please review and stage the changes (src/transform/atomicadd_vectorize.cc).
🔇 Additional comments (2)
src/transform/atomicadd_vectorize.cc (2)
144-151: Good: guarded constant extent for `tx` and captured context.

The `as_const_int` guard and storing `extent_tx_` is correct and addresses prior crash risks.

232-238: Good: safe reconstruction of BufferLoad and address_of.

Switching to validated nodes and rebuilding loads avoids previous UB.
```diff
+ iter_var_ = Var(node->loop_var->name_hint + "_outer");
  auto ret = StmtExprMutator::VisitStmt_(node);
  if (inner_for_ == node) { // rewrite the innermost loop
    For fnode = ret.as<For>().value();
-   auto old_var = fnode->loop_var;
    auto extent_ptr = as_const_int(fnode->extent);
    ICHECK(extent_ptr) << fnode->extent;
    int extent = *extent_ptr;
    ICHECK(extent % vector_size_ == 0)
        << "extent: " << extent << " vector_size_: " << vector_size_;
    ICHECK(is_zero(fnode->min));
    if (!dynamic_) {
-     Var tx_var;
-     PostOrderVisit(fnode->body, [&tx_var](const ObjectRef &node) {
-       if (const VarNode *var = node.as<VarNode>()) {
-         if (var->name_hint == "tx") {
-           tx_var = GetRef<Var>(var);
-         }
-       }
-     });
-     ICHECK(tx_var.defined()) << "Failed to find tx var";
-     Var outer_var = Var(old_var->name_hint + "_outer");
      Map<Var, PrimExpr> vmap;
-     // Scale thread index (tx) and loop variable by vector_size to map each
-     // new iteration to a vectorized chunk
-     vmap.Set(tx_var, tx_var * vector_size_);
-     vmap.Set(fnode->loop_var, outer_var * vector_size_);
+     vmap.Set(fnode->loop_var, iter_var_);
      Stmt body = Substitute(fnode->body, vmap);
-     return For(outer_var, 0, extent / vector_size_, fnode->kind, body,
+     return For(iter_var_, 0, extent / vector_size_, fnode->kind, body,
                 fnode->thread_binding, fnode->annotations, fnode->span);
    }
  }
  return ret;
}
```
🛠️ Refactor suggestion
Loop-splitting bug: inner loop var not scaled by vector size (drops 1 − 1/v work).
You replace the old loop var with `iter_var_` but do not multiply by `vector_size_`. This misindexes all uses of the original loop var outside the AtomicAdd call and breaks semantics. Map the old var to `iter_var_ * vector_size_` (lane 0 representative) or introduce an explicit lane var and fold it. Minimal fix:

```diff
- Map<Var, PrimExpr> vmap;
- vmap.Set(fnode->loop_var, iter_var_);
+ Map<Var, PrimExpr> vmap;
+ vmap.Set(fnode->loop_var, iter_var_ * make_const(DataType::Int(32), vector_size_));
```

Optionally also substitute `tx_var_` if the body indexes with lane-local math that assumed per-lane iteration.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
iter_var_ = Var(node->loop_var->name_hint + "_outer");
auto ret = StmtExprMutator::VisitStmt_(node);
if (inner_for_ == node) { // rewrite the innermost loop
  For fnode = ret.as<For>().value();
  auto extent_ptr = as_const_int(fnode->extent);
  ICHECK(extent_ptr) << fnode->extent;
  int extent = *extent_ptr;
  ICHECK(extent % vector_size_ == 0)
      << "extent: " << extent << " vector_size_: " << vector_size_;
  ICHECK(is_zero(fnode->min));
  if (!dynamic_) {
    Map<Var, PrimExpr> vmap;
    vmap.Set(fnode->loop_var,
             iter_var_ * make_const(DataType::Int(32), vector_size_));
    Stmt body = Substitute(fnode->body, vmap);
    return For(iter_var_, 0, extent / vector_size_, fnode->kind, body,
               fnode->thread_binding, fnode->annotations, fnode->span);
  }
}
return ret;
```
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 182 to 201, the innermost
loop rewrite replaces the original loop variable with iter_var_ but fails to
scale it by vector_size_, causing dropped work and incorrect indexing; fix by
mapping the original loop var to iter_var_ * vector_size_ (or introduce an
explicit lane variable and fold lane offsets) when building vmap before calling
Substitute so all uses outside the AtomicAdd see the lane-0 representative; also
consider substituting tx_var_ if the body contains lane-local arithmetic that
assumed per-lane iteration.
Actionable comments posted: 1
♻️ Duplicate comments (4)
src/transform/atomicadd_vectorize.cc (4)
328-361: Robust bx/by detection (commutative mul) — nice.

Handles var*imm and imm*var; sets strides and captures var exprs. Good improvement.

197-201: Loop-splitting bug: original loop var must be scaled by vector_size_.

Mapping the old loop_var → iter_var_ drops work/misindexes; multiply by vector_size_. Apply:

```diff
- vmap.Set(fnode->loop_var, iter_var_);
+ vmap.Set(
+     fnode->loop_var,
+     iter_var_ * make_const(fnode->loop_var.dtype(), vector_size_));
```

233-236: Zero-denominator hazard in index math (stride_x_ / vector_size_).

If stride_x_ < vector_size_, denominators become 0 in truncdiv/truncmod; the code also assumes divisibility. Add guards; otherwise vectorization must be skipped. Suggest guarding in VectorizeAtomicAdd (see separate comment) and optionally adding an in-function check:

```diff
- if (vector_size_ == 2 || vector_size_ == 4) {
+ if ((vector_size_ == 2 || vector_size_ == 4) &&
+     stride_x_ >= vector_size_ && (stride_x_ % vector_size_ == 0)) {
```

Also applies to: 241-242, 245-251, 256-258

370-375: Add safety gates before invoking the rewriter (avoid 0-denominator and undefined bx/by/stride).

Guard on stride_x and divisibility to prevent invalid IR; also keep the existing defined checks. Apply:

```diff
- if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
-     !bx_var.defined() || !by_var.defined())
+ if (vectorize_hint == 1 ||
+     stride_x <= 0 || (stride_x % vectorize_hint) != 0 ||
+     stride_y <= 0 ||
+     !bx_var.defined() || !by_var.defined())
    return for_node;
```

Optionally also guard against the else-branch ratio being 0:

```diff
+ if (const int64_t* tx_ext = as_const_int(thread_bounds->extent)) {
+   if (*tx_ext > 0 && stride_x < (*tx_ext) * vectorize_hint) {
+     return for_node;
+   }
+ }
```
🧹 Nitpick comments (2)
src/transform/atomicadd_vectorize.cc (2)
211-216: Nit: fix typos in comment.

"tx_txtent" → "tx_extent", "(stride / vector_size_)" → "(stride_x_ / vector_size_)". Apply:

```diff
- // Matrix[by * stride_y + i / (stride_x / (tx_txtent *
+ // Matrix[by * stride_y + i / (stride_x / (tx_extent *
@@
- // (stride / vector_size_)) * vector_size_]
+ // (stride_x_ / vector_size_)) * vector_size_]
```

291-300: Unused/ineffective state: condition_ not applied; dynamic_ path unused.

condition_ is computed but never used to predicate the vectorized call; dynamic_ == true currently yields no safe vectorization. Either remove condition_ or wire it to guard the vectorized call, else keep the scalar AtomicAdd. Apply one of:

- Minimal: skip rewrites when dynamic_ (see earlier diff).
- Full: wrap new_call with Select(condition_, new_call, old_call), but that requires retaining old_call (see the sketch below).
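A rough sketch of that "full" option, retaining the scalar call as the fallback. Member names (`condition_`, `vector_size_`) follow this review; `dst_address` and `value_address` stand in for whatever address expressions the rewriter has already built and are assumptions, not names from the actual code:

```cpp
// Inside the rewriter's VisitExpr_(const CallNode*), after building the
// vectorized call's arguments (sketch; surrounding members are assumed).
PrimExpr old_call = GetRef<PrimExpr>(node);  // original scalar AtomicAdd
PrimExpr new_call =
    Call(node->dtype, builtin::call_extern(),
         {StringImm(vector_size_ == 4 ? "AtomicAddx4" : "AtomicAddx2"),
          dst_address, value_address});
// condition_ is the planner's alignment predicate
// (truncmod(offset, vector_size_) == 0); pick the safe path at runtime.
return Select(condition_, new_call, old_call);
```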
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
- src/transform/atomicadd_vectorize.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
- tilelang/language/tir/op.py (5): truncmod (2944-2967), ret (1656-1669), call_extern (172-194), address_of (463-479), truncdiv (2918-2941)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-amd
🔇 Additional comments (4)
src/transform/atomicadd_vectorize.cc (4)
150-154: Good: added constant-extent guard for tx.

Null-check + ICHECK on thread_bounds->extent prevents crashes for symbolic extents. Looks correct.

261-271: OK: address_of(dst/value) construction matches call_extern ABI.

New AtomicAddx2/x4 args look consistent with address-of semantics.

221-227: Safer: null-checks for dst/value loads — good.

Prevents crashes when args aren't BufferLoad. Solid.

128-129: Dynamic condition updated to truncmod — fine, but ensure it's used.

truncmod is correct; however, without wiring condition_ into codegen (or skipping the rewrite when dynamic_), this has no effect. Would you like me to wire condition_ as a predicate into the vectorized call, or keep the dynamic path scalar-only?
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/transform/atomicadd_vectorize.cc (2)
334-343
: Null-deref risk when extracting dtype; add shape checks.call->args[1] may not be address_of(BufferLoad); current code can crash.
- if (func_name->value == "AtomicAdd") { - DataType dtype = - call->args[1].as<CallNode>()->args[0].as<BufferLoadNode>()->dtype; - vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); - } + if (func_name->value == "AtomicAdd") { + const CallNode* addr = call->args[1].as<CallNode>(); + if (!addr || addr->op != builtin::address_of() || addr->args.size() != 1) { + return; + } + const BufferLoadNode* load = addr->args[0].as<BufferLoadNode>(); + if (!load) { + return; + } + DataType dtype = load->dtype; + vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); + }
120-123
: Avoid potential infinite/invalid loop when shrinking vector_size_.Stop halving at 1 to avoid 0 and undefined behavior in IndiceCanVectorize.
- while (!IndiceCanVectorize(elem_offset, thread_var, thread_extent, - vector_size_, &analyzer_)) { + while (vector_size_ > 1 && + !IndiceCanVectorize(elem_offset, thread_var, thread_extent, + vector_size_, &analyzer_)) { vector_size_ /= 2; }
♻️ Duplicate comments (4)
src/transform/atomicadd_vectorize.cc (4)
219-230: Good: robust operand validation before rewriting AtomicAdd.

345-364: Good: commutativity handled for bx/by stride detection.

196-201: Critical: inner loop-var must be scaled by vector_size_ (drops work, wrong indexing).

Map the original loop var to iter_var_ * vector_size_. This was previously flagged and remains unfixed.

```diff
- vmap.Set(fnode->loop_var, iter_var_);
+ vmap.Set(fnode->loop_var,
+          iter_var_ * make_const(DataType::Int(32), vector_size_));
```

373-377: Add stride/divisibility gate to avoid 0-denominator in index math.

Without this, expressions like truncdiv(tx_var_, stride_x_ / vector_size_) can divide by 0 when stride_x < vector_size_ or is not divisible.

```diff
- if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
-     !bx_var.defined() || !by_var.defined())
+ if (vectorize_hint == 1 || stride_x <= 0 || stride_y <= 0 ||
+     (stride_x % vectorize_hint) != 0 ||
+     !bx_var.defined() || !by_var.defined())
    return for_node;
```
🧹 Nitpick comments (3)
src/transform/atomicadd_vectorize.cc (3)
185-185: Preserve dtype on the new loop var.

Construct iter_var_ with the same dtype as the original loop var.

```diff
- iter_var_ = Var(node->loop_var->name_hint + "_outer");
+ iter_var_ = Var(node->loop_var->name_hint + "_outer", node->loop_var->dtype);
```

214-219: Fix typos in index-derivation comment.

"tx_txtent" → "tx_extent"; "(stride / vector_size_)" → "(stride_x_ / vector_size_)".

```diff
- // Matrix[by * stride_y + i / (stride_x / (tx_txtent *
+ // Matrix[by * stride_y + i / (stride_x / (tx_extent *
    // vector_size_)) + tx_var_ / (stride_x / vector_size_),
    // bx * stride_x + (i % (stride_x / (tx_extent *
    // vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
- // (stride / vector_size_)) * vector_size_]
+ // (stride_x_ / vector_size_)) * vector_size_]
```

144-154: Consider widening extent_tx_ type to int64_t.

Avoid narrowing from thread_bounds->extent on large configurations.

```diff
- int extent_tx_;
+ int64_t extent_tx_;
```

And update the assignment accordingly (a sketch follows).
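A small sketch of the matching constructor update under that change, assuming the member is initialized from `thread_bounds->extent` via `as_const_int` as the earlier comments indicate:

```cpp
// In the AtomicAddVectorizeRewriter constructor (sketch):
const int64_t *tx_ext = as_const_int(thread_bounds->extent);
ICHECK(tx_ext) << "thread_bounds->extent must be a constant for vectorization.";
extent_tx_ = *tx_ext;  // no narrowing once extent_tx_ is declared int64_t
```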
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
- src/transform/atomicadd_vectorize.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/atomicadd_vectorize.cc (1)
- tilelang/language/tir/op.py (5): truncmod (2944-2967), ret (1656-1669), call_extern (172-194), address_of (463-479), truncdiv (2918-2941)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: build-test-amd
- GitHub Check: build-test-nvidia
- GitHub Check: bot-task
```cpp
PrimExpr VisitExpr_(const CallNode *node) final {
  if (dynamic_) {
    return StmtExprMutator::VisitExpr_(node);
  }
```
🛠️ Refactor suggestion
Defensive check inside rewriter to avoid 0-denominator at call rewrite sites.
Even with outer gating, keep a local guard to be safe.
```diff
  PrimExpr VisitExpr_(const CallNode *node) final {
    if (dynamic_) {
      return StmtExprMutator::VisitExpr_(node);
    }
+   if (stride_x_ <= 0 || (stride_x_ % vector_size_) != 0) {
+     return StmtExprMutator::VisitExpr_(node);
+   }
    if (vector_size_ == 2 || vector_size_ == 4) {
```
Also applies to: 231-262
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 206-209 and also apply to
231-262, add a local defensive guard inside the CallNode rewrite path: detect
when the denominator (or divisor) value could be zero and short-circuit or
fallback to the original expression before performing the rewrite; concretely,
compute or extract the divisor locally, check it for zero (or
non-constant/unknown) and if zero/unsafe, return
StmtExprMutator::VisitExpr_(node) (or the unmodified node) instead of proceeding
with the transformation so no 0-division is generated.
Summary by CodeRabbit