Skip to content

Commit eabe8d8

Browse files
authored
Segment self mapping fusions (#1954)
1 parent e96aacf commit eabe8d8

File tree

4 files changed

+109
-32
lines changed

4 files changed

+109
-32
lines changed

torch/csrc/jit/codegen/cuda/compute_at_map.cpp

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
77
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
88

9+
#include <tuple>
10+
911
namespace torch {
1012
namespace jit {
1113
namespace fuser {
@@ -29,8 +31,22 @@ bool idIsALeafDomain(IterDomain* id, TensorView* tv) {
2931

3032
} // namespace
3133

32-
IterDomainGraph::IterDomainGraph(Fusion* fusion) {
34+
//! Build the iter-domain graph for \p fusion and, unless the caller opts
//! out, reject graphs in which a self mapping was recorded during build().
//!
//! \param fusion Fusion handed to build().
//! \param allow_self_mapping When true, the self-mapping assertion is
//!   skipped; the caller can query hasSelfMapping() afterwards instead.
IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) {
  build(fusion);

  // The caller asked to inspect the result itself, so nothing is enforced.
  if (allow_self_mapping) {
    return;
  }

  // self_mapping_info_ is a (tensor, id, id, domain kind) tuple.
  // NOTE(review): the message arguments dereference self_mapping_info_, which
  // is empty when the condition holds. This presumably relies on the assert
  // macro evaluating its message arguments only on failure -- confirm before
  // hoisting any of them into locals.
  TORCH_INTERNAL_ASSERT(
      !hasSelfMapping(),
      "Unsupported domain mapping detected in ",
      std::get<0>(*self_mapping_info_)->toString(),
      ". ",
      std::get<3>(*self_mapping_info_),
      " domains, ",
      std::get<1>(*self_mapping_info_)->toString(),
      " and ",
      std::get<2>(*self_mapping_info_)->toString(),
      ", are mapped with each other.");
}
3551

3652
//! Map corresponding inputs and outputs of swizzle op together
@@ -197,7 +213,8 @@ c10::optional<std::pair<IterDomain*, IterDomain*>> detectMappablePair(
197213
// those domains should never be mapped with each other. It may be
198214
// possible to lift this assumption, but it's unclear if it could
199215
// matter in practice.
200-
void failIfSelfMappingExists(Fusion* fusion, const IterDomainGraph& id_graph) {
216+
c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
217+
findFirstSelfMapping(Fusion* fusion, const IterDomainGraph& id_graph) {
201218
for (auto tv : ir_utils::allTvs(fusion)) {
202219
// For each tensor, make sure root, rfactor and leaf domains
203220
// should not include domains that are mapped with another domain
@@ -207,44 +224,39 @@ void failIfSelfMappingExists(Fusion* fusion, const IterDomainGraph& id_graph) {
207224
// Root domains
208225
auto self_mappped_root_pair =
209226
detectMappablePair(tv->getRootDomain(), id_graph);
210-
TORCH_INTERNAL_ASSERT(
211-
!self_mappped_root_pair.has_value(),
212-
"Unsupported domain mapping detected in ",
213-
tv->toString(),
214-
". Root domains, ",
215-
self_mappped_root_pair->first->toString(),
216-
" and ",
217-
self_mappped_root_pair->second->toString(),
218-
", are mapped with each other.");
227+
if (self_mappped_root_pair.has_value()) {
228+
return std::make_tuple(
229+
tv,
230+
self_mappped_root_pair->first,
231+
self_mappped_root_pair->second,
232+
"Root");
233+
}
219234

220235
// Rfactor domains
221236
if (tv->hasRFactor()) {
222237
auto self_mappped_rf_pair =
223238
detectMappablePair(tv->getRFactorDomain(), id_graph);
224-
TORCH_INTERNAL_ASSERT(
225-
!self_mappped_rf_pair.has_value(),
226-
"Unsupported domain mapping detected in ",
227-
tv->toString(),
228-
". RFactor domains, ",
229-
self_mappped_rf_pair->first->toString(),
230-
" and ",
231-
self_mappped_rf_pair->second->toString(),
232-
", are mapped with each other.");
239+
if (self_mappped_rf_pair.has_value()) {
240+
return std::make_tuple(
241+
tv,
242+
self_mappped_rf_pair->first,
243+
self_mappped_rf_pair->second,
244+
"RFactor");
245+
}
233246
}
234247

235248
// Leaf domains
236249
auto self_mappped_leaf_pair =
237250
detectMappablePair(tv->domain()->domain(), id_graph);
238-
TORCH_INTERNAL_ASSERT(
239-
!self_mappped_leaf_pair.has_value(),
240-
"Unsupported domain mapping detected in ",
241-
tv->toString(),
242-
". Leaf domains, ",
243-
self_mappped_leaf_pair->first->toString(),
244-
" and ",
245-
self_mappped_leaf_pair->second->toString(),
246-
", are mapped with each other.");
251+
if (self_mappped_leaf_pair.has_value()) {
252+
return std::make_tuple(
253+
tv,
254+
self_mappped_leaf_pair->first,
255+
self_mappped_leaf_pair->second,
256+
"Leaf");
257+
}
247258
}
259+
return c10::nullopt;
248260
}
249261

250262
} // namespace
@@ -591,8 +603,7 @@ void IterDomainGraph::build(Fusion* fusion) {
591603
}
592604
}
593605
}
594-
595-
failIfSelfMappingExists(fusion, *this);
606+
self_mapping_info_ = findFirstSelfMapping(fusion, *this);
596607
}
597608

598609
void IterDomainGraph::initializeId(

torch/csrc/jit/codegen/cuda/compute_at_map.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ namespace cuda {
5454
// Do not forward through any broadcast IDs
5555
class TORCH_CUDA_CU_API IterDomainGraph {
5656
public:
57-
IterDomainGraph(Fusion* fusion);
57+
IterDomainGraph(Fusion* fusion, bool allow_self_mapping = false);
5858

5959
const DisjointSets<IterDomain*>& permissiveNodes() const {
6060
return permissive_nodes_;
@@ -88,6 +88,10 @@ class TORCH_CUDA_CU_API IterDomainGraph {
8888
return view_rfactor_ids_;
8989
}
9090

91+
bool hasSelfMapping() const {
92+
return self_mapping_info_.has_value();
93+
}
94+
9195
private:
9296
void build(Fusion* fusion);
9397

@@ -116,6 +120,9 @@ class TORCH_CUDA_CU_API IterDomainGraph {
116120
VectorOfUniqueEntries<IterDomain*> all_ids_;
117121

118122
std::unordered_set<IterDomain*> view_rfactor_ids_;
123+
124+
// Description of the first self mapping found when the graph was built, or
// nullopt when none was found. The tuple holds, in order: the tensor, the
// two iter domains of that tensor that are mapped with each other, and the
// kind of domain they were found in ("Root", "RFactor" or "Leaf").
c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
125+
self_mapping_info_ = c10::nullopt;
119126
};
120127

121128
class TrivialReductionInfo;

torch/csrc/jit/codegen/cuda/scheduler/registry.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,6 +1501,9 @@ bool checkCanSchedule(
15011501
if (!isConnectedFusionGraph(fusion)) {
15021502
return false;
15031503
}
1504+
if (IterDomainGraph(fusion, /*allow_self_mapping=*/true).hasSelfMapping()) {
1505+
return false;
1506+
}
15041507
if (!SchedulerType::canScheduleCompileTime(fusion)) {
15051508
return false;
15061509
}

torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <torch/csrc/jit/codegen/cuda/executor.h>
66
#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
7+
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
78
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
89
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
910
#include <torch/csrc/jit/codegen/cuda/scheduler/transpose.h>
@@ -795,6 +796,61 @@ TEST_F(NVFuserTest, FusionViewNoTranspose_CUDA) {
795796
TORCH_CHECK(!hasAtLeastTwoValidGroups(&fusion));
796797
}
797798

799+
// Adds a 2-D tensor to its own transpose. Constructing an IterDomainGraph
// with default arguments is expected to throw for this fusion, while running
// it through FusionExecutorCache is expected to match the ATen reference.
TEST_F(NVFuserTest, FusionTransposeSelfMapping_CUDA) {
  auto fusion_holder = std::make_unique<Fusion>();
  FusionGuard fg(fusion_holder.get());

  // out = in + transpose(in)
  auto in_tv = makeContigTensor(2);
  fusion_holder->addInput(in_tv);
  auto transposed_tv = transpose(in_tv, 0, 1);
  auto sum_tv = add(in_tv, transposed_tv);
  fusion_holder->addOutput(sum_tv);

  // Keep the argument a call expression: `IterDomainGraph(name);` would be
  // parsed as a variable declaration rather than a temporary.
  auto build_graph = [&]() { IterDomainGraph(fusion_holder.get()); };
  EXPECT_THAT(
      build_graph,
      testing::ThrowsMessage<c10::Error>(
          testing::HasSubstr("Unsupported domain mapping detected")));

  const auto tensor_opts =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({5, 5}, tensor_opts);

  FusionExecutorCache executor_cache(std::move(fusion_holder));
  auto cg_outputs = executor_cache.runFusionWithInputs({t0});

  auto expected = t0.transpose(0, 1) + t0;

  testValidate(
      executor_cache.fusion(),
      cg_outputs,
      {t0},
      {expected},
      __LINE__,
      __FILE__);
}
826+
827+
#if 0
// Disabled: marked as producing a silent wrong result.
// Adds the transpose of a 2-D tensor to a {2, 3} -> {3, 2} view of the same
// tensor and compares against the ATen reference.
TEST_F(NVFuserTest, FusionTransposeViewSelfMapping_CUDA) {
  auto fusion_holder = std::make_unique<Fusion>();
  FusionGuard fg(fusion_holder.get());

  // out = transpose(in) + view(in)
  auto in_tv = makeContigTensor(2);
  fusion_holder->addInput(in_tv);
  auto transposed_tv = transpose(in_tv, 0, 1);
  auto viewed_tv = view(in_tv, {2, 3}, {3, 2});
  auto sum_tv = add(transposed_tv, viewed_tv);
  fusion_holder->addOutput(sum_tv);

  const auto tensor_opts =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({2, 3}, tensor_opts);

  FusionExecutorCache executor_cache(std::move(fusion_holder));
  auto cg_outputs = executor_cache.runFusionWithInputs({t0});

  auto expected = t0.transpose(0, 1) + t0.view({3, 2});

  testValidate(
      executor_cache.fusion(),
      cg_outputs,
      {t0},
      {expected},
      __LINE__,
      __FILE__);
}
#endif
853+
798854
// t0------------.
799855
// t2->broadcast->sub->mul->relu->t6
800856
// t1------------------'

0 commit comments

Comments
 (0)