69 changes: 69 additions & 0 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -650,6 +650,8 @@ IndexCompute::IndexCompute(
}

void IndexCompute::run(const LoopIndexing& loop_indexing) {
TORCH_INTERNAL_ASSERT(
concrete_id_pass_, "concrete pass only for this option");
// Apply loop swizzles if there are any that output to
// the loop domains.
// Currently we only support loop swizzles that directly output
@@ -669,13 +671,80 @@ void IndexCompute::run(const LoopIndexing& loop_indexing) {
}
}

// Resolve the index vals that could be resolved with only
// the loops that consumer_tv doesn't share with any of its
// consumers, i.e. the not-inlined loops that define consumer_tv
// values.

Owner: Is this supposed to be "that could only be resolved with the loops..."?

Owner: What you're saying is: you're resolving anything you can with the loops that consumer_tv doesn't share. What I'm saying is: you're resolving loops that cannot be resolved without the loops that consumer_tv doesn't share with its consumers.

Author: Yes. Both "only with" and "with only" are true here. Will think about formalizing in follow-ups.

collectIndexIntoPermissiveMap(loop_indexing);

// Run through the loop indexing expressions and generate
// the indexing integer math for the concrete ids.
for (auto expr : loop_indexing.getBackwardExprList()) {
// Resolve missing values from permissive map.
updateIndexMapFromPermissiveMap(expr);

handle(expr);
}
}

void IndexCompute::collectIndexIntoPermissiveMap(
const LoopIndexing& loop_indexing) {
// Visit the expressions that produce only un-inlined iterdomains,
// in reverse topological order.
for (auto expr : loop_indexing.getBackwardOutOfLineExprList()) {
// Compute indexing vals for the expression inputs.
//
// This stage should run before any other indexing computation so
// that we can be sure that all index values computed at this stage
// are the ones that can be resolved with only the not-inlined
// iterdomains.
//
auto id_outputs = ir_utils::filterByType<IterDomain>(expr->outputs());
if (std::all_of(
id_outputs.begin(), id_outputs.end(), [this](IterDomain* id) {
return index_map_.count(ir_utils::caMapExactConcreteId(id));
})) {
// Visit this expression:
// LoopIndexingAnalysis::traverseFromDomainVals made sure that each
// concrete index is bound exactly once so computing these expressions
// early should still be consistent.
handle(expr);

auto id_inputs = ir_utils::filterByType<IterDomain>(expr->inputs());
for (auto id : id_inputs) {
// Collect backward pass results from this expression if they are
// made available by this expression.
auto idx_it = index_map_.find(ir_utils::caMapExactConcreteId(id));

if (idx_it != index_map_.end()) {
permissive_index_map_
[GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::PERMISSIVE)] = idx_it->second;
}
}
}
}
}

void IndexCompute::updateIndexMapFromPermissiveMap(const Expr* id_expr) {
Owner: Scary, but seemingly better than what exists today.


auto id_outputs = ir_utils::filterByType<IterDomain>(id_expr->outputs());
for (auto id : id_outputs) {
auto concrete_id = ir_utils::caMapExactConcreteId(id);
// Only try to copy index val from permissive map when
// the index is missing.
if (!index_map_.count(concrete_id)) {
auto permissive_id = GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::PERMISSIVE);
// Write the permissive index val into index_map_ if the
// missing value is found in the permissive map.
auto permissive_it = permissive_index_map_.find(permissive_id);
if (permissive_it != permissive_index_map_.end()) {
index_map_[concrete_id] = permissive_it->second;
}
}
}
}
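
For readers skimming the diff, a minimal standalone sketch of the lookup pattern the two functions above implement may help. This is not part of the PR; Id, Val, IndexMaps, and resolveFromPermissive are illustrative stand-ins, with IterDomain and the CA-map plumbing elided. The idea: bindings collected by the out-of-line pass act as a fallback that is consulted only when the main index map has no exact entry.

#include <cassert>
#include <string>
#include <unordered_map>

using Id = int;          // stand-in for an exact-map concrete IterDomain*
using Val = std::string; // stand-in for an indexing Val*

struct IndexMaps {
  std::unordered_map<Id, Val> index_map;            // authoritative bindings
  std::unordered_map<Id, Val> permissive_index_map; // out-of-line fallback

  // Mirrors updateIndexMapFromPermissiveMap: copy a fallback binding into
  // the main map only when the main map is missing an entry for this id.
  void resolveFromPermissive(Id id) {
    if (index_map.count(id)) {
      return; // never overwrite an exact binding
    }
    auto it = permissive_index_map.find(id);
    if (it != permissive_index_map.end()) {
      index_map[id] = it->second;
    }
  }
};

int main() {
  IndexMaps maps;
  maps.permissive_index_map[0] = "threadIdx.x % 8"; // collected early
  maps.index_map[1] = "blockIdx.x";                 // exact binding

  maps.resolveFromPermissive(0); // missing -> filled from the fallback
  maps.resolveFromPermissive(1); // present -> left untouched

  assert(maps.index_map.at(0) == "threadIdx.x % 8");
  assert(maps.index_map.at(1) == "blockIdx.x");
  return 0;
}

The design choice mirrored here is that the fallback never overwrites an exact binding, which is why the traversal order set up by LoopIndexingAnalysis::traverseFromDomainVals matters.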

void IndexCompute::run() {
const std::vector<Val*> domain_vals(
td_->domain().begin(), td_->domain().end());
22 changes: 22 additions & 0 deletions torch/csrc/jit/codegen/cuda/index_compute.h
@@ -86,6 +86,18 @@ class IndexCompute : public BackwardVisitor {
//! based traversal.
IterDomain* maybeGetExactMapConcreteID(IterDomain* id);

//! (Concrete indexing pass only)
//! Collect permissive index bindings from the out-of-line expressions
//! of the given loop indexing.
//! See also permissive_index_map_ and
//! LoopIndexing::getBackwardOutOfLineExprList.
void collectIndexIntoPermissiveMap(const LoopIndexing& loop_indexing);

//! (Concrete indexing pass only)
//! Iterate through id_expr's outputs and pull index vals from the
//! permissive map, when both of the following are true:
//!  1. the output id is missing in index_map_.
//!  2. the output id is found in the permissive map.
void updateIndexMapFromPermissiveMap(const Expr* id_expr);

// Tensor domain we're mapping back to root
const TensorDomain* td_; // NOLINT

@@ -137,6 +149,16 @@ class IndexCompute : public BackwardVisitor {
// pass. See also [Note on swizzle mode]
SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;

// (Concrete id pass only)
// Contains the indexing math that could be resolved with only the
// iterdomains on the right of the consumer_tv's ca axis, i.e. the
// ones corresponding to the loops that consumer_tv would not
// share with any of its consumers.
// These indexing vals should be kept separate from index_map_ and
// should only be used when the indexing traversal follows the
// order defined in LoopIndexingAnalysis::traverseFromDomainVals.
std::unordered_map<IterDomain*, Val*> permissive_index_map_;

public:
const std::unordered_map<IterDomain*, Val*>& indexMap() const {
return index_map_;
57 changes: 57 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
@@ -438,6 +438,7 @@ class LoopIndexingAnalysis {
indexing.loop_root_ = loop_root_domains_;
indexing.loop_domains_ = loop_domains_.vector();
indexing.index_exprs_ = replayed_exprs_;
indexing.out_of_line_exprs_ = out_of_line_exprs_;
return indexing;
}

@@ -481,6 +482,12 @@ class LoopIndexingAnalysis {
//! loop_domains_ with all of these iter domains.
void constructLoopDomains();

//! Fills out_of_line_exprs_ by traversing the selected list of
//! expressions in reverse topological order, collecting iterdomains
//! on the indexing paths that involve only leaf id's on the right
//! of consumer's ca axis.
void collectOutOfLineExprs();

private:
//! Original loop nest input to derive info from.
const std::vector<kir::ForLoop*>& loops_;
@@ -521,6 +528,10 @@ class LoopIndexingAnalysis {
//! Selected list of exprs that will produce and consume each
//! of the exact concrete ids from the loop nest exactly once.
std::vector<Expr*> replayed_exprs_;

//! Set of expressions from the selected list that can be
//! resolved from axes on the right of ca axes.
std::vector<Expr*> out_of_line_exprs_;
};

LoopIndexingAnalysis::LoopIndexingAnalysis(
@@ -559,6 +570,10 @@ LoopIndexingAnalysis::LoopIndexingAnalysis(
// Reconstruct the iterdomain view of the original loopnest after resolving
// the exact definition of each index.
constructLoopDomains();

// Collect the set of indexing expressions that can be
// resolved out of line.
collectOutOfLineExprs();
}

void LoopIndexingAnalysis::validateLoopStructure(
@@ -1088,6 +1103,48 @@ std::vector<Expr*> LoopIndexingTraversal::getExprList() {

} // namespace

void LoopIndexingAnalysis::collectOutOfLineExprs() {
// Keep track of all the id's that can be resolved without
// iterdomains on the left of ca axes.
std::unordered_set<IterDomain*> out_of_line_ids;

// Seed the set with the leaf ids on the right of the ca axis.
std::transform(
consumer_tv_->domain()->domain().begin() +
consumer_tv_->getComputeAtPosition(),
consumer_tv_->domain()->domain().end(),
std::inserter(out_of_line_ids, out_of_line_ids.end()),
ir_utils::caMapExactConcreteId);

// Get the original selected list of index expressions
// in reverse topological order.
auto backward_expr_list =
LoopIndexingTraversal::backwardTopologicalOrder(replayed_exprs_);

for (auto expr : backward_expr_list) {
auto id_outputs = ir_utils::filterByType<IterDomain>(expr->outputs());
if (
// Check that all of the outputs are out of line
std::all_of(
id_outputs.begin(),
id_outputs.end(),
[&out_of_line_ids](IterDomain* id) {
return out_of_line_ids.count(ir_utils::caMapExactConcreteId(id));
})) {
// Record out of line expression
out_of_line_exprs_.push_back(expr);

// Add all of the expression inputs as out of line id's.
auto id_inputs = ir_utils::filterByType<IterDomain>(expr->inputs());
std::transform(
id_inputs.begin(),
id_inputs.end(),
std::inserter(out_of_line_ids, out_of_line_ids.end()),
ir_utils::caMapExactConcreteId);
}
}
}
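
As a reading aid, the closure computation above can be reduced to the following self-contained sketch. This is not part of the PR; ToyExpr and collectOutOfLine are hypothetical stand-ins, with IterDomain and the exact-concrete-id mapping replaced by plain ints. Seed a set with the leaf ids right of the CA position, walk the expressions in reverse topological order, and whenever every output of an expression is already in the set, record the expression and absorb its inputs.

#include <algorithm>
#include <iostream>
#include <unordered_set>
#include <vector>

// A toy iterdomain expression: outputs are derived from inputs
// (e.g. a split's two outputs come from one input).
struct ToyExpr {
  std::vector<int> inputs;
  std::vector<int> outputs;
};

// Returns the exprs resolvable with only the seeded (not-inlined) ids.
std::vector<ToyExpr> collectOutOfLine(
    const std::vector<ToyExpr>& backward_exprs, // reverse topological order
    std::unordered_set<int> out_of_line_ids) {  // seeded with leaf ids
                                                // right of the CA position
  std::vector<ToyExpr> out_of_line_exprs;
  for (const auto& expr : backward_exprs) {
    bool all_outputs_out_of_line = std::all_of(
        expr.outputs.begin(), expr.outputs.end(),
        [&](int id) { return out_of_line_ids.count(id) > 0; });
    if (all_outputs_out_of_line) {
      out_of_line_exprs.push_back(expr);
      // Everything this expr consumes is now resolvable out of line too.
      out_of_line_ids.insert(expr.inputs.begin(), expr.inputs.end());
    }
  }
  return out_of_line_exprs;
}

int main() {
  // Backward order: the split producing leaf 3 from 2, then the split
  // producing {1, 2} from 0, where leaf 1 is inlined (not in the seed).
  std::vector<ToyExpr> backward_exprs = {
      {{2}, {3}},
      {{0}, {1, 2}},
  };
  auto result = collectOutOfLine(backward_exprs, /*out_of_line_ids=*/{3});
  std::cout << "out-of-line exprs: " << result.size() << "\n"; // prints 1
  return 0;
}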

std::vector<Expr*> LoopIndexing::getForwardExprList() const {
return LoopIndexingTraversal::forwardTopologicalOrder(index_exprs_);
}
12 changes: 12 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_index_compute.h
@@ -127,6 +127,12 @@ class LoopIndexing {
//! topological order.
std::vector<Expr*> getBackwardExprList() const;

//! Returns the set of out of line expressions in
//! reverse topological order.
const std::vector<Expr*>& getBackwardOutOfLineExprList() const {
return out_of_line_exprs_;
}

//! Returns all exact concrete id's that were produced
//! or consumed in the selected indexing expressions
std::unordered_set<IterDomain*> getAllExactConcreteIdSet() const;
@@ -152,6 +158,12 @@ class LoopIndexing {
//! The selected sequence of expressions that should represent
//! the correct indexing math from the given loop nest.
std::vector<Expr*> index_exprs_;

//! The subset of index_exprs_ that can be resolved
//! with only the iterdomains on the right of consumer tv's ca
//! axis.
//! Expressions are ordered in reverse topological order.
std::vector<Expr*> out_of_line_exprs_;
};

// When indexing there is sometimes an option to propagate an index down
63 changes: 63 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -25528,6 +25528,69 @@ TEST_F(NVFuserTest, FusionSizeDependentData_CUDA) {
executor_cache.fusion(), cg_outputs, {a}, {a + 123}, __LINE__, __FILE__);
}

// Repro for issue #1925
TEST_F(NVFuserTest, FusionScheduleTransposeRepro1_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(4);
auto tv1 = makeConcreteTensor({-1, -1, -1, 1});
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = add(tv0, tv1);
fusion.addOutput(tv2);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input0 = at::randn({1, 1, 333, 1}, options);
at::Tensor input1 = at::randn({1, 1, 333, 1}, options);

auto lparams = scheduleTranspose(&fusion, {input0, input1});

FusionExecutor fe;
fe.compileFusion(&fusion, {input0, input1}, lparams);
auto outputs = fe.runFusion({input0, input1}, lparams);

auto tv_ref = input0 + input1;

testValidate(
&fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__);
}

// Repro for issue #1873
TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(1);
auto tv1 = makeContigTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = set(tv0);
auto tv3 = broadcast(tv2, {true, false});
auto tv4 = add(tv3, tv1);
fusion.addOutput(tv4);

tv4->merge(0);
tv4->split(0, 32);

tv0->computeAt(tv4, 1);

tv2->split(-1, 8);
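// Note added for exposition (not in the original PR): this inner split on
// tv2 sits to the right of tv2's computeAt position established above, so
// tv2's index can only be resolved with these not-inlined loops, which is
// the out-of-line path exercised by this fix.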

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({123}, options);
at::Tensor t1 = at::randn({3, 123}, options);

FusionExecutor fe;
fe.compileFusion(&fusion, {t0, t1});

auto outputs = fe.runFusion({t0, t1});

auto tv_ref = t0 + t1;

testValidate(&fusion, outputs, {t0, t1}, {tv_ref}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) {
// https://github.com/csarofeen/pytorch/issues/1926
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h
@@ -5,6 +5,7 @@
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/torch.h>

#include <unordered_map>
@@ -36,6 +37,10 @@ class NVFuserTest : public ::testing::Test {
GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
}
}

void TearDown() override {
c10::cuda::CUDACachingAllocator::emptyCache();
Collaborator: Why do we need to do this every time a test is done? Does it involve cudaFree? If so, wouldn't this make running the tests slower?

}
};

struct ValidationConstants {