
Commit cd899f6

jjsjann123 and naoyam authored
persistent_use_of_buffer is accumulated over all the resolution points. (#4)
Cherry-picking from: csarofeen/pytorch#2576

Author: Naoya Maruyama [email protected]
Date: Mon Mar 13 17:50:01 2023 -0700

persistent_use_of_buffer is accumulated over all the resolution points. (#2576)

Recomputation for each persistent use should be done after the accumulation is done. Currently, recomputation and replaceVal can be done redundantly. For example, on A100, that happens with NvFuserScheduler_BatchNorm_fp32/64/32/256.

Co-authored-by: Naoya Maruyama <[email protected]>
1 parent 48b0cb4 commit cd899f6
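
The change only reorders control flow. Below is a minimal before/after sketch of that restructuring, using identifiers that appear in the diff; `resolution_points`, `isUseOfBufferForResolution`, and the omitted outer loop over persistent buffers are simplified placeholders, not the actual scheduler code.

// Sketch only: resolution_points and isUseOfBufferForResolution are
// hypothetical stand-ins for the real bookkeeping in the scheduler.

// Before: the recompute/replace loop ran inside the resolution-point loop,
// so RecomputeTv::recompute() and ir_utils::replaceValInExpr() could be
// applied redundantly to uses accumulated at earlier resolution points.
for (auto resolution_point : resolution_points) {
  for (auto use : buffer->uses()) {
    if (!isUseOfBufferForResolution(use, resolution_point)) {
      continue;
    }
    persistent_use_of_buffer.emplace_back(use);
  }

  for (auto use : persistent_use_of_buffer) { // nested one level too deep
    auto buffer_replicate = RecomputeTv::recompute(buffer);
    dummy_outputs.emplace_back(add(buffer_replicate, buffer));
    ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
  }
}

// After: accumulate persistent_use_of_buffer over all resolution points
// first, then recompute and replace exactly once per accumulated use.
for (auto resolution_point : resolution_points) {
  for (auto use : buffer->uses()) {
    if (!isUseOfBufferForResolution(use, resolution_point)) {
      continue;
    }
    persistent_use_of_buffer.emplace_back(use);
  }
}

for (auto use : persistent_use_of_buffer) {
  auto buffer_replicate = RecomputeTv::recompute(buffer);
  dummy_outputs.emplace_back(add(buffer_replicate, buffer));
  ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
}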

File tree: 1 file changed (+35, -35 lines)


csrc/scheduler/reduction_utils.cpp

Lines changed: 35 additions & 35 deletions
@@ -728,42 +728,42 @@ std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
         }
         persistent_use_of_buffer.emplace_back(use);
       }
+    }

-      // For all uses that do not go towards the reduction operations in the
-      // persistent section of the graph, recompute the persistent buffer.
-      for (auto use : persistent_use_of_buffer) {
-        TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
-        auto buffer_replicate = RecomputeTv::recompute(buffer);
-        // Create a shortcut buffer <--> buffer_replicate for propagation.
-        // Why is this needed?
-        // Consider that we have a fusion
-        //
-        // T0[I]
-        // T1[b b I] = broadcast(T0)
-        // T2[b b r] = reduction(T1)
-        // T3[b b b] = broadcast(T2)
-        // T4[b, b, I] = T1 + T3
-        // T5[b, b, r] = reduction(T4)
-        //
-        // After projection, it becomes
-        //
-        // T0[I]
-        // T1[b b I] = broadcast(T0)
-        // T2[b b r] = reduction(T1)
-        // T3[b b b] = broadcast(T2)
-        // T6[b b I] = broadcast(T0)
-        // T4[b, b, I] = T6 + T3
-        // T5[b, b, r] = reduction(T4)
-        //
-        // During schedule, we need to propagate from T2 to T5. However, in the
-        // resulting DAG, neither the propagation path T2->T3->T4->T5 nor
-        // T2->T1->T0->T6->T4->T5 works because they both have missing root
-        // domain. But adding `T7 = T1 + T6` creates a new propagation path
-        // `T2->T1->T7->T6->T4->T5` which has all root domain information.
-        // See FusionBroadcastPersistentReduction_CUDA for an example
-        dummy_outputs.emplace_back(add(buffer_replicate, buffer));
-        ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
-      }
+    // For all uses that do not go towards the reduction operations in the
+    // persistent section of the graph, recompute the persistent buffer.
+    for (auto use : persistent_use_of_buffer) {
+      TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
+      auto buffer_replicate = RecomputeTv::recompute(buffer);
+      // Create a shortcut buffer <--> buffer_replicate for propagation.
+      // Why is this needed?
+      // Consider that we have a fusion
+      //
+      // T0[I]
+      // T1[b b I] = broadcast(T0)
+      // T2[b b r] = reduction(T1)
+      // T3[b b b] = broadcast(T2)
+      // T4[b, b, I] = T1 + T3
+      // T5[b, b, r] = reduction(T4)
+      //
+      // After projection, it becomes
+      //
+      // T0[I]
+      // T1[b b I] = broadcast(T0)
+      // T2[b b r] = reduction(T1)
+      // T3[b b b] = broadcast(T2)
+      // T6[b b I] = broadcast(T0)
+      // T4[b, b, I] = T6 + T3
+      // T5[b, b, r] = reduction(T4)
+      //
+      // During schedule, we need to propagate from T2 to T5. However, in the
+      // resulting DAG, neither the propagation path T2->T3->T4->T5 nor
+      // T2->T1->T0->T6->T4->T5 works because they both have missing root
+      // domain. But adding `T7 = T1 + T6` creates a new propagation path
+      // `T2->T1->T7->T6->T4->T5` which has all root domain information.
+      // See FusionBroadcastPersistentReduction_CUDA for an example
+      dummy_outputs.emplace_back(add(buffer_replicate, buffer));
+      ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
+    }
     }
   }
   return dummy_outputs;
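
For context on the comment block in the diff above, here is a hedged sketch of the fusion it describes (T0 through T5), written against the nvFuser C++ arithmetic API. The header paths and the builder usage are assumptions, not an excerpt of FusionBroadcastPersistentReduction_CUDA.

// Sketch, not test code: header paths assume the standalone nvFuser
// layout with csrc/ on the include path.
#include <fusion.h>
#include <ops/all_ops.h>

using namespace nvfuser;

void defineBroadcastPersistentReduction(Fusion* fusion) {
  FusionGuard fg(fusion);

  // T0[I]
  TensorView* tv0 = TensorViewBuilder().ndims(1).build();
  fusion->addInput(tv0);

  // T1[b b I] = broadcast(T0): T1 is the persistent buffer.
  auto tv1 = broadcast(tv0, {true, true, false});
  // T2[b b r] = reduction(T1)
  auto tv2 = sum(tv1, {2});
  // T3[b b b] = broadcast(T2)
  auto tv3 = broadcast(tv2, {false, false, true});
  // T4[b, b, I] = T1 + T3: this use of T1 is what projection replaces
  // with a recomputed copy (T6 = broadcast(T0) in the comment above).
  auto tv4 = add(tv1, tv3);
  // T5[b, b, r] = reduction(T4)
  auto tv5 = sum(tv4, {2});
  fusion->addOutput(tv5);
}

After projectPersistentBuffers runs on such a fusion, the dummy output produced by add(buffer_replicate, buffer) corresponds to the `T7 = T1 + T6` shortcut in the comment, which restores a propagation path from T2 to T5 with full root domain information.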
