-
Notifications
You must be signed in to change notification settings - Fork 444
Add Support for MatMulInteger #2072
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for your contributions. |
I wanted to ask if this change will require unit tests? If so might as well address that also. |
Yes, please add a unit test into here. |
Thanks for your contribution. Could you please follow the details of DCO failure to make it pass? And also please add a new unit test here. |
MatMulInteger was supported in ONNX opset v10 (not checked in proposed change, the error can be addressed on save), this specific type combination is support in TensorFlow, but the node type not identified and handled properly here. Handles onnx#2071 Signed-off-by: Gregory Morse <[email protected]>
Signed-off-by: Gregory Morse <[email protected]>
Signed-off-by: Gregory Morse <[email protected]>
Signed-off-by: Gregory Morse <[email protected]>
Signed-off-by: Gregory Morse <[email protected]>
Signed-off-by: Gregory Morse <[email protected]>
Signed-off-by: Gregory Morse <[email protected]>
Signed-off-by: Gregory Morse <[email protected]>
Signed-off-by: Gregory Morse <[email protected]>
Signed-off-by: Gregory Morse <[email protected]>
All set, kindly ask for a review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your efforts, LGTM!
* Add Support for MatMulInteger MatMulInteger was supported in ONNX opset v10 (not checked in proposed change, the error can be addressed on save), this specific type combination is support in TensorFlow, but the node type not identified and handled properly here. Handles onnx#2071 Signed-off-by: Gregory Morse <[email protected]> * Update math.py Signed-off-by: Gregory Morse <[email protected]> * Update support_status.md Signed-off-by: Gregory Morse <[email protected]> * Update test_backend.py Signed-off-by: Gregory Morse <[email protected]> Signed-off-by: Gregory Morse <[email protected]> Co-authored-by: Jay Zhang <[email protected]> Signed-off-by: Salvetti, Francesco <[email protected]>
MatMulInteger was supported in ONNX opset v10, this specific type combination is support in TensorFlow, but the node type not identified and handled properly here.
Handles #2071