diff --git a/build_variables.bzl b/build_variables.bzl index e4b4b82df5f6..c4177968ad27 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -718,6 +718,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/root_domain_map.cpp", "torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp", "torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp", + "torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp", "torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp", "torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp", "torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp", diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index d8078bbd9e5c..4245fa6e0a01 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -101,6 +101,7 @@ if(USE_CUDA) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_view.cpp) + list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu) endif() diff --git a/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h b/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h index 56460ec92695..90e64a284086 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h @@ -2,6 +2,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -12,7 +13,8 @@ enum class TORCH_CUDA_CU_API ScheduleHeuristic { None, PointWise, Reduction, - Persistent + Persistent, + Transpose }; } } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h b/torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h index f2c9f161619a..c43ef64eac0a 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h @@ -28,6 +28,7 @@ enum class CompileTimeEntryType { DOMAIN_MAP, REFERENCE_TENSORS, VECTORIZABLE_INPUTS_AND_OUTPUTS, + INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS, UNROLLABLE_INPUTS_AND_OUTPUTS, REDUCTION_TVS, PERSISTENT_BUFFER_INFO, @@ -62,6 +63,15 @@ class VectorizableInputsAndOutputs { CompileTimeEntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS; }; +//! Entry type definition class for `INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS`, +//! stores the fusion's inputs and outputs grouped by inner most dimension. +class InputsOutputsInnerDimGroups { + public: + using DataType = std::vector>; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS; +}; + //! Entry type definition class for `UNROLLABLE_INPUTS_AND_OUTPUTS`, //! stores the unrollable TensorViews on a fusion's inputs and outputs. class UnrollableInputsAndOutputs { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index ad04f437389d..56d0f2e43e62 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -57,29 +58,6 @@ class DomainMap : public pointwise_utils::DomainMap { return domain_map.findReferenceTensorView() != nullptr; } - // Determine if output TensorView is a valid reference tensor for this fusion. 
- // The reference tensor must map to all the iterDomains in each input. - bool isValidReference(TensorView* output_tv) const { - if (output_tv->isFusionInput()) { - return false; - } - for (auto input_tv : - ir_utils::filterByType(fusion_->inputs())) { - if (input_tv->uses().empty()) { - continue; - } - - if (fusion_->getOutputAlias(output_tv) == input_tv) { - continue; - } - - if (!areAllInputIdsMappedToOutput(input_tv, output_tv)) { - return false; - } - } - return true; - } - private: bool hasMinimumSize(TensorView* tv, int num_axes) const { TORCH_INTERNAL_ASSERT(tv != nullptr); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp index ff6bfd07dd12..cf823322078f 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp @@ -6,19 +6,9 @@ namespace fuser { namespace cuda { namespace pointwise_utils { -DomainMap::DomainMap(Fusion* fusion) - : fusion_(fusion), ca_map_(ComputeAtMap(fusion)) { - view_tvs_ = scheduler_utils::getViewTVs(fusion); -} - -bool DomainMap::areExactMapped(IterDomain* id1, IterDomain* id2) { - return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT); -} - -// Determine if all IterDomains in input are mapped to output -bool DomainMap::areAllInputIdsMappedToOutput( - TensorView* input_tv, - TensorView* output_tv) const { +// Determine if all IterDomains in input are mapped to the given tensor +bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) + const { // Get concrete IDs for input root or rfactor domain std::unordered_set in_concrete_ids; for (auto in_id : input_tv->getMaybeRFactorDomain()) { @@ -30,11 +20,9 @@ bool DomainMap::areAllInputIdsMappedToOutput( // Erase all input concrete IDs mapped to the output domain // Ignore unresolved broadcast dimensions - for (auto out_id : output_tv->getMaybeRFactorDomain()) { - if (!out_id->isBroadcast()) { - if (!eraseIfMapped(in_concrete_ids, out_id)) { - eraseIfInputMappedThroughViewToOutput(in_concrete_ids, out_id); - } + for (auto id : tv->getMaybeRFactorDomain()) { + if (!eraseIfMapped(in_concrete_ids, id)) { + eraseIfInputMappedThroughViewTo(in_concrete_ids, id); } } return in_concrete_ids.empty(); @@ -45,7 +33,7 @@ bool DomainMap::eraseIfMapped( std::unordered_set& in_concrete_ids, IterDomain* out_id) const { auto out_concrete_id = - ca_map_.getConcreteMappedID(out_id, IdMappingMode::EXACT); + ca_map_.getConcreteMappedID(out_id, IdMappingMode::PERMISSIVE); auto in_concrete_id_iter = in_concrete_ids.find(out_concrete_id); bool found_match = in_concrete_id_iter != in_concrete_ids.end(); if (found_match) { @@ -58,12 +46,12 @@ bool DomainMap::eraseIfMapped( // Currently this function only allow having one view on the path from input to // output. If there are multiple views, then likely the pointwise scheduler will // reject the fusion because we can not correctly find a reference tensor. 
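As an aside, the coverage check above can be summarized with a small standalone sketch (a simplified model, not the nvfuser API): each IterDomain is reduced to an integer "concrete id" standing in for what ComputeAtMap would return, and a candidate reference is accepted only if erasing everything it maps leaves the input's id set empty.

#include <iostream>
#include <unordered_set>
#include <vector>

// Simplified stand-ins: each IterDomain is reduced to the integer id of its
// "concrete" equivalence class, and a tensor is just the list of classes in
// its root/rfactor domain.
using ConcreteId = int;
struct Tensor {
  std::vector<ConcreteId> root_domain;
};

// A candidate reference covers an input if every concrete id of the input is
// also present in the reference (compare DomainMap::areAllInputIdsMappedTo).
bool coversInput(const Tensor& input, const Tensor& reference) {
  std::unordered_set<ConcreteId> remaining(
      input.root_domain.begin(), input.root_domain.end());
  for (ConcreteId id : reference.root_domain) {
    remaining.erase(id); // analogous to eraseIfMapped
  }
  return remaining.empty();
}

int main() {
  Tensor input{{0, 1, 2}}; // a 3D input
  Tensor good_ref{{2, 0, 1}}; // a permutation of the same ids -> valid
  Tensor bad_ref{{0, 1}}; // misses id 2 -> rejected
  std::cout << coversInput(input, good_ref) << " "
            << coversInput(input, bad_ref) << "\n"; // prints "1 0"
  return 0;
}

The real check is more forgiving than this sketch: ids that are only reachable through a view's rfactor domain are handled by eraseIfInputMappedThroughViewTo before a reference is declared invalid.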
-void DomainMap::eraseIfInputMappedThroughViewToOutput( +void DomainMap::eraseIfInputMappedThroughViewTo( std::unordered_set& in_concrete_ids, - IterDomain* out_id) const { + IterDomain* id) const { for (auto view : view_tvs_) { // Find any ID in view rfactor domain that is mapped to output ID - auto view_rfactor_id = anyMapped(view->getRFactorDomain(), out_id); + auto view_rfactor_id = anyMapped(view->getRFactorDomain(), id); if (view_rfactor_id == nullptr) { continue; } @@ -94,6 +82,20 @@ IterDomain* DomainMap::anyMapped( return nullptr; } +// Determine if output TensorView is a valid reference tensor for this fusion. +// The reference tensor must map to all the iterDomains in each input. +bool DomainMap::isValidReference(TensorView* tv) const { + for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { + if (input_tv->uses().empty()) { + continue; + } + if (!areAllInputIdsMappedTo(input_tv, tv)) { + return false; + } + } + return true; +} + } // namespace pointwise_utils } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h index 99d29a452511..7947a27f4836 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h @@ -15,18 +15,26 @@ namespace pointwise_utils { // that maps to all IterDomains in the fusion. class DomainMap { public: - DomainMap(Fusion* fusion); + DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(fusion) { + view_tvs_ = scheduler_utils::getViewTVs(fusion); + } virtual ~DomainMap() = default; - bool areExactMapped(IterDomain* id1, IterDomain* id2); + bool areExactMapped(IterDomain* id1, IterDomain* id2) const { + return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT); + } const ComputeAtMap& getComputeAtMap() const { return ca_map_; } + // Determine if a TensorView is a valid reference tensor for this fusion. + // The reference tensor must map to all the iterDomains in each input. 
+ bool isValidReference(TensorView* tv) const; + protected: - // Determine if all iterDomains are mapped between input and output tvs - bool areAllInputIdsMappedToOutput(TensorView* input_tv, TensorView* output_tv) + // Determine if all IterDomains are mapped between input and the given tvs + bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) const; // Erase input concrete ID if it is mapped to output ID @@ -34,10 +42,10 @@ class DomainMap { std::unordered_set& in_concrete_ids, IterDomain* out_id) const; - // Check if in_id is mapped to out_id through any view rfactor domain - void eraseIfInputMappedThroughViewToOutput( + // Check if in_id is mapped to id through any view rfactor domain + void eraseIfInputMappedThroughViewTo( std::unordered_set& in_concrete_ids, - IterDomain* out_id) const; + IterDomain* id) const; // Find any id in domain that maps with target id IterDomain* anyMapped( diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index 76aeafeb002f..17c72c77da47 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -1244,10 +1245,75 @@ class PersistentKernelScheduler : public SchedulerEntry { } }; +class TransposeScheduler : public SchedulerEntry { + public: + explicit TransposeScheduler( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) + : SchedulerEntry(ScheduleHeuristic::Transpose) { + computeHeuristics(fusion, runtime_info, data_cache); + } + + static bool canScheduleCompileTime(Fusion* fusion) { + // Not enabling this yet. Needs more validation. + return false; +#if 0 + if (!hasAtLeastTwoValidGroups(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Transpose, + "cannot find two mismatching inner most dimensions"); + return false; + } + + // TODO: add support for trivial reduction + auto reduction_ops = + ir_utils::getReductionOps(fusion, false /* ignore_trivial */); + + if (!reduction_ops.empty()) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Transpose, "no support for reduction ops"); + return false; + } + + if (hasNonUniqueBcast(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Transpose, + "Broadcasting dimension might be broadcasting to multiple sizes."); + return false; + } + + return true; +#endif + } + + static bool canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) { + return true; + } + + void schedule(Fusion* fusion) override { + FUSER_PERF_SCOPE("Schedule Transpose Fusion"); + scheduleTranspose(fusion, transposeParams()); + } + + private: + void computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) { + params_ = getTransposeHeuristics(fusion, runtime_info, data_cache); + TORCH_INTERNAL_ASSERT(params_ != nullptr); + } +}; + // Schedule Table const std::vector& all_heuristics() { static const std::vector hlist = { ScheduleHeuristic::Reduction, + ScheduleHeuristic::Transpose, ScheduleHeuristic::PointWise, ScheduleHeuristic::Persistent}; return hlist; @@ -1294,6 +1360,9 @@ bool SchedulerEntry::canSchedule( case ScheduleHeuristic::Persistent: return checkCanSchedule( fusion, runtime_info, data_cache); + case ScheduleHeuristic::Transpose: + return checkCanSchedule( + fusion, 
runtime_info, data_cache); default: TORCH_INTERNAL_ASSERT(false, "unreachable"); return false; @@ -1320,6 +1389,10 @@ std::unique_ptr SchedulerEntry::makeEntry( scheduler_entry = std::make_unique( fusion, runtime_info, data_cache); break; + case ScheduleHeuristic::Transpose: + scheduler_entry = std::make_unique( + fusion, runtime_info, data_cache); + break; default: TORCH_INTERNAL_ASSERT(false, "unreachable"); } @@ -1353,6 +1426,8 @@ std::string toString(ScheduleHeuristic sh) { return "reduction"; case ScheduleHeuristic::Persistent: return "persistent"; + case ScheduleHeuristic::Transpose: + return "transpose"; default: TORCH_INTERNAL_ASSERT(false, "undefined schedule"); } @@ -1405,6 +1480,10 @@ HeuristicSummary::HeuristicSummary( getPersistentHeuristics(fusion, runtime_info, this); PersistentKernelScheduler::canScheduleRunTime(fusion, runtime_info, this); break; + case ScheduleHeuristic::Transpose: + getTransposeHeuristics(fusion, runtime_info, this); + TransposeScheduler::canScheduleRunTime(fusion, runtime_info, this); + break; default: TORCH_INTERNAL_ASSERT(false, "unknown heuristic"); } @@ -1451,6 +1530,11 @@ void HeuristicSummary::validate() const { entry_type_map_.count(EntryType::SCOPE_PERSISTENT_FACTOR_INFO)); break; } + case ScheduleHeuristic::Transpose: { + TORCH_INTERNAL_ASSERT(entry_type_map_.count( + EntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS)); + break; + } default: TORCH_INTERNAL_ASSERT(false, "unknown heuristic"); } @@ -1490,6 +1574,8 @@ template class HeuristicSummaryEntry; template class HeuristicSummaryEntry; template class HeuristicSummaryEntry< HeuristicCompileTime::VectorizableInputsAndOutputs>; +template class HeuristicSummaryEntry< + HeuristicCompileTime::InputsOutputsInnerDimGroups>; template class HeuristicSummaryEntry< HeuristicCompileTime::UnrollableInputsAndOutputs>; template class HeuristicSummaryEntry; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.h b/torch/csrc/jit/codegen/cuda/scheduler/registry.h index 7d2af85bfad0..dd8caf63ccda 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.h @@ -187,6 +187,13 @@ class TORCH_CUDA_CU_API SchedulerEntry { return *pparams; } + const TransposeParams& transposeParams() const { + auto tparams = std::dynamic_pointer_cast(params_); + TORCH_INTERNAL_ASSERT( + tparams != nullptr, "Heuristic parameter is not a transpose parameter"); + return *tparams; + } + void updateLaunchConstraint(const LaunchParams& launch_params) { params_->lparams = launch_params; } diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp new file mode 100644 index 000000000000..1c430b6a0dfa --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp @@ -0,0 +1,616 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { + +constexpr int64_t kThreadsPerBlock = 128; + +// DomainMap uses the ComputeAtMap to find a reference TensorView +// that maps to all iterDomains in the fusion. 
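Stepping back from the DomainMap subclass that follows: the registry changes above plug the transpose scheduler into a first-match dispatch over an ordered heuristic list (Reduction, Transpose, PointWise, Persistent), so once its compile-time check is enabled (it currently returns false), a fusion is handed to the first heuristic whose canSchedule checks pass. Below is a rough standalone model of that selection order; the predicates are invented placeholders, not the real checks.

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Invented, simplified model of the scheduler registry: heuristics are tried
// in a fixed order and the first one whose checks pass claims the fusion.
struct FusionDesc {
  bool has_reduction;
  bool has_mismatched_inner_dims;
};

int main() {
  using Check = std::function<bool(const FusionDesc&)>;
  const std::vector<std::pair<std::string, Check>> heuristics = {
      {"reduction", [](const FusionDesc& f) { return f.has_reduction; }},
      {"transpose",
       [](const FusionDesc& f) {
         return !f.has_reduction && f.has_mismatched_inner_dims;
       }},
      {"pointwise", [](const FusionDesc& f) { return !f.has_reduction; }},
  };

  // A pure transpose-like fusion: no reduction, mismatched inner dimensions.
  const FusionDesc fusion{false, true};
  for (const auto& h : heuristics) {
    if (h.second(fusion)) {
      std::cout << "selected: " << h.first << "\n"; // "selected: transpose"
      break;
    }
  }
  return 0;
}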
+class DomainMap : public pointwise_utils::DomainMap { + public: + using pointwise_utils::DomainMap::DomainMap; + + TensorView* findReferenceFor(const std::vector& group) const { + TensorView* result = nullptr; + int max_dims = -1; + for (auto tv : group) { + if (isValidReference(tv)) { + int dims = pointwise_utils::nRootDims(tv); + if (dims > max_dims) { + result = tv; + max_dims = dims; + } + } + } + return result; + } + + static bool hasAtLeastTwoValidGroups(Fusion* fusion) { + FusionGuard fg(fusion); + DomainMap domain_map(fusion); + auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); + if (grouped_inputs_outputs.size() < 2) { + return false; + } + return domain_map.findReferenceFor(grouped_inputs_outputs[0]) != nullptr && + domain_map.findReferenceFor(grouped_inputs_outputs[1]) != nullptr; + } + + int getPosMappedTo(TensorView* tv, IterDomain* id) { + const auto& dom = tv->domain()->domain(); + for (auto i : c10::irange(dom.size())) { + if (areExactMapped(tv->axis(i), id)) { + return i; + } + } + TORCH_INTERNAL_ASSERT( + false, "Can not find ID mapped to ", id, " in tensor ", tv); + } + + // Group inputs and outputs of a fusion by its inner most domain. For example + // inputs: t0, t1 + // t2 = transpose(t1) + // t3 = t0 + t2 + // t4 = sin(t0) + // t5 = cos(t1) + // outputs: t3, t4, t5 + // + // Then we should have group {t0, t3, t4} and {t1, t5} + // + // The returned groups are sorted in descending size. If the sizes of two + // group are equal, then we sort them by their members in the following order: + // output[0], output[1], ..., input[0], input[1], ... + // That is, {ouput[0], output[2]} will be in front of {ouput[1], output[3]} + // The order here must be deterministic, because in transpose heuristics, we + // have `vectorize_factor1` and `vectorize_factor2` and we need to be sure + // that `1` and `2` are assigned to the same group across runs. 
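The grouping described in the comment above can be reproduced with a small standalone sketch (tensor names and the integer inner-dim class are illustrative stand-ins for TensorViews and the ComputeAtMap): outputs are visited before inputs so ties break deterministically, and a stable sort by descending size keeps that order for equally sized groups. Running it on the comment's example prints the group {t3, t4, t0} followed by {t5, t1}.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Each tensor carries its name and the equivalence class of its inner most
// root dimension (what areExactMapped decides in the real scheduler).
struct Tv {
  std::string name;
  int inner_dim_class;
};

std::vector<std::vector<Tv>> groupByInnerDim(
    const std::vector<Tv>& outputs, const std::vector<Tv>& inputs) {
  std::vector<std::vector<Tv>> groups;
  auto place = [&](const Tv& tv) {
    for (auto& g : groups) {
      if (g.front().inner_dim_class == tv.inner_dim_class) {
        g.push_back(tv);
        return;
      }
    }
    groups.push_back({tv});
  };
  // Outputs first, then inputs, so the tie-breaking order is deterministic.
  for (const auto& tv : outputs) place(tv);
  for (const auto& tv : inputs) place(tv);
  // Stable sort keeps the visit order for equally sized groups.
  std::stable_sort(
      groups.begin(), groups.end(),
      [](const auto& a, const auto& b) { return a.size() > b.size(); });
  return groups;
}

int main() {
  // Mirrors the comment's example: t1 is transposed, so t1/t5 share one inner
  // dim class while t0/t3/t4 share the other.
  std::vector<Tv> inputs = {{"t0", 0}, {"t1", 1}};
  std::vector<Tv> outputs = {{"t3", 0}, {"t4", 0}, {"t5", 1}};
  for (const auto& g : groupByInnerDim(outputs, inputs)) {
    for (const auto& tv : g) std::cout << tv.name << " ";
    std::cout << "\n"; // prints "t3 t4 t0" then "t5 t1"
  }
  return 0;
}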
+ std::vector> groupInputsOutputsByInnerDim() const { + std::vector> groups; + auto output_tvs = ir_utils::filterByType(fusion_->outputs()); + auto input_tvs = ir_utils::filterByType(fusion_->inputs()); + std::unordered_map group_to_inner_dim_map; + decltype(input_tvs)* tv_filtered_group[2] = {&output_tvs, &input_tvs}; + for (auto view : tv_filtered_group) { + for (auto tv : *view) { + auto inner_most_id = scheduler_utils::innerMostRootDim(tv); + bool found = false; + for (auto gi : c10::irange(groups.size())) { + auto& g = groups[gi]; + auto group_inner_dim = group_to_inner_dim_map.at(gi); + if (areExactMapped(inner_most_id, group_inner_dim)) { + g.emplace_back(tv); + found = true; + break; + } + } + if (!found) { + group_to_inner_dim_map[groups.size()] = inner_most_id; + groups.push_back({tv}); + } + } + } + std::stable_sort( + groups.begin(), + groups.end(), + [](const std::vector& v1, + const std::vector& v2) { + return v1.size() > v2.size(); + }); + return groups; + } +}; + +} // namespace + +bool hasAtLeastTwoValidGroups(Fusion* fusion) { + return DomainMap::hasAtLeastTwoValidGroups(fusion); +} + +std::shared_ptr getTransposeHeuristics( + Fusion* fusion, + const at::ArrayRef& runtime_inputs, + HeuristicSummary* data_cache) { + SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true); + return getTransposeHeuristics(fusion, runtime_info, data_cache); +} + +std::shared_ptr getTransposeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("getTransposeHeuristics"); + + FusionGuard fg(fusion); + + // Incase any buffer is of type DataType::Index + DataType index_type = indexModeToDtype(runtime_info.getIndexMode()); + + auto domain_map_entry = + HeuristicSummaryEntry( + data_cache, + [fusion]() { return std::make_unique(fusion); }); + const auto& domain_map = dynamic_cast(domain_map_entry.get()); + + auto grouped_inputs_outputs_entry = + HeuristicSummaryEntry( + data_cache, [&domain_map]() { + return std::make_unique>>( + domain_map.groupInputsOutputsByInnerDim()); + }); + auto grouped_inputs_outputs = grouped_inputs_outputs_entry.get(); + + TORCH_INTERNAL_ASSERT( + grouped_inputs_outputs.size() >= 2, + "Can not find mismatched inner most dim, should use pointwise scheduler."); + + auto largest_entry = + HeuristicSummaryEntry( + data_cache, [&domain_map, &grouped_inputs_outputs]() { + std::vector data{ + domain_map.findReferenceFor(grouped_inputs_outputs[0]), + domain_map.findReferenceFor(grouped_inputs_outputs[1])}; + return std::make_unique>(std::move(data)); + }); + auto& largest = largest_entry.get(); + TORCH_INTERNAL_ASSERT(largest.size() == 2); + TensorView* largest1 = largest[0]; + TensorView* largest2 = largest[1]; + TORCH_INTERNAL_ASSERT( + largest1 != nullptr, "Unable to find reference tensor for group 1"); + TORCH_INTERNAL_ASSERT( + largest2 != nullptr, "Unable to find reference tensor for group 2"); + + const int64_t device_multiprocessor_count = + (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + + int64_t max_input_dtype_size = 1; + + size_t n_input_tensors = 0; + for (auto inp : ir_utils::filterByType(fusion->inputs())) { + max_input_dtype_size = std::max( + max_input_dtype_size, + (int64_t)dataTypeSize(inp->getDataType().value(), index_type)); + n_input_tensors++; + } + + auto ref_root = largest1->getMaybeRFactorDomain(); + int64_t n_elems = 1; + for (size_t ref_i = 0; ref_i < ref_root.size(); ref_i++) { + auto inferred_val = + 
runtime_info.expressionEvaluator().evaluate(ref_root[ref_i]->extent()); + TORCH_INTERNAL_ASSERT( + inferred_val.has_value(), + "Error inferring size for pointwise scheduler: ", + ref_root[ref_i]->extent()->toInlineString()); + n_elems *= inferred_val.value().as(); + } + + auto params = std::make_shared("Transpose heuristics"); + + // Note [vectorization and unroll of input and output] + // + // The choice of vectorization size, block size and tile sizes needs to be + // consistent with each other. Consider the following: + // + // The number of threads in one block is + // num_threads = blockDim.x * blockDim.y + // and the number of elements per each tile is + // num_elems_per_tile = params->tile_size1 * params->tile_size2 + // So each thread needs to process + // num_elems_per_thread = num_elems_per_tile / num_threads + // elements. That is, once the tile sizes and block size are determined, the + // `num_elems_per_thread` is determined, regardless of vectorizability of + // input/output tensors. + // + // To make the selection of tile sizes othogonal to vectorizability, we + // support having both vectorization and unrolling in the same tensor. For + // example, if we have num_elems_per_tile == 1024 and num_threads = 256, then + // we have num_elems_per_thread being 4. And if we have vector size 2, then we + // will do unroll 2 * vectorize 2 at the same tensor. + // + // Also, since the inner most dim of different groups are not the same, it is + // natural to consider their vectorizability separately and allow them to have + // different vectorize/unroll sizes. + + constexpr int64_t kSixteen = 16; // clang tidy + + auto max_unroll_factor = ceilDiv( + // Available unrolling based on size of data type + (int64_t)kSixteen / max_input_dtype_size, + // Reduce max unrolling factor if we have many inputs/outputs to unroll + // as it could start consuming a lot of registers. 
+ std::max( + (scheduler_utils::lastPow2( + (int64_t)grouped_inputs_outputs[0].size() + + (int64_t)grouped_inputs_outputs[1].size()) >> + 2), + (int64_t)1)); + + // Don't unroll at the cost of getting a full wave on the GPU + auto max_unroll_factor_occupancy = ceilDiv( + n_elems, + device_multiprocessor_count * params->tile_size1 * params->tile_size2); + max_unroll_factor = std::min(max_unroll_factor, max_unroll_factor_occupancy); + + // Compute maximum vectorize factor that can be used + size_t vectorize_factor1 = max_unroll_factor; + size_t vectorize_factor2 = max_unroll_factor; + + for (auto tv : grouped_inputs_outputs[0]) { + const auto tv_vectorize_factor = + runtime_info.getInnerDimVectorizableWidth(tv); + vectorize_factor1 = std::min(vectorize_factor1, tv_vectorize_factor); + } + for (auto tv : grouped_inputs_outputs[1]) { + const auto tv_vectorize_factor = + runtime_info.getInnerDimVectorizableWidth(tv); + vectorize_factor2 = std::min(vectorize_factor2, tv_vectorize_factor); + } + + // Try expanding vectorization to contig merged domains + auto expanded_vector_word_size1 = + scheduler_utils::expandVectorizationToContigMergedDomains( + fusion, + runtime_info, + grouped_inputs_outputs[0], + largest1, + 0, + vectorize_factor1); + auto expanded_vector_word_size2 = + scheduler_utils::expandVectorizationToContigMergedDomains( + fusion, + runtime_info, + grouped_inputs_outputs[1], + largest2, + 0, + vectorize_factor2); + + expanded_vector_word_size1 = std::min( + static_cast(max_unroll_factor), expanded_vector_word_size1); + expanded_vector_word_size2 = std::min( + static_cast(max_unroll_factor), expanded_vector_word_size2); + + vectorize_factor1 = std::max(vectorize_factor1, expanded_vector_word_size1); + vectorize_factor2 = std::max(vectorize_factor2, expanded_vector_word_size2); + + params->vectorize_factor1 = vectorize_factor1; + params->vectorize_factor2 = vectorize_factor2; + + // TODO: should we adjust tile size according to max_unroll_factor? + + params->lparams.bind(kThreadsPerBlock, ParallelType::TIDx); + + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + std::cerr << "\n===== Transpose Stats ========\n" + << "num_elems: " << n_elems << "\n" + << "n_input_tensors: " << n_input_tensors << "\n" + << "max_input_dtype_size: " << max_input_dtype_size << "\n" + << "group 1: " << ir_utils::toString(grouped_inputs_outputs[0]) + << "\n" + << "group 2: " << ir_utils::toString(grouped_inputs_outputs[1]) + << std::endl; + std::cerr << std::endl; + std::cerr << params->toString() << std::endl; + } + + return params; +} + +// TODO: remove or return launch parameters +LaunchParams scheduleTranspose( + Fusion* fusion, + const at::ArrayRef& runtime_inputs) { + FUSER_PERF_SCOPE("scheduleFusion"); + auto params = getTransposeHeuristics(fusion, runtime_inputs); + TORCH_INTERNAL_ASSERT( + params != nullptr, "Could not schedule pointwise operation."); + scheduleTranspose(fusion, *params); + return params->lparams; +} + +void scheduleTranspose(Fusion* fusion, const TransposeParams& params) { + FusionGuard fg(fusion); + + // Make sure we don't have global memory set on intermediate tensors from + // fusion segmentation + scheduler_utils::clearMemorySpace(fusion); + + // maybe has_reduction for scheduling should be done on a per output tensor + // basis. 
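To make the unroll-factor heuristic computed in getTransposeHeuristics above concrete, here is a worked sketch under assumed sizes (float inputs, five tensors across the two groups, a 2^20-element problem on an 80-SM device; none of these numbers come from a real run).

#include <algorithm>
#include <cstdint>
#include <iostream>

// Largest power of two <= n, for n >= 1 (stand-in for scheduler_utils::lastPow2).
int64_t lastPow2(int64_t n) {
  int64_t p = 1;
  while (p * 2 <= n) {
    p *= 2;
  }
  return p;
}

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t dtype_size = 4; // float inputs (assumed)
  const int64_t n_group_tensors = 5; // |group 1| + |group 2| (assumed)
  const int64_t n_elems = int64_t(1) << 20; // problem size (assumed)
  const int64_t sm_count = 80, tile1 = 32, tile2 = 32;

  // Data-type term, reduced when many tensors would be unrolled at once.
  int64_t max_unroll = ceilDiv(
      16 / dtype_size,
      std::max(lastPow2(n_group_tensors) >> 2, int64_t(1)));
  // Don't unroll past the point of losing a full wave of blocks.
  const int64_t occupancy_cap = ceilDiv(n_elems, sm_count * tile1 * tile2);
  max_unroll = std::min(max_unroll, occupancy_cap);
  std::cout << max_unroll << "\n"; // prints 4; the occupancy cap (13) does not bind here
  return 0;
}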
+ // TODO: add support for trivial reduction + TORCH_INTERNAL_ASSERT( + ir_utils::getReductionOps(fusion, /*ignore_trivial=*/false).empty(), + "This scheduler only handles pointwise ops."); + + // Cache inputs + auto cached_inputs = scheduler_utils::cacheInputs(fusion, true); + + // Cache and fork outputs + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); + + std::vector input_tvs; + { + auto filtered_tvs = ir_utils::filterByType(fusion->inputs()); + // Remove hanging tensor views + for (auto tv : filtered_tvs) { + if (tv->uses().empty()) { + continue; + } + input_tvs.push_back(tv); + } + } + auto output_tvs = ir_utils::filterByType(fusion->outputs()); + + size_t max_dims = 0; + for (auto inp : input_tvs) { + max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims); + } + + for (auto out : output_tvs) { + max_dims = std::max(pointwise_utils::nRootDims(out), max_dims); + } + + // If everything is zero dim tensors, just return. + if (max_dims == 0) { + return; + } + + DomainMap domain_map(fusion); + auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); + TORCH_INTERNAL_ASSERT(grouped_inputs_outputs.size() >= 2); + + // We need something similar to `cacheFork` for input tensors in group 2. We + // need this because we will want to propagate to the entire DAG except group + // 2 and its cached inputs, so we need to make sure the DAG is still connected + // if we remove group and its cached inputs. For example + // t0 + // | + // cache + // / \ + // t1 t2 + // if groups = {{t1, t2}, {t0}}, then removing {t0, cache} from the DAG will + // make it disconnected. + std::unordered_set group2_and_cached_inputs( + grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end()); + for (auto tv : grouped_inputs_outputs[1]) { + if (tv->isFusionInput()) { + auto existing_cache = ir_utils::consumerTvsOf(tv)[0]; + if (ir_utils::consumerTvsOf(existing_cache).size() > 1) { + auto new_cache = tv->cacheAfter(); + new_cache->setMemoryType(MemoryType::Shared); + group2_and_cached_inputs.emplace(new_cache); + } else { + existing_cache->setMemoryType(MemoryType::Shared); + group2_and_cached_inputs.emplace(existing_cache); + } + } + } + // set cached outputs of group 2 to shared memory + for (auto pair : cached_outputs) { + auto cached_output = pair.first; + auto output = pair.second; + if (group2_and_cached_inputs.count(output) > 0) { + cached_output->setMemoryType(MemoryType::Shared); + } + } + + TensorView* reference1 = + domain_map.findReferenceFor(grouped_inputs_outputs[0]); + TensorView* reference2 = + domain_map.findReferenceFor(grouped_inputs_outputs[1]); + + TORCH_INTERNAL_ASSERT( + reference1 != nullptr, + "Could not find a fully broadcasted tensor to reference schedule on the first group."); + + TORCH_INTERNAL_ASSERT( + reference2 != nullptr, + "Could not find a fully broadcasted tensor to reference schedule on the second group."); + + auto inner_most_id1 = scheduler_utils::innerMostRootDim(reference1); + auto inner_most_id2 = scheduler_utils::innerMostRootDim(reference2); + + auto inner_most_pos1_in_ref1 = + domain_map.getPosMappedTo(reference1, inner_most_id1); + auto inner_most_pos2_in_ref1 = + domain_map.getPosMappedTo(reference1, inner_most_id2); + + // make tile + // [..., I1, .., I2, ...] 
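Before the tiling below, it may help to see the arithmetic from Note [vectorization and unroll of input and output] worked out for the default 32x32 tiles and the 128-thread block this scheduler uses; the vectorize factor of 2 is an assumed value, not one derived from a real input.

#include <cassert>
#include <iostream>

int main() {
  const int tile_size1 = 32, tile_size2 = 32; // TransposeParams defaults
  const int threads_per_block = 128; // kThreadsPerBlock
  const int vectorize_factor = 2; // assumed result of the runtime analysis

  const int elems_per_tile = tile_size1 * tile_size2; // 1024
  const int elems_per_thread = elems_per_tile / threads_per_block; // 8
  // Once tile and block sizes are fixed, each thread owns 8 elements no matter
  // what; vectorizability only decides how they split between the Vectorize
  // and Unroll axes.
  const int unroll_factor = elems_per_thread / vectorize_factor; // 4
  assert(unroll_factor * vectorize_factor == elems_per_thread);
  std::cout << elems_per_thread << " = " << unroll_factor << " (unroll) x "
            << vectorize_factor << " (vectorize)\n";
  return 0;
}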
+ reference1->split(inner_most_pos1_in_ref1, params.tile_size1); + reference1->reorder({{inner_most_pos1_in_ref1 + 1, -1}}); + reference1->split(inner_most_pos2_in_ref1, params.tile_size2); + reference1->reorder({{inner_most_pos2_in_ref1 + 1, -1}}); + // [..., I1/tile1, .., I2/tile2, ..., tile1, tile2] + + // Merge remaining dimensions + int lhs_i = -1; + for (int i = (int)reference1->nDims() - 2; i > 0; i--) { + auto axis_i = i - 1; + if (lhs_i == -1) { + lhs_i = axis_i; + } else { + reference1->merge(axis_i, lhs_i); + lhs_i = axis_i; + } + } + reference1->split(0, 1); + // [merged_dim, 1, tile1, tile2] + + // parallelize non-tile dimensions + reference1->axis(1)->parallelize(ParallelType::Unswitch); + reference1->axis(0)->parallelize(ParallelType::BIDx); + // [BIDx, Unswitch, tile1, tile2] + + // Propagate transformations so far to the entire DAG + TransformPropagator propagator(reference1); + MaxRootDomainInfoSpanningTree entire_dag(reference1); + entire_dag.traverse(&propagator); + scheduler_utils::parallelizeAllLike(reference1); + + // For a transpose scheduling, all we need is to bind threadIdx.x differently + // for inputs and outputs. This swap of binding could happen at any tensor on + // the path from input to output, especially, it does not have to be in the + // transpose tensor. Here, we naively do the binding swap at cached + // input/output for simplicity. We might need to find a better set of swap + // tensors in the future to reduce shared memory usage. + + // transform tile for vectorization/unroll + // See note [vectorization and unroll of input and output] + + // schedule group 2 + int pos = reference2->nDims() - 2; + // [..., tile1, tile2] + reference2->merge(pos); + reference2->split(pos, params.vectorize_factor2); + reference2->split(pos, kThreadsPerBlock); + // [..., Unroll, TIDx, Vectorize] + + // Propagate transformations of reference2 to the entire DAG except + // group 1. We actually only want to propagate to the fusion outputs, but + // inputs and outputs themselves are disconnected, so we have to borrow the + // entire DAG and use its spanning tree. + { + auto all_tvs_except1 = ir_utils::allTvsExcept( + fusion, + {grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()}); + SetSelector selector({all_tvs_except1.begin(), all_tvs_except1.end()}); + MaxRootDomainInfoSpanningTree entire_dag_except1(reference2, &selector); + TransformPropagator propagator(reference2); + entire_dag_except1.traverse(&propagator); + } + + // parallelize group2 and its cached inputs + { + reference2->axis(-1)->parallelize(ParallelType::Vectorize); + reference2->axis(-2)->parallelize(ParallelType::TIDx); + reference2->axis(-3)->parallelize(ParallelType::Unroll); + + ComputeAtMap ca_map(fusion); + + scheduler_utils::parallelizeAllLike( + reference2, + {group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()}, + {ParallelType::Vectorize, ParallelType::TIDx}); + + // Only unrolled the axes that exactly maps to the unrolled axes + // on reference as support for permissively mapped axes are not + // yet clearly defined. 
+ std::vector unrolled_group2_cached_inputs; + for (auto gin : group2_and_cached_inputs) { + if (std::any_of( + gin->domain()->domain().begin(), + gin->domain()->domain().end(), + [&ca_map, reference2](IterDomain* id) { + return ca_map.areMapped( + id, reference2->axis(-3), IdMappingMode::EXACT); + })) { + unrolled_group2_cached_inputs.push_back(gin); + } + } + + scheduler_utils::parallelizeAllLike( + reference2, unrolled_group2_cached_inputs, {ParallelType::Unroll}); + } + + // schedule group 1 + reference1->reorder({{-2, -1}}); + // [..., tile2, tile1] + pos = reference1->nDims() - 2; + reference1->merge(pos); + reference1->split(pos, params.vectorize_factor1); + reference1->split(pos, kThreadsPerBlock); + reference1->axis(-1)->parallelize(ParallelType::Vectorize); + reference1->axis(-2)->parallelize(ParallelType::TIDx); + reference1->axis(-3)->parallelize(ParallelType::Unroll); + // [..., Unroll, TIDx, Vectorize] + + // Propagate transformations, parallelization of the reference1 to the entire + // DAG except group 2 and its corresponding cached outputs. + { + auto all_tvs_except2 = + ir_utils::allTvsExcept(fusion, group2_and_cached_inputs); + SetSelector selector({all_tvs_except2.begin(), all_tvs_except2.end()}); + MaxRootDomainInfoSpanningTree entire_dag_except_outputs( + reference1, &selector); + TransformPropagator propagator(reference1); + entire_dag_except_outputs.traverse(&propagator); + scheduler_utils::parallelizeAllLike( + reference1, all_tvs_except2, {ParallelType::TIDx}); + } + + // vectorize and unroll group 1's output and cached input + { + ComputeAtMap ca_map(fusion); + std::vector group1_and_cached_inputs( + grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()); + for (auto tv : grouped_inputs_outputs[0]) { + if (tv->isFusionInput()) { + group1_and_cached_inputs.emplace_back(ir_utils::consumerTvsOf(tv)[0]); + } + } + scheduler_utils::parallelizeAllLike( + reference1, group1_and_cached_inputs, {ParallelType::Vectorize}); + + // Only unrolled the axes that exactly maps to the unrolled axes + // on reference as support for permissively mapped axes are not + // yet clearly defined. 
+ std::vector unrolled_group1_cached_inputs; + for (auto gin : group1_and_cached_inputs) { + if (std::any_of( + gin->domain()->domain().begin(), + gin->domain()->domain().end(), + [&ca_map, reference1](IterDomain* id) { + return ca_map.areMapped( + id, reference1->axis(-3), IdMappingMode::EXACT); + })) { + unrolled_group1_cached_inputs.push_back(gin); + } + } + + scheduler_utils::parallelizeAllLike( + reference1, unrolled_group1_cached_inputs, {ParallelType::Unroll}); + } + + // cleanup parallelization from reference1 and reference2 if they are fusion + // inputs + for (auto tv : {reference1, reference2}) { + if (tv->isFusionInput()) { + for (auto id : tv->domain()->domain()) { + id->parallelize(ParallelType::Serial); + } + } + } + + // Inline + InlinePropagator inline_propagator( + reference1, -1, ComputeAtMode::MostInlined); + entire_dag.traverse(&inline_propagator); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose.h b/torch/csrc/jit/codegen/cuda/scheduler/transpose.h new file mode 100644 index 000000000000..374840846a61 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose.h @@ -0,0 +1,101 @@ +#pragma once + +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// Note [Transpose scheduling] +// +// The target of transpose scheduling is to get coalesced global memory access +// to as much input and output tensors as possible. For a DAG with only pure +// pointwise operators, the scheduling is very simple because the inner most +// dimension of all input and output tensors are all mapped together in the +// ComputeAtMap, i.e., there is essentially only one inner most dimension. In +// such case, we just vectorize that inner most dimension and bind it to +// threadIdx.x identically for all input and output tensors. In the case where +// transposes are present in the DAG, the inner most dimensions of different +// inputs and outputs might not match. And there is no fixed pattern on which +// input/output tensors should share the same inner most dimension with which. +// Consider the following example DAGs ([T] represents transpose, all tensors +// are 2D): +// +// t0 t1 t0 t1 t0 t1 t0 t1 t0 +// \ | \ / \ | \ | | +// \ [T] [T] [T] \ [T] t2 [T] [T] +// \ / \ / \ / \ / \ / \ | +// t2 t2 t2 t3 t3 t4 t5 [T] +// | +// t1 +// +// In order to support all these cases in a general way, the following +// perspective is very important: What we are looking for is to bind threadIdx.x +// differently for different inputs and outputs, so there has to be some tensor +// somewhere in the DAG that we write and read with different threadIdx.x +// bindings. The tensor of binding swap can be any tensor on the path that +// connects inputs/outputs with different inner most dimension, especially, it +// does not necessarily have to be the tensor of the transpose operator. In +// other words, thanks to our indexing system who is already taking care of the +// correctness of transpose, the scheduler can freely choose where to realize +// these transposes as different threadIdx.x bindings. This observation greatly +// simplifies our scheduling. +// +// Our scheduling strategy is as follows: We first split the inputs and outputs +// of the fusion into two groups according to their inner most dimension. 
The +// inner most dimensions of tensors in the same group are mapped to each other, +// and they are not mapped to the inner most dimension of tensors in a different +// group. Depending on the transpose pattern, there can be more than two groups; +// if this is the case, we only consider the two largest groups, and the tensors +// in the remaining groups will just be accessed unvectorized and uncoalesced. +// We call the largest group `group1` and the second largest group +// `group2`. When we have the split, we will make a 2D tiling [I1, I2] -> +// [I1/tile1, tile1, I2/tile2, tile2] on the inner most dimensions of group1 and +// group2. Each tile [tile1, tile2] will be handled by a block, and the tensors +// that have mismatched threadIdx.x bindings will use shared memory. The outer +// IDs of the tiling split will be merged with non-tiled IDs and then bound to +// blockIdx.x for the entire DAG, regardless of which group a tensor belongs to. +// For the inner tile IDs [tile1, tile2], we need to transform and parallelize +// group 1 and group 2 differently. The intermediate tensors can be transformed +// and parallelized consistently with either group 1 or group 2. Here, since +// group 1 is larger than group 2, we decide to only transform and parallelize +// the cached inputs of group 2 together with group 2, and keep the rest of the +// DAG consistent with group 1. +// +// If you would like to see an example of how to manually schedule a complicated +// DAG using this idea, refer to: +// FusionManualScheduleTransposeComplexDAG1_CUDA + +class SchedulerRuntimeInfo; +class HeuristicSummary; + +TORCH_CUDA_CU_API std::shared_ptr getTransposeHeuristics( + Fusion* fusion, + const at::ArrayRef& runtime_inputs, + HeuristicSummary* data_cache = nullptr); + +TORCH_CUDA_CU_API std::shared_ptr getTransposeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr); + +TORCH_CUDA_CU_API void scheduleTranspose( + Fusion* fusion, + const TransposeParams& params); + +TORCH_CUDA_CU_API LaunchParams scheduleTranspose( + Fusion* fusion, + const at::ArrayRef& runtime_inputs); + +//! Utility for canSchedule interface to check if this fusion has at least two +//! groups, each with a fully broadcasted reference tensor. +bool hasAtLeastTwoValidGroups(Fusion* fusion); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/transpose_heuristic.h new file mode 100644 index 000000000000..2755afeccaa9 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose_heuristic.h @@ -0,0 +1,94 @@ +#pragma once + +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// Parameters of the transpose heuristic to describe the optimal schedule. +// Warning: the equality operator is intended for use in caching the kernel +// associated with these transpose parameters. It does not check if the launch +// parameters are equivalent!
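A minimal sketch of the caching use the warning above refers to, ahead of the class itself: hash() buckets candidate parameter sets and the equality check resolves collisions, while launch parameters stay out of both. The cache below is an invented stand-in for illustration, not nvfuser's actual kernel cache.

#include <cstddef>
#include <iostream>
#include <unordered_map>

// Invented stand-in for TransposeParams holding only the fields that identify
// a compiled kernel; launch parameters are deliberately left out, as the
// warning above notes.
struct Params {
  size_t vectorize_factor1 = 1, vectorize_factor2 = 1;
  size_t tile_size1 = 32, tile_size2 = 32;
  bool operator==(const Params& o) const {
    return vectorize_factor1 == o.vectorize_factor1 &&
        vectorize_factor2 == o.vectorize_factor2 &&
        tile_size1 == o.tile_size1 && tile_size2 == o.tile_size2;
  }
};
struct ParamsHash {
  // Same shift-and-xor combination as hash() in the class below; assumes a
  // 64-bit size_t, as the scheduler code does.
  size_t operator()(const Params& p) const {
    return p.vectorize_factor1 ^ (p.vectorize_factor2 << 16) ^
        (p.tile_size1 << 32) ^ (p.tile_size2 << 48);
  }
};

int main() {
  std::unordered_map<Params, int, ParamsHash> kernel_cache; // value: kernel id
  kernel_cache[Params{2, 4, 32, 32}] = 0; // "compile" once
  const Params query{2, 4, 32, 32}; // same heuristics on a later run
  std::cout << (kernel_cache.count(query) ? "hit" : "miss") << "\n"; // "hit"
  return 0;
}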
+class TransposeParams : public HeuristicParams { + public: + // Vectorization factor for tensors in the first group + size_t vectorize_factor1 = 1; + + // Vectorization factor for tensors in the second group + size_t vectorize_factor2 = 1; + + // TODO: support symbolic tile size + // https://github.com/csarofeen/pytorch/pull/1854#discussion_r928143729 + + // Tile size for the inner most dim of tensors in the first group + size_t tile_size1 = 32; + + // Tile size for the inner most dim of tensors in the second group + size_t tile_size2 = 32; + + using HeuristicParams::HeuristicParams; + + // Warning: Does not check launch parameters! + bool sameAs( + const std::shared_ptr& other_base) const override { + auto other_casted = std::dynamic_pointer_cast(other_base); + if (other_casted == nullptr) { + return false; + } + const TransposeParams& other = *other_casted; + bool attr_equal = other.vectorize_factor1 == vectorize_factor1 && + other.vectorize_factor2 == vectorize_factor2 && + other.tile_size1 == tile_size1 && other.tile_size2 == tile_size2; + return attr_equal; + } + + std::string toString() const override { + std::stringstream ss; + ss << "\n===== Transpose Parameters ========\n" + << (tag == "" ? "" : "Tag: ") << tag << " Transpose Characteristics:\n" + << " Gridx: " << lparams.gdimx() << " BlckY: " << lparams.bdimy() + << " BlckX: " << lparams.bdimx() << "\n"; + ss << " input tile size: " << tile_size1 << "\n"; + ss << " output tile size: " << tile_size2 << "\n"; + int elements_per_tile = tile_size1 * tile_size2; + ss << " elements per tile: " << elements_per_tile << "\n"; + int elements_per_thread = + elements_per_tile / (lparams.bdimy() * lparams.bdimx()); + ss << " elements per thread: " << elements_per_thread << "\n"; + if (vectorize_factor1 > 1) { + ss << "Vectorize set 1, Factor: " << vectorize_factor1 << "\n"; + } + int unroll_factor1 = elements_per_thread / vectorize_factor1; + if (unroll_factor1 > 1) { + ss << "Unroll set 1, Factor: " << unroll_factor1 << "\n"; + } + if (vectorize_factor2 > 1) { + ss << "Vectorize set 2, Factor: " << vectorize_factor2 << "\n"; + } + int unroll_factor2 = elements_per_thread / vectorize_factor2; + if (unroll_factor2 > 1) { + ss << "Unroll set 2, Factor: " << unroll_factor2 << "\n"; + } + ss << "====================================\n"; + return ss.str(); + } + + size_t hash() const override { + size_t attr_hash = vectorize_factor1 ^ (vectorize_factor2 << 16) ^ + (tile_size1 << 32) ^ (tile_size2 << 48); + return attr_hash; + } + + std::shared_ptr clone() const override { + return std::make_shared(*this); + } +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index cce482e8ccc7..f75dd800d25d 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -13150,69 +13150,6 @@ TEST_F(NVFuserTest, FusionWelfordShmoo_CUDA) { } } -TEST_F(NVFuserTest, FusionTranspose1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - constexpr int M = 10; - constexpr int N = 20; - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = transpose(tv0); - fusion.addInput(tv0); - fusion.addOutput(tv1); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N}, options); - std::vector aten_inputs = {t0}; - 
- FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - at::Tensor aten_output = t0.t(); - - testValidate( - &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTranspose2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - constexpr int M = 10; - constexpr int N = 20; - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = transpose(tv0); - fusion.addInput(tv0); - fusion.addOutput(tv1); - - tv1->merge(0); - tv1->split(0, 32); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - at::Tensor aten_output = t0.t(); - - testValidate( - &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - TEST_F(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -14094,134 +14031,6 @@ TEST_F(NVFuserTest, FusionSwizzle2_CUDA) { &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = transpose(tv0); - fusion.addOutput(tv1); - - // tv0: [I0, I1] - // tv1: [I1, I0] - - const int BS = 32; - - // CTA tiling by BS*BS - tv1->split(1, BS); - tv1->split(0, BS); - tv1->reorder({{1, 2}}); - // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)] - - // Create a smem buffer to cache each tile - auto tv0_cache = tv0->cacheAfter(); - tv0_cache->setMemoryType(MemoryType::Shared); - - tv0->computeAt(tv1, 2); - // tv0: [I0, I1] - // tv0_cache: [I1/BS, I0/BS, BS(I1), BS(I0)] - // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)] - - // Assign each thread block to a tile - tv1->axis(0)->parallelize(ParallelType::BIDy); - tv1->axis(1)->parallelize(ParallelType::BIDx); - - // Thread mapping for each tile. For both of the input and output - // tiles, map TIDx to the fastest-changing dimension to facilitate - // coalesced gmem accesses. - tv1->axis(2)->parallelize(ParallelType::TIDy); - tv1->axis(3)->parallelize(ParallelType::TIDx); - // Note that the fastest-changing axis is next to the inner-most - // axis since computeAt reorders the axes as the output tensor. 
- tv0_cache->axis(2)->parallelize(ParallelType::TIDx); - tv0_cache->axis(3)->parallelize(ParallelType::TIDy); - - // Swizzles the smem cache to avoid bank conflicts - tv0_cache->swizzle(SwizzleType::Transpose, {3, 2}); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 100; - const int by = 200; - at::Tensor t0 = at::randn({bx, by}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0.t(); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = transpose(tv0); - fusion.addOutput(tv1); - - // tv0: [I0, I1] - // tv1: [I1, I0] - - const int BS = 32; - const int BDIM = 256; - - // CTA tiling by BS*BS - tv1->split(1, BS); - tv1->split(0, BS); - tv1->reorder({{1, 2}}); - // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)] - - // Create a smem buffer to cache each tile - auto tv0_cache = tv0->cacheAfter(); - tv0_cache->setMemoryType(MemoryType::Shared); - - tv0->computeAt(tv1, 2); - // tv0: [I0, I1] - // tv0_cache: [I1/BS, I0/BS, BS*BS/BDIM, BDIM] - // tv1: [I1/BS, I0/BS, BS*BS/BDIM, BDIM] - - // Tranform the tile axes for 1D thread mapping - tv1->merge(-2, -1); - tv1->split(-1, BDIM); - // tv1: [I1/BS, I0/BS, BS*BS/BDIM, BDIM] - - // Transform the cache similarly but apply swizzle to the 2D tile axes. - tv0_cache->reorder({{-2, -1}}); - tv0_cache->swizzle(SwizzleType::Transpose, {2, 3}); - tv0_cache->merge(-2, -1); - tv0_cache->split(-1, BDIM); - // tv0: [I1/BS, I0/BS, BS*BS/BDIM, BDIM] - - // Assign each thread block to a tile - tv1->axis(0)->parallelize(ParallelType::BIDy); - tv1->axis(1)->parallelize(ParallelType::BIDx); - - // Thread mapping for each tile. 
- tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 100; - const int by = 200; - at::Tensor t0 = at::randn({bx, by}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0.t(); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - TEST_F(NVFuserTest, FusionGridPersistence_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -25593,176 +25402,6 @@ TEST_F(NVFuserTest, FusionPrint_CUDA) { } } -TEST_F(NVFuserTest, FusionManualScheduleTransposeComplexDAG1_CUDA) { - // achieved: 833.526 GB/s on RTX 3090 (theoretical bandwidth: 936 GB/s) - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(3); - auto tv1 = makeContigTensor(3); - auto tv2 = makeContigTensor(3); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - auto tv3 = transpose(tv0, 1, 2); - auto tv4 = transpose(tv1, 0, 1); - auto tv5 = sigmoid(tv1); - auto tv6 = add(tv2, tv3); - auto tv7 = transpose(tv5, 0, 2); - auto tv8 = add(tv4, tv0); - auto tv9 = relu(tv8); - fusion.addOutput(tv9); - auto tv10 = sin(tv6); - fusion.addOutput(tv10); - auto tv11 = transpose(tv6, 0, 1); - auto tv12 = add(tv7, tv11); - fusion.addOutput(tv12); - - // group 1: tv0, tv1, *tv9, innermost dim K - // group 2: tv2, *tv10, tv12, innermost dim N - - // cache inputs and outputs - auto tv0_cache = tv0->cacheAfter(); - auto tv1_cache = tv1->cacheAfter(); - auto tv2_cache = tv2->cacheAfter(); - auto tv9_cache = tv9->cacheBefore(); - auto tv10_cache = tv10->cacheBefore(); - auto tv12_cache = tv12->cacheBefore(); - - // Step 1: Make 32x32 tiles, schedule outer dimensions - { - // Pick an arbitrary tensor as a reference tensor for this step. There is no - // requirement on which group this reference tensor should belong to. Here - // we pick tv9, which belongs to group 1. - - // Make 32x32 tile: - // [M, N, K] - tv9->split(1, 32); - tv9->reorder({{2, -1}}); - tv9->split(2, 32); - tv9->reorder({{3, -1}}); - // [M, N/32, K/32, 32(N), 32(K)] - - // merge outer dims, parallelize on BIDx, and unswitch - tv9->merge(0); - tv9->merge(0); - tv9->split(0, 1); - // [M * N/32 * K/32, 1, 32(N), 32(K)] - tv9->axis(0)->parallelize(ParallelType::BIDx); - tv9->axis(1)->parallelize(ParallelType::Unswitch); - // [BIDx, Unswitch, 32(N), 32(K)] - - // propagate to the entire DAG - MaxRootDomainInfoSpanningTree entire_dag(tv9); - TransformPropagator tp(tv9); - entire_dag.traverse(&tp); - scheduler_utils::parallelizeAllLike(tv9); - } - - constexpr int threads_per_block = 128; - - // Step 2, schedule group 2 - { - // group 2: tv2, *tv10, tv12, innermost dim N - - tv2_cache->setMemoryType(MemoryType::Shared); - tv10_cache->setMemoryType(MemoryType::Shared); - tv12_cache->setMemoryType(MemoryType::Shared); - - // pick tv10 as reference tensor for group 2 - // [BIDx, Unswitch, 32(N), 32(K)] - tv10->reorder({{-1, -2}}); - // [BIDx, Unswitch, 32(K), 32(N)] - tv10->merge(2); - tv10->split(2, 4); - tv10->split(2, threads_per_block); - tv10->axis(-1)->parallelize(ParallelType::Vectorize); - tv10->axis(-2)->parallelize(ParallelType::TIDx); - tv10->axis(-3)->parallelize(ParallelType::Unroll); - // [BIDx, Unswitch, Unroll, TIDx, Vectorize] - - // Propagate to group 2 and its cache. 
Note that group 2 and its cache are - // not connected, so we need to borrow other tensors of the DAG to be able - // to propagate. The transformations on borrowed tensors will be overwritten - // in the next step. We can not borrow the reference tensor of group 1. - auto all_tvs_except_ref1 = ir_utils::allTvsExcept(&fusion, {tv9}); - auto all_tvs_except_ref1_set = std::unordered_set( - all_tvs_except_ref1.begin(), all_tvs_except_ref1.end()); - SetSelector selector(all_tvs_except_ref1_set); - MaxRootDomainInfoSpanningTree tree(tv10, &selector); - TransformPropagator tp(tv10); - tree.traverse(&tp); - scheduler_utils::parallelizeAllLike( - tv10, {tv2_cache, tv10, tv12}, {ParallelType::TIDx}); - scheduler_utils::parallelizeAllLike( - tv10, - {tv2_cache, tv10, tv12}, - {ParallelType::Vectorize, ParallelType::Unroll}); - } - - // Step 3, schedule group 1 - { - // group 1: tv0, tv1, *tv9, innermost dim K - // [BIDx, Unswitch, 32(N), 32(K)] - tv9->merge(2); - tv9->split(2, 4); - tv9->split(2, threads_per_block); - tv9->axis(-1)->parallelize(ParallelType::Vectorize); - tv9->axis(-2)->parallelize(ParallelType::TIDx); - tv9->axis(-3)->parallelize(ParallelType::Unroll); - // [BIDx, Unswitch, Unroll, TIDx, Vectorize] - - // Propagate to the entire DAG except for group 2 and its cached inputs - auto all_tvs_except2 = - ir_utils::allTvsExcept(&fusion, {tv2, tv2_cache, tv10, tv12}); - auto all_tvs_except2_set = std::unordered_set( - all_tvs_except2.begin(), all_tvs_except2.end()); - SetSelector selector(all_tvs_except2_set); - MaxRootDomainInfoSpanningTree tree(tv9, &selector); - TransformPropagator tp(tv9); - tree.traverse(&tp); - scheduler_utils::parallelizeAllLike( - tv9, all_tvs_except2, {ParallelType::TIDx}); - scheduler_utils::parallelizeAllLike( - tv9, - {tv0_cache, tv1_cache, tv9}, - {ParallelType::Vectorize, ParallelType::Unroll}); - } - - // inline - MaxRootDomainInfoSpanningTree entire_dag(tv9); - InlinePropagator inline_propagator(tv9, -1, ComputeAtMode::MostInlined); - entire_dag.traverse(&inline_propagator); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = at::randn({512, 1024, 256}, options); - at::Tensor input1 = at::randn({1024, 512, 256}, options); - at::Tensor input2 = at::randn({512, 256, 1024}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input0, input1, input2}); - auto outputs = fe.runFusion({input0, input1, input2}); - - auto t3 = input0.transpose(1, 2); - auto t4 = input1.transpose(0, 1); - auto t5 = input1.sigmoid(); - auto t6 = input2 + t3; - auto t7 = t5.transpose(0, 2); - auto t8 = t4 + input0; - auto t9 = t8.relu(); - auto t10 = t6.sin(); - auto t11 = t6.transpose(0, 1); - auto t12 = t7 + t11; - - testValidate( - &fusion, - outputs, - {input0, input1, input2}, - {t9, t10, t12}, - __LINE__, - __FILE__); -} - TEST_F(NVFuserTest, FusionCheckedSymbolicShape_CUDA) { const auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp new file mode 100644 index 000000000000..192865769f16 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp @@ -0,0 +1,715 @@ +#if defined(USE_CUDA) +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +using namespace torch::jit::fuser::cuda; + +TEST_F(NVFuserTest, FusionTranspose1_CUDA) { + Fusion fusion; + FusionGuard 
fg(&fusion); + + constexpr int M = 10; + constexpr int N = 20; + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = transpose(tv0); + fusion.addInput(tv0); + fusion.addOutput(tv1); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + at::Tensor aten_output = t0.t(); + + testValidate( + &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTranspose2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int M = 10; + constexpr int N = 20; + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = transpose(tv0); + fusion.addInput(tv0); + fusion.addOutput(tv1); + + tv1->merge(0); + tv1->split(0, 32); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + at::Tensor aten_output = t0.t(); + + testValidate( + &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTransposeWithSwizzle_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = transpose(tv0); + fusion.addOutput(tv1); + + // tv0: [I0, I1] + // tv1: [I1, I0] + + const int BS = 32; + + // CTA tiling by BS*BS + tv1->split(1, BS); + tv1->split(0, BS); + tv1->reorder({{1, 2}}); + // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)] + + // Create a smem buffer to cache each tile + auto tv0_cache = tv0->cacheAfter(); + tv0_cache->setMemoryType(MemoryType::Shared); + + tv0->computeAt(tv1, 2); + // tv0: [I0, I1] + // tv0_cache: [I1/BS, I0/BS, BS(I1), BS(I0)] + // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)] + + // Assign each thread block to a tile + tv1->axis(0)->parallelize(ParallelType::BIDy); + tv1->axis(1)->parallelize(ParallelType::BIDx); + + // Thread mapping for each tile. For both of the input and output + // tiles, map TIDx to the fastest-changing dimension to facilitate + // coalesced gmem accesses. + tv1->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(3)->parallelize(ParallelType::TIDx); + // Note that the fastest-changing axis is next to the inner-most + // axis since computeAt reorders the axes as the output tensor. 
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  const int bx = 100;
+  const int by = 200;
+  at::Tensor t0 = at::randn({bx, by}, options);
+  std::vector<IValue> aten_inputs = {t0};
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs);
+  auto cg_outputs = fe.runFusion(aten_inputs);
+
+  auto aten_output = t0.t();
+
+  testValidate(
+      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+}
+
+TEST_F(NVFuserTest, FusionTransposeWithSwizzle1DThreadBlock_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+  auto tv1 = transpose(tv0);
+  fusion.addOutput(tv1);
+
+  // tv0: [I0, I1]
+  // tv1: [I1, I0]
+
+  const int BS = 32;
+  const int BDIM = 256;
+
+  // CTA tiling by BS*BS
+  tv1->split(1, BS);
+  tv1->split(0, BS);
+  tv1->reorder({{1, 2}});
+  // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]
+
+  // Create a smem buffer to cache each tile
+  auto tv0_cache = tv0->cacheAfter();
+  tv0_cache->setMemoryType(MemoryType::Shared);
+
+  tv0->computeAt(tv1, 2);
+  // tv0: [I0, I1]
+  // tv0_cache: [I1/BS, I0/BS, BS(I1), BS(I0)]
+  // tv1: [I1/BS, I0/BS, BS(I1), BS(I0)]
+
+  // Transform the tile axes for 1D thread mapping
+  tv1->merge(-2, -1);
+  tv1->split(-1, BDIM);
+  // tv1: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
+
+  // Transform the cache similarly but apply swizzle to the 2D tile axes.
+  tv0_cache->reorder({{-2, -1}});
+  tv0_cache->swizzle(SwizzleType::Transpose, {2, 3});
+  tv0_cache->merge(-2, -1);
+  tv0_cache->split(-1, BDIM);
+  // tv0_cache: [I1/BS, I0/BS, BS*BS/BDIM, BDIM]
+
+  // Assign each thread block to a tile
+  tv1->axis(0)->parallelize(ParallelType::BIDy);
+  tv1->axis(1)->parallelize(ParallelType::BIDx);
+
+  // Thread mapping for each tile.
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv0_cache->axis(-1)->parallelize(ParallelType::TIDx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  const int bx = 100;
+  const int by = 200;
+  at::Tensor t0 = at::randn({bx, by}, options);
+  std::vector<IValue> aten_inputs = {t0};
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs);
+  auto cg_outputs = fe.runFusion(aten_inputs);
+
+  auto aten_output = t0.t();
+
+  testValidate(
+      &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
+}
+
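+// The tests below exercise the automatic transpose scheduler. Each one builds
+// a fusion, calls scheduleTranspose() with real inputs (which transforms the
+// fusion in place and returns launch parameters), and then passes the
+// returned LaunchParams to both compileFusion() and runFusion().
+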
+// x->sin->transpose->cos->y
+TEST_F(NVFuserTest, FusionScheduleTransposeSimple_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  fusion.addInput(tv0);
+  auto tv1 = sin(tv0);
+  auto tv2 = transpose(tv1, 1, 2);
+  auto tv3 = cos(tv2);
+  fusion.addOutput(tv3);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn({256, 1024, 1024}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input}, lparams);
+  auto outputs = fe.runFusion({input}, lparams);
+
+  auto tv_ref = input.sin().transpose(1, 2).cos();
+
+  testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
+}
+
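+// The output of the fusion above is fastest-varying along what is the middle
+// dimension of the input, so no single thread mapping can give coalesced
+// global reads and writes at the same time; staging tiles through shared
+// memory, as in the manual tests above, is roughly what a good schedule for
+// this pattern has to look like.
+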
+// x->transpose->sin->transpose->cos->y
+TEST_F(NVFuserTest, FusionScheduleTransposeSinTransposeCos_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  fusion.addInput(tv0);
+  auto tv1 = transpose(tv0, 0, 2);
+  auto tv2 = sin(tv1);
+  auto tv3 = transpose(tv2, 1, 2);
+  auto tv4 = cos(tv3);
+  fusion.addOutput(tv4);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn({256, 1024, 1024}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input}, lparams);
+  auto outputs = fe.runFusion({input}, lparams);
+
+  auto tv_ref = input.transpose(0, 2).sin().transpose(1, 2).cos();
+
+  testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
+}
+
+// t0->transpose--.
+//                 \
+// t1->transpose---add-->sin->t5
+TEST_F(NVFuserTest, FusionScheduleTransposeMultipleInput_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  auto tv1 = makeContigTensor(3);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  auto tv2 = transpose(tv0, 0, 2);
+  auto tv3 = transpose(tv1, 0, 2);
+  auto tv4 = add(tv2, tv3);
+  auto tv5 = sin(tv4);
+  fusion.addOutput(tv5);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input0 = at::randn({256, 1024, 1024}, options);
+  at::Tensor input1 = at::randn({256, 1024, 1024}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input0, input1});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input0, input1}, lparams);
+  auto outputs = fe.runFusion({input0, input1}, lparams);
+
+  auto tv_ref = (input0.transpose(0, 2) + input1.transpose(0, 2)).sin();
+
+  testValidate(
+      &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
+}
+
+// t0->sin->transpose->t5
+//   `->cos->transpose->t6
+TEST_F(NVFuserTest, FusionScheduleTransposeMultipleOutput_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  fusion.addInput(tv0);
+  auto tv2 = sin(tv0);
+  auto tv3 = cos(tv0);
+  auto tv5 = transpose(tv2, 0, 2);
+  auto tv6 = transpose(tv3, 0, 2);
+  fusion.addOutput(tv5);
+  fusion.addOutput(tv6);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn({256, 1024, 1024}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input}, lparams);
+  auto outputs = fe.runFusion({input}, lparams);
+
+  auto tv_ref1 = input.sin().transpose(0, 2);
+  auto tv_ref2 = input.cos().transpose(0, 2);
+
+  testValidate(
+      &fusion, outputs, {input}, {tv_ref1, tv_ref2}, __LINE__, __FILE__);
+}
+
+// t0->transpose->sin->t3
+//  \_.-->cos->t5
+//    /
+//  t1
+TEST_F(NVFuserTest, FusionScheduleTransposeMultipleInputOutput_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  auto tv1 = makeContigTensor(3);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  auto tv2 = transpose(tv0, 0, 2);
+  auto tv3 = sin(tv2);
+  fusion.addOutput(tv3);
+  auto tv4 = add(tv0, tv1);
+  auto tv5 = cos(tv4);
+  fusion.addOutput(tv5);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input0 = at::randn({256, 1024, 1024}, options);
+  at::Tensor input1 = at::randn({256, 1024, 1024}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input0, input1});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input0, input1}, lparams);
+  auto outputs = fe.runFusion({input0, input1}, lparams);
+
+  auto tv_ref1 = input0.transpose(0, 2).sin();
+  auto tv_ref2 = (input0 + input1).cos();
+
+  testValidate(
+      &fusion,
+      outputs,
+      {input0, input1},
+      {tv_ref1, tv_ref2},
+      __LINE__,
+      __FILE__);
+}
+
+//             .------>sin------>z
+// x->transpose->transpose->add->y
+//  \_______________________/
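+// The two transposes cancel out (transpose(transpose(x, 0, 2), 0, 2) == x),
+// so tv2 maps back onto the input layout while tv4 keeps the transposed one;
+// the scheduler has to handle both kinds of output in a single fusion.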
+TEST_F(NVFuserTest, FusionScheduleTransposeMatchingSkipConnection_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  fusion.addInput(tv0);
+  auto tv1 = transpose(tv0, 0, 2);
+  auto tv2 = transpose(tv1, 0, 2);
+  auto tv3 = add(tv0, tv2);
+  fusion.addOutput(tv3);
+  auto tv4 = sin(tv1);
+  fusion.addOutput(tv4);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn({256, 1024, 1024}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input}, lparams);
+  auto outputs = fe.runFusion({input}, lparams);
+
+  auto tv_ref1 = input.transpose(0, 2).transpose(0, 2) + input;
+  auto tv_ref2 = input.transpose(0, 2).sin();
+
+  testValidate(
+      &fusion, outputs, {input}, {tv_ref1, tv_ref2}, __LINE__, __FILE__);
+}
+
+// x->transpose--add->z
+// y->broadcast-/
+TEST_F(NVFuserTest, FusionScheduleTransposeBroadcast_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  auto tv1 = makeContigTensor(2);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  auto tv2 = transpose(tv0, 1, 2);
+  auto tv3 = broadcast(tv1, {false, false, true});
+  auto tv4 = add(tv2, tv3);
+  fusion.addOutput(tv4);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input0 = at::randn({1024, 256, 1024}, options);
+  at::Tensor input1 = at::randn({1024, 1024}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input0, input1});
+  // auto lparams = schedulePointwise(&fusion, {input0, input1});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input0, input1}, lparams);
+  auto outputs = fe.runFusion({input0, input1}, lparams);
+
+  auto tv_ref = input0.transpose(1, 2) + input1.unsqueeze(2);
+
+  testValidate(
+      &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
+}
+
+// x->broadcast--add->z
+// y->broadcast-/
+TEST_F(NVFuserTest, FusionScheduleTransposeNoReference_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(2);
+  auto tv1 = makeContigTensor(2);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  auto tv2 = broadcast(tv0, {false, true, false});
+  auto tv3 = broadcast(tv1, {false, false, true});
+  auto tv4 = add(tv2, tv3);
+  fusion.addOutput(tv4);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input0 = at::randn({1024, 256}, options);
+  at::Tensor input1 = at::randn({1024, 1024}, options);
+
+  EXPECT_THAT(
+      [&]() {
+        scheduleTranspose(&fusion, {input0, input1});
+      },
+      testing::ThrowsMessage<c10::Error>(
+          testing::HasSubstr("reference tensor")));
+}
+
+// x->broadcast--add->z
+// y->broadcast-/
+TEST_F(NVFuserTest, FusionScheduleBroadcastOnly_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeConcreteTensor({1024, 1, 256});
+  auto tv1 = makeConcreteTensor({1024, 1024, 1});
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  auto tv2 = add(tv0, tv1);
+  fusion.addOutput(tv2);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input0 = at::randn({1024, 1, 256}, options);
+  at::Tensor input1 = at::randn({1024, 1024, 1}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input0, input1});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input0, input1}, lparams);
+  auto outputs = fe.runFusion({input0, input1}, lparams);
+
+  auto tv_ref = input0 + input1;
+
+  testValidate(
+      &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
+}
+
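+// A larger DAG with three inputs and three outputs whose outputs disagree on
+// which logical dimension is innermost: tv9 is fastest along K while tv10 and
+// tv12 are fastest along N (see the group comments in the manually scheduled
+// version of the same DAG below).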
+TEST_F(NVFuserTest, FusionScheduleTransposeComplexDAG1_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  auto tv1 = makeContigTensor(3);
+  auto tv2 = makeContigTensor(3);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  fusion.addInput(tv2);
+  auto tv3 = transpose(tv0, 1, 2);
+  auto tv4 = transpose(tv1, 0, 1);
+  auto tv5 = sigmoid(tv1);
+  auto tv6 = add(tv2, tv3);
+  auto tv7 = transpose(tv5, 0, 2);
+  auto tv8 = add(tv4, tv0);
+  auto tv9 = relu(tv8);
+  fusion.addOutput(tv9);
+  auto tv10 = sin(tv6);
+  fusion.addOutput(tv10);
+  auto tv11 = transpose(tv6, 0, 1);
+  auto tv12 = add(tv7, tv11);
+  fusion.addOutput(tv12);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input0 = at::randn({512, 1024, 256}, options);
+  at::Tensor input1 = at::randn({1024, 512, 256}, options);
+  at::Tensor input2 = at::randn({512, 256, 1024}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input0, input1, input2});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input0, input1, input2}, lparams);
+  auto outputs = fe.runFusion({input0, input1, input2}, lparams);
+
+  auto t3 = input0.transpose(1, 2);
+  auto t4 = input1.transpose(0, 1);
+  auto t5 = input1.sigmoid();
+  auto t6 = input2 + t3;
+  auto t7 = t5.transpose(0, 2);
+  auto t8 = t4 + input0;
+  auto t9 = t8.relu();
+  auto t10 = t6.sin();
+  auto t11 = t6.transpose(0, 1);
+  auto t12 = t7 + t11;
+
+  testValidate(
+      &fusion,
+      outputs,
+      {input0, input1, input2},
+      {t9, t10, t12},
+      __LINE__,
+      __FILE__);
+}
+
+TEST_F(NVFuserTest, FusionManualScheduleTransposeComplexDAG1_CUDA) {
+  // achieved: 833.526 GB/s on RTX 3090 (theoretical bandwidth: 936 GB/s)
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  auto tv1 = makeContigTensor(3);
+  auto tv2 = makeContigTensor(3);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  fusion.addInput(tv2);
+  auto tv3 = transpose(tv0, 1, 2);
+  auto tv4 = transpose(tv1, 0, 1);
+  auto tv5 = sigmoid(tv1);
+  auto tv6 = add(tv2, tv3);
+  auto tv7 = transpose(tv5, 0, 2);
+  auto tv8 = add(tv4, tv0);
+  auto tv9 = relu(tv8);
+  fusion.addOutput(tv9);
+  auto tv10 = sin(tv6);
+  fusion.addOutput(tv10);
+  auto tv11 = transpose(tv6, 0, 1);
+  auto tv12 = add(tv7, tv11);
+  fusion.addOutput(tv12);
+
+  // group 1: tv0, tv1, *tv9, innermost dim K
+  // group 2: tv2, *tv10, tv12, innermost dim N
+
+  // cache inputs and outputs
+  auto tv0_cache = tv0->cacheAfter();
+  auto tv1_cache = tv1->cacheAfter();
+  auto tv2_cache = tv2->cacheAfter();
+  auto tv9_cache = tv9->cacheBefore();
+  auto tv10_cache = tv10->cacheBefore();
+  auto tv12_cache = tv12->cacheBefore();
+
+  // Step 1: Make 32x32 tiles, schedule outer dimensions
+  {
+    // Pick an arbitrary tensor as a reference tensor for this step. There is no
+    // requirement on which group this reference tensor should belong to. Here
+    // we pick tv9, which belongs to group 1.
+
+    // Make 32x32 tile:
+    // [M, N, K]
+    tv9->split(1, 32);
+    tv9->reorder({{2, -1}});
+    tv9->split(2, 32);
+    tv9->reorder({{3, -1}});
+    // [M, N/32, K/32, 32(N), 32(K)]
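+    // Step by step, tv9's loop domain evolves as:
+    //   split(1, 32):       [M, N/32, 32(N), K]
+    //   reorder({{2, -1}}): [M, N/32, K, 32(N)]
+    //   split(2, 32):       [M, N/32, K/32, 32(K), 32(N)]
+    //   reorder({{3, -1}}): [M, N/32, K/32, 32(N), 32(K)]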
+
+    // merge outer dims, parallelize on BIDx, and unswitch
+    tv9->merge(0);
+    tv9->merge(0);
+    tv9->split(0, 1);
+    // [M * N/32 * K/32, 1, 32(N), 32(K)]
+    tv9->axis(0)->parallelize(ParallelType::BIDx);
+    tv9->axis(1)->parallelize(ParallelType::Unswitch);
+    // [BIDx, Unswitch, 32(N), 32(K)]
+
+    // propagate to the entire DAG
+    MaxRootDomainInfoSpanningTree entire_dag(tv9);
+    TransformPropagator tp(tv9);
+    entire_dag.traverse(&tp);
+    scheduler_utils::parallelizeAllLike(tv9);
+  }
+
+  constexpr int threads_per_block = 128;
+
+  // Step 2, schedule group 2
+  {
+    // group 2: tv2, *tv10, tv12, innermost dim N
+
+    tv2_cache->setMemoryType(MemoryType::Shared);
+    tv10_cache->setMemoryType(MemoryType::Shared);
+    tv12_cache->setMemoryType(MemoryType::Shared);
+
+    // pick tv10 as reference tensor for group 2
+    // [BIDx, Unswitch, 32(N), 32(K)]
+    tv10->reorder({{-1, -2}});
+    // [BIDx, Unswitch, 32(K), 32(N)]
+    tv10->merge(2);
+    tv10->split(2, 4);
+    tv10->split(2, threads_per_block);
+    tv10->axis(-1)->parallelize(ParallelType::Vectorize);
+    tv10->axis(-2)->parallelize(ParallelType::TIDx);
+    tv10->axis(-3)->parallelize(ParallelType::Unroll);
+    // [BIDx, Unswitch, Unroll, TIDx, Vectorize]
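+    // The 32x32 tile holds 1024 elements; with a vectorization width of 4 and
+    // threads_per_block = 128 that is 1024 / (4 * 128) = 2 iterations on the
+    // Unroll axis created by the splits above.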
+
+    // Propagate to group 2 and its cache. Note that group 2 and its cache are
+    // not connected, so we need to borrow other tensors of the DAG to be able
+    // to propagate. The transformations on borrowed tensors will be overwritten
+    // in the next step. We can not borrow the reference tensor of group 1.
+    auto all_tvs_except_ref1 = ir_utils::allTvsExcept(&fusion, {tv9});
+    auto all_tvs_except_ref1_set = std::unordered_set<TensorView*>(
+        all_tvs_except_ref1.begin(), all_tvs_except_ref1.end());
+    SetSelector selector(all_tvs_except_ref1_set);
+    MaxRootDomainInfoSpanningTree tree(tv10, &selector);
+    TransformPropagator tp(tv10);
+    tree.traverse(&tp);
+    scheduler_utils::parallelizeAllLike(
+        tv10, {tv2_cache, tv10, tv12}, {ParallelType::TIDx});
+    scheduler_utils::parallelizeAllLike(
+        tv10,
+        {tv2_cache, tv10, tv12},
+        {ParallelType::Vectorize, ParallelType::Unroll});
+  }
+
+  // Step 3, schedule group 1
+  {
+    // group 1: tv0, tv1, *tv9, innermost dim K
+    // [BIDx, Unswitch, 32(N), 32(K)]
+    tv9->merge(2);
+    tv9->split(2, 4);
+    tv9->split(2, threads_per_block);
+    tv9->axis(-1)->parallelize(ParallelType::Vectorize);
+    tv9->axis(-2)->parallelize(ParallelType::TIDx);
+    tv9->axis(-3)->parallelize(ParallelType::Unroll);
+    // [BIDx, Unswitch, Unroll, TIDx, Vectorize]
+
+    // Propagate to the entire DAG except for group 2 and its cached inputs
+    auto all_tvs_except2 =
+        ir_utils::allTvsExcept(&fusion, {tv2, tv2_cache, tv10, tv12});
+    auto all_tvs_except2_set = std::unordered_set<TensorView*>(
+        all_tvs_except2.begin(), all_tvs_except2.end());
+    SetSelector selector(all_tvs_except2_set);
+    MaxRootDomainInfoSpanningTree tree(tv9, &selector);
+    TransformPropagator tp(tv9);
+    tree.traverse(&tp);
+    scheduler_utils::parallelizeAllLike(
+        tv9, all_tvs_except2, {ParallelType::TIDx});
+    scheduler_utils::parallelizeAllLike(
+        tv9,
+        {tv0_cache, tv1_cache, tv9},
+        {ParallelType::Vectorize, ParallelType::Unroll});
+  }
+
+  // inline
+  MaxRootDomainInfoSpanningTree entire_dag(tv9);
+  InlinePropagator inline_propagator(tv9, -1, ComputeAtMode::MostInlined);
+  entire_dag.traverse(&inline_propagator);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input0 = at::randn({512, 1024, 256}, options);
+  at::Tensor input1 = at::randn({1024, 512, 256}, options);
+  at::Tensor input2 = at::randn({512, 256, 1024}, options);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input0, input1, input2});
+  auto outputs = fe.runFusion({input0, input1, input2});
+
+  auto t3 = input0.transpose(1, 2);
+  auto t4 = input1.transpose(0, 1);
+  auto t5 = input1.sigmoid();
+  auto t6 = input2 + t3;
+  auto t7 = t5.transpose(0, 2);
+  auto t8 = t4 + input0;
+  auto t9 = t8.relu();
+  auto t10 = t6.sin();
+  auto t11 = t6.transpose(0, 1);
+  auto t12 = t7 + t11;
+
+  testValidate(
+      &fusion,
+      outputs,
+      {input0, input1, input2},
+      {t9, t10, t12},
+      __LINE__,
+      __FILE__);
+}
+
+} // namespace jit
+} // namespace torch
+#endif // #if defined(USE_CUDA)