@@ -191,23 +191,16 @@ auto element_wise_registrations TRTORCH_UNUSED =
191
191
auto self = args[0 ].ITensorOrFreeze (ctx);
192
192
auto other = args[1 ].unwrapToScalar ().to <float >();
193
193
auto alpha = args[2 ].unwrapToScalar ().to <float >();
194
+ auto scaled_val = other * alpha;
194
195
195
- auto rhs = other * alpha;
196
- if (1 != rhs) {
197
- auto rhs_tensor = tensor_to_const (ctx, torch::tensor ({rhs}));
198
- auto sub = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUB , self, rhs_tensor, util::node_info (n));
199
- TRTORCH_CHECK (sub, " Unable to create sub layer from node: " << *n);
200
- sub->setName (util::node_info (n).c_str ());
201
- LOG_DEBUG (" Output tensor shape: " << sub->getOutput (0 )->getDimensions ());
202
- ctx->AssociateValueAndTensor (n->outputs ()[0 ], sub->getOutput (0 ));
203
- return true ;
204
- } else {
205
- LOG_DEBUG (" Nothing to be done this layer, passing through input" );
206
- LOG_DEBUG (" Output tensor shape: " << self->getDimensions ());
207
-
208
- ctx->AssociateValueAndTensor (n->outputs ()[0 ], self);
209
- return true ;
210
- }
196
+ auto scaled_other_tensor = tensor_to_const (ctx, torch::tensor ({scaled_val}));
197
+ auto sub = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUB , self, scaled_other_tensor, util::node_info (n));
198
+ TRTORCH_CHECK (sub, " Unable to create sub layer from node: " << *n);
199
+ sub->setName (util::node_info (n).c_str ());
200
+ LOG_DEBUG (" Output tensor shape: " << sub->getOutput (0 )->getDimensions ());
201
+ ctx->AssociateValueAndTensor (n->outputs ()[0 ], sub->getOutput (0 ));
202
+
203
+ return true ;
211
204
}})
212
205
.pattern({" aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar "
213
206
" alpha=1) -> (Tensor(a!))" ,
0 commit comments