@@ -58,15 +58,49 @@ class DomainMap : public pointwise_utils::DomainMap {
         domain_map.findReferenceFor(grouped_inputs_outputs[1]) != nullptr;
   }
 
-  int getPosMappedTo(TensorView* tv, IterDomain* id) const {
+  int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const {
+    // Find the root id mapped to `root_dim`
+    const auto& root_dom = tv->getRootDomain();
+    IterDomain* mapped_id = nullptr;
+    for (auto i : c10::irange(root_dom.size())) {
+      if (ca_map_.idGraph().permissiveNodes().permissiveAreMapped(
+              root_dom[i], root_dim)) {
+        mapped_id = root_dom[i];
+        break;
+      }
+    }
+    TORCH_INTERNAL_ASSERT(
+        mapped_id != nullptr,
+        "Can not find ID mapped to ",
+        root_dim,
+        " in tensor ",
+        tv);
+    // Project the root id to leaf id
+    while (!mapped_id->uses().empty()) {
+      TORCH_INTERNAL_ASSERT(mapped_id->uses().size() == 1);
+      auto expr = mapped_id->uses()[0];
+      if (expr->isA<Split>()) {
+        mapped_id = expr->as<Split>()->inner();
+      } else {
+        auto merge = expr->as<Merge>();
+        TORCH_INTERNAL_ASSERT(
+            mapped_id == merge->inner(),
+            "Can not find ID mapped to ",
+            root_dim,
+            " in tensor ",
+            tv);
+        mapped_id = merge->out();
+      }
+    }
+    // Find the position of the leaf id
     const auto& dom = tv->domain()->domain();
     for (auto i : c10::irange(dom.size())) {
-      if (areExactMapped(id, tv->axis(i))) {
+      if (dom[i] == mapped_id) {
         return i;
       }
     }
     TORCH_INTERNAL_ASSERT(
-        false, "Can not find ID mapped to ", id, " in tensor ", tv);
+        false, "Can not find ID mapped to ", root_dim, " in tensor ", tv);
   }
 
   // Group inputs and outputs of a fusion by its inner most domain. For example
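To illustrate the root-to-leaf walk in getInnerLeafDim, the following is a minimal standalone sketch with stand-in types (Node and Xform are hypothetical simplifications of nvfuser's IterDomain, Split, and Merge, which carry far more state):

// Minimal model of the projection in getInnerLeafDim. `Node` stands in for
// IterDomain and `Xform` for Split/Merge; both are simplified assumptions.
#include <cassert>
#include <vector>

struct Node;

struct Xform {
  enum Kind { Split, Merge } kind;
  Node* inner = nullptr; // Split: inner output. Merge: inner input.
  Node* out = nullptr;   // Merge: merged output. Unused for Split here.
};

struct Node {
  std::vector<Xform*> uses; // downstream transformations consuming this id
};

// Follow the chain of single uses: a Split keeps us on its inner output, and
// a Merge is only followed when the current id is its inner input. The walk
// ends at the leaf id whose position in tv->domain() the caller wants.
Node* projectToInnerLeaf(Node* mapped_id) {
  while (!mapped_id->uses.empty()) {
    assert(mapped_id->uses.size() == 1);
    Xform* expr = mapped_id->uses[0];
    if (expr->kind == Xform::Split) {
      mapped_id = expr->inner;
    } else {
      assert(mapped_id == expr->inner); // must be the Merge's inner input
      mapped_id = expr->out;
    }
  }
  return mapped_id;
}

For a root domain [I0, I1] where, say, I1 is split and the split's inner output later becomes the inner input of a Merge, the walk from I1 ends at that Merge's output, and its position in the leaf domain is what getInnerLeafDim returns.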
@@ -240,22 +274,37 @@ void maybeBuildVirtualInnerDims(
   // both virtual innermost dim.
   // 2. The satisfied one did not merge in anything. For example,
   //    T0[I0{1024*1024}, I1{2}]
+  //    In this case, we need to split the large inner-most dimension to
+  //    satisfy the small inner-most dimension.
   int64_t large_dim;
   int64_t split_factor;
+  bool split_inner_most;
   if (merged_size1 < params.tile_size1) {
     if (params.dims_merged_with_2.empty()) {
       // case 2
-      return;
+      split_inner_most = true;
+      large_dim = inner_most2;
+      split_factor = params.tile_size2;
+    } else {
+      // case 1
+      split_inner_most = false;
+      large_dim = params.dims_merged_with_2.back();
+      auto prev_merged_size2 = merged_size2 / shape_in_ref1[large_dim];
+      split_factor = ceilDiv(params.tile_size2, prev_merged_size2);
     }
-    large_dim = params.dims_merged_with_2.back();
-    split_factor = ceilDiv(params.tile_size1, merged_size1);
   } else {
     if (params.dims_merged_with_1.empty()) {
       // case 2
-      return;
+      split_inner_most = true;
+      large_dim = inner_most1;
+      split_factor = params.tile_size1;
+    } else {
+      // case 1
+      split_inner_most = false;
+      large_dim = params.dims_merged_with_1.back();
+      auto prev_merged_size1 = merged_size1 / shape_in_ref1[large_dim];
+      split_factor = ceilDiv(params.tile_size1, prev_merged_size1);
     }
-    large_dim = params.dims_merged_with_1.back();
-    split_factor = ceilDiv(params.tile_size2, merged_size2);
   }
   params.split_before_tiling.push_back({large_dim, split_factor});
   // adjust all dims to after-split
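To make the case-2 branch concrete, here is a small arithmetic sketch based on the T0[I0{1024*1024}, I1{2}] example from the comment above. The 32x32 tile sizes and the assignment of I1 to group 1 and I0 to group 2 are assumptions for illustration:

// Case 2 worked through with plain integers (tile sizes and shapes assumed).
#include <cassert>
#include <cstdint>

int main() {
  // T0[I0{1024*1024}, I1{2}]: group 2's inner-most I0 is already large enough,
  // group 1's inner-most I1 is not, and group 2 merged nothing extra.
  const int64_t tile_size1 = 32, tile_size2 = 32;
  const int64_t i0 = int64_t{1024} * 1024;
  const int64_t merged_size1 = 2; // I1 alone
  assert(merged_size1 < tile_size1);

  // split_inner_most = true: split I0 itself by tile_size2. The inner part
  // {32} keeps group 2's tile satisfied, and the outer part {32768} is later
  // merged into group 1's virtual inner-most dim (see the next hunk).
  const int64_t split_factor = tile_size2;                      // 32
  const int64_t i0_outer = i0 / split_factor;                   // 32768
  const int64_t virtual_inner_most1 = merged_size1 * i0_outer;  // 65536
  assert(split_factor >= tile_size2);
  assert(virtual_inner_most1 >= tile_size1);
  return 0;
}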
@@ -271,12 +320,16 @@ void maybeBuildVirtualInnerDims(
   }
   // Give the split-out dim to the unsatisfied one, so that both are satisfied.
   if (merged_size1 < params.tile_size1) {
-    params.dims_merged_with_2.pop_back();
-    params.dims_merged_with_2.push_back(large_dim + 1);
+    if (!split_inner_most) {
+      params.dims_merged_with_2.pop_back();
+      params.dims_merged_with_2.push_back(large_dim + 1);
+    }
     params.dims_merged_with_1.push_back(large_dim);
   } else {
-    params.dims_merged_with_1.pop_back();
-    params.dims_merged_with_1.push_back(large_dim + 1);
+    if (!split_inner_most) {
+      params.dims_merged_with_1.pop_back();
+      params.dims_merged_with_1.push_back(large_dim + 1);
+    }
     params.dims_merged_with_2.push_back(large_dim);
   }
 }
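Case 1 (the satisfied group did merge something) plus the hand-off above can be traced the same way. The sketch below assumes reference1 has shape [I0{4}, I1{1024}, I2{2}], 32x32 tiles, group 1's inner-most is I2, group 2's is I0, and I1 (position 1) was already merged into group 2; it also assumes an inner split by split_factor, producing an outer part at large_dim and an inner part of extent split_factor at large_dim + 1. All of these concrete values are illustrative, not taken from the patch:

// Case 1 plus the hand-off, worked through with plain integers (all shapes,
// tile sizes, and group assignments here are assumed for illustration).
#include <cassert>
#include <cstdint>
#include <vector>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const std::vector<int64_t> shape_in_ref1{4, 1024, 2};
  const int64_t tile_size1 = 32, tile_size2 = 32;
  const int64_t merged_size1 = 2;        // group 1: I2 alone, too small
  const int64_t merged_size2 = 4 * 1024; // group 2: I0 merged with I1
  std::vector<int64_t> dims_merged_with_1;
  std::vector<int64_t> dims_merged_with_2{1}; // I1 was merged into group 2

  // Case 1: split the last dim merged into group 2 (I1), keeping just enough
  // of it for group 2's tile and donating the rest to group 1.
  const int64_t large_dim = dims_merged_with_2.back();                       // 1
  const int64_t prev_merged_size2 = merged_size2 / shape_in_ref1[large_dim]; // 4
  const int64_t split_factor = ceilDiv(tile_size2, prev_merged_size2);       // 8

  // After the split, the domain reads [I0{4}, I1o{128}, I1i{8}, I2{2}].
  // Group 2 keeps the inner part (now at position large_dim + 1)...
  dims_merged_with_2.pop_back();
  dims_merged_with_2.push_back(large_dim + 1);
  const int64_t virtual_inner_most2 = prev_merged_size2 * split_factor; // 32
  // ...and group 1 receives the outer part at position large_dim.
  dims_merged_with_1.push_back(large_dim);
  const int64_t virtual_inner_most1 =
      merged_size1 * (shape_in_ref1[large_dim] / split_factor); // 2 * 128
  assert(virtual_inner_most1 >= tile_size1);
  assert(virtual_inner_most2 >= tile_size2);
  return 0;
}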
@@ -369,12 +422,6 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
   if (n_elems < device_multiprocessor_count * kMaxTileSize * kMaxTileSize) {
     params->tile_size1 = 8;
     params->tile_size2 = 8;
-    // TODO: I was trying the following but I got silent wrong result
-    // params->tile_size1 = 8;
-    // params->tile_size2 = 4;
-    // This should not happen, because the correctness should be irrevalent to
-    // schedulers. We don't have to use tile size (8, 4), but we need to fix our
-    // bug in codegen.
   }
 
   // Expand inner-most dims to virtual inner-most dims so that the inner-most
@@ -383,9 +430,9 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
   auto inner_most_id2 = scheduler_utils::innerMostRootDim(reference2);
 
   auto inner_most_pos1_in_ref1 =
-      domain_map.getPosMappedTo(reference1, inner_most_id1);
+      domain_map.getInnerLeafDim(reference1, inner_most_id1);
   auto inner_most_pos2_in_ref1 =
-      domain_map.getPosMappedTo(reference1, inner_most_id2);
+      domain_map.getInnerLeafDim(reference1, inner_most_id2);
 
   // See note [Supporting small transpose dimensions]
   maybeBuildVirtualInnerDims(
@@ -643,9 +690,9 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
 
   // merge with inner most dims to get virtual inner most dims
   size_t inner_most_pos1_in_ref1 =
-      domain_map.getPosMappedTo(reference1, inner_most_id1);
+      domain_map.getInnerLeafDim(reference1, inner_most_id1);
   size_t inner_most_pos2_in_ref1 =
-      domain_map.getPosMappedTo(reference1, inner_most_id2);
+      domain_map.getInnerLeafDim(reference1, inner_most_id2);
   if (merged1.has_value()) {
     if (inner_most_pos1_in_ref1 < *merged1) {
       reference1->reorder(