From 7587976e34d72411849107554f439d0a2a292c83 Mon Sep 17 00:00:00 2001
From: Christian Sarofeen
Date: Fri, 16 Sep 2022 08:26:12 -0400
Subject: [PATCH] Remove non-const functions, remove GpuLower instance on
 build, pass in ca_map.

---
 .../csrc/jit/codegen/cuda/compute_at_map.cpp  |  2 +-
 torch/csrc/jit/codegen/cuda/contiguity.cpp    |  2 +-
 torch/csrc/jit/codegen/cuda/index_compute.cpp | 23 ++++----
 torch/csrc/jit/codegen/cuda/lower2device.cpp  |  4 +-
 torch/csrc/jit/codegen/cuda/lower2device.h    | 16 +++---
 .../jit/codegen/cuda/lower_allocation.cpp     |  8 +--
 .../jit/codegen/cuda/lower_index_compute.cpp  |  9 ++--
 torch/csrc/jit/codegen/cuda/lower_loops.cpp   |  2 +-
 .../cuda/lower_predicate_elimination.cpp      |  6 +--
 torch/csrc/jit/codegen/cuda/lower_shift.cpp   | 54 +++++++++----------
 torch/csrc/jit/codegen/cuda/lower_shift.h     | 33 ++++++------
 .../codegen/cuda/lower_sync_information.cpp   | 12 ++---
 torch/csrc/jit/codegen/cuda/lower_unroll.cpp  |  2 +-
 torch/csrc/jit/codegen/cuda/lower_utils.cpp   |  2 +-
 torch/csrc/jit/codegen/cuda/lower_utils.h     |  4 +-
 .../jit/codegen/cuda/lower_validation.cpp     |  2 +-
 16 files changed, 90 insertions(+), 91 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
index b1d4819fe5a8b..9c4aca73eb3bf 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -665,7 +665,7 @@ void ComputeAtMap::allocateIndexVariables() {
         // Halo extended parallel loops currently are handled
         // differently and an index variable would still
        // be allocated in this case.
-        (GpuLower::current()->haloInfo().getExtent(id) == nullptr)) {
+        (GpuLower::current()->haloInfo()->getExtent(id) == nullptr)) {
       ptype = id->getParallelType();
       return true;
     }
diff --git a/torch/csrc/jit/codegen/cuda/contiguity.cpp b/torch/csrc/jit/codegen/cuda/contiguity.cpp
index 4817693bebdc3..2b0fbfb37d0c8 100644
--- a/torch/csrc/jit/codegen/cuda/contiguity.cpp
+++ b/torch/csrc/jit/codegen/cuda/contiguity.cpp
@@ -51,7 +51,7 @@ ContigIDs::ContigIDs(
         (ignore_halo_constraint ||
          !GpuLower::current()
               ->haloInfo()
-              .getRootAxisInfo(root_domain_i)
+              ->getRootAxisInfo(root_domain_i)
               .hasHalo())) {
       contig_ids_.emplace(root_domain_i);
       is_contig_root_[root_domain_i] = true;
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp
index ed428dfd6319b..0d9a936dbd438 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -51,8 +51,8 @@ int getProducerHaloOffset(
   IterDomain* consumer_id = it->second;

   const auto& halo_map = GpuLower::current()->haloInfo();
-  const auto p_pad = halo_map.getRootAxisInfo(producer_id).width(0);
-  const auto c_pad = halo_map.getRootAxisInfo(consumer_id).width(0);
+  const auto p_pad = halo_map->getRootAxisInfo(producer_id).width(0);
+  const auto c_pad = halo_map->getRootAxisInfo(consumer_id).width(0);

   auto offset = p_pad - c_pad;

@@ -985,7 +985,7 @@ Val* getHaloExtentOfRootAxis(IterDomain* id, Val* normal_extent = nullptr) {
     normal_extent = id->extent();
   }

-  const auto& halo = GpuLower::current()->haloInfo().getRootAxisInfo(id);
+  const auto& halo = GpuLower::current()->haloInfo()->getRootAxisInfo(id);
   if (halo.hasHalo()) {
     auto halo_extent = SimplifyingIrBuilder::addExpr(
         normal_extent, SimplifyingIrBuilder::create<Int>(halo.width()));
@@ -2351,7 +2351,7 @@ std::vector<PredicateDomainInfo> getPredicateContigIds(
   std::unordered_set<IterDomain*> excluded_ids;

   for (auto consumer_root_id : consumer_root_domain) {
-    if (gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id).hasHalo()) {
+    if (gpu_lower->haloInfo()->getRootAxisInfo(consumer_root_id).hasHalo()) {
       excluded_ids.insert(consumer_root_id);
       continue;
     }
@@ -2487,7 +2487,7 @@ int getUnswitchStopOffset(
   const auto gpu_lower = GpuLower::current();

   AxisHaloInfo halo_info =
-      gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id);
+      gpu_lower->haloInfo()->getRootAxisInfo(consumer_root_id);

   // If the consumer root domain to predicate does not have halo, no
   // adjustment is required.
@@ -2511,7 +2511,7 @@ int getUnswitchStopOffset(
           unswitch_it,
           consumer_tv->domain()->domain().end(),
           [&gpu_lower, &consumer_root_id](auto leaf_id) {
-            return gpu_lower->haloInfo().isHaloInherited(
+            return gpu_lower->haloInfo()->isHaloInherited(
                 consumer_root_id, leaf_id);
           })) {
     return halo_info.width();
@@ -2669,7 +2669,8 @@ std::pair<Val*, Val*> getStartAndStopLimitOffsets(
   Val* stop_limit = SimplifyingIrBuilder::negExpr(consumer_id->stopOffset());

   if (!non_divisible_pred) {
-    AxisHaloInfo halo_info = gpu_lower->haloInfo().getRootAxisInfo(consumer_id);
+    AxisHaloInfo halo_info =
+        gpu_lower->haloInfo()->getRootAxisInfo(consumer_id);

     // Below, "left" and "right" halo mean halo at offset zero and
     // axis extent, respectively.
@@ -2693,8 +2694,8 @@ std::pair<Val*, Val*> getStartAndStopLimitOffsets(
     // that it is less than the extent of the predicated ID +
     // halo. Note that getRootAxisInfo doesn't work since consumer_id
     // isn't a root domain.
-    if (gpu_lower->haloInfo().hasHaloWidth(consumer_id)) {
-      auto halo = gpu_lower->haloInfo().getHaloWidth(consumer_id);
+    if (gpu_lower->haloInfo()->hasHaloWidth(consumer_id)) {
+      auto halo = gpu_lower->haloInfo()->getHaloWidth(consumer_id);
       stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo);
     }
   }
@@ -2841,8 +2842,8 @@ bool canOmitStopPredicate(
   // to be predicated, not its merged contig id even if it exists. So,
   // if contig_id does not have root axis info, contig_id is
   // guaranteed to have no halo.
-  auto halo_ext = gpu_lower->haloInfo().hasRootAxisInfo(contig_id)
-      ? gpu_lower->haloInfo().getRootAxisInfo(contig_id).width()
+  auto halo_ext = gpu_lower->haloInfo()->hasRootAxisInfo(contig_id)
+      ? gpu_lower->haloInfo()->getRootAxisInfo(contig_id).width()
       : 0;

   if (halo_ext + stop_offset_val.value() > 0) {
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index 53b9d172f203f..4b44ef9075b19 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -248,7 +248,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
   // mappings of all iteration domains across the fusion. There are three types
   // of mappings Permissive, Exact, and Loop, see compute_at_map.h/cpp for more
   // information.
-  compute_at_map_ = std::make_unique<ComputeAtMap>(fusion_);
+  compute_at_map_ = std::make_shared<ComputeAtMap>(fusion_);

   if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) {
     std::cout << compute_at_map_->toString() << std::endl;
@@ -281,7 +281,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {

   // Scan the whole fusion and build mappings about halo extensions of
   // all IterDomains
-  haloInfo().build(fusion_);
+  halo_info_ = std::make_shared<HaloInfo>(fusion_, compute_at_map_);

   // Want to run this after parallel map and halo info map are
   // created. vectorized_accesses_ and vectorized_set_info_ are filled.
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h
index d5600e0a25139..7420d19244400 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.h
+++ b/torch/csrc/jit/codegen/cuda/lower2device.h
@@ -76,20 +76,16 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
     return thread_pred_map_;
   }

-  const std::unique_ptr<ComputeAtMap>& caMap() const {
-    return compute_at_map_;
+  std::shared_ptr<const ComputeAtMap> caMap() const {
+    return std::const_pointer_cast<const ComputeAtMap>(compute_at_map_);
   }

   const TrivialReductionInfo& trivialReductionInfo() const {
     return trivial_reduction_info_;
   }

-  const HaloInfo& haloInfo() const {
-    return halo_info_;
-  }
-
-  HaloInfo& haloInfo() {
-    return halo_info_;
+  std::shared_ptr<const HaloInfo> haloInfo() const {
+    return std::const_pointer_cast<const HaloInfo>(halo_info_);
   }

   const ParallelDimensionMap& parallelDimensionMap() const {
@@ -201,9 +197,9 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
   ConcretizedBroadcastDomains concretized_broadcast_domains_;
   ThreadPredicateMap thread_pred_map_;
   PredicateElimination pred_elimination_;
-  std::unique_ptr<ComputeAtMap> compute_at_map_;
+  std::shared_ptr<ComputeAtMap> compute_at_map_;
   TrivialReductionInfo trivial_reduction_info_;
-  HaloInfo halo_info_;
+  std::shared_ptr<HaloInfo> halo_info_;
   LocalAllocationInfoMap local_allocation_info_map_;
   WarpPaddedParallelInfo warp_pad_info_;
   ParallelDimensionMap parallel_dimension_map_;
diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp
index 466dc85c8abff..5df08c33f2262 100644
--- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp
@@ -131,7 +131,7 @@ class AllocationInserter : public kir::ExprMutator {
          ++init_loop_it) {
       auto id = *init_loop_it;
       kir::ForLoop* new_loop = nullptr;
-      auto extent_with_halo = gpu_lower->haloInfo().getExtent(id);
+      auto extent_with_halo = gpu_lower->haloInfo()->getExtent(id);
       if (extent_with_halo) {
         new_loop = IrBuilder::create<kir::ForLoop>(
             id,
@@ -166,7 +166,7 @@ class AllocationInserter : public kir::ExprMutator {
       }
       auto extent = id->extent();
       // Use halo-extended extent if found
-      auto halo_extent = gpu_lower->haloInfo().getRootAxisInfo(id);
+      auto halo_extent = gpu_lower->haloInfo()->getRootAxisInfo(id);
       if (halo_extent.hasHalo()) {
         extent = IrBuilder::addExpr(
             extent, IrBuilder::create<Int>(halo_extent.width()));
@@ -213,7 +213,7 @@ class AllocationInserter : public kir::ExprMutator {

     // Get the halo extent if found
     auto getExtent = [this](IterDomain* id) {
-      auto extent = gpu_lower->haloInfo().getExtent(id);
+      auto extent = gpu_lower->haloInfo()->getExtent(id);
       if (extent == nullptr) {
         extent = id->extent();
       }
@@ -368,7 +368,7 @@ class AllocationInserter : public kir::ExprMutator {

       auto extent = concrete_id->extent();

-      if (gpu_lower->haloInfo().getExtent(info.buffer->axis(axis_i)) !=
+      if (gpu_lower->haloInfo()->getExtent(info.buffer->axis(axis_i)) !=
           nullptr) {
         has_halo = true;
       }
diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
index 513066a5c71c4..c28eb960b0fba 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
@@ -125,7 +125,8 @@ IndexingParameters getLinearIndexParameters(

   // Derive the halo extents from the loop indexing result.
   index_parameters.concrete_id_to_halo_extent =
-      GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing);
+      GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap(
+          loop_indexing);

   protectNonPredicateIndexWithMagicZero(
       loops,
@@ -233,7 +234,8 @@ IndexingParameters getNonGlobalInitialIndexParameters(

   // Derive the halo extents from the loop indexing result.
   index_parameters.concrete_id_to_halo_extent =
-      GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing);
+      GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap(
+          loop_indexing);

   return index_parameters;
 }
@@ -408,7 +410,8 @@ IndexingParameters getPredicateInitialIndexParameters(

   // Derive the halo extents from the loop indexing result.
   index_parameters.concrete_id_to_halo_extent =
-      GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing);
+      GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap(
+          loop_indexing);

   return index_parameters;
 }
diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp
index 7fdb149da9359..e135627950612 100644
--- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp
@@ -33,7 +33,7 @@ LoopNestGenerator::LoopNestGenerator(const std::vector<Expr*>& exprs) {
 namespace {

 kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) {
-  auto extent_with_halo = GpuLower::current()->haloInfo().getExtent(id);
+  auto extent_with_halo = GpuLower::current()->haloInfo()->getExtent(id);
   kir::ForLoop* new_scope = nullptr;
   if (extent_with_halo) {
     // When an axis is extended with halo, unrolling and vectorization
diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp
index 940de32ce9567..66beebe66e7b0 100644
--- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp
@@ -303,12 +303,12 @@ class PredicateChcker : public IterVisitor {

   // Shift is not supported yet.
   bool predicateShift(Expr* expr) const {
-    auto& halo_info = GpuLower::current()->haloInfo();
+    auto halo_info = GpuLower::current()->haloInfo();
     auto input_tvs = ir_utils::filterByType<TensorView>(expr->inputs());
-    return halo_info.needsShiftPredicate(expr) ||
+    return halo_info->needsShiftPredicate(expr) ||
         std::any_of(input_tvs.begin(), input_tvs.end(), [&](auto input_tv) {
           return input_tv->definition() != nullptr &&
-              halo_info.needsShiftPredicate(input_tv->definition());
+              halo_info->needsShiftPredicate(input_tv->definition());
         });
   }

diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp
index fe1e0cc509c13..2bfe7f8e5ab2a 100644
--- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp
@@ -28,7 +28,7 @@ void ShiftPredicateInserter::insert(
   TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output");

   const bool needs_shift_predicate =
-      gpu_lower->haloInfo().needsShiftPredicate(out_tv->definition());
+      gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition());
   if (!needs_shift_predicate) {
     return;
   }
@@ -145,13 +145,6 @@ const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const {
   return it->second;
 }

-AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) {
-  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-  return const_cast<AxisHaloInfo&>(
-      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-      const_cast<const HaloInfo*>(this)->getRootAxisInfo(id));
-}
-
 void HaloInfo::setRootAxisInfo(
     IterDomain* id,
     const AxisHaloInfo& root_axis_info) {
@@ -161,7 +154,9 @@ void HaloInfo::setRootAxisInfo(
   return;
 }

-void HaloInfo::build(Fusion* fusion) {
+HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map)
+    // Make a copy of the permissive map for extent comparators
+    : permissive_map_(ca_map->idGraph().permissiveNodes()) {
   const auto vals = fusion->usedMathVals();
   auto tvs = ir_utils::filterByType<TensorView>(vals);

@@ -202,7 +197,7 @@ void HaloInfo::build(Fusion* fusion) {

   // Note that validation requires consumer halo info
   for (auto tv : tvs) {
-    validate(tv);
+    validate(tv, ca_map);
   }
 }

@@ -474,12 +469,13 @@ void HaloInfo::build(TensorDomain* td) {
 //! Other types of parallelization should be supported except for
 //! vectorization. Vectorization should be eventually supported but
 //! needs further work.
-void HaloInfo::validate(TensorView* tv) const {
+void HaloInfo::validate(
+    TensorView* tv,
+    std::shared_ptr<const ComputeAtMap> ca_map) const {
   const auto mem_type = tv->getMemoryType();

   for (auto axis : tv->domain()->domain()) {
-    auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
-        axis, IdMappingMode::LOOP);
+    auto concrete_id = ca_map->getConcreteMappedID(axis, IdMappingMode::LOOP);

     // The extent is assumed to be the same
     TORCH_INTERNAL_ASSERT(
@@ -526,7 +522,7 @@ void HaloInfo::validate(TensorView* tv) const {
         consumer->domain()->domain().begin(),
         consumer->domain()->domain().end(),
         [&](IterDomain* consumer_axis) {
-          return GpuLower::current()->caMap()->areMapped(
+          return ca_map->areMapped(
               axis, consumer_axis, IdMappingMode::PERMISSIVE);
         });
     if (it == consumer->domain()->domain().end()) {
@@ -626,11 +622,10 @@ bool extentCompare(
     const HaloInfo& halo_map,
     IterDomain* id1,
     IterDomain* id2,
-    Cmp cmp) {
-  auto gpu_lower = GpuLower::current();
+    Cmp cmp,
+    const DisjointSets<IterDomain*>& permissive_map) {
   TORCH_INTERNAL_ASSERT(
-      gpu_lower->caMap()->areMapped(id1, id2, IdMappingMode::PERMISSIVE),
-      "Invalid axes to compare");
+      permissive_map.strictAreMapped(id1, id2), "Invalid axes to compare");

   // It's invalid to compare two axes and when only either of them has
   // halo.
@@ -652,10 +647,10 @@ bool extentCompare(
     auto merge2 = dynamic_cast<Merge*>(id2->definition());
     TORCH_INTERNAL_ASSERT(
         merge2 != nullptr, "Invalid comparison: ", id1, " and ", id2);
-    auto inner_le =
-        extentCompare(halo_map, merge1->inner(), merge2->inner(), cmp);
-    auto outer_le =
-        extentCompare(halo_map, merge1->outer(), merge2->outer(), cmp);
+    auto inner_le = extentCompare(
+        halo_map, merge1->inner(), merge2->inner(), cmp, permissive_map);
+    auto outer_le = extentCompare(
+        halo_map, merge1->outer(), merge2->outer(), cmp, permissive_map);
     return inner_le && outer_le;
   } else {
     // This is not considered. Should never reach here.
@@ -667,11 +662,11 @@ bool extentCompare(
 } // namespace

 bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const {
-  return extentCompare(*this, id1, id2, std::less_equal<>());
+  return extentCompare(*this, id1, id2, std::less_equal<>(), permissive_map_);
 }

 bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const {
-  return extentCompare(*this, id1, id2, std::equal_to<>());
+  return extentCompare(*this, id1, id2, std::equal_to<>(), permissive_map_);
 }

 std::string HaloInfo::toString() const {
@@ -722,11 +717,11 @@ bool HaloInfo::needsShiftPredicate(Expr* expr) const {
 }

 std::unordered_map<IterDomain*, Val*> HaloInfo::buildConcreteHaloExtentMap(
-    const LoopIndexing& loop_indexing) {
+    const LoopIndexing& loop_indexing) const {
   // Use a local workspace to avoid re-defining halo info.
-  HaloInfo local_halo_info;
-  auto& global_halo_info = GpuLower::current()->haloInfo();
+  HaloInfo local_halo_info = *GpuLower::current()->haloInfo();
+  auto global_halo_info = GpuLower::current()->haloInfo();

   // Setup root:
   for (auto consumer_root_id : loop_indexing.consumerTv()->getRootDomain()) {
     auto consumer_index_concrete_id =
@@ -734,7 +729,7 @@ std::unordered_map<IterDomain*, Val*> HaloInfo::buildConcreteHaloExtentMap(
         ir_utils::caMapExactConcreteId(consumer_root_id);
     local_halo_info.setRootAxisInfo(
         consumer_index_concrete_id,
-        global_halo_info.getRootAxisInfo(consumer_root_id));
+        global_halo_info->getRootAxisInfo(consumer_root_id));
   }

   // Track IDs that are generated by merging halo-extended IDs
@@ -801,7 +796,8 @@ std::unordered_map<IterDomain*, Val*> HaloInfo::buildConcreteHaloExtentMap(
       merged_shifted_ids.insert(ir_utils::caMapExactConcreteId(merge->out()));
       // Note that halo_width_map_ is not updated
     } else {
-      setHaloWidth(ir_utils::caMapExactConcreteId(merge->out()), 0);
+      local_halo_info.setHaloWidth(
+          ir_utils::caMapExactConcreteId(merge->out()), 0);
     }
   } else if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(expr)) {
     // Swizzle with halo not yet supported, just set the width
diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h
index d1500c5f9f203..0cb3c3ea44572 100644
--- a/torch/csrc/jit/codegen/cuda/lower_shift.h
+++ b/torch/csrc/jit/codegen/cuda/lower_shift.h
@@ -61,23 +61,12 @@ class AxisHaloInfo {
 class TORCH_CUDA_CU_API HaloInfo {
  public:
   //! Scan a fusion and collect all information for lowering
-  void build(Fusion* fusion);
-
-  //! Build mappings of extent information of a TensorDomain
-  void build(TensorDomain* td);
+  HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map);

   //! Almost exact duplicate of build(TensorDomain* td), except that
   //! the traversal was done on loop indexing expressions.
   std::unordered_map<IterDomain*, Val*> buildConcreteHaloExtentMap(
-      const LoopIndexing& loop_indexing);
-
-  //! Set initial AxisHaloInfo of a root axis
-  //!
-  //! The axis does not need to be a root domain in the case of
-  //! reference tensors. Reference tensors get halo information from
-  //! consumer root domains, which may correspond to rfactor domains
-  //! of tensors from which reference tensors are derived.
-  void setRootAxisInfo(IterDomain* id, const AxisHaloInfo& root_axis_info);
+      const LoopIndexing& loop_indexing) const;

   //! Returns true if id has the root halo information set by
   //! setRootAxisInfo.
@@ -88,7 +77,6 @@ class TORCH_CUDA_CU_API HaloInfo {
   //! This is only for root axes. It is an error to query with
   //! non-root axes.
   const AxisHaloInfo& getRootAxisInfo(IterDomain* id) const;
-  AxisHaloInfo& getRootAxisInfo(IterDomain* id);

   //! Query if an axis has a halo width.
   //!
@@ -139,10 +127,21 @@ class TORCH_CUDA_CU_API HaloInfo {
   std::string toString() const;

  private:
+  //! Build mappings of extent information of a TensorDomain
+  void build(TensorDomain* td);
+
   //! Propagate root axis information from outputs to inputs of an
   //! expression
   void propagateRootAxisInfo(Expr* expr);

+  //! Set initial AxisHaloInfo of a root axis
+  //!
+  //! The axis does not need to be a root domain in the case of
+  //! reference tensors. Reference tensors get halo information from
+  //! consumer root domains, which may correspond to rfactor domains
+  //! of tensors from which reference tensors are derived.
+  void setRootAxisInfo(IterDomain* id, const AxisHaloInfo& root_axis_info);
+
   //! Adds a domain to the halo inheritance map.
   //!
   //! A domain, child, is added to the same set as domain parent. Both
@@ -163,11 +162,15 @@ class TORCH_CUDA_CU_API HaloInfo {
   void initializeFromRootAxisInfo(IterDomain* id);

   //! Validate shift usage
-  void validate(TensorView* td) const;
+  void validate(TensorView* td, std::shared_ptr<const ComputeAtMap> ca_map)
+      const;

   void setHaloWidth(IterDomain* id, int halo_width);

  private:
+  // Copy the permissive map from the passed in compute at map
+  const DisjointSets<IterDomain*> permissive_map_;
+
   //! Halo information of root axes
   std::unordered_map<IterDomain*, AxisHaloInfo> root_axis_map_;

diff --git a/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp b/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp
index 497256b5f850e..9a797692cc4be 100644
--- a/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp
@@ -240,12 +240,12 @@ void SyncMap::build(Fusion* fusion) {
                     p_id, c_id, IdMappingMode::PERMISSIVE)) {
               const auto halo_info = GpuLower::current()->haloInfo();

-              if (halo_info.hasHaloWidth(p_id) !=
-                      halo_info.hasHaloWidth(c_id) ||
-                  (halo_info.hasHaloWidth(p_id) &&
-                   halo_info.hasHaloWidth(c_id) &&
-                   halo_info.getHaloWidth(p_id) !=
-                       halo_info.getHaloWidth(c_id))) {
+              if (halo_info->hasHaloWidth(p_id) !=
+                      halo_info->hasHaloWidth(c_id) ||
+                  (halo_info->hasHaloWidth(p_id) &&
+                   halo_info->hasHaloWidth(c_id) &&
+                   halo_info->getHaloWidth(p_id) !=
+                       halo_info->getHaloWidth(c_id))) {
                 raw_dims.set(parallel_type);
                 continue;
               }
diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
index 434d1711d9c83..165aef7d6f4c7 100644
--- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
@@ -81,7 +81,7 @@ void UnrollPass::handle(Expr* expr) {

   // When a predicate needs to account for ShiftOp, it is currently
   // taken care by its own function.
-  if (GpuLower::current()->haloInfo().needsShiftPredicate(expr)) {
+  if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) {
     ShiftPredicateInserter::insert(
         expr, for_loops_, thread_pred, unswitched_loop_);
     return;
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
index aa88b0e7907db..bd25d8895b856 100644
--- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -740,7 +740,7 @@ bool isTrivialIterDomain(IterDomain* id) {
       (id->extent()->isOneInt() && id->start()->isZeroInt()) ||
       pt == ParallelType::Vectorize ||
       (isParallelTypeThread(pt) &&
-       !GpuLower::current()->haloInfo().hasHaloWidth(id));
+       !GpuLower::current()->haloInfo()->hasHaloWidth(id));
 }

 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h
index 06fbea049bde5..dc3d1c53906eb 100644
--- a/torch/csrc/jit/codegen/cuda/lower_utils.h
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.h
@@ -243,7 +243,7 @@ struct TORCH_CUDA_CU_API IterDomainDependencySorter {
   IterDomainDependencySorter(
       const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>&
           concrete_id_dependencies,
-      const std::unique_ptr<ComputeAtMap>& compute_at_map)
+      std::shared_ptr<const ComputeAtMap> compute_at_map)
       : concrete_id_dependencies_(concrete_id_dependencies),
         compute_at_map_(compute_at_map) {}

@@ -269,7 +269,7 @@ struct TORCH_CUDA_CU_API IterDomainDependencySorter {

   const std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>&
       concrete_id_dependencies_;
-  const std::unique_ptr<ComputeAtMap>& compute_at_map_;
+  const std::shared_ptr<const ComputeAtMap> compute_at_map_;
 };

 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
index de2c1135ad202..da1def37cad84 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -1183,7 +1183,7 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) {

     // Halo is not allowed
     TORCH_CHECK(
-        GpuLower::current()->haloInfo().getExtent(id) == nullptr,
+        GpuLower::current()->haloInfo()->getExtent(id) == nullptr,
        "Invalid use of ParallelType::Group.",
        " Grouping of halo-extended IterDomain, ",
        id->toString(),
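
Note on the accessor change (illustrative sketch, not part of the patch;
`markHaloLoops` is a hypothetical pass invented for this example).
`GpuLower::haloInfo()` and `GpuLower::caMap()` now return
`std::shared_ptr<const HaloInfo>` and `std::shared_ptr<const ComputeAtMap>`
by value instead of references, so call sites switch from `.` to `->`, only
const member functions are reachable through the pointer, and an analysis can
hold the map without repeatedly going through the `GpuLower::current()`
singleton:

    // Hypothetical caller, written against the post-patch API.
    void markHaloLoops(const std::vector<IterDomain*>& leaf_ids) {
      // Fetch once; shared ownership keeps the map valid for this pass.
      std::shared_ptr<const HaloInfo> halo_info =
          GpuLower::current()->haloInfo();
      for (auto* id : leaf_ids) {
        // Only const queries compile here; mutators such as
        // setRootAxisInfo are now private to HaloInfo.
        if (halo_info->hasHaloWidth(id) && halo_info->getHaloWidth(id) > 0) {
          // ... treat this loop as halo-extended ...
        }
      }
    }

The same inversion applies at construction: per the commit subject, HaloInfo
no longer reaches into a GpuLower instance while being built; it receives the
ComputeAtMap as a constructor argument and copies the permissive disjoint
sets it needs for its extent comparators.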