Skip to content

Commit b34e3b9

Browse files
authored
Fix ir_utils::hasBlockSync + misc fixes in transpose scheduler (#1924)
1 parent 14a53e6 commit b34e3b9

File tree

5 files changed

+92
-5
lines changed

5 files changed

+92
-5
lines changed

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -754,10 +754,12 @@ class CudaKernelGenerator : private OptOutConstDispatch {
754754
auto out_tv = rop->output(0)->as<kir::TensorIndex>()->view();
755755
auto index = genTensorIndex(rop->getPhiloxIndex()->as<kir::TensorIndex>());
756756
int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4;
757-
indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = (" << index
758-
<< ") / " << multiple << ";\n";
759-
indent() << "nvfuser_index_t rng_component" << rop->name() << " = ("
760-
<< index << ") % " << multiple << ";\n";
757+
indent() << "nvfuser_index_t linear_index" << rop->name() << " = " << index
758+
<< ";\n";
759+
indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = linear_index"
760+
<< rop->name() << " / " << multiple << ";\n";
761+
indent() << "nvfuser_index_t rng_component" << rop->name()
762+
<< " = linear_index" << rop->name() << " % " << multiple << ";\n";
761763
indent() << "nvfuser_index_t rng_offset" << rop->name() << " = "
762764
<< rop->getRNGOffset() << ";\n";
763765
indent() << "if (rng_subseq != rng_subseq" << rop->name()

torch/csrc/jit/codegen/cuda/lower_utils.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ bool isScalarOp(const Expr* expr) {
204204
}
205205

206206
bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) {
207+
if (expr->isA<kir::BlockSync>()) {
208+
return true;
209+
}
210+
207211
if (!isTvOp(expr)) {
208212
return false;
209213
}

torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ class DomainMap : public pointwise_utils::DomainMap {
109109
decltype(input_tvs)* tv_filtered_groups[2] = {&output_tvs, &input_tvs};
110110
for (auto tv_filtered_group : tv_filtered_groups) {
111111
for (auto tv : *tv_filtered_group) {
112+
if (tv->isFusionInput() && tv->uses().empty()) {
113+
continue;
114+
}
112115
if (grouped.count(tv) > 0) {
113116
continue;
114117
}
@@ -653,7 +656,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
653656
if (inner_most_pos2_in_ref1 > inner_most_pos1_in_ref1) {
654657
inner_most_pos2_in_ref1--;
655658
}
656-
if (!merged2.has_value() && *merged2 > inner_most_pos1_in_ref1) {
659+
if (merged2.has_value() && *merged2 > inner_most_pos1_in_ref1) {
657660
(*merged2)--;
658661
}
659662
reference1->merge(*merged1, inner_most_pos1_in_ref1);

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25512,6 +25512,49 @@ TEST_F(NVFuserTest, FusionSizeDependentData_CUDA) {
2551225512
executor_cache.fusion(), cg_outputs, {a}, {a + 123}, __LINE__, __FILE__);
2551325513
}
2551425514

25515+
TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) {
25516+
// https://github.com/csarofeen/pytorch/issues/1926
25517+
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
25518+
auto fusion = fusion_ptr.get();
25519+
FusionGuard fg(fusion);
25520+
25521+
TensorView* tv0 = makeSymbolicTensor(2);
25522+
fusion->addInput(tv0);
25523+
auto tv1 = set(tv0);
25524+
auto tv2 = set(tv1);
25525+
fusion->addOutput(tv2);
25526+
25527+
tv1->setMemoryType(MemoryType::Shared);
25528+
for (auto tv : {tv1, tv2}) {
25529+
tv->split(0, 4);
25530+
tv->reorder({{1, -1}});
25531+
tv->split(1, 8);
25532+
tv->merge(0);
25533+
tv->split(0, 1);
25534+
tv->axis(0)->parallelize(ParallelType::BIDx);
25535+
tv->axis(1)->parallelize(ParallelType::Unswitch);
25536+
}
25537+
tv1->merge(2);
25538+
tv2->reorder({{2, 3}});
25539+
tv2->merge(2);
25540+
for (auto tv : {tv1, tv2}) {
25541+
tv->axis(-1)->parallelize(ParallelType::TIDx);
25542+
}
25543+
25544+
InlinePropagator propagator(tv2, -1, ComputeAtMode::MostInlined);
25545+
MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);
25546+
25547+
auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
25548+
at::Tensor t0 = at::randn({5, 5}, options);
25549+
25550+
FusionExecutor fe;
25551+
fe.compileFusion(fusion, {t0});
25552+
auto cg_outputs = fe.runFusion({t0});
25553+
auto out = cg_outputs[0];
25554+
25555+
testValidate(fusion, {out}, {t0}, {t0}, __LINE__, __FILE__);
25556+
}
25557+
2551525558
} // namespace jit
2551625559
} // namespace torch
2551725560
#endif // #if defined(USE_CUDA)

torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,5 +264,40 @@ TEST_F(NVFuserTest, FusionBroadcastingRNGSmem_CUDA) {
264264
}
265265
}
266266
267+
TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) {
268+
// https://github.com/csarofeen/pytorch/issues/1926
269+
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
270+
auto fusion = fusion_ptr.get();
271+
FusionGuard fg(fusion);
272+
273+
TensorView* tv0 = makeConcreteTensor({5, 1});
274+
TensorView* tv1 = makeConcreteTensor({5, 5});
275+
fusion->addInput(tv0);
276+
fusion->addInput(tv1);
277+
auto tv2 = randlike(tv0);
278+
auto tv3 = add(tv1, tv2);
279+
auto tv4 = add(tv0, tv3);
280+
fusion->addOutput(tv4);
281+
282+
auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
283+
at::Tensor t0 = at::zeros({5, 1}, options);
284+
at::Tensor t1 = at::zeros({5, 5}, options);
285+
286+
TransposeParams heuristics;
287+
heuristics.tile_size1 = 8;
288+
heuristics.tile_size2 = 4;
289+
scheduleTranspose(fusion, heuristics);
290+
291+
FusionExecutor fe;
292+
fe.compileFusion(fusion, {t0, t1});
293+
auto cg_outputs = fe.runFusion({t0, t1});
294+
auto out = cg_outputs[0];
295+
296+
TORCH_CHECK((out.select(1, 0) == out.select(1, 1)).all().item<bool>());
297+
TORCH_CHECK((out.select(1, 0) == out.select(1, 2)).all().item<bool>());
298+
TORCH_CHECK((out.select(1, 0) == out.select(1, 3)).all().item<bool>());
299+
TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item<bool>());
300+
}
301+
267302
} // namespace jit
268303
} // namespace torch

0 commit comments

Comments (0)