@@ -10,11 +10,11 @@ namespace jit {
10
10
namespace fuser {
11
11
namespace cuda {
12
12
13
- bool InlinePropagatorSelector::allowPasC (TensorView* from, TensorView* to) {
13
+ bool InlinePropagatorSelector::allowC2P (TensorView* from, TensorView* to) {
14
14
return selected_.count (to) > 0 ;
15
15
}
16
16
17
- bool InlinePropagatorSelector::allowCasP (TensorView* from, TensorView* to) {
17
+ bool InlinePropagatorSelector::allowP2C (TensorView* from, TensorView* to) {
18
18
// If the producer is in the selected set, then the consumer must also be
19
19
// replayed to obtain a compatible loop structure so that this producer
20
20
// can be consumed in this loop.
@@ -112,9 +112,9 @@ size_t MaxPosCalculator::getMaxPosSelf(
112
112
// Unrolled dimensions in producer or consumer
113
113
// Dimensions derived from root dimensions that exist in both but are
114
114
// unmappable
115
- size_t MaxPosCalculator::getMaxPosPasC (
116
- TensorView* producer ,
117
- TensorView* consumer ) const {
115
+ size_t MaxPosCalculator::getMaxPosC2P (
116
+ TensorView* consumer ,
117
+ TensorView* producer ) const {
118
118
// Limit max position based on vectorized dims in consumer.
119
119
auto max_consumer_pos = getMaxPosSelf (consumer, true , false , true );
120
120
@@ -144,9 +144,9 @@ size_t MaxPosCalculator::getMaxPosPasC(
144
144
// Unrolled dimensions in producer or consumer
145
145
// Dimensions derived from root dimensions that exist in both but are
146
146
// unmappable
147
- size_t MaxPosCalculator::getMaxPosCasP (
148
- TensorView* consumer ,
149
- TensorView* producer ) const {
147
+ size_t MaxPosCalculator::getMaxPosP2C (
148
+ TensorView* producer ,
149
+ TensorView* consumer ) const {
150
150
auto max_producer_pos = getMaxPosSelf (producer, false , false , false );
151
151
152
152
auto pairwise_root_map = PairwiseRootDomainMap (producer, consumer);
@@ -173,16 +173,14 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
173
173
for (auto consumer_tv : ir_utils::consumerTvsOf (tv)) {
174
174
// consumers are always replayed consistently
175
175
max_pos =
176
- std::min<size_t >(max_pos, max_pos_calc.getMaxPosCasP (consumer_tv, tv ));
176
+ std::min<size_t >(max_pos, max_pos_calc.getMaxPosP2C (tv, consumer_tv ));
177
177
}
178
178
return max_pos;
179
179
}
180
180
181
- size_t InlinePropagator::getFromPosPasC (
182
- TensorView* producer,
183
- TensorView* consumer) {
184
- size_t max_pos = max_pos_calc.getMaxPosPasC (producer, consumer);
185
- size_t pos = mapped_reference_pos_.at (consumer);
181
+ size_t InlinePropagator::getFromPosC2P (TensorView* from, TensorView* to) {
182
+ size_t max_pos = max_pos_calc.getMaxPosC2P (from, to);
183
+ size_t pos = mapped_reference_pos_.at (from);
186
184
187
185
if (mode_ == ComputeAtMode::BestEffort) {
188
186
return std::min (pos, max_pos);
@@ -193,21 +191,19 @@ size_t InlinePropagator::getFromPosPasC(
193
191
TORCH_INTERNAL_ASSERT (
194
192
pos <= max_pos,
195
193
" Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: " ,
196
- consumer ,
194
+ from ,
197
195
" to producer: " ,
198
- producer ,
196
+ to ,
199
197
" tried to do this at position: " ,
200
198
pos,
201
199
" but max position that's allowed is " ,
202
200
max_pos);
203
201
return pos;
204
202
}
205
203
206
- size_t InlinePropagator::getFromPosCasP (
207
- TensorView* consumer,
208
- TensorView* producer) {
209
- size_t max_pos = max_pos_calc.getMaxPosCasP (consumer, producer);
210
- size_t pos = mapped_reference_pos_.at (producer);
204
+ size_t InlinePropagator::getFromPosP2C (TensorView* from, TensorView* to) {
205
+ size_t max_pos = max_pos_calc.getMaxPosP2C (from, to);
206
+ size_t pos = mapped_reference_pos_.at (from);
211
207
212
208
if (mode_ == ComputeAtMode::BestEffort) {
213
209
return std::min (pos, max_pos);
@@ -218,9 +214,9 @@ size_t InlinePropagator::getFromPosCasP(
218
214
TORCH_INTERNAL_ASSERT (
219
215
pos <= max_pos,
220
216
" Invalid compute at position detected in compute at when trying to propagate the CA position from producer: " ,
221
- producer ,
217
+ from ,
222
218
" to consumer: " ,
223
- consumer ,
219
+ to ,
224
220
" tried to do this at position: " ,
225
221
pos,
226
222
" but max position that's allowed is " ,
@@ -263,13 +259,13 @@ InlinePropagator::InlinePropagator(
263
259
" ." );
264
260
}
265
261
266
- void InlinePropagator::propagateTvPasC (TensorView* from, TensorView* to) {
262
+ void InlinePropagator::propagateC2P (TensorView* from, TensorView* to) {
267
263
if (is_first_) {
268
264
is_first_ = false ;
269
265
setCAPos (reference_, reference_pos_);
270
266
mapped_reference_pos_[reference_] = reference_pos_;
271
267
}
272
- int from_pos = getFromPosPasC (to, from );
268
+ int from_pos = getFromPosC2P (from, to );
273
269
auto to_pos =
274
270
TransformReplay::getMatchedLeafPosWithoutReplayPasC (to, from, from_pos);
275
271
TORCH_CHECK (
@@ -283,13 +279,13 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
283
279
mapped_reference_pos_[to] = to_pos;
284
280
}
285
281
286
- void InlinePropagator::propagateTvCasP (TensorView* from, TensorView* to) {
282
+ void InlinePropagator::propagateP2C (TensorView* from, TensorView* to) {
287
283
if (is_first_) {
288
284
is_first_ = false ;
289
285
setCAPos (reference_, reference_pos_);
290
286
mapped_reference_pos_[reference_] = reference_pos_;
291
287
}
292
- int from_pos = getFromPosCasP (to, from );
288
+ int from_pos = getFromPosP2C (from, to );
293
289
auto to_pos =
294
290
TransformReplay::getMatchedLeafPosWithoutReplayCasP (to, from, from_pos);
295
291
TORCH_CHECK (
@@ -303,7 +299,7 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
303
299
mapped_reference_pos_[to] = to_pos;
304
300
}
305
301
306
- void InlinePropagator::propagateTvSibling (TensorView* from, TensorView* to) {
302
+ void InlinePropagator::propagateSibling (TensorView* from, TensorView* to) {
307
303
if (is_first_) {
308
304
is_first_ = false ;
309
305
setCAPos (reference_, reference_pos_);
@@ -388,11 +384,11 @@ void MaxProducerPosUpdater::handle(TensorView* consumer) {
388
384
consumer->setMaxProducer (consumer_pos);
389
385
}
390
386
391
- void MaxProducerPosUpdater::propagateTvPasC (TensorView* from, TensorView* to) {
387
+ void MaxProducerPosUpdater::propagateC2P (TensorView* from, TensorView* to) {
392
388
if (updated_.empty ()) {
393
389
// handle the reference tensor
394
390
updated_.insert (nullptr );
395
- propagateTvPasC (nullptr , from);
391
+ propagateC2P (nullptr , from);
396
392
}
397
393
for (auto consumer_tv : ir_utils::consumerTvsOf (to)) {
398
394
if (updated_.count (consumer_tv) > 0 ) {
@@ -403,14 +399,12 @@ void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) {
403
399
}
404
400
}
405
401
406
- void MaxProducerPosUpdater::propagateTvCasP (TensorView* from, TensorView* to) {
407
- propagateTvPasC (from, to);
402
+ void MaxProducerPosUpdater::propagateP2C (TensorView* from, TensorView* to) {
403
+ propagateC2P (from, to);
408
404
}
409
405
410
- void MaxProducerPosUpdater::propagateTvSibling (
411
- TensorView* from,
412
- TensorView* to) {
413
- propagateTvPasC (from, to);
406
+ void MaxProducerPosUpdater::propagateSibling (TensorView* from, TensorView* to) {
407
+ propagateC2P (from, to);
414
408
}
415
409
416
410
} // namespace cuda
0 commit comments