[ONNX Importer] Add RNN and GRU. #3847

Closed · wants to merge 7 commits
83 changes: 71 additions & 12 deletions include/glow/Graph/Graph.h
@@ -1355,18 +1355,74 @@ class Function final : public Named {
unsigned hiddenSize, unsigned outputSize,
std::vector<NodeValue> &outputs);

/// Type definition for the direction of an RNN module (RNN, GRU, LSTM).
enum class RnnDirection {
Forward,
Reverse,
Bidirectional,
};

/// Definition for a lambda used to create an activation node for RNN modules.
using RnnActivation = std::function<Node *(llvm::StringRef, Node *)>;

/// Create an unrolled multi-layer RNN according to the ONNX definition:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN
/// The RNN has the following inputs:
/// - input \p X with size [S, B, ISize].
/// - weights \p W with size [N, HSize, ISize].
/// - recurrence weights \p R with size [N, HSize, HSize].
/// - bias weights \p B with size [N, 2 * HSize].
/// - initial hidden state \p initial_h with size [N, B, HSize].
/// where S is the sequence length, N is the number of directions, B is the
/// batch size, ISize is the input size and HSize is the hidden size.
/// The RNN has the following outputs:
/// - output \p Y with size [S, N, B, HSize].
/// - final hidden state \p Y_h with size [N, B, HSize].
/// The direction of the instantiated RNN is given by \p direction. The RNN
/// will use the activation functions defined by the \p activations array:
/// - [f] in case the RNN is unidirectional (1 function).
/// - [f] for the forward cell followed by [f] for the reverse cell in
/// case the RNN is bidirectional (2 functions).
/// The input \p B is optional (assumed 0 if nullptr is provided).
/// The names of all the nodes created are prefixed with \p namePrefix.
void createOnnxRNN(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
NodeValue R, NodeValue B, NodeValue initial_h,
NodeValue &Y, NodeValue &Y_h, unsigned hiddenSize,
RnnDirection direction,
std::vector<RnnActivation> &activations);

/// Create an unrolled multi-layer GRU according to the ONNX definition:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU
/// The GRU has the following inputs:
/// - input \p X with size [S, B, ISize].
/// - weights \p W with size [N, 3 * HSize, ISize].
/// - recurrence weights \p R with size [N, 3 * HSize, HSize].
/// - bias weights \p B with size [N, 6 * HSize].
/// - initial hidden state \p initial_h with size [N, B, HSize].
/// where S is the sequence length, N is the number of directions, B is the
/// batch size, ISize is the input size and HSize is the hidden size.
/// The GRU has the following outputs:
/// - output \p Y with size [S, N, B, HSize].
/// - final hidden state \p Y_h with size [N, B, HSize].
/// The direction of the instantiated GRU is given by \p direction. The GRU
/// will use the activation functions defined by the \p activations array:
/// - [f,g] in case the GRU is unidirectional (2 functions).
/// - [f,g] for the forward cell followed by [f,g] for the reverse cell in
/// case the GRU is bidirectional (4 functions).
/// The input \p B is optional (assumed 0 if nullptr is provided).
/// The names of all the nodes created are prefixed with \p namePrefix.
/// The boolean parameter \p linearBeforeReset defines whether the reset
/// for the previous hidden state occurs before/after the linear expression.
void createOnnxGRU(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
NodeValue R, NodeValue B, NodeValue initial_h,
NodeValue &Y, NodeValue &Y_h, unsigned hiddenSize,
RnnDirection direction,
std::vector<RnnActivation> &activations,
bool linearBeforeReset = false);

/// Create an unrolled multi-layer LSTM according to the ONNX definition:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
/// The LSTM has the following inputs:
/// - input \p X with size [S, B, ISize].
/// - weights \p W with size [N, 4 * HSize, ISize].
/// - recurrence weights \p R with size [N, 4 * HSize, HSize].
@@ -1377,22 +1433,25 @@ class Function final : public Named {
/// where S is the sequence length, N is the number of directions, B is the
/// batch size, ISize is the input size and HSize is the hidden size.
/// The LSTM has the following outputs:
/// - output \p Y with size [S, N, B, HSize].
/// - final hidden state \p Y_h with size [N, B, HSize].
/// - final cell state \p Y_c with size [N, B, HSize].
/// The direction of the instantiated LSTM is given by \p direction. The LSTM
/// will use the activation functions defined by the \p activations array:
/// - [f,g,h] in case the LSTM is unidirectional (3 functions).
/// - [f,g,h] for the forward cell followed by [f,g,h] for the reverse cell in
/// case the LSTM is bidirectional (6 functions).
/// The inputs \p B and \p P are optional (assumed 0 if nullptr is provided).
/// The names of all the nodes created are prefixed with \p namePrefix.
/// The boolean parameter \p inputForget defines whether the input and forget
/// gates should be coupled (compute the input gate from the forget gate).
void createOnnxLSTM(llvm::StringRef namePrefix, NodeValue X, NodeValue W,
NodeValue R, NodeValue B, NodeValue initial_h,
NodeValue initial_c, NodeValue P, NodeValue &Y,
NodeValue &Y_h, NodeValue &Y_c, unsigned hiddenSize,
RnnDirection direction,
std::vector<RnnActivation> &activations,
bool inputForget = false);
/// @}

/// Create a TraceEvent in the runtime profile, which triggers collection of
8 changes: 8 additions & 0 deletions include/glow/Importer/ONNXModelLoader.h
@@ -154,6 +154,14 @@ class ONNXModelLoader
Error loadWhere(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);

/// Load RNN ONNX operator.
Error loadRNN(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);

/// Load GRU ONNX operator.
Error loadGRU(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);

/// Load LSTM ONNX operator.
Error loadLSTM(const ONNX_NAMESPACE::NodeProto &op,
const ArgumentDictionaryTy &dict);