Commit d7e47be

peri044 authored and zewenli98 committed
feat: Implement symbolic shape propagation, sym_size converter (#2473)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent a887971 commit d7e47be

14 files changed (+289, -119 lines)


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 20 additions & 7 deletions
@@ -189,10 +189,10 @@ def compile(
         )
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
-
     # Apply lowering on the graph module
     torch_inputs = get_torch_inputs(inputs, device)
     gm = apply_lowering_passes(gm, torch_inputs)
+
     logger.debug("Lowered Input graph: " + str(gm.graph))

     compilation_options = {
@@ -290,6 +290,24 @@ def compile_module(
         f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
     )

+    def contains_metadata(gm: torch.fx.GraphModule) -> bool:
+        for node in gm.graph.nodes:
+            if node.op != "output" and (not node.meta) and "val" not in node.meta:
+                logger.warning(
+                    f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
+                )
+                return False
+        return True
+
+    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    if not contains_metadata(gm):
+        from torch._inductor.compile_fx import fake_tensor_prop
+
+        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
+        with torch.no_grad():
+            # This fails if the module has data-dependent shape operators.
+            fake_tensor_prop(gm, torch_inputs)
+
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False

@@ -348,12 +366,7 @@ def compile_module(
         )

         # Get the submodule inputs for min, opt, max shapes of the graph inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module,
-            submodule,
-            sample_inputs,
-            to_torch_device(settings.device),
-        )
+        submodule_inputs = partitioning.construct_submodule_inputs(submodule)

         logger.debug(
             "Submodule name: %s\n Input shapes: %s\n %s",

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 0 additions & 1 deletion
@@ -74,7 +74,6 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
-
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 16 additions & 0 deletions
@@ -394,6 +394,22 @@ def aten_ops_sigmoid(
     )


+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
+def aten_ops_symsize_int(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 def index_dtype_validator(node: Node) -> bool:
     index = node.args[1]
     for ind in index:
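
For reference, aten.sym_size.int nodes appear when a graph is captured with dynamic shapes: reading a symbolic dimension is recorded as an explicit size-query op, which this converter lowers through impl.shape.shape. A minimal sketch (recent torch.export API, not part of this commit) that produces such a node:

# Sketch only: export with a dynamic batch dimension; the size lookup x.shape[0]
# is recorded as a torch.ops.aten.sym_size.int node in the exported graph.
import torch

class Flatten(torch.nn.Module):
    def forward(self, x):
        batch = x.shape[0]           # becomes aten.sym_size.int(x, 0)
        return x.reshape(batch, -1)

batch_dim = torch.export.Dim("batch", min=2, max=8)
ep = torch.export.export(
    Flatten(),
    (torch.randn(4, 3, 2),),
    dynamic_shapes={"x": {0: batch_dim}},
)
print(ep.graph)  # expect a sym_size.int node feeding the reshape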

py/torch_tensorrt/dynamo/conversion/impl/grid.py

Lines changed: 1 addition & 2 deletions
@@ -1,13 +1,12 @@
 from typing import Optional

+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor

-import tensorrt as trt
-
 # nearest, linear, cubic
 GridSamplerInterpolationMode = {
     0: trt.InterpolationMode.NEAREST,

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 1 addition & 7 deletions
@@ -9,9 +9,9 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
-    cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
+    to_numpy,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
@@ -39,12 +39,6 @@ def shape(
     """
     shape_layer = ctx.net.add_shape(input_val)
     input_shape = shape_layer.get_output(0)
-    input_shape = cast_trt_tensor(
-        ctx,
-        input_shape,
-        trt.int32,
-        name + "_shape_casted",
-    )
     set_layer_name(shape_layer, target, name + "_shape", source_ir)

     n_dims = len(input_val.shape)

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 18 additions & 2 deletions
@@ -3,7 +3,7 @@
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
     shape: Sequence[int],
 ) -> TRTTensor:
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(shape)
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all the dimensions to trt Tensors.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, TRTTensor):
+                trt_shape.append(s)
+            else:
+                a = get_trt_tensor(ctx, s, f"{name}_{i}")
+                trt_shape.append(a)
+        shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

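
The dynamic branch above mirrors the standard TensorRT idiom for data-dependent reshapes: build a 1-D shape tensor by concatenating per-dimension tensors and wire it into the shuffle layer's second input instead of setting static reshape_dims. A minimal sketch using the plain TensorRT Python API (the toy network and tensor names are illustrative, not from this commit):

# Sketch only: reshape a (-1, 3, 4) input to (batch, 12) where batch is known
# only at runtime, using shape/slice/concatenation layers plus set_input(1, ...).
import numpy as np
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (-1, 3, 4))  # dynamic batch dimension

# Runtime shape of x, then its first element (the batch size) as a 1-D tensor.
full_shape = network.add_shape(x).get_output(0)
batch = network.add_slice(full_shape, start=(0,), shape=(1,), stride=(1,))
minus_one = network.add_constant((1,), np.array([-1], dtype=np.int32))

# Concatenate [batch, -1] into the target shape tensor.
new_shape = network.add_concatenation([batch.get_output(0), minus_one.get_output(0)])
new_shape.axis = 0

shuffle = network.add_shuffle(x)
shuffle.set_input(1, new_shape.get_output(0))  # dynamic reshape
network.mark_output(shuffle.get_output(0))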

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,6 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
-
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -99,6 +98,7 @@ def expand(
     stride = tuple(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
+
     layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py

Lines changed: 22 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List

 import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
     ]

     return placeholders
+
+
+def get_metadata(
+    gm: torch.fx.GraphModule, target_op: Any
+) -> List[torch._ops.OpOverload]:
+    """
+    Return the list which has the metadata of all the target_op nodes present in the graph.
+    """
+    return [node.meta for node in gm.graph.nodes if node.target == target_op]
+
+
+def set_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload]
+) -> None:
+    """
+    Return the list which has the metadata of all the target_op nodes present in the graph.
+    """
+    target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
+    assert len(target_nodes) == len(metadata)
+    for idx, node in enumerate(target_nodes):
+        node.meta = metadata[idx]
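
A minimal usage sketch for the two new helpers (the toy graph is illustrative, not from this commit): snapshot the metadata of every aten.view.default node in a fake-traced graph, then write it back with set_metadata. view_to_reshape performs the same round trip, except that it stamps the saved metadata onto the freshly inserted aten.reshape.default nodes.

# Sketch only: get_metadata/set_metadata round trip on a single-view graph.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_metadata, set_metadata

def f(x):
    return torch.ops.aten.view.default(x, [-1])

gm = make_fx(f, tracing_mode="fake")(torch.randn(2, 3))

saved = get_metadata(gm, torch.ops.aten.view.default)  # one meta dict per view node
print(saved[0]["val"].shape)                           # torch.Size([6])

# Writing it back is a no-op here; the lowering pass targets the replacement op instead.
set_metadata(gm, torch.ops.aten.view.default, saved)   # lengths must match (asserted)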

py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py

Lines changed: 18 additions & 18 deletions
@@ -1,9 +1,11 @@
 import logging
-from typing import Callable, List, Sequence, Tuple
+from typing import List, Sequence

 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
+    get_metadata,
+    set_metadata,
 )

 logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
     """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
-    orig, replacement = view_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
-
-    return gm
-
-
-def view_replacement() -> Tuple[
-    torch.fx.GraphModule,
-    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-]:
-    """Constructs the original and replacement functions for view"""
+    orig_op = torch.ops.aten.view.default
+    replacement_op = torch.ops.aten.reshape.default

     # Original graph
     def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.view.default(input, shape)
+        return orig_op(input, shape)

     # Replacement graph
     def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.reshape.default(input, shape)
+        return replacement_op(input, shape)

-    return orig, replacement
+    # Store metadata of the orig_op
+    metadata = get_metadata(gm, orig_op)
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+        # Copy the orig_op's metadata to the replacement op
+        set_metadata(gm, replacement_op, metadata)
+
+    return gm
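
A short sketch of the updated pass in isolation (the import path is assumed from the repo layout; the toy graph is illustrative): aten.view.default nodes are rewritten to aten.reshape.default, and the saved metadata is stamped onto the replacement nodes so shape and dtype information survives the rewrite.

# Sketch only: run view_to_reshape on a fake-traced graph and check that the
# replacement reshape node kept the view node's meta["val"].
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch_tensorrt.dynamo.lowering.passes.view_to_reshape import view_to_reshape

def f(x):
    return torch.ops.aten.view.default(x, [-1]) + 1.0

sample_inputs = [torch.randn(2, 3)]
gm = make_fx(f, tracing_mode="fake")(*sample_inputs)
gm = view_to_reshape(gm, sample_inputs)

for node in gm.graph.nodes:
    if node.target == torch.ops.aten.reshape.default:
        print(node.name, node.meta["val"].shape)  # carried over from the view node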

py/torch_tensorrt/dynamo/partitioning/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,3 +1,7 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
-from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
+from .common import (
+    construct_submodule_inputs,
+    get_graph_converter_support,
+    run_shape_analysis,
+)
