Skip to content

Commit f922d21

Browse files
authored
Merge pull request #195 from iotamudelta/fp16_fixes
hgemm
2 parents fe465db + 8b2da17 commit f922d21

File tree

1 file changed: 8 additions, 0 deletions

aten/src/THC/THCBlas.cu

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -285,6 +285,13 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int6
285 | 285 |   cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
286 | 286 |   cublasSetStream(handle, THCState_getCurrentStream(state));
287 | 287 |
    | 288 | + #ifdef __HIP_PLATFORM_HCC__
    | 289 | + THCublasCheck(rocblas_hgemm(handle, opa, opb, i_m, i_n, i_k,
    | 290 | + reinterpret_cast<rocblas_half*>(&alpha), reinterpret_cast<rocblas_half*>(a), i_lda,
    | 291 | + reinterpret_cast<rocblas_half*>(b), i_ldb, reinterpret_cast<rocblas_half*>(&beta),
    | 292 | + reinterpret_cast<rocblas_half*>(c), i_ldc));
    | 293 | + #else
    | 294 | +
288 | 295 |   // Simulated Hgemm
289 | 296 |   float fAlpha = THC_half2float(alpha);
290 | 297 |   float fBeta = THC_half2float(beta);
@@ -314,6 +321,7 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int6
314 | 321 |   a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
315 | 322 |   i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
316 | 323 |   }
    | 324 | + #endif
317 | 325 |   #endif
318 | 326 |   return;
319 | 327 |   }

0 commit comments

Comments
 (0)