We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents fe465db + 8b2da17 · commit f922d21 · Copy full SHA for f922d21
aten/src/THC/THCBlas.cu
@@ -285,6 +285,13 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int6
285
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
286
cublasSetStream(handle, THCState_getCurrentStream(state));
287
288
+#ifdef __HIP_PLATFORM_HCC__
289
+ THCublasCheck(rocblas_hgemm(handle, opa, opb, i_m, i_n, i_k,
290
+ reinterpret_cast<rocblas_half*>(&alpha), reinterpret_cast<rocblas_half*>(a), i_lda,
291
+ reinterpret_cast<rocblas_half*>(b), i_ldb, reinterpret_cast<rocblas_half*>(&beta),
292
+ reinterpret_cast<rocblas_half*>(c), i_ldc));
293
+#else
294
+
295
// Simulated Hgemm
296
float fAlpha = THC_half2float(alpha);
297
float fBeta = THC_half2float(beta);
@@ -314,6 +321,7 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int6
314
321
a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
315
322
i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
316
323
}
324
+#endif
317
325
#endif
318
326
return;
319
327
0 commit comments