Skip to content

Commit 6e0bb2b

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Move GenMode back to OpInput
Summary: The purpose of this diff stack is to clean up the OpInput code that lives inside exir/dialects/edge, leaving only what is needed to search allowed dtypes per op. Code that is only used for test generation is removed from exir/dialects/edge Reviewed By: larryliu0820 Differential Revision: D49891111 fbshipit-source-id: ef051e470067484513d021b87234d57ffc41ba04
1 parent a4c1f7e commit 6e0bb2b

File tree

2 files changed

+2
-85
lines changed

2 files changed

+2
-85
lines changed

exir/dialects/edge/arg/model.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,6 @@
1313
from executorch.exir.dialects.edge.arg.type import ArgType
1414

1515

16-
class GenMode(Enum):
17-
"""Whether to generate all dtype combinations or, A partial combination.
18-
The definition of partial combination is the following:
19-
Each operator has a set of N arguments, we loop through the dtypes of one
20-
of the arguments, then define a subset S of the remaining argument. For
21-
arguments within S, let their dtypes be the same of the chose argument; for
22-
arguments outside of S, randomly choose a dtype for it."""
23-
24-
All = "All"
25-
Partial = "Partial"
26-
27-
def __str__(self):
28-
return self.value
29-
30-
3116
class ArgMode(Enum):
3217
DEFAULT = 0
3318
ONES = 1

exir/dialects/edge/dtype/runner.py

Lines changed: 2 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import itertools
8-
import random
9-
from typing import Any, Dict, Iterator, List, Optional, Tuple
8+
from typing import Any, Dict, List, Optional, Tuple
109

1110
import torch
1211
import torch.testing._internal.common_dtype as common_dtype
13-
from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg, GenMode
12+
from executorch.exir.dialects.edge.arg.model import ArgMode, BaseArg, BaseKwarg
1413
from executorch.exir.dialects.edge.arg.type import ArgType
1514
from executorch.exir.dialects.edge.dtype.utils import extract_return_dtype
1615
from executorch.exir.dialects.edge.op.api import get_callable
@@ -48,43 +47,6 @@ def _get_args_kwargs(
4847
args.append(val)
4948
return args, kwargs
5049

51-
@staticmethod
52-
def _produce_dtype_tuple(
53-
types: List[ArgType],
54-
code_tuple: Tuple[int],
55-
ty: ArgType,
56-
dt: Optional[torch.dtype],
57-
) -> Optional[Tuple[Optional[torch.dtype]]]:
58-
dtype_tuple = []
59-
for i, code in enumerate(code_tuple):
60-
same_group = [dt]
61-
if ty.is_scalar() and types[i].is_tensor():
62-
if dt == torch.bool or dt == torch.float:
63-
same_group = list(common_dtype.floating_types())
64-
elif dt == torch.int:
65-
same_group = list(common_dtype.integral_types())
66-
else:
67-
same_group = [None]
68-
elif ty.is_tensor() and types[i].is_scalar():
69-
if dt == torch.bool:
70-
same_group = [torch.bool]
71-
elif dt in common_dtype.integral_types():
72-
same_group = [torch.int]
73-
elif dt in common_dtype.floating_types():
74-
same_group = [torch.float]
75-
else:
76-
same_group = [None]
77-
78-
if code == 0:
79-
if dt is None and not types[i].is_optional():
80-
return
81-
dtype_tuple.append(random.choice(same_group))
82-
else:
83-
all_types = common_dtype.all_types_and(torch.bool)
84-
diff_group = list(set(all_types) - set(same_group))
85-
dtype_tuple.append(random.choice(diff_group))
86-
return tuple(dtype_tuple)
87-
8850
def _get_type_tuples(
8951
self, inputs: Dict[str, List[BaseArg]]
9052
) -> List[List[Optional[torch.dtype]]]:
@@ -103,36 +65,6 @@ def mapping(t):
10365

10466
return list(map(mapping, types))
10567

106-
def select_dtype_combinations(
107-
self, inputs: Dict[str, List[BaseArg]], genmode: GenMode
108-
) -> Iterator[Tuple[Optional[torch.dtype]]]:
109-
random.seed(0)
110-
111-
def produce_code_tuples(n: int, i: int) -> Iterator[Tuple[int]]:
112-
codes = [(0,) if j == i else (0, 1) for j in range(n)]
113-
return itertools.product(*codes)
114-
115-
type_tuples = self._get_type_tuples(inputs)
116-
if genmode == GenMode.All:
117-
for dtype_tuple in itertools.product(*type_tuples):
118-
yield dtype_tuple
119-
elif genmode == GenMode.Partial:
120-
dtype_tuples_set = set()
121-
types = DtypeRunner._get_types(inputs)
122-
n = len(types)
123-
for i in range(n):
124-
for dt in type_tuples[i]:
125-
for code_tuple in produce_code_tuples(n, i):
126-
dtype_tuple = DtypeRunner._produce_dtype_tuple(
127-
types, code_tuple, types[i], dt
128-
)
129-
if (
130-
dtype_tuple is not None
131-
and dtype_tuple not in dtype_tuples_set
132-
):
133-
yield dtype_tuple
134-
dtype_tuples_set.add(dtype_tuple)
135-
13668
def run_dtypes(
13769
self,
13870
name: str,

0 commit comments

Comments
 (0)