From e6357773491e86cae6a80f281ee985fc52982367 Mon Sep 17 00:00:00 2001
From: Alex Settle
Date: Fri, 7 Jul 2023 16:12:18 -0700
Subject: [PATCH 1/3] Add detailed layer name including the module name and
 pass to TRT and the profiler.

---
 py/torch_tensorrt/fx/fx2trt.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index d7ef976fba..6de0dfaaaf 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import re
 import warnings
 from datetime import datetime
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
@@ -270,8 +271,27 @@ def run(
             engine, self._input_names, self._output_names, serialized_cache
         )
 
+    def get_node_name(self, node):
+        # nn_module_stack preserves the call stack of pytorch nn.modules
+        # The call stack contains a detailed name of the module
+        # which shows exactly where the module is located in the
+        # network architecture.
+        stack_item = node.meta.get("nn_module_stack", None) 
+        # The current node is the last item in the stack
+        mod_stack = stack_item.popitem() if stack_item else ""
+        node_name = str(node)
+        if mod_stack:
+            mod_name = str(mod_stack[0]).replace("___", "/")
+            # Clean up the module name
+            mod_name = re.sub("^.*__self", "", mod_name)
+            mod_name = re.sub("_(\d+)$", "/\g<1>", mod_name)
+            node_name = mod_name + '/' + node_name
+
+        _LOGGER.debug(f"Node meta name {node_name}")
+        return node_name
+
     def run_node(self, n):
-        self._cur_node_name = str(n)
+        self._cur_node_name = self.get_node_name(n)
         # add "_itensor_to_tensor_meta"
         kwargs = dict(n.kwargs)
         kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta

From db565d88be9d7d3f89d55bdb107ba657dbba71ed Mon Sep 17 00:00:00 2001
From: Alex Settle
Date: Mon, 10 Jul 2023 13:47:28 -0700
Subject: [PATCH 2/3] Ported changes to the dynamo version of fx2trt.py for
 torch.compile compatibility

---
 .../dynamo/fx_ts_compat/fx2trt.py | 22 ++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
index a29cee509d..25ebe7d3c1 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -1,4 +1,5 @@
 import logging
+import re
 import warnings
 from datetime import datetime
 from packaging import version
@@ -291,8 +292,27 @@ def run(
             engine, self._input_names, self._output_names, serialized_cache
         )
 
+    def get_node_name(self, node):
+        # nn_module_stack preserves the call stack of pytorch nn.modules
+        # The call stack contains a detailed name of the module
+        # which shows exactly where the module is located in the
+        # network architecture.
+        stack_item = node.meta.get("nn_module_stack", None) 
+        # The current node is the last item in the stack
+        mod_stack = stack_item.popitem() if stack_item else ""
+        node_name = str(node)
+        if mod_stack:
+            mod_name = str(mod_stack[0]).replace("___", "/")
+            # Clean up the module name
+            mod_name = re.sub("^.*__self", "", mod_name)
+            mod_name = re.sub("_(\d+)$", "/\g<1>", mod_name)
+            node_name = mod_name + '/' + node_name
+
+        _LOGGER.debug(f"Node meta name {node_name}")
+        return node_name
+
     def run_node(self, n):
-        self._cur_node_name = str(n)
+        self._cur_node_name = self.get_node_name(n)
         # add "_itensor_to_tensor_meta"
         kwargs = dict(n.kwargs)
         kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta

From 4e11ca37a04906cc29557dd30c46be2b2041c353 Mon Sep 17 00:00:00 2001
From: GECOS
Date: Tue, 11 Jul 2023 00:51:00 +0000
Subject: [PATCH 3/3] Fixed lint error

---
 py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py | 8 ++++++--
 py/torch_tensorrt/fx/fx2trt.py                  | 4 ++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
index 25ebe7d3c1..beb0b1f1d3 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -297,7 +297,7 @@ def get_node_name(self, node):
         # The call stack contains a detailed name of the module
         # which shows exactly where the module is located in the
         # network architecture.
-        stack_item = node.meta.get("nn_module_stack", None) 
+        stack_item = node.meta.get("nn_module_stack", None)
         # The current node is the last item in the stack
         mod_stack = stack_item.popitem() if stack_item else ""
         node_name = str(node)
@@ -306,7 +306,11 @@ def get_node_name(self, node):
             # Clean up the module name
             mod_name = re.sub("^.*__self", "", mod_name)
             mod_name = re.sub("_(\d+)$", "/\g<1>", mod_name)
-            node_name = mod_name + '/' + node_name
+            node_name = mod_name + "/" + node_name
+        else:
+            # Try an alternative way to get the module info
+            # like the node.meta['source_fn'] attr
+            pass
 
         _LOGGER.debug(f"Node meta name {node_name}")
         return node_name
diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index 6de0dfaaaf..aff46f3290 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -276,7 +276,7 @@ def get_node_name(self, node):
         # The call stack contains a detailed name of the module
         # which shows exactly where the module is located in the
         # network architecture.
-        stack_item = node.meta.get("nn_module_stack", None) 
+        stack_item = node.meta.get("nn_module_stack", None)
         # The current node is the last item in the stack
         mod_stack = stack_item.popitem() if stack_item else ""
         node_name = str(node)
@@ -285,7 +285,7 @@ def get_node_name(self, node):
             # Clean up the module name
             mod_name = re.sub("^.*__self", "", mod_name)
             mod_name = re.sub("_(\d+)$", "/\g<1>", mod_name)
-            node_name = mod_name + '/' + node_name
+            node_name = mod_name + "/" + node_name
 
         _LOGGER.debug(f"Node meta name {node_name}")
         return node_name
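
Note: below is a minimal, self-contained sketch of the name-cleanup logic
that get_node_name() applies, runnable outside of TensorRT. The
derive_layer_name helper and the nn_module_stack entry used here are
hypothetical illustrations; real stack keys and values depend on the
PyTorch/tracer version that produced the graph.

    import re

    def derive_layer_name(nn_module_stack, node_str):
        # Standalone mirror of the patch's get_node_name() cleanup steps.
        if not nn_module_stack:
            return node_str
        # dict.popitem() returns the most recently inserted (key, value)
        # pair, i.e. the innermost module in the recorded call stack.
        mod_name, _ = nn_module_stack.popitem()
        mod_name = str(mod_name).replace("___", "/")
        # Drop everything up to and including the tracer's "__self" prefix.
        mod_name = re.sub("^.*__self", "", mod_name)
        # Rewrite a trailing "_<n>" (e.g. an nn.Sequential slot) as "/<n>".
        mod_name = re.sub(r"_(\d+)$", r"/\g<1>", mod_name)
        return mod_name + "/" + node_str

    # Hypothetical stack entry; real key/value formats vary by version.
    stack = {"L__self___layer1_0": ("L['self'].layer1[0]", "torch.nn.Conv2d")}
    print(derive_layer_name(stack, "conv2d"))  # -> /layer1/0/conv2d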