
Commit 4cae122

schedulePointwise cleanup: -computeAt +InlinePropagator (csarofeen#1815)
1 parent 3df9742 commit 4cae122

File tree (5 files changed: +78, -97 lines)

- torch/csrc/jit/codegen/cuda/compute_at.cpp
- torch/csrc/jit/codegen/cuda/inline_propagator.cpp
- torch/csrc/jit/codegen/cuda/inline_propagator.h
- torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

torch/csrc/jit/codegen/cuda/compute_at.cpp (2 additions, 2 deletions)

@@ -185,7 +185,7 @@ void ComputeAt::runAt(
   InlinePropagatorSelector selector(selected);

   InlinePropagator inline_propagator(
-      selector.selected(), consumer, consumer_position, mode);
+      consumer, consumer_position, mode, selector.selected());
   MaxProducerPosUpdater updater;

   MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);

@@ -227,7 +227,7 @@ void ComputeAt::runWith(
   InlinePropagatorSelector selector(selected);

   InlinePropagator inline_propagator(
-      selector.selected(), producer, producer_position, mode);
+      producer, producer_position, mode, selector.selected());
   MaxProducerPosUpdater updater;

   MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);
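To make the argument reordering concrete, here is a minimal before/after sketch of a call site; `consumer`, `consumer_position`, `mode`, and `selector` are assumed to be in scope exactly as in `ComputeAt::runAt` above, not new names introduced by this commit.

```cpp
// Before this commit: the selected set came first.
// InlinePropagator inline_propagator(
//     selector.selected(), consumer, consumer_position, mode);

// After: the reference tensor and position come first; the selected set
// moves to the end, where it can now be defaulted (see inline_propagator.h).
InlinePropagator inline_propagator(
    consumer, consumer_position, mode, selector.selected());
```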

torch/csrc/jit/codegen/cuda/inline_propagator.cpp (7 additions, 3 deletions)

@@ -148,7 +148,7 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) {

 void InlinePropagator::setCAPos(TensorView* tv) {
   size_t pos = mapped_reference_pos_.at(tv);
-  if (selected_.count(tv) && !tv->isFusionInput()) {
+  if ((selected_.empty() || selected_.count(tv)) && !tv->isFusionInput()) {
     auto max_pos = getMaxPosAll(tv);
     if (mode_ == ComputeAtMode::Standard) {
       TORCH_INTERNAL_ASSERT(

@@ -171,10 +171,10 @@ void InlinePropagator::setCAPos(TensorView* tv) {
 }

 InlinePropagator::InlinePropagator(
-    std::unordered_set<TensorView*> selected,
     TensorView* reference,
     int64_t reference_pos,
-    ComputeAtMode mode)
+    ComputeAtMode mode,
+    std::unordered_set<TensorView*> selected)
     : max_pos_calc(mode),
       selected_(std::move(selected)),
       reference_(reference),

@@ -213,6 +213,8 @@ void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
       to_pos >= 0,
       "Unable to propagate CA position from consumer ",
       from,
+      " at ",
+      from_pos,
       " to producer ",
       to,
       " because this would require replay.");

@@ -240,6 +242,8 @@ void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
       to_pos >= 0,
       "Unable to propagate CA position from producer ",
       from,
+      " at ",
+      from_pos,
       " to consumer ",
       to,
       " because this would require replay.");

torch/csrc/jit/codegen/cuda/inline_propagator.h (21 additions, 6 deletions)

@@ -14,7 +14,8 @@ namespace cuda {
 // Simple selector that only propagates across tensor views in the provided
 // unordered_set. Will also propagate to all consumers of those tensors, and the
 // siblings of those tensors.
-class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector {
+class TORCH_CUDA_CU_API InlinePropagatorSelector
+    : public MaxInfoSpanningTree::Selector {
   std::unordered_set<TensorView*> selected_;

  public:

@@ -29,7 +30,7 @@ class InlinePropagatorSelector : public MaxInfoSpanningTree::Selector {
   }
 };

-class MaxPosCalculator {
+class TORCH_CUDA_CU_API MaxPosCalculator {
   ComputeAtMode mode_ = ComputeAtMode::Standard;

   // Root domains in producer that's unmappable to any of its consumers

@@ -67,7 +68,10 @@ class MaxPosCalculator {
   MaxPosCalculator(ComputeAtMode mode);
 };

-class InlinePropagator : public MaxInfoSpanningTree::Propagator {
+// Propagate inline position to the `selected` tensors in the DAG. If `selected`
+// is not specified or empty, then propagate to the entire DAG.
+class TORCH_CUDA_CU_API InlinePropagator
+    : public MaxInfoSpanningTree::Propagator {
   // Checks producers and consumers to see what the maximum position in tv is
   // that can be shared across both directions.
   size_t getMaxPosAll(TensorView* tv, bool check_siblings = true);

@@ -94,10 +98,20 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {

  public:
   InlinePropagator(
-      std::unordered_set<TensorView*> selected,
       TensorView* reference,
       int64_t reference_pos,
-      ComputeAtMode mode);
+      ComputeAtMode mode = ComputeAtMode::Standard,
+      std::unordered_set<TensorView*> selected = {});
+
+  InlinePropagator(
+      TensorView* reference,
+      int64_t reference_pos,
+      std::unordered_set<TensorView*> selected)
+      : InlinePropagator(
+            reference,
+            reference_pos,
+            ComputeAtMode::Standard,
+            selected) {}

   ~InlinePropagator() = default;

@@ -112,7 +126,8 @@ class InlinePropagator : public MaxInfoSpanningTree::Propagator {
 // the tensors, and it is not needed to compute the max producer position in a
 // specific order. But MaxInfoSpanningTree provides a very convenient API to
 // visit the tensors, so I just use it for cleaner code.
-class MaxProducerPosUpdater : public MaxInfoSpanningTree::Propagator {
+class TORCH_CUDA_CU_API MaxProducerPosUpdater
+    : public MaxInfoSpanningTree::Propagator {
   std::unordered_set<TensorView*> updated_;
   void handle(TensorView* tv);

torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp (35 additions, 73 deletions)

@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/scheduler/pointwise.h>

 #include <torch/csrc/jit/codegen/cuda/executor_utils.h>
+#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/ir_utils.h>

@@ -13,6 +14,9 @@

 #include <ATen/cuda/CUDAContext.h>

+#include <algorithm>
+#include <unordered_map>
+
 namespace torch {
 namespace jit {
 namespace fuser {

@@ -671,6 +675,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
     }
   }

+  int64_t unswitch_pos;
   if (params.break_point) {
     // 2D parallelization scheme
     TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);

@@ -723,8 +728,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
         // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx]
         reference_tv->split(0, 65535);
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 5;
       } else {
         reference_tv->axis(0)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 4;
       }
     } else {
       // [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx]

@@ -734,8 +741,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
         // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx]
         reference_tv->split(1, 65535);
         reference_tv->axis(2)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 5;
       } else {
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 4;
       }
     }
   } else {

@@ -747,8 +756,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
         // [i-remainder, BIDy{65535} | BIDx | Unswitch, Unroll, TIDx]
         reference_tv->split(0, 65535);
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 4;
       } else {
         reference_tv->axis(0)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 3;
       }
     } else {
       // [BIDx | BIDy | Unswitch, Unroll, TIDx]

@@ -757,8 +768,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
        // [BIDx | i-remainder, BIDy{65535} | Unswitch, Unroll, TIDx]
         reference_tv->split(1, 65535);
         reference_tv->axis(2)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 4;
       } else {
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
+        unswitch_pos = 3;
       }
     }
   }

@@ -803,10 +816,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
       reference_tv->axis(3)->parallelize(ParallelType::TIDx);
     }
+    unswitch_pos = 2;
   }

   TransformPropagator propagator(reference_tv);
-  MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator);
+  MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
+  spanning_tree.traverse(&propagator);
   scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);

   if (params.vectorize) {
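An illustration (my annotation, not part of the diff) of how the hard-coded `unswitch_pos` values line up with the axis layouts in the scheduler's own comments:

```cpp
// For the fully split 2D scheme the reference loop nest is
//   [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx]
//    axis 0       axis 1        axis 2 3     axis 4    axis 5  axis 6
// Inlining "at the unswitch position" means compute-at position 5: axes
// 0 through 4, up to and including Unswitch, are shared with consumers.
// Dropping the extra BIDy split removes one leading axis, hence 4 in the
// else branch; the 1D schemes lose another axis, giving 4 and 3.
```

These constants replace the removed per-tensor `std::find_if` scan for the Unswitch axis (see the next hunk): since every tensor is transformed like `reference_tv`, the position is known statically per branch.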
@@ -841,84 +856,31 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
     }
   }

-  // Compute at into cached inputs
-  std::vector<TensorView*> consumers_of_cached_inputs;
-  // Cache of input, and one of its consumers
-  std::vector<std::pair<TensorView*, TensorView*>> input_cache_and_consumer;
-  {
-    // Avoid duplicate additions, so track what we add
-    std::unordered_set<TensorView*> added;
-    for (auto cached_input : cached_inputs) {
-      auto consumer_tvs = ir_utils::consumerTvsOf(cached_input);
-      TORCH_INTERNAL_ASSERT(
-          consumer_tvs.size(),
-          "Input was not succesfully filtered out for scheduling but wasn't used.");
-
-      // Grab a consumer which will be used for computeAt structure of cached
-      // input into a consumer
-      input_cache_and_consumer.emplace_back(
-          std::make_pair(cached_input, consumer_tvs[0]));
-
-      // Grab all consumers which will be used for inlining computeAt for the
-      // body of the computation (excluding caching inputs/outputs)
-      for (auto consumer_tv : consumer_tvs) {
-        // Don't duplicate
-        if (added.insert(consumer_tv).second) {
-          consumers_of_cached_inputs.emplace_back(consumer_tv);
-        }
-      }
-    }
-  }
-
-  for (auto entry : input_cache_and_consumer) {
-    // Compute at inside unswitch position:
-    auto input_cache = entry.first;
-    auto input_cache_consumer = entry.second;
-
-    auto unswitch_it = std::find_if(
-        input_cache_consumer->domain()->domain().begin(),
-        input_cache_consumer->domain()->domain().end(),
-        [](IterDomain* id) {
-          return id->getParallelType() == ParallelType::Unswitch;
-        });
-    auto unswitch_pos =
-        unswitch_it == input_cache_consumer->domain()->domain().end()
-        ? -1
-        : std::distance(
-              input_cache_consumer->domain()->domain().begin(), unswitch_it) +
-            1;
+  // Begin by inlining at the unswitch position for the entire DAG. The cached
+  // inputs, and outputs will keep this inline position, but other tensors will
+  // get a higher position in later inline propagation.
+  InlinePropagator inline_unswitch(
+      reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
+  spanning_tree.traverse(&inline_unswitch);

-    input_cache->computeAt(
-        input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort);
+  // Inline at the inner most position. The CA position of all tensors except
+  // inputs, cached inputs and outputs will be updated.
+  std::unordered_set<TensorView*> inner_most_tensors(
+      all_tvs.begin(), all_tvs.end());
+  for (auto cached_input : cached_inputs) {
+    inner_most_tensors.erase(cached_input);
   }
-
-  // Producers for inlined computeAt
-  std::vector<TensorView*> compute_from = consumers_of_cached_inputs;
-
-  // Consumers for inlined computeAt
-  std::vector<TensorView*> compute_to;
-  // Compute at cached outputs
-  //[BIDx, Unswitch, Vectorization, TIDx]
   for (auto entry : cached_outputs) {
-    auto cached_output = entry.first;
     auto output = entry.second;
-
-    auto unswitch_it = std::find_if(
-        output->domain()->domain().begin(),
-        output->domain()->domain().end(),
-        [](IterDomain* id) {
-          return id->getParallelType() == ParallelType::Unswitch;
-        });
-    auto unswitch_pos = unswitch_it == output->domain()->domain().end()
-        ? -1
-        : std::distance(output->domain()->domain().begin(), unswitch_it) + 1;
-
-    cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort);
-    compute_to.push_back(cached_output);
+    inner_most_tensors.erase(output);
   }
+  InlinePropagator inline_inner_most(
+      reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
+  spanning_tree.traverse(&inline_inner_most);

-  scheduler_utils::computeAtBetween(
-      compute_from, compute_to, -1, ComputeAtMode::BestEffort);
+  // Fix max producer position
+  MaxProducerPosUpdater updater;
+  spanning_tree.traverse(&updater);
 }

 } // namespace cuda
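To see the net effect of the two propagation passes, here is a comment-only sketch for a toy pointwise DAG; the tensor names are hypothetical and the positions follow directly from which tensors are erased from `inner_most_tensors` in the hunk above:

```cpp
// Toy DAG: in -> in_cache -> t2 -> out_cache -> out
//
// Pass 1 (inline_unswitch) inlines every tensor at unswitch_pos.
// Pass 2 (inline_inner_most) re-inlines at -1 (innermost), but only for
// tensors still in inner_most_tensors; cached inputs and fusion outputs
// were erased from that set. Afterwards:
//   in_cache  : CA position = unswitch_pos (kept from pass 1)
//   t2        : innermost               (raised by pass 2)
//   out_cache : innermost               (raised by pass 2)
//   out       : CA position = unswitch_pos (kept from pass 1)
// Pass 3 (MaxProducerPosUpdater) then refreshes each tensor's max producer
// position to match the new compute-at positions.
```

This replaces the old scheme of per-tensor `computeAt` calls plus `scheduler_utils::computeAtBetween`, which had to rediscover the unswitch axis for every cached input and output individually.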

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp (13 additions, 13 deletions)

@@ -1362,26 +1362,26 @@ TEST_F(NVFuserTest, FusionParser_CUDA) {
   // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
   const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3) {
-  int64_t i51;
-  i51 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
-  if ((i51 < T0.size[0])) {
+  int64_t i50;
+  i50 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
+  if ((i50 < T0.size[0])) {
     float T5[1];
     T5[0] = 0;
     T5[0]
-       = T1[i51];
+       = T1[i50];
     float T4[1];
     T4[0] = 0;
     T4[0]
-       = T0[i51];
-    float T6[1];
+       = T0[i50];
     float T2[1];
     T2[0]
       = T4[0]
       * T5[0];
+    float T6[1];
     T6[0]
       = T2[0]
       * T4[0];
-    T3[i51]
+    T3[i50]
       = T6[0];
   }
 }

@@ -19086,18 +19086,17 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) {
   // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
   const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) {
-  int64_t i172;
-  i172 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
-  if ((i172 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
+  int64_t i171;
+  i171 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
+  if ((i171 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
     __half T9[1];
     T9[0] = 0;
     T9[0]
       = T2[((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * (T0.size[2] * T0.size[3]))) * ((T0.size[2] * T0.size[1]) * T0.size[3])) + ((((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) % T0.size[3]) * (T0.size[2] * T0.size[1])) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) / (T0.size[2] * T0.size[3])) * T0.size[2]) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) / T0.size[3])];
     __half T8[1];
     T8[0] = 0;
     T8[0]
-       = T0[i172];
-    __half T10[1];
+       = T0[i171];
     float T3[1];
     T3[0]
       = __half2float(T9[0]);

@@ -19114,9 +19113,10 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2,
     float T6[1];
     T6[0]
       = relu(T5[0]);
+    __half T10[1];
     T10[0]
       = __float2half(T6[0]);
-    T7[i172]
+    T7[i171]
       = T10[0];
   }
 }
