@@ -378,7 +378,6 @@ void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, i
378
378
}
379
379
#endif
380
380
381
- #ifdef __HIP_PLATFORM_HCC__
382
381
void THCudaBlas_SgemmBatched (THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
383
382
float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
384
383
float beta, float *c[], int64_t ldc, int64_t batchCount)
@@ -389,23 +388,15 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t
389
388
" with the bound [val] <= %d" , INT_MAX);
390
389
}
391
390
391
+ #ifdef __HIP_PLATFORM_HCC__
392
+
392
393
const int64_t stridea = (transa == ' N' || transa == ' n' ) ? lda*k : lda*n;
393
394
const int64_t strideb = (transb == ' N' || transb == ' n' ) ? ldb*n : ldb*k;
394
395
const int64_t stridec = ldc*n;
395
396
396
397
THCudaBlas_SgemmStridedBatched (state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);
397
398
398
- }
399
399
#else
400
- void THCudaBlas_SgemmBatched (THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
401
- float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
402
- float beta, float *c[], int64_t ldc, int64_t batchCount)
403
- {
404
- if ( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
405
- {
406
- THError (" Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
407
- " with the bound [val] <= %d" , INT_MAX);
408
- }
409
400
410
401
adjustLdLevel3 (transa, transb, m, n, k, &lda, &ldb, &ldc);
411
402
cublasOperation_t opa = convertTransToCublasOperation (transa);
@@ -417,8 +408,8 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t
417
408
opa, opb, (int )m, (int )n, (int )k,
418
409
&alpha, a, (int )lda, b, (int )ldb, &beta, c, (int )ldc,
419
410
(int )batchCount));
420
- }
421
411
#endif
412
+ }
422
413
423
414
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
424
415
void THCudaBlas_SgemmStridedBatched (THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
@@ -445,7 +436,6 @@ void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, i
445
436
}
446
437
#endif
447
438
448
- #ifdef __HIP_PLATFORM_HCC__
449
439
void THCudaBlas_DgemmBatched (THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
450
440
double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
451
441
double beta, double *c[], int64_t ldc, int64_t batchCount)
@@ -456,23 +446,15 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t
456
446
" with the bound [val] <= %d" , INT_MAX);
457
447
}
458
448
449
+ #ifdef __HIP_PLATFORM_HCC__
450
+
459
451
const int64_t stridea = (transa == ' N' || transa == ' n' ) ? lda*k : lda*n;
460
452
const int64_t strideb = (transb == ' N' || transb == ' n' ) ? ldb*n : ldb*k;
461
453
const int64_t stridec = ldc*n;
462
-
454
+
463
455
THCudaBlas_DgemmStridedBatched (state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);
464
456
465
- }
466
457
#else
467
- void THCudaBlas_DgemmBatched (THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
468
- double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
469
- double beta, double *c[], int64_t ldc, int64_t batchCount)
470
- {
471
- if ( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
472
- {
473
- THError (" Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
474
- " with the bound [val] <= %d" , INT_MAX);
475
- }
476
458
477
459
adjustLdLevel3 (transa, transb, m, n, k, &lda, &ldb, &ldc);
478
460
cublasOperation_t opa = convertTransToCublasOperation (transa);
@@ -484,8 +466,8 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t
484
466
opa, opb, (int )m, (int )n, (int )k,
485
467
&alpha, a, (int )lda, b, (int )ldb, &beta, c, (int )ldc,
486
468
(int )batchCount));
487
- }
488
469
#endif
470
+ }
489
471
490
472
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
491
473
void THCudaBlas_DgemmStridedBatched (THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
0 commit comments