
Commit 421d0a5

[habana] Misc fixes for vision nets (#2838)
* Introduced LRN, fixed #144
* Added Concat node to the list of supported quantized nodes
* Added RescaleQuantized node, replaced quantized MatMul with GEMM
1 parent d67cfad commit 421d0a5

File tree

1 file changed: +51 −10 lines

lib/Backends/Habana/Habana.cpp

Lines changed: 51 additions & 10 deletions
@@ -548,10 +548,10 @@ makeSynPoolParams(llvm::ArrayRef<unsigned_t> kernel,
   params->sW = stride[0];
   params->sH = stride[1];
   // Padding
-  params->pWbegin = pad[0];
-  params->pWend = pad[0];
-  params->pHbegin = pad[1];
-  params->pHend = pad[1];
+  params->pHbegin = pad[0];
+  params->pWbegin = pad[1];
+  params->pHend = pad[2];
+  params->pWend = pad[3];
   // Dilation
   params->dilW = 1;
   params->dilH = 1;
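
The pad fix above is easy to misread in isolation: Glow orders pooling pads top/left/bottom/right (H-begin, W-begin, H-end, W-end), so the old code swapped the H and W axes and never read pad[2] or pad[3], silently mishandling any asymmetric padding. A minimal standalone sketch of the corrected mapping, assuming that TLBR ordering; `PoolPads` is a stand-in struct, not the real Synapse params type:

```cpp
// Sketch of the corrected pad mapping, assuming Glow's TLBR pad order
// {top, left, bottom, right}. PoolPads is a hypothetical stand-in.
#include <array>
#include <cstdio>

struct PoolPads {
  unsigned pHbegin, pWbegin, pHend, pWend;
};

PoolPads mapPads(const std::array<unsigned, 4> &pad) {
  // Height pads come from top/bottom, width pads from left/right.
  return {pad[0], pad[1], pad[2], pad[3]};
}

int main() {
  // With asymmetric pads {1, 2, 3, 4} the old mapping would have set
  // pWbegin = pWend = 1 and pHbegin = pHend = 2.
  PoolPads p = mapPads({1, 2, 3, 4});
  std::printf("H: %u..%u, W: %u..%u\n", p.pHbegin, p.pHend, p.pWbegin, p.pWend);
  return 0;
}
```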
@@ -591,6 +591,16 @@ makeSynSliceAxisParams(unsigned axis, unsigned axes, unsigned outputAxisSize,
   return params;
 }
 
+static std::unique_ptr<ns_LrnKernel::Params>
+makeLrnParams(float alpha, float beta, float knorm, int halfWindowSize) {
+  auto params = llvm::make_unique<ns_LrnKernel::Params>();
+  params->alpha = alpha;
+  params->beta = beta;
+  params->knorm = knorm;
+  params->nsize = 2 * halfWindowSize + 1;
+  return params;
+}
+
 static std::unique_ptr<ns_ConstantKernel::Params>
 makeConstantParams(float value) {
   auto params = llvm::make_unique<ns_ConstantKernel::Params>();
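
makeLrnParams translates Glow's LRN attributes for the Synapse kernel; the one non-obvious step is the window size, since Glow stores a half window while ns_LrnKernel takes the full odd-sized window, hence nsize = 2 * halfWindowSize + 1 (halfWindowSize = 2 gives the classic AlexNet n = 5). For orientation, a rough sketch of the across-channel LRN these parameters describe; the Synapse kernel's exact normalization convention (e.g. whether alpha is pre-divided by the window size) is assumed here, not confirmed:

```cpp
// Illustrative across-channel LRN: out[c] = in[c] / (k + alpha * s)^beta,
// where s sums squares over a window of nsize = 2 * halfWindowSize + 1
// channels centred on c. Assumed semantics, not the "lrn_f32" kernel itself.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> lrnChannels(const std::vector<float> &in, float alpha,
                               float beta, float k, int halfWindowSize) {
  const int numChannels = static_cast<int>(in.size());
  std::vector<float> out(numChannels);
  for (int c = 0; c < numChannels; ++c) {
    const int lo = std::max(0, c - halfWindowSize);
    const int hi = std::min(numChannels - 1, c + halfWindowSize);
    float squareSum = 0.0f;
    for (int i = lo; i <= hi; ++i)
      squareSum += in[i] * in[i];
    out[c] = in[c] / std::pow(k + alpha * squareSum, beta);
  }
  return out;
}
```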
@@ -733,6 +743,8 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
   std::vector<std::unique_ptr<ns_TileKernel::Params>> tileParams;
   std::vector<std::unique_ptr<unsigned>> concatParams;
   std::vector<std::unique_ptr<ns_TakeKernel::Params>> takeParams;
+  std::vector<std::unique_ptr<ns_LrnKernel::Params>> lrnParams;
+  std::vector<std::unique_ptr<synGEMMParams>> gemmParams;
 
   // Keep references to tensor pointer arrays passed into multi-input nodes
   // until the compilation is done.
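
Like the existing params containers, the new lrnParams and gemmParams vectors appear to exist solely to keep the parameter structs alive until compilation finishes, since the generic-node API receives raw pointers to them rather than owning copies (this reading follows the lifetime comment directly above).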
@@ -965,12 +977,16 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       if (MI->getLHS().getType()->isQuantizedType()) {
         // Let GEMM run on MME via FullyConnected node.
         // MME only runs on quantized types, e.g., int8 or int16.
-        auto params = llvm::make_unique<synFCParams>();
-        params->activation.reluEnable = false;
-        chk(synFullyConnected(tensors[MI->getLHS()].get(),
-                              tensors[MI->getRHS()].get(), nullptr,
-                              tensors[MI].get(), *params, ""));
-        fcParams.emplace_back(std::move(params));
+        // The default params are OK - don't transpose A and B
+        auto params = llvm::make_unique<synGEMMParams>();
+        std::vector<synTensor> inputs;
+        inputs.push_back(tensors[MI->getLHS()].get());
+        inputs.push_back(tensors[MI->getRHS()].get());
+        chk(synCreateGenericNode(inputs.data(), &tensors[MI].get(),
+                                 inputs.size(), 1, nullptr, "gemm",
+                                 MI->getName().data(), nullptr, nullptr));
+        gemmParams.emplace_back(std::move(params));
+
       } else {
         std::vector<synTensor> inputs;
         inputs.push_back(tensors[MI->getLHS()].get());
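
With this change the quantized MatMul is emitted as a generic "gemm" node instead of synFullyConnected; the default-constructed synGEMMParams means neither operand is transposed, and there is no bias input or fused ReLU as in the FC path. As a plain-float reference for what such a node computes (a sketch of assumed semantics; the MME kernel itself runs on int8/int16 quantized data):

```cpp
// Reference GEMM with no transposition, bias, or activation:
// C[M x N] = A[M x K] * B[K x N]. Float stand-in for the quantized
// device kernel; matrices are row-major std::vectors.
#include <cstddef>
#include <vector>

std::vector<float> gemm(const std::vector<float> &A,
                        const std::vector<float> &B, size_t M, size_t K,
                        size_t N) {
  std::vector<float> C(M * N, 0.0f);
  for (size_t m = 0; m < M; ++m)
    for (size_t k = 0; k < K; ++k)
      for (size_t n = 0; n < N; ++n)
        C[m * N + n] += A[m * K + k] * B[k * N + n];
  return C;
}
```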
@@ -1015,6 +1031,18 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       convParams.emplace_back(std::move(params));
       break;
     }
+    case Kinded::Kind::LocalResponseNormalizationNodeKind: {
+      auto *NI = llvm::cast<LocalResponseNormalizationNode>(&I);
+      std::unique_ptr<ns_LrnKernel::Params> params = makeLrnParams(
+          NI->getAlpha(), NI->getBeta(), NI->getK(), NI->getHalfWindowSize());
+
+      chk(synCreateGenericNode(&tensors[NI->getInput()].get(),
+                               &tensors[NI].get(), 1, 1, (void *)params.get(),
+                               "lrn_f32", NI->getName().str().c_str(), nullptr,
+                               nullptr));
+      lrnParams.emplace_back(std::move(params));
+      break;
+    }
     case Kinded::Kind::TransposeNodeKind: {
       auto *TI = llvm::cast<TransposeNode>(&I);
       std::unique_ptr<synTransposeParams> params =
@@ -1126,6 +1154,14 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       concatParams.emplace_back(std::move(params));
       break;
     }
+    case Kinded::Kind::RescaleQuantizedNodeKind: {
+      auto *RI = llvm::cast<RescaleQuantizedNode>(&I);
+      chk(synCreateGenericNode(
+          &tensors[RI->getInput()].get(), &tensors[RI].get(), 1, 1, nullptr,
+          getKernelName("requant", RI->getResult().getElementType()).c_str(),
+          RI->getName().data(), nullptr, nullptr));
+      break;
+    }
     case Kinded::Kind::SaveNodeKind: {
       auto *CI = llvm::cast<SaveNode>(&I);
       if (tensors.count(CI)) {
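
The kernel GUID here comes from getKernelName("requant", <element type>), so the output element kind picks the kernel variant. Semantically, a rescale/requantize keeps the represented real values fixed while changing the (scale, offset) pair; a minimal sketch under Glow's real = scale * (q - offset) convention (assumed kernel behaviour, for illustration only):

```cpp
// Re-express an int8 quantized value from (scaleIn, offsetIn) to
// (scaleOut, offsetOut), preserving real = scale * (q - offset).
// Assumed semantics, with round-to-nearest and int8 saturation.
#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t requantize(int8_t q, float scaleIn, int32_t offsetIn, float scaleOut,
                  int32_t offsetOut) {
  const float real = scaleIn * (static_cast<float>(q) - offsetIn);
  float rq = std::round(real / scaleOut + static_cast<float>(offsetOut));
  rq = std::min(127.0f, std::max(-128.0f, rq));
  return static_cast<int8_t>(rq);
}
```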
@@ -1237,7 +1273,11 @@ bool HabanaBackend::isOpSupported(const NodeInfo &NI) const {
   case Kinded::Kind::SplatNodeKind:
   case Kinded::Kind::SubNodeKind:
   case Kinded::Kind::TileNodeKind:
+  case Kinded::Kind::ConcatNodeKind:
     return true;
+  case Kinded::Kind::RescaleQuantizedNodeKind:
+    return NI.allInputsAndOutputsHaveSameElemKind(
+        {ElemKind::Int8QTy, ElemKind::Int16QTy});
   default:
     return false;
   }
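
Per the commit message, this first switch is the supported-quantized-operator list: Concat becomes unconditionally supported there, while RescaleQuantized is supported only when every input and output shares the same Int8/Int16 quantized element kind. The general list is extended separately below.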
@@ -1273,6 +1313,7 @@ bool HabanaBackend::isOpSupported(const NodeInfo &NI) const {
   case Kinded::Kind::TransposeNodeKind:
   case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
   case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind:
+  case Kinded::Kind::LocalResponseNormalizationNodeKind:
     return true;
   default:
     return false;
