-
Notifications
You must be signed in to change notification settings - Fork 7
Propagate permissive mapping information into indexing pass #1929
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bb8cef1
66765a7
7992309
e9d09fe
f406e23
283bc84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -650,6 +650,8 @@ IndexCompute::IndexCompute( | |
} | ||
|
||
void IndexCompute::run(const LoopIndexing& loop_indexing) { | ||
TORCH_INTERNAL_ASSERT( | ||
concrete_id_pass_, "concrete pass only for this option"); | ||
// Apply loop swizzles if there are any that output to | ||
// the loop domains. | ||
// Currently only support loop swizzles that directly output | ||
|
@@ -669,13 +671,80 @@ void IndexCompute::run(const LoopIndexing& loop_indexing) { | |
} | ||
} | ||
|
||
// Resolve the index vals that could be resolved with only | ||
// the loops that consumer_tv doesn't share with any of its | ||
// consumers, i.e. the not-inlined loops that define consumer_tv | ||
// values. | ||
collectIndexIntoPermissiveMap(loop_indexing); | ||
|
||
// Run through the loop indexing expressions and generate | ||
// the indexing integer math for the concrete ids. | ||
for (auto expr : loop_indexing.getBackwardExprList()) { | ||
// Resolve missing values from permissive map. | ||
updateIndexMapFromPermissiveMap(expr); | ||
|
||
handle(expr); | ||
} | ||
} | ||
|
||
// Pre-pass over loop_indexing.getBackwardOutOfLineExprList(). | ||
// For every expression whose IterDomain outputs all already have an entry | ||
// in index_map_ (keyed by exact concrete id), run handle() on it and then | ||
// copy the index of each IterDomain input that is present in index_map_ | ||
// into permissive_index_map_, keyed by that input's PERMISSIVE concrete id. | ||
void IndexCompute::collectIndexIntoPermissiveMap( | ||
const LoopIndexing& loop_indexing) { | ||
// Visit the expressions that only produce un-inlined iterdomains, | ||
// in reverse topological order. | ||
for (auto expr : loop_indexing.getBackwardOutOfLineExprList()) { | ||
// Compute indexing vals for the expression inputs. | ||
// | ||
// This stage should run before any indexing computation so it could be | ||
// made sure that all index values computed at this stage are | ||
// the ones that can be resolved only with the not-inlined | ||
// iterdomains. | ||
// | ||
auto id_outputs = ir_utils::filterByType<IterDomain>(expr->outputs()); | ||
// Only proceed when every IterDomain output is already indexed. | ||
// Note: std::all_of is true on an empty range, so an expression with | ||
// no IterDomain outputs also passes this check. | ||
if (std::all_of( | ||
id_outputs.begin(), id_outputs.end(), [this](IterDomain* id) { | ||
return index_map_.count(ir_utils::caMapExactConcreteId(id)); | ||
})) { | ||
// Visit this expression: | ||
// LoopIndexingAnalysis::traverseFromDomainVals made sure that each | ||
// concrete index is bound exactly once so computing these expressions | ||
// early should still be consistent. | ||
handle(expr); | ||
|
||
auto id_inputs = ir_utils::filterByType<IterDomain>(expr->inputs()); | ||
for (auto id : id_inputs) { | ||
// Collect backward pass results from this expression if they are | ||
// made available by this expression. | ||
auto idx_it = index_map_.find(ir_utils::caMapExactConcreteId(id)); | ||
|
||
// Inputs without an index_map_ entry after handle() are skipped. | ||
if (idx_it != index_map_.end()) { | ||
permissive_index_map_ | ||
[GpuLower::current()->caMap()->getConcreteMappedID( | ||
id, IdMappingMode::PERMISSIVE)] = idx_it->second; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
// For each IterDomain output of id_expr that has no entry in index_map_ | ||
// (keyed by exact concrete id), look up the output's PERMISSIVE concrete id | ||
// in permissive_index_map_ and, if an entry exists, copy that index val | ||
// into index_map_. Existing index_map_ entries are never overwritten. | ||
void IndexCompute::updateIndexMapFromPermissiveMap(const Expr* id_expr) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Scary, but seemingly better than what exists today. |
||
auto id_outputs = ir_utils::filterByType<IterDomain>(id_expr->outputs()); | ||
for (auto id : id_outputs) { | ||
auto concrete_id = ir_utils::caMapExactConcreteId(id); | ||
// Only try to copy index val from permissive map when | ||
// the index is missing. | ||
if (!index_map_.count(concrete_id)) { | ||
auto permissive_id = GpuLower::current()->caMap()->getConcreteMappedID( | ||
id, IdMappingMode::PERMISSIVE); | ||
// Write the permissive index val into index_map_ if the | ||
// missing value is found here. | ||
auto permissive_it = permissive_index_map_.find(permissive_id); | ||
// If the permissive map has no entry either, index_map_ is left | ||
// unchanged for this output. | ||
if (permissive_it != permissive_index_map_.end()) { | ||
index_map_[concrete_id] = permissive_it->second; | ||
} | ||
} | ||
} | ||
} | ||
|
||
void IndexCompute::run() { | ||
const std::vector<Val*> domain_vals( | ||
td_->domain().begin(), td_->domain().end()); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
#include <torch/csrc/jit/codegen/cuda/lower_utils.h> | ||
|
||
#include <ATen/cuda/CUDAContext.h> | ||
#include <c10/cuda/CUDACachingAllocator.h> | ||
#include <torch/torch.h> | ||
|
||
#include <unordered_map> | ||
|
@@ -36,6 +37,10 @@ class NVFuserTest : public ::testing::Test { | |
GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs"; | ||
} | ||
} | ||
|
||
// Test-fixture teardown hook: calls CUDACachingAllocator::emptyCache(). | ||
// NOTE(review): presumably this releases memory cached by the CUDA caching | ||
// allocator between tests; whether that adds noticeable per-test cost | ||
// (e.g. via cudaFree) is not visible here -- confirm. | ||
void TearDown() override { | ||
c10::cuda::CUDACachingAllocator::emptyCache(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to do this every time a test is done? Does it involve cudaFree? If so, wouldn't this running the tests slower? |
||
} | ||
}; | ||
|
||
struct ValidationConstants { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this supposed to be "that could only be resolved with the loops..."?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What you're saying is you're resolving anything you can with the loops that consumer_tv doesn't share
What I'm saying is you're resolving loops that cannot be resolved without the loops that consumer_tv doesn't share with its consumers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Both "only with" and "with only" are true here. Will think about formalizing in follow-ups.