Skip to content

Commit 73d48cf

Browse files
authored
Merge pull request #116 from iotamudelta/master
Scope ifdef down as per review.
2 parents 35bfa51 + 6316bd3 commit 73d48cf

File tree

1 file changed

+7
-25
lines changed

1 file changed

+7
-25
lines changed

aten/src/THC/THCBlas.cu

Lines changed: 7 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -378,7 +378,6 @@ void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, i
378378
}
379379
#endif
380380

381-
#ifdef __HIP_PLATFORM_HCC__
382381
void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
383382
float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
384383
float beta, float *c[], int64_t ldc, int64_t batchCount)
@@ -389,23 +388,15 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t
389388
"with the bound [val] <= %d", INT_MAX);
390389
}
391390

391+
#ifdef __HIP_PLATFORM_HCC__
392+
392393
const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
393394
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
394395
const int64_t stridec = ldc*n;
395396

396397
THCudaBlas_SgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);
397398

398-
}
399399
#else
400-
void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
401-
float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
402-
float beta, float *c[], int64_t ldc, int64_t batchCount)
403-
{
404-
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
405-
{
406-
THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
407-
"with the bound [val] <= %d", INT_MAX);
408-
}
409400

410401
adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
411402
cublasOperation_t opa = convertTransToCublasOperation(transa);
@@ -417,8 +408,8 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t
417408
opa, opb, (int)m, (int)n, (int)k,
418409
&alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
419410
(int)batchCount));
420-
}
421411
#endif
412+
}
422413

423414
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
424415
void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
@@ -445,7 +436,6 @@ void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, i
445436
}
446437
#endif
447438

448-
#ifdef __HIP_PLATFORM_HCC__
449439
void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
450440
double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
451441
double beta, double *c[], int64_t ldc, int64_t batchCount)
@@ -456,23 +446,15 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t
456446
"with the bound [val] <= %d", INT_MAX);
457447
}
458448

449+
#ifdef __HIP_PLATFORM_HCC__
450+
459451
const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
460452
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
461453
const int64_t stridec = ldc*n;
462-
454+
463455
THCudaBlas_DgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);
464456

465-
}
466457
#else
467-
void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
468-
double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
469-
double beta, double *c[], int64_t ldc, int64_t batchCount)
470-
{
471-
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
472-
{
473-
THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
474-
"with the bound [val] <= %d", INT_MAX);
475-
}
476458

477459
adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
478460
cublasOperation_t opa = convertTransToCublasOperation(transa);
@@ -484,8 +466,8 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t
484466
opa, opb, (int)m, (int)n, (int)k,
485467
&alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
486468
(int)batchCount));
487-
}
488469
#endif
470+
}
489471

490472
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
491473
void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,

0 commit comments

Comments
 (0)