diff --git a/include/glow/Quantization/Base/Base.h b/include/glow/Quantization/Base/Base.h
index 7e9e3d18b0..bf8fcaaa9a 100644
--- a/include/glow/Quantization/Base/Base.h
+++ b/include/glow/Quantization/Base/Base.h
@@ -72,28 +72,32 @@ enum Schema {
   SymmetricWithUInt8,
 };
 
-/// Converts floating point value to int8 based on the quantization
-/// parameters \p TQP.
-int8_t quantize(float input, const TensorQuantizationParams &TQP);
-
-/// Converts a floating point \p tensor to int8 based on the quantization
-/// parameters \p TQP.
-Tensor quantizeTensor(const Tensor &tensor,
-                      const TensorQuantizationParams &TQP);
-
-/// Converts int8 quantized value back to floating point number based on
-/// the quantization parameters \p TQP.
-float dequantize(int8_t input, const TensorQuantizationParams &TQP);
-
 /// \returns the value \p in as clipped to the range of \p DestTy.
 template <class SrcTy, class DestTy> DestTy clip(SrcTy in) {
-  assert(sizeof(SrcTy) >= sizeof(DestTy) && "Invalid types");
+  static_assert(sizeof(SrcTy) >= sizeof(DestTy), "Invalid types");
   auto mx = std::numeric_limits<DestTy>::max();
   auto mn = std::numeric_limits<DestTy>::min();
   return std::max<SrcTy>(mn, std::min<SrcTy>(mx, in));
 }
 
+/// Converts floating point value to DestTy (int8 or int32) based on the
+/// quantization parameters \p TQP.
+template <class DestTy = int8_t>
+inline DestTy quantize(float input, const TensorQuantizationParams &TQP) {
+  float result = input / TQP.scale + TQP.offset;
+  return quantization::clip<int32_t, DestTy>((int32_t)nearbyintf(result));
+}
+
+/// Converts a floating point \p tensor to int8 or int32 based on the
+/// quantization parameters \p TQP and \p Ty.
+Tensor quantizeTensor(const Tensor &tensor, const TensorQuantizationParams &TQP,
+                      ElemKind Ty = ElemKind::Int8QTy);
+
+/// Converts int8 quantized value back to floating point number based on
+/// the quantization parameters \p TQP.
+float dequantize(int8_t input, const TensorQuantizationParams &TQP);
+
 /// Convert the floating point quantization parameters \p scale and \p offset
 /// into the integer sequence of:
 /// result = ((input >> pre) * scale) >> post + offset.
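The templated quantize() above is the core of this change: the same formula now targets either int8 or int32, and only the clipping range differs. A minimal standalone sketch (not part of the patch; the names and values are made up) of the mapping it implements:

// Sketch of q = clip<int32_t, DestTy>(round(x / scale + offset)).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

struct Params {
  float scale;
  int32_t offset;
};

template <class DestTy = int8_t> DestTy quantizeSketch(float x, Params p) {
  int32_t q = (int32_t)nearbyintf(x / p.scale + p.offset);
  int32_t mx = std::numeric_limits<DestTy>::max();
  int32_t mn = std::numeric_limits<DestTy>::min();
  return (DestTy)std::max(mn, std::min(mx, q));
}

int main() {
  Params p{0.1f, 0};
  // 100.0 / 0.1 = 1000: saturates to 127 as int8, stays 1000 as int32.
  printf("%d\n", (int)quantizeSketch<int8_t>(100.0f, p));  // 127
  printf("%d\n", (int)quantizeSketch<int32_t>(100.0f, p)); // 1000
  return 0;
}

With a scale of 0.1, the float 100.0 maps to 1000, which saturates as int8 but is representable as int32; this headroom is what the int32 convolution bias below relies on.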
diff --git a/lib/Backends/CPU/LLVMIRGen.cpp b/lib/Backends/CPU/LLVMIRGen.cpp
index 2225de1adb..2ed5c1ace6 100644
--- a/lib/Backends/CPU/LLVMIRGen.cpp
+++ b/lib/Backends/CPU/LLVMIRGen.cpp
@@ -928,8 +928,17 @@ void LLVMIRGen::generateLLVMIRForDataParallelInstr(
 
     auto *stackedOpCall =
         createCall(builder, F, {loopCount, srcPtr, destScale, destOffset});
-    auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, loopCount,
-                                       "buffer.element.addr");
+    llvm::Value *destAddr = nullptr;
+    if (dest->getElementType() == ElemKind::Int8QTy) {
+      destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, loopCount,
+                                   "buffer.element.addr");
+    } else if (dest->getElementType() == ElemKind::Int32QTy) {
+      destAddr = builder.CreateGEP(builder.getInt32Ty(), destPtr, loopCount,
+                                   "buffer.element.addr");
+    } else {
+      GLOW_ASSERT(false && "Type is not supported.");
+    }
+
     builder.CreateStore(stackedOpCall, destAddr);
     break;
   }
diff --git a/lib/Backends/CPU/libjit/libjit.cpp b/lib/Backends/CPU/libjit/libjit.cpp
index 9e1a8e9bb6..2fd54c6278 100644
--- a/lib/Backends/CPU/libjit/libjit.cpp
+++ b/lib/Backends/CPU/libjit/libjit.cpp
@@ -1184,6 +1184,12 @@ int8_t libjit_element_quantize_kernel_i8(size_t idx, const float *inW,
   return (int8_t)MAX(INT8_MIN, MIN(INT8_MAX, result));
 }
 
+int32_t libjit_element_quantize_kernel_i32(size_t idx, const float *inW,
+                                           float scale, int32_t offset) {
+  int32_t result = (int32_t)nearbyintf(inW[idx] / scale + offset);
+  return result;
+}
+
 float libjit_element_dequantize_kernel_f(size_t idx, const int8_t *inW,
                                          float scale, int32_t offset) {
   return scale * (inW[idx] - offset);
diff --git a/lib/Backends/CPU/libjit/libjit_conv.cpp b/lib/Backends/CPU/libjit/libjit_conv.cpp
index e9e9e5b97d..2fc76c0de9 100644
--- a/lib/Backends/CPU/libjit/libjit_conv.cpp
+++ b/lib/Backends/CPU/libjit/libjit_conv.cpp
@@ -418,13 +418,13 @@ void libjit_convolution_f(float *outW, const float *inW, const float *filterW,
 }
 
 void libjit_convolution_i8(
-    int8_t *outW, const int8_t *inW, const int8_t *filterW, const int8_t *biasW,
-    const size_t *outWdims, const size_t *inWdims, const size_t *filterWdims,
-    const size_t *biasWdims, const size_t *kernelSizes, const size_t *strides,
-    const size_t *pads, size_t group, int32_t outOffset, int32_t inOffset,
-    int32_t filterOffset, int32_t biasOffset, int32_t biasPre, int32_t biasPost,
-    int32_t biasScale, int32_t outPre, int32_t outPost, int32_t outScale,
-    unsigned depthUnroll) {
+    int8_t *outW, const int8_t *inW, const int8_t *filterW,
+    const int32_t *biasW, const size_t *outWdims, const size_t *inWdims,
+    const size_t *filterWdims, const size_t *biasWdims,
+    const size_t *kernelSizes, const size_t *strides, const size_t *pads,
+    size_t group, int32_t outOffset, int32_t inOffset, int32_t filterOffset,
+    int32_t biasOffset, int32_t biasPre, int32_t biasPost, int32_t biasScale,
+    int32_t outPre, int32_t outPost, int32_t outScale, unsigned depthUnroll) {
   size_t inChannels = inWdims[3];
   size_t outChannels = outWdims[3];
   size_t inCperG = inChannels / group;
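libjit_convolution_i8 keeps receiving biasPre/biasPost/biasScale, which follow the integer rescaling scheme documented in Base.h: result = ((input >> pre) * scale) >> post + offset. A small sketch of that formula with made-up values (the helper name is not from the patch, and how the kernel applies biasOffset is not shown):

#include <cstdint>
#include <cstdio>

// Explicitly parenthesized form of ((input >> pre) * scale) >> post + offset.
static int32_t rescaleSketch(int32_t input, int32_t pre, int32_t post,
                             int32_t scale, int32_t offset) {
  return (((input >> pre) * scale) >> post) + offset;
}

int main() {
  // An int32 value of 8000 with pre=0, scale=3, post=8, offset=0
  // becomes (8000 * 3) >> 8 = 93.
  printf("%d\n", rescaleSketch(8000, 0, 8, 3, 0));
  return 0;
}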
diff --git a/lib/Backends/Interpreter/InterpreterNodes.cpp b/lib/Backends/Interpreter/InterpreterNodes.cpp
index cb08753492..a306c25930 100644
--- a/lib/Backends/Interpreter/InterpreterNodes.cpp
+++ b/lib/Backends/Interpreter/InterpreterNodes.cpp
@@ -103,6 +103,7 @@ void InterpreterFunction::fwdConvolutionInst_FloatImpl(
 }
 
 // This is the quantized i8 implementation of Convolution.
+// For bias, we support int32 quantization.
 void InterpreterFunction::fwdConvolutionInst_I8Impl(
     Value *inV, Value *outV, Value *filterV, Value *biasV,
     llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
@@ -110,7 +111,7 @@ void InterpreterFunction::fwdConvolutionInst_I8Impl(
   auto inW = getWeightHandle<int8_t>(inV);
   auto outW = getWeightHandle<int8_t>(outV);
   auto filterW = getWeightHandle<int8_t>(filterV);
-  auto biasW = getWeightHandle<int8_t>(biasV);
+  auto biasW = getWeightHandle<int32_t>(biasV);
 
   ShapeNHWC odim(outW.dims());
   ShapeNHWC idim(inW.dims());
@@ -1932,7 +1933,6 @@ void InterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) {
 //===----------------------------------------------------------------------===//
 //                    Instructions used by Quantization
 //===----------------------------------------------------------------------===//
-
 void InterpreterFunction::fwdQuantizationProfileInst(
     const glow::QuantizationProfileInst *I) {
   auto inputTensor = getWeightHandle(I->getInputTensor());
@@ -1946,7 +1946,7 @@
   quantization::generateTensorHistogram(inputTensor, currentHistogram, min,
                                         max);
 }
-
+/*
 template <class T>
 static void fwdQuantize(Handle<float> srcHandle, Tensor *destTensor) {
   TensorQuantizationParams params{destTensor->getType().getScale(),
@@ -1956,8 +1956,21 @@ static void fwdQuantize(Handle<float> srcHandle, Tensor *destTensor) {
   for (size_t i = 0, e = destHandle.size(); i < e; ++i) {
     destHandle.raw(i) = quantization::quantize(srcHandle.raw(i), params);
   }
+  }*/
+
+/// Quantize floating point tensor. Scale and Offset are based on return type
+/// of the instruction \p I.
+void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
+  auto *srcTensor = getTensor(I->getSrc());
+  auto *destTensor = getTensor(I->getDest());
+  auto destTy = destTensor->getType();
+  Tensor qTensor = quantization::quantizeTensor(
+      *srcTensor, {destTy.getScale(), destTy.getOffset()},
+      destTy.getElementType());
+  destTensor->assign(&qTensor);
 }
+/*
 
 /// Quantize floating point tensor. Scale and Offset are based on return type
 /// of the instruction \p I.
 void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
@@ -1976,7 +1989,7 @@ void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
   default:
     llvm_unreachable("Type not supported");
   }
-}
+  }*/
 
 template <class T>
 static void fwdDequantize(Tensor *srcTensor, Handle<float> destHandle) {
diff --git a/lib/Backends/OpenCL/kernels.cl b/lib/Backends/OpenCL/kernels.cl
index 6ae4fe4a78..4ca32d81a7 100644
--- a/lib/Backends/OpenCL/kernels.cl
+++ b/lib/Backends/OpenCL/kernels.cl
@@ -129,6 +129,12 @@ cl_int8_t quantize(float input, float scale, cl_int32_t offset) {
   return clip((cl_int32_t)round(result));
 }
 
+/// Quantizes \p input from float to int32.
+cl_int32_t quantize_i32(float input, float scale, cl_int32_t offset) {
+  float result = input / scale + offset;
+  return (cl_int32_t)round(result);
+}
+
 /// Dequantizes \p input from int8 to float.
 float dequantize(cl_int8_t input, float scale, cl_int32_t offset) {
   return scale * (input - offset);
@@ -140,11 +146,22 @@ __kernel void quantize_i8K(__global cl_int8_t *dest, __global float *src,
   dest[i] = quantize(src[i], scale, offset);
 }
 
+__kernel void quantize_i32K(__global cl_int32_t *dest, __global float *src,
+                            float scale, cl_int32_t offset) {
+  size_t i = get_global_id(0);
+  dest[i] = quantize_i32(src[i], scale, offset);
+}
+
 __kernel void quantize_i8W(__global void *mem, cl_uint32_t dest,
                            cl_uint32_t src, float scale, cl_int32_t offset) {
   quantize_i8K(&mem[dest], &mem[src], scale, offset);
 }
 
+__kernel void quantize_i32W(__global void *mem, cl_uint32_t dest,
+                            cl_uint32_t src, float scale, cl_int32_t offset) {
+  quantize_i32K(&mem[dest], &mem[src], scale, offset);
+}
+
 __kernel void rescalequantized_i8K(__global cl_int8_t *dest,
                                    __global cl_int8_t *src,
                                    cl_int32_t destOffset, cl_int32_t srcOffset,
@@ -834,7 +851,7 @@ __kernel void convolutionW(__global void *mem, cl_uint32_t dest,
 
 __kernel void convolution_i8K(
     __global cl_int8_t *dest, __global cl_int8_t *src,
-    __global cl_int8_t *filter, __global cl_int8_t *bias, ShapeHW kernelSizes,
+    __global cl_int8_t *filter, __global cl_int32_t *bias, ShapeHW kernelSizes,
     ShapeHW strides, cl_int32_t destOffset, float destScale,
     cl_int32_t srcOffset, float srcScale, cl_int32_t filterOffset,
     float filterScale, cl_int32_t biasOffset, float biasScale, PaddingTLBR pads,
diff --git a/lib/Backends/OpenCL/kernels_fwd_quantized_conv.cl b/lib/Backends/OpenCL/kernels_fwd_quantized_conv.cl
index 3d6645dbfa..ea780ae06a 100644
--- a/lib/Backends/OpenCL/kernels_fwd_quantized_conv.cl
+++ b/lib/Backends/OpenCL/kernels_fwd_quantized_conv.cl
@@ -136,7 +136,7 @@ __kernel
     float c_scale, int d_offset, float d_scale) {
   __global const Dtype *im_in = &mem[im_in_offset];
   __global const Dtype *wg = &mem[wg_offset];
-  __global const Dtype *bias = &mem[bias_offset];
+  __global const int *bias = &mem[bias_offset];
   __global Dtype *im_out = &mem[im_out_offset];
   // Thread identifiers.
   // Local row ID (max: RTSM=TSM/WPTM).
@@ -156,7 +156,7 @@ __kernel
   __global const Dtype *Aptr = wg;
   __global const Dtype *Bptr = im_in + v_B_off * batch;
   __global Dtype *Cptr = im_out + v_C_off * batch;
-  __global const Dtype *Dptr = bias;
+  __global const int *Dptr = bias;
 
   // Initialize the accumulation registers.
   {
     int4 Creg[WPTM][WPTN / VWN];
diff --git a/lib/Graph/Graph.cpp b/lib/Graph/Graph.cpp
index 0785198616..f0ddc447ef 100644
--- a/lib/Graph/Graph.cpp
+++ b/lib/Graph/Graph.cpp
@@ -1423,7 +1423,8 @@ BatchOneHotNode *Function::createBatchOneHot(llvm::StringRef name,
 QuantizeNode *Function::createQuantize(llvm::StringRef name, NodeValue input,
                                        TypeRef outTy) {
   assert(input.getType()->isFPType() && "Input must be a floating type");
-  assert(outTy->getElementType() == ElemKind::Int8QTy &&
+  assert((outTy->getElementType() == ElemKind::Int8QTy ||
+          outTy->getElementType() == ElemKind::Int32QTy) &&
          "Output must be a quantized type");
   assert(input.dims().equals(outTy->dims()) &&
          "Different dimensions for input and output");
diff --git a/lib/Graph/Nodes.cpp b/lib/Graph/Nodes.cpp
index 05b9fc2141..7f0a37473e 100644
--- a/lib/Graph/Nodes.cpp
+++ b/lib/Graph/Nodes.cpp
@@ -144,8 +144,14 @@ static void verifyConvolution(NodeValue src, NodeValue dest, NodeValue filter,
                               unsigned_t group) {
   assert(src.getElementType() == dest.getElementType() && "Invalid Type");
   assert(src.getElementType() == filter.getElementType() && "Invalid Type");
-  assert(src.getElementType() == bias.getElementType() && "Invalid Type");
-
+  // Non quantization type check.
+  if (src.getElementType() == ElemKind::FloatTy) {
+    assert(bias.getElementType() == ElemKind::FloatTy && "Invalid Type");
+  }
+  // Quantization type check.
+  if (src.getElementType() == ElemKind::Int8QTy) {
+    assert(bias.getElementType() == ElemKind::Int32QTy && "Invalid Type");
+  }
   ShapeNHWC idim(src.getType()->dims());
   ShapeNHWC odim(dest.getType()->dims());
   PaddingTLBR pdim(pads);
@@ -643,7 +649,9 @@ void IntLookupTableNode::verify() const {
 
 void QuantizeNode::verify() const {
   // Dest must be quantized.
-  checkType(getResult(), ElemKind::Int8QTy);
+  assert((getResult().getElementType() == ElemKind::Int8QTy ||
+          getResult().getElementType() == ElemKind::Int32QTy) &&
+         "Invalid type");
   // Src must be an FP type.
   assert(getInput().getType()->isFPType() && "Invalid type");
   checkSameShape(getResult(), getInput());
diff --git a/lib/IR/Instrs.cpp b/lib/IR/Instrs.cpp
index ca46ac10a1..a41f769583 100644
--- a/lib/IR/Instrs.cpp
+++ b/lib/IR/Instrs.cpp
@@ -90,3 +90,12 @@ void InsertTensorInst::verify() const {
   assert(getAxis() >= 0 && getAxis() < getDest()->dims().size() &&
          "Axis must fit inside Dest dims.");
 }
+
+void QuantizeInst::verify() const {
+  assert((getDest()->getElementType() == ElemKind::Int8QTy ||
+          getDest()->getElementType() == ElemKind::Int32QTy) &&
+         "Invalid type");
+  assert((getSrc()->getElementType() == ElemKind::FloatTy ||
+          getSrc()->getElementType() == ElemKind::Float16Ty) && "Invalid type");
+  assert(getSrc()->dims() == getDest()->dims() && "Invalid shape");
+}
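At the graph level, the pieces above combine as in the following hypothetical snippet (identifiers and numbers are illustrative, not from the patch): a float bias placeholder can now be quantized straight to Int32QTy and still pass QuantizeNode::verify() / QuantizeInst::verify().

#include "glow/Graph/Graph.h"

using namespace glow;

// Hypothetical helper; assumes a Module and Function already exist.
void buildInt32BiasQuantize(Module &mod, Function *F) {
  Placeholder *bias =
      mod.createPlaceholder(ElemKind::FloatTy, {8}, "bias", false);
  TypeRef biasTy = mod.uniqueType(ElemKind::Int32QTy, bias->dims(), 0.04, 0);
  QuantizeNode *biasq = F->createQuantize("bias.q", bias, biasTy);
  (void)biasq;
}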
diff --git a/lib/Optimizer/GraphOptimizer.cpp b/lib/Optimizer/GraphOptimizer.cpp
index 792f985fa7..f667c10988 100644
--- a/lib/Optimizer/GraphOptimizer.cpp
+++ b/lib/Optimizer/GraphOptimizer.cpp
@@ -1639,8 +1639,9 @@ static NodeValue convertConstant(Module &mod, Constant &constant,
     constantToBeModified.getPayload().convertToType(dstTy->getElementType());
     return NodeValue(&constantToBeModified, 0);
   }
+  case ElemKind::Int32QTy:
   case ElemKind::Int8QTy: {
-    // Quantization: {FloatTy, Float16Ty} -> Int8QTy.
+    // Quantization: {FloatTy, Float16Ty} -> Int8QTy or Int32QTy.
     Constant &constantToBeModified = modifyConstantTyAndGet();
     TensorQuantizationParams params{dstTy->getScale(), dstTy->getOffset()};
     Tensor &tensorToBeModified = constantToBeModified.getPayload();
@@ -1650,8 +1651,8 @@ static NodeValue convertConstant(Module &mod, Constant &constant,
     // teach quantizeTensor how to deal with Float16Ty.
     assert(tensor.getType().isFPType() &&
           "Type quantization not implemented");
-    tensorToBeModified =
-        quantization::quantizeTensor(tensorToBeModified, params);
+    tensorToBeModified = quantization::quantizeTensor(
+        tensorToBeModified, params, dstTy->getElementType());
     return NodeValue(&constantToBeModified, 0);
   }
   default:
diff --git a/lib/Quantization/Base/Base.cpp b/lib/Quantization/Base/Base.cpp
index 17d2ab53df..81bd5da482 100644
--- a/lib/Quantization/Base/Base.cpp
+++ b/lib/Quantization/Base/Base.cpp
@@ -22,36 +22,42 @@
 
 namespace glow {
 namespace quantization {
 
-int8_t quantize(float input, const TensorQuantizationParams &TQP) {
-  float result = input / TQP.scale + TQP.offset;
-  return quantization::clip<int32_t, int8_t>((int32_t)nearbyintf(result));
-}
-
-Tensor quantizeTensor(const Tensor &tensor,
-                      const TensorQuantizationParams &TQP) {
-  Tensor tmp(ElemKind::Int8QTy, tensor.dims(), TQP.scale, TQP.offset);
-  auto destHandle = tmp.getHandle<int8_t>();
-  assert(tensor.getType().isFPType() && "Type not supported yet");
-  switch (tensor.getElementType()) {
+template <class eTy>
+static void quantizeTensorUtil(Tensor *dest, const Tensor &src) {
+  auto destH = dest->getHandle<eTy>();
+  TensorQuantizationParams TQP{dest->getType().getScale(),
+                               dest->getType().getOffset()};
+  switch (src.getElementType()) {
   case ElemKind::FloatTy: {
-    auto srcHandle = tensor.getHandle<float>();
-    for (size_t i = 0, e = destHandle.size(); i < e; ++i) {
-      destHandle.raw(i) =
-          quantization::quantize(static_cast<float>(srcHandle.raw(i)), TQP);
+    auto srcHandle = src.getHandle<float>();
+    for (size_t i = 0, e = destH.size(); i < e; ++i) {
+      destH.raw(i) = quantization::quantize<eTy>(
+          static_cast<float>(srcHandle.raw(i)), TQP);
     }
     break;
   }
   case ElemKind::Float16Ty: {
-    auto srcHandle = tensor.getHandle<float16_t>();
-    for (size_t i = 0, e = destHandle.size(); i < e; ++i) {
-      destHandle.raw(i) =
-          quantization::quantize(static_cast<float>(srcHandle.raw(i)), TQP);
+    auto srcHandle = src.getHandle<float16_t>();
+    for (size_t i = 0, e = destH.size(); i < e; ++i) {
+      destH.raw(i) = quantization::quantize<eTy>(
+          static_cast<float>(srcHandle.raw(i)), TQP);
     }
     break;
   }
   default:
     llvm_unreachable("Cannot quantize a type");
   }
+}
+
+Tensor quantizeTensor(const Tensor &tensor, const TensorQuantizationParams &TQP,
+                      ElemKind Ty) {
+  Tensor tmp(Ty, tensor.dims(), TQP.scale, TQP.offset);
+  assert(tensor.getType().isFPType() && "Type not supported yet");
+  if (Ty == ElemKind::Int8QTy) {
+    quantizeTensorUtil<int8_t>(&tmp, tensor);
+  } else {
+    quantizeTensorUtil<int32_t>(&tmp, tensor);
+  }
   return tmp;
 }
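A hypothetical caller of the extended quantizeTensor() API above, with made-up shapes and quantization parameters, to show the two destination kinds side by side:

#include "glow/Base/Tensor.h"
#include "glow/Quantization/Base/Base.h"

using namespace glow;

void quantizeBiasBothWays() {
  Tensor bias(ElemKind::FloatTy, {4});
  bias.getHandle<float>().clear(0.5f);
  // Same float data, two destination element kinds.
  Tensor i8 = quantization::quantizeTensor(bias, {0.1f, 0}, ElemKind::Int8QTy);
  Tensor i32 =
      quantization::quantizeTensor(bias, {7.5e-5f, 0}, ElemKind::Int32QTy);
  // i8 holds 5 (= round(0.5 / 0.1)); i32 holds 6667 (= round(0.5 / 7.5e-5)),
  // a value an int8 tensor could not represent.
}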
diff --git a/lib/Quantization/Quantization.cpp b/lib/Quantization/Quantization.cpp
index ccc003a12c..82e55310c4 100644
--- a/lib/Quantization/Quantization.cpp
+++ b/lib/Quantization/Quantization.cpp
@@ -67,8 +67,22 @@ class FunctionQuantizer : public FunctionConverter {
            "Missing quantization params for a node");
 
     const TensorQuantizationParams &TQP = valTQPIt->second;
-    return mod_.uniqueType(ElemKind::Int8QTy, val.dims(), TQP.scale,
-                           TQP.offset);
+    // For bias of a conv op, it is quantized to int32.
+    if (use.getKind() == glow::Kinded::Kind::ConvolutionNodeKind && idx == 2) {
+      // The bias of a conv op is quantized to int32. Also, we must make
+      // sure its scale is (scale of input) * (scale of weights).
+      auto convN = llvm::dyn_cast<ConvolutionNode>(&use);
+      NodeValue input = convN->getInput();
+      NodeValue weights = convN->getFilter();
+      float scaleInput = input.getNode()->getNthResult(0).getType()->getScale();
+      float scaleWeights =
+          weights.getNode()->getNthResult(0).getType()->getScale();
+      return mod_.uniqueType(ElemKind::Int32QTy, val.dims(),
+                             scaleInput * scaleWeights, TQP.offset);
+    } else {
+      return mod_.uniqueType(ElemKind::Int8QTy, val.dims(), TQP.scale,
+                             TQP.offset);
+    }
   }
 
   /// \see FunctionConverter::canConvert.
@@ -122,7 +136,9 @@ class FunctionQuantizer : public FunctionConverter {
   Node *createConversion(Function &function, NodeValue &val,
                          TypeRef destTy) override {
     if (destTy->isQuantizedType()) {
-      assert(destTy->getElementType() == ElemKind::Int8QTy && "");
+      assert((destTy->getElementType() == ElemKind::Int8QTy ||
+              destTy->getElementType() == ElemKind::Int32QTy) &&
+             "We only support int8_t and int32_t quantization now");
       return function_.createQuantize("quantize", val, destTy);
     }
     assert(destTy->getElementType() == ElemKind::FloatTy && "");
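A worked example of the bias-scale rule above, reusing the scales from quantizedConvTest below (0.025 for the input, 0.003 for the kernel); the float bias value and the zero offset are assumptions for illustration:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  float inputScale = 0.025f;
  float filterScale = 0.003f;
  float biasScale = inputScale * filterScale; // 7.5e-05
  float floatBias = 0.5f;
  int32_t qBias = (int32_t)nearbyintf(floatBias / biasScale); // ~6667
  printf("biasScale=%g qBias=%d\n", biasScale, qBias);
  return 0;
}

A value of that magnitude does not fit in int8, and because the int32 accumulator of (input - inOffset) * (filter - filterOffset) products carries the same inputScale * filterScale scale, a bias quantized this way can be added to the accumulator without further rescaling (assuming a zero bias offset).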
diff --git a/tests/unittests/BackendCorrectnessTest.cpp b/tests/unittests/BackendCorrectnessTest.cpp
index c444ef94a0..d75501246f 100644
--- a/tests/unittests/BackendCorrectnessTest.cpp
+++ b/tests/unittests/BackendCorrectnessTest.cpp
@@ -136,10 +136,10 @@ TEST_P(CPUOnly, quantizedConvTest) {
   PseudoRNG PRNG;
   Tensor inputs(ElemKind::Int8QTy, {20, 41, 32, 6}, 0.025, -7);
   Tensor kernel(ElemKind::Int8QTy, {10, 5, 5, 6}, 0.003, 3);
-  Tensor bias(ElemKind::Int8QTy, {10}, 0.5, -4);
+  Tensor bias(ElemKind::Int32QTy, {10}, 0.5, -4);
   inputs.getHandle<int8_t>().randomize(-128, 127, PRNG);
   kernel.getHandle<int8_t>().randomize(-128, 127, PRNG);
-  bias.getHandle<int8_t>().randomize(-11, 8, PRNG);
+  bias.getHandle<int32_t>().randomize(-11, 8, PRNG);
   std::array<size_t, 4> S{{20, 15, 12, 10}};
   llvm::ArrayRef<size_t> shape(S);
   Tensor out1(ElemKind::Int8QTy, shape, 0.05, -17);
diff --git a/tests/unittests/OperatorTest.cpp b/tests/unittests/OperatorTest.cpp
index 44945c7ac2..0936f8bec4 100644
--- a/tests/unittests/OperatorTest.cpp
+++ b/tests/unittests/OperatorTest.cpp
@@ -1588,7 +1588,7 @@ void checkIntConvolution(ExecutionEngine &EE, Function *F, unsigned convDepth,
   TypeRef inputTy = mod.uniqueType(ElemKind::Int8QTy, input->dims(), 0.01, 0.0);
   TypeRef filterTy =
       mod.uniqueType(ElemKind::Int8QTy, filter->dims(), 0.01, 0.0);
-  TypeRef biasTy = mod.uniqueType(ElemKind::Int8QTy, bias->dims(), 0.04, 0.0);
+  TypeRef biasTy = mod.uniqueType(ElemKind::Int32QTy, bias->dims(), 0.04, 0.0);
 
   auto *inputq = F->createQuantize("input.q", input, inputTy);
   auto *filterq = F->createQuantize("filter.q", filter, filterTy);
diff --git a/tests/unittests/graphTest.cpp b/tests/unittests/graphTest.cpp
index 3c90d4d9fe..3a1cb9734f 100644
--- a/tests/unittests/graphTest.cpp
+++ b/tests/unittests/graphTest.cpp
@@ -344,7 +344,7 @@ TEST(Graph, simpleQuant) {
   auto *filter =
       MD.createPlaceholder(ElemKind::Int8QTy, filterDim, 3.3, 4, "F", true);
   auto *bias =
-      MD.createPlaceholder(ElemKind::Int8QTy, {depth}, 1.3, 5, "B", true);
+      MD.createPlaceholder(ElemKind::Int32QTy, {depth}, 1.3, 5, "B", true);
 
   // Calculate the size and allocate the output buffer.
   auto outSz = calculateConvPoolOutputDims(width, width, kernels, steps, pads);
diff --git a/tools/ClassGen/Backends/OpenCL/OpenCLSpecificInstrs.h b/tools/ClassGen/Backends/OpenCL/OpenCLSpecificInstrs.h
index 8b63feb90c..0ae56d71ba 100644
--- a/tools/ClassGen/Backends/OpenCL/OpenCLSpecificInstrs.h
+++ b/tools/ClassGen/Backends/OpenCL/OpenCLSpecificInstrs.h
@@ -26,7 +26,7 @@ BB.newBackendSpecificInstr("OCLConvolution")
     .addMember(MemberType::VectorUnsigned, "Pads")
     .addMember(MemberType::Unsigned, "Group")
     .autoIRGen()
-    .autoVerify(VerifyKind::SameElementType, {"Dest", "Src", "Filter", "Bias"});
+    .autoVerify(VerifyKind::SameElementType, {"Dest", "Src", "Filter"});
 
 BB.newBackendSpecificInstr("OCLAvgPool")
     .addOperand("Dest", OperandKind::Out)
diff --git a/tools/ClassGen/InstrGen.cpp b/tools/ClassGen/InstrGen.cpp
index 172bb46614..fb3a1c952b 100644
--- a/tools/ClassGen/InstrGen.cpp
+++ b/tools/ClassGen/InstrGen.cpp
@@ -86,8 +86,7 @@ int main(int argc, char **argv) {
       .addMember(MemberType::VectorUnsigned, "Pads")
       .addMember(MemberType::Unsigned, "Group")
      .autoIRGen()
-      .autoVerify(VerifyKind::SameElementType,
-                  {"Dest", "Src", "Filter", "Bias"})
+      .autoVerify(VerifyKind::SameElementType, {"Dest", "Src", "Filter"})
       .addGradientInstr({"Src", "Filter"}, {"Dest", "Src", "Filter", "Bias"});
 
   // MaxPool version caching XY coordinates to speedup gradient-based
@@ -472,9 +471,6 @@ int main(int argc, char **argv) {
   BB.newInstr("Quantize")
       .addOperand("Dest", OperandKind::Out)
       .addOperand("Src", OperandKind::In)
-      .autoVerify(VerifyKind::SameElementType, {"Dest", "ElemKind::Int8QTy"})
-      .autoVerify(VerifyKind::TypeCheck, {"Src", "isFPType()"})
-      .autoVerify(VerifyKind::SameShape, {"Dest", "Src"})
      .dataParallel()
      .autoIRGen();