
Commit eb1ce9c

KyleCZH authored and pruthvistony committed

[ROCm] use ncclAllToAll for rocm

Use ncclAllToAll for ROCm version >= 5.0; details on ncclAllToAll: ROCm/rccl#503
cc @jithunnair-amd @amathews-amd
Pull Request resolved: pytorch#75128
Approved by: https://github.com/wenkaidu, https://github.com/yzygitzh, https://github.com/seemethere

1 parent 44297c1 commit eb1ce9c

File tree

1 file changed: +4 additions, -0 deletions

torch/csrc/cuda/nccl.cpp

Lines changed: 4 additions & 0 deletions
@@ -650,6 +650,9 @@ void all2all_single_equal_split(at::Tensor& input,
   const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
   auto* recvbuff = reinterpret_cast<char *>(output.data_ptr());
   auto comm = to_nccl_comm(_comm);
+#if defined(USE_ROCM) && ROCM_VERSION >= 50000
+  NCCL_CHECK(ncclAllToAll(sendbuff , recvbuff , count, type, comm, stream));
+#else
   NCCL_CHECK(ncclCommCount(comm, &numranks));
   NCCL_CHECK(ncclGroupStart());
   for(const auto r : c10::irange(numranks)) {
@@ -661,6 +664,7 @@ void all2all_single_equal_split(at::Tensor& input,
   }
 }
 NCCL_CHECK(ncclGroupEnd());
+#endif
 #else
 AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
 #endif
