@@ -129,6 +129,42 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkR
129
129
)
130
130
131
131
132
+ def benchmark_qec_unsharded (
133
+ args : argparse .Namespace , output_dir : str
134
+ ) -> List [BenchmarkResult ]:
135
+ tables = get_tables (TABLE_SIZES , is_pooled = False )
136
+ sharder = TestQuantECSharder (
137
+ sharding_type = "" ,
138
+ kernel_type = EmbeddingComputeKernel .QUANT .value ,
139
+ shardable_params = [table .name for table in tables ],
140
+ )
141
+
142
+ module = QuantEmbeddingCollection (
143
+ # pyre-ignore [6]
144
+ tables = tables ,
145
+ device = torch .device ("cpu" ),
146
+ quant_state_dict_split_scale_bias = True ,
147
+ )
148
+
149
+ args_kwargs = {
150
+ argname : getattr (args , argname )
151
+ for argname in dir (args )
152
+ # Don't include output_dir since output_dir was modified
153
+ if not argname .startswith ("_" ) and argname not in IGNORE_ARGNAME
154
+ }
155
+
156
+ return benchmark_module (
157
+ module = module ,
158
+ sharder = sharder ,
159
+ sharding_types = [],
160
+ compile_modes = BENCH_COMPILE_MODES ,
161
+ tables = tables ,
162
+ output_dir = output_dir ,
163
+ benchmark_unsharded = True , # benchmark unsharded module
164
+ ** args_kwargs ,
165
+ )
166
+
167
+
132
168
def benchmark_qebc_unsharded (
133
169
args : argparse .Namespace , output_dir : str
134
170
) -> List [BenchmarkResult ]:
@@ -185,9 +221,10 @@ def main() -> None:
185
221
"QuantEmbeddingCollection" ,
186
222
]
187
223
188
- # Only do unsharded QEBC benchmark when using CPU device
224
+ # Only do unsharded QEBC/QEC benchmark when using CPU device
189
225
if args .device_type == "cpu" :
190
226
module_names .append ("unshardedQuantEmbeddingBagCollection" )
227
+ module_names .append ("unshardedQuantEmbeddingCollection" )
191
228
192
229
for module_name in module_names :
193
230
output_dir = args .output_dir + f"/run_{ datetime_sfx } "
@@ -197,9 +234,12 @@ def main() -> None:
197
234
elif module_name == "QuantEmbeddingCollection" :
198
235
output_dir += "_qec"
199
236
benchmark_func = benchmark_qec
200
- else :
237
+ elif module_name == "unshardedQuantEmbeddingBagCollection" :
201
238
output_dir += "_uqebc"
202
239
benchmark_func = benchmark_qebc_unsharded
240
+ else :
241
+ output_dir += "_uqec"
242
+ benchmark_func = benchmark_qec_unsharded
203
243
204
244
if not os .path .exists (output_dir ):
205
245
# Place all outputs under the datetime folder
0 commit comments