-
Notifications
You must be signed in to change notification settings - Fork 699
[ONNX Importer] Add RNN and GRU. #3847
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@jfix71 As promised, I come back with double the fun: RNN and GRU. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mciprian13 Looks good mostly, few questions/nits. Also wondering about if it's at all possible to reduce some of the logic here between GRU/LSTM/RNN, but understand they're different ops so might be a bit difficult/messier that way.
@@ -124,6 +124,16 @@ TEST(exporter, onnxModels) { | |||
llvm::outs() << "Ignore output file: " << name << "\n"; | |||
continue; | |||
} | |||
if (name.find("rnn") != std::string::npos) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious why we need to disable for the exporter test here -- and same for gru and lstm?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The models don not work when reloading (they are loaded and written without errors but when reloading some errors pop up). I debugged a little bit but I can`t figure out what is the problem. The errors look like this:
Error message: could not find constant with name lstm_Ro_transp
Error message: No node under name lstm_Y_c
Can anyone help with this? Or maybe raise a separate issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, yeah we need to improve the testing for the exporter, this can be done separately, thanks.
…. Add "input_forget" capability for LSTM.
Monsieur @jfix71:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @mciprian13!
@@ -124,6 +124,16 @@ TEST(exporter, onnxModels) { | |||
llvm::outs() << "Ignore output file: " << name << "\n"; | |||
continue; | |||
} | |||
if (name.find("rnn") != std::string::npos) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, yeah we need to improve the testing for the exporter, this can be done separately, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jfix71 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: **Summary** Add RNN and GRU modules in the ONNX Importer. This PR is similar to pytorch#3713. I added both RNN and GRU in the same PR because they are very similar. **Documentation** None **Test Plan** Add ONNX models (and Python generator scripts) with PyTorch numerical references. Pull Request resolved: pytorch#3847 Differential Revision: D18987803 Pulled By: jfix71 fbshipit-source-id: 09f760aa57cd416bec91f8b67c83f315cb5acfff
Summary
Add RNN and GRU modules in the ONNX Importer.
This PR is similar to #3713.
I added both RNN and GRU in the same PR because they are very similar.
Documentation
None
Test Plan
Add ONNX models (and Python generator scripts) with PyTorch numerical references.