@@ -548,10 +548,10 @@ makeSynPoolParams(llvm::ArrayRef<unsigned_t> kernel,
   params->sW = stride[0];
   params->sH = stride[1];
   // Padding
-  params->pWbegin = pad[0];
-  params->pWend = pad[0];
-  params->pHbegin = pad[1];
-  params->pHend = pad[1];
+  params->pHbegin = pad[0];
+  params->pWbegin = pad[1];
+  params->pHend = pad[2];
+  params->pWend = pad[3];
   // Dilation
   params->dilW = 1;
   params->dilH = 1;
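Note on the pooling padding fix above: the begin/end padding for the height (H) and width (W) axes now come from four distinct slots of the pad array instead of reusing pad[0] and pad[1]. A small illustration (not part of the patch), assuming the pad array follows Glow's {top, left, bottom, right} convention:

    // Hypothetical asymmetric padding, in the assumed TLBR order.
    unsigned pad[4] = {1, 2, 3, 4}; // {top, left, bottom, right}
    // Old mapping: pWbegin = pWend = pad[0] = 1, pHbegin = pHend = pad[1] = 2
    //              (begin/end collapsed, and W/H drawn from the wrong slots).
    // New mapping: pHbegin = pad[0] = 1 (top),    pWbegin = pad[1] = 2 (left),
    //              pHend   = pad[2] = 3 (bottom), pWend   = pad[3] = 4 (right).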
@@ -591,6 +591,16 @@ makeSynSliceAxisParams(unsigned axis, unsigned axes, unsigned outputAxisSize,
   return params;
 }
 
+static std::unique_ptr<ns_LrnKernel::Params>
+makeLrnParams(float alpha, float beta, float knorm, int halfWindowSize) {
+  auto params = llvm::make_unique<ns_LrnKernel::Params>();
+  params->alpha = alpha;
+  params->beta = beta;
+  params->knorm = knorm;
+  params->nsize = 2 * halfWindowSize + 1;
+  return params;
+}
+
 static std::unique_ptr<ns_ConstantKernel::Params>
 makeConstantParams(float value) {
   auto params = llvm::make_unique<ns_ConstantKernel::Params>();
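The new makeLrnParams helper packs Glow's LRN attributes into the Synapse ns_LrnKernel::Params struct; Glow exposes the half window size, while the kernel evidently expects the full window, hence nsize = 2 * halfWindowSize + 1. A usage sketch with illustrative values (the numeric constants are only an example, not values taken from the patch):

    // halfWindowSize = 2 covers two channels on each side of the center,
    // so the full LRN window spans 5 channels.
    auto params = makeLrnParams(/*alpha=*/1e-4f, /*beta=*/0.75f,
                                /*knorm=*/2.0f, /*halfWindowSize=*/2);
    // params->nsize == 5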
@@ -733,6 +743,8 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
   std::vector<std::unique_ptr<ns_TileKernel::Params>> tileParams;
   std::vector<std::unique_ptr<unsigned>> concatParams;
   std::vector<std::unique_ptr<ns_TakeKernel::Params>> takeParams;
+  std::vector<std::unique_ptr<ns_LrnKernel::Params>> lrnParams;
+  std::vector<std::unique_ptr<synGEMMParams>> gemmParams;
 
   // Keep references to tensor pointer arrays passed into multi-input nodes
   // until the compilation is done.
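The two new vectors follow the pattern of the existing params vectors: the structs are handed to Synapse by raw pointer, so ownership is parked in function-scope vectors that outlive the node-creation calls, mirroring the comment above about tensor pointer arrays. A generic sketch of the pattern with hypothetical names (KernelParams and createNode are placeholders, not Synapse APIs):

    std::vector<std::unique_ptr<KernelParams>> keepAlive;

    auto p = llvm::make_unique<KernelParams>();
    createNode(/*userParams=*/p.get());   // callee only sees a raw pointer
    keepAlive.emplace_back(std::move(p)); // owner stays alive until compile ends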
@@ -965,12 +977,16 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       if (MI->getLHS().getType()->isQuantizedType()) {
         // Let GEMM run on MME via FullyConnected node.
         // MME only runs on quantized types, e.g., int8 or int16.
-        auto params = llvm::make_unique<synFCParams>();
-        params->activation.reluEnable = false;
-        chk(synFullyConnected(tensors[MI->getLHS()].get(),
-                              tensors[MI->getRHS()].get(), nullptr,
-                              tensors[MI].get(), *params, ""));
-        fcParams.emplace_back(std::move(params));
+        // The default params are OK - don't transpose A and B.
+        auto params = llvm::make_unique<synGEMMParams>();
+        std::vector<synTensor> inputs;
+        inputs.push_back(tensors[MI->getLHS()].get());
+        inputs.push_back(tensors[MI->getRHS()].get());
+        chk(synCreateGenericNode(inputs.data(), &tensors[MI].get(),
+                                 inputs.size(), 1, nullptr, "gemm",
+                                 MI->getName().data(), nullptr, nullptr));
+        gemmParams.emplace_back(std::move(params));
+
       } else {
         std::vector<synTensor> inputs;
         inputs.push_back(tensors[MI->getLHS()].get());
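The quantized MatMul path now goes through the generic-node API with the "gemm" kernel GUID instead of synFullyConnected, apparently matching the non-quantized path in the else branch; default-constructed synGEMMParams mean neither operand is transposed, and the call itself passes nullptr for user params. A condensed sketch of the new call shape (lhsTensor, rhsTensor, resultTensor and nodeName are placeholders for the handles the patch pulls out of the tensors map):

    auto params = llvm::make_unique<synGEMMParams>(); // defaults: no transpose
    synTensor inputs[] = {lhsTensor, rhsTensor};      // A, B
    synTensor outputs[] = {resultTensor};             // C = A * B
    chk(synCreateGenericNode(inputs, outputs, /*numInputs=*/2, /*numOutputs=*/1,
                             /*userParams=*/nullptr, "gemm", nodeName,
                             nullptr, nullptr));
    gemmParams.emplace_back(std::move(params)); // keep alive (see vectors above)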
@@ -1015,6 +1031,18 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       convParams.emplace_back(std::move(params));
       break;
     }
+    case Kinded::Kind::LocalResponseNormalizationNodeKind: {
+      auto *NI = llvm::cast<LocalResponseNormalizationNode>(&I);
+      std::unique_ptr<ns_LrnKernel::Params> params = makeLrnParams(
+          NI->getAlpha(), NI->getBeta(), NI->getK(), NI->getHalfWindowSize());
+
+      chk(synCreateGenericNode(&tensors[NI->getInput()].get(),
+                               &tensors[NI].get(), 1, 1, (void *)params.get(),
+                               "lrn_f32", NI->getName().str().c_str(), nullptr,
+                               nullptr));
+      lrnParams.emplace_back(std::move(params));
+      break;
+    }
     case Kinded::Kind::TransposeNodeKind: {
       auto *TI = llvm::cast<TransposeNode>(&I);
       std::unique_ptr<synTransposeParams> params =
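For reference, the new LRN case maps Glow's LocalResponseNormalizationNode onto the lrn_f32 kernel with the params built above. A naive scalar reference of the AlexNet-style computation those params describe (illustration only; whether the kernel pre-divides alpha by the window size is the TPC implementation's business, so only the shape of the math is shown):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Illustration only: LRN across channels for one output element.
    static float lrnReference(const std::vector<float> &in, int c, float alpha,
                              float beta, float knorm, int halfWindowSize) {
      int numChannels = static_cast<int>(in.size());
      int first = std::max(0, c - halfWindowSize);
      int last = std::min(numChannels - 1, c + halfWindowSize);
      float squareSum = 0.0f;
      for (int j = first; j <= last; ++j)
        squareSum += in[j] * in[j];
      return in[c] / std::pow(knorm + alpha * squareSum, beta);
    }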
@@ -1126,6 +1154,14 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       concatParams.emplace_back(std::move(params));
       break;
     }
+    case Kinded::Kind::RescaleQuantizedNodeKind: {
+      auto *RI = llvm::cast<RescaleQuantizedNode>(&I);
+      chk(synCreateGenericNode(
+          &tensors[RI->getInput()].get(), &tensors[RI].get(), 1, 1, nullptr,
+          getKernelName("requant", RI->getResult().getElementType()).c_str(),
+          RI->getName().data(), nullptr, nullptr));
+      break;
+    }
     case Kinded::Kind::SaveNodeKind: {
       auto *CI = llvm::cast<SaveNode>(&I);
       if (tensors.count(CI)) {
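RescaleQuantized only changes a tensor's quantization parameters, so it lowers to a single "requant" kernel with no user params; getKernelName presumably appends an element-type suffix (e.g. "requant_i8"), and the target scale and offset come from the output tensor's own descriptor. A scalar sketch of the arithmetic such a requantization performs (illustration only, not the Synapse kernel):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Requantize one int8 value from (scaleIn, offsetIn) to (scaleOut, offsetOut)
    // while preserving the real value it represents.
    static int8_t requantize(int8_t q, float scaleIn, int32_t offsetIn,
                             float scaleOut, int32_t offsetOut) {
      float real = scaleIn * (static_cast<int32_t>(q) - offsetIn); // dequantize
      int32_t out =
          static_cast<int32_t>(std::round(real / scaleOut)) + offsetOut;
      return static_cast<int8_t>(std::min(127, std::max(-128, out))); // clamp
    }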
@@ -1237,7 +1273,11 @@ bool HabanaBackend::isOpSupported(const NodeInfo &NI) const {
   case Kinded::Kind::SplatNodeKind:
   case Kinded::Kind::SubNodeKind:
   case Kinded::Kind::TileNodeKind:
+  case Kinded::Kind::ConcatNodeKind:
     return true;
+  case Kinded::Kind::RescaleQuantizedNodeKind:
+    return NI.allInputsAndOutputsHaveSameElemKind(
+        {ElemKind::Int8QTy, ElemKind::Int16QTy});
   default:
     return false;
   }
@@ -1273,6 +1313,7 @@ bool HabanaBackend::isOpSupported(const NodeInfo &NI) const {
   case Kinded::Kind::TransposeNodeKind:
   case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
   case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind:
+  case Kinded::Kind::LocalResponseNormalizationNodeKind:
     return true;
   default:
     return false;