17 changes: 15 additions & 2 deletions torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -178,6 +178,8 @@ void IterDomainGraph::build(Fusion* fusion) {
BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map);

const auto& permissive_c2p_map = permissive_replay_PasC.getReplay();
const auto permissive_disjoint_sets =
permissive_replay_PasC.getDisjointSets();

// For exact mappings do not map any broadcast dimensions to
// non-broadcast dimensions. Prevent any broadcasted axes being mapped
@@ -213,6 +215,17 @@ void IterDomainGraph::build(Fusion* fusion) {
auto p_id = entry.second;
if (idIsAComputeAtLeafDomain(p_id, p_tv)) {
loop_nodes_.mapEntries(c_id, p_id);
} else {
// When there are trivial reductions merged with other dims, `p_id`
// might not be a compute at leaf domain of `p_tv`, but it actually
// has an equivalent compute at leaf domain. For that case, we map
// the equivalent compute at leaf domain.
for (int i = 0; i < p_tv->getComputeAtPosition(); i++) {
auto id = p_tv->axis(i);
if (permissive_disjoint_sets.permissiveAreMapped(p_id, id)) {
loop_nodes_.mapEntries(c_id, id);
}
}
Comment on lines +218 to +228
Collaborator Author

I was trying to do:

        for (const auto& set : permissive_disjoint_sets.disjointSets()) {
          auto id1 = set->front();
          for (auto id2 : *set) {
            auto is_leaf1 = idIsAComputeAtLeafDomain(id1, p_tv);
            auto is_leaf2 = idIsAComputeAtLeafDomain(id2, p_tv);
            if (is_leaf1 || is_leaf2) {
              loop_nodes_.mapEntries(id1, id2);
            }
            permissive_nodes_.mapEntries(id1, id2);

            // Add the swizzle inputs to the same
            //  disjoint set as well if either c_id
            //  or p_id is swizzle output.
            mapMaybeSwizzleOp(permissive_nodes_, id1);
            mapMaybeSwizzleOp(permissive_nodes_, id2);
          }
        }

But it didn't work; I don't know why.

}
permissive_nodes_.mapEntries(c_id, p_id);
consumers_.at(p_id).pushBack(c_id);
@@ -225,8 +238,8 @@ void IterDomainGraph::build(Fusion* fusion) {
mapMaybeSwizzleOp(permissive_nodes_, c_id);
}

// Make sure we always get root mapping for the permissive map. Because
// of forwarding we could otherwise miss some root mappings.
// Make sure we always get root mapping for the permissive map.
// Because of forwarding we could otherwise miss some root mappings.
for (auto entry : permissive_c2p_root_map) {
auto c_id = entry.first;
auto p_id = entry.second;
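
For context, here is the added branch again with explanatory comments — an illustrative annotation, not part of the diff; the axis names I0 and r1 are assumed for the example:

        // Setting: the producer p_tv merged a trivial reduction r1 into an
        // iteration domain I0, so its compute-at leaf domain is I0*r1.
        // BestEffortReplay::replayPasC() replays p_tv through its consumer
        // c_tv, which has no counterpart for r1, so the replayed p_id is I0.
        // idIsAComputeAtLeafDomain(I0, p_tv) is then false, even though I0
        // and I0*r1 sit in the same permissive disjoint set. The loop scans
        // the leaf domains within the compute-at position and maps c_id to
        // the permissively-equivalent leaf instead:
        for (int i = 0; i < p_tv->getComputeAtPosition(); i++) {
          auto id = p_tv->axis(i); // candidate leaf, e.g. I0*r1
          if (permissive_disjoint_sets.permissiveAreMapped(p_id, id)) {
            loop_nodes_.mapEntries(c_id, id); // map to I0*r1, not to I0
          }
        }
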
108 changes: 108 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -24886,6 +24886,76 @@ TEST_F(NVFuserTest, FusionInlinePropagatorBroadcast_CUDA) {
testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionInlinePropagatorBroadcastTrivialReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 3, 4});
fusion.addInput(tv0);
auto tv1 = sin(tv0);
// broadcasting
auto tv2 = broadcast(tv1, {false, true, false, true, false, true});
auto tv3 = tan(tv2);
// trivial reduction
auto tv4 = sum(tv3, {1, 3, 5});
auto tv5 = cos(tv4);
auto tv6 = exp(tv5);
fusion.addOutput(tv6);

for (auto tv : {tv2, tv3, tv4}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

InlinePropagator inline_propagator(tv6, -1, ComputeAtMode::MostInlined);
MaxRootDomainInfoSpanningTree(tv6).traverse(&inline_propagator);

TORCH_CHECK(tv6->getComputeAtPosition() == 3);
TORCH_CHECK(tv5->getComputeAtPosition() == 3);
TORCH_CHECK(tv4->getComputeAtPosition() == 3);
TORCH_CHECK(tv3->getComputeAtPosition() == 3);
TORCH_CHECK(tv2->getComputeAtPosition() == 3);
TORCH_CHECK(tv1->getComputeAtPosition() == 3);

const auto options =
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::randn({2, 3, 4}, options);
auto output = input.sin().tan().cos().exp();

FusionExecutor fe;
fe.compileFusion(&fusion, {input});
auto cg_outputs = fe.runFusion({input});

testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayTrivialReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 1, 3, 1, 4, 1});
fusion.addInput(tv0);
auto tv1 = sum(tv0, {1, 3, 5});
auto tv2 = sin(tv1);
fusion.addOutput(tv2);

for (auto tv : {tv0, tv1}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv1, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv1, tv0, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv1, tv2, 3) == 3);
TORCH_CHECK(
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3);
}

TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayBroadcast_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
@@ -24912,6 +24982,44 @@ TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayBroadcast_CUDA) {
TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3);
}

TEST_F(NVFuserTest, FusionIdGraphTrivialReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({2, 3, 4});
fusion.addInput(tv0);
auto tv1 = broadcast(tv0, {false, true, false, true, false, true});
auto tv2 = sum(tv1, {1, 3, 5});
auto tv3 = sin(tv2);
fusion.addOutput(tv3);

for (auto tv : {tv1, tv2}) {
tv->merge(0);
tv->merge(1);
tv->merge(2);
}

InlinePropagator inline_propagator(tv3, -1, ComputeAtMode::MostInlined);
MaxRootDomainInfoSpanningTree(tv3).traverse(&inline_propagator);

ComputeAtMap ca_map(&fusion);

auto all_tvs = ir_utils::allTvs(&fusion);
for (auto tv1 : all_tvs) {
for (auto tv2 : all_tvs) {
if (tv1->isFusionInput() || tv2->isFusionInput()) {
continue;
}
for (int i : c10::irange(3)) {
auto id1 = tv1->axis(i);
auto id2 = tv2->axis(i);
TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::LOOP));
TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE));
}
}
}
}
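
A worked note on the axis bookkeeping these tests share (my own summary, not from the PR): each one starts from six root domains where every other domain has extent 1, and three merges collapse them pairwise.

        // Root domains: [D0, D1, D2, D3, D4, D5], with D1/D3/D5 broadcast or
        // trivial-reduction domains of extent 1.
        //   merge(0): [D0*D1, D2, D3, D4, D5]
        //   merge(1): [D0*D1, D2*D3, D4, D5]
        //   merge(2): [D0*D1, D2*D3, D4*D5]
        // Three leaf domains remain, which is why every getComputeAtPosition()
        // and getMatchedLeafPosWithoutReplay*() check above expects 3, and why
        // the per-axis loops iterate over c10::irange(3). Summing over the
        // extent-1 domains is an identity, so the ATen reference in the first
        // test is just the pointwise chain input.sin().tan().cos().exp().
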

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)