Pass detailed unique layer name to the TRT engine #2087
Conversation
name and pass to TRT and the profiler.
@alexmsettle can you actually redirect this to the dynamo frontend for torch.compile compatibility?
@narendasan - how do you run the linter locally?
Overall, augmenting the node name to include the critical module-path information is very useful in debugging models. Pending the Python style linting via pre-commit:
pip install -r requirements-dev.txt
pre-commit install
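Once installed, the hooks run on each commit; to check the whole tree on demand, pre-commit also supports a manual invocation:
pre-commit run --all-files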
Added some comments on node names for FX nodes without an nn_module_stack attribute.
def get_node_name(self, node):
    # nn_module_stack preserves the call stack of PyTorch nn.Modules.
    # The call stack contains a detailed name of the module, which
    # shows exactly where the module is located in the network
    # architecture.
    stack_item = node.meta.get("nn_module_stack", None)
    # The current node is the last item in the stack
    mod_stack = stack_item.popitem() if stack_item else ""
    node_name = str(node)
    if mod_stack:
        mod_name = str(mod_stack[0]).replace("___", "/")
        # Clean up the module name
        mod_name = re.sub("^.*__self", "", mod_name)
        mod_name = re.sub(r"_(\d+)$", r"/\g<1>", mod_name)
        node_name = mod_name + "/" + node_name
    _LOGGER.debug(f"Node meta name {node_name}")
    return node_name
Certain nodes in FX do not have nn_module_stack information. For instance, an aten.mul op which appears in BERT has the following n.meta dictionary:
{'stack_trace': ' File "<eval_with_key>.0", line 491, in forward\n encoder_layer_11_attention_self_value = getattr(self.encoder.layer, "11").attention.self.value(encoder_layer_10_output_layer_norm)\n', 'nn_module_stack': {}, 'source_fn': ('getattr_l__self___encoder_layer___11___attention_self_value', <class 'torch.nn.modules.linear.Linear'>), 'original_aten': <OpOverload(op='aten.addmm', overload='default')>, 'val': FakeTensor(..., device='cuda:0', size=(14, 768)), 'tensor_meta': TensorMetadata(shape=torch.Size([14, 768]), dtype=torch.float32, requires_grad=False, stride=(768, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
@alexmsettle - would the stack_trace field potentially be a helpful debugging/naming tool here? The stack trace for the above aten.mul node would be:
File "<eval_with_key>.0", line 491, in forward
encoder_layer_11_attention_self_value = getattr(self.encoder.layer, "11").attention.self.value(encoder_layer_10_output_layer_norm)
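If it helps, here is a minimal sketch of what such a fallback could look like; stack_trace_to_name is a hypothetical helper, not part of this PR, and it simply pulls the first assignment target out of the node's stack_trace string:

import re

def stack_trace_to_name(node) -> str:
    # Hypothetical fallback: when nn_module_stack is empty, derive a name
    # from the first assignment target in the node's stack_trace.
    trace = node.meta.get("stack_trace", "")
    # Matches e.g. "encoder_layer_11_attention_self_value = getattr(..."
    match = re.search(r"^\s*(\w+)\s*=", trace, flags=re.MULTILINE)
    return match.group(1) if match else str(node)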
Other BERT aten.mul nodes, however, do have a non-empty nn_module_stack, such as:
('getattr_L__self___encoder_layer___11___attention_self_value', ("getattr(L['self'].encoder.layer, '11').attention.self.value", <class 'torch.nn.modules.linear.Linear'>))
For these, the output formatted node name string looks great: '/encoder_layer/11/attention_self_value/mul_214'
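For reference, a standalone sketch of the same cleanup steps from get_node_name applied to that stack key ("mul_214" stands in for str(node)):

import re

key = "getattr_L__self___encoder_layer___11___attention_self_value"
mod_name = key.replace("___", "/")
mod_name = re.sub("^.*__self", "", mod_name)        # strip the getattr_L__self prefix
mod_name = re.sub(r"_(\d+)$", r"/\g<1>", mod_name)  # split off a trailing numeric suffix
print(mod_name + "/" + "mul_214")                   # /encoder_layer/11/attention_self_value/mul_214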
@alexmsettle - could you remove the changes to …?
Closing this in favor of #2162. Please post there for further questions.
PyTorch provides a way to retrieve a detailed, unique layer name from the framework. The name is derived from the module class hierarchy, so details of the network architecture are encoded in the name. This information is very useful for performance analysis and for a basic understanding of the network architecture. The goal is to pass this name from PyTorch to TRT; Torch-TRT is effectively just a pass-through stage.
Description
Each layer in PyTorch is defined by classes which inherit from nn.Module. The FX graph represents the network as a sequence of nn.Module instances along with the PyTorch operators that implement the forward method for each class. Each of these operators can be made unique by prefixing it with the name of the module it appears in. The FX graph maintains the module name in the node.meta['nn_module_stack'] attribute. This change fetches 'nn_module_stack', cleans up the name, and prefixes it to the operator name; the result is then passed to the TRT engine IR.
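As a rough illustration of the pass-through step (this wiring is a sketch, not the exact Torch-TensorRT converter code; ILayer.name is the standard TensorRT Python attribute for labeling a layer):

import tensorrt as trt

def label_trt_layer(layer: trt.ILayer, node_name: str) -> None:
    # Sketch: attach the FX-derived node name to the TRT layer so it
    # appears in engine inspection and profiler output under that name.
    layer.name = node_name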
Fixes #2069