
Commit 85ace43

Merge pull request #1011 from ROCmSoftwarePlatform/release/1.10_revert_ncclAllToAll
Deactivate ncclAllToAll
2 parents: ed9b160 + 4a5e2b0

File tree

1 file changed: +0 −4 lines


torch/csrc/cuda/nccl.cpp

Lines changed: 0 additions & 4 deletions
@@ -652,9 +652,6 @@ void all2all_single_equal_split(at::Tensor& input,
   const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
   auto* recvbuff = reinterpret_cast<char *>(output.data_ptr());
   auto comm = to_nccl_comm(_comm);
-#if defined(USE_ROCM) && ROCM_VERSION >= 50000
-  NCCL_CHECK(ncclAllToAll(sendbuff , recvbuff , count, type, comm, stream));
-#else
   NCCL_CHECK(ncclCommCount(comm, &numranks));
   NCCL_CHECK(ncclGroupStart());
   for(const auto r : c10::irange(numranks)) {
@@ -666,7 +663,6 @@ void all2all_single_equal_split(at::Tensor& input,
     }
   }
   NCCL_CHECK(ncclGroupEnd());
-#endif
 #else
   AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
 #endif
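
For context on what the revert leaves in place: with the ROCm-only ncclAllToAll branch removed, both CUDA and ROCm builds take the generic path visible in the diff context, which decomposes the equal-split all-to-all into one ncclSend/ncclRecv pair per peer inside an NCCL group (the error branch above confirms this path requires NCCL >= 2.7.0). Below is a minimal standalone sketch of that pattern; the helper name all_to_all_equal_split and the elem_size parameter are hypothetical, and error handling is collapsed into early returns rather than PyTorch's NCCL_CHECK macro.

// A minimal sketch (not PyTorch's exact code) of the generic equal-split
// all-to-all restored by this revert.
#include <cuda_runtime.h>
#include <nccl.h>
#include <cstddef>

// Exchanges `count` elements of `type` with every peer: rank r's chunk of
// sendbuff goes to rank r, and rank r fills its chunk of recvbuff.
// `elem_size` is the byte size of one element of `type` (hypothetical
// parameter for this sketch).
ncclResult_t all_to_all_equal_split(
    const char* sendbuff,
    char* recvbuff,
    size_t count,        // elements exchanged with each peer (equal split)
    size_t elem_size,
    ncclDataType_t type,
    ncclComm_t comm,
    cudaStream_t stream) {
  int numranks = 0;
  ncclResult_t res = ncclCommCount(comm, &numranks);
  if (res != ncclSuccess) return res;

  const size_t rankdiff = count * elem_size;  // byte offset of rank r's chunk
  ncclGroupStart();
  for (int r = 0; r < numranks; ++r) {
    if (count != 0) {  // skip zero-size messages, as the original loop does
      ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream);
      ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream);
    }
  }
  return ncclGroupEnd();
}

Wrapping the per-peer send/recv pairs in ncclGroupStart/ncclGroupEnd lets NCCL issue them as one fused operation instead of numranks serialized point-to-point calls, which is what makes this loop a practical substitute for a dedicated all-to-all primitive.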
