Skip to content

Commit e43f13f

Browse files
authored
Remove hipblaslt on Windows (#1)
* Remove hipblaslt on Windows * Additional changes with hipblaslt skipped * changed Blas.cpp * fix format
1 parent 534f000 commit e43f13f

File tree

6 files changed

+11
-11
lines changed

6 files changed

+11
-11
lines changed

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
346346
const float fbeta = beta;
347347
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
348348

349-
#if defined(USE_ROCM) && ROCM_VERSION >= 60000
349+
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && !defined(_WIN32)
350350
auto compute_type = CUBLAS_COMPUTE_32F;
351351
#else
352352
auto compute_type = CUDA_R_32F;
@@ -529,7 +529,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
529529
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
530530
}
531531
#endif
532-
#if defined(USE_ROCM) && ROCM_VERSION >= 60000
532+
#if defined(USE_ROCM) && !defined(_WIN32) && ROCM_VERSION >= 60000
533533
auto compute_type = CUBLAS_COMPUTE_32F;
534534
#else
535535
auto compute_type = CUDA_R_32F;
@@ -558,7 +558,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
558558
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
559559
}
560560

561-
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
561+
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && !defined(_WIN32) && ROCM_VERSION >= 50700)
562562

563563
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
564564
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -1102,7 +1102,7 @@ void int8_gemm(
11021102
TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
11031103
#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
11041104
}
1105-
#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
1105+
#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && !defined(_WIN32) && ROCM_VERSION >= 50700)
11061106

11071107
// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
11081108
#if defined(USE_ROCM) && ROCM_VERSION < 50600

aten/src/ATen/cuda/CUDABlas.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
6262
template <>
6363
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
6464

65-
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
65+
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && !defined(_WIN32) && ROCM_VERSION >= 50700)
6666
enum GEMMAndBiasActivationEpilogue {
6767
None,
6868
RELU,

aten/src/ATen/cuda/CUDAContextLight.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
1111
// added bf16 support
12-
#if (!defined(_MSC_VER) && (!defined(USE_ROCM) || ROCM_VERSION >= 50700))
12+
#if (!defined(_MSC_VER) && (!defined(USE_ROCM) && !defined(_WIN32) || ROCM_VERSION >= 50700))
1313
#include <cublasLt.h>
1414
#endif
1515

aten/src/ATen/cuda/CublasHandlePool.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace at::cuda {
2828

2929
namespace {
3030

31-
#if defined(USE_ROCM) && ROCM_VERSION >= 50700
31+
#if defined(USE_ROCM) && !defined(_WIN32) && ROCM_VERSION >= 50700
3232
void createCublasLtHandle(cublasLtHandle_t *handle) {
3333
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
3434
}
@@ -177,7 +177,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
177177
return handle;
178178
}
179179

180-
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
180+
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && !defined(_MSC_VER) && ROCM_VERSION >= 50700)
181181
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
182182
#ifdef USE_ROCM
183183
int device;

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ enum class Activation {
153153
GELU,
154154
};
155155

156-
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
156+
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && !defined(_WIN32) && ROCM_VERSION >= 50700)
157157
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
158158
switch (a) {
159159
case Activation::None:
@@ -320,7 +320,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
320320

321321
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
322322

323-
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
323+
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && !defined(_WIN32) && ROCM_VERSION >= 50700)
324324
if (useLtInterface) {
325325
AT_DISPATCH_FLOATING_TYPES_AND2(
326326
at::ScalarType::Half,

cmake/Dependencies.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1312,7 +1312,7 @@ if(USE_ROCM)
13121312
if(UNIX)
13131313
list(APPEND Caffe2 ${ROCM_ROCTX_LIB})
13141314
endif(UNIX)
1315-
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
1315+
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0" AND NOT WIN32)
13161316
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS ${hipblaslt_LIBRARIES})
13171317
endif()
13181318

0 commit comments

Comments
 (0)