Skip to content

Commit 083e983

Browse files
committed
[Placeholder] Add support for the runBatch interface.
Add support for the runBatch interface. This commit implements the parts of runBatch that copy parts of the tensors into the placeholder. The implementation is zero-copy because we use tensor views. I hope that this is a temporary solution and that we'll get rid of runBatch soon (see #1587).
1 parent bdae151 commit 083e983

File tree

2 files changed

+63
-4
lines changed

2 files changed

+63
-4
lines changed

lib/ExecutionEngine/ExecutionEngine.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,15 @@ void ExecutionEngine::runBatch(size_t iterations,
113113
void ExecutionEngine::updateInputsAndRunNetwork(llvm::ArrayRef<Storage *> vars,
114114
llvm::ArrayRef<Tensor *> inputs,
115115
size_t sampleIdx) {
116-
117116
llvm::SmallVector<Placeholder *, 8> placeholders;
117+
118+
// This container saves the tensor slices that were extracted from the inputs.
119+
// The tensors are alive during the lifetime of this function. These tensors
120+
// must not move, and we must pass them as ArrayRef, so we allocate them here
121+
// manually.
122+
llvm::SmallVector<Tensor, 8> tempTensorViews;
123+
// We need to pass the tensors as array ref, so we reference the static
124+
// storage above, that's frozen during the execution of the program.
118125
llvm::SmallVector<Tensor *, 8> tensors;
119126

120127
// Update the input variables.
@@ -129,7 +136,26 @@ void ExecutionEngine::updateInputsAndRunNetwork(llvm::ArrayRef<Storage *> vars,
129136
// This is a placeholder that we need to make concrete during the execution
130137
// of the program.
131138
placeholders.push_back(cast<Placeholder>(vars[i]));
132-
tensors.push_back(inputs[i]);
139+
140+
// Extract a tensor view from the input values. This has the same
141+
// functionality as the logic in loadValueFromTensorSlice that copies the
142+
// content of the tensor.
143+
auto batchSize = inputs[i]->dims()[0];
144+
auto startIdx = sampleIdx % batchSize;
145+
auto phDims = placeholders[i]->dims();
146+
147+
// The start offset is all zeros except for the batch dimension.
148+
std::vector<size_t> sliceOffset(inputs[i]->dims().size(), 0);
149+
sliceOffset[0] = startIdx;
150+
151+
// Create the tensor view.
152+
tempTensorViews.push_back(inputs[i]->getUnowned(phDims, sliceOffset));
153+
}
154+
155+
// Save the pointer to the allocated tensor view, because the interface
156+
// requires ArrayRef of tensors.
157+
for (int i = 0, e = tempTensorViews.size(); i < e; i++) {
158+
tensors.push_back(&tempTensorViews[i]);
133159
}
134160

135161
// Run the network.

tests/unittests/BackendTest.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ TEST_P(BackendTest, decoupleCodegenFromGraph) {
173173
EXPECT_NEAR(HX.at({2}), 9, 1E-5);
174174
}
175175

176+
/// Check that we can pass information to the execution engine using Placeholder
177+
/// variables and read it back using Save nodes (in variables).
176178
TEST(Placeholder, simplePlaceholderValue) {
177179
Tensor data{99.0, 35.0, 2.0, 3.0};
178180

@@ -188,11 +190,42 @@ TEST(Placeholder, simplePlaceholderValue) {
188190
EE.run({input}, {&data});
189191

190192
auto &res = S->getVariable()->getPayload();
191-
192-
res.getHandle().dump();
193193
EXPECT_TRUE(res.isEqual(data));
194194
}
195195

196+
/// Check that we can pass information to the execution engine using Placeholder
197+
/// variables and the runBatch API.
198+
TEST(Placeholder, runBatchTest) {
199+
// The input contains two slices of 4 floats each.
200+
Tensor data(ElemKind::FloatTy, {4, 4});
201+
// Fill the array with the pattern: [0 1 2 3; 10, 11, 12, 13; 20 21 22 23 ...]
202+
for (size_t i = 0; i < 4; i++) {
203+
for (size_t j = 0; j < 4; j++) {
204+
data.getHandle().at({i, j}) = i * 10 + j;
205+
}
206+
}
207+
208+
ExecutionEngine EE{BackendKind::Interpreter};
209+
auto &mod = EE.getModule();
210+
211+
Function *F = mod.createFunction("main");
212+
auto *input = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "input");
213+
SaveNode *S = F->createSave("ret", input);
214+
215+
EE.compile(CompilationMode::Infer, F);
216+
217+
// Run the batch for 2 iterations:
218+
EE.runBatch(2, {input}, {&data});
219+
220+
Tensor expected{10, 11, 12, 13};
221+
auto EH = expected.getHandle();
222+
auto RH = S->getVariable()->getPayload().getHandle();
223+
224+
for (size_t i = 0; i < 4; i++) {
225+
EXPECT_EQ(RH.raw(i), EH.raw(i));
226+
}
227+
}
228+
196229
INSTANTIATE_TEST_CASE_P(Interpreter, BackendTest,
197230
::testing::Values(BackendKind::Interpreter));
198231

0 commit comments

Comments
 (0)