
Grouped grid welford #1921


Merged: 10 commits, Aug 23, 2022
8 changes: 4 additions & 4 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -1417,12 +1417,12 @@ WelfordResult Welford(
       out_avg,
       out_var,
       out_N, /*out var/avg/count */
-      tv, /*in var/avg/count */
-      FusionGuard::getCurFusion()->zeroVal(),
-      FusionGuard::getCurFusion()->oneVal(),
       init_avg_val,
       init_var_val,
-      init_N); /*init var/avg/count */
+      init_N, /*init var/avg/count */
+      tv,
+      FusionGuard::getCurFusion()->zeroVal(),
+      FusionGuard::getCurFusion()->oneVal()); /*in var/avg/count */

return WelfordResult(out_avg, out_var, out_N);
}
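
For reference, the (avg, var, N) triple this op carries is merged across partial results with the standard parallel-variance update of Chan et al. The following is a minimal standalone C++ sketch of that combine step, written purely for illustration; it is not the nvfuser runtime implementation:

struct WelfordTriple {
  double avg;  // running mean
  double m2;   // sum of squared deviations from the mean (var * n)
  long long n; // element count
};

// Merge two partial Welford results into one.
WelfordTriple welfordCombine(const WelfordTriple& a, const WelfordTriple& b) {
  if (a.n == 0) return b;
  if (b.n == 0) return a;
  const long long n = a.n + b.n;
  const double delta = b.avg - a.avg;
  return {
      a.avg + delta * b.n / n,                     // merged mean
      a.m2 + b.m2 + delta * delta * a.n * b.n / n, // merged m2
      n};
}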
168 changes: 167 additions & 1 deletion torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -1671,6 +1671,16 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << kTab << func_args << ");\n";
}

void handle(const kir::GroupedGridWelford* grouped_gwop) final {
if (grouped_gwop->isAllreduce()) {
generateGroupedGridAllreduceWelford(grouped_gwop);
return;
} else {
TORCH_INTERNAL_ASSERT(
false, "Non-allreduce grouped grid welford is not yet supported");
}
}

// Enumerates all combinations of index values of grouped
// loops. Each combination is a vector of loop index values. The
// length of the vector is the number of grouped loops.
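The body of the helper this comment describes is collapsed in the diff view. As a rough illustration of the idea, enumerating the combinations amounts to a Cartesian product over the grouped loops' extents; this standalone sketch is hypothetical and is not the actual helper:

#include <vector>

// For grouped-loop extents {e0, e1, ...}, return every index vector
// {i0, i1, ...} with 0 <= ik < ek, in row-major order.
std::vector<std::vector<int>> enumerateGroupedIndices(
    const std::vector<int>& extents) {
  std::vector<std::vector<int>> combos{{}}; // start with the empty prefix
  for (int extent : extents) {
    std::vector<std::vector<int>> next;
    for (const auto& prefix : combos) {
      for (int i = 0; i < extent; ++i) {
        next.push_back(prefix);
        next.back().push_back(i);
      }
    }
    combos = std::move(next);
  }
  return combos; // e.g. extents {2, 2} -> {0,0}, {0,1}, {1,0}, {1,1}
}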
@@ -1872,6 +1882,154 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << kTab << func_args << ");\n";
}

// Mostly the same as the grouped grid reduction version
void generateGroupedGridAllreduceWelford(
const kir::GroupedGridWelford* grouped_gwop) {
TORCH_INTERNAL_ASSERT(grouped_gwop->isAllreduce());

const auto index_replacement_maps = getLoopIndexReplacementMaps();
const auto num_grouped_iterations = index_replacement_maps.size();

// This is also checked at lowering validation time, so it
// isn't strictly necessary.
TORCH_INTERNAL_ASSERT(
num_grouped_iterations * grouped_gwop->numExprs() <=
kMaxNumGroupedReductions,
"Too many grouped reductions: ",
grouped_gwop->toString(),
". Up to ",
kMaxNumGroupedReductions,
" reductions are allowed.");

ArgumentBuilder data_types;
ArgumentBuilder index_types;

// Note that the data type of avg and var, and that of N, are each
// the same across all the welford ops, since we only support
// grouping of iterations.
const auto data_type = grouped_gwop->outputVals().at(0).avg()->dtype();
const auto index_type = grouped_gwop->outputVals().at(0).N()->dtype();

std::array<ArgumentBuilder, 3> out_args;
std::array<ArgumentBuilder, 3> in_args;
std::array<ArgumentBuilder, 3> init_args;
std::array<ArgumentBuilder, 3> work_bufs;

ArgumentBuilder bool_types;
ArgumentBuilder read_preds;
ArgumentBuilder write_preds;

for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) {
const auto& output = grouped_gwop->outputVals().at(expr_index);
const auto& input = grouped_gwop->inputVals().at(expr_index);
const auto& init = grouped_gwop->initVals().at(expr_index);

for (const auto& group_index :
c10::irange(index_replacement_maps.size())) {
// Set the index replacement map with the concrete values of
// indices of grouped loops.
index_replacement_map_ = index_replacement_maps.at(group_index);

data_types.arg(data_type);
index_types.arg(index_type);

auto work_buffer_offset = group_index == 0
? "0"
: (genInline(grouped_gwop->buffer_stride()) + " * " +
std::to_string(group_index));

// Setup arguments for avg, var, and N
for (const auto i : c10::irange(3)) {
out_args[i].arg(gen(output.get(i)));
in_args[i].arg(gen(input.get(i)));
init_args[i].arg(gen(init.get(i)));
const auto work_buffer = grouped_gwop->reduction_buffers()[i]
.at(expr_index)
->buffer()
->as<TensorView>();
work_bufs[i]
.arg("&")
.append(varName(work_buffer))
.append("[")
.append(work_buffer_offset)
.append("]");
}

// read and write predicates
bool_types.arg("bool");
// Same argument for all inputs. Different predicates would be
// used when grouping is done across iterations.
TORCH_INTERNAL_ASSERT(grouped_gwop->predicate() != nullptr);
TORCH_INTERNAL_ASSERT(
grouped_gwop->predicate() != nullptr &&
grouped_gwop->predicate()->hasValue());
const auto read_pred = genInline(grouped_gwop->predicate());
read_preds.arg(read_pred);
if (grouped_gwop->writePredicate() != nullptr) {
TORCH_INTERNAL_ASSERT(grouped_gwop->writePredicate()->hasValue());
write_preds.arg(genInline(grouped_gwop->writePredicate()));
} else {
write_preds.arg(read_pred);
}

index_replacement_map_.clear();
}
}

ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
// output
func_args.arg(genCall("RefTuple", data_types, out_args[0]));
func_args.arg(genCall("RefTuple", data_types, out_args[1]));
func_args.arg(genCall("RefTuple", index_types, out_args[2]));
// input
func_args.arg(genCall("ConstRefTuple", data_types, in_args[0]));
func_args.arg(genCall("ConstRefTuple", data_types, in_args[1]));
func_args.arg(genCall("ConstRefTuple", index_types, in_args[2]));
// init
func_args.arg(genCall("LocalTuple", data_types, init_args[0]));
func_args.arg(genCall("LocalTuple", data_types, init_args[1]));
func_args.arg(genCall("LocalTuple", index_types, init_args[2]));
// work buffer
func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0]));
func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1]));
func_args.arg(genCall("VolatilePtrTuple", index_types, work_bufs[2]));
// global_sync_buffer
const auto sync_buffer =
grouped_gwop->sync_buffer()->buffer()->as<TensorView>();
func_args.arg("&").append(varName(sync_buffer)).append("[0]");

// shared_buf
ArgumentBuilder smem_buffer_args;
smem_buffer_args.arg(
genCall("reinterpret_cast", ptrType(data_type), "shared_mem_avg"));
smem_buffer_args.arg(
genCall("reinterpret_cast", ptrType(data_type), "shared_mem_var"));
smem_buffer_args.arg(
genCall("reinterpret_cast", ptrType(index_type), "shared_mem_n"));
func_args.arg(genCall(
"PtrTuple",
ArgumentBuilder().arg(data_type).arg(data_type).arg(index_type),
smem_buffer_args));

func_args.arg(genCall("LocalTuple", bool_types, read_preds));
func_args.arg(genCall("LocalTuple", bool_types, write_preds));

addProfileArguments(func_args, grouped_gwop);

ArgumentBuilder func_template_args;
func_template_args.arg(
grouped_gwop->numExprs() * index_replacement_maps.size());
func_template_args.arg(data_type);
func_template_args.arg(index_type);

indent() << genCall(
genFusedReductionName(ir_utils::getTvOutput(grouped_gwop)) +
".welfordGroup",
func_template_args,
func_args)
<< ";\n";
}
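
Putting the pieces together, the emitted kernel code is a single call whose template arguments carry the total reduction count (numExprs() times the number of grouped iterations) plus the data and index types. A hypothetical rendering for one welford expression grouped over two iterations, with float data and int64_t counts; the instance name, tensor names, and values are invented, while the tuple wrappers and argument order mirror the generator above:

fr_instance.welfordGroup<2, float, int64_t>( // name from genFusedReductionName(); illustrative
    RefTuple<float, float>(T2_avg[0], T2_avg[1]),          // out avg
    RefTuple<float, float>(T2_var[0], T2_var[1]),          // out var
    RefTuple<int64_t, int64_t>(T2_N[0], T2_N[1]),          // out N
    ConstRefTuple<float, float>(T1[0], T1[1]),             // in avg
    ConstRefTuple<float, float>(in_var0, in_var1),         // in var
    ConstRefTuple<int64_t, int64_t>(in_N0, in_N1),         // in N
    LocalTuple<float, float>(0.f, 0.f),                    // init avg
    LocalTuple<float, float>(0.f, 0.f),                    // init var
    LocalTuple<int64_t, int64_t>(0, 0),                    // init N
    VolatilePtrTuple<float, float>(&wb_avg[0], &wb_avg[stride]),
    VolatilePtrTuple<float, float>(&wb_var[0], &wb_var[stride]),
    VolatilePtrTuple<int64_t, int64_t>(&wb_N[0], &wb_N[stride]),
    &T_sync[0],                                            // global sync buffer
    PtrTuple<float, float, int64_t>(
        reinterpret_cast<float*>(shared_mem_avg),
        reinterpret_cast<float*>(shared_mem_var),
        reinterpret_cast<int64_t*>(shared_mem_n)),
    LocalTuple<bool, bool>(read_pred, read_pred),          // read preds
    LocalTuple<bool, bool>(read_pred, read_pred));         // write preds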

void handle(const kir::GridBroadcast* grop) final {
const auto bop = grop->broadcast_op();
TORCH_INTERNAL_ASSERT(bop->out()->isA<kir::TensorIndex>());
@@ -2208,6 +2366,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

void handle(const GroupedWelfordOp* grouped_wop) final {
TORCH_INTERNAL_ASSERT(
false,
"Should not reach here as grouped welford is only enabled for grid welford,",
" which is handled by its own handler");
}

//! True if loop is grouped. The IterDomain of the loop must have
//! ParallelType::Group, but it isn't sufficient as the loop may be
//! for an initialization expression, for which the loop should not
@@ -2216,7 +2216,7 @@
if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
return false;
}
return ExprFinder::exists(loop, {ExprType::GroupedGridReduction});
return ExprFinder::exists(
loop, {ExprType::GroupedGridReduction, ExprType::GroupedGridWelford});
}

void handle(const kir::ForLoop* loop) final {
30 changes: 30 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -113,6 +113,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::WelfordOp:
ptr(handler)->handle(expr->as<WelfordOp>());
return;
case ExprType::GroupedWelfordOp:
ptr(handler)->handle(expr->as<GroupedWelfordOp>());
return;
case ExprType::LoadStoreOp:
ptr(handler)->handle(expr->as<LoadStoreOp>());
return;
@@ -190,6 +193,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::GridWelford:
ptr(handler)->handle(expr->as<kir::GridWelford>());
return;
case ExprType::GroupedGridWelford:
ptr(handler)->handle(expr->as<kir::GroupedGridWelford>());
return;
case ExprType::AllocateFusedReduction:
ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
return;
@@ -287,6 +293,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::WelfordOp:
ptr(handler)->handle(expr->as<WelfordOp>());
return;
case ExprType::GroupedWelfordOp:
ptr(handler)->handle(expr->as<GroupedWelfordOp>());
return;
case ExprType::LoadStoreOp:
ptr(handler)->handle(expr->as<LoadStoreOp>());
return;
@@ -364,6 +373,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::GridWelford:
ptr(handler)->handle(expr->as<kir::GridWelford>());
return;
case ExprType::GroupedGridWelford:
ptr(handler)->handle(expr->as<kir::GroupedGridWelford>());
return;
case ExprType::AllocateFusedReduction:
ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
return;
@@ -469,6 +481,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::WelfordOp:
ptr(mutator)->mutate(expr->as<WelfordOp>());
return;
case ExprType::GroupedWelfordOp:
ptr(mutator)->mutate(expr->as<GroupedWelfordOp>());
return;
case ExprType::LoadStoreOp:
ptr(mutator)->mutate(expr->as<LoadStoreOp>());
return;
@@ -546,6 +561,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::GridWelford:
ptr(mutator)->mutate(expr->as<kir::GridWelford>());
return;
case ExprType::GroupedGridWelford:
ptr(mutator)->mutate(expr->as<kir::GroupedGridWelford>());
return;
case ExprType::AllocateFusedReduction:
ptr(mutator)->mutate(expr->as<kir::AllocateFusedReduction>());
return;
@@ -716,6 +734,9 @@ void OptOutConstDispatch::handle(const GroupedReductionOp* stmt) {
void OptOutConstDispatch::handle(const WelfordOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const GroupedWelfordOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const LoadStoreOp* stmt) {
unhandled(stmt);
}
@@ -793,6 +814,9 @@ void OptOutConstDispatch::handle(const kir::GridBroadcast* stmt) {
void OptOutConstDispatch::handle(const kir::GridWelford* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GroupedGridWelford* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) {
unhandled(stmt);
}
@@ -860,6 +884,9 @@ void OptOutDispatch::handle(GroupedReductionOp* stmt) {
void OptOutDispatch::handle(WelfordOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(GroupedWelfordOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(LoadStoreOp* stmt) {
unhandled(stmt);
}
@@ -937,6 +964,9 @@ void OptOutDispatch::handle(kir::GridBroadcast* stmt) {
void OptOutDispatch::handle(kir::GridWelford* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::GroupedGridWelford* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) {
unhandled(stmt);
}
8 changes: 8 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.h
@@ -74,6 +74,7 @@ class TernaryOp;
class ReductionOp;
class GroupedReductionOp;
class WelfordOp;
class GroupedWelfordOp;
class LoadStoreOp;
class MmaOp;
class BroadcastOp;
@@ -105,6 +106,7 @@ class GridReduction;
class GroupedGridReduction;
class GridBroadcast;
class GridWelford;
class GroupedGridWelford;
class AllocateFusedReduction;
class InitMagicZero;
class UpdateMagicZero;
@@ -146,6 +148,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const ReductionOp* stmt);
virtual void handle(const GroupedReductionOp* stmt);
virtual void handle(const WelfordOp* stmt);
virtual void handle(const GroupedWelfordOp* stmt);
virtual void handle(const LoadStoreOp* stmt);
virtual void handle(const MmaOp* stmt);
virtual void handle(const BroadcastOp* stmt);
@@ -173,6 +176,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const kir::GroupedGridReduction*);
virtual void handle(const kir::GridBroadcast*);
virtual void handle(const kir::GridWelford*);
virtual void handle(const kir::GroupedGridWelford*);
virtual void handle(const kir::AllocateFusedReduction*);
virtual void handle(const kir::Swizzle2DInt*);
virtual void handle(const kir::PairSelect*);
@@ -209,6 +213,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(ReductionOp* stmt);
virtual void handle(GroupedReductionOp* stmt);
virtual void handle(WelfordOp* stmt);
virtual void handle(GroupedWelfordOp* stmt);
virtual void handle(LoadStoreOp* stmt);
virtual void handle(MmaOp* stmt);
virtual void handle(BroadcastOp* stmt);
@@ -236,6 +241,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(kir::GroupedGridReduction* stmt);
virtual void handle(kir::GridBroadcast* stmt);
virtual void handle(kir::GridWelford* stmt);
virtual void handle(kir::GroupedGridWelford* stmt);
virtual void handle(kir::AllocateFusedReduction* stmt);
virtual void handle(kir::Swizzle2DInt* stmt);
virtual void handle(kir::PairSelect* stmt);
@@ -313,6 +319,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(ReductionOp*);
virtual void mutate(GroupedReductionOp*);
virtual void mutate(WelfordOp*);
virtual void mutate(GroupedWelfordOp*);
virtual void mutate(LoadStoreOp*);
virtual void mutate(MmaOp*);
virtual void mutate(BroadcastOp*);
@@ -340,6 +347,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(kir::GroupedGridReduction*);
virtual void mutate(kir::GridBroadcast*);
virtual void mutate(kir::GridWelford*);
virtual void mutate(kir::GroupedGridWelford*);
virtual void mutate(kir::AllocateFusedReduction*);
virtual void mutate(kir::Swizzle2DInt*);
virtual void mutate(kir::PairSelect*);
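
Since the dispatch-class defaults for the new node types simply call unhandled(), a downstream pass overrides only the nodes it cares about. A hedged sketch of a hypothetical analysis pass built on these hooks; only OptOutConstDispatch, Expr, and kir::GroupedGridWelford come from this codebase, and the pass itself is invented:

#include <torch/csrc/jit/codegen/cuda/dispatch.h>

#include <vector>

using namespace torch::jit::fuser::cuda;

// Collects every grouped grid welford found in a list of expressions.
// All other node types fall through to the default (no-op) handlers.
class CollectGroupedGridWelford : public OptOutConstDispatch {
 public:
  using OptOutConstDispatch::handle;

  void handle(const kir::GroupedGridWelford* stmt) final {
    found.push_back(stmt);
  }

  std::vector<const kir::GroupedGridWelford*> found;
};

Usage: calling collector.handle(expr) routes through Expr::constDispatch and lands in the override above only when expr is a kir::GroupedGridWelford.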
8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/executor.cpp
@@ -860,7 +860,13 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
"what can be resident on the GPU at once. Need: ",
launch_params_.gdimx() * launch_params_.gdimy() *
launch_params_.gdimz(),
" but limited to ",
" (",
launch_params_.gdimx(),
" * ",
launch_params_.gdimy(),
" * ",
launch_params_.gdimz(),
") but limited to ",
num_blocks_per_SM,
" * ",
at::cuda::getDeviceProperties(options_.device.index())
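
With the per-dimension breakdown added above, and purely hypothetical numbers (a 64 x 64 x 1 grid, 2 resident blocks per SM), the relevant fragment of the assembled error message would read roughly:

    ... what can be resident on the GPU at once. Need: 4096 (64 * 64 * 1) but limited to 2 * ...

Only the format comes from the TORCH_INTERNAL_ASSERT arguments shown in the hunk; the figures are invented, and the per-SM limit is presumably multiplied by the device's multiprocessor count on the lines collapsed here.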