from .converter_utils import *  # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation, convolution
+from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
+from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
+from torch_tensorrt.fx.converters.impl.elementwise import fmod
+from torch_tensorrt.fx.converters.impl.elementwise import rsub
+from torch_tensorrt.fx.converters.impl.normalization import batch_norm
+from torch_tensorrt.fx.converters.impl.normalization import layer_norm
+from torch_tensorrt.fx.converters.impl.normalization import softmax
+from torch_tensorrt.fx.converters.impl.squeeze import squeeze
+from torch_tensorrt.fx.converters.impl.select import select
+from torch_tensorrt.fx.converters.impl.slice import slice_op
+from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
+from torch_tensorrt.fx.converters.impl.condition import where
+from torch_tensorrt.fx.converters.impl.unsqueeze import unsqueeze
+from torch_tensorrt.fx.converters.impl.elementwise import clamp

_LOGGER: logging.Logger = logging.getLogger(__name__)

+
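+# Helper: fetch optional positional arg i, or None when the aten call omits it.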
+def or_none(args, i):
+    return args[i] if len(args) > i else None
+
+
## converter list in alphabetic order
@tensorrt_converter(torch.ops.aten.add.Tensor)
def aten_ops_add(
@@ -89,18 +108,19 @@ def aten_ops_batch_norm(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "weight": args[1],
-        "bias": args[2],
-        "running_mean": args[3],
-        "running_var": args[4],
-        "training": args[5],
-        "momentum": args[6],
-        "eps": args[7],
-    }
-    return acc_ops_converters.acc_ops_batch_norm(
-        network, target, None, kwargs_new, name
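+    # Positional args follow the aten.batch_norm schema:
+    # (input, weight, bias, running_mean, running_var, training, momentum, eps)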
+    return batch_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+        args[5],
+        args[6],
+        args[7],
    )


@@ -182,9 +202,7 @@ def aten_ops_div(
            network, target, None, kwargs_new, name
        )
    elif rounding_mode == "trunc":
-        return acc_ops_converters.acc_ops_trunc_div(
-            network, target, None, kwargs_new, name
-        )
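+        # "trunc" rounds toward zero, handled by the dedicated trunc_div impl.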
+        return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
    else:
        raise RuntimeError(
            f"Target {target} does not support rounding mode {rounding_mode}"
@@ -242,11 +260,7 @@ def aten_ops_fmod(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "other": args[1],
-    }
-    return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
+    return fmod(network, target, SourceIR.ATEN, name, args[0], args[1])


@tensorrt_converter(torch.ops.aten.hardtanh.default)
@@ -257,12 +271,40 @@ def aten_ops_hardtanh(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
    return activation.hardtanh(
        network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
    )


+@tensorrt_converter(torch.ops.aten.gelu.default)
+def aten_ops_gelu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return activation.gelu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
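+# aten.matmul and aten.mm lower through the same matrix_multiply impl.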
+@tensorrt_converter(torch.ops.aten.matmul)
+@tensorrt_converter(torch.ops.aten.mm.default)
+def aten_ops_matmul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return matrix_multiply(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
@tensorrt_converter(torch.ops.aten.fmod.Tensor)
def aten_ops_fmod(
    network: TRTNetwork,
@@ -286,10 +328,30 @@ def aten_ops_leaky_relu(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])


+@tensorrt_converter(torch.ops.aten.layer_norm.default)
+def aten_ops_layernorm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
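+    # Positional args follow the aten.layer_norm schema:
+    # (input, normalized_shape, weight, bias, eps)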
+    return layer_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+    )
+
+
@tensorrt_converter(torch.ops.aten.linear)
def aten_ops_linear(
    network: TRTNetwork,
@@ -390,6 +452,42 @@ def aten_ops_relu(
    )


+@tensorrt_converter(torch.ops.aten.relu.default)
+def aten_ops_relu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.relu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
+@tensorrt_converter(torch.ops.aten.rsqrt.default)
+def aten_ops_rsqrt(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return rsqrt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
@tensorrt_converter(torch.ops.aten.sub.Tensor)
def aten_ops_sub(
    network: TRTNetwork,
@@ -405,6 +503,29 @@ def aten_ops_sub(
    return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)


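+# Covers both the single-dim (squeeze.dim) and multi-dim (squeeze.dims) overloads.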
+@tensorrt_converter(torch.ops.aten.squeeze.dim)
+@tensorrt_converter(torch.ops.aten.squeeze.dims)
+def aten_ops_squeeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
+@tensorrt_converter(torch.ops.aten.unsqueeze.default)
+def aten_ops_unsqueeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return unsqueeze(network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1])
+
+
@tensorrt_converter(torch.ops.aten.view.default)
def aten_ops_reshape(
    network: TRTNetwork,
@@ -442,6 +563,31 @@ def aten_ops_reshape(
    return layer.get_output(0)


+@tensorrt_converter(torch.ops.aten.rsub.Tensor)
+def aten_ops_rsub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
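+    # rsub computes other - alpha * input; alpha is a keyword-only scalar (default 1).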
+    alpha = None
+    if "alpha" in kwargs:
+        alpha = kwargs["alpha"]
+    return rsub(network, target, SourceIR.ATEN, name, args[0], args[1], alpha)
+
+
+@tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
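+    # _softmax args are (input, dim, half_to_float); only input and dim are used here.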
+    return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
@tensorrt_converter(torch.ops.aten.tanh.default)
def aten_ops_tanh(
    network: TRTNetwork,
@@ -450,7 +596,6 @@ def aten_ops_tanh(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
    return activation.tanh(
        network,
        target,
@@ -460,6 +605,25 @@ def aten_ops_tanh(
    )


+@tensorrt_converter(torch.ops.aten.where.self)
+def aten_ops_where(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
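+    # aten.where.self packs args as (condition, input, other);
+    # the impl expects (input, other, condition), hence the reordering.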
+    return where(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[1],
+        args[2],
+        args[0],
+    )
+
+
@tensorrt_converter(torch.ops.aten.cat.default)
def aten_ops_cat(
    network: TRTNetwork,
@@ -475,6 +639,25 @@ def aten_ops_cat(
    return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)


+@tensorrt_converter(torch.ops.aten.clamp.default)
+def aten_ops_clamp(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
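+    # min and max are optional in aten.clamp, so or_none() maps missing args to None.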
+    return clamp.clamp(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        input_val=args[0],
+        min_val=or_none(args, 1),
+        max_val=or_none(args, 2),
+    )
+
+
@tensorrt_converter(torch.ops.aten.expand.default)
def aten_ops_expand(
    network: TRTNetwork,
@@ -537,6 +720,17 @@ def aten_ops_operator_add(
    return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)


+@tensorrt_converter(torch.ops.aten.select.int)
+def aten_ops_select(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
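+    # aten.select.int args are (input, dim, index).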
+    return select(network, target, SourceIR.ATEN, name, args[0], args[1], args[2])
+
+
@tensorrt_converter(operator.sub)
def aten_ops_operator_sub(
    network: TRTNetwork,
@@ -572,6 +766,27 @@ def aten_ops_sym_numel(
    return reduce_layer.get_output(0)


+@tensorrt_converter(torch.ops.aten.slice.Tensor)
+def aten_ops_slice(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
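+    # Positional args follow aten.slice.Tensor: (input, dim, start, end, step).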
+    return slice_op(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+    )
+
+
@tensorrt_converter(torch.ops.aten.sym_size)
def aten_ops_sym_size(
    network: TRTNetwork,