Skip to content

Commit 6ac74d1

Browse files
authored
Fix sync map (#2047)
1 parent f5bca33 commit 6ac74d1

File tree

5 files changed

+52
-7
lines changed

5 files changed

+52
-7
lines changed

torch/csrc/jit/codegen/cuda/lower2device.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,9 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
303303
// Depends on thread_pred_map_; validates parallelization and collects which
304304
// tensor views need WAR or RAW syncs
305305
sync_map_.build(fusion_);
306+
if (isDebugDumpEnabled(DebugDumpOption::SyncMap)) {
307+
std::cout << sync_map_.toString() << std::endl;
308+
}
306309

307310
partialSplitMap().build(fusion_);
308311

torch/csrc/jit/codegen/cuda/lower_sync_information.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ void SyncMap::build(Fusion* fusion) {
483483
} // end for consumers
484484

485485
if (raw_dims.any()) {
486-
needs_raw_sync_[producer] = raw_dims;
486+
needs_raw_sync_[producer] |= raw_dims;
487487
}
488488

489489
} // end producer
@@ -492,10 +492,14 @@ void SyncMap::build(Fusion* fusion) {
492492

493493
std::string SyncMap::toString() const {
494494
std::stringstream ss;
495-
ss << "TVs requiring RAW:" << std::endl;
495+
ss << "SyncMap:";
496+
bool is_first = true;
496497
for (auto entry : needs_raw_sync_) {
497-
ss << " " << entry.first->toString() << " :: " << entry.second.toString()
498-
<< std::endl;
498+
if (!is_first) {
499+
ss << ",";
500+
}
501+
ss << " " << entry.first->toString() << " -> " << entry.second.toString();
502+
is_first = false;
499503
}
500504
return ss.str();
501505
}

torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5040,6 +5040,40 @@ TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) {
50405040
ASSERT_ANY_THROW(fusion.printKernel());
50415041
}
50425042

5043+
// Repro of #2046
5044+
TEST_F(NVFuserTest, FusionValidateParallelize7_CUDA) {
5045+
Fusion fusion;
5046+
FusionGuard fg(&fusion);
5047+
5048+
auto tv0 = makeSymbolicTensor(2);
5049+
fusion.addInput(tv0);
5050+
5051+
auto tv1 = set(tv0);
5052+
auto tv2 = set(tv1);
5053+
auto tv3 = set(tv1);
5054+
fusion.addOutput(tv2);
5055+
fusion.addOutput(tv3);
5056+
5057+
tv1->setMemoryType(MemoryType::Global);
5058+
5059+
tv1->axis(0)->parallelize(ParallelType::BIDx);
5060+
tv1->axis(1)->parallelize(ParallelType::TIDx);
5061+
5062+
tv2->axis(1)->parallelize(ParallelType::TIDy);
5063+
tv3->axis(0)->parallelize(ParallelType::BIDx);
5064+
5065+
// tv2 uses tv1 but is not parallelized with BIDx, so a grid sync is
5066+
// required. It should be placed as a top-level expression.
5067+
5068+
GpuLower gpulw(&fusion);
5069+
TORCH_CHECK(
5070+
std::any_of(
5071+
gpulw.kernel()->topLevelExprs().begin(),
5072+
gpulw.kernel()->topLevelExprs().end(),
5073+
[](Expr* expr) { return expr->isA<kir::GridSync>(); }),
5074+
"Grid sync not found");
5075+
}
5076+
50435077
TEST_F(NVFuserTest, FusionDAGMerging_CUDA) {
50445078
Fusion fusion;
50455079
FusionGuard fg(&fusion);

torch/csrc/jit/codegen/cuda/utils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ auto parseDebugDumpOptions() {
4343
{DebugDumpOption::TransformPropagator, false},
4444
{DebugDumpOption::Cubin, false},
4545
{DebugDumpOption::Ptx, false},
46-
{DebugDumpOption::BankConflictInfo, false}};
46+
{DebugDumpOption::BankConflictInfo, false},
47+
{DebugDumpOption::SyncMap, false}};
4748

4849
if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) {
4950
c10::string_view options_view(dump_options);
@@ -106,6 +107,8 @@ auto parseDebugDumpOptions() {
106107
options_map[DebugDumpOption::Ptx] = true;
107108
} else if (token == "bank_conflict") {
108109
options_map[DebugDumpOption::BankConflictInfo] = true;
110+
} else if (token == "sync_map") {
111+
options_map[DebugDumpOption::SyncMap] = true;
109112
} else {
110113
TORCH_CHECK(
111114
false,
@@ -118,7 +121,7 @@ auto parseDebugDumpOptions() {
118121
"\tdraw_segmented_fusion, scheduler_params, parallel_dimensions,\n",
119122
"\tbuffer_reuse_verbose, ptxas_verbose, halo, segmenter_logging,\n",
120123
"\tperf_debug_verbose, python_definition, python_frontend_debug,\n",
121-
"\ttransform_propagator, cubin, ptx, bank_conflict\n");
124+
"\ttransform_propagator, cubin, ptx, bank_conflict, sync_map\n");
122125
}
123126
options_view = (end_pos != c10::string_view::npos)
124127
? options_view.substr(end_pos + 1)

torch/csrc/jit/codegen/cuda/utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ enum class DebugDumpOption {
5959
//! path and replay result
6060
Cubin, //! Dump compiled CUBIN
6161
Ptx, //! Dump compiled PTX
62-
BankConflictInfo //! Dump bank confliction info
62+
BankConflictInfo, //! Dump bank conflict info
63+
SyncMap //! Dump RAW dependency info
6364
};
6465

6566
TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);

0 commit comments

Comments
 (0)