diff --git a/lib/Backends/Habana/Habana.cpp b/lib/Backends/Habana/Habana.cpp
index 3442dd4a3e..6c823916b9 100644
--- a/lib/Backends/Habana/Habana.cpp
+++ b/lib/Backends/Habana/Habana.cpp
@@ -548,10 +548,10 @@ makeSynPoolParams(llvm::ArrayRef<unsigned_t> kernel,
   params->sW = stride[0];
   params->sH = stride[1];
   // Padding
-  params->pWbegin = pad[0];
-  params->pWend = pad[0];
-  params->pHbegin = pad[1];
-  params->pHend = pad[1];
+  params->pHbegin = pad[0];
+  params->pWbegin = pad[1];
+  params->pHend = pad[2];
+  params->pWend = pad[3];
   // Dilation
   params->dilW = 1;
   params->dilH = 1;
@@ -591,6 +591,16 @@ makeSynSliceAxisParams(unsigned axis, unsigned axes, unsigned outputAxisSize,
   return params;
 }
 
+static std::unique_ptr<ns_LrnKernel::Params>
+makeLrnParams(float alpha, float beta, float knorm, int halfWindowSize) {
+  auto params = llvm::make_unique<ns_LrnKernel::Params>();
+  params->alpha = alpha;
+  params->beta = beta;
+  params->knorm = knorm;
+  params->nsize = 2 * halfWindowSize + 1;
+  return params;
+}
+
 static std::unique_ptr<ns_ConstantKernel::Params>
 makeConstantParams(float value) {
   auto params = llvm::make_unique<ns_ConstantKernel::Params>();
@@ -733,6 +743,8 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
   std::vector<std::unique_ptr<ns_TileKernel::Params>> tileParams;
   std::vector<std::unique_ptr<synConcatenateParams>> concatParams;
   std::vector<std::unique_ptr<ns_TakeKernel::Params>> takeParams;
+  std::vector<std::unique_ptr<ns_LrnKernel::Params>> lrnParams;
+  std::vector<std::unique_ptr<synGEMMParams>> gemmParams;
 
   // Keep references to tensor pointer arrays passed into multi-input nodes
   // until the compilation is done.
@@ -965,12 +977,16 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       if (MI->getLHS().getType()->isQuantizedType()) {
         // Let GEMM run on MME via FullyConnected node.
         // MME only runs on quantized types, e.g., int8 or int16.
-        auto params = llvm::make_unique<synFCParams>();
-        params->activation.reluEnable = false;
-        chk(synFullyConnected(tensors[MI->getLHS()].get(),
-                              tensors[MI->getRHS()].get(), nullptr,
-                              tensors[MI].get(), *params, ""));
-        fcParams.emplace_back(std::move(params));
+        // The default params are OK - don't transpose A and B
+        auto params = llvm::make_unique<synGEMMParams>();
+        std::vector<synTensor> inputs;
+        inputs.push_back(tensors[MI->getLHS()].get());
+        inputs.push_back(tensors[MI->getRHS()].get());
+        chk(synCreateGenericNode(inputs.data(), &tensors[MI].get(),
+                                 inputs.size(), 1, nullptr, "gemm",
+                                 MI->getName().data(), nullptr, nullptr));
+        gemmParams.emplace_back(std::move(params));
+
       } else {
         std::vector<synTensor> inputs;
         inputs.push_back(tensors[MI->getLHS()].get());
@@ -1015,6 +1031,18 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       convParams.emplace_back(std::move(params));
       break;
     }
+    case Kinded::Kind::LocalResponseNormalizationNodeKind: {
+      auto *NI = llvm::cast<LocalResponseNormalizationNode>(&I);
+      std::unique_ptr<ns_LrnKernel::Params> params = makeLrnParams(
+          NI->getAlpha(), NI->getBeta(), NI->getK(), NI->getHalfWindowSize());
+
+      chk(synCreateGenericNode(&tensors[NI->getInput()].get(),
+                               &tensors[NI].get(), 1, 1, (void *)params.get(),
+                               "lrn_f32", NI->getName().str().c_str(), nullptr,
+                               nullptr));
+      lrnParams.emplace_back(std::move(params));
+      break;
+    }
     case Kinded::Kind::TransposeNodeKind: {
       auto *TI = llvm::cast<TransposeNode>(&I);
       std::unique_ptr<synTransposeParams> params =
@@ -1126,6 +1154,14 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       concatParams.emplace_back(std::move(params));
      break;
     }
+    case Kinded::Kind::RescaleQuantizedNodeKind: {
+      auto *RI = llvm::cast<RescaleQuantizedNode>(&I);
+      chk(synCreateGenericNode(
+          &tensors[RI->getInput()].get(), &tensors[RI].get(), 1, 1, nullptr,
+          getKernelName("requant", RI->getResult().getElementType()).c_str(),
+          RI->getName().data(), nullptr, nullptr));
+      break;
+    }
     case Kinded::Kind::SaveNodeKind: {
       auto *CI = llvm::cast<SaveNode>(&I);
       if (tensors.count(CI)) {
@@ -1237,7 +1273,11 @@ bool HabanaBackend::isOpSupported(const NodeInfo &NI) const {
   case Kinded::Kind::SplatNodeKind:
   case Kinded::Kind::SubNodeKind:
   case Kinded::Kind::TileNodeKind:
+  case Kinded::Kind::ConcatNodeKind:
     return true;
+  case Kinded::Kind::RescaleQuantizedNodeKind:
+    return NI.allInputsAndOutputsHaveSameElemKind(
+        {ElemKind::Int8QTy, ElemKind::Int16QTy});
   default:
     return false;
   }
@@ -1273,6 +1313,7 @@ bool HabanaBackend::isOpSupported(const NodeInfo &NI) const {
   case Kinded::Kind::TransposeNodeKind:
   case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
   case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind:
+  case Kinded::Kind::LocalResponseNormalizationNodeKind:
     return true;
   default:
     return false;
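
Note on the padding change in the first hunk: assuming Glow's usual {top, left, bottom, right}
pad ordering, reusing pad[0]/pad[1] for both the begin and end fields assumed symmetric padding
and also swapped the H and W dimensions. Below is a minimal standalone sketch of the corrected
mapping; PoolPads and mapPads are hypothetical names used only for illustration, not Synapse
API types.

    #include <array>
    #include <cassert>

    // Stand-in for the H/W begin/end padding fields of the Synapse pooling params.
    struct PoolPads {
      int pHbegin, pWbegin, pHend, pWend;
    };

    // Map a Glow-style {top, left, bottom, right} pad array onto the begin/end
    // fields, mirroring the corrected assignments in the patch above.
    static PoolPads mapPads(const std::array<int, 4> &pad) {
      return {/*pHbegin=*/pad[0], /*pWbegin=*/pad[1],
              /*pHend=*/pad[2], /*pWend=*/pad[3]};
    }

    int main() {
      // Asymmetric padding: one row on top and two columns on the left only.
      PoolPads p = mapPads({{1, 2, 0, 0}});
      assert(p.pHbegin == 1 && p.pWbegin == 2 && p.pHend == 0 && p.pWend == 0);
      return 0;
    }

With the old mapping the same input would have produced pWbegin == pWend == 1 and
pHbegin == pHend == 2, i.e. symmetric padding applied to the wrong dimensions.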