Skip to content

Commit 1260d74

Browse files
committed
Update createBundle to set symbol input and output
1 parent 4e0d094 commit 1260d74

File tree

3 files changed

+112
-48
lines changed

3 files changed

+112
-48
lines changed

lib/Backends/BackendUtils.cpp

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,13 @@
1919
using namespace glow;
2020

2121
using llvm::cast;
22+
using llvm::dyn_cast;
2223
using llvm::isa;
2324

2425
void glow::runtime::RuntimeBundle::collectConstants(const IRFunction *F) {
2526
collectConstants(F->getGraph()->getParent());
2627
}
2728

28-
void glow::runtime::RuntimeBundle::setInputsandOutputs() {
29-
for (auto &symbol : symbolTable_) {
30-
symbol.second.input = true;
31-
symbol.second.output = true;
32-
}
33-
}
34-
3529
void glow::runtime::RuntimeBundle::freeConstants() {
3630
if (constants_) {
3731
glow::alignedFree(constants_);
@@ -81,6 +75,45 @@ runtime::RuntimeBundle::getSymbolInfo(const Named *v) const {
8175
return it->second;
8276
}
8377

78+
/// If \p PH is an output placeholder, \returns true.
79+
/// This is determined by checking if the PH has a user which uses the PH as an
80+
/// overwritten input.
81+
bool isOutput(const Placeholder *PH) {
82+
for (const auto &use : PH->getUsers()) {
83+
// Look through the inputs of the PH's users. If an input is overwritten
84+
// check if it's the PH, if it is return true.
85+
auto *user = use.getUser();
86+
for (unsigned i = 0, numInputs = user->getNumInputs(); i < numInputs; i++) {
87+
// If the input is not overwritten we can continue.
88+
if (!user->isOverwrittenNthInput(i)) {
89+
continue;
90+
}
91+
auto input = use.getUser()->getNthInput(i);
92+
if (input == PH) {
93+
return true;
94+
}
95+
}
96+
}
97+
return false;
98+
}
99+
100+
/// If \p PH is an input placeholder, \returns true.
101+
bool isInput(const Placeholder *PH) {
102+
// Check that the PH is the input to a saveNode or is used by a non saveNode.
103+
for (const auto &use : PH->getUsers()) {
104+
// Check if PH is an input to a saveNode.
105+
if (auto *save = dyn_cast<SaveNode>(use.getUser())) {
106+
auto input = save->getInput();
107+
// If the PH is not an input to the saveNode we keep looking.
108+
if (input != PH) {
109+
continue;
110+
}
111+
}
112+
return true;
113+
}
114+
return false;
115+
}
116+
84117
runtime::RuntimeBundle runtime::RuntimeBundle::create(const Function &F) {
85118
std::unordered_map<std::string, runtime::RuntimeSymbolInfo> symbolTable;
86119

@@ -95,6 +128,8 @@ runtime::RuntimeBundle runtime::RuntimeBundle::create(const Function &F) {
95128
symbol.offset = offset;
96129
symbol.size = size;
97130
symbol.type = *V->getType();
131+
symbol.input = false;
132+
symbol.output = false;
98133
symbol.symbolCategory = SymbolCategory::Constant;
99134
symbolTable.emplace(V->getName(), symbol);
100135
}
@@ -107,6 +142,8 @@ runtime::RuntimeBundle runtime::RuntimeBundle::create(const Function &F) {
107142
symbol.offset = offset;
108143
symbol.size = size;
109144
symbol.type = *V->getType();
145+
symbol.output = isOutput(V);
146+
symbol.input = isInput(V);
110147
symbol.symbolCategory = SymbolCategory::Placeholder;
111148
symbolTable.emplace(V->getName(), symbol);
112149
}
@@ -134,6 +171,8 @@ runtime::RuntimeBundle::create(const IRFunction &F,
134171
symbol.size = numBytes;
135172
symbol.offset = addr;
136173
symbol.type = *w->getType();
174+
symbol.input = false;
175+
symbol.output = false;
137176
symbol.symbolCategory = SymbolCategory::Constant;
138177
symbolTable.emplace(std::string(v->getName()), symbol);
139178
}
@@ -150,6 +189,8 @@ runtime::RuntimeBundle::create(const IRFunction &F,
150189
symbol.offset = addr;
151190
symbol.size = numBytes;
152191
symbol.type = *w->getType();
192+
symbol.output = isOutput(v);
193+
symbol.input = isInput(v);
153194
symbol.symbolCategory = SymbolCategory::Placeholder;
154195
symbolTable.emplace(std::string(v->getName()), symbol);
155196
}
@@ -158,7 +199,7 @@ runtime::RuntimeBundle::create(const IRFunction &F,
158199
// Compute the offsets for Activations.
159200

160201
for (const auto &I : F.getInstrs()) {
161-
if (auto *A = llvm::dyn_cast<AllocActivationInst>(&I)) {
202+
if (auto *A = dyn_cast<AllocActivationInst>(&I)) {
162203
auto numBytes = I.getSizeInBytes();
163204
size_t addr = activationsAllocator.allocate(numBytes, A);
164205
assert(!symbolTable.count(std::string(A->getName())) &&
@@ -167,12 +208,14 @@ runtime::RuntimeBundle::create(const IRFunction &F,
167208
symbol.offset = addr;
168209
symbol.size = numBytes;
169210
symbol.type = *A->getType();
211+
symbol.input = false;
212+
symbol.output = false;
170213
symbol.symbolCategory = SymbolCategory::Activation;
171214
symbolTable.emplace(std::string(A->getName()), symbol);
172215
continue;
173216
}
174217

175-
if (auto *TV = llvm::dyn_cast<TensorViewInst>(&I)) {
218+
if (auto *TV = dyn_cast<TensorViewInst>(&I)) {
176219
// Calculate and store the length of the offset into the base, using the
177220
// source of the tensorview.
178221
assert(!symbolTable.count(std::string(TV->getName())) &&
@@ -191,6 +234,8 @@ runtime::RuntimeBundle::create(const IRFunction &F,
191234
(offsetLength * TV->getType()->getElementSize());
192235
symbol.size = TV->getSizeInBytes();
193236
symbol.type = *TV->getType();
237+
symbol.input = false;
238+
symbol.output = false;
194239
auto parentCategory =
195240
symbolTable.find(tvSource->getName())->second.symbolCategory;
196241
if (parentCategory == SymbolCategory::Placeholder) {
@@ -202,7 +247,7 @@ runtime::RuntimeBundle::create(const IRFunction &F,
202247
continue;
203248
}
204249

205-
if (auto *D = llvm::dyn_cast<DeallocActivationInst>(&I)) {
250+
if (auto *D = dyn_cast<DeallocActivationInst>(&I)) {
206251
auto *A = D->getAlloc();
207252
assert(symbolTable.count(std::string(A->getName())) &&
208253
"Invalid deallocation!");

lib/Runtime/Provisioner/Provisioner.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ llvm::Error Provisioner::provision(DAGListTy &networks, Module &module) {
8181
auto compiled = backend_->compile(function, compileOptions);
8282
node->runtimeBundle =
8383
llvm::make_unique<RuntimeBundle>(compiled->getRuntimeBundle());
84-
node->runtimeBundle->setInputsandOutputs();
8584
functionMap.emplace(node->name, compiled.get());
8685
functions_.emplace(node->name, std::move(compiled));
8786
totalMemory += node->runtimeBundle->getConstantWeightSize();

tests/unittests/BackendTest.cpp

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,63 @@ TEST(Interpreter, profileQuantizationForANetwork) {
106106
EXPECT_NEAR(1.6, max, 0.00001);
107107
}
108108

109+
/// Test that the symbol category for a symbol is properly set.
110+
TEST(RuntimeBundle, BundleSymbolInfo) {
111+
Module mod;
112+
ExecutionEngine EE;
113+
PlaceholderBindings bindings;
114+
115+
Tensor inputs(ElemKind::FloatTy, {1, 10, 10, 3});
116+
inputs.getHandle().randomize(-2, 2, mod.getPRNG());
117+
118+
// Create a simple graph that has placeholders, constants, activations, and a
119+
// tensor_view.
120+
Function *F = mod.createFunction("main");
121+
auto *input =
122+
mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false);
123+
124+
auto *ex = mod.createConstant(ElemKind::Int64ITy, {1, 1}, "exp");
125+
126+
auto *FC = F->createFullyConnected(bindings, "FC", input, 30);
127+
auto *RL = F->createRELU("RL2", FC);
128+
auto *SM = F->createSoftMax("sm", RL, ex);
129+
auto *S = F->createSave("ret", SM);
130+
auto *qp = F->createQuantizationProfile(bindings, "qp", input);
131+
132+
EE.compile(CompilationMode::Infer, F);
133+
auto table = EE.getCompiledFunction().getRuntimeBundle().getSymbolTable();
134+
// Check that placeholders and constants are correctly labelled.
135+
EXPECT_EQ(table.find(S->getName())->second.symbolCategory,
136+
glow::runtime::SymbolCategory::Placeholder);
137+
EXPECT_EQ(table.find(ex->getName())->second.symbolCategory,
138+
glow::runtime::SymbolCategory::Constant);
139+
// Check that activations are labelled correctly.
140+
EXPECT_EQ(table.find("fc_add_bias_res")->second.symbolCategory,
141+
glow::runtime::SymbolCategory::Activation);
142+
// Check that tensor views have the same label as their parent symbol. In this
143+
// case same as "input".
144+
EXPECT_EQ(table.find("tensorview_reshape")->second.symbolCategory,
145+
glow::runtime::SymbolCategory::PlaceholderTensorView);
146+
147+
// Check that placeholders and constants input/output flags are correctly set.
148+
EXPECT_EQ(table.find(S->getName())->second.input, false);
149+
EXPECT_EQ(table.find(S->getName())->second.output, true);
150+
EXPECT_EQ(table.find(ex->getName())->second.input, false);
151+
EXPECT_EQ(table.find(ex->getName())->second.output, false);
152+
EXPECT_EQ(table.find(input->getName())->second.input, true);
153+
EXPECT_EQ(table.find(input->getName())->second.output, false);
154+
EXPECT_EQ(table.find(qp->getHistogramPlaceholder()->getName())->second.input,
155+
true);
156+
EXPECT_EQ(table.find(qp->getHistogramPlaceholder()->getName())->second.output,
157+
true);
158+
// Check that activations are labelled correctly.
159+
EXPECT_EQ(table.find("fc_add_bias_res")->second.input, false);
160+
EXPECT_EQ(table.find("fc_add_bias_res")->second.output, false);
161+
// Check that tensor views are labelled correctly.
162+
EXPECT_EQ(table.find("tensorview_reshape")->second.input, false);
163+
EXPECT_EQ(table.find("tensorview_reshape")->second.output, false);
164+
}
165+
109166
TEST_P(BackendTest, simpleInference) {
110167
Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3});
111168
PlaceholderBindings ctx;
@@ -247,43 +304,6 @@ TEST_P(BackendTest, BundleSharedConstant) {
247304
EXPECT_TRUE(it2 != table2.end());
248305
}
249306

250-
/// Test that the symbol category for a symbol is properly set.
251-
TEST_P(BackendTest, BundleSymbolCategory) {
252-
Module mod;
253-
PlaceholderBindings bindings;
254-
255-
Tensor inputs(ElemKind::FloatTy, {1, 10, 10, 3});
256-
inputs.getHandle().randomize(-2, 2, mod.getPRNG());
257-
258-
// Create a simple graph that has placeholders, constants, activations, and a
259-
// tensor_view.
260-
Function *F = mod.createFunction("main");
261-
auto *input =
262-
mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false);
263-
264-
auto *ex = mod.createConstant(ElemKind::Int64ITy, {1, 1}, "exp");
265-
266-
auto *FC = F->createFullyConnected(bindings, "FC", input, 30);
267-
auto *RL = F->createRELU("RL2", FC);
268-
auto *SM = F->createSoftMax("sm", RL, ex);
269-
auto *S = F->createSave("ret", SM);
270-
271-
EE_.compile(CompilationMode::Infer, F);
272-
auto table = EE_.getCompiledFunction().getRuntimeBundle().getSymbolTable();
273-
// Check that placeholders and constants are correctly labelled.
274-
EXPECT_EQ(table.find(S->getName())->second.symbolCategory,
275-
glow::runtime::SymbolCategory::Placeholder);
276-
EXPECT_EQ(table.find(ex->getName())->second.symbolCategory,
277-
glow::runtime::SymbolCategory::Constant);
278-
// Check that activations are labelled correctly.
279-
EXPECT_EQ(table.find("fc_add_bias_res")->second.symbolCategory,
280-
glow::runtime::SymbolCategory::Activation);
281-
// Check that tensor views have the same label as their parent symbol. In this
282-
// case same as "input".
283-
EXPECT_EQ(table.find("tensorview_reshape")->second.symbolCategory,
284-
glow::runtime::SymbolCategory::PlaceholderTensorView);
285-
}
286-
287307
/// Test compiling a vector of functions completes without error.
288308
TEST_P(BackendTest, compileVectorOfFunctions) {
289309
Module mod;

0 commit comments

Comments
 (0)