diff --git a/torch/csrc/jit/codegen/cuda/inlining.cpp b/torch/csrc/jit/codegen/cuda/inlining.cpp
index da6d229c68f8b5..eb2c4b3fb5db59 100644
--- a/torch/csrc/jit/codegen/cuda/inlining.cpp
+++ b/torch/csrc/jit/codegen/cuda/inlining.cpp
@@ -153,29 +153,25 @@ size_t MaxPosCalculator::getMaxPosAll(
   return max_pos;
 }
 
-void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
-  inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids);
+void inlineMost() {
+  inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()));
 }
 
-void inlineMost(
-    const std::vector<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+void inlineMost(const std::vector<TensorView*>& tvs) {
   if (tvs.empty()) {
     return;
   }
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto tv : tvs) {
     tv->inlineAt(-1, true, &calc);
   }
 }
 
-void inlineMost(
-    const std::unordered_set<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+void inlineMost(const std::unordered_set<TensorView*>& tvs) {
   if (tvs.empty()) {
     return;
   }
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto tv : tvs) {
     tv->inlineAt(-1, true, &calc);
   }
@@ -276,10 +272,9 @@ std::unordered_map<TensorView*, int64_t> getPositionsMappedTo(
 void inlineAllAt(
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+    bool best_effort) {
   auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto pair : mapped_positions) {
     pair.first->inlineAt(pair.second, best_effort, &calc);
   }
@@ -289,10 +284,9 @@ void inlineSelectedAt(
     const std::unordered_set<TensorView*>& selected,
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort,
-    const std::unordered_set<IterDomain*>& uninlinable_ids) {
+    bool best_effort) {
   auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos);
-  MaxPosCalculator calc(uninlinable_ids);
+  MaxPosCalculator calc;
   for (auto pair : mapped_positions) {
     if (selected.count(pair.first) > 0) {
       pair.first->inlineAt(pair.second, best_effort, &calc);
diff --git a/torch/csrc/jit/codegen/cuda/inlining.h b/torch/csrc/jit/codegen/cuda/inlining.h
index 3b15eb23f98777..0a3ce1e8012a17 100644
--- a/torch/csrc/jit/codegen/cuda/inlining.h
+++ b/torch/csrc/jit/codegen/cuda/inlining.h
@@ -64,26 +64,20 @@ class MaxPosCalculator {
 
 // Inline to the right most allowed position for all tensors in the current
 // fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost();
 
 // Inline to the right most allowed position for the selected tensors in the
 // current fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::vector<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost(const std::vector<TensorView*>& tvs);
 
 // Inline to the right most allowed position for the selected tensors in the
 // current fusion.
-TORCH_CUDA_CU_API void inlineMost(
-    const std::unordered_set<TensorView*>& tvs,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+TORCH_CUDA_CU_API void inlineMost(const std::unordered_set<TensorView*>& tvs);
 
 // Inline to the position corresponding to the reference position in the
 // reference tensor for all tensors in the current fusion.
 TORCH_CUDA_CU_API void inlineAllAt(
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort = false,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+    bool best_effort = false);
 
 // Inline to the position corresponding to the reference position in the
 // reference tensor for selected tensors in the current fusion.
@@ -91,8 +85,7 @@ TORCH_CUDA_CU_API void inlineSelectedAt(
     const std::unordered_set<TensorView*>& selected,
     TensorView* reference_tv,
     int64_t reference_pos,
-    bool best_effort = false,
-    const std::unordered_set<IterDomain*>& uninlinable_ids = {});
+    bool best_effort = false);
 
 } // namespace cuda
 } // namespace fuser
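Note on the API change above: the removed `uninlinable_ids` parameter was a trailing defaulted argument, so existing `inlineMost()` call sites keep compiling; only callers that explicitly passed a set need updating. A minimal standalone sketch with stub types (illustration only, not nvfuser code) of why the removal is source-compatible:

#include <unordered_set>

struct IterDomain {};

// Old shape of the API (removed by this patch):
//   void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids = {});
// New shape of the API:
void inlineMost() {}

int main() {
  // Callers that relied on the default argument are unaffected; the one
  // caller that passed a set explicitly (multiReductionInliner, further
  // down in this diff) is updated accordingly.
  inlineMost();
  return 0;
}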
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 3319bf28a18a9d..410f008d59cb20 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -1597,6 +1597,43 @@ std::vector<IterDomain*> IterDomain::clone(
   return cloned_domains;
 }
 
+IterType inferIterType(IterDomain* i1, IterDomain* i2) {
+  // The itertype inference is a pattern matching of the rules below:
+  //
+  // X + X = X
+  // trivial reduction + X = X
+  // X + trivial reduction = X
+  // broadcasting + X = X
+  // X + broadcasting = X
+  // fail
+  //
+  // The rules are applied one by one in order. For each rule, we test
+  // whether the given (outer, inner) pair matches the pattern. If it does,
+  // we stop and return the result. If we reach the end without matching any
+  // pattern, it is a mistake and is reported as an error.
+  //
+  // Note that based on the above rules:
+  // broadcasting + (non-trivial) reduction = reduction
+  // broadcasting + trivial reduction = broadcasting
+  if (i1->getIterType() == i2->getIterType()) {
+    return i1->getIterType();
+  }
+  if (i1->isTrivialReduction()) {
+    return i2->getIterType();
+  }
+  if (i2->isTrivialReduction()) {
+    return i1->getIterType();
+  }
+  if (i1->isBroadcast()) {
+    return i2->getIterType();
+  }
+  if (i2->isBroadcast()) {
+    return i1->getIterType();
+  }
+  TORCH_CHECK(
+      false, "Merging IterDomains requires that their iteration types match.");
+}
+
 // Merging does not propagate the start and stop values of the input
 // domains to the merged output domain. The actual range of the
 // domains is enforced by predicates. Note that since only root
@@ -1606,48 +1643,10 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
   TORCH_CHECK(
       !outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
       "Merging IterDomains with ending values that are 0 is not supported at this time.");
-  TORCH_CHECK(
-      outer->isReduction() == inner->isReduction() ||
-          (!outer->isReduction() && inner->isTrivialReduction()) ||
-          (outer->isTrivialReduction() && !inner->isReduction()),
-      "Merging IterDomains requires that their iteration types match.");
-  TORCH_CHECK(
-      (outer->isGather() && inner->isGather()) ||
-          (!outer->isGather() && !inner->isGather()),
-      "Merging gather and non-gather domains is not supported.");
-
-  TORCH_CHECK(
-      !outer->isStride() && !inner->isStride(),
-      "No support for merging stride domains");
 
   Val* merged_id_size = mul(outer->extent(), inner->extent());
 
-  IterType itype = outer->getIterType();
-
-  if (outer->isBroadcast() && inner->isBroadcast()) {
-    itype = IterType::Broadcast;
-  }
-
-  if ((outer->isBroadcast() || inner->isBroadcast()) &&
-      (outer->getIterType() == IterType::Iteration ||
-       inner->getIterType() == IterType::Iteration)) {
-    itype = IterType::Iteration;
-  }
-
-  // Merging trivial reduction with iter domain, that's fine, just make it an
-  // iter domain.
-  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
-      (outer->getIterType() == IterType::Iteration ||
-       inner->getIterType() == IterType::Iteration)) {
-    itype = IterType::Iteration;
-  }
-
-  // Merging trivial reduction with broadcasting, that's fine, just make it a
-  // broadcasting.
-  if ((outer->isTrivialReduction() || inner->isTrivialReduction()) &&
-      (outer->isBroadcast() || inner->isBroadcast())) {
-    itype = IterType::Broadcast;
-  }
+  IterType itype = inferIterType(outer, inner);
 
   Val* expanded_extent = nullptr;
   if (outer->hasExpandedExtent() || inner->hasExpandedExtent()) {
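Since the rule order in `inferIterType` is load-bearing (the trivial-reduction rules must fire before the broadcast rules), a standalone sketch may help. The enum and helper below are illustrative stand-ins, not nvfuser types; in nvfuser a trivial reduction is a Reduction domain with extent 1 rather than a distinct itertype.

#include <cassert>
#include <stdexcept>

// Stand-in for nvfuser's IterType, with trivial reductions modeled as a
// separate tag purely for illustration.
enum class Kind { Iteration, Reduction, TrivialReduction, Broadcast };

Kind infer(Kind a, Kind b) {
  if (a == b) {
    return a; // X + X = X
  }
  if (a == Kind::TrivialReduction) {
    return b; // trivial reduction + X = X
  }
  if (b == Kind::TrivialReduction) {
    return a; // X + trivial reduction = X
  }
  if (a == Kind::Broadcast) {
    return b; // broadcasting + X = X
  }
  if (b == Kind::Broadcast) {
    return a; // X + broadcasting = X
  }
  throw std::logic_error("iteration types do not match");
}

int main() {
  // The two consequences called out in the comment block:
  assert(infer(Kind::Broadcast, Kind::Reduction) == Kind::Reduction);
  assert(infer(Kind::Broadcast, Kind::TrivialReduction) == Kind::Broadcast);
  // Order matters: the trivial-reduction rules are checked before the
  // broadcast rules, which is why the second case comes out as Broadcast.
  assert(infer(Kind::TrivialReduction, Kind::Iteration) == Kind::Iteration);
  return 0;
}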
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
index ae9ecd88bbdc3c..f88c34eb3f59a5 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
@@ -330,13 +330,8 @@ void multiReductionInliner(
     }
   }
 
-  // Find iter domains that are mapped to a trivial reduction, these should
-  // never be inlined.
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      scheduler_utils::getTrivialReductionMap(fusion);
-
   // Inline the schedule
-  inlineMost(mapped_to_trivial_reduction);
+  inlineMost();
 }
 
 namespace {
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
index d985da926354b0..036e9c920824a3 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
@@ -21,26 +21,20 @@ namespace scheduler_utils {
 
 // Returns number of "valid" dimensions. e.g. if tv has
 // [I1, R2, I3, I4, R3{1}]
-// where R3{1} is in dont_merge, resulting domain should be:
-// [I1, I3*I4, R2, R3{1}] with return value 3
+// resulting domain should be:
+// [I1, I3*I4, R2*R3{1}] with return value 3
 //
 // if tv has
 // [R1, I2, R3, I4, R4, R5{1}, R6{1}]
-// where R5{1} and R6{1} are in dont_merge, resulting domain should be:
-// [I2*I4, R1*R3, R4, R5{1}, R6{1}]
+// resulting domain should be:
+// [I2*I4, R1*R3, R4*R5{1}*R6{1}]
// with return value 3
-size_t merge_3d(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t merge_3d(TensorView* tv) {
   bool active_is_reduction = false;
   bool first_dim = true;
   int prev_i = -1;
 
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = tv->axis(i)->isReduction();
       prev_i = i;
@@ -67,10 +61,6 @@ size_t merge_3d(
   for (int i = static_cast<int>(tv->nDims()) - 2; i >= 0; i--) {
     auto id = tv->axis(i);
 
-    if (dont_merge.count(id)) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = id->isReduction();
       prev_i = i;
@@ -96,10 +86,6 @@
   prev_i = -1;
 
   for (int i = static_cast<int>(tv->nDims()) - 3; i >= 0; i--) {
-    if (dont_merge.count(tv->axis(i))) {
-      continue;
-    }
-
     if (first_dim) {
       active_is_reduction = tv->axis(i)->isReduction();
       prev_i = i;
@@ -114,7 +100,7 @@ size_t merge_3d(
   if (prev_i == -1) {
     // Two dimensional, put merged dimensions first
     tv->reorder({{-1, 0}, {-2, 1}});
-    // [outer, inner, dont_merge...]
+    // [outer, inner]
     if (tv->axis(0)->isReduction()) {
       // put reductions as second axis
       tv->reorder({{0, 1}, {1, 0}});
@@ -195,13 +181,11 @@ c10::optional<size_t> mergeDims(
   return left;
 }
 
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t mergeReduction(TensorView* tv) {
   int prev_i = -1;
   size_t num_merged = 0;
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (!tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
+    if (!tv->axis(i)->isReduction()) {
       continue;
     }
     if (prev_i == -1) {
@@ -219,16 +203,14 @@ size_t mergeReduction(
   return prev_i == -1 ? 0 : num_merged + 1;
 }
 
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge) {
+size_t mergeNonReduction(TensorView* tv) {
   int prev_i = -1;
   size_t num_merged = 0;
   if (tv->nDims() == 0) {
     return 0;
   }
   for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
-    if (tv->axis(i)->isReduction() || dont_merge.count(tv->axis(i))) {
+    if (tv->axis(i)->isReduction()) {
       continue;
     }
     if (prev_i == -1) {
@@ -905,63 +887,21 @@ PersistentBufferSizeReturn persistentBufferSize(
   return persistent_buffer_size;
 }
 
-std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion) {
-  auto all_tvs = ir_utils::allTvs(fusion);
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
-  for (auto tv : all_tvs) {
-    // root domain vs domain shouldn't matter as at this point we shouldn't have
-    // any transformations.
-    for (auto id : tv->getRootDomain()) {
-      if (id->isTrivialReduction()) {
-        mapped_to_trivial_reduction.emplace(id);
-      }
-    }
-  }
-
-  if (!mapped_to_trivial_reduction.empty()) {
-    // Use the loop map as that is the most permissive
-    auto ca_map = ComputeAtMap(fusion);
-    // Make a copy we need to check mappings of all
-    auto trivial_ids = mapped_to_trivial_reduction;
-    for (auto tv : all_tvs) {
-      for (auto id : tv->getRootDomain()) {
-        if (!id->extent()->isOneInt()) {
-          continue;
-        }
-        if (std::any_of(
-                trivial_ids.begin(),
-                trivial_ids.end(),
-                [&ca_map, &id](IterDomain* trivial_id) {
-                  return ca_map.areMapped(
-                      id, trivial_id, IdMappingMode::PERMISSIVE);
-                })) {
-          mapped_to_trivial_reduction.emplace(id);
-        }
-      }
-    }
-  }
-  return mapped_to_trivial_reduction;
-}
-
 std::pair<bool, bool> canonicalDimReduction(
     Fusion* fusion,
     TensorView* tv,
     bool schedule_3D) {
-  std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
-      getTrivialReductionMap(fusion);
-
   TORCH_INTERNAL_ASSERT(tv != nullptr);
 
   if (!schedule_3D) {
     // We coalesce all reduction axes to the right;
-    bool has_red_axis = mergeReduction(tv, mapped_to_trivial_reduction) > 0;
+    bool has_red_axis = mergeReduction(tv) > 0;
 
-    bool has_iter_axis = mergeNonReduction(tv, mapped_to_trivial_reduction) > 0;
+    bool has_iter_axis = mergeNonReduction(tv) > 0;
     return {has_iter_axis, has_red_axis};
   } else {
     TORCH_INTERNAL_ASSERT(
-        merge_3d(tv, mapped_to_trivial_reduction) == 3,
-        "Tried 3D merge, but result is not 3D.");
+        merge_3d(tv) == 3, "Tried 3D merge, but result is not 3D.");
     return {true, true};
   }
 }
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.h b/torch/csrc/jit/codegen/cuda/scheduler/utils.h
index 373a879f740d50..b5dbe162f0e938 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/utils.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.h
@@ -78,16 +78,12 @@ TORCH_CUDA_CU_API inline c10::optional<size_t> mergeDims(
 }
 
 // Merge all reduction to the right side and returns total number of
-// reduction axes. Don't merge is typically used for trivial reductions.
-size_t mergeReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
+// reduction axes.
+size_t mergeReduction(TensorView* tv);
 
 // merge all non-reduction axes to the left side and returns total number of
-// iteration axes. Don't merge is typically used for trivial reductions.
-size_t mergeNonReduction(
-    TensorView* tv,
-    const std::unordered_set<IterDomain*>& dont_merge = {});
+// iteration axes.
+size_t mergeNonReduction(TensorView* tv);
 
 // Propagate the parallelization from the selected dimensions of the reference
 // tensor to their corresponding dimensions in all selected tensors in the DAG.
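For intuition about what the now-unconditional merging does, here is a standalone simulation of the 2D canonicalization performed by `mergeReduction` plus `mergeNonReduction`. It is an effect-level sketch over toy axis tags, not the real `IterDomain` transforms: every reduction axis, trivial or not, folds into a single reduction dimension.

#include <cstdint>
#include <iostream>
#include <vector>

// Toy axis model: reduction flag plus extent. A "trivial" reduction is just
// a reduction axis of extent 1; after this patch it needs no special casing.
struct Axis {
  bool is_reduction;
  int64_t extent;
};

// Mimics the effect of mergeReduction followed by mergeNonReduction: all
// reduction axes coalesce into one on the right, all iteration axes into
// one on the left.
std::vector<Axis> canonicalize2D(const std::vector<Axis>& axes) {
  Axis iter{false, 1}, red{true, 1};
  bool has_iter = false, has_red = false;
  for (const auto& a : axes) {
    (a.is_reduction ? red : iter).extent *= a.extent;
    (a.is_reduction ? has_red : has_iter) = true;
  }
  std::vector<Axis> out;
  if (has_iter) out.push_back(iter);
  if (has_red) out.push_back(red);
  return out;
}

int main() {
  // [I(8), R(4), I(16), R(1), R(1)] -> [I(128), R(4)]; the trailing R(1)
  // axes are trivial reductions and simply disappear into the product.
  auto out =
      canonicalize2D({{false, 8}, {true, 4}, {false, 16}, {true, 1}, {true, 1}});
  for (const auto& a : out) {
    std::cout << (a.is_reduction ? "R" : "I") << a.extent << " ";
  }
  std::cout << "\n"; // prints: I128 R4
  return 0;
}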
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index ee5e55bd592e17..8711154c9e7324 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -7853,6 +7853,74 @@ TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) {
       lparams);
 }
 
+// This test checks that the system correctly handles the case where both
+// normal and trivial reductions exist in the fusion. Trivial reductions
+// deserve testing because they are handled more like broadcasts than like
+// reductions.
+TEST_F(NVFuserTest, FusionReductionWithTrivial_CUDA) {
+  constexpr int bid_x = 80;
+  constexpr int tid_x = 4096;
+
+  std::vector<std::vector<int64_t>> shapes = {
+      {-1, -1, 1}, {-1, 1, -1}, {1, -1, -1}};
+
+  for (auto shape : shapes) {
+    std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+    Fusion& fusion = *fusion_ptr;
+    FusionGuard fg(&fusion);
+
+    std::vector<std::vector<int64_t>> reduction_dims = {
+        {0},
+        {1},
+        {2},
+        {0, 1},
+        {0, 2},
+        {1, 2},
+        {0, 1, 2},
+    };
+
+    // Set up your input tensor views
+    TensorView* tv0 = makeConcreteTensor(shape);
+    fusion.addInput(tv0);
+
+    for (auto rdims : reduction_dims) {
+      std::vector<int> rdims_(rdims.begin(), rdims.end());
+      auto tv = sum(tv0, rdims_);
+      fusion.addOutput(tv);
+    }
+
+    const auto options =
+        at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+    auto concrete_shape = shape;
+    std::deque<int64_t> concrete_values = {bid_x, tid_x};
+    for (auto& s : concrete_shape) {
+      if (s == -1) {
+        s = concrete_values.front();
+        concrete_values.pop_front();
+      }
+    }
+
+    at::Tensor aten_input = at::randn(concrete_shape, options);
+    std::vector<at::Tensor> aten_outputs;
+    for (auto rdims : reduction_dims) {
+      aten_outputs.push_back(aten_input.sum(rdims));
+    }
+
+    FusionExecutorCache executor_cache(std::move(fusion_ptr));
+    auto cg_outputs = executor_cache.runFusionWithInputs({aten_input});
+
+    testValidate(
+        &fusion,
+        cg_outputs,
+        {aten_input},
+        aten_outputs,
+        __LINE__,
+        __FILE__,
+        "");
+  }
+}
+
 // Simple reduction parallelized on a symbolic size.
 TEST_F(NVFuserTest, FusionSymbolicReduction_CUDA) {
   Fusion fusion;
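As background for the test above, a small libtorch snippet (independent of the patch, using only standard ATen calls) showing why a trivial reduction is broadcast-like: summing over a size-1 dimension combines nothing and merely drops the dimension.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Shape {80, 4096, 1}: the {-1, -1, 1} case from the test, made concrete.
  auto t = torch::randn({80, 4096, 1});
  // A trivial reduction over the size-1 axis combines no values, so it is
  // equivalent to simply removing that axis.
  std::cout << std::boolalpha
            << torch::allclose(t.sum(2), t.squeeze(2)) << "\n"; // true
  return 0;
}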