[Quantization] Support int32 quantized bias for quantized Conv #1876

Merged: 1 commit, Oct 25, 2018
32 changes: 18 additions & 14 deletions include/glow/Quantization/Base/Base.h
@@ -72,28 +72,32 @@ enum Schema {
SymmetricWithUInt8,
};

/// Converts floating point value to int8 based on the quantization
/// parameters \p TQP.
int8_t quantize(float input, const TensorQuantizationParams &TQP);

/// Converts a floating point \p tensor to int8 based on the quantization
/// parameters \p TQP.
Tensor quantizeTensor(const Tensor &tensor,
const TensorQuantizationParams &TQP);

/// Converts int8 quantized value back to floating point number based on
/// the quantization parameters \p TQP.
float dequantize(int8_t input, const TensorQuantizationParams &TQP);

/// \returns the value \p in as clipped to the range of \p DestTy.
template <class SrcTy, class DestTy> DestTy clip(SrcTy in) {
assert(sizeof(SrcTy) >= sizeof(DestTy) && "Invalid types");
static_assert(sizeof(SrcTy) >= sizeof(DestTy), "Invalid types");

auto mx = std::numeric_limits<DestTy>::max();
auto mn = std::numeric_limits<DestTy>::min();
return std::max<SrcTy>(mn, std::min<SrcTy>(mx, in));
}

/// Converts floating point value to DestTy (int8 or int32) based on the
/// quantization parameters \p TQP.
template <class DestTy = int8_t>
inline DestTy quantize(float input, const TensorQuantizationParams &TQP) {
float result = input / TQP.scale + TQP.offset;
return quantization::clip<int32_t, DestTy>((int32_t)nearbyintf(result));
}

/// Converts a floating point \p tensor to int8 or int32 based on the
/// quantization parameters \p TQP and \p Ty.
Tensor quantizeTensor(const Tensor &tensor, const TensorQuantizationParams &TQP,
ElemKind Ty = ElemKind::Int8QTy);

/// Converts int8 quantized value back to floating point number based on
/// the quantization parameters \p TQP.
float dequantize(int8_t input, const TensorQuantizationParams &TQP);

/// Convert the floating point quantization parameters \p scale and \p offset
/// into the integer sequence of:
/// result = ((input >> pre) * scale) >> post + offset.
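For orientation, a minimal sketch of how the updated declarations above might be called. The header include, tensor setup, and numeric values are illustrative assumptions, not taken from this PR:

```cpp
#include "glow/Quantization/Base/Base.h"

using namespace glow;
using namespace glow::quantization;

void quantizeSketch() {
  // Parameters are {scale, offset}, matching the brace-init used elsewhere in the PR.
  TensorQuantizationParams TQP{0.1f, 0};

  // Scalar path: DestTy defaults to int8_t; int32_t skips the int8 clipping range.
  int8_t q8 = quantize(3.14f, TQP);            // round(3.14 / 0.1) == 31
  int32_t q32 = quantize<int32_t>(3.14f, TQP); // same value, full int32 range

  // Tensor path: the new ElemKind argument selects Int8QTy (default) or Int32QTy.
  Tensor biasF(ElemKind::FloatTy, {4}); // assumed float bias tensor
  auto H = biasF.getHandle<float>();
  for (size_t i = 0, e = H.size(); i < e; ++i) {
    H.raw(i) = 0.5f * float(i);
  }
  Tensor biasQ = quantizeTensor(biasF, TQP, ElemKind::Int32QTy);
  (void)q8;
  (void)q32;
  (void)biasQ;
}
```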
13 changes: 11 additions & 2 deletions lib/Backends/CPU/LLVMIRGen.cpp
@@ -928,8 +928,17 @@ void LLVMIRGen::generateLLVMIRForDataParallelInstr(

auto *stackedOpCall =
createCall(builder, F, {loopCount, srcPtr, destScale, destOffset});
auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, loopCount,
"buffer.element.addr");
llvm::Value *destAddr = nullptr;
if (dest->getElementType() == ElemKind::Int8QTy) {
destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, loopCount,
"buffer.element.addr");
} else if (dest->getElementType() == ElemKind::Int32QTy) {
destAddr = builder.CreateGEP(builder.getInt32Ty(), destPtr, loopCount,
"buffer.element.addr");
} else {
GLOW_ASSERT(false && "Type is not supported.");
}

builder.CreateStore(stackedOpCall, destAddr);
break;
}
6 changes: 6 additions & 0 deletions lib/Backends/CPU/libjit/libjit.cpp
@@ -1184,6 +1184,12 @@ int8_t libjit_element_quantize_kernel_i8(size_t idx, const float *inW,
return (int8_t)MAX(INT8_MIN, MIN(INT8_MAX, result));
}

int32_t libjit_element_quantize_kernel_i32(size_t idx, const float *inW,
float scale, int32_t offset) {
int32_t result = (int32_t)nearbyintf(inW[idx] / scale + offset);
return result;
}

float libjit_element_dequantize_kernel_f(size_t idx, const int8_t *inW,
float scale, int32_t offset) {
return scale * (inW[idx] - offset);
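For comparison, a standalone scalar sketch of what the two element-wise kernels compute, repackaged from the libjit code above for illustration: the i8 kernel clamps to the int8 range, while the new i32 kernel can return the rounded value directly because int32 already covers it.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Same math as libjit_element_quantize_kernel_i8: quantize, then clamp to int8.
int8_t quantize_i8(float v, float scale, int32_t offset) {
  int32_t r = (int32_t)std::nearbyintf(v / scale + offset);
  return (int8_t)std::max<int32_t>(INT8_MIN, std::min<int32_t>(INT8_MAX, r));
}

// Same math as the new libjit_element_quantize_kernel_i32: no clamp is needed.
int32_t quantize_i32(float v, float scale, int32_t offset) {
  return (int32_t)std::nearbyintf(v / scale + offset);
}
```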
14 changes: 7 additions & 7 deletions lib/Backends/CPU/libjit/libjit_conv.cpp
@@ -418,13 +418,13 @@ void libjit_convolution_f(float *outW, const float *inW, const float *filterW,
}

void libjit_convolution_i8(
int8_t *outW, const int8_t *inW, const int8_t *filterW, const int8_t *biasW,
const size_t *outWdims, const size_t *inWdims, const size_t *filterWdims,
const size_t *biasWdims, const size_t *kernelSizes, const size_t *strides,
const size_t *pads, size_t group, int32_t outOffset, int32_t inOffset,
int32_t filterOffset, int32_t biasOffset, int32_t biasPre, int32_t biasPost,
int32_t biasScale, int32_t outPre, int32_t outPost, int32_t outScale,
unsigned depthUnroll) {
int8_t *outW, const int8_t *inW, const int8_t *filterW,
const int32_t *biasW, const size_t *outWdims, const size_t *inWdims,
const size_t *filterWdims, const size_t *biasWdims,
const size_t *kernelSizes, const size_t *strides, const size_t *pads,
size_t group, int32_t outOffset, int32_t inOffset, int32_t filterOffset,
int32_t biasOffset, int32_t biasPre, int32_t biasPost, int32_t biasScale,
int32_t outPre, int32_t outPost, int32_t outScale, unsigned depthUnroll) {
size_t inChannels = inWdims[3];
size_t outChannels = outWdims[3];
size_t inCperG = inChannels / group;
21 changes: 17 additions & 4 deletions lib/Backends/Interpreter/InterpreterNodes.cpp
@@ -103,14 +103,15 @@ void InterpreterFunction::fwdConvolutionInst_FloatImpl(
}

// This is the quantized i8 implementation of Convolution.
// For bias, we support int32 quantization.
void InterpreterFunction::fwdConvolutionInst_I8Impl(
Value *inV, Value *outV, Value *filterV, Value *biasV,
llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
llvm::ArrayRef<unsigned_t> pads, size_t group) {
auto inW = getWeightHandle<int8_t>(inV);
auto outW = getWeightHandle<int8_t>(outV);
auto filterW = getWeightHandle<int8_t>(filterV);
auto biasW = getWeightHandle<int8_t>(biasV);
auto biasW = getWeightHandle<int32_t>(biasV);

ShapeNHWC odim(outW.dims());
ShapeNHWC idim(inW.dims());
@@ -1932,7 +1933,6 @@ void InterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) {
//===----------------------------------------------------------------------===//
// Instructions used by Quantization
//===----------------------------------------------------------------------===//

void InterpreterFunction::fwdQuantizationProfileInst(
const glow::QuantizationProfileInst *I) {
auto inputTensor = getWeightHandle(I->getInputTensor());
@@ -1946,7 +1946,7 @@ void InterpreterFunction::fwdQuantizationProfileInst(
quantization::generateTensorHistogram(inputTensor, currentHistogram, min,
max);
}

/*
template <typename ElemTy>
static void fwdQuantize(Handle<ElemTy> srcHandle, Tensor *destTensor) {
TensorQuantizationParams params{destTensor->getType().getScale(),
@@ -1956,8 +1956,21 @@ static void fwdQuantize(Handle<ElemTy> srcHandle, Tensor *destTensor) {
for (size_t i = 0, e = destHandle.size(); i < e; ++i) {
destHandle.raw(i) = quantization::quantize(srcHandle.raw(i), params);
}
}*/

/// Quantize floating point tensor. Scale and Offset are based on return type
/// of the instruction \p I.
void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
auto *srcTensor = getTensor(I->getSrc());
auto *destTensor = getTensor(I->getDest());
auto destTy = destTensor->getType();
Tensor qTensor = quantization::quantizeTensor(
*srcTensor, {destTy.getScale(), destTy.getOffset()},
destTy.getElementType());
destTensor->assign(&qTensor);
}

/*
/// Quantize floating point tensor. Scale and Offset are based on return type
/// of the instruction \p I.
void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
@@ -1976,7 +1989,7 @@ void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
default:
llvm_unreachable("Type not supported");
}
}
}*/

template <typename ElemTy>
static void fwdDequantize(Tensor *srcTensor, Handle<ElemTy> destHandle) {
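The only change in fwdConvolutionInst_I8Impl above is the bias handle type, but it is the piece that makes the rest of the math line up. A rough, self-contained sketch of the kind of accumulation such an implementation performs; the names and structure are illustrative, not Glow's actual inner loop:

```cpp
#include <cstdint>

// Hypothetical accumulation for one output element of a quantized conv.
int32_t accumulateOneOutput(const int8_t *in, const int8_t *filter, int n,
                            int32_t inOffset, int32_t filterOffset,
                            int32_t biasQ, int32_t biasOffset) {
  int32_t sum = 0;
  for (int i = 0; i < n; ++i) {
    // Accumulate in int32 so the int8 products cannot overflow.
    sum += (int32_t(in[i]) - inOffset) * (int32_t(filter[i]) - filterOffset);
  }
  // The bias is already int32; because its scale is chosen as
  // inputScale * filterScale (see lib/Quantization/Quantization.cpp below),
  // it lives on the same scale as the accumulator and can be added directly.
  sum += biasQ - biasOffset;
  return sum; // later rescaled to the output scale/offset and clipped to int8
}
```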
19 changes: 18 additions & 1 deletion lib/Backends/OpenCL/kernels.cl
@@ -129,6 +129,12 @@ cl_int8_t quantize(float input, float scale, cl_int32_t offset) {
return clip((cl_int32_t)round(result));
}

/// Quantizes \p input from float to int32.
cl_int32_t quantize_i32(float input, float scale, cl_int32_t offset) {
float result = input / scale + offset;
return (cl_int32_t)round(result);
}

/// Dequantizes \p input from int8 to float.
float dequantize(cl_int8_t input, float scale, cl_int32_t offset) {
return scale * (input - offset);
@@ -140,11 +146,22 @@ __kernel void quantize_i8K(__global cl_int8_t *dest, __global float *src,
dest[i] = quantize(src[i], scale, offset);
}

__kernel void quantize_i32K(__global cl_int32_t *dest, __global float *src,
float scale, cl_int32_t offset) {
size_t i = get_global_id(0);
dest[i] = quantize_i32(src[i], scale, offset);
}

__kernel void quantize_i8W(__global void *mem, cl_uint32_t dest,
cl_uint32_t src, float scale, cl_int32_t offset) {
quantize_i8K(&mem[dest], &mem[src], scale, offset);
}

__kernel void quantize_i32W(__global void *mem, cl_uint32_t dest,
cl_uint32_t src, float scale, cl_int32_t offset) {
quantize_i32K(&mem[dest], &mem[src], scale, offset);
}

__kernel void rescalequantized_i8K(__global cl_int8_t *dest,
__global cl_int8_t *src,
cl_int32_t destOffset, cl_int32_t srcOffset,
@@ -834,7 +851,7 @@ __kernel void convolutionW(__global void *mem, cl_uint32_t dest,

__kernel void convolution_i8K(
__global cl_int8_t *dest, __global cl_int8_t *src,
__global cl_int8_t *filter, __global cl_int8_t *bias, ShapeHW kernelSizes,
__global cl_int8_t *filter, __global cl_int32_t *bias, ShapeHW kernelSizes,
ShapeHW strides, cl_int32_t destOffset, float destScale,
cl_int32_t srcOffset, float srcScale, cl_int32_t filterOffset,
float filterScale, cl_int32_t biasOffset, float biasScale, PaddingTLBR pads,
4 changes: 2 additions & 2 deletions lib/Backends/OpenCL/kernels_fwd_quantized_conv.cl
@@ -136,7 +136,7 @@ __kernel
float c_scale, int d_offset, float d_scale) {
__global const Dtype *im_in = &mem[im_in_offset];
__global const Dtype *wg = &mem[wg_offset];
__global const Dtype *bias = &mem[bias_offset];
__global const int *bias = &mem[bias_offset];
__global Dtype *im_out = &mem[im_out_offset];
// Thread identifiers.
// Local row ID (max: RTSM=TSM/WPTM).
@@ -156,7 +156,7 @@ __kernel
__global const Dtype *Aptr = wg;
__global const Dtype *Bptr = im_in + v_B_off * batch;
__global Dtype *Cptr = im_out + v_C_off * batch;
__global const Dtype *Dptr = bias;
__global const int *Dptr = bias;
// Initialize the accumulation registers.
{
int4 Creg[WPTM][WPTN / VWN];
3 changes: 2 additions & 1 deletion lib/Graph/Graph.cpp
@@ -1423,7 +1423,8 @@ BatchOneHotNode *Function::createBatchOneHot(llvm::StringRef name,
QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input,
TypeRef outTy) {
assert(input.getType()->isFPType() && "Input must be a floating type");
assert(outTy->getElementType() == ElemKind::Int8QTy &&
assert((outTy->getElementType() == ElemKind::Int8QTy ||
outTy->getElementType() == ElemKind::Int32QTy) &&
"Output must be a quantized type");
assert(input.dims().equals(outTy->dims()) &&
"Different dimensions for input and output");
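With the relaxed assert, a caller can now build a quantize node whose result type is Int32QTy, e.g. for a conv bias. A hypothetical sketch; the helper name, header path, and the offset of 0 are assumptions for illustration:

```cpp
#include "glow/Graph/Graph.h"

using namespace glow;

// Hypothetical helper: quantize a float conv-bias NodeValue to Int32QTy, using
// the bias-scale convention introduced in this PR (inputScale * filterScale).
QuantizeNode *quantizeConvBias(Function *F, Module &mod, NodeValue bias,
                               float inputScale, float filterScale) {
  TypeRef biasQTy = mod.uniqueType(ElemKind::Int32QTy, bias.dims(),
                                   inputScale * filterScale, /*offset=*/0);
  return F->createQuantize("bias_q", bias, biasQTy);
}
```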
14 changes: 11 additions & 3 deletions lib/Graph/Nodes.cpp
@@ -144,8 +144,14 @@ static void verifyConvolution(NodeValue src, NodeValue dest, NodeValue filter,
unsigned_t group) {
assert(src.getElementType() == dest.getElementType() && "Invalid Type");
assert(src.getElementType() == filter.getElementType() && "Invalid Type");
assert(src.getElementType() == bias.getElementType() && "Invalid Type");

// Non-quantized type check.
if (src.getElementType() == ElemKind::FloatTy) {
assert(bias.getElementType() == ElemKind::FloatTy && "Invalid Type");
}
// Quantization type check.
if (src.getElementType() == ElemKind::Int8QTy) {
assert(bias.getElementType() == ElemKind::Int32QTy && "Invalid Type");
}
ShapeNHWC idim(src.getType()->dims());
ShapeNHWC odim(dest.getType()->dims());
PaddingTLBR pdim(pads);
@@ -643,7 +649,9 @@ void IntLookupTableNode::verify() const {

void QuantizeNode::verify() const {
// Dest must be quantized.
checkType(getResult(), ElemKind::Int8QTy);
assert((getResult().getElementType() == ElemKind::Int8QTy ||
getResult().getElementType() == ElemKind::Int32QTy) &&
"Invalid type");
// Src must be an FP type.
assert(getInput().getType()->isFPType() && "Invalid type");
checkSameShape(getResult(), getInput());
9 changes: 9 additions & 0 deletions lib/IR/Instrs.cpp
@@ -90,3 +90,12 @@ void InsertTensorInst::verify() const {
assert(getAxis() >= 0 && getAxis() < getDest()->dims().size() &&
"Axis must fit inside Dest dims.");
}

void QuantizeInst::verify() const {
assert((getDest()->getElementType() == ElemKind::Int8QTy ||
getDest()->getElementType() == ElemKind::Int32QTy) &&
"Invalid type");
assert((getSrc()->getElementType() == ElemKind::FloatTy ||
getSrc()->getElementType() == ElemKind::Float16Ty) &&
"Invalid type");
assert(getSrc()->dims() == getDest()->dims() && "Invalid shape");
}
7 changes: 4 additions & 3 deletions lib/Optimizer/GraphOptimizer.cpp
@@ -1639,8 +1639,9 @@ static NodeValue convertConstant(Module &mod, Constant &constant,
constantToBeModified.getPayload().convertToType(dstTy->getElementType());
return NodeValue(&constantToBeModified, 0);
}
case ElemKind::Int32QTy:
case ElemKind::Int8QTy: {
// Quantization: {FloatTy, Float16Ty} -> Int8QTy.
// Quantization: {FloatTy, Float16Ty} -> Int8QTy or Int32QTy.
Constant &constantToBeModified = modifyConstantTyAndGet();
TensorQuantizationParams params{dstTy->getScale(), dstTy->getOffset()};
Tensor &tensorToBeModified = constantToBeModified.getPayload();
@@ -1650,8 +1651,8 @@
// teach quantizeTensor how to deal with Float16Ty.
assert(tensor.getType().isFPType() &&
"Type quantization not implemented");
tensorToBeModified =
quantization::quantizeTensor(tensorToBeModified, params);
tensorToBeModified = quantization::quantizeTensor(
tensorToBeModified, params, dstTy->getElementType());
return NodeValue(&constantToBeModified, 0);
}
default:
44 changes: 25 additions & 19 deletions lib/Quantization/Base/Base.cpp
@@ -22,36 +22,42 @@
namespace glow {
namespace quantization {

int8_t quantize(float input, const TensorQuantizationParams &TQP) {
float result = input / TQP.scale + TQP.offset;
return quantization::clip<int32_t, int8_t>((int32_t)nearbyintf(result));
}

Tensor quantizeTensor(const Tensor &tensor,
const TensorQuantizationParams &TQP) {
Tensor tmp(ElemKind::Int8QTy, tensor.dims(), TQP.scale, TQP.offset);
auto destHandle = tmp.getHandle<int8_t>();
assert(tensor.getType().isFPType() && "Type not supported yet");
switch (tensor.getElementType()) {
template <class eTy = int8_t>
static void quantizeTensorUtil(Tensor *dest, const Tensor &src) {
auto destH = dest->getHandle<eTy>();
TensorQuantizationParams TQP{dest->getType().getScale(),
dest->getType().getOffset()};
switch (src.getElementType()) {
case ElemKind::FloatTy: {
auto srcHandle = tensor.getHandle<float>();
for (size_t i = 0, e = destHandle.size(); i < e; ++i) {
destHandle.raw(i) =
quantization::quantize(static_cast<float>(srcHandle.raw(i)), TQP);
auto srcHandle = src.getHandle<float>();
for (size_t i = 0, e = destH.size(); i < e; ++i) {
destH.raw(i) = quantization::quantize<eTy>(
static_cast<float>(srcHandle.raw(i)), TQP);
}
break;
}
case ElemKind::Float16Ty: {
auto srcHandle = tensor.getHandle<float16>();
for (size_t i = 0, e = destHandle.size(); i < e; ++i) {
destHandle.raw(i) =
quantization::quantize(static_cast<float>(srcHandle.raw(i)), TQP);
auto srcHandle = src.getHandle<float16>();
for (size_t i = 0, e = destH.size(); i < e; ++i) {
destH.raw(i) = quantization::quantize<eTy>(
static_cast<float>(srcHandle.raw(i)), TQP);
}
break;
}
default:
llvm_unreachable("Cannot quantize a type");
}
}

Tensor quantizeTensor(const Tensor &tensor, const TensorQuantizationParams &TQP,
ElemKind Ty) {
Tensor tmp(Ty, tensor.dims(), TQP.scale, TQP.offset);
assert(tensor.getType().isFPType() && "Type not supported yet");
if (Ty == ElemKind::Int8QTy) {
quantizeTensorUtil<int8_t>(&tmp, tensor);
} else {
quantizeTensorUtil<int32_t>(&tmp, tensor);
}
return tmp;
}

22 changes: 19 additions & 3 deletions lib/Quantization/Quantization.cpp
@@ -67,8 +67,22 @@ class FunctionQuantizer : public FunctionConverter {
"Missing quantization params for a node");

const TensorQuantizationParams &TQP = valTQPIt->second;
return mod_.uniqueType(ElemKind::Int8QTy, val.dims(), TQP.scale,
TQP.offset);
// The bias of a conv op is quantized to int32.
if (use.getKind() == glow::Kinded::Kind::ConvolutionNodeKind && idx == 2) {
// Its scale must be (scale of input) * (scale of weights).
auto convN = llvm::dyn_cast<ConvolutionNode>(&use);
NodeValue input = convN->getInput();
NodeValue weights = convN->getFilter();
float scaleInput = input.getNode()->getNthResult(0).getType()->getScale();
float scaleWeights =
weights.getNode()->getNthResult(0).getType()->getScale();
return mod_.uniqueType(ElemKind::Int32QTy, val.dims(),
scaleInput * scaleWeights, TQP.offset);
} else {
return mod_.uniqueType(ElemKind::Int8QTy, val.dims(), TQP.scale,
TQP.offset);
}
}

/// \see FunctionConverter::canConvert.
@@ -122,7 +136,9 @@
Node *createConversion(Function &function, NodeValue &val,
TypeRef destTy) override {
if (destTy->isQuantizedType()) {
assert(destTy->getElementType() == ElemKind::Int8QTy && "");
assert((destTy->getElementType() == ElemKind::Int8QTy ||
destTy->getElementType() == ElemKind::Int32QTy) &&
"We only support int8_t and int32_t quantization now");
return function_.createQuantize("quantize", val, destTy);
}
assert(destTy->getElementType() == ElemKind::FloatTy && "");
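A small worked example of the parameters this chooses, with invented numbers: if the conv input has scale 0.02 and the filter has scale 0.5, the int32 accumulator of input_q * filter_q products represents real values at scale 0.02 * 0.5 = 0.01, so the bias is quantized with that same scale.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative values only, not taken from the PR.
  float inputScale = 0.02f, filterScale = 0.5f;
  float biasScale = inputScale * filterScale; // 0.01, matches the accumulator's scale
  int32_t biasOffset = 0;                     // kept at 0 to keep the arithmetic simple
  float biasFloat = 1.234f;
  int32_t biasQ = (int32_t)std::nearbyintf(biasFloat / biasScale + biasOffset); // 123
  std::printf("bias %.3f -> %d (scale %.4f)\n", biasFloat, biasQ, biasScale);
  return 0;
}
```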