
Commit 20cf109

Grouped grid welford (#1921)
Enables grouping of grid welford ops across iterations, with the same functionality as the iteration grouping for GridReduction. This is intended to improve the outer-norm grid persistence in batchnorm-like fusions.
1 parent 6cf7eb0 commit 20cf109
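
For context, a Welford reduction carries an (avg, var, N) triple per output and merges partial triples rather than summing raw values, which is what makes it amenable to the same cross-iteration grouping as GridReduction. Below is a minimal sketch of the standard pairwise combine (Chan et al.); this is background, not code from this commit, and here M2 denotes the unnormalized sum of squared deviations (variance is M2 / N):

    #include <cstdint>

    // Running Welford state: mean, unnormalized variance (M2), and count.
    struct WelfordTriple {
      double avg = 0.0;
      double M2 = 0.0;
      int64_t N = 0;
    };

    // Merge two partial results. This is the associative op a grid welford
    // reduction applies across threads, blocks, and grouped iterations.
    WelfordTriple combine(const WelfordTriple& a, const WelfordTriple& b) {
      if (a.N == 0) return b;
      if (b.N == 0) return a;
      WelfordTriple r;
      r.N = a.N + b.N;
      const double delta = b.avg - a.avg;
      r.avg = a.avg + delta * static_cast<double>(b.N) / static_cast<double>(r.N);
      r.M2 = a.M2 + b.M2 +
          delta * delta * static_cast<double>(a.N) * static_cast<double>(b.N) /
              static_cast<double>(r.N);
      return r;
    }

Grouping lets several such reductions share a single grid synchronization instead of issuing one per iteration, which is where the expected gain for batchnorm-like persistent kernels comes from.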

29 files changed: +2650 −338 lines

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 4 additions & 4 deletions
@@ -1417,12 +1417,12 @@ WelfordResult Welford(
       out_avg,
       out_var,
       out_N, /*out var/avg/count */
+      tv, /*in var/avg/count */
+      FusionGuard::getCurFusion()->zeroVal(),
+      FusionGuard::getCurFusion()->oneVal(),
       init_avg_val,
       init_var_val,
-      init_N, /*init var/avg/count */
-      tv,
-      FusionGuard::getCurFusion()->zeroVal(),
-      FusionGuard::getCurFusion()->oneVal()); /*in var/avg/count */
+      init_N); /*init var/avg/count */

   return WelfordResult(out_avg, out_var, out_N);
 }
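
The reordering above moves the input triple (tv, zeroVal, oneVal) ahead of the init triple in the op's argument list. The zero/one values encode that a raw tensor element enters the reduction as a single sample: its avg is the value itself, it contributes nothing to M2, and it counts once. A one-line sketch, reusing the hypothetical WelfordTriple/combine helpers from the background sketch above:

    // Why a plain tensor input is the triple (x, 0, 1): folding one sample
    // into a running state is just a combine with that singleton triple.
    WelfordTriple foldSample(const WelfordTriple& acc, double x) {
      return combine(acc, WelfordTriple{/*avg=*/x, /*M2=*/0.0, /*N=*/1});
    }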

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 167 additions & 1 deletion
@@ -1671,6 +1671,16 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     indent() << kTab << func_args << ");\n";
   }

+  void handle(const kir::GroupedGridWelford* grouped_gwop) final {
+    if (grouped_gwop->isAllreduce()) {
+      generateGroupedGridAllreduceWelford(grouped_gwop);
+      return;
+    } else {
+      TORCH_INTERNAL_ASSERT(
+          false, "Non-allreduce grouped grid welford is not yet supported");
+    }
+  }
+
   // Enumerates all combinations of index values of grouped
   // loops. Each combination is a vector of loop index values. The
   // length of the vector is the number of grouped loops.
@@ -1872,6 +1882,154 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     indent() << kTab << func_args << ");\n";
   }

+  // Mostly the same as the grouped grid reduction version
+  void generateGroupedGridAllreduceWelford(
+      const kir::GroupedGridWelford* grouped_gwop) {
+    TORCH_INTERNAL_ASSERT(grouped_gwop->isAllreduce());
+
+    const auto index_replacement_maps = getLoopIndexReplacementMaps();
+    const auto num_grouped_iterations = index_replacement_maps.size();
+
+    // This is also checked at the lowering validation time, so it
+    // isn't strictly necessary.
+    TORCH_INTERNAL_ASSERT(
+        num_grouped_iterations * grouped_gwop->numExprs() <=
+            kMaxNumGroupedReductions,
+        "Too many grouped reductions: ",
+        grouped_gwop->toString(),
+        ". Up to ",
+        kMaxNumGroupedReductions,
+        " reductions are allowed.");
+
+    ArgumentBuilder data_types;
+    ArgumentBuilder index_types;
+
+    // Note that the data type of avg and var and that of N are the
+    // same for all the welford ops since we only support
+    // grouping of iterations.
+    const auto data_type = grouped_gwop->outputVals().at(0).avg()->dtype();
+    const auto index_type = grouped_gwop->outputVals().at(0).N()->dtype();
+
+    std::array<ArgumentBuilder, 3> out_args;
+    std::array<ArgumentBuilder, 3> in_args;
+    std::array<ArgumentBuilder, 3> init_args;
+    std::array<ArgumentBuilder, 3> work_bufs;
+
+    ArgumentBuilder bool_types;
+    ArgumentBuilder read_preds;
+    ArgumentBuilder write_preds;
+
+    for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) {
+      const auto& output = grouped_gwop->outputVals().at(expr_index);
+      const auto& input = grouped_gwop->inputVals().at(expr_index);
+      const auto& init = grouped_gwop->initVals().at(expr_index);
+
+      for (const auto& group_index :
+           c10::irange(index_replacement_maps.size())) {
+        // Set the index replacement map with the concrete values of
+        // indices of grouped loops.
+        index_replacement_map_ = index_replacement_maps.at(group_index);
+
+        data_types.arg(data_type);
+        index_types.arg(index_type);
+
+        auto work_buffer_offset = group_index == 0
+            ? "0"
+            : (genInline(grouped_gwop->buffer_stride()) + " * " +
+               std::to_string(group_index));
+
+        // Setup arguments for avg, var, and N
+        for (const auto i : c10::irange(3)) {
+          out_args[i].arg(gen(output.get(i)));
+          in_args[i].arg(gen(input.get(i)));
+          init_args[i].arg(gen(init.get(i)));
+          const auto work_buffer = grouped_gwop->reduction_buffers()[i]
+                                       .at(expr_index)
+                                       ->buffer()
+                                       ->as<TensorView>();
+          work_bufs[i]
+              .arg("&")
+              .append(varName(work_buffer))
+              .append("[")
+              .append(work_buffer_offset)
+              .append("]");
+        }
+
+        // read and write predicates
+        bool_types.arg("bool");
+        // Same argument for all inputs. Different predicates would be
+        // used when grouping is done across iterations
+        TORCH_INTERNAL_ASSERT(grouped_gwop->predicate() != nullptr);
+        TORCH_INTERNAL_ASSERT(
+            grouped_gwop->predicate() != nullptr &&
+            grouped_gwop->predicate()->hasValue());
+        const auto read_pred = genInline(grouped_gwop->predicate());
+        read_preds.arg(read_pred);
+        if (grouped_gwop->writePredicate() != nullptr) {
+          TORCH_INTERNAL_ASSERT(grouped_gwop->writePredicate()->hasValue());
+          write_preds.arg(genInline(grouped_gwop->writePredicate()));
+        } else {
+          write_preds.arg(read_pred);
+        }
+
+        index_replacement_map_.clear();
+      }
+    }
+
+    ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
+    // output
+    func_args.arg(genCall("RefTuple", data_types, out_args[0]));
+    func_args.arg(genCall("RefTuple", data_types, out_args[1]));
+    func_args.arg(genCall("RefTuple", index_types, out_args[2]));
+    // input
+    func_args.arg(genCall("ConstRefTuple", data_types, in_args[0]));
+    func_args.arg(genCall("ConstRefTuple", data_types, in_args[1]));
+    func_args.arg(genCall("ConstRefTuple", index_types, in_args[2]));
+    // init
+    func_args.arg(genCall("LocalTuple", data_types, init_args[0]));
+    func_args.arg(genCall("LocalTuple", data_types, init_args[1]));
+    func_args.arg(genCall("LocalTuple", index_types, init_args[2]));
+    // work buffer
+    func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0]));
+    func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1]));
+    func_args.arg(genCall("VolatilePtrTuple", index_types, work_bufs[2]));
+    // global_sync_buffer
+    const auto sync_buffer =
+        grouped_gwop->sync_buffer()->buffer()->as<TensorView>();
+    func_args.arg("&").append(varName(sync_buffer)).append("[0]");
+
+    // shared_buf
+    ArgumentBuilder smem_buffer_args;
+    smem_buffer_args.arg(
+        genCall("reinterpret_cast", ptrType(data_type), "shared_mem_avg"));
+    smem_buffer_args.arg(
+        genCall("reinterpret_cast", ptrType(data_type), "shared_mem_var"));
+    smem_buffer_args.arg(
+        genCall("reinterpret_cast", ptrType(index_type), "shared_mem_n"));
+    func_args.arg(genCall(
+        "PtrTuple",
+        ArgumentBuilder().arg(data_type).arg(data_type).arg(index_type),
+        smem_buffer_args));
+
+    func_args.arg(genCall("LocalTuple", bool_types, read_preds));
+    func_args.arg(genCall("LocalTuple", bool_types, write_preds));
+
+    addProfileArguments(func_args, grouped_gwop);
+
+    ArgumentBuilder func_template_args;
+    func_template_args.arg(
+        grouped_gwop->numExprs() * index_replacement_maps.size());
+    func_template_args.arg(data_type);
+    func_template_args.arg(index_type);
+
+    indent() << genCall(
+                    genFusedReductionName(ir_utils::getTvOutput(grouped_gwop)) +
+                        ".welfordGroup",
+                    func_template_args,
+                    func_args)
+             << ";\n";
+  }
+
   void handle(const kir::GridBroadcast* grop) final {
     const auto bop = grop->broadcast_op();
     TORCH_INTERNAL_ASSERT(bop->out()->isA<kir::TensorIndex>());
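
To make the argument assembly above concrete, here is a hypothetical shape of the single call this function emits, for one welford expression grouped across two loop iterations (so two total groups). All tensor, buffer, predicate, and instance names are illustrative, not taken from a real generated kernel:

    // Illustrative only: each tuple carries one entry per group, and the
    // first template argument is numExprs() * num_grouped_iterations.
    reduction_instance.welfordGroup<2, float, int64_t>(
        RefTuple<float, float>(T4[0], T4[1]),              // out avg
        RefTuple<float, float>(T5[0], T5[1]),              // out var
        RefTuple<int64_t, int64_t>(T6[0], T6[1]),          // out N
        ConstRefTuple<float, float>(T0[i0], T0[i1]),       // in avg
        ConstRefTuple<float, float>(T1[i0], T1[i1]),       // in var
        ConstRefTuple<int64_t, int64_t>(T2[i0], T2[i1]),   // in N
        LocalTuple<float, float>(0.f, 0.f),                // init avg
        LocalTuple<float, float>(0.f, 0.f),                // init var
        LocalTuple<int64_t, int64_t>(0, 0),                // init N
        VolatilePtrTuple<float, float>(
            &work_buf_avg[0], &work_buf_avg[stride]),      // work buffers,
        VolatilePtrTuple<float, float>(                    // offset by
            &work_buf_var[0], &work_buf_var[stride]),      // buffer_stride()
        VolatilePtrTuple<int64_t, int64_t>(                // per group
            &work_buf_N[0], &work_buf_N[stride]),
        &sync_buf[0],                                      // global sync
        PtrTuple<float, float, int64_t>(
            reinterpret_cast<float*>(shared_mem_avg),
            reinterpret_cast<float*>(shared_mem_var),
            reinterpret_cast<int64_t*>(shared_mem_n)),     // shared memory
        LocalTuple<bool, bool>(pred0, pred0),              // read preds
        LocalTuple<bool, bool>(pred0, pred0));             // write preds
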
@@ -2208,6 +2366,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }

+  void handle(const GroupedWelfordOp* grouped_wop) final {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Should not reach here as grouped welford is only enabled for grid welford,",
+        " which is handled by its own handler");
+  }
+
   //! True if loop is grouped. The IterDomain of the loop must have
   //! ParallelType::Group, but it isn't sufficient as the loop may be
   //! for an initialization expression, for which the loop should not
@@ -2216,7 +2381,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
       return false;
     }
-    return ExprFinder::exists(loop, {ExprType::GroupedGridReduction});
+    return ExprFinder::exists(
+        loop, {ExprType::GroupedGridReduction, ExprType::GroupedGridWelford});
   }

   void handle(const kir::ForLoop* loop) final {

torch/csrc/jit/codegen/cuda/dispatch.cpp

Lines changed: 30 additions & 0 deletions
@@ -113,6 +113,9 @@ void Expr::dispatch(T handler, Expr* expr) {
     case ExprType::WelfordOp:
       ptr(handler)->handle(expr->as<WelfordOp>());
       return;
+    case ExprType::GroupedWelfordOp:
+      ptr(handler)->handle(expr->as<GroupedWelfordOp>());
+      return;
     case ExprType::LoadStoreOp:
       ptr(handler)->handle(expr->as<LoadStoreOp>());
       return;
@@ -190,6 +193,9 @@ void Expr::dispatch(T handler, Expr* expr) {
     case ExprType::GridWelford:
       ptr(handler)->handle(expr->as<kir::GridWelford>());
       return;
+    case ExprType::GroupedGridWelford:
+      ptr(handler)->handle(expr->as<kir::GroupedGridWelford>());
+      return;
     case ExprType::AllocateFusedReduction:
       ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
       return;
@@ -287,6 +293,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
     case ExprType::WelfordOp:
       ptr(handler)->handle(expr->as<WelfordOp>());
       return;
+    case ExprType::GroupedWelfordOp:
+      ptr(handler)->handle(expr->as<GroupedWelfordOp>());
+      return;
     case ExprType::LoadStoreOp:
       ptr(handler)->handle(expr->as<LoadStoreOp>());
       return;
@@ -364,6 +373,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
     case ExprType::GridWelford:
       ptr(handler)->handle(expr->as<kir::GridWelford>());
       return;
+    case ExprType::GroupedGridWelford:
+      ptr(handler)->handle(expr->as<kir::GroupedGridWelford>());
+      return;
     case ExprType::AllocateFusedReduction:
       ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
       return;
@@ -469,6 +481,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
     case ExprType::WelfordOp:
       ptr(mutator)->mutate(expr->as<WelfordOp>());
       return;
+    case ExprType::GroupedWelfordOp:
+      ptr(mutator)->mutate(expr->as<GroupedWelfordOp>());
+      return;
     case ExprType::LoadStoreOp:
       ptr(mutator)->mutate(expr->as<LoadStoreOp>());
       return;
@@ -546,6 +561,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
     case ExprType::GridWelford:
       ptr(mutator)->mutate(expr->as<kir::GridWelford>());
       return;
+    case ExprType::GroupedGridWelford:
+      ptr(mutator)->mutate(expr->as<kir::GroupedGridWelford>());
+      return;
     case ExprType::AllocateFusedReduction:
       ptr(mutator)->mutate(expr->as<kir::AllocateFusedReduction>());
       return;
@@ -716,6 +734,9 @@ void OptOutConstDispatch::handle(const GroupedReductionOp* stmt) {
 void OptOutConstDispatch::handle(const WelfordOp* stmt) {
   unhandled(stmt);
 }
+void OptOutConstDispatch::handle(const GroupedWelfordOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutConstDispatch::handle(const LoadStoreOp* stmt) {
   unhandled(stmt);
 }
@@ -793,6 +814,9 @@ void OptOutConstDispatch::handle(const kir::GridBroadcast* stmt) {
 void OptOutConstDispatch::handle(const kir::GridWelford* stmt) {
   unhandled(stmt);
 }
+void OptOutConstDispatch::handle(const kir::GroupedGridWelford* stmt) {
+  unhandled(stmt);
+}
 void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) {
   unhandled(stmt);
 }
@@ -860,6 +884,9 @@ void OptOutDispatch::handle(GroupedReductionOp* stmt) {
 void OptOutDispatch::handle(WelfordOp* stmt) {
   unhandled(stmt);
 }
+void OptOutDispatch::handle(GroupedWelfordOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutDispatch::handle(LoadStoreOp* stmt) {
   unhandled(stmt);
 }
@@ -937,6 +964,9 @@ void OptOutDispatch::handle(kir::GridBroadcast* stmt) {
 void OptOutDispatch::handle(kir::GridWelford* stmt) {
   unhandled(stmt);
 }
+void OptOutDispatch::handle(kir::GroupedGridWelford* stmt) {
+  unhandled(stmt);
+}
 void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) {
   unhandled(stmt);
 }
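
The pattern repeated in every hunk above is the same: a switch on ExprType forwards to a statically typed handle/mutate overload, so adding a node type means adding one case per dispatch table. A condensed sketch of the mechanism (simplified; the real Expr::dispatch covers every ExprType, and the exact getExprType() accessor shown here is an assumption):

    // Double dispatch in miniature: the enum drives the switch, each case
    // downcasts exactly once, and handlers receive fully typed nodes.
    template <typename T>
    void Expr::dispatch(T handler, Expr* expr) {
      switch (*(expr->getExprType())) {
        case ExprType::WelfordOp:
          ptr(handler)->handle(expr->as<WelfordOp>());
          return;
        case ExprType::GroupedWelfordOp: // added by this commit
          ptr(handler)->handle(expr->as<GroupedWelfordOp>());
          return;
        default:
          TORCH_INTERNAL_ASSERT(false, "Unknown ExprType");
      }
    }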

torch/csrc/jit/codegen/cuda/dispatch.h

Lines changed: 8 additions & 0 deletions
@@ -74,6 +74,7 @@ class TernaryOp;
 class ReductionOp;
 class GroupedReductionOp;
 class WelfordOp;
+class GroupedWelfordOp;
 class LoadStoreOp;
 class MmaOp;
 class BroadcastOp;
@@ -105,6 +106,7 @@ class GridReduction;
 class GroupedGridReduction;
 class GridBroadcast;
 class GridWelford;
+class GroupedGridWelford;
 class AllocateFusedReduction;
 class InitMagicZero;
 class UpdateMagicZero;
@@ -146,6 +148,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const ReductionOp* stmt);
   virtual void handle(const GroupedReductionOp* stmt);
   virtual void handle(const WelfordOp* stmt);
+  virtual void handle(const GroupedWelfordOp* stmt);
   virtual void handle(const LoadStoreOp* stmt);
   virtual void handle(const MmaOp* stmt);
   virtual void handle(const BroadcastOp* stmt);
@@ -173,6 +176,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const kir::GroupedGridReduction*);
   virtual void handle(const kir::GridBroadcast*);
   virtual void handle(const kir::GridWelford*);
+  virtual void handle(const kir::GroupedGridWelford*);
   virtual void handle(const kir::AllocateFusedReduction*);
   virtual void handle(const kir::Swizzle2DInt*);
   virtual void handle(const kir::PairSelect*);
@@ -209,6 +213,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   virtual void handle(ReductionOp* stmt);
   virtual void handle(GroupedReductionOp* stmt);
   virtual void handle(WelfordOp* stmt);
+  virtual void handle(GroupedWelfordOp* stmt);
   virtual void handle(LoadStoreOp* stmt);
   virtual void handle(MmaOp* stmt);
   virtual void handle(BroadcastOp* stmt);
@@ -236,6 +241,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   virtual void handle(kir::GroupedGridReduction* stmt);
   virtual void handle(kir::GridBroadcast* stmt);
   virtual void handle(kir::GridWelford* stmt);
+  virtual void handle(kir::GroupedGridWelford* stmt);
   virtual void handle(kir::AllocateFusedReduction* stmt);
   virtual void handle(kir::Swizzle2DInt* stmt);
   virtual void handle(kir::PairSelect* stmt);
@@ -313,6 +319,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual void mutate(ReductionOp*);
   virtual void mutate(GroupedReductionOp*);
   virtual void mutate(WelfordOp*);
+  virtual void mutate(GroupedWelfordOp*);
   virtual void mutate(LoadStoreOp*);
   virtual void mutate(MmaOp*);
   virtual void mutate(BroadcastOp*);
@@ -340,6 +347,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual void mutate(kir::GroupedGridReduction*);
   virtual void mutate(kir::GridBroadcast*);
   virtual void mutate(kir::GridWelford*);
+  virtual void mutate(kir::GroupedGridWelford*);
   virtual void mutate(kir::AllocateFusedReduction*);
   virtual void mutate(kir::Swizzle2DInt*);
   virtual void mutate(kir::PairSelect*);
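
Because the base-class handlers declared above default to unhandled(), a pass opts in to the new IR node by overriding only the overload it cares about. A minimal sketch with a hypothetical visitor:

    // Hypothetical visitor: count GroupedGridWelford nodes in a kernel.
    // Every other node type falls through to the base unhandled(), which
    // is a no-op in the OptOut* dispatchers.
    class GroupedGridWelfordCounter : public OptOutConstDispatch {
     public:
      void handle(const kir::GroupedGridWelford* gwop) final {
        ++count_;
      }
      int count_ = 0;
    };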

torch/csrc/jit/codegen/cuda/executor.cpp

Lines changed: 7 additions & 1 deletion
@@ -860,7 +860,13 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
       "what can be resident on the GPU at once. Need: ",
       launch_params_.gdimx() * launch_params_.gdimy() *
           launch_params_.gdimz(),
-      " but limited to ",
+      " (",
+      launch_params_.gdimx(),
+      " * ",
+      launch_params_.gdimy(),
+      " * ",
+      launch_params_.gdimz(),
+      ") but limited to ",
       num_blocks_per_SM,
       " * ",
       at::cuda::getDeviceProperties(options_.device.index())
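
The expanded message spells out each grid dimension because the limit being reported is hard: a grid-synchronizing persistent kernel deadlocks unless every block is resident at once. A sketch of the constraint behind the message (reconstructed from the visible fragment, not the verbatim check; the assumption is that num_blocks_per_SM comes from an occupancy query such as cudaOccupancyMaxActiveBlocksPerMultiprocessor):

    // Reconstructed sketch: blocks needed by the launch must not exceed
    // blocks that can be simultaneously resident on the device.
    const int64_t blocks_needed = launch_params_.gdimx() *
        launch_params_.gdimy() * launch_params_.gdimz();
    const int64_t blocks_resident = static_cast<int64_t>(num_blocks_per_SM) *
        at::cuda::getDeviceProperties(options_.device.index())
            ->multiProcessorCount;
    TORCH_INTERNAL_ASSERT(blocks_needed <= blocks_resident);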
