Skip to content

Commit d0d106a

Browse files
authored
Improve view support on pointwise and transpose scheduler (#1906)
1 parent e71e1ec commit d0d106a

12 files changed

+415
-173
lines changed

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,18 +1933,39 @@ TensorDomain* TensorDomain::view(const AnalyzeViewResult& view_analysis) {
19331933
}
19341934

19351935
TensorDomain* TensorDomain::flatten(int64_t start_dim, int64_t end_dim) {
1936+
auto inp_domain = noReductions(getMaybeRFactorDomain());
1937+
19361938
if (start_dim < 0) {
1937-
start_dim += nDims();
1939+
start_dim += inp_domain.size();
19381940
}
19391941
if (end_dim < 0) {
1940-
end_dim += nDims();
1942+
end_dim += inp_domain.size();
19411943
}
1944+
TORCH_CHECK(
1945+
start_dim >= 0 && start_dim < inp_domain.size(),
1946+
"Invalid start_dim ",
1947+
start_dim);
1948+
TORCH_CHECK(
1949+
end_dim >= 0 && end_dim < inp_domain.size(), "Invalid end_dim ", end_dim);
1950+
TORCH_CHECK(start_dim <= end_dim, "start_dim must be <= end_dim");
19421951

19431952
std::vector<IterDomain*> new_root_domain;
1944-
auto inp_domain = noReductions(getMaybeRFactorDomain());
19451953
new_root_domain.reserve(inp_domain.size());
1946-
for (auto id : inp_domain) {
1947-
new_root_domain.push_back(id->cloneWithoutRFactor());
1954+
for (auto i : c10::irange(inp_domain.size())) {
1955+
bool is_rfactor_dim = i >= start_dim && i <= end_dim;
1956+
auto inp_id = inp_domain[i];
1957+
auto out_id = IterDomainBuilder(inp_id)
1958+
.is_rfactor_domain(is_rfactor_dim)
1959+
.extent(
1960+
(is_rfactor_dim && inp_id->hasExpandedExtent())
1961+
? inp_id->expandedExtent()
1962+
: inp_id->extent())
1963+
.iter_type(
1964+
(is_rfactor_dim && inp_id->isBroadcast())
1965+
? IterType::Iteration
1966+
: inp_id->getIterType())
1967+
.build();
1968+
new_root_domain.push_back(out_id);
19481969
}
19491970

19501971
std::vector<IterDomain*> rfactor_domain;
@@ -1966,7 +1987,7 @@ TensorDomain* TensorDomain::flatten(int64_t start_dim, int64_t end_dim) {
19661987
}
19671988
rfactor_domain.push_back(merged_id);
19681989

1969-
for (auto i : c10::irange(end_dim + 1, nDims())) {
1990+
for (auto i : c10::irange(end_dim + 1, inp_domain.size())) {
19701991
rfactor_domain.push_back(new_root_domain[i]);
19711992
}
19721993

torch/csrc/jit/codegen/cuda/ops/alias.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,19 @@ TensorView* view(
111111
}
112112

113113
TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) {
114+
auto inp_domain = TensorDomain::noReductions(x->getMaybeRFactorDomain());
114115
if (start_dim < 0) {
115-
start_dim += x->nDims();
116+
start_dim += inp_domain.size();
116117
}
117118
if (end_dim < 0) {
118-
end_dim += x->nDims();
119+
end_dim += inp_domain.size();
119120
}
120121
TORCH_CHECK(
121-
start_dim >= 0 && start_dim < x->nDims(),
122+
start_dim >= 0 && start_dim < inp_domain.size(),
122123
"Invalid start_dim ",
123124
start_dim);
124125
TORCH_CHECK(
125-
end_dim >= 0 && end_dim < x->nDims(), "Invalid end_dim ", end_dim);
126+
end_dim >= 0 && end_dim < inp_domain.size(), "Invalid end_dim ", end_dim);
126127
TORCH_CHECK(start_dim <= end_dim, "start_dim must be <= end_dim");
127128

128129
if (start_dim == end_dim) {

torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
878878
data_cache, [&first_red_tv]() {
879879
return std::make_unique<std::vector<TensorView*>>(
880880
scheduler_utils::getInputsOutputsWithInnerDim(
881-
first_red_tv, true));
881+
first_red_tv, true, true));
882882
});
883883

884884
auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_entry.get();
@@ -888,7 +888,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
888888
data_cache, [&first_red_tv]() {
889889
return std::make_unique<std::vector<TensorView*>>(
890890
scheduler_utils::getInputsOutputsWithInnerDim(
891-
first_red_tv, false));
891+
first_red_tv, false, false));
892892
});
893893

894894
auto& unrollable_inputs_outputs = unrollable_inputs_outputs_entry.get();

torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
157157
data_cache, [&largest_out]() {
158158
return std::make_unique<std::vector<TensorView*>>(
159159
scheduler_utils::getInputsOutputsWithInnerDim(
160-
largest_out, true));
160+
largest_out, true, true));
161161
});
162162

163163
constexpr int64_t kSixteen = 16; // clang tidy
@@ -691,7 +691,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
691691
if (params.vectorize) {
692692
// Grab all tensor views that should be vectorized
693693
auto inputs_outputs =
694-
scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
694+
scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true, true);
695695
std::vector<TensorView*> vectorized_tvs;
696696
bool should_vectorize_reference_tv = false;
697697
for (auto tv : inputs_outputs) {

torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
924924
data_cache, [&reduction_tv]() {
925925
return std::make_unique<std::vector<TensorView*>>(
926926
scheduler_utils::getInputsOutputsWithInnerDim(
927-
reduction_tv, true));
927+
reduction_tv, true, true));
928928
});
929929

930930
auto& vectorizable_inputs_outputs = vectorizable_inputs_outputs_entry.get();
@@ -934,7 +934,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
934934
data_cache, [&reduction_tv]() {
935935
return std::make_unique<std::vector<TensorView*>>(
936936
scheduler_utils::getInputsOutputsWithInnerDim(
937-
reduction_tv, false));
937+
reduction_tv, false, false));
938938
});
939939

940940
auto& unrollable_inputs_outputs = unrollable_inputs_outputs_entry.get();

torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ void multiReductionInliner(
266266

267267
// Grab all tensor views that should be vectorized
268268
auto vectorizable_inputs_outputs =
269-
scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
269+
scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true, true);
270270

271271
auto vectorizable_expr = [](Expr* e) {
272272
return e->isA<UnaryOp>() &&

torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -86,28 +86,53 @@ class DomainMap : public pointwise_utils::DomainMap {
8686
// The order here must be deterministic, because in transpose heuristics, we
8787
// have `vectorize_factor1` and `vectorize_factor2` and we need to be sure
8888
// that `1` and `2` are assigned to the same group across runs.
89+
//
90+
// In the case where view is present in the graph, there are two cases: if the
91+
view doesn't touch any inner dimension of any group, then supporting it
92+
// is trivial. In the case where view actually touches an inner-most dim, we
93+
// keep track of the inner-most dimension of view's split and merges.
94+
//
95+
// For example, if you have:
96+
// T0 [2, 3, 5] <-- input
97+
// T1 [2, 5, 3] <-- input
98+
// T2 [2, 5, 3] = transpose(T0) + T1
99+
// T3 [2, 15] = view(T2)
100+
// output <-- T3
101+
//
102+
// Then T3 should be in the same group as T1, and T0 should be in a
103+
// different group from T1 and T3.
89104
std::vector<std::vector<TensorView*>> groupInputsOutputsByInnerDim() const {
90105
std::vector<std::vector<TensorView*>> groups;
91106
auto output_tvs = ir_utils::filterByType<TensorView>(fusion_->outputs());
92107
auto input_tvs = ir_utils::filterByType<TensorView>(fusion_->inputs());
93-
std::unordered_map<size_t, IterDomain*> group_to_inner_dim_map;
94-
decltype(input_tvs)* tv_filtered_group[2] = {&output_tvs, &input_tvs};
95-
for (auto view : tv_filtered_group) {
96-
for (auto tv : *view) {
97-
auto inner_most_id = scheduler_utils::innerMostRootDim(tv);
98-
bool found = false;
99-
for (auto gi : c10::irange(groups.size())) {
100-
auto& g = groups[gi];
101-
auto group_inner_dim = group_to_inner_dim_map.at(gi);
102-
if (areExactMapped(inner_most_id, group_inner_dim)) {
103-
g.emplace_back(tv);
104-
found = true;
105-
break;
106-
}
108+
std::unordered_set<TensorView*> grouped;
109+
decltype(input_tvs)* tv_filtered_groups[2] = {&output_tvs, &input_tvs};
110+
for (auto tv_filtered_group : tv_filtered_groups) {
111+
for (auto tv : *tv_filtered_group) {
112+
if (grouped.count(tv) > 0) {
113+
continue;
107114
}
108-
if (!found) {
109-
group_to_inner_dim_map[groups.size()] = inner_most_id;
110-
groups.push_back({tv});
115+
groups.emplace_back(std::vector<TensorView*>{tv});
116+
grouped.emplace(tv);
117+
// We only want to grab the inner-most dimension, because we don't want
118+
// tensors with different inner-most dimension to be put in the same
119+
// group. For example, if we have:
120+
// T2[i1, i3*i2] = relu(view(transpose(T1[i1, i2, i3])))
121+
// then we don't want T1 and T2 to be in the same group.
122+
//
123+
// But we don't want to check contiguity. For example, if we have:
124+
// T1[i1, i2, i3] (contiguous) + T2[i1, i2, i3] (discontiguous)
125+
// Then we still want T1 and T2 to be grouped together.
126+
auto group =
127+
scheduler_utils::getInputsOutputsWithInnerDim(tv, true, false);
128+
for (auto member_tv : group) {
129+
TORCH_INTERNAL_ASSERT(
130+
grouped.count(member_tv) == 0 || member_tv == tv,
131+
"The group of ",
132+
member_tv->toString(),
133+
" is ambiguous. This is likely a bug.");
134+
grouped.emplace(member_tv);
135+
groups.back().emplace_back(member_tv);
111136
}
112137
}
113138
}
@@ -263,6 +288,10 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
263288
runtime_info.getInnerDimVectorizableWidth(tv);
264289
vectorize_factor1 = std::min(vectorize_factor1, tv_vectorize_factor);
265290
}
291+
// TODO: Since group2 only has global->shared and shared->global set op, we
292+
// can have fine-grained control of unroll/vectorization at per tensor level.
293+
// We should not be using a single global vectorize factor for the entire
294+
// group 2
266295
for (auto tv : grouped_inputs_outputs[1]) {
267296
const auto tv_vectorize_factor =
268297
runtime_info.getInnerDimVectorizableWidth(tv);

torch/csrc/jit/codegen/cuda/scheduler/transpose.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ TORCH_CUDA_CU_API LaunchParams scheduleTranspose(
9393

9494
//! Utility for canSchedule interface to check if this fusion has at least two
9595
//! groups, each with a fully broadcasted reference tensor.
96-
bool hasAtLeastTwoValidGroups(Fusion* fusion);
96+
TORCH_CUDA_CU_API bool hasAtLeastTwoValidGroups(Fusion* fusion);
9797

9898
} // namespace cuda
9999
} // namespace fuser

0 commit comments

Comments
 (0)