@@ -120,6 +120,38 @@ std::string genCall(
  return ss.str();
}

+//! A utility class to check if an expression of a particular type exists
+class ExprFinder : kir::ConstIrVisitor {
+ public:
+  //! True if expr or any of its nested expressions is included in
+  //! expr_types
+  static bool exists(
+      const Expr* expr,
+      const std::unordered_set<ExprType>& expr_types) {
+    ExprFinder finder(expr_types);
+    finder.handle(std::vector<const Expr*>{expr});
+    return finder.is_found_;
+  }
+
+ private:
+  ExprFinder(const std::unordered_set<ExprType>& expr_types)
+      : expr_types_(expr_types) {}
+
+  using kir::ConstIrVisitor::handle;
+
+  void handle(const Expr* expr) final {
+    if (expr_types_.find(expr->etype()) != expr_types_.end()) {
+      is_found_ = true;
+      return;
+    }
+    kir::ConstIrVisitor::handle(expr);
+  }
+
+ private:
+  const std::unordered_set<ExprType>& expr_types_;
+  bool is_found_ = false;
+};
+
class CudaKernelGenerator : private OptOutConstDispatch {
  static constexpr const char* kTab = "  ";
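The ExprFinder visitor added above is used later by isGroupedLoop to ask whether a loop body contains a GroupedGridReduction. A rough standalone sketch of the same "does any nested node have one of these types" search, using toy node types in place of the nvFuser kir:: classes (all names below are illustrative, not part of the codebase; the real visitor records a flag while walking the kernel IR):

#include <memory>
#include <unordered_set>
#include <vector>

// Toy stand-ins for the IR node hierarchy.
enum class NodeType { Add, Mul, GroupedGridReduction };

struct Node {
  NodeType type;
  std::vector<std::shared_ptr<Node>> children;
};

// Returns true if `node` or any nested node has a type listed in `types`.
bool typeExists(const Node& node, const std::unordered_set<NodeType>& types) {
  if (types.count(node.type) > 0) {
    return true;
  }
  for (const auto& child : node.children) {
    if (typeExists(*child, types)) {
      return true;
    }
  }
  return false;
}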
@@ -397,6 +429,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
  }

  void handle(const Int* i) final {
+    // Check the replacement map first. If there's an entry for i, use
+    // the corresponding replacement.
+    auto replace_it = index_replacement_map_.find(i);
+    if (replace_it != index_replacement_map_.end()) {
+      code_ << replace_it->second;
+      return;
+    }
+
    const auto def = i->definition();
    const bool has_alloc = alloc_map_.find(i) != alloc_map_.end();
    if (def != nullptr && !has_alloc) {
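The map consulted above lets the generator print a concrete integer in place of a grouped loop's symbolic index; when no entry exists, generation falls back to the usual path. The same lookup-then-fallback pattern in isolation, with a string-keyed toy map standing in for the const Int* keys used here (illustrative only):

#include <cstdint>
#include <string>
#include <unordered_map>

// Emit a concrete value if the index has been bound, otherwise the symbol.
std::string genIndex(
    const std::string& symbolic_name,
    const std::unordered_map<std::string, int64_t>& replacement_map) {
  auto it = replacement_map.find(symbolic_name);
  if (it != replacement_map.end()) {
    return std::to_string(it->second);
  }
  return symbolic_name;
}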
@@ -1535,7 +1575,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    }

    TORCH_INTERNAL_ASSERT(
-        grouped_grop->numReductions() == 2,
+        grouped_grop->numExprs() == 2,
        "Only grouping of 2 reductions is supported. ",
        grouped_grop->toString());
@@ -1554,7 +1594,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    ArgumentBuilder func_args(block_nest_level_ + 1, kTab);

    // Append arguments for each reduction
-    for (const auto i : c10::irange(grouped_grop->numReductions())) {
+    for (const auto i : c10::irange(grouped_grop->numExprs())) {
      TORCH_INTERNAL_ASSERT(
          grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
      const auto work_buffer =
@@ -1596,17 +1636,106 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    indent() << kTab << func_args << ");\n";
  }

+  // Enumerates all combinations of index values of grouped
+  // loops. Each combination is a vector of loop index values. The
+  // length of the vector is the number of grouped loops.
+  //
+  // Example 1: only one domain of extent 2 is grouped: {{0}, {1}}.
+  // Example 2: two domains of extents 2 and 3 are grouped: {{0, 0},
+  // {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}}
+  std::vector<std::vector<int64_t>> getGroupedLoopIndexConcreteIntSets() {
+    std::vector<std::vector<int64_t>> index_combinations;
+
+    // Initialize with an empty vector
+    index_combinations.push_back(std::vector<int64_t>());
+
+    // Incrementally build a combinatorial set
+    for (const auto loop : grouped_loops_) {
+      const auto iter_count = loop->stop()->evaluateInt();
+      std::vector<std::vector<int64_t>> new_combinations;
+      // Append integers from 0 to iter_count to all the vectors built
+      // so far
+      for (const auto& index_vec : index_combinations) {
+        for (int64_t i = 0; i < iter_count; ++i) {
+          auto index_vec_appended = index_vec;
+          index_vec_appended.push_back(i);
+          new_combinations.push_back(index_vec_appended);
+        }
+      }
+      index_combinations = std::move(new_combinations);
+    }
+
+    return index_combinations;
+  }
+
+  //! Returns all combinations of maps from index Vals of grouped loops to
+  //! their concrete integers.
+  std::vector<std::unordered_map<const Int*, int64_t>>
+  getLoopIndexReplacementMaps() {
+    std::vector<std::unordered_map<const Int*, int64_t>> maps;
+
+    if (grouped_loops_.empty()) {
+      std::unordered_map<const Int*, int64_t> empty_map;
+      return {empty_map};
+    }
+
+    // Vector of indices of grouped loops
+    std::vector<Int*> loop_indices;
+    std::transform(
+        grouped_loops_.begin(),
+        grouped_loops_.end(),
+        std::back_inserter(loop_indices),
+        [](const kir::ForLoop* loop) { return loop->index()->as<Int>(); });
+
+    // All combinations of loop index integer values
+    const auto index_val_sets = getGroupedLoopIndexConcreteIntSets();
+
+    // Create maps from loop index Vals to integers
+    for (const auto& index_values : index_val_sets) {
+      TORCH_INTERNAL_ASSERT(loop_indices.size() == index_values.size());
+      std::unordered_map<const Int*, int64_t> index_val_map;
+      for (const auto i : c10::irange(loop_indices.size())) {
+        auto loop_index = loop_indices.at(i);
+        auto index_val = index_values.at(i);
+        index_val_map.emplace(loop_index, index_val);
+      }
+      maps.emplace_back(std::move(index_val_map));
+    }
+
+    return maps;
+  }
+
  void generateGroupedGridAllreduce(
      const kir::GroupedGridReduction* grouped_grop) {
    TORCH_INTERNAL_ASSERT(grouped_grop->isAllreduce());

-    constexpr int max_num_reductions = 8;
+    // There are two dimensions of grouping: horizontal grouping and
+    // iteration grouping. The total number of individual reductions
+    // is the number of horizontal reductions * the extent of grouped
+    // iterations. All of them are packed into a single grid reduction
+    // call. The number of reductions is limited, and currently it is
+    // simply an error if exceeded. This could be avoided by
+    // decomposing grouped_grop into smaller groups within the
+    // limit. TODO: Support a larger number of reductions.
+
+    // First, enumerate all combinations of loop index values of
+    // grouped IterDomains. If only a single domain is grouped, this
+    // is simply a 1D vector of integers from 0 to extent-1. If
+    // two domains are grouped, combinations of two integer vectors
+    // are returned. These loop index value vectors are returned as a
+    // map from loop index Vals to concrete int values.
+    const auto index_replacement_maps = getLoopIndexReplacementMaps();
+    const auto num_grouped_iterations = index_replacement_maps.size();
+
+    // This is also checked at the lowering validation time, so it
+    // isn't strictly necessary.
    TORCH_INTERNAL_ASSERT(
-        grouped_grop->numReductions() <= max_num_reductions,
+        num_grouped_iterations * grouped_grop->numExprs() <=
+            kMaxNumGroupedReductions,
        "Too many grouped reductions: ",
        grouped_grop->toString(),
        ". Up to ",
-        max_num_reductions,
+        kMaxNumGroupedReductions,
        " reductions are allowed.");

    ArgumentBuilder types;
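getGroupedLoopIndexConcreteIntSets above builds the cross product of the grouped loop extents incrementally, and getLoopIndexReplacementMaps then pairs each combination with the grouped loop index Vals. A standalone version of the same construction over plain integer extents (illustrative; the real code walks kir::ForLoop objects and evaluates their stop values):

#include <cstdint>
#include <vector>

// For extents {2, 3} this yields {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2}.
std::vector<std::vector<int64_t>> enumerateIndexCombinations(
    const std::vector<int64_t>& extents) {
  // Start from a single empty combination.
  std::vector<std::vector<int64_t>> combinations(1);
  for (int64_t extent : extents) {
    std::vector<std::vector<int64_t>> extended;
    // Append every value 0..extent-1 to every combination built so far.
    for (const auto& prefix : combinations) {
      for (int64_t i = 0; i < extent; ++i) {
        auto next = prefix;
        next.push_back(i);
        extended.push_back(std::move(next));
      }
    }
    combinations = std::move(extended);
  }
  return combinations;
}

Pairing each returned combination with the grouped loop indices gives one replacement map per combination, which is what the allreduce path below iterates over.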
@@ -1620,44 +1749,65 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    ArgumentBuilder read_preds;
    ArgumentBuilder write_preds;

-    for (const auto i : c10::irange(grouped_grop->numReductions())) {
-      const auto data_type = grouped_grop->outputs().at(i)->dtype();
-      TORCH_INTERNAL_ASSERT(
-          grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
-
-      types.arg(data_type);
+    for (const auto expr_index : c10::irange(grouped_grop->numExprs())) {
+      const auto data_type = grouped_grop->outputs().at(expr_index)->dtype();
+      TORCH_INTERNAL_ASSERT(grouped_grop->reduction_buffers()
+                                .at(expr_index)
+                                ->buffer()
+                                ->isA<TensorView>());

-      // out
-      outputs.arg(gen(grouped_grop->outputs().at(i)));
+      for (const auto& group_index :
+           c10::irange(index_replacement_maps.size())) {
+        // Set the index replacement map with the concrete values of
+        // indices of grouped loops.
+        index_replacement_map_ = index_replacement_maps.at(group_index);

-      // inp
-      inputs.arg(gen(grouped_grop->inputs().at(i)));
+        types.arg(data_type);

-      // global_work_buffer
-      const auto work_buffer =
-          grouped_grop->reduction_buffers().at(i)->buffer()->as<TensorView>();
-      work_bufs.arg("&").append(varName(work_buffer)).append("[0]");
-
-      init_vals.arg(genInline(grouped_grop->initVal(i)));
-
-      reduction_ops.arg(genReductionOp(
-          grouped_grop->getReductionOpType(i),
-          grouped_grop->output(i)->dtype()));
+        // out
+        outputs.arg(gen(grouped_grop->outputs().at(expr_index)));
+
+        // inp
+        inputs.arg(gen(grouped_grop->inputs().at(expr_index)));
+
+        // global_work_buffer
+        const auto work_buffer = grouped_grop->reduction_buffers()
+                                     .at(expr_index)
+                                     ->buffer()
+                                     ->as<TensorView>();
+        // A separate work buffer is used for each reduction.
+        auto work_buffer_offset = group_index == 0
+            ? "0"
+            : (genInline(grouped_grop->buffer_stride()) + " * " +
+               std::to_string(group_index));
+        work_bufs.arg("&")
+            .append(varName(work_buffer))
+            .append("[")
+            .append(work_buffer_offset)
+            .append("]");
+        init_vals.arg(genInline(grouped_grop->initVal(expr_index)));
+
+        reduction_ops.arg(genReductionOp(
+            grouped_grop->getReductionOpType(expr_index),
+            grouped_grop->output(expr_index)->dtype()));
+
+        // read and write predicates
+        bool_types.arg("bool");
+        // Same argument for all inputs. Different predicates would be
+        // used when grouping is done across iterations
+        TORCH_INTERNAL_ASSERT(
+            grouped_grop->predicate() != nullptr &&
+            grouped_grop->predicate()->hasValue());
+        const auto read_pred = genInline(grouped_grop->predicate());
+        read_preds.arg(read_pred);
+        if (grouped_grop->writePredicate() != nullptr) {
+          TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
+          write_preds.arg(genInline(grouped_grop->writePredicate()));
+        } else {
+          write_preds.arg(read_pred);
+        }

-      // read and write predicates
-      bool_types.arg("bool");
-      // Same argument for all inputs. Different predicates would be
-      // used when grouping is done across iterations
-      TORCH_INTERNAL_ASSERT(
-          grouped_grop->predicate() != nullptr &&
-          grouped_grop->predicate()->hasValue());
-      const auto read_pred = genInline(grouped_grop->predicate());
-      read_preds.arg(read_pred);
-      if (grouped_grop->writePredicate() != nullptr) {
-        TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
-        write_preds.arg(genInline(grouped_grop->writePredicate()));
-      } else {
-        write_preds.arg(read_pred);
+        index_replacement_map_.clear();
      }
    }
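The work-buffer argument assembled above gives every (reduction, group) pair its own slice: each horizontal reduction keeps its own buffer, and grouped iteration group_index starts buffer_stride * group_index elements into it, so group 0 reuses the old "[0]" form. A small sketch of just that string-building step (the helper name and parameters are made up for illustration):

#include <cstdint>
#include <string>

// Builds the "&buffer[offset]" argument for one (reduction, group) pair.
// `buffer_stride` is assumed to already be a generated index expression.
std::string workBufferArg(
    const std::string& buffer_name,
    int64_t group_index,
    const std::string& buffer_stride) {
  const std::string offset = group_index == 0
      ? std::string("0")
      : buffer_stride + " * " + std::to_string(group_index);
  return "&" + buffer_name + "[" + offset + "]";
}

For example, workBufferArg("work_buf", 2, "stride") returns "&work_buf[stride * 2]".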
@@ -1975,7 +2125,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
  void handleTrivialLoop(const kir::ForLoop* loop) {
    if (loop->vectorize()) {
-      vectorize_scope_ = loop->vectorize();
+      vectorize_scope_ = true;
    }
    handleScope(loop->body());
    if (loop->vectorize()) {
@@ -1984,7 +2134,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
  }

  void handle(const GroupedReductionOp* grouped_rop) final {
-    for (const auto i : c10::irange(grouped_rop->numReductions())) {
+    for (const auto i : c10::irange(grouped_rop->numExprs())) {
      TORCH_INTERNAL_ASSERT(grouped_rop->output(i)->isA<kir::TensorIndex>());

      const auto output = grouped_rop->output(i)->as<kir::TensorIndex>();
@@ -1997,7 +2147,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
      TORCH_INTERNAL_ASSERT(
          !has_grid_reduce,
-          "GroupedReductionOp does not support block parallelization. GroupedGridReductionOp must be used. ",
+          "GroupedReductionOp does not support block parallelization. GroupedGridReduction must be used. ",
          grouped_rop->toString());

      if (!has_block_reduce) {
@@ -2023,12 +2173,32 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    }
  }

+  //! True if loop is grouped. The IterDomain of the loop must have
+  //! ParallelType::Group, but it isn't sufficient as the loop may be
+  //! for an initialization expression, for which the loop should not
+  //! be grouped. Make sure a GroupedGridReduction is found.
+  bool isGroupedLoop(const kir::ForLoop* loop) {
+    if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
+      return false;
+    }
+    return ExprFinder::exists(loop, {ExprType::GroupedGridReduction});
+  }
+
  void handle(const kir::ForLoop* loop) final {
    if (loop->isTrivial()) {
      handleTrivialLoop(loop);
      return;
    }

+    // If a loop is grouped, no loop is created, but it isn't
+    // considered trivial as the loop trip count is not one.
+    if (isGroupedLoop(loop)) {
+      grouped_loops_.push_back(loop);
+      handleScope(loop->body());
+      grouped_loops_.pop_back();
+      return;
+    }
+
    const auto gen_index = gen(loop->index());
    const auto gen_start = genInline(loop->start());
    const auto gen_stop = genInline(loop->stop());
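The isGroupedLoop check added in this hunk is deliberately two-part: ParallelType::Group on the loop's IterDomain is necessary but not sufficient, because initialization loops carry the same parallel type, so the body must also contain a GroupedGridReduction. The shape of that check with hypothetical stand-in types (the real second step is the recursive ExprFinder search defined near the top of the file):

#include <unordered_set>

// Toy stand-ins; none of these types exist in the codebase.
enum class ToyParallelType { Serial, Group };
enum class ToyExprType { UnaryOp, GroupedGridReduction };

struct ToyLoop {
  ToyParallelType parallel_type;
  // Precomputed here for brevity; the real code searches the loop body.
  std::unordered_set<ToyExprType> contained_expr_types;
};

bool isGroupedLoopToy(const ToyLoop& loop) {
  if (loop.parallel_type != ToyParallelType::Group) {
    return false;
  }
  return loop.contained_expr_types.count(ToyExprType::GroupedGridReduction) > 0;
}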
@@ -2213,10 +2383,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
  // Mark when we are inside of a vectorized for-loop
  bool vectorize_scope_ = false;
-
  //! Keep track of Allocate node for Val. Used to determine if Val
  //! should be inlined.
  std::unordered_map<const Val*, const kir::Allocate*> alloc_map_;
+  //! Keep track of grouped loops
+  std::deque<const kir::ForLoop*> grouped_loops_;
+  //! Used to replace symbolic indices with concrete values
+  std::unordered_map<const Int*, int64_t> index_replacement_map_;
};

} // namespace