
Commit 19e5af7

Add short cut for recomputed tv (#2134)

1 parent f725cfb · commit 19e5af7

File tree: 6 files changed, +81 −5 lines

torch/csrc/jit/codegen/cuda/ir_iostream.cpp

Lines changed: 4 additions & 0 deletions
@@ -161,6 +161,8 @@ void IrPrinter::handle(const TensorView* tv) {
     case MemoryType::Local:
       os_ << "_l";
       break;
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Unknown tensor memory type.");
   }
   handle(tv->domain());

@@ -704,6 +706,8 @@ void IrPrinter::handle(const kir::TensorIndex* ti) {
     case MemoryType::Local:
       os_ << "_l";
       break;
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Unknown tensor memory type.");
   }
   os_ << "[";
   for (auto index : ti->indices()) {
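Both hunks apply the same defensive pattern: a switch over the memory-type enum gets a default case that fails loudly, so a newly added enumerator cannot be printed silently with no suffix. A minimal standalone sketch of the pattern, using a hypothetical MemoryKind enum and a plain assert in place of TORCH_INTERNAL_ASSERT:

  #include <cassert>
  #include <ostream>

  // Hypothetical enum standing in for MemoryType; not code from the diff.
  enum class MemoryKind { Global, Shared, Local };

  void printSuffix(std::ostream& os, MemoryKind kind) {
    switch (kind) {
      case MemoryKind::Global:
        os << "_g";
        break;
      case MemoryKind::Shared:
        os << "_s";
        break;
      case MemoryKind::Local:
        os << "_l";
        break;
      default:
        // Fail loudly on an enumerator this printer does not know about.
        assert(false && "Unknown tensor memory type.");
    }
  }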

torch/csrc/jit/codegen/cuda/kernel.cpp

Lines changed: 2 additions & 0 deletions
@@ -77,6 +77,8 @@ class KernelIrScanner : private IrVisitor {
         summary_.dynamic_lmem_allocations.emplace_back(allocate);
       }
       break;
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Unknown memory type to allocate.");
   }
 }

torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp

Lines changed: 10 additions & 1 deletion
@@ -974,11 +974,14 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(

   // Project the persistent buffers to the inputs. Inputs will be cached in a
   // later step, this will move them to be in a register buffer as expected.
+  // Dummy outputs are helper tensors that make sure persistent buffer
+  // projection does not interfere with transform propagation.
   // TODO: Fix projected persistent buffers with view
   // https://github.com/csarofeen/pytorch/issues/2054
+  std::vector<TensorView*> dummy_outputs;
   if (rparams.project_persistent_buffers &&
       ir_utils::getViewOps(fusion).empty()) {
-    reduction_scheduler_utils::projectPersistentBuffers(fusion);
+    dummy_outputs = reduction_scheduler_utils::projectPersistentBuffers(fusion);
   }

   // Cache tensors before grabbing any references to reductions as cache_before
@@ -1043,6 +1046,9 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(
       reference_tv != nullptr && reduction_tv != nullptr,
       "Need these two tensor views to finish the scheduling.");

+  for (auto output : dummy_outputs) {
+    fusion->addOutput(output);
+  }
   reduction_scheduler_utils::multiReductionInliner(
       fusion,
       rparams,
@@ -1051,6 +1057,9 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(
       reduction_tvs,
       cached_inputs,
       cached_outputs);
+  for (auto output : dummy_outputs) {
+    fusion->removeOutput(output);
+  }
 }

 } // namespace cuda
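Taken together, the three hunks give the dummy outputs a simple create/attach/detach lifecycle. A condensed sketch of that flow, assuming only the calls visible in this diff (the scheduling body is elided, and the wrapper function name is invented for illustration):

  // Condensed sketch of the lifecycle above; not a literal excerpt.
  void scheduleWithDummyOutputs(Fusion* fusion) {
    // Projection now returns the shortcut tensors it created.
    std::vector<TensorView*> dummy_outputs =
        reduction_scheduler_utils::projectPersistentBuffers(fusion);

    // Registering the shortcuts as outputs keeps them alive and visible
    // while transforms are propagated and the kernel is inlined.
    for (auto output : dummy_outputs) {
      fusion->addOutput(output);
    }

    // ... scheduling and multiReductionInliner run here ...

    // The shortcuts are scaffolding only; detach them so they never show
    // up as real kernel outputs.
    for (auto output : dummy_outputs) {
      fusion->removeOutput(output);
    }
  }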

torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp

Lines changed: 32 additions & 1 deletion
@@ -1,5 +1,6 @@
 #include <torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.h>

+#include <torch/csrc/jit/codegen/cuda/arith.h>
 #include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
 #include <torch/csrc/jit/codegen/cuda/inlining.h>
 #include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
@@ -493,8 +494,9 @@ TensorView* sortAndRFactor(TensorView* reference_tv) {
   return ir_utils::rfactorHelper(reference_tv, rfactor_axes);
 }

-void projectPersistentBuffers(Fusion* fusion) {
+std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
   auto persistent_info = scheduler_utils::persistentBuffers(fusion);
+  std::vector<TensorView*> dummy_outputs;

   // Convenience accessors
   const auto& persistent_buffers = persistent_info.persistent_buffers;
@@ -562,10 +564,39 @@ void projectPersistentBuffers(Fusion* fusion) {
     for (auto use : persistent_use_of_buffer) {
       TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
       auto buffer_replicate = RecomputeTv::recompute(buffer);
+      // Create a shortcut buffer <--> buffer_replicate for propagation.
+      // Why is this needed?
+      // Consider that we have a fusion
+      //
+      //   T0[I]
+      //   T1[b b I] = broadcast(T0)
+      //   T2[b b r] = reduction(T1)
+      //   T3[b b b] = broadcast(T2)
+      //   T4[b b I] = T1 + T3
+      //   T5[b b r] = reduction(T4)
+      //
+      // After projection, it becomes
+      //
+      //   T0[I]
+      //   T1[b b I] = broadcast(T0)
+      //   T2[b b r] = reduction(T1)
+      //   T3[b b b] = broadcast(T2)
+      //   T6[b b I] = broadcast(T0)
+      //   T4[b b I] = T6 + T3
+      //   T5[b b r] = reduction(T4)
+      //
+      // During scheduling, we need to propagate from T2 to T5. However, in
+      // the resulting DAG, neither the propagation path T2->T3->T4->T5 nor
+      // T2->T1->T0->T6->T4->T5 works, because both have missing root
+      // domains. But adding `T7 = T1 + T6` creates a new propagation path
+      // `T2->T1->T7->T6->T4->T5` which has all root domain information.
+      // See FusionBroadcastPersistentReduction_CUDA for an example.
+      dummy_outputs.emplace_back(add(buffer_replicate, buffer));
       ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
     }
   }
+  return dummy_outputs;
 }

 } // namespace reduction_scheduler_utils
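To make the comment's example concrete, here is the post-projection DAG spelled out with the same fusion-building helpers the new test below uses. This is a hedged sketch, not code from the diff: T6 stands for the recomputed copy of T1, and T7 for the shortcut that becomes a dummy output.

  // Sketch of the comment's post-projection example (hypothetical shapes).
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto T0 = makeContigTensor(1);                  // T0[I]
  auto T1 = broadcast(T0, {true, true, false});   // T1[b b I]
  auto T2 = sum(T1, {-1});                        // T2[b b r]
  auto T3 = broadcast(T2, {false, false, true});  // T3[b b b]
  auto T6 = broadcast(T0, {true, true, false});   // T6[b b I], recomputed T1
  auto T4 = add(T6, T3);                          // T4[b b I], now consumes T6
  auto T5 = sum(T4, {-1});                        // T5[b b r]
  auto T7 = add(T1, T6);                          // shortcut bridging T1 and T6
  fusion.addInput(T0);
  fusion.addOutput(T5);
  fusion.addOutput(T7);  // dummy output; removed again after scheduling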

torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.h

Lines changed: 5 additions & 2 deletions
@@ -43,8 +43,11 @@ TORCH_CUDA_CU_API void multiReductionInliner(
 // Reduction inliner expects an rfactored domain.
 TORCH_CUDA_CU_API TensorView* sortAndRFactor(TensorView* reference_tv);

-// Take all projectable persistent buffers, and move them to the inputs.
-TORCH_CUDA_CU_API void projectPersistentBuffers(Fusion* fusion);
+// Take all projectable persistent buffers, and move them to the inputs. This
+// function creates dummy outputs which should be used in later stages of
+// scheduling.
+TORCH_CUDA_CU_API std::vector<TensorView*> projectPersistentBuffers(
+    Fusion* fusion);

 } // namespace reduction_scheduler_utils
 } // namespace cuda

torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp

Lines changed: 28 additions & 1 deletion
@@ -6408,10 +6408,37 @@ TEST_F(NVFuserTest, FusionVectorizeRepro1843_CUDA) {
   testValidate(fusion, cg_outputs, {t1, t0}, {ref}, __LINE__, __FILE__);
 }

+TEST_F(NVFuserTest, FusionBroadcastPersistentReduction_CUDA) {
+  // Simplified repro for
+  // https://github.com/csarofeen/pytorch/issues/2094
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  auto tv0 = makeContigTensor(2, DataType::Half);
+  auto tv1 = castOp(DataType::Float, tv0);
+  auto tv2 = broadcast(tv1, {true, true, false, false});
+  auto tv3 = sum(tv2, {-1}, true);
+  auto tv4 = add(tv2, tv3); // TODO: changing this to tv1 still produces errors
+  auto tv5 = sum(tv4, {-1});
+  fusion->addInput(tv0);
+  fusion->addOutput(tv5);
+
+  auto options = at::TensorOptions().dtype(kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({1024, 768}, options);
+  auto t1 = t0.view({1, 1, 1024, 768}).to(kFloat);
+  auto t3 = t1.sum({-1}, true);
+  auto t4 = t1 + t3;
+  auto t5 = t4.sum({-1});
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+  auto cg_outputs = fec.runFusionWithInputs({t0});
+  testValidate(fusion, cg_outputs, {t0}, {t5}, __LINE__, __FILE__);
+}
+
 // Repro for
 // https://github.com/csarofeen/pytorch/issues/2094
 TEST_F(NVFuserTest, FusionRepro2094_CUDA) {
-  return;
   std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
   auto fusion = fusion_ptr.get();
   FusionGuard fg(fusion);
