# Cleanup of lower_utils.cpp stage 1: Isolate out GpuLower usage #1989


Merged · 3 commits · Sep 27, 2022
### torch/csrc/jit/codegen/cuda/index_compute.cpp (20 changes: 14 additions & 6 deletions)
```diff
@@ -178,7 +178,8 @@ Val* getConcreteProducerOffsetWithGather(
   Val* window_idx = nullptr;

   if (use_concrete_map) {
-    window_idx = index_map.at(ir_utils::caMapExactConcreteId(window_id));
+    window_idx = index_map.at(GpuLower::current()->caMap()->getConcreteMappedID(
+        window_id, IdMappingMode::EXACT));
   } else {
     window_idx = index_map.at(window_id);
   }
```
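Note: the substitution above is mechanical. Every call to the former `ir_utils::caMapExactConcreteId` helper is inlined so that `ir_utils` no longer reaches into `GpuLower`, which is the isolation this PR is after. Judging by the file-local aliases reintroduced later in this diff, the removed helper was presumably a thin wrapper along these lines (a sketch, not the verbatim original):

```cpp
// Presumed shape of the removed ir_utils helper: return the representative
// ("concrete") IterDomain of id's equivalence class in the compute-at map,
// under the EXACT mapping mode.
IterDomain* caMapExactConcreteId(IterDomain* id) {
  return GpuLower::current()->caMap()->getConcreteMappedID(
      id, IdMappingMode::EXACT);
}
```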
```diff
@@ -703,7 +704,9 @@ void IndexCompute::collectIndexIntoPermissiveMap(
   auto id_outputs = ir_utils::filterByType<IterDomain>(expr->outputs());
   if (std::all_of(
           id_outputs.begin(), id_outputs.end(), [this](IterDomain* id) {
-            return index_map_.count(ir_utils::caMapExactConcreteId(id));
+            return index_map_.count(
+                GpuLower::current()->caMap()->getConcreteMappedID(
+                    id, IdMappingMode::EXACT));
           })) {
     // Visit this expression:
     // LoopIndexingAnalysis::traverseFromDomainVals made sure that each
```
```diff
@@ -715,7 +718,9 @@
     for (auto id : id_inputs) {
       // Collect backward pass results from this expression if they are
       // made available in by this expression.
-      auto idx_it = index_map_.find(ir_utils::caMapExactConcreteId(id));
+      auto idx_it =
+          index_map_.find(GpuLower::current()->caMap()->getConcreteMappedID(
+              id, IdMappingMode::EXACT));

       if (idx_it != index_map_.end()) {
         permissive_index_map_
```
```diff
@@ -730,7 +735,8 @@
 void IndexCompute::updateIndexMapFromPermissiveMap(const Expr* id_expr) {
   auto id_outputs = ir_utils::filterByType<IterDomain>(id_expr->outputs());
   for (auto id : id_outputs) {
-    auto concrete_id = ir_utils::caMapExactConcreteId(id);
+    auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
+        id, IdMappingMode::EXACT);
     // Only try to copy index val from permissive map when
     // the index is missing.
     if (!index_map_.count(concrete_id)) {
```
```diff
@@ -1506,7 +1512,8 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
   // effort which means some domains may be producer's original domains.
   std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
   for (auto entry : c2p_map) {
-    auto ref_id = ir_utils::caMapExactConcreteId(entry.first);
+    auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID(
+        entry.first, IdMappingMode::EXACT);
     auto p_id = entry.second;
     if (ref_id->getParallelType() == ParallelType::Vectorize) {
       p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
```
```diff
@@ -1745,7 +1752,8 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
   // effort which means some domains may be the originals.
   std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
   for (auto entry : c2p_index_map) {
-    auto ref_id = ir_utils::caMapExactConcreteId(entry.first);
+    auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID(
+        entry.first, IdMappingMode::EXACT);
     auto p_id = entry.second;
     if (ref_id->getParallelType() == ParallelType::Vectorize) {
       p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
```
### torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp (10 changes: 8 additions & 2 deletions)
```diff
@@ -18,6 +18,11 @@ namespace fuser {
 namespace cuda {

 namespace {
+// Alias used for std::transform
+IterDomain* exactConcreteId(IterDomain* id) {
+  return GpuLower::current()->caMap()->getConcreteMappedID(
+      id, IdMappingMode::EXACT);
+}

 //! Checks that the current loop nest is not realizing a serial
 //! broadcast so that each index of producer buffer will only
```
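The file-local `exactConcreteId` alias exists because `std::transform` wants a unary callable, and the inlined `GpuLower::current()->caMap()->getConcreteMappedID(...)` expression cannot be passed by name the way the old `ir_utils::caMapExactConcreteId` could. The same alias reappears in `lower_index_compute.cpp` below; anonymous-namespace symbols are per translation unit, so each file carries its own copy. A self-contained illustration of the pattern, with hypothetical stand-in types (`std::string` instead of `IterDomain*`):

```cpp
#include <algorithm>
#include <iterator>
#include <string>
#include <unordered_set>
#include <vector>

namespace {
// Stand-in for "map an id to its concrete representative".
std::string canonical(const std::string& id) {
  return id.substr(0, id.find('.'));
}
} // namespace

int main() {
  std::vector<std::string> ids{"a.0", "b.1", "a.2"};
  std::unordered_set<std::string> concrete_ids;
  // Same call shape as the transforms in this diff: a named function in an
  // anonymous namespace serves as the unary callable.
  std::transform(
      ids.begin(),
      ids.end(),
      std::inserter(concrete_ids, concrete_ids.begin()),
      canonical);
  return concrete_ids.size() == 2 ? 0 : 1; // collapses to {"a", "b"}
}
```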
```diff
@@ -83,7 +88,7 @@ bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) {
           std::inserter(
               producer_exact_concrete_root_ids,
               producer_exact_concrete_root_ids.begin()),
-          ir_utils::caMapExactConcreteId);
+          exactConcreteId);

   // Check if serial loop roots indexes any exact root id's that
   // is not within the set of producer's root exact id's. These
```
```diff
@@ -92,7 +97,8 @@
   for (auto serial_loop_root :
        ir_utils::filterByType<IterDomain>(serial_loop_roots)) {
     if (!producer_exact_concrete_root_ids.count(
-            ir_utils::caMapExactConcreteId(serial_loop_root))) {
+            GpuLower::current()->caMap()->getConcreteMappedID(
+                serial_loop_root, IdMappingMode::EXACT))) {
       return true;
     }
   }
```
### torch/csrc/jit/codegen/cuda/lower_allocation.cpp (2 changes: 1 addition & 1 deletion)
```diff
@@ -59,7 +59,7 @@ class AllocationInserter : public kir::ExprMutator {
   // info.init_place_before, info.alloc_for_loop, info.alloc_place_before
   void fillAllocationInformation(AllocationInformation& info, Expr* expr) {
     auto loop_alloc_info =
-        loop_utils::getAllocInformation(info.buffer, for_loops_);
+        lower_loop_utils::getAllocInformation(info.buffer, for_loops_);

     info.init_for_loop = loop_alloc_info.init_for_loop;
     info.alloc_for_loop = loop_alloc_info.alloc_for_loop;
```
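Here the call moves to `lower_loop_utils::`, while later files in this diff call `lower_utils::getAllocInformation`. From the fields read at the call sites, the helper's result presumably looks roughly like this (a sketch; the real definition lives with the lowering utilities and is not shown in this diff):

```cpp
// Inferred from usage: fillAllocationInformation reads init_for_loop and
// alloc_for_loop from the returned struct. The field set and types are
// assumptions based on the call sites in this PR.
struct AllocInformationSketch {
  kir::ForLoop* init_for_loop = nullptr;  // loop under which to initialize
  kir::ForLoop* alloc_for_loop = nullptr; // loop under which to allocate
};
```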
### torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp (2 changes: 1 addition & 1 deletion)
```diff
@@ -709,7 +709,7 @@ std::vector<IterDomain*> getLocalDomainOrdering(
   std::sort(
       merged_domain.begin(),
       merged_domain.end(),
-      IterDomainDependencySorter(
+      ir_utils::IterDomainDependencySorter(
           concrete_id_dependencies, GpuLower::current()->caMap()));
   return merged_domain;
 }
```
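`IterDomainDependencySorter` now lives in `ir_utils` and is handed to `std::sort` as a comparator. Generically, that pattern looks like the following self-contained sketch (hypothetical types, not the real nvfuser sorter):

```cpp
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// A comparator ordering a before b whenever b depends on a. Note that
// std::sort requires a strict weak ordering, which the real sorter must
// also satisfy.
struct DependencySorter {
  // deps[x] = set of values that must come before x.
  const std::unordered_map<int, std::unordered_set<int>>& deps;
  bool operator()(int a, int b) const {
    auto it = deps.find(b);
    return it != deps.end() && it->second.count(a) > 0;
  }
};

int main() {
  std::unordered_map<int, std::unordered_set<int>> deps{{2, {1}}, {3, {1, 2}}};
  std::vector<int> v{3, 1, 2};
  std::sort(v.begin(), v.end(), DependencySorter{deps});
  return v.front() == 1 ? 0 : 1; // 1 has no prerequisites, so it sorts first
}
```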
### torch/csrc/jit/codegen/cuda/lower_index.cpp (2 changes: 1 addition & 1 deletion)
```diff
@@ -1075,7 +1075,7 @@ kir::Allocate* IndexLowering::allocateUniqueBuffer(

   // No existing allocation found. Create a new one
   auto new_buffer =
-      ir_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init);
+      lower_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init);

   // Keep track of the allocation
   alloc_map.emplace(out_tv, new_buffer);
```
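`allocGlobalBufferForGridComm` likewise moves from `ir_utils` to `lower_utils`. From the two call sites in this diff (here and in `lower_insert_syncs.cpp` below), its signature is presumably along these lines; the parameter types are assumptions, since the actual declaration is not part of this diff:

```cpp
// Both call sites pass a buffer size, a DataType, and a zero-init flag, and
// use the result as a kir::Allocate*. Val* for the size is an assumption.
kir::Allocate* allocGlobalBufferForGridComm(
    Val* buffer_size,
    DataType dtype,
    bool zero_init);
```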
### torch/csrc/jit/codegen/cuda/lower_index_compute.cpp (78 changes: 54 additions & 24 deletions)
```diff
@@ -112,7 +112,8 @@ IndexingParameters getLinearIndexParameters(

   for (auto loop_idx : c10::irange(loops.size())) {
     auto loop = loops[loop_idx];
-    auto index_domain = ir_utils::caMapExactConcreteId(loop_domain[loop_idx]);
+    auto index_domain = GpuLower::current()->caMap()->getConcreteMappedID(
+        loop_domain[loop_idx], IdMappingMode::EXACT);
     if (loop->isTrivial()) {
       // This is useful information in the case of
       // MisalignedVectorize and double buffer epilog, etc.
```
```diff
@@ -149,7 +150,9 @@

       auto loop_id = loop_indexing.loopDomains()[loop_idx];

-      auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id);
+      auto concrete_loop_id =
+          GpuLower::current()->caMap()->getConcreteMappedID(
+              loop_id, IdMappingMode::EXACT);

       auto stage_depth =
           GpuLower::current()->doubleBufferInfo().getStageDepthFor(
```
```diff
@@ -186,7 +189,7 @@ IndexingParameters getNonGlobalInitialIndexParameters(
   }

   auto alloc_tv = index_producer ? producer_tv : consumer_tv;
-  auto alloc_info = loop_utils::getAllocInformation(
+  auto alloc_info = lower_utils::getAllocInformation(
       alloc_tv, loops, alloc_id_map, index_producer);

   std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
```
```diff
@@ -217,7 +220,9 @@
     auto loop = loops[loop_idx];
     auto loop_domain = loop_domains[loop_idx];

-    auto concrete_loop_domain = ir_utils::caMapExactConcreteId(loop_domain);
+    auto concrete_loop_domain =
+        GpuLower::current()->caMap()->getConcreteMappedID(
+            loop_domain, IdMappingMode::EXACT);

     index_parameters.initial_concrete_id_index[concrete_loop_domain] =
         loop_to_ind_map.at(loop);
```
```diff
@@ -399,7 +404,8 @@ IndexingParameters getPredicateInitialIndexParameters(
   for (int loop_idx : c10::irange(loops.size())) {
     auto loop = loops.at(loop_idx);
     auto concrete_loop_domain =
-        ir_utils::caMapExactConcreteId(loop_domains.at(loop_idx));
+        GpuLower::current()->caMap()->getConcreteMappedID(
+            loop_domains.at(loop_idx), IdMappingMode::EXACT);
     index_parameters.initial_concrete_id_index[concrete_loop_domain] =
         loop_to_ind_map.at(loop);
   }
```
```diff
@@ -566,7 +572,10 @@ LoopIndexingAnalysis::LoopIndexingAnalysis(
   // consume each concrete id once so this map is well defined.
   for (auto expr : replayed_exprs_) {
     for (auto input_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
-      concrete_id_to_consumer_[ir_utils::caMapExactConcreteId(input_id)] = expr;
+      auto concrete_input_id =
+          GpuLower::current()->caMap()->getConcreteMappedID(
+              input_id, IdMappingMode::EXACT);
+      concrete_id_to_consumer_[concrete_input_id] = expr;
     }
   }

```
```diff
@@ -598,7 +607,8 @@ void LoopIndexingAnalysis::validateLoopStructure(
   for (auto it_i = loops.begin(); it_i != loops.end(); ++it_i) {
     // Largely duplicating original logic
     auto loop_id = (*it_i)->iter_domain();
-    auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id);
+    auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
+        loop_id, IdMappingMode::EXACT);

     TORCH_INTERNAL_ASSERT(
         !concrete_to_loop.count(concrete_loop_id),
```
```diff
@@ -662,13 +672,22 @@ void LoopIndexingAnalysis::traverseFromDomainVals() {
 }

 IterDomain* LoopIndexingAnalysis::concretizeAndVisitId(IterDomain* id) {
-  auto concrete_id = ir_utils::caMapExactConcreteId(id);
+  auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
+      id, IdMappingMode::EXACT);
   if (replayed_concrete_ids_.pushBack(concrete_id)) {
     concrete_to_original_id_[concrete_id] = id;
   }
   return concrete_id;
 }

+namespace {
+// Alias used for std::transform
+IterDomain* exactConcreteId(IterDomain* id) {
+  return GpuLower::current()->caMap()->getConcreteMappedID(
+      id, IdMappingMode::EXACT);
+}
+} // namespace
+
 void LoopIndexingAnalysis::visitExpr(Expr* expr) {
   if (auto swizzle2d = dynamic_cast<Swizzle2D*>(expr)) {
     // Swizzle outputs are already forwarded through
```
```diff
@@ -703,14 +722,14 @@
       consumed_ids.begin(),
       consumed_ids.end(),
       std::inserter(consumed_concrete_, consumed_concrete_.end()),
-      ir_utils::caMapExactConcreteId);
+      exactConcreteId);

   auto produced_ids = ir_utils::filterByType<IterDomain>(expr->outputs());
   std::transform(
       produced_ids.begin(),
       produced_ids.end(),
       std::inserter(produced_concrete_, produced_concrete_.end()),
-      ir_utils::caMapExactConcreteId);
+      exactConcreteId);
 }

 bool LoopIndexingAnalysis::visitIdsAndCheckDuplication(
```
Expand Down Expand Up @@ -800,7 +819,8 @@ void LoopIndexingAnalysis::constructLoopDomains() {
// will complain for not having all outputs of the traversal.
for (auto id : ir_utils::filterByType<IterDomain>(all_ids_from_root)) {
if (id->uses().empty()) {
loop_domains_.pushBack(ir_utils::caMapExactConcreteId(id));
loop_domains_.pushBack(GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT));
}
}
}
```diff
@@ -880,7 +900,8 @@ IndexFromIdGraph getTensorIndexFromIdGraph(

     // Exact id will have to be pulled from consumer side as the
     // producer side are replayed ids.
-    auto exact_concrete_id = ir_utils::caMapExactConcreteId(consumer_id);
+    auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
+        consumer_id, IdMappingMode::EXACT);

     index_update_map[exact_concrete_id] = target_id;

```
```diff
@@ -961,7 +982,8 @@ IndexFromIdGraph getPredicateIndexingFromIdGraph(
        ir_utils::filterByType<IterDomain>(all_consumer_vals)) {
     // Track the non-concrete id we were trying to bind index
     // to, whether from producer or consumer.
-    auto exact_concrete_id = ir_utils::caMapExactConcreteId(consumer_id);
+    auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
+        consumer_id, IdMappingMode::EXACT);
     index_update_map[exact_concrete_id] = consumer_id;
   }

```
```diff
@@ -1040,7 +1062,8 @@ LoopIndexingTraversal::LoopIndexingTraversal(
     auto next_ids =
         ir_utils::filterByType<IterDomain>(nextValsInTraversalOrder(expr));
     for (auto id : next_ids) {
-      auto concrete_id = ir_utils::caMapExactConcreteId(id);
+      auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
+          id, IdMappingMode::EXACT);
       TORCH_INTERNAL_ASSERT(
           concrete_id_to_dependency_.insert(std::make_pair(concrete_id, expr))
               .second,
```
```diff
@@ -1108,7 +1131,8 @@ std::vector<Expr*> LoopIndexingTraversal::getExprList() {
     for (auto prev_id :
          ir_utils::filterByType<IterDomain>(prevValsInTraversalOrder(top))) {
       auto prev_expr_it = concrete_id_to_dependency_.find(
-          ir_utils::caMapExactConcreteId(prev_id));
+          GpuLower::current()->caMap()->getConcreteMappedID(
+              prev_id, IdMappingMode::EXACT));
       if (prev_expr_it != concrete_id_to_dependency_.end()) {
         auto prev_expr = prev_expr_it->second;
         if (!visited.count(prev_expr)) {
```
```diff
@@ -1145,7 +1169,7 @@ void LoopIndexingAnalysis::collectOutOfLineExprs() {
       consumer_tv_->getComputeAtPosition(),
       consumer_tv_->domain()->domain().end(),
       std::inserter(out_of_line_ids, out_of_line_ids.end()),
-      ir_utils::caMapExactConcreteId);
+      exactConcreteId);

   // Get the original selected list of index expressions
   // in reverse topological order.
```
```diff
@@ -1160,7 +1184,9 @@
             id_outputs.begin(),
             id_outputs.end(),
             [&out_of_line_ids](IterDomain* id) {
-              return out_of_line_ids.count(ir_utils::caMapExactConcreteId(id));
+              return out_of_line_ids.count(
+                  GpuLower::current()->caMap()->getConcreteMappedID(
+                      id, IdMappingMode::EXACT));
             })) {
       // Record out of line expression
       out_of_line_exprs_.push_back(expr);
```
```diff
@@ -1171,7 +1197,7 @@
           id_inputs.begin(),
           id_inputs.end(),
           std::inserter(out_of_line_ids, out_of_line_ids.end()),
-          ir_utils::caMapExactConcreteId);
+          exactConcreteId);
     }
   }
 }
```
```diff
@@ -1192,14 +1218,14 @@ std::unordered_set<IterDomain*> LoopIndexing::getAllExactConcreteIdSet() const {
         out_ids.begin(),
         out_ids.end(),
         std::inserter(all_id_set, all_id_set.end()),
-        ir_utils::caMapExactConcreteId);
+        exactConcreteId);

     auto in_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
     std::transform(
         in_ids.begin(),
         in_ids.end(),
         std::inserter(all_id_set, all_id_set.end()),
-        ir_utils::caMapExactConcreteId);
+        exactConcreteId);
   }
   return all_id_set;
 }
```
```diff
@@ -1244,7 +1270,9 @@ class LoopIndexingPreferredPathCompute : public IterVisitor {
       }
       mapped_id = c_id_it->second;
     }
-    auto concrete_original_id = ir_utils::caMapExactConcreteId(mapped_id);
+    auto concrete_original_id =
+        GpuLower::current()->caMap()->getConcreteMappedID(
+            mapped_id, IdMappingMode::EXACT);
     if (all_concrete_ids.count(concrete_original_id)) {
       if (original_id->isBroadcast() || original_id->isReduction() ||
           original_id->isStride()) {
```
```diff
@@ -1270,16 +1298,18 @@
             all_iter_inputs.begin(),
             all_iter_inputs.end(),
             [&](IterDomain* inp_id) {
-              return this->preferred_path_.find(ir_utils::caMapExactConcreteId(
-                         inp_id)) != this->preferred_path_.end();
+              return this->preferred_path_.find(
+                         GpuLower::current()->caMap()->getConcreteMappedID(
+                             inp_id, IdMappingMode::EXACT)) !=
+                  this->preferred_path_.end();
             })) {
       auto all_iter_outputs = ir_utils::filterByType<IterDomain>(e->outputs());

       std::transform(
           all_iter_outputs.begin(),
           all_iter_outputs.end(),
           std::inserter(preferred_path_, preferred_path_.end()),
-          ir_utils::caMapExactConcreteId);
+          exactConcreteId);
     }
   }

```
### torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp (4 changes: 2 additions & 2 deletions)
```diff
@@ -293,7 +293,7 @@ class WarSyncInserter : private kir::ExprMutator {
     auto maybe_aliased_tv = alloc_map_.getRealBuffer(tv);
     auto alloc_it = smem_allocations_.find(maybe_aliased_tv);
     auto ca_loop =
-        loop_utils::getAllocInformation(tv, for_loops_).init_for_loop;
+        lower_utils::getAllocInformation(tv, for_loops_).init_for_loop;
     if (alloc_it == smem_allocations_.end()) {
       WarMemoryInfo mem_info;
       mem_info.ca_loop = ca_loop;
```
```diff
@@ -486,7 +486,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
     Expr* sync_expr = nullptr;
     kir::Allocate* maybe_alloc = nullptr;
     if (sync_bitmap.hasBID()) {
-      maybe_alloc = ir_utils::allocGlobalBufferForGridComm(
+      maybe_alloc = lower_utils::allocGlobalBufferForGridComm(
          getGridSyncBufferSize(sync_bitmap), DataType::Int, true);
       sync_expr = IrBuilder::create<kir::GridSync>(
           sync_bitmap, maybe_alloc->buffer());
```
### torch/csrc/jit/codegen/cuda/lower_loops.cpp (2 changes: 1 addition & 1 deletion)
```diff
@@ -252,7 +252,7 @@ void LoopNestGenerator::generate(const std::vector<Expr*>& exprs) {
     std::sort(
         loop_structure.rbegin(),
         loop_structure.rend(),
-        IterDomainDependencySorter(
+        ir_utils::IterDomainDependencySorter(
            concrete_id_dependencies, GpuLower::current()->caMap()));
     loop_structures_[tv] = loop_structure;
   }
```