[Bugfix]: Fix atomicadd auto vectorize identify var error #883
Conversation
Walkthrough

Refactors AtomicAdd vectorization into a planner-driven pipeline and rewrites AtomicAdd lowering to use SIMT fusion, layout inference, thread-loop partitioning, and planner-based vectorization; the public VectorizeAtomicAdd signature is simplified and new planner/plan-result types are introduced.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Lower as AtomicAddNode::Lower
    participant SIMT as SIMT Builder/Fuser
    participant Collector as AtomicLoopNestCollector
    participant Layout as InferLayout
    participant Planner as AtomicAddVectorizePlanner
    participant Vectorize as VectorizeAtomicAdd
    participant Rewriter
    participant IR as Resulting IR
    Caller->>Lower: Lower(args, analyzer)
    Lower->>SIMT: Build + fuse SIMT loops
    SIMT-->>Lower: Fused loop
    Lower->>Collector: Collect loop nest & indices
    Collector-->>Lower: Loop metadata
    Lower->>Layout: ComputeLoopLayoutFromBuffer(...)
    Layout-->>Lower: Layout + guard predicate (optional)
    Lower->>Planner: Plan(For, compute_capability)
    Note right of Planner: Analyze AtomicAdd calls & dtypes → vector_size, dynamic, condition
    Planner-->>Vectorize: PlanResult
    Vectorize->>Rewriter: Rewrite using PlanResult
    Rewriter-->>Vectorize: Vectorized For
    Vectorize-->>Lower: Vectorized loop (maybe guarded)
    Lower-->>Caller: Lowered IR
```
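The planner's width decision in the diagram above can be sketched roughly as follows. This is a hedged Python illustration, not the actual C++ GetVectorizeSizeMax: the float32/cc>=90 rule is taken from the review comments later in this thread, while the float16 width and the scalar fallback are assumptions.

```python
def max_vector_size(compute_capability: int, dtype_bits: int) -> int:
    """Pick the widest AtomicAdd vector lane count (sketch only).

    Assumptions: 16-bit floats vectorize broadly; 32-bit floats reach
    x4 only on compute capability >= 90 (per the review discussion);
    everything else falls back to scalar.
    """
    if dtype_bits == 16:
        return 4  # e.g. the AtomicAddx4 path for half types (assumed)
    if dtype_bits == 32:
        return 4 if compute_capability >= 90 else 1
    return 1

print(max_vector_size(90, 32))  # → 4 (cc 90, float32)
print(max_vector_size(80, 32))  # → 1 (cc 80, float32: no x4 path)
```

The real planner would further clamp this by loop extent, alignment, and index-vectorizability checks before committing to a width.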
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ❌ 1 failed check (warning), ✅ 2 passed.
Summary of Changes

Hello @yyttt6, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves a bug in identifying variables during AtomicAdd auto-vectorization.
Code Review
This pull request fixes a bug in identifying variables for atomicadd auto-vectorization by introducing a more robust ParseIndex function. The changes are a definite improvement over the previous, more brittle implementation. I've identified a potential issue in how multiple AtomicAdd calls within a loop are handled, which could lead to incorrect behavior. My review includes a suggestion to make this logic more robust. Additionally, it's good to see that a previously failing test case has been re-enabled as part of this fix.
Actionable comments posted: 0
🧹 Nitpick comments (5)
testing/python/language/test_tilelang_language_atomic_add.py (1)
375-376: Enable test: good; consider making it hardware-agnostic (float16) to avoid the cc>=90 dependency.

AtomicAddx4 for float32 is only selected when compute capability >= 90. On CI GPUs < 90 (e.g., A100 cc=80), this path may not vectorize and could cause flakiness for the tile-atomic path. Two options:
- Portable: call with float16 so vectorization is available broadly.
- Alternatively, gate/skip on device capability.
Apply this minimal change for portability:
```diff
-def test_tile_atomic_add():
-    run_tile_atomic_add(8, 128, 128, 32, 32)
+def test_tile_atomic_add():
+    run_tile_atomic_add(8, 128, 128, 32, 32, dtype="float16")
```

Also, consider removing or gating the debug prints in run_tile_atomic_add to keep test output clean (prints at lines 58, 72, 73).
src/transform/atomicadd_vectorize.cc (4)
322-347: ParseIndex is too strict; accept const-expr strides and avoid false negatives.

Requiring exactly one MulNode with a Var and an IntImm will miss common canonical forms:
- Stride may be a foldable const expr or come via casts (not a bare IntImm).
- Extra harmless multiplies like x*1 can appear pre-simplification.
- You only need a unique var*const match; other non-relevant muls shouldn't invalidate the parse.

Refine by simplifying first, using as_const_int, and relaxing the check to "exactly one legal var*const mul" regardless of other muls:
```diff
-  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
-                       int &stride_out) -> bool {
+  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
+                       int &stride_out) -> bool {
     int mul_count = 0, legal_mul_count = 0;
     stride_out = -1;
     var_out = PrimExpr();
-    PostOrderVisit(idx, [&](const ObjectRef &obj) {
+    // Simplify to eliminate x*1 and fold-able constants.
+    arith::Analyzer az;
+    PrimExpr sidx = az.Simplify(idx);
+    PostOrderVisit(sidx, [&](const ObjectRef &obj) {
       if (const MulNode *mul = obj.as<MulNode>()) {
         mul_count++;
-        const VarNode *var = nullptr;
-        const IntImmNode *imm = nullptr;
-        if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
-          var_out = mul->a;
-          stride_out = imm->value;
-          legal_mul_count++;
-        } else if ((var = mul->b.as<VarNode>()) &&
-                   (imm = mul->a.as<IntImmNode>())) {
-          var_out = mul->b;
-          stride_out = imm->value;
-          legal_mul_count++;
-        }
+        const VarNode *var = nullptr;
+        const int64_t *c = nullptr;
+        if ((var = mul->a.as<VarNode>()) && (c = as_const_int(mul->b))) {
+          var_out = mul->a;
+          stride_out = static_cast<int>(*c);
+          legal_mul_count++;
+        } else if ((var = mul->b.as<VarNode>()) && (c = as_const_int(mul->a))) {
+          var_out = mul->b;
+          stride_out = static_cast<int>(*c);
+          legal_mul_count++;
+        }
       }
     });
-    if (mul_count == 1 && legal_mul_count == 1)
-      return true;
-    return false;
+    return legal_mul_count == 1;
   };
```

Note: this uses as_const_int and simplification. If not already available, include tvm/arith/analyzer.h (already included).
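To make the relaxed rule concrete, here is a small self-contained Python sketch of the "exactly one legal var*const mul" parse on a toy expression tree. The Var/Mul/Add mini-AST is hypothetical, standing in for TVM's PrimExpr nodes; it only illustrates the matching logic, not the actual C++ pass.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Mul:
    a: object
    b: object

@dataclass(frozen=True)
class Add:
    a: object
    b: object

def parse_index(expr):
    """Return (var, stride) iff exactly one var*const product exists, else None."""
    matches = []

    def visit(e):
        if isinstance(e, Mul):
            # var * const in either operand order counts as a legal match
            if isinstance(e.a, Var) and isinstance(e.b, int):
                matches.append((e.a, e.b))
            elif isinstance(e.b, Var) and isinstance(e.a, int):
                matches.append((e.b, e.a))
            else:
                visit(e.a); visit(e.b)
        elif isinstance(e, Add):
            visit(e.a); visit(e.b)

    visit(expr)
    # Other non-relevant muls don't invalidate the parse; ambiguity does.
    return matches[0] if len(matches) == 1 else None

print(parse_index(Add(Mul(Var("bx"), 32), Var("i"))))        # → (Var(name='bx'), 32)
print(parse_index(Add(Mul(Var("bx"), 32), Mul(Var("by"), 8))))  # → None (ambiguous)
```

The ambiguous case returning None mirrors the review's point that a unique match is required before the rewriter may treat the variable as a block index.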
362-368: Accumulate vectorize_size_max across multiple AtomicAdd sites.

If the loop body contains multiple AtomicAdd calls, you currently overwrite vectorize_size_max. Prefer taking the max to avoid under-vectorizing later calls.
```diff
-      DataType dtype = bufload->dtype;
-      vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+      DataType dtype = bufload->dtype;
+      vectorize_size_max = std::max(
+          vectorize_size_max, GetVectorizeSizeMax(compute_capability, dtype));
```

You'll need:
- Add at the top: #include <algorithm> (for std::max).
362-368: Guard against mis-identifying non-block vars as bx/by.

ParseIndex will happily return any var*const (e.g., a loop var i1). Before accepting, assert the extracted vars are actual block indices (thread/block bindings) for safety, otherwise bail out. For example:
- Verify var_out.as<VarNode>() is bound in thread_binding as blockIdx.{x,y} (or matches the expected bx/by symbols in this pass' context).
- If that metadata isn't available here, at least ensure both extracted vars differ and are not the loop var inside inner_for_.

This avoids rewriting with incorrect axes on more complex index expressions.
380-385: Extra sanity checks before rewriting.

Before constructing the rewriter:
- Ensure bx_var and by_var are Vars: if (!bx_var.as<VarNode>() || !by_var.as<VarNode>()) return for_node;
- Optionally ensure stride_x > 0 && stride_y > 0.

This prevents emitting malformed truncdiv/truncmod expressions.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/transform/atomicadd_vectorize.cc (2 hunks)
- testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: format-check
- GitHub Check: bot-task
Actionable comments posted: 0
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_atomic_add.py (1)
375-377: Remove noisy debug prints before re-enabling this test.

Reactivating test_tile_atomic_add now exercises run_tile_atomic_add, which still contains print(kernel.get_kernel_source()) plus dumps of both 128×128 tensors. That's hundreds of thousands of characters on every run and will swamp CI logs without adding assertion value. Please drop or gate those prints before merging.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- testing/python/language/test_tilelang_language_atomic_add.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
src/transform/atomicadd_vectorize.cc (outdated)

```cpp
For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
                       const Range &thread_bounds, int compute_capability) {
  // ...
  auto ParseIndex = [](const PrimExpr &idx, PrimExpr &var_out,
```
I think we should add some comments for this function.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/atomic_add.cc (1)
35-46: Code duplication: GetArchInt is duplicated from src/target/utils.cc.

The GetArchInt function is duplicated from src/target/utils.cc (lines 17-25) with a different implementation. The local version includes a fallback to return 0 for non-sm_ architectures, while the version in src/target/utils.cc uses ICHECK to enforce the sm_ prefix.

Consider one of the following approaches:
- Preferred: Import and use the existing GetArchInt from src/target/utils.cc if the stricter validation is acceptable, or
- Update the version in src/target/utils.cc to include the fallback behavior and use it consistently across the codebase.

Apply this diff to use the existing function:

```diff
-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}
```

And update the include at the top of the file if not already present:
```cpp
#include "../target/utils.h"
```
🧹 Nitpick comments (3)
src/op/atomic_add.cc (1)
372: Consider removing or adjusting the log level.

The LOG(INFO) statement prints the vectorized loop IR to the console. This may be useful during development but could be noisy in production.

Consider one of the following:
- Remove the log statement if it was added for debugging purposes only.
- Change to VLOG(1) or a higher verbosity level to reduce noise in production logs.
- If this is intentional diagnostic output, add a comment explaining why it's logged at INFO level.

Apply this diff to change to verbose logging:

```diff
- LOG(INFO) << vectorized_thread_loop;
+ VLOG(1) << "Vectorized thread loop: " << vectorized_thread_loop;
```

src/transform/atomicadd_vectorize.cc (2)
33-55: Consider adding documentation for BufferIndiceSimplify.

The BufferIndiceSimplify class lacks documentation. Adding a brief comment explaining its purpose would improve maintainability.

Apply this diff to add documentation:

```diff
+/// \brief Simplifies buffer load and store indices using an analyzer.
+///
+/// This mutator visits BufferLoad and BufferStore nodes and simplifies
+/// their indices by applying the analyzer's Simplify method to each index.
 class BufferIndiceSimplify : public StmtExprMutator {
```
174-231: Consider adding documentation for the run() method.

The run() method implements complex loop transformation logic but lacks documentation explaining the transformation steps and the role of loop_layout and analyzer.

Apply this diff to add documentation:

```diff
+  /// \brief Transform and vectorize the for loop using the provided layout.
+  ///
+  /// \param for_node The original For loop to transform
+  /// \param loop_layout Fragment describing the loop layout transformation
+  /// \param analyzer Analyzer for simplifying indices and binding loop variables
+  /// \return Transformed and vectorized For loop
   For run(For for_node, const Fragment &loop_layout, arith::Analyzer *analyzer) {
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/op/atomic_add.cc (1 hunks)
- src/transform/atomicadd_vectorize.cc (6 hunks)
- src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/atomic_add.cc (2)
- src/transform/atomicadd_vectorize.cc (2): VectorizeAtomicAdd (308-343, 308-310)
- src/target/utils.cc (2): GetArchInt (18-26, 18-18)

src/transform/atomicadd_vectorize.cc (1)
- src/transform/loop_partition.cc (5): BufferIndiceSimplify (38-38), LoopPragmaUnroll (201-205, 201-201), PartitionLoop (61-105, 61-62)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: bot-task
- GitHub Check: format-check
- GitHub Check: format-check
🔇 Additional comments (5)
src/op/atomic_add.cc (1)
364-366: Call site correctly updated.

The call to VectorizeAtomicAdd now includes the required analyzer and loop_layout parameters, matching the updated API signature.

src/transform/atomicadd_vectorize.cc (3)
33-55: BufferIndiceSimplify implementation looks correct.

The mutator correctly simplifies indices for both BufferLoad and BufferStore nodes using the provided analyzer. The implementation properly uses CopyOnWrite() and maps indices through the analyzer's Simplify method.
308-343: VectorizeAtomicAdd implementation looks correct.

The function properly:
- Determines the maximum vectorization size based on compute capability and data type.
- Uses PartitionLoop to simplify the loop before planning.
- Plans the vectorization with the planner.
- Returns the original loop if vectorization is not beneficial (vectorize_hint == 1).
- Applies the vectorization rewrite with the provided analyzer and loop layout.
174-231: Verify vectorization across all loop levels and add tests. The run() method divides every transformed loop extent by vector_size_, not just the innermost; no existing tests cover this behavior—please confirm this is intentional and add multi-level loop vectorization test cases.

src/transform/atomicadd_vectorize.h (1)

17-19: Approve VectorizeAtomicAdd signature update. All call sites, including src/op/atomic_add.cc, have been updated to match the new parameters.
I refactored part of the AtomicAdd auto vectorization code.
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/op/atomic_add.cc (1 hunks)
- src/transform/atomicadd_vectorize.cc (6 hunks)
- src/transform/atomicadd_vectorize.h (1 hunks)
- tilelang/language/atomic.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/transform/atomicadd_vectorize.h
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/atomicadd_vectorize.cc (1)
- src/transform/loop_partition.cc (5): BufferIndiceSimplify (38-38), LoopPragmaUnroll (201-205, 201-201), PartitionLoop (61-105, 61-62)

src/op/atomic_add.cc (1)
- src/transform/atomicadd_vectorize.cc (2): VectorizeAtomicAdd (308-343, 308-310)

tilelang/language/atomic.py (1)
- tilelang/language/customize.py (1): buffer_load_to_tile_region (46-68)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
```cpp
const BufferLoad dst_node =
    Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
const BufferLoad value_node =
    Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
```
Downcast expects an ObjectRef, not a raw node pointer.

Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) doesn't compile—the overload requires an ObjectRef, but as<...>() returns const BufferLoadNode*. Use the original PrimExpr instead (we already checked it's a BufferLoad), e.g.:

```diff
- const BufferLoad dst_node =
-     Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
- const BufferLoad value_node =
-     Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+ const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+ const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
```

Without this change the file fails to build.
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```cpp
const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
```
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 249 to 253, the code calls
Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) and similarly for
args[2], but Downcast expects an ObjectRef not a raw node pointer; replace the
.as<BufferLoadNode>() calls and pass the original PrimExprs (node->args[1] and
node->args[2]) directly to Downcast<BufferLoad>(), relying on the existing type
checks that confirmed these are BufferLoad instances so the Downcast will be
valid.
tilelang/language/atomic.py (outdated)
```python
src_extent = list(get_extent(value))
dst_extent = list(get_extent(dst))
legal = True

if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
    legal = False
elif (dst_extent and src_extent):
    if len(dst_extent) > len(src_extent):
        dst_extent_dims = [x for x in dst_extent if x != 1]
        if dst_extent_dims != src_extent:
            legal = False
    else:
        if dst_extent != src_extent:
            legal = False
else:
    src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
    extent = max(dst_extent, src_extent)
    dst_extent = src_extent = extent
```
Guard get_extent() results before wrapping in list(...).

get_extent() still returns None for scalar PrimExpr inputs (e.g., atomic_add(dst, 1)), so list(get_extent(...)) raises a TypeError before we can fall back to the extern path. This regresses the scalar code path.

Please keep the raw result, check for None, and only convert to list when defined, before the length/shape logic.

```diff
-    src_extent = list(get_extent(value))
-    dst_extent = list(get_extent(dst))
+    src_extent_raw = get_extent(value)
+    dst_extent_raw = get_extent(dst)
+    src_extent = list(src_extent_raw) if src_extent_raw is not None else None
+    dst_extent = list(dst_extent_raw) if dst_extent_raw is not None else None
     legal = True
-    if (dst_extent is None and src_extent is None) or len(dst_extent) < len(src_extent):
+    if dst_extent is None and src_extent is None:
+        legal = False
+    elif dst_extent is None:
+        dst_extent = [1] * len(src_extent)
+    elif src_extent is None:
+        src_extent = [1] * len(dst_extent)
+    elif len(dst_extent) < len(src_extent):
         legal = False
```

Make sure the remaining branches avoid len(None) as well.
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
src_extent_raw = get_extent(value)
dst_extent_raw = get_extent(dst)
src_extent = list(src_extent_raw) if src_extent_raw is not None else None
dst_extent = list(dst_extent_raw) if dst_extent_raw is not None else None
legal = True

if dst_extent is None and src_extent is None:
    legal = False
elif dst_extent is None:
    dst_extent = [1] * len(src_extent)
elif src_extent is None:
    src_extent = [1] * len(dst_extent)
elif len(dst_extent) < len(src_extent):
    legal = False
elif dst_extent and src_extent:
    if len(dst_extent) > len(src_extent):
        dst_extent_dims = [x for x in dst_extent if x != 1]
        if dst_extent_dims != src_extent:
            legal = False
    else:
        if dst_extent != src_extent:
            legal = False
else:
    src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
    dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
    extent = max(dst_extent, src_extent)
    dst_extent = src_extent = extent
```
🤖 Prompt for AI Agents
In tilelang/language/atomic.py around lines 193 to 211, guard the raw results of
get_extent() before calling list(...) to avoid TypeError on scalar PrimExprs:
assign src_extent_raw = get_extent(value) and dst_extent_raw = get_extent(dst),
check if each is None before converting to list; only call list(...) when the
raw extent is not None, and ensure every branch uses None-checks rather than
len(None) (e.g., when one side is None treat it as a scalar extent
[1]*len(other) or follow the extern-path fallback), then perform the same
dimension comparisons and the elementwise max logic operating on actual lists so
no branch ever calls len(...) or iterates over a None.
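The None-guarded extent handling discussed above can be distilled into a standalone sketch. normalize_extents is a hypothetical helper: plain Python lists stand in for buffer extents, None marks a scalar operand, and the broadcast/size-1-dimension rules follow the suggested fix rather than any actual tilelang API.

```python
def normalize_extents(src_extent, dst_extent):
    """Return (legal, src, dst), guarding None before any len()/list() call.

    Sketch of the review's suggested fix: convert to lists only when
    defined, broadcast a scalar side as [1]*rank, reject when dst has
    fewer dims than src, and allow extra size-1 dims on dst.
    """
    src = list(src_extent) if src_extent is not None else None
    dst = list(dst_extent) if dst_extent is not None else None
    if src is None and dst is None:
        return False, None, None        # both scalar: take the extern fallback path
    if dst is None:                     # scalar dst: broadcast against src
        return True, src, [1] * len(src)
    if src is None:                     # scalar value: broadcast against dst
        return True, [1] * len(dst), dst
    if len(dst) < len(src):
        return False, src, dst
    if len(dst) > len(src):
        # dst may carry extra size-1 dims; the remaining dims must match src
        return [x for x in dst if x != 1] == src, src, dst
    return dst == src, src, dst

print(normalize_extents(None, [128, 128]))      # → (True, [1, 1], [128, 128])
print(normalize_extents([32, 32], [1, 32, 32])) # → (True, [32, 32], [1, 32, 32])
```

No branch ever calls len(None), so the scalar path (e.g., atomic_add(dst, 1)) no longer raises before the fallback can run.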
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/atomicadd_vectorize.cc (1)
52-205: Fix AtomicAdd rewrites: size guard and correct Downcast usage.

Two problems here break compilation and correctness:
- We only check args.size() >= 2, yet we always read args[2]; AtomicAdd calls must therefore guard for >= 3 before touching the value operand.
- Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) is invalid—the overload expects an ObjectRef, not a raw node pointer. This reintroduces the build failure from the previous review.

Please tighten the guards and pass the original PrimExpr to Downcast:

```diff
- if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+ if (node->op == builtin::call_extern() && node->args.size() >= 3) {
    if (const auto *func_name = node->args[0].as<StringImmNode>()) {
      if (func_name->value == "AtomicAdd") {
        const BufferLoadNode *temp_dst_node = node->args[1].as<BufferLoadNode>();
        const BufferLoadNode *temp_value_node = node->args[2].as<BufferLoadNode>();
        if (!temp_dst_node || !temp_value_node) {
          return StmtExprMutator::VisitExpr_(node);
        }
-       const BufferLoad dst_node =
-           Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
-       const BufferLoad value_node =
-           Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
+       const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+       const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/op/atomic_add.cc (2 hunks)
- src/transform/atomicadd_vectorize.cc (6 hunks)
- src/transform/atomicadd_vectorize.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/transform/atomicadd_vectorize.h (2)
- src/transform/loop_vectorize.h (2): tvm (31-49), tl (32-48)
- src/transform/atomicadd_vectorize.cc (17): VectorizeAtomicAdd (236-246, 236-236), AtomicAddVectorizePlanner (15-15), Plan (17-44, 18-18), node (162-184, 162-162, 186-228, 186-186), VisitStmt_ (46-49, 46-46), VisitExpr_ (51-71, 51-51), GetVectorizeSizeMax (73-85, 73-74), UpdateVectorSize (87-127, 87-88)

src/transform/atomicadd_vectorize.cc (2)
- src/transform/atomicadd_vectorize.h (1): AtomicAddVectorizePlanner (36-55)
- src/transform/loop_vectorize.cc (4): indices (157-189, 157-157), IndiceCanVectorize (257-298, 257-259)

src/op/atomic_add.cc (6)
- src/op/parallel.cc (8): Lower (184-187, 184-185), VisitStmt_ (130-146, 130-130, 148-160, 148-148), VisitExpr_ (162-173, 162-162)
- src/op/copy.cc (6): Lower (791-823, 791-791, 1776-1898, 1776-1777), MakeSIMTLoop (299-344, 299-299)
- src/op/fill.cc (4): Lower (171-206, 171-171), MakeSIMTLoop (136-151, 136-136)
- src/op/reduce.cc (4): Lower (152-318, 152-152, 413-437, 413-413)
- src/target/utils.cc (2): GetArchInt (18-26, 18-18)
- src/transform/atomicadd_vectorize.cc (6): VisitStmt_ (46-49, 46-46), VisitExpr_ (51-71, 51-51), VectorizeAtomicAdd (236-246, 236-236)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: format-check
```cpp
      const auto *func_name = call->args[0].as<StringImmNode>();
      if (func_name && func_name->value == "AtomicAdd") {
        DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
        vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
      }
    }
  }
```
Guard the dtype probe before dereferencing.

call->args[1] is not guaranteed to be a BufferLoad (it can be wrapped in an IfThenElse by the lowering pipeline), so as<BufferLoadNode>() may return nullptr. Dereferencing it here will crash during planning. Please bail out unless the cast succeeds, e.g.:

```diff
-      DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
-      vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+      if (const auto *dst = call->args[1].as<BufferLoadNode>()) {
+        vectorize_size_max =
+            GetVectorizeSizeMax(compute_capability, dst->dtype);
+      }
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```cpp
const auto *func_name = call->args[0].as<StringImmNode>();
if (func_name && func_name->value == "AtomicAdd") {
  if (const auto *dst = call->args[1].as<BufferLoadNode>()) {
    vectorize_size_max =
        GetVectorizeSizeMax(compute_capability, dst->dtype);
  }
}
```
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 27 to 33, the code
dereferences call->args[1].as<BufferLoadNode>() without checking it may be
nullptr (it can be an IfThenElse or other node), which can crash; store the
result of as<BufferLoadNode>() in a local pointer, check it is non-null before
accessing dtype, and if the cast fails simply skip setting vectorize_size_max
(bail out of this branch) so you don't dereference a null pointer.
Actionable comments posted: 4
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/op/atomic_add.cc (3 hunks)
- src/transform/atomicadd_vectorize.cc (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/atomic_add.cc (5)
- src/op/parallel.cc (8): InferLayout (212-527, 212-213), VisitStmt_ (130-146, 130-130, 148-160, 148-148), VisitExpr_ (162-173, 162-162)
- src/op/copy.cc (6): InferLayout (399-498, 399-400, 1955-1958, 1955-1956), MakeSIMTLoop (299-344, 299-299)
- src/op/parallel.h (1): ParallelOp (151-158)
- src/target/utils.cc (2): GetArchInt (18-26, 18-18)
- src/transform/atomicadd_vectorize.cc (6): VisitStmt_ (45-48, 45-45), VisitExpr_ (50-70, 50-50), VectorizeAtomicAdd (235-244, 235-235)

src/transform/atomicadd_vectorize.cc (2)
- src/transform/atomicadd_vectorize.h (1): AtomicAddVectorizePlanner (36-55)
- src/transform/loop_vectorize.cc (4): indices (157-189, 157-157), IndiceCanVectorize (257-298, 257-259)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-metal
auto GetArchInt = [&](const Target &tgt) -> int {
  int arch_int = 0;
  if (auto s = tgt->GetAttr<String>("arch")) {
    std::string arch = s.value();
    if (arch.rfind("sm_", 0) == 0)
      arch_int = std::stoi(arch.substr(3));
  }
  return arch_int;
};
🛠️ Refactor suggestion | 🟠 Major
Drop the shadowing GetArchInt lambda.
We already have a file-scope GetArchInt(Target) (lines 37-48). Redefining an identical lambda here is redundant, risks divergence, and silently bypasses the shared helper (e.g. any future validation fixes). Call the existing function instead of shadowing it locally.
🤖 Prompt for AI Agents
In src/op/atomic_add.cc around lines 401 to 409, there is a locally defined
GetArchInt lambda that shadows an existing file-scope GetArchInt(Target) (lines
37-48); remove this redundant lambda and replace any uses in this scope with a
direct call to the file-scope GetArchInt(tgt) helper so the shared
implementation (and any future validations) are preserved; ensure the lambda
definition is deleted and all call sites use GetArchInt(tgt) without adding new
duplicates.
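The deduplication the bot asks for is mechanical, but the parsing itself is worth pinning down. Below is a minimal sketch of what a single shared GetArchInt helper computes, assuming the same "sm_XX" naming convention as the lambda above; the std::string parameter is a simplification of TVM's Target and the function name mirrors, but is not, the real helper.

```cpp
#include <string>

// Hypothetical free function standing in for the file-scope
// GetArchInt(Target) helper: parse "sm_90" -> 90, anything else -> 0.
// Call sites should invoke this one helper instead of re-declaring an
// identical lambda that shadows it.
int GetArchInt(const std::string &arch) {
  if (arch.rfind("sm_", 0) == 0)      // arch string starts with "sm_"
    return std::stoi(arch.substr(3)); // numeric compute capability
  return 0;                           // unknown / non-CUDA arch
}
```

With a single definition, any future hardening (e.g. rejecting malformed suffixes) lands at every call site at once, which is exactly what a shadowing copy would silently miss.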
if (plan.dynamic && plan.condition.defined()) {
  pred = plan.condition;
}
DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec
           << " loop_layout=" << loop_layout->DebugOutput();
return {loop_layout, pred};
};

auto ret = AtomicAddInferLayout(transformed_loop,
                                {T.target, T.thread_bounds, T.layout_map,
                                 analyzer, false, T.buffer_remap});
Fragment loop_layout = ret.loop_layout;
auto thread_loop =
    PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
auto vectorized_thread_loop =
    VectorizeAtomicAdd(thread_loop, GetArchInt(target));
return vectorized_thread_loop;
Propagate planner predicate for dynamic vectorization.
AtomicAddVectorizePlanner::Plan can return dynamic=true with condition set. We drop that information: VectorizeAtomicAdd may leave the loop unvectorized, but when dynamic is true the planner expects us to wrap the vectorized body with the guard before executing. Please reintroduce the predicate (e.g. wrap vectorized_thread_loop with IfThenElse(pred, body) when present) so dynamic plans stay correct.
🤖 Prompt for AI Agents
In src/op/atomic_add.cc around lines 517 to 533, the planner predicate produced
for dynamic vectorization is being dropped after AtomicAddInferLayout; capture
the returned predicate from ret and, if it is defined, wrap the final
vectorized_thread_loop in an If/IfThenElse guard that uses that predicate so the
dynamic plan executes the guarded vectorized body only when the condition holds;
specifically, obtain pred from the planner return, and replace the direct return
of vectorized_thread_loop with a conditional wrapper (IfThenElse(pred,
vectorized_thread_loop)) when pred is present, otherwise return the loop
unchanged.
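The shape of that fix can be sketched without TVM types. In the toy version below, std::optional plays the role of the possibly-undefined planner condition and plain strings stand in for Stmt and IfThenElse; every name here is illustrative, not TVM API.

```cpp
#include <optional>
#include <string>

// Sketch: only wrap the vectorized loop in a guard when the planner
// actually produced a predicate; otherwise return the loop unchanged.
std::string WrapWithGuard(const std::string &vectorized_loop,
                          const std::optional<std::string> &pred) {
  if (pred.has_value()) {
    // Mirrors IfThenElse(pred, body): body runs only when pred holds.
    return "if (" + *pred + ") { " + vectorized_loop + " }";
  }
  return vectorized_loop;  // static plan: no guard needed
}
```

The point of the design is that the guard decision is made once, next to where the plan is consumed, so a dynamic plan can never escape without its predicate.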
DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
Guard the BufferLoad cast before dereferencing.
call->args[1].as<BufferLoadNode>() can return nullptr (the planner already saw AtomicAdd nodes wrapped by predicates elsewhere). Dereferencing without a check crashes the planner on valid inputs. Bail out unless the cast succeeds before touching dtype.
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 29-30, the code
dereferences call->args[1].as<BufferLoadNode>() without checking for null which
can crash; modify the code to first store the result of
call->args[1].as<BufferLoadNode>() into a local pointer, check if it is nullptr,
and if so bail out (e.g., return false / skip vectorization) before accessing
dtype; only call GetVectorizeSizeMax when the cast succeeded.
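The guarded-cast pattern generalizes to any as<T>()-style downcast. Here is a self-contained sketch with dynamic_cast standing in for TVM's as<BufferLoadNode>(); the Node hierarchy and the 128-bit maximum access width are assumptions for illustration, not the planner's real types.

```cpp
// Stand-ins for TVM's node hierarchy, for illustration only.
struct Node { virtual ~Node() = default; };
struct BufferLoadNode : Node { int dtype_bits = 32; };

// Guarded version: store the cast result, test it, and bail out with a
// conservative answer (vector size 1) instead of dereferencing nullptr.
int MaxVectorSize(const Node *arg) {
  const auto *load = dynamic_cast<const BufferLoadNode *>(arg);
  if (load == nullptr) {
    return 1;  // not a BufferLoad (e.g. wrapped in a predicate): skip
  }
  return 128 / load->dtype_bits;  // assumed 128-bit max vector access
}
```

The conservative fallback matters: returning vector size 1 keeps the loop correct (just unvectorized), whereas the unguarded dereference turns a merely non-vectorizable input into a compiler crash.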
      node->args[1].as<BufferLoadNode>();
-  const BufferLoadNode *old_value_node =
+  const BufferLoadNode *temp_value_node =
       node->args[2].as<BufferLoadNode>();
-  if (!old_dst_node || !old_value_node) {
+  if (!temp_dst_node || !temp_value_node) {
     return StmtExprMutator::VisitExpr_(node);
   }
   Array<PrimExpr> dst_indices, value_indices;
   if ((extent_tx_ * vector_size_) > stride_x_) {
     dst_indices.push_back(
         by_var_ * stride_y_ +
         iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
         truncdiv(tx_var_, stride_x_ / vector_size_));
     dst_indices.push_back(
         bx_var_ * stride_x_ +
         truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
     value_indices.push_back(
         iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
         truncdiv(tx_var_ * vector_size_, stride_x_));
     value_indices.push_back(
         truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
   } else {
     dst_indices.push_back(
         by_var_ * stride_y_ +
         truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
         truncdiv(tx_var_, stride_x_ / vector_size_));
     dst_indices.push_back(
         bx_var_ * stride_x_ +
         truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
             (extent_tx_ * vector_size_) +
         truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
     value_indices.push_back(
         truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
         truncdiv(tx_var_, stride_x_ / vector_size_));
     value_indices.push_back(
         truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
             (extent_tx_ * vector_size_) +
         truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
   }
   const BufferLoad dst_node =
       Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>());
   const BufferLoad value_node =
       Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>());
Pass the original PrimExprs to Downcast.
Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) calls the overload with a raw pointer, which doesn't compile. We already confirmed the operands are BufferLoad, so pass node->args[1] / node->args[2] directly to Downcast. This matches the required ObjectRef signature and fixes the build.
🤖 Prompt for AI Agents
In src/transform/atomicadd_vectorize.cc around lines 194 to 203, the code calls
Downcast<BufferLoad>(node->args[1].as<BufferLoadNode>()) and
Downcast<BufferLoad>(node->args[2].as<BufferLoadNode>()), which uses the
raw-pointer overload and fails to compile; replace these with
Downcast<BufferLoad>(node->args[1]) and Downcast<BufferLoad>(node->args[2])
respectively so the ObjectRef (PrimExpr) overload is used (keep the preceding
as<BufferLoadNode>() checks to validate the type).
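The overload distinction is easy to reproduce outside TVM: a Downcast-style helper takes a handle by value, so handing it the raw pointer produced by as<...>() matches no viable overload. Below is a toy version in which shared_ptr plays the role of ObjectRef; all types and names are illustrative, not TVM's.

```cpp
#include <memory>
#include <string>

// Toy handle hierarchy standing in for ObjectRef / BufferLoad.
struct Obj { virtual ~Obj() = default; };
struct LoadObj : Obj { std::string name = "load"; };
using Ref = std::shared_ptr<Obj>;
using LoadRef = std::shared_ptr<LoadObj>;

// Handle-in, handle-out, like Downcast<BufferLoad>(ObjectRef):
// it accepts the reference itself, never a raw node pointer.
LoadRef DowncastLoad(const Ref &r) {
  return std::dynamic_pointer_cast<LoadObj>(r);
}
```

In this sketch DowncastLoad(ref) compiles, while DowncastLoad(ref.get()) would not, which is the same category of error flagged in the review comment above.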
Summary by CodeRabbit
New Features
Refactor