diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
index d2d4cad6d4fe..c5755c97aa7a 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -178,6 +178,8 @@ void IterDomainGraph::build(Fusion* fusion) {
           BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map);

       const auto& permissive_c2p_map = permissive_replay_PasC.getReplay();
+      const auto permissive_disjoint_sets =
+          permissive_replay_PasC.getDisjointSets();

       // For exact mapings do not map any broadcast dimensions to
       // non-broadcast dimensions. Prevent any broadcasted axes being mapped
@@ -213,6 +215,17 @@ void IterDomainGraph::build(Fusion* fusion) {
         auto p_id = entry.second;
         if (idIsAComputeAtLeafDomain(p_id, p_tv)) {
           loop_nodes_.mapEntries(c_id, p_id);
+        } else {
+          // When there are trivial reductions merged with other dims, `p_id`
+          // might not be a compute at leaf domain of `p_tv`, but it actually
+          // has an equivalent compute at leaf domain. For that case, we map
+          // the equivalent compute at leaf domain.
+          for (int i = 0; i < p_tv->getComputeAtPosition(); i++) {
+            auto id = p_tv->axis(i);
+            if (permissive_disjoint_sets.permissiveAreMapped(p_id, id)) {
+              loop_nodes_.mapEntries(c_id, id);
+            }
+          }
         }
         permissive_nodes_.mapEntries(c_id, p_id);
         consumers_.at(p_id).pushBack(c_id);
@@ -225,8 +238,8 @@ void IterDomainGraph::build(Fusion* fusion) {
         mapMaybeSwizzleOp(permissive_nodes_, c_id);
       }

-      // Make sure we always get root mapping for the permissive map. Because
-      // of forwarding we could otherwise miss some root mappings.
+      // Make sure we always get root mapping for the permissive map.
+      // Because of forwarding we could otherwise miss some root mappings.
       for (auto entry : permissive_c2p_root_map) {
         auto c_id = entry.first;
         auto p_id = entry.second;
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index a3d7b310902a..91358f0ec54e 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -24886,6 +24886,76 @@ TEST_F(NVFuserTest, FusionInlinePropagatorBroadcast_CUDA) {
   testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__);
 }

+TEST_F(NVFuserTest, FusionInlinePropagatorBroadcastTrivialReduction_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({2, 3, 4});
+  fusion.addInput(tv0);
+  auto tv1 = sin(tv0);
+  // broadcasting
+  auto tv2 = broadcast(tv1, {false, true, false, true, false, true});
+  auto tv3 = tan(tv2);
+  // trivial reduction
+  auto tv4 = sum(tv3, {1, 3, 5});
+  auto tv5 = cos(tv4);
+  auto tv6 = exp(tv5);
+  fusion.addOutput(tv6);
+
+  for (auto tv : {tv2, tv3, tv4}) {
+    tv->merge(0);
+    tv->merge(1);
+    tv->merge(2);
+  }
+
+  InlinePropagator inline_propagator(tv6, -1, ComputeAtMode::MostInlined);
+  MaxRootDomainInfoSpanningTree(tv6).traverse(&inline_propagator);
+
+  TORCH_CHECK(tv6->getComputeAtPosition() == 3);
+  TORCH_CHECK(tv5->getComputeAtPosition() == 3);
+  TORCH_CHECK(tv4->getComputeAtPosition() == 3);
+  TORCH_CHECK(tv3->getComputeAtPosition() == 3);
+  TORCH_CHECK(tv2->getComputeAtPosition() == 3);
+  TORCH_CHECK(tv1->getComputeAtPosition() == 3);
+
+  const auto options =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn({2, 3, 4}, options);
+  auto output = input.sin().tan().cos().exp();
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input});
+  auto cg_outputs = fe.runFusion({input});
+
+  testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__);
+}
+
+TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayTrivialReduction_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({2, 1, 3, 1, 4, 1});
+  fusion.addInput(tv0);
+  auto tv1 = sum(tv0, {1, 3, 5});
+  auto tv2 = sin(tv1);
+  fusion.addOutput(tv2);
+
+  for (auto tv : {tv0, tv1}) {
+    tv->merge(0);
+    tv->merge(1);
+    tv->merge(2);
+  }
+
+  TORCH_CHECK(
+      TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv1, 3) == 3);
+  TORCH_CHECK(
+      TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv1, tv0, 3) == 3);
+  TORCH_CHECK(
+      TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv1, tv2, 3) == 3);
+  TORCH_CHECK(
+      TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3);
+}
+
 TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayBroadcast_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -24912,6 +24982,44 @@ TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayBroadcast_CUDA) {
       TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3);
 }

+TEST_F(NVFuserTest, FusionIdGraphTrivialReduction_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({2, 3, 4});
+  fusion.addInput(tv0);
+  auto tv1 = broadcast(tv0, {false, true, false, true, false, true});
+  auto tv2 = sum(tv1, {1, 3, 5});
+  auto tv3 = sin(tv2);
+  fusion.addOutput(tv3);
+
+  for (auto tv : {tv1, tv2}) {
+    tv->merge(0);
+    tv->merge(1);
+    tv->merge(2);
+  }
+
+  InlinePropagator inline_propagator(tv3, -1, ComputeAtMode::MostInlined);
+  MaxRootDomainInfoSpanningTree(tv3).traverse(&inline_propagator);
+
+  ComputeAtMap ca_map(&fusion);
+
+  auto all_tvs = ir_utils::allTvs(&fusion);
+  for (auto tv1 : all_tvs) {
+    for (auto tv2 : all_tvs) {
+      if (tv1->isFusionInput() || tv2->isFusionInput()) {
+        continue;
+      }
+      for (int i : c10::irange(3)) {
+        auto id1 = tv1->axis(i);
+        auto id2 = tv2->axis(i);
+        TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::LOOP));
+        TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE));
+      }
+    }
+  }
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
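For context, the scenario handled by the new else-branch in IterDomainGraph::build can be reduced to the sketch below. This is illustrative commentary, not part of the patch: it reuses only helpers already visible in the tests above (makeConcreteTensor, sum, sin, merge, InlinePropagator, MaxRootDomainInfoSpanningTree, ComputeAtMap), and the extents and the final check are assumptions modeled on FusionIdGraphTrivialReduction_CUDA.

// A trivial reduction (summing an axis of extent 1) merged into a real
// iteration axis; extents and the final check are illustrative assumptions.
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 1});
fusion.addInput(tv0);
auto tv1 = sum(tv0, {1}); // axis 1 has extent 1: a trivial reduction
auto tv2 = sin(tv1);
fusion.addOutput(tv2);

// After this merge, tv1's only leaf domain combines an iteration domain
// with a trivial reduction domain. The producer ID that BestEffortReplay
// pairs with tv2's axis is then not itself a compute-at leaf domain of
// tv1, but an equivalent leaf domain exists; that is the case the new
// else-branch resolves via permissiveAreMapped.
tv1->merge(0);

InlinePropagator inline_propagator(tv2, -1, ComputeAtMode::MostInlined);
MaxRootDomainInfoSpanningTree(tv2).traverse(&inline_propagator);

// Without the fix, the LOOP map could miss this pairing; with it, tv1's
// merged leaf and tv2->axis(0) land in the same disjoint set.
ComputeAtMap ca_map(&fusion);
TORCH_CHECK(ca_map.areMapped(tv1->axis(0), tv2->axis(0), IdMappingMode::LOOP));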