From 611787872bfd09af1caa5a64fdc303aa24206c67 Mon Sep 17 00:00:00 2001 From: Garret Catron Date: Mon, 8 Apr 2019 10:58:47 +0100 Subject: [PATCH] Update createBundle to set symbol input and output --- lib/Backends/BackendUtils.cpp | 65 ++++++++++++++--- lib/Runtime/Provisioner/Provisioner.cpp | 1 - tests/unittests/BackendTest.cpp | 94 +++++++++++++++---------- 3 files changed, 112 insertions(+), 48 deletions(-) diff --git a/lib/Backends/BackendUtils.cpp b/lib/Backends/BackendUtils.cpp index 20179b12f8..510bd85514 100644 --- a/lib/Backends/BackendUtils.cpp +++ b/lib/Backends/BackendUtils.cpp @@ -19,19 +19,13 @@ using namespace glow; using llvm::cast; +using llvm::dyn_cast; using llvm::isa; void glow::runtime::RuntimeBundle::collectConstants(const IRFunction *F) { collectConstants(F->getGraph()->getParent()); } -void glow::runtime::RuntimeBundle::setInputsandOutputs() { - for (auto &symbol : symbolTable_) { - symbol.second.input = true; - symbol.second.output = true; - } -} - void glow::runtime::RuntimeBundle::freeConstants() { if (constants_) { glow::alignedFree(constants_); @@ -81,6 +75,45 @@ runtime::RuntimeBundle::getSymbolInfo(const Named *v) const { return it->second; } +/// If \p PH is an output placeholder, \returns true. +/// This is determined by checking if the PH has a user which uses the PH as an +/// overwritten input. +bool isOutput(const Placeholder *PH) { + for (const auto &use : PH->getUsers()) { + // Look through the inputs of the PH's users. If an input is overwritten + // check if it's the PH, if it is return true. + auto *user = use.getUser(); + for (unsigned i = 0, numInputs = user->getNumInputs(); i < numInputs; i++) { + // If the input is not overwritten we can continue. + if (!user->isOverwrittenNthInput(i)) { + continue; + } + auto input = use.getUser()->getNthInput(i); + if (input == PH) { + return true; + } + } + } + return false; +} + +/// If \p PH is an input placeholder, \returns true. +bool isInput(const Placeholder *PH) { + // Check that the PH is the input to a saveNode or is used by a non saveNode. + for (const auto &use : PH->getUsers()) { + // Check if PH is an input to a saveNode. + if (auto *save = dyn_cast(use.getUser())) { + auto input = save->getInput(); + // If the PH is not an input to the saveNode we keep looking. + if (input != PH) { + continue; + } + } + return true; + } + return false; +} + runtime::RuntimeBundle runtime::RuntimeBundle::create(const Function &F) { std::unordered_map symbolTable; @@ -95,6 +128,8 @@ runtime::RuntimeBundle runtime::RuntimeBundle::create(const Function &F) { symbol.offset = offset; symbol.size = size; symbol.type = *V->getType(); + symbol.input = false; + symbol.output = false; symbol.symbolCategory = SymbolCategory::Constant; symbolTable.emplace(V->getName(), symbol); } @@ -107,6 +142,8 @@ runtime::RuntimeBundle runtime::RuntimeBundle::create(const Function &F) { symbol.offset = offset; symbol.size = size; symbol.type = *V->getType(); + symbol.output = isOutput(V); + symbol.input = isInput(V); symbol.symbolCategory = SymbolCategory::Placeholder; symbolTable.emplace(V->getName(), symbol); } @@ -134,6 +171,8 @@ runtime::RuntimeBundle::create(const IRFunction &F, symbol.size = numBytes; symbol.offset = addr; symbol.type = *w->getType(); + symbol.input = false; + symbol.output = false; symbol.symbolCategory = SymbolCategory::Constant; symbolTable.emplace(std::string(v->getName()), symbol); } @@ -150,6 +189,8 @@ runtime::RuntimeBundle::create(const IRFunction &F, symbol.offset = addr; symbol.size = numBytes; symbol.type = *w->getType(); + symbol.output = isOutput(v); + symbol.input = isInput(v); symbol.symbolCategory = SymbolCategory::Placeholder; symbolTable.emplace(std::string(v->getName()), symbol); } @@ -158,7 +199,7 @@ runtime::RuntimeBundle::create(const IRFunction &F, // Compute the offsets for Activations. for (const auto &I : F.getInstrs()) { - if (auto *A = llvm::dyn_cast(&I)) { + if (auto *A = dyn_cast(&I)) { auto numBytes = I.getSizeInBytes(); size_t addr = activationsAllocator.allocate(numBytes, A); assert(!symbolTable.count(std::string(A->getName())) && @@ -167,12 +208,14 @@ runtime::RuntimeBundle::create(const IRFunction &F, symbol.offset = addr; symbol.size = numBytes; symbol.type = *A->getType(); + symbol.input = false; + symbol.output = false; symbol.symbolCategory = SymbolCategory::Activation; symbolTable.emplace(std::string(A->getName()), symbol); continue; } - if (auto *TV = llvm::dyn_cast(&I)) { + if (auto *TV = dyn_cast(&I)) { // Calculate and store the length of the offset into the base, using the // source of the tensorview. assert(!symbolTable.count(std::string(TV->getName())) && @@ -191,6 +234,8 @@ runtime::RuntimeBundle::create(const IRFunction &F, (offsetLength * TV->getType()->getElementSize()); symbol.size = TV->getSizeInBytes(); symbol.type = *TV->getType(); + symbol.input = false; + symbol.output = false; auto parentCategory = symbolTable.find(tvSource->getName())->second.symbolCategory; if (parentCategory == SymbolCategory::Placeholder) { @@ -202,7 +247,7 @@ runtime::RuntimeBundle::create(const IRFunction &F, continue; } - if (auto *D = llvm::dyn_cast(&I)) { + if (auto *D = dyn_cast(&I)) { auto *A = D->getAlloc(); assert(symbolTable.count(std::string(A->getName())) && "Invalid deallocation!"); diff --git a/lib/Runtime/Provisioner/Provisioner.cpp b/lib/Runtime/Provisioner/Provisioner.cpp index 4756481e68..66d0724a19 100644 --- a/lib/Runtime/Provisioner/Provisioner.cpp +++ b/lib/Runtime/Provisioner/Provisioner.cpp @@ -68,7 +68,6 @@ llvm::Error Provisioner::provision(DAGListTy &networks, Module &module) { auto compiled = backend_->compile(function, compileOptions); node->runtimeBundle = llvm::make_unique(compiled->getRuntimeBundle()); - node->runtimeBundle->setInputsandOutputs(); functionMap.emplace(node->name, compiled.get()); functions_.emplace(node->name, std::move(compiled)); totalMemory += node->runtimeBundle->getConstantWeightSize(); diff --git a/tests/unittests/BackendTest.cpp b/tests/unittests/BackendTest.cpp index 96c895f21a..949f672362 100644 --- a/tests/unittests/BackendTest.cpp +++ b/tests/unittests/BackendTest.cpp @@ -106,6 +106,63 @@ TEST(Interpreter, profileQuantizationForANetwork) { EXPECT_NEAR(1.6, max, 0.00001); } +/// Test that the symbol category for a symbol is properly set. +TEST(RuntimeBundle, BundleSymbolInfo) { + Module mod; + ExecutionEngine EE; + PlaceholderBindings bindings; + + Tensor inputs(ElemKind::FloatTy, {1, 10, 10, 3}); + inputs.getHandle().randomize(-2, 2, mod.getPRNG()); + + // Create a simple graph that has placeholders, constants, activations, and a + // tensor_view. + Function *F = mod.createFunction("main"); + auto *input = + mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false); + + auto *ex = mod.createConstant(ElemKind::Int64ITy, {1, 1}, "exp"); + + auto *FC = F->createFullyConnected(bindings, "FC", input, 30); + auto *RL = F->createRELU("RL2", FC); + auto *SM = F->createSoftMax("sm", RL, ex); + auto *S = F->createSave("ret", SM); + auto *qp = F->createQuantizationProfile(bindings, "qp", input); + + EE.compile(CompilationMode::Infer, F); + auto table = EE.getCompiledFunction().getRuntimeBundle().getSymbolTable(); + // Check that placeholders and constants are correctly labelled. + EXPECT_EQ(table.find(S->getName())->second.symbolCategory, + glow::runtime::SymbolCategory::Placeholder); + EXPECT_EQ(table.find(ex->getName())->second.symbolCategory, + glow::runtime::SymbolCategory::Constant); + // Check that activations are labelled correctly. + EXPECT_EQ(table.find("fc_add_bias_res")->second.symbolCategory, + glow::runtime::SymbolCategory::Activation); + // Check that tensor views have the same label as their parent symbol. In this + // case same as "input". + EXPECT_EQ(table.find("tensorview_reshape")->second.symbolCategory, + glow::runtime::SymbolCategory::PlaceholderTensorView); + + // Check that placeholders and constants input/output flags are correctly set. + EXPECT_EQ(table.find(S->getName())->second.input, false); + EXPECT_EQ(table.find(S->getName())->second.output, true); + EXPECT_EQ(table.find(ex->getName())->second.input, false); + EXPECT_EQ(table.find(ex->getName())->second.output, false); + EXPECT_EQ(table.find(input->getName())->second.input, true); + EXPECT_EQ(table.find(input->getName())->second.output, false); + EXPECT_EQ(table.find(qp->getHistogramPlaceholder()->getName())->second.input, + true); + EXPECT_EQ(table.find(qp->getHistogramPlaceholder()->getName())->second.output, + true); + // Check that activations are labelled correctly. + EXPECT_EQ(table.find("fc_add_bias_res")->second.input, false); + EXPECT_EQ(table.find("fc_add_bias_res")->second.output, false); + // Check that tensor views are labelled correctly. + EXPECT_EQ(table.find("tensorview_reshape")->second.input, false); + EXPECT_EQ(table.find("tensorview_reshape")->second.output, false); +} + TEST_P(BackendTest, simpleInference) { Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3}); PlaceholderBindings ctx; @@ -247,43 +304,6 @@ TEST_P(BackendTest, BundleSharedConstant) { EXPECT_TRUE(it2 != table2.end()); } -/// Test that the symbol category for a symbol is properly set. -TEST_P(BackendTest, BundleSymbolCategory) { - Module mod; - PlaceholderBindings bindings; - - Tensor inputs(ElemKind::FloatTy, {1, 10, 10, 3}); - inputs.getHandle().randomize(-2, 2, mod.getPRNG()); - - // Create a simple graph that has placeholders, constants, activations, and a - // tensor_view. - Function *F = mod.createFunction("main"); - auto *input = - mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false); - - auto *ex = mod.createConstant(ElemKind::Int64ITy, {1, 1}, "exp"); - - auto *FC = F->createFullyConnected(bindings, "FC", input, 30); - auto *RL = F->createRELU("RL2", FC); - auto *SM = F->createSoftMax("sm", RL, ex); - auto *S = F->createSave("ret", SM); - - EE_.compile(CompilationMode::Infer, F); - auto table = EE_.getCompiledFunction().getRuntimeBundle().getSymbolTable(); - // Check that placeholders and constants are correctly labelled. - EXPECT_EQ(table.find(S->getName())->second.symbolCategory, - glow::runtime::SymbolCategory::Placeholder); - EXPECT_EQ(table.find(ex->getName())->second.symbolCategory, - glow::runtime::SymbolCategory::Constant); - // Check that activations are labelled correctly. - EXPECT_EQ(table.find("fc_add_bias_res")->second.symbolCategory, - glow::runtime::SymbolCategory::Activation); - // Check that tensor views have the same label as their parent symbol. In this - // case same as "input". - EXPECT_EQ(table.find("tensorview_reshape")->second.symbolCategory, - glow::runtime::SymbolCategory::PlaceholderTensorView); -} - /// Test compiling a vector of functions completes without error. TEST_P(BackendTest, compileVectorOfFunctions) { Module mod;