Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,28 @@ __global__ void one_shot_all_reduce(
int32_t flag,
std::array<int32_t*, 8> barriers,
std::array<at::BFloat16*, 8> inputs,
at::BFloat16* ar_input,
at::BFloat16* acc,
at::BFloat16* output,
int32_t N) {
// It is expensive to launch hipMemcpyAsync on ROCm
// Move data copy here. Each block copies part of input data
at::BFloat16* input = inputs[rank];
for (size_t i = blockDim.x * blockIdx.x * 8 + threadIdx.x * 8; i < N;
i += (size_t)blockDim.x * gridDim.x * 8) {
#if defined(USE_ROCM)
__builtin_nontemporal_store(
reinterpret_cast<uint64_t*>(&ar_input[i])[0], (uint64_t*)(&input[i]));
__builtin_nontemporal_store(
reinterpret_cast<uint64_t*>(&ar_input[i])[1],
(uint64_t*)(&input[i]) + 1);
#else
*reinterpret_cast<uint64_t*>(&input[i]) =
reinterpret_cast<uint64_t*>(&ar_input[i])[0];
*(reinterpret_cast<uint64_t*>(&input[i]) + 1) =
reinterpret_cast<uint64_t*>(&ar_input[i])[1];
#endif
}
// Synchronize the ranks.
volatile int32_t* barrier_d = barriers[rank];
if (threadIdx.x < kWorldSize) {
Expand Down Expand Up @@ -516,13 +535,6 @@ void one_shot_car_allreduce(
barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
}

AT_CUDA_CHECK(cudaMemcpyAsync(
inputs[state->rank_],
y.data_ptr<at::BFloat16>(),
y.numel() * y.element_size(),
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream()));

constexpr int32_t N_per_thread = 8;
constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp;
TORCH_CHECK(N % N_per_warp == 0);
Expand Down Expand Up @@ -555,6 +567,7 @@ void one_shot_car_allreduce(
state->flag_ * state->world_size_, \
barriers, \
inputs, \
y.data_ptr<at::BFloat16>(), \
z ? z->data_ptr<at::BFloat16>() : nullptr, \
y_allreduce.data_ptr<at::BFloat16>(), \
N); \
Expand Down