
Commit 025c840

Grouping grid allreduces across iterations (csarofeen#1755)
* Extend the grouped grid reduction kernel. The kernel itself should work with an arbitrary number of inputs, but the underlying data structure, Tuple, still needs to be explicitly specialized for the number of values, which is currently limited to 8.
1 parent 37c579e commit 025c840
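
The Tuple limitation is easier to see with a sketch. This is not the actual runtime header, only an illustration (assuming the common pattern of one hand-written specialization per arity) of why the value count is capped:

```cpp
// Illustrative only: the real runtime Tuple differs, but a pattern of
// one explicit specialization per value count is why the number of
// grouped values is capped (at 8 in the current runtime).
template <typename T, int N>
struct Tuple;

template <typename T>
struct Tuple<T, 1> {
  T val0;
};

template <typename T>
struct Tuple<T, 2> {
  T val0;
  T val1;
};

// ... one specialization per supported count, up to the maximum of 8.
```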

20 files changed: +918 −151 lines

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 216 additions & 43 deletions
@@ -120,6 +120,38 @@ std::string genCall(
   return ss.str();
 }
 
+//! A utility class to check if an expression of a particular type exists
+class ExprFinder : kir::ConstIrVisitor {
+ public:
+  //! True if expr or any of its nested expressions is included in
+  //! expr_types
+  static bool exists(
+      const Expr* expr,
+      const std::unordered_set<ExprType>& expr_types) {
+    ExprFinder finder(expr_types);
+    finder.handle(std::vector<const Expr*>{expr});
+    return finder.is_found_;
+  }
+
+ private:
+  ExprFinder(const std::unordered_set<ExprType>& expr_types)
+      : expr_types_(expr_types) {}
+
+  using kir::ConstIrVisitor::handle;
+
+  void handle(const Expr* expr) final {
+    if (expr_types_.find(expr->etype()) != expr_types_.end()) {
+      is_found_ = true;
+      return;
+    }
+    kir::ConstIrVisitor::handle(expr);
+  }
+
+ private:
+  const std::unordered_set<ExprType>& expr_types_;
+  bool is_found_ = false;
+};
+
 class CudaKernelGenerator : private OptOutConstDispatch {
   static constexpr const char* kTab = "  ";
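
ExprFinder has a single call site in this commit, in isGroupedLoop() further down in this file; the call shape is:

```cpp
// From isGroupedLoop() below: does this loop nest contain a grouped
// grid reduction anywhere inside it?
bool has_grouped_grid_reduction =
    ExprFinder::exists(loop, {ExprType::GroupedGridReduction});
```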

@@ -397,6 +429,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   }
 
   void handle(const Int* i) final {
+    // Check the replacement map first. If there's an entry for i, use
+    // the corresponding replacement.
+    auto replace_it = index_replacement_map_.find(i);
+    if (replace_it != index_replacement_map_.end()) {
+      code_ << replace_it->second;
+      return;
+    }
+
     const auto def = i->definition();
     const bool has_alloc = alloc_map_.find(i) != alloc_map_.end();
     if (def != nullptr && !has_alloc) {
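
A minimal standalone sketch of what the replacement achieves (string keys stand in for the Int* Vals; the name i107 is made up):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

// Symbolic indices print as-is unless the replacement map binds them to
// a concrete value, mirroring the lookup in handle(const Int*) above.
int main() {
  std::unordered_map<std::string, int64_t> index_replacement_map{{"i107", 1}};
  auto gen = [&](const std::string& index) {
    auto it = index_replacement_map.find(index);
    return it != index_replacement_map.end() ? std::to_string(it->second)
                                             : index;
  };
  std::cout << "T0[" << gen("i107") << " * stride]\n";  // T0[1 * stride]
  std::cout << "T0[" << gen("i42") << " * stride]\n";   // T0[i42 * stride]
}
```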
@@ -1535,7 +1575,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   }
 
   TORCH_INTERNAL_ASSERT(
-      grouped_grop->numReductions() == 2,
+      grouped_grop->numExprs() == 2,
       "Only grouping of 2 reductions is supported. ",
       grouped_grop->toString());

@@ -1554,7 +1594,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
 
   // Append arguments for each reduction
-  for (const auto i : c10::irange(grouped_grop->numReductions())) {
+  for (const auto i : c10::irange(grouped_grop->numExprs())) {
     TORCH_INTERNAL_ASSERT(
         grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
     const auto work_buffer =
@@ -1596,17 +1636,106 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     indent() << kTab << func_args << ");\n";
   }
 
+  // Enumerates all combinations of index values of grouped
+  // loops. Each combination is a vector of loop index values. The
+  // length of the vector is the number of grouped loops.
+  //
+  // Example 1: only one domain of extent 2 is grouped: {{0}, {1}}.
+  // Example 2: two domains of extents 2 and 3 are grouped: {{0, 0},
+  // {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}}
+  std::vector<std::vector<int64_t>> getGroupedLoopIndexConcreteIntSets() {
+    std::vector<std::vector<int64_t>> index_combinationsatoins;
+
+    // Initialize with an empty vector
+    index_combinationsatoins.push_back(std::vector<int64_t>());
+
+    // Incrementally build a combinatorial set
+    for (const auto loop : grouped_loops_) {
+      const auto iter_count = loop->stop()->evaluateInt();
+      std::vector<std::vector<int64_t>> new_combinations;
+      // Append integers from 0 to iter_count to all the vectors built
+      // so far
+      for (const auto& index_vec : index_combinationsatoins) {
+        for (int64_t i = 0; i < iter_count; ++i) {
+          auto index_vec_appended = index_vec;
+          index_vec_appended.push_back(i);
+          new_combinations.push_back(index_vec_appended);
+        }
+      }
+      index_combinationsatoins = std::move(new_combinations);
+    }
+
+    return index_combinationsatoins;
+  }
+
+  //! Returns all combinations of maps from index Vals of grouped loops
+  //! to their concrete integers.
+  std::vector<std::unordered_map<const Int*, int64_t>>
+  getLoopIndexReplacementMaps() {
+    std::vector<std::unordered_map<const Int*, int64_t>> maps;
+
+    if (grouped_loops_.empty()) {
+      std::unordered_map<const Int*, int64_t> empty_map;
+      return {empty_map};
+    }
+
+    // Vector of indices of grouped loops
+    std::vector<Int*> loop_indices;
+    std::transform(
+        grouped_loops_.begin(),
+        grouped_loops_.end(),
+        std::back_inserter(loop_indices),
+        [](const kir::ForLoop* loop) { return loop->index()->as<Int>(); });
+
+    // All combinations of loop index integer values
+    const auto index_val_sets = getGroupedLoopIndexConcreteIntSets();
+
+    // Create maps from loop index Vals to integers
+    for (const auto& index_values : index_val_sets) {
+      TORCH_INTERNAL_ASSERT(loop_indices.size() == index_values.size());
+      std::unordered_map<const Int*, int64_t> index_val_map;
+      for (const auto i : c10::irange(loop_indices.size())) {
+        auto loop_index = loop_indices.at(i);
+        auto index_val = index_values.at(i);
+        index_val_map.emplace(loop_index, index_val);
+      }
+      maps.emplace_back(std::move(index_val_map));
+    }
+
+    return maps;
+  }
+
   void generateGroupedGridAllreduce(
       const kir::GroupedGridReduction* grouped_grop) {
     TORCH_INTERNAL_ASSERT(grouped_grop->isAllreduce());
 
-    constexpr int max_num_reductions = 8;
+    // There are two dimensions of grouping: horizontal grouping and
+    // iteration grouping. The total number of individual reductions
+    // is the number of horizontal reductions * the extent of grouped
+    // iterations. All of them are packed into a single grid reduction
+    // call. The number of reductions is limited, and currently it is
+    // simply an error if exceeded. This could be avoided by
+    // decomposing grouped_grop into smaller groups within the
+    // limit. TODO: Support a larger number of reductions.
+
+    // First, enumerate all combinations of loop index values of
+    // grouped IterDomains. If only a single domain is grouped, this
+    // is simply a 1D vector of integers from 0 to extent-1. If
+    // two domains are grouped, combinations of two integer vectors
+    // are returned. These loop index value vectors are returned as a
+    // map from loop index Vals to concrete int values.
+    const auto index_replacement_maps = getLoopIndexReplacementMaps();
+    const auto num_grouped_iterations = index_replacement_maps.size();
+
+    // This is also checked at lowering validation time, so it
+    // isn't strictly necessary.
     TORCH_INTERNAL_ASSERT(
-        grouped_grop->numReductions() <= max_num_reductions,
+        num_grouped_iterations * grouped_grop->numExprs() <=
+            kMaxNumGroupedReductions,
         "Too many grouped reductions: ",
         grouped_grop->toString(),
         ". Up to ",
-        max_num_reductions,
+        kMaxNumGroupedReductions,
         " reductions are allowed.");
 
     ArgumentBuilder types;
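
As a worked example of the enumeration, here is a standalone reduction of getGroupedLoopIndexConcreteIntSets with the extents hard-coded (2 and 3, as in Example 2 of the comment above):

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  const std::vector<int64_t> extents{2, 3};  // two grouped domains
  // Start from one empty combination and extend it by every index value
  // of each grouped domain in turn.
  std::vector<std::vector<int64_t>> combinations{{}};
  for (const auto extent : extents) {
    std::vector<std::vector<int64_t>> next;
    for (const auto& vec : combinations) {
      for (int64_t i = 0; i < extent; ++i) {
        auto appended = vec;
        appended.push_back(i);
        next.push_back(std::move(appended));
      }
    }
    combinations = std::move(next);
  }
  // Prints {0, 0} {0, 1} {0, 2} {1, 0} {1, 1} {1, 2}
  for (const auto& vec : combinations) {
    std::cout << "{" << vec[0] << ", " << vec[1] << "} ";
  }
  std::cout << "\n";
}
```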
@@ -1620,44 +1749,65 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     ArgumentBuilder read_preds;
     ArgumentBuilder write_preds;
 
-    for (const auto i : c10::irange(grouped_grop->numReductions())) {
-      const auto data_type = grouped_grop->outputs().at(i)->dtype();
-      TORCH_INTERNAL_ASSERT(
-          grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
-
-      types.arg(data_type);
+    for (const auto expr_index : c10::irange(grouped_grop->numExprs())) {
+      const auto data_type = grouped_grop->outputs().at(expr_index)->dtype();
+      TORCH_INTERNAL_ASSERT(grouped_grop->reduction_buffers()
+                                .at(expr_index)
+                                ->buffer()
+                                ->isA<TensorView>());
 
-      // out
-      outputs.arg(gen(grouped_grop->outputs().at(i)));
+      for (const auto& group_index :
+           c10::irange(index_replacement_maps.size())) {
+        // Set the index replacement map with the concrete values of
+        // indices of grouped loops.
+        index_replacement_map_ = index_replacement_maps.at(group_index);
 
-      // inp
-      inputs.arg(gen(grouped_grop->inputs().at(i)));
+        types.arg(data_type);
 
-      // global_work_buffer
-      const auto work_buffer =
-          grouped_grop->reduction_buffers().at(i)->buffer()->as<TensorView>();
-      work_bufs.arg("&").append(varName(work_buffer)).append("[0]");
-
-      init_vals.arg(genInline(grouped_grop->initVal(i)));
-
-      reduction_ops.arg(genReductionOp(
-          grouped_grop->getReductionOpType(i),
-          grouped_grop->output(i)->dtype()));
+        // out
+        outputs.arg(gen(grouped_grop->outputs().at(expr_index)));
+
+        // inp
+        inputs.arg(gen(grouped_grop->inputs().at(expr_index)));
+
+        // global_work_buffer
+        const auto work_buffer = grouped_grop->reduction_buffers()
+                                     .at(expr_index)
+                                     ->buffer()
+                                     ->as<TensorView>();
+        // A separate work buffer is used for each reduction.
+        auto work_buffer_offset = group_index == 0
+            ? "0"
+            : (genInline(grouped_grop->buffer_stride()) + " * " +
+               std::to_string(group_index));
+        work_bufs.arg("&")
+            .append(varName(work_buffer))
+            .append("[")
+            .append(work_buffer_offset)
+            .append("]");
+        init_vals.arg(genInline(grouped_grop->initVal(expr_index)));
+
+        reduction_ops.arg(genReductionOp(
+            grouped_grop->getReductionOpType(expr_index),
+            grouped_grop->output(expr_index)->dtype()));
+
+        // read and write predicates
+        bool_types.arg("bool");
+        // Same argument for all inputs. Different predicates would be
+        // used when grouping is done across iterations
+        TORCH_INTERNAL_ASSERT(
+            grouped_grop->predicate() != nullptr &&
+            grouped_grop->predicate()->hasValue());
+        const auto read_pred = genInline(grouped_grop->predicate());
+        read_preds.arg(read_pred);
+        if (grouped_grop->writePredicate() != nullptr) {
+          TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
+          write_preds.arg(genInline(grouped_grop->writePredicate()));
+        } else {
+          write_preds.arg(read_pred);
+        }
 
-      // read and write predicates
-      bool_types.arg("bool");
-      // Same argument for all inputs. Different predicates would be
-      // used when grouping is done across iterations
-      TORCH_INTERNAL_ASSERT(
-          grouped_grop->predicate() != nullptr &&
-          grouped_grop->predicate()->hasValue());
-      const auto read_pred = genInline(grouped_grop->predicate());
-      read_preds.arg(read_pred);
-      if (grouped_grop->writePredicate() != nullptr) {
-        TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
-        write_preds.arg(genInline(grouped_grop->writePredicate()));
-      } else {
-        write_preds.arg(read_pred);
+        index_replacement_map_.clear();
       }
     }
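
A standalone sketch of just the work-buffer offset logic above (the string "buffer_stride" stands in for genInline(grouped_grop->buffer_stride())):

```cpp
#include <cstdint>
#include <iostream>
#include <string>

int main() {
  const std::string buffer_stride = "buffer_stride";
  const int64_t num_grouped_iterations = 3;
  // Each grouped iteration gets its own slice of the work buffer,
  // mirroring the work_buffer_offset computation above.
  for (int64_t group_index = 0; group_index < num_grouped_iterations;
       ++group_index) {
    const std::string offset = group_index == 0
        ? "0"
        : buffer_stride + " * " + std::to_string(group_index);
    std::cout << "&work_buf[" << offset << "]\n";
  }
  // Prints:
  //   &work_buf[0]
  //   &work_buf[buffer_stride * 1]
  //   &work_buf[buffer_stride * 2]
}
```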

@@ -1975,7 +2125,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
 
   void handleTrivialLoop(const kir::ForLoop* loop) {
     if (loop->vectorize()) {
-      vectorize_scope_ = loop->vectorize();
+      vectorize_scope_ = true;
     }
     handleScope(loop->body());
     if (loop->vectorize()) {
@@ -1984,7 +2134,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   }
 
   void handle(const GroupedReductionOp* grouped_rop) final {
-    for (const auto i : c10::irange(grouped_rop->numReductions())) {
+    for (const auto i : c10::irange(grouped_rop->numExprs())) {
       TORCH_INTERNAL_ASSERT(grouped_rop->output(i)->isA<kir::TensorIndex>());
 
       const auto output = grouped_rop->output(i)->as<kir::TensorIndex>();
@@ -1997,7 +2147,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
 
       TORCH_INTERNAL_ASSERT(
           !has_grid_reduce,
-          "GroupedReductionOp does not support block parallelization. GroupedGridReductionOp must be used. ",
+          "GroupedReductionOp does not support block parallelization. GroupedGridReduction must be used. ",
           grouped_rop->toString());
 
       if (!has_block_reduce) {
@@ -2023,12 +2173,32 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }
 
+  //! True if loop is grouped. The IterDomain of the loop must have
+  //! ParallelType::Group, but it isn't sufficient as the loop may be
+  //! for an initialization expression, for which the loop should not
+  //! be grouped. Make sure a GroupedGridReduction is found.
+  bool isGroupedLoop(const kir::ForLoop* loop) {
+    if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
+      return false;
+    }
+    return ExprFinder::exists(loop, {ExprType::GroupedGridReduction});
+  }
+
   void handle(const kir::ForLoop* loop) final {
     if (loop->isTrivial()) {
       handleTrivialLoop(loop);
       return;
     }
 
+    // If a loop is grouped, no loop is created, but it isn't
+    // considered trivial as the loop trip count is not one.
+    if (isGroupedLoop(loop)) {
+      grouped_loops_.push_back(loop);
+      handleScope(loop->body());
+      grouped_loops_.pop_back();
+      return;
+    }
+
     const auto gen_index = gen(loop->index());
     const auto gen_start = genInline(loop->start());
     const auto gen_stop = genInline(loop->stop());
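
For contrast, a hedged sketch of the two emission paths (the generated shapes below are illustrative, not captured output):

```cpp
// Illustrative only. An ordinary serial loop is emitted as a real loop:
//
//   for (nvfuser_index_t i42 = 0; i42 < 2; ++i42) {
//     // ... body indexed by the symbolic i42 ...
//   }
//
// A grouped loop emits no for statement at all: the body is visited once
// per combination of concrete index values, and index_replacement_map_
// substitutes those constants wherever the loop index appears, so all
// iterations land in a single grouped grid reduction call.
```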
@@ -2213,10 +2383,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
 
   // Mark when we are inside of a vectorized for-loop
   bool vectorize_scope_ = false;
-
   //! Keep track of Allocate node for Val. Used to determine if Val
   //! should be inlined.
   std::unordered_map<const Val*, const kir::Allocate*> alloc_map_;
+  //! Keep track of grouped loops
+  std::deque<const kir::ForLoop*> grouped_loops_;
+  //! Used to replace symbolic indices with concrete values
+  std::unordered_map<const Int*, int64_t> index_replacement_map_;
 };
 
 } // namespace

torch/csrc/jit/codegen/cuda/inline_propagator.cpp

Lines changed: 3 additions & 0 deletions
@@ -66,7 +66,10 @@ bool MaxPosCalculator::isAllowedID(
   }
 
   if (!allow_vectorize) {
+    // Avoid inlining if marked as Vectorize or Group. In the case of
+    // BestEffort and MostInlined modes, avoid Unroll as well.
     bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) ||
+        id->getParallelType() == ParallelType::Group ||
         ((mode_ == ComputeAtMode::BestEffort ||
           mode_ == ComputeAtMode::MostInlined) &&
          id->getParallelType() == ParallelType::Unroll);
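
A self-contained restatement of the updated rule with stub enums (not nvfuser's actual types; in the real code isParallelTypeVectorize also covers misaligned vectorization):

```cpp
#include <iostream>

enum class ParallelType { Serial, Vectorize, Group, Unroll };
enum class ComputeAtMode { Standard, BestEffort, MostInlined };

// Vectorize and Group always block inlining; Unroll blocks it only in
// BestEffort and MostInlined modes.
bool blocksInlining(ParallelType pt, ComputeAtMode mode) {
  return pt == ParallelType::Vectorize || pt == ParallelType::Group ||
      ((mode == ComputeAtMode::BestEffort ||
        mode == ComputeAtMode::MostInlined) &&
       pt == ParallelType::Unroll);
}

int main() {
  // Group blocks inlining in every mode (prints 1)...
  std::cout << blocksInlining(ParallelType::Group, ComputeAtMode::Standard)
            << "\n";
  // ...while Unroll does not block it in Standard mode (prints 0).
  std::cout << blocksInlining(ParallelType::Unroll, ComputeAtMode::Standard)
            << "\n";
}
```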

torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 5 additions & 1 deletion
@@ -204,7 +204,9 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {
 
   GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner);
 
-  size_t numReductions() const {
+  //! Number of expressions grouped horizontally. It does not reflect
+  //! iteration grouping.
+  size_t numExprs() const {
     return reduction_op_types_.size();
   }
 
@@ -231,7 +233,9 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {
   bool sameAs(const Statement* other) const override;
 
  private:
+  //! Reduction ops of grouped reductions
   const std::vector<BinaryOpType> reduction_op_types_;
+  //! Initial values of grouped reductions
   const std::vector<Val*> init_vals_;
   //! True if using the fused reduction kernel
   bool is_allreduce_ = false;
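
The distinction documented on numExprs() matters for the size check in codegen.cpp above; a toy illustration (the values are hypothetical):

```cpp
#include <cstddef>
#include <iostream>

int main() {
  // numExprs() counts only horizontal grouping, e.g. a sum and a max
  // grouped into one GroupedReductionOp.
  const size_t num_exprs = 2;
  // Iteration grouping multiplies it: with a grouped domain of extent 3,
  // the generated allreduce call packs 2 * 3 = 6 reductions, which must
  // stay within kMaxNumGroupedReductions.
  const size_t num_grouped_iterations = 3;
  std::cout << "total reductions: " << num_exprs * num_grouped_iterations
            << "\n";
}
```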
