@@ -104,129 +104,56 @@ size_t MaxPosCalculator::getMaxPosSelf(
104
104
return std::distance (dom.begin (), iter);
105
105
}
106
106
107
- // Return the max position in consumer that producer can be inlined to
108
- // Cannot inline:
109
- // Reduction dimensions in producer
110
- // Block broadcast dimensions in producer
111
- // Vectorized dimensions in producer or consumer
112
- // Unrolled dimensions in producer or consumer
113
- // Dimensions derived from root dimensions that exist in both but are
114
- // unmappable
115
- size_t MaxPosCalculator::getMaxPosC2P (
116
- TensorView* consumer,
117
- TensorView* producer) const {
118
- // Limit max position based on vectorized dims in consumer.
119
- auto max_consumer_pos = getMaxPosSelf (consumer, true , false , true );
120
-
121
- auto pairwise_root_map = PairwiseRootDomainMap (producer, consumer);
122
- auto replay_PasC =
123
- BestEffortReplay::replayPasC (producer, consumer, -1 , pairwise_root_map);
124
- auto c2p_replay_map = replay_PasC.getReplay ();
125
-
126
- for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0 ;
127
- consumer_pos--) {
128
- auto map_it = c2p_replay_map.find (consumer->axis ((int )consumer_pos - 1 ));
129
- if (map_it != c2p_replay_map.end ()) {
130
- auto p_id = map_it->second ;
131
- if (!isAllowedID (p_id, producer, true , false , false )) {
132
- max_consumer_pos = consumer_pos - 1 ;
133
- }
134
- }
135
- }
136
-
137
- return max_consumer_pos;
138
- }
139
-
140
107
// Return the max position in producer that can be inlined to consumer
141
108
// Cannot inline:
142
- // Reduction dimensions in producer
143
- // Vectorized dimensions in producer or consumer
144
- // Unrolled dimensions in producer or consumer
145
- // Dimensions derived from root dimensions that exist in both but are
146
- // unmappable
147
- size_t MaxPosCalculator::getMaxPosP2C (
109
+ // Vectorized dimensions in consumer
110
+ // Unrolled dimensions in consumer
111
+ size_t MaxPosCalculator::getMaxProducerPosFromConsumer (
148
112
TensorView* producer,
149
113
TensorView* consumer) const {
150
- auto max_producer_pos = getMaxPosSelf (producer, false , false , false );
151
-
152
114
// Build the producer-axis -> consumer-ID map: replay the consumer as the
// producer (BestEffortReplay::replayCasP) over the pairwise root-domain map
// and take the result of getReplay().
auto pairwise_root_map = PairwiseRootDomainMap (producer, consumer);
153
115
auto replay_CasP =
154
116
BestEffortReplay::replayCasP (consumer, producer, -1 , pairwise_root_map);
155
117
auto p2c_replay_map = replay_CasP.getReplay ();
156
118
157
// Old side (removed): walked positions downward starting from
// getMaxPosSelf(producer, ...) and lowered the limit at every rejected ID.
- for (size_t producer_pos = max_producer_pos ; producer_pos > 0 ;
158
- producer_pos-- ) {
159
- auto map_it = p2c_replay_map.find (producer->axis (( int ) producer_pos - 1 ));
119
// New side: walk producer positions upward from 0; the first position whose
// mapped consumer ID is rejected by isAllowedID is returned as the limit.
// NOTE(review): producer_pos is size_t and is passed to axis() without the
// (int) cast the removed line used -- confirm axis()'s parameter type.
+ for (size_t producer_pos = 0 ; producer_pos < producer-> nDims () ;
120
+ producer_pos++ ) {
121
+ auto map_it = p2c_replay_map.find (producer->axis (producer_pos));
160
122
// Producer axes with no entry in the replay map are skipped, not rejected.
if (map_it != p2c_replay_map.end ()) {
161
123
auto c_id = map_it->second ;
162
124
if (!isAllowedID (c_id, consumer, true , false , true )) {
163
- max_producer_pos = producer_pos - 1 ;
125
+ return producer_pos;
164
126
}
165
127
}
166
128
}
167
-
168
- return max_producer_pos;
129
// No mapped consumer ID was rejected, so the limit is the producer's nDims().
+ return producer->nDims ();
169
130
}
170
131
171
132
// Returns the smallest of getMaxPosSelf(tv, ...) and the per-consumer limit
// computed against every tensor returned by ir_utils::consumerTvsOf(tv).
size_t InlinePropagator::getMaxPosAll (TensorView* tv) {
172
133
// Start from the limit imposed by tv on its own.
auto max_pos = max_pos_calc.getMaxPosSelf (tv, false , false , false );
173
134
for (auto consumer_tv : ir_utils::consumerTvsOf (tv)) {
174
- // consumers are always replayed consistently
175
- max_pos =
176
- std::min<size_t >(max_pos, max_pos_calc.getMaxPosP2C (tv, consumer_tv));
135
// New side: same min-reduction, calling getMaxProducerPosFromConsumer in
// place of the removed getMaxPosP2C call.
+ max_pos = std::min<size_t >(
136
+ max_pos, max_pos_calc.getMaxProducerPosFromConsumer (tv, consumer_tv));
177
137
}
178
138
return max_pos;
179
139
}
180
140
181
- size_t InlinePropagator::getFromPosC2P (TensorView* from, TensorView* to) {
182
- size_t max_pos = max_pos_calc.getMaxPosC2P (from, to);
183
- size_t pos = mapped_reference_pos_.at (from);
184
-
185
- if (mode_ == ComputeAtMode::BestEffort) {
186
- return std::min (pos, max_pos);
187
- } else if (mode_ == ComputeAtMode::MostInlined) {
188
- return max_pos;
189
- }
190
-
191
- TORCH_INTERNAL_ASSERT (
192
- pos <= max_pos,
193
- " Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: " ,
194
- from,
195
- " to producer: " ,
196
- to,
197
- " tried to do this at position: " ,
198
- pos,
199
- " but max position that's allowed is " ,
200
- max_pos);
201
- return pos;
202
- }
203
-
204
- size_t InlinePropagator::getFromPosP2C (TensorView* from, TensorView* to) {
205
- size_t max_pos = max_pos_calc.getMaxPosP2C (from, to);
206
- size_t pos = mapped_reference_pos_.at (from);
207
-
208
- if (mode_ == ComputeAtMode::BestEffort) {
209
- return std::min (pos, max_pos);
210
- } else if (mode_ == ComputeAtMode::MostInlined) {
211
- return max_pos;
212
- }
213
-
214
- TORCH_INTERNAL_ASSERT (
215
- pos <= max_pos,
216
- " Invalid compute at position detected in compute at when trying to propagate the CA position from producer: " ,
217
- from,
218
- " to consumer: " ,
219
- to,
220
- " tried to do this at position: " ,
221
- pos,
222
- " but max position that's allowed is " ,
223
- max_pos);
224
- return pos;
225
- }
226
-
227
- void InlinePropagator::setCAPos (TensorView* tv, size_t pos) {
141
+ void InlinePropagator::setCAPos (TensorView* tv) {
142
+ size_t pos = mapped_reference_pos_.at (tv);
228
143
if (selected_.count (tv) && !tv->isFusionInput ()) {
229
- pos = std::min<size_t >(pos, getMaxPosAll (tv));
144
+ auto max_pos = getMaxPosAll (tv);
145
+ if (mode_ == ComputeAtMode::Standard) {
146
+ TORCH_INTERNAL_ASSERT (
147
+ pos <= max_pos,
148
+ " Invalid compute at position detected in InlinePropagator when trying to set the CA position of: " ,
149
+ tv,
150
+ " to " ,
151
+ pos,
152
+ " , max position that's allowed is " ,
153
+ max_pos);
154
+ } else {
155
+ pos = std::min<size_t >(pos, max_pos);
156
+ }
230
157
// hoist inner most broadcast
231
158
while (pos > 0 && tv->axis (pos - 1 )->isBroadcast ()) {
232
159
pos--;
@@ -262,10 +189,16 @@ InlinePropagator::InlinePropagator(
262
189
// Propagates the recorded position from `from` to `to`; the TORCH_CHECK
// message below refers to `to` as the producer.
void InlinePropagator::propagateC2P (TensorView* from, TensorView* to) {
263
190
if (is_first_) {
264
191
is_first_ = false ;
265
- setCAPos (reference_, reference_pos_);
266
192
// First call only: record the reference position. On the new side the
// setCAPos(reference_) call follows this write; the one-argument setCAPos
// elsewhere in this diff reads mapped_reference_pos_.at(tv), which is
// presumably why the call moved below the write.
mapped_reference_pos_[reference_] = reference_pos_;
193
+ setCAPos (reference_);
194
+ }
195
+ // Step 1: find mapped_reference_pos_[to]
196
// In MostInlined mode the starting position is from->nDims(); otherwise it
// is the position previously recorded for `from`.
// NOTE(review): from_pos is declared int, while nDims() is compared against
// size_t elsewhere in this diff -- confirm the narrowing is intended.
+ int from_pos;
197
+ if (mode_ != ComputeAtMode::MostInlined) {
198
+ from_pos = mapped_reference_pos_.at (from);
199
+ } else {
200
+ from_pos = from->nDims ();
267
201
}
268
- int from_pos = getFromPosC2P (from, to);
269
202
// to_pos is whatever getMatchedLeafPosWithoutReplayPasC reports for from_pos.
// The TORCH_CHECK condition itself is elided by the hunk boundary below.
auto to_pos =
270
203
TransformReplay::getMatchedLeafPosWithoutReplayPasC (to, from, from_pos);
271
204
TORCH_CHECK (
@@ -275,17 +208,24 @@ void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
275
208
" to producer " ,
276
209
to,
277
210
" because this would require replay." );
278
- setCAPos (to, to_pos);
279
211
// New side: the position is recorded first, then setCAPos(to) is called
// without an explicit position argument.
mapped_reference_pos_[to] = to_pos;
212
+ // Step 2: set CA position of `to`
213
+ setCAPos (to);
280
214
}
281
215
282
216
// Propagates the recorded position from `from` to `to`; the TORCH_CHECK
// message below refers to `to` as the consumer.
void InlinePropagator::propagateP2C (TensorView* from, TensorView* to) {
283
217
if (is_first_) {
284
218
is_first_ = false ;
285
- setCAPos (reference_, reference_pos_);
286
219
// First call only: record the reference position. On the new side the
// setCAPos(reference_) call follows this write; the one-argument setCAPos
// elsewhere in this diff reads mapped_reference_pos_.at(tv), which is
// presumably why the call moved below the write.
mapped_reference_pos_[reference_] = reference_pos_;
220
+ setCAPos (reference_);
221
+ }
222
+ // Step 1: find mapped_reference_pos_[to]
223
// In MostInlined mode the starting position is from->nDims(); otherwise it
// is the position previously recorded for `from`.
// NOTE(review): from_pos is declared int, while nDims() is compared against
// size_t elsewhere in this diff -- confirm the narrowing is intended.
+ int from_pos;
224
+ if (mode_ != ComputeAtMode::MostInlined) {
225
+ from_pos = mapped_reference_pos_.at (from);
226
+ } else {
227
+ from_pos = from->nDims ();
287
228
}
288
- int from_pos = getFromPosP2C (from, to);
289
229
// to_pos is whatever getMatchedLeafPosWithoutReplayCasP reports for from_pos.
// The TORCH_CHECK condition itself is elided by the hunk boundary below.
auto to_pos =
290
230
TransformReplay::getMatchedLeafPosWithoutReplayCasP (to, from, from_pos);
291
231
TORCH_CHECK (
@@ -295,16 +235,18 @@ void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
295
235
" to consumer " ,
296
236
to,
297
237
" because this would require replay." );
298
- setCAPos (to, to_pos);
299
238
// New side: the position is recorded first, then setCAPos(to) is called
// without an explicit position argument.
mapped_reference_pos_[to] = to_pos;
239
+ // Step 2: set CA position of `to`
240
+ setCAPos (to);
300
241
}
301
242
302
243
// Copies the position recorded for `from` onto `to`, after checking
// TransformReplay::fullSelfMatching(to, from); the TORCH_CHECK message below
// refers to `to` as the sibling.
void InlinePropagator::propagateSibling (TensorView* from, TensorView* to) {
303
244
if (is_first_) {
304
245
is_first_ = false ;
305
- setCAPos (reference_, reference_pos_);
306
246
// First call only: record the reference position. On the new side the
// setCAPos(reference_) call follows this write; the one-argument setCAPos
// elsewhere in this diff reads mapped_reference_pos_.at(tv), which is
// presumably why the call moved below the write.
mapped_reference_pos_[reference_] = reference_pos_;
247
+ setCAPos (reference_);
307
248
}
249
+ // Step 1: find mapped_reference_pos_[to]
308
250
// Unlike the C2P/P2C hunks in this diff, no MostInlined special case is
// visible here: the position recorded for `from` is always used.
auto from_pos = mapped_reference_pos_.at (from);
309
251
TORCH_CHECK (
310
252
TransformReplay::fullSelfMatching (to, from),
@@ -313,8 +255,9 @@ void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
313
255
" to sibling " ,
314
256
to,
315
257
" because this would require replay." );
316
- setCAPos (to, from_pos);
317
258
// New side: the position is recorded first, then setCAPos(to) is called
// without an explicit position argument.
mapped_reference_pos_[to] = from_pos;
259
+ // Step 2: set CA position of `to`
260
+ setCAPos (to);
318
261
}
319
262
320
263
namespace {
0 commit comments