Skip to content

Commit 93505bc

Browse files
authored
WAR on index mapping when exact and permissive maps differ (#1960)
1 parent 45e95fd commit 93505bc

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

torch/csrc/jit/codegen/cuda/lower_index_compute.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -732,8 +732,36 @@ void LoopIndexingAnalysis::constructLoopDomains() {
732732
!concrete_id_to_consumer_.count(concrete_id) &&
733733
// Use permissive map so the selected ID indeed represents the
734734
// loop.
735-
GpuLower::current()->caMap()->areMapped(
736-
concrete_id, loop_id, IdMappingMode::PERMISSIVE);
735+
// Note: see PR https://github.com/csarofeen/pytorch/pull/1960
736+
// and issue https://github.com/csarofeen/pytorch/issues/1873
737+
// This mapping lookup is part of a staged indexing scheme.
738+
// When we find a replayed exact id that exactly maps to the loop
739+
// id, this means that we can resolve indexing involved in this
740+
// loop "locally", i.e. using only the iterdomains
741+
// on the
742+
//
743+
// given consumer tv.
744+
// When we cannot find an exact mapping, the permissive mapping
745+
// would
746+
// help deferring the indexing resolution for this loop nest
747+
// level to other iterdomain expressions from tv's that are
748+
// further concretized and usually they are further down the
749+
// consumer chain of the given consumer tv.
750+
//
751+
// Intuitively exact mapping of two iterdomains should imply
752+
// permissive mapping
753+
// of them as well and if that was the case, only looking up
754+
// permissive mapping would be enough to address both of the
755+
// cases above.
756+
// FIXME: But currently exact mapping does not imply permissive
757+
// mapping (See issue:
758+
// https://github.com/csarofeen/pytorch/issues/1963)
759+
// This means we should check both exact and permissive mapping
760+
// here.
761+
(GpuLower::current()->caMap()->areMapped(
762+
concrete_id, loop_id, IdMappingMode::EXACT) ||
763+
GpuLower::current()->caMap()->areMapped(
764+
concrete_id, loop_id, IdMappingMode::PERMISSIVE));
737765
});
738766

739767
TORCH_INTERNAL_ASSERT(

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25747,6 +25747,51 @@ TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction2_CUDA) {
2574725747
fusion, {out}, {t0, t1}, {t1 + t0.squeeze(-1)}, __LINE__, __FILE__);
2574825748
}
2574925749

25750+
TEST_F(NVFuserTest, FusionMappingRelation_CUDA) {
25751+
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
25752+
auto fusion = fusion_ptr.get();
25753+
FusionGuard fg(fusion);
25754+
25755+
TensorView* tv0 = makeConcreteTensor({1, 1});
25756+
TensorView* tv1 = makeConcreteTensor({-1, 1, 1});
25757+
fusion->addInput(tv0);
25758+
fusion->addInput(tv1);
25759+
auto tv2 = set(tv0);
25760+
auto tv3 = broadcast(tv2, {true, false, false});
25761+
auto tv4 = add(tv3, tv1);
25762+
25763+
fusion->addOutput(tv4);
25764+
25765+
tv4->merge(-2);
25766+
tv4->merge(-1);
25767+
25768+
tv0->computeAt(tv4, -1);
25769+
tv1->computeAt(tv4, -1);
25770+
25771+
ComputeAtMap ca_map(fusion);
25772+
25773+
// FIXME: This is the concerning part that would motivate some
25774+
// more formalization on concrete/permissive mapping:
25775+
// exact mapping should ideally imply permissive mapping.
25776+
auto tv4_inner_node = tv4->axis(0)->definition()->input(1)->as<IterDomain>();
25777+
TORCH_CHECK(
25778+
ca_map.areMapped(tv2->axis(0), tv4_inner_node, IdMappingMode::EXACT));
25779+
TORCH_CHECK(!ca_map.areMapped(
25780+
tv2->axis(0), tv4_inner_node, IdMappingMode::PERMISSIVE));
25781+
25782+
auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
25783+
at::Tensor t0 = at::randn({1, 1}, options);
25784+
at::Tensor t1 = at::randn({2, 1, 1}, options);
25785+
25786+
FusionExecutor fe;
25787+
fe.compileFusion(fusion, {t0, t1});
25788+
auto cg_outputs = fe.runFusion({t0, t1});
25789+
auto out = cg_outputs[0];
25790+
25791+
testValidate(
25792+
fusion, {out}, {t0, t1}, {t1 + t0.squeeze(0)}, __LINE__, __FILE__);
25793+
}
25794+
2575025795
} // namespace jit
2575125796
} // namespace torch
2575225797
#endif // #if defined(USE_CUDA)

0 commit comments

Comments
 (0)