feat: cherry-pick of Implement symbolic shape propagation, sym_size converter #2751

Merged: 7 commits, Apr 26, 2024
27 changes: 20 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -163,10 +163,10 @@ def compile(
)
gm = exported_program.module()
logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
torch_inputs = get_torch_inputs(inputs, device)
gm = apply_lowering_passes(gm, torch_inputs)

logger.debug("Lowered Input graph: " + str(gm.graph))

compilation_options = {
@@ -264,6 +264,24 @@ def compile_module(
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)

def contains_metadata(gm: torch.fx.GraphModule) -> bool:
for node in gm.graph.nodes:
if node.op != "output" and (not node.meta) and "val" not in node.meta:
logger.warning(
f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
)
return False
return True

# Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
if not contains_metadata(gm):
from torch._inductor.compile_fx import fake_tensor_prop

torch_inputs = get_torch_inputs(sample_inputs, settings.device)
with torch.no_grad():
# This fails if the module has data-dependent shape operators.
fake_tensor_prop(gm, torch_inputs)

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False

@@ -322,12 +340,7 @@ def compile_module(
)

# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.get_submod_inputs(
partitioned_module,
submodule,
sample_inputs,
to_torch_device(settings.device),
)
submodule_inputs = partitioning.construct_submodule_inputs(submodule)

logger.debug(
"Submodule name: %s\n Input shapes: %s\n %s",
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -74,7 +74,6 @@ def _pretraced_backend(
fake_mode, "allow_non_fake_inputs", True
), fake_mode:
repair_input_aliasing(gm)

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
gm,
16 changes: 13 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -38,12 +38,22 @@ def infer_module_output_dtypes(
# such as aten.sum - such outputs can be truncated
output_dtypes = []
for output in module_outputs:
if truncate_long_and_double and output.dtype == dtype.float64:
output_ = output
# We don't need to check if output is nested here because the input module will be flattened
if not isinstance(output, torch.Tensor):
if isinstance(output, str):
raise ValueError(
f"Receieved an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
)
else:
output_ = torch.tensor(output)

if truncate_long_and_double and output_.dtype == dtype.float64:
output_dtypes.append(dtype.float32)
elif truncate_long_and_double and output.dtype == dtype.int64:
elif truncate_long_and_double and output_.dtype == dtype.int64:
output_dtypes.append(dtype.int32)
else:
output_dtypes.append(dtype._from(output.dtype))
output_dtypes.append(dtype._from(output_.dtype))

return output_dtypes

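A hedged illustration of the branch added above: non-tensor outputs (for example, a Python int produced by aten.sym_size) are wrapped in a tensor so their dtype can be read uniformly, while string outputs are rejected. The helper below is hypothetical, not the repo's API.

import torch

def normalize_output(output):
    if isinstance(output, torch.Tensor):
        return output
    if isinstance(output, str):
        raise ValueError(f"Received an unsupported output type: {type(output)}")
    # Python scalars become 0-d tensors, e.g. int -> torch.int64, float -> torch.float32
    return torch.tensor(output)
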
16 changes: 16 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -394,6 +394,22 @@ def aten_ops_sigmoid(
)


@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
def aten_ops_symsize_int(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])


def index_dtype_validator(node: Node) -> bool:
index = node.args[1]
for ind in index:
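
For context, the converter above routes aten.sym_size.int through impl.shape.shape, so a module like the following sketch (exercised more thoroughly by the new test file later in this diff) can now be converted when compiled with dynamic inputs:

import torch

class BatchSize(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> int:
        # Dispatched to aten_ops_symsize_int during conversion.
        return torch.ops.aten.sym_size.int(x, 0)
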
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/grid.py
@@ -1,13 +1,12 @@
from typing import Optional

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

import tensorrt as trt

# nearest, linear, cubic
GridSamplerInterpolationMode = {
0: trt.InterpolationMode.NEAREST,
20 changes: 18 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -3,7 +3,7 @@
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
shape: Sequence[int],
) -> TRTTensor:
layer = ctx.net.add_shuffle(input)
layer.reshape_dims = tuple(shape)
if all(isinstance(s, int) for s in shape):
layer.reshape_dims = tuple(shape)
else:
# Convert all the dimensions to trt Tensors.
trt_shape = []

for i, s in enumerate(shape):
if isinstance(s, TRTTensor):
trt_shape.append(s)
else:
a = get_trt_tensor(ctx, s, f"{name}_{i}")
trt_shape.append(a)
shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
shape_layer.axis = 0
shape_layer.name = f"{name}_output_shape"
layer.set_input(1, shape_layer.get_output(0))

set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)

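The change above follows the usual TensorRT pattern for runtime shapes: if any target dimension is an ITensor, all dimensions are concatenated into a single 1-D shape tensor and supplied to the shuffle layer's second input. A standalone sketch of that pattern is below; the helper name and constant construction are illustrative, not the repo's converter utilities.

import numpy as np
import tensorrt as trt

def reshape_dynamic(network: trt.INetworkDefinition, input_tensor: trt.ITensor, dims, name: str) -> trt.ITensor:
    layer = network.add_shuffle(input_tensor)
    if all(isinstance(d, int) for d in dims):
        layer.reshape_dims = tuple(dims)  # fully static target shape
    else:
        shape_parts = []
        for d in dims:
            if isinstance(d, trt.ITensor):
                shape_parts.append(d)  # runtime dimension
            else:
                const = network.add_constant((1,), np.array([d], dtype=np.int32))
                shape_parts.append(const.get_output(0))
        concat = network.add_concatenation(shape_parts)
        concat.axis = 0
        layer.set_input(1, concat.get_output(0))  # shape tensor resolved at runtime
    layer.name = name
    return layer.get_output(0)
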
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -69,7 +69,6 @@ def expand(
) -> TRTTensor:
shape_rank = len(shape)
initial_tensor_rank = len(input_t.shape)

# If the rank of the input tensor is less than the shape's rank, pad with ones
if initial_tensor_rank < shape_rank:
input_t = prepend_ones(
@@ -99,6 +98,7 @@ def expand(
stride = tuple(
[int(i == o) for i, o in zip(input_tensor_shape, shape)]
) # stride == 1 if dimensions match, 0 otherwise

layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
23 changes: 22 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -1,4 +1,4 @@
from typing import List
from typing import Any, List

import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
]

return placeholders


def get_metadata(
gm: torch.fx.GraphModule, target_op: Any
) -> List[torch._ops.OpOverload]:
"""
Return a list of the metadata (node.meta) of all the target_op nodes present in the graph.
"""
return [node.meta for node in gm.graph.nodes if node.target == target_op]


def set_metadata(
gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload]
) -> None:
"""
Apply the given metadata list to all the target_op nodes present in the graph, in graph order.
"""
target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
assert len(target_nodes) == len(metadata)
for idx, node in enumerate(target_nodes):
node.meta = metadata[idx]
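
A hedged usage sketch of the two helpers above, mirroring how the view_to_reshape pass in the next file preserves node metadata across a subgraph rewrite (the rewrite must map the target nodes one-to-one, since set_metadata asserts equal lengths):

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_metadata, set_metadata

def rewrite_preserving_metadata(gm, orig_pattern, replacement_pattern, orig_op, replacement_op):
    """Snapshot orig_op metadata, rewrite the graph, then reattach it to replacement_op."""
    metadata = get_metadata(gm, orig_op)
    torch.fx.subgraph_rewriter.replace_pattern(gm, orig_pattern, replacement_pattern)
    set_metadata(gm, replacement_op, metadata)
    return gm
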
36 changes: 18 additions & 18 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -1,9 +1,11 @@
import logging
from typing import Callable, List, Sequence, Tuple
from typing import List, Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
get_metadata,
set_metadata,
)

logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
orig, replacement = view_replacement()

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

return gm


def view_replacement() -> Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
]:
"""Constructs the original and replacement functions for view"""
orig_op = torch.ops.aten.view.default
replacement_op = torch.ops.aten.reshape.default

# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)
return orig_op(input, shape)

# Replacement graph
def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.reshape.default(input, shape)
return replacement_op(input, shape)

return orig, replacement
# Store metadata of the orig_op
metadata = get_metadata(gm, orig_op)

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

# Copy the orig_op's metadata to the replacement op
set_metadata(gm, replacement_op, metadata)

return gm
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -1,3 +1,7 @@
from ._adjacency_partitioner import partition as fast_partition
from ._global_partitioner import partition as global_partition
from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
from .common import (
construct_submodule_inputs,
get_graph_converter_support,
run_shape_analysis,
)
164 changes: 89 additions & 75 deletions py/torch_tensorrt/dynamo/partitioning/common.py
@@ -4,11 +4,99 @@
import torch
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._defaults import DEBUG
from torch_tensorrt.dynamo.utils import get_torch_inputs, input_is_dynamic

logger = logging.getLogger(__name__)


def contains_sym_int(tensor: torch.Tensor) -> bool:
"""
Returns true if the given tensor has symbolic shape.
"""
for dim in tensor:
if isinstance(dim, torch.SymInt):
return True
return False


def construct_dynamic_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input:
"""
Constructs a torch_tensorrt.Input based on a symbolic input
Args:
input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values)
Returns:
A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
"""
min_shape = []
opt_shape = []
max_shape = []
for dim in input_shape:
if isinstance(dim, torch.SymInt):
node = dim.node
expr = node.expr
shape_env = node.shape_env
var_range = shape_env.var_to_range.get(expr, None)
var_val = shape_env.var_to_val.get(expr, None)
assert var_range, var_val
# Torchdynamo 0/1 specialization outlier
if var_range.lower == 2:
min_shape.append(1)
else:
min_shape.append(int(var_range.lower))
opt_shape.append(int(var_val))
max_shape.append(int(var_range.upper))
else:
min_shape.append(dim)
opt_shape.append(dim)
max_shape.append(dim)

return Input(
min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, dtype=input_dtype
)


def get_input(input_shape: torch.Size, input_dtype: torch.dtype) -> Input:
"""
Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs
"""
if contains_sym_int(input_shape):
return construct_dynamic_input(input_shape, input_dtype)
else:
return Input(shape=input_shape, dtype=input_dtype)


def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
"""
Construct torch_tensorrt Inputs based on the module inputs.
The module inputs will have meta data which has the shape and dtype info
Args:
module: Input FX GraphModule
Returns:
Sequence of torch_tensorrt.Input's representing inputs to given module
"""
torchtrt_inputs = []
module_inputs = [node for node in module.graph.nodes if node.op == "placeholder"]
for input in module_inputs:
if input.meta:
if "val" in input.meta:
input_meta = input.meta["val"]
input_shape = input_meta.size()
torchtrt_inputs.append(get_input(input_shape, input_meta.dtype))
elif "tensor_meta" in input.meta:
input_meta = input.meta["tensor_meta"]
input_shape = input_meta.shape
torchtrt_inputs.append(get_input(input_shape, input_meta.dtype))
else:
raise AssertionError(
f"Input {input.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly"
)
else:
raise AssertionError(
f"Input {input.name} does not contain metadata. Please ensure you have exported the graph correctly"
)

return torchtrt_inputs


def run_shape_analysis(
parent_module: torch.fx.GraphModule, inputs: Sequence[Input]
) -> Tuple[Dict[Any, Sequence[Any]], Dict[Any, Sequence[Any]]]:
@@ -46,80 +134,6 @@ def get_submodule_io(
return submod_inputs_shape_map, submod_outputs_shape_map


def get_submod_inputs(
mod: torch.fx.GraphModule,
submod: torch.fx.GraphModule,
inputs: Sequence[Input],
device: torch.device,
) -> Optional[Sequence[torch.Tensor]]:
"""Helper function to get inputs to a Torch submodule
Args:
mod: Parent FX GraphModule
submod: Child FX GraphModule
inputs: Sample inputs to parent module
Returns:
Sequence of Tensors representing inputs to child module
"""
acc_inputs: Any = None

def get_input(self: Any, inputs: Sequence[torch.Tensor]) -> None:
nonlocal acc_inputs
acc_inputs = inputs
return

# Register a hook to capture submodule input
handle = submod.register_forward_pre_hook(get_input)
# Iterate over min, opt, max shapes for dynamic inputs
inputs_map = {}

if input_is_dynamic(inputs):
for mode in ["min_shape", "opt_shape", "max_shape"]:
torch_inputs = get_torch_inputs(inputs, device, mode)
mod(*torch_inputs)
inputs_map[mode] = acc_inputs
handle.remove()
else:
torch_inputs = get_torch_inputs(inputs, device)
mod(*torch_inputs)
handle.remove()
assert isinstance(acc_inputs, tuple)
return [
Input(shape=acc_input.shape, dtype=acc_input.dtype)
for acc_input in acc_inputs
]

num_submodule_inputs = (
len(inputs_map["min_shape"]) if inputs_map["min_shape"] else 0
)
submodule_inputs = []
for idx in range(num_submodule_inputs):
if not isinstance(inputs_map["min_shape"][idx], torch.Tensor):
input_val = torch.tensor(inputs_map["opt_shape"][idx], dtype=torch.int32)
logger.warning(
"Detected a zero-dimensional input. This might be a shape tensor input which is not currently supported. This might result in undefined behavior"
)
submodule_inputs.append(
Input(
shape=[1],
torch_tensor=input_val,
dtype=input_val.dtype,
)
)
else:
submodule_inputs.append(
Input(
min_shape=inputs_map["min_shape"][idx].shape,
opt_shape=inputs_map["opt_shape"][idx].shape,
max_shape=inputs_map["max_shape"][idx].shape,
torch_tensor=inputs_map["opt_shape"][idx],
dtype=inputs_map["opt_shape"][idx].dtype,
)
)

return submodule_inputs


def get_graph_converter_support(
graph_module: torch.fx.GraphModule,
verbose: bool = DEBUG,
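
For reference, a minimal sketch of what construct_dynamic_input above reads off a symbolic dimension: each torch.SymInt carries a shape-environment entry with a value range and a concrete hint, which map to the min/max and opt shapes respectively (the helper name is illustrative):

import torch

def symint_bounds(dim):
    """Return (min, opt, max) for a single dimension, static or symbolic."""
    if isinstance(dim, torch.SymInt):
        node = dim.node
        var_range = node.shape_env.var_to_range[node.expr]  # ValueRanges(lower, upper)
        var_val = node.shape_env.var_to_val[node.expr]  # example value seen at export
        # construct_dynamic_input additionally widens a lower bound of 2 to 1
        # to account for Torchdynamo's 0/1 specialization.
        return int(var_range.lower), int(var_val), int(var_range.upper)
    return dim, dim, dim
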
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/utils.py
@@ -88,7 +88,8 @@ def get_torch_inputs(
if isinstance(input, Input)
]
return [
input.torch_tensor.to(device) for input in inputs if isinstance(input, Input)
input.torch_tensor.to(device) if isinstance(input, Input) else input
for input in inputs
]


44 changes: 44 additions & 0 deletions tests/py/dynamo/conversion/test_sym_size.py
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestSymSizeConverter(DispatchTestCase):
@parameterized.expand(
[
((3, 2, 4),),
]
)
def test_sym_size_batch(self, input_shape):
class BatchDim(nn.Module):
def forward(self, x):
return torch.ops.aten.sym_size.int(x, 0)

inputs = [torch.randn(*input_shape)]
self.run_test(
BatchDim(),
inputs,
)

@parameterized.expand(
[
((3, 2, 4),),
]
)
def test_sym_size_non_batch(self, input_shape):
class NonBatchDim(nn.Module):
def forward(self, x):
return torch.ops.aten.sym_size.int(x, 1)

inputs = [torch.randn(*input_shape)]
self.run_test(
NonBatchDim(),
inputs,
)


if __name__ == "__main__":
run_tests()
55 changes: 52 additions & 3 deletions tests/py/dynamo/models/test_dyn_models.py
@@ -3,9 +3,8 @@
import pytest
import timm
import torch
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

assertions = unittest.TestCase()

@@ -65,7 +64,7 @@ def forward(self, x):
@pytest.mark.unit
def test_base_dynamic_fallback(ir):
"""
Tests the model (which is fully convertible) with dynamic shapes
Tests the model with dynamic shapes where torch.abs op is forced to run in PyTorch
"""

class MyModule(torch.nn.Module):
@@ -114,3 +113,53 @@ def forward(self, x):

with torch.no_grad():
torch.cuda.empty_cache()


@pytest.mark.unit
def test_view(ir):
"""
Tests the model (which is fully convertible) with dynamic shapes
"""

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
input_shape = x.size()
y = x.view(input_shape[0], -1)
return y

model = MyModule().eval().cuda()
input = torch.randn((6, 3, 4)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
min_shape=(1, 3, 4),
opt_shape=(4, 3, 4),
max_shape=(8, 3, 4),
dtype=torch.float32,
name="x",
)
],
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
"optimization_level": 1,
"min_block_size": 1,
}

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
torch._dynamo.reset()

with torch.no_grad():
torch.cuda.empty_cache()