From 3db0417a1ac77d7e8560d45b1796a72dd394133f Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 27 Jul 2020 14:43:22 -0400 Subject: [PATCH 1/6] Add lower validation pass to make sure root broadcast dims aren't split on tensors. --- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 12 ++++ torch/csrc/jit/codegen/cuda/lower_utils.h | 8 ++- .../jit/codegen/cuda/lower_validation.cpp | 63 +++++++++++++++---- 3 files changed, 70 insertions(+), 13 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 5c59f0248a41..c6eacbb5f7d3 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -389,6 +389,18 @@ Expr* firstInnerMostScope(Expr* scope) { namespace ir_utils { +std::vector iterDomainInputsOf( + const std::vector& input_ids) { + auto inputs = IterVisitor::getInputsTo({input_ids.begin(), input_ids.end()}); + std::vector id_inputs; + for (auto inp : inputs) { + if (inp->getValType() == ValType::IterDomain) { + id_inputs.push_back(inp->as()); + } + } + return id_inputs; +} + std::vector indices(std::vector loops) { std::vector inds(loops.size()); std::transform( diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index cc07303712b8..e65ba039ffe2 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -55,6 +55,9 @@ Expr* firstInnerMostScope(Expr* scope); namespace ir_utils { +// Return inputs of provided IterDomains that are IterDomains +std::vector iterDomainInputsOf(const std::vector&); + std::vector indices(std::vector); std::vector iterDomains(std::vector); @@ -73,11 +76,14 @@ bool isScope(const Expr*); Expr* asExpr(Statement*); +// TODO: Remove in favor of ->as() TensorView* asTV(Val*); +// TODO: Remove in favor of ->as() kir::ForLoop* asForLoop(Statement*); -const TensorView* asConstTV(const Val* const); +// TODO: Remove in favor of ->as() +const TensorView* asConstTV(const Val*); bool isUnrolledFor(const Expr*); diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 8b18b21d9ad5..de65e1e61c25 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -9,19 +9,56 @@ namespace fuser { // Some pre-compilation checks static void IrValidate(Fusion* fusion) { - fusion->validateInputs(); - for (Val* val : fusion->vals()) { + FusionGuard fg(fusion); + auto used_vals = DependencyCheck::getAllValsBetween( + {fusion->outputs().begin(), fusion->outputs().end()}, fusion->inputs()); + + std::unordered_set used_tvs; + + for (auto val : used_vals) { if (ir_utils::isTV(val)) { - TensorView* tv = ir_utils::asTV(val); - for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) { - IterDomain* id = tv->getComputeAtAxis(i).first; - - if (id->isBlockDim()) { - TORCH_CHECK( - !id->isBroadcast(), - "Parallelization across blocks on broadcast axes is not supported, but found on, ", - tv, - "."); + used_tvs.emplace(val->as()); + } + } + + fusion->validateInputs(); + + for (auto tv : used_tvs) { + for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) { + IterDomain* id = tv->getComputeAtAxis(i).first; + + if (id->isBlockDim()) { + TORCH_CHECK( + !id->isBroadcast(), + "Parallelization across blocks on broadcast axes is not supported, but found on, ", + tv, + "."); + } + if (tv->hasBroadcast() && tv->getMemoryType() != MemoryType::Global) { + auto td = tv->domain()->domain(); + 
auto ca_inputs = ir_utils::iterDomainInputsOf( + {td.begin(), td.begin() + tv->getThisComputeAtAxis()}); + auto non_ca_inputs = ir_utils::iterDomainInputsOf( + {td.begin() + tv->getThisComputeAtAxis(), td.end()}); + + std::unordered_set ca_inputs_set( + ca_inputs.begin(), ca_inputs.end()); + std::unordered_set non_ca_inputs_set( + non_ca_inputs.begin(), non_ca_inputs.end()); + + for (auto id : tv->getRootDomain()) { + if (id->isBroadcast()) { + // If a broadcast dimension is an input to both an axis within the + // computeAt point and outside the compute at point we would have to + // look at consumers to figure out what that axis will be + // broadcasted to, because we would have to generate everything the + // consumer could need on that axis. This could be supported but is + // not at this point. + TORCH_INTERNAL_ASSERT( + !(ca_inputs_set.find(id) != ca_inputs_set.end() && + non_ca_inputs_set.find(id) != non_ca_inputs_set.end()), + "Cannot generate a kernel where a root broadcast dimension is input to both IterDomains outside and within the computeAt point."); + } } } } @@ -33,6 +70,8 @@ void IrBuildSizesMap(Fusion* fusion) { std::unordered_map size_map; // Grab inputs and outputs + // TODO: Only run through inputs for the size map, outputs don't actually set + // any sizes of the problem. std::vector inputs_and_outputs; for (auto val : fusion->inputs()) { if (ir_utils::isTV(val)) { From cfa659693ef24927c56b8b287ac1f0b6d6355209 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 28 Jul 2020 11:37:18 -0400 Subject: [PATCH 2/6] Change global producer indexing so it's not based on consumer. Fixes issue mentioned in comments of: getBCastMergedIndices in index_compute.cpp --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 265 +++++++++++++----- 1 file changed, 197 insertions(+), 68 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 12833380be53..422231ed8546 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -192,6 +192,7 @@ IndexCompute::IndexCompute( return; } + // We may or may not have indices associated with reductions. 
const bool exclude_reduction = td_->nDims() > indices.size(); TORCH_INTERNAL_ASSERT( @@ -220,13 +221,17 @@ IndexCompute::IndexCompute( traverseFrom(indices[0]->fusion(), domain_vals, false); for (auto id : td_->rootDomain()) { - if (exclude_reduction && id->isReduction()) + if (exclude_reduction && id->isReduction()) { continue; - auto it = index_map_.find(id); - TORCH_INTERNAL_ASSERT( - it != index_map_.end(), - "Error during index compute, missed computing a value."); - indices_.push_back(it->second); + } else if (id->getIterType() == IterType::BroadcastWithStride) { + indices_.push_back(new Int(0)); + } else { + auto it = index_map_.find(id); + TORCH_INTERNAL_ASSERT( + it != index_map_.end(), + "Error during index compute, missed computing a value."); + indices_.push_back(it->second); + } } } @@ -255,6 +260,7 @@ std::vector IndexCompute::contiguityAnd( return contig_result; } +// TODO: use new mapping functions std::vector IndexCompute::contiguityPasC( TensorDomain* producer, TensorDomain* consumer) { @@ -289,84 +295,207 @@ std::vector IndexCompute::contiguityPasC( return as_consumer_contiguity; } +namespace { +// Note returned vector is relative to the producer root (rfactor domain is +// taken here for producer if there is one as this is the root domain relative +// to consumer) +// +// Consider: T3[ iS{( 1 * i5 )}, iS{i7} ] compute_at( 4, 1 ) +// = broadcast( T1[ iS{i5}, iS{i7} ] ) +// which could generate the loop nest: +// for(size_t i18 = 0; i18 < ( T4.size[0] * T4.size[1] ); ++i18 ) { +// float T3[T1.size[2]]; +// for(size_t i19 = 0; i19 < T3.size[2]; ++i19 ) { +// T2[ 0 ] +// = T1[...]; +// +// Here the first dimension to index T1 must be i18 % T4.size[1], because T3 has +// a dimension being broadcasted to the extent of T4. This function is looking +// for these types of cases: where there's a dimension merged into an entry into +// consumer->domain(), but this dimension does not exist in producer_root. +// Then we need these dimensions mapped to producer_root, so we know which ones +// we need to access with modulo. We could go consumer->domain() => +// consumer->rootDomain() => producer_root however producer_root could = +// producer->rfactorDomain() then we still might have to map to +// producer->rootDomain(). Therefore we might as well go consumer->domain() => +// producer->domain() => producer->rootDomain(). +std::vector getBCastMergedIndices( + const TensorDomain* producer, + const TensorDomain* consumer) { + auto c_root = consumer->rootDomain(); + auto p_root = producer->hasRFactor() ? producer->rfactorDomain() + : producer->rootDomain(); + + auto root_c2p_idmap = TensorDomain::mapRootCtoP(consumer, producer); + + std::unordered_set bcast_not_in_P; + for (auto c_id : c_root) { + if (c_id->isBroadcast() && + root_c2p_idmap.find(c_id) == root_c2p_idmap.end()) { + bcast_not_in_P.emplace(c_id); + } + } + + // If there are no broadcasts in consumer_root that are not in producer_root, + // we have nothing to track here. + if (bcast_not_in_P.empty()) { + return std::vector(p_root.size(), false); + } + + // We want to know what domains in consumer have a merged root broadcast + // domain not present in producer root. We then want to map that to the + // consumer_root axes impacted by this (the non-bcast axes merged with these + // bcast axes). Then we want to map this to producer_root. 
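+  // Tying this back to the example above: iS{( 1 * i5 )} in T3 merges a root
+  // broadcast that T1 does not have with i5, so that consumer axis gets
+  // flagged here, i5 ends up flagged relative to producer_root, and T1's
+  // first index gets the modulo treatment (i18 % T4.size[1]) in
+  // getGlobalProducerIndex.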
+ + std::vector c_bcast_merged(consumer->nDims(), false); + + for (size_t c_i = 0; c_i < consumer->nDims(); c_i++) { + auto c_id = consumer->axis(c_i); + bool missing_bcast = false; + bool non_missing_bcast = false; + + auto c_id_inps = ir_utils::iterDomainInputsOf({c_id}); + + for (auto inp : c_id_inps) { + if (bcast_not_in_P.find(inp) != bcast_not_in_P.end()) { + missing_bcast = true; + } else { + non_missing_bcast = true; + } + } + + // this domain c_i is guilty. + c_bcast_merged[c_i] = missing_bcast && non_missing_bcast; + } + + // If these missing axes aren't merged with non-missing axes, we have nothing + // to track here. + if (std::none_of(c_bcast_merged.begin(), c_bcast_merged.end(), [](bool b) { + return b; + })) { + return std::vector(p_root.size(), false); + } + + // map c_bcast_merged to producer + std::vector p_bcast_merged(producer->nDims(), false); + auto c2p = + TensorDomain::mapDomainCtoP(consumer->domain(), producer->domain()); + + for (size_t c_i = 0; c_i < c2p.size(); c_i++) { + auto p_i = c2p[c_i]; + if (p_i != -1) { + p_bcast_merged[p_i] = c_bcast_merged[c_i]; + } + } + + // map p_bcast_merged to producer->root + std::vector p_root_bcast_merged(producer->nDims(), false); + + // map producer root IterDomain to it's position in producer->rootDomain() + std::unordered_map p_root_id_to_index; + for (size_t p_i = 0; p_i < producer->rootDomain().size(); p_i++) { + p_root_id_to_index[producer->rootDomain()[p_i]] = p_i; + } + + for (size_t p_i = 0; p_i < p_bcast_merged.size(); p_i++) { + if (!p_bcast_merged[p_i]) + continue; + IterDomain* id = producer->axis((int)p_i); + auto id_inps = ir_utils::iterDomainInputsOf({id}); + for (auto inp : id_inps) { + p_root_bcast_merged[p_root_id_to_index.at(id)] = true; + } + } + + return p_root_bcast_merged; +} +} // namespace kir::TensorIndex* Index::getGlobalProducerIndex( - const TensorView* producer, - const TensorView* consumer, + const TensorView* producer_tv, + const TensorView* consumer_tv, const std::vector& loops) { - // Grab indices from the loops - std::vector indices(loops.size()); - std::transform( - loops.begin(), loops.end(), indices.begin(), [](kir::ForLoop* fl) { - return fl->index(); - }); + // producer_tv->domain() is not replayed as the loop strucutre we were + // provided, so replay it to match consumer_tv which is. + auto producer = TransformReplay::replayPasC( + producer_tv->domain(), consumer_tv->domain(), -1) + .first; + + auto p2c = TensorDomain::mapDomainPtoC( + producer->domain(), consumer_tv->domain()->domain()); + std::vector indices; + for (size_t i = 0; i < producer->domain().size(); i++) { + indices.push_back(loops[p2c[i]]->index()); + } + + std::vector computed_inds = + IndexCompute::get(producer, indices, producer_tv->domain()->contiguity()); + + auto p_root = TensorDomain::noReductions(producer->rootDomain()); - // What would the consumer indices be if it was global, keeping in mind - // reduction axes. We have to do the indexing based on consumer because we - // could hit instances where we have a loop nest generated based on: - // consumer[b{1}, i1, i2] with consumer->merge(0) => consumer[b{1}*i1, i2], - // but producer would just be producer[i1, i2]. It would be very hard to - // generate indices directly on producer, but if we do it on consumer, and - // grab the root axes we need (i1 and i2), it's easy to do. 
- const std::vector c_inds = IndexCompute::get( - consumer->domain(), - indices, - IndexCompute::contiguityPasC(producer->domain(), consumer->domain())); - - // Computed consumer indices should have everything we need for the producer - std::vector p_inds; - auto p_root = TensorDomain::noReductions(producer->getRootDomain()); - // Number of root dims that are broadcasted - size_t implicit_bcast_dims = 0; { - auto c_root = consumer->getRootDomain(); - size_t it_c = 0, it_p = 0; - while (it_c < c_root.size() && it_p < p_root.size()) { - const bool is_bcast = p_root[it_p]->isBroadcast(); - if (c_root[it_c]->isBroadcast() && !p_root[it_p]->isBroadcast()) { - it_c++; - } else { - if (!p_root[it_p]->isBroadcast()) { - p_inds.push_back(c_inds[it_c]); - } else { - if (p_root[it_p]->getIterType() == IterType::BroadcastWithStride) { - p_inds.push_back(new Int(0)); - } else { - implicit_bcast_dims++; - } - } - it_c++; - it_p++; + // remove implicit bcast dims from root + std::vector without_implicit_bcast; + + size_t implicit_bcast_dims = 0; + for (auto id : p_root) { + if (id->getIterType() != IterType::BroadcastWithoutStride) { + without_implicit_bcast.push_back(id); } } + p_root = without_implicit_bcast; } + TORCH_INTERNAL_ASSERT( - p_inds.size() == p_root.size() - implicit_bcast_dims, + computed_inds.size() == p_root.size(), "Dimensionality error in code generator while computing tensor indices."); bool inner_most_dim_contig = - producer->getRootDomain()[producer->getRootDomain().size() - 1] - ->getIterType() == IterType::Iteration && - producer->domain()->contiguity()[producer->getRootDomain().size() - 1]; + p_root[p_root.size() - 1]->getIterType() == IterType::Iteration && + producer->contiguity()[p_root.size() - 1]; + + // This function is projected as consumer->domain() => consumer->rootDomain() + // => producer->rootDomain() + auto p_root_bcast_merged = + getBCastMergedIndices(producer, consumer_tv->domain()); std::vector strided_inds; - for (size_t i = 0; i < p_inds.size(); i++) { - if (p_inds[i]->isZeroInt()) { - strided_inds.push_back(p_inds[i]); - } else if (i == p_inds.size() - 1 && inner_most_dim_contig) { - strided_inds.push_back(p_inds[i]); + for (size_t p_i = 0; p_i < p_root.size(); p_i++) { + Val* extent = nullptr; + if (computed_inds[p_i]->isZeroInt()) { + // If collapsing a dim, but need to module, we need extents multiplied + // together + if (p_root[p_i]->getIterType() == IterType::Iteration) { + if (extent == nullptr) { + extent = p_root[p_i]->extent(); + } else { + extent = mul(extent, p_root[p_i]->extent()); + } + } + continue; + } + + auto maybe_modulo = computed_inds[p_i]; + if (p_root_bcast_merged[p_i]) { + maybe_modulo = + mod(computed_inds[p_i], + extent == nullptr ? 
p_root[p_i]->extent() + : mul(extent, p_root[p_i]->extent())); + } + + if (p_i == computed_inds.size() - 1 && inner_most_dim_contig) { + strided_inds.push_back(maybe_modulo); } else { std::stringstream ss; - ss << "T" << producer->name() << ".stride[" << i << "]"; + ss << "T" << producer_tv->name() << ".stride[" << p_i << "]"; strided_inds.push_back( - mul(p_inds[i], new NamedScalar(ss.str(), DataType::Int))); + mul(maybe_modulo, new NamedScalar(ss.str(), DataType::Int))); } } - // Probably shouldn't ever hit this if (strided_inds.size() == 0) strided_inds.push_back(new Int(0)); - return new kir::TensorIndex(producer, strided_inds); + return new kir::TensorIndex(producer_tv, strided_inds); } // Producer index for either shared or local memory @@ -462,12 +591,11 @@ kir::TensorIndex* Index::getGlobalConsumerIndex( consumer->domain(), indices, consumer->domain()->contiguity()); auto root_dom = consumer->getRootDomain(); + TORCH_INTERNAL_ASSERT( - computed_inds.size() == TensorDomain::noReductions(root_dom).size() || - computed_inds.size() == root_dom.size(), + computed_inds.size() == root_dom.size() || + computed_inds.size() == TensorDomain::noReductions(root_dom).size(), "Dimensionality error in code generator while computing indexing."); - - // Remove indices associated with reductions. if (computed_inds.size() == root_dom.size()) { for (size_t i = 0; i < root_dom.size(); i++) { // Do this backwards so erase offset will be right @@ -515,10 +643,12 @@ kir::TensorIndex* Index::getGlobalConsumerIndex( ->getIterType() == IterType::Iteration && consumer->domain()->contiguity()[consumer->getRootDomain().size() - 1]; + inner_most_dim_contig = false; + std::vector strided_inds; for (size_t i = 0; i < computed_inds.size(); i++) { if (computed_inds[i]->isZeroInt()) { - strided_inds.push_back(computed_inds[i]); + continue; } else if (i == computed_inds.size() - 1 && inner_most_dim_contig) { strided_inds.push_back(computed_inds[i]); } else { @@ -529,7 +659,6 @@ kir::TensorIndex* Index::getGlobalConsumerIndex( } } - // Probably shouldn't ever hit this if (strided_inds.size() == 0) strided_inds.push_back(new Int(0)); From 30f1cc3ecc4c267dbb651c274cff1cdc8443cce5 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 29 Jul 2020 08:08:07 -0400 Subject: [PATCH 3/6] Post cherry-pick cleanup. 
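
This cleanup switches the producer/consumer axis mapping in getGlobalProducerIndex
and getBCastMergedIndices over to the pair-based TensorDomain::mapDomainPandC
interface, fixes the p_root_id_to_index lookup in getBCastMergedIndices to key on
the root input (inp) rather than the derived axis (id), and removes the
inner_most_dim_contig = false; override that was left in getGlobalConsumerIndex.

The pair list is converted into a dense producer-to-consumer lookup before it is
used. A minimal standalone sketch of that pattern (the function name and plain
int vectors are illustrative only, not the fuser API):

  #include <cassert>
  #include <utility>
  #include <vector>

  // Convert (producer axis, consumer axis) pairs into a dense producer ->
  // consumer table; -1 marks producer axes with no consumer counterpart.
  std::vector<int> pairsToP2C(
      const std::vector<std::pair<int, int>>& pc_map,
      size_t producer_ndims) {
    std::vector<int> p2c(producer_ndims, -1);
    for (const auto& entry : pc_map) {
      p2c[entry.first] = entry.second;
    }
    return p2c;
  }

  int main() {
    // Producer axes {0, 1} line up with consumer axes {1, 2}; consumer axis 0
    // (e.g. a broadcast) has no producer counterpart.
    const std::vector<std::pair<int, int>> pc_map = {{0, 1}, {1, 2}};
    const auto p2c = pairsToP2C(pc_map, 2);
    assert(p2c[0] == 1 && p2c[1] == 2);
    return 0;
  }
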
--- torch/csrc/jit/codegen/cuda/index_compute.cpp | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 422231ed8546..579c7600ea93 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -378,19 +378,17 @@ std::vector getBCastMergedIndices( // map c_bcast_merged to producer std::vector p_bcast_merged(producer->nDims(), false); - auto c2p = - TensorDomain::mapDomainCtoP(consumer->domain(), producer->domain()); + std::vector> pc_map = + TensorDomain::mapDomainPandC(consumer->domain(), producer->domain()); - for (size_t c_i = 0; c_i < c2p.size(); c_i++) { - auto p_i = c2p[c_i]; - if (p_i != -1) { - p_bcast_merged[p_i] = c_bcast_merged[c_i]; - } + for (std::pair entry : pc_map) { + int p_i = entry.first; + int c_i = entry.second; + p_bcast_merged[p_i] = c_bcast_merged[c_i]; } // map p_bcast_merged to producer->root std::vector p_root_bcast_merged(producer->nDims(), false); - // map producer root IterDomain to it's position in producer->rootDomain() std::unordered_map p_root_id_to_index; for (size_t p_i = 0; p_i < producer->rootDomain().size(); p_i++) { @@ -403,7 +401,7 @@ std::vector getBCastMergedIndices( IterDomain* id = producer->axis((int)p_i); auto id_inps = ir_utils::iterDomainInputsOf({id}); for (auto inp : id_inps) { - p_root_bcast_merged[p_root_id_to_index.at(id)] = true; + p_root_bcast_merged[p_root_id_to_index.at(inp)] = true; } } @@ -420,8 +418,15 @@ kir::TensorIndex* Index::getGlobalProducerIndex( producer_tv->domain(), consumer_tv->domain(), -1) .first; - auto p2c = TensorDomain::mapDomainPtoC( + std::vector p2c(producer->nDims(), false); + auto pc_map = TensorDomain::mapDomainPandC( producer->domain(), consumer_tv->domain()->domain()); + for (auto entry : pc_map) { + int p_i = entry.first; + int c_i = entry.second; + p2c[p_i] = c_i; + } + std::vector indices; for (size_t i = 0; i < producer->domain().size(); i++) { indices.push_back(loops[p2c[i]]->index()); @@ -643,8 +648,6 @@ kir::TensorIndex* Index::getGlobalConsumerIndex( ->getIterType() == IterType::Iteration && consumer->domain()->contiguity()[consumer->getRootDomain().size() - 1]; - inner_most_dim_contig = false; - std::vector strided_inds; for (size_t i = 0; i < computed_inds.size(); i++) { if (computed_inds[i]->isZeroInt()) { From ef9cc9fc541736859eec8186c590759ab148f5e7 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 30 Jul 2020 20:39:04 -0400 Subject: [PATCH 4/6] Fix loop nest structure for broadcast ops. --- test/cpp/jit/test_gpu.cpp | 58 +++++ test/cpp/jit/tests.h | 1 + torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 2 + torch/csrc/jit/codegen/cuda/kernel_ir.h | 8 +- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 265 ++++++++++++-------- torch/csrc/jit/codegen/cuda/lower_loops.h | 13 +- 6 files changed, 241 insertions(+), 106 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index a57dea7083c7..f6c59eece797 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -51,6 +51,17 @@ TensorView* makeDummyTensor(int nDims, DataType dtype = DataType::Float) { return new TensorView(new TensorDomain(dom), dtype); } +TensorView* makeConcreteTensor( + std::vector sizes, + DataType dtype = DataType::Float) { + // We can uncomment the below statement to test all tests with contiguous + // tensors. 
return makeContigTensor(nDims, dtype); + std::vector dom; + for (int i = 0; i < sizes.size(); i++) + dom.push_back(new IterDomain(new Int(0), new Int(sizes[i]))); + return new TensorView(new TensorDomain(dom), dtype); +} + TensorView* makeTensorWithContig( int nDims, std::vector contig_info, @@ -3038,6 +3049,53 @@ void testGPU_FusionSimpleBCast() { #endif } +void testGPU_FusionComplexBCast() { + Fusion fusion; + FusionGuard fg(&fusion); + + int x = 2, y = 3, z = 4; + + auto tv0 = makeConcreteTensor({y}); + auto tv1 = broadcast(tv0, {false, true}); + auto tv2 = makeConcreteTensor({y, z}); + auto tv3 = mul(tv1, tv2); + auto tv4 = broadcast(tv3, {true, false, false}); + auto tv5 = makeConcreteTensor({x, y, z}); + auto tv6 = add(tv4, tv5); + + // tv0[ i1 ] + // tv1[ i1, b2] + // tv2[ i1, i2] + // tv3[ i1, i2] + // tv4[b0, i1, i2] + // tv5[i0, i1, i2] + // tv6[i0, i1, i2] + + // tv3 = bcast(tv0) * tv2 + // tv6 = bcast(tv3) + tv5 + + fusion.addInput(tv0); + fusion.addInput(tv2); + fusion.addInput(tv5); + + fusion.addOutput(tv6); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({y}, options); + at::Tensor t2 = at::randn({y, z}, options); + at::Tensor t5 = at::randn({x, y, z}, options); + + auto t3 = t0.unsqueeze(-1).expand({y, z}) * t2; + auto t6 = t3.unsqueeze(0).expand({x, y, z}) + t5; + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + auto outputs = fe.runFusion({t0, t2, t5}); + + TORCH_CHECK(t6.allclose(outputs[0])); +} + // Test a simple Gemm but also play around with fusion executor features void testGPU_FusionSimpleGemm() { Fusion fusion; diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 92e4d639e640..89e362f3505f 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -147,6 +147,7 @@ namespace jit { _(GPU_FusionReduction5) \ _(GPU_FusionReductionTFT) \ _(GPU_FusionSimpleBCast) \ + _(GPU_FusionComplexBCast) \ _(GPU_FusionSimpleGemm) \ _(GPU_FusionSoftmax1D) \ _(GPU_FusionSoftmax1DNormalized) \ diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index fab5e612ee7b..0af5c1b656e9 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -610,6 +610,7 @@ void IRPrinter::handle(const kir::TernaryOp* top) { void IRPrinter::handle(const ReductionOp* rop) { TORCH_CHECK(rop->out()->getValType() != ValType::TensorIndex); + indent(); os << rop->out() << " = reduction( " << rop->in() << ", op = " << rop->getReductionOpType() << ", initial value = " << rop->init() << " )\n"; @@ -721,6 +722,7 @@ void IRPrinter::handle(const kir::GridReduction* gr) { void IRPrinter::handle(const BroadcastOp* bop) { TORCH_CHECK(bop->out()->getValType() != ValType::TensorIndex); + indent(); os << bop->out() << " = broadcast( " << bop->in() << " )\n"; } diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 2b6bdc16edab..f9b403e02a32 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -365,12 +365,12 @@ class TORCH_CUDA_API Scope { exprs_.push_back(e); } - void insert(std::vector::iterator it, Expr* expr) { - exprs_.insert(it, expr); + void insert(size_t pos, Expr* expr) { + exprs_.insert(exprs_.begin() + pos, expr); } - void erase(std::vector::iterator it) { - exprs_.erase(it); + void erase(size_t pos) { + exprs_.erase(exprs_.begin() + pos); } bool empty() const { diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp 
b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index adffc7a0452e..1095a421ce16 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -11,31 +11,27 @@ namespace jit { namespace fuser { // Create, place, and return the allocation for tv -Expr* LoopNestGenerator::pushAlloc(TensorView* tv) { +Expr* LoopNestGenerator::pushAlloc(TensorView* tv, IterDomain* alloc_id) { TORCH_INTERNAL_ASSERT( !(FusionGuard::getCurFusion()->hasInput(tv) || FusionGuard::getCurFusion()->hasOutput(tv)), "Tried to allocate an input or output tensor."); - // First figure out which loop nest this allocation needs to be placed in - // Do we need to place the allocation at the root? + // Alloc id is the iteration domain associated with the allocation point of + // tv, figure out which local axis it is so we can determine the allocation + // size. size_t alloc_pos = 0; - // If there's no computeAt, then we want to be allocated at the root - while (alloc_pos <= tv->nDims() && tv->hasComputeAt()) { - // If we have a computeAt and we reached computeAt pos that's where it goes - if (tv->hasComputeAt() && alloc_pos == tv->getThisComputeAtAxis()) { - break; - } - // If we found an unroll, we want to place the allocation outside the unroll - if (alloc_pos < tv->nDims() && - tv->getComputeAtAxis(alloc_pos).first->getParallelType() == - ParallelType::Unroll) { - break; + + if (alloc_id != nullptr) { + for (auto tv_i = alloc_pos; tv_i < tv->nDims(); tv_i++) { + if (alloc_id == tv->getComputeAtAxis(tv_i).first) { + alloc_pos = tv_i + 1; + break; + } } - alloc_pos++; } - // Grab the dimensions the allocation will be based on + // Grab the dimensions the allocation will be based on to compute a size std::vector alloc_dims; for (auto i = alloc_pos; i < tv->nDims(); i++) { IterDomain* compute_at_dim = tv->getComputeAtAxis(i).first; @@ -72,19 +68,18 @@ Expr* LoopNestGenerator::pushAlloc(TensorView* tv) { // Create the allocation node kir::Allocate* alloc = new kir::Allocate(tv, MemoryType::Local, size); - // Place the allocation - if (alloc_pos == 0) { - // If we allocate at the root, insert at the begining of the lowered - // expressions + // Find which for loop we need to place this allocation + bool inserted = false; + for (auto loop : for_loops) { + if (loop->iter_domain() == alloc_id) { + loop->body().insert(0, alloc); + inserted = true; + break; + } + } + + if (!inserted) { lowered_exprs.insert(lowered_exprs.begin(), alloc); - } else if (alloc_pos == for_loops.size()) { - // If we allocate inline, push to the back of the last for loop - scope_utils::pushBack(for_loops[for_loops.size() - 1], alloc); - } else { - // Otherwise we allocate in some loop nest that is not inline, or root, so - // insert right before the loop we're just outside of - scope_utils::insertBefore( - for_loops[alloc_pos - 1], for_loops[alloc_pos], alloc); } return alloc; @@ -123,27 +118,21 @@ void LoopNestGenerator::pushBack(Expr* expr) { void LoopNestGenerator::initReduction( TensorView* tv, Val* init_val, + IterDomain* alloc_id, Expr* alloc_expr) { // This logic was taken from pushAlloc, as the initialization loop nest will // go at the same place. - - // First figure out which loop nest this allocation needs to be placed in - // Do we need to place the allocation at the root? 
size_t alloc_pos = 0; - // If there's no computeAt, then we want to be allocated at the root - while (alloc_pos <= tv->nDims() && tv->hasComputeAt()) { - // If we have a computeAt and we reached computeAt pos that's where it goes - if (tv->hasComputeAt() && alloc_pos == tv->getThisComputeAtAxis()) { - break; - } - // If we found an unroll, we want to place the allocation outside the unroll - if (alloc_pos < tv->nDims() && - tv->getComputeAtAxis(alloc_pos).first->getParallelType() == - ParallelType::Unroll) { - break; + // Allocation id is the iteration domain associated with the for loop which we + // will place this initialization loop nest in + if (alloc_id != nullptr) { + for (auto tv_i = alloc_pos; tv_i < tv->nDims(); tv_i++) { + if (alloc_id == tv->getComputeAtAxis(tv_i).first) { + alloc_pos = tv_i + 1; + break; + } } - alloc_pos++; } // Grab the IDs that will be involved in the initialization, ignore reduction @@ -213,51 +202,42 @@ void LoopNestGenerator::initReduction( inner_fl->body().push_back(init_stmt); } - // Place the allocation - if (alloc_pos == 0) { - // If we allocate at the root, look for the provided allocatoin if it - // exists, and place after it. + // Figure out which loop we need to place this initialization loop in. + kir::ForLoop* insert_loop = nullptr; + for (auto loop : for_loops) { + if (loop->iter_domain() == alloc_id) { + insert_loop = loop; + break; + } + } + + // If we don't find an insertion loop it means it needs to go in lowered_exprs + if (insert_loop == nullptr) { if (alloc_expr != nullptr) { - bool found = false; - for (auto it = lowered_exprs.begin(); it != lowered_exprs.end(); it++) { - if ((*it) == alloc_expr) { - lowered_exprs.insert(it + 1, init_loop_nest); - found = true; - break; - } - } + auto it = + std::find(lowered_exprs.begin(), lowered_exprs.end(), alloc_expr); TORCH_INTERNAL_ASSERT( - found, + it != lowered_exprs.end(), "Could not figure out where to initialize the buffer for ", tv); + lowered_exprs.insert(it + 1, init_loop_nest); } else { lowered_exprs.insert(lowered_exprs.begin(), init_loop_nest); } - } else if (alloc_pos == for_loops.size()) { - // If we allocate inline, push to the back of the last for loop - scope_utils::pushBack(for_loops[for_loops.size() - 1], init_loop_nest); } else { - // Otherwise we allocate in some loop nest that is not inline, or root, so - // insert right before the loop we're just outside of - scope_utils::insertBefore( - for_loops[alloc_pos - 1], for_loops[alloc_pos], init_loop_nest); + if (alloc_expr != nullptr) { + // If there is an allocation for this tensor view place this loop nest after it + insert_loop->body().insert_after(alloc_expr, init_loop_nest); + } else { + // Otherwise we're allocating a global value + insert_loop->body().insert(0, init_loop_nest); + } } } -/* - * This is one of the most complex parts of the code lowering logic. what we - * need to do is: - * 1) Reduce loop structure if needed - * 2) Open to compute At - * - If there is a computeAt set for this TV - * 3) Allocate the output. 
- * 4) If this is a reduction, initialize the output (open for loops to inner - * most, predicate, initialize, close predicate, close to computeAt) - * 5) Open to inner most loop - * 6) Run operation - * 7) Close to computeAt - */ void LoopNestGenerator::handle(Expr* expr) { + + // Check if it's a tensor view expression we need to place in the loop nest structure if (!ir_utils::isTVOp(expr)) { for (auto out : expr->outputs()) { TORCH_INTERNAL_ASSERT( @@ -274,42 +254,131 @@ void LoopNestGenerator::handle(Expr* expr) { } TensorView* out = expr->output(0)->as(); - // 1) Reduce loop structure - while (compute_at_scope.size() > out->getThisComputeAtAxis() && - compute_at_scope.back().second != out && - compute_at_scope.back() != - out->getComputeAtAxis((int)compute_at_scope.size() - 1)) { - popFor(); + + // Figure out what the entire loop structure should look like. + std::deque< + std::pair> + loop_structure; + + // As we go through iteration domains track the previous view + TensorView* last_ca_view = nullptr; + // Check where in the previous view our last axis was in that view + int64_t last_ca_view_ind = 0; + + // Look at each axis individually in out's domain + for (int64_t out_i = 0; out_i < (int64_t)out->getThisComputeAtAxis(); + out_i++) { + + // Grab the axis information + auto ca_point = out->getComputeAtAxis(out_i); + auto ca_view = ca_point.second; + auto ca_id = ca_point.first; + + // Figure out if there are axes in the compute at tensor view that aren't + // in out, make sure to also open them. Check where to start looking for + // them in the compute at view. + size_t start = 0; + if (last_ca_view == nullptr) { + // Start at the begining, we haven't processed any axes yet. + start = 0; + } else if (last_ca_view == ca_view) { + // This view is the same as the last axis, so start where we left off. + start = last_ca_view_ind + 1; + } else { + // This is a new view, figure out where we are in it, and start from there + for (start = 0; start < ca_view->nDims(); start++) { + if (loop_structure.back() == ca_view->getComputeAtAxis(start)) { + break; + } + } + start++; + } + + // Go from start, and open all loops in the computeAt view until we hit the + // one associated with out->getComputeAtAxis(out_i) + for (size_t ca_i = start; ca_i < ca_view->nDims(); ca_i++) { + loop_structure.push_back(ca_view->getComputeAtAxis(ca_i)); + + // Update the last view processed + last_ca_view_ind = ca_i; + last_ca_view = ca_view; + if (ca_view->getComputeAtAxis(ca_i).first == ca_id) { + break; + } + } + + // Shouldn't ever hit this, but make sure we hit the break above, meaning we + // added all necessary axes from the compute at view. + TORCH_INTERNAL_ASSERT( + ca_view->getComputeAtAxis(last_ca_view_ind).first == ca_id); + } + + // We're up to the compute at point in loop_structure, grab the remaining + // axes. + for (int64_t out_i = (int64_t)out->getThisComputeAtAxis(); + out_i < out->nDims(); + out_i++) { + // It's actually local, but getComputeAtAxis returns a std::pair, axis + // doesn't + loop_structure.push_back(out->getComputeAtAxis(out_i)); + } + + // At this point loop_structure contains our overal target loop nest structure + // Lets get a copy of the loop structure, and figure out which loops we need + // to open. 
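+  // Any entry in loop_structure whose IterDomain is already open (in order)
+  // in for_loops is reused as is; whatever remains in loops_to_open after the
+  // matching pass below still needs a new loop opened.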
+ decltype(loop_structure) loops_to_open(loop_structure); + std::deque for_copy(for_loops.begin(), for_loops.end()); + while (!loops_to_open.empty() && !for_copy.empty()) { + if (loops_to_open.front().first == for_copy.front()->iter_domain()) { + loops_to_open.pop_front(); + } + for_copy.pop_front(); } - // 2) Open back up to computeAt - while (compute_at_scope.size() < out->getThisComputeAtAxis()) { - openFor(out->getComputeAtAxis((int)compute_at_scope.size())); + // At this point for_loops + loops_to_open contains our overal target loop + // nest structure. Open loops in "loops_to_open". + while (!loops_to_open.empty()) { + openFor(loops_to_open.front()); + loops_to_open.pop_front(); } - Expr* alloc_stmt = nullptr; - // 3) Allocate the output. + // Figure out where we want to place alloc/reduction initialization + IterDomain* alloc_id = nullptr; + for (size_t out_i = 0; out_i < out->getThisComputeAtAxis(); out_i++) { + auto ca_id = out->getComputeAtAxis(out_i).first; + if (ca_id->getParallelType() == ParallelType::Unroll) + break; + alloc_id = ca_id; + } + + Expr* alloc_expr = nullptr; + // Place the allocation for out if (!FusionGuard::getCurFusion()->hasInput(out) && !FusionGuard::getCurFusion()->hasOutput(out)) { - alloc_stmt = pushAlloc(out); + alloc_expr = pushAlloc(out, alloc_id); } - // 4) If this is a reduction, initialize the output (open for loops to inner + // If this is a reduction, initialize the output (open for loops to inner // most, predicate, initialize, place next after allocation if exists, close // to computeAt) if (out->hasReduction()) - initReduction(out, expr->as()->init(), alloc_stmt); + initReduction(out, expr->as()->init(), alloc_id, alloc_expr); - // 5) Open to inner most loop - for (decltype(out->nDims()) i = for_loops.size(); i < out->nDims(); i++) - openFor(out->getComputeAtAxis(i)); - // 6) Run expression + // Place the expression pushBack(expr); - // 7) Reduce loop structure back to computeAt - while (!compute_at_scope.empty() && - compute_at_scope.size() > out->getThisComputeAtAxis()) - popFor(); + // Reduce the loop nest structure back to computeAt + if (out->getThisComputeAtAxis() == 0) { + while (!for_loops.empty()) + popFor(); + } else { + auto ca_axis = out->getThisComputeAtAxis() - 1; + while (for_loops.size() > 0 && + for_loops.back()->iter_domain() != + out->getComputeAtAxis(ca_axis).first) { + popFor(); + } + } } namespace { diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h index 8decb622636b..dca05430bc1d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.h +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -44,8 +44,9 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { // initialization ThreadPredicateMap& thread_predicates_; - // Create, place, and return the allocation for tv - Expr* pushAlloc(TensorView*); + // Create the allocation for tv, place it inside the loop associated with + // alloc_id, return the node + Expr* pushAlloc(TensorView*, IterDomain* alloc_id); // Open a new inner most for loop, track which TV it was constructed from // according to the computeAt chain. @@ -64,8 +65,12 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { // Initialize a buffer to init_val. If this buffer is in smem or registers, // pass in its allocation statement so we can make sure that we insert this - // initialization comes after the allocation. - void initReduction(TensorView* tv, Val* init_val, Expr* alloc_expr = nullptr); + // initialization after the allocation. 
+ void initReduction( + TensorView* tv, + Val* init_val, + IterDomain* alloc_id, + Expr* alloc_expr = nullptr); // Check if expr is a TV op and handle accordingly. void handle(Expr*) final; From 9966cf2bbe113abd21948832f76c57b834e4470d Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 30 Jul 2020 20:48:29 -0400 Subject: [PATCH 5/6] clang format --- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 8 ++++---- torch/csrc/jit/codegen/cuda/lower_loops.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index cb8f6cefb3fe..e8c217e6f312 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -230,7 +230,8 @@ void LoopNestGenerator::initReduction( } } else { if (alloc_expr != nullptr) { - // If there is an allocation for this tensor view place this loop nest after it + // If there is an allocation for this tensor view place this loop nest + // after it insert_loop->body().insert_after(alloc_expr, init_loop_nest); } else { // Otherwise we're allocating a global value @@ -240,8 +241,8 @@ void LoopNestGenerator::initReduction( } void LoopNestGenerator::handle(Expr* expr) { - - // Check if it's a tensor view expression we need to place in the loop nest structure + // Check if it's a tensor view expression we need to place in the loop nest + // structure if (!ir_utils::isTVOp(expr)) { for (auto out : expr->outputs()) { TORCH_INTERNAL_ASSERT( @@ -283,7 +284,6 @@ void LoopNestGenerator::handle(Expr* expr) { // Look at each axis individually in out's domain for (int64_t out_i = 0; out_i < (int64_t)out->getThisComputeAtAxis(); out_i++) { - // Grab the axis information auto ca_point = out->getComputeAtAxis(out_i); auto ca_view = ca_point.second; diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h index a1aebc804124..f83f10fde6e1 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.h +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -65,7 +65,7 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { // Return the status of the shared memory buffer // False if TensorView is not shared memory buffer bool isModifiedSharedMemory(Val* key) const; - + // Open a new inner most for loop, track which TV it was constructed from // according to the computeAt chain. void openFor(std::pair); From a6e8ee7d6ae6fea0192a3f9504da07a29b140505 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 30 Jul 2020 21:07:49 -0400 Subject: [PATCH 6/6] Clang tidy. --- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index e8c217e6f312..69eceacd0592 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -331,7 +331,7 @@ void LoopNestGenerator::handle(Expr* expr) { // We're up to the compute at point in loop_structure, grab the remaining // axes. for (int64_t out_i = (int64_t)out->getThisComputeAtAxis(); - out_i < out->nDims(); + out_i < (int64_t)out->nDims(); out_i++) { // It's actually local, but getComputeAtAxis returns a std::pair, axis // doesn't
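
The (int64_t) cast in this last hunk is the whole point of the "Clang tidy"
patch: out_i is an int64_t while nDims() presumably returns an unsigned size
type, so the comparison mixes signedness until the unsigned side is cast. A
minimal standalone illustration of the diagnostic and the fix (the size_t
stand-in for nDims() is an assumption, not the actual fuser signature):

  #include <cstdint>
  #include <cstdio>

  int main() {
    const size_t ndims = 3; // stand-in for out->nDims()

    // for (int64_t i = 0; i < ndims; i++) {}  // -Wsign-compare: comparison of
    //                                         // integers of different signs

    for (int64_t i = 0; i < (int64_t)ndims; i++) { // both sides signed: clean
      std::printf("axis %lld\n", static_cast<long long>(i));
    }
    return 0;
  }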