From 5b16c6541a59c83f93d720f31a7673c9bc0855f8 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 4 Oct 2023 15:55:51 -0700 Subject: [PATCH] Move GenMode back to OpInput (#608) Summary: The purpose of this diff stack is to clean up the OpInput code that lives inside exir/dialects/edge, leaving only what is needed to search allowed dtypes per op. Code that is only used for test generation is removed from exir/dialects/edge Reviewed By: larryliu0820 Differential Revision: D49891111 --- exir/dialects/edge/arg/model.py | 15 ------- exir/dialects/edge/dtype/runner.py | 72 +----------------------------- 2 files changed, 2 insertions(+), 85 deletions(-) diff --git a/exir/dialects/edge/arg/model.py b/exir/dialects/edge/arg/model.py index d2acb8c0857..16f54963283 100644 --- a/exir/dialects/edge/arg/model.py +++ b/exir/dialects/edge/arg/model.py @@ -13,21 +13,6 @@ from executorch.exir.dialects.edge.arg.type import ArgType -class GenMode(Enum): - """Whether to generate all dtype combinations or, A partial combination. - The definition of partial combination is the following: - Each operator has a set of N arguments, we loop through the dtypes of one - of the arguments, then define a subset S of the remaining argument. For - arguments within S, let their dtypes be the same of the chose argument; for - arguments outside of S, randomly choose a dtype for it.""" - - All = "All" - Partial = "Partial" - - def __str__(self): - return self.value - - class ArgMode(Enum): DEFAULT = 0 ONES = 1 diff --git a/exir/dialects/edge/dtype/runner.py b/exir/dialects/edge/dtype/runner.py index d64f9620b4c..41cad201ac5 100644 --- a/exir/dialects/edge/dtype/runner.py +++ b/exir/dialects/edge/dtype/runner.py @@ -5,12 +5,11 @@ # LICENSE file in the root directory of this source tree. import itertools -import random -from typing import Any, Dict, Iterator, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.testing._internal.common_dtype as common_dtype -from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg, GenMode +from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg from executorch.exir.dialects.edge.arg.type import ArgType from executorch.exir.dialects.edge.dtype.utils import extract_return_dtype from executorch.exir.dialects.edge.op.api import get_callable @@ -48,43 +47,6 @@ def _get_args_kwargs( args.append(val) return args, kwargs - @staticmethod - def _produce_dtype_tuple( - types: List[ArgType], - code_tuple: Tuple[int], - ty: ArgType, - dt: Optional[torch.dtype], - ) -> Optional[Tuple[Optional[torch.dtype]]]: - dtype_tuple = [] - for i, code in enumerate(code_tuple): - same_group = [dt] - if ty.is_scalar() and types[i].is_tensor(): - if dt == torch.bool or dt == torch.float: - same_group = list(common_dtype.floating_types()) - elif dt == torch.int: - same_group = list(common_dtype.integral_types()) - else: - same_group = [None] - elif ty.is_tensor() and types[i].is_scalar(): - if dt == torch.bool: - same_group = [torch.bool] - elif dt in common_dtype.integral_types(): - same_group = [torch.int] - elif dt in common_dtype.floating_types(): - same_group = [torch.float] - else: - same_group = [None] - - if code == 0: - if dt is None and not types[i].is_optional(): - return - dtype_tuple.append(random.choice(same_group)) - else: - all_types = common_dtype.all_types_and(torch.bool) - diff_group = list(set(all_types) - set(same_group)) - dtype_tuple.append(random.choice(diff_group)) - return tuple(dtype_tuple) - def _get_type_tuples( self, inputs: Dict[str, List[BaseArg]] ) -> List[List[Optional[torch.dtype]]]: @@ -103,36 +65,6 @@ def mapping(t): return list(map(mapping, types)) - def select_dtype_combinations( - self, inputs: Dict[str, List[BaseArg]], genmode: GenMode - ) -> Iterator[Tuple[Optional[torch.dtype]]]: - random.seed(0) - - def produce_code_tuples(n: int, i: int) -> Iterator[Tuple[int]]: - codes = [(0,) if j == i else (0, 1) for j in range(n)] - return itertools.product(*codes) - - type_tuples = self._get_type_tuples(inputs) - if genmode == GenMode.All: - for dtype_tuple in itertools.product(*type_tuples): - yield dtype_tuple - elif genmode == GenMode.Partial: - dtype_tuples_set = set() - types = DtypeRunner._get_types(inputs) - n = len(types) - for i in range(n): - for dt in type_tuples[i]: - for code_tuple in produce_code_tuples(n, i): - dtype_tuple = DtypeRunner._produce_dtype_tuple( - types, code_tuple, types[i], dt - ) - if ( - dtype_tuple is not None - and dtype_tuple not in dtype_tuples_set - ): - yield dtype_tuple - dtype_tuples_set.add(dtype_tuple) - def run_dtypes( self, name: str,