diff --git a/src/op/builtin.cc b/src/op/builtin.cc index d453dfc4a..62ffd935e 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -26,6 +26,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer); +TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool); #define TIR_DEFINE_TL_BUILTIN(OpName) \ const Op &OpName() { \ diff --git a/src/op/builtin.h b/src/op/builtin.h index ab521643a..4234b6f4d 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -28,6 +28,8 @@ static constexpr const char *kDisableSafeMemoryLegalize = static constexpr const char *kDisableWarpSpecialized = "tl.disable_warp_specialized"; static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth"; +static constexpr const char *kEnableAggressiveSharedMemoryMerge = + "tl.enable_aggressive_shared_memory_merge"; /*! * \brief Whether to disable dynamic tail split diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc index 56e8a119c..50665a87a 100644 --- a/src/transform/merge_shared_memory_allocations.cc +++ b/src/transform/merge_shared_memory_allocations.cc @@ -95,9 +95,11 @@ class AllocateCollector : public StmtExprVisitor { // class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { public: - explicit SharedMemLinearAccessPatternFinder(bool is_dynamic = true, - bool verbose = false) - : is_dynamic_(is_dynamic), verbose_(verbose) {} + explicit SharedMemLinearAccessPatternFinder( + bool is_dynamic = true, bool enable_aggressive_merge = false, + bool verbose = false) + : is_dynamic_(is_dynamic), + enable_aggressive_merge_(enable_aggressive_merge), verbose_(verbose) {} /*! \brief record the touch list of statement. */ struct StmtEntry { // The statement @@ -151,9 +153,15 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { ICHECK_LT(it->second.level, scope_.size()); if (IsAppropriateSharedMemory(GetRef(buf))) { // set into scope_.size() - 1 for aggressive memory reuse - scope_[it->second.level].touched.push_back(buf); + auto enable_aggressive_merge = enable_aggressive_merge_; + if (enable_aggressive_merge) { + scope_[scope_.size() - 1].touched.push_back(buf); + } else { + scope_[it->second.level].touched.push_back(buf); + } } } + StmtEntry e = scope_.back(); scope_.pop_back(); if (e.touched.size() != 0) { @@ -185,7 +193,12 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; if (IsAppropriateSharedMemory(GetRef(buf))) { - scope_[it->second.level].touched.push_back(buf); + auto enable_aggressive_merge = enable_aggressive_merge_; + if (enable_aggressive_merge) { + scope_[scope_.size() - 1].touched.push_back(buf); + } else { + scope_[it->second.level].touched.push_back(buf); + } } } } @@ -196,7 +209,12 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()); if (IsAppropriateSharedMemory(GetRef(buf))) { - scope_[it->second.level].touched.push_back(buf); + auto enable_aggressive_merge = enable_aggressive_merge_; + if (enable_aggressive_merge) { + scope_[scope_.size() - 1].touched.push_back(buf); + } else { + scope_[it->second.level].touched.push_back(buf); + } } } } @@ -284,6 +302,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { } // Whether do dyanmic analysis. bool is_dynamic_{true}; + // Whether do aggressive merge. + bool enable_aggressive_merge_{false}; // Whether do verbose logging. bool verbose_{false}; // Whether already in thread env. @@ -317,8 +337,9 @@ class SharedMemoryRewriter : public StmtExprMutator { * \param stmt the statement */ void PlanReuse(const Stmt &stmt, bool is_dynamic = true, - bool verbose = false) { - SharedMemLinearAccessPatternFinder finder(is_dynamic, verbose); + bool enable_aggressive_merge = false, bool verbose = false) { + SharedMemLinearAccessPatternFinder finder(is_dynamic, + enable_aggressive_merge, verbose); finder(stmt); this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_); this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_); @@ -956,6 +977,7 @@ class SharedMemoryRewriter : public StmtExprMutator { } // Wheather enable dyanmic analysis. bool is_dynamic_{true}; + // Whether enable verbose logging. bool verbose_{false}; // The var for the merged buffer @@ -985,18 +1007,19 @@ class SharedMemoryRewriter : public StmtExprMutator { }; Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem, + bool enable_aggressive_merge, bool verbose = false) { AllocateCollector collector; collector(stmt); if (collector.dyn_shmem_allocs_.size() > 1) { SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose); - rewriter.PlanReuse(stmt); + rewriter.PlanReuse(stmt, true, enable_aggressive_merge); stmt = rewriter(std::move(stmt)); } if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) { SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false, verbose); - rewriter.PlanReuse(stmt, false); + rewriter.PlanReuse(stmt, false, enable_aggressive_merge); stmt = rewriter(std::move(stmt)); } return stmt; @@ -1006,17 +1029,18 @@ using namespace tir::transform; namespace transform { -Pass MergeSharedMemoryAllocations() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { +Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false) { + auto pass_func = [enable_aggressive_merge](PrimFunc f, IRModule m, + PassContext ctx) { bool merge_static_smem = ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); bool debug_merge_shared_memory_allocations = ctx->GetConfig(kDebugMergeSharedMemoryAllocations, Bool(false)) .value(); auto *n = f.CopyOnWrite(); - n->body = - tl::MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem, - debug_merge_shared_memory_allocations); + n->body = tl::MergeSharedMemoryAllocations( + std::move(n->body), merge_static_smem, enable_aggressive_merge, + debug_merge_shared_memory_allocations); return f; }; return CreatePrimFuncPass(pass_func, 0, "tl.MergeSharedMemoryAllocations", diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 5d2fbfd78..74ba8d00d 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -50,6 +50,20 @@ def allow_global_thread_synchronization(pass_ctx: Optional[PassContext] = None) return enable_global_thread_sync +def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None, + target: Optional[Target] = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + enable_aggressive_merge = bool( + pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False)) + if allow_warp_specialized(pass_ctx=pass_ctx, target=target): + # This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass + # when warp specialization is enabled, as different warp threads may access different + # buffers, but the liveness analysis is hard because we need to do pipeline. + enable_aggressive_merge = False + return enable_aggressive_merge + + def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Bind the target device information to the module mod = tir.transform.BindTarget(target)(mod) @@ -151,7 +165,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.AnnotateDeviceRegions()(mod) mod = tir.transform.SplitHostDevice()(mod) - mod = tilelang.transform.MergeSharedMemoryAllocations()(mod) + mod = tilelang.transform.MergeSharedMemoryAllocations( + enable_aggressive_merge=should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target))( + mod) mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index a2a4b49f1..ff8a1fb22 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -335,7 +335,7 @@ def EliminateStorageSyncForMBarrier(): return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore -def MergeSharedMemoryAllocations(): +def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False): """MergeSharedMemoryAllocations Returns @@ -343,7 +343,7 @@ def MergeSharedMemoryAllocations(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.MergeSharedMemoryAllocations() # type: ignore + return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge) # type: ignore def LowerL2Persistent(): diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 1ad7dd706..cf426c6d2 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -32,6 +32,9 @@ class PassConfigKey(str, Enum): TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations" """Enable debug information for merge shared memory allocations. Default: False""" + TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE = "tl.enable_aggressive_shared_memory_merge" + """Enable aggressive merge of shared memory allocations. Default: False""" + # TIR related configs TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir" """Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""