@@ -6,16 +6,24 @@ namespace jit {
6
6
namespace fuser {
7
7
namespace cuda {
8
8
9
- bool MaxInfoPropagator ::Information::operator >(const Information& r) const {
9
+ bool MaxInfoSpanningTree ::Information::operator >(const Information& r) const {
10
10
return r < *this ;
11
11
}
12
12
13
- bool MaxInfoPropagator ::Information::operator ==(const Information& r) const {
13
+ bool MaxInfoSpanningTree ::Information::operator ==(const Information& r) const {
14
14
return !(r < *this ) && !(*this < r);
15
15
}
16
16
17
- // Dijkstra
18
- void MaxInfoPropagator::run () {
17
+ // Prim's algorithm
18
+ MaxInfoSpanningTree::MaxInfoSpanningTree (
19
+ TensorView* reference,
20
+ std::shared_ptr<Information> reference_info,
21
+ Selector* selector)
22
+ : reference_(reference),
23
+ reference_info_ (reference_info),
24
+ selector_(selector) {}
25
+
26
+ void MaxInfoSpanningTree::compute_spanning_tree () {
19
27
// A set that allows us to quickly tell if a tensor has been replayed. If yes,
20
28
// then we will not bother computing if a new path to this tensor is worth
21
29
// taking (because the answer is always not worth)
@@ -28,88 +36,114 @@ void MaxInfoPropagator::run() {
28
36
// std::list instead of std::priority_queue because C++'s
29
37
// std::priority_queue does not support increase-key, and might not be
30
38
// deterministic either.
31
- std::list<NextHopInfo> propagation (1 );
32
- propagation .back ().from = nullptr ;
33
- propagation .back ().to = reference ;
34
- propagation .back ().info_to = reference_info ;
39
+ std::list<NextHopWithInfo> candidates (1 );
40
+ candidates .back (). next_hop .from = nullptr ;
41
+ candidates .back ().next_hop . to = reference_ ;
42
+ candidates .back ().info_to = reference_info_ ;
35
43
36
- // Insert the given next hop the correct position in `propagation `. If there
44
+ // Insert the given next hop the correct position in `candidates `. If there
37
45
// is an existing next hop that preserves more information, then we will just
38
46
// discard `info`.
39
- auto insertNextHopInfo = [&](const NextHopInfo & info) {
47
+ auto insertNextHop = [&](const NextHopWithInfo & info) {
40
48
if (!*(info.info_from )) {
41
49
// When there is no more information about the starting tensor,
42
- // we are not interested in continuing the propagation .
50
+ // we are not interested in continuing the path-finding .
43
51
return ;
44
52
}
45
53
// Find if there is already a path to the dest tensor
46
54
auto existing = std::find_if (
47
- propagation .begin (), propagation .end (), [&](const NextHopInfo & i) {
48
- return i.to == info.to ;
55
+ candidates .begin (), candidates .end (), [&](const NextHopWithInfo & i) {
56
+ return i.next_hop . to == info. next_hop .to ;
49
57
});
50
58
// Only insert if there is no existing path to the dest tensor, or the new
51
59
// path preserves more information about the starting tensor.
52
- if (existing == propagation .end () || *existing < info) {
53
- if (existing != propagation .end ()) {
54
- propagation .erase (existing);
60
+ if (existing == candidates .end () || *existing < info) {
61
+ if (existing != candidates .end ()) {
62
+ candidates .erase (existing);
55
63
}
56
- auto pos = std::upper_bound (propagation.begin (), propagation.end (), info);
57
- propagation.insert (pos, info);
64
+ auto pos = std::upper_bound (candidates.begin (), candidates.end (), info);
65
+ candidates.insert (pos, info);
66
+ }
67
+ };
68
+
69
+ auto allowPasC = [this ](TensorView* from, TensorView* to) {
70
+ if (selector_ == nullptr ) {
71
+ return true ;
58
72
}
73
+ return selector_->allowPasC (from, to);
59
74
};
60
75
61
- while (!propagation.empty ()) {
62
- auto next_hop = propagation.back ();
63
- propagation.pop_back ();
76
+ auto allowCasP = [this ](TensorView* from, TensorView* to) {
77
+ if (selector_ == nullptr ) {
78
+ return true ;
79
+ }
80
+ return selector_->allowCasP (from, to);
81
+ };
82
+
83
+ while (!candidates.empty ()) {
84
+ const auto next_hop_info = candidates.back ();
85
+ const auto & next_hop = next_hop_info.next_hop ;
86
+ candidates.pop_back ();
64
87
65
88
if (next_hop.from != nullptr ) {
66
89
// nullptr used to start from reference
67
- switch (next_hop.type ) {
68
- case NextHopType::C_AS_P:
69
- propagateTvCasP (next_hop.from , next_hop.to );
70
- break ;
71
- case NextHopType::P_AS_C:
72
- propagateTvPasC (next_hop.from , next_hop.to );
73
- break ;
74
- }
90
+ path_.push_back (next_hop);
75
91
}
76
92
replayed.emplace (next_hop.to );
77
93
78
94
for (auto consumer_tv : ir_utils::consumerTvsOf (next_hop.to )) {
79
- if (replayed.count (consumer_tv)) {
95
+ if (replayed.count (consumer_tv) || ! allowCasP (next_hop. to , consumer_tv) ) {
80
96
continue ;
81
97
}
82
- insertNextHopInfo (
83
- {.type = NextHopType::C_AS_P,
84
- .from = next_hop.to ,
85
- .to = consumer_tv,
86
- .info_from = next_hop.info_to ,
87
- .info_to =
88
- computeInfoCasP (next_hop.to , consumer_tv, next_hop.info_to )});
98
+ insertNextHop (
99
+ {.next_hop =
100
+ {.type = NextHopType::C_AS_P,
101
+ .from = next_hop.to ,
102
+ .to = consumer_tv},
103
+ .info_from = next_hop_info.info_to ,
104
+ .info_to = computeInfoCasP (
105
+ next_hop.to , consumer_tv, next_hop_info.info_to )});
89
106
}
90
107
91
108
for (auto producer_tv : ir_utils::producerTvsOf (next_hop.to )) {
92
- if (replayed.count (producer_tv)) {
109
+ if (replayed.count (producer_tv) || ! allowPasC (next_hop. to , producer_tv) ) {
93
110
continue ;
94
111
}
95
- insertNextHopInfo (
96
- {.type = NextHopType::P_AS_C,
97
- .from = next_hop.to ,
98
- .to = producer_tv,
99
- .info_from = next_hop.info_to ,
100
- .info_to =
101
- computeInfoPasC (next_hop.to , producer_tv, next_hop.info_to )});
112
+ insertNextHop (
113
+ {.next_hop =
114
+ {.type = NextHopType::P_AS_C,
115
+ .from = next_hop.to ,
116
+ .to = producer_tv},
117
+ .info_from = next_hop_info.info_to ,
118
+ .info_to = computeInfoPasC (
119
+ next_hop.to , producer_tv, next_hop_info.info_to )});
102
120
}
103
121
}
104
122
}
105
123
106
- MaxRootDomainInfoPropagator::RootDomainInfo::operator bool () const {
124
+ void MaxInfoSpanningTree::traverse (Propagator* propagator) {
125
+ if (path_.empty ()) {
126
+ compute_spanning_tree ();
127
+ }
128
+ for (const auto & next_hop : path_) {
129
+ switch (next_hop.type ) {
130
+ case NextHopType::C_AS_P:
131
+ propagator->propagateTvCasP (next_hop.from , next_hop.to );
132
+ break ;
133
+ case NextHopType::P_AS_C:
134
+ propagator->propagateTvPasC (next_hop.from , next_hop.to );
135
+ break ;
136
+ }
137
+ }
138
+ }
139
+
140
+ MaxRootDomainInfoSpanningTree::RootDomainInfo::operator bool () const {
107
141
return !info.empty ();
108
142
}
109
143
110
- bool MaxRootDomainInfoPropagator ::RootDomainInfo::operator <(
111
- const MaxInfoPropagator:: Information& r) const {
112
- auto rr = dynamic_cast <const MaxRootDomainInfoPropagator:: RootDomainInfo&>(r);
144
+ bool MaxRootDomainInfoSpanningTree ::RootDomainInfo::operator <(
145
+ const Information& r) const {
146
+ auto rr = dynamic_cast <const RootDomainInfo&>(r);
113
147
if (info.size () != rr.info .size ()) {
114
148
return info.size () < rr.info .size ();
115
149
}
@@ -174,17 +208,17 @@ std::unordered_set<IterDomain*> mapRFactorToRoot(
174
208
// Given the preserved reference root ID info of a producer, compute
175
209
// the corresponding info in consumer. The given info may be represented by
176
210
// producer's root domain, or rfactor domain, depending on how we reached the
177
- // producer during propagation . If the given info is already represented with
211
+ // producer during path-finding . If the given info is already represented with
178
212
// producer's rfactor domain, then we directly map it to the consumer's root
179
213
// domain. If the given info is represented with producer's root domain, we need
180
214
// to first map it to the rfactor domain of the producer, then we can map it to
181
215
// the consumer's root domain. The computed info will be represented by root
182
216
// domain as root domain contains the raw information.
183
- std::shared_ptr<MaxInfoPropagator ::Information> MaxRootDomainInfoPropagator ::
217
+ std::shared_ptr<MaxInfoSpanningTree ::Information> MaxRootDomainInfoSpanningTree ::
184
218
computeInfoCasP (
185
219
TensorView* from,
186
220
TensorView* to,
187
- std::shared_ptr<Information> from_info) {
221
+ std::shared_ptr<Information> from_info) const {
188
222
RootDomainInfo result;
189
223
190
224
TensorView* producer = from;
@@ -231,17 +265,17 @@ std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
231
265
// Given the preserved reference root ID info of a consumer, compute
232
266
// the corresponding info in producer. The given info may be represented by
233
267
// consumer's root domain, or rfactor domain, depending on how we reached the
234
- // consumer during propagation . If the given info is already represented with
268
+ // consumer during path-finding . If the given info is already represented with
235
269
// consumer's root domain, then we directly map it to the producer's rfactor
236
270
// domain. If the given info is represented with consumer's rfactor domain, we
237
271
// need to first map it to the root domain of the consumer, then we can map it
238
272
// to the producer's rfactor domain. The computed info will be represented by
239
273
// rfactor domain as rfactor domain contains the raw information.
240
- std::shared_ptr<MaxInfoPropagator ::Information> MaxRootDomainInfoPropagator ::
274
+ std::shared_ptr<MaxInfoSpanningTree ::Information> MaxRootDomainInfoSpanningTree ::
241
275
computeInfoPasC (
242
276
TensorView* from,
243
277
TensorView* to,
244
- std::shared_ptr<Information> from_info) {
278
+ std::shared_ptr<Information> from_info) const {
245
279
RootDomainInfo result;
246
280
247
281
TensorView* producer = to;
@@ -279,9 +313,9 @@ std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
279
313
// We will stop at the rfactor ids in producer, and will not further map
280
314
// them into root ids in producer. This means, we only keep the unprocessed
281
315
// raw information of a tensor. This behavior is important to make sure that
282
- // info is as accurate as possible throughout the propagation .
316
+ // info is as accurate as possible throughout the path-finding .
283
317
//
284
- // For example, if we do a C->P->C' propagation , we want to do
318
+ // For example, in a C->P->C' path , we want to do
285
319
// C(root) -> P(rfactor) -> C'(root)
286
320
// instead of
287
321
// C(root) -> P(rfactor) -> P(root) -> P(rfactor) -> C'(root)
@@ -305,6 +339,47 @@ std::shared_ptr<MaxInfoPropagator::Information> MaxRootDomainInfoPropagator::
305
339
return std::make_shared<RootDomainInfo>(std::move (result));
306
340
}
307
341
342
+ std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo>
343
+ MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo (TensorView* tv) {
344
+ RootDomainInfo result;
345
+ const auto & root_domain = tv->getRootDomain ();
346
+ result.info .reserve (root_domain.size ());
347
+ for (auto id : root_domain) {
348
+ result.info .emplace_back (RootIDInfo{{id}, true , false });
349
+ }
350
+ return std::make_shared<RootDomainInfo>(std::move (result));
351
+ }
352
+
353
+ std::shared_ptr<MaxRootDomainInfoSpanningTree::RootDomainInfo>
354
+ MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo (
355
+ TensorView* tv,
356
+ int64_t leaf_pos) {
357
+ if (leaf_pos < 0 ) {
358
+ leaf_pos += int64_t (tv->nDims ()) + 1 ;
359
+ }
360
+ TORCH_CHECK (
361
+ leaf_pos >= 0 && leaf_pos <= tv->nDims (),
362
+ " MaxRootDomainInfoSpanningTree called on an leaf_pos outside valid range." );
363
+ RootDomainInfo result;
364
+ const auto & root_domain = tv->getMaybeRFactorDomain ();
365
+ const auto & leaf_domain = tv->domain ()->domain ();
366
+ std::unordered_set<IterDomain*> selected_leaves (
367
+ leaf_domain.begin (), leaf_domain.begin () + leaf_pos);
368
+ for (auto id : root_domain) {
369
+ if (selected_leaves.count (id) > 0 ) {
370
+ result.info .emplace_back (RootIDInfo{{id}, true , tv->hasRFactor ()});
371
+ continue ;
372
+ }
373
+ for (auto selected_leaf_id : selected_leaves) {
374
+ if (DependencyCheck::isDependencyOf (id, selected_leaf_id)) {
375
+ result.info .emplace_back (RootIDInfo{{id}, true , tv->hasRFactor ()});
376
+ break ;
377
+ }
378
+ }
379
+ }
380
+ return std::make_shared<RootDomainInfo>(std::move (result));
381
+ }
382
+
308
383
} // namespace cuda
309
384
} // namespace fuser
310
385
} // namespace jit
0 commit comments