-
Notifications
You must be signed in to change notification settings - Fork 364
Add support for bool in elementwise ops #321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: inocsin <[email protected]>
@@ -359,8 +359,22 @@ auto element_wise_registrations TRTORCH_UNUSED = | |||
// Should implement self * other | |||
auto self = args[0].ITensorOrFreeze(ctx); | |||
auto other = args[1].ITensorOrFreeze(ctx); | |||
auto mul = | |||
nvinfer1::ILayer* mul = nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@inocsin Are these changes that would only be relevant to mul or other ops in general? Also did you write tests for this? I didnt see them in the commits
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, I only find the Bool * Int operation in my models. I think other operation like Int / Bool or Int +/- Bool doesn't make any sense. I will add test case later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The input type of trtorch is set to float in default, I have to modify the TRTorch/core/conversion/conversion.cpp:150
auto trt_in = ctx->net->addInput(name.c_str(), ctx->input_type, dims.input_shape);
to support input type of Bool
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can try this #327
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does it mean to do bool * int? should that give you a bool out? like what does False * 8 mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[True, False] * [8, 8] = [8, 0], this can be a mask operation, check demo graph here #327
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can close this PR if we merge #327
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/core/conversion/converters/impl/element_wise.cpp b/tmp/changes.txt
index da582c9..f35c52e 100644
--- a/workspace/core/conversion/converters/impl/element_wise.cpp
+++ b/tmp/changes.txt
@@ -360,7 +360,7 @@ auto element_wise_registrations TRTORCH_UNUSED =
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
nvinfer1::ILayer* mul = nullptr;
- if (self->getType() ==nvinfer1::DataType::kBOOL || other->getType() == nvinfer1::DataType::kBOOL) {
+ if (self->getType() == nvinfer1::DataType::kBOOL || other->getType() == nvinfer1::DataType::kBOOL) {
auto self_id = ctx->net->addIdentity(*self);
auto other_id = ctx->net->addIdentity(*other);
if (self->getType() == nvinfer1::DataType::kBOOL) {
@@ -369,11 +369,15 @@ auto element_wise_registrations TRTORCH_UNUSED =
if (other->getType() == nvinfer1::DataType::kBOOL) {
other_id->getOutput(0)->setType(nvinfer1::DataType::kINT32);
}
- mul =
- add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self_id->getOutput(0), other_id->getOutput(0), util::node_info(n));
+ mul = add_elementwise(
+ ctx,
+ nvinfer1::ElementWiseOperation::kPROD,
+ self_id->getOutput(0),
+ other_id->getOutput(0),
+ util::node_info(n));
} else {
mul =
- add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
+ add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
}
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
ERROR: Some files do not conform to style guidelines
We support aten::to operator now which should handle this. Check out https://github.com/NVIDIA/TRTorch/blob/master/tests/core/conversion/converters/test_cast.cpp for usage. |
Description
Adds support for bool tensors in element wise ops
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: