Commit 371f282

Improve trivial reduction merge support (#1931)

1 parent: 1d0c267
2 files changed: +99 / -12 lines

torch/csrc/jit/codegen/cuda/ir_nodes.cpp (26 additions, 10 deletions)
@@ -1434,24 +1434,47 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
       "Merging IterDomains with ending values that are 0 is not supported at this time.");
   TORCH_CHECK(
       outer->isReduction() == inner->isReduction() ||
-          (!outer->isReduction() && inner->extent()->isOneInt()) ||
-          (outer->extent()->isOneInt() && !inner->isReduction()),
+          (!outer->isReduction() && inner->isTrivialReduction()) ||
+          (outer->isTrivialReduction() && !inner->isReduction()),
       "Merging IterDomains requires that their iteration types match.");
   TORCH_CHECK(
       (outer->isGather() && inner->isGather()) ||
           (!outer->isGather() && !inner->isGather()),
       "Merging gather and non-gather domains is not supported.");
 
+  TORCH_CHECK(
+      !outer->isStride() && !inner->isStride(),
+      "No support for merging stride domains");
+
   Val* merged_id_size = mul(outer->extent(), inner->extent());
 
   IterType itype = outer->getIterType();
 
   if (outer->isBroadcast() && inner->isBroadcast()) {
     itype = IterType::Broadcast;
-  } else if (outer->isBroadcast() || inner->isBroadcast()) {
+  }
+
+  if ((outer->isBroadcast() || inner->isBroadcast()) &&
+      (outer->getIterType() == IterType::Iteration ||
+       inner->getIterType() == IterType::Iteration)) {
+    itype = IterType::Iteration;
+  }
+
+  // Merging trivial reduction with iter domain, that's fine, just make it an
+  // iter domain.
+  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
+      (outer->getIterType() == IterType::Iteration ||
+       inner->getIterType() == IterType::Iteration)) {
     itype = IterType::Iteration;
   }
 
+  // Merging trivial reduction with broadcasting, that's fine, just make it a
+  // broadcasting.
+  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
+      (outer->isBroadcast() || inner->isBroadcast())) {
+    itype = IterType::Broadcast;
+  }
+
   Val* expanded_extent = nullptr;
   if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {
     if (outer->hasExpandedExtent() && inner->hasExpandedExtent()) {
@@ -1471,13 +1494,6 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
     }
   }
 
-  // Merging trivial reduction with iter domain, that's fine, just make it an
-  // iter domain.
-  if ((outer->isReduction() || inner->isReduction()) &&
-      (!outer->isReduction() || !inner->isReduction())) {
-    itype = IterType::Iteration;
-  }
-
   IterDomain* merged_id =
       IterDomainBuilder(
           outer->container()->zeroVal(), merged_id_size->as<Int>())
torch/csrc/jit/codegen/cuda/test/test_gpu.cpp (73 additions, 2 deletions)
@@ -24872,7 +24872,7 @@ TEST_F(NVFuserTest, FusionInsertMagicZero1_CUDA) {
   tv2->reorder({{1, 2}, {2, 1}});
   tv2->merge(0);
 
-  TransformPropagator propagator(tv2);
+  TransformPropagatorWithCheck propagator(tv2);
   MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);
 
   tv0->computeAt(tv2, 1);
@@ -24992,7 +24992,7 @@ TEST_F(NVFuserTest, FusionExpandReduce2_CUDA) {
   // [iBIDx, iTIDx, rTIDy, rBIDy, rO]
   auto tv3 = tv2->rFactor({-1});
 
-  TransformPropagator propagator(tv3);
+  TransformPropagatorWithCheck propagator(tv3);
   MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
   scheduler_utils::parallelizeAllLike(tv3);
   tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined);
@@ -25693,6 +25693,77 @@ TEST_F(NVFuserTest, AsyncCompilation_CUDA) {
       executor_cache.fusion(), outputs, aten_inputs, {t6}, __LINE__, __FILE__);
 }
 
+TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction1_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 = makeConcreteTensor({1, 1});
+  TensorView* tv1 = makeConcreteTensor({-1});
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+  auto tv2 = sum(tv0, {1});
+  auto tv3 = add(tv2, tv1);
+  fusion->addOutput(tv3);
+
+  tv0->merge(0);
+
+  MaxRootDomainInfoSpanningTree tree(tv0);
+  TransformPropagatorWithCheck tp(tv0);
+  tree.traverse(&tp);
+
+  InlinePropagator ip(tv0, -1, ComputeAtMode::MostInlined);
+  tree.traverse(&ip);
+
+  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({1, 1}, options);
+  at::Tensor t1 = at::randn({10}, options);
+
+  FusionExecutor fe;
+  fe.compileFusion(fusion, {t0, t1});
+  auto cg_outputs = fe.runFusion({t0, t1});
+  auto out = cg_outputs[0];
+
+  testValidate(
+      fusion, {out}, {t0, t1}, {t1 + t0.flatten()}, __LINE__, __FILE__);
+}
+
+TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction2_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 = makeConcreteTensor({-1, 1, 1});
+  TensorView* tv1 = makeConcreteTensor({-1, -1});
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+  auto tv2 = sum(tv0, {1});
+  auto tv3 = add(tv2, tv1);
+  fusion->addOutput(tv3);
+
+  tv2->merge(1);
+  tv2->merge(0);
+
+  MaxRootDomainInfoSpanningTree tree(tv0);
+  TransformPropagatorWithCheck tp(tv0);
+  tree.traverse(&tp);
+
+  InlinePropagator ip(tv0, -1, ComputeAtMode::MostInlined);
+  tree.traverse(&ip);
+
+  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({10, 1, 1}, options);
+  at::Tensor t1 = at::randn({10, 10}, options);
+
+  FusionExecutor fe;
+  fe.compileFusion(fusion, {t0, t1});
+  auto cg_outputs = fe.runFusion({t0, t1});
+  auto out = cg_outputs[0];
+
+  testValidate(
+      fusion, {out}, {t0, t1}, {t1 + t0.squeeze(-1)}, __LINE__, __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
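A note on what the new tests exercise: in both fusions the summed axis has extent 1, so tv2 carries a trivial reduction domain, and the subsequent merges combine it with broadcast and iteration domains. Under the old code the merged IterType came from the catch-all at the end of IterDomain::merge, which turned any reduction/non-reduction merge into an Iteration domain; the commit moves that handling ahead of the expanded-extent logic and splits it into two rules, so a trivial reduction merged with a broadcast now stays a Broadcast domain. TransformPropagatorWithCheck replaces TransformPropagator here, presumably a test-side variant that additionally verifies the propagated transforms.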
