Skip to content

Commit fbd97e5

Browse files
authored
Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
This reverts commit bca20c1.
1 parent bca20c1 commit fbd97e5

File tree

7 files changed

+154
-139
lines changed

7 files changed

+154
-139
lines changed

torch/csrc/jit/codegen/cuda/inlining.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,25 +153,29 @@ size_t MaxPosCalculator::getMaxPosAll(
153153
return max_pos;
154154
}
155155

156-
void inlineMost() {
157-
inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()));
156+
void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
157+
inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids);
158158
}
159159

160-
void inlineMost(const std::vector<TensorView*>& tvs) {
160+
void inlineMost(
161+
const std::vector<TensorView*>& tvs,
162+
const std::unordered_set<IterDomain*>& uninlinable_ids) {
161163
if (tvs.empty()) {
162164
return;
163165
}
164-
MaxPosCalculator calc;
166+
MaxPosCalculator calc(uninlinable_ids);
165167
for (auto tv : tvs) {
166168
tv->inlineAt(-1, true, &calc);
167169
}
168170
}
169171

170-
void inlineMost(const std::unordered_set<TensorView*>& tvs) {
172+
void inlineMost(
173+
const std::unordered_set<TensorView*>& tvs,
174+
const std::unordered_set<IterDomain*>& uninlinable_ids) {
171175
if (tvs.empty()) {
172176
return;
173177
}
174-
MaxPosCalculator calc;
178+
MaxPosCalculator calc(uninlinable_ids);
175179
for (auto tv : tvs) {
176180
tv->inlineAt(-1, true, &calc);
177181
}
@@ -272,9 +276,10 @@ std::unordered_map<TensorView*, size_t> getPositionsMappedTo(
272276
void inlineAllAt(
273277
TensorView* reference_tv,
274278
int64_t reference_pos,
275-
bool best_effort) {
279+
bool best_effort,
280+
const std::unordered_set<IterDomain*>& uninlinable_ids) {
276281
auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
277-
MaxPosCalculator calc;
282+
MaxPosCalculator calc(uninlinable_ids);
278283
for (auto pair : mapped_positions) {
279284
pair.first->inlineAt(pair.second, best_effort, &calc);
280285
}
@@ -284,9 +289,10 @@ void inlineSelectedAt(
284289
const std::unordered_set<TensorView*>& selected,
285290
TensorView* reference_tv,
286291
int64_t reference_pos,
287-
bool best_effort) {
292+
bool best_effort,
293+
const std::unordered_set<IterDomain*>& uninlinable_ids) {
288294
auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
289-
MaxPosCalculator calc;
295+
MaxPosCalculator calc(uninlinable_ids);
290296
for (auto pair : mapped_positions) {
291297
if (selected.count(pair.first) > 0) {
292298
pair.first->inlineAt(pair.second, best_effort, &calc);

torch/csrc/jit/codegen/cuda/inlining.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,28 +64,35 @@ class MaxPosCalculator {
6464

6565
// Inline to the right most allowed position for all tensors in the current
6666
// fusion.
67-
TORCH_CUDA_CU_API void inlineMost();
67+
TORCH_CUDA_CU_API void inlineMost(
68+
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
6869
// Inline to the right most allowed position for the selected tensors in the
6970
// current fusion.
70-
TORCH_CUDA_CU_API void inlineMost(const std::vector<TensorView*>& tvs);
71+
TORCH_CUDA_CU_API void inlineMost(
72+
const std::vector<TensorView*>& tvs,
73+
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
7174
// Inline to the right most allowed position for the selected tensors in the
7275
// current fusion.
73-
TORCH_CUDA_CU_API void inlineMost(const std::unordered_set<TensorView*>& tvs);
76+
TORCH_CUDA_CU_API void inlineMost(
77+
const std::unordered_set<TensorView*>& tvs,
78+
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
7479

7580
// Inline to the position corresponding to the reference position in the
7681
// reference tensor for all tensors in the current fusion.
7782
TORCH_CUDA_CU_API void inlineAllAt(
7883
TensorView* reference_tv,
7984
int64_t reference_pos,
80-
bool best_effort = false);
85+
bool best_effort = false,
86+
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
8187

8288
// Inline to the position corresponding to the reference position in the
8389
// reference tensor for selected tensors in the current fusion.
8490
TORCH_CUDA_CU_API void inlineSelectedAt(
8591
const std::unordered_set<TensorView*>& selected,
8692
TensorView* reference_tv,
8793
int64_t reference_pos,
88-
bool best_effort = false);
94+
bool best_effort = false,
95+
const std::unordered_set<IterDomain*>& uninlinable_ids = {});
8996

9097
} // namespace cuda
9198
} // namespace fuser

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,43 +1597,6 @@ std::vector<IterDomain*> IterDomain::clone(
15971597
return cloned_domains;
15981598
}
15991599

1600-
IterType inferIterType(IterDomain* i1, IterDomain* i2) {
1601-
// The itertype inference is a pattern matching of the rules below:
1602-
//
1603-
// X + X = X
1604-
// trivial reduction + X = X
1605-
// X + trivial reduction = X
1606-
// broadcasting + X = X
1607-
// X + broadcasting = X
1608-
// fail
1609-
//
1610-
// The rules are applied one by one in order. For each rule, we test if the
1611-
// given (outer, inner) matches the pattern. If it does, then we stop
1612-
// proceeding and get a result. If we have reached the end without finding
1613-
// any matched pattern, then it is a mistake and should be reported.
1614-
//
1615-
// Note that based on the above rule:
1616-
// broadcasting + (non-trivial) reduction = reduction
1617-
// broadcasting + trivial reduction = broadcasting
1618-
if (i1->getIterType() == i2->getIterType()) {
1619-
return i1->getIterType();
1620-
}
1621-
if (i1->isTrivialReduction()) {
1622-
return i2->getIterType();
1623-
}
1624-
if (i2->isTrivialReduction()) {
1625-
return i1->getIterType();
1626-
}
1627-
if (i1->isBroadcast()) {
1628-
return i2->getIterType();
1629-
}
1630-
if (i2->isBroadcast()) {
1631-
return i1->getIterType();
1632-
}
1633-
TORCH_CHECK(
1634-
false, "Merging IterDomains requires that their iteration types match.");
1635-
}
1636-
16371600
// Merging does not propagate the start and stop values of the input
16381601
// domains to the merged output domain. The actual range of the
16391602
// domains is enforced by predicates. Note that since only root
@@ -1643,10 +1606,48 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
16431606
TORCH_CHECK(
16441607
!outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
16451608
"Merging IterDomains with ending values that are 0 is not supported at this time.");
1609+
TORCH_CHECK(
1610+
outer->isReduction() == inner->isReduction() ||
1611+
(!outer->isReduction() && inner->isTrivialReduction()) ||
1612+
(outer->isTrivialReduction() && !inner->isReduction()),
1613+
"Merging IterDomains requires that their iteration types match.");
1614+
TORCH_CHECK(
1615+
(outer->isGather() && inner->isGather()) ||
1616+
(!outer->isGather() && !inner->isGather()),
1617+
"Merging gather and non-gather domains is not supported.");
1618+
1619+
TORCH_CHECK(
1620+
!outer->isStride() && !inner->isStride(),
1621+
"No support for merging stride domains");
16461622

16471623
Val* merged_id_size = mul(outer->extent(), inner->extent());
16481624

1649-
IterType itype = inferIterType(outer, inner);
1625+
IterType itype = outer->getIterType();
1626+
1627+
if (outer->isBroadcast() && inner->isBroadcast()) {
1628+
itype = IterType::Broadcast;
1629+
}
1630+
1631+
if ((outer->isBroadcast() || inner->isBroadcast()) &&
1632+
(outer->getIterType() == IterType::Iteration ||
1633+
inner->getIterType() == IterType::Iteration)) {
1634+
itype = IterType::Iteration;
1635+
}
1636+
1637+
// Merging trivial reduction with iter domain, that's fine, just make it an
1638+
// iter domain.
1639+
if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
1640+
(outer->getIterType() == IterType::Iteration ||
1641+
inner->getIterType() == IterType::Iteration)) {
1642+
itype = IterType::Iteration;
1643+
}
1644+
1645+
// Merging trivial reduction with broadcasting, that's fine, just make it a
1646+
// broadcasting.
1647+
if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
1648+
(outer->isBroadcast() || inner->isBroadcast())) {
1649+
itype = IterType::Broadcast;
1650+
}
16501651

16511652
Val* expanded_extent = nullptr;
16521653
if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {

torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,13 @@ void multiReductionInliner(
330330
}
331331
}
332332

333+
// Find iter domains that are mapped to a trivial reduction, these should
334+
// never be inlined.
335+
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
336+
scheduler_utils::getTrivialReductionMap(fusion);
337+
333338
// Inline the schedule
334-
inlineMost();
339+
inlineMost(mapped_to_trivial_reduction);
335340
}
336341

337342
namespace {

torch/csrc/jit/codegen/cuda/scheduler/utils.cpp

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,26 @@ namespace scheduler_utils {
2121

2222
// Returns number of "valid" dimensions. e.g. if tv has
2323
// [I1, R2, I3, I4, R3{1}]
24-
// resulting domain should be:
25-
// [I1, I3*I4, R2*R3{1}] with return value 3
24+
// where R3{1} is in dont_merge, resulting domain should be:
25+
// [I1, I3*I4, R2, R3{1}] with return value 3
2626
//
2727
// if tv has
2828
// [R1, I2, R3, I4, R4, R5{1}, R6{1}]
29-
// resulting domain should be:
30-
// [I2*I4, R1*R3, R4*R5{1}*R6{1}]
29+
// where R5{1} and R6{1} are in dont_merge, resulting domain should be:
30+
// [I2*I4, R1*R3, R4, R5{1}, R6{1}]
3131
// with return value 3
32-
size_t merge_3d(TensorView* tv) {
32+
size_t merge_3d(
33+
TensorView* tv,
34+
const std::unordered_set<IterDomain*>& dont_merge) {
3335
bool active_is_reduction = false;
3436
bool first_dim = true;
3537
int prev_i = -1;
3638

3739
for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
40+
if (dont_merge.count(tv->axis(i))) {
41+
continue;
42+
}
43+
3844
if (first_dim) {
3945
active_is_reduction = tv->axis(i)->isReduction();
4046
prev_i = i;
@@ -61,6 +67,10 @@ size_t merge_3d(TensorView* tv) {
6167

6268
for (int i = static_cast<int>(tv->nDims()) - 2; i >= 0; i--) {
6369
auto id = tv->axis(i);
70+
if (dont_merge.count(id)) {
71+
continue;
72+
}
73+
6474
if (first_dim) {
6575
active_is_reduction = id->isReduction();
6676
prev_i = i;
@@ -86,6 +96,10 @@ size_t merge_3d(TensorView* tv) {
8696
prev_i = -1;
8797

8898
for (int i = static_cast<int>(tv->nDims()) - 3; i >= 0; i--) {
99+
if (dont_merge.count(tv->axis(i))) {
100+
continue;
101+
}
102+
89103
if (first_dim) {
90104
active_is_reduction = tv->axis(i)->isReduction();
91105
prev_i = i;
@@ -100,7 +114,7 @@ size_t merge_3d(TensorView* tv) {
100114
if (prev_i == -1) {
101115
// Two dimensional, put merged dimensions first
102116
tv->reorder({{-1, 0}, {-2, 1}});
103-
// [outer, inner]
117+
// [outer, inner, dont_merge...]
104118
if (tv->axis(0)->isReduction()) {
105119
// put reductions as second axis
106120
tv->reorder({{0, 1}, {1, 0}});
@@ -181,11 +195,13 @@ c10::optional<size_t> mergeDims(
181195
return left;
182196
}
183197

184-
size_t mergeReduction(TensorView* tv) {
198+
size_t mergeReduction(
199+
TensorView* tv,
200+
const std::unordered_set<IterDomain*>& dont_merge) {
185201
int prev_i = -1;
186202
size_t num_merged = 0;
187203
for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
188-
if (!tv->axis(i)->isReduction()) {
204+
if (!tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
189205
continue;
190206
}
191207
if (prev_i == -1) {
@@ -203,14 +219,16 @@ size_t mergeReduction(TensorView* tv) {
203219
return prev_i == -1 ? 0 : num_merged + 1;
204220
}
205221

206-
size_t mergeNonReduction(TensorView* tv) {
222+
size_t mergeNonReduction(
223+
TensorView* tv,
224+
const std::unordered_set<IterDomain*>& dont_merge) {
207225
int prev_i = -1;
208226
size_t num_merged = 0;
209227
if (tv->nDims() == 0) {
210228
return 0;
211229
}
212230
for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
213-
if (tv->axis(i)->isReduction()) {
231+
if (tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
214232
continue;
215233
}
216234
if (prev_i == -1) {
@@ -887,21 +905,63 @@ PersistentBufferSizeReturn persistentBufferSize(
887905
return persistent_buffer_size;
888906
}
889907

908+
std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion) {
909+
auto all_tvs = ir_utils::allTvs(fusion);
910+
std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
911+
for (auto tv : all_tvs) {
912+
// root domain vs domain shouldn't matter as at this point we shouldn't have
913+
// any transformations.
914+
for (auto id : tv->getRootDomain()) {
915+
if (id->isTrivialReduction()) {
916+
mapped_to_trivial_reduction.emplace(id);
917+
}
918+
}
919+
}
920+
921+
if (!mapped_to_trivial_reduction.empty()) {
922+
// Use the loop map as that is the most permissive
923+
auto ca_map = ComputeAtMap(fusion);
924+
// Make a copy we need to check mappings of all
925+
auto trivial_ids = mapped_to_trivial_reduction;
926+
for (auto tv : all_tvs) {
927+
for (auto id : tv->getRootDomain()) {
928+
if (!id->extent()->isOneInt()) {
929+
continue;
930+
}
931+
if (std::any_of(
932+
trivial_ids.begin(),
933+
trivial_ids.end(),
934+
[&ca_map, &id](IterDomain* trivial_id) {
935+
return ca_map.areMapped(
936+
id, trivial_id, IdMappingMode::PERMISSIVE);
937+
})) {
938+
mapped_to_trivial_reduction.emplace(id);
939+
}
940+
}
941+
}
942+
}
943+
return mapped_to_trivial_reduction;
944+
}
945+
890946
std::pair<bool, bool> canonicalDimReduction(
891947
Fusion* fusion,
892948
TensorView* tv,
893949
bool schedule_3D) {
950+
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
951+
getTrivialReductionMap(fusion);
952+
894953
TORCH_INTERNAL_ASSERT(tv != nullptr);
895954

896955
if (!schedule_3D) {
897956
// We coalesce all reduction axes to the right;
898-
bool has_red_axis = mergeReduction(tv) > 0;
957+
bool has_red_axis = mergeReduction(tv, mapped_to_trivial_reduction) > 0;
899958

900-
bool has_iter_axis = mergeNonReduction(tv) > 0;
959+
bool has_iter_axis = mergeNonReduction(tv, mapped_to_trivial_reduction) > 0;
901960
return {has_iter_axis, has_red_axis};
902961
} else {
903962
TORCH_INTERNAL_ASSERT(
904-
merge_3d(tv) == 3, "Tried 3D merge, but result is not 3D.");
963+
merge_3d(tv, mapped_to_trivial_reduction) == 3,
964+
"Tried 3D merge, but result is not 3D.");
905965
return {true, true};
906966
}
907967
}

torch/csrc/jit/codegen/cuda/scheduler/utils.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,16 @@ TORCH_CUDA_CU_API inline c10::optional<size_t> mergeDims(
7878
}
7979

8080
// Merge all reduction to the right side and returns total number of
81-
// reduction axes.
82-
size_t mergeReduction(TensorView* tv);
81+
// reduction axes. Don't merge is typically used for trivial reductions.
82+
size_t mergeReduction(
83+
TensorView* tv,
84+
const std::unordered_set<IterDomain*>& dont_merge = {});
8385

8486
// merge all non-reduction axes to the left side and returns total number of
85-
// iteration axes.
86-
size_t mergeNonReduction(TensorView* tv);
87+
// iteration axes. Don't merge is typically used for trivial reductions.
88+
size_t mergeNonReduction(
89+
TensorView* tv,
90+
const std::unordered_set<IterDomain*>& dont_merge = {});
8791

8892
// Propagate the parallelization from the selected dimensions of the reference
8993
// tensor to their corresponding dimensions in all selected tensors in the DAG.

0 commit comments

Comments
 (0)