diff --git a/include/glow/Runtime/Executor/Executor.h b/include/glow/Runtime/Executor/Executor.h
new file mode 100644
index 0000000000..fe3d8b842a
--- /dev/null
+++ b/include/glow/Runtime/Executor/Executor.h
@@ -0,0 +1,84 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GLOW_RUNTIME_EXECUTOR_H
+#define GLOW_RUNTIME_EXECUTOR_H
+
+#include "glow/Graph/Context.h"
+#include "glow/Graph/Graph.h"
+
+#include <functional>
+#include <list>
+#include <map>
+#include <set>
+#include <string>
+
+namespace glow {
+namespace runtime {
+
+/// Result of running a DAG, as reported to the completion callback.
+/// Copied @nickg's ResultCode. I think we'll want a common one anyway.
+enum class ResultCode {
+  EXECUTED,
+  FAILED,
+  CANCELLED,
+};
+
+/// This enum lists the available executors.
+enum class ExecutorKind {
+  ThreadPool, // Executor backed by a thread pool.
+};
+
+/// This struct contains the graph to be executed partitioned into subgraphs
+/// that can be run on individual devices as well as extra information to help
+/// manage execution.
+/// NOTE(review): the element types below were reconstructed from how the
+/// thread pool executor uses these members; the name type of `outputs`
+/// (std::string here) should be confirmed against the callers.
+struct ExecutorFunctionDAG {
+  // The list of functions to run for this DAG, topologically sorted.
+  std::list<Function *> functions;
+  // All functions that output final results.
+  std::set<Function *> endpoints;
+  // Maps from a function to its prerequisites and postrequisites.
+  std::map<Function *, std::list<Function *>> incoming;
+  std::map<Function *, std::list<Function *>> outgoing;
+  // Output placeholder names for each function.
+  std::map<Function *, std::list<std::string>> outputs;
+};
+
+/// The class encapsulates the context required to run the given DAG.
+struct ExecutorFunctionDAGContext {
+  // Partitioned contexts, one per function of the DAG.
+  std::map<Function *, Context *> contexts;
+};
+
+/// This is an interface to an executor that can run and return the results of
+/// a partitioned graph.
+class Executor {
+public:
+  /// Virtual destructor. Defaulted inline so that derived destructors have a
+  /// definition to link against (no out-of-line definition is provided).
+  virtual ~Executor() = default;
+
+  /// Completion callback: receives the overall result code and a Context
+  /// pointer (may be null, e.g. on failure).
+  using DoneCb = std::function<void(ResultCode, Context *)>;
+
+  /// Run the DAG specified by \p functionDag using Placeholder values contained
+  /// in \p ctx and call \p cb with the results. \p cb will be called with a
+  /// result code and a Context containing placeholder-tensor mappings for the
+  /// Functions in \p functionDag that have no postrequisites (i.e. the final
+  /// results).
+  virtual void run(ExecutorFunctionDAG *functionDag,
+                   ExecutorFunctionDAGContext *ctx, DoneCb cb) = 0;
+};
+
+/// Create an executor of kind \p executorKind.
+/// NOTE(review): the implementation returns a heap-allocated object through a
+/// raw pointer; presumably the caller owns it and must delete it -- confirm.
+Executor *createExecutor(ExecutorKind executorKind);
+
+} // namespace runtime
+} // namespace glow
+#endif
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index 14da9bcade..a4f9e0aac2 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -8,5 +8,6 @@ add_subdirectory(IR)
 add_subdirectory(Importer)
 add_subdirectory(Optimizer)
 add_subdirectory(Quantization)
+add_subdirectory(Runtime)
 add_subdirectory(Support)
 add_subdirectory(Onnxifi)
diff --git a/lib/Runtime/CMakeLists.txt b/lib/Runtime/CMakeLists.txt
new file mode 100644
index 0000000000..b5d0c0258f
--- /dev/null
+++ b/lib/Runtime/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(Executor)
diff --git a/lib/Runtime/Executor/CMakeLists.txt b/lib/Runtime/Executor/CMakeLists.txt
new file mode 100644
index 0000000000..757da8ac32
--- /dev/null
+++ b/lib/Runtime/Executor/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_library(Executor
+            Executor.cpp
+            ThreadPoolExecutor.cpp)
diff --git a/lib/Runtime/Executor/Executor.cpp b/lib/Runtime/Executor/Executor.cpp
new file mode 100644
index 0000000000..1b165be2b3
--- /dev/null
+++ b/lib/Runtime/Executor/Executor.cpp
@@ -0,0 +1,36 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "glow/Runtime/Executor/Executor.h"
+#include "ThreadPoolExecutor.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace glow {
+namespace runtime {
+
+/// Factory for executors: maps \p executorKind to a concrete implementation.
+/// \returns a newly allocated executor (allocated with `new`).
+Executor *createExecutor(ExecutorKind executorKind) {
+  switch (executorKind) {
+  case ExecutorKind::ThreadPool:
+    return new ThreadPoolExecutor();
+  }
+
+  // Every enumerator is handled above, so this is only reached if the
+  // argument holds a value outside the enum. It also keeps compilers from
+  // warning about a missing return.
+  llvm_unreachable("unreachable");
+}
+
+} // namespace runtime
+} // namespace glow
diff --git a/lib/Runtime/Executor/ThreadPoolExecutor.cpp b/lib/Runtime/Executor/ThreadPoolExecutor.cpp
new file mode 100644
index 0000000000..529570f9e0
--- /dev/null
+++ b/lib/Runtime/Executor/ThreadPoolExecutor.cpp
@@ -0,0 +1,264 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ThreadPoolExecutor.h"
+
+namespace glow {
+namespace runtime {
+
+/// Build a work item for \p dag / \p ctx that reports completion through
+/// \p cb. The item starts in state NONE with its cursor on the first function.
+/// NOTE(review): result_ is allocated here and never freed in this file;
+/// presumably the callback's receiver takes ownership -- confirm.
+ThreadPoolExecutorWorkItem::ThreadPoolExecutorWorkItem(
+    ExecutorFunctionDAG *dag, ExecutorFunctionDAGContext *ctx, DoneCb cb)
+    : cb_(cb), dag_(dag), ctx_(ctx), status_(Status::NONE),
+      result_(new Context()) {
+  it_ = (dag_->functions).begin();
+}
+
+/// \returns true while the work item is in state QUEUED, IN_PROGRESS or
+/// FAILED, i.e. while the caller should bother calling getNext().
+bool ThreadPoolExecutorWorkItem::isMoreWork() {
+  std::lock_guard<std::mutex> lock(mtx_);
+  return (status_ != Status::DONE && status_ != Status::NONE);
+}
+
+/// \returns the next {function, context} pair that is ready to run, or
+/// {nullptr, nullptr} if nothing can be started right now.
+///
+/// Side effects: a FAILED item is moved to DONE and the callback is invoked
+/// with ResultCode::FAILED; an item whose functions have all completed is
+/// moved to DONE and the callback is invoked with ResultCode::EXECUTED and
+/// result_. The callback always runs with mtx_ released.
+std::tuple<Function *, Context *> ThreadPoolExecutorWorkItem::getNext() {
+  std::unique_lock<std::mutex> lock(mtx_);
+
+  // A failed item only needs its callback delivered (exactly once).
+  if (status_ == Status::FAILED) {
+    status_ = Status::DONE;
+    lock.unlock();
+    cb_(ResultCode::FAILED, nullptr);
+    return std::make_tuple(nullptr, nullptr);
+  }
+
+  // Items in state NONE or DONE have no function/context pair to return.
+  if (status_ != Status::QUEUED && status_ != Status::IN_PROGRESS) {
+    return std::make_tuple(nullptr, nullptr);
+  }
+
+  // Process any updates that were made since the last call to getNext() in
+  // order to get updated information on work item state.
+  processUpdates();
+
+  // If all functions are done, move the work item to DONE state and call the
+  // callback with the result context.
+  const auto &functions = dag_->functions;
+  if (completedFunctions_.size() == functions.size()) {
+    status_ = Status::DONE;
+    lock.unlock();
+    cb_(ResultCode::EXECUTED, result_);
+    return std::make_tuple(nullptr, nullptr);
+  }
+
+  // There are still unfinished functions (possibly all executing right now).
+  status_ = Status::IN_PROGRESS;
+
+  // Look at every function exactly once, starting at the cursor and wrapping
+  // around at the end of the list. Bounding the scan by the list size
+  // guarantees termination even when nothing is runnable.
+  for (auto remaining = functions.size(); remaining > 0;
+       --remaining, ++it_) {
+    if (it_ == functions.end()) {
+      it_ = functions.begin();
+    }
+
+    Function *f = *it_;
+
+    // Skip functions that already ran or are running; handing them out again
+    // would execute them twice.
+    if (completedFunctions_.count(f) || inflightFunctions_.count(f)) {
+      continue;
+    }
+
+    // Check if all prerequisites of the current candidate function are done.
+    const std::list<Function *> &prerequisites = (dag_->incoming).at(f);
+    bool allPrerequisitesFinished = true;
+    for (Function *prerequisite : prerequisites) {
+      if (!completedFunctions_.count(prerequisite)) {
+        allPrerequisitesFinished = false;
+        break;
+      }
+    }
+
+    // Record that the function is now executing and hand it to the caller.
+    if (allPrerequisitesFinished) {
+      inflightFunctions_.insert(f);
+      Context *ctx = (ctx_->contexts).at(f);
+      return std::make_tuple(f, ctx);
+    }
+  }
+
+  // One full pass found nothing to run; everything left is either executing
+  // or waiting on something that is.
+  return std::make_tuple(nullptr, nullptr);
+}
+
+/// Buffer the completion of \p function with resulting \p context. The update
+/// is applied by the next call to getNext().
+void ThreadPoolExecutorWorkItem::markSuccess(Function *function,
+                                             Context *context) {
+  std::lock_guard<std::mutex> lock(mtx_);
+  updateFunctions_.insert(function);
+  updateContexts_.insert(context);
+}
+
+/// Move the work item to state QUEUED.
+void ThreadPoolExecutorWorkItem::markQueued() {
+  std::lock_guard<std::mutex> lock(mtx_);
+  status_ = Status::QUEUED;
+}
+
+/// Move the work item to state FAILED unless it is already DONE. A late
+/// failure report must not resurrect a finished item, otherwise the callback
+/// would be invoked a second time.
+void ThreadPoolExecutorWorkItem::markFailure() {
+  std::lock_guard<std::mutex> lock(mtx_);
+  if (status_ != Status::DONE) {
+    status_ = Status::FAILED;
+  }
+}
+
+/// Apply the completions buffered by markSuccess(): copy outputs into the
+/// contexts of dependent functions and move the functions from inflight to
+/// completed. Takes no lock itself; the caller (getNext) holds mtx_.
+///
+/// NOTE(review): functions and contexts are paired by walking two separate
+/// containers in lockstep. If those containers are ordered independently of
+/// insertion order, a function can be paired with another function's
+/// context. Storing {function, context} pairs would remove the hazard.
+/// NOTE(review): nothing here copies endpoint outputs into result_.
+void ThreadPoolExecutorWorkItem::processUpdates() {
+  auto fnIt = updateFunctions_.begin();
+  auto fnEnd = updateFunctions_.end();
+  auto ctxIt = updateContexts_.begin();
+  auto ctxEnd = updateContexts_.end();
+
+  for (; (fnIt != fnEnd) && (ctxIt != ctxEnd); ++fnIt, ++ctxIt) {
+    Function *f = *fnIt;
+    Context *ctx = *ctxIt;
+
+    // For every completed function, copy its outputs to the Context of any
+    // of the functions that depend on it that need that output.
+    const auto &outputs = (dag_->outputs).at(f);
+    const std::list<Function *> &postrequisites = (dag_->outgoing).at(f);
+    for (const auto &output : outputs) {
+      for (Function *postrequisite : postrequisites) {
+        Module *postModule = postrequisite->getParent();
+        Context *postCtx = (ctx_->contexts).at(postrequisite);
+        if (Placeholder *p = postModule->getPlaceholderByName(output)) {
+          postCtx->insert(p, ctx->get(p)->clone());
+        }
+      }
+    }
+
+    // Mark the function as completed instead of inflight/executing.
+    completedFunctions_.insert(f);
+    inflightFunctions_.erase(f);
+  }
+
+  // Drop the buffered updates so that the next call does not copy the same
+  // outputs into the dependent contexts again.
+  updateFunctions_.clear();
+  updateContexts_.clear();
+}
+
+/// Start \p numWorkers threads, each running workerMain(). The stop flag is
+/// initialized before any worker can read it.
+ThreadPoolExecutor::ThreadPoolExecutor(unsigned numWorkers)
+    : shouldStop_(false) {
+  for (unsigned i = 0; i < numWorkers; ++i) {
+    workers_.emplace_back(&ThreadPoolExecutor::workerMain, this);
+  }
+}
+
+/// See Executor::run. Wraps the request in a work item, enqueues it and wakes
+/// one worker.
+/// NOTE(review): the work item allocated here is never deleted in this file.
+void ThreadPoolExecutor::run(ExecutorFunctionDAG *functionDag,
+                             ExecutorFunctionDAGContext *ctx, DoneCb cb) {
+  auto *workItem = new ThreadPoolExecutorWorkItem(functionDag, ctx, cb);
+  {
+    // Put the work item onto the queue and mark it as queued.
+    std::lock_guard<std::mutex> lock(workQueueMtx_);
+    workQueue_.push(workItem);
+    workItem->markQueued();
+  }
+  // Signal to any worker waiting for work items.
+  queueNotEmpty_.notify_one();
+}
+
+/// Ask all workers to stop and join them. Work items still in the queue are
+/// not processed.
+ThreadPoolExecutor::~ThreadPoolExecutor() {
+  {
+    // Set the flag under the queue mutex so that a worker cannot check the
+    // *old* value of shouldStop_ and then start waiting after the
+    // notification below has already been sent.
+    std::lock_guard<std::mutex> lock(workQueueMtx_);
+    shouldStop_ = true;
+  }
+
+  // Wake every worker that is waiting on the condition variable.
+  queueNotEmpty_.notify_all();
+
+  for (auto &w : workers_) {
+    w.join();
+  }
+  workers_.clear();
+}
+
+/// Worker loop: wait for a work item, pop it, and process it with the queue
+/// mutex released. Exits once shouldStop_ is set.
+void ThreadPoolExecutor::workerMain() {
+  while (!shouldStop_) {
+    ThreadPoolExecutorWorkItem *workItem = nullptr;
+    {
+      std::unique_lock<std::mutex> lock(workQueueMtx_);
+
+      // Sleep until a work item is submitted or a stop is requested.
+      queueNotEmpty_.wait(lock, [this] {
+        return shouldStop_.load() || !workQueue_.empty();
+      });
+
+      // If shouldStop_ was set to true while the thread was asleep, leave
+      // the main loop.
+      if (shouldStop_) {
+        break;
+      }
+
+      workItem = workQueue_.front();
+      workQueue_.pop();
+    }
+
+    processWorkItem(workItem);
+  }
+}
+
+/// Advance \p workItem by one step: fetch the next runnable function (if any),
+/// dispatch it, and requeue the item while it still has work left.
+void ThreadPoolExecutor::processWorkItem(ThreadPoolExecutorWorkItem *workItem) {
+  // If there is no more work, the item either succeeded or failed and its
+  // callback has already been invoked; it is simply dropped from the queue.
+  if (!workItem->isMoreWork()) {
+    return;
+  }
+
+  Function *f;
+  Context *ctx;
+  std::tie(f, ctx) = workItem->getNext();
+
+  // If there is a function and context available to work on, run it.
+  // NOTE(review): someDeviceManagerFunction is not declared anywhere in this
+  // change; it looks like a placeholder for the device manager API.
+  if (f && ctx) {
+    someDeviceManagerFunction(
+        f, ctx, [workItem, f](ResultCode resultCode, Context *ctx) {
+          if (resultCode == ResultCode::EXECUTED) {
+            workItem->markSuccess(f, ctx);
+          } else if (resultCode == ResultCode::FAILED) {
+            workItem->markFailure();
+          }
+        });
+  }
+
+  // getNext() may have finished the item (and invoked its callback); in that
+  // case there is nothing left to schedule.
+  if (!workItem->isMoreWork()) {
+    return;
+  }
+
+  // The item still has functions that are executing or waiting on
+  // dependencies. Requeue it so that a worker looks at it again.
+  {
+    std::lock_guard<std::mutex> lock(workQueueMtx_);
+    workQueue_.push(workItem);
+  }
+  queueNotEmpty_.notify_one();
+}
+} // namespace runtime
+} // namespace glow
diff --git a/lib/Runtime/Executor/ThreadPoolExecutor.h b/lib/Runtime/Executor/ThreadPoolExecutor.h
new file mode 100644
index 0000000000..75eacaadd3
--- /dev/null
+++ b/lib/Runtime/Executor/ThreadPoolExecutor.h
@@ -0,0 +1,160 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GLOW_RUNTIME_THREAD_POOL_EXECUTOR_H
+#define GLOW_RUNTIME_THREAD_POOL_EXECUTOR_H
+
+#include <atomic>
+#include <condition_variable>
+#include <list>
+#include <mutex>
+#include <queue>
+#include <set>
+#include <thread>
+#include <tuple>
+
+#include "glow/Runtime/Executor/Executor.h"
+
+namespace glow {
+namespace runtime {
+
+/// This class represents a single work item of the ThreadPoolExecutor. It
+/// contains a DAG, an associated context, and some additional information to
+/// track progress on executing the DAG.
+class ThreadPoolExecutorWorkItem {
+public:
+  using DoneCb = Executor::DoneCb;
+
+  /// This enum represents the overall status of the work item. The transitions
+  /// are as follows:
+  ///
+  ///   NONE ---> QUEUED ---> IN_PROGRESS ---> DONE
+  ///                              |             ^
+  ///                              |             |
+  ///                              ---> FAILED ---
+  enum class Status {
+    // This work item has been created and not queued.
+    NONE,
+    // This work item has been inserted into the queue and can be worked on.
+    QUEUED,
+    // This work item is currently being worked on. Some pieces might be done.
+    IN_PROGRESS,
+    // This work item has failed, but the callback has not been called yet.
+    FAILED,
+    // This work item is done. It has either failed or succeeded, and its
+    // callback has been called.
+    DONE,
+  };
+
+  /// Constructor. \p dag is the DAG that this work item should run, and \p ctx
+  /// is the context that the components should run with. \p cb is the callback
+  /// to be called when the work item is done.
+  explicit ThreadPoolExecutorWorkItem(ExecutorFunctionDAG *dag,
+                                      ExecutorFunctionDAGContext *ctx,
+                                      DoneCb cb);
+
+  /// Destructor.
+  ~ThreadPoolExecutorWorkItem() = default;
+
+  /// \returns whether or not there is more work to be done on this work item.
+  bool isMoreWork();
+
+  /// \returns the next pair {function, context} that is ready for execution
+  /// (i.e. all prerequisites have been fulfilled and it is not already
+  /// being executed), or {nullptr, nullptr} if there is none.
+  std::tuple<Function *, Context *> getNext();
+
+  /// Mark the work item as queued.
+  void markQueued();
+
+  /// Mark that the pair {\p function, \p context} have succeeded.
+  void markSuccess(Function *function, Context *context);
+
+  /// Mark the work item as failed.
+  void markFailure();
+
+private:
+  /// Process the queue of updates buffered by markSuccess and update internal
+  /// records accordingly.
+  void processUpdates();
+
+  /// The callback to call when the work item is done.
+  DoneCb cb_;
+  /// The DAG being executed.
+  ExecutorFunctionDAG *dag_;
+  /// The context with which the DAG should be executed.
+  ExecutorFunctionDAGContext *ctx_;
+  /// An iterator into the list of functions in the DAG.
+  std::list<Function *>::const_iterator it_;
+  /// The current status of the work item.
+  Status status_;
+  /// The Context object that holds the final results of execution.
+  Context *result_;
+  /// All functions that have finished executing.
+  std::set<Function *> completedFunctions_;
+  /// All functions that are currently executing.
+  std::set<Function *> inflightFunctions_;
+  /// All functions that have finished executing but have not been moved
+  /// to completedFunctions_ yet.
+  /// NOTE(review): this set and updateContexts_ are each ordered by pointer
+  /// value, yet processUpdates() walks them in lockstep to pair a function
+  /// with its context; that pairing is not guaranteed to be the one passed to
+  /// markSuccess(). A single container of {function, context} pairs would be
+  /// safer.
+  std::set<Function *> updateFunctions_;
+  /// All contexts that resulted from finished executions whose contents have
+  /// not yet been copied into the contexts of dependent graph components or
+  /// into the result_ Context if applicable.
+  std::set<Context *> updateContexts_;
+  /// A mutex to guard accesses to class members.
+  /// TODO: This can probably be eliminated by making certain members
+  /// std::atomic.
+  std::mutex mtx_;
+};
+
+/// This class implements the Executor interface by doing all of the work on a
+/// thread pool. Each call to ThreadPoolExecutor::run() creates a stateful
+/// work item that is ushered to completion through a series of state
+/// transitions (see ThreadPoolExecutorWorkItem::Status) by the threads in the
+/// pool.
+class ThreadPoolExecutor final : public Executor {
+public:
+  using DoneCb = Executor::DoneCb;
+  /// Constructor. Starts \p numWorkers worker threads.
+  explicit ThreadPoolExecutor(unsigned numWorkers = kNumWorkers);
+
+  /// Destructor. Signals the workers to stop and joins them.
+  ~ThreadPoolExecutor() override;
+
+  /// See Executor::run.
+  void run(ExecutorFunctionDAG *functionDag, ExecutorFunctionDAGContext *ctx,
+           DoneCb cb) override;
+
+private:
+  /// Main loop run by workers in the thread pool.
+  void workerMain();
+  /// Helper function for processing a work item.
+  void processWorkItem(ThreadPoolExecutorWorkItem *workItem);
+  /// The default number of workers in the thread pool.
+  constexpr static unsigned kNumWorkers = 3;
+  /// Thread pool workers.
+  std::list<std::thread> workers_;
+  /// Flag checked by the workers in between work items to determine
+  /// whether they should stop and exit. Initialized here because a
+  /// default-constructed std::atomic holds an indeterminate value (before
+  /// C++20) and the workers read the flag as soon as they start.
+  std::atomic<bool> shouldStop_{false};
+  /// Queue of work items.
+  std::queue<ThreadPoolExecutorWorkItem *> workQueue_;
+  /// Condition variable to signal to threads when work is added to
+  /// the work queue.
+  std::condition_variable queueNotEmpty_;
+  /// Mutex to coordinate access to the work queue.
+  std::mutex workQueueMtx_;
+};
+
+} // namespace runtime
+} // namespace glow
+#endif