Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 35 additions & 35 deletions third_party/nvfuser/csrc/scheduler/reduction_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -721,42 +721,42 @@ std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
}
persistent_use_of_buffer.emplace_back(use);
}
}

// For all uses that do not go towards the reduction operations in the
// persistent section of the graph, recompute the persistent buffer.
for (auto use : persistent_use_of_buffer) {
TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
auto buffer_replicate = RecomputeTv::recompute(buffer);
// Create a shortcut buffer <--> buffer_replicate for propagation.
// Why is this needed?
// Consider that we have a fusion
//
// T0[I]
// T1[b b I] = broadcast(T0)
// T2[b b r] = reduction(T1)
// T3[b b b] = broadcast(T2)
// T4[b, b, I] = T1 + T3
// T5[b, b, r] = reduction(T4)
//
// After projection, it becomes
//
// T0[I]
// T1[b b I] = broadcast(T0)
// T2[b b r] = reduction(T1)
// T3[b b b] = broadcast(T2)
// T6[b b I] = broadcast(T0)
// T4[b, b, I] = T6 + T3
// T5[b, b, r] = reduction(T4)
//
// During scheduling, we need to propagate from T2 to T5. However, in the
// resulting DAG, neither the propagation path T2->T3->T4->T5 nor
// T2->T1->T0->T6->T4->T5 works because both are missing root domain
// information. But adding `T7 = T1 + T6` creates a new propagation path
// `T2->T1->T7->T6->T4->T5` which has all root domain information.
// See FusionBroadcastPersistentReduction_CUDA for an example.
dummy_outputs.emplace_back(add(buffer_replicate, buffer));
ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
}
// For all uses that do not go towards the reduction operations in the
// persistent section of the graph, recompute the persistent buffer.
for (auto use : persistent_use_of_buffer) {
TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
auto buffer_replicate = RecomputeTv::recompute(buffer);
// Create a shortcut buffer <--> buffer_replicate for propagation.
// Why is this needed?
// Consider that we have a fusion
//
// T0[I]
// T1[b b I] = broadcast(T0)
// T2[b b r] = reduction(T1)
// T3[b b b] = broadcast(T2)
// T4[b, b, I] = T1 + T3
// T5[b, b, r] = reduction(T4)
//
// After projection, it becomes
//
// T0[I]
// T1[b b I] = broadcast(T0)
// T2[b b r] = reduction(T1)
// T3[b b b] = broadcast(T2)
// T6[b b I] = broadcast(T0)
// T4[b, b, I] = T6 + T3
// T5[b, b, r] = reduction(T4)
//
// During scheduling, we need to propagate from T2 to T5. However, in the
// resulting DAG, neither the propagation path T2->T3->T4->T5 nor
// T2->T1->T0->T6->T4->T5 works because both are missing root domain
// information. But adding `T7 = T1 + T6` creates a new propagation path
// `T2->T1->T7->T6->T4->T5` which has all root domain information.
// See FusionBroadcastPersistentReduction_CUDA for an example.
dummy_outputs.emplace_back(add(buffer_replicate, buffer));
ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
}
}
return dummy_outputs;
Expand Down