5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import itertools
8
- import random
9
- from typing import Any , Dict , Iterator , List , Optional , Tuple
8
+ from typing import Any , Dict , List , Optional , Tuple
10
9
11
10
import torch
12
11
import torch .testing ._internal .common_dtype as common_dtype
13
- from executorch .exir .dialects .edge .arg .model import ArgMode , BaseArg , BaseKwarg , GenMode
12
+ from executorch .exir .dialects .edge .arg .model import ArgMode , BaseArg , BaseKwarg
14
13
from executorch .exir .dialects .edge .arg .type import ArgType
15
14
from executorch .exir .dialects .edge .dtype .utils import extract_return_dtype
16
15
from executorch .exir .dialects .edge .op .api import get_callable
@@ -48,43 +47,6 @@ def _get_args_kwargs(
48
47
args .append (val )
49
48
return args , kwargs
50
49
51
- @staticmethod
52
- def _produce_dtype_tuple (
53
- types : List [ArgType ],
54
- code_tuple : Tuple [int ],
55
- ty : ArgType ,
56
- dt : Optional [torch .dtype ],
57
- ) -> Optional [Tuple [Optional [torch .dtype ]]]:
58
- dtype_tuple = []
59
- for i , code in enumerate (code_tuple ):
60
- same_group = [dt ]
61
- if ty .is_scalar () and types [i ].is_tensor ():
62
- if dt == torch .bool or dt == torch .float :
63
- same_group = list (common_dtype .floating_types ())
64
- elif dt == torch .int :
65
- same_group = list (common_dtype .integral_types ())
66
- else :
67
- same_group = [None ]
68
- elif ty .is_tensor () and types [i ].is_scalar ():
69
- if dt == torch .bool :
70
- same_group = [torch .bool ]
71
- elif dt in common_dtype .integral_types ():
72
- same_group = [torch .int ]
73
- elif dt in common_dtype .floating_types ():
74
- same_group = [torch .float ]
75
- else :
76
- same_group = [None ]
77
-
78
- if code == 0 :
79
- if dt is None and not types [i ].is_optional ():
80
- return
81
- dtype_tuple .append (random .choice (same_group ))
82
- else :
83
- all_types = common_dtype .all_types_and (torch .bool )
84
- diff_group = list (set (all_types ) - set (same_group ))
85
- dtype_tuple .append (random .choice (diff_group ))
86
- return tuple (dtype_tuple )
87
-
88
50
def _get_type_tuples (
89
51
self , inputs : Dict [str , List [BaseArg ]]
90
52
) -> List [List [Optional [torch .dtype ]]]:
@@ -103,36 +65,6 @@ def mapping(t):
103
65
104
66
return list (map (mapping , types ))
105
67
106
- def select_dtype_combinations (
107
- self , inputs : Dict [str , List [BaseArg ]], genmode : GenMode
108
- ) -> Iterator [Tuple [Optional [torch .dtype ]]]:
109
- random .seed (0 )
110
-
111
- def produce_code_tuples (n : int , i : int ) -> Iterator [Tuple [int ]]:
112
- codes = [(0 ,) if j == i else (0 , 1 ) for j in range (n )]
113
- return itertools .product (* codes )
114
-
115
- type_tuples = self ._get_type_tuples (inputs )
116
- if genmode == GenMode .All :
117
- for dtype_tuple in itertools .product (* type_tuples ):
118
- yield dtype_tuple
119
- elif genmode == GenMode .Partial :
120
- dtype_tuples_set = set ()
121
- types = DtypeRunner ._get_types (inputs )
122
- n = len (types )
123
- for i in range (n ):
124
- for dt in type_tuples [i ]:
125
- for code_tuple in produce_code_tuples (n , i ):
126
- dtype_tuple = DtypeRunner ._produce_dtype_tuple (
127
- types , code_tuple , types [i ], dt
128
- )
129
- if (
130
- dtype_tuple is not None
131
- and dtype_tuple not in dtype_tuples_set
132
- ):
133
- yield dtype_tuple
134
- dtype_tuples_set .add (dtype_tuple )
135
-
136
68
def run_dtypes (
137
69
self ,
138
70
name : str ,
0 commit comments