diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py
index 3fabb1bb45..1261062cf4 100644
--- a/py/torch_tensorrt/dynamo/conversion/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -2,5 +2,6 @@
 from ._TRTInterpreter import *  # noqa: F403
 from .aten_ops_converters import *  # noqa: F403
 from .conversion import *  # noqa: F403
-from .op_evaluators import *  # noqa: F403
+from .ops_evaluators import *  # noqa: F403
+from .prims_ops_converters import *  # noqa: F403
 from .truncate_long_and_double import repair_long_or_double_inputs
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 67ce83469f..318befe945 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -27,6 +27,24 @@ def args_bounds_check(
     return args[i] if len(args) > i else replacement
 
 
+def get_ir(target: Target) -> SourceIR:
+    target_module = getattr(target, "__module__", "None")
+    if any(
+        target_module.startswith(prefix)
+        for prefix in ("torch.ops.aten", "torch._ops.aten")
+    ):
+        return SourceIR.ATEN
+    elif any(
+        target_module.startswith(prefix)
+        for prefix in ("torch.ops.prims", "torch._ops.prims")
+    ):
+        return SourceIR.PRIM
+    elif target_module.startswith("torch.nn"):
+        return SourceIR.NN
+
+    return SourceIR.UNKNOWN
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.batch_norm)  # type: ignore[misc]
 def aten_ops_batch_norm(
     ctx: ConversionContext,
@@ -674,6 +692,7 @@ def aten_ops_amax(
 
 @dynamo_tensorrt_converter(torch.ops.aten.sum.default)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.prims.sum.default)  # type: ignore[misc]
 def aten_ops_sum(
     ctx: ConversionContext,
     target: Target,
@@ -681,16 +700,29 @@ def aten_ops_sum(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.reduce.sum(
+    sum_ = impl.reduce.sum(
         ctx,
         target,
-        SourceIR.ATEN,
+        get_ir(target),
         name,
         args[0],
         args_bounds_check(args, 1, replacement=None),
         args_bounds_check(args, 2, replacement=False),
     )
+    if kwargs.get("output_dtype", None) is not None:
+        return impl.cast.to_copy(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            sum_,
+            kwargs["output_dtype"],
+            force_layer=False,
+        )
+    else:
+        return sum_
+
 
 @dynamo_tensorrt_converter(torch.ops.aten.exp.default)  # type: ignore[misc]
 def aten_ops_exp(
@@ -1189,6 +1221,7 @@ def aten_ops_sub(
 @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.prims.div.default)  # type: ignore[misc]
 def aten_ops_div(
     ctx: ConversionContext,
     target: Target,
@@ -1202,7 +1235,7 @@ def aten_ops_div(
         return impl.elementwise.div(
             ctx,
             target,
-            SourceIR.ATEN,
+            get_ir(target),
             name,
             args[0],
             args[1],
@@ -1211,7 +1244,7 @@
         return impl.elementwise.floor_divide(
             ctx,
             target,
-            SourceIR.ATEN,
+            get_ir(target),
             name,
             args[0],
             args[1],
@@ -1220,7 +1253,7 @@
         return impl.elementwise.trunc_div(
             ctx,
             target,
-            SourceIR.ATEN,
+            get_ir(target),
             name,
             args[0],
             args[1],
@@ -1553,5 +1586,5 @@ def tensorrt_scaled_dot_product_attention(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.attention.scaled_dot_product_attention(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+        ctx, target, SourceIR.TORCHTRT_LOWERED, name, args[0], args[1], args[2]
     )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
index 981c13397f..3078bd4587 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
@@ -1,5 +1,6 @@
 from typing import Optional
 
+import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -23,16 +24,6 @@ def where(
     other: TRTTensor,
     condition: TRTTensor,
 ) -> TRTTensor:
-    input_dim = len(tuple(input.shape))
-    other_dim = len(tuple(other.shape))
-    condition_dim = len(tuple(condition.shape))
-
-    if type(input) != TRTTensor:
-        assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!"
-
-    if type(other) != TRTTensor:
-        assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!"
-
     if not (broadcastable(input, other)):
         assert "The two torch tensors should be broadcastable"
 
@@ -49,33 +40,37 @@ def where(
     x_shape = list(input.shape)
     y_shape = list(other.shape)
     condition_shape = list(condition.shape)
+    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
 
     # expand shape
-    if type(condition) != TRTTensor:
-        assert condition.dtype == torch.bool, "condition dtype is not bool"
+    if not isinstance(condition, TRTTensor):
+        assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
         if condition_shape != output_shape:
-            condition.expand(output_shape)
-            condition = condition.to(torch.int32)
-        condition_const = get_trt_tensor(ctx, condition, f"{name}_condition")
-        condition_layer = ctx.net.add_identity(condition_const)
-        condition_layer.set_output_type(0, trt.bool)
-        set_layer_name(condition_layer, target, f"{name}_condition")
-        condition_val = condition_layer.get_output(0)
+            condition = (
+                condition.expand(output_shape)
+                if isinstance(condition, torch.Tensor)
+                else np.broadcast_to(condition, output_shape)
+            )
+        condition_val = get_trt_tensor(ctx, condition, f"{name}_condition")
     else:
         assert condition.dtype == trt.bool, "mask dtype is not bool!"
-        if len(condition_shape) != condition_dim:
+        if condition_shape != output_shape:
             condition_val = expand(
                 ctx, target, source_ir, f"{name}_expand", condition, output_shape
             )
         else:
             condition_val = condition
 
-    if type(input) != TRTTensor:
+    if not isinstance(input, TRTTensor):
         if x_shape != output_shape:
             # special case where 1 element in input
             if len(input.shape) == 0:
-                input = input.unsqueeze(0)
+                input = (
+                    input.unsqueeze(0)
+                    if isinstance(input, torch.Tensor)
+                    else np.expand_dims(input, axis=0)
+                )
             input = input.expand(output_shape)
         x_val = get_trt_tensor(ctx, input, f"{name}_x")
     else:
@@ -85,11 +80,15 @@ def where(
                 ctx, target, source_ir, f"{name}_x_expand", input, output_shape
             )
 
-    if type(other) != TRTTensor:
+    if not isinstance(other, TRTTensor):
         if y_shape != output_shape:
             # special case where 1 element in other
             if len(other.shape) == 0:
-                other = other.unsqueeze(0)
+                other = (
+                    other.unsqueeze(0)
+                    if isinstance(other, torch.Tensor)
+                    else np.expand_dims(other, axis=0)
+                )
             other = other.expand(output_shape)
         y_val = get_trt_tensor(ctx, other, f"{name}_y")
     else:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
index 0357962be5..eb02657d08 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
@@ -51,7 +51,7 @@ def sum(
     ):
         input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
 
-    if dim is None:
+    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
         dim = tuple(range(len(input_val.shape)))
     layer = ctx.net.add_reduce(
         input_val,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
index 185a985e10..ce893f8d5b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -1,4 +1,4 @@
-from typing import Optional, cast
+from typing import List, Optional, Sequence, cast
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -49,3 +49,42 @@ def unsqueeze(
     )
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def broadcast_in_dim(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_t: TRTTensor,
+    shape: Sequence[int],
+    broadcast_dimensions: Sequence[int],
+) -> TRTTensor:
+    augmented_shape_list: List[Optional[int]] = list(shape)
+
+    # For each dimension being broadcasted, set the augmented shape to None
+    for broadcast_dim in broadcast_dimensions:
+        augmented_shape_list[broadcast_dim] = None
+
+    # TODO: Expand support to arbitrary broadcasts
+    assert all(
+        dim in (1, None) for dim in augmented_shape_list
+    ), "broadcast_in_dim currently only supports unsqueeze broadcasting"
+
+    # Unsqueeze the shape repeatedly to broadcast
+    output = input_t
+    for idx, x in enumerate(augmented_shape_list):
+        # Dimensions not in broadcast_dimensions are new singleton axes to unsqueeze
+        if x is not None:
+            output = unsqueeze(
+                ctx,
+                target,
+                source_ir,
+                name + f"_unsqueeze_for_broadcast_{idx}",
+                output,
+                idx,
+            )
+
+    assert tuple(output.shape) == tuple(shape), "broadcast_in_dim shapes don't match"
+
+    return output
diff --git a/py/torch_tensorrt/dynamo/conversion/op_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/conversion/op_evaluators.py
rename to py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index 08285762ce..5cd09a010c 100644
--- a/py/torch_tensorrt/dynamo/conversion/op_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -19,7 +19,7 @@ def getitem_validator(getitem_node: Node) -> bool:
 
 
 # TODO: Subsequent evaluators should be registered here with their own validators
-@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)  # type: ignore[misc]
 def generic_evaluator(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py
new file mode 100644
index 0000000000..a8c0dfa6fd
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py
@@ -0,0 +1,44 @@
+import logging
+from typing import Dict, Sequence, Tuple, Union
+
+import torch
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.fx.types import TRTTensor
+
+from .converter_registry import dynamo_tensorrt_converter
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+# TODO: expand the scope of this converter with aten.expand implementation
+def broadcast_checker(broadcast_node: torch.fx.Node) -> bool:
+    # The current implementation of broadcast_in_dim can only handle unsqueeze
+    return all(
+        broadcast_node.args[1][i] == 1
+        for i in range(len(broadcast_node.args[1]))
+        if i not in broadcast_node.args[2]
+    )
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.prims.broadcast_in_dim.default, capability_validator=broadcast_checker
+)  # type: ignore[misc]
+def aten_ops_broadcast_in_dim(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unsqueeze.broadcast_in_dim(
+        ctx,
+        target,
+        SourceIR.PRIM,
+        name,
+        args[0],
+        args[1],
+        args[2],
+    )
diff --git a/tests/py/dynamo/conversion/test_div_aten.py b/tests/py/dynamo/conversion/test_div_aten.py
index 2facb52289..882625de25 100644
--- a/tests/py/dynamo/conversion/test_div_aten.py
+++ b/tests/py/dynamo/conversion/test_div_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
@@ -82,6 +83,23 @@ def forward(self, lhs_val):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ("2d", (2, 1)),
+            ("3d", (2, 1, 2)),
+        ]
+    )
+    def test_prims_div_tensor(self, _, shape):
+        class div(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.prims.div.default(lhs_val, rhs_val)
+
+        inputs = [torch.randn(shape), torch.randn(shape)]
+        self.run_test(
+            div(),
+            inputs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_sum_aten.py b/tests/py/dynamo/conversion/test_sum_aten.py
index b279bed43e..c69e7707b7 100644
--- a/tests/py/dynamo/conversion/test_sum_aten.py
+++ b/tests/py/dynamo/conversion/test_sum_aten.py
@@ -108,5 +108,26 @@ def forward(self, x):
         )
 
 
+class TestPrimsSumConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1]),
+            ((2, 1, 4, 5), [1, 2]),
+            ((2, 3, 4, 5), [0, 1, 2, 3]),
+            ((6, 7, 5, 4, 5), [1, 3, 4]),
+        ]
+    )
+    def test_sum_dim_sequence(self, input_shape, dim):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.sum.default(x, dim)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sum(),
+            inputs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_unsqueeze_aten.py b/tests/py/dynamo/conversion/test_unsqueeze_aten.py
index e448c4f925..cc920283b3 100644
--- a/tests/py/dynamo/conversion/test_unsqueeze_aten.py
+++ b/tests/py/dynamo/conversion/test_unsqueeze_aten.py
@@ -3,6 +3,8 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
+
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
@@ -55,5 +57,52 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(Unsqueeze(dim), input_specs)
 
 
+class TestBroadcastInDim(DispatchTestCase):
+    def test_broadcast_in_dim_supported(
+        self,
+    ):
+        class Unsqueeze(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.broadcast_in_dim.default(
+                    x, [4, 5, 6, 1, 1], [0, 1, 2]
+                )
+
+        inputs = [torch.randn(4, 5, 6)]
+        self.run_test(
+            Unsqueeze(),
+            inputs,
+        )
+
+    def test_broadcast_in_dim_supported_singleton(
+        self,
+    ):
+        class Unsqueeze(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.broadcast_in_dim.default(x, [1, 1, 1], [0, 1])
+
+        inputs = [torch.randn(1, 1)]
+        self.run_test(
+            Unsqueeze(),
+            inputs,
+        )
+
+    # TODO: Remove this test when support is updated
+    def test_broadcast_in_dim_unsupported(
+        self,
+    ):
+        class Unsqueeze(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.broadcast_in_dim.default(
+                    x, [4, 5, 6, 7, 1], [0, 1, 2]
+                )
+
+        inputs = [torch.randn(4, 5, 6)]
+        with self.assertRaises(UnsupportedOperatorException):
+            self.run_test(
+                Unsqueeze(),
+                inputs,
+            )
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_where_aten.py b/tests/py/dynamo/conversion/test_where_aten.py
index 2a4bf108da..3594fc6d83 100644
--- a/tests/py/dynamo/conversion/test_where_aten.py
+++ b/tests/py/dynamo/conversion/test_where_aten.py
@@ -42,6 +42,23 @@ def forward(self, condition, x, y):
             (condition, inputX, inputOther),
         )
 
+    def test_const_input(self):
+        class Where(nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.inputY = torch.randn((5, 6, 7))
+                self.inputX = torch.randn((5, 6, 7))
+
+            def forward(self, condition):
+                return torch.ops.aten.where.self(condition, self.inputX, self.inputY)
+
+        input1 = torch.randn((5, 6, 7))
+        condition = input1 < 0
+        self.run_test(
+            Where(),
+            (condition,),
+        )
+
 
 if __name__ == "__main__":
     run_tests()
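
The reduction the new prims.broadcast_in_dim converter performs can be sanity-checked outside TensorRT with plain PyTorch. This is a minimal sketch, not part of the patch, using only the torch.ops.prims API already exercised by the tests above; the tensor and shape arguments are illustrative.

```python
# Sketch: in the unsqueeze-only case accepted by broadcast_checker (every
# output dim not listed in broadcast_dimensions has size 1),
# prims.broadcast_in_dim reduces to a sequence of unsqueezes, which is what
# impl.unsqueeze.broadcast_in_dim emits as TensorRT shuffle layers.
import torch

x = torch.randn(4, 5, 6)
shape = [4, 5, 6, 1, 1]
broadcast_dimensions = [0, 1, 2]

# Reference result from the prim itself
expected = torch.ops.prims.broadcast_in_dim.default(x, shape, broadcast_dimensions)

# Same result built purely from unsqueezes: insert a singleton axis at every
# output position that does not come from the input
out = x
for idx in range(len(shape)):
    if idx not in broadcast_dimensions:
        out = out.unsqueeze(idx)

assert out.shape == expected.shape == torch.Size(shape)
assert torch.equal(out, expected)
```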