
Add provisioner for new Runtime #2276


Closed
wants to merge 33 commits into from
b977e28
Adding Provisioner component of new Runtime design.
gcatron Jan 14, 2019
ffc457f
This adds a simple provisioner and unittest for the new Runtime.
gcatron Jan 16, 2019
829d282
Removed unneeded includes
gcatron Jan 17, 2019
494cf90
[docs] Update docs.
Jan 16, 2019
aea4510
Add a small support function strFormat to create std::strings using p…
opti-mix Jan 17, 2019
e50bb37
Change ResultCode to an enum class (#2278)
nickgg Jan 17, 2019
4a26275
Comment cleanup
gcatron Jan 17, 2019
1434d89
[easy] Fix some dead stores (#2280)
bertmaher Jan 18, 2019
99655c4
[quantization] Profile with required precision.
Jan 17, 2019
ffeadad
Add the Marvell logo
nadavrot Jan 19, 2019
47eb632
[classgen] Add glow/ prefix to OpenCL include paths (#2286)
bertmaher Jan 19, 2019
232e4d4
Create GlowErr type with informational error code enum (#2283)
jackm321 Jan 21, 2019
be5a4e4
[tests] Include private headers via .. relative paths (#2287)
bertmaher Jan 22, 2019
da381d5
Implement RowwiseQuantizedSparseLengthsWeightedSum (#2282)
jfix71 Jan 22, 2019
e0335a0
Addressed comments, cleaned up code and comments
gcatron Jan 22, 2019
e0dba53
Changed Padding factor to constant
gcatron Jan 22, 2019
836c362
Added resultCode to enum use sites
gcatron Jan 22, 2019
679842d
More enum cleanup
gcatron Jan 22, 2019
843491d
Hopefully last enum...
gcatron Jan 22, 2019
aa3effc
Move optimizeFunction from ExecutionEngine into the base Backend so i…
nickgg Jan 23, 2019
cfc989f
[Partitioner] First Graph Partitioning
Dec 10, 2018
73479f5
Refactor CompiledFunction to remove per-run state (V2) (#2274)
nickgg Jan 23, 2019
c96fb78
Adding Provisioner component of new Runtime design.
gcatron Jan 14, 2019
8e85226
This adds a simple provisioner and unittest for the new Runtime.
gcatron Jan 16, 2019
ac05309
Removed unneeded includes
gcatron Jan 17, 2019
3c9d521
Comment cleanup
gcatron Jan 17, 2019
d92c58d
Addressed comments, cleaned up code and comments
gcatron Jan 22, 2019
b18f6be
Changed Padding factor to constant
gcatron Jan 22, 2019
575f2ac
Added resultCode to enum use sites
gcatron Jan 22, 2019
7f655ff
More enum cleanup
gcatron Jan 22, 2019
bf7fe1b
Hopefully last enum...
gcatron Jan 22, 2019
be250a4
Updated Provisioner to match Partitioner's output, updated unit test.
gcatron Jan 24, 2019
38da522
rebase
gcatron Jan 24, 2019
6 changes: 3 additions & 3 deletions README.md
@@ -22,10 +22,10 @@ Note:
List of partner logos sorted alphabetically column order.
-->

-|![Bitmain Logo](./docs/partners/bitmain.png) |![Esperanto Logo](./docs/partners/esperanto.png) | ![NXP Logo](./docs/partners/nxp.png) |
+|![Bitmain Logo](./docs/partners/bitmain.png) |![Esperanto Logo](./docs/partners/esperanto.png) | ![Marvell Logo](./docs/partners/marvell.png) |
 :-------------------------:|:-------------------------:|:-------------------------:
-|![CEVA Logo](./docs/partners/ceva.png) |![Habana Logo](./docs/partners/habana.png) | ![ST Logo](./docs/partners/st.png) |
-|![Cadence Logo](./docs/partners/cadence.png) | ![Intel Logo](./docs/partners/intel.png)| |
+|![CEVA Logo](./docs/partners/ceva.png) |![Habana Logo](./docs/partners/habana.png) | ![NXP Logo](./docs/partners/nxp.png) |
+|![Cadence Logo](./docs/partners/cadence.png) | ![Intel Logo](./docs/partners/intel.png)| ![ST Logo](./docs/partners/st.png) |


## How does it work?
67 changes: 38 additions & 29 deletions docs/Quantization.md
@@ -57,18 +57,16 @@ inference. Then, we recompile the network using this profile information to
convert the network into a quantized form, allowing for static optimization of
the quantized graph. We convert portions of the network into islands of integer
computation and aim to generate outputs in the range that the original
floating-point network produces. During the conversion, for the following types
of quantized nodes, we ignore the output's quantization params (if they are
provided) and force the output to have the same quantization params as the input
for performance purposes:
```
-LocalResponseNormalizationNode
-SigmoidNode
-SliceNode
-ReshapeNode
-TanhNode
-TopKNode
-GatherNode
+LocalResponseNormalizationNode
+SliceNode
+ReshapeNode
+TopKNode
+GatherNode
MaxPoolNode
```

@@ -129,9 +127,13 @@ the quantized text translator:

```./bin/text-translator -m en2gr -load-profile=en2gr.yaml -keep-original-precision-for-nodes=Add,Div```

By default, target quantization precision is int8. However, precision can be
controlled via the command-line parameter `quantization-precision`. There are
two supported values: `Int8` and `Int16`.

## Caffe2 Quantized Model Support

Glow is able to support the Caffe2 Resnet50 quantized model:
https://github.com/caffe2/models/tree/master/resnet50_quantized

To support Caffe2 quantized models, Glow has:
@@ -150,16 +152,16 @@ Int8GivenTensorFill
```
- Supported int32 quantized bias.

In most cases, bias is quantized in int32 to improve precision
(the partial sum of the matrix-matrix multiplication is accumulated into int32,
so int32 bias can be added to the int32 partial sum for better accuracy).
Glow now supports int32 quantized bias in ```Convolution```, ```FullyConnected```
and ```RowwiseQuantizedFullyConnected``` nodes.

- Supported the conversion from uint8 quantized activations to int8 quantized activations.

For the quantized Caffe2 ops, the activations are quantized to uint8. In Glow,
the activations are quantized to int8. Therefore, for the offset read from a
quantized Caffe2 model, we need to subtract 128 (i.e. INT8_MIN) to make the
activations become int8.

## Compiler Optimizations
@@ -189,17 +191,24 @@ For more specific graph optimizations check [here](Optimizations.md#quantization

## Row-wise Quantization

Row-wise (or channel-wise) quantization is an important way to minimize accuracy
drop. Glow supports a row-wise quantized FullyConnected node,
```RowwiseQuantizedFullyConnected```, which is enabled by the
image-classifier/loader option "-enable-rowwise".

For the regular quantized FC, we quantize the whole weights tensor with the same
scale and offset, which are computed based on the max and min of the entire
tensor. But for row-wise, after getting ```min_i``` and ```max_i``` for each
row ```i```, we compute the pair of ```(scale_i, offset_i)``` to quantize each
element in row ```i```. The figure below shows the quantized FC node and the
RowwiseQuantizedFullyConnected node. Instead of using only one tensor to
represent the quantized weights, we need 2 extra vectors ```Scales``` and
```Offsets``` to store the ```(scale, offset)``` for each row.

![](rowwise_quantized_fc.png)
Row-wise quantized SparseLengthsWeightedSum is also supported. Similar to the
above, we compute scales and offsets per row, to be used with the `Data` input
for the `RowwiseQuantizedSparseLengthsSumNode`. Scales and Offsets are inputs to
the node. Output of this node is float, matching the Caffe2 implementation.
Binary file added docs/partners/marvell.png
3 changes: 3 additions & 0 deletions include/glow/Backends/Backend.h
@@ -82,6 +82,9 @@ class Backend {
/// \returns true if the Backend wants the buffer sharing optimization
/// performed.
virtual bool shouldShareBuffers() const { return true; }

/// Optimize the Function \p F given compilation mode \p mode.
void optimizeFunction(CompilationMode mode, Function *F);
};

/// Create a backend of kind \p kind.
18 changes: 9 additions & 9 deletions include/glow/Backends/CompiledFunction.h
@@ -38,21 +38,21 @@ class CompiledFunction {
virtual ~CompiledFunction() = default;
 /// Execute the network and allocate Placeholder memory with given
 /// \p ctx providing mapping between Placeholder and populated tensor.
-virtual void execute() = 0;
+virtual void execute(Context *ctx) = 0;

 /// Does any needed initialization work for the Backend.
 /// This includes device init constant memory allocation and copying to
-/// device.
-virtual void setupRuns() = 0;
+/// device. \deprecated
+virtual void setupRuns() { runsSetup_ = true; }

-/// Per run setup. Copy inputs to device.
-virtual void beforeRun(const Context &ctx) = 0;
+/// Per run setup. Copy inputs to device. \deprecated
+virtual void beforeRun(const Context &ctx) {}

-/// Per run cleanup. Copy outputs from device.
-virtual void afterRun(const Context &ctx) = 0;
+/// Per run cleanup. Copy outputs from device. \deprecated
+virtual void afterRun(const Context &ctx) {}

-/// Final cleanup. Release memory, reset device.
-virtual void tearDownRuns() = 0;
+/// Final cleanup. Release memory, reset device. \deprecated
+virtual void tearDownRuns() { runsSetup_ = false; }

/// Getter for the runtimeBundle.
const runtime::RuntimeBundle &getRuntimeBundle() const {
3 changes: 0 additions & 3 deletions include/glow/ExecutionEngine/ExecutionEngine.h
@@ -43,9 +43,6 @@ class ExecutionEngine final {
/// A glow function compiled for this ExecutionEngine's backend.
std::unique_ptr<CompiledFunction> function_;

/// Optimize the Function \p F given compilation mode \p mode.
void optimizeFunction(CompilationMode mode, Function *F);

public:
ExecutionEngine(BackendKind backendKind = BackendKind::Interpreter);

19 changes: 19 additions & 0 deletions include/glow/Graph/Graph.h
@@ -565,6 +565,25 @@ class Function final : public Named {
NodeValue data, NodeValue weights,
NodeValue indices, NodeValue lengths);

/// Create a node, performing SparseLengthsSum operation, using rowwise
/// quantization for the input data. Gathers slices of the outer-most
/// dimension of Data indexed by Indices vector, and then accumulates them
/// into len(Lengths) entries: first Lengths[0] slices are aggregated to
/// Result[0], next Lengths[1] slices are aggregated to Result[1],
/// etc. I.e. sum(Lengths) must be equal to len(Indices).
RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsSum(llvm::StringRef name, Tensor &data,
NodeValue indices, NodeValue lengths);

/// Same as \ref createRowwiseQuantizedSparseLengthsSum(), but i-th slice is
/// multiplied by weights[i]. len(weights) must be equal to len(indices).
RowwiseQuantizedSparseLengthsWeightedSumNode *
createRowwiseQuantizedSparseLengthsWeightedSum(llvm::StringRef name,
Tensor &data,
NodeValue weights,
NodeValue indices,
NodeValue lengths);

/// Given a vector of segment lengths, calculates offsets of each segment and
/// packs them next to the lengths. For the input vector of length N the
/// output is a Nx2 matrix with (offset, lengths) packaged for each segment.
130 changes: 130 additions & 0 deletions include/glow/Partitioner/Partitioner.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/**
* Copyright (c) 2017-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GLOW_PARTITIONER_PARTITIONER_H
#define GLOW_PARTITIONER_PARTITIONER_H

#include "glow/Graph/Graph.h"
#include "glow/Runtime/RuntimeTypes.h"

#include "llvm/ADT/DenseMap.h"

#include <map>
#include <set>
#include <string>

namespace glow {

using namespace runtime;

using MemUsageMap = std::unordered_map<Node *, unsigned>;

/// Helper structure for building a partition. Records a mapping of nodes in the
/// original function to destination partitions, along with a list of the
/// newly-created functions.
class NodeToFunctionMap {
using Map = llvm::DenseMap<Node *, Function *>;

/// Newly-created partitions.
FunctionList functions_;

/// Map of nodes in the original function to their target partition.
Map nodeToFunction_;

public:
/// Create a new partition \p F.
void createPartition(Function *F) { functions_.emplace_back(F); }

/// Add a new Node->Function mapping.
void add(Node *N, Function *F) { nodeToFunction_[N] = F; }

/// Get list of functions contained in this map.
const FunctionList &getPartitions() const { return functions_; }

/// Map API.
Map::iterator find(Node *N) { return nodeToFunction_.find(N); }
Map::iterator begin() { return nodeToFunction_.begin(); }
Map::iterator end() { return nodeToFunction_.end(); }
Function *operator[](Node *n) { return nodeToFunction_[n]; }
};

/// The struct contains all the created DAGNodes. This DAGNodeList owns all the
/// DAGNodes, which cannot outlive the DAGNodeList. In addition, the DAGNodes
/// can only refer to the DAGNodes from the same DAGNodeList, and they can use
/// the raw pointers to refer to each other since they are in the same
/// DAGNodeList.
struct DAGNodeList {
/// The root DAGNode pointer of each graph/function.
std::vector<std::unique_ptr<DAGNode>> roots;
/// The non-root DAGNode pointers.
std::vector<std::unique_ptr<DAGNode>> nodes;
};

/// Given a module, partitions each of its functions into multiple functions
/// based on memory constraints while minimizing the communication cost.
class Partitioner {
/// The module that needs to be decomposed.
Module *module_;

  /// The representative function used for partitioning. We choose the function
  /// that has the largest memory size.
Function *F_;

/// The cost model related to device.
const std::vector<DeviceInfo> &deviceInfo_;

/// The result of module partitioning.
DAGNodeList partitions_;

/// Total memory (bytes) requested by one module.
size_t memSize_;

/// The map of each operator and the corresponding memory size.
MemUsageMap memUsage_;

/// Get the representative function (the one with the largest input) and
/// update the memSize.
static Function *selectRepFunc(Module *parent, size_t &memSize);

  /// Get the minimal memory requirement for each op in the representative
/// function.
void initOpMemUsage();

/// Assign nodes to partitions and return the mapping.
NodeToFunctionMap selectPartitions(Function *F, unsigned availableMemory);

  /// Assign a logical device ID to each DAGNode. It is possible that two
  /// sub-functions need to be assigned to one device due to memory
  /// constraints.
void adjustLogicalDeviceID(DAGNode *DAG, int num);

/// Given the node-function mapping, do the actual partitioning.
void doPartitioning(Function *F, NodeToFunctionMap &mapping);

public:
  /// \p parent is the module which contains the functions that need to be
  /// divided. Here we assume that all the functions in one module belong to the
  /// same "Function Family", that is, without considering the "dynamic stuff" (i.e.
/// batch size, input/output shape of each op), all the functions are
/// identical. The required memory and computation cost for each op can be
/// found in Module. The \p devices provides the cost model related to
/// devices.
Partitioner(Module *parent, const std::vector<DeviceInfo> &devices);

/// Decompose each function in a module and return a list of DAGNodes.
DAGNodeList &Partition();
};
} // namespace glow
#endif // GLOW_PARTITIONER_PARTITIONER_H
9 changes: 5 additions & 4 deletions include/glow/Quantization/Base/Base.h
@@ -128,10 +128,11 @@ chooseQuantizationParams(float min, float max, Schema schema = Asymmetric,
std::vector<int8_t> createMapping(TypeRef inTy, TypeRef outTy,
std::function<float(float)> f);

-/// Row-wise quantize the tensor \p input. The param \p input is a 2D
-/// tensor (i.e. M * N), \p scales and \p offsets are generated by each row of
-/// \p input, \p output is 2D tensor quantized from \p input using \p scales
-/// and \p offsets for each row.
+/// Row-wise quantize the tensor \p input. \p scales and \p offsets are
+/// generated by each row of \p input, \p output is tensor of the same shape as
+/// input, quantized from \p input using \p scales and \p offsets for each
+/// row. Note that the shape of input/output can be any non-zero number of
+/// dimensions; row refers to all data in the first dimension of the shape.
void tensorRowwiseQuantization(const Tensor &input, Tensor &output,
Tensor &scales, Tensor &offsets);

11 changes: 6 additions & 5 deletions include/glow/Quantization/Quantization.h
@@ -53,11 +53,12 @@ struct NodeQuantizationInfo {
namespace quantization {

 /// Generate NodeQuantizationInfo for all required nodes from function \p F
-/// using the method specified by \p schema. Profiling values will be written
-/// into context \p ctx.
-std::vector<NodeQuantizationInfo>
-generateNodeQuantizationInfos(Context &ctx, const Function *F,
-                              Schema schema = Schema::Asymmetric);
+/// using the method specified by \p schema and target quantization
+/// precision \p quantizationPrecision. Profiling values will be written into
+/// context \p ctx.
+std::vector<NodeQuantizationInfo> generateNodeQuantizationInfos(
+    Context &ctx, const Function *F, Schema schema = Schema::Asymmetric,
+    ElemKind quantizationPrecision = ElemKind::Int8QTy);

/// Quantizes the function \p F into a new unoptimized partially quantized
/// function based on \p quantizationInfos and target quantization precision