Skip to content

Commit 45e95fd

Browse files
authored
Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
1 parent a3ecb33 commit 45e95fd

File tree

3 files changed

+101
-27
lines changed

3 files changed

+101
-27
lines changed

torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ class DomainMap {
2020
}
2121
virtual ~DomainMap() = default;
2222

23-
bool areExactMapped(IterDomain* id1, IterDomain* id2) const {
24-
return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
25-
}
26-
2723
const ComputeAtMap& getComputeAtMap() const {
2824
return ca_map_;
2925
}

torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,49 @@ class DomainMap : public pointwise_utils::DomainMap {
5858
domain_map.findReferenceFor(grouped_inputs_outputs[1]) != nullptr;
5959
}
6060

61-
int getPosMappedTo(TensorView* tv, IterDomain* id) const {
61+
int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const {
62+
// Find the root id mapped to `root_dim`
63+
const auto& root_dom = tv->getRootDomain();
64+
IterDomain* mapped_id = nullptr;
65+
for (auto i : c10::irange(root_dom.size())) {
66+
if (ca_map_.idGraph().permissiveNodes().permissiveAreMapped(
67+
root_dom[i], root_dim)) {
68+
mapped_id = root_dom[i];
69+
break;
70+
}
71+
}
72+
TORCH_INTERNAL_ASSERT(
73+
mapped_id != nullptr,
74+
"Can not find ID mapped to ",
75+
root_dim,
76+
" in tensor ",
77+
tv);
78+
// Project the root id to leaf id
79+
while (!mapped_id->uses().empty()) {
80+
TORCH_INTERNAL_ASSERT(mapped_id->uses().size() == 1);
81+
auto expr = mapped_id->uses()[0];
82+
if (expr->isA<Split>()) {
83+
mapped_id = expr->as<Split>()->inner();
84+
} else {
85+
auto merge = expr->as<Merge>();
86+
TORCH_INTERNAL_ASSERT(
87+
mapped_id == merge->inner(),
88+
"Can not find ID mapped to ",
89+
root_dim,
90+
" in tensor ",
91+
tv);
92+
mapped_id = merge->out();
93+
}
94+
}
95+
// Find the position of the leaf id
6296
const auto& dom = tv->domain()->domain();
6397
for (auto i : c10::irange(dom.size())) {
64-
if (areExactMapped(id, tv->axis(i))) {
98+
if (dom[i] == mapped_id) {
6599
return i;
66100
}
67101
}
68102
TORCH_INTERNAL_ASSERT(
69-
false, "Can not find ID mapped to ", id, " in tensor ", tv);
103+
false, "Can not find ID mapped to ", root_dim, " in tensor ", tv);
70104
}
71105

72106
// Group inputs and outputs of a fusion by its inner most domain. For example
@@ -240,22 +274,37 @@ void maybeBuildVirtualInnerDims(
240274
// both virtual innermost dim.
241275
// 2. The satisfied one did not merge in anything. For example,
242276
// T0[I0{1024*1024}, I1{2}]
277+
// If this is the case, this means that we need to split the large
278+
// inner-most dimension to satisfy the small innermost dimension
243279
int64_t large_dim;
244280
int64_t split_factor;
281+
bool split_inner_most;
245282
if (merged_size1 < params.tile_size1) {
246283
if (params.dims_merged_with_2.empty()) {
247284
// case 2
248-
return;
285+
split_inner_most = true;
286+
large_dim = inner_most2;
287+
split_factor = params.tile_size2;
288+
} else {
289+
// case 1
290+
split_inner_most = false;
291+
large_dim = params.dims_merged_with_2.back();
292+
auto prev_merged_size2 = merged_size2 / shape_in_ref1[large_dim];
293+
split_factor = ceilDiv(params.tile_size2, prev_merged_size2);
249294
}
250-
large_dim = params.dims_merged_with_2.back();
251-
split_factor = ceilDiv(params.tile_size1, merged_size1);
252295
} else {
253296
if (params.dims_merged_with_1.empty()) {
254297
// case 2
255-
return;
298+
split_inner_most = true;
299+
large_dim = inner_most1;
300+
split_factor = params.tile_size1;
301+
} else {
302+
// case 1
303+
split_inner_most = false;
304+
large_dim = params.dims_merged_with_1.back();
305+
auto prev_merged_size1 = merged_size1 / shape_in_ref1[large_dim];
306+
split_factor = ceilDiv(params.tile_size1, prev_merged_size1);
256307
}
257-
large_dim = params.dims_merged_with_1.back();
258-
split_factor = ceilDiv(params.tile_size2, merged_size2);
259308
}
260309
params.split_before_tiling.push_back({large_dim, split_factor});
261310
// adjust all dims to after-split
@@ -271,12 +320,16 @@ void maybeBuildVirtualInnerDims(
271320
}
272321
// Give the split-out dim to the unsatisfied one, so that both are satisfied.
273322
if (merged_size1 < params.tile_size1) {
274-
params.dims_merged_with_2.pop_back();
275-
params.dims_merged_with_2.push_back(large_dim + 1);
323+
if (!split_inner_most) {
324+
params.dims_merged_with_2.pop_back();
325+
params.dims_merged_with_2.push_back(large_dim + 1);
326+
}
276327
params.dims_merged_with_1.push_back(large_dim);
277328
} else {
278-
params.dims_merged_with_1.pop_back();
279-
params.dims_merged_with_1.push_back(large_dim + 1);
329+
if (!split_inner_most) {
330+
params.dims_merged_with_1.pop_back();
331+
params.dims_merged_with_1.push_back(large_dim + 1);
332+
}
280333
params.dims_merged_with_2.push_back(large_dim);
281334
}
282335
}
@@ -369,12 +422,6 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
369422
if (n_elems < device_multiprocessor_count * kMaxTileSize * kMaxTileSize) {
370423
params->tile_size1 = 8;
371424
params->tile_size2 = 8;
372-
// TODO: I was trying the following but I got silent wrong result
373-
// params->tile_size1 = 8;
374-
// params->tile_size2 = 4;
375-
// This should not happen, because the correctness should be irrevalent to
376-
// schedulers. We don't have to use tile size (8, 4), but we need to fix our
377-
// bug in codegen.
378425
}
379426

380427
// Expand inner-most dims to virtual inner-most dims so that the inner-most
@@ -383,9 +430,9 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
383430
auto inner_most_id2 = scheduler_utils::innerMostRootDim(reference2);
384431

385432
auto inner_most_pos1_in_ref1 =
386-
domain_map.getPosMappedTo(reference1, inner_most_id1);
433+
domain_map.getInnerLeafDim(reference1, inner_most_id1);
387434
auto inner_most_pos2_in_ref1 =
388-
domain_map.getPosMappedTo(reference1, inner_most_id2);
435+
domain_map.getInnerLeafDim(reference1, inner_most_id2);
389436

390437
// See note [Supporting small transpose dimensions]
391438
maybeBuildVirtualInnerDims(
@@ -643,9 +690,9 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
643690

644691
// merge with inner most dims to get virtual inner most dims
645692
size_t inner_most_pos1_in_ref1 =
646-
domain_map.getPosMappedTo(reference1, inner_most_id1);
693+
domain_map.getInnerLeafDim(reference1, inner_most_id1);
647694
size_t inner_most_pos2_in_ref1 =
648-
domain_map.getPosMappedTo(reference1, inner_most_id2);
695+
domain_map.getInnerLeafDim(reference1, inner_most_id2);
649696
if (merged1.has_value()) {
650697
if (inner_most_pos1_in_ref1 < *merged1) {
651698
reference1->reorder(

torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,37 @@ TEST_F(NVFuserTest, FusionScheduleTransposeSmallInnerSize3_CUDA) {
932932
testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
933933
}
934934

935+
// x->sin->transpose->cos->y
936+
TEST_F(NVFuserTest, FusionScheduleTranspose2DSmallInnerSize_CUDA) {
937+
std::array<std::vector<int64_t>, 2> shapes{
938+
std::vector<int64_t>{1024 * 1024 * 128, 2},
939+
std::vector<int64_t>{2, 1024 * 1024 * 128}};
940+
for (const auto& shape : shapes) {
941+
Fusion fusion;
942+
FusionGuard fg(&fusion);
943+
944+
auto tv0 = makeContigTensor(2);
945+
fusion.addInput(tv0);
946+
auto tv1 = sin(tv0);
947+
auto tv2 = transpose(tv1, 0, 1);
948+
auto tv3 = cos(tv2);
949+
fusion.addOutput(tv3);
950+
951+
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
952+
at::Tensor input = at::randn(shape, options);
953+
954+
auto lparams = scheduleTranspose(&fusion, {input});
955+
956+
FusionExecutor fe;
957+
fe.compileFusion(&fusion, {input}, lparams);
958+
auto outputs = fe.runFusion({input}, lparams);
959+
960+
auto tv_ref = input.sin().transpose(0, 1).cos();
961+
962+
testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
963+
}
964+
}
965+
935966
} // namespace jit
936967
} // namespace torch
937968
#endif // #if defined(USE_CUDA)

0 commit comments

Comments (0)