
Commit b6a1f48

athitten and hubertlu-tw authored
Add rocblas_alt_impl flag for bwd rocblas calls in MHA (#70)
* Add missing flags arg in gemm_switch_fp32accum call
* Add rocblas_alt_impl flag in MHA
* Add rocblas_alt_impl flag for all bwd gemms in MHA module
* Use ifdef for rocblas_gemm_flags_fp16_alt_impl to target various AMD hardware

Co-authored-by: hubertlu-tw <[email protected]>
1 parent 7bef81f commit b6a1f48

7 files changed: +125 −48 lines
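
The diffs below repeat one pattern: compute a rocBLAS flags value once per kernel entry point and forward it to the strided-batched GEMM helper gemm_switch_fp32accum as a new trailing argument, alongside the flags argument that rocblas_gemm_ex already takes. A minimal sketch of that pattern follows; USE_GEMM_FLAGS_FP16_ALT_IMPL, at::BackwardPassGuard::is_backward_pass(), and rocblas_gemm_flags_fp16_alt_impl come straight from the diff (ROCm PyTorch fork plus rocBLAS), while the helper name select_mha_gemm_flags and the exact include paths are illustrative assumptions, not part of the commit.

// Hypothetical helper sketching the flag-selection guard this commit inserts
// before the MHA GEMMs. Assumes the ROCm fork of PyTorch (which provides
// at::BackwardPassGuard) and rocBLAS headers that define
// rocblas_gemm_flags_fp16_alt_impl; the build is assumed to define
// USE_GEMM_FLAGS_FP16_ALT_IMPL only when that enum exists, so other AMD
// hardware/stacks fall back to flags = 0.
#include <cstdint>
#include <rocblas.h>    // include path may differ across ROCm versions
#include <ATen/ATen.h>  // assumed to expose at::BackwardPassGuard here

static uint32_t select_mha_gemm_flags() {
  uint32_t flags = 0;  // default: no special rocBLAS behavior
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  // Request the fp16 alternate implementation only while autograd is running
  // a backward pass, mirroring the guard added before each bwd GEMM below.
  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  return flags;
}
// The selected value is then passed as the trailing argument of
// gemm_switch_fp32accum(...) and as the existing flags argument of
// rocblas_gemm_ex(...).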

apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu

Lines changed: 19 additions & 7 deletions
@@ -87,6 +87,10 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   char b_layout_n{'n'};
 
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+
   // Input Linear Q Fwd
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                        CUBLAS_OP_T,
@@ -159,7 +163,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                        static_cast<half*>(softmax_results_ptr),
                        k_seq_len,
                        k_seq_len*q_seq_len,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Padded Softmax
   bool softmax_success = false;
@@ -212,7 +217,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                        static_cast<half*>(matmul2_results.data_ptr()),
                        head_dim*attn_batches,
                        head_dim,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Output Linear
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -315,7 +321,9 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
 
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
   // Output Linear Dgrad
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                        CUBLAS_OP_N,
@@ -388,7 +396,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len*q_seq_len,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Matmul2 Dgrad2
   gemm_switch_fp32accum( a_layout_n,
@@ -410,7 +419,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        v_lin_grads_ptr,
                        lead_dim_kv,
                        batch_stride_kv,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Apply Dropout Mask and Scale by Dropout Probability
   apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -449,7 +459,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        q_lin_grads_ptr,
                        lead_dim_q,
                        batch_stride_q,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Matmul1 Dgrad2
   gemm_switch_fp32accum( a_layout_n,
@@ -471,7 +482,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        k_lin_grads_ptr,
                        lead_dim_kv,
                        batch_stride_kv,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Input Linear Q Dgrad
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,

apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu

Lines changed: 20 additions & 6 deletions
@@ -113,6 +113,10 @@ std::vector<torch::Tensor> fwd_cuda(
                      1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
                      static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
 
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+
   // Input Linear Q Fwd
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                        CUBLAS_OP_T,
@@ -185,7 +189,8 @@ std::vector<torch::Tensor> fwd_cuda(
                        static_cast<half*>(softmax_results_ptr),
                        k_seq_len,
                        k_seq_len*q_seq_len,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Padded Softmax
   bool softmax_success = false;
@@ -239,7 +244,8 @@ std::vector<torch::Tensor> fwd_cuda(
                        static_cast<half*>(matmul2_results.data_ptr()),
                        head_dim*attn_batches,
                        head_dim,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Output Linear
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -371,6 +377,10 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
 
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+
 
   // Dropout Add Backward
   apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -452,7 +462,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len*q_seq_len,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Matmul2 Dgrad2
   gemm_switch_fp32accum( a_layout_n,
@@ -474,7 +485,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        v_lin_grads_ptr,
                        lead_dim_kv,
                        batch_stride_kv,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Apply Dropout Mask and Scale by Dropout Probability
   apex_masked_scale_cuda<at::Half,float,uint32_t>(
@@ -513,7 +525,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        q_lin_grads_ptr,
                        lead_dim_q,
                        batch_stride_q,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Matmul1 Dgrad2
   gemm_switch_fp32accum( a_layout_n,
@@ -535,7 +548,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        k_lin_grads_ptr,
                        lead_dim_kv,
                        batch_stride_kv,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Input Linear Q Dgrad
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,

apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu

Lines changed: 19 additions & 7 deletions
@@ -88,6 +88,9 @@ std::vector<torch::Tensor> fwd_cuda(
   char b_layout_n{'n'};
 
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -135,7 +138,8 @@ std::vector<torch::Tensor> fwd_cuda(
                        static_cast<half*>(bmm1_results_ptr),
                        k_seq_len,
                        k_seq_len*q_seq_len,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Padded Softmax
   bool softmax_success = false;
@@ -180,7 +184,8 @@ std::vector<torch::Tensor> fwd_cuda(
                        static_cast<half*>(matmul2_results.data_ptr()),
                        head_dim*attn_batches,
                        head_dim,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   outputs.copy_(output_biases);
 
@@ -270,6 +275,9 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
 
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
 
   // Output Linear Dgrad
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -321,7 +329,7 @@ std::vector<torch::Tensor> bwd_cuda(
                        rocblas_datatype_f32_r,
                        algo,
                        solution_index,
-                      flags));
+                       flags));
 
   auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
   // MatMul2 Dgrad1
@@ -344,7 +352,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len*q_seq_len,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Matmul2 Dgrad2
   gemm_switch_fp32accum( a_layout_n,
@@ -366,7 +375,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        v_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Apply Dropout Mask and Scale by Dropout Probability
   // Softmax Grad
@@ -403,7 +413,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        q_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Matmul1 Dgrad2
   gemm_switch_fp32accum( a_layout_n,
@@ -425,7 +436,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        k_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Input Linear Dgrad
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,

apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu

Lines changed: 19 additions & 6 deletions
@@ -80,6 +80,10 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
   char b_layout_n{'n'};
 
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
+
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -127,7 +131,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
                        static_cast<half*>(softmax_results_ptr),
                        k_seq_len,
                        k_seq_len*q_seq_len,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Padded Softmax
   bool softmax_success = false;
@@ -180,7 +185,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
                        static_cast<half*>(matmul2_results.data_ptr()),
                        head_dim*attn_batches,
                        head_dim,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   outputs.copy_(output_biases);
 
@@ -270,6 +276,9 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_t{'t'};
 
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flags = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
 
   // Output Linear Dgrad
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
@@ -344,7 +353,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len*q_seq_len,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Matmul2 Dgrad2
   gemm_switch_fp32accum( a_layout_n,
@@ -366,7 +376,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        v_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Apply Dropout Mask and Scale by Dropout Probability
   // Softmax Grad
@@ -398,7 +409,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        q_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
-                       attn_batches);
+                       attn_batches,
+                       flags);
 
   // Matmul1 Dgrad2
   gemm_switch_fp32accum( a_layout_n,
@@ -420,7 +432,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        k_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
-                       attn_batches);
+                       attn_batches,
+                       flags);
   // Input Linear Dgrad
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                        CUBLAS_OP_N,
