fix/feat: Add and repair multiple converters for SD + other models #2353

Merged
1 commit merged on Oct 6, 2023
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -2,5 +2,6 @@
from ._TRTInterpreter import * # noqa: F403
from .aten_ops_converters import * # noqa: F403
from .conversion import * # noqa: F403
from .op_evaluators import * # noqa: F403
from .ops_evaluators import * # noqa: F403
from .prims_ops_converters import * # noqa: F403
from .truncate_long_and_double import repair_long_or_double_inputs
45 changes: 39 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -27,6 +27,24 @@ def args_bounds_check(
return args[i] if len(args) > i else replacement


def get_ir(target: Target) -> SourceIR:
target_module = getattr(target, "__module__", "None")
if any(
target_module.startswith(prefix)
for prefix in ("torch.ops.prims", "torch._ops.prims")
):
return SourceIR.ATEN
elif any(
target_module.startswith(prefix)
for prefix in ("torch.ops.prims", "torch._ops.prims")
):
return SourceIR.PRIM
elif target_module.startswith("torch.nn"):
return SourceIR.NN

return SourceIR.UNKNOWN
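
A minimal standalone sketch (not part of the diff, assuming get_ir and SourceIR from this module are in scope): get_ir only inspects the target's __module__ prefix, so stub objects are enough to exercise the mapping.

from types import SimpleNamespace

fake_aten_op = SimpleNamespace(__module__="torch._ops.aten")
fake_prims_op = SimpleNamespace(__module__="torch._ops.prims")

assert get_ir(fake_aten_op) == SourceIR.ATEN    # aten namespaces map to ATEN
assert get_ir(fake_prims_op) == SourceIR.PRIM   # prims namespaces map to PRIM
assert get_ir("not an op") == SourceIR.UNKNOWN  # anything else falls through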


@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc]
def aten_ops_batch_norm(
ctx: ConversionContext,
@@ -674,23 +692,37 @@ def aten_ops_amax(

@dynamo_tensorrt_converter(torch.ops.aten.sum.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.prims.sum.default) # type: ignore[misc]
def aten_ops_sum(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.reduce.sum(
sum_ = impl.reduce.sum(
ctx,
target,
SourceIR.ATEN,
get_ir(target),
name,
args[0],
args_bounds_check(args, 1, replacement=None),
args_bounds_check(args, 2, replacement=False),
)

if kwargs.get("output_dtype", None) is not None:
return impl.cast.to_copy(
ctx,
target,
SourceIR.ATEN,
name,
sum_,
kwargs["output_dtype"],
force_layer=False,
)
else:
return sum_
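
A rough eager-mode analogue of the new sum converter path (an editor's sketch, not the converter itself): reduce first, then cast only when a target dtype was requested; output_dtype below stands in for the node kwarg of the same name.

import torch

x = torch.randn(2, 3)
out = torch.sum(x, dim=[1], keepdim=False)   # counterpart of impl.reduce.sum
output_dtype = torch.float16                 # stand-in for kwargs["output_dtype"]
if output_dtype is not None:
    out = out.to(output_dtype)               # counterpart of impl.cast.to_copy
assert out.shape == (2,) and out.dtype == torch.float16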


@dynamo_tensorrt_converter(torch.ops.aten.exp.default) # type: ignore[misc]
def aten_ops_exp(
@@ -1189,6 +1221,7 @@ def aten_ops_sub(
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.prims.div.default) # type: ignore[misc]
def aten_ops_div(
ctx: ConversionContext,
target: Target,
@@ -1202,7 +1235,7 @@ def aten_ops_div(
return impl.elementwise.div(
ctx,
target,
SourceIR.ATEN,
get_ir(target),
name,
args[0],
args[1],
@@ -1211,7 +1244,7 @@
return impl.elementwise.floor_divide(
ctx,
target,
SourceIR.ATEN,
get_ir(target),
name,
args[0],
args[1],
@@ -1220,7 +1253,7 @@
return impl.elementwise.trunc_div(
ctx,
target,
SourceIR.ATEN,
get_ir(target),
name,
args[0],
args[1],
@@ -1553,5 +1586,5 @@ def tensorrt_scaled_dot_product_attention(
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.attention.scaled_dot_product_attention(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
ctx, target, SourceIR.TORCHTRT_LOWERED, name, args[0], args[1], args[2]
)
47 changes: 23 additions & 24 deletions py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
@@ -1,5 +1,6 @@
from typing import Optional

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
@@ -23,16 +24,6 @@ def where(
other: TRTTensor,
condition: TRTTensor,
) -> TRTTensor:
input_dim = len(tuple(input.shape))
other_dim = len(tuple(other.shape))
condition_dim = len(tuple(condition.shape))

if type(input) != TRTTensor:
assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!"

if type(other) != TRTTensor:
assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!"

if not (broadcastable(input, other)):
assert "The two torch tensors should be broadcastable"

@@ -49,33 +40,37 @@
x_shape = list(input.shape)
y_shape = list(other.shape)
condition_shape = list(condition.shape)

output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))

# expand shape
if type(condition) != TRTTensor:
assert condition.dtype == torch.bool, "condition dtype is not bool"
if not isinstance(condition, TRTTensor):
assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
if condition_shape != output_shape:
condition.expand(output_shape)
condition = condition.to(torch.int32)
condition_const = get_trt_tensor(ctx, condition, f"{name}_condition")
condition_layer = ctx.net.add_identity(condition_const)
condition_layer.set_output_type(0, trt.bool)
set_layer_name(condition_layer, target, f"{name}_condition")
condition_val = condition_layer.get_output(0)
condition = (
condition.expand(output_shape)
if isinstance(condition, torch.Tensor)
else np.broadcast_to(condition, output_shape)
)
condition_val = get_trt_tensor(ctx, condition, f"{name}_condition")
else:
assert condition.dtype == trt.bool, "mask dtype is not bool!"
if len(condition_shape) != condition_dim:
if condition_shape != output_shape:
condition_val = expand(
ctx, target, source_ir, f"{name}_expand", condition, output_shape
)
else:
condition_val = condition

if type(input) != TRTTensor:
if not isinstance(input, TRTTensor):
if x_shape != output_shape:
# special case where 1 element in input
if len(input.shape) == 0:
input = input.unsqueeze(0)
input = (
input.unsqueeze(0)
if isinstance(input, torch.Tensor)
else np.expand_dims(input, axis=0)
)
input = input.expand(output_shape)
x_val = get_trt_tensor(ctx, input, f"{name}_x")
else:
@@ -85,11 +80,15 @@
ctx, target, source_ir, f"{name}_x_expand", input, output_shape
)

if type(other) != TRTTensor:
if not isinstance(other, TRTTensor):
if y_shape != output_shape:
# special case where 1 element in other
if len(other.shape) == 0:
other = other.unsqueeze(0)
other = (
other.unsqueeze(0)
if isinstance(other, torch.Tensor)
else np.expand_dims(other, axis=0)
)
other = other.expand(output_shape)
y_val = get_trt_tensor(ctx, other, f"{name}_y")
else:
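
The numpy branches added above mirror the existing torch handling of frozen constants in where(); a small standalone sketch (not part of the diff) of the two broadcast paths and the 0-d unsqueeze case:

import numpy as np
import torch

out_shape = (3, 2)
cond_np = np.array([True, False])    # numpy bool constant
cond_t = torch.tensor([True, False]) # torch bool constant

assert np.broadcast_to(cond_np, out_shape).shape == out_shape  # numpy path
assert tuple(cond_t.expand(out_shape).shape) == out_shape      # torch path
assert np.expand_dims(np.float32(1.0), axis=0).shape == (1,)   # 0-d scalar gains a leading axis
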
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/reduce.py
@@ -51,7 +51,7 @@ def sum(
):
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

if dim is None:
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
Collaborator: Should we consider the case where dim is [] or ()? In that case, the dim passed into axes will be 0.

Collaborator (Author): I thought that dim being empty implies the same thing as dim being None, which is to reduce over all axes.

Collaborator (Author): This is the behavior in torch, at least with torch.sum.

dim = tuple(range(len(input_val.shape)))
layer = ctx.net.add_reduce(
input_val,
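
A quick check of the normalization discussed in the thread above (editor's sketch): None and an empty dim sequence both collapse to reducing every axis.

input_shape = (2, 3, 4)
for dim in (None, [], ()):
    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
        dim = tuple(range(len(input_shape)))
    assert dim == (0, 1, 2)
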
41 changes: 40 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -1,4 +1,4 @@
from typing import Optional, cast
from typing import List, Optional, Sequence, cast

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -49,3 +49,42 @@ def unsqueeze(
)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def broadcast_in_dim(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_t: TRTTensor,
shape: Sequence[int],
broadcast_dimensions: Sequence[int],
) -> TRTTensor:
augmented_shape_list: List[Optional[int]] = list(shape)

# For each dimension being broadcasted, set the augmented shape to None
for broadcast_dim in broadcast_dimensions:
augmented_shape_list[broadcast_dim] = None

# TODO: Expand support to arbitrary broadcasts
assert all(
dim in (1, None) for dim in augmented_shape_list
), "broadcast_in_dim currently only supports unsqueeze broadcasting"

# Unsqueeze the shape repeatedly to broadcast
output = input_t
for idx, x in enumerate(augmented_shape_list):
# If the value is not None, that dimension is to be broadcasted
if x is not None:
output = unsqueeze(
ctx,
target,
source_ir,
name + f"_unsqueeze_for_broadcast_{idx}",
output,
idx,
)

assert tuple(output.shape) == tuple(shape), "broadcast_in_dim shapes don't match"

return output
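
To make the loop above concrete, here is a standalone trace (not part of the diff) of the unsqueeze-only case it supports: input shape (3, 4), target shape (3, 1, 4, 1), broadcast_dimensions (0, 2). Existing dims land at output positions 0 and 2, while positions 1 and 3 become new size-1 axes.

import torch

x = torch.randn(3, 4)
augmented = [None, 1, None, 1]  # None marks dims carried over from the input
out = x
for idx, size in enumerate(augmented):
    if size is not None:        # every remaining entry must be a new size-1 axis
        out = out.unsqueeze(idx)
assert tuple(out.shape) == (3, 1, 4, 1)
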
py/torch_tensorrt/dynamo/conversion/ops_evaluators.py (renamed from op_evaluators.py)
@@ -19,7 +19,7 @@ def getitem_validator(getitem_node: Node) -> bool:


# TODO: Subsequent evaluators should be registered here with their own validators
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator) # type: ignore[misc]
def generic_evaluator(
ctx: ConversionContext,
target: Target,
44 changes: 44 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py
@@ -0,0 +1,44 @@
import logging
from typing import Dict, Sequence, Tuple, Union

import torch
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor

from .converter_registry import dynamo_tensorrt_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)


# TODO: expand the scope of this converter with aten.expand implementation
def broadcast_checker(broadcast_node: torch.fx.Node) -> bool:
# The current implementation of broadcast_in_dim can only handle unsqueeze
return all(
broadcast_node.args[1][i] == 1
for i in range(len(broadcast_node.args[1]))
if i not in broadcast_node.args[2]
)


@dynamo_tensorrt_converter(
torch.ops.prims.broadcast_in_dim.default, capability_validator=broadcast_checker
) # type: ignore[misc]
def aten_ops_broadcast_in_dim(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.unsqueeze.broadcast_in_dim(
ctx,
target,
SourceIR.PRIM,
name,
args[0],
args[1],
args[2],
)
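
For reference, a quick sketch (with hypothetical arguments, not from the PR) of what broadcast_checker accepts: any target shape that requires a real broadcast rather than a pure unsqueeze is rejected, so the op is left to run in PyTorch instead of being converted.

def would_accept(shape, broadcast_dims):
    # Mirrors broadcast_checker: every newly inserted axis must have size 1.
    return all(shape[i] == 1 for i in range(len(shape)) if i not in broadcast_dims)

assert would_accept([3, 1, 4, 1], (0, 2))   # pure unsqueeze -> converter runs
assert not would_accept([3, 2, 4], (0, 2))  # true broadcast -> rejected by the validator
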
18 changes: 18 additions & 0 deletions tests/py/dynamo/conversion/test_div_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from torch_tensorrt import Input

from .harness import DispatchTestCase
@@ -82,6 +83,23 @@ def forward(self, lhs_val):
inputs,
)

@parameterized.expand(
[
("2d", (2, 1)),
("3d", (2, 1, 2)),
]
)
def test_prims_div_tensor(self, _, shape):
class div(nn.Module):
def forward(self, lhs_val, rhs_val):
return torch.ops.prims.div.default(lhs_val, rhs_val)

inputs = [torch.randn(shape), torch.randn(shape)]
self.run_test(
div(),
inputs,
)


if __name__ == "__main__":
run_tests()
21 changes: 21 additions & 0 deletions tests/py/dynamo/conversion/test_sum_aten.py
@@ -108,5 +108,26 @@ def forward(self, x):
)


class TestPrimsSumConverter(DispatchTestCase):
@parameterized.expand(
[
((3, 2, 4), [1]),
((2, 1, 4, 5), [1, 2]),
((2, 3, 4, 5), [0, 1, 2, 3]),
((6, 7, 5, 4, 5), [1, 3, 4]),
]
)
def test_sum_dim_sequence(self, input_shape, dim):
class Sum(nn.Module):
def forward(self, x):
return torch.ops.prims.sum.default(x, dim)

inputs = [torch.randn(*input_shape)]
self.run_test(
Sum(),
inputs,
)


if __name__ == "__main__":
run_tests()