Skip to content

Commit 38c7f3c

Browse files
authored
InlinePropagator please don't replay (csarofeen#1797)
This PR makes `InlinePropagator` just set compute-at positions. It will not replay any tensor. If you want to replay, please use `TransformPropagator` and friends to do so. Currently, `InlinePropagator` is already asserting no replay for standard and best-effort compute-at. So this PR is mostly about making most-inlined compute-at work as well. This PR also does a lot of cleanup to remove the word "replay" from the comments and from the variable and function names in `InlinePropagator`. I also cleaned up `recordReplayedPos` and `retrieveReplayedPos`; now the logic is much easier to understand.
1 parent 3f2c263 commit 38c7f3c

File tree

6 files changed

+156
-141
lines changed

6 files changed

+156
-141
lines changed

torch/csrc/jit/codegen/cuda/compute_at.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,20 @@ void ComputeAt::runAt(
184184
auto selected = getPropagationSubgraph(producer, consumer);
185185
InlinePropagatorSelector selector(selected);
186186

187-
TransformPropagator propagator(consumer, consumer_position);
188187
InlinePropagator inline_propagator(
189188
selector.selected(), consumer, consumer_position, mode);
190189
MaxProducerPosUpdater updater;
191190

192191
MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);
193-
path.traverse(&propagator);
192+
193+
if (mode == ComputeAtMode::MostInlined) {
194+
MostInlinedTransformPropagator propagator;
195+
path.traverse(&propagator);
196+
} else {
197+
TransformPropagator propagator(consumer, consumer_position);
198+
path.traverse(&propagator);
199+
}
200+
194201
path.traverse(&inline_propagator);
195202
path.traverse(&updater);
196203
}
@@ -219,13 +226,19 @@ void ComputeAt::runWith(
219226
auto selected = getPropagationSubgraph(producer, consumer);
220227
InlinePropagatorSelector selector(selected);
221228

222-
TransformPropagator propagator(producer, producer_position);
223229
InlinePropagator inline_propagator(
224230
selector.selected(), producer, producer_position, mode);
225231
MaxProducerPosUpdater updater;
226232

227233
MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);
228-
path.traverse(&propagator);
234+
235+
if (mode == ComputeAtMode::MostInlined) {
236+
MostInlinedTransformPropagator propagator;
237+
path.traverse(&propagator);
238+
} else {
239+
TransformPropagator propagator(producer, producer_position);
240+
path.traverse(&propagator);
241+
}
229242
path.traverse(&inline_propagator);
230243
path.traverse(&updater);
231244
}

torch/csrc/jit/codegen/cuda/inline_propagator.cpp

Lines changed: 55 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -178,22 +178,11 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
178178
return max_pos;
179179
}
180180

181-
size_t InlinePropagator::adjustComputeAtPos(TensorView* tv, size_t pos) {
182-
pos = std::min<size_t>(pos, getMaxPosAll(tv));
183-
184-
// hoist inner most broadcast
185-
while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
186-
pos--;
187-
}
188-
189-
return pos;
190-
}
191-
192-
size_t InlinePropagator::getReplayPosPasC(
181+
size_t InlinePropagator::getFromPosPasC(
193182
TensorView* producer,
194183
TensorView* consumer) {
195184
size_t max_pos = max_pos_calc.getMaxPosPasC(producer, consumer);
196-
size_t pos = retrieveReplayedPos(consumer);
185+
size_t pos = mapped_reference_pos_.at(consumer);
197186

198187
if (mode_ == ComputeAtMode::BestEffort) {
199188
return std::min(pos, max_pos);
@@ -203,22 +192,22 @@ size_t InlinePropagator::getReplayPosPasC(
203192

204193
TORCH_INTERNAL_ASSERT(
205194
pos <= max_pos,
206-
"Invalid compute at position detected in compute at when trying to replay producer: ",
207-
producer,
208-
" as consumer: ",
195+
"Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: ",
209196
consumer,
197+
" to producer: ",
198+
producer,
210199
" tried to do this at position: ",
211200
pos,
212201
" but max position that's allowed is ",
213202
max_pos);
214203
return pos;
215204
}
216205

217-
size_t InlinePropagator::getReplayPosCasP(
206+
size_t InlinePropagator::getFromPosCasP(
218207
TensorView* consumer,
219208
TensorView* producer) {
220209
size_t max_pos = max_pos_calc.getMaxPosCasP(consumer, producer);
221-
size_t pos = retrieveReplayedPos(producer);
210+
size_t pos = mapped_reference_pos_.at(producer);
222211

223212
if (mode_ == ComputeAtMode::BestEffort) {
224213
return std::min(pos, max_pos);
@@ -228,42 +217,28 @@ size_t InlinePropagator::getReplayPosCasP(
228217

229218
TORCH_INTERNAL_ASSERT(
230219
pos <= max_pos,
231-
"Invalid compute at position detected in compute at when trying to replay consumer: ",
232-
consumer,
233-
" as producer: ",
220+
"Invalid compute at position detected in compute at when trying to propagate the CA position from producer: ",
234221
producer,
222+
" to consumer: ",
223+
consumer,
235224
" tried to do this at position: ",
236225
pos,
237226
" but max position that's allowed is ",
238227
max_pos);
239228
return pos;
240229
}
241230

242-
void InlinePropagator::recordReplayedPos(TensorView* tv, size_t pos) {
243-
if (selected_.count(tv)) {
244-
auto new_pos = adjustComputeAtPos(tv, pos);
245-
if (pos != new_pos) {
246-
replayed_pos_[tv] = pos;
247-
pos = new_pos;
231+
void InlinePropagator::setCAPos(TensorView* tv, size_t pos) {
232+
if (selected_.count(tv) && !tv->isFusionInput()) {
233+
pos = std::min<size_t>(pos, getMaxPosAll(tv));
234+
// hoist inner most broadcast
235+
while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
236+
pos--;
248237
}
249-
if (!tv->isFusionInput()) {
250-
tv->setComputeAt(pos);
251-
} else {
252-
replayed_pos_[tv] = pos;
253-
}
254-
} else {
255-
replayed_pos_[tv] = pos;
238+
tv->setComputeAt(pos);
256239
}
257240
}
258241

259-
size_t InlinePropagator::retrieveReplayedPos(TensorView* tv) {
260-
auto it = replayed_pos_.find(tv);
261-
if (it != replayed_pos_.end()) {
262-
return it->second;
263-
}
264-
return tv->getComputeAtPosition();
265-
}
266-
267242
InlinePropagator::InlinePropagator(
268243
std::unordered_set<TensorView*> selected,
269244
TensorView* reference,
@@ -288,101 +263,62 @@ InlinePropagator::InlinePropagator(
288263
".");
289264
}
290265

291-
namespace {
292-
293-
// Make sure if tv is set to new_td it doesn't violate set compute at and max
294-
// produce at positions.
295-
bool validateDomain(TensorView* tv, TensorDomain* new_td) {
296-
auto first_mismatch =
297-
BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td);
298-
return first_mismatch >= (int)tv->getMaxProducerPosition() &&
299-
first_mismatch >= (int)tv->getComputeAtPosition();
300-
}
301-
302-
} // namespace
303-
304266
void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
305267
if (is_first_) {
306268
is_first_ = false;
307-
recordReplayedPos(reference_, reference_pos_);
269+
setCAPos(reference_, reference_pos_);
270+
mapped_reference_pos_[reference_] = reference_pos_;
308271
}
309-
int pos = getReplayPosPasC(to, from);
272+
int from_pos = getFromPosPasC(to, from);
310273
auto to_pos =
311-
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
312-
if (mode_ != ComputeAtMode::MostInlined) {
313-
TORCH_CHECK(
314-
to_pos >= 0,
315-
"Unable to propagate CA position from consumer ",
316-
from,
317-
" to producer ",
318-
to,
319-
" because this would require replay.");
320-
}
321-
if (to_pos < 0) {
322-
auto replay = TransformReplay::replayPasC(to, from, pos);
323-
TORCH_INTERNAL_ASSERT(
324-
validateDomain(to, replay.first),
325-
"Tried to set the domain of ",
326-
to,
327-
" to ",
328-
replay.first,
329-
" but that would invalidate previously compute at position or max producer position.");
330-
to->setDomain(replay.first);
331-
to_pos = replay.second;
332-
}
333-
recordReplayedPos(to, to_pos);
274+
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
275+
TORCH_CHECK(
276+
to_pos >= 0,
277+
"Unable to propagate CA position from consumer ",
278+
from,
279+
" to producer ",
280+
to,
281+
" because this would require replay.");
282+
setCAPos(to, to_pos);
283+
mapped_reference_pos_[to] = to_pos;
334284
}
335285

336286
void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
337287
if (is_first_) {
338288
is_first_ = false;
339-
recordReplayedPos(reference_, reference_pos_);
289+
setCAPos(reference_, reference_pos_);
290+
mapped_reference_pos_[reference_] = reference_pos_;
340291
}
341-
int pos = getReplayPosCasP(to, from);
292+
int from_pos = getFromPosCasP(to, from);
342293
auto to_pos =
343-
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
344-
if (mode_ != ComputeAtMode::MostInlined) {
345-
TORCH_CHECK(
346-
to_pos >= 0,
347-
"Unable to propagate CA position from producer ",
348-
from,
349-
" to consumer ",
350-
to,
351-
" because this would require replay.");
352-
}
353-
if (to_pos < 0) {
354-
auto replay = TransformReplay::replayCasP(to, from, pos);
355-
TORCH_INTERNAL_ASSERT(
356-
validateDomain(to, replay.first),
357-
"Tried to set the domain of ",
358-
to,
359-
" to ",
360-
replay.first,
361-
" but that would invalidate previously compute at position or max producer position.");
362-
to->setDomain(replay.first);
363-
to_pos = replay.second;
364-
}
365-
recordReplayedPos(to, to_pos);
294+
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
295+
TORCH_CHECK(
296+
to_pos >= 0,
297+
"Unable to propagate CA position from producer ",
298+
from,
299+
" to consumer ",
300+
to,
301+
" because this would require replay.");
302+
setCAPos(to, to_pos);
303+
mapped_reference_pos_[to] = to_pos;
366304
}
367305

368306
void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) {
369307
if (is_first_) {
370308
is_first_ = false;
371-
recordReplayedPos(reference_, reference_pos_);
372-
}
373-
auto from_pos = retrieveReplayedPos(from);
374-
if (!TransformReplay::fullSelfMatching(to, from)) {
375-
auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
376-
TORCH_INTERNAL_ASSERT(
377-
validateDomain(to, replay),
378-
"Tried to set the domain of ",
379-
to,
380-
" to ",
381-
replay,
382-
" but that would invalidate previously compute at position or max producer position.");
383-
to->setDomain(replay);
309+
setCAPos(reference_, reference_pos_);
310+
mapped_reference_pos_[reference_] = reference_pos_;
384311
}
385-
recordReplayedPos(to, from_pos);
312+
auto from_pos = mapped_reference_pos_.at(from);
313+
TORCH_CHECK(
314+
TransformReplay::fullSelfMatching(to, from),
315+
"Unable to propagate CA position from ",
316+
from,
317+
" to sibling ",
318+
to,
319+
" because this would require replay.");
320+
setCAPos(to, from_pos);
321+
mapped_reference_pos_[to] = from_pos;
386322
}
387323

388324
namespace {

torch/csrc/jit/codegen/cuda/inline_propagator.h

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,18 @@ class MaxPosCalculator {
5151
bool allow_unmappable) const;
5252

5353
public:
54-
// Returns the position at which tv can be relayed within.
54+
// Returns the position at which tv can be inlined within.
5555
size_t getMaxPosSelf(
5656
TensorView* tv,
5757
bool allow_reduction,
5858
bool allow_vectorize,
5959
bool allow_unmappable) const;
6060

61-
// Returns the maximum position producer can be replayed based on consumer
61+
// Returns the maximum position producer can be inlined based on consumer
6262
// given the set ComputeAtMode
6363
size_t getMaxPosPasC(TensorView* producer, TensorView* consumer) const;
6464

65-
// Returns the maximum position consumer can be replayed based on producer
65+
// Returns the maximum position consumer can be inlined based on producer
6666
// given the set ComputeAtMode
6767
size_t getMaxPosCasP(TensorView* consumer, TensorView* producer) const;
6868

@@ -74,34 +74,34 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
7474
// that can be shared across both directions.
7575
size_t getMaxPosAll(TensorView* tv);
7676

77-
// Returns position of getMaxPosAll while also hoisting outside broadcast
78-
// dimensions.
79-
size_t adjustComputeAtPos(TensorView* tv, size_t pos);
80-
81-
// Returns the replay position in consumer that producer should be replayed as
77+
// Returns the inline position in consumer that producer should be inlined as
8278
// based on consumer, taking into consideration the max possible returned by
8379
// getMaxPos{PasC, CasP}, the compute at mode type.
84-
size_t getReplayPosPasC(TensorView* producer, TensorView* consumer);
80+
size_t getFromPosPasC(TensorView* producer, TensorView* consumer);
8581

86-
// Returns the replay position in producer that consumer should be replayed as
82+
// Returns the inline position in producer that consumer should be inlined as
8783
// based on producer, taking into consideration the max possible returned by
8884
// getMaxPos{PasC, CasP}, the compute at mode type.
89-
size_t getReplayPosCasP(TensorView* consumer, TensorView* producer);
85+
size_t getFromPosCasP(TensorView* consumer, TensorView* producer);
9086

91-
// Sets the compute at position of tv and records the position in
92-
// replayed_pos_
93-
void recordReplayedPos(TensorView* tv, size_t pos);
87+
// We use mapped_reference_pos_ to keep track of the outer axes information of
88+
// the reference tensor. That is, mapped_reference_pos_[tv] answers the
89+
// question "What outer axes in tv are shared with the specified reference
90+
// tensor's outer axes?". However, when we actually set the CA position of tv,
91+
// we might not want to set it as mapped_reference_pos_[tv] because we
92+
// don't want to inline certain things (such as vectorized dimensions, inner
93+
// most broadcasting, etc.).
94+
std::unordered_map<TensorView*, size_t> mapped_reference_pos_;
9495

95-
// Returns the entry for tv in replayed_pos_ if it exists, else returns the
96-
// compute at position of tv.
97-
size_t retrieveReplayedPos(TensorView* tv);
96+
// Actually set the computeAt position. This does not necessarily equal to
97+
// mapped_reference_pos_[tv] because we don't want to inline certain things.
98+
void setCAPos(TensorView* tv, size_t pos);
9899

99100
const MaxPosCalculator max_pos_calc;
100101
std::unordered_set<TensorView*> selected_;
101102
TensorView* reference_;
102103
size_t reference_pos_;
103104
ComputeAtMode mode_ = ComputeAtMode::Standard;
104-
std::unordered_map<TensorView*, size_t> replayed_pos_;
105105
bool is_first_ = true;
106106

107107
public:

torch/csrc/jit/codegen/cuda/ir_interface_nodes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ enum class ComputeAtMode { Standard, BestEffort, MostInlined };
157157
class InlinePropagator;
158158
class MaxProducerPosUpdater;
159159
class TransformPropagator;
160+
struct MostInlinedTransformPropagator;
160161
class TransformIter;
161162
class TransformReplay;
162163
class OptOutMutator;
@@ -457,6 +458,7 @@ class TORCH_CUDA_CU_API TensorView : public Val {
457458
void applyMmaSwizzle(MmaOptions options);
458459

459460
friend TORCH_CUDA_CU_API TransformPropagator;
461+
friend TORCH_CUDA_CU_API MostInlinedTransformPropagator;
460462
friend TORCH_CUDA_CU_API TransformReplay;
461463
friend TORCH_CUDA_CU_API OptOutMutator;
462464
friend TORCH_CUDA_CU_API InlinePropagator;

0 commit comments

Comments
 (0)