Skip to content

Commit f960816

Browse files
committed
Add custom model for dynamic model checking
1 parent 8006682 commit f960816

File tree

5 files changed

+57
-135
lines changed

5 files changed

+57
-135
lines changed

.github/workflows/build-test-linux.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ jobs:
144144
cd tests/py/dynamo
145145
python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
146146
python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
147-
python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_generation_compile.xml models/test_hf_generate_dynamic.py
148147
popd
149148
150149
tests-py-dynamo-serde:

.github/workflows/build-test-windows.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ jobs:
143143
cd tests/py/dynamo
144144
python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
145145
python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
146-
python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_generation_compile.xml models/test_hf_generate_dynamic.py
147146
popd
148147
149148
tests-py-dynamo-serde:

tests/py/dynamo/models/test_dyn_models.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,59 @@ def forward(self, x):
310310
cos_sim > COSINE_THRESHOLD,
311311
msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
312312
)
313+
314+
315+
@pytest.mark.unit
316+
def test_dynamic_with_fallback_shape_tensor_pass_through(ir):
317+
class MyModule(torch.nn.Module):
318+
def __init__(self):
319+
super().__init__()
320+
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
321+
self.relu = torch.nn.ReLU()
322+
323+
def forward(self, x):
324+
out = self.conv(x)
325+
x = x + 2
326+
x = x * 2
327+
out = torch.reshape(x, (-1, 2))
328+
out = self.relu(out)
329+
return out
330+
331+
model = MyModule().eval().cuda()
332+
input_bs4 = torch.randn((4, 3, 224, 224)).to("cuda")
333+
334+
compile_spec = {
335+
"device": torchtrt.Device("cuda:0"),
336+
"enabled_precisions": {torch.float},
337+
"ir": ir,
338+
"pass_through_build_failures": True,
339+
"min_block_size": 1,
340+
"torch_executed_ops": "torch.ops.aten.add.Tensor",
341+
}
342+
343+
# Compile the model
344+
if ir == "torch_compile":
345+
torch._dynamo.mark_dynamic(input_bs4, 0, min=1, max=8)
346+
# Compile the model
347+
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
348+
trt_model(input_bs4)
349+
elif ir == "dynamo":
350+
compile_spec["inputs"] = [
351+
torchtrt.Input(
352+
min_shape=(1, 3, 224, 224),
353+
opt_shape=(4, 3, 224, 224),
354+
max_shape=(8, 3, 224, 224),
355+
dtype=torch.float32,
356+
name="x",
357+
)
358+
]
359+
trt_model = torchtrt.compile(model, **compile_spec)
360+
361+
trt_model(input_bs4)
362+
363+
input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
364+
cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
365+
assertions.assertTrue(
366+
cos_sim > COSINE_THRESHOLD,
367+
msg=f"test_dynamic_with_fallback_shape_tensor_pass_through model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
368+
)

tests/py/dynamo/models/test_hf_generate_dynamic.py

Lines changed: 0 additions & 132 deletions
This file was deleted.

tests/py/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ pytest-xdist>=3.6.1
99
pyyaml
1010
tensorrt==10.0.1
1111
timm>=1.0.3
12-
transformers==4.41.0
12+
transformers==4.40.2
1313
--extra-index-url https://pypi.nvidia.com

0 commit comments

Comments
 (0)