Commit 5f375d0

More cleanup on InlinePropagator (csarofeen#1800)
I just realized that `InlinePropagator` can be simplified further because it no longer replays. Without replay, it is closer to a "for each" problem than a propagation problem: for each tensor `tv`, once we know the max position of `tv` that is mapped to the reference tensor's selected outer dimensions (stored in `mapped_reference_pos_` in the code), setting the CA position is a very local operation. It is as simple as checking `tv` itself and all its consumers to determine the inline position. `InlinePropagator` is not purely a "for each" problem only because computing `mapped_reference_pos_` is itself a propagation problem.

This cleanup reorganizes the code of `InlinePropagator` to make clear that it is nothing but a two-step process:

Step 1: Do a propagation to find the `mapped_reference_pos_` for all tensors.
Step 2: For each tensor, check itself and its consumers to determine the CA position.

Conceptually, I would like to split step 1 from step 2, because the split decouples these concepts. In particular, this PR makes `mapped_reference_pos_` contain only info about the reference tensor, independent of the CA position (without this PR, this is not true for best effort and most inlined computeAt). Now, in my view, `InlinePropagator` is conceptually very simple and easy to understand.

In terms of implementation, step 1 and step 2 can be interleaved, because we don't need to know the `mapped_reference_pos_` of `tv`'s consumers in order to compute the CA position of `tv`. So a one-pass traversal can do both step 1 and step 2 together.
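The two-step structure described above can be illustrated with a toy model. This is only a sketch under stated assumptions, not the nvFuser implementation: `ToyTensor`, `ToyResult`, and `toyInline` are hypothetical names, the `std::min(from_pos, ndims)` rule in step 1 is a placeholder for the real leaf-position matching, and step 2 always clamps (the non-Standard behaviour) instead of asserting.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a TensorView, listed in traversal order.
struct ToyTensor {
  int from;                  // index of the tensor we propagate from, -1 for the reference
  std::size_t ndims;         // number of leaf axes
  std::size_t max_pos_self;  // inlining limit imposed by the tensor itself
  std::vector<std::size_t> consumer_limits;  // limit imposed by each consumer
};

struct ToyResult {
  std::vector<std::size_t> mapped_reference_pos;  // step 1 output
  std::vector<std::size_t> ca_pos;                // step 2 output
};

// One pass over the traversal order, interleaving both steps.
// Precondition (assumed): every `from` index refers to an earlier entry.
ToyResult toyInline(const std::vector<ToyTensor>& order, std::size_t reference_pos) {
  ToyResult result;
  for (const ToyTensor& tv : order) {
    // Step 1: propagate the reference position. The min() below is a
    // placeholder for the real matching logic, not what nvFuser does.
    std::size_t mapped = reference_pos;
    if (tv.from >= 0) {
      mapped = std::min(
          result.mapped_reference_pos[static_cast<std::size_t>(tv.from)], tv.ndims);
    }
    result.mapped_reference_pos.push_back(mapped);

    // Step 2: a purely local decision based on the tensor and its consumers.
    std::size_t max_pos = tv.max_pos_self;
    for (std::size_t limit : tv.consumer_limits) {
      max_pos = std::min(max_pos, limit);
    }
    result.ca_pos.push_back(std::min(mapped, max_pos));
  }
  return result;
}
```

Note that step 1 reads only earlier `mapped_reference_pos` entries and never `ca_pos`, so a tensor whose CA position gets clamped still passes its full mapped position on to its neighbours.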
1 parent 8d384da commit 5f375d0

2 files changed: 53 additions, 122 deletions
torch/csrc/jit/codegen/cuda/inline_propagator.cpp

Lines changed: 49 additions & 106 deletions
@@ -104,129 +104,56 @@ size_t MaxPosCalculator::getMaxPosSelf(
   return std::distance(dom.begin(), iter);
 }
 
-// Return the max position in consumer that producer can be inlined to
-// Cannot inline:
-// Reduction dimensions in producer
-// Block broadcast dimensions in producer
-// Vectorized dimensions in producer or consumer
-// Unrolled dimensions in producer or consumer
-// Dimensions derived from root dimensions that exist in both but are
-// unmappable
-size_t MaxPosCalculator::getMaxPosC2P(
-    TensorView* consumer,
-    TensorView* producer) const {
-  // Limit max position based on vectorized dims in consumer.
-  auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true);
-
-  auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
-  auto replay_PasC =
-      BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map);
-  auto c2p_replay_map = replay_PasC.getReplay();
-
-  for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0;
-       consumer_pos--) {
-    auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1));
-    if (map_it != c2p_replay_map.end()) {
-      auto p_id = map_it->second;
-      if (!isAllowedID(p_id, producer, true, false, false)) {
-        max_consumer_pos = consumer_pos - 1;
-      }
-    }
-  }
-
-  return max_consumer_pos;
-}
-
 // Return the max position in producer that can be inlined to consumer
 // Cannot inline:
-// Reduction dimensions in producer
-// Vectorized dimensions in producer or consumer
-// Unrolled dimensions in producer or consumer
-// Dimensions derived from root dimensions that exist in both but are
-// unmappable
-size_t MaxPosCalculator::getMaxPosP2C(
+// Vectorized dimensions in consumer
+// Unrolled dimensions in consumer
+size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
     TensorView* producer,
     TensorView* consumer) const {
-  auto max_producer_pos = getMaxPosSelf(producer, false, false, false);
-
   auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
   auto replay_CasP =
       BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map);
   auto p2c_replay_map = replay_CasP.getReplay();
 
-  for (size_t producer_pos = max_producer_pos; producer_pos > 0;
-       producer_pos--) {
-    auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1));
+  for (size_t producer_pos = 0; producer_pos < producer->nDims();
+       producer_pos++) {
+    auto map_it = p2c_replay_map.find(producer->axis(producer_pos));
     if (map_it != p2c_replay_map.end()) {
       auto c_id = map_it->second;
       if (!isAllowedID(c_id, consumer, true, false, true)) {
-        max_producer_pos = producer_pos - 1;
+        return producer_pos;
       }
     }
   }
-
-  return max_producer_pos;
+  return producer->nDims();
 }
 
 size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
   auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false);
   for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
-    // consumers are always replayed consistently
-    max_pos =
-        std::min<size_t>(max_pos, max_pos_calc.getMaxPosP2C(tv, consumer_tv));
+    max_pos = std::min<size_t>(
+        max_pos, max_pos_calc.getMaxProducerPosFromConsumer(tv, consumer_tv));
   }
   return max_pos;
 }
 
-size_t InlinePropagator::getFromPosC2P(TensorView* from, TensorView* to) {
-  size_t max_pos = max_pos_calc.getMaxPosC2P(from, to);
-  size_t pos = mapped_reference_pos_.at(from);
-
-  if (mode_ == ComputeAtMode::BestEffort) {
-    return std::min(pos, max_pos);
-  } else if (mode_ == ComputeAtMode::MostInlined) {
-    return max_pos;
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      pos <= max_pos,
-      "Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: ",
-      from,
-      " to producer: ",
-      to,
-      " tried to do this at position: ",
-      pos,
-      " but max position that's allowed is ",
-      max_pos);
-  return pos;
-}
-
-size_t InlinePropagator::getFromPosP2C(TensorView* from, TensorView* to) {
-  size_t max_pos = max_pos_calc.getMaxPosP2C(from, to);
-  size_t pos = mapped_reference_pos_.at(from);
-
-  if (mode_ == ComputeAtMode::BestEffort) {
-    return std::min(pos, max_pos);
-  } else if (mode_ == ComputeAtMode::MostInlined) {
-    return max_pos;
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      pos <= max_pos,
-      "Invalid compute at position detected in compute at when trying to propagate the CA position from producer: ",
-      from,
-      " to consumer: ",
-      to,
-      " tried to do this at position: ",
-      pos,
-      " but max position that's allowed is ",
-      max_pos);
-  return pos;
-}
-
-void InlinePropagator::setCAPos(TensorView* tv, size_t pos) {
+void InlinePropagator::setCAPos(TensorView* tv) {
+  size_t pos = mapped_reference_pos_.at(tv);
   if (selected_.count(tv) && !tv->isFusionInput()) {
-    pos = std::min<size_t>(pos, getMaxPosAll(tv));
+    auto max_pos = getMaxPosAll(tv);
+    if (mode_ == ComputeAtMode::Standard) {
+      TORCH_INTERNAL_ASSERT(
+          pos <= max_pos,
+          "Invalid compute at position detected in InlinePropagator when trying to set the CA position of: ",
+          tv,
+          " to ",
+          pos,
+          ", max position that's allowed is ",
+          max_pos);
+    } else {
+      pos = std::min<size_t>(pos, max_pos);
+    }
     // hoist inner most broadcast
     while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
       pos--;
@@ -262,10 +189,16 @@ InlinePropagator::InlinePropagator(
 void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
   if (is_first_) {
     is_first_ = false;
-    setCAPos(reference_, reference_pos_);
     mapped_reference_pos_[reference_] = reference_pos_;
+    setCAPos(reference_);
+  }
+  // Step 1: find mapped_reference_pos_[to]
+  int from_pos;
+  if (mode_ != ComputeAtMode::MostInlined) {
+    from_pos = mapped_reference_pos_.at(from);
+  } else {
+    from_pos = from->nDims();
   }
-  int from_pos = getFromPosC2P(from, to);
   auto to_pos =
       TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
   TORCH_CHECK(
@@ -275,17 +208,24 @@ void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
       " to producer ",
       to,
       " because this would require replay.");
-  setCAPos(to, to_pos);
   mapped_reference_pos_[to] = to_pos;
+  // Step 2: set CA position of `to`
+  setCAPos(to);
 }
 
 void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
   if (is_first_) {
     is_first_ = false;
-    setCAPos(reference_, reference_pos_);
     mapped_reference_pos_[reference_] = reference_pos_;
+    setCAPos(reference_);
+  }
+  // Step 1: find mapped_reference_pos_[to]
+  int from_pos;
+  if (mode_ != ComputeAtMode::MostInlined) {
+    from_pos = mapped_reference_pos_.at(from);
+  } else {
+    from_pos = from->nDims();
   }
-  int from_pos = getFromPosP2C(from, to);
   auto to_pos =
       TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
   TORCH_CHECK(
@@ -295,16 +235,18 @@ void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
       " to consumer ",
       to,
       " because this would require replay.");
-  setCAPos(to, to_pos);
   mapped_reference_pos_[to] = to_pos;
+  // Step 2: set CA position of `to`
+  setCAPos(to);
 }
 
 void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
   if (is_first_) {
     is_first_ = false;
-    setCAPos(reference_, reference_pos_);
     mapped_reference_pos_[reference_] = reference_pos_;
+    setCAPos(reference_);
   }
+  // Step 1: find mapped_reference_pos_[to]
   auto from_pos = mapped_reference_pos_.at(from);
   TORCH_CHECK(
       TransformReplay::fullSelfMatching(to, from),
@@ -313,8 +255,9 @@ void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
       " to sibling ",
       to,
       " because this would require replay.");
-  setCAPos(to, from_pos);
   mapped_reference_pos_[to] = from_pos;
+  // Step 2: set CA position of `to`
+  setCAPos(to);
 }
 
 namespace {
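
The forward scan that `getMaxProducerPosFromConsumer` now performs can be shown in isolation. The following is a minimal sketch with stand-in types, not the real API: axes are plain `int` ids, `p2c` plays the role of the producer-to-consumer replay map, and `disallowed_consumer_axes` stands in for the axes that `isAllowedID` would reject. The function name `toyMaxProducerPos` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Returns the index of the first producer axis whose mapped consumer axis is
// disallowed, or the number of producer axes if no such axis exists.
std::size_t toyMaxProducerPos(
    const std::vector<int>& producer_axes,
    const std::unordered_map<int, int>& p2c,
    const std::unordered_set<int>& disallowed_consumer_axes) {
  for (std::size_t pos = 0; pos < producer_axes.size(); ++pos) {
    auto it = p2c.find(producer_axes[pos]);
    // Unmapped producer axes never limit the position in this scan.
    if (it != p2c.end() && disallowed_consumer_axes.count(it->second) > 0) {
      return pos;
    }
  }
  return producer_axes.size();
}
```

For example, with producer axes `{10, 11, 12}` mapped to consumer axes `{20, 21, 22}` and consumer axis `21` disallowed, the function returns `1`: only the first producer axis can be inlined. In this toy model, the old backward loop, which kept lowering the bound every time it met a disallowed axis, would end at the same index; the early return reaches it directly. The producer's own restrictions are not checked here, since `getMaxPosAll` already applies `getMaxPosSelf` separately.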

torch/csrc/jit/codegen/cuda/inline_propagator.h

Lines changed: 4 additions & 16 deletions
@@ -60,11 +60,9 @@ class MaxPosCalculator {
 
   // Returns the maximum position producer can be inlined based on consumer
   // given the set ComputeAtMode
-  size_t getMaxPosC2P(TensorView* from, TensorView* to) const;
-
-  // Returns the maximum position consumer can be inlined based on producer
-  // given the set ComputeAtMode
-  size_t getMaxPosP2C(TensorView* from, TensorView* to) const;
+  size_t getMaxProducerPosFromConsumer(
+      TensorView* producer,
+      TensorView* consumer) const;
 
   MaxPosCalculator(ComputeAtMode mode);
 };
@@ -74,16 +72,6 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
   // that can be shared across both directions.
   size_t getMaxPosAll(TensorView* tv);
 
-  // Returns the inline position in consumer that producer should be inlined as
-  // based on consumer, taking into consideration the max possible returned by
-  // getMaxPos{P2C, C2P}, the compute at mode type.
-  size_t getFromPosC2P(TensorView* from, TensorView* to);
-
-  // Returns the inline position in producer that consumer should be inlined as
-  // based on producer, taking into consideration the max possible returned by
-  // getMaxPos{P2C, C2P}, the compute at mode type.
-  size_t getFromPosP2C(TensorView* from, TensorView* to);
-
   // We use mapped_reference_pos_ to keep track of the outer axes information of
   // the reference tensor. That is, mapped_reference_pos_[tv] answers the
   // question "What outer axes in tv are shared with the specified reference
@@ -95,7 +83,7 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
 
   // Actually set the computeAt position. This does not necessarily equal to
   // mapped_reference_pos_[tv] because we don't want to inline certain things.
-  void setCAPos(TensorView* tv, size_t pos);
+  void setCAPos(TensorView* tv);
 
   const MaxPosCalculator max_pos_calc;
   std::unordered_set<TensorView*> selected_;
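
With the position argument gone, `setCAPos` decides everything from the mapped position, the allowed maximum, and the mode. The sketch below isolates that decision under simplifying assumptions: `ToyMode` and `toyResolveCAPos` are hypothetical names, a `std::invalid_argument` exception stands in for `TORCH_INTERNAL_ASSERT`, and a `std::vector<bool>` stands in for querying `isBroadcast()` on each axis. The caller is assumed to pass `pos <= is_broadcast.size()`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical mirror of the three compute-at modes named in the diff.
enum class ToyMode { Standard, BestEffort, MostInlined };

std::size_t toyResolveCAPos(
    ToyMode mode,
    std::size_t pos,      // mapped reference position for this tensor
    std::size_t max_pos,  // limit from the tensor itself and its consumers
    const std::vector<bool>& is_broadcast) {
  if (mode == ToyMode::Standard) {
    // Standard mode treats an over-large position as an error.
    if (pos > max_pos) {
      throw std::invalid_argument("requested CA position exceeds allowed max");
    }
  } else {
    // BestEffort and MostInlined clamp to the allowed maximum instead.
    pos = std::min(pos, max_pos);
  }
  // Hoist innermost broadcast axes out of the inlined region.
  while (pos > 0 && is_broadcast[pos - 1]) {
    --pos;
  }
  return pos;
}
```

In this model the only difference between `BestEffort` and `MostInlined` is the `pos` the caller passes in: the diff above starts `MostInlined` from `from->nDims()` rather than from the stored mapped position, and the clamp then brings it down to whatever is legal.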
