From 099135cf73884daa6a2f8b408124092c4dffc1bb Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 6 Jul 2023 12:11:25 -0700
Subject: [PATCH] fix: Repair null bindings issue in TRT Engines

- Caused by passing Fake Tensor objects into TRT engines mid-compilation
- Resolved by replacing all FX modules with TRT equivalents after all TRT
  compilation is complete. This way, modules are not run on FakeTensors
  generated during compilation
---
 py/torch_tensorrt/dynamo/backend/backends.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index ef66019ed1..1d770b86a3 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -121,6 +121,9 @@ def _compile_module(
         torch_executed_ops=settings.torch_executed_ops,
     )
 
+    # Store TRT replicas of Torch subgraphs
+    trt_modules = {}
+
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
@@ -138,7 +141,10 @@ def _compile_module(
             settings=settings,
         )
 
-        # Replace FX Module with TRT Module
+        trt_modules[name] = trt_mod
+
+    # Replace all FX Modules with TRT Modules
+    for name, trt_mod in trt_modules.items():
         setattr(partitioned_module, name, trt_mod)
 
     return partitioned_module