From 3b52730e2f18e1478d84379fbd35b59101673b4b Mon Sep 17 00:00:00 2001
From: Xiaodong Wang
Date: Fri, 21 Jun 2024 19:34:45 -0700
Subject: [PATCH] move memory copy into one_shot_all_reduce (#2770)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2770

Avoid the latency of launching hipMemcpyAsync. Benchmarking shows a 3-4us
reduction, and end-to-end testing also shows improvements.

Reviewed By: sryap, jianyuh

Differential Revision: D58223358
---
 .../experimental/gen_ai/src/comm/car.cu | 27 ++++++++++++++-----
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
index 9ef0c4efe4..be111977b2 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
@@ -139,9 +139,28 @@ __global__ void one_shot_all_reduce(
     int32_t flag,
     std::array<int32_t*, 8> barriers,
     std::array<at::BFloat16*, 8> inputs,
+    at::BFloat16* ar_input,
     at::BFloat16* acc,
     at::BFloat16* output,
     int32_t N) {
+  // It is expensive to launch hipMemcpyAsync on ROCm
+  // Move data copy here. Each block copies part of input data
+  at::BFloat16* input = inputs[rank];
+  for (size_t i = blockDim.x * blockIdx.x * 8 + threadIdx.x * 8; i < N;
+       i += (size_t)blockDim.x * gridDim.x * 8) {
+#if defined(USE_ROCM)
+    __builtin_nontemporal_store(
+        reinterpret_cast<uint64_t*>(&ar_input[i])[0], (uint64_t*)(&input[i]));
+    __builtin_nontemporal_store(
+        reinterpret_cast<uint64_t*>(&ar_input[i])[1],
+        (uint64_t*)(&input[i]) + 1);
+#else
+    *reinterpret_cast<uint64_t*>(&input[i]) =
+        reinterpret_cast<uint64_t*>(&ar_input[i])[0];
+    *(reinterpret_cast<uint64_t*>(&input[i]) + 1) =
+        reinterpret_cast<uint64_t*>(&ar_input[i])[1];
+#endif
+  }
   // Synchronize the ranks.
   volatile int32_t* barrier_d = barriers[rank];
   if (threadIdx.x < kWorldSize) {
@@ -516,13 +535,6 @@ void one_shot_car_allreduce(
     barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
   }
 
-  AT_CUDA_CHECK(cudaMemcpyAsync(
-      inputs[state->rank_],
-      y.data_ptr<at::BFloat16>(),
-      y.numel() * y.element_size(),
-      cudaMemcpyDeviceToDevice,
-      at::cuda::getCurrentCUDAStream()));
-
   constexpr int32_t N_per_thread = 8;
   constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp;
   TORCH_CHECK(N % N_per_warp == 0);
@@ -555,6 +567,7 @@ void one_shot_car_allreduce(
               state->flag_ * state->world_size_, \
               barriers, \
               inputs, \
+              y.data_ptr<at::BFloat16>(), \
              z ? z->data_ptr<at::BFloat16>() : nullptr, \
              y_allreduce.data_ptr<at::BFloat16>(), \
              N); \
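
Editor's note: the sketch below is illustrative and not part of the patch. It shows, as a
self-contained CUDA program, the pattern the patch adopts: instead of issuing a separate
cudaMemcpyAsync/hipMemcpyAsync before the kernel launch, the kernel itself performs the local
device-to-device copy with a grid-stride loop that moves 8 bf16 elements (two 64-bit words) per
thread per iteration. The names (copy_into_kernel, kThreadsPerBlock) and launch configuration are
assumptions, and the barrier/all-reduce logic that follows the copy in the real kernel is omitted.

// Illustrative sketch (hypothetical names), mirroring the copy fused into
// one_shot_all_reduce: each thread moves 8 bf16 values (16 bytes) per
// iteration of a grid-stride loop, so no host-side memcpy launch is needed.
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

constexpr int kThreadsPerBlock = 256;  // assumption; the real launch config differs

__global__ void copy_into_kernel(
    const __nv_bfloat16* __restrict__ src,  // plays the role of ar_input
    __nv_bfloat16* __restrict__ dst,        // plays the role of inputs[rank]
    int32_t N) {                            // N must be a multiple of 8
  for (size_t i = (size_t)blockDim.x * blockIdx.x * 8 + (size_t)threadIdx.x * 8;
       i < (size_t)N;
       i += (size_t)blockDim.x * gridDim.x * 8) {
    // Copy 16 bytes as two 64-bit words, like the non-ROCm branch of the patch.
    const uint64_t* s = reinterpret_cast<const uint64_t*>(&src[i]);
    uint64_t* d = reinterpret_cast<uint64_t*>(&dst[i]);
    d[0] = s[0];
    d[1] = s[1];
  }
  // ...the real kernel continues with barrier synchronization and the
  // all-reduce after this copy.
}

int main() {
  const int32_t N = 1 << 20;  // multiple of 8
  __nv_bfloat16 *src = nullptr, *dst = nullptr;
  cudaMalloc(&src, N * sizeof(__nv_bfloat16));
  cudaMalloc(&dst, N * sizeof(__nv_bfloat16));
  // One fused launch replaces cudaMemcpyAsync + kernel launch.
  copy_into_kernel<<<16, kThreadsPerBlock>>>(src, dst, N);
  cudaDeviceSynchronize();
  printf("copy_into_kernel: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(src);
  cudaFree(dst);
  return 0;
}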