Skip to content

Commit fa4e6a4

Browse files
authored
Check siblings in getMaxPosAll (csarofeen#1805)
1 parent 025c840 commit fa4e6a4

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

torch/csrc/jit/codegen/cuda/inline_propagator.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,17 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
132132
return producer->nDims();
133133
}
134134

135-
size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
135+
size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) {
136136
auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false);
137137
for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
138138
max_pos = std::min<size_t>(
139139
max_pos, max_pos_calc.getMaxProducerPosFromConsumer(tv, consumer_tv));
140140
}
141+
if (check_siblings) {
142+
for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) {
143+
max_pos = std::min<size_t>(max_pos, getMaxPosAll(sibling_tv, false));
144+
}
145+
}
141146
return max_pos;
142147
}
143148

torch/csrc/jit/codegen/cuda/inline_propagator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class MaxPosCalculator {
7070
class InlinePropagator : public MaxInfoSpanningTree::Propagator {
7171
// Checks producers and consumers to see what the maximum position in tv is
7272
// that can be shared across both directions.
73-
size_t getMaxPosAll(TensorView* tv);
73+
size_t getMaxPosAll(TensorView* tv, bool check_siblings = true);
7474

7575
// We use mapped_reference_pos_ to keep track of the outer axes information of
7676
// the reference tensor. That is, mapped_reference_pos_[tv] answers the

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24167,6 +24167,30 @@ TEST_F(NVFuserTest, FusionSkipReplay_CUDA) {
2416724167
}
2416824168
}
2416924169

24170+
TEST_F(NVFuserTest, FusionInlineRepro1803_CUDA) {
24171+
Fusion fusion;
24172+
FusionGuard fg(&fusion);
24173+
24174+
TensorView* tv0 = makeContigTensor(2);
24175+
24176+
fusion.addInput(tv0);
24177+
auto tv1 = set(tv0);
24178+
auto tvs = Welford(tv1, {1});
24179+
auto tvo = set(tvs.var_sum);
24180+
fusion.addOutput(tvo);
24181+
24182+
tvo->split(0, 16);
24183+
tvo->axis(1)->parallelize(ParallelType::Unroll);
24184+
24185+
tv0->computeAt(tvo, -1, ComputeAtMode::BestEffort);
24186+
24187+
TORCH_CHECK(
24188+
tvs.var_sum->getComputeAtPosition() == tvs.avg->getComputeAtPosition());
24189+
TORCH_CHECK(
24190+
tvs.var_sum->getComputeAtPosition() == tvs.n->getComputeAtPosition());
24191+
TORCH_CHECK(tvs.var_sum->getComputeAtPosition() == 1);
24192+
}
24193+
2417024194
} // namespace jit
2417124195
} // namespace torch
2417224196
#endif // #if defined(USE_CUDA)

0 commit comments

Comments (0)