diff --git a/devtools/backend_debug/__init__.py b/devtools/backend_debug/__init__.py
index b457b7d11d5..56975e80713 100644
--- a/devtools/backend_debug/__init__.py
+++ b/devtools/backend_debug/__init__.py
@@ -7,6 +7,7 @@
 from executorch.devtools.backend_debug.delegation_info import (
     DelegationBreakdown,
     get_delegation_info,
+    print_delegation_info,
 )
 
-__all__ = ["DelegationBreakdown", "get_delegation_info"]
+__all__ = ["DelegationBreakdown", "get_delegation_info", "print_delegation_info"]
diff --git a/devtools/backend_debug/delegation_info.py b/devtools/backend_debug/delegation_info.py
index b237d162f7a..41c7f8d8e7d 100644
--- a/devtools/backend_debug/delegation_info.py
+++ b/devtools/backend_debug/delegation_info.py
@@ -11,7 +11,7 @@
 import pandas as pd
 
 import torch
-
+from tabulate import tabulate
 
 # Column names of the DataFrame returned by DelegationInfo.get_operator_delegation_dataframe()
 # which describes the summarized delegation information grouped by each operator type
@@ -174,3 +174,10 @@ def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
         num_delegated_subgraphs=delegated_subgraph_counter,
         delegation_by_operator=op_occurrences_dict,
     )
+
+
+def print_delegation_info(graph_module: torch.fx.GraphModule):
+    delegation_info = get_delegation_info(graph_module)
+    print(delegation_info.get_summary())
+    df = delegation_info.get_operator_delegation_dataframe()
+    print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 618c74e8706..4ad92903534 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -23,7 +23,7 @@
 import torch
 
 from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
-from executorch.devtools.backend_debug import get_delegation_info
+from executorch.devtools.backend_debug import print_delegation_info
 from executorch.devtools.etrecord import generate_etrecord
 from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 
@@ -46,7 +46,6 @@
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
-from tabulate import tabulate
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -801,12 +800,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    def print_delegation_info(graph_module: torch.fx.GraphModule):
-        delegation_info = get_delegation_info(graph_module)
-        print(delegation_info.get_summary())
-        df = delegation_info.get_operator_delegation_dataframe()
-        print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
-
     additional_passes = []
     if args.model in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
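
For reference, a minimal usage sketch of the newly exported print_delegation_info helper. The toy module and the XNNPACK partitioner below are illustrative assumptions, not part of this change; any lowered EdgeProgramManager's graph module could be passed in the same way.

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.devtools.backend_debug import print_delegation_info
from executorch.exir import to_edge


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


# Export a toy model and lower it to a backend (XNNPACK chosen here only as an example).
example_inputs = (torch.randn(4), torch.randn(4))
edge = to_edge(torch.export.export(Add(), example_inputs))
edge = edge.to_backend(XnnpackPartitioner())

# Prints the delegation summary plus the per-operator DataFrame rendered with tabulate.
print_delegation_info(edge.exported_program().graph_module)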