Commit 3b87896

Fix allocation of work buffers and fused_reduction::ParallelReduce with unswitch (csarofeen#1818)
* Make sure unique work buffers are used even if expressions are placed in two paths due to unswitch
* Avoid duplicated allocations of fused_reduction
1 parent 4cae122 commit 3b87896
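
In short, the commit routes every grid-communication buffer allocation through a memoizing helper (allocateUniqueBuffer in lower_index.cpp below) that is keyed by the output TensorView and hoisted to the top-level scope, so the two copies of an expression produced by an unswitch share one work buffer, one sync buffer, and one fused_reduction object. The following is a minimal, self-contained sketch of that idea only; TensorView and Allocate here are hypothetical stand-ins, not nvFuser's real classes.

#include <cassert>
#include <memory>
#include <unordered_map>

struct TensorView {};  // placeholder for nvFuser's TensorView
struct Allocate {};    // placeholder for kir::Allocate

using AllocMap = std::unordered_map<TensorView*, std::unique_ptr<Allocate>>;

// Mirrors the shape of the new IndexLowering::allocateUniqueBuffer: return the
// cached allocation for this output tensor if one exists, otherwise create it
// and remember it.
Allocate* allocateUnique(TensorView* out_tv, AllocMap& alloc_map) {
  auto it = alloc_map.find(out_tv);
  if (it != alloc_map.end()) {
    return it->second.get();
  }
  return alloc_map.emplace(out_tv, std::make_unique<Allocate>())
      .first->second.get();
}

int main() {
  AllocMap work_buffer_map;
  TensorView reduction_out;  // the reduction's output tensor

  // Unswitch duplicates the reduction into two paths; both lowerings ask for
  // a buffer for the same output tensor and get the same allocation back.
  auto* a = allocateUnique(&reduction_out, work_buffer_map);
  auto* b = allocateUnique(&reduction_out, work_buffer_map);
  assert(a == b);

  // A different output tensor still gets its own buffer.
  TensorView other_out;
  assert(allocateUnique(&other_out, work_buffer_map) != a);
  return 0;
}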

3 files changed, +153 -60 lines

torch/csrc/jit/codegen/cuda/lower_index.cpp

Lines changed: 121 additions & 58 deletions
@@ -382,13 +382,17 @@ void IndexLowering::handleGridReduction(
   const auto buffer_size_info =
       getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

-  const auto reduce_buffer = ir_utils::allocGlobalBufferForGridComm(
-      buffer_size_info.size_of_privatized_buffer, out->dtype(), false);
-
-  const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain, for_loops_, is_persistent),
-      DataType::Int,
-      true);
+  auto work_buffer = allocateUniqueBuffer(
+      buffer_size_info.size_of_privatized_buffer,
+      out_tv->dtype(),
+      false,
+      out_tv,
+      work_buffer_map_);
+
+  auto sync_buffer_size =
+      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
+  auto sync_buffer = allocateUniqueBuffer(
+      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

   const auto entrance_ind = !is_persistent
       ? getEntranceLinIndGridReduce(for_loops_)
@@ -408,7 +412,7 @@ void IndexLowering::handleGridReduction(
       rop->init(),
       out,
       in,
-      reduce_buffer,
+      work_buffer,
       sync_buffer,
       entrance_ind,
       n_entrances,
@@ -423,17 +427,11 @@ void IndexLowering::handleGridReduction(
     grid_reduction->setWritePredicate(rop->writePredicate());
   }

-  pushBack(reduce_buffer);
-  pushBack(sync_buffer);
   pushBack(grid_reduction);
   GpuLower::current()->propagateExprInfo(rop, back());

   if (rop->isAllreduce()) {
-    // When using the fused reduction, allocate the reduction object at
-    // the outer-most scope
-    auto fused_reduction_alloc_reduction =
-        IrBuilder::create<kir::AllocateFusedReduction>(grid_reduction);
-    insertAtTopLevel(fused_reduction_alloc_reduction);
+    allocateUniqueFusedReduction(grid_reduction, out_tv);
   }
 }

@@ -521,22 +519,24 @@ void IndexLowering::handleGridReduction(
   auto work_buf_size_info =
       getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

-  std::vector<kir::Allocate*> reduce_buffers;
+  std::vector<kir::Allocate*> work_buffers;
   std::transform(
       outputs.begin(),
       outputs.end(),
-      std::back_inserter(reduce_buffers),
+      std::back_inserter(work_buffers),
       [&](Val* output) {
-        return ir_utils::allocGlobalBufferForGridComm(
+        return allocateUniqueBuffer(
             work_buf_size_info.size_of_privatized_buffer,
             output->dtype(),
-            false);
+            false,
+            output->as<kir::TensorIndex>()->view(),
+            work_buffer_map_);
       });

-  const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain, for_loops_, is_persistent),
-      DataType::Int,
-      true);
+  auto sync_buffer_size =
+      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
+  auto sync_buffer = allocateUniqueBuffer(
+      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

   const auto entrance_ind = !is_persistent
       ? getEntranceLinIndGridReduce(for_loops_)
@@ -556,7 +556,7 @@ void IndexLowering::handleGridReduction(
       grouped_rop->initVals(),
       outputs,
       inputs,
-      reduce_buffers,
+      work_buffers,
       sync_buffer,
       entrance_ind,
       n_entrances,
@@ -572,17 +572,11 @@ void IndexLowering::handleGridReduction(
     grid_reduction->setWritePredicate(grouped_rop->writePredicate());
   }

-  for (auto reduce_buffer : reduce_buffers) {
-    pushBack(reduce_buffer);
-  }
-  pushBack(sync_buffer);
   pushBack(grid_reduction);
   GpuLower::current()->propagateExprInfo(grouped_rop, back());

   if (grouped_rop->isAllreduce()) {
-    auto fused_reduction_alloc_reduction =
-        IrBuilder::create<kir::AllocateFusedReduction>(grid_reduction);
-    insertAtTopLevel(fused_reduction_alloc_reduction);
+    allocateUniqueFusedReduction(grid_reduction, out_tv);
   }
 }

@@ -672,17 +666,29 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
       getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

   const auto work_buffer_size = buffer_size_info.size_of_privatized_buffer;
-  const auto out_var_buffer = ir_utils::allocGlobalBufferForGridComm(
-      work_buffer_size, indexed_wop->outVar()->dtype(), false);
-  const auto out_avg_buffer = ir_utils::allocGlobalBufferForGridComm(
-      work_buffer_size, indexed_wop->outAvg()->dtype(), false);
-  const auto out_N_buffer = ir_utils::allocGlobalBufferForGridComm(
-      work_buffer_size, indexed_wop->outN()->dtype(), false);
-
-  const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain, for_loops_, is_persistent),
-      DataType::Int,
-      true);
+  auto out_var_buffer = allocateUniqueBuffer(
+      work_buffer_size,
+      indexed_wop->outVar()->dtype(),
+      false,
+      indexed_wop->outVar()->as<kir::TensorIndex>()->view(),
+      work_buffer_map_);
+  auto out_avg_buffer = allocateUniqueBuffer(
+      work_buffer_size,
+      indexed_wop->outAvg()->dtype(),
+      false,
+      indexed_wop->outAvg()->as<kir::TensorIndex>()->view(),
+      work_buffer_map_);
+  auto out_N_buffer = allocateUniqueBuffer(
+      work_buffer_size,
+      indexed_wop->outN()->dtype(),
+      false,
+      indexed_wop->outN()->as<kir::TensorIndex>()->view(),
+      work_buffer_map_);
+
+  auto sync_buffer_size =
+      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
+  auto sync_buffer = allocateUniqueBuffer(
+      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

   const auto entrance_ind = !is_persistent
       ? getEntranceLinIndGridReduce(for_loops_)
@@ -729,19 +735,13 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
     GpuLower::current()->propagateExprInfo(indexed_wop, back());
   }

-  pushBack(out_var_buffer);
-  pushBack(out_avg_buffer);
-  pushBack(out_N_buffer);
-  pushBack(sync_buffer);
   pushBack(grid_welford);
   GpuLower::current()->propagateExprInfo(indexed_wop, back());

   if (indexed_wop->isAllreduce()) {
     // When using the fused reduction, allocate the reduction object at
     // the outer-most scope
-    auto fused_reduction_alloc_reduction =
-        IrBuilder::create<kir::AllocateFusedReduction>(grid_welford);
-    insertAtTopLevel(fused_reduction_alloc_reduction);
+    allocateUniqueFusedReduction(grid_welford, out_tv);
   }
 }

@@ -792,24 +792,24 @@ void IndexLowering::handle(const BroadcastOp* bop) {

   // Grid broadcast
   const auto out_domain = out_tv->domain();
-  const auto broadcast_buffer = ir_utils::allocGlobalBufferForGridComm(
+  const auto work_buffer_size =
       getGridCommWorkBufferSize(out_domain, for_loops_, true)
-          .size_of_privatized_buffer,
-      out->dtype(),
-      false);
+          .size_of_privatized_buffer;
+
+  auto work_buffer = allocateUniqueBuffer(
+      work_buffer_size, out->dtype(), false, out_tv, work_buffer_map_);

-  const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain, for_loops_, true), DataType::Int, true);
+  auto sync_buffer_size = getGridSyncBufferSize(out_domain, for_loops_, true);
+  auto sync_buffer = allocateUniqueBuffer(
+      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

   auto grid_broadcast = IrBuilder::create<kir::GridBroadcast>(
-      indexed_expr, broadcast_buffer, sync_buffer);
+      indexed_expr, work_buffer, sync_buffer);

   if (bop->predicate()) {
     grid_broadcast->setPredicate(bop->predicate());
   }

-  pushBack(broadcast_buffer);
-  pushBack(sync_buffer);
   pushBack(grid_broadcast);
   GpuLower::current()->propagateExprInfo(bop, back());
 }
@@ -840,6 +840,69 @@ void IndexLowering::generate(const std::vector<Expr*>& exprs) {
   }
 }

+kir::Allocate* IndexLowering::allocateUniqueBuffer(
+    Val* buffer_size,
+    DataType dtype,
+    bool zero_init,
+    TensorView* out_tv,
+    std::unordered_map<TensorView*, kir::Allocate*>& alloc_map) {
+  // Return an existing allocation if one exists
+  auto it = alloc_map.find(out_tv);
+  if (it != alloc_map.end()) {
+    return it->second;
+  }
+
+  // No existing allocation found. Create a new one
+  auto new_buffer =
+      ir_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init);
+
+  // Keep track of the allocation
+  alloc_map.emplace(out_tv, new_buffer);
+
+  // A buffer may be used in both of the unswitched paths, so it must be
+  // placed outside of the current scope. Simply placing it at the
+  // top-level scope should work.
+  insertAtTopLevel(new_buffer);
+
+  return new_buffer;
+}
+
+void IndexLowering::allocateUniqueFusedReduction(
+    Expr* expr,
+    TensorView* out_tv) {
+  auto it = fused_reduction_map_.find(out_tv);
+  if (it != fused_reduction_map_.end()) {
+    return;
+  }
+
+  kir::AllocateFusedReduction* fused_reduction_alloc_reduction = nullptr;
+  switch (expr->getExprType().value()) {
+    case ExprType::GridReduction:
+      fused_reduction_alloc_reduction =
+          IrBuilder::create<kir::AllocateFusedReduction>(
+              expr->as<kir::GridReduction>());
+      break;
+    case ExprType::GridWelford:
+      fused_reduction_alloc_reduction =
+          IrBuilder::create<kir::AllocateFusedReduction>(
+              expr->as<kir::GridWelford>());
+      break;
+    case ExprType::GroupedGridReduction:
+      fused_reduction_alloc_reduction =
+          IrBuilder::create<kir::AllocateFusedReduction>(
+              expr->as<kir::GroupedGridReduction>());
+      break;
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString());
+  }
+
+  fused_reduction_map_.emplace(out_tv, fused_reduction_alloc_reduction);
+
+  // When using the fused reduction, allocate the reduction object at
+  // the outer-most scope
+  insertAtTopLevel(fused_reduction_alloc_reduction);
+}
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit

torch/csrc/jit/codegen/cuda/lower_index.h

Lines changed: 22 additions & 0 deletions
@@ -76,6 +76,21 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {

   void handleGridWelford(WelfordOp* new_wop);

+  // Allocate a unique buffer for grid reductions and broadcast. A
+  // buffer is uniquely allocated for each output tensor of an
+  // expression.
+  kir::Allocate* allocateUniqueBuffer(
+      Val* buffer_size,
+      DataType dtype,
+      bool zero_init,
+      TensorView* out_tv,
+      std::unordered_map<TensorView*, kir::Allocate*>& alloc_map);
+
+  // Allocate a fused reduction object uniquely for a given
+  // TensorView. Parameter expr is the expression corresponding to the
+  // fused reduction.
+  void allocateUniqueFusedReduction(Expr* expr, TensorView* out_tv);
+
  private:
   std::vector<Expr*> lowered_exprs_;

@@ -90,6 +105,13 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {
   // Track for loops to send to indexing. Similar to what's done in
   // kir::IrVisitor
   std::vector<kir::ForLoop*> for_loops_;
+
+  // Maps to keep track of allocated buffers and objects that must be
+  // allocated only once
+  std::unordered_map<TensorView*, kir::Allocate*> sync_buffer_map_;
+  std::unordered_map<TensorView*, kir::Allocate*> work_buffer_map_;
+  std::unordered_map<TensorView*, kir::AllocateFusedReduction*>
+      fused_reduction_map_;
 };

 } // namespace cuda

torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp

Lines changed: 10 additions & 2 deletions
@@ -117,6 +117,7 @@ void validateNoParallelBroadcastExist(kir::Kernel* kernel) {
 TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) {
   const int nx = 999;
   const int tidx = 128;
+  const int bidx = 4;

   if (ceilDiv(nx, tidx) > deviceSMCount()) {
     GTEST_SKIP() << "Not enough SMs to run this test";
@@ -135,13 +136,20 @@ TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) {
   fusion.addOutput(tv3);

   tv3->split(0, tidx);
+  tv3->split(0, bidx);
+  tv3->split(0, 1); // unswitch
   TransformPropagator propagator(tv3);
   MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);

-  tv3->axis(0)->parallelize(ParallelType::BIDx);
-  tv3->axis(1)->parallelize(ParallelType::TIDx);
+  tv3->axis(0)->parallelize(ParallelType::BIDy);
+  tv3->axis(2)->parallelize(ParallelType::BIDx);
+  tv3->axis(3)->parallelize(ParallelType::TIDx);
   scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));

+  // Just to make sure fused_reduction and work buffers are allocated
+  // uniquely
+  tv1->axis(1)->parallelize(ParallelType::Unswitch);
+
   GpuLower gpulw(&fusion);
   validateNoParallelBroadcastExist(gpulw.kernel());
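
The schedule above is what creates the duplicated unswitched paths that the lowering change has to handle. As a rough aid only (not part of the commit), the snippet below recomputes the extents implied by the three splits, assuming the test's constants (nx = 999, tidx = 128, bidx = 4); ceilDiv here is a local helper, not nvFuser's.

#include <cstdio>

// Illustrative only: the extents implied by
//   tv3->split(0, tidx); tv3->split(0, bidx); tv3->split(0, 1);
// together with the parallel types assigned in the test.
constexpr int ceilDiv(int a, int b) {
  return (a + b - 1) / b;
}

int main() {
  const int nx = 999, tidx = 128, bidx = 4;
  const int blocks_x = bidx;                               // axis(2) -> BIDx
  const int blocks_y = ceilDiv(ceilDiv(nx, tidx), bidx);   // axis(0) -> BIDy, extent 2
  // axis(1) has extent 1; it stays serial on tv3 but is parallelized as
  // Unswitch on tv1, which is what duplicates the reduction into two paths.
  std::printf("tv3 domain: [BIDy=%d, 1 (Unswitch on tv1), BIDx=%d, TIDx=%d]\n",
              blocks_y, blocks_x, tidx);
  return 0;
}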
