diff --git a/include/glow/Backends/DeviceManager.h b/include/glow/Backends/DeviceManager.h new file mode 100644 index 0000000000..f4651482c3 --- /dev/null +++ b/include/glow/Backends/DeviceManager.h @@ -0,0 +1,92 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GLOW_BACKENDS_DEVICEMANAGER_H +#define GLOW_BACKENDS_DEVICEMANAGER_H + +#include "glow/Backends/Backend.h" +#include "glow/Backends/CompiledFunction.h" +#include "glow/Graph/Context.h" +#include "glow/Graph/Graph.h" +#include "glow/Runtime/RuntimeTypes.h" + +#include +#include +#include + +namespace glow { + +/// Callback signalling success/failure of loading a Module onto a device. +using ReadyCBTy = std::function; +/// Callback signalling the result of running a function. +using ResultCBTy = std::function)>; +/// Map of Function name -> CompiledFunction, used when loading a network onto a +/// device. +using FunctionMapTy = std::map; + +/// Interface managing a specific instance of a device. +class DeviceManager { +protected: + /// Type of Backend for this Device. + BackendKind backend_; + +public: + DeviceManager(BackendKind backend) : backend_(backend) {} + virtual ~DeviceManager() {} + + /// Initialize the device. + virtual void init() {} + + /// Load the provided module into the device, readyCB will be called when + /// ready to use. + /// \p functions contains the list of functions to load, keyed by their name + /// (as used in runFunction). + virtual void addNetwork(const Module *module, FunctionMapTy functions, + ReadyCBTy readyCB) = 0; + + /// Remove (and delete) the provided network and all it's functions, freeing + /// up space on the device. + virtual void evictNetwork(const Module *module) = 0; + + /// Execute the named Function in an already provided network on the device. + /// functionName must match the name of a function already added. + /// Context should have all Placeholders allocated. resultCB will be called + /// with the Context results filled. + virtual runtime::RunIdentifierTy runFunction(std::string functionName, + std::unique_ptr ctx, + ResultCBTy resultCB) = 0; + + /// Stops execution and shuts down the Device. + virtual void stop(bool block = true) {} + + /// \returns the type of Backend that powers this Device. + BackendKind getBackendKind() { return backend_; } + + /// \returns the maximum memory (in bytes) available on the device. + virtual uint64_t getMaximumMemory() = 0; + + /// \returns the currently available memory (in bytes) available on the + /// device, for provisioning new networks. + virtual uint64_t getAvailableMemory() = 0; + + /// \returns true if we expect a Module with the estimated constant size will + /// fit on the device. + virtual bool isMemoryAvailable(uint64_t estimate) = 0; +}; + +} // namespace glow + +#endif // GLOW_BACKENDS_DEVICEMANAGER_H diff --git a/include/glow/Backends/QueueBackedDeviceManager.h b/include/glow/Backends/QueueBackedDeviceManager.h new file mode 100644 index 0000000000..c3ccd69323 --- /dev/null +++ b/include/glow/Backends/QueueBackedDeviceManager.h @@ -0,0 +1,78 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GLOW_BACKENDS_QUEUEBACKEDDEVICEMANAGER_H +#define GLOW_BACKENDS_QUEUEBACKEDDEVICEMANAGER_H + +#include "glow/Backends/DeviceManager.h" +#include "glow/Support/ThreadPool.h" + +#include + +namespace glow { + +class QueueBackedDeviceManager : public DeviceManager { +protected: + /// Thread which interfaces with the device. + ThreadPool workThread_; + + /// Identifier for next run. + std::atomic nextIdentifier_{1}; + +public: + QueueBackedDeviceManager(BackendKind backend); + virtual ~QueueBackedDeviceManager(); + + /// Initialize the device. + void init() override; + + /// Load the provided module into the device, readyCB will be called when + /// ready to use + void addNetwork(const Module *module, FunctionMapTy functions, + ReadyCBTy readyCB) override; + + /// Remove (and delete) the provided network and all it's functions, freeing + /// up space on the device. + void evictNetwork(const Module *module) override; + + /// Execute the named Function in an already provided network on the device. + /// functionName must match the name of a function already added. + /// Context should have all Placeholders allocated. resultCB will be called + /// with the Context results filled. + runtime::RunIdentifierTy runFunction(std::string functionName, + std::unique_ptr ctx, + ResultCBTy resultCB) override; + + /// Stops execution and shuts down the Device. + void stop(bool block = true) override; + +protected: + /// Operator handling methods to be implemented in subclasses (i.e. per Device + /// type) + + /// Load and compile the Module + virtual void addNetworkImpl(const Module *, FunctionMapTy, ReadyCBTy) = 0; + + /// Remove the module and reclaim it's memory + virtual void evictNetworkImpl(const Module *) = 0; + + /// Execute provided Function. + virtual void runFunctionImpl(runtime::RunIdentifierTy, std::string, + std::unique_ptr, ResultCBTy) = 0; +}; + +} // namespace glow + +#endif // GLOW_BACKENDS_QUEUEBACKEDDEVICEMANAGER_H diff --git a/include/glow/Runtime/RuntimeTypes.h b/include/glow/Runtime/RuntimeTypes.h new file mode 100644 index 0000000000..fbf23c5d0d --- /dev/null +++ b/include/glow/Runtime/RuntimeTypes.h @@ -0,0 +1,67 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GLOW_RUNTIME_RUNTIMETYPES_H +#define GLOW_RUNTIME_RUNTIMETYPES_H + +#include "glow/Backends/BackendUtils.h" +#include "glow/Graph/Graph.h" + +#include +#include +#include + +namespace glow { +namespace runtime { + +using DeviceIDTy = size_t; +using RunIdentifierTy = size_t; + +/// Enum to communicate results when communicating with device at initialization +/// and runtime. +enum ResultCode { Ready, Executed, Failed, Cancelled }; + +/// Data structure that contains device constraint information for each device. +/// Used to communicate memory constraints and later costs to the Partitioner. +struct DeviceInfo { + /// Available memory on device in bytes. + uint64_t availableMemory; +}; + +/// Individual Node in the DAG for a given network. This contains all the +/// information needed to run the sub-network at inference time. +struct DAGNode { + /// The children of this node, these are nodes that depend on the current + /// node. + std::vector children; + /// Pointers to the parents of this node. This is used by the executor for + /// determining if a given node has all dependencies met. + std::vector parents; + /// ID of the deviceManager that this network is assigned to. + DeviceIDTy deviceID; + /// The logicalDevice is an output of the Partitioner to indicate that two + /// networks should be assigned to the same device. + DeviceIDTy logicalDevice; + /// Name assigned to the sub-network, this is the id that will be passed to + /// the DeviceManager when requesting a run of the network. + std::string name; + /// Runtime bundle containing all the symbol information for this network at + /// runtime. + RuntimeBundle runtimeBundle; +}; + +} // namespace runtime +} // namespace glow +#endif // GLOW_RUNTIME_RUNTIMETYPES_H diff --git a/lib/Backends/CMakeLists.txt b/lib/Backends/CMakeLists.txt index 860da68d8c..180aaac9c8 100644 --- a/lib/Backends/CMakeLists.txt +++ b/lib/Backends/CMakeLists.txt @@ -27,3 +27,12 @@ target_link_libraries(Backends ${linked_backends} Base Graph) + +add_library(DeviceManager QueueBackedDeviceManager.cpp) + +target_link_libraries(DeviceManager + PRIVATE + Backends + Graph + ThreadPool) + diff --git a/lib/Backends/CPU/CMakeLists.txt b/lib/Backends/CPU/CMakeLists.txt index 676fc48236..c8afaef8c3 100644 --- a/lib/Backends/CPU/CMakeLists.txt +++ b/lib/Backends/CPU/CMakeLists.txt @@ -114,3 +114,17 @@ if(LLVM_VERSION_MAJOR VERSION_GREATER 6) LLVMOrcJIT) endif() add_dependencies(CPUBackend CPURuntime) + +add_library(CPUDeviceManager + CPUDeviceManager.cpp) +target_link_libraries(CPUDeviceManager + PRIVATE + Backends + BackendUtils + Base + CodeGen + CPUBackend + DeviceManager + Graph + IR + Optimizer) diff --git a/lib/Backends/CPU/CPUDeviceManager.cpp b/lib/Backends/CPU/CPUDeviceManager.cpp new file mode 100644 index 0000000000..cc5e2f8b8c --- /dev/null +++ b/lib/Backends/CPU/CPUDeviceManager.cpp @@ -0,0 +1,103 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "CPUDeviceManager.h" + +using namespace glow; +using namespace glow::runtime; + +uint64_t CPUDeviceManager::getMaximumMemory() { return maxMemoryBytes; } + +uint64_t CPUDeviceManager::getAvailableMemory() { + return maxMemoryBytes - usedMemoryBytes; +} + +bool CPUDeviceManager::isMemoryAvailable(uint64_t estimate) { + // No fuzz factor for the CPU device. + return maxMemoryBytes >= (usedMemoryBytes + estimate); +} + +void CPUDeviceManager::addNetworkImpl(const Module *module, + FunctionMapTy functions, + ReadyCBTy readyCB) { + auto modIt = modules_.find(module); + if (modIt != modules_.end()) { + // Already have a module with this ID. + // TODO: should we replace it? + readyCB(module, Failed); + return; + } + + // TODO: we should update usedMemory but we don't currently have a nice way + // to determine the memory used by the module. I'll come back to this, but for + // now we'll guess (badly). + size_t moduleSize = 200 * 1024 * 1024; + + if (usedMemoryBytes + moduleSize > maxMemoryBytes) { + readyCB(module, Failed); + return; + } + + // Add to the function name lookup map. + for (const auto &func : functions) { + // TODO: collect constants here when available. + functions_.emplace(func.first, func.second); + } + + modules_.emplace_hint(modIt, module, std::move(functions)); + usedMemoryBytes += moduleSize; + + // Fire the ready CB. + readyCB(module, Ready); +} + +void CPUDeviceManager::evictNetworkImpl(const Module *module) { + auto modIt = modules_.find(module); + if (modIt == modules_.end()) { + // Nothing to do. + return; + } + + FunctionMapTy moduleFuncs = std::move(modIt->second); + for (const auto &func : moduleFuncs) { + functions_.erase(func.first); + } + + modules_.erase(modIt); + usedMemoryBytes -= 200 * 1024 * 1024; // TODO: static moduleSize + assert(usedMemoryBytes >= 0); +} + +void CPUDeviceManager::runFunctionImpl(RunIdentifierTy id, std::string function, + std::unique_ptr ctx, + ResultCBTy resultCB) { + auto funcIt = functions_.find(function); + if (funcIt == functions_.end()) { + resultCB(id, Failed, std::move(ctx)); + return; + } + + CompiledFunction *func = funcIt->second; + + // Run that function. + func->setupRuns(); + func->beforeRun(*ctx); + func->execute(); + func->afterRun(*ctx); + func->tearDownRuns(); + + // Fire the resultCB. + resultCB(id, Executed, std::move(ctx)); +} diff --git a/lib/Backends/CPU/CPUDeviceManager.h b/lib/Backends/CPU/CPUDeviceManager.h new file mode 100644 index 0000000000..e2e86f1091 --- /dev/null +++ b/lib/Backends/CPU/CPUDeviceManager.h @@ -0,0 +1,56 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GLOW_BACKENDS_CPUDEVICEMANAGER_H +#define GLOW_BACKENDS_CPUDEVICEMANAGER_H + +#include "glow/Backends/QueueBackedDeviceManager.h" + +namespace glow { + +class CPUDeviceManager : public QueueBackedDeviceManager { + /// Loaded module list. + std::map modules_; + + /// Compiled function list by name. + FunctionMapTy functions_; + + /// Maximum available memory on the device, for CPU devices fix to some + /// constant. + uint64_t maxMemoryBytes{0}; + + /// Amount of memory used by all models. + uint64_t usedMemoryBytes{0}; + +public: + CPUDeviceManager(size_t MBsPerCore = 16000) + : QueueBackedDeviceManager(BackendKind::CPU), + maxMemoryBytes(MBsPerCore * 1024 * 1024) {} + + uint64_t getMaximumMemory() override; + uint64_t getAvailableMemory() override; + bool isMemoryAvailable(uint64_t estimate) override; + +protected: + void addNetworkImpl(const Module *module, FunctionMapTy functions, + ReadyCBTy cb) override; + void evictNetworkImpl(const Module *module) override; + void runFunctionImpl(runtime::RunIdentifierTy id, std::string functionName, + std::unique_ptr ctx, ResultCBTy cb) override; +}; + +} // namespace glow + +#endif // GLOW_BACKENBDS_CPUDEVICEMANAGER_H diff --git a/lib/Backends/QueueBackedDeviceManager.cpp b/lib/Backends/QueueBackedDeviceManager.cpp new file mode 100644 index 0000000000..e87a3ba94f --- /dev/null +++ b/lib/Backends/QueueBackedDeviceManager.cpp @@ -0,0 +1,59 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "glow/Backends/QueueBackedDeviceManager.h" + +using namespace glow; +using namespace glow::runtime; + +QueueBackedDeviceManager::QueueBackedDeviceManager(BackendKind backend) + : DeviceManager(backend), workThread_(1) {} + +QueueBackedDeviceManager::~QueueBackedDeviceManager() { + stop(true); // will join workThread_ +} + +void QueueBackedDeviceManager::init() {} + +void QueueBackedDeviceManager::addNetwork(const Module *module, + FunctionMapTy functions, + ReadyCBTy callback) { + workThread_.submit([this, module, f = std::move(functions), + c = std::move(callback)]() mutable { + addNetworkImpl(module, std::move(f), std::move(c)); + }); +} + +void QueueBackedDeviceManager::evictNetwork(const Module *module) { + workThread_.submit([this, module] { evictNetworkImpl(module); }); +} + +RunIdentifierTy +QueueBackedDeviceManager::runFunction(std::string functionName, + std::unique_ptr ctx, + ResultCBTy callback) { + + RunIdentifierTy id = nextIdentifier_++; + workThread_.submit([this, id, functionName = std::move(functionName), + ctx = std::move(ctx), + callback = std::move(callback)]() mutable { + runFunctionImpl(id, std::move(functionName), std::move(ctx), + std::move(callback)); + }); + return id; +} + +void QueueBackedDeviceManager::stop(bool block) { workThread_.stop(block); } diff --git a/tests/unittests/CMakeLists.txt b/tests/unittests/CMakeLists.txt index dd0363b8fe..2311419a3d 100755 --- a/tests/unittests/CMakeLists.txt +++ b/tests/unittests/CMakeLists.txt @@ -284,6 +284,22 @@ target_link_libraries(LLVMIRGenTest target_include_directories(LLVMIRGenTest PUBLIC ${CMAKE_SOURCE_DIR}/lib/Backends/CPU) add_glow_test(LLVMIRGenTest ${GLOW_BINARY_DIR}/tests/LLVMIRGenTest --gtest_output=xml:LLVMIRGenTest.xml) +add_executable(cpuDeviceTest + CPUDeviceManagerTest.cpp) +target_link_libraries(cpuDeviceTest + PRIVATE + Backends + DeviceManager + CPUDeviceManager + Graph + IR + ExecutionEngine + Optimizer + gtest + TestMain) +target_include_directories(cpuDeviceTest PUBLIC ${CMAKE_SOURCE_DIR}/lib/Backends/CPU) +add_glow_test(cpuDeviceTest ${GLOW_BINARY_DIR}/tests/cpuDeviceTest --gtest_output=xml:cpuDeviceTest.xml) + endif() add_executable(MemoryAllocatorTest diff --git a/tests/unittests/CPUDeviceManagerTest.cpp b/tests/unittests/CPUDeviceManagerTest.cpp new file mode 100644 index 0000000000..7343c92d3a --- /dev/null +++ b/tests/unittests/CPUDeviceManagerTest.cpp @@ -0,0 +1,391 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "CPUDeviceManager.h" +#include "glow/ExecutionEngine/ExecutionEngine.h" + +#include "gtest/gtest.h" + +#include +#include + +using namespace glow; +using namespace glow::runtime; +using namespace std::chrono_literals; + +std::unique_ptr makeBasicModule(std::string functionName = "main") { + std::unique_ptr module = std::make_unique(); + std::unique_ptr ctx = std::make_unique(); + + Function *F = module->createFunction(functionName); + auto *input = module->createPlaceholder(ElemKind::FloatTy, {1, 32, 32, 3}, + "input", false); + + auto *FC = F->createFullyConnected(*ctx, "fc", input, 10); + auto *RU = F->createRELU("relu", FC); + F->createSave("ret", RU); + + return module; +} + +// TODO: This really should be a helper function somewhere +void optimizeFunction(Backend *backend, CompilationMode mode, Function *F) { + // Verify the function pre-optimization/lowering. + assert(F->verify() && "Function must be valid"); + + // Optimize the graph. + ::glow::optimize(F, mode); + + // Allow the backend to transform the graph prior to lowering. + if (backend->transformPreLowering(F, mode)) { + // Optimize the graph again after the backend transformation. + // In particular, DCE is very likely to be useful. + ::glow::optimize(F, mode); + } + + // Lower the graph into a sequence of low-level linear algebra operations. + ::glow::lower(F, *backend); + + // Optimize the graph again. + ::glow::optimize(F, mode); + + // Allow the backend to transform the graph after lowering. + if (backend->transformPostLowering(F, mode)) { + // Optimize the graph again after the backend transformation. + // In particular, DCE is very likely to be useful. + ::glow::optimize(F, mode); + } +} + +FunctionMapTy +compileFunctions(Module *module, + std::vector> &backing) { + FunctionMapTy results; + auto *backend = createBackend(BackendKind::CPU); + for (auto *F : module->getFunctions()) { + optimizeFunction(backend, CompilationMode::Infer, F); + auto f = backend->compile(F); + backing.push_back(std::move(f)); + results.emplace(F->getName(), backing.back().get()); + } + + delete backend; + return results; +} + +template +std::pair, std::future> getFutureHelper() { + std::promise promise; + auto future = promise.get_future(); + return std::make_pair(std::move(promise), std::move(future)); +} + +template +void callbackHelper(std::promise &promise, ResultType res, + ResultCode result, ResultCode expected) { + promise.set_value(result == expected ? std::move(res) : ResultType()); +} + +TEST(CPUDeviceManagerTest, Basic) { + auto module = makeBasicModule(); + std::vector> backing; + FunctionMapTy functions = compileFunctions(module.get(), backing); + + CPUDeviceManager cpuCoreDevice; + cpuCoreDevice.init(); + + std::promise promise; + std::future future; + std::tie(promise, future) = getFutureHelper(); + + cpuCoreDevice.addNetwork(module.get(), std::move(functions), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, Ready); + }); + + future.wait_for(2s); + EXPECT_EQ(future.get(), module.get()); + + std::unique_ptr ctx = std::make_unique(); + ctx->allocate(module->getPlaceholders()); + + Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3}); + updateInputPlaceholders(*ctx, {module->getPlaceholderByName("input")}, + {&inputs}); + + std::promise> runPromise; + std::future> runFuture; + + std::tie(runPromise, runFuture) = getFutureHelper>(); + cpuCoreDevice.runFunction("main", std::move(ctx), + [&runPromise](RunIdentifierTy, ResultCode result, + std::unique_ptr ctx_) { + callbackHelper(runPromise, std::move(ctx_), + result, Executed); + }); + + runFuture.wait_for(2s); + + EXPECT_NE(runFuture.get(), nullptr); +} + +TEST(CPUDeviceManagerTest, MultiRun) { + auto module = makeBasicModule(); + std::vector> backing; + FunctionMapTy functions = compileFunctions(module.get(), backing); + + CPUDeviceManager cpuCoreDevice; + cpuCoreDevice.init(); + + std::promise promise; + std::future future; + std::tie(promise, future) = getFutureHelper(); + cpuCoreDevice.addNetwork(module.get(), std::move(functions), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, Ready); + }); + future.wait_for(2s); + EXPECT_EQ(future.get(), module.get()); + + std::unique_ptr ctx1 = std::make_unique(); + std::unique_ptr ctx2 = std::make_unique(); + ctx1->allocate(module->getPlaceholders()); + ctx2->allocate(module->getPlaceholders()); + + PseudoRNG PRNG; + Tensor inputs1(ElemKind::FloatTy, {1, 32, 32, 3}); + Tensor inputs2(ElemKind::FloatTy, {1, 32, 32, 3}); + inputs1.getHandle().randomize(-12.0, 13.0, PRNG); + inputs2.getHandle().randomize(-12.0, 13.0, PRNG); + + updateInputPlaceholders(*ctx1, {module->getPlaceholderByName("input")}, + {&inputs1}); + updateInputPlaceholders(*ctx2, {module->getPlaceholderByName("input")}, + {&inputs2}); + + std::promise> runP1, runP2; + std::future> runF1, runF2; + std::tie(runP1, runF1) = getFutureHelper>(); + std::tie(runP2, runF2) = getFutureHelper>(); + + cpuCoreDevice.runFunction("main", std::move(ctx1), + [&runP1](RunIdentifierTy, ResultCode result, + std::unique_ptr ctx_) { + callbackHelper(runP1, std::move(ctx_), result, + Executed); + }); + + cpuCoreDevice.runFunction("main", std::move(ctx2), + [&runP2](RunIdentifierTy, ResultCode result, + std::unique_ptr ctx_) { + callbackHelper(runP2, std::move(ctx_), result, + Executed); + }); + + ctx1 = runF1.get(); + ctx2 = runF2.get(); + EXPECT_NE(ctx1, ctx2); +} + +TEST(CPUDeviceManagerTest, MultiFunction) { + auto module = makeBasicModule("func1"); + + std::unique_ptr ctx1 = std::make_unique(); + ctx1->allocate(module->getPlaceholders()); + + Function *F = module->createFunction("func2"); + auto *input = module->getPlaceholderByName("input"); + auto *C = F->createConv(*ctx1, "conv2a", input, 64, 4, 1, 0, 1); + ctx1->get(llvm::cast(C->getFilter()))->getHandle().clear(0.3); + ctx1->get(llvm::cast(C->getBias()))->getHandle().clear(0.4); + F->createSave("ret2", C); + + std::vector> backing; + FunctionMapTy functions = compileFunctions(module.get(), backing); + EXPECT_EQ(functions.size(), 2); + + CPUDeviceManager cpuCoreDevice; + cpuCoreDevice.init(); + + std::promise promise; + std::future future; + std::tie(promise, future) = getFutureHelper(); + cpuCoreDevice.addNetwork(module.get(), std::move(functions), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, Ready); + }); + future.wait_for(2s); + EXPECT_EQ(future.get(), module.get()); + + Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3}); + updateInputPlaceholders(*ctx1, {module->getPlaceholderByName("input")}, + {&inputs}); + + std::unique_ptr ctx2 = std::make_unique(ctx1->clone()); + + std::promise> runP1, runP2; + std::future> runF1, runF2; + std::tie(runP1, runF1) = getFutureHelper>(); + std::tie(runP2, runF2) = getFutureHelper>(); + + cpuCoreDevice.runFunction("func1", std::move(ctx1), + [&runP1](RunIdentifierTy, ResultCode result, + std::unique_ptr ctx_) { + callbackHelper(runP1, std::move(ctx_), result, + Executed); + }); + + cpuCoreDevice.runFunction("func2", std::move(ctx2), + [&runP2](RunIdentifierTy, ResultCode result, + std::unique_ptr ctx_) { + callbackHelper(runP2, std::move(ctx_), result, + Executed); + }); + + ctx1 = runF1.get(); + ctx2 = runF2.get(); + EXPECT_NE(ctx1, ctx2); +} + +TEST(CPUDeviceManagerTest, MultiModule) { + auto module1 = makeBasicModule("func1"); + auto module2 = makeBasicModule("func2"); + + std::vector> backing; + FunctionMapTy functions1 = compileFunctions(module1.get(), backing); + FunctionMapTy functions2 = compileFunctions(module2.get(), backing); + + CPUDeviceManager cpuCoreDevice; + cpuCoreDevice.init(); + + std::promise promise; + std::future future; + std::tie(promise, future) = getFutureHelper(); + cpuCoreDevice.addNetwork(module1.get(), std::move(functions1), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, Ready); + }); + future.wait_for(2s); + EXPECT_EQ(future.get(), module1.get()); + + std::tie(promise, future) = getFutureHelper(); + cpuCoreDevice.addNetwork(module2.get(), std::move(functions2), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, Ready); + }); + future.wait_for(2s); + EXPECT_EQ(future.get(), module2.get()); + + std::unique_ptr ctx1 = std::make_unique(); + ctx1->allocate(module1->getPlaceholders()); + Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3}); + updateInputPlaceholders(*ctx1, {module1->getPlaceholderByName("input")}, + {&inputs}); + + std::unique_ptr ctx2 = std::make_unique(ctx1->clone()); + ctx2->allocate(module2->getPlaceholders()); + updateInputPlaceholders(*ctx2, {module2->getPlaceholderByName("input")}, + {&inputs}); + + std::promise> runP1, runP2; + std::future> runF1, runF2; + std::tie(runP1, runF1) = getFutureHelper>(); + std::tie(runP2, runF2) = getFutureHelper>(); + + cpuCoreDevice.runFunction("func1", std::move(ctx1), + [&runP1](RunIdentifierTy, ResultCode result, + std::unique_ptr ctx_) { + callbackHelper(runP1, std::move(ctx_), result, + Executed); + }); + + cpuCoreDevice.runFunction("func2", std::move(ctx2), + [&runP2](RunIdentifierTy, ResultCode result, + std::unique_ptr ctx_) { + callbackHelper(runP2, std::move(ctx_), result, + Executed); + }); + + ctx1 = runF1.get(); + ctx2 = runF2.get(); + EXPECT_NE(ctx1, ctx2); +} + +TEST(CPUDeviceManagerTest, AvailableMemory) { + std::vector> backing; + std::promise promise; + std::future future; + + CPUDeviceManager cpuCoreDevice(200); + cpuCoreDevice.init(); + + uint64_t expectedBytes = 200 * 1024 * 1024; + EXPECT_EQ(cpuCoreDevice.getMaximumMemory(), expectedBytes); + EXPECT_EQ(cpuCoreDevice.getAvailableMemory(), expectedBytes); + EXPECT_TRUE(cpuCoreDevice.isMemoryAvailable(expectedBytes)); + EXPECT_FALSE(cpuCoreDevice.isMemoryAvailable(expectedBytes + 1)); + + auto module = makeBasicModule(); + std::tie(promise, future) = getFutureHelper(); + cpuCoreDevice.addNetwork(module.get(), + compileFunctions(module.get(), backing), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, Ready); + }); + + future.wait_for(2s); + EXPECT_EQ(future.get(), module.get()); + + EXPECT_EQ(cpuCoreDevice.getMaximumMemory(), expectedBytes); + EXPECT_EQ(cpuCoreDevice.getAvailableMemory(), 0); + EXPECT_FALSE(cpuCoreDevice.isMemoryAvailable(expectedBytes)); + EXPECT_FALSE(cpuCoreDevice.isMemoryAvailable(1)); + + // Let's try again. + auto module2 = makeBasicModule(); + std::tie(promise, future) = getFutureHelper(); + cpuCoreDevice.addNetwork(module2.get(), + compileFunctions(module2.get(), backing), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, Ready); + }); + + future.wait_for(2s); + auto *resultModule = future.get(); + EXPECT_NE(resultModule, module2.get()); + EXPECT_NE(resultModule, module.get()); + EXPECT_EQ(resultModule, nullptr); + + EXPECT_EQ(cpuCoreDevice.getMaximumMemory(), expectedBytes); + EXPECT_EQ(cpuCoreDevice.getAvailableMemory(), 0); + + // Evict the first network. + cpuCoreDevice.evictNetwork(module.get()); + + // And try again, this time with available space. + std::tie(promise, future) = getFutureHelper(); + cpuCoreDevice.addNetwork(module2.get(), + compileFunctions(module2.get(), backing), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, Ready); + }); + + future.wait_for(2s); + EXPECT_EQ(future.get(), module2.get()); + + EXPECT_EQ(cpuCoreDevice.getMaximumMemory(), expectedBytes); + EXPECT_EQ(cpuCoreDevice.getAvailableMemory(), 0); +}