Skip to content

Commit 2caac76

Browse files
committed
Replace torch_dtype_to_trt and torch_dtype_from_trt with unified_dtype_converter, and add the convolution import to acc_ops_converters.py and aten_ops_converters.py
1 parent 12e8aa9 commit 2caac76

File tree

5 files changed

+11
-8
lines changed

5 files changed

+11
-8
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
trt_transposed_matmul,
2727
)
2828
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
29-
from torch_tensorrt.fx.converters.impl import activation
29+
from torch_tensorrt.fx.converters.impl import activation, convolution
3030
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
3131
from torch_tensorrt.fx.converters.impl.unary import sign
3232
from torch_tensorrt.fx.converters.impl.elementwise.base import (

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from .converter_utils import * # noqa: F403
2222
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
23-
from torch_tensorrt.fx.converters.impl import activation
23+
from torch_tensorrt.fx.converters.impl import activation, convolution
2424
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
2525
from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
2626

py/torch_tensorrt/fx/converters/impl/elementwise/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch.fx.node import Target
1111

1212
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp
13-
from torch_tensorrt.fx.utils import torch_dtype_from_trt
13+
from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks
1414
from torch_tensorrt.fx.converters.converter_utils import (
1515
SourceIR,
1616
set_layer_name,
@@ -77,10 +77,10 @@ def convert_binary_elementwise(
7777
is_rhs_trt_tensor = False
7878

7979
if isinstance(lhs_val, TRTTensor):
80-
lhs_dtype = torch_dtype_from_trt(lhs_val.dtype)
80+
lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
8181
is_lhs_trt_tensor = True
8282
if isinstance(rhs_val, TRTTensor):
83-
rhs_dtype = torch_dtype_from_trt(rhs_val.dtype)
83+
rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
8484
is_rhs_trt_tensor = True
8585

8686
if not is_lhs_trt_tensor and not is_rhs_trt_tensor:

py/torch_tensorrt/fx/converters/impl/elementwise/ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch.fx.node import Target
1111

1212
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp
13-
from torch_tensorrt.fx.utils import torch_dtype_from_trt
13+
from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks
1414
from torch_tensorrt.fx.converters.converter_utils import (
1515
SourceIR,
1616
get_trt_tensor,
@@ -70,7 +70,10 @@ def trunc_div(
7070
input = get_trt_tensor(network, input, f"{name}_input")
7171
if not isinstance(other, trt.tensorrt.ITensor):
7272
other = get_trt_tensor(
73-
network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype)
73+
network,
74+
other,
75+
f"{name}_other",
76+
dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
7477
)
7578

7679
abs_input_output = convert_unary(

py/torch_tensorrt/fx/converters/impl/shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch.fx.node import Target
1111

1212
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp
13-
from torch_tensorrt.fx.utils import torch_dtype_from_trt
13+
from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks
1414
from torch_tensorrt.fx.converters.converter_utils import (
1515
SourceIR,
1616
set_layer_name,

0 commit comments

Comments (0)