@@ -1671,6 +1671,16 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     indent() << kTab << func_args << ");\n";
   }
 
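+  // Grouped grid welford is generated only for the allreduce pattern;
+  // any other case trips the assertion below.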
+  void handle(const kir::GroupedGridWelford* grouped_gwop) final {
+    if (grouped_gwop->isAllreduce()) {
+      generateGroupedGridAllreduceWelford(grouped_gwop);
+      return;
+    } else {
+      TORCH_INTERNAL_ASSERT(
+          false, "Non-allreduce grouped grid welford is not yet supported");
+    }
+  }
+
   // Enumerates all combinations of index values of grouped
   // loops. Each combination is a vector of loop index values. The
   // length of the vector is the number of grouped loops.
@@ -1872,6 +1882,154 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     indent() << kTab << func_args << ");\n";
   }
 
+  // Mostly the same as the grouped grid reduction version
+  void generateGroupedGridAllreduceWelford(
+      const kir::GroupedGridWelford* grouped_gwop) {
+    TORCH_INTERNAL_ASSERT(grouped_gwop->isAllreduce());
+
+    const auto index_replacement_maps = getLoopIndexReplacementMaps();
+    const auto num_grouped_iterations = index_replacement_maps.size();
+
+    // This is also checked at the lowering validation time, so it
+    // isn't strictly necessary.
+    TORCH_INTERNAL_ASSERT(
+        num_grouped_iterations * grouped_gwop->numExprs() <=
+            kMaxNumGroupedReductions,
+        "Too many grouped reductions: ",
+        grouped_gwop->toString(),
+        ". Up to ",
+        kMaxNumGroupedReductions,
+        " reductions are allowed.");
+
+    ArgumentBuilder data_types;
+    ArgumentBuilder index_types;
+
+    // Note that the data type of var and avg and that of N are the
+    // same with all the welford ops since we only support
+    // grouping of iterations.
+    const auto data_type = grouped_gwop->outputVals().at(0).avg()->dtype();
+    const auto index_type = grouped_gwop->outputVals().at(0).N()->dtype();
+
+    std::array<ArgumentBuilder, 3> out_args;
+    std::array<ArgumentBuilder, 3> in_args;
+    std::array<ArgumentBuilder, 3> init_args;
+    std::array<ArgumentBuilder, 3> work_bufs;
+
+    ArgumentBuilder bool_types;
+    ArgumentBuilder read_preds;
+    ArgumentBuilder write_preds;
+
+    for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) {
+      const auto& output = grouped_gwop->outputVals().at(expr_index);
+      const auto& input = grouped_gwop->inputVals().at(expr_index);
+      const auto& init = grouped_gwop->initVals().at(expr_index);
+
+      for (const auto& group_index :
+           c10::irange(index_replacement_maps.size())) {
+        // Set the index replacement map with the concrete values of
+        // indices of grouped loops.
+        index_replacement_map_ = index_replacement_maps.at(group_index);
+
+        data_types.arg(data_type);
+        index_types.arg(index_type);
+
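+        // Each grouped iteration uses its own slice of the global work
+        // buffer, offset by buffer_stride * group_index.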
+        auto work_buffer_offset = group_index == 0
+            ? "0"
+            : (genInline(grouped_gwop->buffer_stride()) + " * " +
+               std::to_string(group_index));
+
+        // Setup arguments for avg, var, and N
+        for (const auto i : c10::irange(3)) {
+          out_args[i].arg(gen(output.get(i)));
+          in_args[i].arg(gen(input.get(i)));
+          init_args[i].arg(gen(init.get(i)));
+          const auto work_buffer = grouped_gwop->reduction_buffers()[i]
+                                       .at(expr_index)
+                                       ->buffer()
+                                       ->as<TensorView>();
+          work_bufs[i]
+              .arg("&")
+              .append(varName(work_buffer))
+              .append("[")
+              .append(work_buffer_offset)
+              .append("]");
+        }
+
+        // read and write predicates
+        bool_types.arg("bool");
+        // Same argument for all inputs. Different predicates would be
+        // used when grouping is done across iterations
+        TORCH_INTERNAL_ASSERT(grouped_gwop->predicate() != nullptr);
+        TORCH_INTERNAL_ASSERT(
+            grouped_gwop->predicate() != nullptr &&
+            grouped_gwop->predicate()->hasValue());
+        const auto read_pred = genInline(grouped_gwop->predicate());
+        read_preds.arg(read_pred);
+        if (grouped_gwop->writePredicate() != nullptr) {
+          TORCH_INTERNAL_ASSERT(grouped_gwop->writePredicate()->hasValue());
+          write_preds.arg(genInline(grouped_gwop->writePredicate()));
+        } else {
+          write_preds.arg(read_pred);
+        }
+
+        index_replacement_map_.clear();
+      }
+    }
+
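+    // Assemble the call arguments: the collected avg/var/N entries are
+    // packed into reference, const-reference, local, and volatile-pointer
+    // tuples matching the parameter order of the runtime welfordGroup call.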
+    ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
+    // output
+    func_args.arg(genCall("RefTuple", data_types, out_args[0]));
+    func_args.arg(genCall("RefTuple", data_types, out_args[1]));
+    func_args.arg(genCall("RefTuple", index_types, out_args[2]));
+    // input
+    func_args.arg(genCall("ConstRefTuple", data_types, in_args[0]));
+    func_args.arg(genCall("ConstRefTuple", data_types, in_args[1]));
+    func_args.arg(genCall("ConstRefTuple", index_types, in_args[2]));
+    // init
+    func_args.arg(genCall("LocalTuple", data_types, init_args[0]));
+    func_args.arg(genCall("LocalTuple", data_types, init_args[1]));
+    func_args.arg(genCall("LocalTuple", index_types, init_args[2]));
+    // work buffer
+    func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0]));
+    func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1]));
+    func_args.arg(genCall("VolatilePtrTuple", index_types, work_bufs[2]));
+    // global_sync_buffer
+    const auto sync_buffer =
+        grouped_gwop->sync_buffer()->buffer()->as<TensorView>();
+    func_args.arg("&").append(varName(sync_buffer)).append("[0]");
+
+    // shared_buf
+    ArgumentBuilder smem_buffer_args;
+    smem_buffer_args.arg(
+        genCall("reinterpret_cast", ptrType(data_type), "shared_mem_avg"));
+    smem_buffer_args.arg(
+        genCall("reinterpret_cast", ptrType(data_type), "shared_mem_var"));
+    smem_buffer_args.arg(
+        genCall("reinterpret_cast", ptrType(index_type), "shared_mem_n"));
+    func_args.arg(genCall(
+        "PtrTuple",
+        ArgumentBuilder().arg(data_type).arg(data_type).arg(index_type),
+        smem_buffer_args));
+
+    func_args.arg(genCall("LocalTuple", bool_types, read_preds));
+    func_args.arg(genCall("LocalTuple", bool_types, write_preds));
+
+    addProfileArguments(func_args, grouped_gwop);
+
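+    // Template arguments: the total number of grouped reductions
+    // (numExprs * grouped iterations), followed by the data and index
+    // types. This is the same count bounded by kMaxNumGroupedReductions.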
+    ArgumentBuilder func_template_args;
+    func_template_args.arg(
+        grouped_gwop->numExprs() * index_replacement_maps.size());
+    func_template_args.arg(data_type);
+    func_template_args.arg(index_type);
+
+    indent() << genCall(
+                    genFusedReductionName(ir_utils::getTvOutput(grouped_gwop)) +
+                        ".welfordGroup",
+                    func_template_args,
+                    func_args)
+             << ";\n";
+  }
+
   void handle(const kir::GridBroadcast* grop) final {
     const auto bop = grop->broadcast_op();
     TORCH_INTERNAL_ASSERT(bop->out()->isA<kir::TensorIndex>());
@@ -2208,6 +2366,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }
 
+  void handle(const GroupedWelfordOp* grouped_wop) final {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Should not reach here as grouped welford is only enabled for grid welford,",
+        " which is handled by its own handler");
+  }
+
   //! True if loop is grouped. The IterDomain of the loop must have
   //! ParallelType::Group, but it isn't sufficient as the loop may be
   //! for an initialization expression, for which the loop should not
@@ -2216,7 +2381,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
       return false;
     }
-    return ExprFinder::exists(loop, {ExprType::GroupedGridReduction});
+    return ExprFinder::exists(
+        loop, {ExprType::GroupedGridReduction, ExprType::GroupedGridWelford});
   }
 
   void handle(const kir::ForLoop* loop) final {