6
6
#include < torch/csrc/jit/codegen/cuda/root_domain_map.h>
7
7
#include < torch/csrc/jit/codegen/cuda/transform_iter.h>
8
8
9
+ #include < tuple>
10
+
9
11
namespace torch {
10
12
namespace jit {
11
13
namespace fuser {
@@ -29,8 +31,22 @@ bool idIsALeafDomain(IterDomain* id, TensorView* tv) {
29
31
30
32
} // namespace
31
33
32
- IterDomainGraph::IterDomainGraph (Fusion* fusion) {
34
+ IterDomainGraph::IterDomainGraph (Fusion* fusion, bool allow_self_mapping ) {
33
35
build (fusion);
36
+
37
+ if (!allow_self_mapping) {
38
+ TORCH_INTERNAL_ASSERT (
39
+ !hasSelfMapping (),
40
+ " Unsupported domain mapping detected in " ,
41
+ std::get<0 >(*self_mapping_info_)->toString (),
42
+ " . " ,
43
+ std::get<3 >(*self_mapping_info_),
44
+ " domains, " ,
45
+ std::get<1 >(*self_mapping_info_)->toString (),
46
+ " and " ,
47
+ std::get<2 >(*self_mapping_info_)->toString (),
48
+ " , are mapped with each other." );
49
+ }
34
50
}
35
51
36
52
// ! Map corresponding inputs and outputs of swizzle op together
@@ -197,7 +213,8 @@ c10::optional<std::pair<IterDomain*, IterDomain*>> detectMappablePair(
197
213
// those domains should never be mapped with each other. It may be
198
214
// possible to lift this assumption, but it's unclear if it could
199
215
// matter in practice.
200
- void failIfSelfMappingExists (Fusion* fusion, const IterDomainGraph& id_graph) {
216
+ c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
217
+ findFirstSelfMapping (Fusion* fusion, const IterDomainGraph& id_graph) {
201
218
for (auto tv : ir_utils::allTvs (fusion)) {
202
219
// For each tensor, make sure root, rfactor and leaf domains
203
220
// should not include domains that are mapped with another domain
@@ -207,44 +224,39 @@ void failIfSelfMappingExists(Fusion* fusion, const IterDomainGraph& id_graph) {
207
224
// Root domains
208
225
auto self_mappped_root_pair =
209
226
detectMappablePair (tv->getRootDomain (), id_graph);
210
- TORCH_INTERNAL_ASSERT (
211
- !self_mappped_root_pair.has_value (),
212
- " Unsupported domain mapping detected in " ,
213
- tv->toString (),
214
- " . Root domains, " ,
215
- self_mappped_root_pair->first ->toString (),
216
- " and " ,
217
- self_mappped_root_pair->second ->toString (),
218
- " , are mapped with each other." );
227
+ if (self_mappped_root_pair.has_value ()) {
228
+ return std::make_tuple (
229
+ tv,
230
+ self_mappped_root_pair->first ,
231
+ self_mappped_root_pair->second ,
232
+ " Root" );
233
+ }
219
234
220
235
// Rfactor domains
221
236
if (tv->hasRFactor ()) {
222
237
auto self_mappped_rf_pair =
223
238
detectMappablePair (tv->getRFactorDomain (), id_graph);
224
- TORCH_INTERNAL_ASSERT (
225
- !self_mappped_rf_pair.has_value (),
226
- " Unsupported domain mapping detected in " ,
227
- tv->toString (),
228
- " . RFactor domains, " ,
229
- self_mappped_rf_pair->first ->toString (),
230
- " and " ,
231
- self_mappped_rf_pair->second ->toString (),
232
- " , are mapped with each other." );
239
+ if (self_mappped_rf_pair.has_value ()) {
240
+ return std::make_tuple (
241
+ tv,
242
+ self_mappped_rf_pair->first ,
243
+ self_mappped_rf_pair->second ,
244
+ " RFactor" );
245
+ }
233
246
}
234
247
235
248
// Leaf domains
236
249
auto self_mappped_leaf_pair =
237
250
detectMappablePair (tv->domain ()->domain (), id_graph);
238
- TORCH_INTERNAL_ASSERT (
239
- !self_mappped_leaf_pair.has_value (),
240
- " Unsupported domain mapping detected in " ,
241
- tv->toString (),
242
- " . Leaf domains, " ,
243
- self_mappped_leaf_pair->first ->toString (),
244
- " and " ,
245
- self_mappped_leaf_pair->second ->toString (),
246
- " , are mapped with each other." );
251
+ if (self_mappped_leaf_pair.has_value ()) {
252
+ return std::make_tuple (
253
+ tv,
254
+ self_mappped_leaf_pair->first ,
255
+ self_mappped_leaf_pair->second ,
256
+ " Leaf" );
257
+ }
247
258
}
259
+ return c10::nullopt;
248
260
}
249
261
250
262
} // namespace
@@ -591,8 +603,7 @@ void IterDomainGraph::build(Fusion* fusion) {
591
603
}
592
604
}
593
605
}
594
-
595
- failIfSelfMappingExists (fusion, *this );
606
+ self_mapping_info_ = findFirstSelfMapping (fusion, *this );
596
607
}
597
608
598
609
void IterDomainGraph::initializeId (
0 commit comments