Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);

#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
Expand Down
2 changes: 2 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ static constexpr const char *kDisableSafeMemoryLegalize =
static constexpr const char *kDisableWarpSpecialized =
"tl.disable_warp_specialized";
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
static constexpr const char *kEnableAggressiveSharedMemoryMerge =
"tl.enable_aggressive_shared_memory_merge";

/*!
* \brief Whether to disable dynamic tail split
Expand Down
54 changes: 39 additions & 15 deletions src/transform/merge_shared_memory_allocations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ class AllocateCollector : public StmtExprVisitor {
//
class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
public:
explicit SharedMemLinearAccessPatternFinder(bool is_dynamic = true,
bool verbose = false)
: is_dynamic_(is_dynamic), verbose_(verbose) {}
explicit SharedMemLinearAccessPatternFinder(
bool is_dynamic = true, bool enable_aggressive_merge = false,
bool verbose = false)
: is_dynamic_(is_dynamic),
enable_aggressive_merge_(enable_aggressive_merge), verbose_(verbose) {}
/*! \brief record the touch list of statement. */
struct StmtEntry {
// The statement
Expand Down Expand Up @@ -151,9 +153,15 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
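// With aggressive merge, record the touch at the innermost scope (scope_.size() - 1) for aggressive memory reuse; otherwise record it at the buffer's allocation level.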
scope_[it->second.level].touched.push_back(buf);
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
scope_[it->second.level].touched.push_back(buf);
}
Comment on lines +156 to +161
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic block is duplicated in VisitExpr_(const BufferLoadNode*) (lines 196-201) and VisitExpr_(const VarNode*) (lines 212-217). To improve maintainability and readability, consider refactoring this into a private helper method.

Additionally, this block can be simplified using a ternary operator to avoid the if/else statement and the redundant local variable.

Suggested change
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
scope_[it->second.level].touched.push_back(buf);
}
const size_t target_level = enable_aggressive_merge_ ? scope_.size() - 1 : it->second.level;
scope_[target_level].touched.push_back(buf);

}
}

StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
Expand Down Expand Up @@ -185,7 +193,12 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
ICHECK_LT(it->second.level, scope_.size())
<< "Load memory in places other than store.";
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
scope_[it->second.level].touched.push_back(buf);
}
}
}
}
Expand All @@ -196,7 +209,12 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[it->second.level].touched.push_back(buf);
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
scope_[it->second.level].touched.push_back(buf);
}
}
}
}
Expand Down Expand Up @@ -284,6 +302,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
}
// Whether to do dynamic analysis.
bool is_dynamic_{true};
// Whether to do aggressive merge.
bool enable_aggressive_merge_{false};
// Whether to do verbose logging.
bool verbose_{false};
// Whether already in thread env.
Expand Down Expand Up @@ -317,8 +337,9 @@ class SharedMemoryRewriter : public StmtExprMutator {
* \param stmt the statement
*/
void PlanReuse(const Stmt &stmt, bool is_dynamic = true,
bool verbose = false) {
SharedMemLinearAccessPatternFinder finder(is_dynamic, verbose);
bool enable_aggressive_merge = false, bool verbose = false) {
SharedMemLinearAccessPatternFinder finder(is_dynamic,
enable_aggressive_merge, verbose);
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
Expand Down Expand Up @@ -956,6 +977,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
}
// Whether to enable dynamic analysis.
bool is_dynamic_{true};

// Whether to enable verbose logging.
bool verbose_{false};
// The var for the merged buffer
Expand Down Expand Up @@ -985,18 +1007,19 @@ class SharedMemoryRewriter : public StmtExprMutator {
};

Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem,
bool enable_aggressive_merge,
bool verbose = false) {
AllocateCollector collector;
collector(stmt);
if (collector.dyn_shmem_allocs_.size() > 1) {
SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose);
rewriter.PlanReuse(stmt);
rewriter.PlanReuse(stmt, true, enable_aggressive_merge);
stmt = rewriter(std::move(stmt));
}
if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false,
verbose);
rewriter.PlanReuse(stmt, false);
rewriter.PlanReuse(stmt, false, enable_aggressive_merge);
stmt = rewriter(std::move(stmt));
}
return stmt;
Expand All @@ -1006,17 +1029,18 @@ using namespace tir::transform;

namespace transform {

Pass MergeSharedMemoryAllocations() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false) {
auto pass_func = [enable_aggressive_merge](PrimFunc f, IRModule m,
PassContext ctx) {
bool merge_static_smem =
ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
bool debug_merge_shared_memory_allocations =
ctx->GetConfig<Bool>(kDebugMergeSharedMemoryAllocations, Bool(false))
.value();
auto *n = f.CopyOnWrite();
n->body =
tl::MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem,
debug_merge_shared_memory_allocations);
n->body = tl::MergeSharedMemoryAllocations(
std::move(n->body), merge_static_smem, enable_aggressive_merge,
debug_merge_shared_memory_allocations);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.MergeSharedMemoryAllocations",
Expand Down
18 changes: 17 additions & 1 deletion tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ def allow_global_thread_synchronization(pass_ctx: Optional[PassContext] = None)
return enable_global_thread_sync


def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
enable_aggressive_merge = bool(
pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
Comment on lines +57 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The bool() constructor here is redundant because pass_ctx.config.get(..., False) already returns a boolean value. Removing the explicit cast will make the code cleaner and more concise.

    enable_aggressive_merge = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False)

if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
# Workaround for a bug in the MergeSharedMemoryAllocations pass when warp
# specialization is enabled: different warp threads may access different
# buffers, and the liveness analysis is hard to get right because of pipelining.
enable_aggressive_merge = False
return enable_aggressive_merge


def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Bind the target device information to the module
mod = tir.transform.BindTarget(target)(mod)
Expand Down Expand Up @@ -151,7 +165,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)

mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target))(
mod)

mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
Expand Down
4 changes: 2 additions & 2 deletions tilelang/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,15 @@ def EliminateStorageSyncForMBarrier():
return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore


def MergeSharedMemoryAllocations():
def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False):
"""MergeSharedMemoryAllocations

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.MergeSharedMemoryAllocations() # type: ignore
return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge) # type: ignore


def LowerL2Persistent():
Expand Down
3 changes: 3 additions & 0 deletions tilelang/transform/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class PassConfigKey(str, Enum):
TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations"
"""Enable debug information for merge shared memory allocations. Default: False"""

TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE = "tl.enable_aggressive_shared_memory_merge"
"""Enable aggressive merge of shared memory allocations. Default: False"""

# TIR related configs
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
Expand Down
Loading