diff --git a/include/glow/Graph/Graph.h b/include/glow/Graph/Graph.h index 9e07b42da5..ce5a951f2e 100644 --- a/include/glow/Graph/Graph.h +++ b/include/glow/Graph/Graph.h @@ -510,6 +510,14 @@ class Function final : public Named { NodeValue input, Storage *W, Storage *B, unsigned_t axis = 1); + /// Creates and \returns a FullyConnectedNode with \p name, \p input, weights + /// \p W, bias \p B. If \p input is not 2 dimensional then it is flattened + /// along \p axis. Note, output type and outputDepth are inferred based on + /// the input types. + FullyConnectedNode *createFullyConnected(llvm::StringRef name, + NodeValue input, NodeValue W, + NodeValue B, unsigned_t axis = 1); + + /// Creates and \returns a FullyConnectedNode with \p name, \p input, weights + /// \p W, bias \p B, and \p outTy. If \p input is not 2 dimensional then it is + /// flattened along \p axis. Note, outputDepth is inferred based on \p outTy. @@ -1337,6 +1345,45 @@ class Function final : public Named { const llvm::ArrayRef inputs, unsigned batchSize, unsigned hiddenSize, unsigned outputSize, std::vector &outputs); + + /// Definition for the activation function of an LSTM module. + using LstmActivation = std::function; + + /// Type definition for the direction of an LSTM module. + enum class LstmDirection { + Forward, + Reverse, + Bidirectional, + }; + + /// Create an unrolled multi-layer LSTM according to the ONNX definition. The + /// LSTM has the following inputs: + /// - input \p X with size [S, B, ISize]. + /// - weights \p W with size [N, 4 * HSize, ISize]. + /// - recurrence weights \p R with size [N, 4 * HSize, HSize]. + /// - bias weights \p B with size [N, 8 * HSize]. + /// - initial hidden state \p initial_h with size [N, B, HSize]. + /// - initial cell state \p initial_c with size [N, B, HSize]. + /// - peephole weights \p P with size [N, 3 * HSize]. 
+ /// where S is the sequence length, N is the number of directions, B is the + /// batch size, ISize is the input size and HSize is the hidden size. + /// The LSTM has the following outputs: + /// - output \p Y with size [S, N, B, HSize] + /// - final hidden state \p Y_h with size [N, B, HSize]. + /// - final cell state \p Y_c with size [N, B, HSize]. + /// The direction of the instantiated LSTM is given by \p direction. The LSTM + /// will use the activation functions defined by \p activations which defines: + /// - [f,g,h] in case the LSTM is unidirectional (3 functions). + /// - [f,g,h] for the forward cell followed by [f,g,h] for the reverse cell in + /// case the LSTM is bidirectional (6 functions). + /// The inputs \p B and \p P are optional (assumed 0 if nullptr is provided). + /// The names of all the nodes created are prefixed with \p namePrefix. + void createONNXLSTM(llvm::StringRef namePrefix, NodeValue X, NodeValue W, + NodeValue R, NodeValue B, NodeValue initial_h, + NodeValue initial_c, NodeValue P, NodeValue &Y, + NodeValue &Y_h, NodeValue &Y_c, unsigned hiddenSize, + LstmDirection direction, + std::vector &activations); /// @} /// Create a TraceEvent in the runtime profile, which triggers collection of diff --git a/include/glow/Importer/ONNXModelLoader.h b/include/glow/Importer/ONNXModelLoader.h index b40cb29a2e..399032bf96 100644 --- a/include/glow/Importer/ONNXModelLoader.h +++ b/include/glow/Importer/ONNXModelLoader.h @@ -150,6 +150,10 @@ class ONNXModelLoader Error loadWhere(const ONNX_NAMESPACE::NodeProto &op, const ArgumentDictionaryTy &dict); + /// Load LSTM ONNX operator. + Error loadLSTM(const ONNX_NAMESPACE::NodeProto &op, + const ArgumentDictionaryTy &dict); + /// Load Glow specific operators, not defined in ONNX format /// Load Glow CmpEQ operator. 
Error loadCmpEQ(const ONNX_NAMESPACE::NodeProto &op, diff --git a/lib/Graph/Graph.cpp b/lib/Graph/Graph.cpp index 308e0bcf9e..a2cb0bb798 100644 --- a/lib/Graph/Graph.cpp +++ b/lib/Graph/Graph.cpp @@ -834,6 +834,17 @@ FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name, return createFullyConnected(name, input, W, B, OT, axis); } +FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name, + NodeValue input, NodeValue W, + NodeValue B, + unsigned_t axis) { + TypeRef T = input.getType(); + TypeRef OT = + getParent()->uniqueTypeWithNewShape(T, {input.dims()[0], B.dims()[0]}); + + return createFullyConnected(name, input, W, B, OT, axis); +} + FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name, NodeValue input, NodeValue W, NodeValue B, TypeRef outTy, @@ -2904,6 +2915,369 @@ void Function::createLSTM(PlaceholderBindings &bindings, } }; +void Function::createONNXLSTM(llvm::StringRef namePrefix, NodeValue X, + NodeValue W, NodeValue R, NodeValue B, + NodeValue initial_h, NodeValue initial_c, + NodeValue P, NodeValue &Y, NodeValue &Y_h, + NodeValue &Y_c, unsigned hiddenSize, + LstmDirection direction, + std::vector &activations) { + +#define LSTM_X_SLICE_RANGE(idx) \ + {idx + 0, 0, 0}, { idx + 1, batchSize, inputSize } +#define LSTM_H_SLICE_RANGE(idx) \ + {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize } +#define LSTM_C_SLICE_RANGE(idx) \ + {idx + 0, 0, 0}, { idx + 1, batchSize, hiddenSize } +#define LSTM_W_SLICE_RANGE(idx0, idx1) \ + {idx0, idx1 * hiddenSize, 0}, { idx0 + 1, (idx1 + 1) * hiddenSize, inputSize } +#define LSTM_R_SLICE_RANGE(idx0, idx1) \ + {idx0, idx1 * hiddenSize, 0}, { \ + idx0 + 1, (idx1 + 1) * hiddenSize, hiddenSize \ + } +#define LSTM_B_SLICE_RANGE(idx0, idx1) \ + {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize } +#define LSTM_P_SLICE_RANGE(idx0, idx1) \ + {idx0, idx1 * hiddenSize}, { idx0 + 1, (idx1 + 1) * hiddenSize } +#define LSTM_CREATE_FC(name, LHS, RHS, BIAS) \ + BIAS ? 
(Node *)createFullyConnected(name, LHS, RHS, BIAS) \ + : (Node *)createMatMul(name, LHS, RHS) + + // Operator name. + const std::string &opName = namePrefix.str(); + + // Get all size parameters. + size_t numDirections = (direction == LstmDirection::Bidirectional) ? 2 : 1; + assert(X.dims().size() == 3 && + "ONNX LSTM input 'X' should have 3 dimensions!"); + size_t seqLength = X.dims()[0]; + size_t batchSize = X.dims()[1]; + size_t inputSize = X.dims()[2]; + + // Validate W size. + assert(W.dims().size() == 3 && + "ONNX LSTM input 'W' should have 3 dimensions!"); + assert(W.dims()[0] == numDirections && W.dims()[1] == 4 * hiddenSize && + W.dims()[2] == inputSize && "ONNX LSTM 'W' tensor size invalid!"); + + // Validate R size. + assert(R.dims().size() == 3 && + "ONNX LSTM input 'R' should have 3 dimensions!"); + assert(R.dims()[0] == numDirections && R.dims()[1] == 4 * hiddenSize && + R.dims()[2] == hiddenSize && "ONNX LSTM 'R' tensor size invalid!"); + + // Validate B size. + if (B.getNode()) { + assert(B.dims().size() == 2 && + "ONNX LSTM input 'B' should have 2 dimensions!"); + assert(B.dims()[0] == numDirections && B.dims()[1] == 8 * hiddenSize && + "ONNX LSTM 'B' tensor size invalid!"); + } + + // Validate initial_h size. + assert(initial_h.getNode() && + "ONNX LSTM input 'initial_h' is mandatory. Null provided!"); + assert(initial_h.dims().size() == 3 && + "ONNX LSTM input 'initial_h' should have 3 dimensions!"); + assert(initial_h.dims()[0] == numDirections && + initial_h.dims()[1] == batchSize && + initial_h.dims()[2] == hiddenSize && + "ONNX LSTM 'initial_h' tensor size invalid!"); + + // Validate initial_c size. + assert(initial_c.getNode() && + "ONNX LSTM input 'initial_c' is mandatory. 
Null provided!"); + assert(initial_c.dims().size() == 3 && + "ONNX LSTM input 'initial_c' should have 3 dimensions!"); + assert(initial_c.dims()[0] == numDirections && + initial_c.dims()[1] == batchSize && + initial_c.dims()[2] == hiddenSize && + "ONNX LSTM 'initial_c' tensor size invalid!"); + + // Validate P size. + if (P.getNode()) { + assert(P.dims().size() == 2 && + "ONNX LSTM input 'P' should have 2 dimensions!"); + assert(P.dims()[0] == numDirections && P.dims()[1] == 3 * hiddenSize && + "ONNX LSTM 'P' tensor size invalid!"); + } + + // Validate number of activations. + assert(activations.size() == numDirections * 3 && + "ONNX LSTM activations vector invalid!"); + + // Create X slices. + std::vector Xslices; + for (size_t t = 0; t < seqLength; t++) { + auto XsliceName = opName + ".X" + std::to_string(t) + ".slice"; + Node *Xt = createSlice(XsliceName, X, LSTM_X_SLICE_RANGE(t)); + auto XreshapeName = opName + ".X" + std::to_string(t) + ".reshape"; + Xt = createReshape(XreshapeName, Xt, {batchSize, inputSize}); + Xslices.push_back(Xt); + } + + // Lambda to load forward/backward LSTM cell. + auto loadLSTMCell = [&](bool forward, std::vector &Yslices, + NodeValue &Hslice, NodeValue &Cslice) { + // Name prefix. + std::string dirLabel = forward ? ".fw" : ".bw"; + std::string prefix = opName + ((numDirections > 1) ? dirLabel : ""); + + // Slice index used for creating weights slices. + size_t sliceIdx0 = 0; + if (direction == LstmDirection::Bidirectional) { + sliceIdx0 = forward ? 0 : 1; + } + + // Activations. + size_t activationOffset = 0; + if (direction == LstmDirection::Bidirectional) { + activationOffset = forward ? 0 : 3; + } + auto activationF = activations[activationOffset + 0]; + auto activationG = activations[activationOffset + 1]; + auto activationH = activations[activationOffset + 2]; + + // Create W slices (Required). 
+ NodeValue Wi = + createSlice(prefix + ".Wi.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 0)); + NodeValue Wo = + createSlice(prefix + ".Wo.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 1)); + NodeValue Wf = + createSlice(prefix + ".Wf.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 2)); + NodeValue Wc = + createSlice(prefix + ".Wc.", W, LSTM_W_SLICE_RANGE(sliceIdx0, 3)); + + Wi = createReshape(prefix + ".Wi.reshape", Wi, {hiddenSize, inputSize}); + Wo = createReshape(prefix + ".Wo.reshape", Wo, {hiddenSize, inputSize}); + Wf = createReshape(prefix + ".Wf.reshape", Wf, {hiddenSize, inputSize}); + Wc = createReshape(prefix + ".Wc.reshape", Wc, {hiddenSize, inputSize}); + + Wi = createTranspose(prefix + ".Wi.transp", Wi, {1, 0}); + Wo = createTranspose(prefix + ".Wo.transp", Wo, {1, 0}); + Wf = createTranspose(prefix + ".Wf.transp", Wf, {1, 0}); + Wc = createTranspose(prefix + ".Wc.transp", Wc, {1, 0}); + + // Create R slices (Required). + NodeValue Ri = + createSlice(prefix + ".Ri.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 0)); + NodeValue Ro = + createSlice(prefix + ".Ro.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 1)); + NodeValue Rf = + createSlice(prefix + ".Rf.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 2)); + NodeValue Rc = + createSlice(prefix + ".Rc.", R, LSTM_R_SLICE_RANGE(sliceIdx0, 3)); + + Ri = createReshape(prefix + ".Ri.reshape", Ri, {hiddenSize, hiddenSize}); + Ro = createReshape(prefix + ".Ro.reshape", Ro, {hiddenSize, hiddenSize}); + Rf = createReshape(prefix + ".Rf.reshape", Rf, {hiddenSize, hiddenSize}); + Rc = createReshape(prefix + ".Rc.reshape", Rc, {hiddenSize, hiddenSize}); + + Ri = createTranspose(prefix + ".Ri.transp", Ri, {1, 0}); + Ro = createTranspose(prefix + ".Ro.transp", Ro, {1, 0}); + Rf = createTranspose(prefix + ".Rf.transp", Rf, {1, 0}); + Rc = createTranspose(prefix + ".Rc.transp", Rc, {1, 0}); + + // Create B slices (optional). 
+ NodeValue bWi = nullptr; + NodeValue bWo = nullptr; + NodeValue bWf = nullptr; + NodeValue bWc = nullptr; + NodeValue bRi = nullptr; + NodeValue bRo = nullptr; + NodeValue bRf = nullptr; + NodeValue bRc = nullptr; + + if (B) { + + bWi = createSlice(prefix + ".bWi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 0)); + bWo = createSlice(prefix + ".bWo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 1)); + bWf = createSlice(prefix + ".bWf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 2)); + bWc = createSlice(prefix + ".bWc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 3)); + bRi = createSlice(prefix + ".bRi.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 4)); + bRo = createSlice(prefix + ".bRo.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 5)); + bRf = createSlice(prefix + ".bRf.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 6)); + bRc = createSlice(prefix + ".bRc.", B, LSTM_B_SLICE_RANGE(sliceIdx0, 7)); + + bWi = createReshape(prefix + ".bWi.reshape", bWi, {hiddenSize}); + bWo = createReshape(prefix + ".bWo.reshape", bWo, {hiddenSize}); + bWf = createReshape(prefix + ".bWf.reshape", bWf, {hiddenSize}); + bWc = createReshape(prefix + ".bWc.reshape", bWc, {hiddenSize}); + bRi = createReshape(prefix + ".bRi.reshape", bRi, {hiddenSize}); + bRo = createReshape(prefix + ".bRo.reshape", bRo, {hiddenSize}); + bRf = createReshape(prefix + ".bRf.reshape", bRf, {hiddenSize}); + bRc = createReshape(prefix + ".bRc.reshape", bRc, {hiddenSize}); + } + + // Create P slices (optional). + NodeValue Pi = nullptr; + NodeValue Po = nullptr; + NodeValue Pf = nullptr; + + if (P) { + + Pi = createSlice(prefix + ".Pi.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 0)); + Po = createSlice(prefix + ".Po.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 1)); + Pf = createSlice(prefix + ".Pf.", P, LSTM_P_SLICE_RANGE(sliceIdx0, 2)); + + // Repeat P slices to match [batchSize, hiddenSize]. 
+ Pi = createTile(prefix + ".Pi.repeat", Pi, batchSize, 0); + Po = createTile(prefix + ".Po.repeat", Po, batchSize, 0); + Pf = createTile(prefix + ".Pf.repeat", Pf, batchSize, 0); + } + + // Create H slice for this direction. + Node *Hinit = initial_h.getNode(); + if (numDirections > 1) { + Hinit = createSlice(prefix + ".H.slice", Hinit, + LSTM_H_SLICE_RANGE(sliceIdx0)); + } + Hinit = + createReshape(prefix + ".H.reshape", Hinit, {batchSize, hiddenSize}); + + // Create C slice for this direction. + Node *Cinit = initial_c.getNode(); + if (numDirections > 1) { + Cinit = createSlice(prefix + ".C.slice", Cinit, + LSTM_C_SLICE_RANGE(sliceIdx0)); + } + Cinit = + createReshape(prefix + ".C.reshape", Cinit, {batchSize, hiddenSize}); + + // Initialize. + Node *Ht = Hinit; + Node *Ct = Cinit; + + // Unroll LSTM cell for all time steps. + for (size_t t = 0; t < seqLength; t++) { + + // Input for current time step. + // For the reverse LSTM cell the inputs are provided in reverse order. + Node *Xt = forward ? Xslices[t] : Xslices[seqLength - 1 - t]; + + // Forget gate: ft = f(Wf * Xt + bWf + Rf * Ht-1 + bRf + Pf . Ct-1). + Node *ft = createAdd(prefix + ".F.add1", + LSTM_CREATE_FC(prefix + ".F.fc1", Xt, Wf, bWf), + LSTM_CREATE_FC(prefix + ".F.fc2", Ht, Rf, bRf)); + if (Pf) { + ft = createAdd(prefix + ".F.add2", ft, + createMul(prefix + ".F.mult", Pf, Ct)); + } + ft = activationF(prefix + ".F.act", ft); + + // Cell state candidate: ctild = g(Wc * Xt + bWc + Rc * Ht-1 + bRc). + Node *ctild = + createAdd(prefix + ".ctild.add1", + LSTM_CREATE_FC(prefix + ".ctild.fc1", Xt, Wc, bWc), + LSTM_CREATE_FC(prefix + ".ctild.fc2", Ht, Rc, bRc)); + ctild = activationG(prefix + ".ctild.act", ctild); + + // Input gate: it = f(Wi * Xt + bWi + Ri * Ht-1 + bRi + Pi . Ct-1). 
+ Node *it = createAdd(prefix + ".I.add1", + LSTM_CREATE_FC(prefix + ".I.fc1", Xt, Wi, bWi), + LSTM_CREATE_FC(prefix + ".I.fc2", Ht, Ri, bRi)); + if (Pi) { + it = createAdd(prefix + ".I.add2", it, + createMul(prefix + ".I.mult", Pi, Ct)); + } + it = activationF(prefix + ".I.act", it); + + // Cell state update: Ct = ft . Ct-1 + it . ctild. + Ct = createAdd(prefix + ".C.add", createMul(prefix + ".C.mult1", ft, Ct), + createMul(prefix + ".C.mult2", it, ctild)); + + // Output gate: ot = f(Wo * Xt + bWo + Ro * Ht-1 + bRo + Po . Ct). + Node *ot = createAdd(prefix + ".O.add1", + LSTM_CREATE_FC(prefix + ".O.fc1", Xt, Wo, bWo), + LSTM_CREATE_FC(prefix + ".O.fc2", Ht, Ro, bRo)); + if (Po) { + ot = createAdd(prefix + ".O.add2", ot, + createMul(prefix + ".O.mult", Po, Ct)); + } + ot = activationF(prefix + ".O.act", ot); + + // Hidden state update: Ht = ot . h(Ct). + Ht = + createMul(prefix + ".H.mult", ot, activationH(prefix + ".H.act", Ct)); + + // Output. + Yslices.push_back(Ht); + } + + // Updated states nodes. + Hslice = Ht; + Cslice = Ct; + }; // End of local lambda "loadLSTMCell". + + bool forwardEnabled = ((direction == LstmDirection::Forward) || + (direction == LstmDirection::Bidirectional)); + bool backwardEnabled = ((direction == LstmDirection::Reverse) || + (direction == LstmDirection::Bidirectional)); + + std::vector YSlices; + std::vector Hslices; + std::vector Cslices; + + // Load forward LSTM. + std::vector forwardYslices; + if (forwardEnabled) { + NodeValue forwardHslice; + NodeValue forwardCslice; + loadLSTMCell(/* forward */ true, forwardYslices, forwardHslice, + forwardCslice); + Hslices.push_back(forwardHslice); + Cslices.push_back(forwardCslice); + } + + // Load backward LSTM. 
+ std::vector backwardYslices; + if (backwardEnabled) { + NodeValue backwardHslice; + NodeValue backwardCslice; + loadLSTMCell(/* forward */ false, backwardYslices, backwardHslice, + backwardCslice); + Hslices.push_back(backwardHslice); + Cslices.push_back(backwardCslice); + } + + // Gather Y slices. + for (size_t t = 0; t < seqLength; t++) { + if (forwardEnabled) { + YSlices.push_back(forwardYslices[t]); + } + if (backwardEnabled) { + YSlices.push_back(backwardYslices[seqLength - 1 - t]); + } + } + + // Concatenate Y slices. + // Y size is [seqLength, numDirections, batchSize, hiddenSize]. + Y = createReshape(opName + ".Y.reshape", + createConcat(opName + ".Y.concat", YSlices, 0), + {seqLength, numDirections, batchSize, hiddenSize}); + + // Concatenate Y_h slices. + // Y_h size is [numDirections, batchSize, hiddenSize]. + Y_h = createReshape(opName + ".Y_h.reshape", + createConcat(opName + ".Y_h.concat", Hslices, 0), + {numDirections, batchSize, hiddenSize}); + + // Concatenate Y_c slices. + // Y_c size is [numDirections, batchSize, hiddenSize]. 
+ Y_c = createReshape(opName + ".Y_c.reshape", + createConcat(opName + ".Y_c.concat", Cslices, 0), + {numDirections, batchSize, hiddenSize}); + +#undef LSTM_X_SLICE_RANGE +#undef LSTM_H_SLICE_RANGE +#undef LSTM_C_SLICE_RANGE +#undef LSTM_W_SLICE_RANGE +#undef LSTM_R_SLICE_RANGE +#undef LSTM_B_SLICE_RANGE +#undef LSTM_P_SLICE_RANGE +#undef LSTM_CREATE_FC +} + TraceEventNode *Function::createTraceEvent(llvm::StringRef eventName, llvm::StringRef eventType, Node *data, unsigned index) { diff --git a/lib/Importer/ONNXModelLoader.cpp b/lib/Importer/ONNXModelLoader.cpp index 514606d09f..5f6a681eff 100644 --- a/lib/Importer/ONNXModelLoader.cpp +++ b/lib/Importer/ONNXModelLoader.cpp @@ -1230,6 +1230,211 @@ Error ONNXModelLoader::loadWhere(const ONNX_NAMESPACE::NodeProto &op, return Error::success(); } +// ONNX LSTM: https://github.com/onnx/onnx/blob/master/docs/Operators.md#lstm +// Limitations: +// - Only Sigmoid, Tanh and ReLU activations are supported. +// - Activation clipping not supported. +// - Variable sequence length not supported. +// - Coupling of input and forget gate not supported. +Error ONNXModelLoader::loadLSTM(const ONNX_NAMESPACE::NodeProto &op, + const ArgumentDictionaryTy &dict) { + + const std::string &opName = loadOperatorName(op); + + // ------------------------- Attributes ------------------------------------- + // Get direction (Optional)(Default:forward). 
+ Function::LstmDirection direction = Function::LstmDirection::Forward; + if (dict.count("direction")) { + std::string directionStr; + ASSIGN_VALUE_OR_RETURN_ERR(directionStr, loadStr(dict.at("direction"))); + if (directionStr == "forward") { + direction = Function::LstmDirection::Forward; + } else if (directionStr == "reverse") { + direction = Function::LstmDirection::Reverse; + } else if (directionStr == "bidirectional") { + direction = Function::LstmDirection::Bidirectional; + } else { + RETURN_ERR("ONNX LSTM 'direction' attribute is invalid!", + ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE); + } + } + size_t numDirections = + (direction == Function::LstmDirection::Bidirectional) ? 2 : 1; + + // Activation alpha not supported (Optional)(Default:activation dependent). + RETURN_ERR_IF_NOT(!dict.count("activation_alpha"), + "ONNX LSTM 'activation_alpha' attribute not supported!"); + + // Activation beta not supported (Optional)(Default:activation dependent). + RETURN_ERR_IF_NOT(!dict.count("activation_beta"), + "ONNX LSTM 'activation_beta' attribute not supported!"); + + // Get activations as lambdas (Optional)(Default:f=Sigmoid, g=Tanh, h=Tanh). 
+#define LSTM_ACTIVATION_LAMBDA_RELU \ + [this](llvm::StringRef name, Node *input) { \ + return G_.createRELU(name, input); \ + } +#define LSTM_ACTIVATION_LAMBDA_TANH \ + [this](llvm::StringRef name, Node *input) { \ + return G_.createTanh(name, input); \ + } +#define LSTM_ACTIVATION_LAMBDA_SIGMOID \ + [this](llvm::StringRef name, Node *input) { \ + return G_.createSigmoid(name, input); \ + } + std::vector activations; + if (direction == Function::LstmDirection::Bidirectional) { + activations = { + LSTM_ACTIVATION_LAMBDA_SIGMOID, LSTM_ACTIVATION_LAMBDA_TANH, + LSTM_ACTIVATION_LAMBDA_TANH, LSTM_ACTIVATION_LAMBDA_SIGMOID, + LSTM_ACTIVATION_LAMBDA_TANH, LSTM_ACTIVATION_LAMBDA_TANH}; + } else { + activations = {LSTM_ACTIVATION_LAMBDA_SIGMOID, LSTM_ACTIVATION_LAMBDA_TANH, + LSTM_ACTIVATION_LAMBDA_TANH}; + } + if (dict.count("activations") && dict.at("activations")->strings_size()) { + size_t actNum = dict.at("activations")->strings_size(); + RETURN_ERR_IF_NOT(actNum == numDirections * 3, + "ONNX LSTM 'activations' attribute is invalid!"); + for (size_t actIdx = 0; actIdx < actNum; actIdx++) { + std::string actStr = dict.at("activations")->strings().Get(actIdx); + if (actStr == "Relu") { + activations[actIdx] = LSTM_ACTIVATION_LAMBDA_RELU; + } else if (actStr == "Tanh") { + activations[actIdx] = LSTM_ACTIVATION_LAMBDA_TANH; + } else if (actStr == "Sigmoid") { + activations[actIdx] = LSTM_ACTIVATION_LAMBDA_SIGMOID; + } else { + RETURN_ERR("ONNX LSTM activation '" + actStr + "' not supported!", + ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE); + } + } + } +#undef LSTM_ACTIVATION_LAMBDA_RELU +#undef LSTM_ACTIVATION_LAMBDA_TANH +#undef LSTM_ACTIVATION_LAMBDA_SIGMOID + + // Activation clipping not supported (Optional)(Default: 0 for no clipping). + RETURN_ERR_IF_NOT(!dict.count("clip"), + "ONNX LSTM 'clip' attribute not supported!"); + + // Get hidden size (Required). 
+ size_t hiddenSize; + RETURN_ERR_IF_NOT(dict.count("hidden_size"), + "ONNX LSTM 'hidden_size' attribute is required!"); + ASSIGN_VALUE_OR_RETURN_ERR(hiddenSize, loadInt(dict.at("hidden_size"))); + + // Get input forget (Optional)(Default:0). + int inputForget = 0; + if (dict.count("input_forget") && dict.at("input_forget")->has_i()) { + inputForget = dict.at("input_forget")->i(); + } + RETURN_ERR_IF_NOT(inputForget == 0, + "ONNX LSTM 'input_forget' attribute not supported!"); + + // --------------------------- Inputs --------------------------------------- + const int numInputs = op.input_size(); + RETURN_ERR_IF_NOT((3 <= numInputs) && (numInputs <= 8), + "ONNX LSTM should have minimum 3 and maximum 8 inputs!"); + + // Input0: X (Required). + NodeValue X; + ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0))); + + // Input1: W (Required). + NodeValue W; + ASSIGN_VALUE_OR_RETURN_ERR(W, getNodeValueByName(op.input(1))); + + // Input2: R (Required). + NodeValue R; + ASSIGN_VALUE_OR_RETURN_ERR(R, getNodeValueByName(op.input(2))); + + // Input3: B (Optional). + NodeValue B = nullptr; + if (numInputs > 3 && !op.input(3).empty()) { + ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(3))); + } + + // Input4: sequence_lens (Optional). + if (numInputs > 4) { + RETURN_ERR_IF_NOT(op.input(4).empty(), + "ONNX LSTM 'sequence_lens' attribute not supported!"); + } + + // Input5: initial_h (Optional). + NodeValue initial_h = nullptr; + if (numInputs > 5 && !op.input(5).empty()) { + ASSIGN_VALUE_OR_RETURN_ERR(initial_h, getNodeValueByName(op.input(5))); + } + + // Input6: initial_c (Optional). + NodeValue initial_c = nullptr; + if (numInputs > 6 && !op.input(6).empty()) { + ASSIGN_VALUE_OR_RETURN_ERR(initial_c, getNodeValueByName(op.input(6))); + } + + // Input7: P (Optional). 
+ NodeValue P = nullptr; + if (numInputs > 7 && !op.input(7).empty()) { + ASSIGN_VALUE_OR_RETURN_ERR(P, getNodeValueByName(op.input(7))); + } + + // -------------------------- Outputs --------------------------------------- + // We always create placeholders for the LSTM state variables (Y_h and Y_c) + // for the following reasons: + // - expose the LSTM state in the graph interface for accessibility (set + // desired state, reset state, watch the state being updated automatically). + // - since the LSTM cells are unrolled (no graph loop primitive available + // at this point), the optimal way to use the LSTM within a model would be + // to have it defined with only 1 time step and have the loop in the top + // of the application while the LSTM state will be automatically updated + // from one iteration (time step) to the next through the placeholders. + const int numOutputs = op.output_size(); + RETURN_ERR_IF_NOT(1 <= numOutputs, + "ONNX LSTM should have minimum 1 output defined!"); + + // Derived parameters. + RETURN_ERR_IF_NOT(X.dims().size() == 3, + "ONNX LSTM input 'X' should have 3 dimensions!"); + size_t batchSize = X.dims()[1]; + + // Create Y_h (hidden state) output placeholder. + Placeholder *Y_h_ph; + TypeRef Htype = G_.getParent()->uniqueTypeWithNewShape( + X.getType(), {numDirections, batchSize, hiddenSize}); + std::string Hname = opName + ".Y_h"; + ASSIGN_VALUE_OR_RETURN_ERR(Y_h_ph, + createAndRegisterPlaceholder(Hname, Htype)); + inputVarsByName_.try_emplace(Hname, Y_h_ph); + + // Create Y_c (cell state) output placeholder. + Placeholder *Y_c_ph; + TypeRef Ctype = G_.getParent()->uniqueTypeWithNewShape( + X.getType(), {numDirections, batchSize, hiddenSize}); + std::string Cname = opName + ".Y_c"; + ASSIGN_VALUE_OR_RETURN_ERR(Y_c_ph, + createAndRegisterPlaceholder(Cname, Ctype)); + inputVarsByName_.try_emplace(Cname, Y_c_ph); + + // If LSTM input states are explicitly provided then used them. If not, then + // use the LSTM state placeholders. 
+ NodeValue Y_h_init = initial_h.getNode() ? initial_h : Y_h_ph; + NodeValue Y_c_init = initial_c.getNode() ? initial_c : Y_c_ph; + + // Create ONNX LSTM. + NodeValue Y, Y_h, Y_c; + G_.createONNXLSTM(opName, X, W, R, B, Y_h_init, Y_c_init, P, Y, Y_h, Y_c, + hiddenSize, direction, activations); + + // Save LSTM state in the state placeholders. + G_.createSave(opName + ".Y_h.save", Y_h, Y_h_ph); + G_.createSave(opName + ".Y_c.save", Y_c, Y_c_ph); + + // Add node. + RETURN_IF_ERR(addNodeAsOutput(op, Y, 1)); + return Error::success(); +} + Error ONNXModelLoader::loadCmpEQ(const ONNX_NAMESPACE::NodeProto &op, const ArgumentDictionaryTy &dict) { NodeValue LHS; @@ -1621,6 +1826,9 @@ Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) { if (typeName == "Where") { return loadWhere(op, dict); } + if (typeName == "LSTM") { + return loadLSTM(op, dict); + } // Glow specific operators if (typeName == "CmpEQ") { return loadCmpEQ(op, dict); diff --git a/tests/models/onnxModels/lstmBidirectional.onnxtxt b/tests/models/onnxModels/lstmBidirectional.onnxtxt new file mode 100644 index 0000000000..132ff72dcf --- /dev/null +++ b/tests/models/onnxModels/lstmBidirectional.onnxtxt @@ -0,0 +1,720 @@ +ir_version: 5 +producer_name: "onnx-lstm" +graph { + node { + input: "X" + input: "W" + input: "R" + input: "B" + input: "" + input: "initial_h" + input: "initial_c" + input: "" + output: "Y" + name: "lstm" + op_type: "LSTM" + attribute { + name: "direction" + s: "bidirectional" + type: STRING + } + attribute { + name: "hidden_size" + i: 4 + type: INT + } + } + node { + input: "Y" + input: "Y_ref" + output: "Y_err" + name: "error" + op_type: "Sub" + } + name: "lstm_test" + initializer { + dims: 2 + dims: 5 + dims: 3 + data_type: 1 + float_data: 1.6243454217910767 + float_data: -0.6117563843727112 + float_data: -0.5281717777252197 + float_data: -1.072968602180481 + float_data: 0.8654076457023621 + float_data: -2.3015387058258057 + float_data: 1.744811773300171 + float_data: 
-0.7612069249153137 + float_data: 0.31903910636901855 + float_data: -0.24937038123607635 + float_data: 1.4621078968048096 + float_data: -2.060140609741211 + float_data: -0.3224171996116638 + float_data: -0.38405436277389526 + float_data: 1.1337693929672241 + float_data: -1.0998913049697876 + float_data: -0.1724282056093216 + float_data: -0.8778584003448486 + float_data: 0.042213745415210724 + float_data: 0.5828152298927307 + float_data: -1.1006191968917847 + float_data: 1.144723653793335 + float_data: 0.9015907049179077 + float_data: 0.5024943351745605 + float_data: 0.9008559584617615 + float_data: -0.6837278604507446 + float_data: -0.12289022654294968 + float_data: -0.9357694387435913 + float_data: -0.26788806915283203 + float_data: 0.5303554534912109 + name: "X" + } + initializer { + dims: 2 + dims: 16 + dims: 3 + data_type: 1 + float_data: -0.6916607618331909 + float_data: -0.3967535197734833 + float_data: -0.6871727108955383 + float_data: -0.8452056646347046 + float_data: -0.6712461113929749 + float_data: -0.01266459934413433 + float_data: -1.1173104047775269 + float_data: 0.2344156950712204 + float_data: 1.6598021984100342 + float_data: 0.7420441508293152 + float_data: -0.19183555245399475 + float_data: -0.887628972530365 + float_data: -0.7471582889556885 + float_data: 1.6924545764923096 + float_data: 0.050807755440473557 + float_data: -0.6369956731796265 + float_data: 0.19091548025608063 + float_data: 2.100255250930786 + float_data: 0.12015895545482635 + float_data: 0.6172031164169312 + float_data: 0.30017033219337463 + float_data: -0.3522498607635498 + float_data: -1.142518162727356 + float_data: -0.3493427336215973 + float_data: -0.20889423787593842 + float_data: 0.5866231918334961 + float_data: 0.838983416557312 + float_data: 0.9311020970344543 + float_data: 0.2855873107910156 + float_data: 0.8851411938667297 + float_data: -0.7543979287147522 + float_data: 1.2528681755065918 + float_data: 0.5129297971725464 + float_data: -0.2980928421020508 + float_data: 
0.4885181486606598 + float_data: -0.07557171583175659 + float_data: 1.1316293478012085 + float_data: 1.5198168754577637 + float_data: 2.185575485229492 + float_data: -1.396496295928955 + float_data: -1.444113850593567 + float_data: -0.5044658780097961 + float_data: 0.1600370705127716 + float_data: 0.8761689066886902 + float_data: 0.31563493609428406 + float_data: -2.0222012996673584 + float_data: -0.30620402097702026 + float_data: 0.8279746174812317 + float_data: 0.23009473085403442 + float_data: 0.7620111703872681 + float_data: -0.22232814133167267 + float_data: -0.20075806975364685 + float_data: 0.18656139075756073 + float_data: 0.4100516438484192 + float_data: 0.19829972088336945 + float_data: 0.11900864541530609 + float_data: -0.6706622838973999 + float_data: 0.3775637745857239 + float_data: 0.12182126939296722 + float_data: 1.129483938217163 + float_data: 1.1989178657531738 + float_data: 0.1851564198732376 + float_data: -0.37528494000434875 + float_data: -0.6387304067611694 + float_data: 0.4234943687915802 + float_data: 0.07734006643295288 + float_data: -0.3438536822795868 + float_data: 0.04359685629606247 + float_data: -0.6200008392333984 + float_data: 0.698032021522522 + float_data: -0.447128564119339 + float_data: 1.2245076894760132 + float_data: 0.40349164605140686 + float_data: 0.5935785174369812 + float_data: -1.094911813735962 + float_data: 0.16938243806362152 + float_data: 0.7405564785003662 + float_data: -0.953700602054596 + float_data: -0.26621851325035095 + float_data: 0.03261454775929451 + float_data: -1.3731173276901245 + float_data: 0.3151593804359436 + float_data: 0.8461606502532959 + float_data: -0.8595159649848938 + float_data: 0.3505459725856781 + float_data: -1.3122833967208862 + float_data: -0.03869551047682762 + float_data: -1.6157723665237427 + float_data: 1.121417760848999 + float_data: 0.40890052914619446 + float_data: -0.02461695671081543 + float_data: -0.775161623954773 + float_data: 1.2737559080123901 + float_data: 1.9671016931533813 
+ float_data: -1.8579819202423096 + float_data: 1.2361639738082886 + name: "W" + } + initializer { + dims: 2 + dims: 16 + dims: 4 + data_type: 1 + float_data: 1.6276507377624512 + float_data: 0.3380116820335388 + float_data: -1.1992679834365845 + float_data: 0.8633453249931335 + float_data: -0.1809203028678894 + float_data: -0.6039206385612488 + float_data: -1.230058193206787 + float_data: 0.5505374670028687 + float_data: 0.79280686378479 + float_data: -0.6235307455062866 + float_data: 0.5205763578414917 + float_data: -1.1443413496017456 + float_data: 0.801861047744751 + float_data: 0.04656729847192764 + float_data: -0.18656976521015167 + float_data: -0.10174587368965149 + float_data: 0.8688861727714539 + float_data: 0.7504116296768188 + float_data: 0.5294653177261353 + float_data: 0.13770121335983276 + float_data: 0.07782112807035446 + float_data: 0.6183802485466003 + float_data: 0.2324945628643036 + float_data: 0.682551383972168 + float_data: -0.3101167678833008 + float_data: -2.434837818145752 + float_data: 1.0388245582580566 + float_data: 2.1869795322418213 + float_data: 0.44136443734169006 + float_data: -0.10015523433685303 + float_data: -0.13644474744796753 + float_data: -0.11905419081449509 + float_data: 0.01740940846502781 + float_data: -1.1220186948776245 + float_data: -0.5170944333076477 + float_data: -0.997026801109314 + float_data: 0.2487991601228714 + float_data: -0.29664114117622375 + float_data: 0.49521133303642273 + float_data: -0.17470316588878632 + float_data: 0.9863351583480835 + float_data: 0.2135339081287384 + float_data: 2.190699815750122 + float_data: -1.8963608741760254 + float_data: -0.6469166874885559 + float_data: 0.901486873626709 + float_data: 2.5283257961273193 + float_data: -0.24863477051258087 + float_data: 0.043668992817401886 + float_data: -0.2263142466545105 + float_data: 1.3314571380615234 + float_data: -0.28730785846710205 + float_data: 0.6800698637962341 + float_data: -0.31980159878730774 + float_data: -1.2725588083267212 + 
float_data: 0.3135477304458618 + float_data: 0.5031847953796387 + float_data: 1.293225884437561 + float_data: -0.11044702678918839 + float_data: -0.6173620820045471 + float_data: 0.5627610683441162 + float_data: 0.24073709547519684 + float_data: 0.2806650698184967 + float_data: -0.07311270385980606 + float_data: 1.1603385210037231 + float_data: 0.36949270963668823 + float_data: 1.9046586751937866 + float_data: 1.1110566854476929 + float_data: 0.6590498089790344 + float_data: -1.6274383068084717 + float_data: 0.6023193001747131 + float_data: 0.4202822148799896 + float_data: 0.8109516501426697 + float_data: 1.044442057609558 + float_data: -0.4008781909942627 + float_data: 0.8240056037902832 + float_data: -0.5623054504394531 + float_data: 1.9548780918121338 + float_data: -1.33195161819458 + float_data: -1.7606885433197021 + float_data: -1.6507213115692139 + float_data: -0.8905555605888367 + float_data: -1.1191153526306152 + float_data: 1.9560788869857788 + float_data: -0.32649949193000793 + float_data: -1.342675805091858 + float_data: 1.1143829822540283 + float_data: -0.5865239500999451 + float_data: -1.2368533611297607 + float_data: 0.8758389353752136 + float_data: 0.6233621835708618 + float_data: -0.4349566698074341 + float_data: 1.407539963722229 + float_data: 0.12910157442092896 + float_data: 1.6169495582580566 + float_data: 0.5027408599853516 + float_data: 1.5588055849075317 + float_data: 0.10940269380807877 + float_data: -1.2197444438934326 + float_data: 2.449368715286255 + float_data: -0.5457741618156433 + float_data: -0.19883786141872406 + float_data: -0.7003985047340393 + float_data: -0.20339444279670715 + float_data: 0.24266944825649261 + float_data: 0.2018301784992218 + float_data: 0.6610202789306641 + float_data: 1.7921582460403442 + float_data: -0.12046457082033157 + float_data: -1.2331206798553467 + float_data: -1.182318091392517 + float_data: -0.665754497051239 + float_data: -1.6741957664489746 + float_data: 0.8250298500061035 + float_data: 
-0.4982135593891144 + float_data: -0.3109849691390991 + float_data: -0.0018914828542619944 + float_data: -1.3966203927993774 + float_data: -0.8613163828849792 + float_data: 0.6747115254402161 + float_data: 0.6185391545295715 + float_data: -0.4431719183921814 + float_data: 1.810534954071045 + float_data: -1.3057268857955933 + float_data: -0.3449872136116028 + float_data: -0.23083974421024323 + float_data: -2.7930850982666016 + float_data: 1.9375288486480713 + name: "R" + } + initializer { + dims: 2 + dims: 32 + data_type: 1 + float_data: 0.3663320243358612 + float_data: -1.0445894002914429 + float_data: 2.051173448562622 + float_data: 0.5856620073318481 + float_data: 0.429526150226593 + float_data: -0.6069983839988708 + float_data: 0.1062227264046669 + float_data: -1.5256803035736084 + float_data: 0.7950261235237122 + float_data: -0.3744383156299591 + float_data: 0.134048193693161 + float_data: 1.2020548582077026 + float_data: 0.2847481071949005 + float_data: 0.2624674439430237 + float_data: 0.27649930119514465 + float_data: -0.733271598815918 + float_data: 0.8360047340393066 + float_data: 1.5433591604232788 + float_data: 0.7588056325912476 + float_data: 0.8849087953567505 + float_data: -0.8772815465927124 + float_data: -0.86778724193573 + float_data: -1.4408760070800781 + float_data: 1.232253074645996 + float_data: -0.25417986512184143 + float_data: 1.3998439311981201 + float_data: -0.7819116711616516 + float_data: -0.4375089704990387 + float_data: 0.095425084233284 + float_data: 0.9214500784873962 + float_data: 0.06075019761919975 + float_data: 0.2111247479915619 + float_data: 0.01652756705880165 + float_data: 0.1771877259016037 + float_data: -1.1164699792861938 + float_data: 0.08092710375785828 + float_data: -0.18657898902893066 + float_data: -0.05682447925209999 + float_data: 0.49233654141426086 + float_data: -0.680678129196167 + float_data: -0.0845080241560936 + float_data: -0.2973618805408478 + float_data: 0.4173020124435425 + float_data: 0.784770667552948 + 
float_data: -0.9554252624511719 + float_data: 0.585910439491272 + float_data: 2.0657832622528076 + float_data: -1.4711569547653198 + float_data: -0.8301718831062317 + float_data: -0.8805776238441467 + float_data: -0.2790977358818054 + float_data: 1.6228491067886353 + float_data: 0.013352676294744015 + float_data: -0.6946936249732971 + float_data: 0.6218035221099854 + float_data: -0.5998045206069946 + float_data: 1.1234121322631836 + float_data: 0.3052670359611511 + float_data: 1.3887794017791748 + float_data: -0.6613442301750183 + float_data: 3.0308570861816406 + float_data: 0.8245846033096313 + float_data: 0.6545801758766174 + float_data: -0.05118844658136368 + name: "B" + } + initializer { + dims: 2 + dims: 5 + dims: 4 + data_type: 1 + float_data: -0.7255971431732178 + float_data: -0.8677687048912048 + float_data: -0.13597732782363892 + float_data: -0.7972697615623474 + float_data: 0.28267571330070496 + float_data: -0.8260974287986755 + float_data: 0.6210827231407166 + float_data: 0.9561216831207275 + float_data: -0.705840528011322 + float_data: 1.1926860809326172 + float_data: -0.2379419356584549 + float_data: 1.1552878618240356 + float_data: 0.4381663501262665 + float_data: 1.122328281402588 + float_data: -0.9970197677612305 + float_data: -0.10679398477077484 + float_data: 1.4514292478561401 + float_data: -0.6180368661880493 + float_data: -2.037201166152954 + float_data: -1.9425891637802124 + float_data: -2.5064406394958496 + float_data: -2.114163875579834 + float_data: -0.41163915395736694 + float_data: 1.278528094291687 + float_data: -0.4422292709350586 + float_data: 0.32352736592292786 + float_data: -0.10999149084091187 + float_data: 0.008548945188522339 + float_data: -0.1681988388299942 + float_data: -0.17418034374713898 + float_data: 0.46116408705711365 + float_data: -1.1759827136993408 + float_data: 1.0101271867752075 + float_data: 0.9200179576873779 + float_data: -0.19505734741687775 + float_data: 0.805393397808075 + float_data: -0.7013444304466248 + 
float_data: -0.5372230410575867 + float_data: 0.15626384317874908 + float_data: -0.19022102653980255 + name: "initial_h" + } + initializer { + dims: 2 + dims: 5 + dims: 4 + data_type: 1 + float_data: -0.4487380385398865 + float_data: -0.6724480390548706 + float_data: -0.5574946999549866 + float_data: 0.9391687512397766 + float_data: -1.9433233737945557 + float_data: 0.35249435901641846 + float_data: -0.23643694818019867 + float_data: 0.7278134822845459 + float_data: 0.5150735974311829 + float_data: -2.78253436088562 + float_data: 0.5846465826034546 + float_data: 0.3242742419242859 + float_data: 0.021862836554646492 + float_data: -0.46867382526397705 + float_data: 0.8532811999320984 + float_data: -0.4130293130874634 + float_data: 1.8347176313400269 + float_data: 0.5643828511238098 + float_data: 2.1378281116485596 + float_data: -0.7855340242385864 + float_data: -1.7559256553649902 + float_data: 0.7147895693778992 + float_data: 0.8527040481567383 + float_data: 0.035360097885131836 + float_data: -1.5387932062149048 + float_data: -0.44789519906044006 + float_data: 0.6179855465888977 + float_data: -0.18417632579803467 + float_data: -0.11598518490791321 + float_data: -0.17545896768569946 + float_data: -0.9339146614074707 + float_data: -0.5330203175544739 + float_data: -1.4265553951263428 + float_data: 1.7679599523544312 + float_data: -0.47537288069725037 + float_data: 0.47761017084121704 + float_data: -1.0218859910964966 + float_data: 0.7945282459259033 + float_data: -1.8731609582901 + float_data: 0.9206151366233826 + name: "initial_c" + } + initializer { + dims: 2 + dims: 2 + dims: 5 + dims: 4 + data_type: 1 + float_data: -0.004032406024634838 + float_data: -0.004462004639208317 + float_data: -0.11961524933576584 + float_data: -0.24488280713558197 + float_data: -0.7739036679267883 + float_data: 0.0038530235178768635 + float_data: -0.5268939733505249 + float_data: 0.11705601960420609 + float_data: 0.029958104714751244 + float_data: -0.340408056974411 + float_data: 
0.011654520407319069 + float_data: -0.1988580971956253 + float_data: -0.7096371650695801 + float_data: 0.002710475353524089 + float_data: 0.0027021574787795544 + float_data: -0.22992612421512604 + float_data: 0.1948436051607132 + float_data: 0.23224633932113647 + float_data: 0.0013459676411002874 + float_data: 0.4876309633255005 + float_data: -0.14344684779644012 + float_data: 0.06210453808307648 + float_data: 0.744733989238739 + float_data: 0.15252138674259186 + float_data: -0.18279105424880981 + float_data: -0.021294880658388138 + float_data: 0.5454503297805786 + float_data: -0.004351779818534851 + float_data: 0.23740188777446747 + float_data: -0.035537537187337875 + float_data: 0.03293892741203308 + float_data: 0.09300301969051361 + float_data: -0.3197038173675537 + float_data: 0.22859133780002594 + float_data: -0.22090184688568115 + float_data: 0.006025888025760651 + float_data: 0.013019595295190811 + float_data: 0.04256730526685715 + float_data: -0.24416765570640564 + float_data: 0.17333006858825684 + float_data: -0.3833770155906677 + float_data: 0.029960526153445244 + float_data: -0.026150351390242577 + float_data: 0.14755044877529144 + float_data: -0.3461287319660187 + float_data: 0.015776503831148148 + float_data: 0.002974175615236163 + float_data: -0.1785404086112976 + float_data: 0.3363659083843231 + float_data: -0.20950248837471008 + float_data: 0.3258989155292511 + float_data: -0.11810128390789032 + float_data: -0.021570634096860886 + float_data: 0.03261823207139969 + float_data: -0.027086878195405006 + float_data: -0.4444737434387207 + float_data: 0.4044571816921234 + float_data: 0.6170535683631897 + float_data: 0.19347822666168213 + float_data: 0.44024890661239624 + float_data: -0.8476369380950928 + float_data: 0.8057203888893127 + float_data: 0.6255403757095337 + float_data: 4.2606592614902183e-05 + float_data: -0.5677104592323303 + float_data: -0.07071330398321152 + float_data: 0.5493011474609375 + float_data: -0.01966593973338604 + float_data: 
0.06701567023992538 + float_data: -0.09959989041090012 + float_data: -0.1481088399887085 + float_data: -0.3607737421989441 + float_data: -0.3277128040790558 + float_data: 0.006580004934221506 + float_data: 0.13670679926872253 + float_data: 0.5570814609527588 + float_data: -0.01634780503809452 + float_data: 0.49886056780815125 + float_data: -0.716121256351471 + float_data: -0.03986666351556778 + name: "Y_ref" + } + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 5 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "W" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "R" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 32 + } + } + } + } + } + input { + name: "initial_h" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "initial_c" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "Y_ref" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "Y_err" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/tests/models/onnxModels/lstmForward.onnxtxt b/tests/models/onnxModels/lstmForward.onnxtxt new file mode 100644 index 
0000000000..80394a0398 --- /dev/null +++ b/tests/models/onnxModels/lstmForward.onnxtxt @@ -0,0 +1,496 @@ +ir_version: 5 +producer_name: "onnx-lstm" +graph { + node { + input: "X" + input: "W" + input: "R" + input: "B" + input: "" + input: "initial_h" + input: "initial_c" + input: "" + output: "Y" + name: "lstm" + op_type: "LSTM" + attribute { + name: "direction" + s: "forward" + type: STRING + } + attribute { + name: "hidden_size" + i: 4 + type: INT + } + } + node { + input: "Y" + input: "Y_ref" + output: "Y_err" + name: "error" + op_type: "Sub" + } + name: "lstm_test" + initializer { + dims: 2 + dims: 5 + dims: 3 + data_type: 1 + float_data: 1.6243454217910767 + float_data: -0.6117563843727112 + float_data: -0.5281717777252197 + float_data: -1.072968602180481 + float_data: 0.8654076457023621 + float_data: -2.3015387058258057 + float_data: 1.744811773300171 + float_data: -0.7612069249153137 + float_data: 0.31903910636901855 + float_data: -0.24937038123607635 + float_data: 1.4621078968048096 + float_data: -2.060140609741211 + float_data: -0.3224171996116638 + float_data: -0.38405436277389526 + float_data: 1.1337693929672241 + float_data: -1.0998913049697876 + float_data: -0.1724282056093216 + float_data: -0.8778584003448486 + float_data: 0.042213745415210724 + float_data: 0.5828152298927307 + float_data: -1.1006191968917847 + float_data: 1.144723653793335 + float_data: 0.9015907049179077 + float_data: 0.5024943351745605 + float_data: 0.9008559584617615 + float_data: -0.6837278604507446 + float_data: -0.12289022654294968 + float_data: -0.9357694387435913 + float_data: -0.26788806915283203 + float_data: 0.5303554534912109 + name: "X" + } + initializer { + dims: 1 + dims: 16 + dims: 3 + data_type: 1 + float_data: -0.6916607618331909 + float_data: -0.3967535197734833 + float_data: -0.6871727108955383 + float_data: -0.8452056646347046 + float_data: -0.6712461113929749 + float_data: -0.01266459934413433 + float_data: -1.1173104047775269 + float_data: 0.2344156950712204 + 
float_data: 1.6598021984100342 + float_data: 0.7420441508293152 + float_data: -0.19183555245399475 + float_data: -0.887628972530365 + float_data: -0.7471582889556885 + float_data: 1.6924545764923096 + float_data: 0.050807755440473557 + float_data: -0.6369956731796265 + float_data: 0.19091548025608063 + float_data: 2.100255250930786 + float_data: 0.12015895545482635 + float_data: 0.6172031164169312 + float_data: 0.30017033219337463 + float_data: -0.3522498607635498 + float_data: -1.142518162727356 + float_data: -0.3493427336215973 + float_data: -0.20889423787593842 + float_data: 0.5866231918334961 + float_data: 0.838983416557312 + float_data: 0.9311020970344543 + float_data: 0.2855873107910156 + float_data: 0.8851411938667297 + float_data: -0.7543979287147522 + float_data: 1.2528681755065918 + float_data: 0.5129297971725464 + float_data: -0.2980928421020508 + float_data: 0.4885181486606598 + float_data: -0.07557171583175659 + float_data: 1.1316293478012085 + float_data: 1.5198168754577637 + float_data: 2.185575485229492 + float_data: -1.396496295928955 + float_data: -1.444113850593567 + float_data: -0.5044658780097961 + float_data: 0.1600370705127716 + float_data: 0.8761689066886902 + float_data: 0.31563493609428406 + float_data: -2.0222012996673584 + float_data: -0.30620402097702026 + float_data: 0.8279746174812317 + name: "W" + } + initializer { + dims: 1 + dims: 16 + dims: 4 + data_type: 1 + float_data: 0.23009473085403442 + float_data: 0.7620111703872681 + float_data: -0.22232814133167267 + float_data: -0.20075806975364685 + float_data: 0.18656139075756073 + float_data: 0.4100516438484192 + float_data: 0.19829972088336945 + float_data: 0.11900864541530609 + float_data: -0.6706622838973999 + float_data: 0.3775637745857239 + float_data: 0.12182126939296722 + float_data: 1.129483938217163 + float_data: 1.1989178657531738 + float_data: 0.1851564198732376 + float_data: -0.37528494000434875 + float_data: -0.6387304067611694 + float_data: 0.4234943687915802 + 
float_data: 0.07734006643295288 + float_data: -0.3438536822795868 + float_data: 0.04359685629606247 + float_data: -0.6200008392333984 + float_data: 0.698032021522522 + float_data: -0.447128564119339 + float_data: 1.2245076894760132 + float_data: 0.40349164605140686 + float_data: 0.5935785174369812 + float_data: -1.094911813735962 + float_data: 0.16938243806362152 + float_data: 0.7405564785003662 + float_data: -0.953700602054596 + float_data: -0.26621851325035095 + float_data: 0.03261454775929451 + float_data: -1.3731173276901245 + float_data: 0.3151593804359436 + float_data: 0.8461606502532959 + float_data: -0.8595159649848938 + float_data: 0.3505459725856781 + float_data: -1.3122833967208862 + float_data: -0.03869551047682762 + float_data: -1.6157723665237427 + float_data: 1.121417760848999 + float_data: 0.40890052914619446 + float_data: -0.02461695671081543 + float_data: -0.775161623954773 + float_data: 1.2737559080123901 + float_data: 1.9671016931533813 + float_data: -1.8579819202423096 + float_data: 1.2361639738082886 + float_data: 1.6276507377624512 + float_data: 0.3380116820335388 + float_data: -1.1992679834365845 + float_data: 0.8633453249931335 + float_data: -0.1809203028678894 + float_data: -0.6039206385612488 + float_data: -1.230058193206787 + float_data: 0.5505374670028687 + float_data: 0.79280686378479 + float_data: -0.6235307455062866 + float_data: 0.5205763578414917 + float_data: -1.1443413496017456 + float_data: 0.801861047744751 + float_data: 0.04656729847192764 + float_data: -0.18656976521015167 + float_data: -0.10174587368965149 + name: "R" + } + initializer { + dims: 1 + dims: 32 + data_type: 1 + float_data: 0.8688861727714539 + float_data: 0.7504116296768188 + float_data: 0.5294653177261353 + float_data: 0.13770121335983276 + float_data: 0.07782112807035446 + float_data: 0.6183802485466003 + float_data: 0.2324945628643036 + float_data: 0.682551383972168 + float_data: -0.3101167678833008 + float_data: -2.434837818145752 + float_data: 
1.0388245582580566 + float_data: 2.1869795322418213 + float_data: 0.44136443734169006 + float_data: -0.10015523433685303 + float_data: -0.13644474744796753 + float_data: -0.11905419081449509 + float_data: 0.01740940846502781 + float_data: -1.1220186948776245 + float_data: -0.5170944333076477 + float_data: -0.997026801109314 + float_data: 0.2487991601228714 + float_data: -0.29664114117622375 + float_data: 0.49521133303642273 + float_data: -0.17470316588878632 + float_data: 0.9863351583480835 + float_data: 0.2135339081287384 + float_data: 2.190699815750122 + float_data: -1.8963608741760254 + float_data: -0.6469166874885559 + float_data: 0.901486873626709 + float_data: 2.5283257961273193 + float_data: -0.24863477051258087 + name: "B" + } + initializer { + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: 0.043668992817401886 + float_data: -0.2263142466545105 + float_data: 1.3314571380615234 + float_data: -0.28730785846710205 + float_data: 0.6800698637962341 + float_data: -0.31980159878730774 + float_data: -1.2725588083267212 + float_data: 0.3135477304458618 + float_data: 0.5031847953796387 + float_data: 1.293225884437561 + float_data: -0.11044702678918839 + float_data: -0.6173620820045471 + float_data: 0.5627610683441162 + float_data: 0.24073709547519684 + float_data: 0.2806650698184967 + float_data: -0.07311270385980606 + float_data: 1.1603385210037231 + float_data: 0.36949270963668823 + float_data: 1.9046586751937866 + float_data: 1.1110566854476929 + name: "initial_h" + } + initializer { + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: 0.6590498089790344 + float_data: -1.6274383068084717 + float_data: 0.6023193001747131 + float_data: 0.4202822148799896 + float_data: 0.8109516501426697 + float_data: 1.044442057609558 + float_data: -0.4008781909942627 + float_data: 0.8240056037902832 + float_data: -0.5623054504394531 + float_data: 1.9548780918121338 + float_data: -1.33195161819458 + float_data: -1.7606885433197021 + float_data: -1.6507213115692139 + 
float_data: -0.8905555605888367 + float_data: -1.1191153526306152 + float_data: 1.9560788869857788 + float_data: -0.32649949193000793 + float_data: -1.342675805091858 + float_data: 1.1143829822540283 + float_data: -0.5865239500999451 + name: "initial_c" + } + initializer { + dims: 2 + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: -0.0025073192082345486 + float_data: -0.029248230159282684 + float_data: 0.10335244238376617 + float_data: -0.37830352783203125 + float_data: -0.6846296191215515 + float_data: 0.014212781563401222 + float_data: -0.2873882055282593 + float_data: 0.6879828572273254 + float_data: 0.04762459173798561 + float_data: 0.08635172992944717 + float_data: -0.6732637882232666 + float_data: -0.4424405097961426 + float_data: -0.8606193661689758 + float_data: -0.003097095526754856 + float_data: -0.5818638801574707 + float_data: 0.3305215835571289 + float_data: 0.16285081207752228 + float_data: -0.49072203040122986 + float_data: 0.3839721083641052 + float_data: -0.13327273726463318 + float_data: -0.4976532757282257 + float_data: 0.11136293411254883 + float_data: 0.3407604694366455 + float_data: -0.015222622081637383 + float_data: -0.6233256459236145 + float_data: 0.09541770070791245 + float_data: -0.004889961332082748 + float_data: 0.23766900599002838 + float_data: 0.48381438851356506 + float_data: 0.009071349166333675 + float_data: -0.6133842468261719 + float_data: -0.26729273796081543 + float_data: -0.13883478939533234 + float_data: 0.157814159989357 + float_data: -0.2884327471256256 + float_data: 0.14862513542175293 + float_data: -0.15628398954868317 + float_data: 0.3625759482383728 + float_data: 0.4816629886627197 + float_data: 0.0627441480755806 + name: "Y_ref" + } + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 5 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "W" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + 
dim_value: 16 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "R" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + input { + name: "initial_h" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "initial_c" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "Y_ref" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "Y_err" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/tests/models/onnxModels/lstmForwardNoBias.onnxtxt b/tests/models/onnxModels/lstmForwardNoBias.onnxtxt new file mode 100644 index 0000000000..a67485e563 --- /dev/null +++ b/tests/models/onnxModels/lstmForwardNoBias.onnxtxt @@ -0,0 +1,407 @@ +ir_version: 5 +producer_name: "onnx-lstm" +graph { + node { + input: "X" + input: "W" + input: "R" + input: "" + input: "" + input: "initial_h" + input: "initial_c" + input: "" + output: "Y" + name: "lstm" + op_type: "LSTM" + attribute { + name: "direction" + s: "forward" + type: STRING + } + attribute { + name: "hidden_size" + i: 4 + type: INT + } + } + node { + input: "Y" + input: "Y_ref" + output: "Y_err" + name: "error" + op_type: "Sub" + } + name: "lstm_test" + initializer { + dims: 1 + dims: 5 + dims: 3 + data_type: 1 + float_data: 1.6243454217910767 + float_data: 
-0.6117563843727112 + float_data: -0.5281717777252197 + float_data: -1.072968602180481 + float_data: 0.8654076457023621 + float_data: -2.3015387058258057 + float_data: 1.744811773300171 + float_data: -0.7612069249153137 + float_data: 0.31903910636901855 + float_data: -0.24937038123607635 + float_data: 1.4621078968048096 + float_data: -2.060140609741211 + float_data: -0.3224171996116638 + float_data: -0.38405436277389526 + float_data: 1.1337693929672241 + name: "X" + } + initializer { + dims: 1 + dims: 16 + dims: 3 + data_type: 1 + float_data: -1.0998913049697876 + float_data: -0.1724282056093216 + float_data: -0.8778584003448486 + float_data: 0.042213745415210724 + float_data: 0.5828152298927307 + float_data: -1.1006191968917847 + float_data: 1.144723653793335 + float_data: 0.9015907049179077 + float_data: 0.5024943351745605 + float_data: 0.9008559584617615 + float_data: -0.6837278604507446 + float_data: -0.12289022654294968 + float_data: -0.9357694387435913 + float_data: -0.26788806915283203 + float_data: 0.5303554534912109 + float_data: -0.6916607618331909 + float_data: -0.3967535197734833 + float_data: -0.6871727108955383 + float_data: -0.8452056646347046 + float_data: -0.6712461113929749 + float_data: -0.01266459934413433 + float_data: -1.1173104047775269 + float_data: 0.2344156950712204 + float_data: 1.6598021984100342 + float_data: 0.7420441508293152 + float_data: -0.19183555245399475 + float_data: -0.887628972530365 + float_data: -0.7471582889556885 + float_data: 1.6924545764923096 + float_data: 0.050807755440473557 + float_data: -0.6369956731796265 + float_data: 0.19091548025608063 + float_data: 2.100255250930786 + float_data: 0.12015895545482635 + float_data: 0.6172031164169312 + float_data: 0.30017033219337463 + float_data: -0.3522498607635498 + float_data: -1.142518162727356 + float_data: -0.3493427336215973 + float_data: -0.20889423787593842 + float_data: 0.5866231918334961 + float_data: 0.838983416557312 + float_data: 0.9311020970344543 + float_data: 
0.2855873107910156 + float_data: 0.8851411938667297 + float_data: -0.7543979287147522 + float_data: 1.2528681755065918 + float_data: 0.5129297971725464 + name: "W" + } + initializer { + dims: 1 + dims: 16 + dims: 4 + data_type: 1 + float_data: -0.2980928421020508 + float_data: 0.4885181486606598 + float_data: -0.07557171583175659 + float_data: 1.1316293478012085 + float_data: 1.5198168754577637 + float_data: 2.185575485229492 + float_data: -1.396496295928955 + float_data: -1.444113850593567 + float_data: -0.5044658780097961 + float_data: 0.1600370705127716 + float_data: 0.8761689066886902 + float_data: 0.31563493609428406 + float_data: -2.0222012996673584 + float_data: -0.30620402097702026 + float_data: 0.8279746174812317 + float_data: 0.23009473085403442 + float_data: 0.7620111703872681 + float_data: -0.22232814133167267 + float_data: -0.20075806975364685 + float_data: 0.18656139075756073 + float_data: 0.4100516438484192 + float_data: 0.19829972088336945 + float_data: 0.11900864541530609 + float_data: -0.6706622838973999 + float_data: 0.3775637745857239 + float_data: 0.12182126939296722 + float_data: 1.129483938217163 + float_data: 1.1989178657531738 + float_data: 0.1851564198732376 + float_data: -0.37528494000434875 + float_data: -0.6387304067611694 + float_data: 0.4234943687915802 + float_data: 0.07734006643295288 + float_data: -0.3438536822795868 + float_data: 0.04359685629606247 + float_data: -0.6200008392333984 + float_data: 0.698032021522522 + float_data: -0.447128564119339 + float_data: 1.2245076894760132 + float_data: 0.40349164605140686 + float_data: 0.5935785174369812 + float_data: -1.094911813735962 + float_data: 0.16938243806362152 + float_data: 0.7405564785003662 + float_data: -0.953700602054596 + float_data: -0.26621851325035095 + float_data: 0.03261454775929451 + float_data: -1.3731173276901245 + float_data: 0.3151593804359436 + float_data: 0.8461606502532959 + float_data: -0.8595159649848938 + float_data: 0.3505459725856781 + float_data: 
-1.3122833967208862 + float_data: -0.03869551047682762 + float_data: -1.6157723665237427 + float_data: 1.121417760848999 + float_data: 0.40890052914619446 + float_data: -0.02461695671081543 + float_data: -0.775161623954773 + float_data: 1.2737559080123901 + float_data: 1.9671016931533813 + float_data: -1.8579819202423096 + float_data: 1.2361639738082886 + float_data: 1.6276507377624512 + name: "R" + } + initializer { + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: 0.3380116820335388 + float_data: -1.1992679834365845 + float_data: 0.8633453249931335 + float_data: -0.1809203028678894 + float_data: -0.6039206385612488 + float_data: -1.230058193206787 + float_data: 0.5505374670028687 + float_data: 0.79280686378479 + float_data: -0.6235307455062866 + float_data: 0.5205763578414917 + float_data: -1.1443413496017456 + float_data: 0.801861047744751 + float_data: 0.04656729847192764 + float_data: -0.18656976521015167 + float_data: -0.10174587368965149 + float_data: 0.8688861727714539 + float_data: 0.7504116296768188 + float_data: 0.5294653177261353 + float_data: 0.13770121335983276 + float_data: 0.07782112807035446 + name: "initial_h" + } + initializer { + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: 0.6183802485466003 + float_data: 0.2324945628643036 + float_data: 0.682551383972168 + float_data: -0.3101167678833008 + float_data: -2.434837818145752 + float_data: 1.0388245582580566 + float_data: 2.1869795322418213 + float_data: 0.44136443734169006 + float_data: -0.10015523433685303 + float_data: -0.13644474744796753 + float_data: -0.11905419081449509 + float_data: 0.01740940846502781 + float_data: -1.1220186948776245 + float_data: -0.5170944333076477 + float_data: -0.997026801109314 + float_data: 0.2487991601228714 + float_data: -0.29664114117622375 + float_data: 0.49521133303642273 + float_data: -0.17470316588878632 + float_data: 0.9863351583480835 + name: "initial_c" + } + initializer { + dims: 1 + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + 
float_data: 0.0933857262134552 + float_data: 0.011009203270077705 + float_data: 0.14254379272460938 + float_data: 0.02861296571791172 + float_data: -0.3509131968021393 + float_data: 0.5523810982704163 + float_data: -0.09563437104225159 + float_data: 0.07891085743904114 + float_data: 0.04793476313352585 + float_data: 0.054873671382665634 + float_data: 0.1155666634440422 + float_data: -0.21774370968341827 + float_data: -0.23375603556632996 + float_data: -0.15837359428405762 + float_data: -0.13897359371185303 + float_data: 0.03168010711669922 + float_data: 0.14714477956295013 + float_data: 0.027562865987420082 + float_data: 0.10515212267637253 + float_data: 0.34372878074645996 + name: "Y_ref" + } + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "W" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "R" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "initial_h" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "initial_c" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "Y_ref" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "Y_err" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 10 
+} diff --git a/tests/models/onnxModels/lstmForwardNoState.onnxtxt b/tests/models/onnxModels/lstmForwardNoState.onnxtxt new file mode 100644 index 0000000000..6db43dc7b2 --- /dev/null +++ b/tests/models/onnxModels/lstmForwardNoState.onnxtxt @@ -0,0 +1,369 @@ +ir_version: 5 +producer_name: "onnx-lstm" +graph { + node { + input: "X" + input: "W" + input: "R" + input: "B" + input: "" + input: "" + input: "" + input: "" + output: "Y" + name: "lstm" + op_type: "LSTM" + attribute { + name: "direction" + s: "forward" + type: STRING + } + attribute { + name: "hidden_size" + i: 4 + type: INT + } + } + node { + input: "Y" + input: "Y_ref" + output: "Y_err" + name: "error" + op_type: "Sub" + } + name: "lstm_test" + initializer { + dims: 1 + dims: 5 + dims: 3 + data_type: 1 + float_data: 1.6243454217910767 + float_data: -0.6117563843727112 + float_data: -0.5281717777252197 + float_data: -1.072968602180481 + float_data: 0.8654076457023621 + float_data: -2.3015387058258057 + float_data: 1.744811773300171 + float_data: -0.7612069249153137 + float_data: 0.31903910636901855 + float_data: -0.24937038123607635 + float_data: 1.4621078968048096 + float_data: -2.060140609741211 + float_data: -0.3224171996116638 + float_data: -0.38405436277389526 + float_data: 1.1337693929672241 + name: "X" + } + initializer { + dims: 1 + dims: 16 + dims: 3 + data_type: 1 + float_data: -1.0998913049697876 + float_data: -0.1724282056093216 + float_data: -0.8778584003448486 + float_data: 0.042213745415210724 + float_data: 0.5828152298927307 + float_data: -1.1006191968917847 + float_data: 1.144723653793335 + float_data: 0.9015907049179077 + float_data: 0.5024943351745605 + float_data: 0.9008559584617615 + float_data: -0.6837278604507446 + float_data: -0.12289022654294968 + float_data: -0.9357694387435913 + float_data: -0.26788806915283203 + float_data: 0.5303554534912109 + float_data: -0.6916607618331909 + float_data: -0.3967535197734833 + float_data: -0.6871727108955383 + float_data: -0.8452056646347046 + 
float_data: -0.6712461113929749 + float_data: -0.01266459934413433 + float_data: -1.1173104047775269 + float_data: 0.2344156950712204 + float_data: 1.6598021984100342 + float_data: 0.7420441508293152 + float_data: -0.19183555245399475 + float_data: -0.887628972530365 + float_data: -0.7471582889556885 + float_data: 1.6924545764923096 + float_data: 0.050807755440473557 + float_data: -0.6369956731796265 + float_data: 0.19091548025608063 + float_data: 2.100255250930786 + float_data: 0.12015895545482635 + float_data: 0.6172031164169312 + float_data: 0.30017033219337463 + float_data: -0.3522498607635498 + float_data: -1.142518162727356 + float_data: -0.3493427336215973 + float_data: -0.20889423787593842 + float_data: 0.5866231918334961 + float_data: 0.838983416557312 + float_data: 0.9311020970344543 + float_data: 0.2855873107910156 + float_data: 0.8851411938667297 + float_data: -0.7543979287147522 + float_data: 1.2528681755065918 + float_data: 0.5129297971725464 + name: "W" + } + initializer { + dims: 1 + dims: 16 + dims: 4 + data_type: 1 + float_data: -0.2980928421020508 + float_data: 0.4885181486606598 + float_data: -0.07557171583175659 + float_data: 1.1316293478012085 + float_data: 1.5198168754577637 + float_data: 2.185575485229492 + float_data: -1.396496295928955 + float_data: -1.444113850593567 + float_data: -0.5044658780097961 + float_data: 0.1600370705127716 + float_data: 0.8761689066886902 + float_data: 0.31563493609428406 + float_data: -2.0222012996673584 + float_data: -0.30620402097702026 + float_data: 0.8279746174812317 + float_data: 0.23009473085403442 + float_data: 0.7620111703872681 + float_data: -0.22232814133167267 + float_data: -0.20075806975364685 + float_data: 0.18656139075756073 + float_data: 0.4100516438484192 + float_data: 0.19829972088336945 + float_data: 0.11900864541530609 + float_data: -0.6706622838973999 + float_data: 0.3775637745857239 + float_data: 0.12182126939296722 + float_data: 1.129483938217163 + float_data: 1.1989178657531738 + 
float_data: 0.1851564198732376 + float_data: -0.37528494000434875 + float_data: -0.6387304067611694 + float_data: 0.4234943687915802 + float_data: 0.07734006643295288 + float_data: -0.3438536822795868 + float_data: 0.04359685629606247 + float_data: -0.6200008392333984 + float_data: 0.698032021522522 + float_data: -0.447128564119339 + float_data: 1.2245076894760132 + float_data: 0.40349164605140686 + float_data: 0.5935785174369812 + float_data: -1.094911813735962 + float_data: 0.16938243806362152 + float_data: 0.7405564785003662 + float_data: -0.953700602054596 + float_data: -0.26621851325035095 + float_data: 0.03261454775929451 + float_data: -1.3731173276901245 + float_data: 0.3151593804359436 + float_data: 0.8461606502532959 + float_data: -0.8595159649848938 + float_data: 0.3505459725856781 + float_data: -1.3122833967208862 + float_data: -0.03869551047682762 + float_data: -1.6157723665237427 + float_data: 1.121417760848999 + float_data: 0.40890052914619446 + float_data: -0.02461695671081543 + float_data: -0.775161623954773 + float_data: 1.2737559080123901 + float_data: 1.9671016931533813 + float_data: -1.8579819202423096 + float_data: 1.2361639738082886 + float_data: 1.6276507377624512 + name: "R" + } + initializer { + dims: 1 + dims: 32 + data_type: 1 + float_data: 0.3380116820335388 + float_data: -1.1992679834365845 + float_data: 0.8633453249931335 + float_data: -0.1809203028678894 + float_data: -0.6039206385612488 + float_data: -1.230058193206787 + float_data: 0.5505374670028687 + float_data: 0.79280686378479 + float_data: -0.6235307455062866 + float_data: 0.5205763578414917 + float_data: -1.1443413496017456 + float_data: 0.801861047744751 + float_data: 0.04656729847192764 + float_data: -0.18656976521015167 + float_data: -0.10174587368965149 + float_data: 0.8688861727714539 + float_data: 0.7504116296768188 + float_data: 0.5294653177261353 + float_data: 0.13770121335983276 + float_data: 0.07782112807035446 + float_data: 0.6183802485466003 + float_data: 
0.2324945628643036 + float_data: 0.682551383972168 + float_data: -0.3101167678833008 + float_data: -2.434837818145752 + float_data: 1.0388245582580566 + float_data: 2.1869795322418213 + float_data: 0.44136443734169006 + float_data: -0.10015523433685303 + float_data: -0.13644474744796753 + float_data: -0.11905419081449509 + float_data: 0.01740940846502781 + name: "B" + } + initializer { + dims: 1 + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: 0.019366364926099777 + float_data: -0.06308465451002121 + float_data: 0.2654789090156555 + float_data: -0.055694155395030975 + float_data: 0.05336597561836243 + float_data: -0.4943341612815857 + float_data: -0.2806651294231415 + float_data: 0.023125888779759407 + float_data: 0.005478795152157545 + float_data: -0.014820549637079239 + float_data: 0.38491129875183105 + float_data: -0.1545427292585373 + float_data: -0.1341647356748581 + float_data: -0.31911501288414 + float_data: -0.3730350732803345 + float_data: 0.02135608345270157 + float_data: 0.047005925327539444 + float_data: 0.008991322480142117 + float_data: 0.20875070989131927 + float_data: 0.3259517252445221 + name: "Y_ref" + } + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "W" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "R" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + input { + name: "Y_ref" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + 
} + output { + name: "Y_err" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/tests/models/onnxModels/lstmForwardWithPeephole.onnxtxt b/tests/models/onnxModels/lstmForwardWithPeephole.onnxtxt new file mode 100644 index 0000000000..f8c6132a0b --- /dev/null +++ b/tests/models/onnxModels/lstmForwardWithPeephole.onnxtxt @@ -0,0 +1,495 @@ +ir_version: 5 +producer_name: "onnx-lstm" +graph { + node { + input: "X" + input: "W" + input: "R" + input: "B" + input: "" + input: "initial_h" + input: "initial_c" + input: "P" + output: "Y" + name: "lstm" + op_type: "LSTM" + attribute { + name: "direction" + s: "forward" + type: STRING + } + attribute { + name: "hidden_size" + i: 4 + type: INT + } + } + node { + input: "Y" + input: "Y_ref" + output: "Y_err" + name: "error" + op_type: "Sub" + } + name: "lstm_test" + initializer { + dims: 1 + dims: 5 + dims: 3 + data_type: 1 + float_data: 1.6243454217910767 + float_data: -0.6117563843727112 + float_data: -0.5281717777252197 + float_data: -1.072968602180481 + float_data: 0.8654076457023621 + float_data: -2.3015387058258057 + float_data: 1.744811773300171 + float_data: -0.7612069249153137 + float_data: 0.31903910636901855 + float_data: -0.24937038123607635 + float_data: 1.4621078968048096 + float_data: -2.060140609741211 + float_data: -0.3224171996116638 + float_data: -0.38405436277389526 + float_data: 1.1337693929672241 + name: "X" + } + initializer { + dims: 1 + dims: 16 + dims: 3 + data_type: 1 + float_data: -1.0998913049697876 + float_data: -0.1724282056093216 + float_data: -0.8778584003448486 + float_data: 0.042213745415210724 + float_data: 0.5828152298927307 + float_data: -1.1006191968917847 + float_data: 1.144723653793335 + float_data: 0.9015907049179077 + float_data: 0.5024943351745605 + float_data: 0.9008559584617615 + float_data: -0.6837278604507446 + 
float_data: -0.12289022654294968 + float_data: -0.9357694387435913 + float_data: -0.26788806915283203 + float_data: 0.5303554534912109 + float_data: -0.6916607618331909 + float_data: -0.3967535197734833 + float_data: -0.6871727108955383 + float_data: -0.8452056646347046 + float_data: -0.6712461113929749 + float_data: -0.01266459934413433 + float_data: -1.1173104047775269 + float_data: 0.2344156950712204 + float_data: 1.6598021984100342 + float_data: 0.7420441508293152 + float_data: -0.19183555245399475 + float_data: -0.887628972530365 + float_data: -0.7471582889556885 + float_data: 1.6924545764923096 + float_data: 0.050807755440473557 + float_data: -0.6369956731796265 + float_data: 0.19091548025608063 + float_data: 2.100255250930786 + float_data: 0.12015895545482635 + float_data: 0.6172031164169312 + float_data: 0.30017033219337463 + float_data: -0.3522498607635498 + float_data: -1.142518162727356 + float_data: -0.3493427336215973 + float_data: -0.20889423787593842 + float_data: 0.5866231918334961 + float_data: 0.838983416557312 + float_data: 0.9311020970344543 + float_data: 0.2855873107910156 + float_data: 0.8851411938667297 + float_data: -0.7543979287147522 + float_data: 1.2528681755065918 + float_data: 0.5129297971725464 + name: "W" + } + initializer { + dims: 1 + dims: 16 + dims: 4 + data_type: 1 + float_data: -0.2980928421020508 + float_data: 0.4885181486606598 + float_data: -0.07557171583175659 + float_data: 1.1316293478012085 + float_data: 1.5198168754577637 + float_data: 2.185575485229492 + float_data: -1.396496295928955 + float_data: -1.444113850593567 + float_data: -0.5044658780097961 + float_data: 0.1600370705127716 + float_data: 0.8761689066886902 + float_data: 0.31563493609428406 + float_data: -2.0222012996673584 + float_data: -0.30620402097702026 + float_data: 0.8279746174812317 + float_data: 0.23009473085403442 + float_data: 0.7620111703872681 + float_data: -0.22232814133167267 + float_data: -0.20075806975364685 + float_data: 0.18656139075756073 + 
float_data: 0.4100516438484192 + float_data: 0.19829972088336945 + float_data: 0.11900864541530609 + float_data: -0.6706622838973999 + float_data: 0.3775637745857239 + float_data: 0.12182126939296722 + float_data: 1.129483938217163 + float_data: 1.1989178657531738 + float_data: 0.1851564198732376 + float_data: -0.37528494000434875 + float_data: -0.6387304067611694 + float_data: 0.4234943687915802 + float_data: 0.07734006643295288 + float_data: -0.3438536822795868 + float_data: 0.04359685629606247 + float_data: -0.6200008392333984 + float_data: 0.698032021522522 + float_data: -0.447128564119339 + float_data: 1.2245076894760132 + float_data: 0.40349164605140686 + float_data: 0.5935785174369812 + float_data: -1.094911813735962 + float_data: 0.16938243806362152 + float_data: 0.7405564785003662 + float_data: -0.953700602054596 + float_data: -0.26621851325035095 + float_data: 0.03261454775929451 + float_data: -1.3731173276901245 + float_data: 0.3151593804359436 + float_data: 0.8461606502532959 + float_data: -0.8595159649848938 + float_data: 0.3505459725856781 + float_data: -1.3122833967208862 + float_data: -0.03869551047682762 + float_data: -1.6157723665237427 + float_data: 1.121417760848999 + float_data: 0.40890052914619446 + float_data: -0.02461695671081543 + float_data: -0.775161623954773 + float_data: 1.2737559080123901 + float_data: 1.9671016931533813 + float_data: -1.8579819202423096 + float_data: 1.2361639738082886 + float_data: 1.6276507377624512 + name: "R" + } + initializer { + dims: 1 + dims: 32 + data_type: 1 + float_data: 0.3380116820335388 + float_data: -1.1992679834365845 + float_data: 0.8633453249931335 + float_data: -0.1809203028678894 + float_data: -0.6039206385612488 + float_data: -1.230058193206787 + float_data: 0.5505374670028687 + float_data: 0.79280686378479 + float_data: -0.6235307455062866 + float_data: 0.5205763578414917 + float_data: -1.1443413496017456 + float_data: 0.801861047744751 + float_data: 0.04656729847192764 + float_data: 
-0.18656976521015167 + float_data: -0.10174587368965149 + float_data: 0.8688861727714539 + float_data: 0.7504116296768188 + float_data: 0.5294653177261353 + float_data: 0.13770121335983276 + float_data: 0.07782112807035446 + float_data: 0.6183802485466003 + float_data: 0.2324945628643036 + float_data: 0.682551383972168 + float_data: -0.3101167678833008 + float_data: -2.434837818145752 + float_data: 1.0388245582580566 + float_data: 2.1869795322418213 + float_data: 0.44136443734169006 + float_data: -0.10015523433685303 + float_data: -0.13644474744796753 + float_data: -0.11905419081449509 + float_data: 0.01740940846502781 + name: "B" + } + initializer { + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: -1.1220186948776245 + float_data: -0.5170944333076477 + float_data: -0.997026801109314 + float_data: 0.2487991601228714 + float_data: -0.29664114117622375 + float_data: 0.49521133303642273 + float_data: -0.17470316588878632 + float_data: 0.9863351583480835 + float_data: 0.2135339081287384 + float_data: 2.190699815750122 + float_data: -1.8963608741760254 + float_data: -0.6469166874885559 + float_data: 0.901486873626709 + float_data: 2.5283257961273193 + float_data: -0.24863477051258087 + float_data: 0.043668992817401886 + float_data: -0.2263142466545105 + float_data: 1.3314571380615234 + float_data: -0.28730785846710205 + float_data: 0.6800698637962341 + name: "initial_h" + } + initializer { + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: -0.31980159878730774 + float_data: -1.2725588083267212 + float_data: 0.3135477304458618 + float_data: 0.5031847953796387 + float_data: 1.293225884437561 + float_data: -0.11044702678918839 + float_data: -0.6173620820045471 + float_data: 0.5627610683441162 + float_data: 0.24073709547519684 + float_data: 0.2806650698184967 + float_data: -0.07311270385980606 + float_data: 1.1603385210037231 + float_data: 0.36949270963668823 + float_data: 1.9046586751937866 + float_data: 1.1110566854476929 + float_data: 0.6590498089790344 
+ float_data: -1.6274383068084717 + float_data: 0.6023193001747131 + float_data: 0.4202822148799896 + float_data: 0.8109516501426697 + name: "initial_c" + } + initializer { + dims: 1 + dims: 12 + data_type: 1 + float_data: 1.044442057609558 + float_data: -0.4008781909942627 + float_data: 0.8240056037902832 + float_data: -0.5623054504394531 + float_data: 1.9548780918121338 + float_data: -1.33195161819458 + float_data: -1.7606885433197021 + float_data: -1.6507213115692139 + float_data: -0.8905555605888367 + float_data: -1.1191153526306152 + float_data: 1.9560788869857788 + float_data: -0.32649949193000793 + name: "P" + } + initializer { + dims: 1 + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: 0.014501647092401981 + float_data: -0.021486641839146614 + float_data: 0.05031450092792511 + float_data: -0.16132746636867523 + float_data: 0.4661756157875061 + float_data: 0.04947449266910553 + float_data: -0.2929844558238983 + float_data: 0.03517363220453262 + float_data: 0.1610381305217743 + float_data: 0.05231039971113205 + float_data: 0.01746058464050293 + float_data: 0.06674934178590775 + float_data: 0.5095223784446716 + float_data: 0.24927878379821777 + float_data: -0.5190149545669556 + float_data: 0.008717215619981289 + float_data: 0.39056506752967834 + float_data: 0.0400579571723938 + float_data: 0.47333407402038574 + float_data: 0.14710074663162231 + name: "Y_ref" + } + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "W" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "R" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + 
dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + input { + name: "initial_h" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "initial_c" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "P" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 12 + } + } + } + } + } + input { + name: "Y_ref" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "Y_err" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/tests/models/onnxModels/lstmReverse.onnxtxt b/tests/models/onnxModels/lstmReverse.onnxtxt new file mode 100644 index 0000000000..e040fd6395 --- /dev/null +++ b/tests/models/onnxModels/lstmReverse.onnxtxt @@ -0,0 +1,496 @@ +ir_version: 5 +producer_name: "onnx-lstm" +graph { + node { + input: "X" + input: "W" + input: "R" + input: "B" + input: "" + input: "initial_h" + input: "initial_c" + input: "" + output: "Y" + name: "lstm" + op_type: "LSTM" + attribute { + name: "direction" + s: "reverse" + type: STRING + } + attribute { + name: "hidden_size" + i: 4 + type: INT + } + } + node { + input: "Y" + input: "Y_ref" + output: "Y_err" + name: "error" + op_type: "Sub" + } + name: "lstm_test" + initializer { + dims: 2 + dims: 5 + dims: 3 + data_type: 1 + float_data: 1.6243454217910767 + float_data: -0.6117563843727112 + float_data: -0.5281717777252197 + float_data: -1.072968602180481 + float_data: 0.8654076457023621 + float_data: -2.3015387058258057 + float_data: 1.744811773300171 + 
float_data: -0.7612069249153137 + float_data: 0.31903910636901855 + float_data: -0.24937038123607635 + float_data: 1.4621078968048096 + float_data: -2.060140609741211 + float_data: -0.3224171996116638 + float_data: -0.38405436277389526 + float_data: 1.1337693929672241 + float_data: -1.0998913049697876 + float_data: -0.1724282056093216 + float_data: -0.8778584003448486 + float_data: 0.042213745415210724 + float_data: 0.5828152298927307 + float_data: -1.1006191968917847 + float_data: 1.144723653793335 + float_data: 0.9015907049179077 + float_data: 0.5024943351745605 + float_data: 0.9008559584617615 + float_data: -0.6837278604507446 + float_data: -0.12289022654294968 + float_data: -0.9357694387435913 + float_data: -0.26788806915283203 + float_data: 0.5303554534912109 + name: "X" + } + initializer { + dims: 1 + dims: 16 + dims: 3 + data_type: 1 + float_data: -0.6916607618331909 + float_data: -0.3967535197734833 + float_data: -0.6871727108955383 + float_data: -0.8452056646347046 + float_data: -0.6712461113929749 + float_data: -0.01266459934413433 + float_data: -1.1173104047775269 + float_data: 0.2344156950712204 + float_data: 1.6598021984100342 + float_data: 0.7420441508293152 + float_data: -0.19183555245399475 + float_data: -0.887628972530365 + float_data: -0.7471582889556885 + float_data: 1.6924545764923096 + float_data: 0.050807755440473557 + float_data: -0.6369956731796265 + float_data: 0.19091548025608063 + float_data: 2.100255250930786 + float_data: 0.12015895545482635 + float_data: 0.6172031164169312 + float_data: 0.30017033219337463 + float_data: -0.3522498607635498 + float_data: -1.142518162727356 + float_data: -0.3493427336215973 + float_data: -0.20889423787593842 + float_data: 0.5866231918334961 + float_data: 0.838983416557312 + float_data: 0.9311020970344543 + float_data: 0.2855873107910156 + float_data: 0.8851411938667297 + float_data: -0.7543979287147522 + float_data: 1.2528681755065918 + float_data: 0.5129297971725464 + float_data: -0.2980928421020508 + 
float_data: 0.4885181486606598 + float_data: -0.07557171583175659 + float_data: 1.1316293478012085 + float_data: 1.5198168754577637 + float_data: 2.185575485229492 + float_data: -1.396496295928955 + float_data: -1.444113850593567 + float_data: -0.5044658780097961 + float_data: 0.1600370705127716 + float_data: 0.8761689066886902 + float_data: 0.31563493609428406 + float_data: -2.0222012996673584 + float_data: -0.30620402097702026 + float_data: 0.8279746174812317 + name: "W" + } + initializer { + dims: 1 + dims: 16 + dims: 4 + data_type: 1 + float_data: 0.23009473085403442 + float_data: 0.7620111703872681 + float_data: -0.22232814133167267 + float_data: -0.20075806975364685 + float_data: 0.18656139075756073 + float_data: 0.4100516438484192 + float_data: 0.19829972088336945 + float_data: 0.11900864541530609 + float_data: -0.6706622838973999 + float_data: 0.3775637745857239 + float_data: 0.12182126939296722 + float_data: 1.129483938217163 + float_data: 1.1989178657531738 + float_data: 0.1851564198732376 + float_data: -0.37528494000434875 + float_data: -0.6387304067611694 + float_data: 0.4234943687915802 + float_data: 0.07734006643295288 + float_data: -0.3438536822795868 + float_data: 0.04359685629606247 + float_data: -0.6200008392333984 + float_data: 0.698032021522522 + float_data: -0.447128564119339 + float_data: 1.2245076894760132 + float_data: 0.40349164605140686 + float_data: 0.5935785174369812 + float_data: -1.094911813735962 + float_data: 0.16938243806362152 + float_data: 0.7405564785003662 + float_data: -0.953700602054596 + float_data: -0.26621851325035095 + float_data: 0.03261454775929451 + float_data: -1.3731173276901245 + float_data: 0.3151593804359436 + float_data: 0.8461606502532959 + float_data: -0.8595159649848938 + float_data: 0.3505459725856781 + float_data: -1.3122833967208862 + float_data: -0.03869551047682762 + float_data: -1.6157723665237427 + float_data: 1.121417760848999 + float_data: 0.40890052914619446 + float_data: -0.02461695671081543 + 
float_data: -0.775161623954773 + float_data: 1.2737559080123901 + float_data: 1.9671016931533813 + float_data: -1.8579819202423096 + float_data: 1.2361639738082886 + float_data: 1.6276507377624512 + float_data: 0.3380116820335388 + float_data: -1.1992679834365845 + float_data: 0.8633453249931335 + float_data: -0.1809203028678894 + float_data: -0.6039206385612488 + float_data: -1.230058193206787 + float_data: 0.5505374670028687 + float_data: 0.79280686378479 + float_data: -0.6235307455062866 + float_data: 0.5205763578414917 + float_data: -1.1443413496017456 + float_data: 0.801861047744751 + float_data: 0.04656729847192764 + float_data: -0.18656976521015167 + float_data: -0.10174587368965149 + name: "R" + } + initializer { + dims: 1 + dims: 32 + data_type: 1 + float_data: 0.8688861727714539 + float_data: 0.7504116296768188 + float_data: 0.5294653177261353 + float_data: 0.13770121335983276 + float_data: 0.07782112807035446 + float_data: 0.6183802485466003 + float_data: 0.2324945628643036 + float_data: 0.682551383972168 + float_data: -0.3101167678833008 + float_data: -2.434837818145752 + float_data: 1.0388245582580566 + float_data: 2.1869795322418213 + float_data: 0.44136443734169006 + float_data: -0.10015523433685303 + float_data: -0.13644474744796753 + float_data: -0.11905419081449509 + float_data: 0.01740940846502781 + float_data: -1.1220186948776245 + float_data: -0.5170944333076477 + float_data: -0.997026801109314 + float_data: 0.2487991601228714 + float_data: -0.29664114117622375 + float_data: 0.49521133303642273 + float_data: -0.17470316588878632 + float_data: 0.9863351583480835 + float_data: 0.2135339081287384 + float_data: 2.190699815750122 + float_data: -1.8963608741760254 + float_data: -0.6469166874885559 + float_data: 0.901486873626709 + float_data: 2.5283257961273193 + float_data: -0.24863477051258087 + name: "B" + } + initializer { + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: 0.043668992817401886 + float_data: -0.2263142466545105 + 
float_data: 1.3314571380615234 + float_data: -0.28730785846710205 + float_data: 0.6800698637962341 + float_data: -0.31980159878730774 + float_data: -1.2725588083267212 + float_data: 0.3135477304458618 + float_data: 0.5031847953796387 + float_data: 1.293225884437561 + float_data: -0.11044702678918839 + float_data: -0.6173620820045471 + float_data: 0.5627610683441162 + float_data: 0.24073709547519684 + float_data: 0.2806650698184967 + float_data: -0.07311270385980606 + float_data: 1.1603385210037231 + float_data: 0.36949270963668823 + float_data: 1.9046586751937866 + float_data: 1.1110566854476929 + name: "initial_h" + } + initializer { + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: 0.6590498089790344 + float_data: -1.6274383068084717 + float_data: 0.6023193001747131 + float_data: 0.4202822148799896 + float_data: 0.8109516501426697 + float_data: 1.044442057609558 + float_data: -0.4008781909942627 + float_data: 0.8240056037902832 + float_data: -0.5623054504394531 + float_data: 1.9548780918121338 + float_data: -1.33195161819458 + float_data: -1.7606885433197021 + float_data: -1.6507213115692139 + float_data: -0.8905555605888367 + float_data: -1.1191153526306152 + float_data: 1.9560788869857788 + float_data: -0.32649949193000793 + float_data: -1.342675805091858 + float_data: 1.1143829822540283 + float_data: -0.5865239500999451 + name: "initial_c" + } + initializer { + dims: 2 + dims: 1 + dims: 5 + dims: 4 + data_type: 1 + float_data: -0.05507460609078407 + float_data: -0.00044277211418375373 + float_data: 0.3326265215873718 + float_data: -0.32690632343292236 + float_data: -0.645453691482544 + float_data: 0.013076402246952057 + float_data: -0.1509844958782196 + float_data: 0.2572307288646698 + float_data: 0.06812380254268646 + float_data: 0.05665956065058708 + float_data: -0.49239203333854675 + float_data: -0.666763186454773 + float_data: -0.7415023446083069 + float_data: 0.007944668643176556 + float_data: -0.5241938829421997 + float_data: 
-0.05047144368290901 + float_data: 0.14573483169078827 + float_data: 0.3051227927207947 + float_data: 0.593583881855011 + float_data: 0.010722354054450989 + float_data: -0.19817383587360382 + float_data: 0.0608539842069149 + float_data: 0.14691033959388733 + float_data: 0.16047033667564392 + float_data: 0.5478463172912598 + float_data: 0.05255172774195671 + float_data: -0.25240346789360046 + float_data: 0.23623821139335632 + float_data: 0.15029644966125488 + float_data: 0.22066667675971985 + float_data: -0.7114776968955994 + float_data: -0.1218540221452713 + float_data: -0.10800767689943314 + float_data: -0.03991087153553963 + float_data: -0.4257034659385681 + float_data: 0.34172356128692627 + float_data: -0.03591137006878853 + float_data: 0.034066714346408844 + float_data: 0.34216535091400146 + float_data: -0.16652527451515198 + name: "Y_ref" + } + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 5 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "W" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "R" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + input { + name: "initial_h" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "initial_c" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "Y_ref" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 1 + } + 
dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "Y_err" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/tests/unittests/OnnxExporterTest.cpp b/tests/unittests/OnnxExporterTest.cpp index f921ce8368..8872401b44 100644 --- a/tests/unittests/OnnxExporterTest.cpp +++ b/tests/unittests/OnnxExporterTest.cpp @@ -121,7 +121,11 @@ TEST(exporter, onnxModels) { llvm::outs() << "Ignore output file: " << name << "\n"; continue; } - + if (name.find("lstm") != std::string::npos) { + // Ignore LSTM files. + llvm::outs() << "Ignore LSTM model file: " << name << "\n"; + continue; + } testLoadAndSaveONNXModel(dirIt->path(), /* zipMode */ true); testLoadAndSaveONNXModel(dirIt->path(), /* zipMode */ false); } diff --git a/tests/unittests/OnnxImporterTest.cpp b/tests/unittests/OnnxImporterTest.cpp index 375d0b92ae..82de72fdaa 100644 --- a/tests/unittests/OnnxImporterTest.cpp +++ b/tests/unittests/OnnxImporterTest.cpp @@ -2519,3 +2519,71 @@ TEST(onnx, importLess) { EXPECT_EQ(CMPLT->getResult().dims()[1], 4); EXPECT_EQ(CMPLT->getResult().dims()[2], 1); } + +/// Test loading LSTM from a ONNX model. The ONNX model already computes +/// the error compared to a PyTorch reference implementation. +static void importLSTM(std::string fileName) { + ExecutionEngine EE; + auto &mod = EE.getModule(); + Function *F = mod.createFunction("main"); + + PlaceholderBindings bindings; + { + ONNXModelLoader onnxLD(fileName, {}, {}, *F); + bindings.allocate(mod.getPlaceholders()); + } + + // Search LSTM state placeholders and set to 0. 
+ Placeholder *Y_h_ph = nullptr; + Placeholder *Y_c_ph = nullptr; + for (const auto &ph : mod.getPlaceholders()) { + if (llvm::StringRef(ph->getName()).endswith("Y_h")) + Y_h_ph = ph; + if (llvm::StringRef(ph->getName()).endswith("Y_c")) + Y_c_ph = ph; + } + EXPECT_TRUE(Y_h_ph); + EXPECT_TRUE(Y_c_ph); + bindings.get(Y_h_ph)->zero(); + bindings.get(Y_c_ph)->zero(); + + // Compile and run. + EE.compile(CompilationMode::Infer); + EE.run(bindings); + + // Verify LSTM error. + Placeholder *Y_err_ph = mod.getPlaceholderByName("Y_err"); + EXPECT_TRUE(Y_err_ph); + auto err = bindings.get(Y_err_ph)->getHandle(); + for (size_t idx = 0; idx < Y_err_ph->getType()->size(); idx++) { + EXPECT_TRUE(std::abs(err.raw(idx)) < 1e-6); + } +} + +TEST(onnx, importLSTMForward) { + importLSTM(GLOW_DATA_PATH "tests/models/onnxModels/lstmForward.onnxtxt"); +} + +TEST(onnx, importLSTMReverse) { + importLSTM(GLOW_DATA_PATH "tests/models/onnxModels/lstmReverse.onnxtxt"); +} + +TEST(onnx, importLSTMBidirectional) { + importLSTM(GLOW_DATA_PATH + "tests/models/onnxModels/lstmBidirectional.onnxtxt"); +} + +TEST(onnx, importLSTMForwardNoBias) { + importLSTM(GLOW_DATA_PATH + "tests/models/onnxModels/lstmForwardNoBias.onnxtxt"); +} + +TEST(onnx, importLSTMForwardNoState) { + importLSTM(GLOW_DATA_PATH + "tests/models/onnxModels/lstmForwardNoState.onnxtxt"); +} + +TEST(onnx, importLSTMForwardWithPeephole) { + importLSTM(GLOW_DATA_PATH + "tests/models/onnxModels/lstmForwardWithPeephole.onnxtxt"); +} diff --git a/utils/scripts/gen_onnx_lstm_model.py b/utils/scripts/gen_onnx_lstm_model.py new file mode 100644 index 0000000000..f2f99e9b46 --- /dev/null +++ b/utils/scripts/gen_onnx_lstm_model.py @@ -0,0 +1,462 @@ +# Copyright (c) Glow Contributors. See CONTRIBUTORS file. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import onnx +from onnx import helper +from onnx import TensorProto +import numpy as np + +import torch +import torch.nn +import torch.onnx + +# LSTM enums +LSTM_DIR_FORWARD = 'forward' +LSTM_DIR_REVERSE = 'reverse' +LSTM_DIR_BIDIRECTIONAL = 'bidirectional' +LSTM_DIRS = [LSTM_DIR_FORWARD, LSTM_DIR_REVERSE, LSTM_DIR_BIDIRECTIONAL] + + +# ONNX utility +def make_init(name, type, tensor): + return helper.make_tensor(name=name, data_type=type, dims=tensor.shape, vals=tensor.reshape(tensor.size).tolist()) + + +# Function to generate LSTM ONNX test model +def gen_lstm_onnx_test_model(model_path, seq_length, batch_size, hidden_size, input_size, direction, has_bias, + has_sequence_lens, has_initial_h, has_initial_c, has_peephole): + + # Validate parameters + assert direction in LSTM_DIRS, 'ONNX LSTM direction invalid!' 
    assert not has_sequence_lens, 'ONNX LSTM Variable sequence length not supported'

    # Get number of directions (2 only for bidirectional LSTM).
    num_directions = 2 if (direction == LSTM_DIR_BIDIRECTIONAL) else 1

    # Tensor sizes, following the ONNX LSTM operator layout.
    X_shape = [seq_length, batch_size, input_size]
    W_shape = [num_directions, 4 * hidden_size, input_size]
    R_shape = [num_directions, 4 * hidden_size, hidden_size]
    B_shape = [num_directions, 8 * hidden_size]
    sequence_lens_shape = [batch_size]
    initial_h_shape = [num_directions, batch_size, hidden_size]
    initial_c_shape = [num_directions, batch_size, hidden_size]
    P_shape = [num_directions, 3 * hidden_size]
    Y_shape = [seq_length, num_directions, batch_size, hidden_size]

    # Generate random inputs (weights are assumed concatenated in ONNX format: i,o,f,c).
    # Fixed seed so the generated model files are reproducible.
    np.random.seed(1)
    X = np.random.randn(*X_shape)
    W = np.random.randn(*W_shape)
    R = np.random.randn(*R_shape)
    # Optional inputs default to zeros when not requested.
    B = np.random.randn(*B_shape) if has_bias else np.zeros(B_shape)
    sequence_lens = np.random.randint(
        1, seq_length, batch_size) if has_sequence_lens else np.tile(seq_length, batch_size)
    initial_h = np.random.randn(
        *initial_h_shape) if has_initial_h else np.zeros(initial_h_shape)
    initial_c = np.random.randn(
        *initial_c_shape) if has_initial_c else np.zeros(initial_c_shape)
    P = np.random.randn(*P_shape) if has_peephole else np.zeros(P_shape)

    # Function to get all the weight components for the given direction.
    # Gates are sliced in the ONNX order: i (input), o (output), f (forget),
    # c (cell); B stacks the W-biases (first 4*hidden) then the R-biases.
    def get_weights(dir_idx):
        Wi = np.reshape(W[dir_idx, 0 * hidden_size: 1 *
                          hidden_size, :], [hidden_size, input_size])
        Wo = np.reshape(W[dir_idx, 1 * hidden_size: 2 *
                          hidden_size, :], [hidden_size, input_size])
        Wf = np.reshape(W[dir_idx, 2 * hidden_size: 3 *
                          hidden_size, :], [hidden_size, input_size])
        Wc = np.reshape(W[dir_idx, 3 * hidden_size: 4 *
                          hidden_size, :], [hidden_size, input_size])
        Ri = np.reshape(R[dir_idx, 0 * hidden_size: 1 *
                          hidden_size, :], [hidden_size, hidden_size])
        Ro = np.reshape(R[dir_idx, 1 * hidden_size: 2 *
                          hidden_size, :], [hidden_size, hidden_size])
        Rf = np.reshape(R[dir_idx, 2 * hidden_size: 3 *
                          hidden_size, :], [hidden_size, hidden_size])
        Rc = np.reshape(R[dir_idx, 3 * hidden_size: 4 *
                          hidden_size, :], [hidden_size, hidden_size])
        bWi = np.reshape(B[dir_idx, 0 * hidden_size: 1 *
                           hidden_size], [hidden_size])
        bWo = np.reshape(B[dir_idx, 1 * hidden_size: 2 *
                           hidden_size], [hidden_size])
        bWf = np.reshape(B[dir_idx, 2 * hidden_size: 3 *
                           hidden_size], [hidden_size])
        bWc = np.reshape(B[dir_idx, 3 * hidden_size: 4 *
                           hidden_size], [hidden_size])
        bRi = np.reshape(B[dir_idx, 4 * hidden_size: 5 *
                           hidden_size], [hidden_size])
        bRo = np.reshape(B[dir_idx, 5 * hidden_size: 6 *
                           hidden_size], [hidden_size])
        bRf = np.reshape(B[dir_idx, 6 * hidden_size: 7 *
                           hidden_size], [hidden_size])
        bRc = np.reshape(B[dir_idx, 7 * hidden_size: 8 *
                           hidden_size], [hidden_size])
        # Peepholes are tiled across the batch so they broadcast elementwise
        # against the [batch_size, hidden_size] cell state.
        Pi = np.tile(P[dir_idx, 0 * hidden_size: 1 *
                       hidden_size], (batch_size, 1))
        Po = np.tile(P[dir_idx, 1 * hidden_size: 2 *
                       hidden_size], (batch_size, 1))
        Pf = np.tile(P[dir_idx, 2 * hidden_size: 3 *
                       hidden_size], (batch_size, 1))
        return Wi, Wo, Wf, Wc, Ri, Ro, Rf, Rc, bWi, bWo, bWf, bWc, bRi, bRo, bRf, bRc, Pi, Po, Pf

    # Function to get PyTorch weights (which are in the i, f, c, o order).
    def get_torch_weights(dir_idx):
        Wi, Wo, Wf, Wc, Ri, Ro, Rf, Rc, bWi, bWo, bWf, bWc, bRi, bRo, bRf, bRc, Pi, Po, Pf = get_weights(
            dir_idx)
        W_torch = np.concatenate((Wi, Wf, Wc, Wo), 0)
        R_torch = np.concatenate((Ri, Rf, Rc, Ro), 0)
        bW_torch = np.concatenate((bWi, bWf, bWc, bWo), 0)
        bR_torch = np.concatenate((bRi, bRf, bRc, bRo), 0)
        return (W_torch, R_torch, bW_torch, bR_torch)

    # ----------------------------------------- COMPUTE pyTORCH REFERENCE ----------------------------------------------
    # Compute reference using Pytorch. Pytorch LSTM has only forward/bidirectional so we will do the reverse LSTM using
    # a Pytorch forward LSTM.
    # Single-layer PyTorch LSTM used as the golden reference (no peepholes).
    lstm = torch.nn.LSTM(input_size=input_size,
                         hidden_size=hidden_size,
                         num_layers=1,
                         bias=True,
                         batch_first=False,
                         dropout=0,
                         bidirectional=(direction == LSTM_DIR_BIDIRECTIONAL))

    # Get LSTM state dictionary.
    lstm_state_dict = lstm.state_dict()

    # Assign forward weights (reordered to PyTorch's i,f,c,o gate layout).
    forwardEnabled = direction in [LSTM_DIR_FORWARD, LSTM_DIR_BIDIRECTIONAL]
    if forwardEnabled:
        forward_dir_idx = 0
        (W_torch, R_torch, bW_torch, bR_torch) = get_torch_weights(forward_dir_idx)
        lstm_state_dict['weight_ih_l0'] = torch.tensor(
            W_torch, dtype=torch.float32)
        lstm_state_dict['weight_hh_l0'] = torch.tensor(
            R_torch, dtype=torch.float32)
        lstm_state_dict['bias_ih_l0'] = torch.tensor(
            bW_torch, dtype=torch.float32)
        lstm_state_dict['bias_hh_l0'] = torch.tensor(
            bR_torch, dtype=torch.float32)

    # Assign reverse weights. A pure-reverse LSTM is modeled as a forward
    # PyTorch LSTM (keys without '_reverse') whose input is later flipped;
    # a bidirectional LSTM stores its reverse cell under the '_reverse' keys.
    reverseEnabled = direction in [LSTM_DIR_REVERSE, LSTM_DIR_BIDIRECTIONAL]
    if reverseEnabled:
        if direction == LSTM_DIR_REVERSE:
            reverse_dir_idx = 0
            (W_torch, R_torch, bW_torch, bR_torch) = get_torch_weights(reverse_dir_idx)
            lstm_state_dict['weight_ih_l0'] = torch.tensor(
                W_torch, dtype=torch.float32)
            lstm_state_dict['weight_hh_l0'] = torch.tensor(
                R_torch, dtype=torch.float32)
            lstm_state_dict['bias_ih_l0'] = torch.tensor(
                bW_torch, dtype=torch.float32)
            lstm_state_dict['bias_hh_l0'] = torch.tensor(
                bR_torch, dtype=torch.float32)
        else:
            reverse_dir_idx = 1
            (W_torch, R_torch, bW_torch, bR_torch) = get_torch_weights(reverse_dir_idx)
            lstm_state_dict['weight_ih_l0_reverse'] = torch.tensor(
                W_torch, dtype=torch.float32)
            lstm_state_dict['weight_hh_l0_reverse'] = torch.tensor(
                R_torch, dtype=torch.float32)
            lstm_state_dict['bias_ih_l0_reverse'] = torch.tensor(
                bW_torch, dtype=torch.float32)
            lstm_state_dict['bias_hh_l0_reverse'] = torch.tensor(
                bR_torch, dtype=torch.float32)

    # Set LSTM state dictionary.
    lstm.load_state_dict(lstm_state_dict, strict=True)

    # Perform inference. For the reverse direction, flip the input along the
    # time axis, run the forward cell, then flip the output back.
    X_torch = torch.tensor(X, dtype=torch.float32)
    initial_h_torch = torch.tensor(initial_h, dtype=torch.float32)
    initial_c_torch = torch.tensor(initial_c, dtype=torch.float32)
    if direction == LSTM_DIR_REVERSE:
        Y, (next_h, next_c) = lstm(X_torch.flip(
            [0]), (initial_h_torch, initial_c_torch))
        Y = Y.flip([0])
    else:
        Y, (next_h, next_c) = lstm(X_torch, (initial_h_torch, initial_c_torch))

    # Reshape output to ONNX format [seq_length, num_directions, batch_size, hidden_size].
    Y_ref = Y.detach().numpy()
    Y_ref = np.reshape(
        Y_ref, [seq_length, batch_size, num_directions, hidden_size])
    Y_ref = np.transpose(Y_ref, [0, 2, 1, 3])

    # Reshape states to ONNX format.
    Y_h_ref = next_h.detach().numpy()
    Y_c_ref = next_c.detach().numpy()

    # --------------------------------------- COMPUTE PYTHON-NUMPY REFERENCE -------------------------------------------
    # Create X slices, one [batch_size, input_size] slice per time step.
    Xslices = list()
    for t in range(seq_length):
        Xslices.append(np.reshape(X[t, :, :], [batch_size, input_size]))

    # Function to compute one LSTM cell, following the ONNX LSTM equations
    # (with peepholes): i/f gates peek at the previous cell state, the o gate
    # peeks at the updated cell state.
    def compute_lstm(forward):
        # Direction index into the weights: a pure-reverse LSTM stores its
        # single direction at index 0.
        dir_idx = 0 if forward else (0 if direction == LSTM_DIR_REVERSE else 1)
        Wi, Wo, Wf, Wc, Ri, Ro, Rf, Rc, bWi, bWo, bWf, bWc, bRi, bRo, bRf, bRc, Pi, Po, Pf = get_weights(
            dir_idx)

        # Default ONNX LSTM activations: f = sigmoid, g = h = tanh.
        def f(x): return (1 / (1 + np.exp(-x)))

        def g(x): return np.tanh(x)

        def h(x): return np.tanh(x)

        def mm(x, w): return np.matmul(x, w.transpose())
        Ht = np.reshape(initial_h[dir_idx, :, :], [batch_size, hidden_size])
        Ct = np.reshape(initial_c[dir_idx, :, :], [batch_size, hidden_size])

        Yslices = list()
        for t in range(seq_length):
            # Walk the sequence backwards for the reverse cell.
            xt = Xslices[t] if forward else Xslices[seq_length - 1 - t]
            ft = f(mm(xt, Wf) + bWf + mm(Ht, Rf) + bRf + Pf * Ct)
            it = f(mm(xt, Wi) + bWi + mm(Ht, Ri) + bRi + Pi * Ct)
            ctild = g(mm(xt, Wc) + bWc + mm(Ht, Rc) + bRc)
            Ct = ft * Ct + it * ctild
            ot = f(mm(xt, Wo) + bWo + mm(Ht, Ro) + bRo + Po * Ct)
            Ht = ot * h(Ct)
            Yslices.append(Ht)
        return Yslices, Ht, Ct

    Yslices = list()
    Hslices = list()
    Cslices = list()

    # Compute forward LSTM.
    forwardYslices = list()
    if forwardEnabled:
        Yt, Ht, Ct = compute_lstm(True)
        forwardYslices += Yt
        Hslices.append(Ht)
        Cslices.append(Ct)

    # Compute reverse LSTM.
    reverseYslices = list()
    if reverseEnabled:
        Yt, Ht, Ct = compute_lstm(False)
        reverseYslices += Yt
        Hslices.append(Ht)
        Cslices.append(Ct)

    # Concatenate slices. reverseYslices was produced while walking the
    # sequence backwards, so index seq_length-1-t recovers original time t.
    for t in range(seq_length):
        if forwardEnabled:
            Yslices.append(forwardYslices[t])
        if reverseEnabled:
            Yslices.append(reverseYslices[seq_length - 1 - t])
    Y_ref_np = np.concatenate(Yslices, 0).reshape(
        [seq_length, num_directions, batch_size, hidden_size])
    Y_h_ref_np = np.concatenate(Hslices, 0).reshape(
        [num_directions, batch_size, hidden_size])
    Y_c_ref_np = np.concatenate(Cslices, 0).reshape(
        [num_directions, batch_size, hidden_size])

    # Use numpy implementation when using peepholes (PyTorch has no peephole
    # support), else cross-check the two references against each other.
    if has_peephole:
        Y_ref = Y_ref_np
        Y_h_ref = Y_h_ref_np
        Y_c_ref = Y_c_ref_np
    else:
        assert np.max(np.abs(Y_ref - Y_ref_np)
                      ) < 1e-6, "Mismatch between Pytorch and Numpy LSTM implementation"
        assert np.max(np.abs(Y_h_ref - Y_h_ref_np)
                      ) < 1e-6, "Mismatch between Pytorch and Numpy LSTM implementation"
        assert np.max(np.abs(Y_c_ref - Y_c_ref_np)
                      ) < 1e-6, "Mismatch between Pytorch and Numpy LSTM implementation"

    # ---------------------------------------------- NODE DEFINITION --------------------------------------------------
    # Node inputs; empty strings mark omitted optional inputs
    # (sequence_lens is always omitted).
    node_inputs = ['X',
                   'W',
                   'R',
                   'B' if has_bias else '',
                   '',
                   'initial_h' if has_initial_h else '',
                   'initial_c' if has_initial_c else '',
                   'P' if has_peephole else '']

    # Node outputs: only the full sequence output Y is requested here.
    node_outputs = ['Y']

    # LSTM node definition.
    lstm_node_def = onnx.helper.make_node(
        'LSTM',
        name='lstm',
        inputs=node_inputs,
        outputs=node_outputs,
        hidden_size=hidden_size,
        direction=direction
    )

    # Error node definition: Y_err = Y - Y_ref.
    err_node_def = onnx.helper.make_node(
        'Sub',
        name='error',
        inputs=['Y', 'Y_ref'],
        outputs=['Y_err']
    )

    # --------------------------------------------- GRAPH DEFINITION --------------------------------------------------
    graph_input = list()
    graph_init = list()
    graph_output = list()

    # LSTM inputs.
    graph_input.append(helper.make_tensor_value_info(
        'X', TensorProto.FLOAT, X_shape))
    graph_input.append(helper.make_tensor_value_info(
        'W', TensorProto.FLOAT, W_shape))
    graph_input.append(helper.make_tensor_value_info(
        'R', TensorProto.FLOAT, R_shape))
    if has_bias:
        graph_input.append(helper.make_tensor_value_info(
            'B', TensorProto.FLOAT, B_shape))
    if has_sequence_lens:
        graph_input.append(helper.make_tensor_value_info(
            'sequence_lens', TensorProto.INT32, sequence_lens_shape))
    if has_initial_h:
        graph_input.append(helper.make_tensor_value_info(
            'initial_h', TensorProto.FLOAT, initial_h_shape))
    if has_initial_c:
        graph_input.append(helper.make_tensor_value_info(
            'initial_c', TensorProto.FLOAT, initial_c_shape))
    if has_peephole:
        graph_input.append(helper.make_tensor_value_info(
            'P', TensorProto.FLOAT, P_shape))

    # Reference input.
    graph_input.append(helper.make_tensor_value_info(
        'Y_ref', TensorProto.FLOAT, Y_shape))

    # LSTM initializers: every input is also baked into the model so the
    # consumer needs no external data.
    graph_init.append(make_init('X', TensorProto.FLOAT, X))
    graph_init.append(make_init('W', TensorProto.FLOAT, W))
    graph_init.append(make_init('R', TensorProto.FLOAT, R))
    if has_bias:
        graph_init.append(make_init('B', TensorProto.FLOAT, B))
    if has_sequence_lens:
        graph_init.append(
            make_init('sequence_lens', TensorProto.INT32, sequence_lens))
    if has_initial_h:
        graph_init.append(make_init('initial_h', TensorProto.FLOAT, initial_h))
    if has_initial_c:
        graph_init.append(make_init('initial_c', TensorProto.FLOAT, initial_c))
    if has_peephole:
        graph_init.append(make_init('P', TensorProto.FLOAT, P))

    # Reference initializer.
    graph_init.append(make_init('Y_ref', TensorProto.FLOAT, Y_ref))

    # Graph outputs: the single error tensor checked by the unit test.
    graph_output.append(helper.make_tensor_value_info(
        'Y_err', TensorProto.FLOAT, Y_shape))

    # Define graph (GraphProto).
    graph_name = 'lstm_test'
    graph_def = helper.make_graph(
        [lstm_node_def, err_node_def], graph_name, inputs=graph_input, outputs=graph_output)

    # Set initializers.
    graph_def.initializer.extend(graph_init)

    # --------------------------------------------- MODEL DEFINITION --------------------------------------------------
    # Define model (ModelProto).
    model_def = helper.make_model(graph_def, producer_name='onnx-lstm')

    # Check model.
    onnx.checker.check_model(model_def)

    # Print model as text proto (.onnxtxt).
    with open(model_path, 'w') as f:
        f.write(str(model_def))


# Forward LSTM.
gen_lstm_onnx_test_model(model_path='lstmForward.onnxtxt',
                         seq_length=2,
                         batch_size=5,
                         hidden_size=4,
                         input_size=3,
                         direction='forward',
                         has_bias=True,
                         has_sequence_lens=False,
                         has_initial_h=True,
                         has_initial_c=True,
                         has_peephole=False)

# Reverse LSTM.
gen_lstm_onnx_test_model(model_path='lstmReverse.onnxtxt',
                         seq_length=2,
                         batch_size=5,
                         hidden_size=4,
                         input_size=3,
                         direction='reverse',
                         has_bias=True,
                         has_sequence_lens=False,
                         has_initial_h=True,
                         has_initial_c=True,
                         has_peephole=False)

# Bidirectional LSTM.
gen_lstm_onnx_test_model(model_path='lstmBidirectional.onnxtxt',
                         seq_length=2,
                         batch_size=5,
                         hidden_size=4,
                         input_size=3,
                         direction='bidirectional',
                         has_bias=True,
                         has_sequence_lens=False,
                         has_initial_h=True,
                         has_initial_c=True,
                         has_peephole=False)

# Forward no bias LSTM.
gen_lstm_onnx_test_model(model_path='lstmForwardNoBias.onnxtxt',
                         seq_length=1,
                         batch_size=5,
                         hidden_size=4,
                         input_size=3,
                         direction='forward',
                         has_bias=False,
                         has_sequence_lens=False,
                         has_initial_h=True,
                         has_initial_c=True,
                         has_peephole=False)

# Forward no state LSTM.
gen_lstm_onnx_test_model(model_path='lstmForwardNoState.onnxtxt',
                         seq_length=1,
                         batch_size=5,
                         hidden_size=4,
                         input_size=3,
                         direction='forward',
                         has_bias=True,
                         has_sequence_lens=False,
                         has_initial_h=False,
                         has_initial_c=False,
                         has_peephole=False)

# Forward with peephole LSTM.
gen_lstm_onnx_test_model(model_path='lstmForwardWithPeephole.onnxtxt',
                         seq_length=1,
                         batch_size=5,
                         hidden_size=4,
                         input_size=3,
                         direction='forward',
                         has_bias=True,
                         has_sequence_lens=False,
                         has_initial_h=True,
                         has_initial_c=True,
                         has_peephole=True)