from .converter_utils import *  # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation
-from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
-from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
-from torch_tensorrt.fx.converters.impl.elementwise import fmod
-from torch_tensorrt.fx.converters.impl.elementwise import rsub
-from torch_tensorrt.fx.converters.impl.normalization import batch_norm
-from torch_tensorrt.fx.converters.impl.normalization import layer_norm
-from torch_tensorrt.fx.converters.impl.normalization import softmax
-from torch_tensorrt.fx.converters.impl.squeeze import squeeze
-from torch_tensorrt.fx.converters.impl.select import select
-from torch_tensorrt.fx.converters.impl.slice import slice_op
-from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
-from torch_tensorrt.fx.converters.impl.condition import where
-from torch_tensorrt.fx.converters.impl.unsqueeze import unsqueeze
-from torch_tensorrt.fx.converters.impl.elementwise import clamp

_LOGGER: logging.Logger = logging.getLogger(__name__)

-
-def or_none(args, i):
-    return args[i] if len(args) > i else None
-
-
## converter list in alphabetic order
@tensorrt_converter(torch.ops.aten.add.Tensor)
def aten_ops_add(
@@ -108,19 +89,18 @@ def aten_ops_batch_norm(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return batch_norm(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-        args[5],
-        args[6],
-        args[7],
+    kwargs_new = {
+        "input": args[0],
+        "weight": args[1],
+        "bias": args[2],
+        "running_mean": args[3],
+        "running_var": args[4],
+        "training": args[5],
+        "momentum": args[6],
+        "eps": args[7],
+    }
+    return acc_ops_converters.acc_ops_batch_norm(
+        network, target, None, kwargs_new, name
    )


@@ -179,7 +159,9 @@ def aten_ops_div(
            network, target, None, kwargs_new, name
        )
    elif rounding_mode == "trunc":
-        return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
+        return acc_ops_converters.acc_ops_trunc_div(
+            network, target, None, kwargs_new, name
+        )
    else:
        raise RuntimeError(
            f"Target {target} does not support rounding mode {rounding_mode}"
@@ -237,7 +219,11 @@ def aten_ops_fmod(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.hardtanh.default)
@@ -248,40 +234,12 @@ def aten_ops_hardtanh(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
    return activation.hardtanh(
        network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
    )


-@tensorrt_converter(torch.ops.aten.gelu.default)
-def aten_ops_gelu(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return activation.gelu(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
-@tensorrt_converter(torch.ops.aten.matmul)
-@tensorrt_converter(torch.ops.aten.mm.default)
-def aten_ops_matmul(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return matrix_multiply(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
@tensorrt_converter(torch.ops.aten.fmod.Tensor)
def aten_ops_fmod(
    network: TRTNetwork,
@@ -305,28 +263,8 @@ def aten_ops_leaky_relu(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
-

-@tensorrt_converter(torch.ops.aten.layer_norm.default)
-def aten_ops_layernorm(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return layer_norm(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-    )
+    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])


@tensorrt_converter(torch.ops.aten.linear)
@@ -429,42 +367,6 @@ def aten_ops_relu(
    )


-@tensorrt_converter(torch.ops.aten.relu.default)
-def aten_ops_relu(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    return activation.relu(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
-@tensorrt_converter(torch.ops.aten.rsqrt.default)
-def aten_ops_rsqrt(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    return rsqrt(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
@tensorrt_converter(torch.ops.aten.sub.Tensor)
def aten_ops_sub(
    network: TRTNetwork,
@@ -480,29 +382,6 @@ def aten_ops_sub(
    return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)


-@tensorrt_converter(torch.ops.aten.squeeze.dim)
-@tensorrt_converter(torch.ops.aten.squeeze.dims)
-def aten_ops_squeeze(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
-@tensorrt_converter(torch.ops.aten.unsqueeze.default)
-def aten_ops_unsqueeze(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return unsqueeze(network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1])
-
-
@tensorrt_converter(torch.ops.aten.view.default)
def aten_ops_reshape(
    network: TRTNetwork,
@@ -540,31 +419,6 @@ def aten_ops_reshape(
    return layer.get_output(0)


-@tensorrt_converter(torch.ops.aten.rsub.Tensor)
-def aten_ops_rsub(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    alpha = None
-    if "alpha" in kwargs:
-        alpha = kwargs["alpha"]
-    return rsub(network, target, SourceIR.ATEN, name, args[0], args[1], alpha)
-
-
-@tensorrt_converter(torch.ops.aten._softmax.default)
-def aten_ops_softmax(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
@tensorrt_converter(torch.ops.aten.tanh.default)
def aten_ops_tanh(
    network: TRTNetwork,
@@ -573,30 +427,12 @@ def aten_ops_tanh(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return activation.tanh(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )

-
-@tensorrt_converter(torch.ops.aten.where.self)
-def aten_ops_where(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return where(
+    return activation.tanh(
        network,
        target,
        SourceIR.ATEN,
        name,
-        args[1],
-        args[2],
        args[0],
    )

@@ -616,25 +452,6 @@ def aten_ops_cat(
    return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)


-@tensorrt_converter(torch.ops.aten.clamp.default)
-def aten_ops_clamp(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return clamp.clamp(
-        network,
-        target,
-        SourceIR.ACC,
-        name,
-        input_val=args[0],
-        min_val=or_none(args, 1),
-        max_val=or_none(args, 2),
-    )
-
-
@tensorrt_converter(torch.ops.aten.expand.default)
def aten_ops_expand(
    network: TRTNetwork,
@@ -697,17 +514,6 @@ def aten_ops_operator_add(
    return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)


-@tensorrt_converter(torch.ops.aten.select.int)
-def aten_ops_select(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return select(network, target, SourceIR.ATEN, name, args[0], args[1], args[2])
-
-
@tensorrt_converter(operator.sub)
def aten_ops_operator_sub(
    network: TRTNetwork,
@@ -743,27 +549,6 @@ def aten_ops_sym_numel(
    return reduce_layer.get_output(0)


-@tensorrt_converter(torch.ops.aten.slice.Tensor)
-def aten_ops_slice(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return slice_op(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-    )
-
-
@tensorrt_converter(torch.ops.aten.sym_size)
def aten_ops_sym_size(
    network: TRTNetwork,