diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc
index 8e12c96db..f3dc0d78d 100644
--- a/src/transform/pipeline_planning.cc
+++ b/src/transform/pipeline_planning.cc
@@ -165,8 +165,8 @@ class BufferRegionCollector : public StmtExprVisitor {
 
 class PipelinePlanner : public StmtExprMutator {
 public:
-  static Stmt Substitute(const PrimFunc &f) {
-    PipelinePlanner substituter;
+  static Stmt Substitute(const PrimFunc &f, bool use_async_copy = true) {
+    PipelinePlanner substituter(use_async_copy);
     for (const auto &[_, buffer] : f->buffer_map) {
       substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
     }
@@ -179,6 +179,8 @@ class PipelinePlanner : public StmtExprMutator {
 
 private:
   PipelinePlanner() = default;
+  explicit PipelinePlanner(bool use_async_copy)
+      : use_async_copy_(use_async_copy) {}
 
   /*! \brief Information about a pipeline stage
    *
@@ -262,7 +264,7 @@ class PipelinePlanner : public StmtExprMutator {
         }
       }
       annotations.Set(tir::attr::software_pipeline_stage, stage_anno);
-      if (TargetHasAsyncCopy(target_))
+      if (TargetHasAsyncCopy(target_) && use_async_copy_)
         annotations.Set(tir::attr::software_pipeline_async_stages,
                         Array<Integer>{0});
       auto for_node = GetRef<For>(loop);
@@ -459,7 +461,7 @@ class PipelinePlanner : public StmtExprMutator {
 
     annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
     annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
-    if (TargetHasAsyncCopy(target_))
+    if (TargetHasAsyncCopy(target_) && use_async_copy_)
       annotations.Set(tir::attr::software_pipeline_async_stages,
                       Array<Integer>{0});
 
@@ -480,13 +482,18 @@ class PipelinePlanner : public StmtExprMutator {
 
   Map<Var, Buffer> buffer_data_to_buffer_;
   Target target_;
+  /*! \brief Whether software_pipeline_async_stages may be annotated. */
+  bool use_async_copy_{true};
 };
 
 tvm::transform::Pass PipelinePlanning() {
   using namespace tir::transform;
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    // The "tir.use_async_copy" option falls back to true when it is not set.
+    bool use_async_copy =
+        ctx->GetConfig<Bool>("tir.use_async_copy", Bool(true)).value();
     PrimFuncNode *fptr = f.CopyOnWrite();
-    fptr->body = PipelinePlanner::Substitute(f);
+    fptr->body = PipelinePlanner::Substitute(f, use_async_copy);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});