Fx2trt converters #1658

Closed · wants to merge 7 commits into from

Conversation

apbose (Collaborator) commented on Feb 10, 2023

Description

This PR addresses feature #1657.

Type of change

Addition of the following files:

  1. fx/converters/fx2trt_converters.py
     Converters added (a minimal sketch of the converter pattern follows this list):
     • torch.ops.aten.add.Tensor
     • torch.ops.aten.leaky_relu
     • torch.ops.aten.adaptive_avg_pool2d.default
     • torch.ops.aten._adaptive_avg_pool3d.default
     • torch.ops.aten.mean.dim
     • torch.ops.aten.batch_norm
     • torch.ops.aten.cat.default
     • torch.ops.aten.convolution.default

  2. fx/converters/fx2trt_converters_util.py
     Shared utility functions used by the converters above.

  3. Tests:
     • torch.ops.aten.add.Tensor: added a separate test_add_torch2trt.py; this will be removed later, since the op is already covered in test_binary_ops_aten.py
     • torch.ops.aten.leaky_relu: added test_leaky_relu_torch2trt.py
     • torch.ops.aten.adaptive_avg_pool2d.default: test already present
     • torch.ops.aten._adaptive_avg_pool3d.default: test already present
     • torch.ops.aten.mean.dim: test already present
     • torch.ops.aten.batch_norm: test already present
     • torch.ops.aten.cat.default: test already present
     • torch.ops.aten.convolution.default: test already present
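
For context, the new converters follow a torch2trt-style registration pattern; a minimal sketch of that pattern is shown below, based on the leaky_relu converter that appears in the style-check diff later in this thread. The decorator and helpers (tensorrt_converter, get_arg, add_missing_trt_tensors) are defined in the files added by this PR, so the commented import is an assumption about their eventual paths, not a verbatim line from the PR.

import tensorrt as trt
import torch

# Assumed import path for the helpers added in this PR; the actual module
# name may differ (fx2trt_converters_util.py in the description,
# fx2trt_ops_converter_utils.py in the diff below).
# from torch_tensorrt.fx.converters.fx2trt_ops_converter_utils import (
#     get_arg, add_missing_trt_tensors)


@tensorrt_converter(torch.ops.aten.leaky_relu)
def convert_leaky_relu(network, target, args, kwargs, name):
    # Read each argument positionally or by keyword, falling back to a default.
    input = get_arg(args, kwargs, "input", pos=0, default=None)
    negative_slope = get_arg(args, kwargs, "negative_slope", pos=1, default=0.01)
    # Wrap the Torch tensor as a TensorRT ITensor (constants are created on demand).
    input_trt = add_missing_trt_tensors(network, [input])[0]
    # Emit the TensorRT activation layer and return its output tensor.
    layer = network.add_activation(input_trt, trt.ActivationType.LEAKY_RELU)
    layer.alpha = negative_slope
    return layer.get_output(0)

Each converter is exercised through DispatchTestCase.run_test (see test_add_torch2trt.py in the diff below), which traces a small nn.Module, checks that the expected aten ops appear, and compares the lowered result against eager PyTorch.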

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

github-actions bot commented:

Code conforms to C++ style guidelines

github-actions bot commented:

There are some changes that do not conform to Python style guidelines:

--- py/torch_tensorrt/fx/converters/fx2trt_ops_converter.py	2023-02-10 19:54:46.440236 +0000
+++ py/torch_tensorrt/fx/converters/fx2trt_ops_converter.py	2023-02-10 19:55:00.607626 +0000
@@ -24,106 +24,121 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from .fx2trt_ops_converter_utils import *

_LOGGER: logging.Logger = logging.getLogger(__name__)

+
@tensorrt_converter(torch.ops.aten.add.Tensor)
def convert_add(network, target, args, kwargs, name):
    input_a = args[0]
    input_b = args[1]
    input_a_trt, input_b_trt = add_missing_trt_tensors(network, [input_a, input_b])
-    input_a_trt, input_b_trt = broadcast_trt_tensors(network, [input_a_trt, input_b_trt], len(output.shape))
-    layer = network.add_elementwise(input_a_trt, input_b_trt, trt.ElementWiseOperation.SUM)
-    output = layer.get_output(0) 
+    input_a_trt, input_b_trt = broadcast_trt_tensors(
+        network, [input_a_trt, input_b_trt], len(output.shape)
+    )
+    layer = network.add_elementwise(
+        input_a_trt, input_b_trt, trt.ElementWiseOperation.SUM
+    )
+    output = layer.get_output(0)
    return output
+

@tensorrt_converter(torch.ops.aten.leaky_relu)
def convert_leaky_relu(network, target, args, kwargs, name):
-    input = get_arg(args, kwargs, 'input', pos=0, default=None)
-    negative_slope = get_arg(args, kwargs, 'negative_slope', pos=1, default=0.01)
+    input = get_arg(args, kwargs, "input", pos=0, default=None)
+    negative_slope = get_arg(args, kwargs, "negative_slope", pos=1, default=0.01)
    input_trt = add_missing_trt_tensors(network, [input])[0]
    layer = network.add_activation(input_trt, trt.ActivationType.LEAKY_RELU)
    layer.alpha = negative_slope
    output = layer.get_output(0)
    return output
+

@tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
@tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d)
def convert_adaptive_avg_pool2d(network, target, args, kwargs, name):
    method_args = (network, torch.nn.AdaptiveAvgPool2d(args[1]), args[0])
    output = convert_AdaptiveAvgPool2d(method_args)

+
@tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
def convert_adaptive_avg_pool3d(network, target, args, kwargs, name):
    method_args = (network, torch.nn.AdaptiveAvgPool3d(args[1]), args[0])
    output = convert_AdaptiveAvgPool3d(method_args)

-#FIXME: check if this is required
+
+# FIXME: check if this is required
@tensorrt_converter(torch.ops.aten.mean.dim)
def convert_avg_pool(network, target, args, kwargs, name):
    # parse args
-    input = get_arg(args, kwargs, 'input', pos=0, default=None)
-    kernel_size = get_arg(args, kwargs, 'kernel_size', pos=1, default=None)
-    stride = get_arg(args, kwargs, 'stride', pos=2, default=None)
-    padding = get_arg(args, kwargs, 'padding', pos=3, default=0)
-    ceil_mode = get_arg(args, kwargs,'ceil_mode', pos=4, default=False)
-    count_include_pad = get_arg(args, kwargs, 'count_include_pad', pos=5, default=True)
-    
+    input = get_arg(args, kwargs, "input", pos=0, default=None)
+    kernel_size = get_arg(args, kwargs, "kernel_size", pos=1, default=None)
+    stride = get_arg(args, kwargs, "stride", pos=2, default=None)
+    padding = get_arg(args, kwargs, "padding", pos=3, default=0)
+    ceil_mode = get_arg(args, kwargs, "ceil_mode", pos=4, default=False)
+    count_include_pad = get_arg(args, kwargs, "count_include_pad", pos=5, default=True)
+
    # get input trt tensor (or create constant if it doesn't exist)
    input_trt = add_missing_trt_tensors(network, [input])[0]
    input_dim = input.dim() - 2

    # get kernel size
    if not isinstance(kernel_size, tuple):
-        kernel_size = (kernel_size, ) * input_dim
+        kernel_size = (kernel_size,) * input_dim

    # get stride
    if not isinstance(stride, tuple):
-        stride = (stride, ) * input_dim
+        stride = (stride,) * input_dim

    # get padding
    if not isinstance(padding, tuple):
-        padding = (padding, ) * input_dim
+        padding = (padding,) * input_dim

    layer = network.add_pooling_nd(
-        input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
-    
+        input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size
+    )
+
    layer.stride_nd = stride
    layer.padding_nd = padding
    layer.average_count_excludes_padding = not count_include_pad
-    
+
    if ceil_mode:
        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

    output = layer.get_output(0)
    return output

+
@tensorrt_converter(torch.ops.aten.batch_norm)
def convert_batch_norm(network, target, args, kwargs, name):
-    input = get_arg(args, kwargs, 'input', pos=0, default=None) 
-    running_mean = get_arg(args, kwargs, 'running_mean', pos=1, default=None) 
-    running_var = get_arg(args, kwargs, 'running_var', pos=2, default=None) 
+    input = get_arg(args, kwargs, "input", pos=0, default=None)
+    running_mean = get_arg(args, kwargs, "running_mean", pos=1, default=None)
+    running_var = get_arg(args, kwargs, "running_var", pos=2, default=None)

-    weight = get_arg(args, kwargs, 'weight', pos=3, default=None) 
-    bias = get_arg(args, kwargs, 'bias', pos=4, default=None) 
-    eps = get_arg(args, kwargs, 'eps', pos=7, default=10e-6) 
+    weight = get_arg(args, kwargs, "weight", pos=3, default=None)
+    bias = get_arg(args, kwargs, "bias", pos=4, default=None)
+    eps = get_arg(args, kwargs, "eps", pos=7, default=10e-6)

    input_trt = add_missing_trt_tensors(network, [input])[0]
-    
-    
-    scale = weight.detach().cpu().numpy() / np.sqrt(running_var.detach().cpu().numpy() + eps)
+
+    scale = weight.detach().cpu().numpy() / np.sqrt(
+        running_var.detach().cpu().numpy() + eps
+    )
    bias = bias.detach().cpu().numpy() - running_mean.detach().cpu().numpy() * scale
    power = np.ones_like(scale)

-    layer = network.add_scale_nd(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power, 1)
+    layer = network.add_scale_nd(
+        input_trt, trt.ScaleMode.CHANNEL, bias, scale, power, 1
+    )
    output = layer.get_output(0)
    return output

+
@tensorrt_converter(torch.ops.aten.cat.default)
def convert_cat(network, target, args, kwargs, name):
-    inputs = get_arg(args, kwargs, 'input', pos=0, default=None)
-    dim = get_arg(args, kwargs, 'dim', pos=1, default=0)
+    inputs = get_arg(args, kwargs, "input", pos=0, default=None)
+    dim = get_arg(args, kwargs, "dim", pos=1, default=0)

    # Reverse negative dims.
    if dim < 0:
        dim = len(inputs[0].shape) - abs(dim)

@@ -133,46 +148,48 @@
    layer = network.add_concatenation(inputs=trt_inputs)
    layer.axis = dim
    output = layer.get_output(0)
    return output

+
@tensorrt_converter(torch.ops.aten.convolution.default)
def convert_cat(network, target, args, kwargs, name):
    module = args[0]
    input = args[1]
    input_trt = add_missing_trt_tensors(network, [input])[0]

    input_dim = input.dim() - 2

    kernel_size = module.kernel_size
    if not isinstance(kernel_size, tuple):
-        kernel_size = (kernel_size, ) * input_dim
+        kernel_size = (kernel_size,) * input_dim

    stride = module.stride
    if not isinstance(stride, tuple):
-        stride = (stride, ) * input_dim
+        stride = (stride,) * input_dim

    padding = module.padding
    if not isinstance(padding, tuple):
-        padding = (padding, ) * input_dim
+        padding = (padding,) * input_dim

    dilation = module.dilation
    if not isinstance(dilation, tuple):
-        dilation = (dilation, ) * input_dim
+        dilation = (dilation,) * input_dim

    kernel = module.weight.detach().cpu().numpy()
-    
-    bias = None #trt.Weights(torch_dtype_to_trt(module.weight.dtype))
+
+    bias = None  # trt.Weights(torch_dtype_to_trt(module.weight.dtype))
    if module.bias is not None:
        bias = module.bias.detach().cpu().numpy()

    layer = network.add_convolution_nd(
        input=input_trt,
        num_output_maps=module.out_channels,
        kernel_shape=kernel_size,
        kernel=kernel,
-        bias=bias)
+        bias=bias,
+    )
    layer.stride_nd = stride
    layer.padding_nd = padding
    layer.dilation_nd = dilation

    if module.groups is not None:
--- py/torch_tensorrt/fx/converters/fx2trt_ops_converter_utils.py	2023-02-10 19:54:46.440236 +0000
+++ py/torch_tensorrt/fx/converters/fx2trt_ops_converter_utils.py	2023-02-10 19:55:00.694990 +0000
@@ -1,16 +1,18 @@
import tensorrt as trt
import torch
+

def get_arg(args, kwargs, name, pos, default):
    if name in kwargs:
        return kwargs[name]
    elif len(args) > pos:
        return args[pos]
    else:
        return default
-    
+
+
def add_missing_trt_tensors(network, tensors):
    """Creates missing TensorRT tensors as constants and attaches them to the Torch Tensors"""
    with use_shape_wrapping(False):
        trt_tensors = [None] * len(tensors)

@@ -44,11 +46,10 @@

                weight = t.detach().cpu().numpy()
                t._trt = network.add_constant(shape, weight).get_output(0)
                trt_tensor = t._trt

-
            assert trt_tensor is not None

            trt_tensors[i] = trt_tensor

        return trt_tensors
@@ -73,10 +74,11 @@

            broadcasted_trt_tensors[i] = trt_tensor

        return broadcasted_trt_tensors

+
def check_torch_dtype(*tensors):
    dtype = None
    for t in tensors:
        if isinstance(t, torch.Tensor):
            if dtype is None:
@@ -86,41 +88,48 @@
    assert (
        dtype is not None
    )  # , 'Data type could not be inferred from any item in list')
    return dtype

+
class use_shape_wrapping:

-    stack = [True] # default true
+    stack = [True]  # default true

    def __init__(self, value: bool):
        self._value = value
-    
+
    def __enter__(self, *args, **kwargs):
        self.stack.insert(0, self._value)

    def __exit__(self, *args, **kwargs):
        self.stack.pop(0)
+

def convert_AdaptiveAvgPool2d(method_args):
    network = method_args[0]
    module = method_args[1]
    input = method_args[2]
    input_trt = add_missing_trt_tensors(method_args[0], [input])[0]
    output_size = module.output_size
    if not isinstance(output_size, tuple):
-        output_size = (output_size, ) * 2
+        output_size = (output_size,) * 2

-    stride = (input_trt.shape[-2] // output_size[-2], input_trt.shape[-1] // output_size[-1])
+    stride = (
+        input_trt.shape[-2] // output_size[-2],
+        input_trt.shape[-1] // output_size[-1],
+    )

    kernel_size = stride
    layer = network.add_pooling(
-        input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
+        input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size
+    )
    layer.stride = stride

    output = layer.get_output(0)
    return output
+

def convert_AdaptiveAvgPool3d(method_args):
    network = method_args[0]
    module = method_args[1]
    input = method_args[2]
@@ -144,6 +153,6 @@
        window_size=kernel_size,
    )
    layer.stride_nd = stride

    output = layer.get_output(0)
-    return output
\ No newline at end of file
+    return output
--- py/torch_tensorrt/fx/test/converters/aten_op/test_add_torch2trt.py	2023-02-10 19:54:46.440236 +0000
+++ py/torch_tensorrt/fx/test/converters/aten_op/test_add_torch2trt.py	2023-02-10 19:55:01.881660 +0000
@@ -1,24 +1,26 @@
import torch
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+

class TestAddBasic(DispatchTestCase):
    def test_add(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()

            def forward(self, x, y):
                return x + y
-        
-        inputs = [torch.randn(1,3,224,224), torch.randn(1,3,224,224)]
+
+        inputs = [torch.randn(1, 3, 224, 224), torch.randn(1, 3, 224, 224)]
        self.run_test(
            TestModule(),
            inputs,
            expected_ops=({torch.ops.aten.add.Tensor}),
        )
+
+
if __name__ == "__main__":
    run_tests()
-
--- py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_torch2trt.py	2023-02-10 19:54:46.444236 +0000
+++ py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_torch2trt.py	2023-02-10 19:55:01.961926 +0000
@@ -9,11 +9,13 @@
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.leaky_relu(x)

        inputs = [torch.randn(1, 10)]
-        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default})
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default}
+        )

    def test_relu_with_dynamic_shape(self):
        class TestModule(nn.Module):
            def forward(self, x):
                return nn.functional.leaky_relu(x)

narendasan added the WIP (work is in progress, pull request should not be merged yet) label on Feb 13, 2023
github-actions bot requested a review from frank-wei on February 13, 2023 at 17:41
github-actions bot commented:

Code conforms to C++ style guidelines

github-actions bot commented:

Code conforms to Python style guidelines

github-actions bot commented:

Code conforms to C++ style guidelines

github-actions bot commented:

Code conforms to Python style guidelines

apbose (Collaborator, Author) commented on Mar 17, 2023

Raised #1745

apbose closed this on Mar 17, 2023
Labels: cla signed · component: api [Python] · component: fx · WIP (work is in progress, pull request should not be merged yet)