From 099135cf73884daa6a2f8b408124092c4dffc1bb Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 6 Jul 2023 12:11:25 -0700
Subject: [PATCH] fix: Repair null bindings issue in TRT Engines

- Caused by passing Fake Tensor objects into TRT engines mid-compilation
- Resolved by replacing all FX modules with TRT equivalents after all TRT
  compilation is complete. This way, modules are not run on FakeTensors
  generated during compilation
---
 py/torch_tensorrt/dynamo/backend/backends.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index ef66019ed1..1d770b86a3 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -121,6 +121,9 @@ def _compile_module(
         torch_executed_ops=settings.torch_executed_ops,
     )
 
+    # Store TRT replicas of Torch subgraphs
+    trt_modules = {}
+
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
@@ -138,7 +141,10 @@ def _compile_module(
             settings=settings,
         )
 
-        # Replace FX Module with TRT Module
+        trt_modules[name] = trt_mod
+
+    # Replace all FX Modules with TRT Modules
+    for name, trt_mod in trt_modules.items():
         setattr(partitioned_module, name, trt_mod)
 
     return partitioned_module