Moving permute core to impl - permute (FX Converter Refactor [22/N]) <Target: converter_reorg_elementwise> #1999

Closed

535 changes: 384 additions & 151 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py

Large diffs are not rendered by default.

93 changes: 90 additions & 3 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -23,6 +23,9 @@
from .converter_utils import * # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation
from torch_tensorrt.fx.converters.impl import permute
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
from torch_tensorrt.fx.converters.impl.elementwise import rsqrt

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -161,9 +164,7 @@ def aten_ops_div(
network, target, None, kwargs_new, name
)
elif rounding_mode == "trunc":
return acc_ops_converters.acc_ops_trunc_div(
network, target, None, kwargs_new, name
)
return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
else:
raise RuntimeError(
f"Target {target} does not support rounding mode {rounding_mode}"
@@ -335,6 +336,24 @@ def aten_ops_mul(
return acc_ops_converters.acc_ops_mul(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.permute.default)
def aten_ops_permute_default(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return permute.permute(
network,
target,
SourceIR.ATEN,
name=name,
input_val=args[0],
index=args[1:],
)
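
For context on the argument handling above: aten.permute.default packs the target dimensions into a single list, so args[1:] is a one-element tuple wrapping that list, which the permute impl unwraps (see impl/permute.py further down). A minimal plain-Python sketch of just that unwrapping step, with the helper name invented here for illustration:

def unwrap_permute_index(args):
    # args mimics an aten.permute.default node's arguments: (input, [d0, d1, ...])
    index = args[1:]  # -> ([d0, d1, ...],)
    if len(index) == 1:
        index = index[0]  # -> [d0, d1, ...]
    return list(index)

assert unwrap_permute_index(("x", [0, 2, 1])) == [0, 2, 1]  # list form
assert unwrap_permute_index(("x", 0, 2, 1)) == [0, 2, 1]    # variadic form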


@tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
@tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
def aten_ops_pow(
@@ -369,6 +388,42 @@ def aten_ops_relu(
)


@tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:

return activation.relu(
network,
target,
SourceIR.ATEN,
name,
args[0],
)


@tensorrt_converter(torch.ops.aten.rsqrt.default)
def aten_ops_rsqrt(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:

return rsqrt(
network,
target,
SourceIR.ATEN,
name,
args[0],
)


@tensorrt_converter(torch.ops.aten.sub.Tensor)
def aten_ops_sub(
network: TRTNetwork,
@@ -384,6 +439,38 @@ def aten_ops_sub(
return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.transpose.int)
def aten_ops_transpose_int(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = args[0]
ndim = len(input_val.shape)
if len(args) == 1:
# default is to reverse dimensions
new_order = torch.arange(ndim - 1, -1, -1)
else:
assert (
len(args) == 3
), f"Wrong number of arguments to transpose(): {len(args)-1}"
new_order = torch.arange(ndim)
dim0 = args[1]
if args[1] < 0:
dim0 = dim0 + ndim
dim1 = args[2]
if args[2] < 0:
dim1 = dim1 + ndim
new_order[dim0] = dim1
new_order[dim1] = dim0
_LOGGER.debug("Transpose new order: %s", new_order)
return permute.permute(
network, target, SourceIR.ATEN, name=name, input_val=input_val, index=new_order
)
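
As a sanity check, the swap-based new_order construction above matches torch.transpose; a torch-only sketch (no TensorRT involved), assuming a 3D input and transpose(1, 2):

import torch

x = torch.randn(1, 2, 3)
ndim = x.dim()
dim0, dim1 = 1, 2

# build the permutation exactly as the converter does: identity order, then swap
new_order = list(range(ndim))
new_order[dim0], new_order[dim1] = new_order[dim1], new_order[dim0]  # -> [0, 2, 1]

assert torch.equal(x.permute(*new_order), x.transpose(dim0, dim1))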


@tensorrt_converter(torch.ops.aten.view.default)
def aten_ops_reshape(
network: TRTNetwork,
333 changes: 3 additions & 330 deletions py/torch_tensorrt/fx/converters/converter_utils.py
@@ -28,6 +28,7 @@ class SourceIR(Enum):
ACC = auto()
ATEN = auto()
PRIM = auto()
TORCHTRT_LOWERED = auto()
UNKNOWN = auto()

def __str__(self):
@@ -39,6 +40,8 @@ def __str__(self):
return "aten"
elif self == SourceIR.PRIM:
return "prim"
elif self == SourceIR.TORCHTRT_LOWERED:
return "torchtrt_lowered"
else:
return "unknown_ir"

@@ -383,171 +386,6 @@ def broadcast(
return a, b


def get_shape_with_dynamic_shape(
network: TRTNetwork,
shape: Union[list, tuple, torch.Tensor],
input_val: TRTTensor,
target: Target,
name: str,
) -> TRTTensor:
"""
Prepare the real output tensor shape for dynamic shape mode tensor input.
How this functions works:
Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation
output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual
reduce operation output shape. Steps of calculations are:
1. get the actual tensor shape of input_val via add_shape layer;
2. create a all 0 tensor [0, 0, 0];
3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False];
4. use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace
all -1 dynamic shape dimensions with actual batch_size value;
5. output shape with actual batch_size as [2048, 128, 256]
Args:
network (TRTNetwork): TensorRT network object.
shape: calculated shape of the expected output tensor
input_val (TRTTensor): A TensorRT ITensor.
target (Target): Target of fx node.
name (str): The name we want to assign to the created TensorRT layer.
Returns:
TensorRT ITensors that represents the actual shape of the input_val
"""
# Ger real shape info for input_val
input_shape = network.add_shape(input_val).get_output(0)

scale_layer = network.add_constant(
input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
)
set_layer_name(scale_layer, target, f"{name}_scale")
scale_res = scale_layer.get_output(0)

length = input_shape.shape[0]
zero_layer = network.add_constant(
input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
)
set_layer_name(zero_layer, target, f"{name}_zeros")

condition_val = add_binary_elementwise_layer(
network,
scale_res,
zero_layer.get_output(0),
trt.ElementWiseOperation.LESS,
target,
f"{name}_shape",
)
select_layer = network.add_select(condition_val, input_shape, scale_res)
set_layer_name(select_layer, target, f"{name}_select")
return select_layer.get_output(0)


def add_binary_elementwise_layer(
network: TRTNetwork,
lhs_val: Union[int, float, TRTTensor, torch.Tensor],
rhs_val: Union[int, float, TRTTensor, torch.Tensor],
op_type: trt.ElementWiseOperation,
target: Target,
name: str,
) -> TRTTensor:
"""
This function adds a TensorRT elementwise layer. We allow both operands to be
constant (not a trt tensor) because in implicit batch dimension mode, we could
introduce constant via .size() op. Other scenario should be const folded first.
If any operand is not a trt tensor, we make it a trt constant layer while preserve
its dtype. Then we broadcast these two inputs to have the same number of dimensions.
Limitation:
If we are using implicit batch dim mode, the operand that is not a trt
tensor are not allowed to have larger ranks than the trt tensor operand.
Args:
network (TRTNetwork): TensorRT network object.
lhs_val (TRTTensor): Left operand of the binary operation. Could
be a TensorRT tensor, a PyTorch tensor or a simple value.
rhs_val (TRTTensor): Right operand of the binary operation. Similar
to lhs_val.
op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation.
target (Target): Target of fx node.
name (str): The name we want to assign to the created TensorRT layer.
Returns:
The output of TensorRT Elementwise layer.
"""
lhs_dtype = None
rhs_dtype = None
is_lhs_trt_tensor = False
is_rhs_trt_tensor = False

if isinstance(lhs_val, TRTTensor):
lhs_dtype = torch_dtype_from_trt(lhs_val.dtype)
is_lhs_trt_tensor = True
if isinstance(rhs_val, TRTTensor):
rhs_dtype = torch_dtype_from_trt(rhs_val.dtype)
is_rhs_trt_tensor = True

if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
warnings.warn(
f"Both operands of the binary elementwise op {name} "
"are constant. In this case, please consider constant fold the model first."
)
return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val)

# If the following conditions are true:
# 1. the network has implicit batch dimension,
# 2. one operand has shape [] (real shape is [batch_size]),
# 3. another operand is a scalar,
# then the result should also have shape [] (real shape is [batch_size]).
#
# In such case, we need to convert the scalar operand to tensor, because
# this way the shape will become [1], and then will be properly squeezed
# into [], meaning that the result will have shape [], which is what we
# expect.
#
# Note that the dtype here is supposed to be the same as the scalar
# dtype but we don't have a way to detect whether it makes sense for the
# scalar to be float or half. Hence we go with the lhs dtype.
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)

# When lhs is scalar, and rhs has shape [1,], then currently the assert
# will fail because lhs shape has fewer dimensions than rhs shape. This
# happens when using implicit batch dimension, when we removed the 1st
# dimension from input tensor, causing it to have shape [] - a scalar. We
# fix it by reducing the rhs constant with a squeeze_left, so it becomes a
# scalar too. More generally, we squeeze_left on input if it's a constant
# tensor. This is safe because broadcast will pad dimensions on the left
# (prepend) to make lhs and rhs shape compatible.
if network.has_implicit_batch_dimension:
if isinstance(lhs_val, torch.Tensor):
lhs_val = squeeze_left(lhs_val)
if isinstance(rhs_val, torch.Tensor):
rhs_val = squeeze_left(rhs_val)

lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)

# Check the limitation in the doc string.
if network.has_implicit_batch_dimension:
if is_lhs_trt_tensor and not is_rhs_trt_tensor:
assert len(lhs_val.shape) >= len(
rhs_val.shape
), f"{lhs_val.shape} >= {rhs_val.shape}"
elif not is_lhs_trt_tensor and is_rhs_trt_tensor:
assert len(rhs_val.shape) >= len(
lhs_val.shape
), f"{rhs_val.shape} >= {lhs_val.shape}"

lhs_val, rhs_val = broadcast(
network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
)
layer = network.add_elementwise(lhs_val, rhs_val, op_type)
set_layer_name(layer, target, name)
output = layer.get_output(0)
output.name = output.name + "_" + target.__name__
return output


def squeeze_left(const: torch.Tensor):
"""
Squeeze the size-1 dimensions on the left side of the shape tuple.
@@ -559,38 +397,6 @@ def squeeze_left(const: torch.Tensor):
return const


def add_unary_layer(
network: TRTNetwork,
input_val: TRTTensor,
operation_type: trt.UnaryOperation,
target: Target,
name: str,
) -> TRTTensor:
"""
Add a TensorRT Unary layer to `network`.
Args:
network (TRTNetwork): TensorRT network object.
input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor.
op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation.
target (Target): Target of fx node.
name (str): The name we want to assign to the created TensorRT layer.
Returns:
The output of TensorRT Unary layer.
"""
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"{operation_type} received input {input_val} that is not part "
"of the TensorRT region!"
)
layer = network.add_unary(input_val, operation_type)
set_layer_name(layer, target, name)
output = layer.get_output(0)
output.name = output.name + "_" + target.__name__
return layer.get_output(0)


def add_reduce_layer(
network: TRTNetwork,
target: Target,
@@ -695,139 +501,6 @@ def get_inputs_from_args_and_kwargs(args, kwargs, input_names):
return inputs


def sign(
network: TRTNetwork, input_val: TRTTensor, target: Target, name: str
) -> TRTTensor:
"""
Sign is calculated as below:
x = input
sign = (exp(x) // exp(abs(x))) * 2 - 1
For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0.
With multiply 2, the value become 2(for pos and 0) and 0(for neg).
Finally minus 1, the value become 1(for pos and 0) and -1(for neg).
Args:
network (TRTNetwork): TensorRT network object.
input_val (TRTTensor): The input tensor.
target (Target): fx node target.
name (str): Name of the fx node with optional suffix.
Returns:
A TensorRT tensor represent the result of sign operator.
"""
input_exp_output = add_unary_layer(
network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp"
)
input_abs_output = add_unary_layer(
network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs"
)
input_abs_exp_output = add_unary_layer(
network,
input_abs_output,
trt.UnaryOperation.EXP,
target,
f"{name}_prod_abs_exp",
)
floor_div_output = add_binary_elementwise_layer(
network,
input_exp_output,
input_abs_exp_output,
trt.ElementWiseOperation.FLOOR_DIV,
target,
f"{name}_exp_floor_div",
)
double_floor_div_output = add_binary_elementwise_layer(
network,
floor_div_output,
2,
trt.ElementWiseOperation.PROD,
target,
f"{name}_floor_div*2",
)
return add_binary_elementwise_layer(
network,
double_floor_div_output,
1,
trt.ElementWiseOperation.SUB,
target,
f"{name}_sign",
)


def trunc_div(
input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str
) -> TRTTensor:
"""
Perform trunc divide on Tensor, result of divide will be round toward zero.
This means for positive number, it will be floor round; for negative number,
it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3].
Args:
input: divisor.
other: dividend.
network: INetworkDefinition.
target: node target.
name: namespace for the op
Returns:
A TensorRT tensor represent the result of trunc divide.
"""
prod_output = add_binary_elementwise_layer(
network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod"
)
sign_output = sign(network, prod_output, target, name)

# Convert constant input into ITensor for UnaryOperation
if not isinstance(input, trt.tensorrt.ITensor):
input = get_trt_tensor(network, input, f"{name}_input")
if not isinstance(other, trt.tensorrt.ITensor):
other = get_trt_tensor(
network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype)
)

abs_input_output = add_unary_layer(
network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input"
)
abs_other_output = add_unary_layer(
network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other"
)
abs_floor_output = add_binary_elementwise_layer(
network,
abs_input_output,
abs_other_output,
trt.ElementWiseOperation.FLOOR_DIV,
target,
f"{name}_floor_div",
)
output = add_binary_elementwise_layer(
network,
abs_floor_output,
sign_output,
trt.ElementWiseOperation.PROD,
target,
f"{name}_output",
)

return output


def get_python_op_from_trt_elementwise_op(
trt_op: TRTElementWiseOp,
) -> Callable[[Any, Any], Any]:
if trt_op == trt.ElementWiseOperation.SUM:
return operator.add
elif trt_op == trt.ElementWiseOperation.PROD:
return operator.mul
elif trt_op == trt.ElementWiseOperation.SUB:
return operator.sub
elif trt_op == trt.ElementWiseOperation.DIV:
return operator.truediv
elif trt_op == trt.ElementWiseOperation.FLOOR_DIV:
return operator.floordiv
else:
raise RuntimeError(f"{trt_op} is not supported yet!")


def dtype_uniform(
network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor
):
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/impl/elementwise/__init__.py
@@ -0,0 +1 @@
from .ops import *
147 changes: 147 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/elementwise/base.py
@@ -0,0 +1,147 @@
import operator
import warnings
from typing import Union, Callable, Any, Optional

import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Target

from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp
from torch_tensorrt.fx.utils import torch_dtype_from_trt
from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
set_layer_name,
broadcast,
squeeze_left,
get_trt_tensor,
)


def get_python_op_from_trt_elementwise_op(
trt_op: TRTElementWiseOp,
) -> Callable[[Any, Any], Any]:
if trt_op == trt.ElementWiseOperation.SUM:
return operator.add
elif trt_op == trt.ElementWiseOperation.PROD:
return operator.mul
elif trt_op == trt.ElementWiseOperation.SUB:
return operator.sub
elif trt_op == trt.ElementWiseOperation.DIV:
return operator.truediv
elif trt_op == trt.ElementWiseOperation.FLOOR_DIV:
return operator.floordiv
else:
raise RuntimeError(f"{trt_op} is not supported yet!")


def convert_binary_elementwise(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
op_type: trt.ElementWiseOperation,
lhs_val: Union[int, float, TRTTensor, torch.Tensor],
rhs_val: Union[int, float, TRTTensor, torch.Tensor],
) -> TRTTensor:
"""
This function adds a TensorRT elementwise layer. We allow both operands to be
constant (not a TRT tensor) because in implicit batch dimension mode we could
introduce a constant via the .size() op. Other scenarios should be constant-folded first.
If an operand is not a TRT tensor, we make it a TRT constant layer while preserving
its dtype. Then we broadcast the two inputs to the same number of dimensions.
Limitation:
In implicit batch dimension mode, the operand that is not a TRT
tensor is not allowed to have a larger rank than the TRT tensor operand.
Args:
network (TRTNetwork): TensorRT network object.
target (Target): Target of fx node.
source_ir (SourceIR): The IR that is calling the function.
name (str): The name we want to assign to the created TensorRT layer.
op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation.
lhs_val (TRTTensor): Left operand of the binary operation. Could
be a TensorRT tensor, a PyTorch tensor or a simple value.
rhs_val (TRTTensor): Right operand of the binary operation. Similar
to lhs_val.
Returns:
The output of the TensorRT elementwise layer.
"""
lhs_dtype = None
rhs_dtype = None
is_lhs_trt_tensor = False
is_rhs_trt_tensor = False

if isinstance(lhs_val, TRTTensor):
lhs_dtype = torch_dtype_from_trt(lhs_val.dtype)
is_lhs_trt_tensor = True
if isinstance(rhs_val, TRTTensor):
rhs_dtype = torch_dtype_from_trt(rhs_val.dtype)
is_rhs_trt_tensor = True

if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
warnings.warn(
f"Both operands of the binary elementwise op {name} "
"are constant. In this case, please consider constant fold the model first."
)
return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val)

# If the following conditions are true:
# 1. the network has implicit batch dimension,
# 2. one operand has shape [] (real shape is [batch_size]),
# 3. another operand is a scalar,
# then the result should also have shape [] (real shape is [batch_size]).
#
# In such case, we need to convert the scalar operand to tensor, because
# this way the shape will become [1], and then will be properly squeezed
# into [], meaning that the result will have shape [], which is what we
# expect.
#
# Note that the dtype here is supposed to be the same as the scalar
# dtype but we don't have a way to detect whether it makes sense for the
# scalar to be float or half. Hence we go with the lhs dtype.
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)

# When lhs is scalar, and rhs has shape [1,], then currently the assert
# will fail because lhs shape has fewer dimensions than rhs shape. This
# happens when using implicit batch dimension, when we removed the 1st
# dimension from input tensor, causing it to have shape [] - a scalar. We
# fix it by reducing the rhs constant with a squeeze_left, so it becomes a
# scalar too. More generally, we squeeze_left on input if it's a constant
# tensor. This is safe because broadcast will pad dimensions on the left
# (prepend) to make lhs and rhs shape compatible.
if network.has_implicit_batch_dimension:
if isinstance(lhs_val, torch.Tensor):
lhs_val = squeeze_left(lhs_val)
if isinstance(rhs_val, torch.Tensor):
rhs_val = squeeze_left(rhs_val)

lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)

# Check the limitation in the doc string.
if network.has_implicit_batch_dimension:
if is_lhs_trt_tensor and not is_rhs_trt_tensor:
assert len(lhs_val.shape) >= len(
rhs_val.shape
), f"{lhs_val.shape} >= {rhs_val.shape}"
elif not is_lhs_trt_tensor and is_rhs_trt_tensor:
assert len(rhs_val.shape) >= len(
lhs_val.shape
), f"{rhs_val.shape} >= {lhs_val.shape}"

lhs_val, rhs_val = broadcast(
network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
)
layer = network.add_elementwise(lhs_val, rhs_val, op_type)
set_layer_name(layer, target, name, source_ir)
output = layer.get_output(0)
output.name = output.name + "_" + target.__name__
return output
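
The implicit-batch branch above squeezes leading size-1 dimensions off constant operands before broadcasting. The real helper is converter_utils.squeeze_left; the standalone copy below is only a sketch of its behavior for illustration:

import torch

def squeeze_left_sketch(const: torch.Tensor) -> torch.Tensor:
    # Drop size-1 dimensions from the left until the leading dim is > 1
    # (or the tensor becomes a scalar).
    while len(const.shape) > 0 and const.shape[0] == 1:
        const = const.squeeze(0)
    return const

assert squeeze_left_sketch(torch.ones(1, 1, 3)).shape == (3,)
assert squeeze_left_sketch(torch.tensor([2.0])).shape == ()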
141 changes: 141 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/elementwise/ops.py
@@ -0,0 +1,141 @@
import operator
import warnings
from typing import Union, Callable, Any, Optional

import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Target

from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp
from torch_tensorrt.fx.utils import torch_dtype_from_trt
from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
get_trt_tensor,
)

from torch_tensorrt.fx.converters.impl.elementwise.base import (
convert_binary_elementwise,
)
from torch_tensorrt.fx.converters.impl.unary.base import convert_unary
from torch_tensorrt.fx.converters.impl.unary import sign


def trunc_div(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
other: TRTTensor,
) -> TRTTensor:
"""
Perform truncated division on tensors; the result of the division is rounded toward zero.
For positive results this is a floor round and for negative results a ceil round.
Example: [2.1, 0.8, -3.2] -> [2, 0, -3].
Args:
network: INetworkDefinition.
target: node target.
source_ir (SourceIR): Source IR calling the function.
name: namespace for the op.
input: dividend (the tensor being divided).
other: divisor (the tensor to divide by).
Returns:
A TensorRT tensor representing the result of the truncated division.
"""
prod_output = convert_binary_elementwise(
network,
target,
source_ir,
f"{name}_prod",
trt.ElementWiseOperation.PROD,
input,
other,
)

sign_output = sign(
network,
target,
source_ir,
name,
prod_output,
)

# Convert constant input into ITensor for UnaryOperation
if not isinstance(input, trt.tensorrt.ITensor):
input = get_trt_tensor(network, input, f"{name}_input")
if not isinstance(other, trt.tensorrt.ITensor):
other = get_trt_tensor(
network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype)
)

abs_input_output = convert_unary(
network,
target,
source_ir,
f"{name}_abs_input",
trt.UnaryOperation.ABS,
input,
)
abs_other_output = convert_unary(
network,
target,
source_ir,
f"{name}_abs_other",
trt.UnaryOperation.ABS,
other,
)
abs_floor_output = convert_binary_elementwise(
network,
target,
source_ir,
f"{name}_floor_div",
trt.ElementWiseOperation.FLOOR_DIV,
abs_input_output,
abs_other_output,
)
output = convert_binary_elementwise(
network,
target,
source_ir,
f"{name}_output",
trt.ElementWiseOperation.PROD,
abs_floor_output,
sign_output,
)

return output
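
As a numerical sanity check on the decomposition above (sign of the product times floor-division of the absolute values), the same arithmetic can be reproduced in plain torch and compared against torch.div with rounding_mode="trunc"; this is only a sketch of the math, not the TensorRT lowering:

import torch

a = torch.tensor([2.1, 0.8, -3.2, -7.0])  # dividend
b = torch.tensor([1.0, 1.0, 1.0, 2.0])    # divisor

# sign of the product decides the sign; floor-division of the absolute values gives the magnitude
sign = torch.where(a * b >= 0, torch.tensor(1.0), torch.tensor(-1.0))
expected = torch.floor(a.abs() / b.abs()) * sign

assert torch.equal(expected, torch.div(a, b, rounding_mode="trunc"))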


def rsqrt(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:

sqrt_trt_output = convert_unary(
network,
target,
source_ir,
f"{name}_sqrt",
trt.UnaryOperation.SQRT,
input,
)

output = convert_binary_elementwise(
network,
target,
source_ir,
f"{name}_output",
trt.ElementWiseOperation.DIV,
1,
sqrt_trt_output,
)

return output
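
rsqrt itself is just the identity 1 / sqrt(x), which the function above expresses as a SQRT unary layer followed by a DIV elementwise layer; a one-line torch check of that identity:

import torch

x = torch.rand(2, 3) + 1.0  # keep inputs strictly positive
assert torch.allclose(torch.rsqrt(x), 1.0 / torch.sqrt(x))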
46 changes: 46 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/permute.py
@@ -0,0 +1,46 @@
import numpy as np
import operator
import warnings
from typing import cast, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Argument, Target

from ..converter_utils import * # noqa: F403
from ...utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt

from torch_tensorrt.fx.types import (
TRTNetwork,
TRTTensor,
)


def permute(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
index: Sequence[int],
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"permute received input {input_val} that is not part "
"of the TensorRT region!"
)

ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
# aten.permute packs its dims into a single list argument; unwrap it here
if len(index) == 1:
index = index[0]
permutation = [get_positive_dim(i, ranks) for i in cast(Sequence[int], index)]

if network.has_implicit_batch_dimension:
assert permutation[0] == 0, "Can't permute batch dimension when it's implicit."
permutation = [i - 1 for i in permutation[1:]]

layer = network.add_shuffle(input_val)
layer.second_transpose = tuple(permutation)
set_layer_name(layer, target, name)
return layer.get_output(0)
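
To make the implicit-batch branch above concrete: when the batch dimension is implicit, the requested permutation must leave dim 0 in place, and the remaining entries are shifted down by one before being handed to the shuffle layer. A plain-Python sketch of that adjustment (helper name invented for illustration):

def adjust_for_implicit_batch(permutation):
    # permutation is expressed over the full, batch-inclusive rank
    assert permutation[0] == 0, "Can't permute batch dimension when it's implicit."
    return [i - 1 for i in permutation[1:]]

# x.permute(0, 2, 1) on an implicitly batched 3D tensor becomes a
# (1, 0) transpose over the two non-batch dimensions
assert adjust_for_implicit_batch([0, 2, 1]) == [1, 0]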
81 changes: 81 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/shape.py
@@ -0,0 +1,81 @@
import operator
import warnings
from typing import Union, Callable, Any, Optional

import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Target

from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp
from torch_tensorrt.fx.utils import torch_dtype_from_trt
from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
set_layer_name,
to_numpy,
)

from torch_tensorrt.fx.converters.impl.elementwise.base import (
convert_binary_elementwise,
)


def get_shape_with_dynamic_shape(
network: TRTNetwork,
target: Target,
source_ir: SourceIR,
name: str,
shape: Union[list, tuple, torch.Tensor],
input_val: TRTTensor,
) -> TRTTensor:
"""
Prepare the real output tensor shape for a dynamic-shape-mode tensor input.
How this function works:
Assuming input_val has actual shape [2048, 256, 512] and the expected reduce-operation
output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual
reduce-operation output shape. The steps of the calculation are:
1. get the actual tensor shape of input_val via an add_shape layer;
2. create an all-zero tensor [0, 0, 0];
3. run an elementwise comparison between [0, 0, 0] and [-1, 128, 256] to get a condition tensor [True, False, False];
4. use the condition tensor to select between [2048, 256, 512] and [-1, 128, 256], replacing
all -1 dynamic-shape dimensions with their actual values;
5. the output shape, with actual values filled in, is [2048, 128, 256].
Args:
network (TRTNetwork): TensorRT network object.
target (Target): Target of fx node.
source_ir (SourceIR): Source IR calling the function.
name (str): The name we want to assign to the created TensorRT layer.
shape: calculated shape of the expected output tensor.
input_val (TRTTensor): A TensorRT ITensor.
Returns:
A TensorRT ITensor that represents the actual shape of input_val.
"""
# Get real shape info for input_val
input_shape = network.add_shape(input_val).get_output(0)

scale_layer = network.add_constant(
input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
)
set_layer_name(scale_layer, target, f"{name}_scale")
scale_res = scale_layer.get_output(0)

length = input_shape.shape[0]
zero_layer = network.add_constant(
input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
)
set_layer_name(zero_layer, target, f"{name}_zeros")

condition_val = convert_binary_elementwise(
network,
target,
source_ir,
f"{name}_shape",
trt.ElementWiseOperation.LESS,
scale_res,
zero_layer.get_output(0),
)
select_layer = network.add_select(condition_val, input_shape, scale_res)
set_layer_name(select_layer, target, f"{name}_select")
return select_layer.get_output(0)
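
The five steps in the docstring boil down to an elementwise select: wherever the requested shape entry is negative, take the actual runtime dimension instead. A torch-only sketch of that selection logic (the converter builds the equivalent graph out of shape, constant, LESS-comparison, and select layers):

import torch

actual_shape = torch.tensor([2048, 256, 512], dtype=torch.int32)
requested = torch.tensor([-1, 128, 256], dtype=torch.int32)

condition = requested < 0  # [True, False, False]
resolved = torch.where(condition, actual_shape, requested)

assert resolved.tolist() == [2048, 128, 256]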
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/impl/unary/__init__.py
@@ -0,0 +1 @@
from .ops import *
53 changes: 53 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/unary/base.py
@@ -0,0 +1,53 @@
import operator
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from enum import Enum, auto

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
from torch.fx.node import Target

from torch_tensorrt.fx.types import (
TRTNetwork,
TRTTensor,
)

from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name

from torch_tensorrt.fx.converters.impl.elementwise.base import (
convert_binary_elementwise,
)


def convert_unary(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
operation_type: trt.UnaryOperation,
input_val: TRTTensor,
) -> TRTTensor:
"""
Add a TensorRT Unary layer to `network`.
Args:
network (TRTNetwork): TensorRT network object.
target (Target): Target of fx node.
source_ir (SourceIR): Source IR calling the function.
name (str): The name we want to assign to the created TensorRT layer.
operation_type (trt.UnaryOperation): Type of the TensorRT unary operation.
input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor.
Returns:
The output of the TensorRT unary layer.
"""
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"{operation_type} received input {input_val} that is not part "
"of the TensorRT region!"
)
layer = network.add_unary(input_val, operation_type)
set_layer_name(layer, target, name, source_ir)
output = layer.get_output(0)
output.name = output.name + "_" + target.__name__
return output
104 changes: 104 additions & 0 deletions py/torch_tensorrt/fx/converters/impl/unary/ops.py
@@ -0,0 +1,104 @@
import operator
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from enum import Enum, auto

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
from torch.fx.node import Target

from torch_tensorrt.fx.types import (
TRTNetwork,
TRTTensor,
)

from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
)

from torch_tensorrt.fx.converters.impl.elementwise.base import (
convert_binary_elementwise,
)
from torch_tensorrt.fx.converters.impl.unary.base import convert_unary


def sign(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
) -> TRTTensor:
"""
Sign is calculated as below:
x = input
sign = (exp(x) // exp(abs(x))) * 2 - 1
For positive numbers and 0, (exp(x) // exp(abs(x))) yields 1; for negative numbers it yields 0.
Multiplying by 2 turns these values into 2 (for pos and 0) and 0 (for neg).
Finally, subtracting 1 yields 1 (for pos and 0) and -1 (for neg).
Args:
network (TRTNetwork): TensorRT network object.
target (Target): fx node target.
source_ir (SourceIR): Source IR calling the function
name (str): Name of the fx node with optional suffix.
input_val (TRTTensor): The input tensor.
Returns:
A TensorRT tensor representing the result of the sign operator.
"""
input_exp_output = convert_unary(
network,
target,
source_ir,
f"{name}_prod_exp",
trt.UnaryOperation.EXP,
input_val,
)
input_abs_output = convert_unary(
network,
target,
source_ir,
f"{name}_prod_abs",
trt.UnaryOperation.ABS,
input_val,
)
input_abs_exp_output = convert_unary(
network,
target,
source_ir,
f"{name}_prod_abs_exp",
trt.UnaryOperation.EXP,
input_abs_output,
)

floor_div_output = convert_binary_elementwise(
network,
target,
source_ir,
f"{name}_exp_floor_div",
trt.ElementWiseOperation.FLOOR_DIV,
input_exp_output,
input_abs_exp_output,
)

double_floor_div_output = convert_binary_elementwise(
network,
target,
source_ir,
f"{name}_floor_div*2",
trt.ElementWiseOperation.PROD,
floor_div_output,
2,
)

return convert_binary_elementwise(
network,
target,
source_ir,
f"{name}_sign",
trt.ElementWiseOperation.SUB,
double_floor_div_output,
1,
)
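
The docstring's formula can be spot-checked numerically with plain torch; the converter builds the same expression out of EXP/ABS unary layers and FLOOR_DIV/PROD/SUB elementwise layers:

import torch

x = torch.tensor([3.5, 0.0, -2.0])
sign = torch.floor(torch.exp(x) / torch.exp(torch.abs(x))) * 2 - 1

assert sign.tolist() == [1.0, 1.0, -1.0]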
86 changes: 86 additions & 0 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_permute_aten.py
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


class TestPermuteConverter(DispatchTestCase):
@parameterized.expand(
[
("positive", [0, 2, 1]),
("negative", [0, -1, -2]),
]
)
def test_permute_list(self, _, permutation):
class Permute(nn.Module):
def forward(self, x):
return x.permute(permutation)

inputs = [torch.randn(1, 3, 2)]
self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default})

@parameterized.expand(
[
("positive", [0, 2, 1]),
("negative", [0, -1, -2]),
]
)
def test_permute(self, _, permutation):
class Permute(nn.Module):
def forward(self, x):
return x.permute(*permutation)

inputs = [torch.randn(1, 3, 2)]
self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default})

@parameterized.expand(
[
("positive", (1, 2)),
("negative", (-1, -2)),
]
)
def test_transpose(self, _, dims):
class Transpose(nn.Module):
def forward(self, x):
return x.transpose(*dims)

inputs = [torch.randn(1, 2, 3)]
self.run_test(Transpose(), inputs, expected_ops={torch.ops.aten.transpose.int})

def test_permute_with_dynamic_shape(self):
class Permute(nn.Module):
def forward(self, x):
return x.permute(1, 2, 0)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
),
]
self.run_test_with_dynamic_shape(
Permute(), input_specs, expected_ops={torch.ops.aten.permute.default}
)

def test_permute_with_dynamic_shape_four_dimensions(self):
class Permute(nn.Module):
def forward(self, x):
return x.permute(1, 2, 3, 0)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
),
]

self.run_test_with_dynamic_shape(
Permute(), input_specs, expected_ops={torch.ops.aten.permute.default}
)


if __name__ == "__main__":
run_tests()
29 changes: 29 additions & 0 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


class TestRSqrtConverter(DispatchTestCase):
@parameterized.expand(
[
("2d_dim_alpha", (2, 1), 2),
("3d_dim_alpha", (2, 1, 2), 2),
]
)
def test_rsqrt(self, _, x, alpha):
class rsqrt(nn.Module):
def forward(self, input):
return torch.rsqrt(input)

inputs = [torch.randn(x) + 1]
self.run_test(
rsqrt(),
inputs,
expected_ops={torch.ops.aten.rsqrt.default},
)


if __name__ == "__main__":
run_tests()