@@ -309,9 +309,15 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
309
309
int pos = getReplayPosPasC (to, from);
310
310
auto to_pos =
311
311
TransformReplay::getMatchedLeafPosWithoutReplayPasC (to, from, pos);
312
- // TODO: Can we make TransformPropagator do the transformation, and
313
- // InlinePropagator only set the CA positions?
314
- // TORCH_CHECK(to_pos >= 0);
312
+ if (mode_ != ComputeAtMode::MostInlined) {
313
+ TORCH_CHECK (
314
+ to_pos >= 0 ,
315
+ " Unable to propagate CA position from consumer " ,
316
+ from,
317
+ " to producer " ,
318
+ to,
319
+ " because this would require replay." );
320
+ }
315
321
if (to_pos < 0 ) {
316
322
auto replay = TransformReplay::replayPasC (to, from, pos);
317
323
TORCH_INTERNAL_ASSERT (
@@ -335,9 +341,15 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
335
341
int pos = getReplayPosCasP (to, from);
336
342
auto to_pos =
337
343
TransformReplay::getMatchedLeafPosWithoutReplayCasP (to, from, pos);
338
- // TODO: Can we make TransformPropagator do the transformation, and
339
- // InlinePropagator only set the CA positions?
340
- // TORCH_CHECK(to_pos >= 0);
344
+ if (mode_ != ComputeAtMode::MostInlined) {
345
+ TORCH_CHECK (
346
+ to_pos >= 0 ,
347
+ " Unable to propagate CA position from producer " ,
348
+ from,
349
+ " to consumer " ,
350
+ to,
351
+ " because this would require replay." );
352
+ }
341
353
if (to_pos < 0 ) {
342
354
auto replay = TransformReplay::replayCasP (to, from, pos);
343
355
TORCH_INTERNAL_ASSERT (
@@ -373,23 +385,71 @@ void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) {
373
385
recordReplayedPos (to, from_pos);
374
386
}
375
387
388
+ namespace {
389
+
376
390
// Try to find the aligned position on consumer's domain corresponding to the
377
- // compute at position of producer domain.
378
- void MaxProducerPosUpdater::handle (TensorView* consumer) {
391
+ // compute at position of producer domain. Used in computeAt pass only. No
392
+ // checking on actual producer-consumer relationship.
393
+ unsigned int getConsumerPosAlignedToProducerCA (
394
+ TensorView* consumer,
395
+ TensorView* producer) {
396
+ // Locate consumer's position that aligns with
397
+ // the producer's new compute at axis. We need broadcast axes forwarded so we
398
+ // need to replay PasC as CasP will not forward broadcast dims. For example
399
+ // if we have:
400
+ // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
401
+ // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
402
+ // have the mapping iS22{( 3 * 1 )} <- iS1{3}. We need the latter. Refer to
403
+ // NVFuserTest.FusionComplexBCast1_CUDA
404
+
405
+ auto c2p_map =
406
+ BestEffortReplay::replayPasC (
407
+ producer,
408
+ consumer,
409
+ -1 ,
410
+ // Compute at root domain may not be valid here, as all
411
+ // producers don't have to be able to map into consumer at
412
+ // max producer position. Since computeAt should be valid
413
+ // and this mechanism is only intended to lower produce
414
+ // position of consumer, we can simply use the pairwise map.
415
+ PairwiseRootDomainMap (producer, consumer))
416
+ .getReplay ();
417
+
418
+ // Find the innermost position of consumer that has
419
+ // been mapped within the producer ca axis.
379
420
unsigned int consumer_pos = consumer->nDims ();
380
421
while (consumer_pos > 0 ) {
381
- for (auto producer : ir_utils::producerTvsOf (consumer)) {
382
- auto producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC (
383
- producer, consumer, consumer_pos);
384
- if (producer_pos >= 0 &&
385
- producer_pos <= producer->getComputeAtPosition ()) {
386
- goto finished;
387
- }
422
+ auto consumer_id = consumer->axis ((int )consumer_pos - 1 );
423
+ auto p_dom = producer->domain ()->domain ();
424
+ if (std::any_of (
425
+ p_dom.begin (),
426
+ p_dom.begin () + producer->getComputeAtPosition (),
427
+ [&consumer_id, &c2p_map](IterDomain* p_id) {
428
+ auto c_id_it = c2p_map.find (consumer_id);
429
+ if (c_id_it != c2p_map.end ()) {
430
+ return c_id_it->second == p_id;
431
+ }
432
+ return false ;
433
+ })) {
434
+ break ;
388
435
}
389
436
consumer_pos--;
390
437
}
391
- finished:
392
- consumer->setMaxProducer (consumer_pos, true );
438
+
439
+ return consumer_pos;
440
+ }
441
+
442
+ } // namespace
443
+
444
+ // Try to find the aligned position on consumer's domain corresponding to the
445
+ // compute at position of producer domain.
446
+ void MaxProducerPosUpdater::handle (TensorView* consumer) {
447
+ unsigned int consumer_pos = 0 ;
448
+ for (auto producer : ir_utils::producerTvsOf (consumer)) {
449
+ consumer_pos = std::max (
450
+ consumer_pos, getConsumerPosAlignedToProducerCA (consumer, producer));
451
+ }
452
+ consumer->setMaxProducer (consumer_pos);
393
453
}
394
454
395
455
void MaxProducerPosUpdater::propagateTvPasC (TensorView* from, TensorView* to) {
0 commit comments