Skip to content

Commit ed9b160

Browse files
KyleCZHpruthvistony
authored andcommitted
[ROCm] use ncclAllToAll for rocm
use ncclAllToAll for rocm version > 5.0; ROCm/rccl#503 detail on ncclAllToAll: ROCm/rccl#503 @jithunnair-amd @amathews-amd Pull Request resolved: pytorch#75128 Approved by: https://github.com/wenkaidu, https://github.com/yzygitzh, https://github.com/seemethere
1 parent 1f19f03 commit ed9b160

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torch/csrc/cuda/nccl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,9 @@ void all2all_single_equal_split(at::Tensor& input,
652652
const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
653653
auto* recvbuff = reinterpret_cast<char *>(output.data_ptr());
654654
auto comm = to_nccl_comm(_comm);
655+
#if defined(USE_ROCM) && ROCM_VERSION >= 50000
656+
NCCL_CHECK(ncclAllToAll(sendbuff , recvbuff , count, type, comm, stream));
657+
#else
655658
NCCL_CHECK(ncclCommCount(comm, &numranks));
656659
NCCL_CHECK(ncclGroupStart());
657660
for(const auto r : c10::irange(numranks)) {
@@ -663,6 +666,7 @@ void all2all_single_equal_split(at::Tensor& input,
663666
}
664667
}
665668
NCCL_CHECK(ncclGroupEnd());
669+
#endif
666670
#else
667671
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
668672
#endif

0 commit comments

Comments
 (0)