
Commit bca20c1

Cleanup trivial reduction workarounds (#2006)
1 parent e4b6585 commit bca20c1

7 files changed (+139, -154 lines)

torch/csrc/jit/codegen/cuda/inlining.cpp

Lines changed: 10 additions & 16 deletions
@@ -153,29 +153,25 @@ size_t MaxPosCalculator::getMaxPosAll(
   return max_pos;
 }
 
-void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
-  inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids);
+void inlineMost() {
+  inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()));
 }
 
-void inlineMost(
-    const std::vector<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+void inlineMost(const std::vector<TensorView*>& tvs) {
   if (tvs.empty()) {
     return;
   }
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto tv : tvs) {
     tv->inlineAt(-1, true, &calc);
   }
 }
 
-void inlineMost(
-    const std::unordered_set<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+void inlineMost(const std::unordered_set<TensorView*>& tvs) {
   if (tvs.empty()) {
     return;
   }
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto tv : tvs) {
     tv->inlineAt(-1, true, &calc);
   }
@@ -276,10 +272,9 @@ std::unordered_map<TensorView*, size_t> getPositionsMappedTo(
 void inlineAllAt(
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+    bool best_effort) {
   auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto pair : mapped_positions) {
     pair.first->inlineAt(pair.second, best_effort, &calc);
   }
@@ -289,10 +284,9 @@ void inlineSelectedAt(
     const std::unordered_set<TensorView*>& selected,
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+    bool best_effort) {
   auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto pair : mapped_positions) {
     if (selected.count(pair.first) > 0) {
       pair.first->inlineAt(pair.second, best_effort, &calc);

torch/csrc/jit/codegen/cuda/inlining.h

Lines changed: 5 additions & 12 deletions
@@ -64,35 +64,28 @@ class MaxPosCalculator {
 
 // Inline to the right most allowed position for all tensors in the current
 // fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost();
 // Inline to the right most allowed position for the selected tensors in the
 // current fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::vector<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost(const std::vector<TensorView*>& tvs);
 // Inline to the right most allowed position for the selected tensors in the
 // current fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::unordered_set<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost(const std::unordered_set<TensorView*>& tvs);
 
 // Inline to the position corresponding to the reference position in the
 // reference tensor for all tensors in the current fusion.
 TORCH_CUDA_CU_API void inlineAllAt(
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort = false,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+    bool best_effort = false);
 
 // Inline to the position corresponding to the reference position in the
 // reference tensor for selected tensors in the current fusion.
 TORCH_CUDA_CU_API void inlineSelectedAt(
     const std::unordered_set<TensorView*>& selected,
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort = false,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+    bool best_effort = false);
 
 } // namespace cuda
 } // namespace fuser
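
For orientation, a hedged sketch of how a scheduler might drive the simplified entry points declared above after this change. The fusion construction and scheduling transforms are assumed to happen elsewhere, and scheduleEpilogue is a hypothetical helper written for illustration, not a function in the repository; only the calls visible in this header are used.

// Illustration only: assumes the nvFuser internal headers are available and
// that reference_tv belongs to the fusion currently guarded by FusionGuard.
#include <torch/csrc/jit/codegen/cuda/inlining.h>

using namespace torch::jit::fuser::cuda;

void scheduleEpilogue(TensorView* reference_tv) {
  // Inline every tensor at the position mapped from position 2 of the
  // reference tensor, falling back to the closest legal position.
  inlineAllAt(reference_tv, /*reference_pos=*/2, /*best_effort=*/true);
  // Then inline everything else as deeply as allowed. After this commit no
  // set of uninlinable trivial-reduction IterDomains needs to be threaded in.
  inlineMost();
}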

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 38 additions & 39 deletions
@@ -1597,6 +1597,43 @@ std::vector<IterDomain*> IterDomain::clone(
   return cloned_domains;
 }
 
+IterType inferIterType(IterDomain* i1, IterDomain* i2) {
+  // The itertype inference is a pattern matching of the rules below:
+  //
+  // X + X = X
+  // trivial reduction + X = X
+  // X + trivial reduction = X
+  // broadcasting + X = X
+  // X + broadcasting = X
+  // fail
+  //
+  // The rules are processed one by one in order. For each rule, we test if the
+  // given (outer, inner) matches the pattern. If it does, then we stop
+  // proceeding and get a result. If we have reached the end without finding
+  // any matched pattern, then it is a mistake and should be reported.
+  //
+  // Note that based on the above rules:
+  // broadcasting + (non-trivial) reduction = reduction
+  // broadcasting + trivial reduction = broadcasting
+  if (i1->getIterType() == i2->getIterType()) {
+    return i1->getIterType();
+  }
+  if (i1->isTrivialReduction()) {
+    return i2->getIterType();
+  }
+  if (i2->isTrivialReduction()) {
+    return i1->getIterType();
+  }
+  if (i1->isBroadcast()) {
+    return i2->getIterType();
+  }
+  if (i2->isBroadcast()) {
+    return i1->getIterType();
+  }
+  TORCH_CHECK(
+      false, "Merging IterDomains requires that their iteration types match.");
+}
+
 // Merging does not propagate the start and stop values of the input
 // domains to the merged output domain. The actual range of the
 // domains is enforced by predicates. Note that since only root
@@ -1606,48 +1643,10 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
   TORCH_CHECK(
       !outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
       "Merging IterDomains with ending values that are 0 is not supported at this time.");
-  TORCH_CHECK(
-      outer->isReduction() == inner->isReduction() ||
-          (!outer->isReduction() && inner->isTrivialReduction()) ||
-          (outer->isTrivialReduction() && !inner->isReduction()),
-      "Merging IterDomains requires that their iteration types match.");
-  TORCH_CHECK(
-      (outer->isGather() && inner->isGather()) ||
-          (!outer->isGather() && !inner->isGather()),
-      "Merging gather and non-gather domains is not supported.");
-
-  TORCH_CHECK(
-      !outer->isStride() && !inner->isStride(),
-      "No support for merging stride domains");
 
   Val* merged_id_size = mul(outer->extent(), inner->extent());
 
-  IterType itype = outer->getIterType();
-
-  if (outer->isBroadcast() && inner->isBroadcast()) {
-    itype = IterType::Broadcast;
-  }
-
-  if ((outer->isBroadcast() || inner->isBroadcast()) &&
-      (outer->getIterType() == IterType::Iteration ||
-       inner->getIterType() == IterType::Iteration)) {
-    itype = IterType::Iteration;
-  }
-
-  // Merging trivial reduction with iter domain, that's fine, just make it an
-  // iter domain.
-  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
-      (outer->getIterType() == IterType::Iteration ||
-       inner->getIterType() == IterType::Iteration)) {
-    itype = IterType::Iteration;
-  }
-
-  // Merging trivial reduction with broadcasting, that's fine, just make it a
-  // broadcasting.
-  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
-      (outer->isBroadcast() || inner->isBroadcast())) {
-    itype = IterType::Broadcast;
-  }
+  IterType itype = inferIterType(outer, inner);
 
   Val* expanded_extent = nullptr;
   if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {
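
To make the first-match rule ordering above concrete, here is a small self-contained sketch, not nvFuser code, that models the same inference with a simplified stand-in for IterDomain; the Kind enum, Axis struct, and infer function are illustrative only.

#include <cassert>
#include <stdexcept>

// Simplified stand-ins for nvFuser's IterType / IterDomain (illustrative only).
enum class Kind { Iteration, Reduction, Broadcast };

struct Axis {
  Kind kind;
  bool trivial_reduction; // a reduction over an extent-1 domain
};

// First-match rule table, mirroring the ordering in inferIterType: identical
// kinds win, then trivial reductions defer to the other side, then broadcasts
// defer to the other side, otherwise the merge is rejected.
Kind infer(const Axis& a, const Axis& b) {
  if (a.kind == b.kind) return a.kind;
  if (a.trivial_reduction) return b.kind;
  if (b.trivial_reduction) return a.kind;
  if (a.kind == Kind::Broadcast) return b.kind;
  if (b.kind == Kind::Broadcast) return a.kind;
  throw std::runtime_error("iteration types do not match");
}

int main() {
  Axis iter{Kind::Iteration, false};
  Axis bcast{Kind::Broadcast, false};
  Axis red{Kind::Reduction, false};
  Axis triv{Kind::Reduction, true};

  assert(infer(bcast, red) == Kind::Reduction);  // broadcast + reduction = reduction
  assert(infer(bcast, triv) == Kind::Broadcast); // broadcast + trivial reduction = broadcast
  assert(infer(triv, iter) == Kind::Iteration);  // trivial reduction + iteration = iteration
  return 0;
}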

torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp

Lines changed: 1 addition & 6 deletions
@@ -330,13 +330,8 @@ void multiReductionInliner(
     }
   }
 
-  // Find iter domains that are mapped to a trivial reduction, these should
-  // never be inlined.
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      scheduler_utils::getTrivialReductionMap(fusion);
-
   // Inline the schedule
-  inlineMost(mapped_to_trivial_reduction);
+  inlineMost();
 }
 
 namespace {

torch/csrc/jit/codegen/cuda/scheduler/utils.cpp

Lines changed: 13 additions & 73 deletions
@@ -21,26 +21,20 @@ namespace scheduler_utils {
 
 // Returns number of "valid" dimensions. e.g. if tv has
 // [I1, R2, I3, I4, R3{1}]
-// where R3{1} is in dont_merge, resulting domain should be:
-// [I1, I3*I4, R2, R3{1}] with return value 3
+// resulting domain should be:
+// [I1, I3*I4, R2*R3{1}] with return value 3
 //
 // if tv has
 // [R1, I2, R3, I4, R4, R5{1}, R6{1}]
-// where R5{1} and R6{1} are in dont_merge, resulting domain should be:
-// [I2*I4, R1*R3, R4, R5{1}, R6{1}]
+// resulting domain should be:
+// [I2*I4, R1*R3, R4*R5{1}*R6{1}]
 // with return value 3
-size_t merge_3d(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t merge_3d(TensorView* tv) {
   bool active_is_reduction = false;
   bool first_dim = true;
   int prev_i = -1;
 
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = tv->axis(i)->isReduction();
       prev_i = i;
@@ -67,10 +61,6 @@ size_t merge_3d(
 
   for (int i = static_cast<int>(tv->nDims()) - 2; i >= 0; i--) {
     auto id = tv->axis(i);
-    if (dont_merge.count(id)) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = id->isReduction();
       prev_i = i;
@@ -96,10 +86,6 @@ size_t merge_3d(
   prev_i = -1;
 
   for (int i = static_cast<int>(tv->nDims()) - 3; i >= 0; i--) {
-    if (dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = tv->axis(i)->isReduction();
       prev_i = i;
@@ -114,7 +100,7 @@ size_t merge_3d(
   if (prev_i == -1) {
     // Two dimensional, put merged dimensions first
     tv->reorder({{-1, 0}, {-2, 1}});
-    // [outer, inner, dont_merge...]
+    // [outer, inner]
    if (tv->axis(0)->isReduction()) {
      // put reductions as second axis
      tv->reorder({{0, 1}, {1, 0}});
@@ -195,13 +181,11 @@ c10::optional<size_t> mergeDims(
   return left;
 }
 
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t mergeReduction(TensorView* tv) {
   int prev_i = -1;
   size_t num_merged = 0;
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (!tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
+    if (!tv->axis(i)->isReduction()) {
       continue;
     }
     if (prev_i == -1) {
@@ -219,16 +203,14 @@ size_t mergeReduction(
   return prev_i == -1 ? 0 : num_merged + 1;
 }
 
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t mergeNonReduction(TensorView* tv) {
   int prev_i = -1;
   size_t num_merged = 0;
   if (tv->nDims() == 0) {
     return 0;
   }
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
+    if (tv->axis(i)->isReduction()) {
      continue;
    }
    if (prev_i == -1) {
@@ -905,63 +887,21 @@ PersistentBufferSizeReturn persistentBufferSize(
   return persistent_buffer_size;
 }
 
-std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion) {
-  auto all_tvs = ir_utils::allTvs(fusion);
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
-  for (auto tv : all_tvs) {
-    // root domain vs domain shouldn't matter as at this point we shouldn't have
-    // any transformations.
-    for (auto id : tv->getRootDomain()) {
-      if (id->isTrivialReduction()) {
-        mapped_to_trivial_reduction.emplace(id);
-      }
-    }
-  }
-
-  if (!mapped_to_trivial_reduction.empty()) {
-    // Use the loop map as that is the most permissive
-    auto ca_map = ComputeAtMap(fusion);
-    // Make a copy we need to check mappings of all
-    auto trivial_ids = mapped_to_trivial_reduction;
-    for (auto tv : all_tvs) {
-      for (auto id : tv->getRootDomain()) {
-        if (!id->extent()->isOneInt()) {
-          continue;
-        }
-        if (std::any_of(
-                trivial_ids.begin(),
-                trivial_ids.end(),
-                [&ca_map, &id](IterDomain* trivial_id) {
-                  return ca_map.areMapped(
-                      id, trivial_id, IdMappingMode::PERMISSIVE);
-                })) {
-          mapped_to_trivial_reduction.emplace(id);
-        }
-      }
-    }
-  }
-  return mapped_to_trivial_reduction;
-}
-
 std::pair<bool, bool> canonicalDimReduction(
     Fusion* fusion,
     TensorView* tv,
     bool schedule_3D) {
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      getTrivialReductionMap(fusion);
-
   TORCH_INTERNAL_ASSERT(tv != nullptr);
 
   if (!schedule_3D) {
     // We coalesce all reduction axes to the right;
-    bool has_red_axis = mergeReduction(tv, mapped_to_trivial_reduction) > 0;
+    bool has_red_axis = mergeReduction(tv) > 0;
 
-    bool has_iter_axis = mergeNonReduction(tv, mapped_to_trivial_reduction) > 0;
+    bool has_iter_axis = mergeNonReduction(tv) > 0;
     return {has_iter_axis, has_red_axis};
   } else {
     TORCH_INTERNAL_ASSERT(
-        merge_3d(tv, mapped_to_trivial_reduction) == 3,
-        "Tried 3D merge, but result is not 3D.");
+        merge_3d(tv) == 3, "Tried 3D merge, but result is not 3D.");
     return {true, true};
   }
 }
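
As a sanity check on the doc-comment examples above, here is a small self-contained toy model, not nvFuser code, of the grouping that mergeReduction and mergeNonReduction produce in the non-3D path of canonicalDimReduction. Axes are modeled as plain labels, "merging" just concatenates them, and the actual right-to-left merge order and reorder mechanics are ignored.

#include <cassert>
#include <string>
#include <vector>

// Toy axis model (illustrative only): a label like "R2" or "I3" plus a
// reduction flag. After this commit, trivial reductions such as R3{1} are
// treated like any other reduction axis rather than being excluded.
struct Axis {
  std::string label;
  bool is_reduction;
};

// Collapse all iteration axes into one group and all reduction axes into
// another, mimicking the canonical [iteration-merged, reduction-merged] form.
std::vector<std::string> canonicalize(const std::vector<Axis>& axes) {
  std::string iter, red;
  for (const auto& a : axes) {
    std::string& dst = a.is_reduction ? red : iter;
    dst = dst.empty() ? a.label : dst + "*" + a.label;
  }
  std::vector<std::string> out;
  if (!iter.empty()) out.push_back(iter);
  if (!red.empty()) out.push_back(red);
  return out;
}

int main() {
  // [I1, R2, I3, I4, R3{1}] collapses into one iteration group and one
  // reduction group; R3{1} simply joins the reduction group.
  std::vector<Axis> tv = {
      {"I1", false}, {"R2", true}, {"I3", false}, {"I4", false}, {"R3{1}", true}};
  auto result = canonicalize(tv);
  assert(result.size() == 2);
  assert(result[0] == "I1*I3*I4");
  assert(result[1] == "R2*R3{1}");
  return 0;
}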

torch/csrc/jit/codegen/cuda/scheduler/utils.h

Lines changed: 4 additions & 8 deletions
@@ -78,16 +78,12 @@ TORCH_CUDA_CU_API inline c10::optional<size_t> mergeDims(
 }
 
 // Merge all reduction to the right side and returns total number of
-// reduction axes. Don't merge is typically used for trivial reductions.
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
+// reduction axes.
+size_t mergeReduction(TensorView* tv);
 
 // merge all non-reduction axes to the left side and returns total number of
-// iteration axes. Don't merge is typically used for trivial reductions.
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
+// iteration axes.
+size_t mergeNonReduction(TensorView* tv);
 
 // Propagate the parallelization from the selected dimensions of the reference
 // tensor to their corresponding dimensions in all selected tensors in the DAG.

0 commit comments