Skip to content

Scope ifdef down as per review. #116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 11, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 7 additions & 25 deletions aten/src/THC/THCBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,6 @@ void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, i
}
#endif

#ifdef __HIP_PLATFORM_HCC__
void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
float beta, float *c[], int64_t ldc, int64_t batchCount)
Expand All @@ -389,23 +388,15 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t
"with the bound [val] <= %d", INT_MAX);
}

#ifdef __HIP_PLATFORM_HCC__

const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
const int64_t stridec = ldc*n;

THCudaBlas_SgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);

}
#else
void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb,
float beta, float *c[], int64_t ldc, int64_t batchCount)
{
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
{
THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX);
}

adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
cublasOperation_t opa = convertTransToCublasOperation(transa);
Expand All @@ -417,8 +408,8 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t
opa, opb, (int)m, (int)n, (int)k,
&alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
(int)batchCount));
}
#endif
}

#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
Expand All @@ -445,7 +436,6 @@ void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, i
}
#endif

#ifdef __HIP_PLATFORM_HCC__
void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
double beta, double *c[], int64_t ldc, int64_t batchCount)
Expand All @@ -456,23 +446,15 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t
"with the bound [val] <= %d", INT_MAX);
}

#ifdef __HIP_PLATFORM_HCC__

const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
const int64_t stridec = ldc*n;

THCudaBlas_DgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount);

}
#else
void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb,
double beta, double *c[], int64_t ldc, int64_t batchCount)
{
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
{
THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX);
}

adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
cublasOperation_t opa = convertTransToCublasOperation(transa);
Expand All @@ -484,8 +466,8 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t
opa, opb, (int)m, (int)n, (int)k,
&alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
(int)batchCount));
}
#endif
}

#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
Expand Down