Skip to content

Commit 457fdab

Browse files
committed
Matmul scheduler - apply changes from #2488
- apply improvement in matmul instruction scheduling with loop rotation
1 parent 8d19dcd commit 457fdab

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

third_party/nvfuser/csrc/scheduler/matmul.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -604,12 +604,12 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
604604
acr->axis(-1)->parallelize(ParallelType::Vectorize);
605605
bcr->axis(-1)->parallelize(ParallelType::Vectorize);
606606

607-
// 0 1 2 3 4 5 6 7 8 9 10
608-
// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
607+
// 0 1 2 3 4 5 6 7 8 9 10
608+
// [Mo No Ko Kwo Mwo Nwo Mw Nw (Mi Ni Ki)]
609609
cc->axis(0)->parallelize(ParallelType::BIDx);
610610
cc->axis(1)->parallelize(ParallelType::BIDy);
611-
cc->axis(3)->parallelize(ParallelType::TIDz);
612-
cc->axis(4)->parallelize(ParallelType::TIDy);
611+
cc->axis(4)->parallelize(ParallelType::TIDz);
612+
cc->axis(5)->parallelize(ParallelType::TIDy);
613613

614614
// Propagate mma output swizzle and parallelization down the DAG
615615
if (params.double_buffer_options.double_buffer_smem_write) {
@@ -640,6 +640,11 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
640640
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
641641
.propagateParallelType()
642642
.propagateToBoundary());
643+
644+
if (params.double_buffer_options.double_buffer_smem_read &&
645+
params.double_buffer_options.double_buffer_smem_write) {
646+
scheduler_utils::rotateLoop(cc, 2, {acr, bcr});
647+
}
643648
}
644649

645650
} // namespace nvfuser

third_party/nvfuser/csrc/scheduler/matmul_heuristic.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class MatmulParams : public HeuristicParams {
1717
int smem_double_buffer_stage = 2;
1818
};
1919

20+
//! Whether to rotate the ldmatrix out of the main loop
21+
bool rotate_ldmatrix_out_of_main_loop = true;
22+
2023
//! (Ampere+) Use cp.async to load operands.
2124
bool async_gmem_load_operands = false;
2225

0 commit comments

Comments
 (0)