diff --git a/py/torch_tensorrt/dynamo/tools/opset_coverage.py b/py/torch_tensorrt/dynamo/tools/opset_coverage.py index 9244f06f19..977d8db875 100644 --- a/py/torch_tensorrt/dynamo/tools/opset_coverage.py +++ b/py/torch_tensorrt/dynamo/tools/opset_coverage.py @@ -48,6 +48,10 @@ class OpsetCoverage: Path(os.path.dirname(torchgen.__file__)) / "packaged/ATen/native/tags.yaml" ) +DYNAMO_REGISTRY_NAME = "Dynamo ATen Converters Registry" +FX_REGISTRY_NAME = "FX ATen Converters Registry" +FX_LEGACY_REGISTRY_NAME = "FX Legacy ATen Converters Registry" + def get_aten_ops() -> List[Tuple[str, str]]: parsed_yaml = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH) @@ -140,13 +144,25 @@ def opset_coverage( _, registry_data = c_registry.get_all_converters_with_target( target, return_registry_info=True ) + if registry_data is not None: - if registry_data["Dynamo ATen Converters Registry"] >= 1: + if ( + DYNAMO_REGISTRY_NAME in registry_data + and registry_data[DYNAMO_REGISTRY_NAME] >= 1 + ): status = SupportStatus.CONVERTED support_count += 1 - elif registry_data["FX ATen Converters Registry"] >= 1: + elif ( + FX_REGISTRY_NAME in registry_data + and registry_data[FX_REGISTRY_NAME] >= 1 + ) or ( + FX_LEGACY_REGISTRY_NAME in registry_data + and registry_data[FX_LEGACY_REGISTRY_NAME] >= 1 + ): status = SupportStatus.LEGACY_CONVERTED legacy_count += 1 + else: + raise Exception(f"Op belongs to unknown registry: {registry_data}") support_status[target_str] = { "schema": f"{target_str.split('.')[0]}.{opset_schemas[target_str]}",