Skip to content

[c10d] Working async version of AllGather, test fix and compiler warnings, and CI #10932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

teng-li
Copy link
Contributor

@teng-li teng-li commented Aug 28, 2018

The previous NCCL all gather doesn't work as expected. This is a fully working async version. Tested on both C++ and Python Frontend.

Multi-node:

tengli@learnfair042:~/new_pytorch/pytorch/torch/lib/build/c10d/test$ TMPFILE="/private/home/tengli/temp/tengli-test" RANK=0 WORLD_SIZE=2 ./ProcessGroupNCCLTest
Multi-node world size: 2 rank: 0
Allreduce test successful
Broadcast test successful
Reduce test successful
Allgather test successful

tengli@learnfair117:~/new_pytorch/pytorch/torch/lib/build/c10d/test$ TMPFILE="/private/home/tengli/temp/tengli-test" RANK=1 WORLD_SIZE=2 ./ProcessGroupNCCLTest
Multi-node world size: 2 rank: 1
Allreduce test successful
Broadcast test successful
Reduce test successful
Allgather test successful

CI test:

test_set_get (__main__.FileStoreTest) ... ok
test_set_get (__main__.PrefixFileStoreTest) ... ok
test_set_get (__main__.PrefixTCPStoreTest) ... ok
test_allreduce_ops (__main__.ProcessGroupGlooTest) ... ok
test_broadcast_ops (__main__.ProcessGroupGlooTest) ... ok
test_allgather_ops (__main__.ProcessGroupNCCLTest) ... ok
test_allreduce_ops (__main__.ProcessGroupNCCLTest) ... ok
test_broadcast_ops (__main__.ProcessGroupNCCLTest) ... ok
test_reduce_ops (__main__.ProcessGroupNCCLTest) ... ok
test_common_errors (__main__.RendezvousFileTest) ... ok
test_nominal (__main__.RendezvousFileTest) ... ok
test_common_errors (__main__.RendezvousTCPTest) ... ok
test_nominal (__main__.RendezvousTCPTest) ... ok
test_unknown_handler (__main__.RendezvousTest) ... ok
test_set_get (__main__.TCPStoreTest) ... ok

@teng-li teng-li requested a review from pietern August 28, 2018 03:14
@teng-li teng-li requested a review from apaszke as a code owner August 28, 2018 03:14
@teng-li teng-li added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Aug 28, 2018
@teng-li teng-li force-pushed the test_fix branch 2 times, most recently from 109e585 to a96580a Compare August 28, 2018 05:46
@teng-li teng-li changed the title [c10d] fix test, and compiler warnings, enabling nccl for ci [c10d] skip test, and compiler warnings, enabling nccl for ci Aug 28, 2018
@teng-li teng-li force-pushed the test_fix branch 2 times, most recently from 08a738f to 7a43626 Compare August 28, 2018 07:42
@teng-li teng-li changed the title [c10d] skip test, and compiler warnings, enabling nccl for ci [c10d] Working async version of AllGather, test fix and compiler warnings, and CI Aug 28, 2018
@pietern
Copy link
Contributor

pietern commented Aug 28, 2018

Looking at the code, was it the test case not allocating enough memory for the output tensors?

@teng-li
Copy link
Contributor Author

teng-li commented Aug 28, 2018

@pietern The biggest issue was that each big tensor (flattened) should be on a different GPU, that was causing the error.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

teng-li has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

teng-li has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018
…nd CI (pytorch#10932)

Summary:
The previous NCCL all gather doesn't work as expected. This is a fully working async version.  Tested on both C++ and Python Frontend.

Multi-node:
```
tengli@learnfair042:~/new_pytorch/pytorch/torch/lib/build/c10d/test$ TMPFILE="/private/home/tengli/temp/tengli-test" RANK=0 WORLD_SIZE=2 ./ProcessGroupNCCLTest
Multi-node world size: 2 rank: 0
Allreduce test successful
Broadcast test successful
Reduce test successful
Allgather test successful

tengli@learnfair117:~/new_pytorch/pytorch/torch/lib/build/c10d/test$ TMPFILE="/private/home/tengli/temp/tengli-test" RANK=1 WORLD_SIZE=2 ./ProcessGroupNCCLTest
Multi-node world size: 2 rank: 1
Allreduce test successful
Broadcast test successful
Reduce test successful
Allgather test successful
```

CI test:
```
test_set_get (__main__.FileStoreTest) ... ok
test_set_get (__main__.PrefixFileStoreTest) ... ok
test_set_get (__main__.PrefixTCPStoreTest) ... ok
test_allreduce_ops (__main__.ProcessGroupGlooTest) ... ok
test_broadcast_ops (__main__.ProcessGroupGlooTest) ... ok
test_allgather_ops (__main__.ProcessGroupNCCLTest) ... ok
test_allreduce_ops (__main__.ProcessGroupNCCLTest) ... ok
test_broadcast_ops (__main__.ProcessGroupNCCLTest) ... ok
test_reduce_ops (__main__.ProcessGroupNCCLTest) ... ok
test_common_errors (__main__.RendezvousFileTest) ... ok
test_nominal (__main__.RendezvousFileTest) ... ok
test_common_errors (__main__.RendezvousTCPTest) ... ok
test_nominal (__main__.RendezvousTCPTest) ... ok
test_unknown_handler (__main__.RendezvousTest) ... ok
test_set_get (__main__.TCPStoreTest) ... ok
```
Pull Request resolved: pytorch#10932

Differential Revision: D9542067

Pulled By: teng-li

fbshipit-source-id: 25513eddcc3119fd736875d69dfb631b10f4ac86
@ezyang ezyang added the merged label Jun 26, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
oncall: distributed Add this issue/PR to distributed oncall triage queue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants