Skip to content

Conversation

narendasan
Copy link
Collaborator

Description

Adds support for bool tensors in element wise ops

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler labels Feb 4, 2021
@@ -359,8 +359,22 @@ auto element_wise_registrations TRTORCH_UNUSED =
// Should implement self * other
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto mul =
nvinfer1::ILayer* mul = nullptr;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@inocsin Are these changes that would only be relevant to mul or other ops in general? Also did you write tests for this? I didnt see them in the commits

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, I only find the Bool * Int operation in my models. I think other operation like Int / Bool or Int +/- Bool doesn't make any sense. I will add test case later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input type of trtorch is set to float in default, I have to modify the TRTorch/core/conversion/conversion.cpp:150

auto trt_in = ctx->net->addInput(name.c_str(), ctx->input_type, dims.input_shape);

to support input type of Bool

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can try this #327

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it mean to do bool * int? should that give you a bool out? like what does False * 8 mean?

Copy link
Contributor

@inocsin inocsin Feb 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[True, False] * [8, 8] = [8, 0], this can be a mask operation, check demo graph here #327

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can close this PR if we merge #327

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/workspace/core/conversion/converters/impl/element_wise.cpp b/tmp/changes.txt
index da582c9..f35c52e 100644
--- a/workspace/core/conversion/converters/impl/element_wise.cpp
+++ b/tmp/changes.txt
@@ -360,7 +360,7 @@ auto element_wise_registrations TRTORCH_UNUSED =
                    auto self = args[0].ITensorOrFreeze(ctx);
                    auto other = args[1].ITensorOrFreeze(ctx);
                    nvinfer1::ILayer* mul = nullptr;
-                    if (self->getType() ==nvinfer1::DataType::kBOOL || other->getType() == nvinfer1::DataType::kBOOL) {
+                    if (self->getType() == nvinfer1::DataType::kBOOL || other->getType() == nvinfer1::DataType::kBOOL) {
                      auto self_id = ctx->net->addIdentity(*self);
                      auto other_id = ctx->net->addIdentity(*other);
                      if (self->getType() == nvinfer1::DataType::kBOOL) {
@@ -369,11 +369,15 @@ auto element_wise_registrations TRTORCH_UNUSED =
                      if (other->getType() == nvinfer1::DataType::kBOOL) {
                        other_id->getOutput(0)->setType(nvinfer1::DataType::kINT32);
                      }
-                      mul =
-                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self_id->getOutput(0), other_id->getOutput(0), util::node_info(n));
+                      mul = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kPROD,
+                          self_id->getOutput(0),
+                          other_id->getOutput(0),
+                          util::node_info(n));
                    } else {
                      mul =
-                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
+                          add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
                    }
                    TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);

ERROR: Some files do not conform to style guidelines

@peri044
Copy link
Collaborator

peri044 commented Aug 24, 2021

We support aten::to operator now which should handle this. Check out https://github.com/NVIDIA/TRTorch/blob/master/tests/core/conversion/converters/test_cast.cpp for usage.

@peri044 peri044 closed this Aug 24, 2021
@narendasan narendasan deleted the bool_elem_wise branch February 24, 2022 00:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants