Skip to content

Commit 3f2c263

Browse files
authored
validateDomain in TransformPropagator (#1796)
1 parent c077085 commit 3f2c263

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

torch/csrc/jit/codegen/cuda/transform_replay.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,19 @@ bool TransformReplay::fullSelfMatching(
835835
return true;
836836
}
837837

838+
namespace {
839+
840+
// Make sure if tv is set to new_td it doesn't violate set compute at and max
841+
// produce at positions.
842+
bool validateDomain(TensorView* tv, TensorDomain* new_td) {
843+
auto first_mismatch =
844+
BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td);
845+
return first_mismatch >= (int)tv->getMaxProducerPosition() &&
846+
first_mismatch >= (int)tv->getComputeAtPosition();
847+
}
848+
849+
} // namespace
850+
838851
void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
839852
int pos = replayed_pos_.at(from);
840853
// Note: [Using multiple TransformPropagators]
@@ -849,6 +862,13 @@ void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
849862
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
850863
if (new_pos < 0) {
851864
auto replay = TransformReplay::replayPasC(to, from, pos);
865+
TORCH_INTERNAL_ASSERT(
866+
validateDomain(to, replay.first),
867+
"Tried to set the domain of ",
868+
to,
869+
" to ",
870+
replay.first,
871+
" but that would invalidate previously compute at position or max producer position.");
852872
to->setDomain(replay.first);
853873
new_pos = replay.second;
854874
}
@@ -862,6 +882,13 @@ void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) {
862882
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
863883
if (new_pos < 0) {
864884
auto replay = TransformReplay::replayCasP(to, from, pos);
885+
TORCH_INTERNAL_ASSERT(
886+
validateDomain(to, replay.first),
887+
"Tried to set the domain of ",
888+
to,
889+
" to ",
890+
replay.first,
891+
" but that would invalidate previously compute at position or max producer position.");
865892
to->setDomain(replay.first);
866893
new_pos = replay.second;
867894
}
@@ -873,6 +900,13 @@ void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) {
873900
// See note [Using multiple TransformPropagators]
874901
if (!TransformReplay::fullSelfMatching(to, from)) {
875902
auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
903+
TORCH_INTERNAL_ASSERT(
904+
validateDomain(to, replay),
905+
"Tried to set the domain of ",
906+
to,
907+
" to ",
908+
replay,
909+
" but that would invalidate previously compute at position or max producer position.");
876910
to->setDomain(replay);
877911
}
878912
replayed_pos_[to] = pos;

0 commit comments

Comments
 (0)