
Fix allocation of work buffers and fused_reduction::ParallelReduce with unswitch #1818


Merged
merged 7 commits on Jul 15, 2022
179 changes: 121 additions & 58 deletions torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -382,13 +382,17 @@ void IndexLowering::handleGridReduction(
const auto buffer_size_info =
getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

const auto reduce_buffer = ir_utils::allocGlobalBufferForGridComm(
buffer_size_info.size_of_privatized_buffer, out->dtype(), false);

const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
getGridSyncBufferSize(out_domain, for_loops_, is_persistent),
DataType::Int,
true);
auto work_buffer = allocateUniqueBuffer(
buffer_size_info.size_of_privatized_buffer,
out_tv->dtype(),
false,
out_tv,
work_buffer_map_);

auto sync_buffer_size =
getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
auto sync_buffer = allocateUniqueBuffer(
sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

const auto entrance_ind = !is_persistent
? getEntranceLinIndGridReduce(for_loops_)
@@ -408,7 +412,7 @@ void IndexLowering::handleGridReduction(
rop->init(),
out,
in,
reduce_buffer,
work_buffer,
sync_buffer,
entrance_ind,
n_entrances,
Expand All @@ -423,17 +427,11 @@ void IndexLowering::handleGridReduction(
grid_reduction->setWritePredicate(rop->writePredicate());
}

pushBack(reduce_buffer);
pushBack(sync_buffer);
pushBack(grid_reduction);
GpuLower::current()->propagateExprInfo(rop, back());

if (rop->isAllreduce()) {
// When using the fused reduction, allocate the reduction object at
// the outer-most scope
auto fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(grid_reduction);
insertAtTopLevel(fused_reduction_alloc_reduction);
allocateUniqueFusedReduction(grid_reduction, out_tv);
}
}

@@ -521,22 +519,24 @@ void IndexLowering::handleGridReduction(
auto work_buf_size_info =
getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

std::vector<kir::Allocate*> reduce_buffers;
std::vector<kir::Allocate*> work_buffers;
std::transform(
outputs.begin(),
outputs.end(),
std::back_inserter(reduce_buffers),
std::back_inserter(work_buffers),
[&](Val* output) {
return ir_utils::allocGlobalBufferForGridComm(
return allocateUniqueBuffer(
work_buf_size_info.size_of_privatized_buffer,
output->dtype(),
false);
false,
output->as<kir::TensorIndex>()->view(),
work_buffer_map_);
});

const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
getGridSyncBufferSize(out_domain, for_loops_, is_persistent),
DataType::Int,
true);
auto sync_buffer_size =
getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
auto sync_buffer = allocateUniqueBuffer(
sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

const auto entrance_ind = !is_persistent
? getEntranceLinIndGridReduce(for_loops_)
@@ -556,7 +556,7 @@ void IndexLowering::handleGridReduction(
grouped_rop->initVals(),
outputs,
inputs,
reduce_buffers,
work_buffers,
sync_buffer,
entrance_ind,
n_entrances,
Expand All @@ -572,17 +572,11 @@ void IndexLowering::handleGridReduction(
grid_reduction->setWritePredicate(grouped_rop->writePredicate());
}

for (auto reduce_buffer : reduce_buffers) {
pushBack(reduce_buffer);
}
pushBack(sync_buffer);
pushBack(grid_reduction);
GpuLower::current()->propagateExprInfo(grouped_rop, back());

if (grouped_rop->isAllreduce()) {
auto fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(grid_reduction);
insertAtTopLevel(fused_reduction_alloc_reduction);
allocateUniqueFusedReduction(grid_reduction, out_tv);
}
}

@@ -672,17 +666,29 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

const auto work_buffer_size = buffer_size_info.size_of_privatized_buffer;
const auto out_var_buffer = ir_utils::allocGlobalBufferForGridComm(
work_buffer_size, indexed_wop->outVar()->dtype(), false);
const auto out_avg_buffer = ir_utils::allocGlobalBufferForGridComm(
work_buffer_size, indexed_wop->outAvg()->dtype(), false);
const auto out_N_buffer = ir_utils::allocGlobalBufferForGridComm(
work_buffer_size, indexed_wop->outN()->dtype(), false);

const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
getGridSyncBufferSize(out_domain, for_loops_, is_persistent),
DataType::Int,
true);
auto out_var_buffer = allocateUniqueBuffer(
work_buffer_size,
indexed_wop->outVar()->dtype(),
false,
indexed_wop->outVar()->as<kir::TensorIndex>()->view(),
work_buffer_map_);
auto out_avg_buffer = allocateUniqueBuffer(
work_buffer_size,
indexed_wop->outAvg()->dtype(),
false,
indexed_wop->outAvg()->as<kir::TensorIndex>()->view(),
work_buffer_map_);
auto out_N_buffer = allocateUniqueBuffer(
work_buffer_size,
indexed_wop->outN()->dtype(),
false,
indexed_wop->outN()->as<kir::TensorIndex>()->view(),
work_buffer_map_);

auto sync_buffer_size =
getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
auto sync_buffer = allocateUniqueBuffer(
sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

const auto entrance_ind = !is_persistent
? getEntranceLinIndGridReduce(for_loops_)
@@ -729,19 +735,13 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
GpuLower::current()->propagateExprInfo(indexed_wop, back());
}

pushBack(out_var_buffer);
pushBack(out_avg_buffer);
pushBack(out_N_buffer);
pushBack(sync_buffer);
pushBack(grid_welford);
GpuLower::current()->propagateExprInfo(indexed_wop, back());

if (indexed_wop->isAllreduce()) {
// When using the fused reduction, allocate the reduction object at
// the outer-most scope
auto fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(grid_welford);
insertAtTopLevel(fused_reduction_alloc_reduction);
allocateUniqueFusedReduction(grid_welford, out_tv);
}
}

@@ -792,24 +792,24 @@ void IndexLowering::handle(const BroadcastOp* bop) {

// Grid broadcast
const auto out_domain = out_tv->domain();
const auto broadcast_buffer = ir_utils::allocGlobalBufferForGridComm(
getGridCommWorkBufferSize(out_domain, for_loops_, true)
.size_of_privatized_buffer,
out->dtype(),
false);
const auto work_buffer_size =
getGridCommWorkBufferSize(out_domain, for_loops_, true)
.size_of_privatized_buffer;

auto work_buffer = allocateUniqueBuffer(
work_buffer_size, out->dtype(), false, out_tv, work_buffer_map_);

const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
getGridSyncBufferSize(out_domain, for_loops_, true), DataType::Int, true);
auto sync_buffer_size = getGridSyncBufferSize(out_domain, for_loops_, true);
auto sync_buffer = allocateUniqueBuffer(
sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

auto grid_broadcast = IrBuilder::create<kir::GridBroadcast>(
indexed_expr, broadcast_buffer, sync_buffer);
indexed_expr, work_buffer, sync_buffer);

if (bop->predicate()) {
grid_broadcast->setPredicate(bop->predicate());
}

pushBack(broadcast_buffer);
pushBack(sync_buffer);
pushBack(grid_broadcast);
GpuLower::current()->propagateExprInfo(bop, back());
}
@@ -840,6 +840,69 @@ void IndexLowering::generate(const std::vector<Expr*>& exprs) {
}
}

kir::Allocate* IndexLowering::allocateUniqueBuffer(
Val* buffer_size,
DataType dtype,
bool zero_init,
TensorView* out_tv,
std::unordered_map<TensorView*, kir::Allocate*>& alloc_map) {
// Return an existing allocation if one exists
auto it = alloc_map.find(out_tv);
if (it != alloc_map.end()) {
return it->second;
}

// No existing allocation found. Create a new one
auto new_buffer =
ir_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init);

// Keep track of the allocation
alloc_map.emplace(out_tv, new_buffer);

// A buffer may be used in both unswitched paths, so it must be
// placed outside of the current scope. Simply placing it at the
// top-level scope should work.
insertAtTopLevel(new_buffer);

return new_buffer;
}
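
A note on the dedup above: it is plain memoization keyed on the output TensorView, so a grid reduction that ends up in both paths of an unswitched loop reuses a single work buffer instead of allocating twice. A minimal standalone sketch of the same pattern, using hypothetical Tensor/Buffer stand-ins rather than the real TensorView/kir::Allocate types:

#include <string>
#include <unordered_map>

// Hypothetical stand-ins for TensorView and kir::Allocate.
struct Tensor { int id; };
struct Buffer { std::string name; };

// Return the buffer already recorded for `key`, or create and record one.
Buffer& getOrCreateBuffer(
    Tensor* key,
    std::unordered_map<Tensor*, Buffer>& alloc_map) {
  auto it = alloc_map.find(key);
  if (it != alloc_map.end()) {
    return it->second; // reuse: both unswitch paths see the same buffer
  }
  auto result = alloc_map.emplace(
      key, Buffer{"work_buf_" + std::to_string(key->id)});
  return result.first->second;
}

int main() {
  std::unordered_map<Tensor*, Buffer> alloc_map;
  Tensor t{0};
  // Two requests for the same tensor must yield the same allocation.
  return &getOrCreateBuffer(&t, alloc_map) == &getOrCreateBuffer(&t, alloc_map)
      ? 0
      : 1;
}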

void IndexLowering::allocateUniqueFusedReduction(
Expr* expr,
TensorView* out_tv) {
auto it = fused_reduction_map_.find(out_tv);
if (it != fused_reduction_map_.end()) {
return;
}

kir::AllocateFusedReduction* fused_reduction_alloc_reduction = nullptr;
switch (expr->getExprType().value()) {
case ExprType::GridReduction:
fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(
expr->as<kir::GridReduction>());
break;
case ExprType::GridWelford:
fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(
expr->as<kir::GridWelford>());
break;
case ExprType::GroupedGridReduction:
fused_reduction_alloc_reduction =
IrBuilder::create<kir::AllocateFusedReduction>(
expr->as<kir::GroupedGridReduction>());
break;
default:
TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString());
}

fused_reduction_map_.emplace(out_tv, fused_reduction_alloc_reduction);

// When using the fused reduction, allocate the reduction object at
// the outer-most scope
insertAtTopLevel(fused_reduction_alloc_reduction);
}
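
The insertAtTopLevel calls are the other half of the fix: an allocation emitted only in the scope currently being lowered could become visible to just one branch of an unswitched loop, so the work/sync buffers and the fused-reduction object are hoisted to the outermost scope. A rough sketch of that hoisting idea, with a hypothetical scope stack in place of the real kir structures:

#include <vector>

// Hypothetical, simplified IR nodes.
struct Stmt {};
struct Scope { std::vector<Stmt*> stmts; };

// scope_stack.front() is the kernel's top-level scope; scope_stack.back() is
// the scope currently being generated, e.g. one branch of an unswitch.
// Allocations that every branch must see are appended to the front.
void insertAtTopLevel(std::vector<Scope*>& scope_stack, Stmt* allocation) {
  scope_stack.front()->stmts.push_back(allocation);
}

int main() {
  Scope top, unswitch_branch;
  std::vector<Scope*> scope_stack{&top, &unswitch_branch};
  Stmt alloc;
  insertAtTopLevel(scope_stack, &alloc);
  return top.stmts.size() == 1 ? 0 : 1; // the allocation lands at the top
}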

} // namespace cuda
} // namespace fuser
} // namespace jit
22 changes: 22 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_index.h
@@ -76,6 +76,21 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {

void handleGridWelford(WelfordOp* new_wop);

// Allocate a unique buffer for grid reductions and broadcasts. A
// buffer is uniquely allocated for each output tensor of an
// expression.
kir::Allocate* allocateUniqueBuffer(
Val* buffer_size,
DataType dtype,
bool zero_init,
TensorView* out_tv,
std::unordered_map<TensorView*, kir::Allocate*>& alloc_map);

// Allocate a fused reduction object uniquely for a given
// TensorView. Parameter expr is the expression corresponding to the
// fused reduction.
void allocateUniqueFusedReduction(Expr* expr, TensorView* out_tv);

private:
std::vector<Expr*> lowered_exprs_;

@@ -90,6 +105,13 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {
// Track for loops to send to indexing. Similar to what's done in
// kir::IrVisitor
std::vector<kir::ForLoop*> for_loops_;

// Maps to keep track of allocated buffers and objects that must be
// allocated only once
std::unordered_map<TensorView*, kir::Allocate*> sync_buffer_map_;
std::unordered_map<TensorView*, kir::Allocate*> work_buffer_map_;
std::unordered_map<TensorView*, kir::AllocateFusedReduction*>
fused_reduction_map_;
};

} // namespace cuda
12 changes: 10 additions & 2 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
@@ -117,6 +117,7 @@ void validateNoParallelBroadcastExist(kir::Kernel* kernel) {
TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) {
const int nx = 999;
const int tidx = 128;
const int bidx = 4;

if (ceilDiv(nx, tidx) > deviceSMCount()) {
GTEST_SKIP() << "Not enough SMs to run this test";
@@ -135,13 +136,20 @@ TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) {
fusion.addOutput(tv3);

tv3->split(0, tidx);
tv3->split(0, bidx);
tv3->split(0, 1); // unswitch
TransformPropagator propagator(tv3);
MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);

tv3->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(1)->parallelize(ParallelType::TIDx);
tv3->axis(0)->parallelize(ParallelType::BIDy);
tv3->axis(2)->parallelize(ParallelType::BIDx);
tv3->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));

// Just to make sure the fused_reduction object and the work buffers
// are allocated uniquely
tv1->axis(1)->parallelize(ParallelType::Unswitch);

GpuLower gpulw(&fusion);
validateNoParallelBroadcastExist(gpulw.kernel());
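
For reference, with nx = 999, tidx = 128, and bidx = 4 the three splits give tv3 axes of extent [2, 1, 4, 128], bound to BIDy, the extent-1 axis that tv1 unswitches over, BIDx, and TIDx. A quick check of that arithmetic, assuming the usual ceil-division used by split:

#include <cstdio>

// Assumed ceil-division, matching how scheduler splits round up.
constexpr int ceilDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int nx = 999, tidx = 128, bidx = 4;
  const int after_tidx_split = ceilDiv(nx, tidx); // 8
  const int after_bidx_split = ceilDiv(after_tidx_split, bidx); // 2
  const int after_unswitch_split = ceilDiv(after_bidx_split, 1); // 2
  // tv3 axes: [2, 1, 4, 128] -> [BIDy, unswitch/serial, BIDx, TIDx]
  std::printf("[%d, 1, %d, %d]\n", after_unswitch_split, bidx, tidx);
  return 0;
}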
