
Pass detailed unique layer name to the TRT engine #2087

Closed
wants to merge 3 commits

Conversation

alexmsettle

PyTorch provides a way to retrieve a detailed unique layer name from the framework. The name is derived from the module class hierarchy, so details of the network architecture are encoded in the name. This information is very useful for performance analysis and for a basic understanding of the network architecture. The goal is to pass this name from PyTorch to TRT; Torch-TRT is effectively just a pass-through stage.

Description

Each layer in PyTorch is defined by a class that inherits from nn.Module. The FX graph represents the network as a sequence of nn.Module instances along with the PyTorch operators that implement the forward method for each class. Each of these operators can be made unique by prefixing it with the name of the module it appears in. The FX graph maintains the module name in the node.meta['nn_module_stack'] attribute. This change fetches the 'nn_module_stack', cleans up the name, prefixes it to the operator name, and then passes the result to the TRT engine IR.

Fixes #2069

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

name and pass to TRT and the profiler.
@narendasan
Collaborator

@alexmsettle can you actually redirect this to the dynamo frontend (torch_tensorrt.dynamo.fx_ts_compat)? This is going to be the new default frontend for torch_tensorrt, so most of the development work is going there. The same patch should apply.

cc: @gs-olive @peri044

@narendasan narendasan requested review from peri044 and gs-olive July 10, 2023 18:41
for torch.compile compatibility
@alexmsettle
Author

@narendasan - how do you run the linter locally?

Collaborator

@gs-olive left a comment


Overall, augmenting the node name to include the critical module-path information is very useful in debugging models. Pending the Python style linting via pre-commit:

pip install -r requirements-dev.txt 
pre-commit install

Added some comments on node names for FX nodes without an nn_module_stack attribute.

Comment on lines 295 to 312
def get_node_name(self, node):
    # nn_module_stack preserves the call stack of PyTorch nn.Modules.
    # The call stack contains a detailed name of the module, which shows
    # exactly where the module is located in the network architecture.
    stack_item = node.meta.get("nn_module_stack", None)
    # The current node is the last item in the stack
    mod_stack = stack_item.popitem() if stack_item else ""
    node_name = str(node)
    if mod_stack:
        mod_name = str(mod_stack[0]).replace("___", "/")
        # Clean up the module name
        mod_name = re.sub(r"^.*__self", "", mod_name)
        mod_name = re.sub(r"_(\d+)$", r"/\g<1>", mod_name)
        node_name = mod_name + "/" + node_name

    _LOGGER.debug(f"Node meta name {node_name}")
    return node_name
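As a sanity check, the cleanup logic can be exercised without torch or TensorRT by stubbing out an FX node. The FakeNode class below is a hypothetical stand-in for torch.fx.Node (not part of this PR), and the name-cleaning body is copied from the diff above, minus the debug logging:

```python
import re

class FakeNode:
    """Hypothetical stand-in for torch.fx.Node, for illustration only."""

    def __init__(self, name, meta):
        self._name = name
        self.meta = meta

    def __str__(self):
        return self._name

def get_node_name(node):
    # Same logic as the diff above, without the _LOGGER call.
    stack_item = node.meta.get("nn_module_stack", None)
    # dict.popitem() returns the most recently inserted (key, value) pair,
    # i.e. the innermost module on the call stack.
    mod_stack = stack_item.popitem() if stack_item else ""
    node_name = str(node)
    if mod_stack:
        mod_name = str(mod_stack[0]).replace("___", "/")
        mod_name = re.sub(r"^.*__self", "", mod_name)
        mod_name = re.sub(r"_(\d+)$", r"/\g<1>", mod_name)
        node_name = mod_name + "/" + node_name
    return node_name

node = FakeNode(
    "mul_214",
    {"nn_module_stack": {"L__self___encoder_layer___11___attention_self_value": None}},
)
print(get_node_name(node))  # /encoder_layer/11/attention_self_value/mul_214
```

Note that popitem() mutates the stack dict, so the function should be called at most once per node.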
Collaborator


Certain nodes in FX do not have nn_module_stack information. For instance, an aten.mul op which appears in BERT has the following n.meta dictionary:

{'stack_trace': ' File "<eval_with_key>.0", line 491, in forward\n encoder_layer_11_attention_self_value = getattr(self.encoder.layer, "11").attention.self.value(encoder_layer_10_output_layer_norm)\n',
 'nn_module_stack': {},
 'source_fn': ('getattr_l__self___encoder_layer___11___attention_self_value', <class 'torch.nn.modules.linear.Linear'>),
 'original_aten': <OpOverload(op='aten.addmm', overload='default')>,
 'val': FakeTensor(..., device='cuda:0', size=(14, 768)),
 'tensor_meta': TensorMetadata(shape=torch.Size([14, 768]), dtype=torch.float32, requires_grad=False, stride=(768, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}

@alexmsettle - would the stack_trace field potentially be a helpful debugging/naming tool here? The stack trace for the above aten.mul node would be:

   File "<eval_with_key>.0", line 491, in forward
encoder_layer_11_attention_self_value = getattr(self.encoder.layer, "11").attention.self.value(encoder_layer_10_output_layer_norm)
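If stack_trace were used as a fallback, one possible approach is to take the assignment target on the last line of the trace as the node name. This is a sketch only; the parsing heuristic below is an assumption, not something this PR implements:

```python
import re

def name_from_stack_trace(stack_trace: str) -> str:
    """Hypothetical fallback: derive a readable name from node.meta['stack_trace']."""
    # The last non-empty line of the trace is typically the source expression
    # that produced the node, e.g. "foo_bar = getattr(...)(...)"
    lines = [ln.strip() for ln in stack_trace.splitlines() if ln.strip()]
    if not lines:
        return ""
    # Use the left-hand-side assignment target as the name, if there is one.
    match = re.match(r"(\w+)\s*=", lines[-1])
    return match.group(1) if match else ""

trace = (
    '  File "<eval_with_key>.0", line 491, in forward\n'
    '    encoder_layer_11_attention_self_value = '
    'getattr(self.encoder.layer, "11").attention.self.value(encoder_layer_10_output_layer_norm)\n'
)
print(name_from_stack_trace(trace))  # encoder_layer_11_attention_self_value
```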

Collaborator


Other BERT aten.mul nodes, however, do have a non-empty nn_module_stack, such as:

('getattr_L__self___encoder_layer___11___attention_self_value', ("getattr(L['self'].encoder.layer, '11').attention.self.value", <class 'torch.nn.modules.linear.Linear'>))

For these, the output formatted node name string looks great: '/encoder_layer/11/attention_self_value/mul_214'

@gs-olive
Collaborator

@alexmsettle - could you remove the changes to py/torch_tensorrt/fx/fx2trt.py and leave only those to py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py? Afterwards, I think it should be good to go!

@peri044 peri044 mentioned this pull request Aug 2, 2023
@peri044
Collaborator

peri044 commented Aug 2, 2023

Closing this in favor of #2162. Please post there for further questions.

@peri044 peri044 closed this Aug 2, 2023

Successfully merging this pull request may close these issues.

✨[Feature] Add detailed layer name to the NVTX markers in torch-TRT
5 participants