diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
index 9ef0c4efe4..be111977b2 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
@@ -139,9 +139,28 @@ __global__ void one_shot_all_reduce(
     int32_t flag,
     std::array<int32_t*, 8> barriers,
     std::array<at::BFloat16*, 8> inputs,
+    at::BFloat16* ar_input,
     at::BFloat16* acc,
     at::BFloat16* output,
     int32_t N) {
+  // It is expensive to launch hipMemcpyAsync on ROCm.
+  // Move the data copy here instead: each block copies part of the input data.
+  at::BFloat16* input = inputs[rank];
+  for (size_t i = blockDim.x * blockIdx.x * 8 + threadIdx.x * 8; i < N;
+       i += (size_t)blockDim.x * gridDim.x * 8) {
+#if defined(USE_ROCM)
+    __builtin_nontemporal_store(
+        reinterpret_cast<uint64_t*>(&ar_input[i])[0], (uint64_t*)(&input[i]));
+    __builtin_nontemporal_store(
+        reinterpret_cast<uint64_t*>(&ar_input[i])[1],
+        (uint64_t*)(&input[i]) + 1);
+#else
+    *reinterpret_cast<uint64_t*>(&input[i]) =
+        reinterpret_cast<uint64_t*>(&ar_input[i])[0];
+    *(reinterpret_cast<uint64_t*>(&input[i]) + 1) =
+        reinterpret_cast<uint64_t*>(&ar_input[i])[1];
+#endif
+  }
   // Synchronize the ranks.
   volatile int32_t* barrier_d = barriers[rank];
   if (threadIdx.x < kWorldSize) {
@@ -516,13 +535,6 @@ void one_shot_car_allreduce(
     barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
   }
 
-  AT_CUDA_CHECK(cudaMemcpyAsync(
-      inputs[state->rank_],
-      y.data_ptr(),
-      y.numel() * y.element_size(),
-      cudaMemcpyDeviceToDevice,
-      at::cuda::getCurrentCUDAStream()));
-
   constexpr int32_t N_per_thread = 8;
   constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp;
   TORCH_CHECK(N % N_per_warp == 0);
@@ -555,6 +567,7 @@ void one_shot_car_allreduce(
         state->flag_ * state->world_size_, \
         barriers, \
         inputs, \
+        y.data_ptr<at::BFloat16>(), \
        z ? z->data_ptr<at::BFloat16>() : nullptr, \
        y_allreduce.data_ptr<at::BFloat16>(), \
        N); \