diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index d56d57120c03..1da77963f6c2 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -13,7 +13,8 @@ ComputeAtData::ComputeAtData(TensorView* tv) : tv_ref_(tv), original_has_compute_at_(tv->hasComputeAt()), original_compute_at_position(tv->getThisComputeAtAxis()), - original_domain_(tv->domain()) {} + original_domain_(tv->domain()), + new_compute_at_domain_(tv->domain()) {} // Clear pass based data void ComputeAtData::clearPass() { @@ -187,16 +188,17 @@ unsigned int ComputeAt::backwardComputeAt_impl( TensorView* producer, TensorView* consumer, unsigned int consumer_compute_at_axis) { - auto& entry = tv_data.at(producer); + auto& producer_entry = tv_data.at(producer); // Use TensorDomain interface so it doesn't set computeAt automatically auto replay = TransformReplay::replayPasC( producer, consumer, (int)consumer_compute_at_axis); - entry.setPassPosition(replay.second); + producer_entry.setPassPosition(replay.second); - if (entry.shouldSetComputeAt(replay.second)) { + if (producer_entry.shouldSetComputeAt(replay.second)) { producer->setComputeAt(consumer, (int)consumer_compute_at_axis); + producer_entry.setComputeAtDomain(producer->domain()); } return replay.second; @@ -213,11 +215,17 @@ unsigned int ComputeAt::forwardComputeAt_impl( auto replay = TransformReplay::replayCasP( consumer, producer, (int)producer_compute_at_axis); - consumer_entry.setPassPosition(replay.second); if (producer_entry.shouldSetComputeAt(producer_compute_at_axis)) { producer->setComputeAt(consumer, replay.second); } + consumer_entry.setPassPosition(replay.second); + if ((consumer_entry.shouldSetComputeAt(replay.second) && + consumer != consumer_) || + (consumer == consumer_ && replay.second >= consumer_position_)) { + consumer_entry.setComputeAtDomain(consumer->domain()); + } + return replay.second; } @@ -359,9 +367,19 @@ void 
ComputeAt::runPass() { setupOutputs(); - for (const auto entry : tv_data) { + for (const auto& entry : tv_data) { + entry.first->setDomain(entry.second.getComputeAtDomain()); entry.second.validateNewComputeAt(); } + + TORCH_INTERNAL_ASSERT( + BestEffortReplay::findFirstMismatchedID( + consumer_->domain(), tv_data.at(consumer_).getOriginalDomain()) == + consumer_->domain()->nDims(), + "ComputeAt logic changed the consumer domain which should not happen. Domain was ", + tv_data.at(consumer_).getOriginalDomain(), + " but is now: ", + consumer_->domain()); } void ComputeAt::setupOutputs() { diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 2cefd04edffd..2071788ddd7c 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -29,7 +29,7 @@ class ComputeAtData { // an invalid compute_at that would require tensor replication. void setPassPosition(unsigned int pos); - // Returns if new postion is greater or equal to previous seen + // Returns whether the new position is greater than all previously seen positions and at least the current traversal position bool shouldSetComputeAt(unsigned int pos) const { return pos > original_compute_at_position && pos > new_compute_at_position && pos >= current_traversal_position; } @@ -48,14 +48,23 @@ class ComputeAtData { return touched_; } - // Traversal domain, public as it can freely be set without impacting any - // other data. Just a convenience to have it included here. - TensorDomain* traversal_domain = nullptr; + TensorDomain* getOriginalDomain() const { + return original_domain_; + } - private: - // Position to update after a traversal - unsigned int new_compute_at_position = 0; + // If we set computeAt, save the domain so we can reset it after traversal. + // Traversal state can deviate from the domain we will want to save after the + // entire computeAt pass. 
+ void setComputeAtDomain(TensorDomain* td) { + new_compute_at_domain_ = td; + } + + // Return domain set in setComputeAtDomain + TensorDomain* getComputeAtDomain() const { + return new_compute_at_domain_; + } + private: // Was the position ever modified? bool touched_ = false; @@ -76,6 +85,13 @@ class ComputeAtData { // Did this traversal set a position or not yet bool current_traversal_position_set = false; + + // Position to update after a traversal + unsigned int new_compute_at_position = 0; + + // Domain when we actually set computeAt, will set back to this after the + // pass. + TensorDomain* new_compute_at_domain_; }; class ComputeAt { diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 854e5437cf91..893078c21866 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -355,6 +355,9 @@ IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) is_broadcast_domain_(src->is_broadcast_domain_) {} bool IterDomain::sameAs(const IterDomain* const other) const { + if (other == this) + return true; + bool is_same = isReduction() == other->isReduction() && parallel_method() == other->parallel_method(); is_same = is_same && ScalarCheck::sameAs(extent(), other->extent()); diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index 2713bd8ebd87..1c6996ba93b8 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -369,16 +369,17 @@ int BestEffortReplay::findFirstMismatchedID( const TensorDomain* td2) { std::unordered_map<IterDomain*, IterDomain*> id_map; auto rd1 = td1->rootDomain(); + auto rd2 = td2->rootDomain(); std::unordered_set<IterDomain*> rd2_set( td2->rootDomain().begin(), td2->rootDomain().end()); // Find matching root IterDomains, we could make this O(nlog(n)) if we could // sort IterDomains. 
for (auto rd1i : rd1) { - for (IterDomain* rd2_id : rd2_set) { - if (rd1i->sameAs(rd2_id)) { - id_map[rd1i] = rd2_id; - rd2_set.erase(rd2_id); + for (auto rd2i : rd2) { + if (rd1i->sameAs(rd2i) && rd2_set.find(rd2i) != rd2_set.end()) { + id_map[rd1i] = rd2i; + rd2_set.erase(rd2i); break; } } @@ -387,8 +388,14 @@ int BestEffortReplay::findFirstMismatchedID( BestEffortReplay ber(td2->domain(), td1->domain(), id_map); for (size_t i = 0; i < td1->domain().size(); i++) { - if (ber.getReplay().find(td1->axis(i)) == ber.getReplay().end()) + if (ber.getReplay().find(td1->axis(i)) == ber.getReplay().end()) { return i; + } + // Order is important. + auto td2_axis = ber.getReplay().at(td1->axis(i)); + if (td2->axis(i) != td2_axis) { + return i; + } } return td1->nDims(); }