Commit caf3a92

fix: Repair integer inputs in dynamic shape cases (#2876)
1 parent fbc72d5 commit caf3a92

File tree: 5 files changed, +82 −22 lines changed

- .github/workflows/build-test-linux.yml
- py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
- py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
- tests/py/dynamo/models/test_dyn_models.py
- tests/py/requirements.txt

.github/workflows/build-test-linux.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -264,4 +264,4 @@ jobs:
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
-  cancel-in-progress: true
+  cancel-in-progress: true
```

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -128,6 +128,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         self.context = self.engine.create_execution_context()
 
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
+        contiguous_inputs: List[torch.Tensor] = [
+            (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
+            for i in inputs
+        ]
         with (
             torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
             if self.profiling_enabled
@@ -174,7 +179,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 self.input_names
             ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
 
-            contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
             for i, input_name in enumerate(self.input_names):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
@@ -193,12 +197,17 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     contiguous_inputs[i].dtype == self.input_dtypes[i]
                 ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
+                # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
+                # as per TensorRT requirements
                 if self.engine.is_shape_inference_io(input_name):
-                    # Shape tensor inputs are casted to int32 explicitly.
-                    # Refer to https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
-                    inputs_cpu = contiguous_inputs[i].cpu().to(torch.int32)
+                    # Shape tensor inputs are casted to int64 explicitly
+                    # Currently Torch CPU pointers are not working; numpy pointers are used instead
+                    # to refer to underlying memory
+                    inputs_cpu = (
+                        contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+                    )
                     self.context.set_tensor_address(
-                        input_name, inputs_cpu.data_ptr()
+                        input_name, inputs_cpu.ctypes.data
                     )
                 else:
                     self.context.set_input_shape(
```
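The subtle part of this hunk is pointer lifetime: `set_tensor_address` for a shape-inference input takes a raw host address that TensorRT dereferences later, at execution time. A minimal sketch of the pattern, assuming a TensorRT execution context; the `bind_shape_input` helper and its arguments are illustrative, not part of this commit:

```python
import torch

def bind_shape_input(context, input_name: str, shape_values: torch.Tensor):
    # Shape tensors must live on the host: move to CPU, widen to int64,
    # and copy into a NumPy array that owns its own buffer.
    host_copy = shape_values.cpu().to(torch.int64).numpy().copy()
    # .ctypes.data is the raw host address TensorRT reads during shape
    # inference; per the commit, Torch CPU data_ptr() was not working here.
    context.set_tensor_address(input_name, host_copy.ctypes.data)
    # The caller must keep `host_copy` alive until execution finishes,
    # otherwise TensorRT would read freed memory.
    return host_copy
```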

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 11 additions & 15 deletions
```diff
@@ -146,7 +146,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         """Implementation of the forward pass for a TensorRT engine
 
         Args:
-            *inputs (torch.Tensor): Inputs to the forward function, must all be ``torch.Tensor``
+            *inputs (Union[torch.Tensor, int]): Inputs to the forward function
 
         Returns:
             torch.Tensor or Tuple(torch.Tensor): Result of the engine computation
@@ -158,22 +158,18 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self.input_binding_names
         ), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}."
 
-        types: List[bool] = [issubclass(type(i), torch.Tensor) for i in inputs]
-
-        try:
-            assert all(types)
-        except AssertionError:
-
-            def is_non_tensor(i: Tuple[Any, bool]) -> bool:
-                return not i[1]
-
-            non_tensors = [i[0] for i in filter(is_non_tensor, zip(inputs, types))]
-            raise RuntimeError(
-                f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
-            )
+        # If the inputs are not Torch Tensors, which can occur in scenarios such as shape tensors
+        # which are outputs of a preceding Torch subgraph (where the Dynamic input may be an integer)
+        # directly cast the input to a Torch Tensor.
+        #
+        # This also avoids the need for type-checking inputs, since they are now explicitly casted to Torch tensors
+        input_tensors: List[torch.Tensor] = [
+            (i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
+            for i in inputs
+        ]
 
         outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
-            list(inputs), self.engine
+            list(input_tensors), self.engine
        )
 
         if len(outputs) == 1:
```
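The replacement cast is easy to sanity-check in isolation; a small standalone illustration of the same list comprehension, with made-up inputs (requires a CUDA device):

```python
import torch

# A dynamic batch size coming out of a Torch-executed subgraph arrives as a
# plain Python int; the module now wraps it as a 0-d CUDA tensor instead of
# raising "found non tensors: [...]".
raw_inputs = (torch.randn(2, 3, device="cuda"), 7)
input_tensors = [
    (i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
    for i in raw_inputs
]
assert all(isinstance(t, torch.Tensor) for t in input_tensors)
assert input_tensors[1].item() == 7 and input_tensors[1].is_cuda
```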

tests/py/dynamo/models/test_dyn_models.py

Lines changed: 55 additions & 0 deletions
```diff
@@ -310,3 +310,58 @@ def forward(self, x):
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
+
+
+@pytest.mark.unit
+def test_dynamic_with_fallback_shape_tensor_pass_through(ir):
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            out = self.conv(x)
+            x = x + 2
+            x = x * 2
+            out = torch.reshape(x, (-1, 224 * 224))
+            out = self.relu(out)
+            return out
+
+    model = MyModule().eval().cuda()
+    input_bs4 = torch.randn((4, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "min_block_size": 1,
+        "torch_executed_ops": {"torch.ops.aten.add.Tensor"},
+    }
+
+    # Compile the model
+    if ir == "torch_compile":
+        torch._dynamo.mark_dynamic(input_bs4, 0, min=4, max=1024)
+        trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
+        trt_model(input_bs4)
+    elif ir == "dynamo":
+        compile_spec["inputs"] = [
+            torchtrt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(4, 3, 224, 224),
+                max_shape=(1024, 3, 224, 224),
+                dtype=torch.float32,
+                name="x",
+            )
+        ]
+        trt_model = torchtrt.compile(model, **compile_spec)
+
+    trt_model(input_bs4)
+
+    input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
+    cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_dynamic_with_fallback_shape_tensor_pass_through model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
```

tests/py/requirements.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,5 +9,5 @@ pytest-xdist>=3.6.1
 pyyaml
 tensorrt==10.0.1
 timm>=1.0.3
-transformers==4.39.3
+transformers==4.40.2
 --extra-index-url https://pypi.nvidia.com
```
