From d3a295a0ec65105c74fd18d413ac03b90b3fa0ab Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 5 Apr 2020 10:56:51 -0400 Subject: [PATCH 1/9] Continue lowering refactor, split out loop nest generator, create scope/loop utils. --- caffe2/CMakeLists.txt | 2 + test/cpp/jit/test_gpu.cpp | 6 +- tools/build_variables.bzl | 2 + torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 4 + torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 2 + torch/csrc/jit/codegen/cuda/lower2device.cpp | 486 ++++++------------ torch/csrc/jit/codegen/cuda/lower2device.h | 61 +-- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 196 +++++++ torch/csrc/jit/codegen/cuda/lower_loops.h | 60 +++ torch/csrc/jit/codegen/cuda/lower_utils.cpp | 249 +++++++++ torch/csrc/jit/codegen/cuda/lower_utils.h | 45 ++ torch/csrc/jit/codegen/cuda/type.cpp | 13 +- 12 files changed, 751 insertions(+), 375 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/lower_loops.cpp create mode 100644 torch/csrc/jit/codegen/cuda/lower_loops.h create mode 100644 torch/csrc/jit/codegen/cuda/lower_utils.cpp create mode 100644 torch/csrc/jit/codegen/cuda/lower_utils.h diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8035e2e898d1..124c5b452b9e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -579,6 +579,8 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 3f89530aa6a2..1adced3aa5bc 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -1,4 +1,4 @@ -#if defined(USE_CUDA) +//#if defined(USE_CUDA) #include #include @@ -1001,8 +1001,6 @@ void testGPU_FusionForLoop() { } } -void testGPU_Fusion() {} - } // namespace jit } // namespace torch -#endif // #if defined(USE_CUDA) +//#endif // #if defined(USE_CUDA) diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 823bf732e846..f7c03c458269 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -238,6 +238,8 @@ libtorch_cuda_sources = [ "torch/csrc/jit/codegen/cuda/ir_iostream.cpp", "torch/csrc/jit/codegen/cuda/iter_visitor.cpp", "torch/csrc/jit/codegen/cuda/kernel.cpp", + "torch/csrc/jit/codegen/cuda/lower_loops.cpp", + "torch/csrc/jit/codegen/cuda/lower_utils.cpp", "torch/csrc/jit/codegen/cuda/lower2device.cpp", "torch/csrc/jit/codegen/cuda/manager.cpp", "torch/csrc/jit/codegen/cuda/mutator.cpp", diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 36de14cacab7..c5a1cdb275df 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -147,6 +147,10 @@ bool Scope::sameAs(const Scope& other) const { return true; } +void Scope::clear() { + this->exprs_ = std::vector(); +} + bool IRInputOutput::hasInput(const Val* const input) const { for (auto val : inputs_) if (val == input) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index ee3c92abfbc5..1a4d905cc240 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h 
@@ -251,6 +251,8 @@ struct TORCH_CUDA_API Scope { bool sameAs(const Scope& other) const; + void clear(); + private: std::vector exprs_; }; diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 98c56dcdf28f..99e741367416 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -1,7 +1,8 @@ +#include #include -#include +#include +#include #include -#include #include #include @@ -37,165 +38,9 @@ const TensorView* asConstTV(const Val* const val) { return static_cast(val); } -struct parentScope_ : private OptInDispatch { - private: - Expr* parent_ = nullptr; - - void handle(ForLoop* fl) final { - parent_ = fl->parentScope(); - } - - void handle(IfThenElse* ite) final { - parent_ = ite->parentScope(); - } - - void handle(Expr* expr) final { - OptInDispatch::handle(expr); - } - - public: - static Expr* parent(Expr* scope) { - parentScope_ sp; - sp.handle(scope); - return sp.parent_; - } -}; - -struct forLoopCount : private OptInDispatch { - private: - unsigned int count_ = 0; - - void handle(ForLoop* fl) final { - count_++; - } - - void handle(IfThenElse* ite) final {} - - void handle(Expr* expr) final { - OptInDispatch::handle(expr); - } - - public: - static unsigned int count(Expr* scope) { - forLoopCount flc; - Expr* it = scope; - while (it != nullptr) { - flc.handle(it); - it = parentScope_::parent(it); - } - return flc.count_; - } -}; - -struct scopePushBack : private OptInDispatch { - private: - Expr* _expr = nullptr; - void handle(ForLoop* fl) final { - fl->body().push_back(_expr); - } - - void handle(IfThenElse* ite) final { - ite->body().push_back(_expr); - } - - void handle(Expr* expr) final { - OptInDispatch::handle(expr); - } - - public: - static void pushBack(Expr* scope, Expr* expr) { - scopePushBack pb; - TORCH_INTERNAL_ASSERT( - expr != nullptr && scope != nullptr, - "Cannot push back, scope or expr is a nullptr."); - pb._expr = expr; - pb.handle(scope); - } -}; - -struct forLoopIndices : private OptInDispatch { - private: - std::vector inds_; - void handle(ForLoop* fl) final { - inds_.insert(inds_.begin(), fl->index()); - } - - void handle(IfThenElse* ite) final {} - - void handle(Expr* expr) final { - OptInDispatch::handle(expr); - } - - public: - static std::vector get(Expr* scope) { - forLoopIndices fli; - Expr* it = scope; - while (it != nullptr) { - fli.handle(it); - it = parentScope_::parent(it); - } - return fli.inds_; - } -}; - -struct forLoopIDs : private OptInDispatch { - private: - std::vector IDs_; - void handle(ForLoop* fl) final { - IDs_.insert(IDs_.begin(), fl->range()); - } - - void handle(IfThenElse* ite) final {} - - void handle(Expr* expr) final { - OptInDispatch::handle(expr); - } - - public: - static std::vector get(Expr* scope) { - forLoopIDs fli; - Expr* it = scope; - while (it != nullptr) { - fli.handle(it); - it = parentScope_::parent(it); - } - return fli.IDs_; - } -}; - } // namespace // END HELPER FUNCTIONS -// Open a new inner most for loop -void GPULower::openFor(IterDomain* id) { - ForLoop* new_scope = nullptr; - if (id->isThread()) { - new_scope = new ForLoop( - new NamedScalar(stringify(id->parallel_method()), DataType::Int), - id, - {}, - active_scope); - } else { - new_scope = new ForLoop(new Int(), id, {}, active_scope); - } - pushBack(new_scope); - active_scope = new_scope; -} - -// Close the inner most scope -void GPULower::closeScope() { - TORCH_INTERNAL_ASSERT( - active_scope != nullptr, - "Tried to close the active scope, but 
there isn't one set."); - Expr* parent = parentScope_::parent(active_scope); - active_scope = parent; -} - -// Close all scopes -void GPULower::resetScope() { - active_scope = nullptr; -} - // Clear out the last recorded computeAtView void GPULower::clearActiveView() { active_view_axis = 0; @@ -208,14 +53,6 @@ void GPULower::setActiveView(const TensorView* const tv) { active_view = tv->getComputeAtView(); } -std::vector GPULower::getLoopIndices() { - return forLoopIndices::get(active_scope); -} - -std::vector GPULower::getLoopIterDomains() { - return forLoopIDs::get(active_scope); -} - TensorIndex* GPULower::getGlobalProducerIndex( TensorView* producer, TensorView* consumer) { @@ -224,11 +61,11 @@ TensorIndex* GPULower::getGlobalProducerIndex( // This replay will ignore reduction dimensions on the producer TransformReplay::fullReplay(consumer, cloned_tv); TORCH_INTERNAL_ASSERT( - getLoopIndices().size() == cloned_tv->nDims(), + scope_utils::getLoopIndices(active_scope).size() == cloned_tv->nDims(), "Dimensionality error in code generator while computing indexing."); - const std::vector computed_inds = - IndexCompute::computeIndices(cloned_tv, getLoopIndices()); + const std::vector computed_inds = IndexCompute::computeIndices( + cloned_tv, scope_utils::getLoopIndices(active_scope)); TORCH_INTERNAL_ASSERT( computed_inds.size() == producer->getRootDomain()->size(), @@ -253,14 +90,15 @@ TensorIndex* GPULower::getLocalProducerIndex( TensorView* producer, TensorView* consumer) { TORCH_INTERNAL_ASSERT( - computeForDepth() == producer->nDims(), + scope_utils::computeForDepth(active_scope) == producer->nDims(), "Expected a tensor with ", - computeForDepth(), + scope_utils::computeForDepth(active_scope), " dimensions but got one with ", producer->nDims()); - std::vector loopInds = getLoopIndices(); - std::vector ranges = getLoopIterDomains(); + std::vector loopInds = scope_utils::getLoopIndices(active_scope); + std::vector ranges = + scope_utils::getLoopIterDomains(active_scope); std::vector computed_inds; std::vector used_ranges; for (decltype(loopInds.size()) i{0}; i < loopInds.size(); i++) { @@ -295,11 +133,11 @@ TensorIndex* GPULower::getProducerIndex( TensorIndex* GPULower::getGlobalConsumerIndex(TensorView* consumer) { TORCH_INTERNAL_ASSERT( - getLoopIndices().size() == consumer->nDims(), + scope_utils::getLoopIndices(active_scope).size() == consumer->nDims(), "Dimensionality error in code generator while computing indexing."); - const std::vector computed_inds = - IndexCompute::computeIndices(consumer, getLoopIndices()); + const std::vector computed_inds = IndexCompute::computeIndices( + consumer, scope_utils::getLoopIndices(active_scope)); TORCH_INTERNAL_ASSERT( computed_inds.size() == consumer->getRootDomain()->size(), @@ -322,14 +160,15 @@ TensorIndex* GPULower::getGlobalConsumerIndex(TensorView* consumer) { TensorIndex* GPULower::getLocalConsumerIndex(TensorView* consumer) { TORCH_INTERNAL_ASSERT( - computeForDepth() == consumer->nDims(), + scope_utils::computeForDepth(active_scope) == consumer->nDims(), "Expected a tensor with ", - computeForDepth(), + scope_utils::computeForDepth(active_scope), " dimensions but got one with ", consumer->nDims()); - std::vector loopInds = getLoopIndices(); - std::vector ranges = getLoopIterDomains(); + std::vector loopInds = scope_utils::getLoopIndices(active_scope); + std::vector ranges = + scope_utils::getLoopIterDomains(active_scope); std::vector computed_inds; std::vector used_ranges; @@ -364,57 +203,11 @@ TensorIndex* 
GPULower::getConsumerIndex(TensorView* consumer) { return getLocalConsumerIndex(consumer); } -// Track how far our for loop scope is -unsigned int GPULower::computeForDepth() { - return forLoopCount::count(active_scope); -} - -// Push an expr to the active scope -void GPULower::pushBack(Expr* expr) { - if (active_scope == nullptr) { - lowered_exprs.push_back(expr); - return; - } - scopePushBack::pushBack(active_scope, expr); -} - -// Return the parent of the active scope -Expr* GPULower::parentScope() { - if (active_scope == nullptr) - return nullptr; - return parentScope_::parent(active_scope); -} - -Allocate* GPULower::getAlloc(TensorView* tv) { - TORCH_INTERNAL_ASSERT( - !(FusionGuard::getCurFusion()->hasInput(tv) || - FusionGuard::getCurFusion()->hasOutput(tv)), - "Tried to allocate an input or output tensor."); - - std::vector alloc_dims; - - for (decltype(tv->nDims()) i = tv->getComputeAtAxis(); i < tv->nDims(); i++) { - IterDomain* dim = tv->getComputeAtAxis(i); - if (dim->isThreadDim()) - continue; - alloc_dims.push_back(dim->size()); - } - - Val* size; - if (alloc_dims.size() == 0) { - size = new Int(1); - } else { - size = alloc_dims[0]; - for (decltype(alloc_dims.size()) i{1}; i < alloc_dims.size(); i++) { - size = mul(size, alloc_dims[i]); - } - } - return new Allocate(tv, size); -} - IfThenElse* GPULower::getPredicate(const TensorView* const pred_tv) { TensorIndex* ti = new TensorIndex( - pred_tv, IndexCompute::computeIndices(pred_tv, getLoopIndices())); + pred_tv, + IndexCompute::computeIndices( + pred_tv, scope_utils::getLoopIndices(active_scope))); std::vector all_preds = PredicateCompute::computePredicates(ti); @@ -438,41 +231,92 @@ IfThenElse* GPULower::getPredicate(const TensorView* const pred_tv) { return new IfThenElse(cond, {}, {}, active_scope); } -// Custom dispatch for Expr, want to find out of it's a TV op -void GPULower::handle(Expr* expr) { - if (!isTVOp(expr)) - return; +void GPULower::pushBack(Expr* expr) { + if (active_scope == nullptr) + lowered_exprs.push_back(expr); + else + scope_utils::pushBack(active_scope, expr); +} + +Statement* GPULower::mutate(Expr* expr) { + Statement* mutated_stmt = OptOutMutator::mutate(expr); + TORCH_INTERNAL_ASSERT( + mutated_stmt->isExpr(), + "Tried to generate a kernel but hit a non expression during lowering: ", + mutated_stmt); + return mutated_stmt; +} + +Statement* GPULower::mutate(ForLoop* fl) { + Expr* prev_scope = active_scope; + active_scope = fl; + std::vector mutated_exprs; + bool is_mutated = false; + for (auto expr : fl->body().exprs()) { + Statement* mutated_stmt = mutate(expr); + + TORCH_INTERNAL_ASSERT( + mutated_stmt->isExpr(), + "Tried to generate a kernel but hit a non expression during lowering: ", + mutated_stmt); + + mutated_exprs.push_back(static_cast(mutated_stmt)); + if (!(mutated_exprs.back()->sameAs(expr))) + is_mutated = true; + } - TensorView* out = static_cast(expr->output(0)); + if (is_mutated) { + scope_utils::clearScope(active_scope); + for (auto expr : mutated_exprs) + pushBack(expr); + } - updateView(out); + active_scope = prev_scope; - // 8) Run operation - OptOutDispatch::handle(expr); + if (is_mutated) + return new ForLoop( + fl->index(), fl->range(), mutated_exprs, fl->parentScope()); - // 9) Close predicate - if (active_scope != nullptr && - active_scope->getExprType() == ExprType::IfThenElse) - closeScope(); + return fl; } -void GPULower::handle(UnaryOp* uop) { - TORCH_INTERNAL_ASSERT( - isTV(uop->out()), - "Expected a tensor view but got ", - uop->out()->getValType().value()); 
+Statement* GPULower::mutate(UnaryOp* uop) { + if (!isTVOp(uop)) + return OptOutMutator::mutate(uop); + + IfThenElse* pred = getPredicate(asTV(uop->out())); + bool predicated = !pred->cond()->sameAs(new Int(1)); + if (predicated) { + pushBack(pred); + active_scope = pred; + } + TensorIndex* out = getConsumerIndex(asTV(uop->out())); Val* in = uop->in(); if (isTV(in)) in = getProducerIndex(asTV(in), asTV(uop->out())); - pushBack(new UnaryOp(uop->getUnaryOpType(), out, in)); + Expr* new_op = new UnaryOp(uop->getUnaryOpType(), out, in); + + if (predicated) { + active_scope = scope_utils::getParent(active_scope); + pushBack(new_op); + return pred; + } + + return new_op; } -void GPULower::handle(BinaryOp* bop) { - TORCH_INTERNAL_ASSERT( - isTV(bop->out()), - "Expected a tensor view but got ", - bop->out()->getValType().value()); +Statement* GPULower::mutate(BinaryOp* bop) { + if (!isTVOp(bop)) + return OptOutMutator::mutate(bop); + + IfThenElse* pred = getPredicate(asTV(bop->out())); + bool predicated = !pred->cond()->sameAs(new Int(1)); + if (predicated) { + pushBack(pred); + active_scope = pred; + } + TensorIndex* out = getConsumerIndex(asTV(bop->out())); Val* lhs = bop->lhs(); Val* rhs = bop->rhs(); @@ -483,83 +327,15 @@ void GPULower::handle(BinaryOp* bop) { if (isTV(rhs)) rhs = getProducerIndex(asTV(rhs), asTV(bop->out())); - pushBack(new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs)); -} - -/* - * This is one of the most complex parts of the code lowering logic. what we - * need to do is: 1) Reduce loop structure - * - Reset all loops if active_view == nullptr (I'm not the last in a series - * of computeAts) - * - Else reduce to active_view_axis if loop_depth > active_view_axis - * 2) Set active_view(_axis) - * - If there is a computeAt set for this TV - * 3) Open to compute At - * - If there is a computeAt set for this TV - * 4) Allocate the output. - * 5) If this is a reduction, initialize the output (open for loops to inner - * most, predicate, initialize, close predicate, close to computeAt) 6) Open to - * inner most loop 7) Open predicate 8) Run operation 9) Close predicate - */ - -// Update fors based on tv. -void GPULower::updateView(TensorView* tv) { - // 1) Reduce loop structure - if (active_view == nullptr) { - // - Reset all loops if active_view == nullptr (I'm not the last in a series - // of computeAts) - resetScope(); - } else { - // - Else reduce to active_view_axis if loop_depth > active_view_axis - auto depth = computeForDepth(); - for (auto i = depth; i > active_view_axis; i--) { - closeScope(); - } - } - if (tv->hasComputeAt()) { - // 2) Set active_view(_axis) - // - If there is a computeAt set for this TV - setActiveView(tv); - - // 3) Open to compute At - // - If there is a computeAt set for this TV - auto depth = computeForDepth(); - for (auto i = depth; i < tv->getComputeAtAxis(); i++) - openFor(tv->getComputeAtAxis(i)); - } else { - if (active_view != nullptr) - // If we're the last computeAt of a block, active view should match this - // tv - TORCH_INTERNAL_ASSERT( - tv->sameAs(active_view), - "Error detected in code lowering. Expected ", - active_view, - " but recieved ", - tv); - clearActiveView(); - } - - // 4) Allocate the output. 
+ Expr* new_op = new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs); - if (!FusionGuard::getCurFusion()->hasInput(tv) && - !FusionGuard::getCurFusion()->hasOutput(tv)) { - pushBack(getAlloc(tv)); + if (predicated) { + pushBack(new_op); + active_scope = scope_utils::getParent(active_scope); + return pred; } - // TODO: - // 5) If this is a reduction, initialize the output (open for loops to inner - // most, predicate, initialize, close predicate, close to computeAt) - - // 6) Open to inner most loop - for (decltype(tv->nDims()) i = computeForDepth(); i < tv->nDims(); i++) - openFor(tv->getComputeAtAxis(i)); - - // 7) Open predicate - IfThenElse* pred = getPredicate(tv); - if (!pred->cond()->sameAs(new Int(1))) { - pushBack(pred); - active_scope = pred; - } + return new_op; } // TensorViews are all based on symbolic sizes. When we first initialize them we @@ -677,27 +453,75 @@ void GPULower::replaceSizes() { ReplaceAll::instancesOf(tv_map); } +namespace { + +// Some pre-compilation checks +void validate(Fusion* fusion) { + for (Val* val : fusion->vals()) { + if (isTV(val)) { + TensorView* tv = asTV(val); + for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) { + IterDomain* id = tv->getComputeAtAxis(i); + + if (id->isThread()) + TORCH_CHECK( + !id->isReduction(), + "Parallelization on reduction axes is not supported at the moment, found on ", + tv, + "."); + + if (tv->hasComputeAt()) + if (i < tv->getComputeAtAxis()) + TORCH_CHECK( + id->parallel_method() != ParallelType::Unroll, + "Unroll dimension cannot be outside computeAt, found on: ", + tv, + " compute at ", + tv->getComputeAtView(), + " axis = ", + tv->getComputeAtAxis(), + "."); + } + } // if isTV + } // for(Val* val : fusion->vals()) + +} // validate +} // namespace + // Traverse through the fusion and print CUDA code associated with it std::vector GPULower::getLoweredExprs() { FusionGuard fg(fusion_); + // Likely we lowered this fusion, we can simply return the lowered expressions + // Not the safest approach but good enough for now. + if (fusion_->lowered && lowered_exprs.size() != 0) + return lowered_exprs; + TORCH_CHECK( !fusion_->lowered, "Fusions can only be lowered once as of now. You could reuse the lowering using", " std::vector GPULower::getLoweredExprs() the result can be printed as", " a kernel with IRPrinter irp(os); irp.printKernel(lowered_exprs, kernel_name);"); + validate(fusion_); + // Initialize members of the class - lowered_exprs = std::vector(); active_view = nullptr; active_view_axis = 0; replaceSizes(); - // Run through and lower the expressions - std::vector exprs = fusion_->exprs(true); - for (auto* expr : exprs) - handle(expr); + auto loop_nests = LoopNestGenerator::getLoopNest(fusion_); + + // Run through loop nests and further lower the expressions + for (auto* expr : loop_nests) { + Statement* mutated_stmt = mutate(expr); + TORCH_INTERNAL_ASSERT( + mutated_stmt->isExpr(), + "Tried to generate a kernel but hit a non expression during lowering: ", + mutated_stmt); + lowered_exprs.push_back(static_cast(mutated_stmt)); + } fusion_->lowered = true; return lowered_exprs; diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h index 00018addf79c..7c29b6a88b82 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.h +++ b/torch/csrc/jit/codegen/cuda/lower2device.h @@ -2,8 +2,8 @@ #include -#include #include +#include #include #include #include @@ -21,71 +21,58 @@ namespace fuser { // keep user references intact so they can lower it as they describe the kernel.
// Right now we can only lower once. -struct TORCH_CUDA_API GPULower : public OptOutDispatch { +struct TORCH_CUDA_API GPULower : public OptOutMutator { private: bool lowered = false; - Fusion* fusion_; + Fusion* const fusion_; std::vector lowered_exprs; Expr* active_scope = nullptr; + // Track the last computeAt TensorView and axis const TensorView* active_view; unsigned int active_view_axis; - // Open a new inner most for loop - void openFor(IterDomain*); - // Close the inner most for loop - void closeScope(); - // Close all for loops - void resetScope(); // Clear out the last recorded computeAtView void clearActiveView(); // Set active views from computeAtView void setActiveView(const TensorView* const); - // Grab the index variables of the active loop nest - std::vector getLoopIndices(); - // Grab the iterDomains of the active loops - std::vector getLoopIterDomains(); - // Gets the indexing of a TensorView producer. These are values consumed in a - // TensorView Expr. We use the consumer (left hand side of the =) to compute - // the indexing into the consumer. + + // Indexing functions + // Consumer = Producer + // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer + // Producer indexing dispatch TensorIndex* getProducerIndex(TensorView* producer, TensorView* consumer); + // Producer if it's in global memory TensorIndex* getGlobalProducerIndex( TensorView* producer, TensorView* consumer); + // Producer indexing if it's in registers TensorIndex* getLocalProducerIndex( TensorView* producer, TensorView* consumer); + // Consumer index dispatch TensorIndex* getConsumerIndex(TensorView* consumer); + // Consumer indexing if it's in global memory TensorIndex* getGlobalConsumerIndex(TensorView* consumer); + // Consumer indexing if it's in local memory TensorIndex* getLocalConsumerIndex(TensorView* consumer); - // Track how far our for loop scope is - unsigned int computeForDepth(); - // Push an expr to the active scope - void pushBack(Expr* expr); - // Return the parent of the active scope - Expr* parentScope(); - - // Get Register allocation statement for tensorview - Allocate* getAlloc(TensorView*); // Get a predicate based on a particular tensorview IfThenElse* getPredicate(const TensorView* const); - // Custom dispatch for Expr, want to find out of it's a TV op - void handle(Expr*) final; + // Wrap pushBack in lower_utils if active_scope is null we want it to go + // straight to lower_exprs + void pushBack(Expr*); - // Remake operations with TensorIndex - void handle(UnaryOp*) final; - void handle(BinaryOp*) final; + // Custom dispatch for Expr, want to find out of it's a TV op + Statement* mutate(Expr*) final; - // Ignore split/merge/reorder operations, - // we don't want to print them. - void handle(Split*) final {} - void handle(Merge*) final {} - void handle(Reorder*) final {} + // Open the for loop. + Statement* mutate(ForLoop*) final; - // Update for loop structure based on producing provided TensorView - void updateView(TensorView*); + // Remake operations with TensorIndex + Statement* mutate(UnaryOp*) final; + Statement* mutate(BinaryOp*) final; // TensorViews are all based on symbolic sizes. 
When we first initialize them // we don't know if they're inputs or outputs which would mean that they have diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp new file mode 100644 index 000000000000..1cc67f5b1991 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -0,0 +1,196 @@ +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { + +// HELPER NAMESPACE +namespace { + +bool isTV(const Val* const val) { + return val->getValType().value() == ValType::TensorView; +} + +// Check if we're a TensorView op that we can generate code for. +bool isTVOp(const Expr* expr) { + if (expr->nOutputs() == 1 && isTV(expr->output(0)) && + (expr->getExprType().value() == ExprType::BinaryOp || + expr->getExprType().value() == ExprType::UnaryOp)) + return true; + return false; +} + +TensorView* asTV(Val* val) { + TORCH_INTERNAL_ASSERT(isTV(val)); + return static_cast(val); +} + +const TensorView* asConstTV(const Val* const val) { + TORCH_INTERNAL_ASSERT(isTV(val)); + return static_cast(val); +} + +} // namespace + +Allocate* LoopNestGenerator::getAlloc(TensorView* tv) { + TORCH_INTERNAL_ASSERT( + !(FusionGuard::getCurFusion()->hasInput(tv) || + FusionGuard::getCurFusion()->hasOutput(tv)), + "Tried to allocate an input or output tensor."); + + std::vector alloc_dims; + + for (decltype(tv->nDims()) i = tv->getComputeAtAxis(); i < tv->nDims(); i++) { + IterDomain* dim = tv->getComputeAtAxis(i); + if (dim->isThreadDim()) + continue; + alloc_dims.push_back(dim->size()); + } + + Val* size; + if (alloc_dims.size() == 0) { + size = new Int(1); + } else { + size = alloc_dims[0]; + for (decltype(alloc_dims.size()) i{1}; i < alloc_dims.size(); i++) { + size = mul(size, alloc_dims[i]); + } + } + return new Allocate(tv, size); +} + +// Clear out the last recorded computeAtView +void LoopNestGenerator::clearActiveView() { + active_view_axis = 0; + active_view = nullptr; +} + +// Set active views from computeAtView +void LoopNestGenerator::setActiveView(const TensorView* const tv) { + active_view_axis = tv->getComputeAtAxis(); + active_view = tv->getComputeAtView(); +} + +void LoopNestGenerator::openFor(IterDomain* id) { + Expr* new_scope = scope_utils::openFor(active_scope, id); + if (active_scope == nullptr) { + pushBack(new_scope); + } + active_scope = new_scope; +} + +void LoopNestGenerator::pushBack(Expr* expr) { + if (active_scope == nullptr) + lowered_exprs.push_back(expr); + else + scope_utils::pushBack(active_scope, expr); +} + +/* + * This is one of the most complex parts of the code lowering logic. what we + * need to do is: 1) Reduce loop structure + * - Reset all loops if active_view == nullptr (I'm not the last in a series + * of computeAts) + * - Else reduce to active_view_axis if loop_depth > active_view_axis + * 2) Set active_view(_axis) + * - If there is a computeAt set for this TV + * 3) Open to compute At + * - If there is a computeAt set for this TV + * 4) Allocate the output. + * 5) If this is a reduction, initialize the output (open for loops to inner + * most, predicate, initialize, close predicate, close to computeAt) 6) Open to + * inner most loop 7) Open predicate 8) Run operation 9) Close predicate + */ + +// Update fors based on tv. 
+void LoopNestGenerator::updateLoopNest(TensorView* tv) { + // 1) Reduce loop structure + if (active_view != nullptr) { + // - Else reduce to active_view_axis if loop_depth > active_view_axis + auto depth = scope_utils::computeForDepth(active_scope); + for (auto i = depth; i > active_view_axis; i--) { + active_scope = scope_utils::closeScope(active_scope); + } + } + + if (tv->hasComputeAt()) { + // 2) Set active_view(_axis) + // - If there is a computeAt set for this TV + setActiveView(tv); + + // 3) Open to compute At + // - If there is a computeAt set for this TV + auto depth = scope_utils::computeForDepth(active_scope); + + for (auto i = depth; i < tv->getComputeAtAxis(); i++) + openFor(tv->getComputeAtAxis(i)); + } else { + if (active_view != nullptr) + // If we're the last computeAt of a block, active view should match this + // tv + TORCH_INTERNAL_ASSERT( + tv->sameAs(active_view), + "Error detected in code lowering. Expected ", + active_view, + " but received ", + tv); + + clearActiveView(); + } + // 4) Allocate the output. + if (!FusionGuard::getCurFusion()->hasInput(tv) && + !FusionGuard::getCurFusion()->hasOutput(tv)) { + pushBack(getAlloc(tv)); + } + // TODO: + // 5) If this is a reduction, initialize the output (open for loops to inner + // most, predicate, initialize, close predicate, close to computeAt) + + // 6) Open to inner most loop + for (decltype(tv->nDims()) i = scope_utils::computeForDepth(active_scope); + i < tv->nDims(); + i++) + openFor(tv->getComputeAtAxis(i)); +} + +// Custom dispatch for Expr, want to find out if it's a TV op +void LoopNestGenerator::handle(Expr* expr) { + if (!isTVOp(expr)) + return; + + TensorView* out = static_cast(expr->output(0)); + updateLoopNest(out); + + pushBack(expr); +} + +// Generate the loop nest structure and place it in lowered_exprs +void LoopNestGenerator::generate() { + FusionGuard fg(fusion_); + + // Likely we lowered this fusion, we can simply return the lowered expressions + // Not the safest approach but good enough for now. + if (fusion_->lowered && lowered_exprs.size() != 0) + return; + + TORCH_CHECK( + !fusion_->lowered, + "Fusions can only be lowered once as of now. 
You could reuse the lowering using", + " std::vector GPULower::getLoweredExprs() the result can be printed as", + " a kernel with IRPrinter irp(os); irp.printKernel(lowered_exprs, kernel_name);"); + + // Initialize members of the class + lowered_exprs = std::vector(); + active_view = nullptr; + active_view_axis = 0; + + std::vector exprs = fusion_->exprs(true); + for (auto* expr : exprs) + handle(expr); +} + +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h new file mode 100644 index 000000000000..7bfbc8dc40db --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -0,0 +1,60 @@ +#pragma once +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { + +struct TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { + private: + std::vector lowered_exprs; + Fusion* fusion_; + + // Track the last computeAt TensorView and axis + const TensorView* active_view; + unsigned int active_view_axis; + + // Active IfThenElse or ForLoop + Expr* active_scope = nullptr; + + // Get Register allocation statement for tensorview + Allocate* getAlloc(TensorView*); + + // Clear out the last recorded computeAtView + void clearActiveView(); + // Set active views from computeAtView + void setActiveView(const TensorView* const); + + // Open a new inner most for loop + void openFor(IterDomain*); + + // Wrap pushBack in lower_utils if active_scope is null we want it to go + // straight to lower_exprs + void pushBack(Expr*); + + // Update for loop structure based on this TensorView + void updateLoopNest(TensorView*); + + // Check if a TV op, generate for loop nest around it + void handle(Expr*) final; + + // Generate the loop nest structure and place it in lowered_exprs + void generate(); + + LoopNestGenerator(Fusion* _fusion) : fusion_(_fusion) {} + + public: + static std::vector getLoopNest(Fusion* fusion) { + FusionGuard fg(fusion); + LoopNestGenerator lng(fusion); + lng.generate(); + return lng.lowered_exprs; + } +}; + +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp new file mode 100644 index 000000000000..79cfd74a0b05 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -0,0 +1,249 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace fuser { + +namespace scope_utils { + +namespace { + +struct forLoopIndices : private OptInDispatch { + private: + std::vector inds_; + void handle(ForLoop* fl) final { + inds_.insert(inds_.begin(), fl->index()); + } + + void handle(IfThenElse* ite) final {} + + void handle(Expr* expr) final { + OptInDispatch::handle(expr); + } + + public: + static std::vector get(Expr* scope) { + forLoopIndices fli; + Expr* it = scope; + while (it != nullptr) { + fli.handle(it); + it = getParent(it); + } + return fli.inds_; + } +}; + +struct parentScope : private OptInDispatch { + private: + Expr* parent_ = nullptr; + + void handle(ForLoop* fl) final { + parent_ = fl->parentScope(); + } + + void handle(IfThenElse* ite) final { + parent_ = ite->parentScope(); + } + + void handle(Expr* expr) final { + OptInDispatch::handle(expr); + } + + public: + static Expr* get(Expr* scope) { + parentScope sp; + sp.handle(scope); + return sp.parent_; + } +}; + +struct forLoopCount : private OptInDispatch { + private: + unsigned int count_ = 0; + + void 
handle(ForLoop* fl) final { + count_++; + } + + void handle(IfThenElse* ite) final {} + + void handle(Expr* expr) final { + OptInDispatch::handle(expr); + } + + public: + static unsigned int get(Expr* scope) { + forLoopCount flc; + Expr* it = scope; + while (it != nullptr) { + flc.handle(it); + it = getParent(it); + } + return flc.count_; + } +}; + +struct forLoopIDs : private OptInDispatch { + private: + std::vector IDs_; + void handle(ForLoop* fl) final { + IDs_.insert(IDs_.begin(), fl->range()); + } + + void handle(IfThenElse* ite) final {} + + void handle(Expr* expr) final { + OptInDispatch::handle(expr); + } + + public: + static std::vector get(Expr* scope) { + forLoopIDs fli; + Expr* it = scope; + while (it != nullptr) { + fli.handle(it); + it = getParent(it); + } + return fli.IDs_; + } +}; + +struct scopePushBack : private OptInDispatch { + private: + Expr* _expr = nullptr; + void handle(ForLoop* fl) final { + fl->body().push_back(_expr); + } + + void handle(IfThenElse* ite) final { + ite->body().push_back(_expr); + } + + void handle(Expr* expr) final { + OptInDispatch::handle(expr); + } + + public: + static void push(Expr* scope, Expr* expr) { + scopePushBack pb; + TORCH_INTERNAL_ASSERT( + expr != nullptr && scope != nullptr, + "Cannot push back, scope or expr is a nullptr."); + pb._expr = expr; + pb.handle(scope); + } +}; + +struct scopeClearExprs : private OptInDispatch { + private: + Expr* _expr = nullptr; + void handle(ForLoop* fl) final { + fl->body().clear(); + } + + void handle(IfThenElse* ite) final { + ite->body().clear(); + } + + void handle(Expr* expr) final { + OptInDispatch::handle(expr); + } + + public: + static void clear(Expr* scope) { + scopeClearExprs sce; + TORCH_INTERNAL_ASSERT( + scope != nullptr, "Cannot clear scope, scope is a nullptr."); + sce.handle(scope); + } +}; + +void assertScope(Expr* expr) { + TORCH_INTERNAL_ASSERT( + expr->getExprType() == ExprType::ForLoop || + expr->getExprType() == ExprType::IfThenElse, + "Assert Scope failed when calling a scope_util function."); +} + +} // namespace + +// Grab the index variables of the active loop nest +std::vector getLoopIndices(Expr* scope) { + if (scope == nullptr) + return std::vector(); + assertScope(scope); + return forLoopIndices::get(scope); +} + +// Grab the iterDomains of the active loops +std::vector getLoopIterDomains(Expr* scope) { + if (scope == nullptr) + return std::vector(); + assertScope(scope); + return forLoopIDs::get(scope); +} + +// Track how far our for loop scope is +unsigned int computeForDepth(Expr* scope) { + if (scope == nullptr) + return 0; + assertScope(scope); + return forLoopCount::get(scope); +} + +// Push back an expr to scope +void pushBack(Expr* scope, Expr* expr) { + TORCH_INTERNAL_ASSERT( + scope != nullptr, "Scope is a nullptr, cannot push an expr to it."); + assertScope(scope); + scopePushBack::push(scope, expr); +} + +// Return the parent of the active scope +Expr* getParent(Expr* scope) { + TORCH_INTERNAL_ASSERT( + scope != nullptr, + "Tried to close the active scope, but there isn't one set."); + assertScope(scope); + return parentScope::get(scope); +} + +// Open a new inner most for loop +Expr* openFor(Expr* scope, IterDomain* id) { + ForLoop* new_scope = nullptr; + if (id->isThread()) { + new_scope = new ForLoop( + new NamedScalar(stringify(id->parallel_method()), DataType::Int), + id, + {}, + scope); + } else { + new_scope = new ForLoop(new Int(), id, {}, scope); + } + if (scope != nullptr) + pushBack(scope, new_scope); + return new_scope; +} + +// Close the inner 
most for loop +Expr* closeScope(Expr* scope) { + TORCH_INTERNAL_ASSERT( + scope != nullptr, "Tried to close a scope but got a nullptr."); + return getParent(scope); +} + +// Clear all expressions from the scope +Expr* clearScope(Expr* scope) { + TORCH_INTERNAL_ASSERT( + scope != nullptr, "Tried to clear a scope but got a nullptr."); + assertScope(scope); + scopeClearExprs::clear(scope); + return scope; +} + +} // namespace scope_utils +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h new file mode 100644 index 000000000000..f61826c0dd94 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -0,0 +1,45 @@ +#pragma once + +#include + +#include + +// Provides utilities for dealing with nested ForLoop and IfThenElse scopes + +namespace torch { +namespace jit { +namespace fuser { + +namespace scope_utils { + +// Grab the index variables of the active loop nest +std::vector getLoopIndices(Expr* scope); + +// Grab the iterDomains of the active loops +std::vector getLoopIterDomains(Expr* scope); + +// Track how far our for loop scope is +unsigned int computeForDepth(Expr* scope); + +// Push back an expr to scope +void pushBack(Expr* scope, Expr* expr); + +// Return the parent of the active scope +Expr* getParent(Expr* scope); + +// Open a new inner most for loop +Expr* openFor(Expr* scope, IterDomain*); + +// Close the inner most for loop +Expr* closeScope(Expr* scope); + +// Clear all expressions from the scope +Expr* clearScope(Expr* scope); + +// Track how far our for loop scope is +unsigned int computeForDepth(Expr* scope); + +} // namespace scope_utils +} // namespace fuser +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 928eb4158c2c..1e396fee5df8 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -9,8 +9,12 @@ namespace fuser { // Return highest on list (smallest enum val) DataType promote_type(const DataType& t1, const DataType& t2) { - TORCH_CHECK(DataType::Null != t1 && DataType::Null != t2, - "Expected promotable DataTypes but got: ", t1, " and ", t2); + TORCH_CHECK( + DataType::Null != t1 && DataType::Null != t2, + "Expected promotable DataTypes but got: ", + t1, + " and ", + t2); return t1 < t2 ? t1 : t2; } @@ -18,7 +22,10 @@ DataType promote_type(const DataType& t1, const DataType& t2) { ValType promote_type(const ValType& t1, const ValType& t2) { TORCH_CHECK( t1 >= ValType::TensorView && t2 >= ValType::TensorView, - "Expected promotable ValTypes but got: ", t1, " and ", t2); + "Expected promotable ValTypes but got: ", + t1, + " and ", + t2); // Check that it's a promotable type (with dtype) // static_assert?? return t1 < t2 ? 
t1 : t2; From 1e2540a10c26fb3aa9789fbd6976bde2ccffa3f8 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 5 Apr 2020 18:22:06 -0400 Subject: [PATCH 2/9] ForLoop::range renamed to ForLoop::iter_domain --- torch/csrc/jit/codegen/cuda/ir_internal_nodes.h | 12 ++++++------ torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 4 ++-- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 8 ++++---- torch/csrc/jit/codegen/cuda/lower2device.cpp | 2 +- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 2 +- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 4352d0258117..16ec7e105bc6 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -99,7 +99,7 @@ struct TORCH_CUDA_API BinaryOp : public Expr { }; /* - * Simply a representation of an iterable from 0 to size. TensorDomains which + * Simply a representation of an iterable from start to extent. TensorDomains which * represent how to iterate over a tensor is made up of IterDomains. We directly * set parallization strategies on IterDomains. */ @@ -317,7 +317,7 @@ struct TORCH_CUDA_API Reorder : public Expr { }; /* - * ForLoop provides scoping around an int iterator from 0 to range. Exprs placed + * ForLoop provides scoping around an index through an IterDomain. Exprs placed * in its body are considered inside the scope of the for loop. In the future * the implementation should look quite different so that we can do proper * dependency annalysis like in Fusion. @@ -329,7 +329,7 @@ struct TORCH_API ForLoop : public Expr { ~ForLoop() = default; ForLoop( Val* _index, - IterDomain* _range, + IterDomain* _iter_domain, const std::vector& _body = {}, Expr* parent_scope = nullptr); @@ -343,8 +343,8 @@ struct TORCH_API ForLoop : public Expr { return index_; } - IterDomain* range() const noexcept { - return range_; + IterDomain* iter_domain() const noexcept { + return iter_domain_; } Scope& body() noexcept { @@ -365,7 +365,7 @@ struct TORCH_API ForLoop : public Expr { private: Val* const index_; - IterDomain* const range_; + IterDomain* const iter_domain_; Scope body_; Expr* parent_scope_; }; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 821c739d0518..3f1dcf2a948a 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -256,7 +256,7 @@ void IRPrinter::handle(const BinaryOp* const bop) { } void IRPrinter::handle(const ForLoop* const fl) { - if (fl->range()->isThread()) { + if (fl->iter_domain()->isThread()) { for (auto& expr : fl->constBody().exprs()) handle(expr); return; @@ -268,7 +268,7 @@ void IRPrinter::handle(const ForLoop* const fl) { os << "{0}; "; handle(fl->index()); os << " < "; - print_inline(fl->range()->size()); + print_inline(fl->iter_domain()->size()); os << "; ++"; handle(fl->index()); os << " ) {\n"; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index daa79e7de29c..97e576f5d0ff 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -174,25 +174,25 @@ bool Reorder::sameAs(const Reorder* const other) const { ForLoop::ForLoop( Val* _index, - IterDomain* _range, + IterDomain* _iter_domain, const std::vector& _body, Expr* _parent_scope) : Expr(ExprType::ForLoop), index_{_index}, - range_{_range}, + iter_domain_{_iter_domain}, parent_scope_{_parent_scope} { 
TORCH_INTERNAL_ASSERT( _index->isAnInt(), "Cannot create a for loop with an index that is not an int."); addInput(_index); - addInput(_range); + addInput(_iter_domain); this->name_ = FusionGuard::getCurFusion()->registerExpr(this); for (Expr* expr : _body) body().push_back(expr); } bool ForLoop::sameAs(const ForLoop* other) const { - if (this->range() != other->range()) + if (this->iter_domain() != other->iter_domain()) return false; if (!(constBody().sameAs(other->constBody()))) return false; diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 99e741367416..1fe6b727bd1b 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -275,7 +275,7 @@ Statement* GPULower::mutate(ForLoop* fl) { if (is_mutated) return new ForLoop( - fl->index(), fl->range(), mutated_exprs, fl->parentScope()); + fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); return fl; } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 79cfd74a0b05..d74b4fe8fb10 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -89,7 +89,7 @@ struct forLoopIDs : private OptInDispatch { private: std::vector IDs_; void handle(ForLoop* fl) final { - IDs_.insert(IDs_.begin(), fl->range()); + IDs_.insert(IDs_.begin(), fl->iter_domain()); } void handle(IfThenElse* ite) final {} From dc5d14c977dd1401afc2160a8c8e03ff037aac01 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 5 Apr 2020 18:29:30 -0400 Subject: [PATCH 3/9] Rename IterDomain::size -> IterDomain::extent. --- test/cpp/jit/test_gpu.cpp | 16 +++++------ torch/csrc/jit/codegen/cuda/index_compute.cpp | 2 +- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 6 ++-- torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 4 +-- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 28 +++++++++---------- torch/csrc/jit/codegen/cuda/lower2device.cpp | 8 +++--- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 2 +- torch/csrc/jit/codegen/cuda/mutator.cpp | 4 +-- .../jit/codegen/cuda/predicate_compute.cpp | 2 +- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 6 ++-- 10 files changed, 39 insertions(+), 39 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 1adced3aa5bc..064d22e57b84 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -385,22 +385,22 @@ void testGPU_FusionTVSplit() { tv = tv->split(2, 2); TORCH_CHECK(tv->nDims() == 4); - Expr* outer = tv->axis(2)->size()->getOrigin(); + Expr* outer = tv->axis(2)->extent()->getOrigin(); TORCH_CHECK( outer->getExprType().value() == ExprType::BinaryOp && static_cast(outer)->getBinaryOpType() == BinaryOpType::CeilDiv && static_cast(outer)->lhs()->sameAs( - tv->getRootDomain()->axis(2)->size()) && + tv->getRootDomain()->axis(2)->extent()) && static_cast(static_cast(outer)->rhs()) ->sameAs(new Int(2))); IterDomain* inner = static_cast(tv->axis(3)); TORCH_CHECK( - inner->size()->isScalar() && - static_cast(inner->size())->isConst() && - static_cast(inner->size())->value().value() == 2); + inner->extent()->isScalar() && + static_cast(inner->extent())->isConst() && + static_cast(inner->extent())->value().value() == 2); } void testGPU_FusionTVMerge() { @@ -410,15 +410,15 @@ void testGPU_FusionTVMerge() { TensorView* tv = makeDummyTensor(3); tv = tv->merge(1); - Expr* axisOp = tv->axis(1)->size()->getOrigin(); + Expr* axisOp = tv->axis(1)->extent()->getOrigin(); TORCH_CHECK( 
tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp && static_cast(axisOp)->getBinaryOpType() == BinaryOpType::Mul && static_cast(axisOp)->lhs() == - tv->getRootDomain()->axis(1)->size() && + tv->getRootDomain()->axis(1)->extent() && static_cast(axisOp)->rhs() == - tv->getRootDomain()->axis(2)->size()); + tv->getRootDomain()->axis(2)->extent()); } void testGPU_FusionTVReorder() { diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index ddc3e2ee1c62..deef3aa6f89f 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -21,7 +21,7 @@ void IndexCompute::replayBackward(Merge* expr) { ax >= 0 && ax < indices.size(), "Hit an invalid MERGE transformation during IndexCompute, axis is not within bounds."); - Val* I = expr->in()->axis(ax + 1)->size(); + Val* I = expr->in()->axis(ax + 1)->extent(); Val* ind = indices[ax]; indices[ax] = div(ind, I); indices.insert(indices.begin() + ax + 1, mod(ind, I)); diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 16ec7e105bc6..a7338a87a13e 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -109,7 +109,7 @@ struct TORCH_CUDA_API IterDomain : public Val { IterDomain() = delete; IterDomain( - Val* int_size, + Val* _extent, ParallelType _parallel_method = ParallelType::Serial, bool _reduction_domain = false); @@ -165,7 +165,7 @@ struct TORCH_CUDA_API IterDomain : public Val { return parallel_method_; } - Val* size() const; + Val* extent() const; IterDomain(const IterDomain& other) = delete; IterDomain& operator=(const IterDomain& other) = delete; @@ -174,7 +174,7 @@ struct TORCH_CUDA_API IterDomain : public Val { IterDomain& operator=(IterDomain&& other) = delete; private: - Val* const size_; + Val* const extent_; ParallelType parallel_method_ = ParallelType::Serial; bool is_reduction_domain_; }; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 3f1dcf2a948a..2b6fc648734d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -108,7 +108,7 @@ void IRPrinter::handle(const IterDomain* const id) { os << id->parallel_method(); } os << "{"; - print_inline(id->size()); + print_inline(id->extent()); os << "}"; } @@ -268,7 +268,7 @@ void IRPrinter::handle(const ForLoop* const fl) { os << "{0}; "; handle(fl->index()); os << " < "; - print_inline(fl->iter_domain()->size()); + print_inline(fl->iter_domain()->extent()); os << "; ++"; handle(fl->index()); os << " ) {\n"; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 97e576f5d0ff..b6b363837149 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -56,43 +56,43 @@ bool BinaryOp::sameAs(const BinaryOp* other) const { } IterDomain::IterDomain( - Val* _size, + Val* _extent, ParallelType _parallel_method, bool _reduction_domain) : Val(ValType::IterDomain, DataType::Int), - size_(_size), + extent_(_extent), parallel_method_(_parallel_method), is_reduction_domain_(_reduction_domain) { TORCH_INTERNAL_ASSERT( - _size->isAnInt(), - "Cannot create an iter domain over a size that is not an int."); + _extent->isAnInt(), + "Cannot create an iter domain over an extent that is not an int."); } bool IterDomain::sameAs(const IterDomain* const other) const { bool is_same = 
isReduction() == other->isReduction() && parallel_method() == other->parallel_method(); - if (size()->getValType() == ValType::NamedScalar && - other->size()->getValType() == ValType::NamedScalar) { + if (extent()->getValType() == ValType::NamedScalar && + other->extent()->getValType() == ValType::NamedScalar) { is_same = is_same && - (static_cast(size())->name().compare( - static_cast(other->size())->name()) == 0); + (static_cast(extent())->name().compare( + static_cast(other->extent())->name()) == 0); } else { - is_same = is_same && size()->sameAs(other->size()); + is_same = is_same && extent()->sameAs(other->extent()); } return is_same; } -Val* IterDomain::size() const { +Val* IterDomain::extent() const { if (isThread()) { - if (size_->getValType() == ValType::Scalar) - if (static_cast(size_)->isConst()) - return size_; + if (extent_->getValType() == ValType::Scalar) + if (static_cast(extent_)->isConst()) + return extent_; std::string parallel_dim = stringifyThreadSize(parallel_method_); return new NamedScalar(parallel_dim, DataType::Int); } - return size_; + return extent_; } bool TensorDomain::sameAs(const TensorDomain* const other) const { diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 1fe6b727bd1b..a458a067decb 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -113,7 +113,7 @@ TensorIndex* GPULower::getLocalProducerIndex( for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { Val* ind = computed_inds[i]; for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) - ind = mul(ind, used_ranges[i]->size()); + ind = mul(ind, used_ranges[i]->extent()); computed_inds[i] = ind; } if (computed_inds.size() == 0) @@ -184,7 +184,7 @@ TensorIndex* GPULower::getLocalConsumerIndex(TensorView* consumer) { for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { Val* ind = computed_inds[i]; for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) - ind = mul(ind, used_ranges[i]->size()); + ind = mul(ind, used_ranges[i]->extent()); computed_inds[i] = ind; } @@ -388,7 +388,7 @@ void GPULower::replaceSizes() { std::vector new_domain; TensorDomain* root_td = tv->getRootDomain(); for (decltype(root_td->size()) i{0}; i < root_td->size(); i++) { - Val* orig_size = root_td->axis(i)->size(); + Val* orig_size = root_td->axis(i)->extent(); std::stringstream ss; ss << "T" << new_tv->name() << ".size[" << i << "]"; Val* new_size = @@ -412,7 +412,7 @@ void GPULower::replaceSizes() { TensorDomain* root_td = tv->getRootDomain(); for (decltype(root_td->size()) i{0}; i < root_td->size(); i++) { - Val* new_size = root_td->axis(i)->size(); + Val* new_size = root_td->axis(i)->extent(); if (size_map.find(new_size) != size_map.end()) new_size = size_map[new_size]; new_domain.push_back(new IterDomain( diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index 1cc67f5b1991..a6f97de6a449 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -46,7 +46,7 @@ Allocate* LoopNestGenerator::getAlloc(TensorView* tv) { IterDomain* dim = tv->getComputeAtAxis(i); if (dim->isThreadDim()) continue; - alloc_dims.push_back(dim->size()); + alloc_dims.push_back(dim->extent()); } Val* size; diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 5086e32c895c..00ba69a1a2bd 100644 --- 
a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -29,8 +29,8 @@ void OptOutMutator::mutate(Fusion* fusion) { // MUTATE FUNCTIONS FOR VALS Statement* OptOutMutator::mutate(IterDomain* id) { - Val* s = mutateAsVal(id->size())->asVal(); - if (!s->sameAs(id->size())) { + Val* s = mutateAsVal(id->extent())->asVal(); + if (!s->sameAs(id->extent())) { Val* mutated_val = new IterDomain(s, id->parallel_method(), id->isReduction()); registerMutation(id, mutated_val); diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp index dd55f876bc75..2068a398ff6f 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp @@ -27,7 +27,7 @@ std::vector PredicateCompute::computePredicates(const TensorIndex* ti) { for (decltype(ti->size()) i{0}; i < ti->size(); i++) if (FusionGuard::getCurFusion()->origin(ti->index(i)) != nullptr) { - Val* pred = lt(ti->index(i), root->axis(i)->size()); + Val* pred = lt(ti->index(i), root->axis(i)->extent()); TORCH_CHECK( pred->getValType().value() == ValType::Scalar && pred->getDataType().value() == DataType::Int); diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 680643a7ca26..440cd2b415b5 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -58,7 +58,7 @@ TensorView* split_(TensorView* tv, int axis, int factor) { new_domain.push_back(td->axis(i)); else { // outer loop size - Val* vo = ceilDiv(id->size(), fact); + Val* vo = ceilDiv(id->extent(), fact); Int* so = static_cast(vo); // outer loop IterDomain @@ -96,7 +96,7 @@ TensorView* merge_(TensorView* tv, int axis) { assert(first->isReduction() == second->isReduction()); assert(first->parallel_method() == second->parallel_method()); - Val* merged_id_size = mul(first->size(), second->size()); + Val* merged_id_size = mul(first->extent(), second->extent()); IterDomain* merged_id = new IterDomain( static_cast(merged_id_size), first->parallel_method(), @@ -221,7 +221,7 @@ TensorView* TensorView::newForOutput(DataType dtype) const { // consumers and we're copying over a producer. if (this->axis(i)->isReduction()) continue; - domain_copy.push_back(new IterDomain(this->axis(i)->size())); + domain_copy.push_back(new IterDomain(this->axis(i)->extent())); } TensorDomain* td = new TensorDomain(domain_copy); return new TensorView(td, dtype); From c512f7ad4538144fe07d723f79fa8260a1a58d3a Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 6 Apr 2020 10:25:58 -0400 Subject: [PATCH 4/9] Last working test before unrolling. Add incrementally better scalar checking (still not recursive). Add start index to For Loops. 
--- test/cpp/jit/test_gpu.cpp | 17 +++--- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 11 ++++ torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 2 + .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 20 ++++-- torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 9 ++- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 61 ++++++++++++++++--- torch/csrc/jit/codegen/cuda/lower2device.cpp | 3 +- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 1 + torch/csrc/jit/codegen/cuda/mutator.cpp | 17 +++--- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 22 +++++-- .../jit/codegen/cuda/transform_replay.cpp | 9 +++ 11 files changed, 136 insertions(+), 36 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 064d22e57b84..9b3ab22de536 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -1,4 +1,4 @@ -//#if defined(USE_CUDA) +#if defined(USE_CUDA) #include #include @@ -28,7 +28,7 @@ using namespace torch::jit::fuser; TensorView* makeDummyTensor(int nDims) { std::vector dom; for (int i = 0; i < nDims; i++) - dom.push_back(new IterDomain(new Int())); + dom.push_back(new IterDomain(new Int(0), new Int())); return new TensorView(new TensorDomain(dom), DataType::Float); } @@ -857,7 +857,7 @@ void testGPU_FusionSimplePWise() { // Set up symbolic sizes for the axes should be dimensionality of the problem std::vector dom; for (int i = 0; i < nDims; i++) - dom.push_back(new IterDomain(new Int())); + dom.push_back(new IterDomain(new Int(0), new Int())); // Set up your input tensor views TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float); @@ -973,14 +973,16 @@ void testGPU_FusionForLoop() { FusionGuard fg(&fusion); const auto TV0 = new TensorView( - new TensorDomain({new IterDomain(new Int(16))}), DataType::Float); + new TensorDomain({new IterDomain(new Int(0), new Int(16))}), + DataType::Float); const auto TV1 = new TensorView( - new TensorDomain({new IterDomain(new Int(16))}), DataType::Float); + new TensorDomain({new IterDomain(new Int(0), new Int(16))}), + DataType::Float); fusion.addInput(TV0); fusion.addInput(TV1); - auto ID0 = new IterDomain(new Int(8)); + auto ID0 = new IterDomain(new Int(0), new Int(8)); TensorView* TV2 = static_cast(add(TV0, TV1)); BinaryOp* op = static_cast(TV2->getOrigin()); @@ -1001,6 +1003,7 @@ void testGPU_FusionForLoop() { } } + } // namespace jit } // namespace torch -//#endif // #if defined(USE_CUDA) +#endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index c5a1cdb275df..6a93e7d70f52 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -41,6 +41,9 @@ Val::Val(ValType _vtype, DataType _dtype) : vtype_{_vtype}, dtype_{_dtype} { } } +// Traverse origin of all values involved in constructing the provided val. +// Check if all values involved are constant values, meaning the provided +// val is also a constant value. 
namespace { struct ConstCheck : OptOutConstDispatch { @@ -88,6 +91,14 @@ bool Val::isConstScalar() const { return ConstCheck::isConst(this); } +bool Val::isZeroInt() const { + if (isConstScalar() && getValType().value() == ValType::Scalar && + getDataType().value() == DataType::Int && + static_cast(this)->value().value() == 0) + return true; + return false; +} + c10::optional Val::getDataType() const { TORCH_INTERNAL_ASSERT( dtype_ != DataType::Null, "Value does not have a data type."); diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 1a4d905cc240..5a7d2b5fa9a1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -178,6 +178,8 @@ struct TORCH_CUDA_API Val : public Statement { return isScalar() && dtype_ == DataType::Int; } + bool isZeroInt() const; + // Returns the Expr that this value is an output of, returns nullptr if none // was found Expr* getOrigin(); diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index a7338a87a13e..9040e6a58d2f 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -99,9 +99,9 @@ struct TORCH_CUDA_API BinaryOp : public Expr { }; /* - * Simply a representation of an iterable from start to extent. TensorDomains which - * represent how to iterate over a tensor is made up of IterDomains. We directly - * set parallization strategies on IterDomains. + * Simply a representation of an iterable from start to extent. TensorDomains + * which represent how to iterate over a tensor is made up of IterDomains. We + * directly set parallization strategies on IterDomains. */ struct TORCH_CUDA_API IterDomain : public Val { ~IterDomain() = default; @@ -109,6 +109,7 @@ struct TORCH_CUDA_API IterDomain : public Val { IterDomain() = delete; IterDomain( + Val* _start, Val* _extent, ParallelType _parallel_method = ParallelType::Serial, bool _reduction_domain = false); @@ -157,7 +158,14 @@ struct TORCH_CUDA_API IterDomain : public Val { TORCH_CHECK( t != ParallelType::Vectorize, "Vectorization not yet supported."); if (t == ParallelType::Unroll) - TORCH_CHECK(false, "Unrolling not yet supported."); + TORCH_CHECK( + start()->isZeroInt() && extent()->isConstScalar(), + "Unrolling only supported with start = 0 and extent as a const int, but got ", + "a start of ", + start(), + " and extent ", + extent(), + " ."); } } @@ -165,6 +173,9 @@ struct TORCH_CUDA_API IterDomain : public Val { return parallel_method_; } + Val* start() const noexcept { + return start_; + } Val* extent() const; IterDomain(const IterDomain& other) = delete; @@ -174,6 +185,7 @@ struct TORCH_CUDA_API IterDomain : public Val { IterDomain& operator=(IterDomain&& other) = delete; private: + Val* const start_; Val* const extent_; ParallelType parallel_method_ = ParallelType::Serial; bool is_reduction_domain_; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 2b6fc648734d..c26c47d567b2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -107,7 +107,12 @@ void IRPrinter::handle(const IterDomain* const id) { default: os << id->parallel_method(); } + os << "{"; + if (!id->start()->isZeroInt()) { + print_inline(id->start()); + os << " : "; + } print_inline(id->extent()); os << "}"; } @@ -265,7 +270,9 @@ void IRPrinter::handle(const ForLoop* const fl) { indent(); os << 
"for(size_t "; handle(fl->index()); - os << "{0}; "; + os << " = "; + print_inline(fl->iter_domain()->start()); + os << "; "; handle(fl->index()); os << " < "; print_inline(fl->iter_domain()->extent()); diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index b6b363837149..331ffdc8c2ce 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -10,6 +10,46 @@ namespace torch { namespace jit { namespace fuser { +namespace { +struct ScalarCheck : OptInDispatch { + Val* v1_; + Val* v2_; + bool same = false; + + void handle(Float* f) { + same = static_cast(v1_)->sameAs(static_cast(v2_)); + } + + void handle(Int* i) { + same = static_cast(v1_)->sameAs(static_cast(v2_)); + } + + void handle(NamedScalar* ns) { + same = + static_cast(v1_)->sameAs(static_cast(v2_)); + } + + ScalarCheck(Val* _v1, Val* _v2) : v1_(_v1), v2_(_v2) { + OptInDispatch::handle(v1_); + } + + public: + static bool sameAs(Val* v1, Val* v2) { + if (v1 == v2) + return true; + + if (v1->getValType() != v2->getValType()) + return false; + + if (v1->getDataType() != v2->getDataType()) + return false; + + ScalarCheck sc(v1, v2); + return sc.same; + } +}; +} // namespace + bool Float::sameAs(const Float* const other) const { if (isConst() && other->isConst()) return *value() == *(other->value()); @@ -56,30 +96,33 @@ bool BinaryOp::sameAs(const BinaryOp* other) const { } IterDomain::IterDomain( + Val* _start, Val* _extent, ParallelType _parallel_method, bool _reduction_domain) : Val(ValType::IterDomain, DataType::Int), + start_(_start), extent_(_extent), parallel_method_(_parallel_method), is_reduction_domain_(_reduction_domain) { TORCH_INTERNAL_ASSERT( _extent->isAnInt(), - "Cannot create an iter domain over an extent that is not an int."); + "Cannot create an iter domain over an extent that is not an int but recieved ", + _extent, + " ."); + TORCH_INTERNAL_ASSERT( + _start->isAnInt(), + "Cannot create an iter domain with a start that is not an int but recieved ", + _extent, + " ."); } bool IterDomain::sameAs(const IterDomain* const other) const { bool is_same = isReduction() == other->isReduction() && parallel_method() == other->parallel_method(); + is_same = is_same && ScalarCheck::sameAs(extent(), other->extent()); + is_same = is_same && ScalarCheck::sameAs(start(), other->start()); - if (extent()->getValType() == ValType::NamedScalar && - other->extent()->getValType() == ValType::NamedScalar) { - is_same = is_same && - (static_cast(extent())->name().compare( - static_cast(other->extent())->name()) == 0); - } else { - is_same = is_same && extent()->sameAs(other->extent()); - } return is_same; } diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index a458a067decb..10f0712c1025 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -396,6 +396,7 @@ void GPULower::replaceSizes() { size_map[orig_size] = new_size; new_domain.push_back(new IterDomain( + root_td->axis(i)->start(), new_size, root_td->axis(i)->parallel_method(), root_td->axis(i)->isReduction())); @@ -416,6 +417,7 @@ void GPULower::replaceSizes() { if (size_map.find(new_size) != size_map.end()) new_size = size_map[new_size]; new_domain.push_back(new IterDomain( + root_td->axis(i)->start(), new_size, root_td->axis(i)->parallel_method(), root_td->axis(i)->isReduction())); @@ -531,7 +533,6 @@ std::ostream& GPULower::printKernel( std::ostream& os, const std::string& kernel_name) 
{ FusionGuard fg(fusion_); - getLoweredExprs(); IRPrinter irp(os); diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index a6f97de6a449..1b438e7580cb 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -46,6 +46,7 @@ Allocate* LoopNestGenerator::getAlloc(TensorView* tv) { IterDomain* dim = tv->getComputeAtAxis(i); if (dim->isThreadDim()) continue; + //TORCH_INTERNAL_ASSERT() alloc_dims.push_back(dim->extent()); } diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 00ba69a1a2bd..d91c57023a6e 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -29,14 +29,15 @@ void OptOutMutator::mutate(Fusion* fusion) { // MUTATE FUNCTIONS FOR VALS Statement* OptOutMutator::mutate(IterDomain* id) { - Val* s = mutateAsVal(id->extent())->asVal(); - if (!s->sameAs(id->extent())) { - Val* mutated_val = - new IterDomain(s, id->parallel_method(), id->isReduction()); - registerMutation(id, mutated_val); - return mutated_val; - } - return id; + Val* s = mutateAsVal(id->start())->asVal(); + Val* e = mutateAsVal(id->extent())->asVal(); + if (s->sameAs(id->start()) && e->sameAs(id->extent())) + return id; + + Val* mutated_val = + new IterDomain(s, e, id->parallel_method(), id->isReduction()); + registerMutation(id, mutated_val); + return mutated_val; } Statement* OptOutMutator::mutate(TensorDomain* td) { diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 440cd2b415b5..9545c93fea18 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -38,6 +38,10 @@ TensorView* split_(TensorView* tv, int axis, int factor) { IterDomain* id = td->axis(axis); + TORCH_CHECK( + id->start()->isZeroInt(), + "Splitting IterDomains with starting values that aren't 0, is not supported at this time."); + if (id->parallel_method() != ParallelType::Serial) TORCH_CHECK( false, @@ -62,13 +66,13 @@ TensorView* split_(TensorView* tv, int axis, int factor) { Int* so = static_cast(vo); // outer loop IterDomain - IterDomain* ido = - new IterDomain(so, id->parallel_method(), id->isReduction()); + IterDomain* ido = new IterDomain( + new Int(0), so, id->parallel_method(), id->isReduction()); new_domain.push_back(ido); // inner loop IterDomain - IterDomain* idi = - new IterDomain(fact, id->parallel_method(), id->isReduction()); + IterDomain* idi = new IterDomain( + new Int(0), fact, id->parallel_method(), id->isReduction()); new_domain.push_back(idi); } } @@ -93,11 +97,16 @@ TensorView* merge_(TensorView* tv, int axis) { IterDomain* first = td->axis(axis); IterDomain* second = td->axis(axis + 1); + TORCH_CHECK( + first->start()->isZeroInt() && second->start()->isZeroInt(), + "Merging IterDomains with starting values that aren't 0, is not supported at this time."); + assert(first->isReduction() == second->isReduction()); assert(first->parallel_method() == second->parallel_method()); Val* merged_id_size = mul(first->extent(), second->extent()); IterDomain* merged_id = new IterDomain( + new Int(0), static_cast(merged_id_size), first->parallel_method(), first->isReduction()); @@ -202,7 +211,7 @@ TensorView::TensorView(const std::shared_ptr& tensor_type) TORCH_CHECK( tensor_type->dim().has_value(), "Requires static rank for Tensor"); for (int i = 0; i < tensor_type->dim().value(); i++) { - sizes.push_back(new IterDomain(new Int())); + 
sizes.push_back(new IterDomain(new Int(0), new Int())); } domain_ = new TensorDomain(sizes); } @@ -221,7 +230,8 @@ TensorView* TensorView::newForOutput(DataType dtype) const { // consumers and we're copying over a producer. if (this->axis(i)->isReduction()) continue; - domain_copy.push_back(new IterDomain(this->axis(i)->extent())); + domain_copy.push_back( + new IterDomain(this->axis(i)->start(), this->axis(i)->extent())); } TensorDomain* td = new TensorDomain(domain_copy); return new TensorView(td, dtype); diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 80ff45ed8751..717fd7f7d558 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -67,6 +67,10 @@ TensorView* TransformReplay::replay(Split* expr, TensorView* tv) { TORCH_INTERNAL_ASSERT( real_axis != -1, "During transformation replay attempted to split an imaginary axis."); + TORCH_INTERNAL_ASSERT( + tv->axis(real_axis)->start()->isZeroInt(), + "Transform Replay tried to split an IterDomain with a start value that is not 0,", + " this is not currently supported."); // Replay split tv->split(real_axis, *(expr->factor()->value())); // Inserted a real axis, push everything in axis_map over to the right @@ -97,6 +101,11 @@ TensorView* TransformReplay::replay(Merge* expr, TensorView* tv) { axis_map[axis] != -1 && axis_map[axis + 1] != -1, "During transformation replay attempted to merge an imaginary axis."); // Replay merge + TORCH_INTERNAL_ASSERT( + tv->axis(axis)->start()->isZeroInt() && + tv->axis(axis + 1)->start()->isZeroInt(), + "Transform Replay tried to Merge IterDomains with a start value that is not 0,", + " this is not currently supported."); tv->merge(axis_map[axis]); } else { // If we aren't applying the merge, we won't change any following axis From 95755cab7d9dfd5625c2f89913e71e9aac1abe6c Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 6 Apr 2020 11:04:41 -0400 Subject: [PATCH 5/9] Add basic infrastructure for unrolling pass. 
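Patch 5 note: UnrollPass is inserted between loop-nest generation and the final lowering mutation. At this stage it is mostly scaffolding (its UnaryOp/BinaryOp mutators are pass-throughs); the pipeline it sets up inside GPULower::getLoweredExprs, per the hunk below, looks roughly like this sketch (template arguments restored, assertions omitted):

    // Inside GPULower::getLoweredExprs(), after replaceSizes():
    std::vector<Expr*> loop_nests = LoopNestGenerator::getLoopNest(fusion_);
    std::vector<Expr*> unrolled_loops = UnrollPass::runPass(fusion_, loop_nests);
    for (Expr* expr : unrolled_loops) {
      Statement* mutated_stmt = mutate(expr);  // further lowering of each nest
      lowered_exprs.push_back(static_cast<Expr*>(mutated_stmt));
    }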
--- torch/csrc/jit/codegen/cuda/lower2device.cpp | 3 +- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 93 +++++++++++++++++++- torch/csrc/jit/codegen/cuda/lower_loops.h | 39 ++++++++ 3 files changed, 132 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 10f0712c1025..55e6db31e227 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -514,9 +514,10 @@ std::vector GPULower::getLoweredExprs() { replaceSizes(); auto loop_nests = LoopNestGenerator::getLoopNest(fusion_); + auto unrolled_loops = UnrollPass::runPass(fusion_, loop_nests); // Run through loop nests and further lower the expressions - for (auto* expr : loop_nests) { + for (auto* expr : unrolled_loops) { Statement* mutated_stmt = mutate(expr); TORCH_INTERNAL_ASSERT( mutated_stmt->isExpr(), diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index 1b438e7580cb..f70b9df22a2b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include namespace torch { @@ -34,6 +34,95 @@ const TensorView* asConstTV(const Val* const val) { } // namespace +void UnrollPass::pushBack(Expr* expr) { + if (active_scope == nullptr) + lowered_exprs.push_back(expr); + else + scope_utils::pushBack(active_scope, expr); +} + +// Custom dispatch for Expr, want to find out of it's a TV op +Statement* UnrollPass::mutate(Expr* expr) { + Statement* mutated_stmt = OptOutMutator::mutate(expr); + TORCH_INTERNAL_ASSERT( + mutated_stmt->isExpr(), + "Tried to generate a kernel but hit a non expression during lowering: ", + mutated_stmt); + return mutated_stmt; +} + +// Open the for loop. +Statement* UnrollPass::mutate(ForLoop* fl) { + Expr* prev_scope = active_scope; + active_scope = fl; + std::vector mutated_exprs; + bool is_mutated = false; + for (auto expr : fl->body().exprs()) { + Statement* mutated_stmt = mutate(expr); + + TORCH_INTERNAL_ASSERT( + mutated_stmt->isExpr(), + "Tried to generate a kernel but hit a non expression during lowering: ", + mutated_stmt); + + mutated_exprs.push_back(static_cast(mutated_stmt)); + if (!(mutated_exprs.back()->sameAs(expr))) + is_mutated = true; + } + + if (is_mutated) { + scope_utils::clearScope(active_scope); + for (auto expr : mutated_exprs) + pushBack(expr); + } + + active_scope = prev_scope; + + if (is_mutated) + return new ForLoop( + fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); + + return fl; +} + +// Remake operations with TensorIndex +Statement* UnrollPass::mutate(UnaryOp* uop) { + return uop; +} +Statement* UnrollPass::mutate(BinaryOp* bop) { + return bop; +} + +// Generate the loop nest structure and place it in lowered_exprs +void UnrollPass::runPass() { + FusionGuard fg(fusion_); + + // Likely we lowered this fusion, we can simply return the lowered expressions + // Not the safest approach but good enough for now. + if (fusion_->lowered && incoming_exprs_.size() != 0) + return; + + TORCH_CHECK( + !fusion_->lowered, + "Fusions can only be lowered once as of now. 
You could reuse the lowering using", + " std::vector GPULower::getLoweredExprs() the result can be printed as", + " a kernel with IRPrinter irp(os); irp.printKernel(lowered_exprs, kernel_name);"); + + // Initialize members of the class + active_view = nullptr; + active_view_axis = 0; + + // Run through loop nests and further lower the expressions + for (auto* expr : incoming_exprs_) { + Statement* mutated_stmt = mutate(expr); + TORCH_INTERNAL_ASSERT( + mutated_stmt->isExpr(), + "Tried to generate a kernel but hit a non expression during lowering: ", + mutated_stmt); + lowered_exprs.push_back(static_cast(mutated_stmt)); + } +} + Allocate* LoopNestGenerator::getAlloc(TensorView* tv) { TORCH_INTERNAL_ASSERT( !(FusionGuard::getCurFusion()->hasInput(tv) || @@ -46,7 +135,7 @@ Allocate* LoopNestGenerator::getAlloc(TensorView* tv) { IterDomain* dim = tv->getComputeAtAxis(i); if (dim->isThreadDim()) continue; - //TORCH_INTERNAL_ASSERT() + // TORCH_INTERNAL_ASSERT() alloc_dims.push_back(dim->extent()); } diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h index 7bfbc8dc40db..f82ff9b74955 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.h +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -8,6 +8,45 @@ namespace torch { namespace jit { namespace fuser { +struct UnrollPass : public OptOutMutator { + private: + Fusion* fusion_; + std::vector lowered_exprs; + const std::vector& incoming_exprs_; + Expr* active_scope = nullptr; + + // Track the last computeAt TensorView and axis + const TensorView* active_view; + unsigned int active_view_axis; + + // Wrap pushBack in lower_utils if active_scope is null we want it to go + // straight to lower_exprs + void pushBack(Expr*); + + // Custom dispatch for Expr, want to find out of it's a TV op + Statement* mutate(Expr*) final; + + // Open the for loop. + Statement* mutate(ForLoop*) final; + + // Remake operations with TensorIndex + Statement* mutate(UnaryOp*) final; + Statement* mutate(BinaryOp*) final; + + UnrollPass(Fusion* _fusion, const std::vector& _incoming_exprs) + : fusion_(_fusion), incoming_exprs_(_incoming_exprs) {} + + void runPass(); + + public: + static std::vector runPass(Fusion* fusion, std::vector exprs) { + FusionGuard fg(fusion); + UnrollPass up(fusion, exprs); + up.runPass(); + return up.lowered_exprs; + } +}; + struct TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { private: std::vector lowered_exprs; From 6741cbe5d6af6cd6f548c5836c6af331c92e1420 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 6 Apr 2020 12:03:28 -0400 Subject: [PATCH 6/9] Factor out ir utilities that can be reused during lowering. --- torch/csrc/jit/codegen/cuda/lower2device.cpp | 72 ++++++-------------- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 57 +++------------- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 48 +++++++++++++ torch/csrc/jit/codegen/cuda/lower_utils.h | 18 +++++ 4 files changed, 97 insertions(+), 98 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 55e6db31e227..a11d37c90475 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -12,35 +12,6 @@ namespace torch { namespace jit { namespace fuser { -// START HELPER FUNCTIONS -namespace { - -bool isTV(const Val* const val) { - return val->getValType().value() == ValType::TensorView; -} - -// Check if we're a TensorView op that we can generate code for. 
-bool isTVOp(const Expr* expr) { - if (expr->nOutputs() == 1 && isTV(expr->output(0)) && - (expr->getExprType().value() == ExprType::BinaryOp || - expr->getExprType().value() == ExprType::UnaryOp)) - return true; - return false; -} - -TensorView* asTV(Val* val) { - TORCH_INTERNAL_ASSERT(isTV(val)); - return static_cast(val); -} - -const TensorView* asConstTV(const Val* const val) { - TORCH_INTERNAL_ASSERT(isTV(val)); - return static_cast(val); -} - -} // namespace -// END HELPER FUNCTIONS - // Clear out the last recorded computeAtView void GPULower::clearActiveView() { active_view_axis = 0; @@ -281,20 +252,20 @@ Statement* GPULower::mutate(ForLoop* fl) { } Statement* GPULower::mutate(UnaryOp* uop) { - if (!isTVOp(uop)) + if (!ir_utils::isTVOp(uop)) return OptOutMutator::mutate(uop); - IfThenElse* pred = getPredicate(asTV(uop->out())); + IfThenElse* pred = getPredicate(ir_utils::asTV(uop->out())); bool predicated = !pred->cond()->sameAs(new Int(1)); if (predicated) { pushBack(pred); active_scope = pred; } - TensorIndex* out = getConsumerIndex(asTV(uop->out())); + TensorIndex* out = getConsumerIndex(ir_utils::asTV(uop->out())); Val* in = uop->in(); - if (isTV(in)) - in = getProducerIndex(asTV(in), asTV(uop->out())); + if (ir_utils::isTV(in)) + in = getProducerIndex(ir_utils::asTV(in), ir_utils::asTV(uop->out())); Expr* new_op = new UnaryOp(uop->getUnaryOpType(), out, in); if (predicated) { @@ -307,25 +278,25 @@ Statement* GPULower::mutate(UnaryOp* uop) { } Statement* GPULower::mutate(BinaryOp* bop) { - if (!isTVOp(bop)) + if (!ir_utils::isTVOp(bop)) return OptOutMutator::mutate(bop); - IfThenElse* pred = getPredicate(asTV(bop->out())); + IfThenElse* pred = getPredicate(ir_utils::asTV(bop->out())); bool predicated = !pred->cond()->sameAs(new Int(1)); if (predicated) { pushBack(pred); active_scope = pred; } - TensorIndex* out = getConsumerIndex(asTV(bop->out())); + TensorIndex* out = getConsumerIndex(ir_utils::asTV(bop->out())); Val* lhs = bop->lhs(); Val* rhs = bop->rhs(); - if (isTV(lhs)) - lhs = getProducerIndex(asTV(lhs), asTV(bop->out())); + if (ir_utils::isTV(lhs)) + lhs = getProducerIndex(ir_utils::asTV(lhs), ir_utils::asTV(bop->out())); - if (isTV(rhs)) - rhs = getProducerIndex(asTV(rhs), asTV(bop->out())); + if (ir_utils::isTV(rhs)) + rhs = getProducerIndex(ir_utils::asTV(rhs), ir_utils::asTV(bop->out())); Expr* new_op = new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs); @@ -356,11 +327,11 @@ void GPULower::replaceSizes() { std::vector orig_intermediates; for (auto* val : fusion->deterministic_vals()) { - if (isTV(val)) { + if (ir_utils::isTV(val)) { if (fusion->hasInput(val) || fusion->hasOutput(val)) { - orig_inp_out.push_back(asTV(val)); + orig_inp_out.push_back(ir_utils::asTV(val)); } else { - orig_intermediates.push_back(asTV(val)); + orig_intermediates.push_back(ir_utils::asTV(val)); } } } @@ -427,8 +398,8 @@ void GPULower::replaceSizes() { // Now that we have the base tensor views. Lets fix its members. for (auto entry : tv_map) { - TensorView* orig_tv = asTV(entry.first); - TensorView* new_tv = asTV(entry.second); + TensorView* orig_tv = ir_utils::asTV(entry.first); + TensorView* new_tv = ir_utils::asTV(entry.second); // Domain in the new TV is the root domain, replay it like the original // domain. 
@@ -448,7 +419,8 @@ void GPULower::replaceSizes() { computeAtTV, " but one wasn't found."); new_tv->setComputeAt( - asTV(tv_map[computeAtTV]), (int)(orig_tv->getComputeAtAxis())); + ir_utils::asTV(tv_map[computeAtTV]), + (int)(orig_tv->getComputeAtAxis())); } } @@ -460,8 +432,8 @@ namespace { // Some pre-compilation checks void validate(Fusion* fusion) { for (Val* val : fusion->vals()) { - if (isTV(val)) { - TensorView* tv = asTV(val); + if (ir_utils::isTV(val)) { + TensorView* tv = ir_utils::asTV(val); for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) { IterDomain* id = tv->getComputeAtAxis(i); @@ -484,7 +456,7 @@ void validate(Fusion* fusion) { tv->getComputeAtAxis(), "."); } - } // if isTV + } // if ir_utils::isTV } // for(Val* val : fusion->vals()) } // validate diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index f70b9df22a2b..2b14cefd50d0 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -1,39 +1,11 @@ -#include #include +#include #include namespace torch { namespace jit { namespace fuser { -// HELPER NAMESPACE -namespace { - -bool isTV(const Val* const val) { - return val->getValType().value() == ValType::TensorView; -} - -// Check if we're a TensorView op that we can generate code for. -bool isTVOp(const Expr* expr) { - if (expr->nOutputs() == 1 && isTV(expr->output(0)) && - (expr->getExprType().value() == ExprType::BinaryOp || - expr->getExprType().value() == ExprType::UnaryOp)) - return true; - return false; -} - -TensorView* asTV(Val* val) { - TORCH_INTERNAL_ASSERT(isTV(val)); - return static_cast(val); -} - -const TensorView* asConstTV(const Val* const val) { - TORCH_INTERNAL_ASSERT(isTV(val)); - return static_cast(val); -} - -} // namespace - void UnrollPass::pushBack(Expr* expr) { if (active_scope == nullptr) lowered_exprs.push_back(expr); @@ -44,10 +16,7 @@ void UnrollPass::pushBack(Expr* expr) { // Custom dispatch for Expr, want to find out of it's a TV op Statement* UnrollPass::mutate(Expr* expr) { Statement* mutated_stmt = OptOutMutator::mutate(expr); - TORCH_INTERNAL_ASSERT( - mutated_stmt->isExpr(), - "Tried to generate a kernel but hit a non expression during lowering: ", - mutated_stmt); + ir_utils::ASSERT_EXPR(mutated_stmt); return mutated_stmt; } @@ -58,16 +27,12 @@ Statement* UnrollPass::mutate(ForLoop* fl) { std::vector mutated_exprs; bool is_mutated = false; for (auto expr : fl->body().exprs()) { - Statement* mutated_stmt = mutate(expr); - - TORCH_INTERNAL_ASSERT( - mutated_stmt->isExpr(), - "Tried to generate a kernel but hit a non expression during lowering: ", - mutated_stmt); - - mutated_exprs.push_back(static_cast(mutated_stmt)); - if (!(mutated_exprs.back()->sameAs(expr))) + if (ir_utils::isUnrolledFor(expr)) { is_mutated = true; + mutated_exprs.push_back(expr); + } else { + mutated_exprs.push_back(expr); + } } if (is_mutated) { @@ -115,11 +80,7 @@ void UnrollPass::runPass() { // Run through loop nests and further lower the expressions for (auto* expr : incoming_exprs_) { Statement* mutated_stmt = mutate(expr); - TORCH_INTERNAL_ASSERT( - mutated_stmt->isExpr(), - "Tried to generate a kernel but hit a non expression during lowering: ", - mutated_stmt); - lowered_exprs.push_back(static_cast(mutated_stmt)); + lowered_exprs.push_back(ir_utils::asExpr(mutated_stmt)); } } @@ -247,7 +208,7 @@ void LoopNestGenerator::updateLoopNest(TensorView* tv) { // Custom dispatch for Expr, want to find out of it's a TV op void LoopNestGenerator::handle(Expr* 
expr) { - if (!isTVOp(expr)) + if (!ir_utils::isTVOp(expr)) return; TensorView* out = static_cast(expr->output(0)); diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index d74b4fe8fb10..30abd998f176 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -244,6 +244,54 @@ Expr* clearScope(Expr* scope) { } } // namespace scope_utils + +namespace ir_utils { + +bool isTV(const Val* const val) { + return val->getValType().value() == ValType::TensorView; +} + +// Check if we're a TensorView op that we can generate code for. +bool isTVOp(const Expr* expr) { + if (expr->nOutputs() == 1 && isTV(expr->output(0)) && + (expr->getExprType().value() == ExprType::BinaryOp || + expr->getExprType().value() == ExprType::UnaryOp)) + return true; + return false; +} + +void ASSERT_EXPR(Statement* stmt) { + TORCH_INTERNAL_ASSERT( + stmt->isExpr(), + "Tried to generate a kernel but hit a non expression during lowering: ", + stmt); +} + +Expr* asExpr(Statement* stmt) { + ASSERT_EXPR(stmt); + return static_cast(stmt); +} + +TensorView* asTV(Val* val) { + TORCH_INTERNAL_ASSERT(isTV(val)); + return static_cast(val); +} + +const TensorView* asConstTV(const Val* const val) { + TORCH_INTERNAL_ASSERT(isTV(val)); + return static_cast(val); +} + +bool isUnrolledFor(const Expr* expr) { + if (expr->getExprType() != ExprType::ForLoop) { + return false; + } + return static_cast(expr)->iter_domain()->parallel_method() == + ParallelType::Unroll; +} + +} // namespace ir_utils + } // namespace fuser } // namespace jit } // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index f61826c0dd94..f632eeb02bba 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -40,6 +40,24 @@ Expr* clearScope(Expr* scope); unsigned int computeForDepth(Expr* scope); } // namespace scope_utils + +namespace ir_utils { + +bool isTV(const Val* const); + +bool isTVOp(const Expr*); + +void ASSERT_EXPR(Statement*); + +Expr* asExpr(Statement*); + +TensorView* asTV(Val*); + +const TensorView* asConstTV(const Val* const); + +bool isUnrolledFor(const Expr*); + +} // namespace ir_utils } // namespace fuser } // namespace jit } // namespace torch \ No newline at end of file From e71d8b2d8129fea1e0523f35bb5423deb3a39a00 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 7 Apr 2020 21:49:18 -0400 Subject: [PATCH 7/9] Unrolling loops seemingly working. 
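Patch 7 note: this is where unrolling takes shape. Predication is pulled out of GPULower::mutate(UnaryOp)/mutate(BinaryOp) and into UnrollPass, which clones each unrolled loop nest twice and selects between the clones with a hoisted IfThenElse: an unpredicated fast path when the whole tile is in bounds, and a per-element predicated fallback otherwise. A condensed sketch of the construction, assuming the scope_utils helpers added below (cloneLoopNest, firstInnerMostScope, replaceExprsInScope) and using the predicates and loop handles computed in UnrollPass::handle(ForLoop):

    IfThenElse* unroll_ite =
        new IfThenElse(unroll_predicate, {}, {}, first_unroll->parentScope());
    ForLoop* unrolled_loop = scope_utils::cloneLoopNest(first_unroll, unroll_ite);
    ForLoop* inlined_loop  = scope_utils::cloneLoopNest(first_unroll, unroll_ite);
    Expr* inner_most = scope_utils::firstInnerMostScope(inlined_loop);
    IfThenElse* inline_ite =
        new IfThenElse(inline_predicate, {expr}, {}, inner_most);
    std::unordered_map<Expr*, Expr*> replacements{{expr, inline_ite}};
    scope_utils::replaceExprsInScope(inner_most, replacements);
    unroll_ite->body().push_back(unrolled_loop);      // fast, unpredicated path
    unroll_ite->elseBody().push_back(inlined_loop);   // guarded fallback path
    loop_replacement_map.insert({first_unroll, unroll_ite});  // spliced in by runPass

Outside of an unrolled scope the pass still wraps a tensor-view op in a single predicated IfThenElse, and only when the predicate is not the literal 1.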
--- test/cpp/jit/test_gpu.cpp | 66 ++++- test/cpp/jit/tests.h | 3 +- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 8 + torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 1 + torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 1 + torch/csrc/jit/codegen/cuda/ir_iostream.h | 2 +- torch/csrc/jit/codegen/cuda/lower2device.cpp | 118 ++++----- torch/csrc/jit/codegen/cuda/lower2device.h | 8 +- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 191 ++++++++++---- torch/csrc/jit/codegen/cuda/lower_loops.h | 34 ++- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 233 +++++++++++++++--- torch/csrc/jit/codegen/cuda/lower_utils.h | 16 +- torch/csrc/jit/codegen/cuda/mutator.cpp | 75 +++++- 13 files changed, 569 insertions(+), 187 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 9b3ab22de536..8a6db3a68f5a 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -937,13 +937,18 @@ void testGPU_FusionExecKernel() { // Register your outputs fusion.addOutput(tv3); + tv3->split(0, 4); + // For all inputs, computeAt the output inline, temporaries should be squeezed // between them - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); // Parallelize TV3 tv3->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); torch::jit::fuser::cuda::CudaKernel prog; @@ -1003,6 +1008,63 @@ void testGPU_FusionForLoop() { } } +void testGPU_FusionLoopUnroll() { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeDummyTensor(1); + TensorView* tv1 = makeDummyTensor(1); + + // Register your inputs + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Do math with it, it returns a `Val*` but can be static_casted back to + // TensorView + TensorView* tv2 = static_cast(add(tv1, new Float(2.0))); + TensorView* tv3 = static_cast(add(tv0, tv2)); + + // Register your outputs + fusion.addOutput(tv3); + + tv3->split(0, 16); + tv3->split(0, 4); + + // For all inputs, computeAt the output inline, temporaries should be squeezed + // between them + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + // Parallelize + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + torch::jit::fuser::cuda::CudaKernel prog; + prog.device_ = 0; + prog.grid(1); // 1 CTA + prog.block(128); // 128 Threads + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input1 = at::ones({1, 128}, options); + at::Tensor input2 = at::ones_like(input1); + ; + at::Tensor output = at::empty_like(input1); + std::vector inputs{{input1, input2}}; + std::vector outputs{{output}}; + + torch::jit::fuser::cuda::compileKernel(fusion, prog); + torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs); + + at::Tensor check = at::full({1, 128}, 4, options); + ; + TORCH_CHECK(output.equal(check)); + +} } // namespace jit } // namespace torch diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 7aece9e352df..18dcba708152 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -116,7 +116,8 @@ namespace jit { _(GPU_FusionCodeGen2) \ _(GPU_FusionSimplePWise) \ _(GPU_FusionExecKernel) \ - _(GPU_FusionForLoop) 
+ _(GPU_FusionForLoop) \ + _(GPU_FusionLoopUnroll) #else #define TH_FORALL_TESTS_CUDA(_) \ _(ArgumentSpec) \ diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 6a93e7d70f52..f1437efc1333 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -99,6 +99,14 @@ bool Val::isZeroInt() const { return false; } +bool Val::isOneInt() const { + if (isConstScalar() && getValType().value() == ValType::Scalar && + getDataType().value() == DataType::Int && + static_cast(this)->value().value() == 1) + return true; + return false; +} + c10::optional Val::getDataType() const { TORCH_INTERNAL_ASSERT( dtype_ != DataType::Null, "Value does not have a data type."); diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 5a7d2b5fa9a1..4ff3e84aa852 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -179,6 +179,7 @@ struct TORCH_CUDA_API Val : public Statement { } bool isZeroInt() const; + bool isOneInt() const; // Returns the Expr that this value is an output of, returns nullptr if none // was found diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index c26c47d567b2..7ca98384da50 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -291,6 +291,7 @@ void IRPrinter::handle(const ForLoop* const fl) { void IRPrinter::handle(const IfThenElse* const ite) { indent(); + // IF os << "if ( "; print_inline(ite->cond()); diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 774eddc073c8..3ca1074da1e7 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -47,6 +47,7 @@ struct Add; */ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { +public: std::ostream& os; bool print_inline_ = false; @@ -65,7 +66,6 @@ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { void printHeader(Fusion* fusion, const std::string& kernel_name_); - public: IRPrinter(std::ostream& _os) : os(_os) {} virtual void handle(Fusion* const f); diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index a11d37c90475..af02554cea86 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -1,8 +1,12 @@ #include #include +#include +#include #include #include #include +#include +#include #include #include @@ -174,34 +178,6 @@ TensorIndex* GPULower::getConsumerIndex(TensorView* consumer) { return getLocalConsumerIndex(consumer); } -IfThenElse* GPULower::getPredicate(const TensorView* const pred_tv) { - TensorIndex* ti = new TensorIndex( - pred_tv, - IndexCompute::computeIndices( - pred_tv, scope_utils::getLoopIndices(active_scope))); - - std::vector all_preds = PredicateCompute::computePredicates(ti); - - std::vector preds; - - Int* one = new Int(1); - - for (Int* pred : all_preds) - if (!pred->sameAs(one)) - preds.push_back(pred); - - if (preds.size() == 0) { - return new IfThenElse(one, {}, {}, active_scope); - } - - Int* cond = preds[0]; - - for (decltype(preds.size()) i{1}; i < preds.size(); i++) - cond = static_cast(andOp(cond, preds[i])); - - return new IfThenElse(cond, {}, {}, active_scope); -} - void GPULower::pushBack(Expr* expr) { if (active_scope == nullptr) lowered_exprs.push_back(expr); @@ 
-218,35 +194,65 @@ Statement* GPULower::mutate(Expr* expr) { return mutated_stmt; } -Statement* GPULower::mutate(ForLoop* fl) { +Statement* GPULower::mutate(IfThenElse* ite) { Expr* prev_scope = active_scope; - active_scope = fl; + active_scope = ite; std::vector mutated_exprs; bool is_mutated = false; - for (auto expr : fl->body().exprs()) { + for (auto expr : ite->body().exprs()) { Statement* mutated_stmt = mutate(expr); + Expr* mutated_expr = ir_utils::asExpr(mutated_stmt); + mutated_exprs.push_back(mutated_expr); + is_mutated = is_mutated | (mutated_expr != expr); + } - TORCH_INTERNAL_ASSERT( - mutated_stmt->isExpr(), - "Tried to generate a kernel but hit a non expression during lowering: ", - mutated_stmt); - - mutated_exprs.push_back(static_cast(mutated_stmt)); - if (!(mutated_exprs.back()->sameAs(expr))) - is_mutated = true; + std::vector mutated_else_exprs; + for (auto expr : ite->elseBody().exprs()) { + Statement* mutated_stmt = mutate(expr); + Expr* mutated_expr = ir_utils::asExpr(mutated_stmt); + mutated_else_exprs.push_back(mutated_expr); + is_mutated = is_mutated | (mutated_expr != expr); } if (is_mutated) { - scope_utils::clearScope(active_scope); + ite->body().clear(); for (auto expr : mutated_exprs) - pushBack(expr); + ite->body().push_back(expr); + ite->elseBody().clear(); + for (auto expr : mutated_else_exprs) + ite->elseBody().push_back(expr); } active_scope = prev_scope; - if (is_mutated) - return new ForLoop( + if (is_mutated){ + auto new_ite = new IfThenElse( + ite->cond(), mutated_exprs, mutated_else_exprs, ite->parentScope()); + return new_ite; + } + + return ite; +} + +Statement* GPULower::mutate(ForLoop* fl) { + Expr* prev_scope = active_scope; + active_scope = fl; + std::vector mutated_exprs; + bool is_mutated = false; + for (auto expr : fl->body().exprs()) { + Statement* mutated_stmt = mutate(expr); + Expr* mutated_expr = ir_utils::asExpr(mutated_stmt); + mutated_exprs.push_back(mutated_expr); + is_mutated = is_mutated | (mutated_expr != expr); + } + + active_scope = prev_scope; + + if (is_mutated){ + auto newFL = new ForLoop( fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); + return newFL; + } return fl; } @@ -255,25 +261,12 @@ Statement* GPULower::mutate(UnaryOp* uop) { if (!ir_utils::isTVOp(uop)) return OptOutMutator::mutate(uop); - IfThenElse* pred = getPredicate(ir_utils::asTV(uop->out())); - bool predicated = !pred->cond()->sameAs(new Int(1)); - if (predicated) { - pushBack(pred); - active_scope = pred; - } - TensorIndex* out = getConsumerIndex(ir_utils::asTV(uop->out())); Val* in = uop->in(); if (ir_utils::isTV(in)) in = getProducerIndex(ir_utils::asTV(in), ir_utils::asTV(uop->out())); Expr* new_op = new UnaryOp(uop->getUnaryOpType(), out, in); - if (predicated) { - active_scope = scope_utils::getParent(active_scope); - pushBack(new_op); - return pred; - } - return new_op; } @@ -281,13 +274,6 @@ Statement* GPULower::mutate(BinaryOp* bop) { if (!ir_utils::isTVOp(bop)) return OptOutMutator::mutate(bop); - IfThenElse* pred = getPredicate(ir_utils::asTV(bop->out())); - bool predicated = !pred->cond()->sameAs(new Int(1)); - if (predicated) { - pushBack(pred); - active_scope = pred; - } - TensorIndex* out = getConsumerIndex(ir_utils::asTV(bop->out())); Val* lhs = bop->lhs(); Val* rhs = bop->rhs(); @@ -300,12 +286,6 @@ Statement* GPULower::mutate(BinaryOp* bop) { Expr* new_op = new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs); - if (predicated) { - pushBack(new_op); - active_scope = scope_utils::getParent(active_scope); - return pred; - } - 
return new_op; } diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h index 7c29b6a88b82..ff9282612227 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.h +++ b/torch/csrc/jit/codegen/cuda/lower2device.h @@ -2,12 +2,7 @@ #include -#include #include -#include -#include -#include -#include #include #include @@ -70,6 +65,9 @@ struct TORCH_CUDA_API GPULower : public OptOutMutator { // Open the for loop. Statement* mutate(ForLoop*) final; + // Open the for loop. + Statement* mutate(IfThenElse*) final; + // Remake operations with TensorIndex Statement* mutate(UnaryOp*) final; Statement* mutate(BinaryOp*) final; diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index 2b14cefd50d0..dcfc0384e524 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -1,65 +1,154 @@ -#include #include +#include #include +#include +#include +#include + namespace torch { namespace jit { namespace fuser { +// all the way in the loop nest, grab predicate +/* +for( i : ceil(I/4) ) { + for( j : ceil(J/128) ) { -void UnrollPass::pushBack(Expr* expr) { - if (active_scope == nullptr) - lowered_exprs.push_back(expr); - else - scope_utils::pushBack(active_scope, expr); + if( i * 4 + 3 < I && j * 128 + 127 < J ){ + for( k : 4) + for( l : 128 ) + T0[ ( i * 4 + k ) * J + j * 128 + l ] = … + } else { + for( k : 4 ) + for( l : 128 ) + if( i * 4 + k < I && j * 128 + l < J) + T0[ ( i * 4 + k ) * J + j * 128 + l ] = … + } + + } } +*/ // Custom dispatch for Expr, want to find out of it's a TV op -Statement* UnrollPass::mutate(Expr* expr) { - Statement* mutated_stmt = OptOutMutator::mutate(expr); - ir_utils::ASSERT_EXPR(mutated_stmt); - return mutated_stmt; +void UnrollPass::handle(Expr* expr) { + OptOutDispatch::handle(expr); } -// Open the for loop. -Statement* UnrollPass::mutate(ForLoop* fl) { - Expr* prev_scope = active_scope; - active_scope = fl; - std::vector mutated_exprs; - bool is_mutated = false; - for (auto expr : fl->body().exprs()) { - if (ir_utils::isUnrolledFor(expr)) { - is_mutated = true; - mutated_exprs.push_back(expr); - } else { - mutated_exprs.push_back(expr); - } - } +namespace { +Int* getPredicate(const TensorView* const pred_tv, std::vector indices) { + TensorIndex* ti = + new TensorIndex(pred_tv, IndexCompute::computeIndices(pred_tv, indices)); + std::vector all_preds = PredicateCompute::computePredicates(ti); - if (is_mutated) { - scope_utils::clearScope(active_scope); - for (auto expr : mutated_exprs) - pushBack(expr); - } + std::vector preds; - active_scope = prev_scope; + Int* one = new Int(1); - if (is_mutated) - return new ForLoop( - fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); + for (Int* pred : all_preds) + if (!pred->sameAs(one)) + preds.push_back(pred); - return fl; -} + Int* cond = preds[0]; + + for (decltype(preds.size()) i{1}; i < preds.size(); i++) + cond = static_cast(andOp(cond, preds[i])); -// Remake operations with TensorIndex -Statement* UnrollPass::mutate(UnaryOp* uop) { - return uop; + return cond; } -Statement* UnrollPass::mutate(BinaryOp* bop) { - return bop; +} // namespace + +// Open the for loop. 
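One caveat for the getPredicate helper above: it keeps only the per-axis bounds checks that are not the literal 1, so when every check is trivial the preds vector is empty and preds[0] is read out of range (the GPULower::getPredicate removed earlier in this patch fell back to the literal 1 in that case). A guarded variant, sketched against the signatures used here with a hypothetical name:

    Int* getPredicateGuarded(const TensorView* pred_tv, std::vector<Val*> indices) {
      TensorIndex* ti = new TensorIndex(
          pred_tv, IndexCompute::computeIndices(pred_tv, indices));
      Int* one = new Int(1);
      std::vector<Int*> preds;
      for (Int* pred : PredicateCompute::computePredicates(ti))
        if (!pred->sameAs(one))
          preds.push_back(pred);
      if (preds.empty())
        return one;  // every axis is provably in bounds
      Int* cond = preds[0];
      for (size_t i = 1; i < preds.size(); i++)
        cond = static_cast<Int*>(andOp(cond, preds[i]));
      return cond;
    }

Returning the literal 1 keeps the callers' pred->isOneInt() checks meaningful.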
+void UnrollPass::handle(ForLoop* fl) { + + // Setup for loop scoping + for_loops.push_back(fl); + bool prev_unroll = within_unroll; + within_unroll = ir_utils::isUnrolledFor(fl) || within_unroll; + + for (auto expr : fl->body().exprs()) { + OptOutDispatch::handle(expr); + + if (ir_utils::isTVOp(expr)) { + if (within_unroll) { + TORCH_INTERNAL_ASSERT( + fl->body().size() == 1, + "Expected to only find a single expr in an inner most for loop inside an unrolled scope."); + + // Indices used to detect when we can unroll a loop safely + // For loops outside the unroll, it's just he index, for loops inside + // the unroll, if it's a thread it's the thread index, otherwise it's + // the size-1 + std::vector unroll_pred_inds; + auto it = for_loops.begin(); + while (it != for_loops.end()) { + if (ir_utils::isUnrolledFor(*it)) + break; + unroll_pred_inds.push_back((*it)->index()); + it++; + } + + // This is the outer most loop that needs to be unrolled + ForLoop* first_unroll = *it; + + // Indicies inside the unroll + while (it != for_loops.end()) { + IterDomain* id = (*it)->iter_domain(); + if (id->isThread()) + unroll_pred_inds.push_back((*it)->index()); + else + unroll_pred_inds.push_back(sub(id->extent(), new Int(1))); + it++; + } + + // Tensorview of the predicate determining op + TensorView* out = + ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]); + + // Make predicates for the unrolling, and the epilogue + Int* unroll_predicate = getPredicate(out, unroll_pred_inds); + Int* inline_predicate = + getPredicate(out, scope_utils::getLoopIndices(for_loops.back())); + + // Make the IfThenElse controlling the unrolling + IfThenElse* unroll_ite = new IfThenElse(unroll_predicate, {}, {}, first_unroll->parentScope() ); + // Get the loop nest for the unrolled path + ForLoop* unrolled_loop = scope_utils::cloneLoopNest(first_unroll, unroll_ite); + ForLoop* inlined_loop = scope_utils::cloneLoopNest(first_unroll, unroll_ite); + Expr* inner_most_inlined_loop = scope_utils::firstInnerMostScope(inlined_loop); + // Get the predicate for the non-unrolled (inline) path + IfThenElse* inline_ite = + new IfThenElse(inline_predicate, {expr}, {}, inner_most_inlined_loop); + std::unordered_map inline_replacement_map; + inline_replacement_map.emplace( std::pair(expr, inline_ite)); + scope_utils::replaceExprsInScope(inner_most_inlined_loop, inline_replacement_map); + + unroll_ite->body().push_back(unrolled_loop); + unroll_ite->elseBody().push_back(inlined_loop); + + loop_replacement_map.insert({first_unroll, unroll_ite}); + + } else { + // ! 
within_unroll + TensorView* out = + ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]); + Int* pred = + getPredicate(out, scope_utils::getLoopIndices(for_loops.back())); + if(!pred->isOneInt()){ + IfThenElse* inline_ite = + new IfThenElse(pred, {expr}, {}, for_loops.back()); + for_loops.back()->body().insert_before(expr, inline_ite); + for_loops.back()->body().erase(expr); + } + } + } // if (ir_utils::isTVOp(expr)) + } // for (auto expr : fl->body().exprs()) + + for_loops.pop_back(); + bool within_unroll = prev_unroll; } // Generate the loop nest structure and place it in lowered_exprs -void UnrollPass::runPass() { +void UnrollPass::computeMap() { FusionGuard fg(fusion_); // Likely we lowered this fusion, we can simply return the lowered expressions @@ -79,11 +168,27 @@ void UnrollPass::runPass() { // Run through loop nests and further lower the expressions for (auto* expr : incoming_exprs_) { - Statement* mutated_stmt = mutate(expr); - lowered_exprs.push_back(ir_utils::asExpr(mutated_stmt)); + OptOutDispatch::handle(expr); } } + std::vector UnrollPass::runPass(Fusion* fusion, std::vector exprs) { + FusionGuard fg(fusion); + UnrollPass up(fusion, exprs); + up.computeMap(); + std::vector mutated_exprs; + for(Expr* expr : exprs){ + if(up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()){ + mutated_exprs.push_back(up.loop_replacement_map[expr]); + }else{ + if(ir_utils::isScope(expr)) + scope_utils::replaceExprsInScope(expr, up.loop_replacement_map); + mutated_exprs.push_back(expr); + } + } + return mutated_exprs; + } + Allocate* LoopNestGenerator::getAlloc(TensorView* tv) { TORCH_INTERNAL_ASSERT( !(FusionGuard::getCurFusion()->hasInput(tv) || diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h index f82ff9b74955..46aeae110842 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.h +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -2,49 +2,43 @@ #include #include + +#include #include namespace torch { namespace jit { namespace fuser { -struct UnrollPass : public OptOutMutator { +struct UnrollPass : public OptOutDispatch { private: + std::unordered_map loop_replacement_map; Fusion* fusion_; - std::vector lowered_exprs; const std::vector& incoming_exprs_; - Expr* active_scope = nullptr; + + // Keep all for loops conveniently to make unrolling easier + std::vector for_loops; + + // keep track if we're within an unrolled loop + bool within_unroll = false; // Track the last computeAt TensorView and axis const TensorView* active_view; unsigned int active_view_axis; - // Wrap pushBack in lower_utils if active_scope is null we want it to go - // straight to lower_exprs - void pushBack(Expr*); - // Custom dispatch for Expr, want to find out of it's a TV op - Statement* mutate(Expr*) final; + void handle(Expr*) final; // Open the for loop. 
- Statement* mutate(ForLoop*) final; - - // Remake operations with TensorIndex - Statement* mutate(UnaryOp*) final; - Statement* mutate(BinaryOp*) final; + void handle(ForLoop*) final; UnrollPass(Fusion* _fusion, const std::vector& _incoming_exprs) : fusion_(_fusion), incoming_exprs_(_incoming_exprs) {} - void runPass(); + void computeMap(); public: - static std::vector runPass(Fusion* fusion, std::vector exprs) { - FusionGuard fg(fusion); - UnrollPass up(fusion, exprs); - up.runPass(); - return up.lowered_exprs; - } + static std::vector runPass(Fusion* fusion, std::vector exprs); }; struct TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch { diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 30abd998f176..bb9056365a86 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -1,13 +1,12 @@ -#pragma once - #include - +#include namespace torch { namespace jit { namespace fuser { namespace scope_utils { +// START SCOPE HELPER SYSTEMS namespace { struct forLoopIndices : private OptInDispatch { @@ -29,33 +28,34 @@ struct forLoopIndices : private OptInDispatch { Expr* it = scope; while (it != nullptr) { fli.handle(it); - it = getParent(it); + it = scope_utils::getParent(it); } return fli.inds_; } }; -struct parentScope : private OptInDispatch { +struct forLoopIDs : private OptInDispatch { private: - Expr* parent_ = nullptr; - + std::vector IDs_; void handle(ForLoop* fl) final { - parent_ = fl->parentScope(); + IDs_.insert(IDs_.begin(), fl->iter_domain()); } - void handle(IfThenElse* ite) final { - parent_ = ite->parentScope(); - } + void handle(IfThenElse* ite) final {} void handle(Expr* expr) final { OptInDispatch::handle(expr); } public: - static Expr* get(Expr* scope) { - parentScope sp; - sp.handle(scope); - return sp.parent_; + static std::vector get(Expr* scope) { + forLoopIDs fli; + Expr* it = scope; + while (it != nullptr) { + fli.handle(it); + it = scope_utils::getParent(it); + } + return fli.IDs_; } }; @@ -79,46 +79,48 @@ struct forLoopCount : private OptInDispatch { Expr* it = scope; while (it != nullptr) { flc.handle(it); - it = getParent(it); + it = scope_utils::getParent(it); } return flc.count_; } }; -struct forLoopIDs : private OptInDispatch { +struct scopePushBack : private OptInDispatch { private: - std::vector IDs_; + Expr* _expr = nullptr; void handle(ForLoop* fl) final { - IDs_.insert(IDs_.begin(), fl->iter_domain()); + fl->body().push_back(_expr); } - void handle(IfThenElse* ite) final {} + void handle(IfThenElse* ite) final { + ite->body().push_back(_expr); + } void handle(Expr* expr) final { OptInDispatch::handle(expr); } public: - static std::vector get(Expr* scope) { - forLoopIDs fli; - Expr* it = scope; - while (it != nullptr) { - fli.handle(it); - it = getParent(it); - } - return fli.IDs_; + static void push(Expr* scope, Expr* expr) { + scopePushBack pb; + TORCH_INTERNAL_ASSERT( + expr != nullptr && scope != nullptr, + "Cannot push back, scope or expr is a nullptr."); + pb._expr = expr; + pb.handle(scope); } }; -struct scopePushBack : private OptInDispatch { +struct parentScope : private OptInDispatch { private: - Expr* _expr = nullptr; + Expr* parent_ = nullptr; + void handle(ForLoop* fl) final { - fl->body().push_back(_expr); + parent_ = fl->parentScope(); } void handle(IfThenElse* ite) final { - ite->body().push_back(_expr); + parent_ = ite->parentScope(); } void handle(Expr* expr) final { @@ -126,13 +128,10 @@ struct scopePushBack : private 
OptInDispatch { } public: - static void push(Expr* scope, Expr* expr) { - scopePushBack pb; - TORCH_INTERNAL_ASSERT( - expr != nullptr && scope != nullptr, - "Cannot push back, scope or expr is a nullptr."); - pb._expr = expr; - pb.handle(scope); + static Expr* get(Expr* scope) { + parentScope sp; + sp.handle(scope); + return sp.parent_; } }; @@ -167,6 +166,137 @@ void assertScope(Expr* expr) { "Assert Scope failed when calling a scope_util function."); } +struct CloneLoopNest : public OptOutMutator { +private: + + Expr* parent_scope_ = nullptr; + Expr* to_clone_ = nullptr; + + Statement* mutate(ForLoop* fl){ + std::vector mutated_exprs; + for(Expr* expr : fl->body().exprs()){ + mutated_exprs.push_back(ir_utils::asExpr(OptOutMutator::mutate(expr))); + } + if( fl == to_clone_ ) + return new ForLoop(fl->index(), fl->iter_domain(), mutated_exprs, parent_scope_); + return new ForLoop(fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); + + } + + CloneLoopNest(Expr* _to_clone, Expr* _parent_scope) : parent_scope_(_parent_scope), to_clone_(_to_clone) {} + +public: + static ForLoop* getClone(ForLoop* _to_clone, Expr* _parent_scope){ + TORCH_INTERNAL_ASSERT(_to_clone != nullptr, "Tried to clone a scope, but received a nullptr."); + CloneLoopNest cln(_to_clone, _parent_scope); + return ir_utils::asForLoop(ir_utils::asExpr(cln.mutate(_to_clone))); + } + +}; + +struct ReplaceExprsInScope : public OptOutDispatch { +private: + + std::unordered_map replacement_map_; + + + void handle(Expr* expr){ + OptOutDispatch::handle(expr); + } + + void handle(ForLoop* fl){ + for(Expr* expr : fl->body().exprs()){ + auto it = replacement_map_.find(expr); + if(it == replacement_map_.end()){ + handle(expr); + continue; + } + fl->body().insert_before(expr, replacement_map_[expr]); + fl->body().erase(expr); + } + } + + void handle(IfThenElse* ite){ + for(Expr* expr : ite->body().exprs()){ + auto it = replacement_map_.find(expr); + if(it == replacement_map_.end()){ + handle(expr); + continue; + } + ite->body().insert_before(expr, replacement_map_[expr]); + ite->body().erase(expr); + } + for(Expr* expr : ite->elseBody().exprs()){ + auto it = replacement_map_.find(expr); + if(it == replacement_map_.end()){ + handle(expr); + continue; + } + ite->elseBody().insert_before(expr, replacement_map_[expr]); + ite->elseBody().erase(expr); + } + } + + ReplaceExprsInScope(std::unordered_map _replacement_map) : replacement_map_(_replacement_map){} + +public: + static void replace(Expr* scope, std::unordered_map replacement_map){ + ReplaceExprsInScope reis(replacement_map); + reis.handle(scope); + } + +}; + +struct FirstInnerMostScope : private OptInDispatch { + private: + Expr* active_scope = nullptr; + + void handle(ForLoop* fl) final { + for(auto expr : fl->body().exprs()){ + if(ir_utils::isScope(expr)){ + active_scope = expr; + return; + } + } + active_scope = nullptr; + } + + void handle(IfThenElse* ite) final { + for(auto expr : ite->body().exprs()){ + if(ir_utils::isScope(expr)){ + active_scope = expr; + return; + } + } + for(auto expr : ite->elseBody().exprs()){ + if(ir_utils::isScope(expr)){ + active_scope = expr; + return; + } + } + active_scope = nullptr; + } + + Expr* getInner(Expr* expr) { + OptInDispatch::handle(expr); + return active_scope; + } + + public: + static Expr* get(Expr* scope) { + TORCH_INTERNAL_ASSERT( + scope != nullptr, + "Tried to get inner most scope, but was provided nullptr."); + + FirstInnerMostScope fims; + Expr* inner = fims.getInner(scope); + while (fims.getInner(inner) != nullptr) + 
inner = fims.getInner(inner); + return inner; + } +}; + +// END SCOPE HELPER SYSTEMS } // namespace // Grab the index variables of the active loop nest @@ -243,6 +373,21 @@ Expr* clearScope(Expr* scope) { return scope; } + +ForLoop* cloneLoopNest(ForLoop* to_clone, Expr* parent_scope){ + return CloneLoopNest::getClone(to_clone, parent_scope); +} + +void replaceExprsInScope(Expr* scope, std::unordered_map replacement_map){ + TORCH_INTERNAL_ASSERT(replacement_map.find(scope) == replacement_map.end(), + "Error trying to replace expressions in a scope, scope wants to be replaced entirely."); + ReplaceExprsInScope::replace(std::move(scope), std::move(replacement_map)); +} + +Expr* firstInnerMostScope(Expr* scope){ + return FirstInnerMostScope::get(scope); +} + } // namespace scope_utils namespace ir_utils { @@ -277,6 +422,16 @@ TensorView* asTV(Val* val) { return static_cast(val); } +bool isScope(const Expr* expr){ + return expr->getExprType() == ExprType::ForLoop || expr->getExprType() == ExprType::IfThenElse; +} + +ForLoop* asForLoop(Statement* stmt) { + Expr* expr = asExpr(stmt); + TORCH_INTERNAL_ASSERT(expr->getExprType() == ExprType::ForLoop); + return static_cast(expr); +} + const TensorView* asConstTV(const Val* const val) { TORCH_INTERNAL_ASSERT(isTV(val)); return static_cast(val); diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index f632eeb02bba..980cbac9da5f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -36,11 +36,17 @@ Expr* closeScope(Expr* scope); // Clear all expressions from the scope Expr* clearScope(Expr* scope); -// Track how far our for loop scope is -unsigned int computeForDepth(Expr* scope); +// Provide a new for loop matching the one provided +ForLoop* cloneLoopNest(ForLoop* to_clone, Expr* parent_scope); + +// Run through a scope and replace expressions inside with replacement_map +void replaceExprsInScope(Expr* scope, std::unordered_map replacement_map); + +Expr* firstInnerMostScope(Expr* scope); } // namespace scope_utils + namespace ir_utils { bool isTV(const Val* const); @@ -49,10 +55,14 @@ bool isTVOp(const Expr*); void ASSERT_EXPR(Statement*); +bool isScope(const Expr*); + Expr* asExpr(Statement*); TensorView* asTV(Val*); +ForLoop* asForLoop(Statement*); + const TensorView* asConstTV(const Val* const); bool isUnrolledFor(const Expr*); @@ -60,4 +70,4 @@ bool isUnrolledFor(const Expr*); } // namespace ir_utils } // namespace fuser } // namespace jit -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index d91c57023a6e..190951f17e50 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -176,12 +176,79 @@ Statement* OptOutMutator::mutate(BinaryOp* bop) { return new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs); } -Statement* OptOutMutator::mutate(ForLoop* n) { - return n; +Statement* OptOutMutator::mutate(ForLoop* fl) { + + Val* index = mutateAsVal(fl->index())->asVal(); + Val* val_id = mutateAsVal(fl->iter_domain())->asVal(); + + TORCH_INTERNAL_ASSERT(val_id->getValType() == ValType::IterDomain); + IterDomain* id = static_cast(val_id); + + bool is_mutated = !index->sameAs(fl->index()); + is_mutated = is_mutated | !id->sameAs(fl->iter_domain()); + + std::vector mutated_exprs; + for (auto expr : fl->body().exprs()) { + Statement* mutated_stmt = mutate(expr); + TORCH_INTERNAL_ASSERT( + 
mutated_stmt->isExpr(),
+        "While mutating a for loop, received a non-expression for a body entry.");
+    Expr* mutated_expr = static_cast<Expr*>(mutated_stmt);
+    mutated_exprs.push_back(mutated_expr);
+    // could use sameAs here, but we'd have to check the output value separately
+    is_mutated = is_mutated | (mutated_expr != expr);
+  }
+
+  if (is_mutated){
+    auto newFL = new ForLoop(
+        index, id, mutated_exprs, fl->parentScope());
+    return newFL;
+  }
+
+  return fl;
 }
 
-Statement* OptOutMutator::mutate(IfThenElse* n) {
-  return n;
+Statement* OptOutMutator::mutate(IfThenElse* ite) {
+
+  Val* val_cond = mutateAsVal(ite->cond())->asVal();
+  TORCH_INTERNAL_ASSERT(
+      val_cond->getValType().value() == ValType::Scalar &&
+      val_cond->getDataType().value() == DataType::Int);
+  Int* cond = static_cast<Int*>(val_cond);
+
+  bool is_mutated = !cond->sameAs(ite->cond());
+
+  std::vector<Expr*> mutated_exprs;
+  for (auto expr : ite->body().exprs()) {
+    Statement* mutated_stmt = mutate(expr);
+    TORCH_INTERNAL_ASSERT(
+        mutated_stmt->isExpr(),
+        "While mutating a for loop, received a non-expression for a body entry.");
+    Expr* mutated_expr = static_cast<Expr*>(mutated_stmt);
+    mutated_exprs.push_back(mutated_expr);
+    // could use sameAs here, but we'd have to check the output value separately
+    is_mutated = is_mutated | (mutated_expr != expr);
+  }
+
+  std::vector<Expr*> mutated_else_exprs;
+  for (auto expr : ite->elseBody().exprs()) {
+    Statement* mutated_stmt = mutate(expr);
+    TORCH_INTERNAL_ASSERT(
+        mutated_stmt->isExpr(),
+        "While mutating a for loop, received a non-expression for a body entry.");
+    Expr* mutated_expr = static_cast<Expr*>(mutated_stmt);
+    mutated_else_exprs.push_back(mutated_expr);
+    // could use sameAs here, but we'd have to check the output value separately
+    is_mutated = is_mutated | (mutated_expr != expr);
+  }
+
+  if (is_mutated){
+    auto newITE = new IfThenElse(
+        cond, ite->body().exprs(), ite->elseBody().exprs(), ite->parentScope());
+    return newITE;
+  }
+
+  return ite;
 }
 
 // START REPLACE ALL

From 1ba0c74e7e2069e13522bb89ae205d79ab5e2911 Mon Sep 17 00:00:00 2001
From: Christian Sarofeen
Date: Tue, 7 Apr 2020 22:00:06 -0400
Subject: [PATCH 8/9] Clang.
--- test/cpp/jit/test_gpu.cpp | 3 +- torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 1 - torch/csrc/jit/codegen/cuda/ir_iostream.h | 2 +- torch/csrc/jit/codegen/cuda/lower2device.cpp | 4 +- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 73 ++++++++-------- torch/csrc/jit/codegen/cuda/lower_loops.h | 2 +- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 91 +++++++++++--------- torch/csrc/jit/codegen/cuda/lower_utils.h | 5 +- torch/csrc/jit/codegen/cuda/mutator.cpp | 9 +- 9 files changed, 98 insertions(+), 92 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 8a6db3a68f5a..cc10242704db 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -938,7 +938,7 @@ void testGPU_FusionExecKernel() { fusion.addOutput(tv3); tv3->split(0, 4); - + // For all inputs, computeAt the output inline, temporaries should be squeezed // between them tv0->computeAt(tv3, 1); @@ -1063,7 +1063,6 @@ void testGPU_FusionLoopUnroll() { at::Tensor check = at::full({1, 128}, 4, options); ; TORCH_CHECK(output.equal(check)); - } } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 7ca98384da50..c26c47d567b2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -291,7 +291,6 @@ void IRPrinter::handle(const ForLoop* const fl) { void IRPrinter::handle(const IfThenElse* const ite) { indent(); - // IF os << "if ( "; print_inline(ite->cond()); diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 3ca1074da1e7..65b0693ff8aa 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -47,7 +47,7 @@ struct Add; */ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { -public: + public: std::ostream& os; bool print_inline_ = false; diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index af02554cea86..09d8930f4964 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -225,7 +225,7 @@ Statement* GPULower::mutate(IfThenElse* ite) { active_scope = prev_scope; - if (is_mutated){ + if (is_mutated) { auto new_ite = new IfThenElse( ite->cond(), mutated_exprs, mutated_else_exprs, ite->parentScope()); return new_ite; @@ -248,7 +248,7 @@ Statement* GPULower::mutate(ForLoop* fl) { active_scope = prev_scope; - if (is_mutated){ + if (is_mutated) { auto newFL = new ForLoop( fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); return newFL; diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index dcfc0384e524..92f411064960 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include @@ -31,7 +31,7 @@ for( i : ceil(I/4) ) { // Custom dispatch for Expr, want to find out of it's a TV op void UnrollPass::handle(Expr* expr) { - OptOutDispatch::handle(expr); + OptOutDispatch::handle(expr); } namespace { @@ -59,7 +59,6 @@ Int* getPredicate(const TensorView* const pred_tv, std::vector indices) { // Open the for loop. 
void UnrollPass::handle(ForLoop* fl) { - // Setup for loop scoping for_loops.push_back(fl); bool prev_unroll = within_unroll; @@ -101,26 +100,31 @@ void UnrollPass::handle(ForLoop* fl) { } // Tensorview of the predicate determining op - TensorView* out = - ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]); - + TensorView* out = ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]); + // Make predicates for the unrolling, and the epilogue Int* unroll_predicate = getPredicate(out, unroll_pred_inds); Int* inline_predicate = getPredicate(out, scope_utils::getLoopIndices(for_loops.back())); - + // Make the IfThenElse controlling the unrolling - IfThenElse* unroll_ite = new IfThenElse(unroll_predicate, {}, {}, first_unroll->parentScope() ); + IfThenElse* unroll_ite = new IfThenElse( + unroll_predicate, {}, {}, first_unroll->parentScope()); // Get the loop nest for the unrolled path - ForLoop* unrolled_loop = scope_utils::cloneLoopNest(first_unroll, unroll_ite); - ForLoop* inlined_loop = scope_utils::cloneLoopNest(first_unroll, unroll_ite); - Expr* inner_most_inlined_loop = scope_utils::firstInnerMostScope(inlined_loop); + ForLoop* unrolled_loop = + scope_utils::cloneLoopNest(first_unroll, unroll_ite); + ForLoop* inlined_loop = + scope_utils::cloneLoopNest(first_unroll, unroll_ite); + Expr* inner_most_inlined_loop = + scope_utils::firstInnerMostScope(inlined_loop); // Get the predicate for the non-unrolled (inline) path - IfThenElse* inline_ite = - new IfThenElse(inline_predicate, {expr}, {}, inner_most_inlined_loop); + IfThenElse* inline_ite = new IfThenElse( + inline_predicate, {expr}, {}, inner_most_inlined_loop); std::unordered_map inline_replacement_map; - inline_replacement_map.emplace( std::pair(expr, inline_ite)); - scope_utils::replaceExprsInScope(inner_most_inlined_loop, inline_replacement_map); + inline_replacement_map.emplace( + std::pair(expr, inline_ite)); + scope_utils::replaceExprsInScope( + inner_most_inlined_loop, inline_replacement_map); unroll_ite->body().push_back(unrolled_loop); unroll_ite->elseBody().push_back(inlined_loop); @@ -129,20 +133,19 @@ void UnrollPass::handle(ForLoop* fl) { } else { // ! 
within_unroll - TensorView* out = - ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]); + TensorView* out = ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]); Int* pred = getPredicate(out, scope_utils::getLoopIndices(for_loops.back())); - if(!pred->isOneInt()){ + if (!pred->isOneInt()) { IfThenElse* inline_ite = - new IfThenElse(pred, {expr}, {}, for_loops.back()); + new IfThenElse(pred, {expr}, {}, for_loops.back()); for_loops.back()->body().insert_before(expr, inline_ite); for_loops.back()->body().erase(expr); } } } // if (ir_utils::isTVOp(expr)) } // for (auto expr : fl->body().exprs()) - + for_loops.pop_back(); bool within_unroll = prev_unroll; } @@ -172,22 +175,24 @@ void UnrollPass::computeMap() { } } - std::vector UnrollPass::runPass(Fusion* fusion, std::vector exprs) { - FusionGuard fg(fusion); - UnrollPass up(fusion, exprs); - up.computeMap(); - std::vector mutated_exprs; - for(Expr* expr : exprs){ - if(up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()){ - mutated_exprs.push_back(up.loop_replacement_map[expr]); - }else{ - if(ir_utils::isScope(expr)) - scope_utils::replaceExprsInScope(expr, up.loop_replacement_map); - mutated_exprs.push_back(expr); - } +std::vector UnrollPass::runPass( + Fusion* fusion, + std::vector exprs) { + FusionGuard fg(fusion); + UnrollPass up(fusion, exprs); + up.computeMap(); + std::vector mutated_exprs; + for (Expr* expr : exprs) { + if (up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()) { + mutated_exprs.push_back(up.loop_replacement_map[expr]); + } else { + if (ir_utils::isScope(expr)) + scope_utils::replaceExprsInScope(expr, up.loop_replacement_map); + mutated_exprs.push_back(expr); } - return mutated_exprs; } + return mutated_exprs; +} Allocate* LoopNestGenerator::getAlloc(TensorView* tv) { TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h index 46aeae110842..ac2eac5f12b6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.h +++ b/torch/csrc/jit/codegen/cuda/lower_loops.h @@ -3,8 +3,8 @@ #include -#include #include +#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index bb9056365a86..9b794dec7331 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -167,47 +167,47 @@ void assertScope(Expr* expr) { } struct CloneLoopNest : public OptOutMutator { -private: - + private: Expr* parent_scope_ = nullptr; Expr* to_clone_ = nullptr; - Statement* mutate(ForLoop* fl){ + Statement* mutate(ForLoop* fl) { std::vector mutated_exprs; - for(Expr* expr : fl->body().exprs()){ + for (Expr* expr : fl->body().exprs()) { mutated_exprs.push_back(ir_utils::asExpr(OptOutMutator::mutate(expr))); } - if( fl == to_clone_ ) - return new ForLoop(fl->index(), fl->iter_domain(), mutated_exprs, parent_scope_); - return new ForLoop(fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); - + if (fl == to_clone_) + return new ForLoop( + fl->index(), fl->iter_domain(), mutated_exprs, parent_scope_); + return new ForLoop( + fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); } - CloneLoopNest(Expr* _to_clone, Expr* _parent_scope) : parent_scope_(_parent_scope), to_clone_(_to_clone) {} + CloneLoopNest(Expr* _to_clone, Expr* _parent_scope) + : parent_scope_(_parent_scope), to_clone_(_to_clone) {} -public: - static ForLoop* getClone(ForLoop* _to_clone, Expr* _parent_scope){ - 
TORCH_INTERNAL_ASSERT(_to_clone != nullptr, "Tried to clone a scope, but received a nullptr."); + public: + static ForLoop* getClone(ForLoop* _to_clone, Expr* _parent_scope) { + TORCH_INTERNAL_ASSERT( + _to_clone != nullptr, + "Tried to clone a scope, but received a nullptr."); CloneLoopNest cln(_to_clone, _parent_scope); return ir_utils::asForLoop(ir_utils::asExpr(cln.mutate(_to_clone))); } - }; struct ReplaceExprsInScope : public OptOutDispatch { -private: - + private: std::unordered_map replacement_map_; - - void handle(Expr* expr){ + void handle(Expr* expr) { OptOutDispatch::handle(expr); } - void handle(ForLoop* fl){ - for(Expr* expr : fl->body().exprs()){ + void handle(ForLoop* fl) { + for (Expr* expr : fl->body().exprs()) { auto it = replacement_map_.find(expr); - if(it == replacement_map_.end()){ + if (it == replacement_map_.end()) { handle(expr); continue; } @@ -216,19 +216,19 @@ struct ReplaceExprsInScope : public OptOutDispatch { } } - void handle(IfThenElse* ite){ - for(Expr* expr : ite->body().exprs()){ + void handle(IfThenElse* ite) { + for (Expr* expr : ite->body().exprs()) { auto it = replacement_map_.find(expr); - if(it == replacement_map_.end()){ + if (it == replacement_map_.end()) { handle(expr); continue; } ite->body().insert_before(expr, replacement_map_[expr]); ite->body().erase(expr); } - for(Expr* expr : ite->elseBody().exprs()){ + for (Expr* expr : ite->elseBody().exprs()) { auto it = replacement_map_.find(expr); - if(it == replacement_map_.end()){ + if (it == replacement_map_.end()) { handle(expr); continue; } @@ -237,14 +237,16 @@ struct ReplaceExprsInScope : public OptOutDispatch { } } - ReplaceExprsInScope(std::unordered_map _replacement_map) : replacement_map_(_replacement_map){} + ReplaceExprsInScope(std::unordered_map _replacement_map) + : replacement_map_(_replacement_map) {} -public: - static void replace(Expr* scope, std::unordered_map replacement_map){ + public: + static void replace( + Expr* scope, + std::unordered_map replacement_map) { ReplaceExprsInScope reis(replacement_map); reis.handle(scope); } - }; struct FirstInnerMostScope : private OptInDispatch { @@ -252,8 +254,8 @@ struct FirstInnerMostScope : private OptInDispatch { Expr* active_scope = nullptr; void handle(ForLoop* fl) final { - for(auto expr : fl->body().exprs()){ - if(ir_utils::isScope(expr)){ + for (auto expr : fl->body().exprs()) { + if (ir_utils::isScope(expr)) { active_scope = expr; return; } @@ -262,14 +264,14 @@ struct FirstInnerMostScope : private OptInDispatch { } void handle(IfThenElse* ite) final { - for(auto expr : ite->body().exprs()){ - if(ir_utils::isScope(expr)){ + for (auto expr : ite->body().exprs()) { + if (ir_utils::isScope(expr)) { active_scope = expr; return; } } - for(auto expr : ite->elseBody().exprs()){ - if(ir_utils::isScope(expr)){ + for (auto expr : ite->elseBody().exprs()) { + if (ir_utils::isScope(expr)) { active_scope = expr; return; } @@ -373,18 +375,20 @@ Expr* clearScope(Expr* scope) { return scope; } - -ForLoop* cloneLoopNest(ForLoop* to_clone, Expr* parent_scope){ +ForLoop* cloneLoopNest(ForLoop* to_clone, Expr* parent_scope) { return CloneLoopNest::getClone(to_clone, parent_scope); } -void replaceExprsInScope(Expr* scope, std::unordered_map replacement_map){ - TORCH_INTERNAL_ASSERT(replacement_map.find(scope) == replacement_map.end(), - "Error trying to replace expressions in a scope, scope wants to be replaced entirely."); +void replaceExprsInScope( + Expr* scope, + std::unordered_map replacement_map) { + TORCH_INTERNAL_ASSERT( + 
replacement_map.find(scope) == replacement_map.end(), + "Error trying to replace expressions in a scope, scope wants to be replaced entirely."); ReplaceExprsInScope::replace(std::move(scope), std::move(replacement_map)); } -Expr* firstInnerMostScope(Expr* scope){ +Expr* firstInnerMostScope(Expr* scope) { return FirstInnerMostScope::get(scope); } @@ -422,8 +426,9 @@ TensorView* asTV(Val* val) { return static_cast(val); } -bool isScope(const Expr* expr){ - return expr->getExprType() == ExprType::ForLoop || expr->getExprType() == ExprType::IfThenElse; +bool isScope(const Expr* expr) { + return expr->getExprType() == ExprType::ForLoop || + expr->getExprType() == ExprType::IfThenElse; } ForLoop* asForLoop(Statement* stmt) { diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 980cbac9da5f..0f4cb4726f51 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -40,13 +40,14 @@ Expr* clearScope(Expr* scope); ForLoop* cloneLoopNest(ForLoop* to_clone, Expr* parent_scope); // Run through a scope and replace expressions inside with replacement_map -void replaceExprsInScope(Expr* scope, std::unordered_map replacement_map); +void replaceExprsInScope( + Expr* scope, + std::unordered_map replacement_map); Expr* firstInnerMostScope(Expr* scope); } // namespace scope_utils - namespace ir_utils { bool isTV(const Val* const); diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 190951f17e50..0973d0419c2d 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -177,7 +177,6 @@ Statement* OptOutMutator::mutate(BinaryOp* bop) { } Statement* OptOutMutator::mutate(ForLoop* fl) { - Val* index = mutateAsVal(fl->index())->asVal(); Val* val_id = mutateAsVal(fl->iter_domain())->asVal(); @@ -199,9 +198,8 @@ Statement* OptOutMutator::mutate(ForLoop* fl) { is_mutated = is_mutated | (mutated_expr != expr); } - if (is_mutated){ - auto newFL = new ForLoop( - index, id, mutated_exprs, fl->parentScope()); + if (is_mutated) { + auto newFL = new ForLoop(index, id, mutated_exprs, fl->parentScope()); return newFL; } @@ -209,7 +207,6 @@ Statement* OptOutMutator::mutate(ForLoop* fl) { } Statement* OptOutMutator::mutate(IfThenElse* ite) { - Val* val_cond = mutateAsVal(ite->cond())->asVal(); TORCH_INTERNAL_ASSERT( val_cond->getValType().value() == ValType::Scalar && @@ -242,7 +239,7 @@ Statement* OptOutMutator::mutate(IfThenElse* ite) { is_mutated = is_mutated | (mutated_expr != expr); } - if (is_mutated){ + if (is_mutated) { auto newITE = new IfThenElse( cond, ite->body().exprs(), ite->elseBody().exprs(), ite->parentScope()); return newITE; From 436f7a266ad49c68d596d533b47334756afca376 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Thu, 9 Apr 2020 09:51:45 -0400 Subject: [PATCH 9/9] Test fix. 
--- test/cpp/jit/test_gpu.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index cc10242704db..981699f7bce3 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -1,4 +1,4 @@ -#if defined(USE_CUDA) +// #if defined(USE_CUDA) #include #include @@ -1045,14 +1045,17 @@ void testGPU_FusionLoopUnroll() { torch::jit::fuser::cuda::CudaKernel prog; prog.device_ = 0; - prog.grid(1); // 1 CTA - prog.block(128); // 128 Threads + prog.grid(2); + prog.block(16); + + // GPULower lower(&fusion); + // lower.printKernel(std::cout); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::ones({1, 128}, options); + at::Tensor input1 = at::ones({128}, options); at::Tensor input2 = at::ones_like(input1); - ; + at::Tensor output = at::empty_like(input1); std::vector inputs{{input1, input2}}; std::vector outputs{{output}}; @@ -1060,11 +1063,11 @@ void testGPU_FusionLoopUnroll() { torch::jit::fuser::cuda::compileKernel(fusion, prog); torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs); - at::Tensor check = at::full({1, 128}, 4, options); - ; + at::Tensor check = at::full({128}, 4, options); + TORCH_CHECK(output.equal(check)); } } // namespace jit } // namespace torch -#endif // #if defined(USE_CUDA) +// #endif // #if defined(USE_CUDA) \ No newline at end of file
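
Illustration (not part of the patch series above): the UnrollPass introduced in lower_loops.cpp predicates a split-and-unrolled loop nest with two paths, a fast path guarded by a single "unroll predicate" that covers the whole tile, and a cloned fallback loop in which every element carries an inline predicate, as sketched in the comment block at the top of lower_loops.cpp. The standalone C++ sketch below shows only that control-flow shape on a plain array; the names (UNROLL, process) and the element-wise operation are made up for illustration and do not correspond to any API in these patches.

// Sketch of the unroll-with-predicates structure, assuming a split factor of 4.
#include <cstdio>
#include <vector>

constexpr int UNROLL = 4; // unroll factor, analogous to tv->split(0, 4)

void process(std::vector<float>& data) {
  const int n = static_cast<int>(data.size());
  const int outer = (n + UNROLL - 1) / UNROLL; // ceil(n / UNROLL)
  for (int i = 0; i < outer; ++i) {
    if (i * UNROLL + UNROLL - 1 < n) {
      // Unroll predicate holds: the whole tile is in bounds, so the inner
      // loop runs with no per-element check (the "unrolled_loop" path).
      for (int j = 0; j < UNROLL; ++j)
        data[i * UNROLL + j] += 1.0f;
    } else {
      // Fallback path (the cloned "inlined_loop"): every element is guarded
      // by an inline predicate, handling the tail of the tiled iteration.
      for (int j = 0; j < UNROLL; ++j)
        if (i * UNROLL + j < n)
          data[i * UNROLL + j] += 1.0f;
    }
  }
}

int main() {
  std::vector<float> data(7, 3.0f); // 7 is not a multiple of UNROLL
  process(data);
  for (float v : data)
    std::printf("%.1f ", v); // prints seven 4.0 values
  std::printf("\n");
  return 0;
}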