[ONNX] Import LSTM #3713

Closed
wants to merge 5 commits into from
47 changes: 47 additions & 0 deletions include/glow/Graph/Graph.h
@@ -510,6 +510,14 @@ class Function final : public Named {
NodeValue input, Storage *W,
Storage *B, unsigned_t axis = 1);

/// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
/// \p W, bias \p B. If \p input is not 2 dimensional then it is flattened
/// along \p axis. Note, output type and outputDepth are inferred based on
/// the input types.
FullyConnectedNode *createFullyConnected(llvm::StringRef name,
NodeValue input, NodeValue W,
NodeValue B, unsigned_t axis = 1);

/// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
/// \p W, bias \p B, and \p outTy. If \p input is not 2 dimensional then it is
/// flattened along \p axis. Note, outputDepth is inferred based on \p outTy.
@@ -1337,6 +1345,45 @@ class Function final : public Named {
const llvm::ArrayRef<NodeValue> inputs, unsigned batchSize,
unsigned hiddenSize, unsigned outputSize,
std::vector<NodeValue> &outputs);

/// Definition for the activation function of an LSTM module.
using LstmActivation = std::function<Node *(llvm::StringRef, Node *)>;

/// Type definition for the direction of an LSTM module.
enum class LstmDirection {
Forward,
Reverse,
Bidirectional,
};

/// Create an unrolled multi-layer LSTM according to the ONNX definition. The
/// LSTM has the following inputs:
/// - input \p X with size [S, B, ISize].
/// - weights \p W with size [N, 4 * HSize, ISize].
/// - recurrence weights \p R with size [N, 4 * HSize, HSize].
/// - bias weights \p B with size [N, 8 * HSize].
/// - initial hidden state \p initial_h with size [N, B, HSize].
/// - initial cell state \p initial_c with size [N, B, HSize].
/// - peephole weights \p P with size [N, 3 * HSize].
/// where S is the sequence length, N is the number of directions, B is the
/// batch size, ISize is the input size and HSize is the hidden size.
/// The LSTM has the following outputs:
/// - output \p Y with size [S, N, B, HSize]
/// - final hidden state \p Y_h with size [N, B, HSize].
/// - final cell state \p Y_c with size [N, B, HSize].
/// The direction of the instantiated LSTM is given by \p direction. The LSTM
/// will use the activation functions given in \p activations, which holds:
/// - [f,g,h] in case the LSTM is unidirectional (3 functions).
/// - [f,g,h] for the forward cell followed by [f,g,h] for the reverse cell in
/// case the LSTM is bidirectional (6 functions).
/// The inputs \p B and \p P are optional (assumed 0 if nullptr is provided).
/// The names of all the nodes created are prefixed with \p namePrefix.
void createONNXLSTM(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
NodeValue R, NodeValue B, NodeValue initial_h,
NodeValue initial_c, NodeValue P, NodeValue &Y,
NodeValue &Y_h, NodeValue &Y_c, unsigned hiddenSize,
LstmDirection direction,
std::vector<LstmActivation> &activations);
/// @}

/// Create a TraceEvent in the runtime profile, which triggers collection of
4 changes: 4 additions & 0 deletions include/glow/Importer/ONNXModelLoader.h
@@ -150,6 +150,10 @@ class ONNXModelLoader
Error loadWhere(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);

/// Load LSTM ONNX operator.
Error loadLSTM(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);

/// Load Glow specific operators, not defined in ONNX format
/// Load Glow CmpEQ operator.
Error loadCmpEQ(const ONNX_NAMESPACE::NodeProto &op,
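The shape rules stated in the `createONNXLSTM` doc comment can be sanity-checked with a small standalone sketch. It does not use Glow at all: `Direction`, `Dims`, `LstmShapes`, `numDirections`, `numActivations` and `onnxLstmShapes` are hypothetical names introduced only for illustration, and the code assumes nothing beyond the sizes written in the comment (S = sequence length, N = number of directions, B = batch size, ISize = input size, HSize = hidden size).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper types, NOT part of Glow. They only mirror the shape
// rules written in the createONNXLSTM doc comment.
enum class Direction { Forward, Reverse, Bidirectional };

using Dims = std::vector<std::size_t>;

struct LstmShapes {
  Dims X, W, R, B, initialH, initialC, P; // inputs
  Dims Y, Yh, Yc;                         // outputs
};

// N in the doc comment: 1 for Forward or Reverse, 2 for Bidirectional.
inline std::size_t numDirections(Direction d) {
  return d == Direction::Bidirectional ? 2 : 1;
}

// Expected length of the activations vector: [f,g,h] per direction.
inline std::size_t numActivations(Direction d) {
  return 3 * numDirections(d);
}

// Computes the documented tensor shapes for the given sizes.
inline LstmShapes onnxLstmShapes(std::size_t S, std::size_t B,
                                 std::size_t ISize, std::size_t HSize,
                                 Direction d) {
  assert(S > 0 && B > 0 && ISize > 0 && HSize > 0 && "sizes must be nonzero");
  const std::size_t N = numDirections(d);
  LstmShapes s;
  s.X = {S, B, ISize};
  s.W = {N, 4 * HSize, ISize}; // four gates stacked along dimension 1
  s.R = {N, 4 * HSize, HSize};
  s.B = {N, 8 * HSize};        // input bias and recurrence bias, 4 gates each
  s.initialH = {N, B, HSize};
  s.initialC = {N, B, HSize};
  s.P = {N, 3 * HSize};        // peephole weights for three gates
  s.Y = {S, N, B, HSize};
  s.Yh = {N, B, HSize};
  s.Yc = {N, B, HSize};
  return s;
}
```

For example, calling `onnxLstmShapes(5, 2, 8, 16, Direction::Bidirectional)` gives N = 2, so `W` has shape [2, 64, 8] and `Y` has shape [5, 2, 2, 16], and the activations vector is expected to hold 6 functions.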