From e8966d7b09a54515bdf8c436c8e575b3c2fa2185 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 2 Aug 2023 21:34:05 -0700 Subject: [PATCH] feat(torch_tensorrt.dynamo.tools): Tool to calculate coverage of PyTorch opsets Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .../dynamo/conversion/converter_registry.py | 19 +- py/torch_tensorrt/dynamo/tools/__init__.py | 0 .../dynamo/tools/opset_coverage.py | 204 ++++++++++++++++++ setup.py | 2 + 4 files changed, 220 insertions(+), 5 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/tools/__init__.py create mode 100644 py/torch_tensorrt/dynamo/tools/opset_coverage.py diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py index e29e5b8437..9bdfc9bf05 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_registry.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Dict, Optional, Sequence, Union, List from enum import Enum, auto from torch.fx.node import Target, Node, _get_qualified_name @@ -305,6 +305,7 @@ def unique_targets(self): """Returns the set of unique converter targets stored across all registries""" return set.union(*[set(registry.keys()) for registry in self.registries]) + # TODO: Make this a static method since it does not need state def qualified_name_or_str(self, target: Target) -> str: """Returns string representation of an FX Node target""" if isinstance(target, str): @@ -312,16 +313,24 @@ def qualified_name_or_str(self, target: Target) -> str: else: return _get_qualified_name(target) - def display_all_available_converters(self) -> str: - """Returns a string with all converters and their source, separated by newlines""" - available_converters = "Available converters in ATen registries with counts:\n" - + def 
get_converter_support_info(self) -> Dict[str, Dict[str, int]]: + """Returns a dictionary of targets backed by at least one converter""" + available_converters = {} for target in sorted( self.unique_targets(), key=lambda target: self.qualified_name_or_str(target) ): _, registry_data = self.get_all_converters_with_target( target, return_registry_info=True ) + available_converters[self.qualified_name_or_str(target)] = registry_data + return available_converters + + def display_all_available_converters(self) -> str: + """Returns a string with all converters and their source, separated by newlines""" + available_converters = "Available converters in ATen registries with counts:\n" + + support_info = self.get_converter_support_info() + for target, registry_data in support_info.items(): available_converters += f"Node: {self.qualified_name_or_str(target)} - Registry Presence Counts: {registry_data}\n" return available_converters diff --git a/py/torch_tensorrt/dynamo/tools/__init__.py b/py/torch_tensorrt/dynamo/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/dynamo/tools/opset_coverage.py b/py/torch_tensorrt/dynamo/tools/opset_coverage.py new file mode 100644 index 0000000000..a46236c408 --- /dev/null +++ b/py/torch_tensorrt/dynamo/tools/opset_coverage.py @@ -0,0 +1,204 @@ +import dataclasses +import json +import os +from collections import OrderedDict +from dataclasses import dataclass +from enum import Enum, auto +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch._prims as prims +import torchgen +from torch._ops import OpOverload +from torch._dynamo.variables import BuiltinVariable +from torch_tensorrt.dynamo.conversion.converter_registry import ( + DYNAMO_CONVERTERS, + ConverterRegistry, +) +from torch_tensorrt.dynamo.lowering import get_decompositions +from torchgen.gen import parse_native_yaml + + +class SupportStatus(Enum): + CONVERTED = auto() + 
LEGACY_CONVERTED = auto() + LOWERED = auto() + FALLBACK = auto() + + def __str__(self) -> str: + return self.name + + +@dataclass +class OpsetCoverage: + support_status: Dict[str, Dict[str, str]] + dynamo_coverage: float + legacy_coverage: float + decomposition_coverage: float + fallback_coverage: float + + +NATIVE_FUNCTION_YAML_PATH = ( + Path(os.path.dirname(torchgen.__file__)) + / "packaged/ATen/native/native_functions.yaml" +) +TAGS_YAML_PATH = ( + Path(os.path.dirname(torchgen.__file__)) / "packaged/ATen/native/tags.yaml" +) + + +def get_aten_ops() -> List[Tuple[str, str]]: + parsed_yaml = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH) + native_functions = parsed_yaml.native_functions + + aten_ops = OrderedDict() + for function in native_functions: + if "core" in function.tags: + op_name = str(function.func.name) + aten_ops[op_name] = function + + op_schema_pairs = [] + for key, op in sorted(aten_ops.items()): + op_name = f"aten.{key}" + schema = str(op.func).replace("*", r"\*") + + op_schema_pairs.append((op_name, schema)) + + return op_schema_pairs + + +ATEN_OPS = get_aten_ops() + + +def get_prims_ops() -> List[Tuple[str, str]]: + op_schema_pairs = [] + for op_name in prims.__all__: + op_overload = getattr(prims, op_name, None) + + if not isinstance(op_overload, torch._ops.OpOverload): + continue + + op_overloadpacket = op_overload.overloadpacket + + op_name = str(op_overload).replace(".default", "") + schema = op_overloadpacket.schema.replace("*", r"\*") + + op_schema_pairs.append((op_name, schema)) + + return op_schema_pairs + + +PRIM_OPS = get_prims_ops() + + +def get_overloaded_py_ops() -> List[Tuple[str, str]]: + python_ops = BuiltinVariable._fx_graph_functions() + op_schema_pairs = [] + for op in python_ops: + name = op.__name__ + op_schema_pairs.append((f"_operator.{name}", "")) + + return op_schema_pairs + + +OVERLOADED_PY_OPS = get_overloaded_py_ops() + + +def opset_coverage( + opset: List[Tuple[str, str]], + converter_registry: 
Optional[ConverterRegistry] = None, + decomposition_registry: Optional[Dict[OpOverload, Callable[..., Any]]] = None, +) -> OpsetCoverage: + + opset_schemas = dict(opset) + opset_targets = set(opset_schemas.keys()) + + support_status = {} + + # TODO: Could be way less complicated if there is a way to convert from + # strings to OpOverload + c_registry = ( + converter_registry if converter_registry is not None else DYNAMO_CONVERTERS + ) + converter_registry_targets = { + c_registry.qualified_name_or_str(target).removeprefix("torch.ops.") + for target in c_registry.keys() + } + supported_converted_targets = opset_targets.intersection(converter_registry_targets) + support_count = 0 + legacy_count = 0 + for target in c_registry.keys(): + target_str = c_registry.qualified_name_or_str(target).removeprefix("torch.ops.") + if target_str in opset_targets: + _, registry_data = c_registry.get_all_converters_with_target( + target, return_registry_info=True + ) + if registry_data["Dynamo ATen Converters Registry"] >= 1: + status = SupportStatus.CONVERTED + support_count += 1 + elif registry_data["FX ATen Converters Registry"] >= 1: + status = SupportStatus.LEGACY_CONVERTED + legacy_count += 1 + + support_status[target_str] = { + "schema": f"{target_str.split('.')[0]}.{opset_schemas[target_str]}", + "status": str(status), + } + + l_registry = ( + decomposition_registry + if decomposition_registry is not None + else get_decompositions() + ) + decomp_registry_targets = { + c_registry.qualified_name_or_str(target).removeprefix("torch.ops.") + for target in l_registry.keys() + } + supported_decomp_targets = opset_targets.intersection(decomp_registry_targets) + decomposition_count = len(supported_decomp_targets) + for target in supported_decomp_targets: + support_status[target] = { + "schema": f"{target.split('.')[0]}.{opset_schemas[target]}", + "status": str(SupportStatus.LOWERED), + } + + unsupported_targets = opset_targets.difference( + 
supported_converted_targets.union(supported_decomp_targets) + ) + unsupported_count = len(unsupported_targets) + for target in unsupported_targets: + support_status[target] = { + "schema": f"{target.split('.')[0]}.{opset_schemas[target]}", + "status": str(SupportStatus.FALLBACK), + } + + return OpsetCoverage( + support_status, + dynamo_coverage=support_count / len(opset), + legacy_coverage=legacy_count / len(opset), + decomposition_coverage=decomposition_count / len(opset), + fallback_coverage=unsupported_count / len(opset), + ) + + +if __name__ == "__main__": + + def find_coverage_status(opset: List[Tuple[str, str]], name: str) -> None: + coverage = opset_coverage(opset) + print(f"{name}:") + print(f" - Dynamo converters: {coverage.dynamo_coverage:.2%}") + print(f" - Decomposed: {coverage.decomposition_coverage:.2%}") + print(f" - Legacy FX converters: {coverage.legacy_coverage:.2%}") + print(f" - Ops to fallback to Torch: {coverage.fallback_coverage:.2%}") + print( + f"Per op coverage status saved to /tmp/{name.lower()}_coverage_status.json" + ) + + with open(f"/tmp/{name.lower()}_coverage_status.json", "w") as f: + json.dump(dataclasses.asdict(coverage), f) + + print("-------- OPERATOR SET COVERAGE --------") + find_coverage_status(ATEN_OPS, "ATen") + find_coverage_status(PRIM_OPS, "prim") + find_coverage_status(OVERLOADED_PY_OPS, "py_overload") diff --git a/setup.py b/setup.py index 8e5c5330b7..7544ea8eb5 100644 --- a/setup.py +++ b/setup.py @@ -350,6 +350,7 @@ def run(self): "torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.substitutions", "torch_tensorrt.dynamo.runtime", + "torch_tensorrt.dynamo.tools", "torch_tensorrt.fx", "torch_tensorrt.fx.converters", "torch_tensorrt.fx.converters.impl", @@ -374,6 +375,7 @@ def run(self): "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering", "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions", "torch_tensorrt.dynamo.runtime": 
"py/torch_tensorrt/dynamo/runtime", + "torch_tensorrt.dynamo.tools": "py/torch_tensorrt/dynamo/tools", "torch_tensorrt.fx": "py/torch_tensorrt/fx", "torch_tensorrt.fx.converters": "py/torch_tensorrt/fx/converters", "torch_tensorrt.fx.converters.impl": "py/torch_tensorrt/fx/converters/impl",