From 77f634f513f44867e84d4948c7e3f7a71f445483 Mon Sep 17 00:00:00 2001 From: Nadav Rotem Date: Thu, 12 Oct 2017 15:22:53 -0700 Subject: [PATCH 1/5] ExecutionEngine: split the interpreter into two parts: the target-specific execution of instructions and management of buffers, and to the execution engine, that handles the high-level operations, such as inference and everything else. --- examples/CMakeLists.txt | 2 + examples/cifar10.cpp | 24 +- examples/mnist.cpp | 26 +- .../glow/ExecutionEngine/ExecutionEngine.h | 88 +++++++ include/glow/Importer/Caffe2.h | 6 +- include/glow/Interpreter/Interpreter.h | 78 +----- src/glow/CMakeLists.txt | 1 + src/glow/ExecutionEngine/CMakeLists.txt | 11 + src/glow/ExecutionEngine/ExecutionEngine.cpp | 187 +++++++++++++++ src/glow/Importer/Caffe2.cpp | 18 +- src/glow/Interpreter/Interpreter.cpp | 221 +---------------- src/glow/Interpreter/InterpreterNodes.cpp | 222 +++++++++--------- tests/unittests/CMakeLists.txt | 2 + tests/unittests/IRGradCheck.cpp | 38 +-- tests/unittests/InterpreterTest.cpp | 130 +++++----- tools/loader/CMakeLists.txt | 1 + tools/loader/loader.cpp | 13 +- 17 files changed, 554 insertions(+), 514 deletions(-) create mode 100644 include/glow/ExecutionEngine/ExecutionEngine.h create mode 100644 src/glow/ExecutionEngine/CMakeLists.txt create mode 100644 src/glow/ExecutionEngine/ExecutionEngine.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 79c6012fd0..405a9db44c 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -5,6 +5,7 @@ target_link_libraries(cifar10 PRIVATE Interpreter Network + ExecutionEngine Graph IR Support) @@ -15,6 +16,7 @@ target_link_libraries(mnist PRIVATE Interpreter Network + ExecutionEngine Graph IR Support) diff --git a/examples/cifar10.cpp b/examples/cifar10.cpp index 4e9e74df3e..5086a3c730 100644 --- a/examples/cifar10.cpp +++ b/examples/cifar10.cpp @@ -1,9 +1,9 @@ +#include "glow/ExecutionEngine/ExecutionEngine.h" #include "glow/Graph/Graph.h" #include "glow/Graph/Nodes.h" #include "glow/IR/IR.h" #include "glow/IR/IRBuilder.h" #include "glow/IR/Instrs.h" -#include "glow/Interpreter/Interpreter.h" #include "glow/Support/Support.h" #include "llvm/Support/Timer.h" @@ -67,14 +67,14 @@ void testCIFAR10() { GLOW_ASSERT(idx == cifarImageSize * cifarNumImages && "Invalid input file"); // Construct the network: - Interpreter IP; - IP.getConfig().learningRate = 0.001; - IP.getConfig().momentum = 0.9; - IP.getConfig().L2Decay = 0.0001; + ExecutionEngine EE; + EE.getConfig().learningRate = 0.001; + EE.getConfig().momentum = 0.9; + EE.getConfig().L2Decay = 0.0001; unsigned minibatchSize = 8; - auto &G = IP.getGraph(); + auto &G = EE.getGraph(); // Create the input layer: auto *A = @@ -100,9 +100,9 @@ void testCIFAR10() { auto *SM = G.createSoftMax("softmax", RL3, E); auto *result = G.createReturn("ret", SM); - IP.getModule().generateIR(); - IP.optimize(OptimizationMode::Train); - IP.initVars(); + EE.getModule().generateIR(); + EE.optimize(OptimizationMode::Train); + EE.initVars(); // Report progress every this number of training iterations. int reportRate = 256; @@ -117,15 +117,15 @@ void testCIFAR10() { // Bind the images tensor to the input array A, and the labels tensor // to the softmax node SM. 
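    // A rough sketch of one call below, in terms of the helpers this patch
    // adds in ExecutionEngine.cpp: each of the `reportRate` iterations runs
    //   updateForwardBackward(weights, inputs, trainCounter + batchSize);
    //   trainCounter += batchSize;
    //   learnGradient(batchSize);
    // where batchSize is minibatchSize, the first dimension of A.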
- IP.train(reportRate, {A, E}, {&images, &labels}); + EE.train(reportRate, {A, E}, {&images, &labels}); unsigned score = 0; for (unsigned int i = 0; i < 100 / minibatchSize; i++) { Tensor sample(ElemKind::FloatTy, {minibatchSize, 3, 32, 32}); sample.copyConsecutiveSlices(&images, minibatchSize * i); - IP.infer({A}, {&sample}); - auto *res = IP.getTensorForNode(result); + EE.infer({A}, {&sample}); + auto *res = EE.getTensor(result); for (unsigned int iter = 0; iter < minibatchSize; iter++) { auto T = res->getHandle().extractSlice(iter); diff --git a/examples/mnist.cpp b/examples/mnist.cpp index 259c42d2db..10b0416e60 100644 --- a/examples/mnist.cpp +++ b/examples/mnist.cpp @@ -1,3 +1,4 @@ +#include "glow/ExecutionEngine/ExecutionEngine.h" #include "glow/Graph/Graph.h" #include "glow/Graph/Node.h" #include "glow/Graph/Nodes.h" @@ -69,13 +70,14 @@ void testMNIST() { unsigned minibatchSize = 8; + ExecutionEngine EE; + // Construct the network: - Interpreter IP; - IP.getConfig().learningRate = 0.001; - IP.getConfig().momentum = 0.9; - IP.getConfig().L2Decay = 0.001; + EE.getConfig().learningRate = 0.001; + EE.getConfig().momentum = 0.9; + EE.getConfig().L2Decay = 0.001; - auto &G = IP.getGraph(); + auto &G = EE.getGraph(); Variable *A = G.createVariable(ElemKind::FloatTy, {minibatchSize, 28, 28, 1}, "input", Variable::InitKind::Extern); @@ -97,9 +99,9 @@ void testMNIST() { auto *result = G.createReturn("return", SM); - IP.getModule().generateIR(); - IP.optimize(OptimizationMode::Train); - IP.initVars(); + EE.getModule().generateIR(); + EE.optimize(OptimizationMode::Train); + EE.initVars(); // Report progress every this number of training iterations. constexpr int reportRate = 30; @@ -115,12 +117,12 @@ void testMNIST() { // On each training iteration take an input from imageInputs and update // the input variable A, and add take a corresponding label and update the // softmax layer. - IP.train(reportRate, {A, selected}, {&imageInputs, &labelInputs}); + EE.train(reportRate, {A, selected}, {&imageInputs, &labelInputs}); timer.stopTimer(); } std::cout << "Validating.\n"; - IP.optimize(OptimizationMode::Infer); + EE.optimize(OptimizationMode::Infer); auto LIH = labelInputs.getHandle(); @@ -129,8 +131,8 @@ void testMNIST() { Tensor sample(ElemKind::FloatTy, {minibatchSize, 1, 28, 28}); sample.copyConsecutiveSlices(&imageInputs, 0); - IP.infer({A}, {&sample}); - Tensor *res = IP.getTensorForNode(result); + EE.infer({A}, {&sample}); + Tensor *res = EE.getTensor(result); for (unsigned int iter = 0; iter < minibatchSize; iter++) { auto T = res->getHandle().extractSlice(iter); diff --git a/include/glow/ExecutionEngine/ExecutionEngine.h b/include/glow/ExecutionEngine/ExecutionEngine.h new file mode 100644 index 0000000000..646997c7eb --- /dev/null +++ b/include/glow/ExecutionEngine/ExecutionEngine.h @@ -0,0 +1,88 @@ +#ifndef GLOW_EXECUTIONENGINE_EXECUTIONENGINE_H +#define GLOW_EXECUTIONENGINE_EXECUTIONENGINE_H + +#include "glow/Base/Tensor.h" +#include "glow/Base/Train.h" +#include "glow/Graph/Graph.h" +#include "glow/IR/IR.h" +#include "glow/IR/IRBuilder.h" +#include "glow/Interpreter/Interpreter.h" +#include "glow/Optimizer/Optimizer.h" + +#include "llvm/ADT/ArrayRef.h" + +#include + +namespace glow { + +/// This is the ExecutionEngine. It owns the Graph, the IR, and the backends. +class ExecutionEngine final { + /// The Graph that represents the high-level program. + Graph G_{}; + /// The Module that holds the IR. + Module M_; + /// The network interpreter + Interpreter IP_; + /// The network trainer. 
+  Trainer trainer_{};
+
+public:
+  ExecutionEngine() : M_(G_), IP_(M_) {}
+
+  /// \returns the internal module.
+  Module &getModule() { return M_; }
+
+  /// \returns the internal graph.
+  Graph &getGraph() { return G_; }
+
+  /// Run the target-independent optimizations on the module.
+  void optimize(OptimizationMode mode);
+
+  /// Provides access to the training configuration.
+  TrainingConfig &getConfig() { return trainer_.config; }
+
+  /// Initialize all of the variables in the program.
+  void initVars();
+
+  /// Runs the program in a forward pass. Update the nodes in \p nodes with the
+  /// values \p inputs.
+  void infer(llvm::ArrayRef vars, llvm::ArrayRef inputs);
+
+  /// Train the network. Perform \p iterations in the training loop. Each
+  /// iteration does a full forward and backward pass of a whole batch.
+  /// The method updates the variables in \p vars with the tensors \p inputs.
+  void train(size_t iterations, llvm::ArrayRef vars,
+             llvm::ArrayRef inputs);
+
+  /// \returns a pointer to the tensor that is saved under \p v. The tensor
+  /// is owned by the Interpreter.
+  Tensor *getTensor(const Node *v) const;
+
+  /// \returns a float-handle to the tensor that is stored at \p v.
+  Handle getWeightHandle(Variable *v) const;
+
+  /// \returns a float-handle to the tensor that is stored at \p v.
+  Handle getGradHandle(Variable *v);
+
+  /// Copies the content of the tensor \p t into the value \p v.
+  void initValue(const Variable *v, const Tensor *t);
+
+private:
+  /// Update the inputs for all variables \p vars with data from the inputs \p
+  /// inputs at offset \p sampleIdx. Then perform a forward and backward scan.
+  void updateForwardBackward(llvm::ArrayRef vars,
+                             llvm::ArrayRef inputs, size_t sampleIdx);
+
+  /// Update all of the weight tensors (non-activation) with their gradients.
+  void learnGradient(size_t batchSize);
+
+  /// Update the content of the tensor \p v with data that comes from \p input.
+  /// The data starts at slice \p sampleIdx and wraps around until the data in
+  /// \p v is filled. All dimensions, except for the first (batch) dimension
+  /// must be identical.
+  void loadValueFromTensor(const Value *v, Tensor *input, size_t sampleIdx);
+};
+
+} // namespace glow
+
+#endif // GLOW_EXECUTIONENGINE_EXECUTIONENGINE_H
diff --git a/include/glow/Importer/Caffe2.h b/include/glow/Importer/Caffe2.h
index 8a1f3391ee..ad1c3fc95c 100644
--- a/include/glow/Importer/Caffe2.h
+++ b/include/glow/Importer/Caffe2.h
@@ -18,14 +18,14 @@ namespace glow {
 class IRBuilder;
 class Instruction;
-class Interpreter;
+class ExecutionEngine;
 class Tensor;
 class Value;
 
 /// Loads caffe2 models.
 class caffe2ModelLoader {
   /// The interpreter that runs the program.
-  Interpreter &IP_;
+  ExecutionEngine &EE_;
 
   /// Saves network nodes by name.
   std::unordered_map nodeByName_;
@@ -75,7 +75,7 @@ class caffe2ModelLoader {
   caffe2ModelLoader(const std::string &netDescFilename,
                     const std::string &netWeightFilename,
                     llvm::ArrayRef names,
-                    llvm::ArrayRef tensors, Interpreter &IP);
+                    llvm::ArrayRef tensors, ExecutionEngine &EE);
 
 /// \returns the output of the network. This is usually the result of the last
 /// softmax or regression layer.
diff --git a/include/glow/Interpreter/Interpreter.h b/include/glow/Interpreter/Interpreter.h
index 4ac67d2c4d..3dfc1beaa9 100644
--- a/include/glow/Interpreter/Interpreter.h
+++ b/include/glow/Interpreter/Interpreter.h
@@ -19,10 +19,8 @@ class Context;
 
 /// This is the IR-interpreter. It owns the IR, and the heap, and is able to
 /// execute the instructions one at a time.
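 //
 // After this split, a typical caller goes through the ExecutionEngine, which
 // owns the Graph, the Module and the Trainer and delegates instruction
 // execution to this class. A minimal sketch, using only names from this
 // patch:
 //
 //   ExecutionEngine EE;
 //   auto &G = EE.getGraph();        // build the network here
 //   EE.getModule().generateIR();
 //   EE.optimize(OptimizationMode::Infer);
 //   EE.initVars();
 //   EE.infer({input}, {&inputs});   // ends in Interpreter::doForwardPass
 //   Tensor *res = EE.getTensor(result);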
class Interpreter final { - /// The Graph that represents the high-level program. - Graph G_; /// The Module that holds the IR. - Module M_; + Module &M_; /// Maps values to Tensors, that are owned by this class. std::unordered_map tensors_; @@ -30,91 +28,40 @@ class Interpreter final { /// are owned by this map. std::unordered_map gradients_; - /// The network trainer. - Trainer trainer_; - public: /// \returns the internal module. Module &getModule() { return M_; } - /// \returns the internal module. - Graph &getGraph() { return G_; } - /// Run the target-independent optimizations on the module. - void optimize(OptimizationMode mode); /// Ctor. - Interpreter(); + Interpreter(Module &M) : M_(M) {} /// Dtor. ~Interpreter(); - /// Provides access to the training configuration. - TrainingConfig &getConfig() { return trainer_.config; } - - /// Registers the tensor \p t under the IR value \p v. - void registerTensor(Value *v, Tensor *t); - /// \returns a pointer to the tensor that is saved under \p v. The tensor /// is owned by the Interpreter. - Tensor *getTensorForValue(const Value *v) const; + Tensor *getTensor(const Value *v) const; - /// \returns a pointer to the tensor that is saved under \p v. The tensor - /// is owned by the Interpreter. - Tensor *getTensorForNode(const Node *v) const; + /// Allocate a tensor to back the value \p v. Do not allocate anything if a + /// tensor is already allocated for \p v. + /// \returns a tensor for \p v. + Tensor *getOrCreateTensor(const Value *v); - /// Remove the tensor (and the gradient) that's stored for \p v; - void deleteTensor(const Value *v); + /// \returns True if a tensor was allocated for \p v. + bool hasTensor(const Value *v); /// Copies the content of the tensor \p t into the value \p v. - void initValue(const WeightVar *v, const Tensor *t); + void initValue(const Value *v, const Tensor *t); /// \returns gets or creates a new tensor to back the value \p v. If the /// tensor does not exist then this method creates it. The dimension of the /// gradient tensor must match the dimensions of the tensor that backs \p v. Tensor *getOrCreateGradTensor(const Value *v); - /// Update the content of the tensor \p v with data that comes from \p input. - /// The data starts at slice \p sampleIdx and wraps around until the data in - /// \p v is filled. All dimensions, except for the first (batch) dimension - /// must be identical. - void loadValueFromTensor(const Value *v, Tensor *input, size_t sampleIdx); - /// \returns a float-handle to the tensor that is stored at \p v. - Handle getWeightHandle(Context *, Value *v) const; + Handle getWeightHandle(Value *v) const; /// \returns a float-handle to the tensor that is stored at \p v. - Handle getWeightHandle(Context *, Variable *v) const; - - /// \returns a float-handle to the tensor that is stored at \p v. - Handle getGradHandle(Context *, Value *v); - - /// \returns a float-handle to the tensor that is stored at \p v. - Handle getGradHandle(Context *, Variable *v); - - /// Initialize all of the variables in the program. - void initVars(); - - /// Runs the program in a forward pass. Update the nodes in \p nodes with the - /// values \p inputs. - void infer(llvm::ArrayRef vars, llvm::ArrayRef inputs); - - /// Train the network. Perform \p iterations in the training loop. Each - /// iteration does a full forward and backward pass of a whole batch. - /// The method updates the variables in \p vars with the tensors \p inputs. 
- void train(size_t iterations, llvm::ArrayRef vars, - llvm::ArrayRef inputs); - -private: - /// Allocate a tensor to back the value \p v. Do not allocate anything if a - /// tensor is already allocated for \p v. - /// \returns v's Tensor. - Tensor *allocateBackingTensor(const Value *v); - - /// Update all of the weight tensors (non-activation) with their gradients. - void learnGradient(size_t batchSize); - - /// Update the inputs for all variables \p vars with data from the inputs \p - /// inputs at offset \p sampleIdx. Then perform a forward and backwards scan. - void updateForwardBackward(llvm::ArrayRef vars, - llvm::ArrayRef inputs, size_t sampleIdx); + Handle getGradHandle(Value *v); /// Perform a single forward scan of the network, interpreting all of the /// instructions. @@ -124,6 +71,7 @@ class Interpreter final { /// instructions. void doBackwardPass(); +private: /// @name Interpreter methods. This is a list of method declerations that are /// used by the interpreter to dispatch different instructions. ///@{ diff --git a/src/glow/CMakeLists.txt b/src/glow/CMakeLists.txt index 0f8ef6f3ee..76c7b516e3 100644 --- a/src/glow/CMakeLists.txt +++ b/src/glow/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Base) +add_subdirectory(ExecutionEngine) add_subdirectory(Graph) add_subdirectory(IR) add_subdirectory(Importer) diff --git a/src/glow/ExecutionEngine/CMakeLists.txt b/src/glow/ExecutionEngine/CMakeLists.txt new file mode 100644 index 0000000000..116abf0f92 --- /dev/null +++ b/src/glow/ExecutionEngine/CMakeLists.txt @@ -0,0 +1,11 @@ + +add_library(ExecutionEngine + ExecutionEngine.cpp) + +target_link_libraries(ExecutionEngine + PRIVATE + Interpreter + Optimizer + Base + Graph) + diff --git a/src/glow/ExecutionEngine/ExecutionEngine.cpp b/src/glow/ExecutionEngine/ExecutionEngine.cpp new file mode 100644 index 0000000000..27c02d4d48 --- /dev/null +++ b/src/glow/ExecutionEngine/ExecutionEngine.cpp @@ -0,0 +1,187 @@ +// Copyright 2017 Facebook Inc. All Rights Reserved. + +#include "glow/ExecutionEngine/ExecutionEngine.h" + +using namespace glow; + +void ExecutionEngine::infer(llvm::ArrayRef vars, + llvm::ArrayRef inputs) { + assert(!inputs.empty() && "No inputs"); + assert(inputs.size() == vars.size() && + "The number of inputs does not match the number of variables"); + + // Update the input variables. + for (int i = 0, e = vars.size(); i < e; i++) { + auto *val = M_.getWeightForNode(vars[i]); + loadValueFromTensor(val, inputs[i], 0); + } + + IP_.doForwardPass(false); +} + +void ExecutionEngine::train(size_t iterations, llvm::ArrayRef vars, + llvm::ArrayRef inputs) { + static size_t trainCounter = 0; + + assert(!inputs.empty() && "No inputs"); + assert(inputs.size() == vars.size() && + "The number of inputs does not match the number of variables"); + + std::vector weights; + for (auto *v : vars) { + weights.push_back(M_.getWeightForNode(v)); + } + + // This is the size of one batch (the number of samples in the batch). 
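+  // E.g. for an input variable of shape {8, 28, 28, 1}, as in the MNIST
+  // example earlier in this patch, dims()[0] is the minibatch size, 8.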
+  size_t batchSize = vars[0]->dims()[0];
+
+  for (size_t i = 0; i < iterations; i++) {
+    // Launch threads that update the different chunks in the batch:
+    updateForwardBackward(weights, inputs, trainCounter + batchSize);
+
+    trainCounter += batchSize;
+
+    // The algorithm for merging the state from the different threads is
+    // described in the paper: Alex Krizhevsky [2014]
+    // "One weird trick for parallelizing convolutional neural networks"
+    learnGradient(batchSize);
+  }
+}
+
+void ExecutionEngine::learnGradient(size_t batchSize) {
+  for (auto *V : M_.getWeights()) {
+    // Do not try to learn the values of input/output buffers.
+    if (V->getInitKind() == WeightVar::InitKind::Extern) {
+      continue;
+    }
+
+    auto W = IP_.getTensor(V);
+    auto G = IP_.getOrCreateGradTensor(V);
+
+    // Handle weight update by learning the gradients into the weights.
+    trainer_.train(W, G, batchSize);
+  }
+}
+
+void ExecutionEngine::updateForwardBackward(llvm::ArrayRef vars,
+                                            llvm::ArrayRef inputs,
+                                            size_t sampleIdx) {
+  // Update the input variables.
+  for (int i = 0, e = vars.size(); i < e; i++) {
+    loadValueFromTensor(vars[i], inputs[i], sampleIdx);
+  }
+
+  IP_.doForwardPass(true);
+
+  IP_.doBackwardPass();
+}
+
+void ExecutionEngine::loadValueFromTensor(const Value *v, Tensor *input,
+                                          size_t sampleIdx) {
+  assert(v && "Invalid value");
+  auto *t = IP_.getTensor(v);
+
+  auto dim = input->dims();
+  assert(t->dims().drop_front() == dim.drop_front() && "Invalid slice size");
+  // Extract the n'th slice, that must be a tensor.
+  size_t slc = sampleIdx % dim[0];
+  t->copyConsecutiveSlices(input, slc);
+}
+
+void ExecutionEngine::optimize(OptimizationMode mode) {
+  ::glow::optimize(M_, mode);
+}
+
+Tensor *ExecutionEngine::getTensor(const Node *v) const {
+  auto val = M_.getWeightForNode(v);
+  assert(val && "Node does not have a registered IR value");
+  return IP_.getTensor(val);
+}
+
+/// \returns a float-handle to the tensor that is stored at \p v.
+Handle ExecutionEngine::getWeightHandle(Variable *v) const {
+  auto val = M_.getWeightForNode(v);
+  return IP_.getWeightHandle(val);
+}
+
+/// \returns a float-handle to the tensor that is stored at \p v.
+Handle ExecutionEngine::getGradHandle(Variable *v) {
+  auto val = M_.getWeightForNode(v);
+  return IP_.getGradHandle(val);
+}
+
+/// Copies the content of the tensor \p t into the value \p v.
+void ExecutionEngine::initValue(const Variable *v, const Tensor *t) {
+  auto *N = M_.getWeightForNode(v);
+  return IP_.initValue(N, t);
+}
+
+void ExecutionEngine::initVars() {
+  for (auto *W : M_.getWeights()) {
+    // Don't initialize tensors that are already initialized.
+    if (IP_.hasTensor(W)) {
+      continue;
+    }
+
+    auto *T = IP_.getOrCreateTensor(W);
+    // The parameter to the instruction.
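+    // `val` parameterizes the InitKind handled below: Broadcast fills the
+    // tensor with the constant `val`, Xavier randomizes it using `val`, and
+    // Extern leaves the buffer for the caller to fill (e.g. via initValue).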
+ auto val = W->getVal(); + + switch (W->getInitKind()) { + case WeightVar::InitKind::Extern: + break; + + case WeightVar::InitKind::Broadcast: { + switch (T->getElementType()) { + case ElemKind::FloatTy: { + T->getHandle().clear(val); + break; + } + case ElemKind::DoubleTy: { + T->getHandle().clear(val); + break; + } + case ElemKind::Int8Ty: { + T->getHandle().clear(val); + break; + }; + case ElemKind::Int32Ty: { + T->getHandle().clear(val); + break; + } + case ElemKind::IndexTy: { + T->getHandle().clear(val); + break; + } + } + break; + } + + case WeightVar::InitKind::Xavier: { + switch (T->getElementType()) { + case ElemKind::FloatTy: { + T->getHandle().randomize(val); + break; + } + case ElemKind::DoubleTy: { + T->getHandle().randomize(val); + break; + } + case ElemKind::Int8Ty: { + T->getHandle().randomize(val); + break; + }; + case ElemKind::Int32Ty: { + T->getHandle().randomize(val); + break; + } + case ElemKind::IndexTy: { + T->getHandle().randomize(val); + break; + } + } + break; + } + } + } +} diff --git a/src/glow/Importer/Caffe2.cpp b/src/glow/Importer/Caffe2.cpp index 61c79b9c2b..db3fef63a8 100644 --- a/src/glow/Importer/Caffe2.cpp +++ b/src/glow/Importer/Caffe2.cpp @@ -2,12 +2,12 @@ #include "glow/Importer/Caffe2.h" #include "glow/Base/Tensor.h" +#include "glow/ExecutionEngine/ExecutionEngine.h" #include "glow/Graph/Graph.h" #include "glow/Graph/Nodes.h" #include "glow/IR/IR.h" #include "glow/IR/IRBuilder.h" #include "glow/IR/Instrs.h" -#include "glow/Interpreter/Interpreter.h" #include "glow/Support/Casting.h" #include "caffe.pb.h" @@ -110,7 +110,7 @@ Node *caffe2ModelLoader::getNodeByName(const std::string &name) { } Node *caffe2ModelLoader::getOrCreateNodeByName(const std::string &name) { - auto &G = IP_.getGraph(); + auto &G = EE_.getGraph(); auto it = nodeByName_.find(name); if (it != nodeByName_.end()) { return it->second; @@ -124,7 +124,7 @@ Node *caffe2ModelLoader::getOrCreateNodeByName(const std::string &name) { } void caffe2ModelLoader::loadOperator(const caffe2::OperatorDef &op) { - auto &G = IP_.getGraph(); + auto &G = EE_.getGraph(); ArgumentDictionaryTy dict = loadArgumenrMap(op); @@ -394,8 +394,8 @@ caffe2ModelLoader::caffe2ModelLoader(const std::string &netDescFilename, const std::string &netWeightFilename, llvm::ArrayRef names, llvm::ArrayRef tensors, - Interpreter &IP) - : IP_(IP) { + ExecutionEngine &EE) + : EE_(EE) { // Verify that the version of the library that we linked against is // compatible with the version of the headers we compiled against. GOOGLE_PROTOBUF_VERIFY_VERSION; @@ -417,17 +417,17 @@ caffe2ModelLoader::caffe2ModelLoader(const std::string &netDescFilename, loadNetwork(networkDef); // Save the result of the last operator into a weight. - auto &G = IP_.getGraph(); - auto &M = IP_.getModule(); + auto &G = EE_.getGraph(); + auto &M = EE_.getModule(); root_ = G.createReturn("ret", root_); // Emit IR for the graph. - IP_.getModule().generateIR(); + EE_.getModule().generateIR(); // Load the value of the variables. 
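   // variableInit_ pairs each graph node with the tensor that was read from
   // the protobuf; marking the weight Extern keeps initVars() from
   // re-initializing the loaded data later.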
for (auto p : variableInit_) { WeightVar *N = cast(M.getWeightForNode(p.first)); N->setInitKind(WeightVar::InitKind::Extern); - IP.initValue(N, p.second); + EE.initValue(p.first, p.second); } } diff --git a/src/glow/Interpreter/Interpreter.cpp b/src/glow/Interpreter/Interpreter.cpp index 99aacb799a..6395cc5fb6 100644 --- a/src/glow/Interpreter/Interpreter.cpp +++ b/src/glow/Interpreter/Interpreter.cpp @@ -9,8 +9,6 @@ using namespace glow; -Interpreter::Interpreter() : M_(G_) {} - Interpreter::~Interpreter() { // Delete the tensors that are owned by this module. for (auto p : tensors_) { @@ -23,50 +21,13 @@ Interpreter::~Interpreter() { } } -void Interpreter::optimize(OptimizationMode mode) { - ::glow::optimize(M_, mode); -} - -void Interpreter::registerTensor(Value *v, Tensor *t) { - assert(t->getType().isEqual(v->getType()) && - "Tensor must match variable dimensions"); - - auto it = tensors_.find(v); - if (it != tensors_.end()) { - delete it->second; - it->second = t; - return; - } - tensors_[v] = t; -} - -Tensor *Interpreter::getTensorForValue(const Value *v) const { +Tensor *Interpreter::getTensor(const Value *v) const { auto it = tensors_.find(v); assert(it != tensors_.end() && "Unknown key Value."); return it->second; } -Tensor *Interpreter::getTensorForNode(const Node *v) const { - auto val = M_.getWeightForNode(v); - assert(val && "Node does not have a registered IR value"); - return getTensorForValue(val); -} - -void Interpreter::deleteTensor(const Value *v) { - auto it = tensors_.find(v); - assert(it != tensors_.end() && "Unknown key Value."); - auto *T = it->second; - delete T; - tensors_.erase(it); - - auto git = gradients_.find(T); - if (git != gradients_.end()) { - delete git->second; - gradients_.erase(git); - } -} - -void Interpreter::initValue(const WeightVar *v, const Tensor *t) { +void Interpreter::initValue(const Value *v, const Tensor *t) { auto it = tensors_.find(v); if (it != tensors_.end()) { it->second->copyFrom(t); @@ -79,7 +40,7 @@ void Interpreter::initValue(const WeightVar *v, const Tensor *t) { } Tensor *Interpreter::getOrCreateGradTensor(const Value *v) { - auto *T = getTensorForValue(v); + auto *T = getTensor(v); auto it = gradients_.find(T); if (it != gradients_.end()) { return it->second; @@ -91,190 +52,26 @@ Tensor *Interpreter::getOrCreateGradTensor(const Value *v) { return N; } -void Interpreter::loadValueFromTensor(const Value *v, Tensor *input, - size_t sampleIdx) { - assert(v && "Invalid value"); - auto *t = getTensorForValue(v); - - auto dim = input->dims(); - assert(t->dims().drop_front() == dim.drop_front() && "Invalid slice size"); - // Extract the n'th slice, that must be a tensor. - size_t slc = sampleIdx % dim[0]; - t->copyConsecutiveSlices(input, slc); -} - -Handle Interpreter::getWeightHandle(Context *ctx, Value *v) const { - return getTensorForValue(v)->getHandle(); -} - -Handle Interpreter::getWeightHandle(Context *ctx, Variable *v) const { - auto *N = M_.getWeightForNode(v); - return getTensorForValue(N)->getHandle(); +Handle Interpreter::getWeightHandle(Value *v) const { + return getTensor(v)->getHandle(); } -Handle Interpreter::getGradHandle(Context *ctx, Value *v) { +Handle Interpreter::getGradHandle(Value *v) { return getOrCreateGradTensor(v)->getHandle(); } -Handle Interpreter::getGradHandle(Context *ctx, Variable *v) { - auto *N = M_.getWeightForNode(v); - return getOrCreateGradTensor(N)->getHandle(); -} - -Tensor *Interpreter::allocateBackingTensor(const Value *v) { - // Allocate a tensor for the variable. 
- Tensor *T = nullptr; +Tensor *Interpreter::getOrCreateTensor(const Value *v) { // Pick the tensor. auto it = tensors_.find(v); if (it == tensors_.end()) { - T = new Tensor(v->getType()); + Tensor *T = new Tensor(v->getType()); tensors_[v] = T; return T; } return it->second; } -void Interpreter::initVars() { - for (auto *W : M_.getWeights()) { - // Don't initialize tensors that are already initialized. - if (tensors_.count(W)) { - continue; - } - - auto *T = allocateBackingTensor(W); - // The parameter to the instruction. - auto val = W->getVal(); - - switch (W->getInitKind()) { - case WeightVar::InitKind::Extern: - break; - - case WeightVar::InitKind::Broadcast: { - switch (T->getElementType()) { - case ElemKind::FloatTy: { - T->getHandle().clear(val); - break; - } - case ElemKind::DoubleTy: { - T->getHandle().clear(val); - break; - } - case ElemKind::Int8Ty: { - T->getHandle().clear(val); - break; - }; - case ElemKind::Int32Ty: { - T->getHandle().clear(val); - break; - } - case ElemKind::IndexTy: { - T->getHandle().clear(val); - break; - } - } - break; - } - - case WeightVar::InitKind::Xavier: { - switch (T->getElementType()) { - case ElemKind::FloatTy: { - T->getHandle().randomize(val); - break; - } - case ElemKind::DoubleTy: { - T->getHandle().randomize(val); - break; - } - case ElemKind::Int8Ty: { - T->getHandle().randomize(val); - break; - }; - case ElemKind::Int32Ty: { - T->getHandle().randomize(val); - break; - } - case ElemKind::IndexTy: { - T->getHandle().randomize(val); - break; - } - } - break; - } - } - } -} - -void Interpreter::infer(llvm::ArrayRef vars, - llvm::ArrayRef inputs) { - assert(!inputs.empty() && "No inputs"); - assert(inputs.size() == vars.size() && - "The number of inputs does not match the number of variables"); - - // Update the input variables. - for (int i = 0, e = vars.size(); i < e; i++) { - auto *val = M_.getWeightForNode(vars[i]); - loadValueFromTensor(val, inputs[i], 0); - } - - doForwardPass(false); -} - -void Interpreter::train(size_t iterations, llvm::ArrayRef vars, - llvm::ArrayRef inputs) { - static size_t trainCounter = 0; - - assert(!inputs.empty() && "No inputs"); - assert(inputs.size() == vars.size() && - "The number of inputs does not match the number of variables"); - - std::vector weights; - for (auto *v : vars) { - weights.push_back(M_.getWeightForNode(v)); - } - - // This is the size of one batch (the number of samples in the batch). - size_t batchSize = vars[0]->dims()[0]; - - for (size_t i = 0; i < iterations; i++) { - // Launch threads that update the different chunks in the batch: - updateForwardBackward(weights, inputs, trainCounter + batchSize); - - trainCounter += batchSize; - - // The algorithm for merging the state from the different threads is - /// described in the paper: Alex Krizhevsky [2014] - // "One weird trick for parallelizing convolutional neural networks" - learnGradient(batchSize); - } -} - -void Interpreter::learnGradient(size_t batchSize) { - for (auto *V : M_.getWeights()) { - // Do not try to learn the values of input/output buffers. - if (V->getInitKind() == WeightVar::InitKind::Extern) { - continue; - } - - auto W = getTensorForValue(V); - auto G = getOrCreateGradTensor(V); - - // Handle weight update by learning the gradients into the weights. - trainer_.train(W, G, batchSize); - } -} - -void Interpreter::updateForwardBackward(llvm::ArrayRef vars, - llvm::ArrayRef inputs, - size_t sampleIdx) { - // Update the input variables. 
- for (int i = 0, e = vars.size(); i < e; i++) { - loadValueFromTensor(vars[i], inputs[i], sampleIdx); - } - - doForwardPass(true); - - doBackwardPass(); -} +bool Interpreter::hasTensor(const Value *v) { return tensors_.count(v); } void Interpreter::doForwardPass(bool isTrain) { // Do the forward pass. diff --git a/src/glow/Interpreter/InterpreterNodes.cpp b/src/glow/Interpreter/InterpreterNodes.cpp index dd4e8c292d..fa5270a13f 100644 --- a/src/glow/Interpreter/InterpreterNodes.cpp +++ b/src/glow/Interpreter/InterpreterNodes.cpp @@ -11,8 +11,8 @@ using namespace glow; //===----------------------------------------------------------------------===// void Interpreter::fwdCopyInst(Context *ctx, bool isTrain, const CopyInst *I) { - auto S = getWeightHandle(ctx, I->getSrc()); - auto D = getWeightHandle(ctx, I->getDest()); + auto S = getWeightHandle(I->getSrc()); + auto D = getWeightHandle(I->getDest()); for (size_t i = 0, e = S.size(); i < e; i++) { D.raw(i) = S.raw(i); @@ -20,8 +20,8 @@ void Interpreter::fwdCopyInst(Context *ctx, bool isTrain, const CopyInst *I) { } void Interpreter::bwdCopyInst(Context *ctx, const CopyInst *I) { - auto inG = getGradHandle(ctx, I->getSrc()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inG = getGradHandle(I->getSrc()); + auto outG = getGradHandle(I->getDest()); for (size_t i = 0, e = outG.size(); i < e; i++) { inG.raw(i) += outG.raw(i); @@ -30,10 +30,10 @@ void Interpreter::bwdCopyInst(Context *ctx, const CopyInst *I) { void Interpreter::fwdConvolutionInst(Context *ctx, bool isTrain, const ConvolutionInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto filterW = getWeightHandle(ctx, I->getFilter()); - auto biasW = getWeightHandle(ctx, I->getBias()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto filterW = getWeightHandle(I->getFilter()); + auto biasW = getWeightHandle(I->getBias()); size_t filterSize = I->getKernel(); size_t pad = I->getPad(); @@ -83,14 +83,14 @@ void Interpreter::fwdConvolutionInst(Context *ctx, bool isTrain, } void Interpreter::bwdConvolutionInst(Context *ctx, const ConvolutionInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto inG = getGradHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto inG = getGradHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto outG = getGradHandle(I->getDest()); - auto filterW = getWeightHandle(ctx, I->getFilter()); - auto filterG = getGradHandle(ctx, I->getFilter()); - auto biasG = getGradHandle(ctx, I->getBias()); + auto filterW = getWeightHandle(I->getFilter()); + auto filterG = getGradHandle(I->getFilter()); + auto biasG = getGradHandle(I->getBias()); size_t filterSize = I->getKernel(); size_t pad = I->getPad(); @@ -153,8 +153,8 @@ void Interpreter::fwdPoolInst(Context *ctx, bool isTrain, const PoolInst *I) { } void Interpreter::fwdPoolMax_impl(Context *ctx, const PoolInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); ShapeNHWC odim(outW.dims()); ShapeNHWC idim(inW.dims()); @@ -163,7 +163,7 @@ void Interpreter::fwdPoolMax_impl(Context *ctx, const PoolInst *I) { auto filterSize = I->getKernel(); auto stride = I->getStride(); - auto SXY = getTensorForValue(I->srcXY())->getHandle(); + auto SXY = 
getTensor(I->srcXY())->getHandle(); // For each input in the batch: for (size_t n = 0; n < odim.n; n++) { @@ -214,8 +214,8 @@ void Interpreter::fwdPoolMax_impl(Context *ctx, const PoolInst *I) { } void Interpreter::fwdPoolAvg_impl(Context *ctx, const PoolInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); ShapeNHWC odim(outW.dims()); ShapeNHWC idim(inW.dims()); @@ -271,13 +271,13 @@ void Interpreter::bwdPoolInst(Context *ctx, const PoolInst *I) { } void Interpreter::bwdPoolMax_impl(Context *ctx, const PoolInst *I) { - auto inG = getGradHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inG = getGradHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto outG = getGradHandle(I->getDest()); ShapeNHWC odim(outW.dims()); - auto SXY = getTensorForValue(I->srcXY())->getHandle(); + auto SXY = getTensor(I->srcXY())->getHandle(); // For each input in the batch: for (size_t n = 0; n < odim.n; n++) { @@ -302,10 +302,10 @@ void Interpreter::bwdPoolMax_impl(Context *ctx, const PoolInst *I) { } void Interpreter::bwdPoolAvg_impl(Context *ctx, const PoolInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto inG = getGradHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto inG = getGradHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto outG = getGradHandle(I->getDest()); ShapeNHWC odim(outW.dims()); ShapeNHWC idim(inW.dims()); @@ -353,14 +353,14 @@ void Interpreter::bwdPoolAvg_impl(Context *ctx, const PoolInst *I) { void Interpreter::fwdFullyConnectedInst(Context *ctx, const bool isTrain, const FullyConnectedInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); auto odim = flattenCdr(outW.dims()); auto idim = flattenCdr(inW.dims()); - auto filterW = getWeightHandle(ctx, I->getFilter()); - auto biasW = getWeightHandle(ctx, I->getBias()); + auto filterW = getWeightHandle(I->getFilter()); + auto biasW = getWeightHandle(I->getBias()); size_t inputSize = idim.second; @@ -382,17 +382,17 @@ void Interpreter::fwdFullyConnectedInst(Context *ctx, const bool isTrain, void Interpreter::bwdFullyConnectedInst(Context *ctx, const FullyConnectedInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto inG = getGradHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto inG = getGradHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto outG = getGradHandle(I->getDest()); auto odim = flattenCdr(outW.dims()); auto idim = flattenCdr(inW.dims()); - auto filterW = getWeightHandle(ctx, I->getFilter()); - auto filterG = getGradHandle(ctx, I->getFilter()); - auto biasG = getGradHandle(ctx, I->getBias()); + auto filterW = getWeightHandle(I->getFilter()); + auto filterG = getGradHandle(I->getFilter()); + auto biasG = getGradHandle(I->getBias()); size_t inSize = idim.second; @@ -420,8 +420,8 @@ void Interpreter::bwdFullyConnectedInst(Context *ctx, //===----------------------------------------------------------------------===// void 
Interpreter::fwdReluInst(Context *ctx, bool isTrain, const ReluInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); for (size_t i = 0, e = inW.size(); i < e; i++) { FloatTy val = inW.raw(i); @@ -430,9 +430,9 @@ void Interpreter::fwdReluInst(Context *ctx, bool isTrain, const ReluInst *I) { } void Interpreter::bwdReluInst(Context *ctx, const ReluInst *I) { - auto inG = getGradHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inG = getGradHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto outG = getGradHandle(I->getDest()); for (size_t i = 0, e = outW.size(); i < e; i++) { FloatTy val = outW.raw(i); @@ -442,8 +442,8 @@ void Interpreter::bwdReluInst(Context *ctx, const ReluInst *I) { void Interpreter::fwdSigmoidInst(Context *ctx, bool isTrain, const SigmoidInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); for (size_t i = 0, e = outW.size(); i < e; i++) { FloatTy val = inW.raw(i); @@ -451,9 +451,9 @@ void Interpreter::fwdSigmoidInst(Context *ctx, bool isTrain, } } void Interpreter::bwdSigmoidInst(Context *ctx, const SigmoidInst *I) { - auto inG = getGradHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inG = getGradHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto outG = getGradHandle(I->getDest()); for (size_t i = 0, e = outW.size(); i < e; i++) { FloatTy val = outW.raw(i); @@ -462,8 +462,8 @@ void Interpreter::bwdSigmoidInst(Context *ctx, const SigmoidInst *I) { } void Interpreter::fwdTanhInst(Context *ctx, bool isTrain, const TanhInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); for (size_t i = 0, e = inW.size(); i < e; i++) { FloatTy val = inW.raw(i); @@ -473,9 +473,9 @@ void Interpreter::fwdTanhInst(Context *ctx, bool isTrain, const TanhInst *I) { } } void Interpreter::bwdTanhInst(Context *ctx, const TanhInst *I) { - auto inG = getGradHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inG = getGradHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto outG = getGradHandle(I->getDest()); for (size_t i = 0, e = outW.size(); i < e; i++) { FloatTy val = outW.raw(i); @@ -489,11 +489,11 @@ void Interpreter::bwdTanhInst(Context *ctx, const TanhInst *I) { void Interpreter::fwdSoftMaxInst(Context *ctx, bool isTrain, const SoftMaxInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); auto idim = inW.dims(); - auto EH = getWeightHandle(ctx, I->getE()); + auto EH = getWeightHandle(I->getE()); for (size_t n = 0; n < idim[0]; n++) { FloatTy max = inW.at({n, 0}); @@ -521,11 +521,11 @@ void Interpreter::fwdSoftMaxInst(Context *ctx, bool isTrain, } void Interpreter::bwdSoftMaxInst(Context *ctx, const SoftMaxInst *I) { - auto inG = getGradHandle(ctx, I->getSrc()); + auto inG = getGradHandle(I->getSrc()); auto idim = inG.dims(); - auto EH = 
getTensorForValue(I->getE())->getHandle(); - auto selectedH = getTensorForValue(I->getSelected())->getHandle(); + auto EH = getTensor(I->getE())->getHandle(); + auto selectedH = getTensor(I->getSelected())->getHandle(); // http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/ // https://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network @@ -540,8 +540,8 @@ void Interpreter::bwdSoftMaxInst(Context *ctx, const SoftMaxInst *I) { void Interpreter::fwdRegressionInst(Context *ctx, bool isTrain, const RegressionInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); for (size_t i = 0, e = inW.size(); i < e; i++) { outW.raw(i) = inW.raw(i); @@ -549,9 +549,9 @@ void Interpreter::fwdRegressionInst(Context *ctx, bool isTrain, } void Interpreter::bwdRegressionInst(Context *ctx, const RegressionInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto inG = getGradHandle(ctx, I->getSrc()); - auto expected = getTensorForValue(I->getExpected()); + auto inW = getWeightHandle(I->getSrc()); + auto inG = getGradHandle(I->getSrc()); + auto expected = getTensor(I->getExpected()); auto idim = inW.dims(); assert(idim.size() == 2 && "Input is expected to be a vector per input"); @@ -574,8 +574,8 @@ void Interpreter::bwdRegressionInst(Context *ctx, const RegressionInst *I) { void Interpreter::fwdTransposeInst(Context *ctx, bool isTrain, const TransposeInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getTensorForValue(I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getTensor(I->getDest()); assert(outW->size() == inW.size() && "Invalid tensor dimensions"); inW.transpose(outW, I->getShuffle()); @@ -583,7 +583,7 @@ void Interpreter::fwdTransposeInst(Context *ctx, bool isTrain, void Interpreter::bwdTransposeInst(Context *ctx, const TransposeInst *I) { auto inG = getOrCreateGradTensor(I->getSrc()); - auto outG = getGradHandle(ctx, I->getDest()); + auto outG = getGradHandle(I->getDest()); assert(outG.size() == inG->size() && "Invalid tensor dimensions"); @@ -602,16 +602,16 @@ void Interpreter::bwdTransposeInst(Context *ctx, const TransposeInst *I) { void Interpreter::fwdReshapeInst(Context *ctx, bool isTrain, const ReshapeInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); for (size_t i = 0, e = inW.size(); i < e; i++) { outW.raw(i) = inW.raw(i); } } void Interpreter::bwdReshapeInst(Context *ctx, const ReshapeInst *I) { - auto inG = getGradHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inG = getGradHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto outG = getGradHandle(I->getDest()); for (size_t i = 0, e = outW.size(); i < e; i++) { inG.raw(i) += outG.raw(i); } @@ -619,14 +619,14 @@ void Interpreter::bwdReshapeInst(Context *ctx, const ReshapeInst *I) { void Interpreter::fwdConcatInst(Context *ctx, bool isTrain, const ConcatInst *I) { - auto outW = getWeightHandle(ctx, I->getDest()); + auto outW = getWeightHandle(I->getDest()); // Insert the tensors at this coordinate. Start at zero. 
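   // E.g. concatenating two {1, 4} inputs along dim == 1 inserts the first
   // at offset {0, 0} and the second at {0, 4}; only the dim'th entry of
   // `offset` advances between inserts.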
std::vector offset(outW.size(), 0); auto dim = I->getDim(); for (unsigned i = 1, e = I->getNumOperands(); i < e; i++) { - auto inW = getWeightHandle(ctx, I->getOperand(i).first); + auto inW = getWeightHandle(I->getOperand(i).first); // Insert the tensor. outW.insertTensors(inW, offset); @@ -636,7 +636,7 @@ void Interpreter::fwdConcatInst(Context *ctx, bool isTrain, } } void Interpreter::bwdConcatInst(Context *ctx, const ConcatInst *I) { - auto outG = getGradHandle(ctx, I->getDest()); + auto outG = getGradHandle(I->getDest()); // Insert the tensors at this coordinate. Start at zero. std::vector offset(outG.size(), 0); @@ -644,7 +644,7 @@ void Interpreter::bwdConcatInst(Context *ctx, const ConcatInst *I) { auto dim = I->getDim(); for (unsigned i = 1, e = I->getNumOperands(); i < e; i++) { - auto inG = getGradHandle(ctx, I->getOperand(i).first); + auto inG = getGradHandle(I->getOperand(i).first); // Insert the tensor. outG.extractTensors(inG, offset); @@ -672,13 +672,13 @@ void Interpreter::fwdBatchNormalizationInst(Context *ctx, bool isTrain, void Interpreter::fwdBatchNormalizationInst_infer( Context *ctx, const BatchNormalizationInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); - auto betaWH = getWeightHandle(ctx, I->getBias()); - auto gammaWH = getWeightHandle(ctx, I->getScale()); - auto varH = getWeightHandle(ctx, I->getVar()); - auto meanH = getWeightHandle(ctx, I->getMean()); + auto betaWH = getWeightHandle(I->getBias()); + auto gammaWH = getWeightHandle(I->getScale()); + auto varH = getWeightHandle(I->getVar()); + auto meanH = getWeightHandle(I->getMean()); auto channelIdx = I->getChannelIdx(); auto epsilon = I->getEpsilon(); @@ -710,9 +710,9 @@ void Interpreter::fwdBatchNormalizationInst_infer( void Interpreter::fwdBatchNormalizationInst_train( Context *ctx, const BatchNormalizationInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto varH = getWeightHandle(ctx, I->getVar()); - auto meanH = getWeightHandle(ctx, I->getMean()); + auto inW = getWeightHandle(I->getSrc()); + auto varH = getWeightHandle(I->getVar()); + auto meanH = getWeightHandle(I->getMean()); auto channelIdx = I->getChannelIdx(); auto momentum = I->getMomentum(); @@ -766,16 +766,16 @@ void Interpreter::fwdBatchNormalizationInst_train( void Interpreter::bwdBatchNormalizationInst(Context *ctx, const BatchNormalizationInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto inG = getGradHandle(ctx, I->getSrc()); - auto outG = getGradHandle(ctx, I->getDest()); + auto inW = getWeightHandle(I->getSrc()); + auto inG = getGradHandle(I->getSrc()); + auto outG = getGradHandle(I->getDest()); - auto gammaWH = getWeightHandle(ctx, I->getScale()); - auto betaGH = getGradHandle(ctx, I->getBias()); - auto gammaGH = getGradHandle(ctx, I->getScale()); + auto gammaWH = getWeightHandle(I->getScale()); + auto betaGH = getGradHandle(I->getBias()); + auto gammaGH = getGradHandle(I->getScale()); - auto varH = getWeightHandle(ctx, I->getVar()); - auto meanH = getWeightHandle(ctx, I->getMean()); + auto varH = getWeightHandle(I->getVar()); + auto meanH = getWeightHandle(I->getMean()); auto channelIdx = I->getChannelIdx(); auto epsilon = I->getEpsilon(); @@ -850,9 +850,9 @@ void Interpreter::bwdBatchNormalizationInst(Context *ctx, void Interpreter::fwdLocalResponseNormalizationInst( glow::Context *ctx, bool isTrain, const glow::LocalResponseNormalizationInst *I) { - auto inW = 
getWeightHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto scaleCache = getWeightHandle(ctx, I->getScale()); + auto inW = getWeightHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto scaleCache = getWeightHandle(I->getScale()); ShapeNHWC odim(outW.dims()); ShapeNHWC idim(inW.dims()); @@ -915,11 +915,11 @@ void Interpreter::fwdLocalResponseNormalizationInst( void Interpreter::bwdLocalResponseNormalizationInst( glow::Context *ctx, const glow::LocalResponseNormalizationInst *I) { - auto inW = getWeightHandle(ctx, I->getSrc()); - auto inG = getGradHandle(ctx, I->getSrc()); - auto outW = getWeightHandle(ctx, I->getDest()); - auto outG = getGradHandle(ctx, I->getDest()); - auto scaleCache = getWeightHandle(ctx, I->getScale()); + auto inW = getWeightHandle(I->getSrc()); + auto inG = getGradHandle(I->getSrc()); + auto outW = getWeightHandle(I->getDest()); + auto outG = getGradHandle(I->getDest()); + auto scaleCache = getWeightHandle(I->getScale()); ShapeNHWC odim(outW.dims()); @@ -989,9 +989,9 @@ void Interpreter::bwdLocalResponseNormalizationInst( void Interpreter::fwdArithmeticInst(Context *ctx, bool isTrain, const ArithmeticInst *I) { - auto outW = getWeightHandle(ctx, I->getDest()); - auto LHSW = getWeightHandle(ctx, I->getLHS()); - auto RHSW = getWeightHandle(ctx, I->getRHS()); + auto outW = getWeightHandle(I->getDest()); + auto LHSW = getWeightHandle(I->getLHS()); + auto RHSW = getWeightHandle(I->getRHS()); switch (I->getKind()) { case ArithmeticInst::OpKind::Add: @@ -1011,11 +1011,11 @@ void Interpreter::fwdArithmeticInst(Context *ctx, bool isTrain, } void Interpreter::bwdArithmeticInst(Context *ctx, const ArithmeticInst *I) { - auto LHSW = getWeightHandle(ctx, I->getLHS()); - auto RHSW = getWeightHandle(ctx, I->getRHS()); - auto outG = getGradHandle(ctx, I->getDest()); - auto LHSG = getGradHandle(ctx, I->getLHS()); - auto RHSG = getGradHandle(ctx, I->getRHS()); + auto LHSW = getWeightHandle(I->getLHS()); + auto RHSW = getWeightHandle(I->getRHS()); + auto outG = getGradHandle(I->getDest()); + auto LHSG = getGradHandle(I->getLHS()); + auto RHSG = getGradHandle(I->getRHS()); switch (I->getKind()) { case ArithmeticInst::OpKind::Add: @@ -1042,7 +1042,7 @@ void Interpreter::bwdArithmeticInst(Context *ctx, const ArithmeticInst *I) { void Interpreter::fwdAllocActivationInst(Context *ctx, bool isTrain, const AllocActivationInst *I) { - allocateBackingTensor(I); + getOrCreateTensor(I); // Prepare for the next backprop iteration by zeroing the gradient // tensors. Notice that this only zeros the temporary grad tensors that // match the output tensors but not the gradient tensors that are diff --git a/tests/unittests/CMakeLists.txt b/tests/unittests/CMakeLists.txt index 8819a8f966..c785287a39 100644 --- a/tests/unittests/CMakeLists.txt +++ b/tests/unittests/CMakeLists.txt @@ -15,6 +15,7 @@ target_link_libraries(IRgradCheckTest PRIVATE Network Interpreter + ExecutionEngine IR gtest gtest_main) @@ -38,6 +39,7 @@ target_link_libraries(InterpreterTest Interpreter Network IR + ExecutionEngine gtest gtest_main) add_test(InterpreterTest ${GLOW_BINARY_DIR}/tests/InterpreterTest) diff --git a/tests/unittests/IRGradCheck.cpp b/tests/unittests/IRGradCheck.cpp index 3c6f026841..acd4566c31 100644 --- a/tests/unittests/IRGradCheck.cpp +++ b/tests/unittests/IRGradCheck.cpp @@ -1,12 +1,12 @@ // Copyright 2017 Facebook Inc. All Rights Reserved. 
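 
 // The checks in this file compare the analytical gradient against a
 // numerical estimate: each input x_i is nudged by +/- delta, the L2 loss is
 // re-measured, and the two gradients are compared with the relative error
 //   gradDiff(G1, G2) = |G1 - G2| / |G1 + G2 + 1|
 // defined below, which must stay under allowedError.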
#include "glow/Base/Tensor.h" +#include "glow/ExecutionEngine/ExecutionEngine.h" #include "glow/Graph/Graph.h" #include "glow/Graph/Nodes.h" #include "glow/IR/IR.h" #include "glow/IR/IRBuilder.h" #include "glow/IR/Instrs.h" -#include "glow/Interpreter/Interpreter.h" #include "gtest/gtest.h" @@ -34,7 +34,7 @@ FloatTy gradDiff(FloatTy G1, FloatTy G2) { return std::abs(G1 - G2) / std::abs(G1 + G2 + 1); } -void performGradCheck(Interpreter &IP, Node *result, Variable *inputVar, +void performGradCheck(ExecutionEngine &IP, Node *result, Variable *inputVar, Variable *expVar, Tensor *inputs, Tensor *outputs, float delta, float allowedError) { auto inputsH = inputs->getHandle(); @@ -43,12 +43,12 @@ void performGradCheck(Interpreter &IP, Node *result, Variable *inputVar, IP.train(300, {inputVar, expVar}, {inputs, outputs}); // Clear the gradients of the first layer. - IP.getGradHandle(nullptr, inputVar).clear(); + IP.getGradHandle(inputVar).clear(); // Train the network just once to calculate the grads. IP.train(1, {inputVar, expVar}, {inputs, outputs}); - auto analyticalGradsH = IP.getGradHandle(nullptr, inputVar); + auto analyticalGradsH = IP.getGradHandle(inputVar); for (size_t i = 0; i < analyticalGradsH.size(); i++) { auto old = inputsH.raw(i); @@ -56,13 +56,13 @@ void performGradCheck(Interpreter &IP, Node *result, Variable *inputVar, // Calculate f(x+e): inputsH.raw(i) = old + delta; IP.infer({inputVar}, {inputs}); - Tensor *res = IP.getTensorForNode(result); + Tensor *res = IP.getTensor(result); auto plusLoss = computeL2Loss(outputs, res); // Calculate f(x-e): inputsH.raw(i) = old - delta; IP.infer({inputVar}, {inputs}); - res = IP.getTensorForNode(result); + res = IP.getTensor(result); auto minusLoss = computeL2Loss(outputs, res); inputsH.raw(i) = old; @@ -77,7 +77,7 @@ void performGradCheck(Interpreter &IP, Node *result, Variable *inputVar, } TEST(Network, gradientCheck_FC_Concat_RELU) { - Interpreter IP; + ExecutionEngine IP; IP.getConfig().maxNumThreads = 1; size_t numInputElem = 20; @@ -118,7 +118,7 @@ TEST(Network, gradientCheck_FC_Concat_RELU) { } TEST(Network, gradientCheck_Conv) { - Interpreter IP; + ExecutionEngine IP; IP.getConfig().maxNumThreads = 1; size_t numDim = 10; @@ -155,7 +155,7 @@ TEST(Network, gradientCheck_Conv) { } TEST(Network, gradientCheck_AvgPool) { - Interpreter IP; + ExecutionEngine IP; IP.getConfig().maxNumThreads = 1; size_t numDim = 10; @@ -190,7 +190,7 @@ TEST(Network, gradientCheck_AvgPool) { } TEST(Network, gradientCheck_batchNorm) { - Interpreter IP; + ExecutionEngine IP; IP.getConfig().maxNumThreads = 1; size_t numDim = 5; @@ -230,7 +230,7 @@ TEST(Network, gradientCheck_batchNorm) { } TEST(Network, gradientCheck_Arithmetic) { - Interpreter IP; + ExecutionEngine IP; IP.getConfig().maxNumThreads = 1; size_t numDim = 5; @@ -274,16 +274,16 @@ TEST(Network, gradientCheck_Arithmetic) { IP.train(30, {A, B, C, Exp}, {&iA, &iB, &iC, &outputs}); // Clear the gradients of the last layer. 
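   // Gradients are accumulated with += in the backward pass, so they are
   // zeroed first; the single train() call below then yields the gradient of
   // exactly one step.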
- IP.getGradHandle(nullptr, A).clear(); - IP.getGradHandle(nullptr, B).clear(); - IP.getGradHandle(nullptr, C).clear(); + IP.getGradHandle(A).clear(); + IP.getGradHandle(B).clear(); + IP.getGradHandle(C).clear(); IP.train(1, {A, B, C, Exp}, {&iA, &iB, &iC, &outputs}); auto check = [&](Variable *var, Tensor *t) { auto iH = t->getHandle(); - auto analyticalGradsH = IP.getGradHandle(nullptr, var); + auto analyticalGradsH = IP.getGradHandle(var); float delta = 0.001; for (size_t i = 0; i < numDim; i++) { @@ -292,14 +292,14 @@ TEST(Network, gradientCheck_Arithmetic) { // Calculate f(x+e): iH.at({0, i}) = old + delta; IP.infer({A, B, C, Exp}, {&iA, &iB, &iC, &outputs}); - Tensor *res = IP.getTensorForNode(result); + Tensor *res = IP.getTensor(result); auto plusLoss = computeL2Loss(&outputs, res); // Calculate f(x-e): iH.at({0, i}) = old - delta; IP.infer({A, B, C, Exp}, {&iA, &iB, &iC, &outputs}); - res = IP.getTensorForNode(result); + res = IP.getTensor(result); auto minusLoss = computeL2Loss(&outputs, res); iH.at({0, i}) = old; @@ -320,7 +320,7 @@ TEST(Network, gradientCheck_Arithmetic) { TEST(Network, gradientCheck_FC_Concat_Tanh) { // Using the same gradient check test setup as gradientCheck_FC_Concat_RELU - Interpreter IP; + ExecutionEngine IP; IP.getConfig().maxNumThreads = 1; size_t numInputElem = 20; @@ -355,7 +355,7 @@ TEST(Network, gradientCheck_FC_Concat_Tanh) { TEST(Network, gradientCheck_Transpose) { // Using the same gradient check test setup as gradientCheck_FC_Concat_RELU - Interpreter IP; + ExecutionEngine IP; IP.getConfig().maxNumThreads = 1; size_t numOutputElem = 10; diff --git a/tests/unittests/InterpreterTest.cpp b/tests/unittests/InterpreterTest.cpp index 75485183b9..b559569ffe 100644 --- a/tests/unittests/InterpreterTest.cpp +++ b/tests/unittests/InterpreterTest.cpp @@ -1,6 +1,6 @@ // Copyright 2017 Facebook Inc. All Rights Reserved. -#include "glow/Interpreter/Interpreter.h" +#include "glow/ExecutionEngine/ExecutionEngine.h" #include "glow/Graph/Node.h" #include "glow/Graph/Nodes.h" #include "glow/IR/IR.h" @@ -15,11 +15,11 @@ using namespace glow; TEST(Interpreter, interpret) { - Interpreter IP; + ExecutionEngine EE; Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3}); - auto &G = IP.getGraph(); + auto &G = EE.getGraph(); auto *input = G.createVariable(ElemKind::FloatTy, {1, 32, 32, 3}, "input"); auto *ex = G.createVariable(ElemKind::IndexTy, {1, 1}, "exp"); @@ -41,19 +41,19 @@ TEST(Interpreter, interpret) { auto *SM = G.createSoftMax("sm", RL3, ex); G.createReturn("ret", SM); - IP.getModule().generateIR(); - IP.optimize(OptimizationMode::Infer); - IP.initVars(); - IP.infer({input}, {&inputs}); + EE.getModule().generateIR(); + EE.optimize(OptimizationMode::Infer); + EE.initVars(); + EE.infer({input}, {&inputs}); } TEST(Interpreter, trainASimpleNetwork) { - Interpreter IP; + ExecutionEngine EE; // Learning a single input vector. - IP.getConfig().maxNumThreads = 1; - IP.getConfig().learningRate = 0.05; + EE.getConfig().maxNumThreads = 1; + EE.getConfig().learningRate = 0.05; - auto &G = IP.getGraph(); + auto &G = EE.getGraph(); // Create a variable with 1 input, which is a vector of 4 elements. 
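   // The leading 1 in {1, 4} is the batch dimension that train() and infer()
   // slice over; the remaining 4 elements are the features.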
auto *A = G.createVariable(ElemKind::FloatTy, {1, 4}, "A"); @@ -72,18 +72,18 @@ TEST(Interpreter, trainASimpleNetwork) { inputs.getHandle() = {0.15, 0.15, 0.15, 0.15}; expected.getHandle() = {0.9, 0.9, 0.9, 0.9}; - IP.getModule().generateIR(); - IP.optimize(OptimizationMode::Train); - IP.initVars(); + EE.getModule().generateIR(); + EE.optimize(OptimizationMode::Train); + EE.initVars(); // Train the network. Learn 1000 batches. - IP.train(1000, {A, E}, {&inputs, &expected}); + EE.train(1000, {A, E}, {&inputs, &expected}); // Testing the output vector. - IP.optimize(OptimizationMode::Infer); - IP.infer({A}, {&inputs}); - auto RNWH = IP.getTensorForNode(result)->getHandle(); + EE.optimize(OptimizationMode::Infer); + EE.infer({A}, {&inputs}); + auto RNWH = EE.getTensor(result)->getHandle(); (void)RNWH; // Test the output: @@ -97,16 +97,16 @@ TEST(Interpreter, simpleRegression) { const int numInputs = 4; // Learning the Xor function. - Interpreter IP; + ExecutionEngine EE; // Learning a single input vector. - IP.getConfig().maxNumThreads = 1; - IP.getConfig().learningRate = 0.05; + EE.getConfig().maxNumThreads = 1; + EE.getConfig().learningRate = 0.05; Tensor inputs(ElemKind::FloatTy, {1, numInputs}); Tensor expected(ElemKind::FloatTy, {1, numInputs}); - auto &G = IP.getGraph(); + auto &G = EE.getGraph(); auto *A = G.createVariable(ElemKind::FloatTy, {1, numInputs}, "A"); auto *Ex = G.createVariable(ElemKind::FloatTy, {1, numInputs}, "E"); Node *O = G.createFullyConnected("fc", A, 4); @@ -117,16 +117,16 @@ TEST(Interpreter, simpleRegression) { auto I = inputs.getHandle(); auto E = expected.getHandle(); - IP.getModule().generateIR(); - IP.optimize(OptimizationMode::Train); - IP.initVars(); + EE.getModule().generateIR(); + EE.optimize(OptimizationMode::Train); + EE.initVars(); // Train the network: for (int iter = 0; iter < 1000; iter++) { float target = float(iter % 9); I = {target, 0., 0., 0.}; E = {0., target + 1, 0., 0.}; - IP.train(1, {A, Ex}, {&inputs, &expected}); + EE.train(1, {A, Ex}, {&inputs, &expected}); } // Verify the result of the regression layer. @@ -135,9 +135,9 @@ TEST(Interpreter, simpleRegression) { for (int iter = 0; iter < 5; iter++) { float target = iter % 9 + 1; I = {target, 0., 0., 0.}; - IP.infer({A}, {&inputs}); + EE.infer({A}, {&inputs}); - auto resH = IP.getTensorForNode(result)->getHandle(); + auto resH = EE.getTensor(result)->getHandle(); (void)resH; EXPECT_NEAR(I.at({0, 0}) + 1, resH.at({0, 1}), 0.1); @@ -149,13 +149,13 @@ TEST(Interpreter, learnXor) { unsigned numTests = 10; // Learning the Xor function. - Interpreter IP; + ExecutionEngine EE; // Learning a single input vector. 
- IP.getConfig().maxNumThreads = 1; - IP.getConfig().learningRate = 0.05; + EE.getConfig().maxNumThreads = 1; + EE.getConfig().learningRate = 0.05; - auto &G = IP.getGraph(); + auto &G = EE.getGraph(); auto *A = G.createVariable(ElemKind::FloatTy, {numInputs, 2}, "A"); auto *Ex = G.createVariable(ElemKind::FloatTy, {numInputs, 1}, "Ex"); @@ -185,12 +185,12 @@ TEST(Interpreter, learnXor) { TL.at({i, 0}) = a ^ b; } - IP.getModule().generateIR(); - IP.optimize(OptimizationMode::Train); - IP.initVars(); + EE.getModule().generateIR(); + EE.optimize(OptimizationMode::Train); + EE.initVars(); // Train the network: - IP.train(2500, {A, Ex}, {&trainingSet, &trainingLabels}); + EE.train(2500, {A, Ex}, {&trainingSet, &trainingLabels}); // Prepare the testing tensor: for (unsigned i = 0; i < numTests; i++) { @@ -198,8 +198,8 @@ TEST(Interpreter, learnXor) { TT.at({i, 1}) = (i >> 1) % 2; } - IP.infer({A}, {&trainingSet}); - auto resH = IP.getTensorForNode(result)->getHandle(); + EE.infer({A}, {&trainingSet}); + auto resH = EE.getTensor(result)->getHandle(); // Test the output: for (size_t i = 0; i < numTests; i++) { @@ -245,14 +245,14 @@ void generateCircleData(Tensor &coordinates, Tensor &labels) { /// http://cs.stanford.edu/people/karpathy/convnetjs/demo/classify2d.html TEST(Network, circle) { // Testing the softmax layer. - Interpreter IP; + ExecutionEngine EE; // Learning a single input vector. - IP.getConfig().maxNumThreads = 1; - IP.getConfig().momentum = 0.9; - IP.getConfig().learningRate = 0.01; + EE.getConfig().maxNumThreads = 1; + EE.getConfig().momentum = 0.9; + EE.getConfig().learningRate = 0.01; - auto &G = IP.getGraph(); + auto &G = EE.getGraph(); auto *A = G.createVariable(ElemKind::FloatTy, {1, 2}, "A"); auto *S = G.createVariable(ElemKind::IndexTy, {1, 1}, "S", Variable::InitKind::Extern); @@ -264,16 +264,16 @@ TEST(Network, circle) { auto *SM = G.createSoftMax("soft", RL1, S); auto *result = G.createReturn("ret", SM); - IP.getModule().generateIR(); - IP.optimize(OptimizationMode::Train); - IP.initVars(); + EE.getModule().generateIR(); + EE.optimize(OptimizationMode::Train); + EE.initVars(); Tensor coordinates(ElemKind::FloatTy, {numSamples, 2}); Tensor labels(ElemKind::IndexTy, {numSamples, 1}); generateCircleData(coordinates, labels); // Training: - IP.train(4000, {A, S}, {&coordinates, &labels}); + EE.train(4000, {A, S}, {&coordinates, &labels}); // Print a diagram that depicts the network decision on a grid. for (int x = -10; x < 10; x++) { @@ -282,9 +282,9 @@ TEST(Network, circle) { Tensor sample(ElemKind::FloatTy, {1, 2}); sample.getHandle() = {float(x) / 10, float(y) / 10}; - IP.infer({A}, {&sample}); + EE.infer({A}, {&sample}); - auto SMH = IP.getTensorForNode(result)->getHandle(); + auto SMH = EE.getTensor(result)->getHandle(); auto A = SMH.at({0, 0}); auto B = SMH.at({0, 1}); @@ -305,8 +305,8 @@ TEST(Network, circle) { // The dot in the middle must be zero. Tensor sample(ElemKind::FloatTy, {1, 2}); sample.getHandle() = {0., 0.}; - IP.infer({A}, {&sample}); - auto SMH = IP.getTensorForNode(result)->getHandle(); + EE.infer({A}, {&sample}); + auto SMH = EE.getTensor(result)->getHandle(); auto A = SMH.at({0, 0}); auto B = SMH.at({0, 1}); EXPECT_LE(A, 0.1); @@ -317,8 +317,8 @@ TEST(Network, circle) { // Far away dot must be one. 
Tensor sample(ElemKind::FloatTy, {1, 2}); sample.getHandle() = {1., 1.}; - IP.infer({A}, {&sample}); - auto SMH = IP.getTensorForNode(result)->getHandle(); + EE.infer({A}, {&sample}); + auto SMH = EE.getTensor(result)->getHandle(); auto A = SMH.at({0, 0}); auto B = SMH.at({0, 1}); EXPECT_GE(A, 0.9); @@ -327,15 +327,15 @@ TEST(Network, circle) { } TEST(Network, learnSingleValueConcat) { - Interpreter IP; + ExecutionEngine EE; unsigned width = 6; // Learning a single input vector. - IP.getConfig().maxNumThreads = 1; - IP.getConfig().momentum = 0.9; - IP.getConfig().learningRate = 0.01; + EE.getConfig().maxNumThreads = 1; + EE.getConfig().momentum = 0.9; + EE.getConfig().learningRate = 0.01; - auto &G = IP.getGraph(); + auto &G = EE.getGraph(); auto *Ex = G.createVariable(ElemKind::FloatTy, {1, width * 2}, "Ex"); @@ -359,18 +359,18 @@ TEST(Network, learnSingleValueConcat) { inputs.getHandle().clear(0.15); expected.getHandle().clear(0.9); - IP.getModule().generateIR(); - IP.optimize(OptimizationMode::Train); - IP.initVars(); + EE.getModule().generateIR(); + EE.optimize(OptimizationMode::Train); + EE.initVars(); // Train the network: - IP.train(1000, {A, B, Ex}, {&inputs, &inputs, &expected}); + EE.train(1000, {A, B, Ex}, {&inputs, &inputs, &expected}); - IP.optimize(OptimizationMode::Infer); + EE.optimize(OptimizationMode::Infer); // Testing the output vector. - IP.infer({A}, {&inputs}); - auto RNWH = IP.getTensorForNode(result)->getHandle(); + EE.infer({A}, {&inputs}); + auto RNWH = EE.getTensor(result)->getHandle(); (void)RNWH; // Test the output: diff --git a/tools/loader/CMakeLists.txt b/tools/loader/CMakeLists.txt index 8c09c3be1d..5dde967933 100644 --- a/tools/loader/CMakeLists.txt +++ b/tools/loader/CMakeLists.txt @@ -5,6 +5,7 @@ target_link_libraries(loader PRIVATE Interpreter Importer + ExecutionEngine Network IR Support) diff --git a/tools/loader/loader.cpp b/tools/loader/loader.cpp index 0c450f1ea4..d7c6947289 100644 --- a/tools/loader/loader.cpp +++ b/tools/loader/loader.cpp @@ -2,6 +2,7 @@ #include "glow/Base/Image.h" #include "glow/Base/Tensor.h" +#include "glow/ExecutionEngine/ExecutionEngine.h" #include "glow/Graph/Graph.h" #include "glow/Graph/Nodes.h" #include "glow/Importer/Caffe2.h" @@ -77,20 +78,20 @@ int main(int argc, char **argv) { auto imageMode = strToImageNormalizationMode(argv[2]); loadImageAndPreprocess(argv[1], &data, imageMode); - Interpreter IP; + ExecutionEngine EE; caffe2ModelLoader LD(argv[3], argv[4], {"data", "gpu_0/data", "softmax_expected"}, - {&data, &data, &expected_softmax}, IP); + {&data, &data, &expected_softmax}, EE); - IP.optimize(OptimizationMode::Infer); - IP.initVars(); + EE.optimize(OptimizationMode::Infer); + EE.initVars(); auto *SM = LD.getRoot(); auto *i0 = cast(LD.getOrCreateNodeByName("gpu_0/data")); auto *i1 = cast(LD.getOrCreateNodeByName("data")); - IP.infer({i0, i1}, {&data, &data}); - auto *res = IP.getTensorForNode(SM); + EE.infer({i0, i1}, {&data, &data}); + auto *res = EE.getTensor(SM); auto H = res->getHandle(); H.dump("res = ", "\n"); Tensor slice = H.extractSlice(0); From c14e2336348eefed9bc85b79a6b5d27f223d402f Mon Sep 17 00:00:00 2001 From: Nadav Rotem Date: Fri, 13 Oct 2017 09:33:57 -0700 Subject: [PATCH 2/5] Remove redundant includes from headers. This commit removes the redundant includes from header files. It converts a bunch of includes into forward declarations and moves files into the right base library. 
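As a rough sketch, the include-to-forward-declaration conversion looks like
this (the header and class names below are illustrative, not code from this
patch):

    // SomeUser.h, before: every file that includes this header must also
    // parse the full definition of Graph.
    #include "glow/Graph/Graph.h"

    class SomeUser {
      glow::Graph *G_;
    };

    // SomeUser.h, after: a forward declaration is enough for a pointer (or
    // reference) member; only SomeUser.cpp includes "glow/Graph/Graph.h".
    namespace glow {
    class Graph;
    } // namespace glow

    class SomeUser {
      glow::Graph *G_;
    };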
It also moves the execution engine into a "pimpl" implementation, where the
details of the classes are hidden behind a pointer. Plain pointers are used
here instead of a smart pointer, because std::unique_ptr requires the
complete type at the point of initialization.
---
 examples/mnist.cpp                            |  8 +--
 include/glow/{IR => Base}/Traits.h            |  0
 .../glow/ExecutionEngine/ExecutionEngine.h    | 33 ++++++----
 include/glow/Graph/Graph.h                    | 18 ------
 include/glow/Graph/Node.h                     |  2 +-
 include/glow/IR/IR.h                          | 10 ++--
 include/glow/IR/IRBuilder.h                   |  4 +-
 include/glow/Interpreter/Interpreter.h        | 24 ++++----
 src/glow/ExecutionEngine/ExecutionEngine.cpp  | 60 ++++++++++++-------
 src/glow/IR/IRBuilder.cpp                     | 40 ++++++-------
 src/glow/IR/IRGen.cpp                         | 12 ++--
 src/glow/Interpreter/Interpreter.cpp          |  4 +-
 tests/unittests/GraphTest.cpp                 |  4 +-
 tests/unittests/IRTest.cpp                    | 18 +++---
 tests/unittests/InterpreterTest.cpp           |  3 +-
 tests/unittests/Tensors.cpp                   |  7 ---
 16 files changed, 123 insertions(+), 124 deletions(-)
 rename include/glow/{IR => Base}/Traits.h (100%)

diff --git a/examples/mnist.cpp b/examples/mnist.cpp
index 10b0416e60..34888c25ef 100644
--- a/examples/mnist.cpp
+++ b/examples/mnist.cpp
@@ -1,11 +1,5 @@
 #include "glow/ExecutionEngine/ExecutionEngine.h"
 #include "glow/Graph/Graph.h"
-#include "glow/Graph/Node.h"
-#include "glow/Graph/Nodes.h"
-#include "glow/IR/IR.h"
-#include "glow/IR/IRBuilder.h"
-#include "glow/IR/Instrs.h"
-#include "glow/Interpreter/Interpreter.h"
 #include "glow/Support/Support.h"

 #include "llvm/Support/Timer.h"
@@ -99,7 +93,7 @@ void testMNIST() {

   auto *result = G.createReturn("return", SM);

-  EE.getModule().generateIR();
+  EE.generateIR();
   EE.optimize(OptimizationMode::Train);
   EE.initVars();

diff --git a/include/glow/IR/Traits.h b/include/glow/Base/Traits.h
similarity index 100%
rename from include/glow/IR/Traits.h
rename to include/glow/Base/Traits.h
diff --git a/include/glow/ExecutionEngine/ExecutionEngine.h b/include/glow/ExecutionEngine/ExecutionEngine.h
index 646997c7eb..e8b4c8bca6 100644
--- a/include/glow/ExecutionEngine/ExecutionEngine.h
+++ b/include/glow/ExecutionEngine/ExecutionEngine.h
@@ -1,43 +1,54 @@
 #ifndef GLOW_EXECUTIONENGINE_EXECUTIONENGINE_H
 #define GLOW_EXECUTIONENGINE_EXECUTIONENGINE_H

-#include "glow/Base/Tensor.h"
 #include "glow/Base/Train.h"
-#include "glow/Graph/Graph.h"
-#include "glow/IR/IR.h"
-#include "glow/IR/IRBuilder.h"
-#include "glow/Interpreter/Interpreter.h"
 #include "glow/Optimizer/Optimizer.h"

 #include "llvm/ADT/ArrayRef.h"

+#include
 #include

 namespace glow {

+class Graph;
+class Node;
+class Interpreter;
+class Variable;
+class Tensor;
+class Value;
+
 /// This is the ExecutionEngine. It owns the Graph, the IR, and the backends.
+/// The Graph, Module, etc. in this class are defined as pointers, in order to
+/// erase the type and prevent the internal types from leaking out to the
+/// users of this class.
 class ExecutionEngine final {
   /// The Graph that represents the high-level program.
-  Graph G_{};
+  Graph *G_;
   /// The Module that holds the IR.
-  Module M_;
+  Module *M_;
   /// The network interpreter
-  Interpreter IP_;
+  Interpreter *IP_;
   /// The network trainer.
   Trainer trainer_{};

 public:
-  ExecutionEngine() : M_(G_), IP_(M_) {}
+  ExecutionEngine();
+
+  ~ExecutionEngine();

   /// \returns the internal module.
-  Module &getModule() { return M_; }
+  Module &getModule() { return *M_; }

   /// \returns the internal module.
-  Graph &getGraph() { return G_; }
+  Graph &getGraph() { return *G_; }

   /// Run the target-independent optimizations on the module.
void optimize(OptimizationMode mode); + /// Generate IR from the graph nodes. + void generateIR(); + /// Provides access to the training configuration. TrainingConfig &getConfig() { return trainer_.config; } diff --git a/include/glow/Graph/Graph.h b/include/glow/Graph/Graph.h index c03a99c334..eebb3b15b4 100644 --- a/include/glow/Graph/Graph.h +++ b/include/glow/Graph/Graph.h @@ -13,24 +13,6 @@ namespace glow { -class Node; -class Variable; -class ConvolutionNode; -class PoolNode; -class FullyConnectedNode; -class ReluNode; -class SigmoidNode; -class TanhNode; -class SoftMaxNode; -class RegressionNode; -class ReshapeNode; -class TransposeNode; -class ConcatNode; -class BatchNormalizationNode; -class LocalResponseNormalizationNode; -class ArithmeticNode; -class ReturnNode; - /// Represents the compute graph. class Graph final { /// A uniqued list of types in the module. Types in this list can be compared diff --git a/include/glow/Graph/Node.h b/include/glow/Graph/Node.h index e9dae953a0..735bee8309 100644 --- a/include/glow/Graph/Node.h +++ b/include/glow/Graph/Node.h @@ -4,7 +4,7 @@ #include "llvm/ADT/StringRef.h" #include "glow/Base/Type.h" -#include "glow/IR/Traits.h" +#include "glow/Base/Traits.h" namespace glow { diff --git a/include/glow/IR/IR.h b/include/glow/IR/IR.h index 658b4c370c..4c8c14f3ab 100644 --- a/include/glow/IR/IR.h +++ b/include/glow/IR/IR.h @@ -2,7 +2,7 @@ #define GLOW_IR_IR_H #include "glow/Base/Type.h" -#include "glow/IR/Traits.h" +#include "glow/Base/Traits.h" #include "glow/IR/UseDef.h" #include "llvm/ADT/ArrayRef.h" @@ -102,8 +102,8 @@ class Module final { using WeightVarListTy = std::list; private: - /// A reference to the graph structure. - Graph &G_; + /// A pointer to the graph structure. The Module does not own the graph. + Graph *G_; /// A list of weights. Weights are shared between all execution context. std::list weights_{}; @@ -122,7 +122,7 @@ class Module final { /// Add an instruction to the instr stream. void pushInstr(Instruction *I) { instrs_.push_back(I); } - explicit Module(Graph &G) : G_(G) {} + explicit Module(Graph *G) : G_(G) {} ~Module(); @@ -130,7 +130,7 @@ class Module final { void generateIR(); /// \returns a reference to the original graph. - Graph &getGraph() { return G_; } + Graph *getGraph() { return G_; } /// Verify the correctness of the module. void verify() const; diff --git a/include/glow/IR/IRBuilder.h b/include/glow/IR/IRBuilder.h index 694289c8c0..c1622497a7 100644 --- a/include/glow/IR/IRBuilder.h +++ b/include/glow/IR/IRBuilder.h @@ -15,14 +15,14 @@ class IRBuilder { using InitKind = WeightVar::InitKind; /// The module that we are building. - Module &M_; + Module *M_; /// A list of allocated buffers that need to be deallocated at the end of the /// program that we are constructing. 
std::vector activeAllocs_; public: - explicit IRBuilder(Module &M) : M_(M) {} + explicit IRBuilder(Module *M) : M_(M) {} ~IRBuilder(); diff --git a/include/glow/Interpreter/Interpreter.h b/include/glow/Interpreter/Interpreter.h index 3dfc1beaa9..6d645c2564 100644 --- a/include/glow/Interpreter/Interpreter.h +++ b/include/glow/Interpreter/Interpreter.h @@ -2,11 +2,6 @@ #define GLOW_INTERPRETER_INTERPRETER_H #include "glow/Base/Tensor.h" -#include "glow/Base/Train.h" -#include "glow/Graph/Graph.h" -#include "glow/IR/IR.h" -#include "glow/IR/IRBuilder.h" -#include "glow/Optimizer/Optimizer.h" #include "llvm/ADT/ArrayRef.h" @@ -15,12 +10,21 @@ namespace glow { class Context; +class Module; +class Value; +class Tensor; + +// Forward declare all of the classes. +#define DEF_VALUE(CLASS, NAME) class CLASS; +#define DEF_NODE(CLASS, NAME) class CLASS; +#define DEF_INSTR(CLASS, NAME) class CLASS; +#include "glow/IR/Instrs.def" /// This is the IR-interpreter. It owns the IR, and the heap, and is able to /// execute the instructions one at a time. class Interpreter final { - /// The Module that holds the IR. - Module &M_; + /// The Module that holds the IR. This does not own the module. + Module *M_; /// Maps values to Tensors, that are owned by this class. std::unordered_map tensors_; @@ -29,11 +33,8 @@ class Interpreter final { std::unordered_map gradients_; public: - /// \returns the internal module. - Module &getModule() { return M_; } - /// Ctor. - Interpreter(Module &M) : M_(M) {} + Interpreter(Module *M) : M_(M) {} /// Dtor. ~Interpreter(); @@ -75,6 +76,7 @@ class Interpreter final { /// @name Interpreter methods. This is a list of method declerations that are /// used by the interpreter to dispatch different instructions. ///@{ + #define DEF_VALUE(CLASS, NAME) #define DEF_NODE(CLASS, NAME) #define DEF_INSTR(CLASS, NAME) \ diff --git a/src/glow/ExecutionEngine/ExecutionEngine.cpp b/src/glow/ExecutionEngine/ExecutionEngine.cpp index 27c02d4d48..e5132701fc 100644 --- a/src/glow/ExecutionEngine/ExecutionEngine.cpp +++ b/src/glow/ExecutionEngine/ExecutionEngine.cpp @@ -1,9 +1,25 @@ // Copyright 2017 Facebook Inc. All Rights Reserved. #include "glow/ExecutionEngine/ExecutionEngine.h" +#include "glow/Interpreter/Interpreter.h" + +#include "glow/Graph/Graph.h" +#include "glow/IR/IR.h" +#include "glow/IR/IRBuilder.h" +#include "glow/IR/Instrs.h" +#include "glow/Optimizer/Optimizer.h" using namespace glow; +ExecutionEngine::ExecutionEngine() + : G_(new Graph()), M_(new Module(G_)), IP_(new Interpreter(M_)) {} + +ExecutionEngine::~ExecutionEngine() { + delete IP_; + delete M_; + delete G_; +} + void ExecutionEngine::infer(llvm::ArrayRef vars, llvm::ArrayRef inputs) { assert(!inputs.empty() && "No inputs"); @@ -12,11 +28,11 @@ void ExecutionEngine::infer(llvm::ArrayRef vars, // Update the input variables. for (int i = 0, e = vars.size(); i < e; i++) { - auto *val = M_.getWeightForNode(vars[i]); + auto *val = M_->getWeightForNode(vars[i]); loadValueFromTensor(val, inputs[i], 0); } - IP_.doForwardPass(false); + IP_->doForwardPass(false); } void ExecutionEngine::train(size_t iterations, llvm::ArrayRef vars, @@ -29,7 +45,7 @@ void ExecutionEngine::train(size_t iterations, llvm::ArrayRef vars, std::vector weights; for (auto *v : vars) { - weights.push_back(M_.getWeightForNode(v)); + weights.push_back(M_->getWeightForNode(v)); } // This is the size of one batch (the number of samples in the batch). 
@@ -49,14 +65,14 @@ void ExecutionEngine::train(size_t iterations, llvm::ArrayRef vars, } void ExecutionEngine::learnGradient(size_t batchSize) { - for (auto *V : M_.getWeights()) { + for (auto *V : M_->getWeights()) { // Do not try to learn the values of input/output buffers. if (V->getInitKind() == WeightVar::InitKind::Extern) { continue; } - auto W = IP_.getTensor(V); - auto G = IP_.getOrCreateGradTensor(V); + auto W = IP_->getTensor(V); + auto G = IP_->getOrCreateGradTensor(V); // Handle weight update by learning the gradients into the weights. trainer_.train(W, G, batchSize); @@ -71,15 +87,15 @@ void ExecutionEngine::updateForwardBackward(llvm::ArrayRef vars, loadValueFromTensor(vars[i], inputs[i], sampleIdx); } - IP_.doForwardPass(true); + IP_->doForwardPass(true); - IP_.doBackwardPass(); + IP_->doBackwardPass(); } void ExecutionEngine::loadValueFromTensor(const Value *v, Tensor *input, size_t sampleIdx) { assert(v && "Invalid value"); - auto *t = IP_.getTensor(v); + auto *t = IP_->getTensor(v); auto dim = input->dims(); assert(t->dims().drop_front() == dim.drop_front() && "Invalid slice size"); @@ -89,41 +105,43 @@ void ExecutionEngine::loadValueFromTensor(const Value *v, Tensor *input, } void ExecutionEngine::optimize(OptimizationMode mode) { - ::glow::optimize(M_, mode); + ::glow::optimize(*M_, mode); } +void ExecutionEngine::generateIR() { M_->generateIR(); } + Tensor *ExecutionEngine::getTensor(const Node *v) const { - auto val = M_.getWeightForNode(v); + auto val = M_->getWeightForNode(v); assert(val && "Node does not have a registered IR value"); - return IP_.getTensor(val); + return IP_->getTensor(val); } /// \returns a float-handle to the tensor that is stored at \p v. Handle ExecutionEngine::getWeightHandle(Variable *v) const { - auto val = M_.getWeightForNode(v); - return IP_.getWeightHandle(val); + auto val = M_->getWeightForNode(v); + return IP_->getWeightHandle(val); } /// \returns a float-handle to the tensor that is stored at \p v. Handle ExecutionEngine::getGradHandle(Variable *v) { - auto val = M_.getWeightForNode(v); - return IP_.getGradHandle(val); + auto val = M_->getWeightForNode(v); + return IP_->getGradHandle(val); } /// Copies the content of the tensor \p t into the value \p v. void ExecutionEngine::initValue(const Variable *v, const Tensor *t) { - auto *N = M_.getWeightForNode(v); - return IP_.initValue(N, t); + auto *N = M_->getWeightForNode(v); + return IP_->initValue(N, t); } void ExecutionEngine::initVars() { - for (auto *W : M_.getWeights()) { + for (auto *W : M_->getWeights()) { // Don't initialize tensors that are already initialized. - if (IP_.hasTensor(W)) { + if (IP_->hasTensor(W)) { continue; } - auto *T = IP_.getOrCreateTensor(W); + auto *T = IP_->getOrCreateTensor(W); // The parameter to the instruction. 
auto val = W->getVal(); diff --git a/src/glow/IR/IRBuilder.cpp b/src/glow/IR/IRBuilder.cpp index 487a9020ae..56de69467a 100644 --- a/src/glow/IR/IRBuilder.cpp +++ b/src/glow/IR/IRBuilder.cpp @@ -239,7 +239,7 @@ Value *IRBuilder::createReturnOp(Value *input) { CopyInst *IRBuilder::createCopyInst(Value *dest, Value *src) { auto *A = new CopyInst(dest, src); - M_.pushInstr(A); + M_->pushInstr(A); return A; } @@ -249,7 +249,7 @@ ConvolutionInst *IRBuilder::createConvolutionInst(Value *dest, Value *src, size_t pad, size_t depth) { auto *A = new ConvolutionInst(dest, src, filter, bias, kernel, stride, pad, depth); - M_.pushInstr(A); + M_->pushInstr(A); return A; } @@ -257,7 +257,7 @@ PoolInst *IRBuilder::createPoolInst(Value *dest, Value *src, Value *srcXY, PoolInst::OpKind kind, size_t kernel, size_t stride, size_t pad) { auto *A = new PoolInst(dest, src, srcXY, kind, kernel, stride, pad); - M_.pushInstr(A); + M_->pushInstr(A); return A; } @@ -266,46 +266,46 @@ FullyConnectedInst *IRBuilder::createFullyConnectedInst(Value *dest, Value *src, Value *bias, size_t depth) { auto *A = new FullyConnectedInst(dest, src, filter, bias, depth); - M_.pushInstr(A); + M_->pushInstr(A); return A; } ReluInst *IRBuilder::createReluInst(Value *dest, Value *src) { auto *A = new ReluInst(dest, src); - M_.pushInstr(A); + M_->pushInstr(A); return A; } SigmoidInst *IRBuilder::createSigmoidInst(Value *dest, Value *src) { auto *A = new SigmoidInst(dest, src); - M_.pushInstr(A); + M_->pushInstr(A); return A; } TanhInst *IRBuilder::createTanhInst(Value *dest, Value *src) { auto *A = new TanhInst(dest, src); - M_.pushInstr(A); + M_->pushInstr(A); return A; } SoftMaxInst *IRBuilder::createSoftMaxInst(Value *dest, Value *src, Value *E, Value *selected) { auto *A = new SoftMaxInst(dest, src, E, selected); - M_.pushInstr(A); + M_->pushInstr(A); return A; } RegressionInst *IRBuilder::createRegressionInst(Value *dest, Value *src, Value *expected) { auto *A = new RegressionInst(dest, src, expected); - M_.pushInstr(A); + M_->pushInstr(A); return A; } ReshapeInst *IRBuilder::createReshapeInst(Value *dest, Value *src, llvm::ArrayRef shape) { auto *A = new ReshapeInst(dest, src, shape); - M_.pushInstr(A); + M_->pushInstr(A); return A; } @@ -313,7 +313,7 @@ TransposeInst * IRBuilder::createTransposeInst(Value *dest, Value *src, llvm::ArrayRef shuffle) { auto *A = new TransposeInst(dest, src, shuffle); - M_.pushInstr(A); + M_->pushInstr(A); return A; } @@ -321,7 +321,7 @@ ConcatInst *IRBuilder::createConcatInst(Value *dest, llvm::ArrayRef src, size_t dim) { auto *A = new ConcatInst(dest, src, dim); - M_.pushInstr(A); + M_->pushInstr(A); return A; } @@ -330,7 +330,7 @@ BatchNormalizationInst *IRBuilder::createBatchNormalizationInst( size_t channelIdx, float epsilon, float momentum) { auto *A = new BatchNormalizationInst(dest, src, scale, bias, mean, var, channelIdx, epsilon, momentum); - M_.pushInstr(A); + M_->pushInstr(A); return A; } @@ -339,7 +339,7 @@ LocalResponseNormalizationInst *IRBuilder::createLocalResponseNormalizationInst( float beta, float k) { auto *A = new LocalResponseNormalizationInst(dest, src, scale, halfWindowSize, alpha, beta, k); - M_.pushInstr(A); + M_->pushInstr(A); return A; } @@ -347,7 +347,7 @@ ArithmeticInst *IRBuilder::createArithmeticInst(Value *dest, Value *LHS, Value *RHS, ArithmeticInst::OpKind kind) { auto *A = new ArithmeticInst(dest, LHS, RHS, kind); - M_.pushInstr(A); + M_->pushInstr(A); return A; } @@ -355,14 +355,14 @@ WeightVar *IRBuilder::createWeightVar(ElemKind elemTy, llvm::ArrayRef dims, 
llvm::StringRef name, InitKind initKind, float val) { - auto T = M_.getGraph().uniqueType(elemTy, dims); + auto T = M_->getGraph()->uniqueType(elemTy, dims); return createWeightVar(T, name, initKind, val); } WeightVar *IRBuilder::createWeightVar(TypeRef T, llvm::StringRef name, InitKind initKind, float val) { auto *A = new WeightVar(T, initKind, val); - M_.getWeights().push_back(A); + M_->getWeights().push_back(A); A->setName(name); return A; } @@ -370,19 +370,19 @@ WeightVar *IRBuilder::createWeightVar(TypeRef T, llvm::StringRef name, AllocActivationInst * IRBuilder::createAllocActivationInst(TypeRef T, llvm::StringRef name) { auto *A = new AllocActivationInst(T); - M_.pushInstr(A); + M_->pushInstr(A); // Add this instruction to the list of open allocations. activeAllocs_.push_back(A); return A; } AllocActivationInst *IRBuilder::createAllocActivationInst( ElemKind elemTy, llvm::ArrayRef dims, llvm::StringRef name) { - auto T = M_.getGraph().uniqueType(elemTy, dims); + auto T = M_->getGraph()->uniqueType(elemTy, dims); return createAllocActivationInst(T, name); } DeallocActivationInst *IRBuilder::createDeallocActivationInst(Value *src) { auto *A = new DeallocActivationInst(src); - M_.pushInstr(A); + M_->pushInstr(A); return A; } diff --git a/src/glow/IR/IRGen.cpp b/src/glow/IR/IRGen.cpp index 4b0be7e374..12604c0fcb 100644 --- a/src/glow/IR/IRGen.cpp +++ b/src/glow/IR/IRGen.cpp @@ -23,7 +23,7 @@ struct IRGenVisitor : NodeVisitor { /// Holds the mapping between graph nodes to IR variables. NodeToInstrTy generatedNodes; /// The module that we are building. - Module &M_; + Module *M_; /// The builder that adds instructions into the module. IRBuilder builder_; @@ -33,7 +33,7 @@ struct IRGenVisitor : NodeVisitor { return !generatedNodes.count(N); } - explicit IRGenVisitor(Module &M) : M_(M), builder_(M_) {} + explicit IRGenVisitor(Module *M) : M_(M), builder_(M_) {} /// \returns the generated instruction for the node \p N. Value *valueForNode(Node *N) { @@ -48,7 +48,7 @@ struct IRGenVisitor : NodeVisitor { "Value operand must be a memory location"); generatedNodes[N] = v; // Register the fact that we've lowered this variable to the new weight. - auto &map = M_.getVariableMap(); + auto &map = M_->getVariableMap(); map[N] = v; } @@ -215,13 +215,13 @@ struct IRGenVisitor : NodeVisitor { } // namespace void Module::generateIR() { - IRGenVisitor irgen(*this); + IRGenVisitor irgen(this); - for (auto &N : G_.getVars()) { + for (auto &N : G_->getVars()) { N->visit(nullptr, &irgen); } - for (auto &N : G_.getNodes()) { + for (auto &N : G_->getNodes()) { N->visit(nullptr, &irgen); } } diff --git a/src/glow/Interpreter/Interpreter.cpp b/src/glow/Interpreter/Interpreter.cpp index 6395cc5fb6..2ddb14b493 100644 --- a/src/glow/Interpreter/Interpreter.cpp +++ b/src/glow/Interpreter/Interpreter.cpp @@ -83,7 +83,7 @@ void Interpreter::doForwardPass(bool isTrain) { break; \ } // Dispatch the interpreter on each instruction in the program: - for (auto *I : M_.getInstrs()) { + for (auto *I : M_->getInstrs()) { switch (I->getKind()) { #include "glow/IR/Instrs.def" default: @@ -103,7 +103,7 @@ void Interpreter::doBackwardPass() { } // Dispatch the interpreter on each instruction in the program, in reverse // order. 
- auto &L = M_.getInstrs(); + auto &L = M_->getInstrs(); for (auto it = L.rbegin(), e = L.rend(); it != e; it++) { switch ((*it)->getKind()) { #include "glow/IR/Instrs.def" diff --git a/tests/unittests/GraphTest.cpp b/tests/unittests/GraphTest.cpp index b825af9ceb..16b2ecaceb 100644 --- a/tests/unittests/GraphTest.cpp +++ b/tests/unittests/GraphTest.cpp @@ -20,7 +20,7 @@ TEST(Graph, simpleTest) { { Graph G; - Module M(G); + Module M(&G); Node *K = G.createVariable(ElemKind::FloatTy, {4, 320, 200, 3}, "input"); Node *S = G.createVariable(ElemKind::IndexTy, {4, 1}, "select"); @@ -36,7 +36,7 @@ TEST(Graph, simpleTest) { { unsigned numInputs = 10; Graph G; - Module M(G); + Module M(&G); auto *A = G.createVariable(ElemKind::FloatTy, {numInputs, 2}, "A"); auto *Ex = G.createVariable(ElemKind::FloatTy, {numInputs, 1}, "Ex"); diff --git a/tests/unittests/IRTest.cpp b/tests/unittests/IRTest.cpp index df60dc9553..3c4b96159d 100644 --- a/tests/unittests/IRTest.cpp +++ b/tests/unittests/IRTest.cpp @@ -21,7 +21,7 @@ using namespace glow; TEST(IR, uniqueTypes) { Graph G; - Module M(G); + Module M(&G); Type T1(ElemKind::FloatTy, {320, 200}); Type T2(ElemKind::FloatTy, {320, 200}); Type T3(ElemKind::FloatTy, {1, 2}); @@ -40,9 +40,9 @@ TEST(IR, uniqueTypes) { TEST(IR, basicUseList) { Graph G; - Module M(G); + Module M(&G); { - IRBuilder builder(M); + IRBuilder builder(&M); auto *V1 = builder.createWeightVar(ElemKind::FloatTy, {320, 200}); auto *V2 = builder.createWeightVar(ElemKind::FloatTy, {320, 200}); @@ -66,14 +66,14 @@ TEST(IR, allInstrs) { using InitKind = WeightVar::InitKind; Graph G; - Module M(G); + Module M(&G); auto T1 = G.uniqueType(ElemKind::FloatTy, {1, 24, 24, 3}); auto T2 = G.uniqueType(ElemKind::FloatTy, {64}); auto T4 = G.uniqueType(ElemKind::IndexTy, {1, 1}); auto T5 = G.uniqueType(ElemKind::FloatTy, {3}); { - IRBuilder builder(M); + IRBuilder builder(&M); auto *I0 = builder.createWeightVar(T1, "I0", InitKind::Extern, 0); auto *I1 = builder.createWeightVar(T1, "I1", InitKind::Extern, 0); @@ -120,9 +120,9 @@ TEST(IR, allInstrs) { TEST(IR, highLevelBuilder) { Graph G; - Module M(G); + Module M(&G); { - IRBuilder bb(M); + IRBuilder bb(&M); auto *input = bb.createWeightVar(ElemKind::FloatTy, {1, 224, 224, 3}); auto *conv = bb.createConvOp(input, 16, 7, 2, 3); @@ -154,9 +154,9 @@ TEST(IR, highLevelBuilder) { TEST(IR, casting) { Graph G; - Module M(G); + Module M(&G); { - IRBuilder bb(M); + IRBuilder bb(&M); auto *input = bb.createWeightVar(ElemKind::FloatTy, {1, 224, 224, 3}); auto *conv = bb.createConvOp(input, 16, 7, 2, 3); diff --git a/tests/unittests/InterpreterTest.cpp b/tests/unittests/InterpreterTest.cpp index b559569ffe..c2db3dee4c 100644 --- a/tests/unittests/InterpreterTest.cpp +++ b/tests/unittests/InterpreterTest.cpp @@ -1,8 +1,7 @@ // Copyright 2017 Facebook Inc. All Rights Reserved. 
 #include "glow/ExecutionEngine/ExecutionEngine.h"
-#include "glow/Graph/Node.h"
-#include "glow/Graph/Nodes.h"
+#include "glow/Graph/Graph.h"
 #include "glow/IR/IR.h"
 #include "glow/IR/IRBuilder.h"
 #include "glow/IR/Instrs.h"
diff --git a/tests/unittests/Tensors.cpp b/tests/unittests/Tensors.cpp
index be023009b7..5dc2c84302 100644
--- a/tests/unittests/Tensors.cpp
+++ b/tests/unittests/Tensors.cpp
@@ -4,13 +4,6 @@

 #include "gtest/gtest.h"

-#include
-#include
-#include
-#include
-#include
-#include
-
 using namespace glow;

 TEST(Tensor, init) {

From d706046d3f087f37c35a1061476ec8b5abd791ba Mon Sep 17 00:00:00 2001
From: Nadav Rotem
Date: Fri, 13 Oct 2017 10:03:46 -0700
Subject: [PATCH 3/5] Tidy and clean code base. Add explicit, remove calls to
 constructors without parameters, etc.

---
 include/glow/Base/Tensor.h             | 10 +++++-----
 include/glow/Base/Traits.h             |  8 ++++----
 include/glow/Graph/Node.h              |  2 +-
 include/glow/IR/IR.h                   |  4 ++--
 include/glow/IR/Instrs.h               |  4 ++--
 include/glow/Interpreter/Interpreter.h |  2 +-
 src/glow/Interpreter/Interpreter.cpp   |  2 +-
 7 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/include/glow/Base/Tensor.h b/include/glow/Base/Tensor.h
index 5619e9a99a..6f33663e2c 100644
--- a/include/glow/Base/Tensor.h
+++ b/include/glow/Base/Tensor.h
@@ -112,9 +112,9 @@ class Tensor final {
   Tensor() = default;

   /// Initialize from a list of float literals.
-  Tensor(const std::initializer_list &vec) : type_{} {
+  Tensor(const std::initializer_list &vec) {
     reset(ElemKind::FloatTy, {vec.size()});
-    FloatTy *data = getRawDataPointer();
+    auto *data = getRawDataPointer();
     int i = 0;
     for (auto &f : vec) {
       data[i++] = f;
@@ -122,10 +122,10 @@
   }

   /// Allocate and initialize a new tensor.
-  Tensor(TypeRef ty) : data_(nullptr), type_(*ty) { reset(*ty); }
+  explicit Tensor(TypeRef ty) : data_(nullptr), type_(*ty) { reset(*ty); }

   /// Allocate and initialize a new tensor.
-  Tensor(const Type &ty) : data_(nullptr), type_(ty) { reset(ty); }
+  explicit Tensor(const Type &ty) : data_(nullptr), type_(ty) { reset(ty); }

   /// Allocate and initialize a new tensor.
Tensor(ElemKind elemTy, llvm::ArrayRef dims) @@ -320,7 +320,7 @@ template class Handle final { } void clear(ElemTy value = 0) { - ElemTy *data = tensor_->getRawDataPointer(); + auto *data = tensor_->getRawDataPointer(); std::fill(&data[0], &data[0] + size(), value); } diff --git a/include/glow/Base/Traits.h b/include/glow/Base/Traits.h index f28e718812..2edf4f4d88 100644 --- a/include/glow/Base/Traits.h +++ b/include/glow/Base/Traits.h @@ -1,5 +1,5 @@ -#ifndef GLOW_IR_TRAITS_H -#define GLOW_IR_TRAITS_H +#ifndef GLOW_BASE_TRAITS_H +#define GLOW_BASE_TRAITS_H #include "glow/Base/Type.h" @@ -31,7 +31,7 @@ class Typed { TypeRef Ty_{}; public: - Typed(TypeRef Ty) : Ty_(Ty){}; + explicit Typed(TypeRef Ty) : Ty_(Ty){}; TypeRef getType() const { return Ty_; } @@ -78,4 +78,4 @@ class Kinded { } // namespace glow -#endif // GLOW_IR_TRAITS_H +#endif // GLOW_BASE_TRAITS_H diff --git a/include/glow/Graph/Node.h b/include/glow/Graph/Node.h index 735bee8309..6ff6af07dc 100644 --- a/include/glow/Graph/Node.h +++ b/include/glow/Graph/Node.h @@ -3,8 +3,8 @@ #include "llvm/ADT/StringRef.h" -#include "glow/Base/Type.h" #include "glow/Base/Traits.h" +#include "glow/Base/Type.h" namespace glow { diff --git a/include/glow/IR/IR.h b/include/glow/IR/IR.h index 4c8c14f3ab..347c7e1125 100644 --- a/include/glow/IR/IR.h +++ b/include/glow/IR/IR.h @@ -1,8 +1,8 @@ #ifndef GLOW_IR_IR_H #define GLOW_IR_IR_H -#include "glow/Base/Type.h" #include "glow/Base/Traits.h" +#include "glow/Base/Type.h" #include "glow/IR/UseDef.h" #include "llvm/ADT/ArrayRef.h" @@ -33,7 +33,7 @@ class Value : public Named, public Typed, public Kinded { public: - Value(TypeRef T, Kinded::Kind k) : Named(), UseDef(), Typed(T), Kinded(k) {} + Value(TypeRef T, Kinded::Kind k) : Typed(T), Kinded(k) {} }; /// This represents an instruction in our IR. diff --git a/include/glow/IR/Instrs.h b/include/glow/IR/Instrs.h index f85357f126..f0425de953 100644 --- a/include/glow/IR/Instrs.h +++ b/include/glow/IR/Instrs.h @@ -12,7 +12,7 @@ namespace glow { class AllocActivationInst : public Instruction { public: - AllocActivationInst(TypeRef Ty) + explicit AllocActivationInst(TypeRef Ty) : Instruction(Kinded::Kind::AllocActivationInstKind, Ty) {} static bool classof(const Kinded *k) { @@ -25,7 +25,7 @@ class AllocActivationInst : public Instruction { class DeallocActivationInst : public Instruction { public: - DeallocActivationInst(Value *src) + explicit DeallocActivationInst(Value *src) : Instruction(Kinded::Kind::DeallocActivationInstKind, src->getType(), {{src, OperandKind::Out}}) {} diff --git a/include/glow/Interpreter/Interpreter.h b/include/glow/Interpreter/Interpreter.h index 6d645c2564..d9bc615673 100644 --- a/include/glow/Interpreter/Interpreter.h +++ b/include/glow/Interpreter/Interpreter.h @@ -34,7 +34,7 @@ class Interpreter final { public: /// Ctor. - Interpreter(Module *M) : M_(M) {} + explicit Interpreter(Module *M) : M_(M) {} /// Dtor. ~Interpreter(); diff --git a/src/glow/Interpreter/Interpreter.cpp b/src/glow/Interpreter/Interpreter.cpp index 2ddb14b493..3ab489dffe 100644 --- a/src/glow/Interpreter/Interpreter.cpp +++ b/src/glow/Interpreter/Interpreter.cpp @@ -64,7 +64,7 @@ Tensor *Interpreter::getOrCreateTensor(const Value *v) { // Pick the tensor. 
auto it = tensors_.find(v); if (it == tensors_.end()) { - Tensor *T = new Tensor(v->getType()); + auto *T = new Tensor(v->getType()); tensors_[v] = T; return T; } From 3b928c896292a03a62ac9287c9fb57fd2955b9c5 Mon Sep 17 00:00:00 2001 From: Nadav Rotem Date: Fri, 13 Oct 2017 11:08:31 -0700 Subject: [PATCH 4/5] Remove the Context parameter from the interpreter. --- include/glow/Interpreter/Interpreter.h | 18 ++-- src/glow/Interpreter/Interpreter.cpp | 4 +- src/glow/Interpreter/InterpreterNodes.cpp | 107 ++++++++++------------ 3 files changed, 57 insertions(+), 72 deletions(-) diff --git a/include/glow/Interpreter/Interpreter.h b/include/glow/Interpreter/Interpreter.h index d9bc615673..9cdfa3f95e 100644 --- a/include/glow/Interpreter/Interpreter.h +++ b/include/glow/Interpreter/Interpreter.h @@ -80,19 +80,17 @@ class Interpreter final { #define DEF_VALUE(CLASS, NAME) #define DEF_NODE(CLASS, NAME) #define DEF_INSTR(CLASS, NAME) \ - void fwd##CLASS(Context *ctx, bool isTrain, const CLASS *I); \ - void bwd##CLASS(Context *ctx, const CLASS *I); + void fwd##CLASS(bool isTrain, const CLASS *I); \ + void bwd##CLASS(const CLASS *I); #include "glow/IR/Instrs.def" - void fwdPoolMax_impl(Context *ctx, const PoolInst *I); - void fwdPoolAvg_impl(Context *ctx, const PoolInst *I); - void bwdPoolMax_impl(Context *ctx, const PoolInst *I); - void bwdPoolAvg_impl(Context *ctx, const PoolInst *I); + void fwdPoolMax_impl(const PoolInst *I); + void fwdPoolAvg_impl(const PoolInst *I); + void bwdPoolMax_impl(const PoolInst *I); + void bwdPoolAvg_impl(const PoolInst *I); - void fwdBatchNormalizationInst_infer(Context *ctx, - const BatchNormalizationInst *I); - void fwdBatchNormalizationInst_train(Context *ctx, - const BatchNormalizationInst *I); + void fwdBatchNormalizationInst_infer(const BatchNormalizationInst *I); + void fwdBatchNormalizationInst_train(const BatchNormalizationInst *I); ///@} }; diff --git a/src/glow/Interpreter/Interpreter.cpp b/src/glow/Interpreter/Interpreter.cpp index 3ab489dffe..fd7539c80d 100644 --- a/src/glow/Interpreter/Interpreter.cpp +++ b/src/glow/Interpreter/Interpreter.cpp @@ -79,7 +79,7 @@ void Interpreter::doForwardPass(bool isTrain) { #define DEF_NODE(CLASS, NAME) #define DEF_INSTR(CLASS, NAME) \ case Kinded::Kind::CLASS##Kind: { \ - fwd##CLASS(nullptr, isTrain, cast(I)); \ + fwd##CLASS(isTrain, cast(I)); \ break; \ } // Dispatch the interpreter on each instruction in the program: @@ -98,7 +98,7 @@ void Interpreter::doBackwardPass() { #define DEF_NODE(CLASS, NAME) #define DEF_INSTR(CLASS, NAME) \ case Kinded::Kind::CLASS##Kind: { \ - bwd##CLASS(nullptr, cast(*it)); \ + bwd##CLASS(cast(*it)); \ break; \ } // Dispatch the interpreter on each instruction in the program, in reverse diff --git a/src/glow/Interpreter/InterpreterNodes.cpp b/src/glow/Interpreter/InterpreterNodes.cpp index fa5270a13f..7dd12da8c8 100644 --- a/src/glow/Interpreter/InterpreterNodes.cpp +++ b/src/glow/Interpreter/InterpreterNodes.cpp @@ -10,7 +10,7 @@ using namespace glow; // Convolution //===----------------------------------------------------------------------===// -void Interpreter::fwdCopyInst(Context *ctx, bool isTrain, const CopyInst *I) { +void Interpreter::fwdCopyInst(bool isTrain, const CopyInst *I) { auto S = getWeightHandle(I->getSrc()); auto D = getWeightHandle(I->getDest()); @@ -19,7 +19,7 @@ void Interpreter::fwdCopyInst(Context *ctx, bool isTrain, const CopyInst *I) { } } -void Interpreter::bwdCopyInst(Context *ctx, const CopyInst *I) { +void Interpreter::bwdCopyInst(const CopyInst *I) { auto inG 
= getGradHandle(I->getSrc()); auto outG = getGradHandle(I->getDest()); @@ -28,8 +28,7 @@ void Interpreter::bwdCopyInst(Context *ctx, const CopyInst *I) { } } -void Interpreter::fwdConvolutionInst(Context *ctx, bool isTrain, - const ConvolutionInst *I) { +void Interpreter::fwdConvolutionInst(bool isTrain, const ConvolutionInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); auto filterW = getWeightHandle(I->getFilter()); @@ -82,7 +81,7 @@ void Interpreter::fwdConvolutionInst(Context *ctx, bool isTrain, } // N } -void Interpreter::bwdConvolutionInst(Context *ctx, const ConvolutionInst *I) { +void Interpreter::bwdConvolutionInst(const ConvolutionInst *I) { auto inW = getWeightHandle(I->getSrc()); auto inG = getGradHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -144,15 +143,15 @@ void Interpreter::bwdConvolutionInst(Context *ctx, const ConvolutionInst *I) { // Pooling //===----------------------------------------------------------------------===// -void Interpreter::fwdPoolInst(Context *ctx, bool isTrain, const PoolInst *I) { +void Interpreter::fwdPoolInst(bool isTrain, const PoolInst *I) { if (I->getKind() == PoolInst::OpKind::Max) { - return fwdPoolMax_impl(ctx, I); + return fwdPoolMax_impl(I); } - return fwdPoolAvg_impl(ctx, I); + return fwdPoolAvg_impl(I); } -void Interpreter::fwdPoolMax_impl(Context *ctx, const PoolInst *I) { +void Interpreter::fwdPoolMax_impl(const PoolInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -213,7 +212,7 @@ void Interpreter::fwdPoolMax_impl(Context *ctx, const PoolInst *I) { } // N } -void Interpreter::fwdPoolAvg_impl(Context *ctx, const PoolInst *I) { +void Interpreter::fwdPoolAvg_impl(const PoolInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -262,15 +261,15 @@ void Interpreter::fwdPoolAvg_impl(Context *ctx, const PoolInst *I) { } // N } -void Interpreter::bwdPoolInst(Context *ctx, const PoolInst *I) { +void Interpreter::bwdPoolInst(const PoolInst *I) { if (I->getKind() == PoolInst::OpKind::Max) { - return bwdPoolMax_impl(ctx, I); + return bwdPoolMax_impl(I); } - return bwdPoolAvg_impl(ctx, I); + return bwdPoolAvg_impl(I); } -void Interpreter::bwdPoolMax_impl(Context *ctx, const PoolInst *I) { +void Interpreter::bwdPoolMax_impl(const PoolInst *I) { auto inG = getGradHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); auto outG = getGradHandle(I->getDest()); @@ -301,7 +300,7 @@ void Interpreter::bwdPoolMax_impl(Context *ctx, const PoolInst *I) { } // N } -void Interpreter::bwdPoolAvg_impl(Context *ctx, const PoolInst *I) { +void Interpreter::bwdPoolAvg_impl(const PoolInst *I) { auto inW = getWeightHandle(I->getSrc()); auto inG = getGradHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -351,7 +350,7 @@ void Interpreter::bwdPoolAvg_impl(Context *ctx, const PoolInst *I) { // Fully Connected //===----------------------------------------------------------------------===// -void Interpreter::fwdFullyConnectedInst(Context *ctx, const bool isTrain, +void Interpreter::fwdFullyConnectedInst(const bool isTrain, const FullyConnectedInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -380,8 +379,7 @@ void Interpreter::fwdFullyConnectedInst(Context *ctx, const bool isTrain, } // N } -void Interpreter::bwdFullyConnectedInst(Context *ctx, - const FullyConnectedInst *I) { +void Interpreter::bwdFullyConnectedInst(const FullyConnectedInst *I) { 
auto inW = getWeightHandle(I->getSrc()); auto inG = getGradHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -419,7 +417,7 @@ void Interpreter::bwdFullyConnectedInst(Context *ctx, // Activation functions //===----------------------------------------------------------------------===// -void Interpreter::fwdReluInst(Context *ctx, bool isTrain, const ReluInst *I) { +void Interpreter::fwdReluInst(bool isTrain, const ReluInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -429,7 +427,7 @@ void Interpreter::fwdReluInst(Context *ctx, bool isTrain, const ReluInst *I) { } } -void Interpreter::bwdReluInst(Context *ctx, const ReluInst *I) { +void Interpreter::bwdReluInst(const ReluInst *I) { auto inG = getGradHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); auto outG = getGradHandle(I->getDest()); @@ -440,8 +438,7 @@ void Interpreter::bwdReluInst(Context *ctx, const ReluInst *I) { } } -void Interpreter::fwdSigmoidInst(Context *ctx, bool isTrain, - const SigmoidInst *I) { +void Interpreter::fwdSigmoidInst(bool isTrain, const SigmoidInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -450,7 +447,7 @@ void Interpreter::fwdSigmoidInst(Context *ctx, bool isTrain, outW.raw(i) = 1 / (1 + std::exp(-val)); } } -void Interpreter::bwdSigmoidInst(Context *ctx, const SigmoidInst *I) { +void Interpreter::bwdSigmoidInst(const SigmoidInst *I) { auto inG = getGradHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); auto outG = getGradHandle(I->getDest()); @@ -461,7 +458,7 @@ void Interpreter::bwdSigmoidInst(Context *ctx, const SigmoidInst *I) { } } -void Interpreter::fwdTanhInst(Context *ctx, bool isTrain, const TanhInst *I) { +void Interpreter::fwdTanhInst(bool isTrain, const TanhInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -472,7 +469,7 @@ void Interpreter::fwdTanhInst(Context *ctx, bool isTrain, const TanhInst *I) { outW.raw(i) = (exp_val - exp_neg_val) / (exp_val + exp_neg_val); } } -void Interpreter::bwdTanhInst(Context *ctx, const TanhInst *I) { +void Interpreter::bwdTanhInst(const TanhInst *I) { auto inG = getGradHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); auto outG = getGradHandle(I->getDest()); @@ -487,8 +484,7 @@ void Interpreter::bwdTanhInst(Context *ctx, const TanhInst *I) { // Loss Functions (Softmax/regression/...) 
//===----------------------------------------------------------------------===// -void Interpreter::fwdSoftMaxInst(Context *ctx, bool isTrain, - const SoftMaxInst *I) { +void Interpreter::fwdSoftMaxInst(bool isTrain, const SoftMaxInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); auto idim = inW.dims(); @@ -520,7 +516,7 @@ void Interpreter::fwdSoftMaxInst(Context *ctx, bool isTrain, } // N } -void Interpreter::bwdSoftMaxInst(Context *ctx, const SoftMaxInst *I) { +void Interpreter::bwdSoftMaxInst(const SoftMaxInst *I) { auto inG = getGradHandle(I->getSrc()); auto idim = inG.dims(); @@ -538,8 +534,7 @@ void Interpreter::bwdSoftMaxInst(Context *ctx, const SoftMaxInst *I) { } } -void Interpreter::fwdRegressionInst(Context *ctx, bool isTrain, - const RegressionInst *I) { +void Interpreter::fwdRegressionInst(bool isTrain, const RegressionInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -548,7 +543,7 @@ void Interpreter::fwdRegressionInst(Context *ctx, bool isTrain, } } -void Interpreter::bwdRegressionInst(Context *ctx, const RegressionInst *I) { +void Interpreter::bwdRegressionInst(const RegressionInst *I) { auto inW = getWeightHandle(I->getSrc()); auto inG = getGradHandle(I->getSrc()); auto expected = getTensor(I->getExpected()); @@ -572,8 +567,7 @@ void Interpreter::bwdRegressionInst(Context *ctx, const RegressionInst *I) { // Tensor shape (transpose/reshape/concat/...) //===----------------------------------------------------------------------===// -void Interpreter::fwdTransposeInst(Context *ctx, bool isTrain, - const TransposeInst *I) { +void Interpreter::fwdTransposeInst(bool isTrain, const TransposeInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getTensor(I->getDest()); @@ -581,7 +575,7 @@ void Interpreter::fwdTransposeInst(Context *ctx, bool isTrain, inW.transpose(outW, I->getShuffle()); } -void Interpreter::bwdTransposeInst(Context *ctx, const TransposeInst *I) { +void Interpreter::bwdTransposeInst(const TransposeInst *I) { auto inG = getOrCreateGradTensor(I->getSrc()); auto outG = getGradHandle(I->getDest()); @@ -600,15 +594,14 @@ void Interpreter::bwdTransposeInst(Context *ctx, const TransposeInst *I) { outG.transpose(inG, reverseShuffle); } -void Interpreter::fwdReshapeInst(Context *ctx, bool isTrain, - const ReshapeInst *I) { +void Interpreter::fwdReshapeInst(bool isTrain, const ReshapeInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); for (size_t i = 0, e = inW.size(); i < e; i++) { outW.raw(i) = inW.raw(i); } } -void Interpreter::bwdReshapeInst(Context *ctx, const ReshapeInst *I) { +void Interpreter::bwdReshapeInst(const ReshapeInst *I) { auto inG = getGradHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); auto outG = getGradHandle(I->getDest()); @@ -617,8 +610,7 @@ void Interpreter::bwdReshapeInst(Context *ctx, const ReshapeInst *I) { } } -void Interpreter::fwdConcatInst(Context *ctx, bool isTrain, - const ConcatInst *I) { +void Interpreter::fwdConcatInst(bool isTrain, const ConcatInst *I) { auto outW = getWeightHandle(I->getDest()); // Insert the tensors at this coordinate. Start at zero. @@ -635,7 +627,7 @@ void Interpreter::fwdConcatInst(Context *ctx, bool isTrain, offset[dim] += inW.dims()[dim]; } } -void Interpreter::bwdConcatInst(Context *ctx, const ConcatInst *I) { +void Interpreter::bwdConcatInst(const ConcatInst *I) { auto outG = getGradHandle(I->getDest()); // Insert the tensors at this coordinate. 
Start at zero. @@ -661,17 +653,17 @@ void Interpreter::bwdConcatInst(Context *ctx, const ConcatInst *I) { // Batch Normalization //===----------------------------------------------------------------------===// -void Interpreter::fwdBatchNormalizationInst(Context *ctx, bool isTrain, +void Interpreter::fwdBatchNormalizationInst(bool isTrain, const BatchNormalizationInst *I) { if (isTrain) { - return fwdBatchNormalizationInst_train(ctx, I); + return fwdBatchNormalizationInst_train(I); } - return fwdBatchNormalizationInst_infer(ctx, I); + return fwdBatchNormalizationInst_infer(I); } void Interpreter::fwdBatchNormalizationInst_infer( - Context *ctx, const BatchNormalizationInst *I) { + const BatchNormalizationInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -709,7 +701,7 @@ void Interpreter::fwdBatchNormalizationInst_infer( } void Interpreter::fwdBatchNormalizationInst_train( - Context *ctx, const BatchNormalizationInst *I) { + const BatchNormalizationInst *I) { auto inW = getWeightHandle(I->getSrc()); auto varH = getWeightHandle(I->getVar()); auto meanH = getWeightHandle(I->getMean()); @@ -761,11 +753,10 @@ void Interpreter::fwdBatchNormalizationInst_train( } // TODO: should we be using the running mean or the local mean? - fwdBatchNormalizationInst_infer(ctx, I); + fwdBatchNormalizationInst_infer(I); } -void Interpreter::bwdBatchNormalizationInst(Context *ctx, - const BatchNormalizationInst *I) { +void Interpreter::bwdBatchNormalizationInst(const BatchNormalizationInst *I) { auto inW = getWeightHandle(I->getSrc()); auto inG = getGradHandle(I->getSrc()); auto outG = getGradHandle(I->getDest()); @@ -848,8 +839,7 @@ void Interpreter::bwdBatchNormalizationInst(Context *ctx, } void Interpreter::fwdLocalResponseNormalizationInst( - glow::Context *ctx, bool isTrain, - const glow::LocalResponseNormalizationInst *I) { + bool isTrain, const glow::LocalResponseNormalizationInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); auto scaleCache = getWeightHandle(I->getScale()); @@ -914,7 +904,7 @@ void Interpreter::fwdLocalResponseNormalizationInst( } void Interpreter::bwdLocalResponseNormalizationInst( - glow::Context *ctx, const glow::LocalResponseNormalizationInst *I) { + const glow::LocalResponseNormalizationInst *I) { auto inW = getWeightHandle(I->getSrc()); auto inG = getGradHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -987,8 +977,7 @@ void Interpreter::bwdLocalResponseNormalizationInst( // Arithmetic operations //===----------------------------------------------------------------------===// -void Interpreter::fwdArithmeticInst(Context *ctx, bool isTrain, - const ArithmeticInst *I) { +void Interpreter::fwdArithmeticInst(bool isTrain, const ArithmeticInst *I) { auto outW = getWeightHandle(I->getDest()); auto LHSW = getWeightHandle(I->getLHS()); auto RHSW = getWeightHandle(I->getRHS()); @@ -1010,7 +999,7 @@ void Interpreter::fwdArithmeticInst(Context *ctx, bool isTrain, } } -void Interpreter::bwdArithmeticInst(Context *ctx, const ArithmeticInst *I) { +void Interpreter::bwdArithmeticInst(const ArithmeticInst *I) { auto LHSW = getWeightHandle(I->getLHS()); auto RHSW = getWeightHandle(I->getRHS()); auto outG = getGradHandle(I->getDest()); @@ -1040,7 +1029,7 @@ void Interpreter::bwdArithmeticInst(Context *ctx, const ArithmeticInst *I) { // Tensor allocation operations //===----------------------------------------------------------------------===// -void Interpreter::fwdAllocActivationInst(Context 
*ctx, bool isTrain, +void Interpreter::fwdAllocActivationInst(bool isTrain, const AllocActivationInst *I) { getOrCreateTensor(I); // Prepare for the next backprop iteration by zeroing the gradient @@ -1051,11 +1040,9 @@ void Interpreter::fwdAllocActivationInst(Context *ctx, bool isTrain, getOrCreateGradTensor(I)->zero(); } -void Interpreter::bwdAllocActivationInst(Context *ctx, - const AllocActivationInst *I) {} +void Interpreter::bwdAllocActivationInst(const AllocActivationInst *I) {} -void Interpreter::fwdDeallocActivationInst(Context *ctx, bool isTrain, +void Interpreter::fwdDeallocActivationInst(bool isTrain, const DeallocActivationInst *I) {} -void Interpreter::bwdDeallocActivationInst(Context *ctx, - const DeallocActivationInst *I) {} +void Interpreter::bwdDeallocActivationInst(const DeallocActivationInst *I) {} From 8d0a6ea7ad760bcbdca028c2c901594b756bbb4a Mon Sep 17 00:00:00 2001 From: Nadav Rotem Date: Fri, 13 Oct 2017 11:12:32 -0700 Subject: [PATCH 5/5] Cleanup the code for some of the test binaries. Remove unused headers. NFC. --- examples/cifar10.cpp | 12 +----------- examples/mnist.cpp | 4 ---- tools/loader/loader.cpp | 5 ++--- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/examples/cifar10.cpp b/examples/cifar10.cpp index 5086a3c730..2235802045 100644 --- a/examples/cifar10.cpp +++ b/examples/cifar10.cpp @@ -1,22 +1,12 @@ #include "glow/ExecutionEngine/ExecutionEngine.h" #include "glow/Graph/Graph.h" -#include "glow/Graph/Nodes.h" -#include "glow/IR/IR.h" -#include "glow/IR/IRBuilder.h" -#include "glow/IR/Instrs.h" #include "glow/Support/Support.h" #include "llvm/Support/Timer.h" #include -#include -#include #include #include -#include -#include -#include -#include using namespace glow; @@ -100,7 +90,7 @@ void testCIFAR10() { auto *SM = G.createSoftMax("softmax", RL3, E); auto *result = G.createReturn("ret", SM); - EE.getModule().generateIR(); + EE.generateIR(); EE.optimize(OptimizationMode::Train); EE.initVars(); diff --git a/examples/mnist.cpp b/examples/mnist.cpp index 34888c25ef..8eb37e5763 100644 --- a/examples/mnist.cpp +++ b/examples/mnist.cpp @@ -5,12 +5,8 @@ #include "llvm/Support/Timer.h" #include -#include -#include #include #include -#include -#include using namespace glow; diff --git a/tools/loader/loader.cpp b/tools/loader/loader.cpp index d7c6947289..f4335226d5 100644 --- a/tools/loader/loader.cpp +++ b/tools/loader/loader.cpp @@ -3,10 +3,9 @@ #include "glow/Base/Image.h" #include "glow/Base/Tensor.h" #include "glow/ExecutionEngine/ExecutionEngine.h" -#include "glow/Graph/Graph.h" -#include "glow/Graph/Nodes.h" #include "glow/Importer/Caffe2.h" -#include "glow/Interpreter/Interpreter.h" + +#include using namespace glow;