Commit f2efdea

peri044 authored and laikhtewari committed
feat: cherry-pick of Implement symbolic shape propagation, sym_size converter (#2751)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 00c66e6 commit f2efdea

File tree

14 files changed: +301 -115 lines changed


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 20 additions & 7 deletions
@@ -163,10 +163,10 @@ def compile(
     )
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
-
     # Apply lowering on the graph module
     torch_inputs = get_torch_inputs(inputs, device)
     gm = apply_lowering_passes(gm, torch_inputs)
+
     logger.debug("Lowered Input graph: " + str(gm.graph))

     compilation_options = {
@@ -264,6 +264,24 @@ def compile_module(
         f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
     )

+    def contains_metadata(gm: torch.fx.GraphModule) -> bool:
+        for node in gm.graph.nodes:
+            if node.op != "output" and (not node.meta) and "val" not in node.meta:
+                logger.warning(
+                    f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
+                )
+                return False
+        return True
+
+    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    if not contains_metadata(gm):
+        from torch._inductor.compile_fx import fake_tensor_prop
+
+        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
+        with torch.no_grad():
+            # This fails if the module has data-dependent shape operators.
+            fake_tensor_prop(gm, torch_inputs)
+
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False

@@ -322,12 +340,7 @@ def compile_module(
         )

         # Get the submodule inputs for min, opt, max shapes of the graph inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module,
-            submodule,
-            sample_inputs,
-            to_torch_device(settings.device),
-        )
+        submodule_inputs = partitioning.construct_submodule_inputs(submodule)

         logger.debug(
             "Submodule name: %s\n Input shapes: %s\n %s",

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 0 additions & 1 deletion
@@ -74,7 +74,6 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
-
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 13 additions & 3 deletions
@@ -38,12 +38,22 @@ def infer_module_output_dtypes(
     # such as aten.sum - such outputs can be truncated
     output_dtypes = []
     for output in module_outputs:
-        if truncate_long_and_double and output.dtype == dtype.float64:
+        output_ = output
+        # We don't need to check if output is nested here because the input module will be flattened
+        if not isinstance(output, torch.Tensor):
+            if isinstance(output, str):
+                raise ValueError(
+                    f"Receieved an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
+                )
+            else:
+                output_ = torch.tensor(output)
+
+        if truncate_long_and_double and output_.dtype == dtype.float64:
             output_dtypes.append(dtype.float32)
-        elif truncate_long_and_double and output.dtype == dtype.int64:
+        elif truncate_long_and_double and output_.dtype == dtype.int64:
             output_dtypes.append(dtype.int32)
         else:
-            output_dtypes.append(dtype._from(output.dtype))
+            output_dtypes.append(dtype._from(output_.dtype))

     return output_dtypes
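
Rough illustration of the new non-tensor handling (a standalone sketch, not code from this diff): scalar outputs such as a Python int or float, e.g. the value produced by a sym_size node, are wrapped in a tensor so a dtype can still be inferred, while string outputs are rejected.

import torch

# Assumed example outputs: a tensor, a Python int, and a Python float.
module_outputs = [torch.randn(2, 3), 7, 3.14]

output_dtypes = []
for output in module_outputs:
    # Wrap non-tensor outputs so .dtype is available, mirroring the change above.
    output_ = output if isinstance(output, torch.Tensor) else torch.tensor(output)
    output_dtypes.append(output_.dtype)

print(output_dtypes)  # [torch.float32, torch.int64, torch.float32]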

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 16 additions & 0 deletions
@@ -394,6 +394,22 @@ def aten_ops_sigmoid(
     )


+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
+def aten_ops_symsize_int(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 def index_dtype_validator(node: Node) -> bool:
     index = node.args[1]
     for ind in index:
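
Where aten.sym_size.int comes from: a minimal sketch (the toy module, dimension name, and shapes are assumptions, not part of the commit). Exporting with a dynamic batch dimension turns x.shape[0] into a SymInt, which appears in the graph as a torch.ops.aten.sym_size.int node that this converter now lowers via impl.shape.shape.

import torch

class Flatten(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = x.shape[0]  # becomes torch.ops.aten.sym_size.int(x, 0) under dynamic shapes
        return x.reshape(b, -1)

batch = torch.export.Dim("batch", min=2, max=32)
ep = torch.export.export(
    Flatten(), (torch.randn(4, 3, 8),), dynamic_shapes={"x": {0: batch}}
)
# The exported graph contains a sym_size.int node feeding the reshape/view.
print(ep.graph)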

py/torch_tensorrt/dynamo/conversion/impl/grid.py

Lines changed: 1 addition & 2 deletions
@@ -1,13 +1,12 @@
 from typing import Optional

+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor

-import tensorrt as trt
-
 # nearest, linear, cubic
 GridSamplerInterpolationMode = {
     0: trt.InterpolationMode.NEAREST,

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 18 additions & 2 deletions
@@ -3,7 +3,7 @@
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
     shape: Sequence[int],
 ) -> TRTTensor:
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(shape)
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all the dimensions to trt Tensors.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, TRTTensor):
+                trt_shape.append(s)
+            else:
+                a = get_trt_tensor(ctx, s, f"{name}_{i}")
+                trt_shape.append(a)
+        shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,6 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
-
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -99,6 +98,7 @@ def expand(
     stride = tuple(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
+
     layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py

Lines changed: 22 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List

 import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
     ]

     return placeholders
+
+
+def get_metadata(
+    gm: torch.fx.GraphModule, target_op: Any
+) -> List[torch._ops.OpOverload]:
+    """
+    Return the list which has the metadata of all the target_op nodes present in the graph.
+    """
+    return [node.meta for node in gm.graph.nodes if node.target == target_op]
+
+
+def set_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload]
+) -> None:
+    """
+    Return the list which has the metadata of all the target_op nodes present in the graph.
+    """
+    target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
+    assert len(target_nodes) == len(metadata)
+    for idx, node in enumerate(target_nodes):
+        node.meta = metadata[idx]

py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py

Lines changed: 18 additions & 18 deletions
@@ -1,9 +1,11 @@
 import logging
-from typing import Callable, List, Sequence, Tuple
+from typing import List, Sequence

 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
+    get_metadata,
+    set_metadata,
 )

 logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
     """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
-    orig, replacement = view_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
-
-    return gm
-
-
-def view_replacement() -> Tuple[
-    torch.fx.GraphModule,
-    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-]:
-    """Constructs the original and replacement functions for view"""
+    orig_op = torch.ops.aten.view.default
+    replacement_op = torch.ops.aten.reshape.default

     # Original graph
     def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.view.default(input, shape)
+        return orig_op(input, shape)

     # Replacement graph
     def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.reshape.default(input, shape)
+        return replacement_op(input, shape)

-    return orig, replacement
+    # Store metadata of the orig_op
+    metadata = get_metadata(gm, orig_op)
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    # Copy the orig_op's metadata to the replacement op
+    set_metadata(gm, replacement_op, metadata)
+
+    return gm
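
Why the metadata bookkeeping is needed (a standalone sketch with an assumed toy module, not code from this commit): torch.fx.subgraph_rewriter.replace_pattern creates fresh nodes without the original .meta, so the pass records the view nodes' metadata before rewriting and writes it back onto the new reshape nodes afterwards.

import torch

class View(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.aten.view.default(x, [-1])

gm = torch.fx.symbolic_trace(View())

orig_op = torch.ops.aten.view.default
replacement_op = torch.ops.aten.reshape.default

# Save .meta of every view node before rewriting (what get_metadata does).
saved = [n.meta for n in gm.graph.nodes if n.target == orig_op]

def orig(x, shape):
    return torch.ops.aten.view.default(x, shape)

def replacement(x, shape):
    return torch.ops.aten.reshape.default(x, shape)

torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement)

# Restore the saved metadata onto the new reshape nodes (what set_metadata does).
reshape_nodes = [n for n in gm.graph.nodes if n.target == replacement_op]
for node, meta in zip(reshape_nodes, saved):
    node.meta = meta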

py/torch_tensorrt/dynamo/partitioning/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,3 +1,7 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
-from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
+from .common import (
+    construct_submodule_inputs,
+    get_graph_converter_support,
+    run_shape_analysis,
+)
