Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/transform/pipeline_planning.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ class BufferRegionCollector : public StmtExprVisitor {

class PipelinePlanner : public StmtExprMutator {
public:
static Stmt Substitute(const PrimFunc &f) {
PipelinePlanner substituter;
static Stmt Substitute(const PrimFunc &f, bool use_async_copy = true) {
PipelinePlanner substituter(use_async_copy);
for (const auto &[_, buffer] : f->buffer_map) {
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
Expand All @@ -179,6 +179,7 @@ class PipelinePlanner : public StmtExprMutator {

private:
PipelinePlanner() = default;
PipelinePlanner(bool use_async_copy) : use_async_copy_(use_async_copy) {}

/*! \brief Information about a pipeline stage
*
Expand Down Expand Up @@ -262,7 +263,7 @@ class PipelinePlanner : public StmtExprMutator {
}
}
annotations.Set(tir::attr::software_pipeline_stage, stage_anno);
if (TargetHasAsyncCopy(target_))
if (TargetHasAsyncCopy(target_) && use_async_copy_)
annotations.Set(tir::attr::software_pipeline_async_stages,
Array<Integer>{0});
auto for_node = GetRef<For>(loop);
Expand Down Expand Up @@ -459,7 +460,7 @@ class PipelinePlanner : public StmtExprMutator {

annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
if (TargetHasAsyncCopy(target_))
if (TargetHasAsyncCopy(target_) && use_async_copy_)
annotations.Set(tir::attr::software_pipeline_async_stages,
Array<Integer>{0});

Expand All @@ -480,13 +481,16 @@ class PipelinePlanner : public StmtExprMutator {

Map<Var, Buffer> buffer_data_to_buffer_;
Target target_;
bool use_async_copy_;
};

tvm::transform::Pass PipelinePlanning() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
bool use_async_copy =
ctx->GetConfig<Bool>("tir.use_async_copy", Bool(true)).value();
Comment on lines +490 to +491
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The configuration key "tir.use_async_copy" is a bit generic. Since this pass is part of tilelang, it would be better to namespace the configuration key to avoid potential conflicts with other TIR passes and to make its scope clearer. Suggest using a more specific key, like "tl.pipeline.use_async_copy".

Suggested change
bool use_async_copy =
ctx->GetConfig<Bool>("tir.use_async_copy", Bool(true)).value();
bool use_async_copy =
ctx->GetConfig<Bool>("tl.pipeline.use_async_copy", Bool(true)).value();

PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = PipelinePlanner::Substitute(f);
fptr->body = PipelinePlanner::Substitute(f, use_async_copy);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
Comment on lines 487 to 496
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new functionality to disable asynchronous copy is not covered by any tests. Add a new test case to verify the behavior when use_async_copy is False to prevent future regressions.

Expand Down
Loading