Skip to content

[ONNX Importer] Add RNN and GRU. #3847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from

Conversation

mciprian13
Copy link
Contributor

Summary
Add RNN and GRU modules in the ONNX Importer.
This PR is similar to #3713.
I added both RNN and GRU in the same PR because they are very similar.

Documentation
None

Test Plan
Add ONNX models (and Python generator scripts) with PyTorch numerical references.

@mciprian13
Copy link
Contributor Author

@jfix71 As promised, I come back with double the fun: RNN and GRU.

Copy link
Contributor

@jfix71 jfix71 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mciprian13 Looks good mostly, few questions/nits. Also wondering about if it's at all possible to reduce some of the logic here between GRU/LSTM/RNN, but understand they're different ops so might be a bit difficult/messier that way.

@@ -124,6 +124,16 @@ TEST(exporter, onnxModels) {
llvm::outs() << "Ignore output file: " << name << "\n";
continue;
}
if (name.find("rnn") != std::string::npos) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why we need to disable for the exporter test here -- and same for gru and lstm?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The models don not work when reloading (they are loaded and written without errors but when reloading some errors pop up). I debugged a little bit but I can`t figure out what is the problem. The errors look like this:

Error message: could not find constant with name lstm_Ro_transp
Error message: No node under name lstm_Y_c

Can anyone help with this? Or maybe raise a separate issue?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, yeah we need to improve the testing for the exporter, this can be done separately, thanks.

@mciprian13
Copy link
Contributor Author

mciprian13 commented Dec 12, 2019

Monsieur @jfix71:

  • I updated the little nits you mentioned.
  • I refactored a little bit the code from OnnnModelLoader.cpp to reuse more code between the 3 modules (the part when reading the direction and the one reading the activations from the proto).
  • As a bonus I also added the "input_forget" feature for the LSTM (plus unit test).
  • The rest of the code I don`t think we can safely refactor for code reuse since the 3 modules (RNN, GRU, LSTM) might evolve slightly different in the future and the refactored code would look messy.
  • The remaining problem is the one related to the ONNX Exporter test which I cannot figure out.

Copy link
Contributor

@jfix71 jfix71 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mciprian13!

@@ -124,6 +124,16 @@ TEST(exporter, onnxModels) {
llvm::outs() << "Ignore output file: " << name << "\n";
continue;
}
if (name.find("rnn") != std::string::npos) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, yeah we need to improve the testing for the exporter, this can be done separately, thanks.

Copy link

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jfix71 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link

@jfix71 merged this pull request in 306a648.

@mciprian13 mciprian13 deleted the Add_ONNX_GRU branch December 20, 2019 13:06
vdantu pushed a commit to vdantu/glow that referenced this pull request Jul 12, 2020
Summary:
**Summary**
Add RNN and GRU modules in the ONNX Importer.
This PR is similar to pytorch#3713.
I added both RNN and GRU in the same PR because they are very similar.

**Documentation**
None

**Test Plan**
Add ONNX models (and Python generator scripts) with PyTorch numerical references.
Pull Request resolved: pytorch#3847

Differential Revision: D18987803

Pulled By: jfix71

fbshipit-source-id: 09f760aa57cd416bec91f8b67c83f315cb5acfff
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants