diff --git a/include/glow/ExecutionEngine/ExecutionEngine.h b/include/glow/ExecutionEngine/ExecutionEngine.h
index 55a64e5941..bae28e9dc4 100644
--- a/include/glow/ExecutionEngine/ExecutionEngine.h
+++ b/include/glow/ExecutionEngine/ExecutionEngine.h
@@ -89,11 +89,8 @@ class ExecutionEngine final {
 // Helper methods for running the execution engine.
 //===----------------------------------------------------------------------===//
 
-/// This method updates the variables in \p vars with the tensor content
-/// values \p inputs.
-void updateVariables(llvm::ArrayRef<Variable *> vars,
-                     llvm::ArrayRef<Tensor *> inputs);
-
+/// This method updates the placeholders in \p ph with the tensor content
+/// values \p inputs, in \p ctx.
 void updateVariables(Context &ctx, llvm::ArrayRef<Placeholder *> ph,
                      llvm::ArrayRef<Tensor *> inputs);
 
@@ -104,14 +101,6 @@ void updateInputsByName(Context &ctx, Module *mod,
                         llvm::ArrayRef<llvm::StringRef> ph,
                         llvm::ArrayRef<Tensor *> inputs);
 
-/// Update the content of the tensors \p vars with some slices that from \p
-/// inputs. The data starts at slice \p sampleIdx and wraps around until the
-/// data in \p v is filled. All dimensions, except for the first (batch)
-/// dimension must be identical.
-void updateVariablesFromBatch(llvm::ArrayRef<Variable *> vars,
-                              llvm::ArrayRef<Tensor *> inputs,
-                              size_t sampleIdx);
-
 /// Runs \p iterations iterations of the compiled function. The method updates a
 /// global counter and future invocations of this method continue running
 /// iterations of the batch at the next available slice.
@@ -124,9 +113,6 @@ void updateVariablesFromBatch(llvm::ArrayRef<Variable *> vars,
 /// variable records the number of samples that were consumed by the network in
 /// previous iterations. The next input to be loaded is
 /// (sampleCounter % batchsize).
-void runBatch(ExecutionEngine &EE, size_t iterations, size_t &sampleCounter,
-              llvm::ArrayRef<Variable *> vars, llvm::ArrayRef<Tensor *> inputs);
-
 void runBatch(ExecutionEngine &EE, Context &ctx, size_t iterations,
               size_t &sampleCounter, llvm::ArrayRef<Placeholder *> ph,
               llvm::ArrayRef<Tensor *> inputs);
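With the Variable-based overloads removed above, inputs are fed through a `Context` that owns the backing tensors of the graph's Placeholders. Below is a minimal sketch of driving the surviving overloads; the placeholder names, shapes, and iteration count are assumptions for illustration only (they are not part of this patch), and the function is assumed to have been compiled already.

```cpp
// Sketch only: assumes the usual Glow headers are included and that the
// module's function has already been compiled for a backend.
using namespace glow;

void feedInputsSketch() {
  ExecutionEngine EE;
  auto &mod = EE.getModule();

  // Inputs are Placeholders; their backing storage lives in a Context.
  Placeholder *in = mod.createPlaceholder(ElemKind::FloatTy, {8, 32}, "in",
                                          /* isTrainable */ false);
  Context ctx;
  ctx.allocate(in);

  // Copy one tensor straight into the placeholder's backing storage.
  Tensor sample(ElemKind::FloatTy, {8, 32});
  updateVariables(ctx, {in}, {&sample});

  // Or stream a data set: each iteration loads the next batch-sized slice
  // (sampleCounter % number of samples) and runs the compiled function once.
  Tensor data(ElemKind::FloatTy, {256, 32});
  size_t sampleCounter = 0;
  runBatch(EE, ctx, /* iterations */ 10, sampleCounter, {in}, {&data});
}
```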
diff --git a/include/glow/Graph/Graph.h b/include/glow/Graph/Graph.h
index 4ec0faaa5e..b43aced0e0 100644
--- a/include/glow/Graph/Graph.h
+++ b/include/glow/Graph/Graph.h
@@ -141,23 +141,15 @@ class Module final {
                                  float scale, int32_t offset,
                                  llvm::StringRef name, bool isTrainable);
 
-  Variable *createVariable(TypeRef T, llvm::StringRef name,
-                           VisibilityKind visibility = VisibilityKind::Private,
-                           bool isTrainable = true);
+  Variable *createVariable(TypeRef T, llvm::StringRef name);
 
   Variable *createVariable(ElemKind T, llvm::ArrayRef<size_t> dims,
-                           llvm::StringRef name,
-                           VisibilityKind visibility = VisibilityKind::Private,
-                           bool isTrainable = true);
+                           llvm::StringRef name);
 
   Variable *createVariable(ElemKind T, llvm::ArrayRef<size_t> dims, float scale,
-                           int32_t offset, llvm::StringRef name,
-                           VisibilityKind visibility = VisibilityKind::Private,
-                           bool isTrainable = true);
+                           int32_t offset, llvm::StringRef name);
 
-  Variable *createVariable(llvm::StringRef name, const Tensor &tensor,
-                           VisibilityKind visibility = VisibilityKind::Private,
-                           bool isTrainable = true);
+  Variable *createVariable(llvm::StringRef name, const Tensor &tensor);
 
   ///@}
 
diff --git a/include/glow/Graph/Nodes.h b/include/glow/Graph/Nodes.h
index d35b0a818e..db09749a40 100644
--- a/include/glow/Graph/Nodes.h
+++ b/include/glow/Graph/Nodes.h
@@ -31,12 +31,8 @@ namespace glow {
 // Storage is the base class for Variables, which are bound to tensors, and
 // Placeholder nodes which are unbound.
 class Storage : public Node {
-  /// Specifies if the variable or placeholder is trainable.
-  bool isTrainable_;
-
 public:
-  Storage(Kinded::Kind k, llvm::StringRef name, bool isTrainable)
-      : Node(k, name), isTrainable_(isTrainable) {}
+  Storage(Kinded::Kind k, llvm::StringRef name) : Node(k, name) {}
 
   /// \return the single output value of the node.
   NodeValue getOutput() { return getNthResult(0); }
@@ -54,10 +50,6 @@ class Storage : public Node {
   Node *clone() const;
   /// @}
 
-  /// \returns True if the Variable or placeholder are trainable during
-  /// differentiation.
-  bool isTraining() const { return isTrainable_; }
-
   /// \returns result type of the variable.
   TypeRef getType() const { return Node::getType(0); }
 
@@ -74,37 +66,27 @@ class Storage : public Node {
 };
 
 class Variable : public Storage {
-  /// Specifies the visibility of the variable.
-  VisibilityKind visibility_;
   /// The tensor payload that the variable holds.
   Tensor payload_;
 
 public:
   /// Create a new variable and initialize its payload.
-  Variable(llvm::StringRef name, TypeRef Ty, VisibilityKind visibility,
-           bool isTrainable)
-      : Storage(Kinded::Kind::VariableKind, name, isTrainable),
-        visibility_(visibility) {
+  Variable(llvm::StringRef name, TypeRef Ty)
+      : Storage(Kinded::Kind::VariableKind, name) {
     addResult(Ty);
     payload_.reset(*Ty);
   }
 
-  Variable(llvm::StringRef name, VisibilityKind visibility, Tensor &&payload)
-      : Storage(Kinded::Kind::VariableKind, name, false),
-        visibility_(visibility), payload_(std::move(payload)) {
+  Variable(llvm::StringRef name, Tensor &&payload)
+      : Storage(Kinded::Kind::VariableKind, name),
+        payload_(std::move(payload)) {
     addResult(&payload_.getType());
   }
 
-  /// \returns True if the Variable is private.
-  bool isPrivate() const { return visibility_ == VisibilityKind::Private; }
-
   static bool classof(const Kinded *k) {
     return k->getKind() == Kinded::Kind::VariableKind;
   }
 
-  /// \returns the visibility of the variable.
-  VisibilityKind getVisibilityKind() const { return visibility_; }
-
   Tensor &getPayload() { return payload_; }
 
   const Tensor &getPayload() const { return payload_; }
 
@@ -124,13 +106,21 @@ class Variable : public Storage {
 /// this node at runtime. Placeholders are used as inputs and output nodes to
 /// the network.
 class Placeholder : public Storage {
+  /// Specifies if the variable or placeholder is trainable.
+  bool isTrainable_;
+
 public:
   /// Create a new placeholder variable.
   Placeholder(llvm::StringRef name, TypeRef Ty, bool isTrainable)
-      : Storage(Kinded::Kind::PlaceholderKind, name, isTrainable) {
+      : Storage(Kinded::Kind::PlaceholderKind, name),
+        isTrainable_(isTrainable) {
     addResult(Ty);
   }
 
+  /// \returns True if the Variable or placeholder are trainable during
+  /// differentiation.
+  bool isTraining() const { return isTrainable_; }
+
   static bool classof(const Kinded *k) {
     return k->getKind() == Kinded::Kind::PlaceholderKind;
   }
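The net effect of the `Storage` changes above: a `Variable` is now always a private constant bound to its payload, and the trainable flag exists only on `Placeholder`. A short sketch of the resulting API surface, with illustrative names and shapes (not taken from this patch):

```cpp
// Sketch only: assumes the usual Glow graph headers are included.
using namespace glow;

void storageSketch() {
  Module mod;

  // A Variable is a constant; create it and fill its payload directly.
  Variable *W = mod.createVariable(ElemKind::FloatTy, {4, 4}, "weights");
  W->getPayload().getHandle<float>().clear(0.5f);

  // Only Placeholders carry trainability.
  Placeholder *in = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "in",
                                          /* isTrainable */ false);
  bool trains = in->isTraining(); // false for this input
  (void)trains;
  (void)W;
}
```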
diff --git a/include/glow/Importer/CommonOperatorLoader.h b/include/glow/Importer/CommonOperatorLoader.h
index 3d33c07e60..f0cdebef66 100644
--- a/include/glow/Importer/CommonOperatorLoader.h
+++ b/include/glow/Importer/CommonOperatorLoader.h
@@ -119,8 +119,7 @@ class CommonOperatorLoader : public ProtobufLoader {
     // have an option for a selected input anyway. So I am creating this as a
     // placeholder which goes unused during inference.
     auto selected = G_.getParent()->createVariable(
-        ElemKind::Int64ITy, {in.dims()[0], 1}, "selected",
-        VisibilityKind::Private, false);
+        ElemKind::Int64ITy, {in.dims()[0], 1}, "selected");
 
     // ONNX allows shapes like . Flatten the inputs to the
     // softmax function. This is similar to a bitcast operation.
diff --git a/lib/Backends/CPU/AllocationsInfo.cpp b/lib/Backends/CPU/AllocationsInfo.cpp
index 74e03f800e..8677d3863a 100644
--- a/lib/Backends/CPU/AllocationsInfo.cpp
+++ b/lib/Backends/CPU/AllocationsInfo.cpp
@@ -45,8 +45,6 @@ void AllocationsInfo::allocateWeightVars(const IRFunction *F,
   for (auto &v : F->getGraph()->getParent()->getVars()) {
     assert(isa<WeightVar>(F->getWeightForNode(v)));
     auto *w = cast<WeightVar>(F->getWeightForNode(v));
-    if (v->getVisibilityKind() == VisibilityKind::Public)
-      continue;
     auto numBytes = w->getSizeInBytes();
     size_t addr = constantWeightVarsAllocator.allocate(numBytes, w);
     if (!absoluteAddr) {
@@ -58,23 +56,6 @@ void AllocationsInfo::allocateWeightVars(const IRFunction *F,
     }
   }
 
-  // Process all mutable WeightVars afterwards.
-  for (auto &v : F->getGraph()->getParent()->getVars()) {
-    assert(isa<WeightVar>(F->getWeightForNode(v)));
-    auto *w = cast<WeightVar>(F->getWeightForNode(v));
-    if (v->getVisibilityKind() != VisibilityKind::Public)
-      continue;
-    auto numBytes = w->getSizeInBytes();
-    size_t addr = mutableWeightVarsAllocator.allocate(numBytes, w);
-    if (!absoluteAddr) {
-      allocatedAddressed_[w] = addr;
-    } else {
-      // Reuse the address used by the payload.
-      allocatedAddressed_[w] =
-          v->getPayload().getUnsafePtr() - static_cast<char *>(nullptr);
-    }
-  }
-
   // Allocate addresses for the Placeholders.
   for (auto PH : ctx.pairs()) {
     assert(isa<WeightVar>(F->getWeightForNode(PH.first)));
@@ -208,10 +189,7 @@ void AllocationsInfo::numberValues(const IRFunction *F) {
   for (auto &v : F->getGraph()->getParent()->getVars()) {
     assert(isa<WeightVar>(F->getWeightForNode(v)));
     auto *w = cast<WeightVar>(F->getWeightForNode(v));
-    auto kind = v->getVisibilityKind() != VisibilityKind::Public
-                    ? ValueKind::ConstantWeight
-                    : ValueKind::MutableWeight;
-    valueNumbers_[w] = std::make_pair(kind, valueIdx++);
+    valueNumbers_[w] = std::make_pair(ValueKind::ConstantWeight, valueIdx++);
   }
 
   // Assign numbers to all placeholders.
diff --git a/lib/Backends/CPU/BundleSaver.cpp b/lib/Backends/CPU/BundleSaver.cpp
index 9e54c36386..c4adcdcc00 100644
--- a/lib/Backends/CPU/BundleSaver.cpp
+++ b/lib/Backends/CPU/BundleSaver.cpp
@@ -54,8 +54,6 @@ void BundleSaver::saveWeights(llvm::StringRef weightsFileName) {
   size_t maxPos = 0;
   for (auto &v : F_->getGraph()->getParent()->getVars()) {
     auto *w = cast<WeightVar>(F_->getWeightForNode(v));
-    if (v->getVisibilityKind() == VisibilityKind::Public)
-      continue;
    auto numBytes = w->getSizeInBytes();
     auto payload = v->getPayload().getUnsafePtr();
     auto addr = allocationsInfo_.allocatedAddressed_[w];
@@ -99,7 +97,6 @@ void BundleSaver::emitSymbolTable() {
   // size and kind.
   for (auto &v : F_->getGraph()->getParent()->getVars()) {
     auto *w = cast<WeightVar>(F_->getWeightForNode(v));
-    bool isConstWeight = v->getVisibilityKind() != VisibilityKind::Public;
     auto size = w->getType()->size();
     auto addr = allocationsInfo_.allocatedAddressed_[w];
     // Create an SymbolTableEntry.
@@ -114,7 +111,7 @@ void BundleSaver::emitSymbolTable() {
          // size.
          llvm::ConstantInt::get(sizeTTy, size),
          // kind.
-         llvm::ConstantInt::get(charTy, isConstWeight ? 0 : 1)});
+         llvm::ConstantInt::get(charTy, /*isConstWeight*/ 0)});
     entries.push_back(entry);
   }
 
diff --git a/lib/Backends/CPU/Transforms.cpp b/lib/Backends/CPU/Transforms.cpp
index 3bd79a5cb4..07b46ff8c1 100644
--- a/lib/Backends/CPU/Transforms.cpp
+++ b/lib/Backends/CPU/Transforms.cpp
@@ -44,7 +44,7 @@ static Node *optimizeCPUConv(ConvolutionNode *CN, Function *F) {
   }
 
   Variable *filter = dyn_cast<Variable>(CN->getFilter());
-  if (!filter || filter->getNumUsers() != 1 || !filter->isPrivate()) {
+  if (!filter || filter->getNumUsers() != 1) {
     // Can't mutate the filter.
     return nullptr;
   }
@@ -58,9 +58,9 @@ static Node *optimizeCPUConv(ConvolutionNode *CN, Function *F) {
   TypeRef filterTy = filter->getType();
   auto dims = filterTy->dims();
   assert(dims.size() == 4 && "Invalid filter size");
-  auto *filter8 = M->createVariable(
-      filterTy->getElementType(), {dims[0] / 8, dims[1], dims[2], dims[3], 8},
-      filter->getName(), VisibilityKind::Private, false);
+  auto *filter8 = M->createVariable(filterTy->getElementType(),
+                                    {dims[0] / 8, dims[1], dims[2], dims[3], 8},
+                                    filter->getName());
 
   auto F8H = filter8->getHandle();
   auto FH = filter->getHandle();
diff --git a/lib/ExecutionEngine/ExecutionEngine.cpp b/lib/ExecutionEngine/ExecutionEngine.cpp
index 953bb08f4d..f7f362a92b 100644
--- a/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -41,26 +41,6 @@ void ExecutionEngine::setBackend(Backend *backend) {
 
 ExecutionEngine::~ExecutionEngine() = default;
 
-void glow::updateVariables(llvm::ArrayRef<Variable *> vars,
-                           llvm::ArrayRef<Tensor *> inputs) {
-  assert(inputs.size() == vars.size() &&
-         "The number of inputs does not match the number of variables");
-
-  // Update the input variables.
-  for (int i = 0, e = vars.size(); i < e; i++) {
-    assert(vars[i] && "Invalid value");
-    assert(vars[i]->getVisibilityKind() == VisibilityKind::Public &&
-           "Trying to update a private variable");
-    auto &t = vars[i]->getPayload();
-    auto dim = inputs[i]->dims();
-    (void)dim;
-    assert(t.dims() == dim &&
-           t.getElementType() == inputs[i]->getElementType() &&
-           "Mismatch on Variable and Tensor types.");
-    t.assign(inputs[i]);
-  }
-}
-
 void glow::updateVariables(Context &ctx, llvm::ArrayRef<Placeholder *> ph,
                            llvm::ArrayRef<Tensor *> inputs) {
   assert(inputs.size() == ph.size() &&
@@ -99,47 +79,6 @@ void ExecutionEngine::run() {
   function_->execute();
 }
 
-/// Update the content of the tensors \p vars with some slices that are from \p
-/// inputs. The data starts at slice \p sampleIdx and wraps around until the
-/// data in \p v is filled. All dimensions, except for the first (batch)
-/// dimension must be identical.
-void glow::updateVariablesFromBatch(llvm::ArrayRef<Variable *> vars,
-                                    llvm::ArrayRef<Tensor *> inputs,
-                                    size_t sampleIdx) {
-  assert(!inputs.empty() && "No inputs");
-  assert(inputs.size() == vars.size() &&
-         "The number of inputs does not match the number of variables");
-
-  // Update the input variables.
-  for (int i = 0, e = vars.size(); i < e; i++) {
-    assert(vars[i] && "Invalid value");
-    auto &t = vars[i]->getPayload();
-
-    auto dim = inputs[i]->dims();
-    assert(t.dims().drop_front() == dim.drop_front() && "Invalid slice size");
-    // Extract the n'th slice, that must be a tensor.
-    size_t slc = sampleIdx % dim[0];
-    t.copyConsecutiveSlices(inputs[i], slc);
-  }
-}
-
-void glow::runBatch(ExecutionEngine &EE, size_t iterations,
-                    size_t &sampleCounter, llvm::ArrayRef<Variable *> vars,
-                    llvm::ArrayRef<Tensor *> inputs) {
-  // This is the size of one batch (the number of samples in the batch).
-  size_t batchSize = vars[0]->getType()->dims()[0];
-
-  for (size_t i = 0; i < iterations; i++) {
-    // Pick up one slice from the input tensors, and load it into corresponding
-    // network Variables. Then, run a single pass over the network.
-    glow::updateVariablesFromBatch(vars, inputs, sampleCounter);
-
-    // Run the network.
-    EE.run();
-    sampleCounter += batchSize;
-  }
-}
-
 void glow::runBatch(ExecutionEngine &EE, Context &ctx, size_t iterations,
                     size_t &sampleCounter, llvm::ArrayRef<Placeholder *> ph,
                     llvm::ArrayRef<Tensor *> inputs) {
diff --git a/lib/Graph/Graph.cpp b/lib/Graph/Graph.cpp
index 623d206005..a92c1709a2 100644
--- a/lib/Graph/Graph.cpp
+++ b/lib/Graph/Graph.cpp
@@ -158,11 +158,7 @@ class AbstractDottyPrinter {
     auto nodeColor = colorNames[colorIdx % arrayLen];
 
     if (auto V = llvm::dyn_cast<Variable>(N)) {
-      if (V->getVisibilityKind() == VisibilityKind::Public) {
-        os << "\tfillcolor=Snow2 color=DarkOliveGreen4\n";
-      } else {
-        os << "\tfillcolor=Snow3 color=DeepSkyBlue4\n";
-      }
+      os << "\tfillcolor=Snow3 color=DeepSkyBlue4\n";
     } else {
       os << "\tfillcolor=" << nodeColor << "\n";
     }
@@ -353,31 +349,26 @@ Placeholder *Module::createPlaceholder(ElemKind T, llvm::ArrayRef<size_t> dims,
   return createPlaceholder(FT, name, isTrainable);
 }
 
-Variable *Module::createVariable(TypeRef T, llvm::StringRef name,
-                                 VisibilityKind visibility, bool isTrainable) {
+Variable *Module::createVariable(TypeRef T, llvm::StringRef name) {
   auto FT = uniqueType(*T);
-  return addVar(new Variable(name, FT, visibility, isTrainable));
+  return addVar(new Variable(name, FT));
 }
 
 Variable *Module::createVariable(ElemKind T, llvm::ArrayRef<size_t> dims,
-                                 llvm::StringRef name,
-                                 VisibilityKind visibility, bool isTrainable) {
+                                 llvm::StringRef name) {
   auto FT = uniqueType(T, dims);
-  return createVariable(FT, name, visibility, isTrainable);
+  return createVariable(FT, name);
 }
 
 Variable *Module::createVariable(ElemKind T, llvm::ArrayRef<size_t> dims,
                                  float scale, int32_t offset,
-                                 llvm::StringRef name,
-                                 VisibilityKind visibility, bool isTrainable) {
+                                 llvm::StringRef name) {
   auto FT = uniqueType(T, dims, scale, offset);
-  return createVariable(FT, name, visibility, isTrainable);
+  return createVariable(FT, name);
 }
 
-Variable *Module::createVariable(llvm::StringRef name, const Tensor &tensor,
-                                 VisibilityKind visibility, bool trainable) {
-  auto *V = createVariable(tensor.getElementType(), tensor.dims(), name,
-                           visibility, trainable);
+Variable *Module::createVariable(llvm::StringRef name, const Tensor &tensor) {
+  auto *V = createVariable(tensor.getElementType(), tensor.dims(), name);
   V->assign(&tensor);
   return V;
 }
@@ -575,14 +566,11 @@ Function::createRowwiseQuantizedFullyConnected(llvm::StringRef name,
   // provided. But for rowwise quantization, the scales and offsets are stored
   // in vectors separately, we add the dummy scale and offset here.
   auto *qWeights = getParent()->createVariable(ElemKind::Int8QTy, W->dims(),
-                                               0.0, 0, "weights.rwqfc",
-                                               VisibilityKind::Private, false);
+                                               0.0, 0, "weights.rwqfc");
   auto *scales =
-      getParent()->createVariable(ElemKind::FloatTy, {numRows}, "scales.rwqfc",
-                                  VisibilityKind::Private, false);
+      getParent()->createVariable(ElemKind::FloatTy, {numRows}, "scales.rwqfc");
   auto *offsets = getParent()->createVariable(ElemKind::Int32QTy, {numRows},
-                                              0.0, 0, "offsets.rwqfc",
-                                              VisibilityKind::Private, false);
+                                              0.0, 0, "offsets.rwqfc");
 
   quantization::tensorRowwiseQuantization(
       weights->getPayload(), qWeights->getPayload(), scales->getPayload(),
@@ -1306,7 +1294,7 @@ Function::createIntLookupTable(llvm::StringRef name, NodeValue input,
                                TypeRef outTy) {
   auto *mapping = getParent()->createVariable(
       ElemKind::Int8QTy, {initValues.size()}, outTy->getScale(),
-      outTy->getOffset(), "mapping", VisibilityKind::Private, false);
+      outTy->getOffset(), "mapping");
   mapping->getHandle<int8_t>() = initValues;
 
   return addNode(new IntLookupTableNode(name, outTy, input, mapping));
diff --git a/lib/Graph/Nodes.cpp b/lib/Graph/Nodes.cpp
index 7f40e66c00..d22f360816 100644
--- a/lib/Graph/Nodes.cpp
+++ b/lib/Graph/Nodes.cpp
@@ -27,7 +27,7 @@ bool Storage::isEqual(const Storage &other) const {
 }
 
 llvm::hash_code Variable::getHash() const {
-  return llvm::hash_combine(getName(), isTraining(), getType());
+  return llvm::hash_combine(getName(), getType());
 }
 
 llvm::hash_code Placeholder::getHash() const {
@@ -81,18 +81,11 @@ Node *Storage::clone() const { llvm_unreachable("variables can't be cloned."); }
 //                        Debug description methods
 //===----------------------------------------------------------------------===//
 
-static const char *getVariableVisibilityKindStr(VisibilityKind kind) {
-  const char *names[] = {"public", "private", nullptr};
-  return names[static_cast<int>(kind)];
-}
-
 std::string Variable::getDebugDesc() const {
   DescriptionBuilder db(getKindName());
   db.addParam("name", quote(getName()))
      .addParam("output", *getType())
-      .addParam("visibility", getVariableVisibilityKindStr(visibility_));
-  db.addParam("train", isTraining());
-  db.addParam("users", getNumUsers());
+      .addParam("users", getNumUsers());
   return db;
 }
diff --git a/lib/IR/IRGen.cpp b/lib/IR/IRGen.cpp
index e9b11b658c..9f806b27a4 100644
--- a/lib/IR/IRGen.cpp
+++ b/lib/IR/IRGen.cpp
@@ -369,7 +369,7 @@ struct IRGenVisitor : NodeWalker {
     auto *V = cast<Variable>(N);
     auto *W = builder_.createWeightVar(V->getType(), V->getName(),
                                        WeightVar::MutabilityKind::Mutable,
-                                       V->getVisibilityKind());
+                                       VisibilityKind::Private);
     W->setName(N->getName());
     registerIR(N, W);
     break;
diff --git a/lib/Importer/Caffe2ModelLoader.cpp b/lib/Importer/Caffe2ModelLoader.cpp
index 750abf72f3..d6b09b12fd 100644
--- a/lib/Importer/Caffe2ModelLoader.cpp
+++ b/lib/Importer/Caffe2ModelLoader.cpp
@@ -249,10 +249,8 @@ void Caffe2ModelLoader::loadOperator(const caffe2::OperatorDef &op) {
     auto channel = getChannel(dict);
     auto *scaleV = G_.getParent()->createVariable("scale", *scale);
     auto *biasV = G_.getParent()->createVariable("bias", *bias);
-    auto *meanV = G_.getParent()->createVariable(
-        "mean", *mean, VisibilityKind::Private, false);
-    auto *varV = G_.getParent()->createVariable("var", *var,
-                                                VisibilityKind::Private, false);
+    auto *meanV = G_.getParent()->createVariable("mean", *mean);
+    auto *varV = G_.getParent()->createVariable("var", *var);
 
     auto *node = G_.createBatchNormalization(opName, in, biasV, scaleV, meanV,
                                              varV, channel, epsilon);
@@ -318,10 +316,8 @@ void Caffe2ModelLoader::loadOperator(const caffe2::OperatorDef &op) {
     } else
       w->transpose(&wtag, {1, 0});
 
-    auto W = G_.getParent()->addVar(
-        new Variable("weights", VisibilityKind::Private, std::move(wtag)));
-    auto B = G_.getParent()->addVar(
-        new Variable("biases", VisibilityKind::Private, std::move(*b)));
+    auto W = G_.getParent()->addVar(new Variable("weights", std::move(wtag)));
+    auto B = G_.getParent()->addVar(new Variable("biases", std::move(*b)));
     auto *node = G_.createFullyConnected(opName, in, W, B);
 
     // Save the outputs:
diff --git a/lib/Importer/ONNXModelLoader.cpp b/lib/Importer/ONNXModelLoader.cpp
index 63e339ca57..fc147728fe 100644
--- a/lib/Importer/ONNXModelLoader.cpp
+++ b/lib/Importer/ONNXModelLoader.cpp
@@ -417,10 +417,8 @@ bool ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) {
     auto *scaleV = G_.getParent()->createVariable("scale", *scale);
     auto *biasV = G_.getParent()->createVariable("bias", *bias);
-    auto *meanV = G_.getParent()->createVariable(
-        "mean", *mean, VisibilityKind::Private, false);
-    auto *varV = G_.getParent()->createVariable("var", *var,
-                                                VisibilityKind::Private, false);
+    auto *meanV = G_.getParent()->createVariable("mean", *mean);
+    auto *varV = G_.getParent()->createVariable("var", *var);
 
     auto *node = G_.createBatchNormalization(opName, in, biasV, scaleV, meanV,
                                              varV, 1, epsilon);
 
diff --git a/lib/Importer/ProtobufLoader.cpp b/lib/Importer/ProtobufLoader.cpp
index 5db9358731..ee3f7c3a2a 100644
--- a/lib/Importer/ProtobufLoader.cpp
+++ b/lib/Importer/ProtobufLoader.cpp
@@ -65,9 +65,7 @@ Variable *ProtobufLoader::createAndRegisterConstant(llvm::StringRef name,
   assert(!hasNodeByName(name) && "Creating an already existing node");
   // Note: We do not support training from models loaded from protos, so
   // trainable is always set to false here.
-  Variable *node =
-      G_.getParent()->createVariable(name, tensor, VisibilityKind::Private,
-                                     /* trainable */ false);
+  Variable *node = G_.getParent()->createVariable(name, tensor);
   nodeValueByName_[name] = NodeValue(node, 0);
   return node;
 }
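The loader changes above all converge on the same pattern: build a `Tensor` and either copy it in with `createVariable(name, tensor)` or move it into a `Variable` directly; there is no visibility or trainable argument left to thread through. A rough sketch of that pattern follows; the tensor shapes and fill values stand in for the loaders' real protobuf deserialization and are assumptions, not code from this patch.

```cpp
// Sketch only: the tensor values below stand in for proto deserialization.
using namespace glow;

void importerConstantsSketch() {
  Module mod;

  Tensor weightsT(ElemKind::FloatTy, {64, 32});
  weightsT.getHandle<float>().clear(0.01f);
  Tensor biasT(ElemKind::FloatTy, {32});
  biasT.getHandle<float>().clear(0.0f);

  // Copying form, as in ProtobufLoader::createAndRegisterConstant.
  auto *weightsV = mod.createVariable("weights", weightsT);

  // Moving form, as in the Caffe2 fully-connected path.
  auto *biasV = mod.addVar(new Variable("bias", std::move(biasT)));

  (void)weightsV;
  (void)biasV;
}
```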
diff --git a/lib/Optimizer/GraphOptimizer.cpp b/lib/Optimizer/GraphOptimizer.cpp
index 10c760446e..e8f1031049 100644
--- a/lib/Optimizer/GraphOptimizer.cpp
+++ b/lib/Optimizer/GraphOptimizer.cpp
@@ -52,13 +52,6 @@ static bool shouldDeleteNode(Node *N) {
     return false;
   }
 
-  if (Variable *V = dyn_cast<Variable>(N)) {
-    // We don't want to delete unused public variables because they are
-    // accessible to the outside world that may hold a reference to them.
-    if (V->getVisibilityKind() == VisibilityKind::Public)
-      return false;
-  }
-
   return true;
 }
 
@@ -612,7 +605,7 @@ static bool mergeTransposeIntoMatMul(Function *F) {
 
     // MatMul RHS is constant weights.
    auto *W = dyn_cast<Variable>(MMN->getRHS());
-    if (!W || !W->isPrivate()) {
+    if (!W) {
       continue;
     }
 
@@ -658,8 +651,7 @@ static bool mergeTransposeIntoMatMul(Function *F) {
         F->getParent()->uniqueTypeWithNewShape(W->getType(), newShape);
 
     // New reordered weights.
-    auto *newW = F->getParent()->createVariable(
-        W->getType(), W->getName(), W->getVisibilityKind(), W->isTraining());
+    auto *newW = F->getParent()->createVariable(W->getType(), W->getName());
     Tensor reshapedSrc(W->getPayload().getUnsafePtr(), reshapedWTy);
     Tensor reshapedDst(newW->getPayload().getUnsafePtr(), reshapedNewWTy);
     reshapedSrc.transpose(&reshapedDst, shuffle);
@@ -1317,14 +1309,13 @@ static void optimizeTranspose(Function *F) {
      continue;
     }
     auto *V = dyn_cast<Variable>(TN->getInput());
-    // V must have a single use and be private.
-    if (!V || !V->hasOneUse() || !V->isPrivate()) {
+    // V must have a single use.
+    if (!V || !V->hasOneUse()) {
      continue;
     }
     // Create a new variable NV to hold the transposed result.
     auto *NV =
-        F->getParent()->createVariable(TN->getResult().getType(), V->getName(),
-                                       V->getVisibilityKind(), V->isTraining());
+        F->getParent()->createVariable(TN->getResult().getType(), V->getName());
     // Transpose the value of V into NV.
     genericTranspose(&V->getPayload(), &NV->getPayload(), TN->getShuffle());
     // Rewrite uses of TN to reference NV.
@@ -1430,10 +1421,6 @@ struct VarsEqDedup {
     if (lhs->getType() != rhs->getType()) {
       return false;
     }
-    assert(lhs->getVisibilityKind() == rhs->getVisibilityKind() &&
-           "Should only be comparing Variables with same VisibilityKind.");
-    assert(lhs->isTraining() == rhs->isTraining() &&
-           "Should only be comparing Variables with same training mode.");
     // Only combine Vars if their data matches exactly, so allowed error is 0.0.
     return lhs->getPayload().isEqual(rhs->getPayload(), /* allowedError */ 0.0);
   }
@@ -1463,7 +1450,6 @@ static bool hasWriters(Variable *V) {
 
 /// Deduplicates constant variables in the Module \p M. Applicable constant
 /// variables for deduplication must have the same data, have
-/// VisibilityKind::Private, not trainable, and have no writers.
 static void deduplicateConstants(Module *M) {
   // Map from Variables to other Variables that are equivalent for purposes of
   // deduplication.
@@ -1480,11 +1466,6 @@ static void deduplicateConstants(Module *M) {
       continue;
     }
 
-    // Only perform deduplication on private vars that have no train kind.
-    if (V->getVisibilityKind() != VisibilityKind::Private || V->isTraining()) {
-      continue;
-    }
-
     // Only perform deduplication on vars that have no writers.
     if (hasWriters(V)) {
       continue;
     }
@@ -1573,11 +1554,10 @@ static void optimizeReshape(Function *F) {
     // Only do this if the Variable has a single use, as otherwise we would
     // duplicate the Variable and increase the memory footprint.
     auto *V = dyn_cast<Variable>(inputNode);
-    if (V && V->isPrivate() && V->hasOneUse()) {
+    if (V && V->hasOneUse()) {
       // Create a new variable with the type of the reshape.
       auto *newV = F->getParent()->createVariable(
-          reshapeNode->getResult().getType(), V->getName(),
-          V->getVisibilityKind(), /* isTrainable */ false);
+          reshapeNode->getResult().getType(), V->getName());
       // Create an unowned view of the original tensor with the correct shape,
       // and assign it to the new Variable.
       Tensor reshapedT = V->getPayload().getUnowned(reshapeNode->getDims());
@@ -1673,17 +1653,15 @@ static void optimizeQuantization(Function *F) {
     if (auto *V = dyn_cast<Variable>(Q->getInput())) {
       // Quantize(Variable) -> Variable
-      // V must be a private variable.
       // Note, it does not really matter how many usages this var has.
       // Quantized graph will use optimized var and other functions will
       // refer to the floating point original var.
-      if (!V || !V->isPrivate()) {
+      if (!V) {
        continue;
       }
       // Create a new variable NV to hold the quantized result.
-      auto *NV = F->getParent()->createVariable(
-          Q->getResult().getType(), V->getName(), V->getVisibilityKind(),
-          false);
+      auto *NV = F->getParent()->createVariable(Q->getResult().getType(),
+                                                V->getName());
       // Quantize V into NV.
       auto srcHandle = V->getHandle();
       auto destHandle = NV->getHandle();
@@ -1983,8 +1961,7 @@ void glow::convertPlaceholdersToConstants(Function *F, const Context &ctx,
     if (!tensor) {
       continue;
     }
-    auto *constantV = M->createVariable(PH->getName(), *tensor,
-                                        VisibilityKind::Private, false);
+    auto *constantV = M->createVariable(PH->getName(), *tensor);
     PH->getOutput().replaceAllUsesOfWith(constantV, F);
   }
 }
diff --git a/tests/unittests/OperatorTest.cpp b/tests/unittests/OperatorTest.cpp
index 58295c8aa7..bb765e4bb1 100644
--- a/tests/unittests/OperatorTest.cpp
+++ b/tests/unittests/OperatorTest.cpp
@@ -4017,8 +4017,7 @@ TEST_P(InterpAndCPU, rowwiseQuantizedFCTest) {
   // The FC fomula is I * W + B, while the RWQFC is I * transpose(W) + B.
   // So get the tranpose of weights from the above FC.
   auto *newWeights = mod_.createVariable(
-      ElemKind::FloatTy, {weights->dims()[1], weights->dims()[0]}, "newW",
-      VisibilityKind::Private, false);
+      ElemKind::FloatTy, {weights->dims()[1], weights->dims()[0]}, "newW");
   ctx_.get(weights)->transpose(&newWeights->getPayload(), {1, 0});
 
   TypeRef inputTy =
diff --git a/tests/unittests/graphOptzTest.cpp b/tests/unittests/graphOptzTest.cpp
index eee5297a7f..f7b158fb86 100644
--- a/tests/unittests/graphOptzTest.cpp
+++ b/tests/unittests/graphOptzTest.cpp
@@ -1813,12 +1813,9 @@ TEST_F(GraphOptz, VarsCSE) {
  // writers. The first two variables have the same data, and so should be
   // combined via variable CSE. The third variable differs by the last value,
   // and so should not be combined.
-  auto *input1 = mod_.createVariable(ElemKind::FloatTy, {10}, "input1",
-                                     VisibilityKind::Private, false);
-  auto *input2 = mod_.createVariable(ElemKind::FloatTy, {10}, "input2",
-                                     VisibilityKind::Private, false);
-  auto *input3 = mod_.createVariable(ElemKind::FloatTy, {10}, "input3",
-                                     VisibilityKind::Private, false);
+  auto *input1 = mod_.createVariable(ElemKind::FloatTy, {10}, "input1");
+  auto *input2 = mod_.createVariable(ElemKind::FloatTy, {10}, "input2");
+  auto *input3 = mod_.createVariable(ElemKind::FloatTy, {10}, "input3");
   input1->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
   input2->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
   input3->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1};
@@ -1981,8 +1978,8 @@ TEST_F(GraphOptz, ReshapePrivateVarOneUse) {
 TEST_F(GraphOptz, mergeTransposeIntoMatMul) {
   auto *input =
       mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3}, "input", false);
-  auto *weights = F_->getParent()->createVariable(
-      ElemKind::FloatTy, {6, 1}, "weights", VisibilityKind::Private);
+  auto *weights =
+      F_->getParent()->createVariable(ElemKind::FloatTy, {6, 1}, "weights");
   weights->getHandle() = {0, 1, 2, 3, 4, 5};
 
   float newWeightsRef[] = {0, 2, 4, 1, 3, 5};
Variable("V", mod.uniqueType(ElemKind::FloatTy, {3, 32})); // Variables don't belong to any function... EXPECT_EQ(V->getParent(), nullptr);