Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions exir/dialects/edge/arg/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,6 @@
from executorch.exir.dialects.edge.arg.type import ArgType


class GenMode(Enum):
"""Whether to generate all dtype combinations or, A partial combination.
The definition of partial combination is the following:
Each operator has a set of N arguments, we loop through the dtypes of one
of the arguments, then define a subset S of the remaining argument. For
arguments within S, let their dtypes be the same of the chose argument; for
arguments outside of S, randomly choose a dtype for it."""

All = "All"
Partial = "Partial"

def __str__(self):
return self.value


class ArgMode(Enum):
DEFAULT = 0
ONES = 1
Expand Down
72 changes: 2 additions & 70 deletions exir/dialects/edge/dtype/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
# LICENSE file in the root directory of this source tree.

import itertools
import random
from typing import Any, Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.testing._internal.common_dtype as common_dtype
from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg, GenMode
from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg
from executorch.exir.dialects.edge.arg.type import ArgType
from executorch.exir.dialects.edge.dtype.utils import extract_return_dtype
from executorch.exir.dialects.edge.op.api import get_callable
Expand Down Expand Up @@ -48,43 +47,6 @@ def _get_args_kwargs(
args.append(val)
return args, kwargs

@staticmethod
def _produce_dtype_tuple(
types: List[ArgType],
code_tuple: Tuple[int],
ty: ArgType,
dt: Optional[torch.dtype],
) -> Optional[Tuple[Optional[torch.dtype]]]:
dtype_tuple = []
for i, code in enumerate(code_tuple):
same_group = [dt]
if ty.is_scalar() and types[i].is_tensor():
if dt == torch.bool or dt == torch.float:
same_group = list(common_dtype.floating_types())
elif dt == torch.int:
same_group = list(common_dtype.integral_types())
else:
same_group = [None]
elif ty.is_tensor() and types[i].is_scalar():
if dt == torch.bool:
same_group = [torch.bool]
elif dt in common_dtype.integral_types():
same_group = [torch.int]
elif dt in common_dtype.floating_types():
same_group = [torch.float]
else:
same_group = [None]

if code == 0:
if dt is None and not types[i].is_optional():
return
dtype_tuple.append(random.choice(same_group))
else:
all_types = common_dtype.all_types_and(torch.bool)
diff_group = list(set(all_types) - set(same_group))
dtype_tuple.append(random.choice(diff_group))
return tuple(dtype_tuple)

def _get_type_tuples(
self, inputs: Dict[str, List[BaseArg]]
) -> List[List[Optional[torch.dtype]]]:
Expand All @@ -103,36 +65,6 @@ def mapping(t):

return list(map(mapping, types))

def select_dtype_combinations(
self, inputs: Dict[str, List[BaseArg]], genmode: GenMode
) -> Iterator[Tuple[Optional[torch.dtype]]]:
random.seed(0)

def produce_code_tuples(n: int, i: int) -> Iterator[Tuple[int]]:
codes = [(0,) if j == i else (0, 1) for j in range(n)]
return itertools.product(*codes)

type_tuples = self._get_type_tuples(inputs)
if genmode == GenMode.All:
for dtype_tuple in itertools.product(*type_tuples):
yield dtype_tuple
elif genmode == GenMode.Partial:
dtype_tuples_set = set()
types = DtypeRunner._get_types(inputs)
n = len(types)
for i in range(n):
for dt in type_tuples[i]:
for code_tuple in produce_code_tuples(n, i):
dtype_tuple = DtypeRunner._produce_dtype_tuple(
types, code_tuple, types[i], dt
)
if (
dtype_tuple is not None
and dtype_tuple not in dtype_tuples_set
):
yield dtype_tuple
dtype_tuples_set.add(dtype_tuple)

def run_dtypes(
self,
name: str,
Expand Down