
Commit 0580bd0

TroyGarden authored and facebook-github-bot committed
benchmark of fbgemm op - permute_multi_embedding (#2158)
Summary:
Pull Request resolved: #2158

X-link: pytorch/FBGEMM#2771

# context
* added both **op-level** and **fn-level** benchmarks for the KT.regroup implementations
* analyzed op-level and fn-level performance in runtime and memory usage
* findings:
  * **a**. at the fn level, `permute_multi_embedding` (new op) outperforms both the native-pytorch implementation and `permute_pooled_embs_auto_grad` (current prod) by ~50% in GPU runtime and ~33% in memory usage
  * **b**. at the op level, the new op is slightly slower than the current prod (by ~5% GPU runtime)
* conclusion: **we should use the new op**

# other considerations
The good:
1. the algorithm is designed so that it doesn't need to know in advance whether a 1-to-N mapping exists in the permutes.
2. `_all_keys_used_once` is no longer needed.
3. no torch.cat is needed before calling the old operator.
4. no need to use `_pin_and_move` for the metadata (arguments); it is handled inside the operator, which is more tracing-friendly.
5. no need to fall back to the native-pytorch implementation when duplicates exist.

The bad (same as before):
1. it requires several HtoD transfers (moving tensors to the device):
   a) [resolved] 3 tensors (`permutes`, `input_lengths`, and `output_lengths`), which need to be on the device so that the CUDA kernels can access them.
   b) [resolved] 2 lists of (scalar_t*) pointers, one for the input tensors and one for the output tensors.
   c) [resolved] didn't find a good way to let the kernel know the addresses of the input/output tensor lists, because the lists also need to be on the device.
2. tensor.contiguous is needed in the backward function; the gradients coming into the backward are somehow not contiguous.

# benchmark
* op-level results: the new op is ~5% slower in GPU runtime
```
INFO:root:size: 1024 x 136896; permute_multi_embedding: 2.25 ms; permute_pooled_embs: 2.15 ms; delta: 4.5%
INFO:root:size: 1024 x 108432; permute_multi_embedding: 1.79 ms; permute_pooled_embs: 1.7 ms; delta: 5.3%
INFO:root:size: 1024 x 277232; permute_multi_embedding: 4.54 ms; permute_pooled_embs: 4.37 ms; delta: 3.9%
INFO:root:size: 1024 x 244352; permute_multi_embedding: 4.01 ms; permute_pooled_embs: 3.83 ms; delta: 4.9%
INFO:root:size: 1024 x 524224; permute_multi_embedding: 8.62 ms; permute_pooled_embs: 8.25 ms; delta: 4.5%
INFO:root:size: 1024 x 564080; permute_multi_embedding: 9.27 ms; permute_pooled_embs: 8.92 ms; delta: 3.9%
```
* fn-level results: the new op is 50%+ faster in GPU runtime and uses ~33% less GPU memory
```
_regroup_keyed_tenors | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.8 ms | Memory (P90): 1011.0
KeyedTensor.regroup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 5.0 ms | Memory (P90): 1517.0
KTRegroupAsDict | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 4.9 ms | Memory (P90): 1517.0
permute_multi_embs | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.2 ms | Memory (P90): 1011.0
_regroup_keyed_tenors_dup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.5 ms | Memory (P90): 1011.0
KeyedTensor.regroup_dup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.5 ms | Memory (P90): 1011.0
KTRegroupAsDict_dup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 2.5 ms | Memory (P90): 1011.0
permute_multi_embs_dup | B: 1024 | F: 1020 | device: cuda | Runtime (P90): 3.2 ms | Memory (P90): 1011.0
```

# traces
* [files](https://drive.google.com/drive/folders/1_9hOtQUQeFICBVxQtusvpQ_VajduFUmR?usp=sharing)
```
[[email protected] /data/sandcastle/boxes/fbsource (ae677c240)]$ ll *.json
-rw-rw-r-- 1 hhy hhy 8062993 Jun 21 23:26 trace-KeyedTensor.regroup_dup.json
-rw-rw-r-- 1 hhy hhy 949610 Jun 21 23:26 trace-KeyedTensor.regroup.json
-rw-rw-r-- 1 hhy hhy 5140143 Jun 21 23:26 trace-KTRegroupAsDict_dup.json
-rw-rw-r-- 1 hhy hhy 350370 Jun 21 23:26 trace-KTRegroupAsDict.json
-rw-rw-r-- 1 hhy hhy 581033 Jun 21 23:26 trace-permute_multi_embs_dup.json
-rw-rw-r-- 1 hhy hhy 582607 Jun 21 23:26 trace-permute_multi_embs.json
-rw-rw-r-- 1 hhy hhy 8025337 Jun 21 23:26 trace-_regroup_keyed_tenors_dup.json
-rw-rw-r-- 1 hhy hhy 8041586 Jun 21 23:26 trace-_regroup_keyed_tenors.json
```
* native-pytorch {F1713052022}
* current prod {F1713052648}
* new op {F1713052907}
* runtime

|Operator|CPU runtime|GPU runtime|GPU memory|notes|
|---|---|---|---|---|
|**native-pytorch**|3.9 ms|3.1 ms|1.0 K|CPU-bound, allows duplicates|
|**prod op**|2.1 ms|4.9 ms|1.5 K|GPU-bound due to torch.cat, does **NOT** allow duplicates|
|**new op**|2.0 ms|2.2 ms|1.0 K|outperforms on both CPU and GPU runtime, **ALLOWS** duplicates|

Reviewed By: dstaay-fb

Differential Revision: D58906839

fbshipit-source-id: 6cb28ca17daf16943b28af9b074d1032e7079912
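To make the fn-level comparison concrete, here is a minimal sketch (not part of this commit) of the call pattern the benchmark exercises: the current `KeyedTensor.regroup` path versus the new `permute_multi_embedding` op. The feature names, dims, batch size, and device handling are made up, and the import path for `permute_multi_embedding` is assumed to mirror the one added to the benchmark script below.

```python
# Minimal sketch (not from this commit); names, dims, and device are illustrative.
import torch

from torchrec.sparse.jagged_tensor import KeyedTensor, permute_multi_embedding

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 1024

# Two pooled-embedding KeyedTensors, each holding a few features along dim 1.
kts = [
    KeyedTensor(
        keys=["f1", "f2"],
        length_per_key=[8, 16],
        values=torch.randn(batch_size, 24, device=device),
    ),
    KeyedTensor(
        keys=["f3"],
        length_per_key=[32],
        values=torch.randn(batch_size, 32, device=device),
    ),
]

# Regroup the features into new groups; the new op also supports a key
# appearing in more than one group (the "_dup" cases in the benchmark).
groups = [["f1", "f3"], ["f2"]]

prod_out = KeyedTensor.regroup(kts, groups)     # current prod path
new_out = permute_multi_embedding(kts, groups)  # new fbgemm-backed op
```

Both calls return one tensor per output group; the fn-level benchmark times exactly this regrouping, with and without duplicated keys across groups.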
1 parent 3db28b3 commit 0580bd0

File tree

1 file changed (+13 −1 lines)

torchrec/sparse/tests/jagged_tensor_benchmark.py

Lines changed: 13 additions & 1 deletion
@@ -21,6 +21,7 @@
     _regroup_keyed_tensors,
     KeyedJaggedTensor,
     KeyedTensor,
+    permute_multi_embedding,
 )
 from torchrec.sparse.tests.utils import build_groups, build_kts
 
@@ -105,7 +106,7 @@ def wrapped_func(
     )
 
     print(
-        f" {name : <{35}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.1f} ms | Memory (P90): {result.max_mem_percentile(90):5.1f}"
+        f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | Memory (P90): {result.max_mem_percentile(90):5.1f}"
     )
 
 
@@ -246,6 +247,17 @@ def main(
         {"keyed_tensors": kts},
         profile,
     )
+    bench(
+        "permute_multi_embs" + dup,
+        labels,
+        batch_size,
+        n_dense + n_sparse,
+        device_type,
+        run_backward,
+        permute_multi_embedding,
+        {"keyed_tensors": kts, "groups": groups},
+        profile,
+    )
 
 
 if __name__ == "__main__":
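As a side note on the print-format tweak above (name width 35 → 30, runtime `.1f` → `.2f`), here is a small stand-alone sketch of the row it emits. The values are stand-ins for the `result.runtime_percentile(90)` and `result.max_mem_percentile(90)` calls in the real script.

```python
# Stand-in values; in the benchmark these come from the result object returned by bench().
name, batch_size, feature_count, device_type = "permute_multi_embs", 1024, 1020, "cuda"
runtime_p90_ms, mem_p90 = 2.21, 1011.0

# Same format string as the updated print in the diff, split for readability.
print(
    f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} "
    f"| device: {device_type : <{8}} | Runtime (P90): {runtime_p90_ms:5.2f} ms "
    f"| Memory (P90): {mem_p90:5.1f}"
)
```

This produces the fixed-width rows shown in the fn-level results block of the commit message.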
