Skip to content

Commit ef04f6c

Browse files
authored
Coding style cleanups (csarofeen#1798)
Per offline discussion with @csarofeen, this PR does many renaming for better coding style: For all propagation-related things, I am now using the names `P2C` and `C2P` instead of `CasP` and `PasC`. Because "A as B" somewhat implies we want to replay A the same as B, but "B to A" sounds more general and is a better word for this case. Also, I modified the order of function arguments to match the order in its name. For example `PasC` should have `(producer, consumer)` or `(to, from)`, but not `(consumer, producer)` or `(from, to)`, and `C2P` should have `(consumer, producer)` or `(from, to)`, but not `(producer, consumer)` or `(to, from)`.
1 parent 38c7f3c commit ef04f6c

File tree

7 files changed

+97
-103
lines changed

7 files changed

+97
-103
lines changed

torch/csrc/jit/codegen/cuda/inline_propagator.cpp

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ namespace jit {
1010
namespace fuser {
1111
namespace cuda {
1212

13-
bool InlinePropagatorSelector::allowPasC(TensorView* from, TensorView* to) {
13+
bool InlinePropagatorSelector::allowC2P(TensorView* from, TensorView* to) {
1414
return selected_.count(to) > 0;
1515
}
1616

17-
bool InlinePropagatorSelector::allowCasP(TensorView* from, TensorView* to) {
17+
bool InlinePropagatorSelector::allowP2C(TensorView* from, TensorView* to) {
1818
// If the producer is in the selected set, then the consumer must also be
1919
// replayed to obtain a compatible loop structure so that this producer
2020
// can be consumed in this loop.
@@ -112,9 +112,9 @@ size_t MaxPosCalculator::getMaxPosSelf(
112112
// Unrolled dimensions in producer or consumer
113113
// Dimensions derived from root dimensions that exist in both but are
114114
// unmappable
115-
size_t MaxPosCalculator::getMaxPosPasC(
116-
TensorView* producer,
117-
TensorView* consumer) const {
115+
size_t MaxPosCalculator::getMaxPosC2P(
116+
TensorView* consumer,
117+
TensorView* producer) const {
118118
// Limit max position based on vectorized dims in consumer.
119119
auto max_consumer_pos = getMaxPosSelf(consumer, true, false, true);
120120

@@ -144,9 +144,9 @@ size_t MaxPosCalculator::getMaxPosPasC(
144144
// Unrolled dimensions in producer or consumer
145145
// Dimensions derived from root dimensions that exist in both but are
146146
// unmappable
147-
size_t MaxPosCalculator::getMaxPosCasP(
148-
TensorView* consumer,
149-
TensorView* producer) const {
147+
size_t MaxPosCalculator::getMaxPosP2C(
148+
TensorView* producer,
149+
TensorView* consumer) const {
150150
auto max_producer_pos = getMaxPosSelf(producer, false, false, false);
151151

152152
auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer);
@@ -173,16 +173,14 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv) {
173173
for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
174174
// consumers are always replayed consistently
175175
max_pos =
176-
std::min<size_t>(max_pos, max_pos_calc.getMaxPosCasP(consumer_tv, tv));
176+
std::min<size_t>(max_pos, max_pos_calc.getMaxPosP2C(tv, consumer_tv));
177177
}
178178
return max_pos;
179179
}
180180

181-
size_t InlinePropagator::getFromPosPasC(
182-
TensorView* producer,
183-
TensorView* consumer) {
184-
size_t max_pos = max_pos_calc.getMaxPosPasC(producer, consumer);
185-
size_t pos = mapped_reference_pos_.at(consumer);
181+
size_t InlinePropagator::getFromPosC2P(TensorView* from, TensorView* to) {
182+
size_t max_pos = max_pos_calc.getMaxPosC2P(from, to);
183+
size_t pos = mapped_reference_pos_.at(from);
186184

187185
if (mode_ == ComputeAtMode::BestEffort) {
188186
return std::min(pos, max_pos);
@@ -193,21 +191,19 @@ size_t InlinePropagator::getFromPosPasC(
193191
TORCH_INTERNAL_ASSERT(
194192
pos <= max_pos,
195193
"Invalid compute at position detected in compute at when trying to propagate the CA position from consumer: ",
196-
consumer,
194+
from,
197195
" to producer: ",
198-
producer,
196+
to,
199197
" tried to do this at position: ",
200198
pos,
201199
" but max position that's allowed is ",
202200
max_pos);
203201
return pos;
204202
}
205203

206-
size_t InlinePropagator::getFromPosCasP(
207-
TensorView* consumer,
208-
TensorView* producer) {
209-
size_t max_pos = max_pos_calc.getMaxPosCasP(consumer, producer);
210-
size_t pos = mapped_reference_pos_.at(producer);
204+
size_t InlinePropagator::getFromPosP2C(TensorView* from, TensorView* to) {
205+
size_t max_pos = max_pos_calc.getMaxPosP2C(from, to);
206+
size_t pos = mapped_reference_pos_.at(from);
211207

212208
if (mode_ == ComputeAtMode::BestEffort) {
213209
return std::min(pos, max_pos);
@@ -218,9 +214,9 @@ size_t InlinePropagator::getFromPosCasP(
218214
TORCH_INTERNAL_ASSERT(
219215
pos <= max_pos,
220216
"Invalid compute at position detected in compute at when trying to propagate the CA position from producer: ",
221-
producer,
217+
from,
222218
" to consumer: ",
223-
consumer,
219+
to,
224220
" tried to do this at position: ",
225221
pos,
226222
" but max position that's allowed is ",
@@ -263,13 +259,13 @@ InlinePropagator::InlinePropagator(
263259
".");
264260
}
265261

266-
void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
262+
void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
267263
if (is_first_) {
268264
is_first_ = false;
269265
setCAPos(reference_, reference_pos_);
270266
mapped_reference_pos_[reference_] = reference_pos_;
271267
}
272-
int from_pos = getFromPosPasC(to, from);
268+
int from_pos = getFromPosC2P(from, to);
273269
auto to_pos =
274270
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
275271
TORCH_CHECK(
@@ -283,13 +279,13 @@ void InlinePropagator::propagateTvPasC(TensorView* from, TensorView* to) {
283279
mapped_reference_pos_[to] = to_pos;
284280
}
285281

286-
void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
282+
void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
287283
if (is_first_) {
288284
is_first_ = false;
289285
setCAPos(reference_, reference_pos_);
290286
mapped_reference_pos_[reference_] = reference_pos_;
291287
}
292-
int from_pos = getFromPosCasP(to, from);
288+
int from_pos = getFromPosP2C(from, to);
293289
auto to_pos =
294290
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
295291
TORCH_CHECK(
@@ -303,7 +299,7 @@ void InlinePropagator::propagateTvCasP(TensorView* from, TensorView* to) {
303299
mapped_reference_pos_[to] = to_pos;
304300
}
305301

306-
void InlinePropagator::propagateTvSibling(TensorView* from, TensorView* to) {
302+
void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
307303
if (is_first_) {
308304
is_first_ = false;
309305
setCAPos(reference_, reference_pos_);
@@ -388,11 +384,11 @@ void MaxProducerPosUpdater::handle(TensorView* consumer) {
388384
consumer->setMaxProducer(consumer_pos);
389385
}
390386

391-
void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) {
387+
void MaxProducerPosUpdater::propagateC2P(TensorView* from, TensorView* to) {
392388
if (updated_.empty()) {
393389
// handle the reference tensor
394390
updated_.insert(nullptr);
395-
propagateTvPasC(nullptr, from);
391+
propagateC2P(nullptr, from);
396392
}
397393
for (auto consumer_tv : ir_utils::consumerTvsOf(to)) {
398394
if (updated_.count(consumer_tv) > 0) {
@@ -403,14 +399,12 @@ void MaxProducerPosUpdater::propagateTvPasC(TensorView* from, TensorView* to) {
403399
}
404400
}
405401

406-
void MaxProducerPosUpdater::propagateTvCasP(TensorView* from, TensorView* to) {
407-
propagateTvPasC(from, to);
402+
void MaxProducerPosUpdater::propagateP2C(TensorView* from, TensorView* to) {
403+
propagateC2P(from, to);
408404
}
409405

410-
void MaxProducerPosUpdater::propagateTvSibling(
411-
TensorView* from,
412-
TensorView* to) {
413-
propagateTvPasC(from, to);
406+
void MaxProducerPosUpdater::propagateSibling(TensorView* from, TensorView* to) {
407+
propagateC2P(from, to);
414408
}
415409

416410
} // namespace cuda

torch/csrc/jit/codegen/cuda/inline_propagator.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector {
1818
std::unordered_set<TensorView*> selected_;
1919

2020
public:
21-
virtual bool allowPasC(TensorView* from, TensorView* to) override;
22-
virtual bool allowCasP(TensorView* from, TensorView* to) override;
21+
virtual bool allowC2P(TensorView* from, TensorView* to) override;
22+
virtual bool allowP2C(TensorView* from, TensorView* to) override;
2323
virtual bool allowSibling(TensorView* from, TensorView* to) override;
2424

2525
InlinePropagatorSelector(std::unordered_set<TensorView*> selected)
@@ -60,11 +60,11 @@ class MaxPosCalculator {
6060

6161
// Returns the maximum position producer can be inlined based on consumer
6262
// given the set ComputeAtMode
63-
size_t getMaxPosPasC(TensorView* producer, TensorView* consumer) const;
63+
size_t getMaxPosC2P(TensorView* from, TensorView* to) const;
6464

6565
// Returns the maximum position consumer can be inlined based on producer
6666
// given the set ComputeAtMode
67-
size_t getMaxPosCasP(TensorView* consumer, TensorView* producer) const;
67+
size_t getMaxPosP2C(TensorView* from, TensorView* to) const;
6868

6969
MaxPosCalculator(ComputeAtMode mode);
7070
};
@@ -76,13 +76,13 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
7676

7777
// Returns the inline position in consumer that producer should be inlined as
7878
// based on consumer, taking into consideration the max possible returned by
79-
// getMaxPos{PasC, CasP}, the compute at mode type.
80-
size_t getFromPosPasC(TensorView* producer, TensorView* consumer);
79+
// getMaxPos{P2C, C2P}, the compute at mode type.
80+
size_t getFromPosC2P(TensorView* from, TensorView* to);
8181

8282
// Returns the inline position in producer that consumer should be inlined as
8383
// based on producer, taking into consideration the max possible returned by
84-
// getMaxPos{PasC, CasP}, the compute at mode type.
85-
size_t getFromPosCasP(TensorView* consumer, TensorView* producer);
84+
// getMaxPos{P2C, C2P}, the compute at mode type.
85+
size_t getFromPosP2C(TensorView* from, TensorView* to);
8686

8787
// We use mapped_reference_pos_ to keep track of the outer axes information of
8888
// the reference tensor. That is, mapped_reference_pos_[tv] answers the
@@ -115,9 +115,9 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
115115

116116
// Actually propagate the transformations for the inlining pass. Uses the
117117
// functions above to figure out what position to do the propagation at.
118-
virtual void propagateTvPasC(TensorView* from, TensorView* to) override;
119-
virtual void propagateTvCasP(TensorView* from, TensorView* to) override;
120-
virtual void propagateTvSibling(TensorView* from, TensorView* to) override;
118+
virtual void propagateC2P(TensorView* from, TensorView* to) override;
119+
virtual void propagateP2C(TensorView* from, TensorView* to) override;
120+
virtual void propagateSibling(TensorView* from, TensorView* to) override;
121121
};
122122

123123
// This is actually not a propagation, it only sets the max producer position of
@@ -129,9 +129,9 @@ class MaxProducerPosUpdater : public MaxInfoSpanningTree::Propagator {
129129
void handle(TensorView* tv);
130130

131131
public:
132-
virtual void propagateTvPasC(TensorView* from, TensorView* to) override;
133-
virtual void propagateTvCasP(TensorView* from, TensorView* to) override;
134-
virtual void propagateTvSibling(TensorView* from, TensorView* to) override;
132+
virtual void propagateC2P(TensorView* from, TensorView* to) override;
133+
virtual void propagateP2C(TensorView* from, TensorView* to) override;
134+
virtual void propagateSibling(TensorView* from, TensorView* to) override;
135135
};
136136

137137
} // namespace cuda

torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,18 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
6666
}
6767
};
6868

69-
auto allowPasC = [this](TensorView* from, TensorView* to) {
69+
auto allowC2P = [this](TensorView* from, TensorView* to) {
7070
if (selector_ == nullptr) {
7171
return true;
7272
}
73-
return selector_->allowPasC(from, to);
73+
return selector_->allowC2P(from, to);
7474
};
7575

76-
auto allowCasP = [this](TensorView* from, TensorView* to) {
76+
auto allowP2C = [this](TensorView* from, TensorView* to) {
7777
if (selector_ == nullptr) {
7878
return true;
7979
}
80-
return selector_->allowCasP(from, to);
80+
return selector_->allowP2C(from, to);
8181
};
8282

8383
auto allowSibling = [this](TensorView* from, TensorView* to) {
@@ -114,7 +114,7 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
114114
}
115115

116116
for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) {
117-
if (replayed.count(consumer_tv) || !allowCasP(next_hop.to, consumer_tv)) {
117+
if (replayed.count(consumer_tv) || !allowP2C(next_hop.to, consumer_tv)) {
118118
continue;
119119
}
120120
insertNextHop(
@@ -128,7 +128,7 @@ void MaxInfoSpanningTree::compute_spanning_tree() {
128128
}
129129

130130
for (auto producer_tv : ir_utils::producerTvsOf(next_hop.to)) {
131-
if (replayed.count(producer_tv) || !allowPasC(next_hop.to, producer_tv)) {
131+
if (replayed.count(producer_tv) || !allowC2P(next_hop.to, producer_tv)) {
132132
continue;
133133
}
134134
insertNextHop(
@@ -150,13 +150,13 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) {
150150
for (const auto& next_hop : path_) {
151151
switch (next_hop.type) {
152152
case NextHopType::SIBLING:
153-
propagator->propagateTvSibling(next_hop.from, next_hop.to);
153+
propagator->propagateSibling(next_hop.from, next_hop.to);
154154
break;
155155
case NextHopType::C_AS_P:
156-
propagator->propagateTvCasP(next_hop.from, next_hop.to);
156+
propagator->propagateP2C(next_hop.from, next_hop.to);
157157
break;
158158
case NextHopType::P_AS_C:
159-
propagator->propagateTvPasC(next_hop.from, next_hop.to);
159+
propagator->propagateC2P(next_hop.from, next_hop.to);
160160
break;
161161
}
162162
}
@@ -416,20 +416,20 @@ std::shared_ptr<MaxInfoSpanningTree::Information> MaxRootDomainInfoSpanningTree:
416416
return from_info;
417417
}
418418

419-
void SpanningTreePrinter::propagateTvPasC(TensorView* from, TensorView* to) {
420-
stream_ << "propagateTvPasC" << std::endl;
419+
void SpanningTreePrinter::propagateC2P(TensorView* from, TensorView* to) {
420+
stream_ << "propagateC2P" << std::endl;
421421
stream_ << " from: " << from->toString() << std::endl;
422422
stream_ << " to: " << to->toString() << std::endl;
423423
}
424424

425-
void SpanningTreePrinter::propagateTvCasP(TensorView* from, TensorView* to) {
426-
stream_ << "propagateTvCasP" << std::endl;
425+
void SpanningTreePrinter::propagateP2C(TensorView* from, TensorView* to) {
426+
stream_ << "propagateP2C" << std::endl;
427427
stream_ << " from: " << from->toString() << std::endl;
428428
stream_ << " to: " << to->toString() << std::endl;
429429
}
430430

431-
void SpanningTreePrinter::propagateTvSibling(TensorView* from, TensorView* to) {
432-
stream_ << "propagateTvSibling" << std::endl;
431+
void SpanningTreePrinter::propagateSibling(TensorView* from, TensorView* to) {
432+
stream_ << "propagateSibling" << std::endl;
433433
stream_ << " from: " << from->toString() << std::endl;
434434
stream_ << " to: " << to->toString() << std::endl;
435435
}

torch/csrc/jit/codegen/cuda/maxinfo_propagator.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,16 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
3939
// Class to subclass in order to stop traversal, by which limits the nodes in
4040
// the spanning tree.
4141
struct Selector {
42-
virtual bool allowPasC(TensorView* from, TensorView* to) = 0;
43-
virtual bool allowCasP(TensorView* from, TensorView* to) = 0;
42+
virtual bool allowC2P(TensorView* from, TensorView* to) = 0;
43+
virtual bool allowP2C(TensorView* from, TensorView* to) = 0;
4444
virtual bool allowSibling(TensorView* from, TensorView* to) = 0;
4545
};
4646

4747
// This is the interface to implement the actual propagation
4848
struct Propagator {
49-
virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0;
50-
virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0;
51-
virtual void propagateTvSibling(TensorView* from, TensorView* to) = 0;
49+
virtual void propagateC2P(TensorView* from, TensorView* to) = 0;
50+
virtual void propagateP2C(TensorView* from, TensorView* to) = 0;
51+
virtual void propagateSibling(TensorView* from, TensorView* to) = 0;
5252
};
5353

5454
// This is the interface that specifies the structure of information used to
@@ -237,9 +237,9 @@ class TORCH_CUDA_CU_API SpanningTreePrinter
237237
std::ostream& stream_;
238238

239239
public:
240-
virtual void propagateTvPasC(TensorView* from, TensorView* to) override;
241-
virtual void propagateTvCasP(TensorView* from, TensorView* to) override;
242-
virtual void propagateTvSibling(TensorView* from, TensorView* to) override;
240+
virtual void propagateC2P(TensorView* from, TensorView* to) override;
241+
virtual void propagateP2C(TensorView* from, TensorView* to) override;
242+
virtual void propagateSibling(TensorView* from, TensorView* to) override;
243243
SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {}
244244
};
245245

0 commit comments

Comments
 (0)