@@ -178,22 +178,11 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
178
178
return max_pos;
179
179
}
180
180
181
- size_t InlinePropagator::adjustComputeAtPos (TensorView* tv, size_t pos) {
182
- pos = std::min<size_t >(pos, getMaxPosAll (tv));
183
-
184
- // hoist inner most broadcast
185
- while (pos > 0 && tv->axis (pos - 1 )->isBroadcast ()) {
186
- pos--;
187
- }
188
-
189
- return pos;
190
- }
191
-
192
- size_t InlinePropagator::getReplayPosPasC (
181
+ size_t InlinePropagator::getFromPosPasC (
193
182
TensorView* producer,
194
183
TensorView* consumer) {
195
184
size_t max_pos = max_pos_calc.getMaxPosPasC (producer, consumer);
196
- size_t pos = retrieveReplayedPos (consumer);
185
+ size_t pos = mapped_reference_pos_. at (consumer);
197
186
198
187
if (mode_ == ComputeAtMode::BestEffort) {
199
188
return std::min (pos, max_pos);
@@ -203,22 +192,22 @@ size_t InlinePropagator::getReplayPosPasC(
203
192
204
193
TORCH_INTERNAL_ASSERT (
205
194
pos <= max_pos,
206
- " Invalid compute at position detected in compute at when trying to replay producer: " ,
207
- producer,
208
- " as consumer: " ,
195
+ " Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: " ,
209
196
consumer,
197
+ " to producer: " ,
198
+ producer,
210
199
" tried to do this at position: " ,
211
200
pos,
212
201
" but max position that's allowed is " ,
213
202
max_pos);
214
203
return pos;
215
204
}
216
205
217
- size_t InlinePropagator::getReplayPosCasP (
206
+ size_t InlinePropagator::getFromPosCasP (
218
207
TensorView* consumer,
219
208
TensorView* producer) {
220
209
size_t max_pos = max_pos_calc.getMaxPosCasP (consumer, producer);
221
- size_t pos = retrieveReplayedPos (producer);
210
+ size_t pos = mapped_reference_pos_. at (producer);
222
211
223
212
if (mode_ == ComputeAtMode::BestEffort) {
224
213
return std::min (pos, max_pos);
@@ -228,42 +217,28 @@ size_t InlinePropagator::getReplayPosCasP(
228
217
229
218
TORCH_INTERNAL_ASSERT (
230
219
pos <= max_pos,
231
- " Invalid compute at position detected in compute at when trying to replay consumer: " ,
232
- consumer,
233
- " as producer: " ,
220
+ " Invalid compute at position detected in compute at when trying to propagate the CA position from producer: " ,
234
221
producer,
222
+ " to consumer: " ,
223
+ consumer,
235
224
" tried to do this at position: " ,
236
225
pos,
237
226
" but max position that's allowed is " ,
238
227
max_pos);
239
228
return pos;
240
229
}
241
230
242
- void InlinePropagator::recordReplayedPos (TensorView* tv, size_t pos) {
243
- if (selected_.count (tv)) {
244
- auto new_pos = adjustComputeAtPos (tv, pos );
245
- if (pos != new_pos) {
246
- replayed_pos_[tv] = pos;
247
- pos = new_pos ;
231
+ void InlinePropagator::setCAPos (TensorView* tv, size_t pos) {
232
+ if (selected_.count (tv) && !tv-> isFusionInput () ) {
233
+ pos = std::min< size_t >(pos, getMaxPosAll (tv) );
234
+ // hoist inner most broadcast
235
+ while (pos > 0 && tv-> axis ( pos - 1 )-> isBroadcast ()) {
236
+ pos-- ;
248
237
}
249
- if (!tv->isFusionInput ()) {
250
- tv->setComputeAt (pos);
251
- } else {
252
- replayed_pos_[tv] = pos;
253
- }
254
- } else {
255
- replayed_pos_[tv] = pos;
238
+ tv->setComputeAt (pos);
256
239
}
257
240
}
258
241
259
- size_t InlinePropagator::retrieveReplayedPos (TensorView* tv) {
260
- auto it = replayed_pos_.find (tv);
261
- if (it != replayed_pos_.end ()) {
262
- return it->second ;
263
- }
264
- return tv->getComputeAtPosition ();
265
- }
266
-
267
242
InlinePropagator::InlinePropagator (
268
243
std::unordered_set<TensorView*> selected,
269
244
TensorView* reference,
@@ -288,101 +263,62 @@ InlinePropagator::InlinePropagator(
288
263
" ." );
289
264
}
290
265
291
- namespace {
292
-
293
- // Make sure if tv is set to new_td it doesn't violate set compute at and max
294
- // produce at positions.
295
- bool validateDomain (TensorView* tv, TensorDomain* new_td) {
296
- auto first_mismatch =
297
- BestEffortReplay::findFirstMismatchedID (tv->domain (), new_td);
298
- return first_mismatch >= (int )tv->getMaxProducerPosition () &&
299
- first_mismatch >= (int )tv->getComputeAtPosition ();
300
- }
301
-
302
- } // namespace
303
-
304
266
void InlinePropagator::propagateTvPasC (TensorView* from, TensorView* to) {
305
267
if (is_first_) {
306
268
is_first_ = false ;
307
- recordReplayedPos (reference_, reference_pos_);
269
+ setCAPos (reference_, reference_pos_);
270
+ mapped_reference_pos_[reference_] = reference_pos_;
308
271
}
309
- int pos = getReplayPosPasC (to, from);
272
+ int from_pos = getFromPosPasC (to, from);
310
273
auto to_pos =
311
- TransformReplay::getMatchedLeafPosWithoutReplayPasC (to, from, pos);
312
- if (mode_ != ComputeAtMode::MostInlined) {
313
- TORCH_CHECK (
314
- to_pos >= 0 ,
315
- " Unable to propagate CA position from consumer " ,
316
- from,
317
- " to producer " ,
318
- to,
319
- " because this would require replay." );
320
- }
321
- if (to_pos < 0 ) {
322
- auto replay = TransformReplay::replayPasC (to, from, pos);
323
- TORCH_INTERNAL_ASSERT (
324
- validateDomain (to, replay.first ),
325
- " Tried to set the domain of " ,
326
- to,
327
- " to " ,
328
- replay.first ,
329
- " but that would invalidate previously compute at position or max producer position." );
330
- to->setDomain (replay.first );
331
- to_pos = replay.second ;
332
- }
333
- recordReplayedPos (to, to_pos);
274
+ TransformReplay::getMatchedLeafPosWithoutReplayPasC (to, from, from_pos);
275
+ TORCH_CHECK (
276
+ to_pos >= 0 ,
277
+ " Unable to propagate CA position from consumer " ,
278
+ from,
279
+ " to producer " ,
280
+ to,
281
+ " because this would require replay." );
282
+ setCAPos (to, to_pos);
283
+ mapped_reference_pos_[to] = to_pos;
334
284
}
335
285
336
286
void InlinePropagator::propagateTvCasP (TensorView* from, TensorView* to) {
337
287
if (is_first_) {
338
288
is_first_ = false ;
339
- recordReplayedPos (reference_, reference_pos_);
289
+ setCAPos (reference_, reference_pos_);
290
+ mapped_reference_pos_[reference_] = reference_pos_;
340
291
}
341
- int pos = getReplayPosCasP (to, from);
292
+ int from_pos = getFromPosCasP (to, from);
342
293
auto to_pos =
343
- TransformReplay::getMatchedLeafPosWithoutReplayCasP (to, from, pos);
344
- if (mode_ != ComputeAtMode::MostInlined) {
345
- TORCH_CHECK (
346
- to_pos >= 0 ,
347
- " Unable to propagate CA position from producer " ,
348
- from,
349
- " to consumer " ,
350
- to,
351
- " because this would require replay." );
352
- }
353
- if (to_pos < 0 ) {
354
- auto replay = TransformReplay::replayCasP (to, from, pos);
355
- TORCH_INTERNAL_ASSERT (
356
- validateDomain (to, replay.first ),
357
- " Tried to set the domain of " ,
358
- to,
359
- " to " ,
360
- replay.first ,
361
- " but that would invalidate previously compute at position or max producer position." );
362
- to->setDomain (replay.first );
363
- to_pos = replay.second ;
364
- }
365
- recordReplayedPos (to, to_pos);
294
+ TransformReplay::getMatchedLeafPosWithoutReplayCasP (to, from, from_pos);
295
+ TORCH_CHECK (
296
+ to_pos >= 0 ,
297
+ " Unable to propagate CA position from producer " ,
298
+ from,
299
+ " to consumer " ,
300
+ to,
301
+ " because this would require replay." );
302
+ setCAPos (to, to_pos);
303
+ mapped_reference_pos_[to] = to_pos;
366
304
}
367
305
368
306
void InlinePropagator::propagateTvSibling (TensorView* from, TensorView* to) {
369
307
if (is_first_) {
370
308
is_first_ = false ;
371
- recordReplayedPos (reference_, reference_pos_);
372
- }
373
- auto from_pos = retrieveReplayedPos (from);
374
- if (!TransformReplay::fullSelfMatching (to, from)) {
375
- auto replay = TransformReplay::fullSelfReplay (to->domain (), from->domain ());
376
- TORCH_INTERNAL_ASSERT (
377
- validateDomain (to, replay),
378
- " Tried to set the domain of " ,
379
- to,
380
- " to " ,
381
- replay,
382
- " but that would invalidate previously compute at position or max producer position." );
383
- to->setDomain (replay);
309
+ setCAPos (reference_, reference_pos_);
310
+ mapped_reference_pos_[reference_] = reference_pos_;
384
311
}
385
- recordReplayedPos (to, from_pos);
312
+ auto from_pos = mapped_reference_pos_.at (from);
313
+ TORCH_CHECK (
314
+ TransformReplay::fullSelfMatching (to, from),
315
+ " Unable to propagate CA position from " ,
316
+ from,
317
+ " to sibling " ,
318
+ to,
319
+ " because this would require replay." );
320
+ setCAPos (to, from_pos);
321
+ mapped_reference_pos_[to] = from_pos;
386
322
}
387
323
388
324
namespace {
0 commit comments