diff --git a/README.md b/README.md index 64d93edf0d..8ef9bc776a 100644 --- a/README.md +++ b/README.md @@ -22,10 +22,10 @@ Note: List of partner logos sorted alphabetically column order. --> -|![Bitmain Logo](./docs/partners/bitmain.png) |![Esperanto Logo](./docs/partners/esperanto.png) | ![NXP Logo](./docs/partners/nxp.png) | +|![Bitmain Logo](./docs/partners/bitmain.png) |![Esperanto Logo](./docs/partners/esperanto.png) | ![Marvell Logo](./docs/partners/marvell.png) | :-------------------------:|:-------------------------:|:-------------------------: -|![CEVA Logo](./docs/partners/ceva.png) |![Habana Logo](./docs/partners/habana.png) | ![ST Logo](./docs/partners/st.png) | -|![Cadence Logo](./docs/partners/cadence.png) | ![Intel Logo](./docs/partners/intel.png)| | +|![CEVA Logo](./docs/partners/ceva.png) |![Habana Logo](./docs/partners/habana.png) | ![NXP Logo](./docs/partners/nxp.png) | +|![Cadence Logo](./docs/partners/cadence.png) | ![Intel Logo](./docs/partners/intel.png)| ![ST Logo](./docs/partners/st.png) | ## How does it work? diff --git a/docs/Quantization.md b/docs/Quantization.md index 820da151e1..e862b729eb 100644 --- a/docs/Quantization.md +++ b/docs/Quantization.md @@ -57,18 +57,16 @@ inference. Then, we recompile the network using this profile information to convert the network into a quantized form, allowing for static optimization of the quantized graph. We convert portions of the network into islands of integer computation and aim to generate outputs in the range that the original -floating-point network produces. During the conversion, for the following types -of quantized nodes, we ignore the output's quantization params (if they are -provided) and force the output have the same quantization params as the input +floating-point network produces. During the conversion, for the following types +of quantized nodes, we ignore the output's quantization params (if they are +provided) and force the output have the same quantization params as the input for performance purpose: ``` -LocalResponseNormalizationNode -SigmoidNode -SliceNode -ReshapeNode -TanhNode -TopKNode -GatherNode +LocalResponseNormalizationNode +SliceNode +ReshapeNode +TopKNode +GatherNode MaxPoolNode ``` @@ -129,9 +127,13 @@ the quantized text translator: ```./bin/text-translator -m en2gr -load-profile=en2gr.yaml -keep-original-precision-for-nodes=Add,Div``` -## Caffe2 Quantized Model Support +By default, target quantization precision is int8. However, precision can be +controlled via command line parameter: `quantization-precision`. There are +two supported values: `Int8` and `Int16`. -Glow is able to support Caffe2 Resnet50 quantized model: +## Caffe2 Quantized Model Support + +Glow is able to support Caffe2 Resnet50 quantized model: https://github.com/caffe2/models/tree/master/resnet50_quantized To support Caffe2 quantized models, Glow has: @@ -150,16 +152,16 @@ Int8GivenTensorFill ``` - Supported int32 quantized bias. -In most of the cases, bias is quantized in int32 to improve precision -(the partial sum of the matrix-matrix multiplication is accumulated into int32, -so int32 bias can be added to the int32 partial sum for better accuracy). -Glow now supports int32 quantized bias in ```Convolution```, ```FullyConnected``` +In most of the cases, bias is quantized in int32 to improve precision +(the partial sum of the matrix-matrix multiplication is accumulated into int32, +so int32 bias can be added to the int32 partial sum for better accuracy). +Glow now supports int32 quantized bias in ```Convolution```, ```FullyConnected``` and ```RowwiseQuantizedFullyConnected``` nodes. - Supported the conversion from uint8 quantized activations to int8 quantized activations. -For the quantized Caffe2 ops, the activations are quantized to uint8. In Glow, the -activations are quantized to int_8. Therefore, for the offset read from quantized Caffe2 +For the quantized Caffe2 ops, the activations are quantized to uint8. In Glow, the +activations are quantized to int_8. Therefore, for the offset read from quantized Caffe2 model, we need to subtract 128(i.e. INT8_MIN) to make the activations become int8. ## Compiler Optimizations @@ -189,17 +191,24 @@ For more specific graph optimizations check [here](Optimizations.md#quantization ## Row-wise Quantization -Row-wise (or channel-wise) quantization is an important way to minimize accuracy drop. -Glow supports row-wise quantized FullyConnected node ```RowwiseQuantizedFullyConnected``` -which is enabled by an image-classifier/loader option "-enable-rowwise". +Row-wise (or channel-wise) quantization is an important way to minimize accuracy +drop. Glow supports row-wise quantized FullyConnected node +```RowwiseQuantizedFullyConnected``` which is enabled by an +image-classifier/loader option "-enable-rowwise". + +For the regular quantized FC, we quantize the whole weights tensor with the same +scale and offset, which are computed based on the max and min of the entire +tensor). But for row-wise, after getting ```min_i``` and ```max_i``` for each +row ```i```, we compute the pair of ```(scale_i, offset_i)``` to quantize each +element in row ```i```. The figure below shows the quantized FC node and +RowwiseQuantizedFullyConnected node. Instead of using only one tensor to +represent the quantized weights, we need 2 extra vectors ```Scales``` and +```Offsets``` to store the ```(scale, offset)``` for each row. -For the regular quantized FC, we quantize the whole weights tensor with the same -scale and offset, which are computed based on the max and min of the entire tensor). -But for row-wise, after getting ```min_i``` and ```max_i``` for each row ```i```, we compute the pair -of ```(scale_i, offset_i)``` to quantize each element in row ```i```. The figure below shows -the quantized FC node and RowwiseQuantizedFullyConnected node. Instead of using only -one tensor to represent the quantized weights, we need 2 extra vectors ```Scales``` -and ```Offsets``` to store the ```(scale, offset)``` for each row. +![](rowwise_quantized_fc.png) -![](rowwise_quantized_fc.png) \ No newline at end of file +Row-wise quantized SparseLengthsWeightedSum is also supported. Similar to the +above, we compute scales and offsets per row, to be used with the `Data` input +for the `RowwiseQuantizedSparseLengthsSumNode`. Scales and Offsets are inputs to +the node. Output of this node is float, matching the Caffe2 implementation. diff --git a/docs/partners/marvell.png b/docs/partners/marvell.png new file mode 100644 index 0000000000..6a6edfa288 Binary files /dev/null and b/docs/partners/marvell.png differ diff --git a/include/glow/Backends/Backend.h b/include/glow/Backends/Backend.h index d7d68987a3..2e435aa87d 100644 --- a/include/glow/Backends/Backend.h +++ b/include/glow/Backends/Backend.h @@ -82,6 +82,9 @@ class Backend { /// \returns true if the Backend wants the buffer sharing optimization /// performed. virtual bool shouldShareBuffers() const { return true; } + + /// Optimize the Function \p F given compilation mode \p mode. + void optimizeFunction(CompilationMode mode, Function *F); }; /// Create a backend of kind \p kind. diff --git a/include/glow/Backends/CompiledFunction.h b/include/glow/Backends/CompiledFunction.h index a5197c8906..27483a03ba 100644 --- a/include/glow/Backends/CompiledFunction.h +++ b/include/glow/Backends/CompiledFunction.h @@ -38,21 +38,21 @@ class CompiledFunction { virtual ~CompiledFunction() = default; /// Execute the network and allocate Placeholder memory with given /// \p ctx providing mapping between Placeholder and populated tensor. - virtual void execute() = 0; + virtual void execute(Context *ctx) = 0; /// Does any needed initialization work for the Backend. /// This includes device init constant memory allocation and copying to - /// device. - virtual void setupRuns() = 0; + /// device. \deprecated + virtual void setupRuns() { runsSetup_ = true; } - /// Per run setup. Copy inputs to device. - virtual void beforeRun(const Context &ctx) = 0; + /// Per run setup. Copy inputs to device. \deprecated + virtual void beforeRun(const Context &ctx) {} - /// Per run cleanup. Copy outputs from device. - virtual void afterRun(const Context &ctx) = 0; + /// Per run cleanup. Copy outputs from device. \deprecated + virtual void afterRun(const Context &ctx) {} - /// Final cleanup. Release memory, reset device. - virtual void tearDownRuns() = 0; + /// Final cleanup. Release memory, reset device. \deprecated + virtual void tearDownRuns() { runsSetup_ = false; } /// Getter for the runtimeBundle. const runtime::RuntimeBundle &getRuntimeBundle() const { diff --git a/include/glow/ExecutionEngine/ExecutionEngine.h b/include/glow/ExecutionEngine/ExecutionEngine.h index 96e30cdc0d..fcfb861dd7 100644 --- a/include/glow/ExecutionEngine/ExecutionEngine.h +++ b/include/glow/ExecutionEngine/ExecutionEngine.h @@ -43,9 +43,6 @@ class ExecutionEngine final { /// A glow function compiled for this ExecutionEngine's backend. std::unique_ptr function_; - /// Optimize the Function \p F given compilation mode \p mode. - void optimizeFunction(CompilationMode mode, Function *F); - public: ExecutionEngine(BackendKind backendKind = BackendKind::Interpreter); diff --git a/include/glow/Graph/Graph.h b/include/glow/Graph/Graph.h index fe54bf4e37..ad7c9a1718 100644 --- a/include/glow/Graph/Graph.h +++ b/include/glow/Graph/Graph.h @@ -565,6 +565,25 @@ class Function final : public Named { NodeValue data, NodeValue weights, NodeValue indices, NodeValue lengths); + /// Create a node, performing SparseLengthsSum operation, using rowwise + /// quantization for the input data. Gathers slices of the outer-most + /// dimension of Data indexed by Indices vector, and then accumulates them + /// into len(Lengths) entries: first Lengths[0] slices are aggregated to + /// Result[0], next Lengths[1] slices are aggregated to Result[1], + /// etc. I.e. sum(Lengths) must be equal to len(Indices). + RowwiseQuantizedSparseLengthsWeightedSumNode * + createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name, Tensor &data, + NodeValue indices, NodeValue lengths); + + /// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but i-th slice is + /// multiplied by weights[i]. len(weights) must be equal to len(indices). + RowwiseQuantizedSparseLengthsWeightedSumNode * + createRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name, + Tensor &data, + NodeValue weights, + NodeValue indices, + NodeValue lengths); + /// Given a vector of segment lengths, calculates offsets of each segment and /// packs them next to the lengths. For the input vector of length N the /// output is a Nx2 matrix with (offset, lengths) packaged for each segment. diff --git a/include/glow/Partitioner/Partitioner.h b/include/glow/Partitioner/Partitioner.h new file mode 100644 index 0000000000..0e83c1bf42 --- /dev/null +++ b/include/glow/Partitioner/Partitioner.h @@ -0,0 +1,130 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GLOW_PARTITIONER_PARTITIONER_H +#define GLOW_PARTITIONER_PARTITIONER_H + +#include "glow/Graph/Graph.h" +#include "glow/Runtime/RuntimeTypes.h" + +#include "llvm/ADT/DenseMap.h" + +#include +#include +#include + +namespace glow { + +using namespace runtime; + +using MemUsageMap = std::unordered_map; + +/// Helper structure for building a partition. Records a mapping of nodes in the +/// original function to destination partitions, along with a list of the +/// newly-created functions. +class NodeToFunctionMap { + using Map = llvm::DenseMap; + + /// Newly-created partitions. + FunctionList functions_; + + /// Map of nodes in the original function to their target partition. + Map nodeToFunction_; + +public: + /// Create a new partition \p F. + void createPartition(Function *F) { functions_.emplace_back(F); } + + /// Add a new Node->Function mapping. + void add(Node *N, Function *F) { nodeToFunction_[N] = F; } + + /// Get list of functions contained in this map. + const FunctionList &getPartitions() const { return functions_; } + + /// Map API. + Map::iterator find(Node *N) { return nodeToFunction_.find(N); } + Map::iterator begin() { return nodeToFunction_.begin(); } + Map::iterator end() { return nodeToFunction_.end(); } + Function *operator[](Node *n) { return nodeToFunction_[n]; } +}; + +/// The struct contains all the created DAGNodes. This DAGNodeList owns all the +/// DAGNodes, which cannot outlive the DAGNodeList. In addition, the DAGNodes +/// can only refer to the DAGNodes from the same DAGNodeList, and they can use +/// the raw pointers to refer to each other since they are in the same +/// DAGNodeList. +struct DAGNodeList { + /// The root DAGNode pointer of each graph/function. + std::vector> roots; + /// The non-root DAGNode pointers. + std::vector> nodes; +}; + +/// Given a module, partitions each of the its functions into multiple ones +/// based on memory constraints and minimizes the communication cost. +class Partitioner { + /// The module that needs to be decomposed. + Module *module_; + + /// The representative function used for partition. We choose the function who + /// has the largest memory size. + Function *F_; + + /// The cost model related to device. + const std::vector &deviceInfo_; + + /// The result of module partitioning. + DAGNodeList partitions_; + + /// Total memory (bytes) requested by one module. + size_t memSize_; + + /// The map of each operator and the corresponding memory size. + MemUsageMap memUsage_; + + /// Get the representative function (the one with the largest input) and + /// update the memSize. + static Function *selectRepFunc(Module *parent, size_t &memSize); + + /// Get the minimal memory requirement for each op in the representive + /// function. + void initOpMemUsage(); + + /// Assign nodes to partitions and return the mapping. + NodeToFunctionMap selectPartitions(Function *F, unsigned availableMemory); + + /// Adjust a logicalDevice ID to each DAGNode. It is possible that two + /// sub-functions need to be assigned into 1 device due to the memory + /// constraits. + void adjustLogicalDeviceID(DAGNode *DAG, int num); + + /// Given the node-function mapping, do the actual partitioning. + void doPartitioning(Function *F, NodeToFunctionMap &mapping); + +public: + /// \p parent is the module which contains the functions need to be divided. + /// Here we assume that all the functions in one module belong to a same + /// "Function Family", that is, without considerting the "dynamic stuff" (i.e. + /// batch size, input/output shape of each op), all the functions are + /// identical. The required memory and computation cost for each op can be + /// found in Module. The \p devices provides the cost model related to + /// devices. + Partitioner(Module *parent, const std::vector &devices); + + /// Decompose each function in a module and return a list of DAGNodes. + DAGNodeList &Partition(); +}; +} // namespace glow +#endif // GLOW_PARTITIONER_PARTITIONER_H diff --git a/include/glow/Quantization/Base/Base.h b/include/glow/Quantization/Base/Base.h index 417f20acc7..abbe2ef85d 100644 --- a/include/glow/Quantization/Base/Base.h +++ b/include/glow/Quantization/Base/Base.h @@ -128,10 +128,11 @@ chooseQuantizationParams(float min, float max, Schema schema = Asymmetric, std::vector createMapping(TypeRef inTy, TypeRef outTy, std::function f); -/// Row-wise quantize the tensor \p input. The param \p input is a 2D -/// tensor (i.e. M * N), \p scales and \p offsets are generated by each row of -/// \p input, \p output is 2D tensor quantized from \p input using \p scales -/// and \p offsets for each row. +/// Row-wise quantize the tensor \p input. \p scales and \p offsets are +/// generated by each row of \p input, \p output is tensor of the same shape as +/// input, quantized from \p input using \p scales and \p offsets for each +/// row. Note that the shape of input/output can be any non-zero number of +/// dimensions; row refers to all data in the first dimension of the shape. void tensorRowwiseQuantization(const Tensor &input, Tensor &output, Tensor &scales, Tensor &offsets); diff --git a/include/glow/Quantization/Quantization.h b/include/glow/Quantization/Quantization.h index 8678cad4f7..3decac3fc8 100644 --- a/include/glow/Quantization/Quantization.h +++ b/include/glow/Quantization/Quantization.h @@ -53,11 +53,12 @@ struct NodeQuantizationInfo { namespace quantization { /// Generate NodeQuantizationInfo for all required nodes from function \p F -/// using the method specified by \p schema. Profiling values will be written -/// into context \p ctx. -std::vector -generateNodeQuantizationInfos(Context &ctx, const Function *F, - Schema schema = Schema::Asymmetric); +/// using the method specified by \p schema and target quantization +/// precision \p quantizationPrecision. Profiling values will be written into +/// context \p ctx. +std::vector generateNodeQuantizationInfos( + Context &ctx, const Function *F, Schema schema = Schema::Asymmetric, + ElemKind quantizationPrecision = ElemKind::Int8QTy); /// Quantizes the function \p F into a new unoptimized partially quantized /// function based on \p quantizationInfos and target quantization precision diff --git a/include/glow/Runtime/Provisioner/Provisioner.h b/include/glow/Runtime/Provisioner/Provisioner.h new file mode 100644 index 0000000000..76a380a63b --- /dev/null +++ b/include/glow/Runtime/Provisioner/Provisioner.h @@ -0,0 +1,54 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GLOW_RUNTIME_PROVISIONER_H +#define GLOW_RUNTIME_PROVISIONER_H + +#include "glow/Backends/Backend.h" +#include "glow/Backends/DeviceManager.h" +#include "glow/Runtime/RuntimeTypes.h" + +#include + +namespace glow { +namespace runtime { + +/// The Provisioner is responsible for assigning networks to an actual device. +/// It also compiles the networks before passing the compiled functions to the +/// device. +class Provisioner final { +public: + /// Walks \p networks and assigns each function to a DeviceManager in \p + /// devices. The Provisioner calls the addNetwork method for each + /// DeviceManager. Returns a ResultCode indicating if the operation was a + /// success. + ResultCode + provision(std::vector> &networks, + std::map> &devices, + Module &module); + +private: + /// Pointer to backend used for compilation. This currently gets reset per + /// device to ensure the correct backed per device. + std::unique_ptr backend_; + + /// Padding factor to account for generated code size. Should be greater + /// than 1.0. + const float NETWORK_PADDING_FACTOR = 1.1; +}; +} // namespace runtime +} // namespace glow + +#endif // GLOW_RUNTIME_PROVISIONER_H diff --git a/include/glow/Runtime/RuntimeTypes.h b/include/glow/Runtime/RuntimeTypes.h index fbf23c5d0d..d65976a1e1 100644 --- a/include/glow/Runtime/RuntimeTypes.h +++ b/include/glow/Runtime/RuntimeTypes.h @@ -31,7 +31,7 @@ using RunIdentifierTy = size_t; /// Enum to communicate results when communicating with device at initialization /// and runtime. -enum ResultCode { Ready, Executed, Failed, Cancelled }; +enum class ResultCode { Ready, Executed, Failed, Cancelled }; /// Data structure that contains device constraint information for each device. /// Used to communicate memory constraints and later costs to the Partitioner. diff --git a/include/glow/Support/Error.h b/include/glow/Support/Error.h index 1a58a179d6..b6b4d1debc 100644 --- a/include/glow/Support/Error.h +++ b/include/glow/Support/Error.h @@ -30,12 +30,6 @@ namespace glow { /// line numbers also. extern llvm::ExitOnError exitOnErr; -/// Take a message \p str and prepend it with the given \p file and \p line -/// number. This is useful for augmenting StringErrors with information about -/// where they were generated. -std::string addFileAndLineToError(llvm::StringRef str, llvm::StringRef file, - uint32_t line); - /// Is true_type only if applied to llvm::Error or a descendant. template struct IsLLVMError : public std::is_base_of {}; @@ -45,6 +39,101 @@ template struct IsLLVMExpected : public std::false_type {}; template struct IsLLVMExpected> : public std::true_type {}; +/// Represents errors in Glow. GlowErr track the file name and line number of +/// where they were created as well as a textual message and/or a error code to +/// help identify the type of error the occurred programtically. +class GlowErr final : public llvm::ErrorInfo { +public: + /// Used by ErrorInfo::classID. + static const uint8_t ID; + /// An enumeration of error codes representing various possible errors that + /// could occur. + /// NOTE: when updating this enum, also update ErrorCodeToString function + /// below. + enum class ErrorCode { + // An unknown error ocurred. This is the default value. + UNKNOWN, + // Model loader encountered an unsupported shape. + MODEL_LOADER_UNSUPPORTED_SHAPE, + // Model loader encountered an unsupported operator. + MODEL_LOADER_UNSUPPORTED_OPERATOR, + // Model loader encountered an unsupported attribute. + MODEL_LOADER_UNSUPPORTED_ATTRIBUTE, + // Model loader encountered an unsupported datatype. + MODEL_LOADER_UNSUPPORTED_DATATYPE, + // Model loader encountered an unsupported ONNX version. + MODEL_LOADER_UNSUPPORTED_ONNX_VERSION, + // Model loader encountered an invalid protobuf. + MODEL_LOADER_INVALID_PROTOBUF, + }; + + /// GlowErr is not convertable to std::error_code. This is included for + /// compatiblity with ErrorInfo. + virtual std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } + + /// Log to \p OS relevant error information including the file name and + /// line number the GlowErr was created on as well as the message and/or error + /// code the GlowErr was created with. + void log(llvm::raw_ostream &OS) const override { + OS << "file: " << fileName_ << " line: " << lineNumber_; + if (ec_ != ErrorCode::UNKNOWN) { + OS << " error code: " << errorCodeToString(ec_); + } + if (!message_.empty()) { + OS << " message: " << message_; + } + } + + GlowErr(llvm::StringRef fileName, size_t lineNumber, llvm::StringRef message, + ErrorCode ec) + : lineNumber_(lineNumber), fileName_(fileName), message_(message), + ec_(ec) {} + + GlowErr(llvm::StringRef fileName, size_t lineNumber, ErrorCode ec, + llvm::StringRef message) + : lineNumber_(lineNumber), fileName_(fileName), message_(message), + ec_(ec) {} + + GlowErr(llvm::StringRef fileName, size_t lineNumber, ErrorCode ec) + : lineNumber_(lineNumber), fileName_(fileName), ec_(ec) {} + + GlowErr(llvm::StringRef fileName, size_t lineNumber, llvm::StringRef message) + : lineNumber_(lineNumber), fileName_(fileName), message_(message) {} + +private: + /// Convert ErrorCode values to string. + static std::string errorCodeToString(const ErrorCode &ec) { + switch (ec) { + case ErrorCode::UNKNOWN: + return "UNKNOWN"; + case ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE: + return "MODEL_LOADER_UNSUPPORTED_SHAPE"; + case ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR: + return "MODEL_LOADER_UNSUPPORTED_OPERATOR"; + case ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE: + return "MODEL_LOADER_UNSUPPORTED_ATTRIBUTE"; + case ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE: + return "MODEL_LOADER_UNSUPPORTED_DATATYPE"; + case ErrorCode::MODEL_LOADER_UNSUPPORTED_ONNX_VERSION: + return "MODEL_LOADER_UNSUPPORTED_ONNX_VERSION"; + case ErrorCode::MODEL_LOADER_INVALID_PROTOBUF: + return "MODEL_LOADER_INVALID_PROTOBUF"; + }; + llvm_unreachable("unsupported ErrorCode"); + } + + /// The line number the error was generated on. + size_t lineNumber_; + /// The name of the file the error was generated in. + std::string fileName_; + /// Optional message associated with the error. + std::string message_; + /// Optional error code associated with the error. + ErrorCode ec_ = ErrorCode::UNKNOWN; +}; + /// Unwraps the T from within an llvm::Expected. If the Expected contains /// an error, the program will exit. #define EXIT_ON_ERR(...) (exitOnErr(__VA_ARGS__)) @@ -56,15 +145,12 @@ struct IsLLVMExpected> : public std::true_type {}; #define TEMP_EXIT_ON_ERR(...) (EXIT_ON_ERR(__VA_ARGS__)) /// Make a new llvm::StringError. -#define MAKE_ERR(str) \ - llvm::make_error( \ - (addFileAndLineToError(str, __FILE__, __LINE__)), \ - llvm::inconvertibleErrorCode()) +#define MAKE_ERR(...) llvm::make_error(__FILE__, __LINE__, __VA_ARGS__) /// Makes a new llvm::StringError and returns it. -#define RETURN_ERR(str) \ +#define RETURN_ERR(...) \ do { \ - return MAKE_ERR(str); \ + return MAKE_ERR(__VA_ARGS__); \ } while (0) /// Takes an llvm::Expected \p lhsOrErr and if it is an Error then returns @@ -94,10 +180,10 @@ struct IsLLVMExpected> : public std::true_type {}; /// Takes a predicate \p and if it is false then creates a new llvm::StringError /// and returns it. -#define RETURN_ERR_IF_NOT(p, str) \ +#define RETURN_ERR_IF_NOT(p, ...) \ do { \ if (!(p)) { \ - RETURN_ERR(str); \ + RETURN_ERR(__VA_ARGS__); \ } \ } while (0) } // end namespace glow diff --git a/include/glow/Support/Support.h b/include/glow/Support/Support.h index b67522f653..92000b330c 100644 --- a/include/glow/Support/Support.h +++ b/include/glow/Support/Support.h @@ -110,6 +110,10 @@ class DescriptionBuilder { void report(const char *msg); inline void report(const std::string &str) { report(str.c_str()); } inline void report(llvm::StringRef str) { report(str.data()); } + +/// Printf-like formatting for std::string. +const std::string strFormat(const char *format, ...) + __attribute__((__format__(__printf__, 1, 2))); } // namespace glow #endif // GLOW_SUPPORT_SUPPORT_H diff --git a/lib/Backends/Backend.cpp b/lib/Backends/Backend.cpp new file mode 100644 index 0000000000..5238b43cbe --- /dev/null +++ b/lib/Backends/Backend.cpp @@ -0,0 +1,49 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "glow/Backends/Backend.h" +#include "glow/Graph/Graph.h" +#include "glow/Optimizer/Optimizer.h" + +using namespace glow; + +void Backend::optimizeFunction(CompilationMode mode, Function *F) { + // Verify the function pre-optimization/lowering. + assert(F->verify() && "Function must be valid"); + + // Optimize the graph. + ::glow::optimize(F, mode); + + // Allow the backend to transform the graph prior to lowering. + if (transformPreLowering(F, mode)) { + // Optimize the graph again after the backend transformation. + // In particular, DCE is very likely to be useful. + ::glow::optimize(F, mode); + } + + // Lower the graph into a sequence of low-level linear algebra operations. + ::glow::lower(F, *this); + + // Optimize the graph again. + ::glow::optimize(F, mode); + + // Allow the backend to transform the graph after lowering. + if (transformPostLowering(F, mode)) { + // Optimize the graph again after the backend transformation. + // In particular, DCE is very likely to be useful. + ::glow::optimize(F, mode); + } +} diff --git a/lib/Backends/CMakeLists.txt b/lib/Backends/CMakeLists.txt index 180aaac9c8..bb4ed77d3e 100644 --- a/lib/Backends/CMakeLists.txt +++ b/lib/Backends/CMakeLists.txt @@ -1,4 +1,6 @@ -add_library(Backends Backends.cpp) +add_library(Backends + Backend.cpp + Backends.cpp) add_library(BackendUtils BackendUtils.cpp) diff --git a/lib/Backends/CPU/CPUDeviceManager.cpp b/lib/Backends/CPU/CPUDeviceManager.cpp index cc5e2f8b8c..37e33514c1 100644 --- a/lib/Backends/CPU/CPUDeviceManager.cpp +++ b/lib/Backends/CPU/CPUDeviceManager.cpp @@ -36,7 +36,7 @@ void CPUDeviceManager::addNetworkImpl(const Module *module, if (modIt != modules_.end()) { // Already have a module with this ID. // TODO: should we replace it? - readyCB(module, Failed); + readyCB(module, ResultCode::Failed); return; } @@ -46,7 +46,7 @@ void CPUDeviceManager::addNetworkImpl(const Module *module, size_t moduleSize = 200 * 1024 * 1024; if (usedMemoryBytes + moduleSize > maxMemoryBytes) { - readyCB(module, Failed); + readyCB(module, ResultCode::Failed); return; } @@ -60,7 +60,7 @@ void CPUDeviceManager::addNetworkImpl(const Module *module, usedMemoryBytes += moduleSize; // Fire the ready CB. - readyCB(module, Ready); + readyCB(module, ResultCode::Ready); } void CPUDeviceManager::evictNetworkImpl(const Module *module) { @@ -85,19 +85,15 @@ void CPUDeviceManager::runFunctionImpl(RunIdentifierTy id, std::string function, ResultCBTy resultCB) { auto funcIt = functions_.find(function); if (funcIt == functions_.end()) { - resultCB(id, Failed, std::move(ctx)); + resultCB(id, ResultCode::Failed, std::move(ctx)); return; } CompiledFunction *func = funcIt->second; // Run that function. - func->setupRuns(); - func->beforeRun(*ctx); - func->execute(); - func->afterRun(*ctx); - func->tearDownRuns(); + func->execute(ctx.get()); // Fire the resultCB. - resultCB(id, Executed, std::move(ctx)); + resultCB(id, ResultCode::Executed, std::move(ctx)); } diff --git a/lib/Backends/CPU/CPUFunction.cpp b/lib/Backends/CPU/CPUFunction.cpp index 16d00439a1..ffc36f10f8 100644 --- a/lib/Backends/CPU/CPUFunction.cpp +++ b/lib/Backends/CPU/CPUFunction.cpp @@ -30,42 +30,29 @@ CPUFunction::~CPUFunction() { tearDownRuns(); } -void CPUFunction::setupRuns() { - if (!runsSetup_) { - if (runtimeBundle_.getActivationsSize() != 0) { - baseActivationsAddress_ = (uint8_t *)alignedAlloc( - runtimeBundle_.getActivationsSize(), TensorAlignment); - } - - if (runtimeBundle_.getMutableWeightSize() != 0) { - baseMutableWeightVarsAddress_ = (uint8_t *)alignedAlloc( - runtimeBundle_.getMutableWeightSize(), TensorAlignment); - } - runsSetup_ = true; - } -} - void CPUFunction::collectConstants(IRFunction *F) { runtimeBundle_.collectConstants(F); } -void CPUFunction::beforeRun(const Context &ctx) { +void CPUFunction::loadPlaceholders(Context *ctx, + uint8_t *baseMutableWeightVarsAddress) { // Copy Placeholders into allocated memory. - for (auto PH : ctx.pairs()) { + for (auto PH : ctx->pairs()) { auto payload = PH.second->getUnsafePtr(); auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first); auto addr = symbolInfo.offset; auto numBytes = symbolInfo.size; // copy PH to allocated memory. - memcpy(baseMutableWeightVarsAddress_ + addr, payload, numBytes); + memcpy(baseMutableWeightVarsAddress + addr, payload, numBytes); } } -void CPUFunction::afterRun(const Context &ctx) { +void CPUFunction::updatePlaceholders(Context *ctx, + uint8_t *baseMutableWeightVarsAddress) { // Copy placeholders from device back into context. - for (auto PH : ctx.pairs()) { + for (auto PH : ctx->pairs()) { auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first); - auto payload = baseMutableWeightVarsAddress_ + symbolInfo.offset; + auto payload = baseMutableWeightVarsAddress + symbolInfo.offset; auto numBytes = symbolInfo.size; auto addr = PH.second->getUnsafePtr(); // copy PH from allocated memory. @@ -73,20 +60,25 @@ void CPUFunction::afterRun(const Context &ctx) { } } -void CPUFunction::tearDownRuns() { - if (baseMutableWeightVarsAddress_) { - alignedFree(baseMutableWeightVarsAddress_); - baseMutableWeightVarsAddress_ = nullptr; +void CPUFunction::execute(Context *ctx) { + /// Base address for Activations memory block. + uint8_t *baseActivationsAddress{nullptr}; + + /// Base address for Mutable weights memory block, Inputs and Outputs. + uint8_t *baseMutableWeightVarsAddress{nullptr}; + + if (runtimeBundle_.getActivationsSize() != 0) { + baseActivationsAddress = (uint8_t *)alignedAlloc( + runtimeBundle_.getActivationsSize(), TensorAlignment); } - if (baseActivationsAddress_) { - alignedFree(baseActivationsAddress_); - baseActivationsAddress_ = nullptr; + if (runtimeBundle_.getMutableWeightSize() != 0) { + baseMutableWeightVarsAddress = (uint8_t *)alignedAlloc( + runtimeBundle_.getMutableWeightSize(), TensorAlignment); } - runsSetup_ = false; -} -void CPUFunction::execute() { + loadPlaceholders(ctx, baseMutableWeightVarsAddress); + auto sym = JIT_->findSymbol("jitmain"); assert(sym && "Unable to JIT the code!"); using JitFuncType = @@ -95,9 +87,14 @@ void CPUFunction::execute() { auto address = sym.getAddress(); if (address) { JitFuncType funcPtr = reinterpret_cast(address.get()); - funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress_, - baseActivationsAddress_); + funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress, + baseActivationsAddress); } else { GLOW_ASSERT(false && "Error getting address."); } + + updatePlaceholders(ctx, baseMutableWeightVarsAddress); + + alignedFree(baseMutableWeightVarsAddress); + alignedFree(baseActivationsAddress); } diff --git a/lib/Backends/CPU/CPUFunction.h b/lib/Backends/CPU/CPUFunction.h index c1ca4fdf4b..d35277f39b 100644 --- a/lib/Backends/CPU/CPUFunction.h +++ b/lib/Backends/CPU/CPUFunction.h @@ -28,12 +28,6 @@ class CPUFunction final : public CompiledFunction { /// initializes the LLVM backends. std::unique_ptr JIT_; - /// Base address for Activations memory block. - uint8_t *baseActivationsAddress_{}; - - /// Base address for Mutable weights memory block, Inputs and Outputs. - uint8_t *baseMutableWeightVarsAddress_{}; - public: /// Ctor. CPUFunction(std::unique_ptr JIT, @@ -42,24 +36,19 @@ class CPUFunction final : public CompiledFunction { /// Collects constants for runtime. void collectConstants(IRFunction *F); - /// Allocate Mutable buffers on device this includes Activations and - /// Placeholders. - void setupRuns() override; - - /// Copy Input Placeholder data to position. - void beforeRun(const Context &ctx) override; - - /// Copy Outputs to Placeholders in \p ctx. - void afterRun(const Context &ctx) override; - - /// Final cleanup, free all allocations. - void tearDownRuns() override; - /// \name CompiledFunction interface ///@{ ~CPUFunction() override; - void execute() override; + void execute(Context *ctx) override; ///@} +private: + /// Load constant tensors from \p ctx into \p weightsAddress, as defined by + /// the RuntimeBundle (pre-run). + void loadPlaceholders(Context *ctx, uint8_t *weightsAddress); + + /// Load weights from \p weightsAddress into applicable backing tensors in + /// \p ctx, as defined by the RuntimeBundle (post-run). + void updatePlaceholders(Context *ctx, uint8_t *weightsAddress); }; } // end namespace glow diff --git a/lib/Backends/CPU/LLVMIRGen.cpp b/lib/Backends/CPU/LLVMIRGen.cpp index 6df7f8aa59..dd1857a258 100644 --- a/lib/Backends/CPU/LLVMIRGen.cpp +++ b/lib/Backends/CPU/LLVMIRGen.cpp @@ -2292,6 +2292,32 @@ void LLVMIRGen::generateLLVMIRForInstr(llvm::IRBuilder<> &builder, break; } + case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumInstKind: { + auto *N = cast(I); + auto *dest = N->getDest(); + auto *data = N->getData(); + auto *scales = N->getScales(); + auto *offsets = N->getOffsets(); + auto *weights = N->getWeights(); + auto *indices = N->getIndices(); + auto *lengths = N->getLengths(); + auto *destPtr = emitValueAddress(builder, dest); + auto *dataPtr = emitValueAddress(builder, data); + auto *scalesPtr = emitValueAddress(builder, scales); + auto *offsetsPtr = emitValueAddress(builder, offsets); + auto *weightsPtr = emitValueAddress(builder, weights); + auto *indicesPtr = emitValueAddress(builder, indices); + auto *lengthsPtr = emitValueAddress(builder, lengths); + auto *segments = emitConstSizeT(builder, lengths->dims()[0]); + auto *lineSize = emitConstSizeT(builder, data->size() / data->dims()[0]); + auto *F = getFunction("rowwise_quantized_sparse_lengths_weighted_sum", + dest->getElementType()); + createCall(builder, F, + {destPtr, dataPtr, scalesPtr, offsetsPtr, weightsPtr, indicesPtr, + lengthsPtr, segments, lineSize}); + break; + } + case Kinded::Kind::SparseToDenseInstKind: { auto *STDI = llvm::cast(I); auto *indices = STDI->getIndices(); diff --git a/lib/Backends/CPU/libjit/libjit.cpp b/lib/Backends/CPU/libjit/libjit.cpp index aaa992ab89..4b8fbece6f 100644 --- a/lib/Backends/CPU/libjit/libjit.cpp +++ b/lib/Backends/CPU/libjit/libjit.cpp @@ -1037,6 +1037,26 @@ void libjit_sparse_lengths_weighted_sum_f(float *dest, float *data, } } +void libjit_rowwise_quantized_sparse_lengths_weighted_sum_f( + float *dest, int8_t *data, float *scales, int32_t *offsets, float *weights, + size_t *indices, int32_t *lengths, size_t segments, size_t lineSize) { + memset(dest, 0, segments * lineSize * sizeof(float)); + size_t curIndex = 0; + for (size_t i = 0; i < segments; i++) { + for (int32_t j = 0; j < lengths[i]; j++) { + float weight = weights[curIndex]; + size_t line = indices[curIndex]; + const float scale = scales[line]; + const int32_t offset = offsets[line]; + for (size_t k = 0; k < lineSize; k++) { + const float fData = scale * (data[line * lineSize + k] - offset); + dest[i * lineSize + k] += weight * fData; + } + curIndex++; + } + } +} + void libjit_sparse_to_dense_f(float *dest, const size_t *indices, const float *values, size_t numIndices, size_t destSize, size_t valueSize) { diff --git a/lib/Backends/Interpreter/InterpreterFunction.cpp b/lib/Backends/Interpreter/InterpreterFunction.cpp index c95b7386ad..6ab0f503bc 100644 --- a/lib/Backends/Interpreter/InterpreterFunction.cpp +++ b/lib/Backends/Interpreter/InterpreterFunction.cpp @@ -22,6 +22,7 @@ #include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" using namespace glow; InterpreterFunction::InterpreterFunction(std::unique_ptr F, @@ -29,22 +30,18 @@ InterpreterFunction::InterpreterFunction(std::unique_ptr F, : CompiledFunction(bundle), F_(std::move(F)) {} InterpreterFunction::~InterpreterFunction() { - // Delete the tensors that are owned by this backend. - for (const auto &p : tensors_) { + for (const auto &p : constants_) { delete p.second; } - tensors_.clear(); - externalTensors_.clear(); + constants_.clear(); + alignedFree(runtimeBundle_.getConstants()); tearDownRuns(); } void InterpreterFunction::collectConstants(IRFunction *F) { runtimeBundle_.collectConstants(F); -} - -void InterpreterFunction::setupRuns() { - if (!runsSetup_) { + if (constants_.empty()) { if (runtimeBundle_.getConstantWeightSize()) { for (const auto &v : F_->getGraph()->getParent()->getConstants()) { auto symbolInfo = runtimeBundle_.getSymbolInfo(v); @@ -53,36 +50,27 @@ void InterpreterFunction::setupRuns() { constants_.emplace(std::string(v->getName()), tensor); } } - runsSetup_ = true; - } -} - -void InterpreterFunction::beforeRun(const Context &ctx) { - // Register the concrete tensors that back the placeholder tensors. - for (auto &ph : ctx.pairs()) { - auto *w = F_->getWeightForNode(ph.first); - assert(!externalTensors_.count(w) && "The tensor is already registered"); - externalTensors_[w] = ph.second; } } -void InterpreterFunction::afterRun(const Context &ctx) { - // Remove the concrete tensors that back the placeholder tensors. - for (auto &ph : ctx.pairs()) { - auto *w = F_->getWeightForNode(ph.first); - externalTensors_.erase(w); +void InterpreterFunction::execute(Context *ctx) { + if (constants_.empty()) { + collectConstants(F_.get()); } + BoundInterpreterFunction boundFunc(constants_); + boundFunc.execute(F_.get(), ctx); } -void InterpreterFunction::tearDownRuns() { - for (const auto &p : constants_) { +BoundInterpreterFunction::~BoundInterpreterFunction() { + // Delete the tensors that are owned by this backend. + for (const auto &p : tensors_) { delete p.second; } - constants_.clear(); - runsSetup_ = false; + tensors_.clear(); + externalTensors_.clear(); } -Tensor *InterpreterFunction::getTensor(const Value *v) const { +Tensor *BoundInterpreterFunction::getTensor(const Value *v) const { auto it = tensors_.find(v); if (it != tensors_.end()) { return it->second; @@ -97,7 +85,7 @@ Tensor *InterpreterFunction::getTensor(const Value *v) const { return ie->second; } -Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) { +Tensor *BoundInterpreterFunction::getOrCreateTensor(const Value *v) { auto ie = externalTensors_.find(v); if (ie != externalTensors_.end()) { return ie->second; @@ -117,9 +105,8 @@ Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) { return it->second; } -Tensor * -InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src, - llvm::ArrayRef offsets) { +Tensor *BoundInterpreterFunction::getOrCreateUnownedTensor( + const Value *v, const Value *src, llvm::ArrayRef offsets) { assert(llvm::isa(v) && "Expected a tensor view"); // Pick the tensor. @@ -136,7 +123,7 @@ InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src, return T; } -void InterpreterFunction::deleteTensor(const Value *v) { +void BoundInterpreterFunction::deleteTensor(const Value *v) { auto it = tensors_.find(v); if (it == tensors_.end()) { return; @@ -146,7 +133,14 @@ void InterpreterFunction::deleteTensor(const Value *v) { tensors_.erase(it); } -void InterpreterFunction::execute() { +void BoundInterpreterFunction::execute(IRFunction *F, Context *ctx) { + // Register the concrete tensors that back the placeholder tensors. + for (auto &ph : ctx->pairs()) { + auto *w = F->getWeightForNode(ph.first); + assert(!externalTensors_.count(w) && "The tensor is already registered"); + externalTensors_[w] = ph.second; + } + // Do the forward pass. #define DEF_VALUE(CLASS, NAME) #define DEF_INSTR(CLASS, NAME) \ @@ -156,7 +150,7 @@ void InterpreterFunction::execute() { } #define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) // Dispatch the interpreter on each instruction in the program: - for (const auto &I : F_->getInstrs()) { + for (const auto &I : F->getInstrs()) { switch (I.getKind()) { #include "glow/AutoGenInstr.def" @@ -164,4 +158,10 @@ void InterpreterFunction::execute() { llvm_unreachable("Invalid instruction."); } } + + // Remove the concrete tensors that back the placeholder tensors. + for (auto &ph : ctx->pairs()) { + auto *w = F->getWeightForNode(ph.first); + externalTensors_.erase(w); + } } diff --git a/lib/Backends/Interpreter/InterpreterFunction.h b/lib/Backends/Interpreter/InterpreterFunction.h index 9fb49aa187..8ac6832e1a 100644 --- a/lib/Backends/Interpreter/InterpreterFunction.h +++ b/lib/Backends/Interpreter/InterpreterFunction.h @@ -45,10 +45,7 @@ class Constant; class InterpreterFunction final : public CompiledFunction { /// The IR to be executed. std::unique_ptr F_; - /// Maps values to Tensors, that are owned by this class. - std::unordered_map tensors_; - /// Maps values to Tensors, that are *not* owned by this class. - std::unordered_map externalTensors_; + /// Maps Value.name to tensors for constants. std::unordered_map constants_; @@ -60,29 +57,34 @@ class InterpreterFunction final : public CompiledFunction { ///@{ ~InterpreterFunction() override; - /// Does any needed initialization work for the Backend, creates tensors from - /// constants. - /// Collects constants for runtime. void collectConstants(IRFunction *F); - void setupRuns() override; - - /// Per run setup, adds references for tensors from \p ctx to - /// externalTensors_. - void beforeRun(const Context &ctx) override; - - /// Per run cleanup, removes references for tensors from \p ctx from - /// externalTensors_. - void afterRun(const Context &ctx) override; + void execute(Context *ctx) override; - /// Final cleanup, remove created constant Tensors. - void tearDownRuns() override; - - void execute() override; /// Get reference to IR function. IRFunction *getIR() { return F_.get(); } ///@} +}; + +/// An InterpreterFunction bound to a specific invocation. +class BoundInterpreterFunction { + /// Maps values to Tensors, that are owned by this class. + std::unordered_map tensors_; + + /// Maps values to Tensors, that are *not* owned by this class. + std::unordered_map externalTensors_; + + /// A reference to the constant map from the owning InterpreterFunction. + const std::unordered_map &constants_; + +public: + BoundInterpreterFunction( + const std::unordered_map &constants) + : constants_(constants) {} + ~BoundInterpreterFunction(); + + void execute(IRFunction *F, Context *ctx); private: /// \returns a pointer to the tensor that is saved under \p v. @@ -108,8 +110,9 @@ class InterpreterFunction final : public CompiledFunction { return getTensor(v)->getHandle(); } - /// @name Interpreter methods. This is a list of method declerations that are - /// used by the interpreter to dispatch different instructions. + /// @name BoundInterpreterFunction methods. This is a list of method + /// declerations that are used by the interpreter to dispatch different + /// instructions. ///@{ #define DEF_VALUE(CLASS, NAME) diff --git a/lib/Backends/Interpreter/InterpreterNodes.cpp b/lib/Backends/Interpreter/InterpreterNodes.cpp index fd125e7bfd..7f3648b0d5 100644 --- a/lib/Backends/Interpreter/InterpreterNodes.cpp +++ b/lib/Backends/Interpreter/InterpreterNodes.cpp @@ -78,7 +78,7 @@ using namespace glow; /// This is the floating point implementation of Convolution. template -void InterpreterFunction::fwdConvolutionInstFloatImpl( +void BoundInterpreterFunction::fwdConvolutionInstFloatImpl( Value *inV, Value *outV, Value *filterV, Value *biasV, llvm::ArrayRef kernelSizes, llvm::ArrayRef strides, llvm::ArrayRef pads, size_t group) { @@ -148,7 +148,7 @@ void InterpreterFunction::fwdConvolutionInstFloatImpl( /// This is the quantized implementation of Convolution. /// For bias, we support int32 quantization. template -void InterpreterFunction::fwdConvolutionInstQuantizedImpl( +void BoundInterpreterFunction::fwdConvolutionInstQuantizedImpl( Value *inV, Value *outV, Value *filterV, Value *biasV, llvm::ArrayRef kernelSizes, llvm::ArrayRef strides, llvm::ArrayRef pads, size_t group) { @@ -242,7 +242,7 @@ void InterpreterFunction::fwdConvolutionInstQuantizedImpl( } // N } -void InterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) { +void BoundInterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) { auto kernelSizes = I->getKernels(); auto pads = I->getPads(); auto strides = I->getStrides(); @@ -262,7 +262,8 @@ void InterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) { kernelSizes, strides, pads, group); } -void InterpreterFunction::fwdConvolutionGradInst(const ConvolutionGradInst *I) { +void BoundInterpreterFunction::fwdConvolutionGradInst( + const ConvolutionGradInst *I) { auto inW = getWeightHandle(I->getSrc()); auto inG = getWeightHandle(I->getSrcGrad()); auto outG = getWeightHandle(I->getDestGrad()); @@ -400,7 +401,7 @@ static void fwdMaxPool(Tensor *inW, Tensor *outW, Handle *SXY, } // N } -void InterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) { +void BoundInterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) { auto inW = getTensor(I->getSrc()); auto outW = getTensor(I->getDest()); @@ -416,7 +417,8 @@ void InterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) { I->getPads()); } -void InterpreterFunction::fwdMaxPoolWithXYInst(const MaxPoolWithXYInst *I) { +void BoundInterpreterFunction::fwdMaxPoolWithXYInst( + const MaxPoolWithXYInst *I) { auto inW = getTensor(I->getSrc()); auto outW = getTensor(I->getDest()); auto SXY = getWeightHandle(I->getSrcXY()); @@ -433,7 +435,7 @@ void InterpreterFunction::fwdMaxPoolWithXYInst(const MaxPoolWithXYInst *I) { } template -void InterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) { +void BoundInterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) { staticAssertFloatingPointType(ElemTy); ShapeNHWC odim(I->getDest()->dims()); @@ -481,7 +483,7 @@ void InterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) { } // N } -void InterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) { +void BoundInterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) { ShapeNHWC odim(I->getDest()->dims()); ShapeNHWC idim(I->getSrc()->dims()); @@ -534,7 +536,7 @@ void InterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) { } // N } -void InterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) { +void BoundInterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) { if (I->getSrc()->getType()->isQuantizedType()) { fwdAvgPoolInstI8Impl(I); return; @@ -544,7 +546,7 @@ void InterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) { I->getSrc()->getElementType(), I); } -void InterpreterFunction::fwdMaxPoolWithXYGradInst( +void BoundInterpreterFunction::fwdMaxPoolWithXYGradInst( const MaxPoolWithXYGradInst *I) { auto inG = getWeightHandle(I->getSrcGrad()); auto outW = getWeightHandle(I->getDest()); @@ -578,7 +580,7 @@ void InterpreterFunction::fwdMaxPoolWithXYGradInst( } // N } -void InterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) { +void BoundInterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) { auto inG = getWeightHandle(I->getSrcGrad()); auto outW = getWeightHandle(I->getDest()); auto outG = getWeightHandle(I->getDestGrad()); @@ -630,7 +632,7 @@ void InterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) { // Activation functions //===----------------------------------------------------------------------===// template -void InterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) { +void BoundInterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) { staticAssertFloatingPointType(ElemTy); auto inW = getWeightHandle(I->getSrc()); @@ -642,13 +644,13 @@ void InterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) { } } -void InterpreterFunction::fwdSigmoidInst(const SigmoidInst *I) { +void BoundInterpreterFunction::fwdSigmoidInst(const SigmoidInst *I) { dispatchFloatingPointImpl(fwdSigmoidInstFloatImpl, I->getSrc()->getElementType(), I); } template -void InterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) { +void BoundInterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) { staticAssertFloatingPointType(ElemTy); auto inW = getWeightHandle(I->getSrc()); @@ -660,7 +662,7 @@ void InterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) { } } -void InterpreterFunction::fwdTanhInst(const TanhInst *I) { +void BoundInterpreterFunction::fwdTanhInst(const TanhInst *I) { dispatchFloatingPointImpl(fwdTanhInstFloatImpl, I->getSrc()->getElementType(), I); } @@ -670,7 +672,7 @@ void InterpreterFunction::fwdTanhInst(const TanhInst *I) { //===----------------------------------------------------------------------===// template -void InterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) { +void BoundInterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) { staticAssertFloatingPointType(ElemTy); auto inW = getWeightHandle(I->getSrc()); @@ -699,12 +701,12 @@ void InterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) { } // N } -void InterpreterFunction::fwdSoftMaxInst(const SoftMaxInst *I) { +void BoundInterpreterFunction::fwdSoftMaxInst(const SoftMaxInst *I) { dispatchFloatingPointImpl(fwdSoftMaxInstImpl, I->getSrc()->getElementType(), I); } -void InterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) { +void BoundInterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) { auto inG = getWeightHandle(I->getSrcGrad()); auto idim = inG.dims(); auto outW = getWeightHandle(I->getOrigDest()); @@ -723,7 +725,7 @@ void InterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) { } template -void InterpreterFunction::fwdCrossEntropyLossInstFloatImpl( +void BoundInterpreterFunction::fwdCrossEntropyLossInstFloatImpl( const CrossEntropyLossInst *I) { staticAssertFloatingPointType(ElemTy); @@ -739,13 +741,13 @@ void InterpreterFunction::fwdCrossEntropyLossInstFloatImpl( } } -void InterpreterFunction::fwdCrossEntropyLossInst( +void BoundInterpreterFunction::fwdCrossEntropyLossInst( const CrossEntropyLossInst *I) { dispatchFloatingPointImpl(fwdCrossEntropyLossInstFloatImpl, I->getP()->getElementType(), I); } -void InterpreterFunction::fwdCrossEntropyLossGradInst( +void BoundInterpreterFunction::fwdCrossEntropyLossGradInst( const CrossEntropyLossGradInst *I) { auto P = getWeightHandle(I->getP()); auto Labels = getWeightHandle(I->getLabels()); @@ -763,13 +765,13 @@ void InterpreterFunction::fwdCrossEntropyLossGradInst( // Tensor shape (copy/transpose/concat/...) //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdCopyInst(const CopyInst *I) { +void BoundInterpreterFunction::fwdCopyInst(const CopyInst *I) { auto inT = getTensor(I->getSrc()); auto outT = getTensor(I->getDest()); outT->copyRawFrom(inT); } -void InterpreterFunction::fwdTransposeInst(const TransposeInst *I) { +void BoundInterpreterFunction::fwdTransposeInst(const TransposeInst *I) { auto inT = getTensor(I->getSrc()); (void)inT; auto outT = getTensor(I->getDest()); @@ -783,11 +785,11 @@ void InterpreterFunction::fwdTransposeInst(const TransposeInst *I) { } } -void InterpreterFunction::fwdTensorViewInst(const TensorViewInst *I) { +void BoundInterpreterFunction::fwdTensorViewInst(const TensorViewInst *I) { getOrCreateUnownedTensor(I, I->getSrc(), I->getOffsets()); } -void InterpreterFunction::fwdSplatInst(const glow::SplatInst *I) { +void BoundInterpreterFunction::fwdSplatInst(const glow::SplatInst *I) { auto *T = getTensor(I->getDest()); ElemKind k = T->getElementType(); @@ -815,7 +817,8 @@ void InterpreterFunction::fwdSplatInst(const glow::SplatInst *I) { llvm_unreachable("Unsupported tensor type"); } -void InterpreterFunction::fwdInsertTensorInst(const glow::InsertTensorInst *I) { +void BoundInterpreterFunction::fwdInsertTensorInst( + const glow::InsertTensorInst *I) { Tensor *outT = getTensor(I->getDest()); Tensor *inT = getTensor(I->getSrc()); ElemKind k = outT->getElementType(); @@ -835,7 +838,7 @@ void InterpreterFunction::fwdInsertTensorInst(const glow::InsertTensorInst *I) { llvm_unreachable("Unsupported tensor type"); } -void InterpreterFunction::fwdExtractTensorInst( +void BoundInterpreterFunction::fwdExtractTensorInst( const glow::ExtractTensorInst *I) { Tensor *outT = getTensor(I->getDest()); Tensor *inT = getTensor(I->getSrc()); @@ -857,7 +860,7 @@ void InterpreterFunction::fwdExtractTensorInst( } template -void InterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) { +void BoundInterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) { Tensor *dataT = getTensor(I->getData()); auto &dataTy = dataT->getType(); Tensor *indicesT = getTensor(I->getIndices()); @@ -895,7 +898,7 @@ void InterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) { } } -void InterpreterFunction::fwdGatherInst(const glow::GatherInst *I) { +void BoundInterpreterFunction::fwdGatherInst(const glow::GatherInst *I) { switch (I->getIndices()->getElementType()) { case ElemKind::Int64ITy: fwdGatherInstImpl(I); @@ -909,7 +912,7 @@ void InterpreterFunction::fwdGatherInst(const glow::GatherInst *I) { } template -void InterpreterFunction::fwdGatherRangesInstImpl( +void BoundInterpreterFunction::fwdGatherRangesInstImpl( const glow::GatherRangesInst *I) { Tensor *dataT = getTensor(I->getData()); auto &dataTy = dataT->getType(); @@ -977,7 +980,8 @@ void InterpreterFunction::fwdGatherRangesInstImpl( assert(grandTotalLen == (outP / dataElementSize)); } -void InterpreterFunction::fwdGatherRangesInst(const glow::GatherRangesInst *I) { +void BoundInterpreterFunction::fwdGatherRangesInst( + const glow::GatherRangesInst *I) { switch (I->getRanges()->getElementType()) { case ElemKind::Int64ITy: fwdGatherRangesInstImpl(I); @@ -990,7 +994,7 @@ void InterpreterFunction::fwdGatherRangesInst(const glow::GatherRangesInst *I) { } } -void InterpreterFunction::fwdScatterAssignInst( +void BoundInterpreterFunction::fwdScatterAssignInst( const glow::ScatterAssignInst *I) { Tensor *dataT = getTensor(I->getData()); Tensor *indicesT = getTensor(I->getIndices()); @@ -1010,7 +1014,8 @@ void InterpreterFunction::fwdScatterAssignInst( } template -void InterpreterFunction::fwdBatchOneHotImpl(const glow::BatchOneHotInst *I) { +void BoundInterpreterFunction::fwdBatchOneHotImpl( + const glow::BatchOneHotInst *I) { auto dataH = getWeightHandle(I->getData()); auto lengthsH = getWeightHandle(I->getLengths()); auto valuesH = getWeightHandle(I->getValues()); @@ -1034,7 +1039,8 @@ void InterpreterFunction::fwdBatchOneHotImpl(const glow::BatchOneHotInst *I) { } } -void InterpreterFunction::fwdBatchOneHotInst(const glow::BatchOneHotInst *I) { +void BoundInterpreterFunction::fwdBatchOneHotInst( + const glow::BatchOneHotInst *I) { switch (I->getData()->getElementType()) { case ElemKind::Int64ITy: fwdBatchOneHotImpl(I); @@ -1053,7 +1059,7 @@ void InterpreterFunction::fwdBatchOneHotInst(const glow::BatchOneHotInst *I) { //===----------------------------------------------------------------------===// template -void InterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl( +void BoundInterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl( const glow::LocalResponseNormalizationInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1111,13 +1117,13 @@ void InterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl( } } -void InterpreterFunction::fwdLocalResponseNormalizationInst( +void BoundInterpreterFunction::fwdLocalResponseNormalizationInst( const LocalResponseNormalizationInst *I) { dispatchFloatingPointImpl(fwdLocalResponseNormalizationInstFloatImpl, I->getSrc()->getElementType(), I); } -void InterpreterFunction::fwdLocalResponseNormalizationGradInst( +void BoundInterpreterFunction::fwdLocalResponseNormalizationGradInst( const glow::LocalResponseNormalizationGradInst *I) { auto inW = getWeightHandle(I->getSrc()); auto inG = getWeightHandle(I->getSrcGrad()); @@ -1190,7 +1196,8 @@ void InterpreterFunction::fwdLocalResponseNormalizationGradInst( //===----------------------------------------------------------------------===// // Arithmetic operations //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdElementAddInstI8Impl(const ElementAddInst *I) { +void BoundInterpreterFunction::fwdElementAddInstI8Impl( + const ElementAddInst *I) { assert(getTensor(I->getLHS())->getType().isQuantizedType() && "Wrong function"); auto lhsTy = I->getLHS()->getType(); @@ -1224,7 +1231,8 @@ void InterpreterFunction::fwdElementAddInstI8Impl(const ElementAddInst *I) { } template -void InterpreterFunction::fwdElementAddInstFloatImpl(const ElementAddInst *I) { +void BoundInterpreterFunction::fwdElementAddInstFloatImpl( + const ElementAddInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1235,7 +1243,7 @@ void InterpreterFunction::fwdElementAddInstFloatImpl(const ElementAddInst *I) { } } -void InterpreterFunction::fwdElementAddInst(const ElementAddInst *I) { +void BoundInterpreterFunction::fwdElementAddInst(const ElementAddInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { fwdElementAddInstI8Impl(I); return; @@ -1246,7 +1254,8 @@ void InterpreterFunction::fwdElementAddInst(const ElementAddInst *I) { } template -void InterpreterFunction::fwdElementSubInstFloatImpl(const ElementSubInst *I) { +void BoundInterpreterFunction::fwdElementSubInstFloatImpl( + const ElementSubInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1257,7 +1266,7 @@ void InterpreterFunction::fwdElementSubInstFloatImpl(const ElementSubInst *I) { } } -void InterpreterFunction::fwdElementSubInst(const ElementSubInst *I) { +void BoundInterpreterFunction::fwdElementSubInst(const ElementSubInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto destTy = I->getDest()->getType(); auto lhsTy = I->getLHS()->getType(); @@ -1290,7 +1299,8 @@ void InterpreterFunction::fwdElementSubInst(const ElementSubInst *I) { } template -void InterpreterFunction::fwdElementMulInstFloatImpl(const ElementMulInst *I) { +void BoundInterpreterFunction::fwdElementMulInstFloatImpl( + const ElementMulInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1301,7 +1311,7 @@ void InterpreterFunction::fwdElementMulInstFloatImpl(const ElementMulInst *I) { } } -void InterpreterFunction::fwdElementMulInst(const ElementMulInst *I) { +void BoundInterpreterFunction::fwdElementMulInst(const ElementMulInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto lhsTy = I->getLHS()->getType(); auto rhsTy = I->getRHS()->getType(); @@ -1327,7 +1337,7 @@ void InterpreterFunction::fwdElementMulInst(const ElementMulInst *I) { I->getDest()->getElementType(), I); } -void InterpreterFunction::fwdElementDivInst(const ElementDivInst *I) { +void BoundInterpreterFunction::fwdElementDivInst(const ElementDivInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto destTy = I->getDest()->getType(); auto lhsTy = I->getLHS()->getType(); @@ -1382,7 +1392,8 @@ void InterpreterFunction::fwdElementDivInst(const ElementDivInst *I) { } } -void InterpreterFunction::fwdElementMaxInstI8Impl(const ElementMaxInst *I) { +void BoundInterpreterFunction::fwdElementMaxInstI8Impl( + const ElementMaxInst *I) { assert(getTensor(I->getLHS())->getType().isQuantizedType() && "Wrong function"); auto lhsTy = I->getLHS()->getType(); @@ -1408,7 +1419,8 @@ void InterpreterFunction::fwdElementMaxInstI8Impl(const ElementMaxInst *I) { } template -void InterpreterFunction::fwdElementMaxInstFloatImpl(const ElementMaxInst *I) { +void BoundInterpreterFunction::fwdElementMaxInstFloatImpl( + const ElementMaxInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1419,7 +1431,7 @@ void InterpreterFunction::fwdElementMaxInstFloatImpl(const ElementMaxInst *I) { } } -void InterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) { +void BoundInterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { fwdElementMaxInstI8Impl(I); return; @@ -1430,7 +1442,8 @@ void InterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) { } template -void InterpreterFunction::fwdElementMinInstFloatImpl(const ElementMinInst *I) { +void BoundInterpreterFunction::fwdElementMinInstFloatImpl( + const ElementMinInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1441,7 +1454,7 @@ void InterpreterFunction::fwdElementMinInstFloatImpl(const ElementMinInst *I) { } } -void InterpreterFunction::fwdElementMinInst(const ElementMinInst *I) { +void BoundInterpreterFunction::fwdElementMinInst(const ElementMinInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto lhsTy = I->getLHS()->getType(); auto rhsTy = I->getRHS()->getType(); @@ -1471,7 +1484,7 @@ void InterpreterFunction::fwdElementMinInst(const ElementMinInst *I) { } template -void InterpreterFunction::fwdElementCmpLTEInstFloatImpl( +void BoundInterpreterFunction::fwdElementCmpLTEInstFloatImpl( const ElementCmpLTEInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1485,7 +1498,8 @@ void InterpreterFunction::fwdElementCmpLTEInstFloatImpl( // For both quantized and non-quantized CmpLTE, we set the result to 1.0/0.0. // In the quantized case, we assume that the scale params are (1.0, 0). -void InterpreterFunction::fwdElementCmpLTEInst(const ElementCmpLTEInst *I) { +void BoundInterpreterFunction::fwdElementCmpLTEInst( + const ElementCmpLTEInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto lhsTy = I->getLHS()->getType(); auto rhsTy = I->getRHS()->getType(); @@ -1513,7 +1527,8 @@ void InterpreterFunction::fwdElementCmpLTEInst(const ElementCmpLTEInst *I) { } template -void InterpreterFunction::fwdElementCmpEQInstImpl(const ElementCmpEQInst *I) { +void BoundInterpreterFunction::fwdElementCmpEQInstImpl( + const ElementCmpEQInst *I) { auto outW = getWeightHandle(I->getDest()); auto lhsW = getWeightHandle(I->getLHS()); auto rhsW = getWeightHandle(I->getRHS()); @@ -1522,7 +1537,7 @@ void InterpreterFunction::fwdElementCmpEQInstImpl(const ElementCmpEQInst *I) { } } -void InterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) { +void BoundInterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) { auto *T = getTensor(I->getDest()); switch (T->getElementType()) { @@ -1536,7 +1551,8 @@ void InterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) { } template -void InterpreterFunction::fwdElementPowInstFloatImpl(const ElementPowInst *I) { +void BoundInterpreterFunction::fwdElementPowInstFloatImpl( + const ElementPowInst *I) { staticAssertFloatingPointType(ElemTy); auto baseW = getWeightHandle(I->getLHS()); @@ -1547,13 +1563,14 @@ void InterpreterFunction::fwdElementPowInstFloatImpl(const ElementPowInst *I) { } } -void InterpreterFunction::fwdElementPowInst(const glow::ElementPowInst *I) { +void BoundInterpreterFunction::fwdElementPowInst( + const glow::ElementPowInst *I) { dispatchFloatingPointImpl(fwdElementPowInstFloatImpl, I->getLHS()->getElementType(), I); } template -void InterpreterFunction::fwdElementIsNaNInstFloatImpl( +void BoundInterpreterFunction::fwdElementIsNaNInstFloatImpl( const ElementIsNaNInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1565,13 +1582,15 @@ void InterpreterFunction::fwdElementIsNaNInstFloatImpl( } } -void InterpreterFunction::fwdElementIsNaNInst(const glow::ElementIsNaNInst *I) { +void BoundInterpreterFunction::fwdElementIsNaNInst( + const glow::ElementIsNaNInst *I) { dispatchFloatingPointImpl(fwdElementIsNaNInstFloatImpl, I->getSrc()->getElementType(), I); } template -void InterpreterFunction::fwdElementLogInstFloatImpl(const ElementLogInst *I) { +void BoundInterpreterFunction::fwdElementLogInstFloatImpl( + const ElementLogInst *I) { staticAssertFloatingPointType(ElemTy); auto inW = getWeightHandle(I->getSrc()); @@ -1582,13 +1601,13 @@ void InterpreterFunction::fwdElementLogInstFloatImpl(const ElementLogInst *I) { } } -void InterpreterFunction::fwdElementLogInst(const ElementLogInst *I) { +void BoundInterpreterFunction::fwdElementLogInst(const ElementLogInst *I) { dispatchFloatingPointImpl(fwdElementLogInstFloatImpl, I->getSrc()->getElementType(), I); } template -void InterpreterFunction::fwdElementSelectInstFloatImpl( +void BoundInterpreterFunction::fwdElementSelectInstFloatImpl( const glow::ElementSelectInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1601,7 +1620,7 @@ void InterpreterFunction::fwdElementSelectInstFloatImpl( } } -void InterpreterFunction::fwdElementSelectInst( +void BoundInterpreterFunction::fwdElementSelectInst( const glow::ElementSelectInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto destTy = I->getDest()->getType(); @@ -1637,7 +1656,7 @@ void InterpreterFunction::fwdElementSelectInst( // Mat Mul //===----------------------------------------------------------------------===// template -void InterpreterFunction::fwdMatMulInstQuantizedImpl( +void BoundInterpreterFunction::fwdMatMulInstQuantizedImpl( const glow::MatMulInst *I) { assert(getTensor(I->getLHS())->getType().isQuantizedType()); auto lhs = getWeightHandle(I->getLHS()); @@ -1683,7 +1702,7 @@ void InterpreterFunction::fwdMatMulInstQuantizedImpl( } template -void InterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) { +void BoundInterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) { staticAssertFloatingPointType(ElemTy); auto lhs = getWeightHandle(I->getLHS()); @@ -1709,7 +1728,7 @@ void InterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) { } } -void InterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) { +void BoundInterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { dispatchQuantizedWithAccumulationImpl(fwdMatMulInstQuantizedImpl, I->getLHS()->getElementType(), I); @@ -1723,7 +1742,7 @@ void InterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) { //===----------------------------------------------------------------------===// // Row-wise quantized FC //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdRowwiseQuantizedFullyConnectedInst( +void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInst( const RowwiseQuantizedFullyConnectedInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -1811,7 +1830,7 @@ static void fwdBatchedAdd(Tensor *batch, Tensor *slice, Tensor *dest) { } template -void InterpreterFunction::fwdBatchedAddInstFloatImpl( +void BoundInterpreterFunction::fwdBatchedAddInstFloatImpl( const glow::BatchedAddInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1834,7 +1853,8 @@ void InterpreterFunction::fwdBatchedAddInstFloatImpl( } } -void InterpreterFunction::fwdBatchedAddInst(const glow::BatchedAddInst *I) { +void BoundInterpreterFunction::fwdBatchedAddInst( + const glow::BatchedAddInst *I) { if (getTensor(I->getBatch())->getType().isQuantizedType()) { dispatchQuantizedImpl(fwdBatchedAdd, I->getSlice()->getElementType(), getTensor(I->getBatch()), getTensor(I->getSlice()), @@ -1846,7 +1866,7 @@ void InterpreterFunction::fwdBatchedAddInst(const glow::BatchedAddInst *I) { } template -void InterpreterFunction::fwdBatchedReduceAddInstFloatImpl( +void BoundInterpreterFunction::fwdBatchedReduceAddInstFloatImpl( Value *batch, Value *dest, unsigned_t axis, const ShapeVector &eBatchDims, const ShapeVector &eDestDims) { staticAssertFloatingPointType(ElemTy); @@ -1878,7 +1898,7 @@ void InterpreterFunction::fwdBatchedReduceAddInstFloatImpl( } } -void InterpreterFunction::fwdBatchedReduceAddInst( +void BoundInterpreterFunction::fwdBatchedReduceAddInst( const glow::BatchedReduceAddInst *I) { static_assert(max_tensor_dimensions == 6, "Loops below assume max_tensor_dimensions = 6."); @@ -1956,7 +1976,8 @@ void InterpreterFunction::fwdBatchedReduceAddInst( } template -void InterpreterFunction::fwdLengthsSumInstFloatImpl(const LengthsSumInst *I) { +void BoundInterpreterFunction::fwdLengthsSumInstFloatImpl( + const LengthsSumInst *I) { staticAssertFloatingPointType(ElemTy); auto out = getTensor(I->getDest()); @@ -1989,13 +2010,13 @@ void InterpreterFunction::fwdLengthsSumInstFloatImpl(const LengthsSumInst *I) { assert(offsetOut == out->size() && "All values in Dest should be written to"); } -void InterpreterFunction::fwdLengthsSumInst(const LengthsSumInst *I) { +void BoundInterpreterFunction::fwdLengthsSumInst(const LengthsSumInst *I) { dispatchFloatingPointImpl(fwdLengthsSumInstFloatImpl, I->getData()->getElementType(), I) } template -void InterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl( +void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl( const SparseLengthsWeightedSumInst *I) { staticAssertFloatingPointType(ElemTy); @@ -2036,7 +2057,7 @@ void InterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl( } } -void InterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl( +void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl( const SparseLengthsWeightedSumInst *I) { auto out = getTensor(I->getDest()); @@ -2088,7 +2109,7 @@ void InterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl( } } -void InterpreterFunction::fwdSparseLengthsWeightedSumInst( +void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInst( const SparseLengthsWeightedSumInst *I) { if (I->getDest()->getType()->isQuantizedType()) { return fwdSparseLengthsWeightedSumInstI8Impl(I); @@ -2097,7 +2118,57 @@ void InterpreterFunction::fwdSparseLengthsWeightedSumInst( I->getData()->getElementType(), I); } -void InterpreterFunction::fwdLengthsToRangesInst(const LengthsToRangesInst *I) { +void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumInst( + const RowwiseQuantizedSparseLengthsWeightedSumInst *I) { + auto *out = getTensor(I->getDest()); + auto *data = getTensor(I->getData()); + auto *dataScales = getTensor(I->getScales()); + auto *dataOffsets = getTensor(I->getOffsets()); + auto *weights = getTensor(I->getWeights()); + auto *indices = getTensor(I->getIndices()); + auto *lengths = getTensor(I->getLengths()); + + out->zero(); + + auto IH = indices->getHandle(); + auto LH = lengths->getHandle(); + + size_t segments = lengths->dims()[0]; + size_t totalLength = 0; + for (size_t i = 0; i < segments; i++) { + totalLength += LH.raw(i); + } + assert(totalLength == indices->dims()[0] && + "sum(Lengths) must be equal to len(Indices)"); + + size_t lineSize = data->size() / data->dims()[0]; + + auto DH = data->getHandle(); + auto DSH = dataScales->getHandle(); + auto DOH = dataOffsets->getHandle(); + auto WH = weights->getHandle(); + auto OH = out->getHandle(); + + size_t curIdx = 0; + for (size_t i = 0; i < segments; i++) { + for (size_t j = 0, e = LH.raw(i); j < e; j++) { + const float weight = WH.raw(curIdx); + const size_t rowIdx = IH.raw(curIdx++); + const float scale = DSH.at({rowIdx}); + const int32_t offset = DOH.at({rowIdx}); + size_t offsetIn = rowIdx * lineSize; + size_t offsetOut = i * lineSize; + for (size_t k = 0; k < lineSize; k++) { + float d = quantization::dequantize( + DH.raw(offsetIn++), TensorQuantizationParams{scale, offset}); + OH.raw(offsetOut++) += d * weight; + } + } + } +} + +void BoundInterpreterFunction::fwdLengthsToRangesInst( + const LengthsToRangesInst *I) { auto ranges = getTensor(I->getDest())->getHandle(); auto lengths = getTensor(I->getLengths())->getHandle(); int32_t offset = 0; @@ -2110,7 +2181,7 @@ void InterpreterFunction::fwdLengthsToRangesInst(const LengthsToRangesInst *I) { } template -void InterpreterFunction::fwdSparseToDenseInstFloatImpl( +void BoundInterpreterFunction::fwdSparseToDenseInstFloatImpl( const SparseToDenseInst *I) { staticAssertFloatingPointType(ElemTy); @@ -2158,7 +2229,8 @@ void InterpreterFunction::fwdSparseToDenseInstFloatImpl( } } -void InterpreterFunction::fwdSparseToDenseInst(const SparseToDenseInst *I) { +void BoundInterpreterFunction::fwdSparseToDenseInst( + const SparseToDenseInst *I) { dispatchFloatingPointImpl(fwdSparseToDenseInstFloatImpl, I->getDest()->getElementType(), I); } @@ -2201,7 +2273,7 @@ static void fwdTopK(Tensor *outW, Tensor *indW, Tensor *inW, size_t k) { // Sorting operators //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdTopKInst(const TopKInst *I) { +void BoundInterpreterFunction::fwdTopKInst(const TopKInst *I) { auto outW = getTensor(I->getValues()); auto indW = getTensor(I->getIndices()); auto inW = getTensor(I->getInput()); @@ -2219,11 +2291,12 @@ void InterpreterFunction::fwdTopKInst(const TopKInst *I) { // Tensor allocation operations //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdAllocActivationInst(const AllocActivationInst *I) { +void BoundInterpreterFunction::fwdAllocActivationInst( + const AllocActivationInst *I) { getOrCreateTensor(I); } -void InterpreterFunction::fwdDeallocActivationInst( +void BoundInterpreterFunction::fwdDeallocActivationInst( const DeallocActivationInst *I) { deleteTensor(I->getSrc()); } @@ -2235,7 +2308,7 @@ void InterpreterFunction::fwdDeallocActivationInst( /// Prints a value of the instruction's operand. /// In most cases it will be the name of the variable and the value of the /// tensor. -void InterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) { +void BoundInterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) { auto *V = I->getSrc(); llvm::outs() << I->getName() << ": "; // Dump the content of a value. @@ -2248,7 +2321,7 @@ void InterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) { //===----------------------------------------------------------------------===// // Instructions used by Quantization //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdQuantizationProfileInst( +void BoundInterpreterFunction::fwdQuantizationProfileInst( const glow::QuantizationProfileInst *I) { auto inputTensor = getWeightHandle(I->getInputTensor()); auto currentHistogram = getWeightHandle(I->getHistogram()); @@ -2264,7 +2337,7 @@ void InterpreterFunction::fwdQuantizationProfileInst( /// Quantize floating point tensor. Scale and Offset are based on return type /// of the instruction \p I. -void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) { +void BoundInterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) { auto *srcTensor = getTensor(I->getSrc()); auto *destTensor = getTensor(I->getDest()); auto destTy = destTensor->getType(); @@ -2276,7 +2349,8 @@ void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) { /// Dequantize integer tensor. Scale and Offset are based /// on the source tensor type. -void InterpreterFunction::fwdDequantizeInst(const glow::DequantizeInst *I) { +void BoundInterpreterFunction::fwdDequantizeInst( + const glow::DequantizeInst *I) { auto *srcTensor = getTensor(I->getSrc()); auto *destTensor = getTensor(I->getDest()); auto destTy = destTensor->getType(); @@ -2286,7 +2360,7 @@ void InterpreterFunction::fwdDequantizeInst(const glow::DequantizeInst *I) { } template -void InterpreterFunction::fwdRescaleQuantizedInstImpl( +void BoundInterpreterFunction::fwdRescaleQuantizedInstImpl( Value *src, Value *dest, TensorQuantizationParams &srcQ, TensorQuantizationParams &destQ) { @@ -2299,7 +2373,7 @@ void InterpreterFunction::fwdRescaleQuantizedInstImpl( } } -void InterpreterFunction::fwdRescaleQuantizedInst( +void BoundInterpreterFunction::fwdRescaleQuantizedInst( const glow::RescaleQuantizedInst *I) { auto src = I->getSrc(); auto dest = I->getDest(); @@ -2313,7 +2387,8 @@ void InterpreterFunction::fwdRescaleQuantizedInst( src, dest, srcQ, destQ); } -void InterpreterFunction::fwdIntLookupTableInst(const IntLookupTableInst *I) { +void BoundInterpreterFunction::fwdIntLookupTableInst( + const IntLookupTableInst *I) { auto srcH = getWeightHandle(I->getSrc()); auto destH = getWeightHandle(I->getDest()); auto mappingH = getWeightHandle(I->getMapping()); @@ -2323,7 +2398,7 @@ void InterpreterFunction::fwdIntLookupTableInst(const IntLookupTableInst *I) { } } -void InterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) { +void BoundInterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) { Tensor *source = getTensor(I->getInput()); Tensor *dest = getTensor(I->getResult()); if (source->getType() == dest->getType()) { diff --git a/lib/Backends/OpenCL/OpenCL.cpp b/lib/Backends/OpenCL/OpenCL.cpp index 7038824159..6bfef00d11 100644 --- a/lib/Backends/OpenCL/OpenCL.cpp +++ b/lib/Backends/OpenCL/OpenCL.cpp @@ -621,7 +621,9 @@ static void topK(Tensor &outW, Tensor &indW, Tensor &inW, size_t k) { } } -void OpenCLFunction::execute() { +void OpenCLFunction::execute(Context *ctx) { + (void)ctx; + for (const auto &I : F_->getInstrs()) { // The kernels are named after the name of the instruction, plus the "W" // suffix to prevent name colissions for functions like 'tanh' that are also diff --git a/lib/Backends/OpenCL/OpenCL.h b/lib/Backends/OpenCL/OpenCL.h index 5662a1e138..83d5e31fd7 100644 --- a/lib/Backends/OpenCL/OpenCL.h +++ b/lib/Backends/OpenCL/OpenCL.h @@ -96,7 +96,7 @@ class OpenCLFunction final : public CompiledFunction { ///@{ ~OpenCLFunction() override; - void execute() override; + void execute(Context *ctx) override; ///@} /// Allocates on device buffer and copies Constant weights to device. void setupRuns() override; @@ -205,6 +205,11 @@ class OCLBackend final : public BackendUsingGlowIR { return false; } } + + if (elementTy == ElemKind::Int16QTy) { + return false; + } + return true; }; diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 9ce747ce8e..06b6f5cdf8 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -10,5 +10,8 @@ add_subdirectory(IR) add_subdirectory(Importer) add_subdirectory(Optimizer) add_subdirectory(Quantization) +add_subdirectory(Runtime) add_subdirectory(Support) add_subdirectory(Onnxifi) +add_subdirectory(Partitioner) + diff --git a/lib/ExecutionEngine/ExecutionEngine.cpp b/lib/ExecutionEngine/ExecutionEngine.cpp index 96d558ef80..8d37be3cef 100644 --- a/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/lib/ExecutionEngine/ExecutionEngine.cpp @@ -82,7 +82,7 @@ void ExecutionEngine::run(Context &ctx) { ctx.allocate(M_.getPlaceholders()); function_->setupRuns(); function_->beforeRun(ctx); - function_->execute(); + function_->execute(&ctx); function_->afterRun(ctx); } @@ -121,42 +121,14 @@ void glow::runBatch(ExecutionEngine &EE, Context &ctx, size_t iterations, } } -void ExecutionEngine::optimizeFunction(CompilationMode mode, Function *F) { - // Verify the function pre-optimization/lowering. - assert(F->verify() && "Function must be valid"); - - // Optimize the graph. - ::glow::optimize(F, mode); - - // Allow the backend to transform the graph prior to lowering. - if (backend_->transformPreLowering(F, mode)) { - // Optimize the graph again after the backend transformation. - // In particular, DCE is very likely to be useful. - ::glow::optimize(F, mode); - } - - // Lower the graph into a sequence of low-level linear algebra operations. - ::glow::lower(F, *backend_); - - // Optimize the graph again. - ::glow::optimize(F, mode); - - // Allow the backend to transform the graph after lowering. - if (backend_->transformPostLowering(F, mode)) { - // Optimize the graph again after the backend transformation. - // In particular, DCE is very likely to be useful. - ::glow::optimize(F, mode); - } -} - void ExecutionEngine::compile(CompilationMode mode, Function *F) { - optimizeFunction(mode, F); + backend_->optimizeFunction(mode, F); function_ = backend_->compile(F); } void ExecutionEngine::save(CompilationMode mode, Function *F, llvm::StringRef outputDir, llvm::StringRef networkName) { - optimizeFunction(mode, F); + backend_->optimizeFunction(mode, F); backend_->save(F, outputDir, networkName); } diff --git a/lib/Graph/Graph.cpp b/lib/Graph/Graph.cpp index 3f99a341af..d9f0c05c35 100644 --- a/lib/Graph/Graph.cpp +++ b/lib/Graph/Graph.cpp @@ -1369,6 +1369,60 @@ Function::createSparseLengthsWeightedSum(llvm::StringRef name, TypeRef outTy, indices, lengths)); } +/// Helper to create a RowwiseQuantizedSparseLengthsWeightedSumNode in the +/// Function \p F with \p name, using \ data, \p weights, \p indices, and \p +/// lengths as inputs. The provided float data in \p Tensor is rowwise +/// quantized, creating Constants for the rowwise quantized data as well as +/// Scales and Offsets, in the Module containing \p F. +static RowwiseQuantizedSparseLengthsWeightedSumNode * +quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum( + Function *F, llvm::StringRef name, Tensor &data, NodeValue weights, + NodeValue indices, NodeValue lengths) { + auto inDims = data.dims(); + ShapeVector outDims(inDims.begin(), inDims.end()); + outDims[0] = lengths.dims()[0]; + auto outTy = F->getParent()->uniqueType(ElemKind::FloatTy, outDims); + + // Note: In rwqData, we are using a quantized type, however the scale/offset + // are set to dummy values 0.0/0. This is because the actually used + // scale/offset come from dataScales and dataOffsets. + Constant *rwqData = + F->getParent()->createConstant(ElemKind::Int8QTy, inDims, 0.0, 0, "data"); + Constant *dataScales = F->getParent()->createConstant( + ElemKind::FloatTy, {inDims[0]}, "dataScales"); + Constant *dataOffsets = F->getParent()->createConstant( + ElemKind::Int32ITy, {inDims[0]}, "dataOffsets"); + + quantization::tensorRowwiseQuantization(data, rwqData->getPayload(), + dataScales->getPayload(), + dataOffsets->getPayload()); + + return F->addNode(new RowwiseQuantizedSparseLengthsWeightedSumNode( + name, outTy, rwqData, dataScales, dataOffsets, weights, indices, + lengths)); +} + +RowwiseQuantizedSparseLengthsWeightedSumNode * +Function::createRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name, + Tensor &data, + NodeValue weights, + NodeValue indices, + NodeValue lengths) { + return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum( + this, name, data, weights, indices, lengths); +} + +RowwiseQuantizedSparseLengthsWeightedSumNode * +Function::createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name, + Tensor &data, + NodeValue indices, + NodeValue lengths) { + auto ty = getParent()->uniqueType(ElemKind::FloatTy, {indices.dims()[0]}); + auto ones = createSplat(name.str() + ".ones", ty, 1.0); + return quantizeDataAndCreateRowwiseQuantizedSparseLengthsWeightedSum( + this, name, data, ones, indices, lengths); +} + LengthsToRangesNode *Function::createLengthsToRanges(llvm::StringRef name, NodeValue lengths) { ShapeVector outDims({lengths.dims()[0], 2}); diff --git a/lib/Graph/Nodes.cpp b/lib/Graph/Nodes.cpp index ed0c6e4089..ae35c0491a 100644 --- a/lib/Graph/Nodes.cpp +++ b/lib/Graph/Nodes.cpp @@ -384,7 +384,7 @@ bool AvgPoolGradNode::verify() const { isValid &= verifyPool(getGradOfInputNamedInput(), getGradOfOriginalOutputNamedResult(), Kernels_, Strides_, Pads_); - return true; + return isValid; } bool MatMulNode::verify() const { @@ -717,6 +717,36 @@ bool SparseLengthsWeightedSumNode::verify() const { return isValid; } +bool RowwiseQuantizedSparseLengthsWeightedSumNode::verify() const { + bool isValid = checkType(getResult(), ElemKind::FloatTy, this); + isValid &= checkType(getData(), ElemKind::Int8QTy, this); + isValid &= checkType(getScales(), ElemKind::FloatTy, this); + isValid &= checkType(getOffsets(), ElemKind::Int32ITy, this); + isValid &= checkType(getWeights(), ElemKind::FloatTy, this); + isValid &= checkType(getIndices(), ElemKind::Int64ITy, this); + isValid &= checkType(getLengths(), ElemKind::Int32ITy, this); + isValid &= expectCompareTrue("Indices must be a 1D vector", + getIndices().dims().size(), size_t(1), this); + isValid &= expectCompareTrue("Lengths must be a 1D vector", + getLengths().dims().size(), size_t(1), this); + isValid &= expectCompareTrue("Weights must be a 1D vector", + getWeights().dims().size(), size_t(1), this); + isValid &= expectCompareTrue("Scales must be a 1D vector", + getScales().dims().size(), size_t(1), this); + isValid &= expectCompareTrue("Offsets must be a 1D vector", + getOffsets().dims().size(), size_t(1), this); + isValid &= + expectCompareTrue("Weights and Indices must have the same size", + getWeights().dims()[0], getIndices().dims()[0], this); + isValid &= expectCompareTrue( + "Scales and Data must have the same first dimension size", + getData().dims()[0], getScales().dims()[0], this); + isValid &= expectCompareTrue( + "Offsets and Data must have the same first dimension size", + getData().dims()[0], getOffsets().dims()[0], this); + return isValid; +} + bool LengthsToRangesNode::verify() const { bool isValid = checkType(getResult(), getLengths().getElementType(), this); isValid &= checkType(getLengths(), ElemKind::Int32ITy, this); diff --git a/lib/Importer/ONNXModelLoader.cpp b/lib/Importer/ONNXModelLoader.cpp index 873d09af92..86d00f6b88 100644 --- a/lib/Importer/ONNXModelLoader.cpp +++ b/lib/Importer/ONNXModelLoader.cpp @@ -135,7 +135,8 @@ llvm::Error ONNXModelLoader::setVersion(ONNX_NAMESPACE::ModelProto MP) { opsetVersion_ = 0; RETURN_ERR_IF_NOT( irVersion_ >= 3, - "This ONNX model with ir_version < 3 is too old to be supported."); + "This ONNX model with ir_version < 3 is too old to be supported.", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_ONNX_VERSION); for (const auto &imp : MP.opset_import()) { if (!imp.has_domain() || imp.domain() == "") { opsetVersion_ = imp.version(); @@ -156,7 +157,8 @@ ONNXModelLoader::loadProto(google::protobuf::io::ZeroCopyInputStream &iStream) { codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE, MAX_PROTO_SIZE); ONNX_NAMESPACE::ModelProto MP; bool parseNet = MP.ParseFromCodedStream(&codedStream); - RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto"); + RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto", + GlowErr::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF); return MP; } @@ -169,7 +171,8 @@ ONNXModelLoader::loadProto(const void *onnxModel, size_t onnxModelSize) { llvm::Expected ONNXModelLoader::loadProto(const std::string &filename) { std::ifstream ff(filename, std::ios::in | std::ios::binary); - RETURN_ERR_IF_NOT(ff, "Can't find the model or network files."); + RETURN_ERR_IF_NOT(ff, "Can't find the model or network files.", + GlowErr::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF); // TODO: intend to find a way to reuse the following function later // for the text format onnx model: @@ -181,7 +184,8 @@ ONNXModelLoader::loadProto(const std::string &filename) { ONNX_NAMESPACE::ModelProto MP; bool parseNet = google::protobuf::TextFormat::ParseFromString(str, &MP); - RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto"); + RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto", + GlowErr::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF); return MP; } @@ -232,7 +236,8 @@ static llvm::Error loadTensor(const ONNX_NAMESPACE::TensorProto &in, std::istringstream inStream(in.raw_data(), std::stringstream::binary); inStream.read(T->getUnsafePtr(), T->size() * sizeof(float)); } else { - RETURN_ERR("Unsupported Tensor format."); + RETURN_ERR("Unsupported Tensor format.", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); } } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT64) { T->reset(ElemKind::Int64ITy, dim); @@ -247,7 +252,8 @@ static llvm::Error loadTensor(const ONNX_NAMESPACE::TensorProto &in, std::istringstream inStream(in.raw_data(), std::stringstream::binary); inStream.read(T->getUnsafePtr(), T->size() * sizeof(int64_t)); } else { - RETURN_ERR("Unsupported Tensor format."); + RETURN_ERR("Unsupported Tensor format.", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); } } else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT32) { // There are few cases when we will have int32 tensors. For example, the @@ -264,10 +270,12 @@ static llvm::Error loadTensor(const ONNX_NAMESPACE::TensorProto &in, std::istringstream inStream(in.raw_data(), std::stringstream::binary); inStream.read(T->getUnsafePtr(), T->size() * sizeof(int32_t)); } else { - RETURN_ERR("Unsupported Tensor format."); + RETURN_ERR("Unsupported Tensor format.", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); } } else { - RETURN_ERR("Only float and index tensors are supported"); + RETURN_ERR("Only float and index tensors are supported", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); } return llvm::Error::success(); } @@ -307,7 +315,8 @@ llvm::Error ONNXModelLoader::loadConstant(const ONNX_NAMESPACE::NodeProto &op, RETURN_ERR_IF_NOT(dict.at("value")->type() == ONNX_NAMESPACE::AttributeProto::TENSOR, - "Only Tensor type constants are supported."); + "Only Tensor type constants are supported.", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE); std::unique_ptr T(new Tensor()); RETURN_IF_ERR(loadTensor(dict.at("value")->t(), T.get())); @@ -511,7 +520,8 @@ llvm::Error ONNXModelLoader::loadPool(const ONNX_NAMESPACE::NodeProto &op, // Glow doesn't support argmax output yet. if (op.output_size() > 1) { - RETURN_ERR("Glow doesn't support argmax output yet."); + RETURN_ERR("Glow doesn't support argmax output yet.", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR); } // Load the inputs: NodeValue in; @@ -529,7 +539,8 @@ llvm::Error ONNXModelLoader::loadPool(const ONNX_NAMESPACE::NodeProto &op, if (in.dims().size() != 4 || kernels.size() != 2) { // Glow only handles 2D pooling currently. - RETURN_ERR("Glow only handles 2D pooling currently."); + RETURN_ERR("Glow only handles 2D pooling currently.", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE); } auto *tr = G_.createTranspose(opName, in, NCHW2NHWC); @@ -785,7 +796,8 @@ llvm::Error ONNXModelLoader::loadPad(const ONNX_NAMESPACE::NodeProto &op, } else if (modeStr == "edge") { mode = PaddingMode::EDGE; } else { - RETURN_ERR("Pad: Invalid mode"); + RETURN_ERR("Pad: Invalid mode", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE); } } float value = 0.f; // Default @@ -873,7 +885,8 @@ llvm::Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) { return loadPad(op, dict); } - RETURN_ERR("Failed to load operator."); + RETURN_ERR("Failed to load operator.", + GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR); } llvm::Error ONNXModelLoader::loadInitializers(ONNX_NAMESPACE::GraphProto &net) { diff --git a/lib/Optimizer/IROptimizer.cpp b/lib/Optimizer/IROptimizer.cpp index 5cd7990cf7..753ad55cc6 100644 --- a/lib/Optimizer/IROptimizer.cpp +++ b/lib/Optimizer/IROptimizer.cpp @@ -574,7 +574,7 @@ static void calculateLiveIntervals(const IRFunction &M, (opKind == OperandKind::Out) && (op->getType()->size() < loc->getType()->size()); - auto opIdx = instIdx; + unsigned opIdx; if (opKind == OperandKind::Out && !isPartialWrite) { opIdx = LiveIntervalsInstructionNumbering::getInstrWriteSlotNumber(instIdx); diff --git a/lib/Partitioner/CMakeLists.txt b/lib/Partitioner/CMakeLists.txt new file mode 100644 index 0000000000..ffd6ea532b --- /dev/null +++ b/lib/Partitioner/CMakeLists.txt @@ -0,0 +1,6 @@ +add_library(Partitioner + Partitioner.cpp) + +target_link_libraries(Partitioner + PRIVATE + Graph) diff --git a/lib/Partitioner/Partitioner.cpp b/lib/Partitioner/Partitioner.cpp new file mode 100644 index 0000000000..91bfc5689f --- /dev/null +++ b/lib/Partitioner/Partitioner.cpp @@ -0,0 +1,363 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "glow/Partitioner/Partitioner.h" +#include "glow/Graph/Context.h" +#include "glow/Graph/Utils.h" + +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" + +using namespace glow; +using llvm::isa; + +/// A graph with BFS oder. +struct BFSLevel { + /// A list of with BFS order. + std::vector>> levels; + /// A set of visited nodes. + std::unordered_set visited; +}; + +Partitioner::Partitioner(Module *parent, const std::vector &devices) + : module_(parent), deviceInfo_(devices) { + memSize_ = module_->getConstantsSize(); +} + +Function *Partitioner::selectRepFunc(Module *parent, size_t &memSize) { + auto funcList = parent->getFunctions(); + Function *ret = nullptr; + for (Function *F : funcList) { + size_t size = memSize; + + // The set to keep the placeholders (only for Inputs) whose size is + // already calculated. + std::set pSet; + + for (auto &node : F->getNodes()) { + int n = node.getNumInputs(); + if (node.getKind() == Kinded::Kind::SaveNodeKind) { + // Special node, the placeholder should be ignored? + continue; + } + for (int i = 0; i < n; i++) { + Placeholder *in = + llvm::dyn_cast(node.getNthInput(i).getNode()); + if (in && pSet.count(in->getName()) == 0) { + auto ty = in->getType(); + size += ty->getSizeInBytes(); + pSet.insert(in->getName()); + } + } + } + // Find the function with largest required memory as the representive + // function. + if (size > memSize) { + ret = F; + memSize = size; + } + } + return ret; +} + +/// Get the minimal memory requirement (constant) for each op in the function. +void Partitioner::initOpMemUsage() { + memUsage_.clear(); + for (auto &node : F_->getNodes()) { + int n = node.getNumInputs(); + unsigned size = 0; + if (node.getKind() == Kinded::Kind::SaveNodeKind) { + memUsage_[&node] = size; + continue; + } + for (int i = 0; i < n; i++) { + Storage *in = llvm::dyn_cast(node.getNthInput(i).getNode()); + if (in) { + auto ty = in->getType(); + size += ty->getSizeInBytes(); + } + } + memUsage_[&node] = size; + } +} + +static BFSLevel getBFSLevel(Function *F) { + // Visit graph nodes in BFS order. For each non-storage node, get its level. + // Use the preorder to get the roots. Now assume there is only one output op + // (i.e. root) now. + GraphPreOrderVisitor visitor(*F); + Node *node = nullptr; + for (auto &N : visitor.getPreOrder()) { + if (isa(N)) { + continue; + } + node = N; + break; + } + + BFSLevel bfs; + int level = 0; + int current = 0; + bfs.levels.push_back({level, {node}}); + bfs.visited.insert(node); + level++; + while (current < level) { + std::vector nodes; + for (int i = 0, e = bfs.levels[current].second.size(); i < e; i++) { + Node *N = bfs.levels[current].second[i]; + + for (int j = 0, e = N->getNumInputs(); j < e; ++j) { + Node *in = N->getNthInput(j).getNode(); + if (isa(in) || bfs.visited.count(in)) { + continue; + } + nodes.push_back(in); + bfs.visited.insert(in); + } + } + if (nodes.size() > 0) { + auto newPair = std::make_pair(level, nodes); + bfs.levels.push_back(newPair); + level++; + } + current++; + } + + return bfs; +} + +/// Assign nodes to partitions and return the mapping. +NodeToFunctionMap Partitioner::selectPartitions(Function *F, + unsigned availableMemory) { + NodeToFunctionMap mapping; + BFSLevel bfs = getBFSLevel(F); + unsigned level = bfs.levels.size(); + // A list of cut. The graph can be partitioned by levels [level - 1, + // cut[0]), [cut[0] - 1, cut[1]), ..., [cut[n], -1). + std::vector cut; + + // Step 1 : get the initial cut based on BFS levels and avaiableMemory. + // TODO .. need to remove the duplicated memory usage. + unsigned mem = 0; + for (int i = level - 1; i >= 0; i--) { + unsigned tmp = 0; + for (int j = 0, e = bfs.levels[i].second.size(); j < e; j++) { + Node *N = bfs.levels[i].second[j]; + tmp += memUsage_[N]; + } + if (mem + tmp > availableMemory) { + if (mem == 0) { + // This means the mem usage for one level exceeds the availableMem, + // accept it now and will do adjustment later. + cut.push_back(i + 1); + } else { + cut.push_back(i); + mem = tmp; + } + } else { + mem += tmp; + } + } + + // The last border. + cut.push_back(-1); + + // Step 2 : Create the initial mapping between node and functions. + for (int k = 0, e = cut.size(); k < e; k++) { + auto *newF = F->getParent()->createFunction(std::string(F->getName()) + + "_part" + std::to_string(k)); + mapping.createPartition(newF); + for (int i = k > 0 ? cut[k - 1] : level - 1; i > cut[k]; i--) { + for (int j = 0, e1 = bfs.levels[i].second.size(); j < e1; j++) { + Node *N = bfs.levels[i].second[j]; + mapping.add(N, newF); + } + } + } + // Step 3 : adjust the partition based on performance (Advanced Graph + // Paritioning algrithm will be applied here). + // --- TODO + + return mapping; +} + +/// Adjust the logicalDevice ID for each DAGNode. This happens when \p num (i.e. +/// the number of DAGNodes) is larger than the number of devices. E.g: +/// node1(6GB) -> node2(14GB) -> node3(6GB). The memory limitation is 16GB, and +/// there is only 2 devices. +void Partitioner::adjustLogicalDeviceID(DAGNode *DAG, int num) {} + +/// Current only partition the representive function. +void Partitioner::doPartitioning(Function *F, NodeToFunctionMap &mapping) { + // The dummy node. + std::unique_ptr DAG = std::make_unique(); + DAG->logicalDevice = 0; + DAG->name = F->getName(); + DAG->deviceID = 0; + DAG->logicalDevice = 0; + DAGNode *root = DAG.get(); + partitions_.roots.push_back(std::move(DAG)); + llvm::DenseMap currToNew; + + // Clone nodes into target partition. + for (auto &N : F->getNodes()) { + auto *clone = N.clone(); + currToNew[&N] = clone; + mapping[&N]->addNode(clone); + } + + // For any dependency that crosses a partition, add a placeholder and save + // node. Record the dependence in the function graph. + int logicalID = 0; + llvm::DenseMap placeholders; + llvm::DenseMap funcDAG; + for (auto *subF : mapping.getPartitions()) { + if (funcDAG.find(subF) == funcDAG.end()) { + std::unique_ptr subDAG = std::make_unique(); + subDAG->name = subF->getName(); + subDAG->logicalDevice = logicalID++; + funcDAG[subF] = subDAG.get(); + partitions_.nodes.push_back(std::move(subDAG)); + } + + // Link subF to its parents. + for (auto &N : subF->getNodes()) { + for (int inp = 0, e = N.getNumInputs(); inp < e; inp++) { + auto input = N.getNthInput(inp); + if (isa(input.getNode())) + continue; + + auto *inputF = mapping[input.getNode()]; + if (subF == inputF) + continue; + + // Check if a DAGNode for subF's parent is created or not. If not, + // create one. + if (funcDAG.find(inputF) == funcDAG.end()) { + std::unique_ptr subDAG = std::make_unique(); + subDAG->name = inputF->getName(); + subDAG->logicalDevice = logicalID++; + funcDAG[inputF] = subDAG.get(); + partitions_.nodes.push_back(std::move(subDAG)); + } + + // subF is a child of inputF, inputF is a parent of subF. + funcDAG[inputF]->children.push_back(funcDAG[subF]); + funcDAG[subF]->parents.push_back(funcDAG[inputF]); + + // If we've already created a placeholder for this dependence, use it. + auto it = placeholders.find(input.getNode()); + if (it != placeholders.end()) { + N.setNthInput(inp, it->second); + continue; + } + + // Create a new placeholder to represent this dependence. + auto *save = inputF->createSave("tmp", input); + auto *tmp = save->getPlaceholder(); + placeholders[input.getNode()] = tmp; + N.setNthInput(inp, tmp); + } + } + } + + // Update links between nodes in the cloned functions. Add placeholders (and + // save nodes) where a link crosses a partition boundary. + for (auto *subF : mapping.getPartitions()) { + for (auto &N : subF->getNodes()) { + for (int inp = 0, e = N.getNumInputs(); inp < e; inp++) { + auto input = N.getNthInput(inp); + if (isa(input.getNode())) + continue; + // Link this node to the clone of its input. + auto *clone = currToNew[input.getNode()]; + N.setNthInput(inp, NodeValue(clone, input.getResNo())); + } + } + } + + // For all DAGNode without parents, link them to the root DAG. + for (auto *subF : mapping.getPartitions()) { + if (funcDAG[subF]->parents.size() == 0) { + funcDAG[subF]->parents.push_back(DAG.get()); + root->children.push_back(funcDAG[subF]); + } + } + + // Adjust the logicalDevice for each DAGNode. + if (mapping.getPartitions().size() > deviceInfo_.size()) { + adjustLogicalDeviceID(DAG.get(), mapping.getPartitions().size()); + } +} + +DAGNodeList &Partitioner::Partition() { + + // Find the representive function for running partitioning algrithm. + F_ = selectRepFunc(module_, memSize_); + + // Possible minimal k devices for a succesful partitioning + // Note: here 2 is for testing; + unsigned k = 2; //(memSize_ + MARGIN) / devices[0].availableMemory; + + if (k == 1) { + // No partition is needed. Create DAGNode and return. This root is alway a + // dummy function. + for (auto F : module_->getFunctions()) { + std::unique_ptr DAG = std::make_unique(); + DAG->logicalDevice = 0; + DAG->name = F->getName(); + std::unique_ptr DAG1 = std::make_unique(); + DAG1->logicalDevice = 0; + DAG1->name = F->getName(); + DAG1->parents.push_back(DAG.get()); + DAG->children.push_back(DAG1.get()); + partitions_.roots.push_back(std::move(DAG)); + partitions_.nodes.push_back(std::move(DAG1)); + } + return partitions_; + } + + // Prepare 1: Get the min memory usage for each op. + initOpMemUsage(); + + // Prepare 2: TODO: get the minimal comunication cost for any 2 ops (i.e. the + // output data size) Will calculate it on the fly. -- Will double check which + // way is better. + + // Partition + // Use BFS to do the initial partitioning. Starting from the final node, BFS + // until the memory limitation reached one by one. + unsigned unitMem = memSize_ / k; // used for testing + + NodeToFunctionMap partitionMap = selectPartitions(F_, unitMem); + + doPartitioning(F_, partitionMap); + + // Remove the original function after partitioning. + module_->eraseFunction(F_); + + auto funcList = module_->getFunctions(); + for (Function *F : funcList) { + (void)F; + assert(F->verify() && "Conversion led to invalid function"); + } + + // TODO: Optional: if (k < number of devices) + // Check the computation time of each sub-module, and find out the "key" + // sub-module to decide if duplicating the sub-module is necessary. + + return partitions_; +} diff --git a/lib/Quantization/Base/Base.cpp b/lib/Quantization/Base/Base.cpp index 97f0aeee05..0844931155 100644 --- a/lib/Quantization/Base/Base.cpp +++ b/lib/Quantization/Base/Base.cpp @@ -364,10 +364,13 @@ std::vector createMapping(TypeRef inTy, TypeRef outTy, void tensorRowwiseQuantization(const Tensor &input, Tensor &output, Tensor &scales, Tensor &offsets) { - ShapeHW idim(input.dims()); + const auto fDims = flattenCdr(input.dims()); + Tensor finalIn = input.getUnowned({fDims.first, fDims.second}); + Tensor finalOut = output.getUnowned({fDims.first, fDims.second}); + ShapeHW idim(finalIn.dims()); - auto srcH = input.getHandle(); - auto destH = output.getHandle(); + auto srcH = finalIn.getHandle(); + auto destH = finalOut.getHandle(); auto scalesH = scales.getHandle(); auto offsetsH = offsets.getHandle(); for (size_t i = 0; i < idim.height; i++) { diff --git a/lib/Quantization/Quantization.cpp b/lib/Quantization/Quantization.cpp index 13cfece4d5..2321e7693e 100644 --- a/lib/Quantization/Quantization.cpp +++ b/lib/Quantization/Quantization.cpp @@ -195,10 +195,8 @@ class FunctionQuantizer : public FunctionConverter { // clang-format off #define casesForNodesWithIRConstraint \ IR_CONSTRAINT_CASE(LocalResponseNormalization, Input, Result): \ - IR_CONSTRAINT_CASE(Sigmoid, Input, Result): \ IR_CONSTRAINT_CASE(Slice, Input, Result): \ IR_CONSTRAINT_CASE(Reshape, Input, Result): \ - IR_CONSTRAINT_CASE(Tanh, Input, Result): \ IR_CONSTRAINT_CASE(TopK, Input, Values): \ IR_CONSTRAINT_CASE(Gather, Data, Result): \ IR_CONSTRAINT_CASE(MaxPool, Input, Result) @@ -460,7 +458,8 @@ namespace glow { namespace quantization { std::vector -generateNodeQuantizationInfos(Context &ctx, const Function *F, Schema schema) { +generateNodeQuantizationInfos(Context &ctx, const Function *F, Schema schema, + ElemKind quantizationPrecision) { std::vector quantizationInfos; for (auto &node : F->getNodes()) { @@ -480,7 +479,8 @@ generateNodeQuantizationInfos(Context &ctx, const Function *F, Schema schema) { // TODO: Ideally tensor quantization params should be calculated // based on the histogram distribution. Use simplistic approach for now. (void)histogram; - TensorQuantizationParams TQP = chooseQuantizationParams(min, max, schema); + TensorQuantizationParams TQP = + chooseQuantizationParams(min, max, schema, quantizationPrecision); quantizationInfos.emplace_back(fullOutputName, TQP); } diff --git a/lib/Runtime/CMakeLists.txt b/lib/Runtime/CMakeLists.txt new file mode 100644 index 0000000000..f16decc8e3 --- /dev/null +++ b/lib/Runtime/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Provisioner) \ No newline at end of file diff --git a/lib/Runtime/Provisioner/CMakeLists.txt b/lib/Runtime/Provisioner/CMakeLists.txt new file mode 100644 index 0000000000..86eaf3591f --- /dev/null +++ b/lib/Runtime/Provisioner/CMakeLists.txt @@ -0,0 +1,6 @@ +add_library(Provisioner + Provisioner.cpp) + +target_link_libraries(Provisioner + PRIVATE + Backends) diff --git a/lib/Runtime/Provisioner/Provisioner.cpp b/lib/Runtime/Provisioner/Provisioner.cpp new file mode 100644 index 0000000000..e8a8579c68 --- /dev/null +++ b/lib/Runtime/Provisioner/Provisioner.cpp @@ -0,0 +1,106 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "glow/Runtime/Provisioner/Provisioner.h" +#include "glow/Backends/BackendUtils.h" +#include "glow/Backends/CompiledFunction.h" +#include "glow/Graph/Graph.h" + +#include +#include + +using namespace glow; +using namespace runtime; +using DeviceID = unsigned int; +ResultCode Provisioner::provision( + std::vector> &networks, + std::map> &devices, + Module &module) { + // For the first pass we will just assign and load devices in order and update + // the deviceID field of the node. + std::queue> nextNode; + // Process head node, this does not contain a function but serves as an entry + // point for the network. We build a vector of nodes, containing all + // sub-functions that use the same constants. Later we will group by + // logicalDevice. + for (int i = 0; i < networks[0]->children.size(); i++) { + std::vector newSet; + for (auto &node : networks) { + newSet.push_back(node->children[i]); + } + nextNode.push(newSet); + } + while (!nextNode.empty()) { + std::map compiledFunctions; + std::vector> functions; + auto nodes = nextNode.front(); + nextNode.pop(); + // Add child nodes to the queue. + for (int i = 0; i < nodes[0]->children.size(); i++) { + std::vector newSet; + for (auto node : nodes) { + newSet.push_back(node->children[i]); + } + nextNode.push(newSet); + } + // Assign collection of nodes to a device, compile and load the device. + // We will do a round robin assignment of nodes. If there is not space we + // will return an error. + // TODO Add ability to try against another device when currDevice has + // insufficient space. + auto currDevice = devices.begin(); + // Set backend to match the device. + backend_.reset(createBackend(currDevice->second->getBackendKind())); + // Iterate over the nodes, compile them and add them to compiledFunctions. + for (auto node : nodes) { + node->deviceID = currDevice->first; + Function *function = module.getFunction(node->name); + auto compiled = backend_->compile(function); + node->runtimeBundle = compiled->getRuntimeBundle(); + compiledFunctions.emplace(node->name, compiled.get()); + functions.push_back(std::move(compiled)); + } + // Check if sufficient space on device. Currently requiring a buffer + // over the size of constants determined by NETWORK_PADDING_FACTOR. + auto availableMemory = currDevice->second->getAvailableMemory(); + if (availableMemory < NETWORK_PADDING_FACTOR * + nodes[0]->runtimeBundle.getConstantWeightSize()) { + return ResultCode::Failed; + } + // Load functions on device. + std::promise addNetwork; + auto ready = addNetwork.get_future(); + currDevice->second->addNetwork( + &module, compiledFunctions, + [&addNetwork](const Module *, ResultCode result) { + if (result == ResultCode::Ready) { + addNetwork.set_value(true); + } else { + addNetwork.set_value(false); + } + }); + auto result = ready.get(); + if (!result) { + return ResultCode::Failed; + } + currDevice++; + // Handle wrapping around to start of devices again. + if (currDevice == devices.end()) { + currDevice = devices.begin(); + } + } + return ResultCode::Ready; +}; diff --git a/lib/Support/Error.cpp b/lib/Support/Error.cpp index f438350633..73e9153afa 100644 --- a/lib/Support/Error.cpp +++ b/lib/Support/Error.cpp @@ -20,9 +20,6 @@ namespace glow { llvm::ExitOnError exitOnErr("Encountered an error, exiting.\n"); -std::string addFileAndLineToError(llvm::StringRef str, llvm::StringRef file, - uint32_t line) { - return llvm::formatv("Error at file {0} line {1} \"{2}\"", file, line, str); -} - +/// ID used by llvm::ErrorInfo::isA's dynamic typing. +uint8_t const GlowErr::ID = 0; } // namespace glow diff --git a/lib/Support/Support.cpp b/lib/Support/Support.cpp index c4952ff7a4..42e2338641 100644 --- a/lib/Support/Support.cpp +++ b/lib/Support/Support.cpp @@ -18,6 +18,7 @@ #include "llvm/Support/Debug.h" #include +#include #include #include @@ -77,4 +78,25 @@ std::string escapeDottyString(const std::string &str) { } void report(const char *msg) { errs() << msg; } + +const std::string strFormat(const char *format, ...) { + // Initialize use of varargs. + va_list vaArgs; + va_start(vaArgs, format); + + // Create a copy of the varargs. + va_list vaArgsCopy; + va_copy(vaArgsCopy, vaArgs); + // Compute the length of the output to be produced. + // The vsnprintf call does not actually write anything, but properly computes + // the amount of characters that would be written. + const int len = vsnprintf(NULL, 0, format, vaArgsCopy); + va_end(vaArgsCopy); + + // Create a formatted string without any risk of memory issues. + std::vector str(len + 1); + std::vsnprintf(str.data(), str.size(), format, vaArgs); + va_end(vaArgs); + return std::string(str.data(), len); +} } // namespace glow diff --git a/tests/unittests/BackendCorrectnessTest.cpp b/tests/unittests/BackendCorrectnessTest.cpp index f96d8f68ff..b06d1dafa4 100644 --- a/tests/unittests/BackendCorrectnessTest.cpp +++ b/tests/unittests/BackendCorrectnessTest.cpp @@ -316,7 +316,7 @@ TEST_P(CPUOnly, dataParallelStackingTest) { auto function = backend.compileIR(std::move(M)); function->setupRuns(); function->beforeRun(ctx); - function->execute(); + function->execute(&ctx); function->afterRun(ctx); function->tearDownRuns(); auto H = outputTensor->getHandle(); diff --git a/tests/unittests/BackendTest.cpp b/tests/unittests/BackendTest.cpp index b96dfc87ec..cc4d0682e8 100644 --- a/tests/unittests/BackendTest.cpp +++ b/tests/unittests/BackendTest.cpp @@ -163,7 +163,7 @@ TEST_P(BackendTest, debugPrint) { auto function = backend->compileIR(std::move(IR)); function->setupRuns(); function->beforeRun(ctx); - function->execute(); + function->execute(&ctx); function->afterRun(ctx); function->tearDownRuns(); } diff --git a/tests/unittests/BackendTestUtils.h b/tests/unittests/BackendTestUtils.h index 9685374bbc..93ad2dcde7 100644 --- a/tests/unittests/BackendTestUtils.h +++ b/tests/unittests/BackendTestUtils.h @@ -23,11 +23,7 @@ namespace glow { /// MockBackend used only for unit testing. class MockBackend : public Backend { class MockFunction : public CompiledFunction { - void execute() override{}; - void setupRuns() override{}; - void beforeRun(const Context &ctx) override{}; - void afterRun(const Context &ctx) override{}; - void tearDownRuns() override{}; + void execute(Context *) override{}; }; std::unique_ptr compile(Function *F) const override { return llvm::make_unique(); diff --git a/tests/unittests/CMakeLists.txt b/tests/unittests/CMakeLists.txt index d40824ab1a..d58cd5f6a8 100755 --- a/tests/unittests/CMakeLists.txt +++ b/tests/unittests/CMakeLists.txt @@ -141,7 +141,6 @@ target_link_libraries(GraphSchedulerTest IR gtest TestMain) -target_include_directories(GraphSchedulerTest PRIVATE ${CMAKE_SOURCE_DIR}/lib/IR) add_glow_test(GraphSchedulerTest ${GLOW_BINARY_DIR}/tests/GraphSchedulerTest --gtest_output=xml:GraphSchedulerTest.xml) add_executable(QuantizationTest @@ -190,6 +189,7 @@ add_glow_test(ThreadPoolTest ${GLOW_BINARY_DIR}/tests/ThreadPoolTest --gtest_out add_executable(UtilsTest StrCheck.cpp + SupportTest.cpp TaggedListTest.cpp UtilsTest.cpp) target_link_libraries(UtilsTest @@ -282,12 +282,11 @@ target_link_libraries(LLVMIRGenTest Support gtest TestMain) -target_include_directories(LLVMIRGenTest PUBLIC ${CMAKE_SOURCE_DIR}/lib/Backends/CPU) add_glow_test(LLVMIRGenTest ${GLOW_BINARY_DIR}/tests/LLVMIRGenTest --gtest_output=xml:LLVMIRGenTest.xml) -add_executable(cpuDeviceTest +add_executable(CPUDeviceTest CPUDeviceManagerTest.cpp) -target_link_libraries(cpuDeviceTest +target_link_libraries(CPUDeviceTest PRIVATE Backends DeviceManager @@ -298,8 +297,7 @@ target_link_libraries(cpuDeviceTest Optimizer gtest TestMain) -target_include_directories(cpuDeviceTest PUBLIC ${CMAKE_SOURCE_DIR}/lib/Backends/CPU) -add_glow_test(cpuDeviceTest ${GLOW_BINARY_DIR}/tests/cpuDeviceTest --gtest_output=xml:cpuDeviceTest.xml) +add_glow_test(CPUDeviceTest ${GLOW_BINARY_DIR}/tests/CPUDeviceTest --gtest_output=xml:CPUDeviceTest.xml) endif() @@ -313,6 +311,19 @@ target_link_libraries(MemoryAllocatorTest TestMain) add_glow_test(MemoryAllocatorTest ${GLOW_BINARY_DIR}/tests/MemoryAllocatorTest --gtest_output=xml:MemoryAllocatorTest.xml) +add_executable(PartitionerTest + PartitionerTest.cpp) +target_link_libraries(PartitionerTest + PRIVATE + Backends + ExecutionEngine + Graph + IR + Partitioner + gtest + TestMain) +add_glow_test(PartitionerTest ${GLOW_BINARY_DIR}/tests/PartitionerTest --gtest_output=xml:PartitionerTest.xml) + add_executable(Caffe2ImporterTest Caffe2ImporterTest.cpp) target_link_libraries(Caffe2ImporterTest @@ -346,13 +357,25 @@ target_link_libraries(GlowOnnxifiManagerTest onnxifi-glow-lib gtest TestMain) -target_include_directories(GlowOnnxifiManagerTest - PRIVATE - ${CMAKE_SOURCE_DIR}/lib/Onnxifi ${GLOW_THIRDPARTY_DIR}/onnx) add_glow_test(NAME GlowOnnxifiManagerTest COMMAND ${GLOW_BINARY_DIR}/tests/GlowOnnxifiManagerTest --gtest_output=xml:GlowOnnxifiManagerTest.xml WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) +add_executable(ProvisionerTest + ProvisionerTest.cpp) +target_link_libraries(ProvisionerTest + PRIVATE + Backends + Graph + IR + Provisioner + DeviceManager + CPUDeviceManager + gtest + TestMain) +target_include_directories(ProvisionerTest PUBLIC ${CMAKE_SOURCE_DIR}/lib/Backends/CPU) +add_glow_test(ProvisionerTest ${GLOW_BINARY_DIR}/tests/ProvisionerTest --gtest_output=xml:ProvisionerTest.xml) + LIST(APPEND UNOPT_TESTS ./tests/BackendTest -optimize-ir=false && ./tests/MLTest -optimize-ir=false && diff --git a/tests/unittests/CPUDeviceManagerTest.cpp b/tests/unittests/CPUDeviceManagerTest.cpp index 7343c92d3a..a4aaa6c6c2 100644 --- a/tests/unittests/CPUDeviceManagerTest.cpp +++ b/tests/unittests/CPUDeviceManagerTest.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "CPUDeviceManager.h" +#include "../../lib/Backends/CPU/CPUDeviceManager.h" #include "glow/ExecutionEngine/ExecutionEngine.h" #include "gtest/gtest.h" @@ -41,42 +41,13 @@ std::unique_ptr makeBasicModule(std::string functionName = "main") { return module; } -// TODO: This really should be a helper function somewhere -void optimizeFunction(Backend *backend, CompilationMode mode, Function *F) { - // Verify the function pre-optimization/lowering. - assert(F->verify() && "Function must be valid"); - - // Optimize the graph. - ::glow::optimize(F, mode); - - // Allow the backend to transform the graph prior to lowering. - if (backend->transformPreLowering(F, mode)) { - // Optimize the graph again after the backend transformation. - // In particular, DCE is very likely to be useful. - ::glow::optimize(F, mode); - } - - // Lower the graph into a sequence of low-level linear algebra operations. - ::glow::lower(F, *backend); - - // Optimize the graph again. - ::glow::optimize(F, mode); - - // Allow the backend to transform the graph after lowering. - if (backend->transformPostLowering(F, mode)) { - // Optimize the graph again after the backend transformation. - // In particular, DCE is very likely to be useful. - ::glow::optimize(F, mode); - } -} - FunctionMapTy compileFunctions(Module *module, std::vector> &backing) { FunctionMapTy results; auto *backend = createBackend(BackendKind::CPU); for (auto *F : module->getFunctions()) { - optimizeFunction(backend, CompilationMode::Infer, F); + backend->optimizeFunction(CompilationMode::Infer, F); auto f = backend->compile(F); backing.push_back(std::move(f)); results.emplace(F->getName(), backing.back().get()); @@ -113,7 +84,8 @@ TEST(CPUDeviceManagerTest, Basic) { cpuCoreDevice.addNetwork(module.get(), std::move(functions), [&promise](const Module *module, ResultCode result) { - callbackHelper(promise, module, result, Ready); + callbackHelper(promise, module, result, + ResultCode::Ready); }); future.wait_for(2s); @@ -134,7 +106,7 @@ TEST(CPUDeviceManagerTest, Basic) { [&runPromise](RunIdentifierTy, ResultCode result, std::unique_ptr ctx_) { callbackHelper(runPromise, std::move(ctx_), - result, Executed); + result, ResultCode::Executed); }); runFuture.wait_for(2s); @@ -155,7 +127,8 @@ TEST(CPUDeviceManagerTest, MultiRun) { std::tie(promise, future) = getFutureHelper(); cpuCoreDevice.addNetwork(module.get(), std::move(functions), [&promise](const Module *module, ResultCode result) { - callbackHelper(promise, module, result, Ready); + callbackHelper(promise, module, result, + ResultCode::Ready); }); future.wait_for(2s); EXPECT_EQ(future.get(), module.get()); @@ -185,14 +158,14 @@ TEST(CPUDeviceManagerTest, MultiRun) { [&runP1](RunIdentifierTy, ResultCode result, std::unique_ptr ctx_) { callbackHelper(runP1, std::move(ctx_), result, - Executed); + ResultCode::Executed); }); cpuCoreDevice.runFunction("main", std::move(ctx2), [&runP2](RunIdentifierTy, ResultCode result, std::unique_ptr ctx_) { callbackHelper(runP2, std::move(ctx_), result, - Executed); + ResultCode::Executed); }); ctx1 = runF1.get(); @@ -225,7 +198,8 @@ TEST(CPUDeviceManagerTest, MultiFunction) { std::tie(promise, future) = getFutureHelper(); cpuCoreDevice.addNetwork(module.get(), std::move(functions), [&promise](const Module *module, ResultCode result) { - callbackHelper(promise, module, result, Ready); + callbackHelper(promise, module, result, + ResultCode::Ready); }); future.wait_for(2s); EXPECT_EQ(future.get(), module.get()); @@ -245,14 +219,14 @@ TEST(CPUDeviceManagerTest, MultiFunction) { [&runP1](RunIdentifierTy, ResultCode result, std::unique_ptr ctx_) { callbackHelper(runP1, std::move(ctx_), result, - Executed); + ResultCode::Executed); }); cpuCoreDevice.runFunction("func2", std::move(ctx2), [&runP2](RunIdentifierTy, ResultCode result, std::unique_ptr ctx_) { callbackHelper(runP2, std::move(ctx_), result, - Executed); + ResultCode::Executed); }); ctx1 = runF1.get(); @@ -276,7 +250,8 @@ TEST(CPUDeviceManagerTest, MultiModule) { std::tie(promise, future) = getFutureHelper(); cpuCoreDevice.addNetwork(module1.get(), std::move(functions1), [&promise](const Module *module, ResultCode result) { - callbackHelper(promise, module, result, Ready); + callbackHelper(promise, module, result, + ResultCode::Ready); }); future.wait_for(2s); EXPECT_EQ(future.get(), module1.get()); @@ -284,7 +259,8 @@ TEST(CPUDeviceManagerTest, MultiModule) { std::tie(promise, future) = getFutureHelper(); cpuCoreDevice.addNetwork(module2.get(), std::move(functions2), [&promise](const Module *module, ResultCode result) { - callbackHelper(promise, module, result, Ready); + callbackHelper(promise, module, result, + ResultCode::Ready); }); future.wait_for(2s); EXPECT_EQ(future.get(), module2.get()); @@ -309,14 +285,14 @@ TEST(CPUDeviceManagerTest, MultiModule) { [&runP1](RunIdentifierTy, ResultCode result, std::unique_ptr ctx_) { callbackHelper(runP1, std::move(ctx_), result, - Executed); + ResultCode::Executed); }); cpuCoreDevice.runFunction("func2", std::move(ctx2), [&runP2](RunIdentifierTy, ResultCode result, std::unique_ptr ctx_) { callbackHelper(runP2, std::move(ctx_), result, - Executed); + ResultCode::Executed); }); ctx1 = runF1.get(); @@ -340,11 +316,11 @@ TEST(CPUDeviceManagerTest, AvailableMemory) { auto module = makeBasicModule(); std::tie(promise, future) = getFutureHelper(); - cpuCoreDevice.addNetwork(module.get(), - compileFunctions(module.get(), backing), - [&promise](const Module *module, ResultCode result) { - callbackHelper(promise, module, result, Ready); - }); + cpuCoreDevice.addNetwork( + module.get(), compileFunctions(module.get(), backing), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, ResultCode::Ready); + }); future.wait_for(2s); EXPECT_EQ(future.get(), module.get()); @@ -357,11 +333,11 @@ TEST(CPUDeviceManagerTest, AvailableMemory) { // Let's try again. auto module2 = makeBasicModule(); std::tie(promise, future) = getFutureHelper(); - cpuCoreDevice.addNetwork(module2.get(), - compileFunctions(module2.get(), backing), - [&promise](const Module *module, ResultCode result) { - callbackHelper(promise, module, result, Ready); - }); + cpuCoreDevice.addNetwork( + module2.get(), compileFunctions(module2.get(), backing), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, ResultCode::Ready); + }); future.wait_for(2s); auto *resultModule = future.get(); @@ -377,11 +353,11 @@ TEST(CPUDeviceManagerTest, AvailableMemory) { // And try again, this time with available space. std::tie(promise, future) = getFutureHelper(); - cpuCoreDevice.addNetwork(module2.get(), - compileFunctions(module2.get(), backing), - [&promise](const Module *module, ResultCode result) { - callbackHelper(promise, module, result, Ready); - }); + cpuCoreDevice.addNetwork( + module2.get(), compileFunctions(module2.get(), backing), + [&promise](const Module *module, ResultCode result) { + callbackHelper(promise, module, result, ResultCode::Ready); + }); future.wait_for(2s); EXPECT_EQ(future.get(), module2.get()); diff --git a/tests/unittests/GlowOnnxifiManagerTest.cpp b/tests/unittests/GlowOnnxifiManagerTest.cpp index df53625563..7bba2714d8 100644 --- a/tests/unittests/GlowOnnxifiManagerTest.cpp +++ b/tests/unittests/GlowOnnxifiManagerTest.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "GlowOnnxifiManager.h" +#include "../../lib/Onnxifi/GlowOnnxifiManager.h" #include "gtest/gtest.h" diff --git a/tests/unittests/GraphSchedulerTest.cpp b/tests/unittests/GraphSchedulerTest.cpp index e65c94bf83..f2fb2f1bfc 100644 --- a/tests/unittests/GraphSchedulerTest.cpp +++ b/tests/unittests/GraphSchedulerTest.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "GraphScheduler.h" +#include "../../lib/IR/GraphScheduler.h" #include "glow/Graph/Context.h" #include "glow/Graph/Graph.h" diff --git a/tests/unittests/LLVMIRGenTest.cpp b/tests/unittests/LLVMIRGenTest.cpp index f50b218dc5..51c21361d7 100644 --- a/tests/unittests/LLVMIRGenTest.cpp +++ b/tests/unittests/LLVMIRGenTest.cpp @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "LLVMIRGen.h" -#include "AllocationsInfo.h" +#include "../../lib/Backends/CPU/LLVMIRGen.h" +#include "../../lib/Backends/CPU/AllocationsInfo.h" #include "glow/IR/IR.h" diff --git a/tests/unittests/OperatorTest.cpp b/tests/unittests/OperatorTest.cpp index 6fb966aed7..0933ba05ee 100644 --- a/tests/unittests/OperatorTest.cpp +++ b/tests/unittests/OperatorTest.cpp @@ -4037,6 +4037,114 @@ TEST_P(InterpOnly, SparseLengthsWeightedSumI8) { EXPECT_TRUE(expected.isEqual(result)); } +TEST_P(InterpAndCPU, RowwiseQuantizedSparseLengthsWeightedSum) { + /* + DATA = [2.0, -0.5, 13] + WEIGHTS = [3, 1, 0, 0, 0, 0, 2, -0.5] + INDICES = [1, 0, 2, 0, 1, 2, 2, 0] + LENGTHS = [3, 0, 3, 2] + OUTPUT = [0.5, 0, 0, 25] + */ + Tensor data(ElemKind::FloatTy, {3}); + data.getHandle() = { + 2.0, + -0.5, + 13, + }; + + Constant *weights = mod_.createConstant(ElemKind::FloatTy, {8}, "weights"); + weights->getPayload().getHandle() = { + 3., 1., 0., 0., 0., 0., 2., -0.5, + }; + + Placeholder *indices = + mod_.createPlaceholder(ElemKind::Int64ITy, {8}, "indices", + /* isTrainable */ false); + Placeholder *lengths = + mod_.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths", + /* isTrainable */ false); + + ctx_.allocate(indices)->getHandle() = { + 1, 0, 2, 0, 1, 2, 2, 0, + }; + ctx_.allocate(lengths)->getHandle() = { + 3, + 0, + 3, + 2, + }; + + auto *R = F_->createRowwiseQuantizedSparseLengthsWeightedSum( + "RQSLWS", data, weights, indices, lengths); + SaveNode *S = F_->createSave("save", R); + ctx_.allocate(S->getPlaceholder()); + + EE_.compile(CompilationMode::Infer, F_); + EE_.run(ctx_); + + Tensor &result = *ctx_.get(S->getPlaceholder()); + Tensor expected(ElemKind::FloatTy, {4}); + expected.getHandle() = { + 0.5, + 0, + 0, + 25, + }; + + EXPECT_TRUE(expected.isEqual(result)); +} + +TEST_P(InterpAndCPU, RowwiseQuantizedSparseLengthsSum) { + /* + DATA = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + INDICES = [2, 0, 1, 2, 0, 0, 0, 0] + LENGTHS = [2, 0, 2, 1, 3] + OUTPUT = [ + [5.5, 6.9], + [0.0, 0.0], + [6.8, 9.1], + [1.0, 1.2], + [3.0, 3.6], + ] + */ + Tensor data(ElemKind::FloatTy, {3, 2}); + data.getHandle() = { + 1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f, + }; + + Placeholder *indices = mod_.createPlaceholder( + ElemKind::Int64ITy, {8}, "indices", /* isTrainable */ false); + Placeholder *lengths = mod_.createPlaceholder( + ElemKind::Int32ITy, {5}, "lengths", /* isTrainable */ false); + + ctx_.allocate(indices)->getHandle() = { + 2, 0, 1, 2, 0, 0, 0, 0, + }; + ctx_.allocate(lengths)->getHandle() = { + 2, 0, 2, 1, 3, + }; + + auto *R = F_->createRowwiseQuantizedSparseLengthsSum("RQSLWS", data, indices, + lengths); + SaveNode *S = F_->createSave("save", R); + ctx_.allocate(S->getPlaceholder()); + + EE_.compile(CompilationMode::Infer, F_); + EE_.run(ctx_); + + Tensor &result = *ctx_.get(S->getPlaceholder()); + Tensor expected(ElemKind::FloatTy, {5, 2}); + expected.getHandle() = { + 5.5f, 6.9f, 0.0f, 0.0f, 6.8f, 9.1f, 1.0f, 1.2f, 3.0f, 3.6f, + }; + + EXPECT_TRUE(expected.isEqual(result, 0.02)); +} + TEST_P(InterpAndCPU, SparseToDense) { // Create and initialize inputs. Make input 3D to make sure // multidimensional values are handled properly. diff --git a/tests/unittests/PartitionerTest.cpp b/tests/unittests/PartitionerTest.cpp new file mode 100644 index 0000000000..e8d57d9102 --- /dev/null +++ b/tests/unittests/PartitionerTest.cpp @@ -0,0 +1,118 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "glow/Partitioner/Partitioner.h" +#include "glow/ExecutionEngine/ExecutionEngine.h" +#include "glow/Graph/Graph.h" + +#include "gtest/gtest.h" + +using namespace glow; + +class PartitionerTest : public ::testing::Test { +public: + PartitionerTest() : F_(mod_.createFunction("main")) {} + +protected: + Module mod_; + Function *F_; + Context ctx_; +}; + +/// Execute a graph of functions based on the given DAG. +static void executeDAG(DAGNode *G, Module &mod, Context &ctx, + llvm::ArrayRef vars, + llvm::ArrayRef inputs) { + std::unordered_map name2func; + + for (auto *F : mod.getFunctions()) { + name2func[F->getName()] = F; + } + + std::vector exeList; + int endPt = 0; + int curPt = 0; + // The first node is always the dummy node. + exeList.push_back(G); + endPt++; + while (curPt < endPt) { + DAGNode *dag = exeList.at(curPt); + // The root in a G is always a dummy function. + if (curPt > 0) { + ExecutionEngine EE; + Function *func = name2func[dag->name]; + EE.compile(CompilationMode::Infer, func); + updateInputPlaceholders(ctx, vars, inputs); + EE.run(ctx); + } + for (int i = 0, e = dag->children.size(); i < e; i++) { + exeList.push_back(dag->children.at(i)); + endPt++; + } + curPt++; + } +} + +TEST_F(PartitionerTest, test1) { + auto *input = + mod_.createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false); + ctx_.allocate(input); + + // Initial FC. + Node *I = F_->createFullyConnected(ctx_, "initial_fc", input, 16); + I = F_->createSigmoid("initial_sigmoid", I); + + // Left branch. + Node *L = F_->createFullyConnected(ctx_, "left_fc1", I, 16); + L = F_->createSigmoid("left_sigmoid1", L); + L = F_->createFullyConnected(ctx_, "left_fc2", L, 8); + L = F_->createSigmoid("left_sigmoid2", L); + + // Right branch. + Node *R = F_->createFullyConnected(ctx_, "right_fc1", I, 16); + R = F_->createSigmoid("right_sigmoid1", R); + R = F_->createFullyConnected(ctx_, "right_fc2", R, 8); + R = F_->createSigmoid("right_sigmoid2", R); + + // Join branches. + auto *mul = F_->createMul("mul", L, R); + auto *save = F_->createSave("ret", mul); + auto &res = *ctx_.allocate(save->getPlaceholder()); + + // Infer using the un-partitioned graph. + Tensor in(ElemKind::FloatTy, {1, 32}); + ExecutionEngine EE; + + EE.compile(CompilationMode::Infer, F_); + updateInputPlaceholders(ctx_, {input}, {&in}); + EE.run(ctx_); + Tensor ref = res.clone(); + + std::vector devices; + Partitioner myPartitioner(&mod_, devices); + + DAGNodeList myList = std::move(myPartitioner.Partition()); + ASSERT_EQ(mod_.getFunctions().size(), 3); + ASSERT_EQ(myList.roots.size(), 1); + + // Run the paritioned graph and compare the results. + ctx_.allocate(mod_.getPlaceholders()); + for (auto it = myList.roots.begin(); it != myList.roots.end(); ++it) { + ctx_.allocate(mod_.getPlaceholders()); + executeDAG((*it).get(), mod_, ctx_, {input}, {&in}); + Tensor test = res.clone(); + EXPECT_TRUE(ref.isEqual(test)); + } +} diff --git a/tests/unittests/ProvisionerTest.cpp b/tests/unittests/ProvisionerTest.cpp new file mode 100644 index 0000000000..b19cc717dd --- /dev/null +++ b/tests/unittests/ProvisionerTest.cpp @@ -0,0 +1,74 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "glow/Runtime/Provisioner/Provisioner.h" +#include "CPUDeviceManager.h" + +#include "gtest/gtest.h" + +using namespace glow; +using namespace glow::runtime; +using DAGNodePairTy = std::pair>, + std::vector>>; + +class ProvisionerTest : public ::testing::Test {}; +std::unique_ptr setupModule(uint functionCount) { + std::unique_ptr module = std::make_unique(); + for (int i = 0; i < functionCount; i++) { + Function *F = module->createFunction("function" + std::to_string(i)); + auto *X = module->createPlaceholder(ElemKind::FloatTy, {3}, + "X" + std::to_string(i), false); + auto *pow = F->createPow("Pow" + std::to_string(i), X, 2.0); + F->createSave("save" + std::to_string(i), pow); + } + return module; +} + +DAGNodePairTy setupDAG(uint rootCount, uint childCount) { + std::vector> networks; + std::vector> children; + uint currentFunction = 0; + for (int root = 0; root < rootCount; root++) { + auto rootNode = std::make_unique(); + auto firstNode = std::make_unique(); + rootNode->name = "root" + std::to_string(root); + rootNode->children.push_back(firstNode.get()); + firstNode->name = "function" + std::to_string(currentFunction); + currentFunction++; + for (int child = 0; child < childCount; child++) { + auto newChild = std::make_unique(); + newChild->name = "function" + std::to_string(currentFunction); + currentFunction++; + firstNode->children.push_back(newChild.get()); + children.push_back(std::move(newChild)); + } + networks.push_back(std::move(rootNode)); + children.push_back(std::move(firstNode)); + } + return std::make_pair(std::move(networks), std::move(children)); +} + +TEST_F(ProvisionerTest, provisionDag) { + auto mod = setupModule(6); + auto networks = setupDAG(2, 2); + auto provisioner = Provisioner(); + std::map> devices; + for (int i = 0; i < 6; i++) { + std::unique_ptr device(new CPUDeviceManager); + devices.emplace(i, std::move(device)); + } + auto result = provisioner.provision(networks.first, devices, *mod.get()); + EXPECT_EQ(result, ResultCode::Ready); +} \ No newline at end of file diff --git a/tests/unittests/QuantizationTest.cpp b/tests/unittests/QuantizationTest.cpp index fd5388b54a..8a2a5356f9 100644 --- a/tests/unittests/QuantizationTest.cpp +++ b/tests/unittests/QuantizationTest.cpp @@ -388,7 +388,9 @@ static Function *createSimpleGraphForQuantization(Module *M, Context &ctx, return F; } -TEST_P(Operator, end2end) { +void testQuantizationEnd2End(ExecutionEngine &profileEE, + ExecutionEngine &backendSpecificEE, + ElemKind quantizationPrecision) { auto *mod = &profileEE.getModule(); Context ctx; @@ -408,13 +410,14 @@ TEST_P(Operator, end2end) { // Get quantization infos and build new quantized graph. std::vector QI = - quantization::generateNodeQuantizationInfos(ctx, F1); + quantization::generateNodeQuantizationInfos( + ctx, F1, quantization::Schema::Asymmetric, quantizationPrecision); // STEP2 - Use the profile to quantize a network. SaveNode *result2 = cast(F2->getNodeByName("save")); - F2 = quantization::quantizeFunction(backendSpecificEE, QI, ElemKind::Int8QTy, - F2); + F2 = quantization::quantizeFunction(backendSpecificEE, QI, + quantizationPrecision, F2); backendSpecificEE.compile(CompilationMode::Infer, F2); backendSpecificEE.run(ctx); @@ -433,6 +436,14 @@ TEST_P(Operator, end2end) { } } +TEST_P(Operator, end2endInt8) { + testQuantizationEnd2End(profileEE, backendSpecificEE, ElemKind::Int8QTy); +} + +TEST_P(Operator, end2endInt16) { + testQuantizationEnd2End(profileEE, backendSpecificEE, ElemKind::Int16QTy); +} + /// Fills the tensor \p H with some stable random integers with the seed \p seed /// and the range [0, scale). static void fillStableRandomIndex(Handle H, size_t seed, diff --git a/tests/unittests/SupportTest.cpp b/tests/unittests/SupportTest.cpp new file mode 100644 index 0000000000..a081b83d93 --- /dev/null +++ b/tests/unittests/SupportTest.cpp @@ -0,0 +1,34 @@ +/** + * Copyright (c) 2018-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "glow/Support/Support.h" +#include "glow/Testing/StrCheck.h" +#include "gtest/gtest.h" + +using namespace glow; +using glow::StrCheck; + +TEST(Support, strFormat) { + // Check single-line formatted output. + std::string str1 = strFormat("%s %d %c", "string1", 123, 'x'); + EXPECT_TRUE(StrCheck(str1).sameln("string1").sameln("123").sameln("x")); + + // Check multi-line formatted output. + std::string str2 = strFormat("%s\n%d\n%c\n", "string2", 456, 'y'); + EXPECT_TRUE(StrCheck(str2).check("string2").check("456").check("y")); + // Output is not a single line. + EXPECT_FALSE(StrCheck(str2).sameln("string2").sameln("456").sameln("y")); +} diff --git a/tools/ClassGen/Backends/OpenCL/OpenCLSpecificInstrs.h b/tools/ClassGen/Backends/OpenCL/OpenCLSpecificInstrs.h index 0ae56d71ba..12fec9d1b5 100644 --- a/tools/ClassGen/Backends/OpenCL/OpenCLSpecificInstrs.h +++ b/tools/ClassGen/Backends/OpenCL/OpenCLSpecificInstrs.h @@ -47,6 +47,7 @@ BB.newBackendSpecificInstr("OCLMaxPool") .autoIRGen() .autoVerify(VerifyKind::SameElementType, {"Dest", "Src"}); -BB.includeBackendSpecificVerification("OpenCLSpecificInstrsVerification.h"); +BB.includeBackendSpecificVerification( + "glow/OpenCLSpecificInstrsVerification.h"); #endif // GLOW_WITH_CPU diff --git a/tools/ClassGen/Backends/OpenCL/OpenCLSpecificNodes.h b/tools/ClassGen/Backends/OpenCL/OpenCLSpecificNodes.h index b02d1f3e0d..de3fe0c0d9 100644 --- a/tools/ClassGen/Backends/OpenCL/OpenCLSpecificNodes.h +++ b/tools/ClassGen/Backends/OpenCL/OpenCLSpecificNodes.h @@ -51,6 +51,6 @@ BB.newNode("OCLMaxPool") "provided " "Kernel, Stride, and Pads. The input and output are in NCHW format"); -BB.includeBackendSpecificVerification("OpenCLSpecificNodesVerification.h"); +BB.includeBackendSpecificVerification("glow/OpenCLSpecificNodesVerification.h"); #endif // GLOW_WITH_CPU diff --git a/tools/ClassGen/InstrGen.cpp b/tools/ClassGen/InstrGen.cpp index efd45c1ce2..bd8dedafca 100644 --- a/tools/ClassGen/InstrGen.cpp +++ b/tools/ClassGen/InstrGen.cpp @@ -229,6 +229,27 @@ int main(int argc, char **argv) { {"Lengths", "ElemKind::Int32ITy"}) .autoVerify(VerifyKind::SameShape, {"Weights", "Indices"}); + BB.newInstr("RowwiseQuantizedSparseLengthsWeightedSum") + .addOperand("Dest", OperandKind::Out) + .addOperand("Data", OperandKind::In) + .addOperand("Scales", OperandKind::In) + .addOperand("Offsets", OperandKind::In) + .addOperand("Weights", OperandKind::In) + .addOperand("Indices", OperandKind::In) + .addOperand("Lengths", OperandKind::In) + .autoIRGen() + .autoVerify(VerifyKind::SameElementType, {"Dest", "ElemKind::FloatTy"}) + .autoVerify(VerifyKind::SameElementType, {"Data", "ElemKind::Int8QTy"}) + .autoVerify(VerifyKind::SameElementType, {"Scales", "ElemKind::FloatTy"}) + .autoVerify(VerifyKind::SameElementType, + {"Offsets", "ElemKind::Int32ITy"}) + .autoVerify(VerifyKind::SameElementType, {"Weights", "ElemKind::FloatTy"}) + .autoVerify(VerifyKind::SameElementType, + {"Indices", "ElemKind::Int64ITy"}) + .autoVerify(VerifyKind::SameElementType, + {"Lengths", "ElemKind::Int32ITy"}) + .autoVerify(VerifyKind::SameShape, {"Weights", "Indices"}); + BB.newInstr("LengthsToRanges") .addOperand("Dest", OperandKind::Out) .addOperand("Lengths", OperandKind::In) diff --git a/tools/ClassGen/NodeGen.cpp b/tools/ClassGen/NodeGen.cpp index dd38ae4b71..6c3f31bc8e 100644 --- a/tools/ClassGen/NodeGen.cpp +++ b/tools/ClassGen/NodeGen.cpp @@ -335,6 +335,26 @@ int main(int argc, char **argv) { "Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... " "It implies that len(Weights) == len(Indices)."); + BB.newNode("RowwiseQuantizedSparseLengthsWeightedSum") + .addInput("Data") + .addInput("Scales") + .addInput("Offsets") + .addInput("Weights") + .addInput("Indices") + .addInput("Lengths") + .addResultFromCtorArg() + .setDocstring("Gathers slices of the outer-most dimension of Data " + "indexed by Indices vector, and then accumulates them into " + "len(Lengths) entries: first Lengths[0] slices are " + "aggregated to Result[0], next Lengths[1] slices are " + "aggregated to Result[1], etc. I.e. sum(Lengths) must be " + "equal to len(Indices). Before doing aggregation, each " + "individual slice is scaled by its weight: Result[0] = " + "Weights[0] * Slice(0) + Weights[1] * Slice(1) + ... " + "It implies that len(Weights) == len(Indices). The input " + "data is rowwise-quantized, where the Scales and Offsets " + "are 1D tensors of length equal to the first dim of Data."); + BB.newNode("LengthsToRanges") .addInput("Lengths") .addResultFromCtorArg() diff --git a/tools/loader/Loader.cpp b/tools/loader/Loader.cpp index 2c48f0fb70..6a423b9fef 100644 --- a/tools/loader/Loader.cpp +++ b/tools/loader/Loader.cpp @@ -328,7 +328,8 @@ void Loader::generateAndSerializeQuantizationInfos(Context &ctx) { assert(!dumpProfileFileOpt.empty() && "Filename to dump serialized profile to must not be empty."); std::vector QI = - quantization::generateNodeQuantizationInfos(ctx, F_, quantizationSchema); + quantization::generateNodeQuantizationInfos(ctx, F_, quantizationSchema, + quantizationPrecision); serializeToYaml(dumpProfileFileOpt, QI); }