
Commit 2c9a6c0

Add extra configurability to parallelizeAllLike (#1831)
1 parent 3b87896 commit 2c9a6c0

File tree: 11 files changed (+178, −93 lines)

torch/csrc/jit/codegen/cuda/ir_utils.cpp

Lines changed: 13 additions & 0 deletions

@@ -652,6 +652,19 @@ std::vector<TensorView*> allTvs(Fusion* fusion) {
   return uniqueEntries<TensorView>(all_tvs);
 }
 
+std::vector<TensorView*> allTvsExcept(
+    Fusion* fusion,
+    const std::unordered_set<TensorView*>& except) {
+  auto all_tvs = allTvs(fusion);
+  std::vector<TensorView*> result;
+  for (auto tv : all_tvs) {
+    if (except.count(tv) == 0) {
+      result.emplace_back(tv);
+    }
+  }
+  return result;
+}
+
 std::vector<Expr*> getReductionOps(Fusion* fusion, bool ignore_trivial) {
   std::vector<Expr*> red_ops;
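The new helper is a thin filter over allTvs. A minimal usage sketch, assuming a fusion with a cache tensor tv_cache the caller wants to skip (the variable names here are hypothetical, not part of this commit):

    // Hypothetical sketch: visit every tensor in the fusion except tv_cache.
    std::unordered_set<TensorView*> except = {tv_cache};
    for (auto tv : ir_utils::allTvsExcept(fusion, except)) {
      // scheduling that should not touch tv_cache goes here
    }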

torch/csrc/jit/codegen/cuda/ir_utils.h

Lines changed: 6 additions & 0 deletions

@@ -282,6 +282,12 @@ TORCH_CUDA_CU_API std::vector<TensorView*> outputTvsOf(
 // returns all tensor views in fusion that are used between outputs and inputs.
 TORCH_CUDA_CU_API std::vector<TensorView*> allTvs(Fusion* fusion);
 
+// returns all tensor views in fusion that are used between outputs and inputs
+// except the specified set.
+TORCH_CUDA_CU_API std::vector<TensorView*> allTvsExcept(
+    Fusion* fusion,
+    const std::unordered_set<TensorView*>& except);
+
 TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps(
     Fusion* fusion,
     bool ignore_trivial = true);

torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp

Lines changed: 23 additions & 29 deletions

@@ -676,6 +676,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
   }
 
   int64_t unswitch_pos;
+  IterDomain* vectorize_id = nullptr;
   if (params.break_point) {
     // 2D parallelization scheme
     TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);
@@ -692,9 +693,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
       reference_tv->axis(3)->parallelize(ParallelType::TIDx);
 
-      // Aggressively mark with vectorized and cleanup later. That way we
-      // don't have to manually specify parallelization outside the reference.
-      reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
+      vectorize_id = reference_tv->axis(4);
 
       // [outer, Unswitch | i-remainder, TIDx, Vectorization]
       // To make consistent with unrolling:
@@ -797,7 +796,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
       // Aggressively mark with vectorized and cleanup later. That way we
       // don't have to manually specify parallelization outside the reference.
-      reference_tv->axis(3)->parallelize(ParallelType::Vectorize);
+      vectorize_id = reference_tv->axis(3);
 
       //[BIDx, TIDx, Unswitch, Vectorization]
       // To make consistent with unrolling:
@@ -822,37 +821,32 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
   TransformPropagator propagator(reference_tv);
   MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
   spanning_tree.traverse(&propagator);
-  scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);
+  scheduler_utils::parallelizeAllLike(reference_tv);
 
   if (params.vectorize) {
     // Grab all tensor views that should be vectorized
-    auto vectorized_tvs =
+    auto inputs_outputs =
         scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
-    // Going to move inputs to consumers of inputs, need a copy as we'll modify
-    // the original.
-    {
-      auto vectorized_tvs_copy = vectorized_tvs;
-      for (auto inp : vectorized_tvs_copy) {
-        if (!inp->isFusionInput()) {
-          continue;
-        }
-        vectorized_tvs.erase(
-            std::find(vectorized_tvs.begin(), vectorized_tvs.end(), inp));
-        auto consumer_tvs = ir_utils::consumerTvsOf(inp);
-        vectorized_tvs.insert(
-            vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
+    std::vector<TensorView*> vectorized_tvs;
+    bool should_vectorize_reference_tv = false;
+    for (auto tv : inputs_outputs) {
+      if (!tv->isFusionInput()) {
+        vectorized_tvs.emplace_back(tv);
+        continue;
       }
-    }
-    // Clear vectorize on tensors that shouldn't have it
-    for (auto tv : all_tvs) {
-      if (std::find(vectorized_tvs.begin(), vectorized_tvs.end(), tv) ==
-          vectorized_tvs.end()) {
-        for (auto id : tv->domain()->domain()) {
-          if (id->getParallelType() == ParallelType::Vectorize) {
-            id->parallelize(ParallelType::Serial);
-          }
-        }
+      if (tv == reference_tv) {
+        should_vectorize_reference_tv = true;
       }
+      // move inputs to consumers of inputs
+      auto consumer_tvs = ir_utils::consumerTvsOf(tv);
+      vectorized_tvs.insert(
+          vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
+    }
+    vectorize_id->parallelize(ParallelType::Vectorize);
+    scheduler_utils::parallelizeAllLike(
+        reference_tv, vectorized_tvs, {ParallelType::Vectorize});
+    if (!should_vectorize_reference_tv) {
+      vectorize_id->parallelize(ParallelType::Serial);
    }
  }
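The net effect in the pointwise scheduler: instead of aggressively vectorizing the reference axis and then walking all tensors to serialize the ones that should not be vectorized, the scheduler now records the candidate axis in vectorize_id, propagates every other parallel type fusion-wide, and propagates only ParallelType::Vectorize to the explicitly collected vectorized_tvs. A condensed restatement of the new control flow (a sketch of the lines above, not additional code):

    // 1. Mark the vectorizable axis on the reference only.
    vectorize_id->parallelize(ParallelType::Vectorize);
    // 2. Copy just the Vectorize parallel type to the qualifying tensors.
    scheduler_utils::parallelizeAllLike(
        reference_tv, vectorized_tvs, {ParallelType::Vectorize});
    // 3. If the reference itself does not qualify, undo the mark on it.
    if (!should_vectorize_reference_tv) {
      vectorize_id->parallelize(ParallelType::Serial);
    }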

torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp

Lines changed: 1 addition & 1 deletion

@@ -251,7 +251,7 @@ void multiReductionInliner(
   }
 
   // Propagate parallelization
-  scheduler_utils::parallelizeAllLike(reference_tv, ir_utils::allTvs(fusion));
+  scheduler_utils::parallelizeAllLike(reference_tv);
 
   // Find iter domains that are mapped to a trivial reduction, these should
   // never be inlined.

torch/csrc/jit/codegen/cuda/scheduler/utils.cpp

Lines changed: 35 additions & 12 deletions

@@ -188,30 +188,53 @@ size_t mergeNonReduction(
 
 void parallelizeAllLike(
     TensorView* reference_tv,
-    const std::vector<TensorView*>& all_tvs) {
+    int64_t pos,
+    std::vector<TensorView*> selected_tvs,
+    const std::unordered_set<ParallelType>& selected_parallel_types,
+    bool propagate_padding) {
   FusionGuard fg(reference_tv->fusion());
 
+  if (pos < 0) {
+    pos += reference_tv->nDims() + 1;
+  }
+  TORCH_CHECK(
+      pos >= 0 && pos <= reference_tv->nDims(),
+      "parallelizeAllLike called on a position outside valid range.");
+
+  std::unordered_map<IterDomain*, IterDomain*> concrete_to_reference_map;
+
   auto ca_map = ComputeAtMap(FusionGuard::getCurFusion());
 
-  for (auto id : reference_tv->domain()->domain()) {
-    ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
-        ->parallelize(id->getParallelType());
-    if (id->hasPaddingToMultipleOfWarp()) {
-      ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
-          ->padToMultipleOfWarp(id->getMaybeSizeAfterPadding());
-    }
+  const auto& reference_dom = reference_tv->domain()->domain();
+  for (auto it = reference_dom.begin(); it != reference_dom.begin() + pos;
+       it++) {
+    auto ca_id = ca_map.getConcreteMappedID(*it, IdMappingMode::PERMISSIVE);
+    concrete_to_reference_map[ca_id] = *it;
   }
 
-  for (auto tv : all_tvs) {
+  if (selected_tvs.empty()) {
+    selected_tvs = ir_utils::allTvs(reference_tv->fusion());
+  }
+  for (auto tv : selected_tvs) {
     if (tv->isFusionInput()) {
       continue;
     }
     for (const auto i : c10::irange(tv->domain()->domain().size())) {
       auto ca_id =
           ca_map.getConcreteMappedID(tv->axis(i), IdMappingMode::PERMISSIVE);
-      tv->axis(i)->parallelize(ca_id->getParallelType());
-      if (ca_id->hasPaddingToMultipleOfWarp()) {
-        tv->axis(i)->padToMultipleOfWarp(ca_id->getMaybeSizeAfterPadding());
+      if (concrete_to_reference_map.count(ca_id) > 0) {
+        auto reference_id = concrete_to_reference_map.at(ca_id);
+        auto reference_parallel_type = reference_id->getParallelType();
+        if (selected_parallel_types.empty() ||
+            selected_parallel_types.count(reference_parallel_type)) {
+          tv->axis(i)->parallelize(reference_parallel_type);
+        }
+        if (propagate_padding) {
+          if (reference_id->hasPaddingToMultipleOfWarp()) {
+            tv->axis(i)->padToMultipleOfWarp(
+                reference_id->getMaybeSizeAfterPadding());
+          }
+        }
       }
     }
   }
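Two details of the rewrite are worth noting. First, a negative pos is normalized by pos += nDims() + 1, so pos = -1 on a 3-D reference becomes 3 and selects all of dimensions [0, 1, 2]. Second, propagation is now pull-based: the selected reference iter domains are keyed by their permissive concrete IDs, and an axis of a selected tensor receives a parallel type only if it maps back to one of those keys and the type is in the selected set. A standalone sketch of the normalization rule (an illustrative helper, not part of the commit):

    // Illustrative only: mirrors the pos handling inside parallelizeAllLike.
    int64_t normalizePos(int64_t pos, int64_t ndims) {
      if (pos < 0) {
        pos += ndims + 1; // -1 -> ndims (all dims), -2 -> ndims - 1, ...
      }
      return pos; // the caller still checks 0 <= pos <= ndims
    }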

torch/csrc/jit/codegen/cuda/scheduler/utils.h

Lines changed: 24 additions & 1 deletion

@@ -49,9 +49,32 @@ size_t mergeNonReduction(
     TensorView* tv,
     const std::unordered_set<IterDomain*>& dont_merge = {});
 
+// Propagate the parallelization from the selected dimensions of the reference
+// tensor to their corresponding dimensions in all selected tensors in the DAG.
+// Position `pos` means selecting all the dimensions [0, 1, ..., pos - 1]. pos =
+// -1 means selecting all dimensions. `selected_tvs` are selected tensors in the
+// DAG. Empty `selected_tvs` means selecting all tensors in the fusion of
+// `reference_tv`. `selected_parallel_types` are the selected parallel types.
+// Empty `selected_parallel_types` means selecting all parallel types.
 TORCH_CUDA_CU_API void parallelizeAllLike(
     TensorView* reference_tv,
-    const std::vector<TensorView*>& all_tvs);
+    int64_t pos = -1,
+    std::vector<TensorView*> selected_tvs = {},
+    const std::unordered_set<ParallelType>& selected_parallel_types = {},
+    bool propagate_padding = true);
+
+TORCH_CUDA_CU_API inline void parallelizeAllLike(
+    TensorView* reference_tv,
+    std::vector<TensorView*> selected_tvs,
+    const std::unordered_set<ParallelType>& selected_parallel_types = {},
+    bool propagate_padding = true) {
+  parallelizeAllLike(
+      reference_tv,
+      -1,
+      std::move(selected_tvs),
+      selected_parallel_types,
+      propagate_padding);
+}
 
 TORCH_CUDA_CU_API void computeAtInputs(
     TensorView* consumer,
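With the defaults, the common case collapses to a single argument, and the inline overload lets callers pass a tensor list without spelling out pos. Call forms the new signatures admit (ref, tv1, and tv2 are hypothetical tensors):

    scheduler_utils::parallelizeAllLike(ref);     // all dims, all tvs, all types
    scheduler_utils::parallelizeAllLike(ref, 2);  // only reference dims [0, 1]
    scheduler_utils::parallelizeAllLike(          // inline overload: pos = -1,
        ref, {tv1, tv2}, {ParallelType::Vectorize}); // only Vectorize is copied
    scheduler_utils::parallelizeAllLike(
        ref, -1, {}, {}, /*propagate_padding=*/false); // skip warp padding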

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

Lines changed: 14 additions & 14 deletions

@@ -13849,7 +13849,7 @@ TEST_F(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) {
 
   TransformPropagatorWithCheck propagator(tv3);
   MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
-  scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv3);
 
   tv0_cache->axis(2)->parallelize(ParallelType::Vectorize);
   tv1_cache->axis(2)->parallelize(ParallelType::Vectorize);
@@ -17152,7 +17152,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination3_CUDA) {
 
   tv4->axis(0)->parallelize(ParallelType::BIDx);
   tv4->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv4);
 
   GpuLower gpulw(&fusion);
 
@@ -17203,7 +17203,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination4_CUDA) {
 
   tv1->axis(0)->parallelize(ParallelType::TIDy);
   tv1->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1);
 
   GpuLower gpulw(&fusion);
 
@@ -17252,7 +17252,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination5_CUDA) {
   auto rtvs2 = tvs2.rFactor({1});
 
   rtvs2.avg->axis(0)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(rtvs2.avg, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(rtvs2.avg);
 
   GpuLower gpulw(&fusion);
 
@@ -20392,7 +20392,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering1_CUDA) {
 
   tv3->axis(-2)->parallelize(ParallelType::BIDx);
   tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv3);
 
   tv1->doubleBuffer();
 
@@ -20430,7 +20430,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering2_CUDA) {
 
   tv3->axis(-2)->parallelize(ParallelType::BIDx);
   tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv3);
 
   tv1->doubleBuffer();
 
@@ -20479,7 +20479,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering3_CUDA) {
   tv2->doubleBuffer();
 
   tv3->axis(-1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv3);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
   at::manual_seed(0);
@@ -20520,7 +20520,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering4_CUDA) {
 
   tv3->axis(-1)->parallelize(ParallelType::TIDx);
   tv3->axis(1)->parallelize(ParallelType::Unswitch);
-  scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv3);
 
   tv2->doubleBuffer();
 
@@ -20562,7 +20562,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering5_CUDA) {
 
   tv2->axis(-1)->parallelize(ParallelType::TIDx);
   tv2->axis(1)->parallelize(ParallelType::Unswitch);
-  scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv2);
 
   tv1->doubleBuffer();
 
@@ -20684,7 +20684,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering8_CUDA) {
   tv1->computeAt(tv4, 1);
 
   tv4->axis(-1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv4);
 
   tv2->doubleBuffer();
   tv3->doubleBuffer();
@@ -20728,7 +20728,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering9_CUDA) {
   tv3->computeAt(out, -1);
 
   out->axis(-1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(out);
 
   tv2->doubleBuffer();
   tv3->doubleBuffer();
@@ -20806,7 +20806,7 @@ TEST_F(NVFuserTest, FusionSmemBlockGemmCacheDoubleBuffer_CUDA) {
   tv5->axis(-3)->parallelize(ParallelType::TIDy);
   tv5->axis(-1)->parallelize(ParallelType::TIDx);
 
-  scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv5);
 
   tv0_cache_local->doubleBuffer();
   tv1_cache_local->doubleBuffer();
@@ -21170,7 +21170,7 @@ TEST_F(NVFuserTest, FusionIssue1430_CUDA) {
 
   auto rfactor = ir_utils::rfactorHelper(tv3, {1, 4});
 
-  scheduler_utils::parallelizeAllLike(rfactor, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(rfactor);
 
   for (auto tv : ir_utils::allTvs(&fusion)) {
     if (tv != tv1 || tv != tv3) {
@@ -23202,7 +23202,7 @@ TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) {
   TransformPropagatorWithCheck propagator(reduction_tv);
   MaxRootDomainInfoSpanningTree(reduction_tv).traverse(&propagator);
   auto rfactor_tv = ir_utils::rfactorHelper(reduction_tv, {4});
-  scheduler_utils::parallelizeAllLike(rfactor_tv, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(rfactor_tv);
 
   tv0->computeAt(tv_avg, 2);
   tv0->computeAt(cached_input, -2);
