
Commit 6f345cf

Revert all changes to py/torch_tensorrt/fx
Revert "fix: Add automatic type promotion for FX ops" This reverts commit f1f3716. Revert "Moved clamp to impl" This reverts commit df401dd. Revert "aten::unsqueeze impl refactor" This reverts commit b424735. Revert "Converter reorg and where operator" This reverts commit b4da15e. Revert "converter reorg and matmul" This reverts commit 7551eee. Revert "converter reorg and slice" This reverts commit 9bbdc9e. Revert "Converter reorg and select operation" This reverts commit fb70253. Revert "Converter reorg and squeeze operator" This reverts commit 294545c. Revert "Converter reorg and gelu" This reverts commit 37d1168. Revert "Converter reorg and softmax operation" This reverts commit 1ba6d13. Revert "layer_norm converter" This reverts commit e0b34b1. Revert "Converter reorg batch norm" This reverts commit 59354e5. Revert "Converter reorg and rsub" This reverts commit db15d27. Revert "Converter reorg fmod" This reverts commit ce3fa67. Revert "Moving elementwise core to impl - rsqrt (FX Converter Refactor [9/N]) <Target: converter_reorg_elementwise> (#1905)" This reverts commit 7158ca5. Revert "refactor: Moving elementwise and unary core to impl" This reverts commit 45e43ca.
1 parent f1f3716 commit 6f345cf

38 files changed: 850 additions & 3076 deletions

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 489 additions & 388 deletions
Large diffs are not rendered by default.

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 23 additions & 238 deletions
@@ -21,28 +21,9 @@
 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters.impl import activation
-from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
-from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
-from torch_tensorrt.fx.converters.impl.elementwise import fmod
-from torch_tensorrt.fx.converters.impl.elementwise import rsub
-from torch_tensorrt.fx.converters.impl.normalization import batch_norm
-from torch_tensorrt.fx.converters.impl.normalization import layer_norm
-from torch_tensorrt.fx.converters.impl.normalization import softmax
-from torch_tensorrt.fx.converters.impl.squeeze import squeeze
-from torch_tensorrt.fx.converters.impl.select import select
-from torch_tensorrt.fx.converters.impl.slice import slice_op
-from torch_tensorrt.fx.converters.impl.matmul import matrix_multiply
-from torch_tensorrt.fx.converters.impl.condition import where
-from torch_tensorrt.fx.converters.impl.unsqueeze import unsqueeze
-from torch_tensorrt.fx.converters.impl.elementwise import clamp
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
-
-def or_none(args, i):
-    return args[i] if len(args) > i else None
-
-
 ## converter list in alphabetic order
 @tensorrt_converter(torch.ops.aten.add.Tensor)
 def aten_ops_add(
@@ -108,19 +89,18 @@ def aten_ops_batch_norm(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return batch_norm(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-        args[5],
-        args[6],
-        args[7],
+    kwargs_new = {
+        "input": args[0],
+        "weight": args[1],
+        "bias": args[2],
+        "running_mean": args[3],
+        "running_var": args[4],
+        "training": args[5],
+        "momentum": args[6],
+        "eps": args[7],
+    }
+    return acc_ops_converters.acc_ops_batch_norm(
+        network, target, None, kwargs_new, name
     )
 
 
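Note: the hunk above restores the earlier dispatch style in which the aten-level converter repacks its positional schema into the keyword schema expected by the acc-level converter and delegates to it. A minimal, self-contained sketch of that pattern follows; the stub names (acc_ops_batch_norm_stub, aten_ops_batch_norm_sketch) are hypothetical stand-ins, not the real torch_tensorrt API.

from typing import Any, Dict, Sequence


def acc_ops_batch_norm_stub(network: Any, target: Any, args: Any,
                            kwargs: Dict[str, Any], name: str) -> str:
    # Hypothetical stand-in for acc_ops_converters.acc_ops_batch_norm:
    # it only echoes the keyword arguments it would receive.
    return f"batch_norm({sorted(kwargs)}) as layer '{name}'"


def aten_ops_batch_norm_sketch(network: Any, target: Any,
                               args: Sequence[Any], name: str) -> str:
    # Repack the positional aten schema into acc-style keyword arguments,
    # mirroring the kwargs_new dict in the hunk above.
    kwargs_new = {
        "input": args[0],
        "weight": args[1],
        "bias": args[2],
        "running_mean": args[3],
        "running_var": args[4],
        "training": args[5],
        "momentum": args[6],
        "eps": args[7],
    }
    return acc_ops_batch_norm_stub(network, target, None, kwargs_new, name)


if __name__ == "__main__":
    print(aten_ops_batch_norm_sketch(None, None, list(range(8)), "bn_0"))

In the real converters, network and target are a TensorRT network and an FX node target rather than placeholders; only the args-to-kwargs repacking is the point of the sketch.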
@@ -179,7 +159,9 @@ def aten_ops_div(
             network, target, None, kwargs_new, name
         )
     elif rounding_mode == "trunc":
-        return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
+        return acc_ops_converters.acc_ops_trunc_div(
+            network, target, None, kwargs_new, name
+        )
     else:
         raise RuntimeError(
             f"Target {target} does not support rounding mode {rounding_mode}"
@@ -237,7 +219,11 @@ def aten_ops_fmod(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
 
 
 @tensorrt_converter(torch.ops.aten.hardtanh.default)
@@ -248,40 +234,12 @@ def aten_ops_hardtanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
     return activation.hardtanh(
         network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
     )
 
 
-@tensorrt_converter(torch.ops.aten.gelu.default)
-def aten_ops_gelu(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return activation.gelu(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
-@tensorrt_converter(torch.ops.aten.matmul)
-@tensorrt_converter(torch.ops.aten.mm.default)
-def aten_ops_matmul(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return matrix_multiply(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
 @tensorrt_converter(torch.ops.aten.fmod.Tensor)
 def aten_ops_fmod(
     network: TRTNetwork,
@@ -305,28 +263,8 @@ def aten_ops_leaky_relu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
-
 
-@tensorrt_converter(torch.ops.aten.layer_norm.default)
-def aten_ops_layernorm(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return layer_norm(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-    )
+    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])
 
 
 @tensorrt_converter(torch.ops.aten.linear)
@@ -429,42 +367,6 @@ def aten_ops_relu(
     )
 
 
-@tensorrt_converter(torch.ops.aten.relu.default)
-def aten_ops_relu(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    return activation.relu(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
-@tensorrt_converter(torch.ops.aten.rsqrt.default)
-def aten_ops_rsqrt(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    return rsqrt(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
-
-
 @tensorrt_converter(torch.ops.aten.sub.Tensor)
 def aten_ops_sub(
     network: TRTNetwork,
@@ -480,29 +382,6 @@ def aten_ops_sub(
     return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.squeeze.dim)
-@tensorrt_converter(torch.ops.aten.squeeze.dims)
-def aten_ops_squeeze(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
-@tensorrt_converter(torch.ops.aten.unsqueeze.default)
-def aten_ops_unsqueeze(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return unsqueeze(network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1])
-
-
 @tensorrt_converter(torch.ops.aten.view.default)
 def aten_ops_reshape(
     network: TRTNetwork,
@@ -540,31 +419,6 @@ def aten_ops_reshape(
     return layer.get_output(0)
 
 
-@tensorrt_converter(torch.ops.aten.rsub.Tensor)
-def aten_ops_rsub(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    alpha = None
-    if "alpha" in kwargs:
-        alpha = kwargs["alpha"]
-    return rsub(network, target, SourceIR.ATEN, name, args[0], args[1], alpha)
-
-
-@tensorrt_converter(torch.ops.aten._softmax.default)
-def aten_ops_softmax(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])
-
-
 @tensorrt_converter(torch.ops.aten.tanh.default)
 def aten_ops_tanh(
     network: TRTNetwork,
@@ -573,30 +427,12 @@ def aten_ops_tanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return activation.tanh(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-    )
 
-
-@tensorrt_converter(torch.ops.aten.where.self)
-def aten_ops_where(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return where(
+    return activation.tanh(
         network,
         target,
         SourceIR.ATEN,
         name,
-        args[1],
-        args[2],
         args[0],
     )
 
@@ -616,25 +452,6 @@ def aten_ops_cat(
     return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.clamp.default)
-def aten_ops_clamp(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return clamp.clamp(
-        network,
-        target,
-        SourceIR.ACC,
-        name,
-        input_val=args[0],
-        min_val=or_none(args, 1),
-        max_val=or_none(args, 2),
-    )
-
-
 @tensorrt_converter(torch.ops.aten.expand.default)
 def aten_ops_expand(
     network: TRTNetwork,
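For reference, the removed aten_ops_clamp above relied on the or_none helper (also deleted in the first hunk of this file) to treat trailing positional arguments as optional, since min and max may be omitted in the aten.clamp schema. A small standalone sketch of that idiom, with hypothetical argument values:

from typing import Any, Dict, Optional, Sequence


def or_none(args: Sequence[Any], i: int) -> Optional[Any]:
    # Return args[i] when present, otherwise None (as in the removed helper).
    return args[i] if len(args) > i else None


def clamp_args_sketch(args: Sequence[Any]) -> Dict[str, Any]:
    # Mirror how the removed aten_ops_clamp mapped a variable-length
    # positional schema onto named, optional parameters.
    return {
        "input_val": args[0],
        "min_val": or_none(args, 1),
        "max_val": or_none(args, 2),
    }


if __name__ == "__main__":
    print(clamp_args_sketch(["x"]))            # only input: min/max default to None
    print(clamp_args_sketch(["x", 0.0, 6.0]))  # fully specified clamp range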
@@ -697,17 +514,6 @@ def aten_ops_operator_add(
     return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.select.int)
-def aten_ops_select(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return select(network, target, SourceIR.ATEN, name, args[0], args[1], args[2])
-
-
 @tensorrt_converter(operator.sub)
 def aten_ops_operator_sub(
     network: TRTNetwork,
@@ -743,27 +549,6 @@ def aten_ops_sym_numel(
     return reduce_layer.get_output(0)
 
 
-@tensorrt_converter(torch.ops.aten.slice.Tensor)
-def aten_ops_slice(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return slice_op(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4],
-    )
-
-
 @tensorrt_converter(torch.ops.aten.sym_size)
 def aten_ops_sym_size(
     network: TRTNetwork,

0 commit comments
