Skip to content

Commit eaea423

Browse files
gnahzgfacebook-github-bot
authored andcommitted
Enable unsharded QEC benchmark
Summary: As titled Differential Revision: D58822423
1 parent dcd5c72 commit eaea423

File tree

1 file changed

+42
-2
lines changed

1 file changed

+42
-2
lines changed

torchrec/distributed/benchmark/benchmark_inference.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,42 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkR
129129
)
130130

131131

132+
def benchmark_qec_unsharded(
133+
args: argparse.Namespace, output_dir: str
134+
) -> List[BenchmarkResult]:
135+
tables = get_tables(TABLE_SIZES, is_pooled=False)
136+
sharder = TestQuantECSharder(
137+
sharding_type="",
138+
kernel_type=EmbeddingComputeKernel.QUANT.value,
139+
shardable_params=[table.name for table in tables],
140+
)
141+
142+
module = QuantEmbeddingCollection(
143+
# pyre-ignore [6]
144+
tables=tables,
145+
device=torch.device("cpu"),
146+
quant_state_dict_split_scale_bias=True,
147+
)
148+
149+
args_kwargs = {
150+
argname: getattr(args, argname)
151+
for argname in dir(args)
152+
# Don't include output_dir since output_dir was modified
153+
if not argname.startswith("_") and argname not in IGNORE_ARGNAME
154+
}
155+
156+
return benchmark_module(
157+
module=module,
158+
sharder=sharder,
159+
sharding_types=[],
160+
compile_modes=BENCH_COMPILE_MODES,
161+
tables=tables,
162+
output_dir=output_dir,
163+
benchmark_unsharded=True, # benchmark unsharded module
164+
**args_kwargs,
165+
)
166+
167+
132168
def benchmark_qebc_unsharded(
133169
args: argparse.Namespace, output_dir: str
134170
) -> List[BenchmarkResult]:
@@ -185,9 +221,10 @@ def main() -> None:
185221
"QuantEmbeddingCollection",
186222
]
187223

188-
# Only do unsharded QEBC benchmark when using CPU device
224+
# Only do unsharded QEBC/QEC benchmark when using CPU device
189225
if args.device_type == "cpu":
190226
module_names.append("unshardedQuantEmbeddingBagCollection")
227+
module_names.append("unshardedQuantEmbeddingCollection")
191228

192229
for module_name in module_names:
193230
output_dir = args.output_dir + f"/run_{datetime_sfx}"
@@ -197,9 +234,12 @@ def main() -> None:
197234
elif module_name == "QuantEmbeddingCollection":
198235
output_dir += "_qec"
199236
benchmark_func = benchmark_qec
200-
else:
237+
elif module_name == "unshardedQuantEmbeddingBagCollection":
201238
output_dir += "_uqebc"
202239
benchmark_func = benchmark_qebc_unsharded
240+
else:
241+
output_dir += "_uqec"
242+
benchmark_func = benchmark_qec_unsharded
203243

204244
if not os.path.exists(output_dir):
205245
# Place all outputs under the datetime folder

0 commit comments

Comments
 (0)