Skip to content

Commit de6b7ca

Browse files
authored
Fix negative position in InlinePropagator (csarofeen#1813)
1 parent 10a996c commit de6b7ca

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

torch/csrc/jit/codegen/cuda/compute_at.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ std::unordered_set<TensorView*> getPropagationSubgraph(
163163
void ComputeAt::runAt(
164164
TensorView* producer,
165165
TensorView* consumer,
166-
unsigned int consumer_position,
166+
int64_t consumer_position,
167167
ComputeAtMode mode) {
168168
FUSER_PERF_SCOPE("ComputeAt::runAt");
169169

@@ -176,7 +176,7 @@ void ComputeAt::runAt(
176176
" are not in the same fusion.");
177177

178178
if (mode == ComputeAtMode::MostInlined) {
179-
consumer_position = consumer->nDims();
179+
consumer_position = -1;
180180
}
181181

182182
FusionGuard fg(producer->fusion());
@@ -205,7 +205,7 @@ void ComputeAt::runAt(
205205
void ComputeAt::runWith(
206206
TensorView* producer,
207207
TensorView* consumer,
208-
unsigned int producer_position,
208+
int64_t producer_position,
209209
ComputeAtMode mode) {
210210
FUSER_PERF_SCOPE("ComputeAt::runWith");
211211

@@ -218,7 +218,7 @@ void ComputeAt::runWith(
218218
" are not in the same fusion.");
219219

220220
if (mode == ComputeAtMode::MostInlined) {
221-
producer_position = producer->nDims();
221+
producer_position = -1;
222222
}
223223

224224
FusionGuard fg(producer->fusion());

torch/csrc/jit/codegen/cuda/compute_at.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ struct ComputeAt {
2727
static void runAt(
2828
TensorView* producer,
2929
TensorView* consumer,
30-
unsigned int consumer_position,
30+
int64_t consumer_position,
3131
ComputeAtMode mode = ComputeAtMode::Standard);
3232

3333
// Runs the compute with pass making consumer look like producer, computing
3434
// producer relative to consumer
3535
static void runWith(
3636
TensorView* producer,
3737
TensorView* consumer,
38-
unsigned int producer_position,
38+
int64_t producer_position,
3939
ComputeAtMode mode = ComputeAtMode::Standard);
4040
};
4141

torch/csrc/jit/codegen/cuda/inline_propagator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ InlinePropagator::InlinePropagator(
178178
: max_pos_calc(mode),
179179
selected_(std::move(selected)),
180180
reference_(reference),
181-
reference_pos_(reference_pos),
182181
mode_(mode) {
183182
if (reference_pos < 0) {
184183
reference_pos += int64_t(reference->nDims()) + 1;
@@ -192,6 +191,7 @@ InlinePropagator::InlinePropagator(
192191
" and <= ",
193192
reference->nDims(),
194193
".");
194+
reference_pos_ = reference_pos;
195195
}
196196

197197
void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {

0 commit comments

Comments
 (0)