Commit 7f77444

xw285cornell authored and facebook-github-bot committed
move memory copy into one_shot_all_reduce (#2770)
Summary:
Pull Request resolved: #2770

Avoid the latency of launching hipMemcpyAsync by moving the memory copy into the one_shot_all_reduce kernel itself. Benchmarking showed a 3-4 us reduction, and end-to-end testing also showed improvements.

Reviewed By: sryap, jianyuh

Differential Revision: D58223358

fbshipit-source-id: c5bf36866ab5f89a8ce186bcd728d02638c12070
1 parent 2114817 · commit 7f77444
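The technique in brief: the commit drops a host-launched device-to-device copy and instead has the kernel itself stage the input with a grid-stride loop, so the copy costs no extra stream launch. A minimal, self-contained CUDA sketch of that pattern, where each thread moves 8 bf16 values (16 bytes) per step via two 64-bit loads/stores; the kernel and parameter names here (`fused_copy_kernel`, `src`, `dst`, `n`) are hypothetical, not the ones in car.cu:

```cuda
#include <cuda_bf16.h>
#include <cstdint>

// Sketch only: fuse a device-to-device copy into the kernel instead of
// calling cudaMemcpyAsync/hipMemcpyAsync from the host.
__global__ void fused_copy_kernel(
    const __nv_bfloat16* src, // plays the role of ar_input (y) in the diff
    __nv_bfloat16* dst,       // plays the role of inputs[rank] in the diff
    int32_t n) {              // element count, assumed divisible by 8
  for (size_t i =
           (size_t)blockDim.x * blockIdx.x * 8 + (size_t)threadIdx.x * 8;
       i < (size_t)n;
       i += (size_t)blockDim.x * gridDim.x * 8) {
    // Two 64-bit moves copy 8 contiguous bf16 elements per thread.
    const uint64_t* s = reinterpret_cast<const uint64_t*>(&src[i]);
    uint64_t* d = reinterpret_cast<uint64_t*>(&dst[i]);
    d[0] = s[0];
    d[1] = s[1];
  }
}
```

In the commit the same loop runs at the top of `one_shot_all_reduce`, before the barrier synchronization, so the staging copy overlaps with the kernel launch it was already paying for.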

File tree

  • fbgemm_gpu/experimental/gen_ai/src/comm

1 file changed: +20 −7 lines

fbgemm_gpu/experimental/gen_ai/src/comm/car.cu

Lines changed: 20 additions & 7 deletions
@@ -139,9 +139,28 @@ __global__ void one_shot_all_reduce(
     int32_t flag,
     std::array<int32_t*, 8> barriers,
     std::array<at::BFloat16*, 8> inputs,
+    at::BFloat16* ar_input,
     at::BFloat16* acc,
     at::BFloat16* output,
     int32_t N) {
+  // It is expensive to launch hipMemcpyAsync on ROCm
+  // Move data copy here. Each block copies part of input data
+  at::BFloat16* input = inputs[rank];
+  for (size_t i = blockDim.x * blockIdx.x * 8 + threadIdx.x * 8; i < N;
+       i += (size_t)blockDim.x * gridDim.x * 8) {
+#if defined(USE_ROCM)
+    __builtin_nontemporal_store(
+        reinterpret_cast<uint64_t*>(&ar_input[i])[0], (uint64_t*)(&input[i]));
+    __builtin_nontemporal_store(
+        reinterpret_cast<uint64_t*>(&ar_input[i])[1],
+        (uint64_t*)(&input[i]) + 1);
+#else
+    *reinterpret_cast<uint64_t*>(&input[i]) =
+        reinterpret_cast<uint64_t*>(&ar_input[i])[0];
+    *(reinterpret_cast<uint64_t*>(&input[i]) + 1) =
+        reinterpret_cast<uint64_t*>(&ar_input[i])[1];
+#endif
+  }
   // Synchronize the ranks.
   volatile int32_t* barrier_d = barriers[rank];
   if (threadIdx.x < kWorldSize) {
@@ -516,13 +535,6 @@ void one_shot_car_allreduce(
     barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
   }
 
-  AT_CUDA_CHECK(cudaMemcpyAsync(
-      inputs[state->rank_],
-      y.data_ptr<at::BFloat16>(),
-      y.numel() * y.element_size(),
-      cudaMemcpyDeviceToDevice,
-      at::cuda::getCurrentCUDAStream()));
-
   constexpr int32_t N_per_thread = 8;
   constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp;
   TORCH_CHECK(N % N_per_warp == 0);
@@ -555,6 +567,7 @@ void one_shot_car_allreduce(
       state->flag_ * state->world_size_, \
       barriers, \
       inputs, \
+      y.data_ptr<at::BFloat16>(), \
       z ? z->data_ptr<at::BFloat16>() : nullptr, \
       y_allreduce.data_ptr<at::BFloat16>(), \
       N); \
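A note on the ROCm branch of the diff: `__builtin_nontemporal_store(value, ptr)` is a Clang builtin that writes `value` through `ptr` with a nontemporal hint, telling the compiler the data need not stay resident in cache. A sketch of the diff's 16-byte copy step factored into a helper; the `copy16` name is hypothetical, while the `USE_ROCM` guard mirrors car.cu:

```cuda
#include <cstdint>

// Hypothetical helper isolating the diff's 16-byte copy step:
// 8 bf16 elements moved as two 64-bit words.
__device__ inline void copy16(const void* src, void* dst) {
  const uint64_t* s = static_cast<const uint64_t*>(src);
  uint64_t* d = static_cast<uint64_t*>(dst);
#if defined(USE_ROCM)
  // Clang builtin: store with a nontemporal hint, so write-once
  // streaming data need not pollute the cache hierarchy.
  __builtin_nontemporal_store(s[0], d);
  __builtin_nontemporal_store(s[1], d + 1);
#else
  d[0] = s[0];
  d[1] = s[1];
#endif
}
```

The CUDA side keeps plain stores; the nontemporal hint presumably pays off on ROCm because the staged buffer is consumed once by the subsequent reduction, though the commit message only cites the hipMemcpyAsync launch latency as the motivation.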
