
Move scheduler vectorize utilities into their own file #1959


Merged: 47 commits, Sep 10, 2022
Commits
11cadb9
Expand+Reduction, Expand+View support, rework View analysis and guard…
csarofeen Aug 1, 2022
276369b
Minor fixes
csarofeen Aug 1, 2022
973e4f5
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 3, 2022
c230d23
Minor rework of view cache system, add test cases for view cache.
csarofeen Aug 3, 2022
c423ecf
Start adding view to scheduling and fusions.
csarofeen Aug 4, 2022
d295fe9
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 4, 2022
f0e4887
Merge branch 'view_fixes' into view_schedule
csarofeen Aug 4, 2022
8ce26ff
Missed an isEnabled switch.
csarofeen Aug 5, 2022
34ae3f9
Merge branch 'view_fixes' into view_schedule
csarofeen Aug 5, 2022
c443502
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 10, 2022
6d142c9
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 15, 2022
c5dfc54
Map through view operations.
csarofeen Aug 18, 2022
98f5138
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 18, 2022
8596126
Merge branch 'view_mapping' into view_schedule
csarofeen Aug 18, 2022
0424dc6
Minor warning fix.
csarofeen Aug 18, 2022
5e2ffde
View pointwise and 2D pointwise scheduling drafted. Still needs to be…
csarofeen Aug 24, 2022
4584ce1
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 24, 2022
cd85707
minor fix.
csarofeen Aug 25, 2022
cd776f2
View compute at mapping fix.
csarofeen Aug 26, 2022
0bc2a92
Pointwise schedule fix.
csarofeen Aug 26, 2022
1739702
Lint.
csarofeen Aug 26, 2022
fdbd14b
Minor refactor.
csarofeen Aug 26, 2022
3906da0
View scheduling tests.
csarofeen Aug 26, 2022
226acfb
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 26, 2022
01a73e9
Mapping fix for view.
csarofeen Aug 26, 2022
f03ac36
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 26, 2022
5728d87
Merge branch 'view_mapping' into view_schedule
csarofeen Aug 26, 2022
33b28d3
Minor text fix.
csarofeen Aug 26, 2022
20a6b23
Comment cleanup.
csarofeen Aug 26, 2022
941818f
Merge branch 'view_mapping' into view_schedule
csarofeen Aug 26, 2022
50d8a38
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 26, 2022
afab725
Update python tests.
csarofeen Aug 26, 2022
3542591
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 26, 2022
fc3dd9d
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 26, 2022
cd115ae
Disable compute at root mapping check on multiple mapped domains.
csarofeen Aug 28, 2022
77ec1bc
Update torch/csrc/jit/codegen/cuda/scheduler/pointwise.h
csarofeen Aug 30, 2022
a8104f5
Update torch/csrc/jit/codegen/cuda/scheduler/pointwise.h
csarofeen Aug 30, 2022
930b22d
Minor cleanup.
csarofeen Aug 31, 2022
975a93c
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Aug 31, 2022
6f39979
Merge branch 'view_schedule' of https://www.github.com/csarofeen/pyto…
csarofeen Aug 31, 2022
111dcef
Comments.
csarofeen Sep 3, 2022
be28867
Merge branch 'devel' of https://www.github.com/csarofeen/pytorch into…
csarofeen Sep 3, 2022
0df1a73
Move schedule vectorize helper to separate file.
csarofeen Sep 5, 2022
ca65845
Merge branch 'devel' into move_vec_help
zasdfgbnm Sep 10, 2022
6e43941
no enablement
zasdfgbnm Sep 10, 2022
22c8702
cleanup
zasdfgbnm Sep 10, 2022
dc7d506
no registry.cpp
zasdfgbnm Sep 10, 2022
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -730,6 +730,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp",
"torch/csrc/jit/codegen/cuda/type_inference.cpp",
"torch/csrc/jit/codegen/cuda/type_promotion.cpp",
"torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp",
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
@@ -909,7 +909,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
}

// Try expanding vectorization to contig merged domains
vectorize_factor = scheduler_utils::expandVectorizationToContigMergedDomains(
vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains(
fusion,
runtime_info,
vectorizable_inputs_outputs,
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -344,7 +344,7 @@ std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
// TODO: This is an expensive function that shouldn't be in heuristics without
// caching.
auto expanded_vector_word_size =
scheduler_utils::expandVectorizationToContigMergedDomains(
vectorize_helper::expandVectorizationToContigMergedDomains(
fusion,
runtime_info,
vectorizable_inputs_outputs,
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
@@ -954,7 +954,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
}

// Try expanding vectorization to contig merged domains
vectorize_factor = scheduler_utils::expandVectorizationToContigMergedDomains(
vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains(
fusion,
runtime_info,
vectorizable_inputs_outputs,
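The three caller hunks above are mechanical: expandVectorizationToContigMergedDomains keeps its name and argument list and only moves from the scheduler_utils namespace into the new vectorize_helper namespace. Below is a hypothetical sketch (not code from the PR) of a complete call after the move; the parameter types are taken from the definition removed from utils.cpp further down, while the header path, enclosing namespace, and wrapper function are assumptions made for illustration.

// Hypothetical sketch, not part of this PR. The header path is assumed to
// mirror the scheduler/vectorize_helper.cpp entry added to build_variables.bzl,
// and the enclosing namespace is assumed to match the schedulers above.
#include <torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h>

#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// This wrapper exists only to show the post-move spelling of the call; the
// pointwise/reduction/normalization heuristics call the helper directly.
size_t expandedWordSize(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    const std::vector<TensorView*>& vectorizable_inputs_outputs,
    TensorView* reference_tv,
    int break_point,
    size_t default_word_size) {
  return vectorize_helper::expandVectorizationToContigMergedDomains(
      fusion,
      runtime_info,
      vectorizable_inputs_outputs,
      reference_tv,
      break_point,
      default_word_size);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch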
260 changes: 0 additions & 260 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
@@ -1620,89 +1620,6 @@ BroadcastMultipleInformation getBroadcastMultiples(
return bcast_info;
}

size_t collectMaxVectorizeSizeWithContigMerge(
TensorView* tv,
IterDomain* leaf_merged_domain,
size_t max_vector_size_in_byte,
ExpressionEvaluator& expression_evaluator,
DataType index_type) {
// Maybe too conservative, but only handles fully contiguous tensors
// TODO: Relax the contiguity constraint to be similar to that in index
// computing. Just looking for all merged root domains in the right order,
// all merged root dimensions are contiguous, all merged root dimensions are
// next to each other (excluding broadcast).
if (std::any_of(
tv->domain()->contiguity().begin(),
tv->domain()->contiguity().end(),
[](const auto contig) { return !contig; })) {
return 1;
}

auto dtype_size = dataTypeSize(tv->dtype(), index_type);
const size_t max_vector_size = max_vector_size_in_byte / dtype_size;

// Assume no halo-related expression appears in the fusion. No
// broadcast is merged, so indexability can be assumed to be true.
ContigIDs contigIds(
{leaf_merged_domain},
tv->getMaybeRFactorDomain(),
tv->domain()->contiguity(),
{},
{},
true,
true);

auto innermost_root_id = tv->getMaybeRFactorDomain().back();
auto indexed_id = contigIds.rootToIndexedID().at(innermost_root_id);

size_t merged_size = 1;
// If the indexed ID is a contig merged domain, i.e., it is
// different from innermost_root_id, we accumulate the extents of
// all the root domains covered by the contig indexed ID. Otherwise,
// just look at the extent of the innermost root ID.
if (indexed_id != innermost_root_id) {
const auto& within_root = contigIds.withinContigIDs().at(indexed_id);
for (auto root_id : tv->getMaybeRFactorDomain()) {
if (within_root.find(root_id) == within_root.end()) {
continue;
}
auto maybe_dimension_size =
expression_evaluator.evaluate(root_id->extent());
TORCH_INTERNAL_ASSERT(
maybe_dimension_size.has_value(),
"Unknown extent of tv: ",
tv->toString(),
", id: ",
root_id->toString());
merged_size *= maybe_dimension_size->as<int64_t>();
}
} else {
auto maybe_dimension_size =
expression_evaluator.evaluate(innermost_root_id->extent());
TORCH_INTERNAL_ASSERT(
maybe_dimension_size.has_value(),
"Unknown extent of tv: ",
tv->toString(),
", id: ",
innermost_root_id->toString());
merged_size = maybe_dimension_size->as<int64_t>();
}

size_t vector_size = 1;
size_t next_vector_size = vector_size * 2;

// Try until vector size exceeds the max allowed size
while (next_vector_size <= max_vector_size) {
if (merged_size % next_vector_size != 0) {
break;
}
vector_size = next_vector_size;
next_vector_size *= 2;
}

return vector_size;
}

namespace matmul_utils {

void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {
@@ -2260,183 +2177,6 @@ void BoundedDirectionalTransformPropagator::bothWays(
propagate(from, pos, included_tvs, *options);
}

// Grab all values and expressions used to make the merged_domain and remove
// them from the fusion
void cleanUpInnermostMergedDomains(
const std::vector<IterDomain*>& root_domain,
IterDomain* merged_domain) {
TORCH_INTERNAL_ASSERT(merged_domain != nullptr);
TORCH_INTERNAL_ASSERT(!root_domain.empty());

std::unordered_set<Val*> root_set({root_domain.begin(), root_domain.end()});

auto vals = DependencyCheck::getAllValsBetween(root_set, {merged_domain});

for (auto it = vals.rbegin(); it != vals.rend(); ++it) {
TORCH_INTERNAL_ASSERT((*it)->isA<IterDomain>());
auto id = (*it)->as<IterDomain>();
if (root_set.find(id) != root_set.end()) {
continue;
}
Fusion* fusion = id->container()->as<Fusion>();
auto id_def = id->definition();
TORCH_INTERNAL_ASSERT(
id_def->isA<Merge>(),
"Invalid ID: ",
id->toString(),
". Expected definition of a Merge expression: ",
(id_def != nullptr ? id_def->toString() : "nullptr"));
fusion->removeExpr(id_def);
fusion->removeVal(id);
}
}

// Merge innermost domains for finding the widest vectorizable
// size. Return the merged domain or nullptr if no merge is done.
IterDomain* mergeInnermostDomains(
const std::vector<IterDomain*>& domain,
int num_merged_domains) {
const auto ndims = domain.size();
IterDomain* merged_id = nullptr;
bool is_merge_done = false;
for (const auto i : c10::irange(num_merged_domains)) {
auto id = domain.at(ndims - 1 - i);
// broadcast and trivial reductions are ignored
if (id->isBroadcast() || id->isTrivialReduction()) {
continue;
}
if (merged_id == nullptr) {
merged_id = id;
} else {
auto id_inner = merged_id;
auto id_outer = id;
merged_id = IterDomain::merge(id_outer, id_inner);
is_merge_done = true;
}
}
return is_merge_done ? merged_id : nullptr;
}

//! Attempt to expand vectorized domains to contig merged domains. Break point
//! identifies the point in which you can't propagate contiguous merges. For
//! example in pointwise this is the point where we want to split the
//! parallelization to take advantage of broadcast, and for reduction
//! schedulers it's the point where we switch from a reduction domain to an
//! iter domain (or vice versa).
size_t expandVectorizationToContigMergedDomains(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
const std::vector<TensorView*> vectorizable_inputs_outputs,
TensorView* reference_tv,
int break_point,
size_t default_word_size) {
size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte;
size_t common_alignment_size =
SchedulerRuntimeInfo::max_alignment_size_in_byte;

for (auto inp_out : vectorizable_inputs_outputs) {
auto dtype_size = dataTypeSize(
inp_out->dtype(), indexModeToDtype(runtime_info.getIndexMode()));

max_expand_size = std::min(
max_expand_size,
SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size);
max_expand_size = std::min(
max_expand_size, runtime_info.getMaxVectorizableWidth(inp_out));
common_alignment_size =
std::min(common_alignment_size, runtime_info.getAlignmentSize(inp_out));
}

// If there's no possibility to increase vector size of provided tensors,
// then don't bother doing a more complex analysis to try and do so, just
// return early.
if (max_expand_size == default_word_size) {
return default_word_size;
}

auto ca_map = ComputeAtMap(fusion);

// Merge the domains right of the break point
const auto& ref_root = reference_tv->getMaybeRFactorDomain();
const int num_merged_domains =
static_cast<int>(ref_root.size()) - static_cast<int>(break_point);

// No expansion with no merged domain
if (num_merged_domains == 0) {
return default_word_size;
}

// Merge the domains but don't modify TensorDomain
auto merged_domain = mergeInnermostDomains(ref_root, num_merged_domains);

// No expansion is done if no merge is done.
if (merged_domain == nullptr) {
return default_word_size;
}

// Find the vectorizable word size with the merged domains
size_t word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge(
reference_tv,
merged_domain,
common_alignment_size,
runtime_info.expressionEvaluator(),
indexModeToDtype(runtime_info.getIndexMode()));

cleanUpInnermostMergedDomains(ref_root, merged_domain);

// Stop if the reference doesn't get a larger word size.
if (word_size <= default_word_size) {
return default_word_size;
}

// Check the other TVs and take the minimum of the valid word sizes
for (const auto tv : vectorizable_inputs_outputs) {
if (tv == reference_tv) {
continue;
}

const auto& tv_root = tv->getMaybeRFactorDomain();

int tv_num_merged_domains = 0;
for (const auto i : c10::irange(num_merged_domains)) {
if (i == tv_root.size()) {
break;
}
auto ref_id = ref_root.at(ref_root.size() - 1 - i);
IterDomain* tv_id = tv_root.at(tv_root.size() - 1 - i);
// If not mapped, stop expanding.
if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT)) {
break;
} else {
++tv_num_merged_domains;
}
}

size_t tv_word_size = 1;
if (tv_num_merged_domains > 1) {
auto tv_merged_domain =
mergeInnermostDomains(tv_root, tv_num_merged_domains);
if (tv_merged_domain == nullptr) {
tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv);
} else {
tv_word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge(
tv,
tv_merged_domain,
common_alignment_size,
runtime_info.expressionEvaluator(),
indexModeToDtype(runtime_info.getIndexMode()));
cleanUpInnermostMergedDomains(tv_root, tv_merged_domain);
}
} else {
tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv);
}

word_size = std::min(word_size, tv_word_size);
}

return word_size;
}

DisjointSets<IterDomain*> disjointViewSets(Fusion* fusion) {
// Start from the exact iter domain graph of the fusion
IterDomainGraph id_graph(fusion);
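The heart of the removed collectMaxVectorizeSizeWithContigMerge is the doubling loop at its end: starting from a width of 1, the candidate vector size is doubled while it still divides the merged extent and stays within the alignment-limited byte budget; the merged extent itself comes from accumulating the root-domain extents covered by the contiguous merged ID. The following is a standalone toy version of just that search, not nvfuser code; the function name and the main() driver are illustrative only.

#include <cstddef>
#include <iostream>

// Toy reimplementation of the power-of-two search from the helper removed
// above: the widest power-of-two width that divides merged_size and fits in
// the byte budget for the given element size.
size_t maxVectorWidth(
    size_t merged_size,             // product of contiguously merged extents
    size_t max_vector_size_in_byte, // alignment-limited byte budget
    size_t dtype_size) {            // element size in bytes
  const size_t max_vector_size = max_vector_size_in_byte / dtype_size;
  size_t vector_size = 1;
  size_t next_vector_size = 2;
  while (next_vector_size <= max_vector_size &&
         merged_size % next_vector_size == 0) {
    vector_size = next_vector_size;
    next_vector_size *= 2;
  }
  return vector_size;
}

int main() {
  // A fully contiguous float tensor whose innermost root domains merge to
  // 4096 elements, with a 16-byte budget, vectorizes 4 elements at a time.
  std::cout << maxVectorWidth(4096, 16, sizeof(float)) << std::endl; // prints 4
  return 0;
}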