Skip to content

Commit d0d0908

Browse files
authored
Some further cleanup for the new computeAt interface (#1793)
Revert MaxProducerPosUpdater to old algo.
1 parent 45f5203 commit d0d0908

File tree

2 files changed

+86
-18
lines changed

2 files changed

+86
-18
lines changed

torch/csrc/jit/codegen/cuda/compute_at.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ void ComputeAt::runAt(
165165
TensorView* consumer,
166166
unsigned int consumer_position,
167167
ComputeAtMode mode) {
168-
FUSER_PERF_SCOPE("ComputeAt::run");
168+
FUSER_PERF_SCOPE("ComputeAt::runAt");
169169

170170
// Make sure the correct fusion is setup between this and consumer.
171171
TORCH_CHECK(
@@ -175,6 +175,10 @@ void ComputeAt::runAt(
175175
consumer,
176176
" are not in the same fusion.");
177177

178+
if (mode == ComputeAtMode::MostInlined) {
179+
consumer_position = consumer->nDims();
180+
}
181+
178182
FusionGuard fg(producer->fusion());
179183

180184
auto selected = getPropagationSubgraph(producer, consumer);
@@ -206,6 +210,10 @@ void ComputeAt::runWith(
206210
consumer,
207211
" are not in the same fusion.");
208212

213+
if (mode == ComputeAtMode::MostInlined) {
214+
producer_position = producer->nDims();
215+
}
216+
209217
FusionGuard fg(producer->fusion());
210218

211219
auto selected = getPropagationSubgraph(producer, consumer);

torch/csrc/jit/codegen/cuda/inline_propagator.cpp

Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,15 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
309309
int pos = getReplayPosPasC(to, from);
310310
auto to_pos =
311311
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
312-
// TODO: Can we make TransformPropagator do the transformation, and
313-
// InlinePropagator only set the CA positions?
314-
// TORCH_CHECK(to_pos >= 0);
312+
if (mode_ != ComputeAtMode::MostInlined) {
313+
TORCH_CHECK(
314+
to_pos >= 0,
315+
"Unable to propagate CA position from consumer ",
316+
from,
317+
" to producer ",
318+
to,
319+
" because this would require replay.");
320+
}
315321
if (to_pos < 0) {
316322
auto replay = TransformReplay::replayPasC(to, from, pos);
317323
TORCH_INTERNAL_ASSERT(
@@ -335,9 +341,15 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
335341
int pos = getReplayPosCasP(to, from);
336342
auto to_pos =
337343
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
338-
// TODO: Can we make TransformPropagator do the transformation, and
339-
// InlinePropagator only set the CA positions?
340-
// TORCH_CHECK(to_pos >= 0);
344+
if (mode_ != ComputeAtMode::MostInlined) {
345+
TORCH_CHECK(
346+
to_pos >= 0,
347+
"Unable to propagate CA position from producer ",
348+
from,
349+
" to consumer ",
350+
to,
351+
" because this would require replay.");
352+
}
341353
if (to_pos < 0) {
342354
auto replay = TransformReplay::replayCasP(to, from, pos);
343355
TORCH_INTERNAL_ASSERT(
@@ -373,23 +385,71 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) {
373385
recordReplayedPos(to, from_pos);
374386
}
375387

388+
namespace {
389+
376390
// Try to find the aligned position on consumer's domain corresponding to the
377-
// compute at position of producer domain.
378-
void MaxProducerPosUpdater::handle(TensorView* consumer) {
391+
// compute at position of producer domain. Used in computeAt pass only. No
392+
// checking on actual producer-consumer relationship.
393+
unsigned int getConsumerPosAlignedToProducerCA(
394+
TensorView* consumer,
395+
TensorView* producer) {
396+
// Locate consumer's position that aligns with
397+
// the producer's new compute at axis. We need broadcast axes forwarded so we
398+
// need to replay PasC as CasP will not forward braodcast dims. For example
399+
// if we have:
400+
// T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
401+
// produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
402+
// have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
403+
// NVFuserTest.FusionComplexBCast1_CUDA
404+
405+
auto c2p_map =
406+
BestEffortReplay::replayPasC(
407+
producer,
408+
consumer,
409+
-1,
410+
// Compute at root domain may not be valid here, as all
411+
// producers don't have to be able to map into consumer at
412+
// max producer position. Since computeAt should be valid
413+
// and this mechanism is only intended to lower produce
414+
// position of consumer, we can simply use the pairwise map.
415+
PairwiseRootDomainMap(producer, consumer))
416+
.getReplay();
417+
418+
// Find the innermost position of consumer that has
419+
// been mapped within the producer ca axis.
379420
unsigned int consumer_pos = consumer->nDims();
380421
while (consumer_pos > 0) {
381-
for (auto producer : ir_utils::producerTvsOf(consumer)) {
382-
auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
383-
producer, consumer, consumer_pos);
384-
if (producer_pos >= 0 &&
385-
producer_pos <= producer->getComputeAtPosition()) {
386-
goto finished;
387-
}
422+
auto consumer_id = consumer->axis((int)consumer_pos - 1);
423+
auto p_dom = producer->domain()->domain();
424+
if (std::any_of(
425+
p_dom.begin(),
426+
p_dom.begin() + producer->getComputeAtPosition(),
427+
[&consumer_id, &c2p_map](IterDomain* p_id) {
428+
auto c_id_it = c2p_map.find(consumer_id);
429+
if (c_id_it != c2p_map.end()) {
430+
return c_id_it->second == p_id;
431+
}
432+
return false;
433+
})) {
434+
break;
388435
}
389436
consumer_pos--;
390437
}
391-
finished:
392-
consumer->setMaxProducer(consumer_pos, true);
438+
439+
return consumer_pos;
440+
}
441+
442+
} // namespace
443+
444+
// Try to find the aligned position on consumer's domain corresponding to the
445+
// compute at position of producer domain.
446+
void MaxProducerPosUpdater::handle(TensorView* consumer) {
447+
unsigned int consumer_pos = 0;
448+
for (auto producer : ir_utils::producerTvsOf(consumer)) {
449+
consumer_pos = std::max(
450+
consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer));
451+
}
452+
consumer->setMaxProducer(consumer_pos);
393453
}
394454

395455
void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) {

0 commit comments

Comments
 (0)