diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 57c720ffba..54c546b2c3 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -27,6 +27,14 @@ ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous from torch_tensorrt.fx.converters.impl import activation +from torch_tensorrt.fx.converters.impl import permute +from torch_tensorrt.fx.converters.impl.elementwise import trunc_div +from torch_tensorrt.fx.converters.impl.unary import sign +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.fx.converters.impl.unary.base import convert_unary +from torch_tensorrt.fx.converters.impl.shape import get_shape_with_dynamic_shape _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -78,13 +86,14 @@ def trt_transposed_linear_converter(network, target, args, kwargs, name): trt.MatrixOperation.NONE, ) set_layer_name(layer, target, f"{name}_mm") - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - layer.get_output(0), - bias, - trt.ElementWiseOperation.SUM, target, + SourceIR.TORCHTRT_LOWERED, f"{name}_add", + trt.ElementWiseOperation.SUM, + layer.get_output(0), + bias, ) @@ -755,13 +764,14 @@ def layer_norm( set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") # X-E[x] - sub_trt = add_binary_elementwise_layer( + sub_trt = convert_binary_elementwise( network, - input_val, - mean_expected_layer.get_output(0), - trt.ElementWiseOperation.SUB, target, + SourceIR.ACC, f"{name}_sub", + trt.ElementWiseOperation.SUB, + input_val, + mean_expected_layer.get_output(0), ) # Variance = mean(pow(x_sub_mean,2)) pow_tensor = network.add_constant( @@ -769,13 +779,14 @@ def layer_norm( trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), ) pow_tensor.name = f"{name}_power" - pow_var = add_binary_elementwise_layer( + pow_var = convert_binary_elementwise( network, - sub_trt, - pow_tensor.get_output(0), - trt.ElementWiseOperation.POW, target, + SourceIR.ACC, f"{name}_pow_var", + trt.ElementWiseOperation.POW, + sub_trt, + pow_tensor.get_output(0), ) mean_trt_layer = network.add_reduce( pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True @@ -787,26 +798,33 @@ def layer_norm( trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), ) eps_tensor.name = f"{name}_eps" - add_trt = add_binary_elementwise_layer( + add_trt = convert_binary_elementwise( network, - mean_trt_layer.get_output(0), - eps_tensor.get_output(0), - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, f"{name}_add", + trt.ElementWiseOperation.SUM, + mean_trt_layer.get_output(0), + eps_tensor.get_output(0), ) # SQRT((Var + eps)) - sqrt_trt = add_unary_layer( - network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt" + sqrt_trt = convert_unary( + network, + target, + SourceIR.ACC, + f"{name}_sqrt", + trt.UnaryOperation.SQRT, + add_trt, ) # (x - E[x]) / sqrt((var + eps)) - div_trt = add_binary_elementwise_layer( + div_trt = convert_binary_elementwise( network, - sub_trt, - sqrt_trt, - trt.ElementWiseOperation.DIV, target, + SourceIR.ACC, f"{name}_div_trt", + trt.ElementWiseOperation.DIV, + sub_trt, + sqrt_trt, ) assert gamma is not None @@ -816,21 +834,23 @@ def layer_norm( beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) # type: ignore[attr-defined] beta_tensor.name = f"{name}_beta" # y * gamma + beta - scale_layer = add_binary_elementwise_layer( + scale_layer = convert_binary_elementwise( network, - div_trt, - gamma_tensor.get_output(0), - trt.ElementWiseOperation.PROD, target, + SourceIR.ACC, f"{name}_scale", + trt.ElementWiseOperation.PROD, + div_trt, + gamma_tensor.get_output(0), ) - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - scale_layer, - beta_tensor.get_output(0), - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.SUM, + scale_layer, + beta_tensor.get_output(0), ) @@ -933,13 +953,14 @@ def acc_ops_tile( else: d = get_trt_tensor(network, d, f"{name}_{i}") shape.append(d) - mul = add_binary_elementwise_layer( + mul = convert_binary_elementwise( network, - s, - d, - trt.ElementWiseOperation.PROD, target, + SourceIR.ACC, f"{name}_mul_{i}", + trt.ElementWiseOperation.PROD, + s, + d, ) shapes.append(mul) dims = shape @@ -968,13 +989,14 @@ def acc_ops_tile( dims_tensor = concat_dims_layer.get_output(0) input_shape_layer = network.add_shape(input_val) input_shape_layer.name = f"{name}_slice_input_shape" - slice_shapes_tensor = add_binary_elementwise_layer( + slice_shapes_tensor = convert_binary_elementwise( network, - input_shape_layer.get_output(0), - dims_tensor, - trt.ElementWiseOperation.PROD, target, + SourceIR.ACC, f"{name}_slice_shapes", + trt.ElementWiseOperation.PROD, + input_shape_layer.get_output(0), + dims_tensor, ) layer.set_input(1, starts_tensor) layer.set_input(2, slice_shapes_tensor) @@ -995,9 +1017,22 @@ def acc_ops_sign( if trt.__version__ >= "8.2" and not network.has_implicit_batch_dimension: input_val = kwargs["input"] operation_type = trt.UnaryOperation.SIGN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) - return sign(network, input_val, target, name) + return sign( + network, + target, + SourceIR.ACC, + name, + input_val, + ) @tensorrt_converter(acc_ops.relu) @@ -1099,7 +1134,14 @@ def acc_ops_sin( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.SIN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.cos) @@ -1112,7 +1154,14 @@ def acc_ops_cos( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.COS - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.tan) @@ -1125,7 +1174,14 @@ def acc_ops_tan( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.TAN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.sinh) @@ -1138,7 +1194,14 @@ def acc_ops_sinh( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.SINH - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.cosh) @@ -1151,7 +1214,14 @@ def acc_ops_cosh( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.COSH - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.tanh) @@ -1181,7 +1251,14 @@ def acc_ops_asin( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.ASIN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.acos) @@ -1194,7 +1271,14 @@ def acc_ops_acos( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.ACOS - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.atan) @@ -1207,7 +1291,14 @@ def acc_ops_atan( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.ATAN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.exp) @@ -1220,7 +1311,14 @@ def acc_ops_exp( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.EXP - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.log) @@ -1233,7 +1331,14 @@ def acc_ops_log( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.LOG - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.sqrt) @@ -1246,7 +1351,14 @@ def acc_ops_sqrt( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.SQRT - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.reciprocal) @@ -1259,7 +1371,14 @@ def acc_ops_reciprocal( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.RECIP - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.abs) @@ -1272,7 +1391,14 @@ def acc_ops_abs( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.ABS - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.neg) @@ -1285,7 +1411,14 @@ def acc_ops_neg( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.NEG - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.floor) @@ -1298,7 +1431,14 @@ def acc_ops_floor( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.FLOOR - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.ceil) @@ -1311,7 +1451,14 @@ def acc_ops_ceil( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.CEIL - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.sum) @@ -1486,13 +1633,14 @@ def acc_ops_maximum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.MAX, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.MAX, + kwargs["input"], + kwargs["other"], ) @@ -1504,13 +1652,14 @@ def acc_ops_minimum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.MIN, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.MIN, + kwargs["input"], + kwargs["other"], ) @@ -1569,7 +1718,14 @@ def acc_ops_logical_not( # cast to bool type if input_val.dtype in (trt.float32, trt.float16, trt.int32): input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool) - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.logical_and, no_implicit_batch_dim=True) @@ -1615,8 +1771,14 @@ def check_is_bool(input_t): input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool) if other_t.dtype != trt.bool: other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.AND, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.AND, + input_t, + other_t, ) @@ -1640,11 +1802,24 @@ def acc_ops_ne( other_t = get_trt_tensor(network, other_t, f"{name}_other_t") input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - eq_t = add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name + eq_t = convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.EQUAL, + input_t, + other_t, ) - return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + trt.UnaryOperation.NOT, + eq_t, + ) @tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True) @@ -1667,8 +1842,14 @@ def acc_ops_eq( other_t = get_trt_tensor(network, other_t, f"{name}_other_t") input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.EQUAL, + input_t, + other_t, ) @@ -1692,8 +1873,14 @@ def acc_ops_gt( other_t = get_trt_tensor(network, other_t, f"{name}_other_t") input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.GREATER, + input_t, + other_t, ) @@ -1717,8 +1904,14 @@ def acc_ops_lt( other_t = get_trt_tensor(network, other_t, f"{name}_other_t") input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.LESS, + input_t, + other_t, ) @@ -1754,8 +1947,14 @@ def acc_ops_logical_or( set_layer_name(layer_o, target, f"{name}_other_dtype_change") other_t = layer_o.get_output(0) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.OR, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.OR, + input_t, + other_t, ) @@ -1791,8 +1990,14 @@ def acc_ops_logical_xor( set_layer_name(layer_o, target, f"{name}_other_dtype_change") other_t = layer_o.get_output(0) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.XOR, + input_t, + other_t, ) @@ -1889,23 +2094,30 @@ def acc_ops_fmod( ) -> Union[TRTTensor, Sequence[TRTTensor]]: # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it trunc_div_value = trunc_div( - kwargs["input"], kwargs["other"], network, target, name + "_trunc_div" - ) - prod_value = add_binary_elementwise_layer( network, - trunc_div_value, + target, + SourceIR.ACC, + name + "_trunc_div", + kwargs["input"], kwargs["other"], - trt.ElementWiseOperation.PROD, + ) + prod_value = convert_binary_elementwise( + network, target, + SourceIR.ACC, name + "_prod", + trt.ElementWiseOperation.PROD, + trunc_div_value, + kwargs["other"], ) - sub_value = add_binary_elementwise_layer( + sub_value = convert_binary_elementwise( network, - kwargs["input"], - prod_value, - trt.ElementWiseOperation.SUB, target, + SourceIR.ACC, name + "_sub", + trt.ElementWiseOperation.SUB, + kwargs["input"], + prod_value, ) return sub_value @@ -2136,13 +2348,14 @@ def acc_ops_add( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.SUM, + kwargs["input"], + kwargs["other"], ) @@ -2154,13 +2367,14 @@ def acc_ops_sub( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.SUB, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.SUB, + kwargs["input"], + kwargs["other"], ) @@ -2172,13 +2386,14 @@ def acc_ops_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.DIV, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.DIV, + kwargs["input"], + kwargs["other"], ) @@ -2190,13 +2405,14 @@ def acc_ops_floor_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.FLOOR_DIV, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.FLOOR_DIV, + kwargs["input"], + kwargs["other"], ) @@ -2208,7 +2424,14 @@ def acc_ops_trunc_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return trunc_div(kwargs["input"], kwargs["other"], network, target, name) + return trunc_div( + network, + target, + SourceIR.ACC, + name, + kwargs["input"], + kwargs["other"], + ) @tensorrt_converter(acc_ops.mul) @@ -2219,13 +2442,14 @@ def acc_ops_mul( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.PROD, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.PROD, + kwargs["input"], + kwargs["other"], ) @@ -2237,13 +2461,14 @@ def acc_ops_pow( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["exponent"], - trt.ElementWiseOperation.POW, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.POW, + kwargs["input"], + kwargs["exponent"], ) @@ -2549,7 +2774,12 @@ def acc_ops_slice_tensor( if dynamic_shape > 0: output_shape = get_shape_with_dynamic_shape( - network, output_shape, input_val, target, name + network, + target, + SourceIR.ACC, + name, + output_shape, + input_val, ) layer = network.add_slice( input_val, @@ -2793,7 +3023,12 @@ def acc_ops_split( start[dim] = offset if dynamic_shape: shape = get_shape_with_dynamic_shape( - network, shape, input_val, target, f"{name}_shape_{i}" + network, + target, + SourceIR.ACC, + f"{name}_shape_{i}", + shape, + input_val, ) layer = network.add_slice( input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride @@ -2857,13 +3092,14 @@ def acc_ops_linear( if kwargs["bias"] is not None: bias = get_trt_tensor(network, kwargs["bias"], f"{name}_bias") # type: ignore[arg-type] - res = add_binary_elementwise_layer( + res = convert_binary_elementwise( network, - matmul_layer.get_output(0), - bias, - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, f"{name}_add", + trt.ElementWiseOperation.SUM, + matmul_layer.get_output(0), + bias, ) return res @@ -3049,7 +3285,14 @@ def slice_to_trt_params(py_slice, dim_size): i += 1 if dynamic_shape: - size = get_shape_with_dynamic_shape(network, size, input_val, target, name) + size = get_shape_with_dynamic_shape( + network, + target, + SourceIR.ACC, + name, + size, + input_val, + ) layer = network.add_slice( input=input_val, @@ -3204,27 +3447,10 @@ def acc_ops_permute( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] - ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr] - if len(kwargs["permutation"]) == 1: - index = kwargs["permutation"][0] - else: - index = kwargs["permutation"] - permutation = [get_positive_dim(i, ranks) for i in cast(Sequence[int], index)] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"permute received input {input_val} that is not part " - "of the TensorRT region!" - ) - - if network.has_implicit_batch_dimension: - assert permutation[0] == 0, "Can't permute batch dimension when it's implicit." - permutation = [i - 1 for i in permutation[1:]] - - layer = network.add_shuffle(input_val) - layer.second_transpose = tuple(permutation) - set_layer_name(layer, target, name) - return layer.get_output(0) + index = kwargs["permutation"] + return permute.permute( + network, target, SourceIR.ACC, name, input_val=input_val, index=index + ) @tensorrt_converter(acc_ops.quantize_per_tensor) @@ -3474,7 +3700,12 @@ def acc_ops_chunk( shape[dim] = min(split_size, max_offset - offset) if dynamic_shape: shape = get_shape_with_dynamic_shape( - network, shape, input_val, target, f"{name}_{i}" + network, + target, + SourceIR.ACC, + f"{name}_{i}", + shape, + input_val, ) start[dim] = offset layer = network.add_slice( @@ -3537,13 +3768,14 @@ def acc_ops_cumsum( set_layer_name(running_sum, target, f"{name}_running_sum_1") running_sum_tensor = running_sum.get_output(0) - current_sum = add_binary_elementwise_layer( + current_sum = convert_binary_elementwise( network, - data, - running_sum_tensor, - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, f"{name}_sum_1", + trt.ElementWiseOperation.SUM, + data, + running_sum_tensor, ) running_sum.set_input(1, current_sum) @@ -3551,13 +3783,14 @@ def acc_ops_cumsum( set_layer_name(running_sum, target, f"{name}_running_sum_2") running_sum_tensor = running_sum.get_output(0) - current_sum = add_binary_elementwise_layer( + current_sum = convert_binary_elementwise( network, - data, - running_sum_tensor, - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, f"{name}_sum_2", + trt.ElementWiseOperation.SUM, + data, + running_sum_tensor, ) running_sum.set_input(1, current_sum) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 82847cc760..9e6dff5c69 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -23,6 +23,9 @@ from .converter_utils import * # noqa: F403 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils from torch_tensorrt.fx.converters.impl import activation +from torch_tensorrt.fx.converters.impl import permute +from torch_tensorrt.fx.converters.impl.elementwise import trunc_div +from torch_tensorrt.fx.converters.impl.elementwise import rsqrt _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -161,9 +164,7 @@ def aten_ops_div( network, target, None, kwargs_new, name ) elif rounding_mode == "trunc": - return acc_ops_converters.acc_ops_trunc_div( - network, target, None, kwargs_new, name - ) + return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1]) else: raise RuntimeError( f"Target {target} does not support rounding mode {rounding_mode}" @@ -335,6 +336,24 @@ def aten_ops_mul( return acc_ops_converters.acc_ops_mul(network, target, None, kwargs_new, name) +@tensorrt_converter(torch.ops.aten.permute.default) +def aten_ops_permute_default( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return permute.permute( + network, + target, + SourceIR.ATEN, + name=name, + input_val=args[0], + index=args[1:], + ) + + @tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) @tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor) def aten_ops_pow( @@ -369,6 +388,42 @@ def aten_ops_relu( ) +@tensorrt_converter(torch.ops.aten.relu.default) +def aten_ops_relu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return activation.relu( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@tensorrt_converter(torch.ops.aten.rsqrt.default) +def aten_ops_rsqrt( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return rsqrt( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @tensorrt_converter(torch.ops.aten.sub.Tensor) def aten_ops_sub( network: TRTNetwork, @@ -384,6 +439,38 @@ def aten_ops_sub( return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name) +@tensorrt_converter(torch.ops.aten.transpose.int) +def aten_ops_transpose_int( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + input_val = args[0] + ndim = len(input_val.shape) + if len(args) == 1: + # default is to reverse dimensions + new_order = torch.arange(0, start=ndim - 1, step=-1) + else: + assert ( + len(args) == 3 + ), f"Wrong number of arguments to transpose(): {len(args)-1}" + new_order = torch.arange(ndim) + dim0 = args[1] + if args[1] < 0: + dim0 = dim0 + ndim + dim1 = args[2] + if args[2] < 0: + dim1 = dim1 + ndim + new_order[dim0] = dim1 + new_order[dim1] = dim0 + print("New order: ", new_order) + return permute.permute( + network, target, SourceIR.ATEN, name=name, input_val=input_val, index=new_order + ) + + @tensorrt_converter(torch.ops.aten.view.default) def aten_ops_reshape( network: TRTNetwork, diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index d13be41d05..432ec9eecd 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -28,6 +28,7 @@ class SourceIR(Enum): ACC = auto() ATEN = auto() PRIM = auto() + TORCHTRT_LOWERED = auto() UNKNOWN = auto() def __str__(self): @@ -39,6 +40,8 @@ def __str__(self): return "aten" elif self == SourceIR.PRIM: return "prim" + elif self == SourceIR.TORCHTRT_LOWERED: + return "torchtrt_lowered" else: return "unknown_ir" @@ -383,171 +386,6 @@ def broadcast( return a, b -def get_shape_with_dynamic_shape( - network: TRTNetwork, - shape: Union[list, tuple, torch.Tensor], - input_val: TRTTensor, - target: Target, - name: str, -) -> TRTTensor: - """ - Prepare the real output tensor shape for dynamic shape mode tensor input. - How this functions works: - Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation - output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual - reduce operation output shape. Steps of calculations are: - 1. get the actual tensor shape of input_val via add_shape layer; - 2. create a all 0 tensor [0, 0, 0]; - 3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False]; - 4. use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace - all -1 dynamic shape dimensions with actual batch_size value; - 5. output shape with actual batch_size as [2048, 128, 256] - - Args: - network (TRTNetwork): TensorRT network object. - shape: calculated shape of the expected output tensor - input_val (TRTTensor): A TensorRT ITensor. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - Returns: - TensorRT ITensors that represents the actual shape of the input_val - """ - # Ger real shape info for input_val - input_shape = network.add_shape(input_val).get_output(0) - - scale_layer = network.add_constant( - input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) - ) - set_layer_name(scale_layer, target, f"{name}_scale") - scale_res = scale_layer.get_output(0) - - length = input_shape.shape[0] - zero_layer = network.add_constant( - input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) - ) - set_layer_name(zero_layer, target, f"{name}_zeros") - - condition_val = add_binary_elementwise_layer( - network, - scale_res, - zero_layer.get_output(0), - trt.ElementWiseOperation.LESS, - target, - f"{name}_shape", - ) - select_layer = network.add_select(condition_val, input_shape, scale_res) - set_layer_name(select_layer, target, f"{name}_select") - return select_layer.get_output(0) - - -def add_binary_elementwise_layer( - network: TRTNetwork, - lhs_val: Union[int, float, TRTTensor, torch.Tensor], - rhs_val: Union[int, float, TRTTensor, torch.Tensor], - op_type: trt.ElementWiseOperation, - target: Target, - name: str, -) -> TRTTensor: - """ - This function adds a TensorRT elementwise layer. We allow both operands to be - constant (not a trt tensor) because in implicit batch dimension mode, we could - introduce constant via .size() op. Other scenario should be const folded first. - If any operand is not a trt tensor, we make it a trt constant layer while preserve - its dtype. Then we broadcast these two inputs to have the same number of dimensions. - - Limitation: - If we are using implicit batch dim mode, the operand that is not a trt - tensor are not allowed to have larger ranks than the trt tensor operand. - - Args: - network (TRTNetwork): TensorRT network object. - lhs_val (TRTTensor): Left operand of the binary operation. Could - be a TensorRT tensor, a PyTorch tensor or a simple value. - rhs_val (TRTTensor): Right operand of the binary operation. Similar - to lhs_val. - op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - - Returns: - The output of TensorRT Elementwise layer. - """ - lhs_dtype = None - rhs_dtype = None - is_lhs_trt_tensor = False - is_rhs_trt_tensor = False - - if isinstance(lhs_val, TRTTensor): - lhs_dtype = torch_dtype_from_trt(lhs_val.dtype) - is_lhs_trt_tensor = True - if isinstance(rhs_val, TRTTensor): - rhs_dtype = torch_dtype_from_trt(rhs_val.dtype) - is_rhs_trt_tensor = True - - if not is_lhs_trt_tensor and not is_rhs_trt_tensor: - warnings.warn( - f"Both operands of the binary elementwise op {name} " - "are constant. In this case, please consider constant fold the model first." - ) - return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val) - - # If the following conditions are true: - # 1. the network has implicit batch dimension, - # 2. one operand has shape [] (real shape is [batch_size]), - # 3. another operand is a scalar, - # then the result should also have shape [] (real shape is [batch_size]). - # - # In such case, we need to convert the scalar operand to tensor, because - # this way the shape will become [1], and then will be properly squeezed - # into [], meaning that the result will have shape [], which is what we - # expect. - # - # Note that the dtype here is supposed to be the same as the scalar - # dtype but we don't have a way to detect whether it makes sense for the - # scalar to be float or half. Hence we go with the lhs dtype. - if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): - rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) - if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): - lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) - - # When lhs is scalar, and rhs has shape [1,], then currently the assert - # will fail because lhs shape has fewer dimensions than rhs shape. This - # happens when using implicit batch dimension, when we removed the 1st - # dimension from input tensor, causing it to have shape [] - a scalar. We - # fix it by reducing the rhs constant with a squeeze_left, so it becomes a - # scalar too. More generally, we squeeze_left on input if it's a constant - # tensor. This is safe because broadcast will pad dimensions on the left - # (prepend) to make lhs and rhs shape compatible. - if network.has_implicit_batch_dimension: - if isinstance(lhs_val, torch.Tensor): - lhs_val = squeeze_left(lhs_val) - if isinstance(rhs_val, torch.Tensor): - rhs_val = squeeze_left(rhs_val) - - lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) - rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) - - # Check the limitation in the doc string. - if network.has_implicit_batch_dimension: - if is_lhs_trt_tensor and not is_rhs_trt_tensor: - assert len(lhs_val.shape) >= len( - rhs_val.shape - ), f"{lhs_val.shape} >= {rhs_val.shape}" - elif not is_lhs_trt_tensor and is_rhs_trt_tensor: - assert len(rhs_val.shape) >= len( - lhs_val.shape - ), f"{rhs_val.shape} >= {lhs_val.shape}" - - lhs_val, rhs_val = broadcast( - network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" - ) - layer = network.add_elementwise(lhs_val, rhs_val, op_type) - set_layer_name(layer, target, name) - output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ - return output - - def squeeze_left(const: torch.Tensor): """ Squeeze the size-1 dimensions on the left side of the shape tuple. @@ -559,38 +397,6 @@ def squeeze_left(const: torch.Tensor): return const -def add_unary_layer( - network: TRTNetwork, - input_val: TRTTensor, - operation_type: trt.UnaryOperation, - target: Target, - name: str, -) -> TRTTensor: - """ - Add a TensorRT Unary layer to `network`. - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. - op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - - Returns: - The output of TensorRT Unary layer. - """ - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"{operation_type} received input {input_val} that is not part " - "of the TensorRT region!" - ) - layer = network.add_unary(input_val, operation_type) - set_layer_name(layer, target, name) - output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ - return layer.get_output(0) - - def add_reduce_layer( network: TRTNetwork, target: Target, @@ -695,139 +501,6 @@ def get_inputs_from_args_and_kwargs(args, kwargs, input_names): return inputs -def sign( - network: TRTNetwork, input_val: TRTTensor, target: Target, name: str -) -> TRTTensor: - """ - Sign is calculated as below: - x = input - sign = (exp(x) // exp(abs(x))) * 2 - 1 - For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0. - With multiply 2, the value become 2(for pos and 0) and 0(for neg). - Finally minus 1, the value become 1(for pos and 0) and -1(for neg). - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): The input tensor. - target (Target): fx node target. - name (str): Name of the fx node with optional suffix. - - Returns: - A TensorRT tensor represent the result of sign operator. - """ - input_exp_output = add_unary_layer( - network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp" - ) - input_abs_output = add_unary_layer( - network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs" - ) - input_abs_exp_output = add_unary_layer( - network, - input_abs_output, - trt.UnaryOperation.EXP, - target, - f"{name}_prod_abs_exp", - ) - floor_div_output = add_binary_elementwise_layer( - network, - input_exp_output, - input_abs_exp_output, - trt.ElementWiseOperation.FLOOR_DIV, - target, - f"{name}_exp_floor_div", - ) - double_floor_div_output = add_binary_elementwise_layer( - network, - floor_div_output, - 2, - trt.ElementWiseOperation.PROD, - target, - f"{name}_floor_div*2", - ) - return add_binary_elementwise_layer( - network, - double_floor_div_output, - 1, - trt.ElementWiseOperation.SUB, - target, - f"{name}_sign", - ) - - -def trunc_div( - input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str -) -> TRTTensor: - """ - Perform trunc divide on Tensor, result of divide will be round toward zero. - This means for positive number, it will be floor round; for negative number, - it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. - - Args: - input: divisor. - other: dividend. - network: INetworkDefinition. - target: node target. - name: namespace for the op - - Returns: - A TensorRT tensor represent the result of trunc divide. - """ - prod_output = add_binary_elementwise_layer( - network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod" - ) - sign_output = sign(network, prod_output, target, name) - - # Convert constant input into ITensor for UnaryOperation - if not isinstance(input, trt.tensorrt.ITensor): - input = get_trt_tensor(network, input, f"{name}_input") - if not isinstance(other, trt.tensorrt.ITensor): - other = get_trt_tensor( - network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype) - ) - - abs_input_output = add_unary_layer( - network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input" - ) - abs_other_output = add_unary_layer( - network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other" - ) - abs_floor_output = add_binary_elementwise_layer( - network, - abs_input_output, - abs_other_output, - trt.ElementWiseOperation.FLOOR_DIV, - target, - f"{name}_floor_div", - ) - output = add_binary_elementwise_layer( - network, - abs_floor_output, - sign_output, - trt.ElementWiseOperation.PROD, - target, - f"{name}_output", - ) - - return output - - -def get_python_op_from_trt_elementwise_op( - trt_op: TRTElementWiseOp, -) -> Callable[[Any, Any], Any]: - if trt_op == trt.ElementWiseOperation.SUM: - return operator.add - elif trt_op == trt.ElementWiseOperation.PROD: - return operator.mul - elif trt_op == trt.ElementWiseOperation.SUB: - return operator.sub - elif trt_op == trt.ElementWiseOperation.DIV: - return operator.truediv - elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: - return operator.floordiv - else: - raise RuntimeError(f"{trt_op} is not supported yet!") - - def dtype_uniform( network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor ): diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/__init__.py b/py/torch_tensorrt/fx/converters/impl/elementwise/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/base.py b/py/torch_tensorrt/fx/converters/impl/elementwise/base.py new file mode 100644 index 0000000000..261e45728f --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/base.py @@ -0,0 +1,147 @@ +import operator +import warnings +from typing import Union, Callable, Any, Optional + +import numpy as np + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp +from torch_tensorrt.fx.utils import torch_dtype_from_trt +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + set_layer_name, + broadcast, + squeeze_left, + get_trt_tensor, +) + + +def get_python_op_from_trt_elementwise_op( + trt_op: TRTElementWiseOp, +) -> Callable[[Any, Any], Any]: + if trt_op == trt.ElementWiseOperation.SUM: + return operator.add + elif trt_op == trt.ElementWiseOperation.PROD: + return operator.mul + elif trt_op == trt.ElementWiseOperation.SUB: + return operator.sub + elif trt_op == trt.ElementWiseOperation.DIV: + return operator.truediv + elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: + return operator.floordiv + else: + raise RuntimeError(f"{trt_op} is not supported yet!") + + +def convert_binary_elementwise( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + op_type: trt.ElementWiseOperation, + lhs_val: Union[int, float, TRTTensor, torch.Tensor], + rhs_val: Union[int, float, TRTTensor, torch.Tensor], +) -> TRTTensor: + """ + This function adds a TensorRT elementwise layer. We allow both operands to be + constant (not a trt tensor) because in implicit batch dimension mode, we could + introduce constant via .size() op. Other scenario should be const folded first. + If any operand is not a trt tensor, we make it a trt constant layer while preserve + its dtype. Then we broadcast these two inputs to have the same number of dimensions. + + Limitation: + If we are using implicit batch dim mode, the operand that is not a trt + tensor are not allowed to have larger ranks than the trt tensor operand. + + Args: + network (TRTNetwork): TensorRT network object. + target (Target): Target of fx node. + source_ir (SourceIR): The IR that is calling the function. + name (str): The name we want to assign to the created TensorRT layer. + lhs_val (TRTTensor): Left operand of the binary operation. Could + be a TensorRT tensor, a PyTorch tensor or a simple value. + rhs_val (TRTTensor): Right operand of the binary operation. Similar + to lhs_val. + op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation. + + Returns: + The output of TensorRT Elementwise layer. + """ + lhs_dtype = None + rhs_dtype = None + is_lhs_trt_tensor = False + is_rhs_trt_tensor = False + + if isinstance(lhs_val, TRTTensor): + lhs_dtype = torch_dtype_from_trt(lhs_val.dtype) + is_lhs_trt_tensor = True + if isinstance(rhs_val, TRTTensor): + rhs_dtype = torch_dtype_from_trt(rhs_val.dtype) + is_rhs_trt_tensor = True + + if not is_lhs_trt_tensor and not is_rhs_trt_tensor: + warnings.warn( + f"Both operands of the binary elementwise op {name} " + "are constant. In this case, please consider constant fold the model first." + ) + return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val) + + # If the following conditions are true: + # 1. the network has implicit batch dimension, + # 2. one operand has shape [] (real shape is [batch_size]), + # 3. another operand is a scalar, + # then the result should also have shape [] (real shape is [batch_size]). + # + # In such case, we need to convert the scalar operand to tensor, because + # this way the shape will become [1], and then will be properly squeezed + # into [], meaning that the result will have shape [], which is what we + # expect. + # + # Note that the dtype here is supposed to be the same as the scalar + # dtype but we don't have a way to detect whether it makes sense for the + # scalar to be float or half. Hence we go with the lhs dtype. + if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): + rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) + if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): + lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) + + # When lhs is scalar, and rhs has shape [1,], then currently the assert + # will fail because lhs shape has fewer dimensions than rhs shape. This + # happens when using implicit batch dimension, when we removed the 1st + # dimension from input tensor, causing it to have shape [] - a scalar. We + # fix it by reducing the rhs constant with a squeeze_left, so it becomes a + # scalar too. More generally, we squeeze_left on input if it's a constant + # tensor. This is safe because broadcast will pad dimensions on the left + # (prepend) to make lhs and rhs shape compatible. + if network.has_implicit_batch_dimension: + if isinstance(lhs_val, torch.Tensor): + lhs_val = squeeze_left(lhs_val) + if isinstance(rhs_val, torch.Tensor): + rhs_val = squeeze_left(rhs_val) + + lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) + rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) + + # Check the limitation in the doc string. + if network.has_implicit_batch_dimension: + if is_lhs_trt_tensor and not is_rhs_trt_tensor: + assert len(lhs_val.shape) >= len( + rhs_val.shape + ), f"{lhs_val.shape} >= {rhs_val.shape}" + elif not is_lhs_trt_tensor and is_rhs_trt_tensor: + assert len(rhs_val.shape) >= len( + lhs_val.shape + ), f"{rhs_val.shape} >= {lhs_val.shape}" + + lhs_val, rhs_val = broadcast( + network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" + ) + layer = network.add_elementwise(lhs_val, rhs_val, op_type) + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return output diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py new file mode 100644 index 0000000000..8fddb426a6 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py @@ -0,0 +1,141 @@ +import operator +import warnings +from typing import Union, Callable, Any, Optional + +import numpy as np + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp +from torch_tensorrt.fx.utils import torch_dtype_from_trt +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + get_trt_tensor, +) + +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.fx.converters.impl.unary.base import convert_unary +from torch_tensorrt.fx.converters.impl.unary import sign + + +def trunc_div( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + """ + Perform trunc divide on Tensor, result of divide will be round toward zero. + This means for positive number, it will be floor round; for negative number, + it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. + + Args: + network: INetworkDefinition. + target: node target + source_ir (SourceIR): Source IR calling the function. + name: namespace for the op + input: divisor. + other: dividend. + + Returns: + A TensorRT tensor represent the result of trunc divide. + """ + prod_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_prod", + trt.ElementWiseOperation.PROD, + input, + other, + ) + + sign_output = sign( + network, + target, + source_ir, + name, + prod_output, + ) + + # Convert constant input into ITensor for UnaryOperation + if not isinstance(input, trt.tensorrt.ITensor): + input = get_trt_tensor(network, input, f"{name}_input") + if not isinstance(other, trt.tensorrt.ITensor): + other = get_trt_tensor( + network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype) + ) + + abs_input_output = convert_unary( + network, + target, + source_ir, + f"{name}_abs_input", + trt.UnaryOperation.ABS, + input, + ) + abs_other_output = convert_unary( + network, + target, + source_ir, + f"{name}_abs_other", + trt.UnaryOperation.ABS, + other, + ) + abs_floor_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_floor_div", + trt.ElementWiseOperation.FLOOR_DIV, + abs_input_output, + abs_other_output, + ) + output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_output", + trt.ElementWiseOperation.PROD, + abs_floor_output, + sign_output, + ) + + return output + + +def rsqrt( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + + sqrt_trt_output = convert_unary( + network, + target, + source_ir, + f"{name}_sqrt", + trt.UnaryOperation.SQRT, + input, + ) + + output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_output", + trt.ElementWiseOperation.DIV, + 1, + sqrt_trt_output, + ) + + return output diff --git a/py/torch_tensorrt/fx/converters/impl/permute.py b/py/torch_tensorrt/fx/converters/impl/permute.py new file mode 100644 index 0000000000..dabe4de5bc --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/permute.py @@ -0,0 +1,46 @@ +import numpy as np +import operator +import warnings +from typing import cast, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Argument, Target + +from ..converter_utils import * # noqa: F403 +from ...utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + + +def permute( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + index: Sequence[TRTTensor], +) -> Union[TRTTensor, Sequence[TRTTensor]]: + ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr] + if len(index) == 1: + index = index[0] + permutation = [get_positive_dim(i, ranks) for i in cast(Sequence[int], index)] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"permute received input {input_val} that is not part " + "of the TensorRT region!" + ) + + if network.has_implicit_batch_dimension: + assert permutation[0] == 0, "Can't permute batch dimension when it's implicit." + permutation = [i - 1 for i in permutation[1:]] + + layer = network.add_shuffle(input_val) + layer.second_transpose = tuple(permutation) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/fx/converters/impl/shape.py b/py/torch_tensorrt/fx/converters/impl/shape.py new file mode 100644 index 0000000000..8667c712b8 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/shape.py @@ -0,0 +1,81 @@ +import operator +import warnings +from typing import Union, Callable, Any, Optional + +import numpy as np + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp +from torch_tensorrt.fx.utils import torch_dtype_from_trt +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + set_layer_name, + to_numpy, +) + +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) + + +def get_shape_with_dynamic_shape( + network: TRTNetwork, + target: Target, + source_ir: SourceIR, + name: str, + shape: Union[list, tuple, torch.Tensor], + input_val: TRTTensor, +) -> TRTTensor: + """ + Prepare the real output tensor shape for dynamic shape mode tensor input. + How this functions works: + Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation + output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual + reduce operation output shape. Steps of calculations are: + 1. get the actual tensor shape of input_val via add_shape layer; + 2. create a all 0 tensor [0, 0, 0]; + 3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False]; + 4. use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace + all -1 dynamic shape dimensions with actual batch_size value; + 5. output shape with actual batch_size as [2048, 128, 256] + + Args: + network (TRTNetwork): TensorRT network object. + shape: calculated shape of the expected output tensor + input_val (TRTTensor): A TensorRT ITensor. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + Returns: + TensorRT ITensors that represents the actual shape of the input_val + """ + # Ger real shape info for input_val + input_shape = network.add_shape(input_val).get_output(0) + + scale_layer = network.add_constant( + input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) + ) + set_layer_name(scale_layer, target, f"{name}_scale") + scale_res = scale_layer.get_output(0) + + length = input_shape.shape[0] + zero_layer = network.add_constant( + input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) + ) + set_layer_name(zero_layer, target, f"{name}_zeros") + + condition_val = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_shape", + trt.ElementWiseOperation.LESS, + scale_res, + zero_layer.get_output(0), + ) + select_layer = network.add_select(condition_val, input_shape, scale_res) + set_layer_name(select_layer, target, f"{name}_select") + return select_layer.get_output(0) diff --git a/py/torch_tensorrt/fx/converters/impl/unary/__init__.py b/py/torch_tensorrt/fx/converters/impl/unary/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/unary/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/fx/converters/impl/unary/base.py b/py/torch_tensorrt/fx/converters/impl/unary/base.py new file mode 100644 index 0000000000..fea6334170 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/unary/base.py @@ -0,0 +1,53 @@ +import operator +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from enum import Enum, auto + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + +from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name + +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) + + +def convert_unary( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + operation_type: trt.UnaryOperation, + input_val: TRTTensor, +) -> TRTTensor: + """ + Add a TensorRT Unary layer to `network`. + + Args: + network (TRTNetwork): TensorRT network object. + input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. + op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + + Returns: + The output of TensorRT Unary layer. + """ + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" + ) + layer = network.add_unary(input_val, operation_type) + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return layer.get_output(0) diff --git a/py/torch_tensorrt/fx/converters/impl/unary/ops.py b/py/torch_tensorrt/fx/converters/impl/unary/ops.py new file mode 100644 index 0000000000..cba760736a --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/unary/ops.py @@ -0,0 +1,104 @@ +import operator +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from enum import Enum, auto + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, +) + +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.fx.converters.impl.unary.base import convert_unary + + +def sign( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + """ + Sign is calculated as below: + x = input + sign = (exp(x) // exp(abs(x))) * 2 - 1 + For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0. + With multiply 2, the value become 2(for pos and 0) and 0(for neg). + Finally minus 1, the value become 1(for pos and 0) and -1(for neg). + + Args: + network (TRTNetwork): TensorRT network object. + target (Target): fx node target. + source_ir (SourceIR): Source IR calling the function + name (str): Name of the fx node with optional suffix. + input_val (TRTTensor): The input tensor. + + Returns: + A TensorRT tensor represent the result of sign operator. + """ + input_exp_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_exp", + trt.UnaryOperation.EXP, + input_val, + ) + input_abs_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_abs", + trt.UnaryOperation.ABS, + input_val, + ) + input_abs_exp_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_abs_exp", + trt.UnaryOperation.EXP, + input_abs_output, + ) + + floor_div_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_exp_floor_div", + trt.ElementWiseOperation.FLOOR_DIV, + input_exp_output, + input_abs_exp_output, + ) + + double_floor_div_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_floor_div*2", + trt.ElementWiseOperation.PROD, + floor_div_output, + 2, + ) + + return convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_sign", + trt.ElementWiseOperation.SUB, + double_floor_div_output, + 1, + ) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_permute_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_permute_aten.py new file mode 100644 index 0000000000..8f42a25311 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_permute_aten.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestPermuteConverter(DispatchTestCase): + @parameterized.expand( + [ + ("positive", [0, 2, 1]), + ("negative", [0, -1, -2]), + ] + ) + def test_permute_list(self, _, permutation): + class Permute(nn.Module): + def forward(self, x): + return x.permute(permutation) + + inputs = [torch.randn(1, 3, 2)] + self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default}) + + @parameterized.expand( + [ + ("positive", [0, 2, 1]), + ("negative", [0, -1, -2]), + ] + ) + def test_permute(self, _, permutation): + class Permute(nn.Module): + def forward(self, x): + return x.permute(*permutation) + + inputs = [torch.randn(1, 3, 2)] + self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default}) + + @parameterized.expand( + [ + ("positive", (1, 2)), + ("negative", (-1, -2)), + ] + ) + def test_transpose(self, _, dims): + class Transpose(nn.Module): + def forward(self, x): + return x.transpose(*dims) + + inputs = [torch.randn(1, 2, 3)] + self.run_test(Transpose(), inputs, expected_ops={torch.ops.aten.transpose.int}) + + def test_permute_with_dynamic_shape(self): + class Permute(nn.Module): + def forward(self, x): + return x.permute(1, 2, 0) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Permute(), input_specs, expected_ops={torch.ops.aten.permute.default} + ) + + def test_permute_with_dynamic_shape_four_dimensions(self): + class Permute(nn.Module): + def forward(self, x): + return x.permute(1, 2, 3, 0) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Permute(), input_specs, expected_ops={torch.ops.aten.permute.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py new file mode 100644 index 0000000000..3fa27af1a0 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestRSqrtConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsqrt(self, _, x, alpha): + class rsqrt(nn.Module): + def forward(self, input): + return torch.rsqrt(input) + + inputs = [torch.randn(x) + 1] + self.run_test( + rsqrt(), + inputs, + expected_ops={torch.ops.aten.rsqrt.default}, + ) + + +if __name__ == "__main__": + run_tests()