
feat(torch_tensorrt.dynamo.tools): Tool to calculate coverage of PyTorch #2166

Merged: 1 commit, Aug 4, 2023
19 changes: 14 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/converter_registry.py
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Sequence, Union
from typing import Any, Callable, Dict, Optional, Sequence, Union, List
from enum import Enum, auto

from torch.fx.node import Target, Node, _get_qualified_name
@@ -305,23 +305,32 @@ def unique_targets(self):
"""Returns the set of unique converter targets stored across all registries"""
return set.union(*[set(registry.keys()) for registry in self.registries])

# TODO: Make this a static method since it does not need state
def qualified_name_or_str(self, target: Target) -> str:
"""Returns string representation of an FX Node target"""
if isinstance(target, str):
return target
else:
return _get_qualified_name(target)

def display_all_available_converters(self) -> str:
"""Returns a string with all converters and their source, separated by newlines"""
available_converters = "Available converters in ATen registries with counts:\n"

def get_converter_support_info(self) -> Dict[str, Dict[str, int]]:
"""Returns a dictionary of targets backed by at least one converter"""
available_converters = {}
for target in sorted(
self.unique_targets(), key=lambda target: self.qualified_name_or_str(target)
):
_, registry_data = self.get_all_converters_with_target(
target, return_registry_info=True
)
available_converters[self.qualified_name_or_str(target)] = registry_data
return available_converters

def display_all_available_converters(self) -> str:
"""Returns a string with all converters and their source, separated by newlines"""
available_converters = "Available converters in ATen registries with counts:\n"

support_info = self.get_converter_support_info()
        for target, registry_data in support_info.items():
            available_converters += f"Node: {self.qualified_name_or_str(target)} - Registry Presence Counts: {registry_data}\n"

        return available_converters
Empty file.
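
For reference, a minimal usage sketch (not part of this diff) of the new registry API. It assumes DYNAMO_CONVERTERS is the module-level registry instance exported by torch_tensorrt.dynamo.conversion.converter_registry, as imported by the new tool below:

# Hypothetical usage sketch; DYNAMO_CONVERTERS is assumed to be the
# registry instance referenced elsewhere in this PR.
from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS

# Per-target registry presence counts, keyed by qualified target name.
support_info = DYNAMO_CONVERTERS.get_converter_support_info()
print(f"{len(support_info)} targets are backed by at least one converter")

# Human-readable summary built on top of the same data.
print(DYNAMO_CONVERTERS.display_all_available_converters())
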
204 changes: 204 additions & 0 deletions py/torch_tensorrt/dynamo/tools/opset_coverage.py
@@ -0,0 +1,204 @@
import dataclasses
import json
import os
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch._prims as prims
import torchgen
from torch._ops import OpOverload
from torch._dynamo.variables import BuiltinVariable
from torch_tensorrt.dynamo.conversion.converter_registry import (
    DYNAMO_CONVERTERS,
    ConverterRegistry,
)
from torch_tensorrt.dynamo.lowering import get_decompositions
from torchgen.gen import parse_native_yaml


class SupportStatus(Enum):
    CONVERTED = auto()
    LEGACY_CONVERTED = auto()
    LOWERED = auto()
    FALLBACK = auto()

    def __str__(self) -> str:
        return self.name


@dataclass
class OpsetCoverage:
    support_status: Dict[str, Dict[str, str]]
    dynamo_coverage: float
    legacy_coverage: float
    decomposition_coverage: float
    fallback_coverage: float


NATIVE_FUNCTION_YAML_PATH = (
    Path(os.path.dirname(torchgen.__file__))
    / "packaged/ATen/native/native_functions.yaml"
)
TAGS_YAML_PATH = (
    Path(os.path.dirname(torchgen.__file__)) / "packaged/ATen/native/tags.yaml"
)


def get_aten_ops() -> List[Tuple[str, str]]:
    parsed_yaml = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH)
    native_functions = parsed_yaml.native_functions

    aten_ops = OrderedDict()
    for function in native_functions:
        if "core" in function.tags:
            op_name = str(function.func.name)
            aten_ops[op_name] = function

    op_schema_pairs = []
    for key, op in sorted(aten_ops.items()):
        op_name = f"aten.{key}"
        schema = str(op.func).replace("*", r"\*")

        op_schema_pairs.append((op_name, schema))

    return op_schema_pairs


ATEN_OPS = get_aten_ops()


def get_prims_ops() -> List[Tuple[str, str]]:
    op_schema_pairs = []
    for op_name in prims.__all__:
        op_overload = getattr(prims, op_name, None)

        if not isinstance(op_overload, torch._ops.OpOverload):
            continue

        op_overloadpacket = op_overload.overloadpacket

        op_name = str(op_overload).replace(".default", "")
        schema = op_overloadpacket.schema.replace("*", r"\*")

        op_schema_pairs.append((op_name, schema))

    return op_schema_pairs


PRIM_OPS = get_prims_ops()


def get_overloaded_py_ops() -> List[Tuple[str, str]]:
    python_ops = BuiltinVariable._fx_graph_functions()
    op_schema_pairs = []
    for op in python_ops:
        name = op.__name__
        op_schema_pairs.append((f"_operator.{name}", ""))

    return op_schema_pairs


OVERLOADED_PY_OPS = get_overloaded_py_ops()


def opset_coverage(
    opset: List[Tuple[str, str]],
    converter_registry: Optional[ConverterRegistry] = None,
    decomposition_registry: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
) -> OpsetCoverage:

    opset_schemas = dict(opset)
    opset_targets = set(opset_schemas.keys())

    support_status = {}

    # TODO: Could be way less complicated if there is a way to convert from
    # strings to OpOverload
    c_registry = (
        converter_registry if converter_registry is not None else DYNAMO_CONVERTERS
    )
    converter_registry_targets = {
        c_registry.qualified_name_or_str(target).removeprefix("torch.ops.")
        for target in c_registry.keys()
    }
    supported_converted_targets = opset_targets.intersection(converter_registry_targets)
    support_count = 0
    legacy_count = 0
    for target in c_registry.keys():
        target_str = c_registry.qualified_name_or_str(target).removeprefix("torch.ops.")
        if target_str in opset_targets:
            _, registry_data = c_registry.get_all_converters_with_target(
                target, return_registry_info=True
            )
            if registry_data["Dynamo ATen Converters Registry"] >= 1:
                status = SupportStatus.CONVERTED
                support_count += 1
            elif registry_data["FX ATen Converters Registry"] >= 1:
                status = SupportStatus.LEGACY_CONVERTED
                legacy_count += 1

            support_status[target_str] = {
                "schema": f"{target_str.split('.')[0]}.{opset_schemas[target_str]}",
                "status": str(status),
            }

    l_registry = (
        decomposition_registry
        if decomposition_registry is not None
        else get_decompositions()
    )
    decomp_registry_targets = {
        c_registry.qualified_name_or_str(target).removeprefix("torch.ops.")
        for target in l_registry.keys()
    }
    supported_decomp_targets = opset_targets.intersection(decomp_registry_targets)
    decomposition_count = len(supported_decomp_targets)
    for target in supported_decomp_targets:
        support_status[target] = {
            "schema": f"{target.split('.')[0]}.{opset_schemas[target]}",
            "status": str(SupportStatus.LOWERED),
        }

    unsupported_targets = opset_targets.difference(
        supported_converted_targets.union(supported_decomp_targets)
    )
    unsupported_count = len(unsupported_targets)
    for target in unsupported_targets:
        support_status[target] = {
            "schema": f"{target.split('.')[0]}.{opset_schemas[target]}",
            "status": str(SupportStatus.FALLBACK),
        }

    return OpsetCoverage(
        support_status,
        dynamo_coverage=support_count / len(opset),
        legacy_coverage=legacy_count / len(opset),
        decomposition_coverage=decomposition_count / len(opset),
        fallback_coverage=unsupported_count / len(opset),
    )


if __name__ == "__main__":

    def find_coverage_status(opset: List[Tuple[str, str]], name: str) -> None:
        coverage = opset_coverage(opset)
        print(f"{name}:")
        print(f" - Dynamo converters: {coverage.dynamo_coverage:.2%}")
        print(f" - Decomposed: {coverage.decomposition_coverage:.2%}")
        print(f" - Legacy FX converters: {coverage.legacy_coverage:.2%}")
        print(f" - Ops to fallback to Torch: {coverage.fallback_coverage:.2%}")
        print(
            f"Per op coverage status saved to /tmp/{name.lower()}_coverage_status.json"
        )

        with open(f"/tmp/{name.lower()}_coverage_status.json", "w") as f:
            json.dump(dataclasses.asdict(coverage), f)

    print("-------- OPERATOR SET COVERAGE --------")
    find_coverage_status(ATEN_OPS, "ATen")
    find_coverage_status(PRIM_OPS, "prim")
    find_coverage_status(OVERLOADED_PY_OPS, "py_overload")
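
As a usage sketch (not part of this diff), the tool could also be driven programmatically. The module path torch_tensorrt.dynamo.tools.opset_coverage is assumed from the package entries added to setup.py below:

# Hypothetical programmatic use of the coverage tool.
from torch_tensorrt.dynamo.tools.opset_coverage import ATEN_OPS, opset_coverage

coverage = opset_coverage(ATEN_OPS)
print(f"Dynamo converter coverage: {coverage.dynamo_coverage:.2%}")

# Targets with neither a converter nor a decomposition fall back to Torch.
fallback_ops = [
    op
    for op, info in coverage.support_status.items()
    if info["status"] == "FALLBACK"
]
print(f"{len(fallback_ops)} core ATen ops fall back to Torch")
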
2 changes: 2 additions & 0 deletions setup.py
@@ -350,6 +350,7 @@ def run(self):
"torch_tensorrt.dynamo.lowering",
"torch_tensorrt.dynamo.lowering.substitutions",
"torch_tensorrt.dynamo.runtime",
"torch_tensorrt.dynamo.tools",
"torch_tensorrt.fx",
"torch_tensorrt.fx.converters",
"torch_tensorrt.fx.converters.impl",
@@ -374,6 +375,7 @@ def run(self):
"torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering",
"torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions",
"torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime",
"torch_tensorrt.dynamo.tools": "py/torch_tensorrt/dynamo/tools",
"torch_tensorrt.fx": "py/torch_tensorrt/fx",
"torch_tensorrt.fx.converters": "py/torch_tensorrt/fx/converters",
"torch_tensorrt.fx.converters.impl": "py/torch_tensorrt/fx/converters/impl",