diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index ef66019ed1..1d770b86a3 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -121,6 +121,9 @@ def _compile_module( torch_executed_ops=settings.torch_executed_ops, ) + # Store TRT replicas of Torch subgraphs + trt_modules = {} + # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those for name, _ in partitioned_module.named_children(): @@ -138,7 +141,10 @@ def _compile_module( settings=settings, ) - # Replace FX Module with TRT Module + trt_modules[name] = trt_mod + + # Replace all FX Modules with TRT Modules + for name, trt_mod in trt_modules.items(): setattr(partitioned_module, name, trt_mod) return partitioned_module