@@ -835,6 +835,19 @@ bool TransformReplay::fullSelfMatching(
835
835
return true ;
836
836
}
837
837
838
+ namespace {
839
+
840
+ // Make sure if tv is set to new_td it doesn't violate set compute at and max
841
+ // produce at positions.
842
+ bool validateDomain (TensorView* tv, TensorDomain* new_td) {
843
+ auto first_mismatch =
844
+ BestEffortReplay::findFirstMismatchedID (tv->domain (), new_td);
845
+ return first_mismatch >= (int )tv->getMaxProducerPosition () &&
846
+ first_mismatch >= (int )tv->getComputeAtPosition ();
847
+ }
848
+
849
+ } // namespace
850
+
838
851
void TransformPropagator::propagateTvPasC (TensorView* from, TensorView* to) {
839
852
int pos = replayed_pos_.at (from);
840
853
// Note: [Using multiple TransformPropagators]
@@ -849,6 +862,13 @@ void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
849
862
TransformReplay::getMatchedLeafPosWithoutReplayPasC (to, from, pos);
850
863
if (new_pos < 0 ) {
851
864
auto replay = TransformReplay::replayPasC (to, from, pos);
865
+ TORCH_INTERNAL_ASSERT (
866
+ validateDomain (to, replay.first ),
867
+ " Tried to set the domain of " ,
868
+ to,
869
+ " to " ,
870
+ replay.first ,
871
+ " but that would invalidate previously compute at position or max producer position." );
852
872
to->setDomain (replay.first );
853
873
new_pos = replay.second ;
854
874
}
@@ -862,6 +882,13 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) {
862
882
TransformReplay::getMatchedLeafPosWithoutReplayCasP (to, from, pos);
863
883
if (new_pos < 0 ) {
864
884
auto replay = TransformReplay::replayCasP (to, from, pos);
885
+ TORCH_INTERNAL_ASSERT (
886
+ validateDomain (to, replay.first ),
887
+ " Tried to set the domain of " ,
888
+ to,
889
+ " to " ,
890
+ replay.first ,
891
+ " but that would invalidate previously compute at position or max producer position." );
865
892
to->setDomain (replay.first );
866
893
new_pos = replay.second ;
867
894
}
@@ -873,6 +900,13 @@ void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) {
873
900
// See note [Using multiple TransformPropagators]
874
901
if (!TransformReplay::fullSelfMatching (to, from)) {
875
902
auto replay = TransformReplay::fullSelfReplay (to->domain (), from->domain ());
903
+ TORCH_INTERNAL_ASSERT (
904
+ validateDomain (to, replay),
905
+ " Tried to set the domain of " ,
906
+ to,
907
+ " to " ,
908
+ replay,
909
+ " but that would invalidate previously compute at position or max producer position." );
876
910
to->setDomain (replay);
877
911
}
878
912
replayed_pos_[to] = pos;
0 commit comments