Commit a054b3e
Refactor TransformPropagator to allow specifying a position and propagating to part of the DAG (#1775)
`MaxInfoPropagator` is renamed to `MaxInfoSpanningTree`. It now only does path-finding; the propagation itself lives in a separate class, `MaxInfoSpanningTree::Propagator`. The same split applies to `MaxRootDomainInfoPropagator` (now `MaxRootDomainInfoSpanningTree`). `MaxInfoSpanningTree` and `MaxRootDomainInfoSpanningTree` now allow specifying a selector, which controls which subgraph is included in path-finding. `MaxRootDomainInfoSpanningTree` also gains a few convenience constructors. `TransformPropagator` is now a subclass of `MaxInfoSpanningTree::Propagator`, so the way it is used has changed.

`MaxInfoSpanningTree` and `MaxRootDomainInfoSpanningTree` now store the path after generation, so the same path can be traversed multiple times. This is useful for supporting use cases such as the new `computeAt`. Pseudo-code:

```C++
void TensorView::computeAt(TensorView* tv, int pos) {
  ComputeAtSubgraphSelector selector(this, tv);
  MaxRootDomainInfoSpanningTree path(tv, pos, &selector);
  TransformPropagator propagator(tv, pos);
  path.traverse(&propagator);
  ComputeAtPosPropagator ca_propagator(tv, pos);
  path.traverse(&ca_propagator);
}
```
1 parent d67e1cd commit a054b3e
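The message above mentions selectors without showing one. As a minimal sketch, a concrete selector might look like the class below; it assumes the `MaxInfoSpanningTree::Selector` interface with the `allowPasC`/`allowCasP` hooks visible in the diff that follows, while the `SetSelector` class itself is hypothetical:

```C++
#include <unordered_set>

// Hypothetical selector that limits path-finding to a fixed set of tensors:
// an edge is followed only if its destination tensor is in the set.
class SetSelector : public MaxInfoSpanningTree::Selector {
  std::unordered_set<TensorView*> selected_;

 public:
  explicit SetSelector(std::unordered_set<TensorView*> selected)
      : selected_(std::move(selected)) {}

  // Consulted before following a producer-as-consumer edge.
  bool allowPasC(TensorView* from, TensorView* to) override {
    return selected_.count(to) > 0;
  }

  // Consulted before following a consumer-as-producer edge.
  bool allowCasP(TensorView* from, TensorView* to) override {
    return selected_.count(to) > 0;
  }
};
```

Passing such a selector to the spanning-tree constructor prunes edges during path-finding, so later `traverse` calls only ever visit the selected subgraph.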

9 files changed: +508 −205 lines

torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp

Lines changed: 132 additions & 57 deletions
```diff
@@ -6,16 +6,24 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-bool MaxInfoPropagator::Information::operator>(const Information& r) const {
+bool MaxInfoSpanningTree::Information::operator>(const Information& r) const {
   return r < *this;
 }
 
-bool MaxInfoPropagator::Information::operator==(const Information& r) const {
+bool MaxInfoSpanningTree::Information::operator==(const Information& r) const {
   return !(r < *this) && !(*this < r);
 }
 
-// Dijkstra
-void MaxInfoPropagator::run() {
+// Prim's algorithm
+MaxInfoSpanningTree::MaxInfoSpanningTree(
+    TensorView* reference,
+    std::shared_ptr<Information> reference_info,
+    Selector* selector)
+    : reference_(reference),
+      reference_info_(reference_info),
+      selector_(selector) {}
+
+void MaxInfoSpanningTree::compute_spanning_tree() {
   // A set that allows us to quickly tell if a tensor has been replayed. If yes,
   // then we will not bother computing if a new path to this tensor is worth
   // taking (because the answer is always not worth)
```
```diff
@@ -28,88 +36,114 @@ void MaxInfoPropagator::run() {
   // std::list instead of std::priority_queue because C++'s
   // std::priority_queue does not support increase-key, and might not be
   // deterministic either.
-  std::list<NextHopInfo> propagation(1);
-  propagation.back().from = nullptr;
-  propagation.back().to = reference;
-  propagation.back().info_to = reference_info;
+  std::list<NextHopWithInfo> candidates(1);
+  candidates.back().next_hop.from = nullptr;
+  candidates.back().next_hop.to = reference_;
+  candidates.back().info_to = reference_info_;
 
-  // Insert the given next hop the correct position in `propagation`. If there
+  // Insert the given next hop the correct position in `candidates`. If there
   // is an existing next hop that preserves more information, then we will just
   // discard `info`.
-  auto insertNextHopInfo = [&](const NextHopInfo& info) {
+  auto insertNextHop = [&](const NextHopWithInfo& info) {
     if (!*(info.info_from)) {
       // When there is no more information about the starting tensor,
-      // we are not interested in continuing the propagation.
+      // we are not interested in continuing the path-finding.
       return;
     }
     // Find if there is already a path to the dest tensor
     auto existing = std::find_if(
-        propagation.begin(), propagation.end(), [&](const NextHopInfo& i) {
-          return i.to == info.to;
+        candidates.begin(), candidates.end(), [&](const NextHopWithInfo& i) {
+          return i.next_hop.to == info.next_hop.to;
         });
     // Only insert if there is no existing path to the dest tensor, or the new
     // path preserves more information about the starting tensor.
-    if (existing == propagation.end() || *existing < info) {
-      if (existing != propagation.end()) {
-        propagation.erase(existing);
+    if (existing == candidates.end() || *existing < info) {
+      if (existing != candidates.end()) {
+        candidates.erase(existing);
       }
-      auto pos = std::upper_bound(propagation.begin(), propagation.end(), info);
-      propagation.insert(pos, info);
+      auto pos = std::upper_bound(candidates.begin(), candidates.end(), info);
+      candidates.insert(pos, info);
+    }
+  };
+
+  auto allowPasC = [this](TensorView* from, TensorView* to) {
+    if (selector_ == nullptr) {
+      return true;
     }
+    return selector_->allowPasC(from, to);
   };
 
-  while (!propagation.empty()) {
-    auto next_hop = propagation.back();
-    propagation.pop_back();
+  auto allowCasP = [this](TensorView* from, TensorView* to) {
+    if (selector_ == nullptr) {
+      return true;
+    }
+    return selector_->allowCasP(from, to);
+  };
+
+  while (!candidates.empty()) {
+    const auto next_hop_info = candidates.back();
+    const auto& next_hop = next_hop_info.next_hop;
+    candidates.pop_back();
 
     if (next_hop.from != nullptr) {
       // nullptr used to start from reference
-      switch (next_hop.type) {
-        case NextHopType::C_AS_P:
-          propagateTvCasP(next_hop.from, next_hop.to);
-          break;
-        case NextHopType::P_AS_C:
-          propagateTvPasC(next_hop.from, next_hop.to);
-          break;
-      }
+      path_.push_back(next_hop);
     }
     replayed.emplace(next_hop.to);
 
     for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) {
-      if (replayed.count(consumer_tv)) {
+      if (replayed.count(consumer_tv) || !allowCasP(next_hop.to, consumer_tv)) {
         continue;
       }
-      insertNextHopInfo(
-          {.type = NextHopType::C_AS_P,
-           .from = next_hop.to,
-           .to = consumer_tv,
-           .info_from = next_hop.info_to,
-           .info_to =
-               computeInfoCasP(next_hop.to, consumer_tv, next_hop.info_to)});
+      insertNextHop(
+          {.next_hop =
+               {.type = NextHopType::C_AS_P,
+                .from = next_hop.to,
+                .to = consumer_tv},
+           .info_from = next_hop_info.info_to,
+           .info_to = computeInfoCasP(
+               next_hop.to, consumer_tv, next_hop_info.info_to)});
     }
 
     for (auto producer_tv : ir_utils::producerTvsOf(next_hop.to)) {
-      if (replayed.count(producer_tv)) {
+      if (replayed.count(producer_tv) || !allowPasC(next_hop.to, producer_tv)) {
         continue;
       }
-      insertNextHopInfo(
-          {.type = NextHopType::P_AS_C,
-           .from = next_hop.to,
-           .to = producer_tv,
-           .info_from = next_hop.info_to,
-           .info_to =
-               computeInfoPasC(next_hop.to, producer_tv, next_hop.info_to)});
+      insertNextHop(
+          {.next_hop =
+               {.type = NextHopType::P_AS_C,
+                .from = next_hop.to,
+                .to = producer_tv},
+           .info_from = next_hop_info.info_to,
+           .info_to = computeInfoPasC(
+               next_hop.to, producer_tv, next_hop_info.info_to)});
     }
   }
 }
 
-MaxRootDomainInfoPropagator::RootDomainInfo::operator bool() const {
+void MaxInfoSpanningTree::traverse(Propagator* propagator) {
+  if (path_.empty()) {
+    compute_spanning_tree();
+  }
+  for (const auto& next_hop : path_) {
+    switch (next_hop.type) {
+      case NextHopType::C_AS_P:
+        propagator->propagateTvCasP(next_hop.from, next_hop.to);
+        break;
+      case NextHopType::P_AS_C:
+        propagator->propagateTvPasC(next_hop.from, next_hop.to);
+        break;
+    }
+  }
+}
+
+MaxRootDomainInfoSpanningTree::RootDomainInfo::operator bool() const {
   return !info.empty();
 }
 
-bool MaxRootDomainInfoPropagator::RootDomainInfo::operator<(
-    const MaxInfoPropagator::Information& r) const {
-  auto rr = dynamic_cast<const MaxRootDomainInfoPropagator::RootDomainInfo&>(r);
+bool MaxRootDomainInfoSpanningTree::RootDomainInfo::operator<(
+    const Information& r) const {
+  auto rr = dynamic_cast<const RootDomainInfo&>(r);
   if (info.size() != rr.info.size()) {
     return info.size() < rr.info.size();
   }
```
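A note on the data structure in the hunk above: the candidate list is kept sorted by hand because, as the comment says, C++'s `std::priority_queue` supports neither increase-key nor deterministic ordering. A minimal standalone sketch of that pattern follows; the `Candidate` struct and `insertOrIncrease` helper are illustrative, not part of the diff:

```C++
#include <algorithm>
#include <list>

// Illustrative candidate edge, ordered by how much information it
// preserves; larger keys are "better" and end up at the back of the list.
struct Candidate {
  int dest;
  int key;
  bool operator<(const Candidate& other) const {
    return key < other.key;
  }
};

// Insert `c`, replacing any worse candidate for the same destination.
// This emulates increase-key on a sorted list: erase the stale entry,
// then re-insert at the position that keeps the list sorted.
void insertOrIncrease(std::list<Candidate>& queue, const Candidate& c) {
  auto existing = std::find_if(
      queue.begin(), queue.end(), [&](const Candidate& i) {
        return i.dest == c.dest;
      });
  if (existing != queue.end()) {
    if (!(*existing < c)) {
      return; // existing candidate is at least as good; keep it
    }
    queue.erase(existing);
  }
  auto pos = std::upper_bound(queue.begin(), queue.end(), c);
  queue.insert(pos, c);
}
```

The best candidate is then taken from the back (`queue.back()` followed by `queue.pop_back()`), mirroring `candidates.back()` / `candidates.pop_back()` in the diff.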
```diff
@@ -174,17 +208,17 @@ std::unordered_set<IterDomain*> mapRFactorToRoot(
 // Given the preserved reference root ID info of a producer, compute
 // the corresponding info in consumer. The given info may be represented by
 // producer's root domain, or rfactor domain, depending on how we reached the
-// producer during propagation. If the given info is already represented with
+// producer during path-finding. If the given info is already represented with
 // producer's rfactor domain, then we directly map it to the consumer's root
 // domain. If the given info is represented with producer's root domain, we need
 // to first map it to the rfactor domain of the producer, then we can map it to
 // the consumer's root domain. The computed info will be represented by root
 // domain as root domain contains the raw information.
-std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
+std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
     computeInfoCasP(
         TensorView* from,
         TensorView* to,
-        std::shared_ptr<Information> from_info) {
+        std::shared_ptr<Information> from_info) const {
   RootDomainInfo result;
 
   TensorView* producer = from;
```
```diff
@@ -231,17 +265,17 @@ std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
 // Given the preserved reference root ID info of a consumer, compute
 // the corresponding info in producer. The given info may be represented by
 // consumer's root domain, or rfactor domain, depending on how we reached the
-// consumer during propagation. If the given info is already represented with
+// consumer during path-finding. If the given info is already represented with
 // consumer's root domain, then we directly map it to the producer's rfactor
 // domain. If the given info is represented with consumer's rfactor domain, we
 // need to first map it to the root domain of the consumer, then we can map it
 // to the producer's rfactor domain. The computed info will be represented by
 // rfactor domain as rfactor domain contains the raw information.
-std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
+std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree::
     computeInfoPasC(
         TensorView* from,
         TensorView* to,
-        std::shared_ptr<Information> from_info) {
+        std::shared_ptr<Information> from_info) const {
   RootDomainInfo result;
 
   TensorView* producer = to;
```
```diff
@@ -279,9 +313,9 @@ std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
   // We will stop at the rfactor ids in producer, and will not further map
   // them into root ids in producer. This means, we only keep the unprocessed
   // raw information of a tensor. This behavior is important to make sure that
-  // info is as accurate as possible throughout the propagation.
+  // info is as accurate as possible throughout the path-finding.
   //
-  // For example, if we do a C->P->C' propagation, we want to do
+  // For example, in a C->P->C' path, we want to do
   //   C(root) -> P(rfactor) -> C'(root)
   // instead of
   //   C(root) -> P(rfactor) -> P(root) -> P(rfactor) -> C'(root)
```
```diff
@@ -305,6 +339,47 @@ std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
   return std::make_shared<RootDomainInfo>(std::move(result));
 }
 
+std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo>
+MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo(TensorView* tv) {
+  RootDomainInfo result;
+  const auto& root_domain = tv->getRootDomain();
+  result.info.reserve(root_domain.size());
+  for (auto id : root_domain) {
+    result.info.emplace_back(RootIDInfo{{id}, true, false});
+  }
+  return std::make_shared<RootDomainInfo>(std::move(result));
+}
+
+std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo>
+MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo(
+    TensorView* tv,
+    int64_t leaf_pos) {
+  if (leaf_pos < 0) {
+    leaf_pos += int64_t(tv->nDims()) + 1;
+  }
+  TORCH_CHECK(
+      leaf_pos >= 0 && leaf_pos <= tv->nDims(),
+      "MaxRootDomainInfoSpanningTree called on an leaf_pos outside valid range.");
+  RootDomainInfo result;
+  const auto& root_domain = tv->getMaybeRFactorDomain();
+  const auto& leaf_domain = tv->domain()->domain();
+  std::unordered_set<IterDomain*> selected_leaves(
+      leaf_domain.begin(), leaf_domain.begin() + leaf_pos);
+  for (auto id : root_domain) {
+    if (selected_leaves.count(id) > 0) {
+      result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()});
+      continue;
+    }
+    for (auto selected_leaf_id : selected_leaves) {
+      if (DependencyCheck::isDependencyOf(id, selected_leaf_id)) {
+        result.info.emplace_back(RootIDInfo{{id}, true, tv->hasRFactor()});
+        break;
+      }
+    }
+  }
+  return std::make_shared<RootDomainInfo>(std::move(result));
+}
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
```
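The second `getReferenceRootIDInfo` overload added above normalizes a negative `leaf_pos` the way Python indexing does, then keeps only the root IDs that the first `leaf_pos` leaf domains depend on. A minimal sketch of just the normalization rule, using an illustrative standalone helper (`normalizePos` is not a function from the diff):

```C++
#include <cassert>
#include <cstdint>

// Negative positions count from the end, as in Python slicing: for a
// tensor with `ndims` leaf domains, valid normalized positions are
// 0..ndims inclusive, and pos == -1 maps to ndims (one past the last).
int64_t normalizePos(int64_t pos, int64_t ndims) {
  if (pos < 0) {
    pos += ndims + 1;
  }
  assert(pos >= 0 && pos <= ndims);
  return pos;
}

// Example: for a 3-D tensor, normalizePos(-1, 3) == 3 and
// normalizePos(-4, 3) == 0, while normalizePos(2, 3) stays 2.
```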
