
Commit 8275d2b

narendasan and gs-olive authored and committed
Combination: 16 commits with aten improvements
- refactor: Moving elementwise and unary core to impl (new file: ../converters/impl/unary/base.py). Signed-off-by: Naren Dasan <[email protected]>
- Moving elementwise core to impl - rsqrt (FX Converter Refactor [9/N]) <Target: converter_reorg_elementwise> (#1905)
- Converter reorg: fmod
- Converter reorg: rsub; rsub error fixes, linting fixes, and a test case covering different inputs
- Converter reorg: batch norm; error fix and linting fix
- layer_norm converter; layer norm linting correction, ops file correction, lint fixes, and acc_ops layer_norm correction
- Converter reorg: softmax operation; softmax linting fix
- Converter reorg: gelu; linting fix
- Converter reorg: squeeze operator; corrected the squeeze implementation, linting, and the acc squeeze test; added the condition to convert dim to int and removed the stale comment
- Converter reorg: select operation; select correction and linting changes
- Converter reorg: slice op; linting corrections and slice fixes
- Converter reorg: matmul; matmul issue fixes and lint checks; moved matmul to an individual file
- Converter reorg: where operator; added the aten::where op, with aten::where corrections and linting changes
- aten::unsqueeze impl refactor. Signed-off-by: Boris Fomitchev <[email protected]>
- Moved clamp to impl; fixed method name. Signed-off-by: Boris Fomitchev <[email protected]>
- fix: Add automatic type promotion for FX ops: implement functionality to cast tensors to alternative types, add type promotion and the necessary casts to elementwise ops, address FX ops where mixed-precision computation can cause errors, and add test cases to validate the fix (a minimal sketch of this idea follows the commit metadata below)
1 parent 0d1fc3b commit 8275d2b
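The last item in the commit message describes automatic type promotion for elementwise FX ops; the code for that change lives in converter utility and impl files whose diffs are not shown on this page. A minimal, self-contained sketch of the idea, using plain PyTorch and an illustrative helper name rather than the actual converter-utils API:

    import torch

    def promote_and_cast(lhs: torch.Tensor, rhs: torch.Tensor):
        # Find the common dtype for a mixed-precision pair and cast both operands,
        # so a downstream elementwise op sees matching types.
        promoted = torch.promote_types(lhs.dtype, rhs.dtype)
        return lhs.to(promoted), rhs.to(promoted)

    a = torch.ones(2, dtype=torch.float16)
    b = torch.ones(2, dtype=torch.float32)
    a, b = promote_and_cast(a, b)
    assert a.dtype == b.dtype == torch.float32  # fp16 operand is promoted to fp32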

38 files changed (+3076, -850 lines)

py/torch_tensorrt/fx/converters/acc_ops_converters.py (+388, -489)
Large diffs are not rendered by default.

py/torch_tensorrt/fx/converters/aten_ops_converters.py (+238, -23)
@@ -21,9 +21,28 @@
 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters.impl import activation, convolution
+from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
+from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
+from torch_tensorrt.fx.converters.impl.elementwise import fmod
+from torch_tensorrt.fx.converters.impl.elementwise import rsub
+from torch_tensorrt.fx.converters.impl.normalization import batch_norm
+from torch_tensorrt.fx.converters.impl.normalization import layer_norm
+from torch_tensorrt.fx.converters.impl.normalization import softmax
+from torch_tensorrt.fx.converters.impl.squeeze import squeeze
+from torch_tensorrt.fx.converters.impl.select import select
+from torch_tensorrt.fx.converters.impl.slice import slice_op
+from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
+from torch_tensorrt.fx.converters.impl.condition import where
+from torch_tensorrt.fx.converters.impl.unsqueeze import unsqueeze
+from torch_tensorrt.fx.converters.impl.elementwise import clamp
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
+
+def or_none(args, i):
+    return args[i] if len(args) > i else None
+
+
 ## converter list in alphabetic order
 @tensorrt_converter(torch.ops.aten.add.Tensor)
 def aten_ops_add(
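The new or_none helper above returns args[i] when the argument is present and None otherwise, so converters for ops with optional trailing arguments (for example the clamp converter added further down, whose min/max may be omitted) avoid repeated length checks. A standalone usage sketch with made-up argument tuples:

    def or_none(args, i):
        return args[i] if len(args) > i else None

    args_full = ("input_tensor", -1.0, 1.0)   # clamp(input, min, max)
    args_short = ("input_tensor",)            # clamp(input) with min/max omitted
    assert or_none(args_full, 2) == 1.0       # present -> the value
    assert or_none(args_short, 1) is None     # absent  -> None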
@@ -89,18 +108,19 @@ def aten_ops_batch_norm(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "weight": args[1],
-        "bias": args[2],
-        "running_mean": args[3],
-        "running_var": args[4],
-        "training": args[5],
-        "momentum": args[6],
-        "eps": args[7],
-    }
-    return acc_ops_converters.acc_ops_batch_norm(
-        network, target, None, kwargs_new, name
+    return batch_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+        args[5],
+        args[6],
+        args[7],
     )
 
 
@@ -182,9 +202,7 @@ def aten_ops_div(
             network, target, None, kwargs_new, name
         )
     elif rounding_mode == "trunc":
-        return acc_ops_converters.acc_ops_trunc_div(
-            network, target, None, kwargs_new, name
-        )
+        return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
     else:
         raise RuntimeError(
             f"Target {target} does not support rounding mode {rounding_mode}"
@@ -242,11 +260,7 @@ def aten_ops_fmod(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "other": args[1],
-    }
-    return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
+    return fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
 
 
 @tensorrt_converter(torch.ops.aten.hardtanh.default)
@@ -257,12 +271,40 @@ def aten_ops_hardtanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
     return activation.hardtanh(
         network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
     )
 
 
+@tensorrt_converter(torch.ops.aten.gelu.default)
+def aten_ops_gelu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return activation.gelu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
+@tensorrt_converter(torch.ops.aten.matmul)
+@tensorrt_converter(torch.ops.aten.mm.default)
+def aten_ops_matmul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return matrix_multiply(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 @tensorrt_converter(torch.ops.aten.fmod.Tensor)
 def aten_ops_fmod(
     network: TRTNetwork,
@@ -286,10 +328,30 @@ def aten_ops_leaky_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
     return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
 
 
+@tensorrt_converter(torch.ops.aten.layer_norm.default)
+def aten_ops_layernorm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return layer_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+    )
+
+
 @tensorrt_converter(torch.ops.aten.linear)
 def aten_ops_linear(
     network: TRTNetwork,
@@ -390,6 +452,42 @@ def aten_ops_relu(
     )
 
 
+@tensorrt_converter(torch.ops.aten.relu.default)
+def aten_ops_relu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.relu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
+@tensorrt_converter(torch.ops.aten.rsqrt.default)
+def aten_ops_rsqrt(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return rsqrt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @tensorrt_converter(torch.ops.aten.sub.Tensor)
 def aten_ops_sub(
     network: TRTNetwork,
@@ -405,6 +503,29 @@ def aten_ops_sub(
     return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)
 
 
+@tensorrt_converter(torch.ops.aten.squeeze.dim)
+@tensorrt_converter(torch.ops.aten.squeeze.dims)
+def aten_ops_squeeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
+@tensorrt_converter(torch.ops.aten.unsqueeze.default)
+def aten_ops_unsqueeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return unsqueeze(network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1])
+
+
 @tensorrt_converter(torch.ops.aten.view.default)
 def aten_ops_reshape(
     network: TRTNetwork,
@@ -442,6 +563,31 @@ def aten_ops_reshape(
     return layer.get_output(0)
 
 
+@tensorrt_converter(torch.ops.aten.rsub.Tensor)
+def aten_ops_rsub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    alpha = None
+    if "alpha" in kwargs:
+        alpha = kwargs["alpha"]
+    return rsub(network, target, SourceIR.ATEN, name, args[0], args[1], alpha)
+
+
+@tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 @tensorrt_converter(torch.ops.aten.tanh.default)
 def aten_ops_tanh(
     network: TRTNetwork,
@@ -450,7 +596,6 @@ def aten_ops_tanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
     return activation.tanh(
         network,
         target,
@@ -460,6 +605,25 @@
     )
 
 
+@tensorrt_converter(torch.ops.aten.where.self)
+def aten_ops_where(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return where(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[1],
+        args[2],
+        args[0],
+    )
+
+
 @tensorrt_converter(torch.ops.aten.cat.default)
 def aten_ops_cat(
     network: TRTNetwork,
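A note on the argument order in aten_ops_where above: torch.ops.aten.where.self receives (condition, input, other), while the impl-level where helper is evidently called with the condition last, hence the reordering to args[1], args[2], args[0]. A small runnable sketch of the mapping (the impl signature here is inferred from this call site, not confirmed from the impl module):

    # Illustrative only: how the converter remaps aten.where.self arguments for the
    # impl-level helper (assumed order: input, other, condition).
    args = ("condition_tensor", "input_tensor", "other_tensor")  # aten.where.self order
    condition, input_t, other = args[0], args[1], args[2]
    impl_call_args = (input_t, other, condition)                 # order passed to impl where()
    assert impl_call_args == ("input_tensor", "other_tensor", "condition_tensor")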
@@ -475,6 +639,25 @@ def aten_ops_cat(
     return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)
 
 
+@tensorrt_converter(torch.ops.aten.clamp.default)
+def aten_ops_clamp(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return clamp.clamp(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        input_val=args[0],
+        min_val=or_none(args, 1),
+        max_val=or_none(args, 2),
+    )
+
+
 @tensorrt_converter(torch.ops.aten.expand.default)
 def aten_ops_expand(
     network: TRTNetwork,
@@ -537,6 +720,17 @@ def aten_ops_operator_add(
     return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
 
 
+@tensorrt_converter(torch.ops.aten.select.int)
+def aten_ops_select(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return select(network, target, SourceIR.ATEN, name, args[0], args[1], args[2])
+
+
 @tensorrt_converter(operator.sub)
 def aten_ops_operator_sub(
     network: TRTNetwork,
@@ -572,6 +766,27 @@ def aten_ops_sym_numel(
     return reduce_layer.get_output(0)
 
 
+@tensorrt_converter(torch.ops.aten.slice.Tensor)
+def aten_ops_slice(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return slice_op(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+    )
+
+
 @tensorrt_converter(torch.ops.aten.sym_size)
 def aten_ops_sym_size(
     network: TRTNetwork,
