File tree 3 files changed +7
-7
lines changed
torch/csrc/jit/codegen/cuda 3 files changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -163,7 +163,7 @@ std::unordered_set<TensorView*> getPropagationSubgraph(
163
163
void ComputeAt::runAt (
164
164
TensorView* producer,
165
165
TensorView* consumer,
166
- unsigned int consumer_position,
166
+ int64_t consumer_position,
167
167
ComputeAtMode mode) {
168
168
FUSER_PERF_SCOPE (" ComputeAt::runAt" );
169
169
@@ -176,7 +176,7 @@ void ComputeAt::runAt(
176
176
" are not in the same fusion." );
177
177
178
178
if (mode == ComputeAtMode::MostInlined) {
179
- consumer_position = consumer-> nDims () ;
179
+ consumer_position = - 1 ;
180
180
}
181
181
182
182
FusionGuard fg (producer->fusion ());
@@ -205,7 +205,7 @@ void ComputeAt::runAt(
205
205
void ComputeAt::runWith (
206
206
TensorView* producer,
207
207
TensorView* consumer,
208
- unsigned int producer_position,
208
+ int64_t producer_position,
209
209
ComputeAtMode mode) {
210
210
FUSER_PERF_SCOPE (" ComputeAt::runWith" );
211
211
@@ -218,7 +218,7 @@ void ComputeAt::runWith(
218
218
" are not in the same fusion." );
219
219
220
220
if (mode == ComputeAtMode::MostInlined) {
221
- producer_position = producer-> nDims () ;
221
+ producer_position = - 1 ;
222
222
}
223
223
224
224
FusionGuard fg (producer->fusion ());
Original file line number Diff line number Diff line change @@ -27,15 +27,15 @@ struct ComputeAt {
27
27
static void runAt (
28
28
TensorView* producer,
29
29
TensorView* consumer,
30
- unsigned int consumer_position,
30
+ int64_t consumer_position,
31
31
ComputeAtMode mode = ComputeAtMode::Standard);
32
32
33
33
// Runs the compute with pass making consumer look like producer, computing
34
34
// producer relative to consumer
35
35
static void runWith (
36
36
TensorView* producer,
37
37
TensorView* consumer,
38
- unsigned int producer_position,
38
+ int64_t producer_position,
39
39
ComputeAtMode mode = ComputeAtMode::Standard);
40
40
};
41
41
Original file line number Diff line number Diff line change @@ -178,7 +178,6 @@ InlinePropagator::InlinePropagator(
178
178
: max_pos_calc(mode),
179
179
selected_ (std::move(selected)),
180
180
reference_(reference),
181
- reference_pos_(reference_pos),
182
181
mode_(mode) {
183
182
if (reference_pos < 0 ) {
184
183
reference_pos += int64_t (reference->nDims ()) + 1 ;
@@ -192,6 +191,7 @@ InlinePropagator::InlinePropagator(
192
191
" and <= " ,
193
192
reference->nDims (),
194
193
" ." );
194
+ reference_pos_ = reference_pos;
195
195
}
196
196
197
197
void InlinePropagator::propagateC2P (TensorView* from, TensorView* to) {
You can’t perform that action at this time.
0 commit comments