Skip to content

Commit 33a824d

Browse files
authored
Adding sibling path for MaxInfoSpanningTree (#1776)
The sibling path is required to generate a consistent replay in some cases where `MaxInfoSpanningTree` is used with a selector — for example, when the producer of a Welford is excluded from the propagation section. See the test `FusionTransformPropagateSelectorSibling_CUDA` for a detailed example. Moreover, since we know that siblings should be transformed exactly the same, the sibling path is a perfect next hop for preserving information. If you want a spanning tree without a sibling path, you can override `allowSibling` to return `false` in your selector.
1 parent 86f46aa commit 33a824d

File tree

7 files changed

+221
-54
lines changed

7 files changed

+221
-54
lines changed

torch/csrc/jit/codegen/cuda/ir_utils.cpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,22 @@ TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(Val* val) {
529529
return uniqueEntries<Val>(consumer_vals);
530530
}
531531

532+
// Return immediate siblings of val
533+
TORCH_CUDA_CU_API std::vector<Val*> siblingValsOf(Val* val) {
534+
std::vector<Val*> sibling_vals;
535+
auto def = val->definition();
536+
if (def != nullptr) {
537+
auto outs = def->outputs();
538+
for (auto sibling_val : outs) {
539+
if (sibling_val == val) {
540+
continue;
541+
}
542+
sibling_vals.emplace_back(sibling_val);
543+
}
544+
}
545+
return sibling_vals;
546+
}
547+
532548
// Return immediate producers of val
533549
TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(
534550
const std::vector<Val*>& vals) {
@@ -556,22 +572,21 @@ TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(
556572
}
557573

558574
std::vector<TensorView*> producerTvsOf(TensorView* tv) {
559-
if (tv->definition() == nullptr) {
560-
return {};
561-
}
562-
auto producer_vals =
563-
ir_utils::filterByType<TensorView>(tv->definition()->inputs());
564-
return uniqueEntries<TensorView>(
565-
{producer_vals.begin(), producer_vals.end()});
575+
auto producer_vals = producerValsOf(tv);
576+
auto producer_tvs = ir_utils::filterByType<TensorView>(producer_vals);
577+
return {producer_tvs.begin(), producer_tvs.end()};
566578
}
567579

568580
std::vector<TensorView*> consumerTvsOf(TensorView* tv) {
569-
std::vector<TensorView*> consumer_tvs;
570-
for (auto use_expr : tv->uses()) {
571-
auto outputs = ir_utils::filterByType<TensorView>(use_expr->outputs());
572-
consumer_tvs.insert(consumer_tvs.end(), outputs.begin(), outputs.end());
573-
}
574-
return uniqueEntries<TensorView>(consumer_tvs);
581+
auto consumer_vals = consumerValsOf(tv);
582+
auto consumer_tvs = ir_utils::filterByType<TensorView>(consumer_vals);
583+
return {consumer_tvs.begin(), consumer_tvs.end()};
584+
}
585+
586+
TORCH_CUDA_CU_API std::vector<TensorView*> siblingTvsOf(TensorView* tv) {
587+
auto sibling_vals = siblingValsOf(tv);
588+
auto sibling_tvs = ir_utils::filterByType<TensorView>(sibling_vals);
589+
return {sibling_tvs.begin(), sibling_tvs.end()};
575590
}
576591

577592
std::vector<TensorView*> producerTvsOf(const std::vector<TensorView*>& tvs) {

torch/csrc/jit/codegen/cuda/ir_utils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,16 @@ TORCH_CUDA_CU_API std::vector<Val*> producerValsOf(Val* val);
181181
// code.
182182
TORCH_CUDA_CU_API std::vector<Val*> consumerValsOf(Val* val);
183183

184+
// Return immediate siblings of val, this function can be used on any Val and
185+
// will return siblings through Exprs.
186+
//
187+
// Warning: returned val's are not guaranteed to be between fusion inputs and
188+
// outputs. This function simply uses val->definition() or val->uses() which is
189+
// limited to not go through fusion inputs/outputs, but if on a path that isn't
190+
// strictly between fusion inputs/outputs, it could effectively return dead
191+
// code.
192+
TORCH_CUDA_CU_API std::vector<Val*> siblingValsOf(Val* val);
193+
184194
// Return immediate producers of vals, this function can be used on any vals and
185195
// will return producers through Exprs.
186196
//
@@ -223,6 +233,16 @@ TORCH_CUDA_CU_API std::vector<TensorView*> producerTvsOf(TensorView* tv);
223233
// code.
224234
TORCH_CUDA_CU_API std::vector<TensorView*> consumerTvsOf(TensorView* tv);
225235

236+
// Return immediate siblings of tv, this function will return all immediate
237+
// siblings of tv through Exprs.
238+
//
239+
// Warning: returned tv's are not guaranteed to be between fusion inputs and
240+
// outputs. This function simply uses tv->definition() or tv->uses() which is
241+
// limited to not go through fusion inputs/outputs, but if on a path that isn't
242+
// strictly between fusion inputs/outputs, it could effectively return dead
243+
// code.
244+
TORCH_CUDA_CU_API std::vector<TensorView*> siblingTvsOf(TensorView* tv);
245+
226246
// Return immediate producers of tvs, this function will return all immediate
227247
// producers of tvs through Exprs.
228248
//

torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
8080
return selector_->allowCasP(from, to);
8181
};
8282

83+
auto allowSibling = [this](TensorView* from, TensorView* to) {
84+
if (selector_ == nullptr) {
85+
return true;
86+
}
87+
return selector_->allowSibling(from, to);
88+
};
89+
8390
while (!candidates.empty()) {
8491
const auto next_hop_info = candidates.back();
8592
const auto& next_hop = next_hop_info.next_hop;
@@ -91,6 +98,21 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
9198
}
9299
replayed.emplace(next_hop.to);
93100

101+
for (auto sibling_tv : ir_utils::siblingTvsOf(next_hop.to)) {
102+
if (replayed.count(sibling_tv) ||
103+
!allowSibling(next_hop.to, sibling_tv)) {
104+
continue;
105+
}
106+
insertNextHop(
107+
{.next_hop =
108+
{.type = NextHopType::SIBLING,
109+
.from = next_hop.to,
110+
.to = sibling_tv},
111+
.info_from = next_hop_info.info_to,
112+
.info_to = computeInfoSibling(
113+
next_hop.to, sibling_tv, next_hop_info.info_to)});
114+
}
115+
94116
for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) {
95117
if (replayed.count(consumer_tv) || !allowCasP(next_hop.to, consumer_tv)) {
96118
continue;
@@ -127,6 +149,9 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) {
127149
}
128150
for (const auto& next_hop : path_) {
129151
switch (next_hop.type) {
152+
case NextHopType::SIBLING:
153+
propagator->propagateTvSibling(next_hop.from, next_hop.to);
154+
break;
130155
case NextHopType::C_AS_P:
131156
propagator->propagateTvCasP(next_hop.from, next_hop.to);
132157
break;
@@ -380,6 +405,17 @@ MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo(
380405
return std::make_shared<RootDomainInfo>(std::move(result));
381406
}
382407

408+
// Given the preserved reference root ID info of a tensor, compute
409+
// the corresponding info in its sibling. Since info has nothing to do with
410+
// replay state, so sibling info is always identical by definition.
411+
std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
412+
computeInfoSibling(
413+
TensorView* from,
414+
TensorView* to,
415+
std::shared_ptr<Information> from_info) const {
416+
return from_info;
417+
}
418+
383419
} // namespace cuda
384420
} // namespace fuser
385421
} // namespace jit

torch/csrc/jit/codegen/cuda/maxinfo_propagator.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ namespace cuda {
2929
* MaxInfoSpanningTree::Information and implement `operator<` which is used to
3030
* tell which path contains more information, and `operator bool` which is used
3131
* to tell if there is any information stored. You also need to implement
32-
* computeInfoPasC and computeInfoCasP, which are the functions that compute
33-
* information of the `to` tensor from the information of the `from` tensor.
32+
* computeInfoPasC, computeInfoCasP, and computeInfoSibling, which are the
33+
* functions that compute information of the `to` tensor from the information of
34+
* the `from` tensor.
3435
*/
3536
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
3637
class TORCH_CUDA_CU_API MaxInfoSpanningTree {
@@ -40,12 +41,14 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
4041
struct Selector {
4142
virtual bool allowPasC(TensorView* from, TensorView* to) = 0;
4243
virtual bool allowCasP(TensorView* from, TensorView* to) = 0;
44+
virtual bool allowSibling(TensorView* from, TensorView* to) = 0;
4345
};
4446

4547
// This is the interface to implement the actual propagation
4648
struct Propagator {
4749
virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0;
4850
virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0;
51+
virtual void propagateTvSibling(TensorView* from, TensorView* to) = 0;
4952
};
5053

5154
// This is the interface that specifies the structure of information used to
@@ -71,6 +74,7 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
7174

7275
private:
7376
enum class NextHopType {
77+
SIBLING,
7478
C_AS_P,
7579
P_AS_C,
7680
};
@@ -109,6 +113,10 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
109113
TensorView* from,
110114
TensorView* to,
111115
std::shared_ptr<Information> from_info) const = 0;
116+
virtual std::shared_ptr<Information> computeInfoSibling(
117+
TensorView* from,
118+
TensorView* to,
119+
std::shared_ptr<Information> from_info) const = 0;
112120

113121
public:
114122
MaxInfoSpanningTree(
@@ -190,6 +198,10 @@ class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree
190198
TensorView* from,
191199
TensorView* to,
192200
std::shared_ptr<Information> from_info) const override;
201+
virtual std::shared_ptr<Information> computeInfoSibling(
202+
TensorView* from,
203+
TensorView* to,
204+
std::shared_ptr<Information> from_info) const override;
193205

194206
private:
195207
static std::shared_ptr<RootDomainInfo> getReferenceRootIDInfo(TensorView* tv);

0 commit comments

Comments
 (0)