Skip to content

Commit a43cb20

Browse files
authored
Make inlining even more modular (#2004)
I don't like `TensorView::setComputeAt` and `TensorView::setMaxProducer`, they are private, and I cannot use them conveniently. It would be better if there were some public method of `TensorView` that allows directly setting the CA position of a TV with necessary validation. So I added two public methods `TensorView::inlineAt` and `TensorView::updateMaxProducerPosition` and removed `TensorView::setComputeAt` and `TensorView::setMaxProducer`. The `inlineAt` can be safely used publicly. It will not inline into disallowed dimensions, and the max producer position will be kept consistent. There are two ways of using `inlineAt`: If you only want to set the CA position of a single tensor, then simply do ```C++ tv->inlineAt(pos, /*best_effort=*/true); ``` If you want to set the CA position of multiple tensors, then you can do ```C++ MaxPosCalculator calc; for (auto tv : tensors) { tv->inlineAt(pos, /*best_effort=*/true, &calc); } ``` In both cases, the max producer position will be updated at the end of the `inlineAt` call. Manually constructing the `MaxPosCalculator` object is mainly for performance reasons: we don't want to build the unmappable dimensions every time we call `inlineAt`. If we want to inline multiple tensors, we should build it at the beginning and use it in all `inlineAt` calls. Even though `inlineAt` always updates the max producer position automatically, there are still cases where we want to manually trigger an update of the max producer position, and `updateMaxProducerPosition` is designed for such a purpose. It is mainly used for grouped reductions. **With `inlineAt`, I can refactor inlining to make it even more modular:** There is no longer an `InlinePropagator`. 
Innermost inlining is now just a dumb for loop: ```C++ MaxPosCalculator calc; for (auto tv : all_tvs) { tv->inlineAt(-1, /*best_effort=*/true, &calc); } ``` For standard and best-effort inlining, we first need to do a propagation to find, for each tensor, the position that maps to the given position of the given reference tensor. With the positions calculated, inlining is again a dumb for loop.
1 parent dc45835 commit a43cb20

19 files changed

+597
-625
lines changed

build_variables.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ libtorch_cuda_core_sources = [
649649
"torch/csrc/autograd/functions/comm.cpp",
650650
"torch/csrc/jit/codegen/cuda/arith.cpp",
651651
"torch/csrc/jit/codegen/cuda/compute_at.cpp",
652-
"torch/csrc/jit/codegen/cuda/inline_propagator.cpp",
652+
"torch/csrc/jit/codegen/cuda/inlining.cpp",
653653
"torch/csrc/jit/codegen/cuda/compute_at_map.cpp",
654654
"torch/csrc/jit/codegen/cuda/codegen.cpp",
655655
"torch/csrc/jit/codegen/cuda/contiguity.cpp",

torch/csrc/jit/codegen/cuda/compute_at.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,20 +213,21 @@ void ComputeAt::runAt(
213213
auto selected = getPropagationSubgraph(producer, consumer);
214214
ComputeAtSelector selector(selected);
215215

216-
InlinePropagator inline_propagator(
217-
consumer, consumer_position, mode, selector.selected());
218-
219216
MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);
220217

221218
if (mode == ComputeAtMode::MostInlined) {
222219
MostInlinedTransformPropagator propagator;
223220
path.traverse(&propagator);
221+
inlineMost(selected);
224222
} else {
225223
TransformPropagator propagator(consumer, consumer_position);
226224
path.traverse(&propagator);
225+
inlineSelectedAt(
226+
selected,
227+
consumer,
228+
consumer_position,
229+
mode == ComputeAtMode::BestEffort);
227230
}
228-
229-
path.traverse(&inline_propagator);
230231
}
231232

232233
void ComputeAt::runWith(
@@ -253,19 +254,21 @@ void ComputeAt::runWith(
253254
auto selected = getPropagationSubgraph(producer, consumer);
254255
ComputeAtSelector selector(selected);
255256

256-
InlinePropagator inline_propagator(
257-
producer, producer_position, mode, selector.selected());
258-
259257
MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);
260258

261259
if (mode == ComputeAtMode::MostInlined) {
262260
MostInlinedTransformPropagator propagator;
263261
path.traverse(&propagator);
262+
inlineMost(selected);
264263
} else {
265264
TransformPropagator propagator(producer, producer_position);
266265
path.traverse(&propagator);
266+
inlineSelectedAt(
267+
selected,
268+
producer,
269+
producer_position,
270+
mode == ComputeAtMode::BestEffort);
267271
}
268-
path.traverse(&inline_propagator);
269272
}
270273

271274
} // namespace cuda

torch/csrc/jit/codegen/cuda/compute_at.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
3+
#include <torch/csrc/jit/codegen/cuda/inlining.h>
44
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
55
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
66

torch/csrc/jit/codegen/cuda/grouped_reduction.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ bool hasMatchingTransformations(TensorView* ref, TensorView* other) {
3838
}
3939

4040
// Validate grouping of reductions and return a new max producer position
41-
unsigned int validateReductionGrouping(
41+
void validateReductionGrouping(
4242
const std::vector<Val*>& inputs,
4343
const std::vector<Val*>& outputs) {
4444
TORCH_INTERNAL_ASSERT(inputs.size() == outputs.size());
@@ -57,7 +57,6 @@ unsigned int validateReductionGrouping(
5757
const auto num_root_dims = ref_domain.size();
5858
const auto num_dims = ref_tv->nDims();
5959
const auto ref_ca_pos = ref_tv->getComputeAtPosition();
60-
auto max_producer_pos = ref_tv->getMaxProducerPosition();
6160
for (const auto i : c10::irange(inputs.size())) {
6261
auto output_tv = outputs.at(i)->as<TensorView>();
6362
const auto& output_domain = output_tv->getRootDomain();
@@ -136,9 +135,6 @@ unsigned int validateReductionGrouping(
136135
ref_tv->toString(),
137136
". Mismatched tensor: ",
138137
output_tv->toString());
139-
140-
max_producer_pos =
141-
std::max(max_producer_pos, output_tv->getMaxProducerPosition());
142138
}
143139

144140
// Must not have any data dependency from outputs to inputs
@@ -152,8 +148,6 @@ unsigned int validateReductionGrouping(
152148
}
153149
TORCH_INTERNAL_ASSERT(all_dep_vals.empty(), ss.str());
154150
}
155-
156-
return max_producer_pos;
157151
}
158152

159153
} // namespace
@@ -194,14 +188,14 @@ void groupReductions(const std::vector<TensorView*>& reduction_outputs) {
194188
inputs.at(i) = rop->in();
195189
}
196190

197-
auto max_producer_pos = validateReductionGrouping(inputs, outputs);
198-
199-
for (auto output : ir_utils::filterByType<TensorView>(outputs)) {
200-
output->setMaxProducer(max_producer_pos);
201-
}
191+
validateReductionGrouping(inputs, outputs);
202192

203193
IrBuilder::create<GroupedReductionOp>(
204194
container, op_types, init_vals, outputs, inputs);
195+
196+
for (auto output : ir_utils::filterByType<TensorView>(outputs)) {
197+
output->updateMaxProducerPosition();
198+
}
205199
}
206200

207201
} // namespace cuda

0 commit comments

Comments
 (0)