diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index f9f13d916b53b5..8e5831cd099463 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -1417,12 +1417,12 @@ WelfordResult Welford( out_avg, out_var, out_N, /*out var/avg/count */ + tv, /*in var/avg/count */ + FusionGuard::getCurFusion()->zeroVal(), + FusionGuard::getCurFusion()->oneVal(), init_avg_val, init_var_val, - init_N, /*init var/avg/count */ - tv, - FusionGuard::getCurFusion()->zeroVal(), - FusionGuard::getCurFusion()->oneVal()); /*in var/avg/count */ + init_N); /*init var/avg/count */ return WelfordResult(out_avg, out_var, out_N); } diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 1fa8425c465dec..07e6564f274359 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -1671,6 +1671,16 @@ class CudaKernelGenerator : private OptOutConstDispatch { indent() << kTab << func_args << ");\n"; } + void handle(const kir::GroupedGridWelford* grouped_gwop) final { + if (grouped_gwop->isAllreduce()) { + generateGroupedGridAllreduceWelford(grouped_gwop); + return; + } else { + TORCH_INTERNAL_ASSERT( + false, "Non-allreduce grouped grid welford is not yet supported"); + } + } + // Enumerates all combinations of index values of grouped // loops. Each combination is a vector of loop index values. The // length of the vector is the number of grouped loops. @@ -1872,6 +1882,154 @@ class CudaKernelGenerator : private OptOutConstDispatch { indent() << kTab << func_args << ");\n"; } + // Mostly the same as the grouped grid reduction version + void generateGroupedGridAllreduceWelford( + const kir::GroupedGridWelford* grouped_gwop) { + TORCH_INTERNAL_ASSERT(grouped_gwop->isAllreduce()); + + const auto index_replacement_maps = getLoopIndexReplacementMaps(); + const auto num_grouped_iterations = index_replacement_maps.size(); + + // This is also checked at lowering validation time, so it + // isn't strictly necessary. + TORCH_INTERNAL_ASSERT( + num_grouped_iterations * grouped_gwop->numExprs() <= + kMaxNumGroupedReductions, + "Too many grouped reductions: ", + grouped_gwop->toString(), + ". Up to ", + kMaxNumGroupedReductions, + " reductions are allowed."); + + ArgumentBuilder data_types; + ArgumentBuilder index_types; + + // Note that the data type of var and avg and that of N are the + // same across all the welford ops since we only support + // grouping of iterations. + const auto data_type = grouped_gwop->outputVals().at(0).avg()->dtype(); + const auto index_type = grouped_gwop->outputVals().at(0).N()->dtype(); + + std::array out_args; + std::array in_args; + std::array init_args; + std::array work_bufs; + + ArgumentBuilder bool_types; + ArgumentBuilder read_preds; + ArgumentBuilder write_preds; + + for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) { + const auto& output = grouped_gwop->outputVals().at(expr_index); + const auto& input = grouped_gwop->inputVals().at(expr_index); + const auto& init = grouped_gwop->initVals().at(expr_index); + + for (const auto& group_index : + c10::irange(index_replacement_maps.size())) { + // Set the index replacement map with the concrete values of + // indices of grouped loops. + index_replacement_map_ = index_replacement_maps.at(group_index); + + data_types.arg(data_type); + index_types.arg(index_type); + + auto work_buffer_offset = group_index == 0 + ? 
"0" + : (genInline(grouped_gwop->buffer_stride()) + " * " + + std::to_string(group_index)); + + // Setup arguments for avg, var, and N + for (const auto i : c10::irange(3)) { + out_args[i].arg(gen(output.get(i))); + in_args[i].arg(gen(input.get(i))); + init_args[i].arg(gen(init.get(i))); + const auto work_buffer = grouped_gwop->reduction_buffers()[i] + .at(expr_index) + ->buffer() + ->as(); + work_bufs[i] + .arg("&") + .append(varName(work_buffer)) + .append("[") + .append(work_buffer_offset) + .append("]"); + } + + // read and write predicates + bool_types.arg("bool"); + // Same argument for all inputs. Different predicates would be + // used when grouping is done across iterations + TORCH_INTERNAL_ASSERT(grouped_gwop->predicate() != nullptr); + TORCH_INTERNAL_ASSERT( + grouped_gwop->predicate() != nullptr && + grouped_gwop->predicate()->hasValue()); + const auto read_pred = genInline(grouped_gwop->predicate()); + read_preds.arg(read_pred); + if (grouped_gwop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(grouped_gwop->writePredicate()->hasValue()); + write_preds.arg(genInline(grouped_gwop->writePredicate())); + } else { + write_preds.arg(read_pred); + } + + index_replacement_map_.clear(); + } + } + + ArgumentBuilder func_args(block_nest_level_ + 1, kTab); + // output + func_args.arg(genCall("RefTuple", data_types, out_args[0])); + func_args.arg(genCall("RefTuple", data_types, out_args[1])); + func_args.arg(genCall("RefTuple", index_types, out_args[2])); + // input + func_args.arg(genCall("ConstRefTuple", data_types, in_args[0])); + func_args.arg(genCall("ConstRefTuple", data_types, in_args[1])); + func_args.arg(genCall("ConstRefTuple", index_types, in_args[2])); + // init + func_args.arg(genCall("LocalTuple", data_types, init_args[0])); + func_args.arg(genCall("LocalTuple", data_types, init_args[1])); + func_args.arg(genCall("LocalTuple", index_types, init_args[2])); + // work buffer + func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0])); + func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1])); + func_args.arg(genCall("VolatilePtrTuple", index_types, work_bufs[2])); + // global_sync_buffer + const auto sync_buffer = + grouped_gwop->sync_buffer()->buffer()->as(); + func_args.arg("&").append(varName(sync_buffer)).append("[0]"); + + // shared_buf + ArgumentBuilder smem_buffer_args; + smem_buffer_args.arg( + genCall("reinterpret_cast", ptrType(data_type), "shared_mem_avg")); + smem_buffer_args.arg( + genCall("reinterpret_cast", ptrType(data_type), "shared_mem_var")); + smem_buffer_args.arg( + genCall("reinterpret_cast", ptrType(index_type), "shared_mem_n")); + func_args.arg(genCall( + "PtrTuple", + ArgumentBuilder().arg(data_type).arg(data_type).arg(index_type), + smem_buffer_args)); + + func_args.arg(genCall("LocalTuple", bool_types, read_preds)); + func_args.arg(genCall("LocalTuple", bool_types, write_preds)); + + addProfileArguments(func_args, grouped_gwop); + + ArgumentBuilder func_template_args; + func_template_args.arg( + grouped_gwop->numExprs() * index_replacement_maps.size()); + func_template_args.arg(data_type); + func_template_args.arg(index_type); + + indent() << genCall( + genFusedReductionName(ir_utils::getTvOutput(grouped_gwop)) + + ".welfordGroup", + func_template_args, + func_args) + << ";\n"; + } + void handle(const kir::GridBroadcast* grop) final { const auto bop = grop->broadcast_op(); TORCH_INTERNAL_ASSERT(bop->out()->isA()); @@ -2208,6 +2366,13 @@ class CudaKernelGenerator : private OptOutConstDispatch { } } + void handle(const 
GroupedWelfordOp* grouped_wop) final { + TORCH_INTERNAL_ASSERT( + false, + "Should not reach here as grouped welford is only enabled for grid welford,", + " which is handled by its own handler"); + } + //! True if loop is grouped. The IterDomain of the loop must have //! ParallelType::Group, but it isn't sufficient as the loop may be //! for an initialization expression, for which the loop shold not @@ -2216,7 +2381,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (loop->iter_domain()->getParallelType() != ParallelType::Group) { return false; } - return ExprFinder::exists(loop, {ExprType::GroupedGridReduction}); + return ExprFinder::exists( + loop, {ExprType::GroupedGridReduction, ExprType::GroupedGridWelford}); } void handle(const kir::ForLoop* loop) final { diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 8a514d4c2e16fc..676cb80866ea58 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -113,6 +113,9 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::WelfordOp: ptr(handler)->handle(expr->as()); return; + case ExprType::GroupedWelfordOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::LoadStoreOp: ptr(handler)->handle(expr->as()); return; @@ -190,6 +193,9 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::GridWelford: ptr(handler)->handle(expr->as()); return; + case ExprType::GroupedGridWelford: + ptr(handler)->handle(expr->as()); + return; case ExprType::AllocateFusedReduction: ptr(handler)->handle(expr->as()); return; @@ -287,6 +293,9 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::WelfordOp: ptr(handler)->handle(expr->as()); return; + case ExprType::GroupedWelfordOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::LoadStoreOp: ptr(handler)->handle(expr->as()); return; @@ -364,6 +373,9 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::GridWelford: ptr(handler)->handle(expr->as()); return; + case ExprType::GroupedGridWelford: + ptr(handler)->handle(expr->as()); + return; case ExprType::AllocateFusedReduction: ptr(handler)->handle(expr->as()); return; @@ -469,6 +481,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) { case ExprType::WelfordOp: ptr(mutator)->mutate(expr->as()); return; + case ExprType::GroupedWelfordOp: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::LoadStoreOp: ptr(mutator)->mutate(expr->as()); return; @@ -546,6 +561,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) { case ExprType::GridWelford: ptr(mutator)->mutate(expr->as()); return; + case ExprType::GroupedGridWelford: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::AllocateFusedReduction: ptr(mutator)->mutate(expr->as()); return; @@ -716,6 +734,9 @@ void OptOutConstDispatch::handle(const GroupedReductionOp* stmt) { void OptOutConstDispatch::handle(const WelfordOp* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const GroupedWelfordOp* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const LoadStoreOp* stmt) { unhandled(stmt); } @@ -793,6 +814,9 @@ void OptOutConstDispatch::handle(const kir::GridBroadcast* stmt) { void OptOutConstDispatch::handle(const kir::GridWelford* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const kir::GroupedGridWelford* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) { unhandled(stmt); } @@ -860,6 +884,9 @@ void 
OptOutDispatch::handle(GroupedReductionOp* stmt) { void OptOutDispatch::handle(WelfordOp* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(GroupedWelfordOp* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(LoadStoreOp* stmt) { unhandled(stmt); } @@ -937,6 +964,9 @@ void OptOutDispatch::handle(kir::GridBroadcast* stmt) { void OptOutDispatch::handle(kir::GridWelford* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(kir::GroupedGridWelford* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) { unhandled(stmt); } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index f680871b6460cb..5f84ecca406965 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -74,6 +74,7 @@ class TernaryOp; class ReductionOp; class GroupedReductionOp; class WelfordOp; +class GroupedWelfordOp; class LoadStoreOp; class MmaOp; class BroadcastOp; @@ -105,6 +106,7 @@ class GridReduction; class GroupedGridReduction; class GridBroadcast; class GridWelford; +class GroupedGridWelford; class AllocateFusedReduction; class InitMagicZero; class UpdateMagicZero; @@ -146,6 +148,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const ReductionOp* stmt); virtual void handle(const GroupedReductionOp* stmt); virtual void handle(const WelfordOp* stmt); + virtual void handle(const GroupedWelfordOp* stmt); virtual void handle(const LoadStoreOp* stmt); virtual void handle(const MmaOp* stmt); virtual void handle(const BroadcastOp* stmt); @@ -173,6 +176,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const kir::GroupedGridReduction*); virtual void handle(const kir::GridBroadcast*); virtual void handle(const kir::GridWelford*); + virtual void handle(const kir::GroupedGridWelford*); virtual void handle(const kir::AllocateFusedReduction*); virtual void handle(const kir::Swizzle2DInt*); virtual void handle(const kir::PairSelect*); @@ -209,6 +213,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(ReductionOp* stmt); virtual void handle(GroupedReductionOp* stmt); virtual void handle(WelfordOp* stmt); + virtual void handle(GroupedWelfordOp* stmt); virtual void handle(LoadStoreOp* stmt); virtual void handle(MmaOp* stmt); virtual void handle(BroadcastOp* stmt); @@ -236,6 +241,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(kir::GroupedGridReduction* stmt); virtual void handle(kir::GridBroadcast* stmt); virtual void handle(kir::GridWelford* stmt); + virtual void handle(kir::GroupedGridWelford* stmt); virtual void handle(kir::AllocateFusedReduction* stmt); virtual void handle(kir::Swizzle2DInt* stmt); virtual void handle(kir::PairSelect* stmt); @@ -313,6 +319,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(ReductionOp*); virtual void mutate(GroupedReductionOp*); virtual void mutate(WelfordOp*); + virtual void mutate(GroupedWelfordOp*); virtual void mutate(LoadStoreOp*); virtual void mutate(MmaOp*); virtual void mutate(BroadcastOp*); @@ -340,6 +347,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(kir::GroupedGridReduction*); virtual void mutate(kir::GridBroadcast*); virtual void mutate(kir::GridWelford*); + virtual void mutate(kir::GroupedGridWelford*); virtual void mutate(kir::AllocateFusedReduction*); virtual void mutate(kir::Swizzle2DInt*); 
virtual void mutate(kir::PairSelect*); diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index d2299a0ce54974..ecd006f1461d6c 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -860,7 +860,13 @@ std::vector FusionExecutor::runFusion( "what can be resident on the GPU at once. Need: ", launch_params_.gdimx() * launch_params_.gdimy() * launch_params_.gdimz(), - " but limited to ", + " (", + launch_params_.gdimx(), + " * ", + launch_params_.gdimy(), + " * ", + launch_params_.gdimz(), + ") but limited to ", num_blocks_per_SM, " * ", at::cuda::getDeviceProperties(options_.device.index()) diff --git a/torch/csrc/jit/codegen/cuda/grouped_reduction.h b/torch/csrc/jit/codegen/cuda/grouped_reduction.h index 39e6e0850e67a2..330a6018446bd9 100644 --- a/torch/csrc/jit/codegen/cuda/grouped_reduction.h +++ b/torch/csrc/jit/codegen/cuda/grouped_reduction.h @@ -27,6 +27,10 @@ namespace cuda { //! dimensions, the same transformations and the same axes to //! reduce. //! +//! Note that Welford is not allowed yet, though it should be +//! technically straightforward to support horizontal fusions of +//! welford ops. Unclear how common it would be in practice, though. +//! //! \param reduction_outputs Tensors produced by ReductionOp TORCH_CUDA_CU_API void groupReductions( const std::vector& reduction_outputs); diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 51fc2354431bbe..129a0a0e79e429 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -247,6 +247,10 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { return is_allreduce_; } + //! Return the index of the corresponding reduction expression for + //! a given output val. + int getExprIndexOfOutput(Val* output_val) const; + bool sameAs(const Statement* other) const override; private: @@ -258,78 +262,215 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { bool is_allreduce_ = false; }; +//! Average, variance and N (count) vals for Welford +class TORCH_CUDA_CU_API WelfordTriplet { + public: + //! Names of the Welford triplet vals + enum class ValName { Avg, Var, N }; + + WelfordTriplet() = default; + + WelfordTriplet(Val* avg, Val* var, Val* N) : vals_({avg, var, N}) {} + + Val* const& avg() const { + return get(ValName::Avg); + } + + Val*& avg() { + return get(ValName::Avg); + } + + TensorView* avgTv() const { + TORCH_INTERNAL_ASSERT(avg()->isA()); + return avg()->as(); + } + + Val* const& var() const { + return get(ValName::Var); + } + + Val*& var() { + return get(ValName::Var); + } + + TensorView* varTv() const { + TORCH_INTERNAL_ASSERT(var()->isA()); + return var()->as(); + } + + Val* const& N() const { + return get(ValName::N); + } + + Val*& N() { + return get(ValName::N); + } + + TensorView* NTv() const { + TORCH_INTERNAL_ASSERT(N()->isA()); + return N()->as(); + } + + //! Get the i-th val. Ordering is defined by ValName. + Val* const& get(int i) const { + return vals_.at(i); + } + + //! Get the i-th val. Ordering is defined by ValName. + Val*& get(int i) { + return vals_.at(i); + } + + Val* const& get(ValName name) const { + return get(valNameToIndex(name)); + } + + Val*& get(ValName name) { + return get(valNameToIndex(name)); + } + + //! Get the name of a given val in this triplet. None is returned if + //! not found. + c10::optional getNameOf(Val* val) const; + + //! 
Return a new triplet with outputs produced by a function applied + //! to each of this triplet + template + WelfordTriplet transform(Func func) const { + return WelfordTriplet(func(avg()), func(var()), func(N())); + } + + bool sameAs(const WelfordTriplet& other) const; + + WelfordTriplet clone(IrCloner* ir_cloner) const; + + //! Clone a vector of triplets + static std::vector clone( + const std::vector& src, + IrCloner* ir_cloner); + + auto begin() { + return vals_.begin(); + } + + auto begin() const { + return vals_.begin(); + } + + auto end() { + return vals_.end(); + } + + auto end() const { + return vals_.end(); + } + + private: + //! Convert a given val name to an index + static int valNameToIndex(ValName name) { + return static_cast(name); + } + + //! Convert a given index to a name + static ValName indexToValName(int index) { + TORCH_INTERNAL_ASSERT(index >= 0 && index < 3, "Invalid index: ", index); + return static_cast(index); + } + + private: + //! Holds avg, var and N in this order + std::array vals_ = {nullptr, nullptr, nullptr}; +}; + //! Welford Scan operation. class TORCH_CUDA_CU_API WelfordOp : public Expr { public: + WelfordOp( + IrBuilderPasskey, + const WelfordTriplet& output, + const WelfordTriplet& input, + const WelfordTriplet& init, + bool is_fused = false); + WelfordOp( IrBuilderPasskey, Val* out_avg, Val* out_var, Val* out_N, - Val* init_avg, - Val* init_var, - Val* init_N, Val* in_avg, Val* in_var, Val* in_N, + Val* init_avg, + Val* init_var, + Val* init_N, bool is_fused = false); WelfordOp(const WelfordOp* src, IrCloner* ir_cloner); Val* out() const { - return out_avg_; + return output().avg(); } Val* in() const { - return in_avg_; + return input().avg(); } bool sameAs(const Statement* const other) const override; - // Welford Accessors - // TODO clean up + const WelfordTriplet& output() const { + return output_; + } + Val* outAvg() const { - return out_avg_; + return output().avg(); } Val* outVar() const { - return out_var_; + return output().var(); } Val* outN() const { - return out_N_; + return output().N(); + } + + const WelfordTriplet& input() const { + return input_; } Val* inAvg() const { - return in_avg_; + return input().avg(); } Val* inVar() const { - return in_var_; + return input().var(); } Val* inN() const { - return in_N_; + return input().N(); + } + + const WelfordTriplet& init() const { + return init_; } Val* initAvg() const { - return init_avg_; + return init().avg(); } Val* initVar() const { - return init_var_; + return init().var(); } Val* initN() const { - return init_N_; + return init().N(); } bool singleValue() const { - return in_N_->isOneInt(); + return inN()->isOneInt(); } bool hasInit() const { - return !init_N_->isZeroInt(); + return !initN()->isZeroInt(); } bool isAllreduce() const { @@ -338,20 +479,121 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { std::vector getInitVals() const; + //! Return the init val for an output val + Val* getInitValOfOutput(Val* output_val) const; + private: - Val* const out_avg_; - Val* const out_var_; - Val* const out_N_; - Val* const init_avg_; - Val* const init_var_; - Val* const init_N_; - Val* const in_avg_; - Val* const in_var_; - Val* const in_N_; + const WelfordTriplet output_; + const WelfordTriplet input_; + const WelfordTriplet init_; //! 
True if using the fused reduction kernel (not implemented yet) bool is_allreduce_ = false; }; +class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { + public: + GroupedWelfordOp( + IrBuilderPasskey, + std::vector output_vals, + std::vector input_vals, + std::vector init_vals, + bool is_allreduce = false, + ExprType expr_type = ExprType::GroupedWelfordOp); + + GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner); + + //! Number of expressions grouped horizontally. It does not reflect + //! iteration grouping. As horizontal grouping is not supported, + //! this always returns 1. + size_t numExprs() const { + return 1; + } + + Val* out(size_t index) const { + return outAvg(index); + } + + Val* in(size_t index) const { + return inAvg(index); + } + + bool sameAs(const Statement* const other) const override; + + const std::vector& outputVals() const { + return output_vals_; + } + + const std::vector& inputVals() const { + return input_vals_; + } + + const std::vector& initVals() const { + return init_vals_; + } + + Val* outAvg(size_t index) const { + return outputVals().at(index).avg(); + } + + Val* outVar(size_t index) const { + return outputVals().at(index).var(); + } + + Val* outN(size_t index) const { + return outputVals().at(index).N(); + } + + Val* inAvg(size_t index) const { + return inputVals().at(index).avg(); + } + + Val* inVar(size_t index) const { + return inputVals().at(index).var(); + } + + Val* inN(size_t index) const { + return inputVals().at(index).N(); + } + + Val* initAvg(size_t index) const { + return initVals().at(index).avg(); + } + + Val* initVar(size_t index) const { + return initVals().at(index).var(); + } + + Val* initN(size_t index) const { + return initVals().at(index).N(); + } + + //! Return the index of the corresponding welford expression for + //! a given output val + int getExprIndexOfOutput(Val* output_val) const; + + //! Return the init val for an output val + Val* getInitValOfOutput(Val* output_val) const; + + bool singleValue(size_t index) const { + return inN(index)->isOneInt(); + } + + bool hasInit(size_t index) const { + return !initN(index)->isZeroInt(); + } + + bool isAllreduce() const { + return is_allreduce_; + } + + private: + const std::vector output_vals_; + const std::vector input_vals_; + const std::vector init_vals_; + //! True if using the fused reduction kernel + bool is_allreduce_ = false; +}; + //! Fused Matmul operation class TORCH_CUDA_CU_API MmaOp : public Expr { public: diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index c435c2c557162f..005eeea8ae21c3 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -398,11 +398,12 @@ void IrPrinter::handle(const ReductionOp* rop) { indent() << " = reduction( " << rop->in() << ", op = " << rop->getReductionOpType() << ", initial value = " << rop->init() - << ", allreduce = " << rop->isAllreduce() << " )\n"; + << ", allreduce = " << (rop->isAllreduce() ? 
"true" : "false") + << " )\n"; } void IrPrinter::handle(const GroupedReductionOp* grouped_rop) { - indent() << "Grouped reduction(\n"; + indent() << "GroupedReductionOp(\n"; ++indent_size_; for (const auto i : c10::irange(grouped_rop->numExprs())) { indent() << grouped_rop->output(i) << " = reduction( " @@ -430,10 +431,34 @@ void IrPrinter::handle(const WelfordOp* wop) { os_ << "\n initial value = " << wop->initAvg() << "(Avg)\n " << wop->initVar() << "(Var)\n " << wop->initN() << "(N)"; } - os_ << "\n allreduce = " << wop->isAllreduce(); + os_ << "\n allreduce = " << (wop->isAllreduce() ? "true" : "false"); os_ << " )\n"; } +void IrPrinter::handle(const GroupedWelfordOp* grouped_wop) { + indent() << "GroupedWelford(\n"; + ++indent_size_; + for (const auto i : c10::irange(grouped_wop->numExprs())) { + indent() << grouped_wop->outAvg(i) << " (Avg),\n"; + indent() << grouped_wop->outVar(i) << " (Var),\n"; + indent() << grouped_wop->outN(i) << " (Count)\n"; + indent() << " = Welford ( "; + ++indent_size_; + indent() << grouped_wop->inAvg(i) << " (Avg),\n"; + indent() << grouped_wop->inVar(i) << " (Var),\n"; + indent() << grouped_wop->inN(i) << " (Count)\n"; + indent() << "initial value =\n"; + ++indent_size_; + indent() << grouped_wop->initAvg(i) << " (Avg),\n"; + indent() << grouped_wop->initVar(i) << " (Var),\n"; + indent() << grouped_wop->initN(i) << " (Count) )\n"; + indent_size_ -= 2; + } + indent() << "allreduce = " << (grouped_wop->isAllreduce() ? "true" : "false") + << " )\n"; + --indent_size_; +} + void IrPrinter::handle(const LoadStoreOp* ldst) { indent() << ldst->out() << " = " << ldst->opType() << "( " << ldst->in() << " )\n"; @@ -649,138 +674,171 @@ void IrPrinter::handle(const kir::GridBroadcast* node) { } void IrPrinter::handle(const kir::GridReduction* node) { - indent(); - handle(node->out()); - os_ << " = " - << "GRID_REDUCTION(op='" << node->getReductionOpType() << "'" - << ", in="; - handle(node->in()); - os_ << ", init="; - handle(node->init()); - os_ << ", read_pred="; + indent() << node->out() << " = reduction( " << node->in() + << ", op = " << node->getReductionOpType() + << ", initial value = " << node->init() << ",\n"; + ++indent_size_; + indent() << "reduction buffer = " << node->reduction_buffer()->buffer() + << ",\n"; + indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n"; + indent() << "read predicate = "; if (node->predicate() != nullptr) { - handle(node->predicate()); + os_ << node->predicate(); } else { os_ << "nullptr"; } - os_ << ")\n"; - os_ << ", write_pred="; + os_ << ",\n"; + indent() << "write predicate = "; if (node->writePredicate() != nullptr) { - handle(node->writePredicate()); + os_ << node->writePredicate(); } else { os_ << "nullptr"; } - os_ << ")\n"; - indent() << kTab << ".reduction_buffer="; - handle(node->reduction_buffer()->buffer()); - os_ << "\n"; - indent() << kTab << ".sync_buffer="; - handle(node->sync_buffer()->buffer()); - os_ << "\n"; + os_ << ",\n"; + indent() << "thread predicate = " << node->threadPredicate().toString() + << ",\n"; + indent() << "allreduce = " << (node->isAllreduce() ? 
"true" : "false") + << " )\n"; + --indent_size_; } void IrPrinter::handle(const kir::GroupedGridReduction* node) { - indent() << "Grouped grid reduction(\n"; + indent() << "GroupedGridReduction(\n"; ++indent_size_; for (const auto i : c10::irange(node->numExprs())) { - indent(); - handle(node->output(i)); - os_ << " = " - << "reduction(op='" << node->getReductionOpType(i) << "'" - << ", in="; - handle(node->input(i)); - os_ << ", init="; - handle(node->initVal(i)); - os_ << "\n"; - } - indent() << kTab << ".read_pred="; + indent() << node->output(i) << " = reduction( " << node->input(i) + << ", op = " << node->getReductionOpType(i) + << ", initial value = " << node->initVal(i) + << ", reduction buffer = " + << node->reduction_buffers().at(i)->buffer() << " )\n"; + } + indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n"; + indent() << "read predicate = "; if (node->predicate() != nullptr) { - handle(node->predicate()); + os_ << node->predicate(); } else { os_ << "nullptr"; } - os_ << "\n"; - indent() << kTab << ".write_pred="; + os_ << ",\n"; + indent() << "write predicate = "; if (node->writePredicate() != nullptr) { - handle(node->writePredicate()); + os_ << node->writePredicate(); } else { os_ << "nullptr"; } - os_ << "\n"; - for (const auto i : c10::irange(node->numExprs())) { - indent() << kTab << ".reduction_buffer="; - handle(node->reduction_buffers().at(i)->buffer()); - os_ << "\n"; - } - indent() << kTab << ".sync_buffer="; - handle(node->sync_buffer()->buffer()); - os_ << "\n"; + os_ << ",\n"; + indent() << "thread predicate = " << node->threadPredicate().toString() + << ",\n"; + indent() << "allreduce = " << (node->isAllreduce() ? "true" : "false") + << " )\n"; + --indent_size_; } void IrPrinter::handle(const kir::GridWelford* node) { + std::cerr << "current indent size: " << indent_size_ << std::endl; const auto* welford_op = node->welford_op(); - indent(); - handle(welford_op->outVar()); - os_ << ","; - handle(welford_op->outAvg()); - os_ << ","; - handle(welford_op->outN()); - os_ << " = " - << "GRID_WELFORD(" - << "inAvg="; - handle(welford_op->inAvg()); - if (!welford_op->inN()->isOneInt()) { - indent() << ", inVar="; - handle(welford_op->inVar()); - } - indent() << ", inN="; - handle(welford_op->inN()); - if (!welford_op->initN()->isZeroInt()) { - indent() << ", initVar="; - handle(welford_op->initVar()); - os_ << " initAvg="; - handle(welford_op->initAvg()); - os_ << " initN="; - handle(welford_op->initN()); - } - indent() << ", read_pred="; + indent() << welford_op->outAvg() << " (Avg),\n"; + indent() << welford_op->outVar() << " (Var),\n"; + indent() << welford_op->outN() << " (Count)\n"; + indent() << " = Welford (\n"; + ++indent_size_; + indent() << welford_op->inAvg() << " (Avg),\n"; + indent() << welford_op->inVar() << " (Var),\n"; + indent() << welford_op->inN() << " (Count)\n"; + indent() << "initial value =\n"; + ++indent_size_; + indent() << welford_op->initAvg() << " (Avg),\n"; + indent() << welford_op->initVar() << " (Var),\n"; + indent() << welford_op->initN() << " (Count),\n"; + --indent_size_; + indent() << "reduction buffer =\n"; + ++indent_size_; + indent() << node->avg_buffer()->buffer() << " (Avg),\n"; + indent() << node->var_buffer()->buffer() << " (Var),\n"; + indent() << node->N_buffer()->buffer() << " (Count),\n"; + --indent_size_; + indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n"; + indent() << "read predicate = "; if (welford_op->predicate() != nullptr) { - handle(welford_op->predicate()); + os_ << 
welford_op->predicate(); } else { os_ << "nullptr"; } - os_ << ")\n"; - indent() << ", write_pred="; + os_ << ",\n"; + indent() << "write predicate = "; if (welford_op->writePredicate() != nullptr) { - handle(welford_op->writePredicate()); + os_ << welford_op->writePredicate(); } else { os_ << "nullptr"; } - os_ << ")\n"; - indent() << kTab << ".var_buffer="; - handle(node->var_buffer()->buffer()); - os_ << ".avg_buffer="; - handle(node->avg_buffer()->buffer()); - os_ << ".n_buffer="; - handle(node->N_buffer()->buffer()); - os_ << "\n"; - indent() << kTab << ".sync_buffer="; - handle(node->sync_buffer()->buffer()); - os_ << "\n"; - indent() << kTab << ".grid_read_pred="; + os_ << ",\n"; + indent() << "grid read predicate = "; if (node->predicate() != nullptr) { - handle(node->predicate()); + os_ << node->predicate(); } else { os_ << "nullptr"; } - os_ << "\n"; - indent() << kTab << ".grid_write_pred="; + os_ << ",\n"; + indent() << "grid write predicate = "; if (node->writePredicate() != nullptr) { - handle(node->writePredicate()); + os_ << node->writePredicate(); } else { os_ << "nullptr"; } - os_ << "\n"; + os_ << ",\n"; + indent() << "thread predicate = " << node->threadPredicate().toString() + << ",\n"; + indent() << "allreduce = " << (welford_op->isAllreduce() ? "true" : "false") + << " )\n"; + --indent_size_; +} + +void IrPrinter::handle(const kir::GroupedGridWelford* node) { + indent() << "GroupedGridWelford(\n"; + ++indent_size_; + for (const auto i : c10::irange(node->numExprs())) { + indent() << node->outAvg(i) << " (Avg),\n"; + indent() << node->outVar(i) << " (Var),\n"; + indent() << node->outN(i) << " (Count)\n"; + indent() << " = Welford (\n"; + ++indent_size_; + indent() << node->inAvg(i) << " (Avg),\n"; + indent() << node->inVar(i) << " (Var),\n"; + indent() << node->inN(i) << " (Count)\n"; + indent() << "initial value =\n"; + ++indent_size_; + indent() << node->initAvg(i) << " (Avg),\n"; + indent() << node->initVar(i) << " (Var),\n"; + indent() << node->initN(i) << " (Count),\n"; + --indent_size_; + indent() << "reduction buffer =\n"; + ++indent_size_; + indent() << node->reduction_buffers()[0].at(i)->buffer() << " (Avg),\n"; + indent() << node->reduction_buffers()[1].at(i)->buffer() << " (Var),\n"; + indent() << node->reduction_buffers()[2].at(i)->buffer() << " (Count) )\n"; + indent_size_ -= 2; + } + indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n"; + indent() << "read predicate = "; + if (node->predicate() != nullptr) { + os_ << node->predicate(); + } else { + os_ << "nullptr"; + } + os_ << ",\n"; + indent() << "write predicate = "; + if (node->writePredicate() != nullptr) { + os_ << node->writePredicate(); + } else { + os_ << "nullptr"; + } + os_ << ",\n"; + indent() << "thread predicate = " << node->threadPredicate().toString() + << ",\n"; + indent() << "allreduce = " << (node->isAllreduce() ? 
"true" : "false") + << " )\n"; + --indent_size_; } void IrPrinter::handle(const kir::InitMagicZero* node) { diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 5266c259c1f5be..2df1ec2ec230aa 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -88,6 +88,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const ReductionOp*) final; void handle(const GroupedReductionOp*) final; void handle(const WelfordOp*) final; + void handle(const GroupedWelfordOp*) final; void handle(const LoadStoreOp*) final; void handle(const MmaOp*) final; void handle(const BroadcastOp*) final; @@ -106,6 +107,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const kir::GridReduction*) final; void handle(const kir::GroupedGridReduction*) final; void handle(const kir::GridWelford*) final; + void handle(const kir::GroupedGridWelford*) final; void handle(const kir::ForLoop*) final; void handle(const kir::IfThenElse*) final; void handle(const kir::Allocate*) final; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 47e6a39ab632d5..19999a92ca3ac0 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -452,6 +452,16 @@ GroupedReductionOp::GroupedReductionOp( init_vals_(ir_cloner->clone(src->init_vals_)), is_allreduce_(src->is_allreduce_) {} +int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { + auto it = std::find(outputs().begin(), outputs().end(), output_val); + if (it != outputs().end()) { + return std::distance(outputs().begin(), it); + } + + TORCH_INTERNAL_ASSERT( + false, "Not an output, ", output_val->toString(), ", of ", toString()); +} + bool GroupedReductionOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -478,128 +488,337 @@ bool GroupedReductionOp::sameAs(const Statement* other) const { WelfordOp::WelfordOp( IrBuilderPasskey passkey, - Val* out_avg, - Val* out_var, - Val* out_N, - Val* init_avg, - Val* init_var, - Val* init_N, - Val* in_avg, - Val* in_var, - Val* in_N, + const WelfordTriplet& output, + const WelfordTriplet& input, + const WelfordTriplet& init, bool is_fused) : Expr(passkey, ExprType::WelfordOp), - out_avg_(out_avg), - out_var_(out_var), - out_N_(out_N), - init_avg_(init_avg), - init_var_(init_var), - init_N_(init_N), - in_avg_(in_avg), - in_var_(in_var == nullptr ? in_avg->container()->zeroVal() : in_var), - in_N_(in_N), + output_(output), + input_(input), + init_(init), is_allreduce_(is_fused) { + // Previously, nullptr was accepted and implicitly replaced by + // default values. Looks like we always pass some non-null values, + // so removed the implicit default behavior for code simplicity. 
+ TORCH_INTERNAL_ASSERT(output.avg() != nullptr); + TORCH_INTERNAL_ASSERT(output.var() != nullptr); + TORCH_INTERNAL_ASSERT(output.N() != nullptr); + TORCH_INTERNAL_ASSERT(init.avg() != nullptr); + TORCH_INTERNAL_ASSERT(init.var() != nullptr); + TORCH_INTERNAL_ASSERT(init.N() != nullptr); + TORCH_INTERNAL_ASSERT(input.avg() != nullptr); + TORCH_INTERNAL_ASSERT(input.var() != nullptr); + TORCH_INTERNAL_ASSERT(input.N() != nullptr); + // Check output type TORCH_INTERNAL_ASSERT( - out_avg->getValType().value() == ValType::TensorView || - out_avg->getValType().value() == ValType::TensorIndex); + output.avg()->getValType().value() == ValType::TensorView || + output.avg()->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - out_var->getValType().value() == ValType::TensorView || - out_var->getValType().value() == ValType::TensorIndex); + output.var()->getValType().value() == ValType::TensorView || + output.var()->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - out_N->getValType().value() == ValType::TensorView || - out_N->getValType().value() == ValType::TensorIndex); + output.N()->getValType().value() == ValType::TensorView || + output.N()->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT(isIntegralType(output.N()->dtype())); // check initial value - TORCH_INTERNAL_ASSERT(init_N->getValType().value() == ValType::Scalar); - if (!init_N->isZeroInt()) { + TORCH_INTERNAL_ASSERT(init.N()->getValType().value() == ValType::Scalar); + TORCH_INTERNAL_ASSERT(isIntegralType(init.N()->dtype())); + if (!init.N()->isZeroInt()) { // when initial count is zero, no initial variance or average is needed // initial value with a count of 1 is un-common enough that I'll push // the responsibility of creating all-zero var tensors to the user TORCH_INTERNAL_ASSERT( - init_avg && - (init_avg->getValType().value() == ValType::TensorView || - init_avg->getValType().value() == ValType::TensorIndex)); + init_.avg()->getValType().value() == ValType::TensorView || + init_.avg()->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - init_var && - (init_var->getValType().value() == ValType::TensorView || - init_var->getValType().value() == ValType::TensorIndex)); + init_.var()->getValType().value() == ValType::TensorView || + init_.var()->getValType().value() == ValType::TensorIndex, + "Invalid initial var: ", + init_.var()->toString()); } - TORCH_INTERNAL_ASSERT( - in_avg && - (in_avg->getValType().value() == ValType::TensorView || - in_avg->getValType().value() == ValType::TensorIndex), - in_avg->getValType().value()); // check input TORCH_INTERNAL_ASSERT( - in_N->getValType().value() == ValType::Scalar || - in_N->getValType().value() == ValType::TensorView || - in_N->getValType().value() == ValType::TensorIndex); - if (!in_N->isOneInt()) { + input_.avg()->getValType().value() == ValType::TensorView || + input_.avg()->getValType().value() == ValType::TensorIndex, + input_.avg()->getValType().value()); + TORCH_INTERNAL_ASSERT( + input_.N()->getValType().value() == ValType::Scalar || + input_.N()->getValType().value() == ValType::TensorView || + input_.N()->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT(isIntegralType(input_.N()->dtype())); + if (!input_.N()->isOneInt()) { // when input is only one value, only the value is required through avg // input the var part is implicitly 0 and codegen will handle that. 
TORCH_INTERNAL_ASSERT( - in_var && - (in_var->getValType().value() == ValType::TensorView || - in_var->getValType().value() == ValType::TensorIndex)); + input_.var()->getValType().value() == ValType::TensorView || + input_.var()->getValType().value() == ValType::TensorIndex); } else { TORCH_INTERNAL_ASSERT( - in_var == nullptr || in_var->isZeroInt(), + input_.var() == nullptr || input_.var()->isZeroInt(), "Invalid var input, which must be either nullptr or scalar zero when the N input is one."); } - addOutput(out_avg_); - addOutput(out_var_); - addOutput(out_N_); + addOutput(output_.avg()); + addOutput(output_.var()); + addOutput(output_.N()); - addInput(in_avg_); - // Previously in_var_ was allowed to be null - TORCH_INTERNAL_ASSERT( - in_var_ != nullptr, "Welford var input nullptr not allowed"); - addInput(in_var_); - addInput(in_N_); + addInput(input_.avg()); + addInput(input_.var()); + addInput(input_.N()); +} + +c10::optional WelfordTriplet::getNameOf( + Val* val) const { + auto it = std::find(begin(), end(), val); + if (it != end()) { + return indexToValName(std::distance(begin(), it)); + } + + return c10::optional(); +} + +bool WelfordTriplet::sameAs(const WelfordTriplet& other) const { + return this == &other || + (avg()->sameAs(other.avg()) && var()->sameAs(other.var()) && + N()->sameAs(other.N())); +} + +WelfordTriplet WelfordTriplet::clone(IrCloner* ir_cloner) const { + return transform([&](const Val* val) { return ir_cloner->clone(val); }); } +std::vector WelfordTriplet::clone( + const std::vector& src, + IrCloner* ir_cloner) { + std::vector cloned; + for (const auto& triplet : src) { + cloned.emplace_back(triplet.clone(ir_cloner)); + } + return cloned; +} + +WelfordOp::WelfordOp( + IrBuilderPasskey passkey, + Val* out_avg, + Val* out_var, + Val* out_N, + Val* in_avg, + Val* in_var, + Val* in_N, + Val* init_avg, + Val* init_var, + Val* init_N, + bool is_fused) + : WelfordOp( + passkey, + WelfordTriplet(out_avg, out_var, out_N), + WelfordTriplet(in_avg, in_var, in_N), + WelfordTriplet(init_avg, init_var, init_N), + is_fused) {} + WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner) : Expr(src, ir_cloner), - out_avg_(ir_cloner->clone(src->out_avg_)), - out_var_(ir_cloner->clone(src->out_var_)), - out_N_(ir_cloner->clone(src->out_N_)), - init_avg_(src->init_avg_ ? ir_cloner->clone(src->init_avg_) : nullptr), - init_var_(src->init_var_ ? ir_cloner->clone(src->init_var_) : nullptr), - init_N_(ir_cloner->clone(src->init_N_)), - in_avg_(ir_cloner->clone(src->in_avg_)), - in_var_(src->in_var_ ? 
ir_cloner->clone(src->in_var_) : nullptr), - in_N_(ir_cloner->clone(src->in_N_)), + output_(src->output_.clone(ir_cloner)), + input_(src->input_.clone(ir_cloner)), + init_(src->init_.clone(ir_cloner)), is_allreduce_(src->is_allreduce_) {} -namespace { -inline bool sameOptionalVal(Val* a, Val* b) { - return ((a == nullptr && b == nullptr)) || ((a && b) && (a->sameAs(b))); +Val* WelfordOp::getInitValOfOutput(Val* output_val) const { + auto val_name = output().getNameOf(output_val); + + TORCH_INTERNAL_ASSERT( + val_name.has_value(), + "Not an output val ", + output_val->toString(), + " of ", + toString()); + + return init().get(*val_name); } -} // namespace bool WelfordOp::sameAs(const Statement* other) const { if (this == other) { return true; } if (auto other_wop = dynamic_cast(other)) { - return in_avg_->sameAs(other_wop->in_avg_) && - sameOptionalVal(in_var_, other_wop->in_var_) && - in_N_->sameAs(other_wop->in_N_) && - sameOptionalVal(init_avg_, other_wop->init_avg_) && - sameOptionalVal(init_var_, other_wop->init_var_) && - init_N_->sameAs(other_wop->init_N_); + return input_.sameAs(other_wop->input_) && init_.sameAs(other_wop->init_); } return false; } std::vector WelfordOp::getInitVals() const { - std::vector init_vals({init_avg_, init_var_, init_N_}); + std::vector init_vals({init_.avg(), init_.var(), init_.N()}); return init_vals; } +GroupedWelfordOp::GroupedWelfordOp( + IrBuilderPasskey passkey, + std::vector output_vals, + std::vector input_vals, + std::vector init_vals, + bool is_allreduce, + ExprType expr_type) + : Expr(passkey, expr_type), + output_vals_(std::move(output_vals)), + input_vals_(std::move(input_vals)), + init_vals_(std::move(init_vals)), + is_allreduce_(is_allreduce) { + const auto num_grouped_ops = output_vals_.size(); + + TORCH_INTERNAL_ASSERT( + input_vals_.size() == num_grouped_ops, + "Invalid number of input arguments. Expected: ", + num_grouped_ops, + ", Given: ", + input_vals_.size()); + TORCH_INTERNAL_ASSERT( + init_vals_.size() == num_grouped_ops, + "Invalid number of N arguments. Expected: ", + num_grouped_ops, + ", Given: ", + init_vals_.size()); + + for (const auto i : c10::irange(num_grouped_ops)) { + // Check output type + TORCH_INTERNAL_ASSERT( + output_vals_[i].avg()->getValType().value() == ValType::TensorView || + output_vals_[i].avg()->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT( + output_vals_[i].var()->getValType().value() == ValType::TensorView || + output_vals_[i].var()->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT( + output_vals_[i].N()->getValType().value() == ValType::TensorView || + output_vals_[i].N()->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT(isIntegralType(output_vals_[i].N()->dtype())); + + // check initial value + auto init_avg = init_vals_[i].avg(); + auto init_var = init_vals_[i].var(); + auto init_N = init_vals_[i].N(); + TORCH_INTERNAL_ASSERT( + init_avg != nullptr && init_var != nullptr && init_N != nullptr, + "nullptr init vals are not allowed"); + TORCH_INTERNAL_ASSERT(init_N->getValType().value() == ValType::Scalar); + TORCH_INTERNAL_ASSERT(isIntegralType(init_N->dtype())); + TORCH_INTERNAL_ASSERT( + init_avg->getValType().value() == ValType::TensorView || + init_avg->getValType().value() == ValType::TensorIndex || + (init_N->isZeroInt() && + init_avg->getValType().value() == ValType::Scalar), + "Initial avg must be a tensor or, can be a scalar if initial N is zero.", + " Initial avg: ", + init_avg->toString(), + ". 
Initial N: ", + init_N->toString()); + TORCH_INTERNAL_ASSERT( + init_var->getValType().value() == ValType::TensorView || + init_var->getValType().value() == ValType::TensorIndex || + (init_N->isZeroInt() && + init_var->getValType().value() == ValType::Scalar), + "Initial var must be a tensor or, can be a scalar if initial N is zero: ", + init_var->toString()); + + // check input + auto in_avg = input_vals_[i].avg(); + auto in_var = input_vals_[i].var(); + auto in_N = input_vals_[i].N(); + TORCH_INTERNAL_ASSERT( + in_avg != nullptr && in_var != nullptr && in_N != nullptr, + "nullptr input vals are not allowed"); + TORCH_INTERNAL_ASSERT( + in_N->getValType().value() == ValType::Scalar || + in_N->getValType().value() == ValType::TensorView || + in_N->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT(isIntegralType(in_N->dtype())); + TORCH_INTERNAL_ASSERT( + in_avg->getValType().value() == ValType::TensorView || + in_avg->getValType().value() == ValType::TensorIndex, + "Invalid input avg argument type: ", + in_avg->getValType().value()); + + if (in_N->isOneInt()) { + // when input is only one value, only the value is required through avg + // input the var part must be implicitly 0 + TORCH_INTERNAL_ASSERT( + in_var->isZeroInt(), + "Invalid var input, which must be scalar zero when the N input is one: ", + in_var->toString()); + } else { + TORCH_INTERNAL_ASSERT( + in_var->getValType().value() == ValType::TensorView || + in_var->getValType().value() == ValType::TensorIndex, + in_var->getValType().value(), + ", ", + in_N->toString()); + } + } + + for (const auto i : c10::irange(num_grouped_ops)) { + addOutput(output_vals_[i].avg()); + addOutput(output_vals_[i].var()); + addOutput(output_vals_[i].N()); + addInput(input_vals_[i].avg()); + addInput(input_vals_[i].var()); + addInput(input_vals_[i].N()); + } +} + +GroupedWelfordOp::GroupedWelfordOp( + const GroupedWelfordOp* src, + IrCloner* ir_cloner) + : Expr(src, ir_cloner), + output_vals_(WelfordTriplet::clone(src->output_vals_, ir_cloner)), + input_vals_(WelfordTriplet::clone(src->input_vals_, ir_cloner)), + init_vals_(WelfordTriplet::clone(src->init_vals_, ir_cloner)), + is_allreduce_(src->is_allreduce_) {} + +bool GroupedWelfordOp::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + + auto grouped_op = dynamic_cast(other); + if (grouped_op == nullptr) { + return false; + } + + if (!Expr::sameAs(other)) { + return false; + } + + for (const auto i : c10::irange(numExprs())) { + if (!initAvg(i)->sameAs(grouped_op->initAvg(i)) || + !initVar(i)->sameAs(grouped_op->initVar(i)) || + !initN(i)->sameAs(grouped_op->initN(i))) { + return false; + } + } + + return true; +} + +int GroupedWelfordOp::getExprIndexOfOutput(Val* output_val) const { + for (const auto expr_idx : c10::irange(numExprs())) { + if (outputVals().at(expr_idx).getNameOf(output_val).has_value()) { + return expr_idx; + } + } + + TORCH_INTERNAL_ASSERT( + false, "Not an output, ", output_val->toString(), ", of ", toString()); +} + +Val* GroupedWelfordOp::getInitValOfOutput(Val* output_val) const { + auto expr_index = getExprIndexOfOutput(output_val); + + auto val_name = outputVals().at(expr_index).getNameOf(output_val).value(); + + return initVals().at(expr_index).get(val_name); +} + MmaOp::MmaOp( IrBuilderPasskey passkey, Val* out, diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 68616db917e941..d8520bf047f41e 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ 
b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -415,12 +415,12 @@ struct SubstituteInExpr : public OptInDispatch { out_avg, out_var, out_N, - init_avg, - init_var, - init_N, in_avg, in_var, in_N, + init_avg, + init_var, + init_N, welford_expr->isAllreduce()); } @@ -782,29 +782,12 @@ Val* getReductionInitValOf(TensorView* tv) { if (auto rop = dynamic_cast(def)) { init = rop->init(); } else if (auto grop = dynamic_cast(def)) { - int output_idx = -1; - for (const auto i : c10::irange(grop->numExprs())) { - if (tv == grop->output(i)) { - output_idx = static_cast(i); - break; - } - } - TORCH_INTERNAL_ASSERT( - output_idx >= 0, - "Matching output not found for GroupedReductionOp: ", - tv->toString(), - ". Defined by: ", - def->toString()); + int output_idx = grop->getExprIndexOfOutput(tv); init = grop->initVal(output_idx); } else if (auto wop = dynamic_cast(def)) { - if (tv == wop->outAvg()) { - init = wop->initAvg(); - } else if (tv == wop->outVar()) { - init = wop->initVar(); - } else { - TORCH_INTERNAL_ASSERT(tv == wop->outN()); - init = wop->initN(); - } + return wop->getInitValOfOutput(tv); + } else if (auto gwop = dynamic_cast(def)) { + init = gwop->getInitValOfOutput(tv); } else if (auto mma = dynamic_cast(def)) { init = mma->init(); } @@ -817,7 +800,8 @@ Val* getReductionInitValOf(TensorView* tv) { bool isReductionOp(const Expr* expr) { // Note that GridReduction inherits ReductionOp return expr->isA() || expr->isA() || - expr->isA() || expr->isA(); + expr->isA() || expr->isA() || + expr->isA() || expr->isA(); } bool isReductionTvOp(const Expr* expr) { diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index 3f5efe02d8ed69..404e61e7dc5277 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -137,6 +137,15 @@ class KernelIrScanner : private IrVisitor { } } + void handle(GroupedGridWelford* grid_welford) final { + summary_.has_welford = true; + summary_.has_grid_welford = true; + summary_.has_grid_reductions = true; + if (grid_welford->isAllreduce()) { + summary_.has_cooperative_grid_reduction = true; + } + } + void handle(GridBroadcast* grid_broadcast) final { summary_.has_cooperative_grid_reduction = true; handle(grid_broadcast->broadcast_op()); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 11bf1b18543787..132b99b31c34bb 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -589,6 +589,34 @@ GridWelford::GridWelford( "IR type only valid for Kernel container."); } +GroupedGridWelford::GroupedGridWelford( + IrBuilderPasskey passkey, + std::vector output_vals, + std::vector input_vals, + std::vector init_vals, + std::array, 3> reduction_buffers, + Allocate* sync_buffer, + Val* entrance_index, + Val* entrances, + Val* buffer_stride, + bool is_allreduce) + : GroupedWelfordOp( + passkey, + std::move(output_vals), + std::move(input_vals), + std::move(init_vals), + is_allreduce, + ExprType::GroupedGridWelford), + reduction_buffers_(std::move(reduction_buffers)), + sync_buffer_(sync_buffer), + entrance_index_(entrance_index), + entrances_(entrances), + buffer_stride_(buffer_stride) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + AllocateFusedReduction::AllocateFusedReduction( IrBuilderPasskey passkey, GridReduction* grid_reduction) @@ -619,6 +647,16 @@ AllocateFusedReduction::AllocateFusedReduction( "IR type only valid for Kernel 
container."); } +AllocateFusedReduction::AllocateFusedReduction( + IrBuilderPasskey passkey, + GroupedGridWelford* grouped_grid_welford) + : Expr(passkey, ExprType::AllocateFusedReduction), + grid_expr_(grouped_grid_welford) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + TensorIndex* AllocateFusedReduction::out() const { TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr); if (grid_expr_->isA() || @@ -626,6 +664,10 @@ TensorIndex* AllocateFusedReduction::out() const { return grid_expr_->outputs().at(0)->as(); } else if (auto grid_welford = dynamic_cast(grid_expr_)) { return grid_welford->welford_op()->out()->as(); + } else if ( + auto grouped_grid_welford = + dynamic_cast(grid_expr_)) { + return grouped_grid_welford->out(0)->as(); } else { TORCH_INTERNAL_ASSERT( false, "Invalid grid expression: ", grid_expr_->toString()); @@ -642,6 +684,10 @@ const ParallelTypeBitmap& AllocateFusedReduction::threadPredicate() const { auto grouped_grid_reduction = dynamic_cast(grid_expr_)) { return grouped_grid_reduction->threadPredicate(); + } else if ( + auto grouped_grid_welford = + dynamic_cast(grid_expr_)) { + return grouped_grid_welford->threadPredicate(); } else { TORCH_INTERNAL_ASSERT( false, "Invalid grid expression: ", grid_expr_->toString()); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index b629f687e2c014..cc41376435c6b6 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -64,6 +64,7 @@ class GridReduction; class GroupedGridReduction; class GridBroadcast; class GridWelford; +class GroupedGridWelford; class AllocateFusedReduction; // Expr container @@ -694,6 +695,8 @@ class TORCH_CUDA_CU_API GridBroadcast final : public Expr { //! //! This node provides FusionExecutor the information it needs to allocate the //! reduction and sync buffers. +//! +//! TODO: Make this a subclass of WelfordOp class TORCH_CUDA_CU_API GridWelford final : public Expr { public: GridWelford( @@ -758,6 +761,64 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { ParallelTypeBitmap thread_predicate_; }; +class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { + public: + // input, output and init vals are vectors of triplets + GroupedGridWelford( + IrBuilderPasskey passkey, + std::vector output_vals, + std::vector input_vals, + std::vector init_vals, + std::array, 3> reduction_buffers, + Allocate* sync_buffer, + Val* entrance_index, + Val* entrances, + Val* buffer_stride, + bool is_allreduce = false); + + const std::array, 3>& reduction_buffers() const { + return reduction_buffers_; + } + + Allocate* sync_buffer() const { + return sync_buffer_; + } + + // Which instance of entering this grid reduction is this iteration? + Val* entrance_index() const { + return entrance_index_; + } + + // How many times will this grid reduction be entered + Val* entrances() const { + return entrances_; + } + + Val* buffer_stride() const { + return buffer_stride_; + } + + const ParallelTypeBitmap& threadPredicate() const { + return thread_predicate_; + } + + void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) { + thread_predicate_ = thread_predicate; + } + + private: + std::array, 3> reduction_buffers_; + Allocate* sync_buffer_ = nullptr; + // gridReduce has template flags for thread predicates. In order to + // use them, the thread predicate is held here separately from + // Expr::predicate_. 
+ ParallelTypeBitmap thread_predicate_; + Val* entrance_index_ = nullptr; + Val* entrances_ = nullptr; + // Stride of reduction buffers + Val* buffer_stride_ = nullptr; +}; + // Allocate an instance of the fused reduction class. class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { public: @@ -773,6 +834,10 @@ class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { IrBuilderPasskey passkey, GroupedGridReduction* grouped_grid_reduction); + explicit AllocateFusedReduction( + IrBuilderPasskey passkey, + GroupedGridWelford* grouped_grid_welford); + Expr* gridExpr() const { return grid_expr_; } @@ -782,7 +847,7 @@ class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { const ParallelTypeBitmap& threadPredicate() const; private: - //! GridReduction, GridWelford or GroupedGridReduction + //! GridReduction, GridWelford, GroupedGridReduction or GroupedGridWelford Expr* grid_expr_ = nullptr; }; diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index f0c7b26c1802ff..466dc85c8abff3 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -478,6 +478,11 @@ class AllocationInserter : public kir::ExprMutator { out->name() == welford->outN()->name(), "Unreachable"); init = welford->initN(); } + } else if (expr->isA()) { + TORCH_INTERNAL_ASSERT( + default_val == nullptr, + "Welford should not have a default initialization value for predicate elimination."); + init = expr->as()->getInitValOfOutput(out); } else if (default_val != nullptr) { init = default_val; } diff --git a/torch/csrc/jit/codegen/cuda/lower_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/lower_fused_reduction.cpp index 213abda029a6c7..744feab598b3a6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_fused_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_fused_reduction.cpp @@ -266,15 +266,9 @@ class FusionTransformer { fusion_->removeExpr(welford); fused_expr = IrBuilder::create( - out_avg, - out_var, - out_n, - init_avg, - init_var, - init_n, - in_avg, - in_var, - in_n, + WelfordTriplet{out_avg, out_var, out_n}, + WelfordTriplet{in_avg, in_var, in_n}, + WelfordTriplet{init_avg, init_var, init_n}, true); } else if (auto grouped_rop = dynamic_cast(expr)) { TORCH_INTERNAL_ASSERT(!grouped_rop->isAllreduce()); diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 3d9b5b17d8b68f..ab5eef6b21cffa 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -615,8 +615,6 @@ void IndexLowering::handle(const WelfordOp* wop) { const bool has_block_reduce = out_domain->hasBlockReduction(); const bool has_grid_reduce = out_domain->hasGridReduction(); - // If we do a grid reduction we can't have a reduction axis that is not bound - // to a grid or block dim () if (has_grid_reduce) { TORCH_INTERNAL_ASSERT( std::none_of( @@ -650,12 +648,12 @@ void IndexLowering::handle(const WelfordOp* wop) { out_avg, out_var, out_N, - wop->initAvg(), - wop->initVar(), - wop->initN(), in_avg, in_var, in_N, + wop->initAvg(), + wop->initVar(), + wop->initN(), wop->isAllreduce()); if (wop->predicate()) { @@ -692,18 +690,18 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent); const auto work_buffer_size = buffer_size_info.size_of_privatized_buffer; - auto out_var_buffer = allocateUniqueBuffer( - work_buffer_size, - 
indexed_wop->outVar()->dtype(), - false, - indexed_wop->outVar()->as()->view(), - work_buffer_map_); auto out_avg_buffer = allocateUniqueBuffer( work_buffer_size, indexed_wop->outAvg()->dtype(), false, indexed_wop->outAvg()->as()->view(), work_buffer_map_); + auto out_var_buffer = allocateUniqueBuffer( + work_buffer_size, + indexed_wop->outVar()->dtype(), + false, + indexed_wop->outVar()->as()->view(), + work_buffer_map_); auto out_N_buffer = allocateUniqueBuffer( work_buffer_size, indexed_wop->outN()->dtype(), @@ -771,6 +769,150 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { } } +void IndexLowering::handle(const GroupedWelfordOp* grouped_wop) { + TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(grouped_wop)); + + const auto out_tv = ir_utils::getTvOutput(grouped_wop); + const auto out_domain = out_tv->domain(); + + const bool has_grid_reduce = out_domain->hasGridReduction(); + + std::vector indexed_outputs(grouped_wop->numExprs()); + std::vector indexed_inputs(grouped_wop->numExprs()); + + for (const auto i : c10::irange(grouped_wop->numExprs())) { + const auto& output = grouped_wop->outputVals().at(i); + const auto& input = grouped_wop->inputVals().at(i); + WelfordTriplet indexed_output; + WelfordTriplet indexed_input; + for (const auto j : c10::irange(3)) { + indexed_output.get(j) = lowerDstIndex(output.get(j)); + indexed_input.get(j) = lowerSrcIndex(input.get(j), output.get(j)); + } + indexed_outputs[i] = indexed_output; + indexed_inputs[i] = indexed_input; + } + + if (has_grid_reduce) { + handleGroupedGridWelford( + grouped_wop, indexed_outputs, indexed_inputs, grouped_wop->initVals()); + } else { + TORCH_INTERNAL_ASSERT( + false, + "Only grid welford is supported. Validation should have caught non-grid welford grouping."); + } +} + +std::vector IndexLowering::allocateWelfordWorkBuffer( + const std::vector& triplets, + WelfordTriplet::ValName name, + Val* buffer_size) { + std::vector work_buffers; + + std::transform( + triplets.begin(), + triplets.end(), + std::back_inserter(work_buffers), + [&](const WelfordTriplet& output) { + return allocateUniqueBuffer( + buffer_size, + output.get(name)->dtype(), + false, + output.get(name)->as(), + work_buffer_map_); + }); + + return work_buffers; +} + +void IndexLowering::handleGroupedGridWelford( + const GroupedWelfordOp* op, + const std::vector& output_vals, + const std::vector& input_vals, + const std::vector& init_vals) { + const auto out_tv = ir_utils::getTvOutput(op); + const auto out_domain = out_tv->domain(); + + TORCH_INTERNAL_ASSERT(out_domain->hasGridReduction()); + + // If we do a grid reduction we can't have a reduction axis that is not bound + // to a grid or block dim. + TORCH_INTERNAL_ASSERT( + std::none_of( + out_domain->domain().begin(), + out_domain->domain().end(), + [](IterDomain* id) { + return !id->isThread() && id->isReduction() && + !id->extent()->isOneInt(); + }), + "Found a reduction stage that has both a non-parallelized ", + "reduction and a grid reduction. 
This is not supported, ", + "please use rfactor to do the serialized reduction first, ", + "then the grid reduction."); + + const bool is_persistent = op->isAllreduce(); + auto work_buf_size_info = + getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent); + + const auto work_buffers_avg = allocateWelfordWorkBuffer( + op->outputVals(), + WelfordTriplet::ValName::Avg, + work_buf_size_info.size_of_privatized_buffer); + const auto work_buffers_var = allocateWelfordWorkBuffer( + op->outputVals(), + WelfordTriplet::ValName::Var, + work_buf_size_info.size_of_privatized_buffer); + const auto work_buffers_N = allocateWelfordWorkBuffer( + op->outputVals(), + WelfordTriplet::ValName::N, + work_buf_size_info.size_of_privatized_buffer); + + auto sync_buffer_size = + getGridSyncBufferSize(out_domain, for_loops_, is_persistent); + auto sync_buffer = allocateUniqueBuffer( + sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_); + + const auto entrance_ind = !is_persistent + ? getEntranceLinIndGridReduce(for_loops_) + : GpuLower::current()->kernel()->zeroVal(); + const auto n_entrances = !is_persistent + ? getEntranceCountGridReduce(for_loops_) + : GpuLower::current()->kernel()->oneVal(); + + // The thread predicate needs to be set separately from the main + // predicate. Do not combine them like other expressions. + const auto& thread_pred = + GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv); + + auto indexed_op = IrBuilder::create( + output_vals, + input_vals, + init_vals, + std::array, 3>{ + work_buffers_avg, work_buffers_var, work_buffers_N}, + sync_buffer, + entrance_ind, + n_entrances, + work_buf_size_info.buffer_stride, + op->isAllreduce()); + + indexed_op->setThreadPredicate(thread_pred); + + if (op->predicate()) { + indexed_op->setPredicate(op->predicate()); + } + if (op->writePredicate()) { + indexed_op->setWritePredicate(op->writePredicate()); + } + + pushBack(indexed_op); + GpuLower::current()->propagateExprInfo(op, back()); + + if (op->isAllreduce()) { + allocateUniqueFusedReduction(indexed_op, out_tv); + } +} + void IndexLowering::handle(const LoadStoreOp* ldst) { const auto in = lowerSrcIndex(ldst->in(), ldst->out()); const auto out = lowerDstIndex(ldst->out()); @@ -923,6 +1065,11 @@ void IndexLowering::allocateUniqueFusedReduction( IrBuilder::create( expr->as()); break; + case ExprType::GroupedGridWelford: + fused_reduction_alloc_reduction = + IrBuilder::create( + expr->as()); + break; default: TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString()); } diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index b42cd288d04591..539c40f0fb6ce4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -48,6 +48,7 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { void handle(const ReductionOp*) final; void handle(const GroupedReductionOp*) final; void handle(const WelfordOp*) final; + void handle(const GroupedWelfordOp*) final; void handle(const LoadStoreOp*) final; void handle(const MmaOp*) final; void handle(const BroadcastOp*) final; @@ -80,6 +81,17 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { void handleGridWelford(WelfordOp* new_wop); + void handleGroupedBlockWelford( + const GroupedWelfordOp* wop, + const std::vector& output_vals, + const std::vector& input_vals, + const std::vector& init_vals); + void handleGroupedGridWelford( + const GroupedWelfordOp* wop, + const std::vector& output_vals, + 
const std::vector& input_vals, + const std::vector& init_vals); + // Allocate a unique buffer for grid reductions and broadcast. A // buffer is uniquely allocated for each output tensor of an // expression. @@ -90,6 +102,11 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { TensorView* out_tv, std::unordered_map& alloc_map); + std::vector allocateWelfordWorkBuffer( + const std::vector& triplets, + WelfordTriplet::ValName name, + Val* buffer_size); + // Allocate a fused reduction object uniquely for a given // TensorView. Parameter expr is the expression corresponding to the // fused reduction. diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 453872d4f4e97d..940de32ce9567c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -640,10 +640,11 @@ class PredicateChcker : public IterVisitor { // If input is not predicated, out-of-bound value may be // overwritten by a garbage value. However, it doesn't matter if // the input is also produced by another welford. - if (!input_def->isA() && + if (!input_def->isA() && !input_def->isA() && non_predicated_exprs_.find(input_def) != non_predicated_exprs_.end()) { needs_predicate_ = true; + return; } } } @@ -691,12 +692,8 @@ class PredicateChcker : public IterVisitor { } else if ( auto input_def_grouped_rop = dynamic_cast(input_def)) { - auto input_index_as_output = std::distance( - input_def_grouped_rop->outputs().begin(), - std::find( - input_def_grouped_rop->outputs().begin(), - input_def_grouped_rop->outputs().end(), - input)); + auto input_index_as_output = + input_def_grouped_rop->getExprIndexOfOutput(input); if (grouped_rop->getReductionOpType(i) != input_def_grouped_rop->getReductionOpType( input_index_as_output) && @@ -714,6 +711,62 @@ class PredicateChcker : public IterVisitor { } } + void handle(GroupedWelfordOp* grouped_wop) final { + for (const auto expr_idx : c10::irange(grouped_wop->numExprs())) { + for (const auto val_idx : c10::irange(3)) { + auto init = grouped_wop->initVals().at(expr_idx).get(val_idx); + + // Welford input can be a scalar. Predicate is required unless + // the scalar value is equal to the init value. + auto input = grouped_wop->inputVals().at(expr_idx).get(val_idx); + if (input->isScalar()) { + if (!input->sameAs(init)) { + needs_predicate_ = true; + return; + } + continue; + } + + auto input_tv = dynamic_cast(input); + TORCH_INTERNAL_ASSERT(input_tv != nullptr); + + auto input_def = input->definition(); + + // When input_def is null, input must be an input to the fusion, + // so that must be allocated on global memory. Since we don't omit + // predication for expressions involving global memory, this + // should never occur. + TORCH_INTERNAL_ASSERT( + input_def != nullptr, + "Inconsistent input found: ", + input->toString()); + + // The input needs to be initialized to the init value to omit + // the predicate, so if the input has its own init value, i.e., + // produced by another reduction, they must use the same init + // value. + Val* input_init = ir_utils::getReductionInitValOf(input_tv); + if (input_init != nullptr && !init->sameAs(input_init)) { + needs_predicate_ = true; + return; + } + + // If input is not predicated, out-of-bound value may be + // overwritten by a garbage value. 
However, it doesn't matter if + // the input is also produced by another reduction op as it + // must be initialized and its initialized value is already + // found to be equal to the init value of this op. + if (!input_def->isA<WelfordOp>() && + !input_def->isA<GroupedWelfordOp>() && + non_predicated_exprs_.find(input_def) != + non_predicated_exprs_.end()) { + needs_predicate_ = true; + return; + } + } + } + } + + // Similar to the above reduction constraint but for MMA + void handle(MmaOp* mma) final { + for (auto input : ir_utils::filterByType<TensorView>(mma->inputs())) { diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 3c5908ee14facf..28da8774daa289 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -94,6 +94,7 @@ bool isTvOp(const Expr* expr) { expr->getExprType().value() == ExprType::ReductionOp || expr->getExprType().value() == ExprType::GroupedReductionOp || expr->getExprType().value() == ExprType::WelfordOp || + expr->getExprType().value() == ExprType::GroupedWelfordOp || expr->getExprType().value() == ExprType::LoadStoreOp || expr->getExprType().value() == ExprType::MmaOp || expr->getExprType().value() == ExprType::BroadcastOp || @@ -104,8 +105,10 @@ bool isTvOp(const Expr* expr) { expr->getExprType().value() == ExprType::ViewAsScalar || expr->getExprType().value() == ExprType::ViewOp || expr->getExprType().value() == ExprType::GridReduction || + expr->getExprType().value() == ExprType::GroupedGridReduction || expr->getExprType().value() == ExprType::GridBroadcast || - expr->getExprType().value() == ExprType::GridWelford)) { + expr->getExprType().value() == ExprType::GridWelford || + expr->getExprType().value() == ExprType::GroupedGridWelford)) { return true; } return false; } diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 9c522f6ae9646a..631e1b9be7724b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -1204,8 +1204,10 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) { TORCH_CHECK( tv->definition()->isA<ReductionOp>() || - tv->definition()->isA<GroupedReductionOp>(), - "Invalid use of ParallelType::Group. Only ReductionOp and GroupedReductionOp are allowed. ", + tv->definition()->isA<GroupedReductionOp>() || + tv->definition()->isA<WelfordOp>() || + tv->definition()->isA<GroupedWelfordOp>(), + "Invalid use of ParallelType::Group. Only ReductionOp, GroupedReductionOp, WelfordOp and GroupedWelfordOp are allowed. 
", tv->definition()->toString()); // Convert ReductionOp to GroupedReductionOp @@ -1238,6 +1240,36 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) { outputs, inputs, is_allreduce); + } else if (tv->definition()->isA()) { + // Convert WelfordOp to GroupedWelfordOp + auto wop = def->as(); + auto is_allreduce = wop->isAllreduce(); + + TORCH_CHECK( + is_allreduce, + "Invalid use of ParallelType::Group.", + " Only enabled for allreduce reductions: ", + wop->toString()); + + TORCH_CHECK( + tv->domain()->hasGridReduction(), + "Invalid use of ParallelType::Group.", + " Only enabled for grid reductions: ", + wop->toString()); + + std::vector output_vals( + {{wop->outAvg(), wop->outVar(), wop->outN()}}); + std::vector input_vals( + {{wop->inAvg(), wop->inVar(), wop->inN()}}); + std::vector init_vals( + {{wop->initAvg(), wop->initVar(), wop->initN()}}); + fusion->removeExpr(wop); + IrBuilder::create( + static_cast(fusion), + output_vals, + input_vals, + init_vals, + is_allreduce); } } } diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 50dc7a57908eb4..bfb9d6a2534ef8 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -263,15 +263,52 @@ void OptOutMutator::mutate(WelfordOp* wop) { out_avg, out_var, out_N, - init_avg, - init_var, - init_N, in_avg, in_var, in_N, + init_avg, + init_var, + init_N, wop->isAllreduce()); } +void OptOutMutator::mutate(GroupedWelfordOp* wop) { + bool is_same = true; + + std::vector output_vals; + for (const auto& out : wop->outputVals()) { + auto maybe_mutated = + out.transform([&](Val* val) { return maybeMutated(val); }); + is_same = is_same && maybe_mutated.sameAs(out); + output_vals.push_back(maybe_mutated); + } + + std::vector input_vals; + for (const auto& inp : wop->inputVals()) { + auto maybe_mutated = + inp.transform([&](Val* val) { return maybeMutated(val); }); + is_same = is_same && maybe_mutated.sameAs(inp); + input_vals.push_back(maybe_mutated); + } + + std::vector init_vals; + for (const auto& init : wop->initVals()) { + auto maybe_mutated = + init.transform([&](Val* val) { return maybeMutated(val); }); + is_same = is_same && maybe_mutated.sameAs(init); + init_vals.push_back(maybe_mutated); + } + + if (is_same) { + return; + } + + auto container = wop->container(); + container->removeExpr(wop); + IrBuilder::create( + container, output_vals, input_vals, init_vals, wop->isAllreduce()); +} + void OptOutMutator::mutate(MmaOp* mma) { Val* out = maybeMutated(mma->out()); Val* in_a = maybeMutated(mma->inA()); @@ -511,6 +548,9 @@ void OptOutMutator::mutate(kir::GridBroadcast*) { void OptOutMutator::mutate(kir::GridWelford*) { TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } +void OptOutMutator::mutate(kir::GroupedGridWelford*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} void OptOutMutator::mutate(kir::AllocateFusedReduction*) { TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } diff --git a/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp b/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp index 43961dbda4754c..9e3ff2046c0f6e 100644 --- a/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp +++ b/torch/csrc/jit/codegen/cuda/parallel_type_bitmap.cpp @@ -12,6 +12,7 @@ constexpr std::bitset std::string ParallelTypeBitmap::toString() const { std::stringstream ss; + ss << "("; bool is_first = true; for (ParallelType pt : *this) { if (!is_first) { @@ -20,6 +21,7 @@ std::string ParallelTypeBitmap::toString() const { ss << pt; 
is_first = false; } + ss << ")"; return ss.str(); } diff --git a/torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu b/torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu index a1cbcb1b398e82..de111b6782d6b9 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu @@ -1,4 +1,95 @@ namespace fused_reduction { + +// Tuple of Welford avg, var and N parameters. +// +// Template parameters: +// - DataTypeT: Type of avg and var +// - IndexTypeT: Type of N +// - MakeTuple: Template template parameter to define Tuple types +// (e.g., MakeLocalTuple> +template < + int NumVals, + typename DataTypeT, + typename IndexTypeT, + template + typename MakeTuple> +struct WelfordTripletTuple { + static constexpr int num_vals = NumVals; + using DataType = DataTypeT; + using IndexType = IndexTypeT; + using DataTuple = typename MakeTuple::type; + using IndexTuple = typename MakeTuple::type; + + DataTuple avg; + DataTuple var; + IndexTuple N; + + WelfordTripletTuple( + const DataTuple& avg, + const DataTuple& var, + const IndexTuple& N) + : avg(avg), var(var), N(N) {} +}; + +template +using LocalWelfordTripletTuple = + WelfordTripletTuple; + +template +using RefWelfordTripletTuple = + WelfordTripletTuple; + +template +using ConstRefWelfordTripletTuple = + WelfordTripletTuple; + +template +using VolatilePtrWelfordTripletTuple = + WelfordTripletTuple; + +// Advance pointer offsets of WelfordTripleTuple. Only valid when the +// values are pointer values. +template +__inline__ __device__ static void operator+=( + WelfordTripletTupleType& triplet, + nvfuser_index_t offset) { + triplet.avg += offset; + triplet.var += offset; + triplet.N += offset; +} + +// Copy each of the triplet tuples +template +__inline__ __device__ static void copyWelfordTripletTuple( + DstType& dst, + nvfuser_index_t dst_offset, + const SrcType& src, + nvfuser_index_t src_offset = 0) { + copyTuple(dst.avg, dst_offset, src.avg, src_offset); + copyTuple(dst.var, dst_offset, src.var, src_offset); + copyTuple(dst.N, dst_offset, src.N, src_offset); +} + +// Copy each of the triplet tuples +template +__inline__ __device__ static void copyWelfordTripletTuple( + DstType& dst, + const SrcType& src, + nvfuser_index_t src_offset = 0) { + copyWelfordTripletTuple(dst, 0, src, src_offset); +} + +// Copy each of the triplet tuples +template +__inline__ __device__ static void copyWelfordTripletTupleIf( + DstType& dst, + const SrcType& src, + const PredType& pred) { + copyTupleIf(dst.avg, src.avg, pred); + copyTupleIf(dst.var, src.var, pred); + copyTupleIf(dst.N, src.N, pred); +} + namespace impl { //! Suppose f_i be the i-th function of the binary function @@ -110,6 +201,140 @@ __inline__ __device__ static void reduceEach( val0, offset0, val1, offset1, reduction_ops...); } +//! Implementation helper for welfordEach. 
+template +struct WelfordForEach { + static __inline__ __device__ void call( + Triplet0& triplet0, + nvfuser_index_t offset0, + const Triplet1& triplet1, + nvfuser_index_t offset1) { + static_assert( + Triplet0::num_vals == Triplet1::num_vals, "Invalid Triplet types"); + static_assert( + IsSameType:: + value, + "Invalid Triplet types"); + static_assert( + IsSameType:: + value, + "Invalid Triplet types"); + + using DataType = typename Triplet0::DataType; + using IndexType = typename Triplet0::IndexType; + + WelfordForEach::call( + triplet0, offset0, triplet1, offset1); + welfordCombine( + triplet0.avg.val(offset0), + triplet0.var.val(offset0), + triplet0.N.val(offset0), + triplet1.avg.val(offset1), + triplet1.var.val(offset1), + triplet1.N.val(offset1)); + } +}; + +template +struct WelfordForEach<-1, Triplet0, Triplet1> { + __inline__ __device__ static void call( + Triplet0& triplet0, + nvfuser_index_t offset0, + const Triplet1& triplet1, + nvfuser_index_t offset1) {} +}; + +//! Call welfordCombine with each of the triplet tuples. This is a +//! welford version of reduceEach. +template +__inline__ __device__ static void welfordEach( + Triplet0& triplet0, + nvfuser_index_t offset0, + const Triplet1& triplet1, + nvfuser_index_t offset1) { + WelfordForEach::call( + triplet0, offset0, triplet1, offset1); +} + +template +struct TupleReduce {}; + +template +struct TupleReduce { + __inline__ __device__ static void reduce( + TupleType0& val0, + nvfuser_index_t offset0, + const TupleType1& val1, + nvfuser_index_t offset1, + Func reduction_op) { + static_assert( + IsSameType< + typename TupleType0::ValTypes, + typename TupleType1::ValTypes>::value, + "Invalid value types"); + reduction_op(val0.val<0>(offset0), val1.val<0>(offset1)); + } +}; + +template +struct TupleReduce { + __inline__ __device__ static void reduce( + TupleType0& val0, + nvfuser_index_t offset0, + const TupleType1& val1, + nvfuser_index_t offset1, + Func reduction_op) { + static_assert( + IsSameType< + typename TupleType0::ValTypes, + typename TupleType1::ValTypes>::value, + "Invalid value types"); + reduction_op( + val0.val<0>(offset0), + val0.val<1>(offset0), + val1.val<0>(offset1), + val1.val<1>(offset1)); + } +}; + +template +struct TupleReduce { + __inline__ __device__ static void reduce( + TupleType0& val0, + nvfuser_index_t offset0, + const TupleType1& val1, + nvfuser_index_t offset1, + Func reduction_op) { + static_assert( + IsSameType< + typename TupleType0::ValTypes, + typename TupleType1::ValTypes>::value, + "Invalid value types"); + reduction_op( + val0.val<0>(offset0), + val0.val<1>(offset0), + val0.val<2>(offset0), + val1.val<0>(offset1), + val1.val<1>(offset1), + val1.val<2>(offset1)); + } +}; + +//! Reduce all values of a tuple together. The reduction function must +//! have the same number of inputs as the number of values of each tuple. 
+template +__inline__ __device__ void reduceTuple( + TupleType0& val0, + nvfuser_index_t offset0, + const TupleType1& val1, + nvfuser_index_t offset1, + Func reduction_op) { + static_assert( + TupleType0::num_vals == TupleType1::num_vals, "Invalid number of values"); + TupleReduce::reduce( + val0, offset0, val1, offset1, reduction_op); +} + // Reduces all of the first (idx+1) values by a thread block template < int idx, @@ -297,83 +522,191 @@ __inline__ __device__ void blockReduceEach( reduction_ops...); } -template -struct TupleReduce {}; - -template -struct TupleReduce { +// Welford version of BlockReduceEach +template < + int idx, + bool BROADCAST, + bool FORWARD_PROTECT_SMEM, + typename LocalWelfordTripletTupleT> +struct BlockWelfordEach { __inline__ __device__ static void reduce( - TupleType0& val0, - nvfuser_index_t offset0, - const TupleType1& val1, - nvfuser_index_t offset1, - Func reduction_op) { - static_assert( - IsSameType< - typename TupleType0::ValTypes, - typename TupleType1::ValTypes>::value, - "Invalid value types"); - reduction_op(val0.val<0>(offset0), val1.val<0>(offset1)); - } -}; + LocalWelfordTripletTupleT& block_result, + const LocalWelfordTripletTupleT& partial_result, + PtrTuple< + typename LocalWelfordTripletTupleT::DataType, + typename LocalWelfordTripletTupleT::DataType, + typename LocalWelfordTripletTupleT::IndexType> shared_buf, + bool has_block_result, + int tid_in_reduction, + int num_threads_per_reduction, + int num_elements_per_reduction, + int reduction_idx) { + // Finish the reduction of each tuple value with a smaller offset + BlockWelfordEach:: + reduce( + block_result, + partial_result, + shared_buf, + has_block_result, + tid_in_reduction, + num_threads_per_reduction, + num_elements_per_reduction, + reduction_idx); -template -struct TupleReduce { - __inline__ __device__ static void reduce( - TupleType0& val0, - nvfuser_index_t offset0, - const TupleType1& val1, - nvfuser_index_t offset1, - Func reduction_op) { - static_assert( - IsSameType< - typename TupleType0::ValTypes, - typename TupleType1::ValTypes>::value, - "Invalid value types"); - reduction_op( - val0.val<0>(offset0), - val0.val<1>(offset0), - val1.val<0>(offset1), - val1.val<1>(offset1)); + if (num_elements_per_reduction == 1) { + if (has_block_result) { + copyWelfordTripletTuple(block_result, partial_result); + } + return; + } + + using DataType = typename LocalWelfordTripletTupleT::DataType; + using IndexType = typename LocalWelfordTripletTupleT::IndexType; + + LocalTuple block_result_i( + partial_result.avg.val(0), + partial_result.var.val(0), + partial_result.N.val(0)); + + const auto smem_offset = + reduction_idx * num_threads_per_reduction + tid_in_reduction; + + const int np2 = 1 << (31 - __clz(num_elements_per_reduction)); + + // Threads values are initialized, so all can participate here + if (tid_in_reduction >= np2) { + copyTuple(shared_buf, smem_offset, block_result_i); + } + + block_sync::sync(); + if (tid_in_reduction < np2 && + tid_in_reduction + np2 < num_elements_per_reduction) { + impl::reduceTuple( + block_result_i, + 0, + shared_buf, + smem_offset + np2, + welfordCombine); + } + + if (tid_in_reduction < np2) { + copyTuple(shared_buf, smem_offset, block_result_i); + } + + // Always sync when communicating across smem + block_sync::sync(); + + // Reduce down to 2 values, last thread will do the final reduction and + // can save a syncthreads this way + for (int factor = np2 / 2; factor > 1; factor >>= 1) { + if (tid_in_reduction < factor) { + impl::reduceTuple( + 
shared_buf, + smem_offset, + shared_buf, + smem_offset + factor, + welfordCombine); + } + block_sync::sync(); + } + + copyTuple(block_result_i, shared_buf, smem_offset); + + // Do the last reduction + if (has_block_result) { + impl::reduceTuple( + block_result_i, + 0, + shared_buf, + smem_offset + 1, + welfordCombine); + } + + if (BROADCAST) { + if (has_block_result) { + // Put result back in shared memory, put in the first entry of the + // reduction segment's buffer + copyTuple( + shared_buf, + reduction_idx * num_threads_per_reduction, + block_result_i); + } + + // Sync threads to make sure result is in smem + block_sync::sync(); + + copyTuple( + block_result_i, + shared_buf, + reduction_idx * num_threads_per_reduction); + } + + block_result.avg.val(0) = block_result_i.val<0>(0); + block_result.var.val(0) = block_result_i.val<1>(0); + block_result.N.val(0) = block_result_i.val<2>(0); + + if (FORWARD_PROTECT_SMEM) { + block_sync::sync(); + } } }; -template -struct TupleReduce { +// Specialization for idx == -1, i.e., no value to reduce. +template < + bool BROADCAST, + bool FORWARD_PROTECT_SMEM, + typename LocalWelfordTripletTupleT> +struct BlockWelfordEach< + -1, + BROADCAST, + FORWARD_PROTECT_SMEM, + LocalWelfordTripletTupleT> { __inline__ __device__ static void reduce( - TupleType0& val0, - nvfuser_index_t offset0, - const TupleType1& val1, - nvfuser_index_t offset1, - Func reduction_op) { - static_assert( - IsSameType< - typename TupleType0::ValTypes, - typename TupleType1::ValTypes>::value, - "Invalid value types"); - reduction_op( - val0.val<0>(offset0), - val0.val<1>(offset0), - val0.val<2>(offset0), - val1.val<0>(offset1), - val1.val<1>(offset1), - val1.val<2>(offset1)); - } + LocalWelfordTripletTupleT& block_result, + const LocalWelfordTripletTupleT& partial_result, + PtrTuple< + typename LocalWelfordTripletTupleT::DataType, + typename LocalWelfordTripletTupleT::DataType, + typename LocalWelfordTripletTupleT::IndexType> shared_buf, + bool has_block_result, + int tid_in_reduction, + int num_threads_per_reduction, + int num_elements_per_reduction, + int reduction_idx) {} }; -//! Reduce all values of a tuple together. The reduction function must -//! have the same number of inputs as the number of values of each tuple. -template -__inline__ __device__ void reduceTuple( - TupleType0& val0, - nvfuser_index_t offset0, - const TupleType1& val1, - nvfuser_index_t offset1, - Func reduction_op) { - static_assert( - TupleType0::num_vals == TupleType1::num_vals, "Invalid number of values"); - TupleReduce::reduce( - val0, offset0, val1, offset1, reduction_op); +//! Welford version of blockReduceEach. Perform block-parallel Welford +//! reduction of each Welford triplet. 
+template < + bool BROADCAST, + bool FORWARD_PROTECT_SMEM, + typename LocalWelfordTripletTupleT> +__inline__ __device__ void blockWelfordEach( + LocalWelfordTripletTupleT& block_result, + const LocalWelfordTripletTupleT& partial_result, + PtrTuple< + typename LocalWelfordTripletTupleT::DataType, + typename LocalWelfordTripletTupleT::DataType, + typename LocalWelfordTripletTupleT::IndexType> shared_buf, + bool has_block_result, + int tid_in_reduction, + int num_threads_per_reduction, + int num_elements_per_reduction, + int reduction_idx) { + BlockWelfordEach< + LocalWelfordTripletTupleT::num_vals - 1, + BROADCAST, + FORWARD_PROTECT_SMEM, + LocalWelfordTripletTupleT>:: + reduce( + block_result, + partial_result, + shared_buf, + has_block_result, + tid_in_reduction, + num_threads_per_reduction, + num_elements_per_reduction, + reduction_idx); } } // namespace impl @@ -1109,6 +1442,203 @@ class ParallelReduce { } } + // User-visible entry point of grouped grid welford + + // broadcast. Mostly the same as reduceGroup, and it would be + // possible to combine this to reduceGroup, but it might make the + // templated data structures even more complicated and difficult to + // understand. For now, keep it as a separate function. + // + // Unlike reduceGroup, though, the data types of welford ops must be + // the same. For example, reduceGroup can be used to reduce half and + // float values by passing a tuple of, e.g., LocalTuple, but that's not supported here for implementation + // simplicity. In practice, it should be really uncommon to group + // welford ops with different data types, so this restriction + // shouldn't be an issue. + template + __device__ __inline__ void welfordGroup( + typename MakeRefTuple::type out_avg, + typename MakeRefTuple::type out_var, + typename MakeRefTuple::type out_N, + const typename MakeConstRefTuple::type& inp_avg, + const typename MakeConstRefTuple::type& inp_var, + const typename MakeConstRefTuple::type& inp_N, + const typename MakeLocalTuple::type& init_avg, + const typename MakeLocalTuple::type& init_var, + const typename MakeLocalTuple::type& init_N, + typename MakeVolatilePtrTuple::type + global_work_buffer_avg, + typename MakeVolatilePtrTuple::type + global_work_buffer_var, + typename MakeVolatilePtrTuple::type + global_work_buffer_N, + int64_t* global_sync_buffer, + PtrTuple shared_buf, + const typename MakeLocalTuple::type& read_preds, + const typename MakeLocalTuple::type& write_preds) { + const ConstRefWelfordTripletTuple inp( + inp_avg, inp_var, inp_N); + RefWelfordTripletTuple out( + out_avg, out_var, out_N); + + // If no reduction needed, just return input + if (!BLOCK_REDUCE && !GRID_REDUCE) { + copyWelfordTripletTupleIf(out, inp, read_preds && write_preds); + return; + } + + // Don't read/write in temporary buffers if in a predicated dimension + const bool block_reduce_participate = index_utils:: + maskedIsZero( + threadIdx); + + // Only threads that with id == 0 in the dimensions being reduced will + // have a valid result + const bool has_block_result = index_utils::maskedIsZero< + isReduce(X_THREAD), + isReduce(Y_THREAD), + isReduce(Z_THREAD)>(threadIdx); + + LocalWelfordTripletTuple block_result( + init_avg, init_var, init_N); + + // Initial per-block reduction. Result is broadcast if specified + // and this call is block reduction only. 
+ welfordGroupBlock( + block_result, inp, shared_buf, read_preds, block_reduce_participate); + + // If block reduction only, save to out and exit + if (!GRID_REDUCE) { + copyWelfordTripletTupleIf( + out, + block_result, + write_preds && + (block_reduce_participate && (BROADCAST || has_block_result))); + + // Need a block sync here as reduceGroupBlock does not + // forward-protect the smem buffer. This block sync is not + // necessary when a grid reduction follows since a block sync is + // done just before the grid sync. + block_sync::sync(); + return; + } + + // -- START GRID REDUCTION -- // + // Grid reductions are more challenging for two reasons, (1) the reduction + // itself is 3D instead of 2D because we now have an iter domain space in + // the grid dimension. (2) a tree reduction isn't performed, instead all + // blocks will populate GMEM and one block will finish the grid reduction. + + // What is the grid reduction size, block reduction already performed so + // that doesn't have to be taken into consideration + const auto grid_red_size = index_utils:: + maskedSize( + gridDim); + + // Which ID in the reduction is this block. Threads can participate in + // multiple grid reductions, but the block will have the same relative index + // in those reductions + const auto idx_in_grid_red = index_utils:: + maskedOffset( + blockIdx, gridDim); + + // How many grid reductions have to be performed, in the grid dimension + const auto num_block_iters = index_utils:: + maskedSize(gridDim); + + // Which grid reduction does this block participate in, in the grid + // dimension + const auto block_red_idx_offset = index_utils:: + maskedOffset( + blockIdx, gridDim); + + // How many grid reductions have to be performed, in the block dimension + const auto num_thread_iters = index_utils:: + maskedSize( + blockDim); + + // Which grid reduction does this thread participate in, in the block + // dimension + const auto thread_red_idx_offset = index_utils:: + maskedOffset( + threadIdx, blockDim); + + // 3D buffer of reductions: + // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] + // Offset into the work buffer + auto work_buf_offset = + (idx_in_grid_red * num_block_iters + block_red_idx_offset) * + num_thread_iters + + thread_red_idx_offset; + + // Don't read/write in temporary buffers if in a predicated dimension + bool grid_reduce_participate = index_utils:: + maskedIsZero( + blockIdx); + + VolatilePtrWelfordTripletTuple + global_work_buffer( + global_work_buffer_avg, + global_work_buffer_var, + global_work_buffer_N); + + if (PERSISTENT_REDUCTION && flip) { + auto global_buffer_size = + index_utils:: + maskedSize( + gridDim) * + index_utils:: + maskedSize( + blockDim) * + grid_red_size; + global_work_buffer += global_buffer_size; + } + flip = !flip; + + // Per-block partial reduction to global work buffer + if (grid_reduce_participate && block_reduce_participate && + has_block_result) { + copyWelfordTripletTuple( + global_work_buffer, work_buf_offset, block_result); + } + + // -- GLOBAL BUFFER FILLED -- // + + bool last_block = index_utils:: + maskedIsLast( + blockIdx, gridDim); + + if (grid_reduce_participate) { + // Don't need to sync up blocks that are not participating in this + // reduction + grid_sync::sync< + isReduce(X_BLOCK), + isReduce(Y_BLOCK), + isReduce(Z_BLOCK), + PERSISTENT_REDUCTION>( + global_sync_buffer[block_red_idx_offset], grid_red_size, last_block); + } + + // -- START BLOCK CLEANUP -- // + welfordGroupLastBlock( + out, + global_work_buffer, + 
LocalWelfordTripletTuple( + init_avg, init_var, init_N), + shared_buf, + block_red_idx_offset, + num_thread_iters, + num_block_iters, + thread_red_idx_offset, + grid_red_size, + write_preds, + block_reduce_participate, + grid_reduce_participate); + + // Forward protect the smem buffer + block_sync::sync(); + } + private: __device__ bool isLastBlockInGrid() { return index_utils::maskedIsLast< @@ -1287,6 +1817,153 @@ class ParallelReduce { } } + //! Welford version of reduceGroupBlock + template < + bool BLOCK_BROADCAST, + int NumVals, + typename DataType, + typename IndexType> + __device__ __inline__ static void welfordGroupBlock( + LocalWelfordTripletTuple& block_result, + const ConstRefWelfordTripletTuple& inp, + PtrTuple shared_buf, + const typename MakeLocalTuple::type& read_preds, + bool block_reduce_participate) { + const bool has_block_result = index_utils::maskedIsZero< + isReduce(X_THREAD), + isReduce(Y_THREAD), + isReduce(Z_THREAD)>(threadIdx); + + copyWelfordTripletTupleIf( + block_result, inp, block_reduce_participate && read_preds); + + // Size of the block reduction segment, can be an int since it's limited + // to number of threads + const int block_reduction_size = index_utils:: + maskedSize( + blockDim); + + // Index in the reduction segment, can be an int since it's limited to + // number of threads + const int tid_in_block_reduction = index_utils::maskedOffset< + isReduce(X_THREAD), + isReduce(Y_THREAD), + isReduce(Z_THREAD)>(threadIdx, blockDim); + + // ID of the block reduction this thread is participating in + // + // If any of the parallel dimensions are predicated out, that means + // they've already been reduced, so we only care about the first thread in + // that dimension. Therefore don't expand the reduction_idx by that + // dimension + const int block_reduction_idx = index_utils:: + maskedOffset( + threadIdx, blockDim); + + // Do not protect the smem buffer as it's not always necessary. + impl::blockWelfordEach< + BLOCK_BROADCAST, + false, + LocalWelfordTripletTuple>( + block_result, + block_result, + shared_buf, + has_block_result, + tid_in_block_reduction, + block_reduction_size, + block_reduction_size, + block_reduction_idx); + } + + //! Welford version of reduceGrouplLastBlock + template + __device__ __inline__ void welfordGroupLastBlock( + RefWelfordTripletTuple& out, + const VolatilePtrWelfordTripletTuple& + global_work_buffer, + const LocalWelfordTripletTuple& init_val, + PtrTuple shared_buf, + nvfuser_index_t block_red_idx_offset, + nvfuser_index_t num_thread_iters, + nvfuser_index_t num_block_iters, + nvfuser_index_t thread_red_idx_offset, + nvfuser_index_t grid_red_size, + const typename MakeLocalTuple::type& write_preds, + bool block_reduce_participate, + bool grid_reduce_participate) { + // Initialize block result + auto last_block_result = init_val; + + const bool last_block = index_utils:: + maskedIsLast( + blockIdx, gridDim); + + if ((PERSISTENT_REDUCTION || last_block) && grid_reduce_participate) { + // Can use the last block to reduce all the values the blocks filled in. 
+ // Can use any thread that has been predicated, or has been reduced to do + // this reduction, cannot use any block that's associated with an + // iteration domain + + // Start with non-block reduction + + // Index in the reduction segment + int tid_in_block_reduction = index_utils::maskedOffset< + activeNotIter(X_THREAD), + activeNotIter(Y_THREAD), + activeNotIter(Z_THREAD)>(threadIdx, blockDim); + + int block_reduction_size = index_utils::maskedSize< + activeNotIter(X_THREAD), + activeNotIter(Y_THREAD), + activeNotIter(Z_THREAD)>(blockDim); + + bool has_block_result = index_utils::maskedIsZero< + activeNotIter(X_THREAD), + activeNotIter(Y_THREAD), + activeNotIter(Z_THREAD)>(threadIdx); + + // 3D buffer of reductions: + // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] + // Change the offset, we want to keep the last two dimensions, but the + // first dimension is what we will reduce over + const auto work_buf_offset = + block_red_idx_offset * num_thread_iters + thread_red_idx_offset; + for (auto reduction_i = tid_in_block_reduction; + reduction_i < grid_red_size; + reduction_i += block_reduction_size) { + impl::welfordEach( + last_block_result, + 0, + global_work_buffer, + work_buf_offset + reduction_i * num_block_iters * num_thread_iters); + } + + // Which block reduction this thread is participating in + int block_reduction_idx = index_utils:: + maskedOffset( + threadIdx, blockDim); + + impl::blockWelfordEach< + BROADCAST, + false, + LocalWelfordTripletTuple>( + last_block_result, + last_block_result, + shared_buf, + has_block_result, + tid_in_block_reduction, + block_reduction_size, + min(grid_red_size, block_reduction_size), + block_reduction_idx); + + copyWelfordTripletTupleIf( + out, + last_block_result, + write_preds && + (block_reduce_participate && (BROADCAST || has_block_result))); + } + } + // End Parallel reduce class }; diff --git a/torch/csrc/jit/codegen/cuda/runtime/tuple.cu b/torch/csrc/jit/codegen/cuda/runtime/tuple.cu index 7f2e7ab94b7d4b..6daac6b99b758c 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/tuple.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/tuple.cu @@ -606,6 +606,179 @@ using PtrTuple = PtrTupleBase; template using VolatilePtrTuple = PtrTupleBase; +// Define a LocalTuple of NumVals values of type Type +template +struct MakeLocalTuple; + +template +struct MakeLocalTuple<1, Type> { + using type = LocalTuple; +}; + +template +struct MakeLocalTuple<2, Type> { + using type = LocalTuple; +}; + +template +struct MakeLocalTuple<3, Type> { + using type = LocalTuple; +}; + +template +struct MakeLocalTuple<4, Type> { + using type = LocalTuple; +}; + +template +struct MakeLocalTuple<5, Type> { + using type = LocalTuple; +}; + +template +struct MakeLocalTuple<6, Type> { + using type = LocalTuple; +}; + +template +struct MakeLocalTuple<7, Type> { + using type = LocalTuple; +}; + +template +struct MakeLocalTuple<8, Type> { + using type = LocalTuple; +}; + +template +struct MakeRefTuple; + +template +struct MakeRefTuple<1, Type> { + using type = RefTuple; +}; + +template +struct MakeRefTuple<2, Type> { + using type = RefTuple; +}; + +template +struct MakeRefTuple<3, Type> { + using type = RefTuple; +}; + +template +struct MakeRefTuple<4, Type> { + using type = RefTuple; +}; + +template +struct MakeRefTuple<5, Type> { + using type = RefTuple; +}; + +template +struct MakeRefTuple<6, Type> { + using type = RefTuple; +}; + +template +struct MakeRefTuple<7, Type> { + using type = RefTuple; +}; + +template +struct MakeRefTuple<8, Type> { + using type = 
RefTuple; +}; + +template +struct MakeConstRefTuple; + +template +struct MakeConstRefTuple<1, Type> { + using type = ConstRefTuple; +}; + +template +struct MakeConstRefTuple<2, Type> { + using type = ConstRefTuple; +}; + +template +struct MakeConstRefTuple<3, Type> { + using type = ConstRefTuple; +}; + +template +struct MakeConstRefTuple<4, Type> { + using type = ConstRefTuple; +}; + +template +struct MakeConstRefTuple<5, Type> { + using type = ConstRefTuple; +}; + +template +struct MakeConstRefTuple<6, Type> { + using type = ConstRefTuple; +}; + +template +struct MakeConstRefTuple<7, Type> { + using type = ConstRefTuple; +}; + +template +struct MakeConstRefTuple<8, Type> { + using type = ConstRefTuple; +}; + +template +struct MakeVolatilePtrTuple; + +template +struct MakeVolatilePtrTuple<1, Type> { + using type = VolatilePtrTuple; +}; + +template +struct MakeVolatilePtrTuple<2, Type> { + using type = VolatilePtrTuple; +}; + +template +struct MakeVolatilePtrTuple<3, Type> { + using type = VolatilePtrTuple; +}; + +template +struct MakeVolatilePtrTuple<4, Type> { + using type = VolatilePtrTuple; +}; + +template +struct MakeVolatilePtrTuple<5, Type> { + using type = VolatilePtrTuple; +}; + +template +struct MakeVolatilePtrTuple<6, Type> { + using type = VolatilePtrTuple; +}; + +template +struct MakeVolatilePtrTuple<7, Type> { + using type = VolatilePtrTuple; +}; + +template +struct MakeVolatilePtrTuple<8, Type> { + using type = VolatilePtrTuple; +}; + // Utility definitions. Currently only used with LocalTuple template diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 406524c5f78b6e..186605ebfadfac 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -818,6 +818,15 @@ std::vector TensorView::rFactor( "Rfactor of a multi-output reduction not used correctly"); } + // Currently grouping of welford is only supported through + // ParallelType::Group, so GroupedWelfordOp is only created during + // the lowering time. As rFactor is done before lowering, there + // should be no GroupedWelfordOp at this point. 
+ TORCH_INTERNAL_ASSERT( + !definition()->isA(), + "GroupedWelfordOp found: ", + definition()->toString()); + std::vector rf_tvs(tvs.size()); // Make sure this gets rfactored last so everybody gets @@ -844,25 +853,25 @@ std::vector TensorView::rFactor( IrBuilder::create( producer_avg, producer_var, - producer_n, /*out var/avg/count */ - wop->initAvg(), - wop->initVar(), - wop->initN(), /*init var/avg/count */ + producer_n, wop->inAvg(), wop->inVar(), - wop->inN()); + wop->inN(), + wop->initAvg(), + wop->initVar(), + wop->initN()); // Expr* consumer_definition = IrBuilder::create( wop->outAvg(), wop->outVar(), wop->outN(), - wop->initAvg(), - wop->initVar(), - wop->initN(), producer_avg, producer_var, - producer_n); + producer_n, + wop->initAvg(), + wop->initVar(), + wop->initN()); } else if ( auto grouped_rop = dynamic_cast(definition())) { IrBuilder::create( diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp index 332fec9b652368..3b9e7cbd962c65 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -24,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -2148,6 +2150,319 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce4_CUDA) { testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelford1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = Welford(tv1, {0}).avg; + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv0, tv3); + fusion.addOutput(tv4); + + const int vec = 2; + const int tidx = 32; + const int tidy = 8; + + tv1->split(1, vec); + tv1->split(1, tidx); + tv1->split(0, tidy); + TransformPropagator propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + + tv1->axis(0)->parallelize(ParallelType::BIDy); + tv1->axis(1)->parallelize(ParallelType::TIDy); + tv1->axis(2)->parallelize(ParallelType::BIDx); + tv1->axis(3)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv1); + + tv2->axis(4)->parallelize(ParallelType::Group); + + // Make sure the reduction expr is converted to GroupedGridReduciton + // and the non-reduction domains of the output TV are either + // grouped or parallelized + GpuLower gpulw(&fusion); + bool validated = false; + for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { + auto grouped_grid_reduction = dynamic_cast(expr); + if (grouped_grid_reduction == nullptr) { + continue; + } + validated = true; + } + TORCH_CHECK( + validated, "Invalid lowered kernel. 
No GroupedGridWelford found."); + + std::vector<int64_t> shape({99, 101}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + auto t0_double = t0.to(at::kDouble); + auto ref = t0_double + t0_double.mean({0}).unsqueeze(0); + + testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Test grouping of two domains +TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelford2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = Welford(tv1, {0}).avg; + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv0, tv3); + fusion.addOutput(tv4); + + const int vec1 = 2; + const int vec2 = 3; + const int tidx = 16; + const int tidy = 8; + + tv1->split(1, vec1); + tv1->split(1, vec2); + tv1->split(1, tidx); + tv1->split(0, tidy); + TransformPropagator propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + + tv1->axis(0)->parallelize(ParallelType::BIDy); + tv1->axis(1)->parallelize(ParallelType::TIDy); + tv1->axis(2)->parallelize(ParallelType::BIDx); + tv1->axis(3)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv1); + + tv2->axis(4)->parallelize(ParallelType::Group); + tv2->axis(5)->parallelize(ParallelType::Group); + + std::vector<int64_t> shape({99, 129}); + + // Make sure the reduction expr is converted to GroupedGridWelford + // and the non-reduction domains of the output TV are either + // grouped or parallelized + GpuLower gpulw(&fusion); + bool validated = false; + for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { + auto grouped_grid_reduction = dynamic_cast<kir::GroupedGridWelford*>(expr); + if (grouped_grid_reduction == nullptr) { + continue; + } + validated = true; + } + TORCH_CHECK( + validated, "Invalid lowered kernel. 
No GroupedGridWelford found."); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + auto t0_double = t0.to(at::kDouble); + auto ref = t0_double + t0_double.mean({0}).unsqueeze(0); + + testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Follows the pattern of persistent outer grid welford in batchnorm +TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelfordShmoo_CUDA) { + struct Params { + int N; + int H; + int W; + int C; + int tidx; + int tidy; + int vect; + int persistent_buffer; + int bidx; + }; + + auto test = [](const Params& params) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector bcast_pattern{true, true, true, false}; + std::vector reduction_dims{2, 1, 0}; + + auto tv0 = makeSymbolicTensor(4); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tvs = Welford(tv1, reduction_dims); + auto tv2 = tvs.avg; + auto tv3 = tvs.var_sum; + auto tv4 = tvs.n; + auto tv5 = broadcast(tv2, bcast_pattern); + auto tv6 = broadcast(tv3, bcast_pattern); + auto tv7 = broadcast(tv4, bcast_pattern); + auto tv8 = sub(tv1, tv5); + auto tv9 = add(tv8, tv6); + // auto tv10 = div(tv9, tv7); + // fusion.addOutput(tv10); + fusion.addOutput(tv9); + + // Schedule the fusion as it will be done by the persistent + // scheduler + + auto input_cache = tv1; + auto output_cache = tv9->cacheBefore(); + + auto transform_ref = tv2; + + transform_ref->merge(0)->merge(0); + + int reduction_pos = 1; + + transform_ref->split(0, params.tidy); + ++reduction_pos; + transform_ref->axis(1)->parallelize(ParallelType::TIDy); + + // Persistent buffer + transform_ref->split(0, params.persistent_buffer); + ++reduction_pos; + + // Unswitch + transform_ref->split(0, 1); + ++reduction_pos; + transform_ref->axis(1)->parallelize(ParallelType::Unswitch); + + transform_ref->axis(0)->parallelize(ParallelType::BIDy); + + transform_ref->split(reduction_pos, params.vect); + transform_ref->axis(reduction_pos + 1) + ->parallelize(ParallelType::Vectorize); + + transform_ref->split(reduction_pos, params.tidx); + transform_ref->axis(reduction_pos + 1)->parallelize(ParallelType::TIDx); + transform_ref->split(reduction_pos, params.bidx); + transform_ref->axis(reduction_pos + 1)->parallelize(ParallelType::BIDx); + + auto transform_ref_rf = + reduction_scheduler_utils::sortAndRFactor(transform_ref); + + TransformPropagator propagator(transform_ref_rf); + MaxRootDomainInfoSpanningTree(transform_ref_rf).traverse(&propagator); + + int vec_id = std::distance( + transform_ref_rf->domain()->domain().begin(), + std::find_if( + transform_ref_rf->domain()->domain().begin(), + transform_ref_rf->domain()->domain().end(), + [](auto id) { + return id->getParallelType() == ParallelType::Vectorize; + })); + transform_ref_rf->axis(vec_id)->parallelize(ParallelType::Serial); + + int unswitch_id = std::distance( + transform_ref_rf->domain()->domain().begin(), + std::find_if( + transform_ref_rf->domain()->domain().begin(), + transform_ref_rf->domain()->domain().end(), + [](auto id) { + return id->getParallelType() == ParallelType::Unswitch; + })); + transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Serial); + + scheduler_utils::parallelizeAllLike( + transform_ref_rf, ir_utils::allTvs(&fusion)); + + ParallelType vec_pt = ParallelType::Vectorize; + tv1->axis(vec_id)->parallelize(vec_pt); + 
tv9->axis(vec_id)->parallelize(vec_pt); + + transform_ref->axis(vec_id)->parallelize(ParallelType::Group); + + transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Unswitch); + + InlinePropagator inline_propagator( + transform_ref_rf, -1, ComputeAtMode::MostInlined); + MaxRootDomainInfoSpanningTree(transform_ref_rf) + .traverse(&inline_propagator); + + // Make sure the reduction expr is converted to GroupedGridReduciton + // and the non-reduction domains of the output TV are either + // grouped or parallelized + GpuLower gpulw(&fusion); + bool validated = false; + for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { + auto grouped_grid_reduction = + dynamic_cast(expr); + validated = true; + } + TORCH_CHECK( + validated, "Invalid lowered kernel. No GroupedGridWelford found."); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + const std::vector input_shape{ + params.N, params.H, params.W, params.C}; + auto t0 = at::randn(input_shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + + // Skip the rest of this test size if the required number of SMs + // exceeds the available SM count + const auto num_required_sms = params.bidx * + ceilDiv(ceilDiv(params.N * params.H * params.W, params.tidy), + params.persistent_buffer); + if (num_required_sms > deviceSMCount()) { + return; + } + + auto cg_outputs = fe.runFusion({t0}); + + auto t1 = t0.to(at::kDouble); + auto t2 = t1.mean({0, 1, 2}).unsqueeze(0).unsqueeze(0).unsqueeze(0); + auto t3 = + at::var(t1, {0, 1, 2}, false).unsqueeze(0).unsqueeze(0).unsqueeze(0); + auto t4 = params.N * params.H * params.W; + auto ref = (t1 - t2 + (t3 * t4)); + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__, ""); + }; + + std::vector base_params; + base_params.push_back({256, 7, 7, 1, 8, 32, 2, 32, 4}); + base_params.push_back({256, 7, 7, 1, 16, 16, 4, 50, 4}); + base_params.push_back({128, 7, 7, 1, 16, 16, 4, 32, 4}); + base_params.push_back({128, 14, 14, 1, 16, 16, 4, 32, 1}); + base_params.push_back({128, 14, 14, 1, 16, 16, 2, 64, 2}); + base_params.push_back({128, 14, 14, 1, 8, 32, 4, 50, 4}); + base_params.push_back({128, 14, 14, 1, 8, 32, 2, 50, 4}); + + std::vector param_vec; + for (const auto base_p : base_params) { + for (const auto c_dim : {512, 1024, 2048}) { + auto tmp = base_p; + tmp.C = c_dim; + param_vec.push_back(tmp); + } + } + + for (const auto& params : param_vec) { + test(params); + } +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 35f6310f2820fd..ef8136c631d400 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -304,6 +304,8 @@ static const char* expr_type2string(ExprType t) { return "BroadcastOp"; case ExprType::WelfordOp: return "WelfordOp"; + case ExprType::GroupedWelfordOp: + return "GroupedWelfordOp"; case ExprType::LoadStoreOp: return "LoadStoreOp"; case ExprType::MmaOp: @@ -350,6 +352,8 @@ static const char* expr_type2string(ExprType t) { return "GridBroadcast"; case ExprType::GridWelford: return "GridWelford"; + case ExprType::GroupedGridWelford: + return "GroupedGridWelford"; case ExprType::Swizzle2D: return "Swizzle2D"; case ExprType::Swizzle2DInt: diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index fce051d432fff9..455a995568349d 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h 
@@ -111,6 +111,7 @@ enum class ExprType { GroupedReductionOp, BroadcastOp, WelfordOp, + GroupedWelfordOp, MmaOp, TransposeOp, ExpandOp, @@ -137,6 +138,7 @@ enum class ExprType { GroupedGridReduction, GridBroadcast, GridWelford, + GroupedGridWelford, AllocateFusedReduction };
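
Note on the Welford combine used throughout the new runtime code: welfordEach and blockWelfordEach above repeatedly merge two partial (avg, var, N) triplets via welfordCombine. The standalone host-side sketch below is illustrative only; it is not the runtime's welfordCombine, and the struct, function, and variable names are made up for the example. It shows the pairwise update the grouped kernels apply per triplet (here `var` holds the running sum of squared deviations, matching the `var_sum` convention in the tests) and checks that merging two partial results agrees with a single pass over the data.

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical host-side mirror of a Welford triplet (avg, M2/var_sum, count).
struct WelfordTriplet {
  double avg = 0.0;
  double var = 0.0;  // running sum of squared deviations, not the variance
  long long N = 0;
};

// Pairwise Welford update (Chan et al.): merge partial result b into a.
void combine(WelfordTriplet& a, const WelfordTriplet& b) {
  if (b.N == 0) {
    return;
  }
  const long long new_N = a.N + b.N;
  const double delta = b.avg - a.avg;
  const double b_over_new = static_cast<double>(b.N) / static_cast<double>(new_N);
  a.avg += delta * b_over_new;
  a.var += b.var + delta * delta * static_cast<double>(a.N) * b_over_new;
  a.N = new_N;
}

int main() {
  // Two partial reductions merged together must match one sequential pass.
  std::vector<double> x{1, 2, 3, 4, 5, 6, 7, 8};
  WelfordTriplet first_half, second_half, all;
  for (int i = 0; i < 4; ++i) combine(first_half, {x[i], 0.0, 1});
  for (int i = 4; i < 8; ++i) combine(second_half, {x[i], 0.0, 1});
  for (double v : x) combine(all, {v, 0.0, 1});
  combine(first_half, second_half);
  assert(first_half.N == all.N);
  assert(std::abs(first_half.avg - all.avg) < 1e-12);
  assert(std::abs(first_half.var - all.var) < 1e-9);
  std::printf("avg=%f var_sum=%f N=%lld\n", first_half.avg, first_half.var, first_half.N);
  return 0;
}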
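Note on the work-buffer layout in welfordGroup above: each block writes its partial triplet into a 3D row-major global buffer indexed as [reduction_offset(grid), iter_offset(grid), iter_offset(block)]. The minimal host-side sketch below uses hypothetical extents and names (not taken from the runtime) to verify that the offset expression from welfordGroup assigns every (grid-reduction index, grid-iteration segment, block-iteration segment) combination a unique slot, so partial results never collide.

#include <cassert>
#include <vector>

int main() {
  // Hypothetical extents for the three buffer dimensions.
  const int grid_red_size = 4;     // idx_in_grid_red in [0, grid_red_size)
  const int num_block_iters = 3;   // block_red_idx_offset in [0, num_block_iters)
  const int num_thread_iters = 5;  // thread_red_idx_offset in [0, num_thread_iters)

  std::vector<int> hits(grid_red_size * num_block_iters * num_thread_iters, 0);

  for (int idx_in_grid_red = 0; idx_in_grid_red < grid_red_size; ++idx_in_grid_red) {
    for (int block_red_idx_offset = 0; block_red_idx_offset < num_block_iters;
         ++block_red_idx_offset) {
      for (int thread_red_idx_offset = 0; thread_red_idx_offset < num_thread_iters;
           ++thread_red_idx_offset) {
        // Same offset expression as in welfordGroup:
        // [reduction_offset(grid), iter_offset(grid), iter_offset(block)]
        const int work_buf_offset =
            (idx_in_grid_red * num_block_iters + block_red_idx_offset) * num_thread_iters +
            thread_red_idx_offset;
        ++hits[work_buf_offset];
      }
    }
  }

  // Every slot is written exactly once.
  for (int h : hits) {
    assert(h == 1);
  }
  return 0;
}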