 #include <torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.h>

+#include <torch/csrc/jit/codegen/cuda/arith.h>
 #include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/inlining.h>
 #include <torch/csrc/jit/codegen/cuda/ir_cloner.h>

@@ -493,8 +494,9 @@ TensorView* sortAndRFactor(TensorView* reference_tv) {
   return ir_utils::rfactorHelper(reference_tv, rfactor_axes);
 }

-void projectPersistentBuffers(Fusion* fusion) {
+std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
   auto persistent_info = scheduler_utils::persistentBuffers(fusion);
+  std::vector<TensorView*> dummy_outputs;

   // Convenience accessors
   const auto& persistent_buffers = persistent_info.persistent_buffers;
@@ -562,10 +564,39 @@ void projectPersistentBuffers(Fusion* fusion) {
       for (auto use : persistent_use_of_buffer) {
         TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
         auto buffer_replicate = RecomputeTv::recompute(buffer);
+        // Create a shortcut buffer <--> buffer_replicate for propagation.
+        // Why is this needed?
+        // Consider that we have a fusion
+        //
+        //   T0[I]
+        //   T1[b b I] = broadcast(T0)
+        //   T2[b b r] = reduction(T1)
+        //   T3[b b b] = broadcast(T2)
+        //   T4[b, b, I] = T1 + T3
+        //   T5[b, b, r] = reduction(T4)
+        //
+        // After projection, it becomes
+        //
+        //   T0[I]
+        //   T1[b b I] = broadcast(T0)
+        //   T2[b b r] = reduction(T1)
+        //   T3[b b b] = broadcast(T2)
+        //   T6[b b I] = broadcast(T0)
+        //   T4[b, b, I] = T6 + T3
+        //   T5[b, b, r] = reduction(T4)
+        //
+        // During scheduling, we need to propagate from T2 to T5. However, in
+        // the resulting DAG, neither the propagation path T2->T3->T4->T5 nor
+        // T2->T1->T0->T6->T4->T5 works, because both are missing root domain
+        // information. Adding `T7 = T1 + T6` creates a new propagation path
+        // `T2->T1->T7->T6->T4->T5` which has all root domain information.
+        // See FusionBroadcastPersistentReduction_CUDA for an example.
+        dummy_outputs.emplace_back(add(buffer_replicate, buffer));
         ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
       }
     }
   }
+  return dummy_outputs;
 }

 } // namespace reduction_scheduler_utils
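
For context on how the new return value is meant to be consumed, here is a minimal caller-side sketch. The helper name and the surrounding scheduler code are illustrative assumptions, not part of this diff; only projectPersistentBuffers, Fusion::addOutput, and Fusion::removeOutput come from the existing codebase. The idea is that the returned shortcut tensors are temporarily registered as fusion outputs so transform propagation can traverse the `T2->T1->T7->T6->T4->T5` path described in the comment, and are dropped again once scheduling is done.

// Hypothetical caller, for illustration only (not part of this change).
void scheduleWithProjectedBuffers(Fusion* fusion) {
  // Project persistent buffers back to the fusion inputs and collect the
  // shortcut tensors (buffer + buffer_replicate) created above.
  auto dummy_outputs =
      reduction_scheduler_utils::projectPersistentBuffers(fusion);

  // Temporarily register the shortcuts as outputs so that transform
  // propagation can see the extra path between the original buffer and its
  // recomputed replicate.
  for (auto output : dummy_outputs) {
    fusion->addOutput(output);
  }

  // ... run the usual reduction scheduling and propagation here ...

  // The shortcuts are only propagation aids; remove them so they do not
  // remain as real kernel outputs.
  for (auto output : dummy_outputs) {
    fusion->removeOutput(output);
  }
}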