Skip to content

Commit 73e887c

Browse files
authored
fix: Converter, inputs, and utils bugfixes for Transformer XL (#2404)
1 parent 4e5b0f6 commit 73e887c

File tree

7 files changed

+58
-63
lines changed

7 files changed

+58
-63
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def _pretraced_backend(
8989

9090
gm = apply_lowering_passes(gm, sample_inputs)
9191

92-
torchtrt_inputs = prepare_inputs(sample_inputs)
92+
torchtrt_inputs = prepare_inputs(
93+
sample_inputs, disable_memory_format_check=True
94+
)
9395
trt_compiled = compile_module(
9496
gm,
9597
torchtrt_inputs,

py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,16 @@ def convert_module(
3535
if not isinstance(module_outputs, (list, tuple)):
3636
module_outputs = [module_outputs]
3737

38-
output_dtypes = [output.dtype for output in module_outputs]
38+
# Int64 outputs can sometimes be generated from within other operators
39+
# such as aten.sum - such outputs can be truncated
40+
output_dtypes = []
41+
for output in module_outputs:
42+
if settings.truncate_long_and_double and output.dtype == torch.float64:
43+
output_dtypes.append(torch.float32)
44+
elif settings.truncate_long_and_double and output.dtype == torch.int64:
45+
output_dtypes.append(torch.int32)
46+
else:
47+
output_dtypes.append(output.dtype)
3948

4049
interpreter = TRTInterpreter(
4150
module,

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,8 @@ def sum(
4949
dim: Optional[Union[int, Sequence[int]]],
5050
keepdim: bool,
5151
) -> TRTTensor:
52-
if (isinstance(input_val, TRTTensor)) and (
53-
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
54-
):
55-
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
52+
if (isinstance(input_val, TRTTensor)) and (input_val.dtype == trt.bool):
53+
input_val = cast_trt_tensor(ctx, input_val, trt.int32, name)
5654

5755
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
5856
dim = tuple(range(len(input_val.shape)))

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,23 @@ def slice_op( # TODO: This should be slice not whatever is in base
3131
"of the TensorRT region!"
3232
)
3333

34-
ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0)
35-
dim = get_positive_dim(dim, ranks)
36-
dynamic_shape = has_dynamic_shape(input.shape)
37-
if ctx.net.has_implicit_batch_dimension:
38-
if dim == 0:
39-
raise RuntimeError(
40-
f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
41-
)
42-
dim = dim - 1
43-
else:
44-
if dynamic_shape:
45-
# Check whether slice target dim is dynamic shape dim
46-
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
47-
start_int = start
48-
stop_int = stop
49-
if stop_int == 2**63 - 1:
50-
stop_int = input.shape[dim]
51-
step_int = step
34+
dim = get_positive_dim(dim, len(input.shape))
35+
start = get_positive_dim(start, input.shape[dim])
36+
stop = get_positive_dim(stop, input.shape[dim])
37+
38+
if has_dynamic_shape(input.shape):
39+
# Check whether slice target dim is dynamic shape dim
40+
assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
41+
42+
if stop == 2**63 - 1:
43+
stop = input.shape[dim]
44+
5245
start_slice = [0] * len(input.shape)
53-
start_slice[dim] = start_int
54-
stride_slice = [1] * len(start_slice)
55-
stride_slice[dim] = step_int
46+
start_slice[dim] = start
47+
stride_slice = [1] * len(input.shape)
48+
stride_slice[dim] = step
5649
output_shape = list(input.shape)
57-
output_shape[dim] = math.ceil((stop_int - start_int) / step_int)
50+
output_shape[dim] = math.ceil((stop - start) / step)
5851

5952
return slice(
6053
ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice

py/torch_tensorrt/dynamo/utils.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
from typing import Any, Callable, Dict, Optional, Sequence, Union
66

77
import torch
8-
import torch_tensorrt
98
from torch_tensorrt._Device import Device
109
from torch_tensorrt._Input import Input
1110
from torch_tensorrt.dynamo import CompilationSettings
1211
from torch_tensorrt.dynamo._defaults import PRECISION
1312

13+
import torch_tensorrt
1414
from packaging import version
1515

1616
logger = logging.getLogger(__name__)
@@ -104,25 +104,32 @@ def set_log_level(parent_logger: Any, level: Any) -> None:
104104

105105
def prepare_inputs(
106106
inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
107+
disable_memory_format_check: bool = False,
107108
) -> Any:
108109
if isinstance(inputs, Input):
109110
return inputs
110111

111112
elif isinstance(inputs, torch.Tensor):
112-
return Input.from_tensor(inputs)
113+
return Input.from_tensor(
114+
inputs, disable_memory_format_check=disable_memory_format_check
115+
)
113116

114117
elif isinstance(inputs, list):
115118
torchtrt_input_list = []
116119
for input_obj in inputs:
117-
torchtrt_input = prepare_inputs(input_obj)
120+
torchtrt_input = prepare_inputs(
121+
input_obj, disable_memory_format_check=disable_memory_format_check
122+
)
118123
torchtrt_input_list.append(torchtrt_input)
119124

120125
return torchtrt_input_list
121126

122127
elif isinstance(inputs, tuple):
123128
torchtrt_inputs_tup = []
124129
for input_obj in inputs:
125-
torchtrt_input = prepare_inputs(input_obj)
130+
torchtrt_input = prepare_inputs(
131+
input_obj, disable_memory_format_check=disable_memory_format_check
132+
)
126133
torchtrt_inputs_tup.append(torchtrt_input)
127134

128135
return tuple(torchtrt_inputs_tup)
@@ -131,7 +138,9 @@ def prepare_inputs(
131138
torchtrt_inputs_dict: Dict[Any, Any] = dict()
132139

133140
for key, input_obj in inputs.items():
134-
torchtrt_input = prepare_inputs(input_obj)
141+
torchtrt_input = prepare_inputs(
142+
input_obj, disable_memory_format_check=disable_memory_format_check
143+
)
135144
torchtrt_inputs_dict[key] = torchtrt_input
136145

137146
return torchtrt_inputs_dict

tests/py/dynamo/conversion/test_slice_aten.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,20 @@
11
import torch
22
from parameterized import parameterized
33
from torch.testing._internal.common_utils import run_tests
4+
45
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

89

9-
class TestSelectConverterImplicitBatch(DispatchTestCase):
10+
class TestSelectConverter(DispatchTestCase):
1011
@parameterized.expand(
1112
[
1213
("select_dim_start_stop_step", 0, 0, 7, 2),
13-
]
14-
)
15-
def test_slice(self, _, dim, start, stop, step):
16-
class TestModule(torch.nn.Module):
17-
def __init__(self):
18-
super().__init__()
19-
20-
def forward(self, input):
21-
out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step)
22-
return out
23-
24-
input = [torch.randn(10, 2, 3, 1)]
25-
self.run_test(
26-
TestModule(),
27-
input,
28-
)
29-
30-
31-
class TestSelectConverterExplicitBatch(DispatchTestCase):
32-
@parameterized.expand(
33-
[
34-
("select_dim_start_stop_step", 1, 0, 7, 2),
14+
("select_dim_start_stop_step_offset", 1, 0, 7, 2),
3515
("select_dim_start_stop_step_exact", 1, 0, 10, 2),
16+
("select_dim_start_stop_step_negatives", -3, -2, -1, 1),
17+
("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
3618
]
3719
)
3820
def test_slice(self, _, dim, start, stop, step):

tests/py/dynamo/conversion/test_sum_aten.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def forward(self, x):
7070

7171
@parameterized.expand(
7272
[
73-
((3, 2, 4), 1, True, torch.int, 0, 5),
74-
((2, 3, 4, 5), None, True, torch.int, -10, 10),
73+
((3, 2, 4), 1, True, torch.int32, 0, 5),
74+
((2, 3, 4, 5), None, True, torch.int32, -10, 10),
7575
((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
7676
((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
7777
]
@@ -85,16 +85,18 @@ def forward(self, x):
8585
self.run_test(
8686
Sum(),
8787
inputs,
88-
check_dtype=False,
88+
output_dtypes=[torch.int32],
8989
)
9090

9191
@parameterized.expand(
9292
[
93-
((1, 2, 4), [], True, torch.int, 0, 5),
94-
((3, 2, 4), [1], True, torch.int, 0, 5),
95-
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
93+
((1, 2, 4), [], True, torch.int32, 0, 5),
94+
((3, 2, 4), [1], True, torch.int32, 0, 5),
95+
((2, 1, 4, 5), [0, 3], True, torch.int32, -10, 10),
9696
((2, 3, 4, 5), None, False, torch.int32, -5, 0),
9797
((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
98+
((6, 7, 5, 4, 5), [1, 3, 4], False, torch.bool, 0, 2),
99+
((4, 7, 1, 5), None, True, torch.bool, 0, 2),
98100
]
99101
)
100102
def test_sum_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
@@ -106,7 +108,7 @@ def forward(self, x):
106108
self.run_test(
107109
Sum(),
108110
inputs,
109-
check_dtype=False,
111+
output_dtypes=[torch.int32],
110112
)
111113

112114

0 commit comments

Comments
 (0)