Skip to content

Commit 9a09514

Browse files
committed
fix: fix aten::sub.scalar operator
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 2760b8d commit 9a09514

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -191,23 +191,16 @@ auto element_wise_registrations TRTORCH_UNUSED =
191191
auto self = args[0].ITensorOrFreeze(ctx);
192192
auto other = args[1].unwrapToScalar().to<float>();
193193
auto alpha = args[2].unwrapToScalar().to<float>();
194+
auto scaled_val = other * alpha;
194195

195-
auto rhs = other * alpha;
196-
if (1 != rhs) {
197-
auto rhs_tensor = tensor_to_const(ctx, torch::tensor({rhs}));
198-
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, rhs_tensor, util::node_info(n));
199-
TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);
200-
sub->setName(util::node_info(n).c_str());
201-
LOG_DEBUG("Output tensor shape: " << sub->getOutput(0)->getDimensions());
202-
ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0));
203-
return true;
204-
} else {
205-
LOG_DEBUG("Nothing to be done this layer, passing through input");
206-
LOG_DEBUG("Output tensor shape: " << self->getDimensions());
207-
208-
ctx->AssociateValueAndTensor(n->outputs()[0], self);
209-
return true;
210-
}
196+
auto scaled_other_tensor = tensor_to_const(ctx, torch::tensor({scaled_val}));
197+
auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, scaled_other_tensor, util::node_info(n));
198+
TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);
199+
sub->setName(util::node_info(n).c_str());
200+
LOG_DEBUG("Output tensor shape: " << sub->getOutput(0)->getDimensions());
201+
ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0));
202+
203+
return true;
211204
}})
212205
.pattern({"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar "
213206
"alpha=1) -> (Tensor(a!))",

0 commit comments

Comments
 (0)