From ec8969ba32824a24a4e452b49b094f44e338e2a3 Mon Sep 17 00:00:00 2001
From: Jeff Daily
Date: Wed, 26 Jan 2022 14:06:48 -0800
Subject: [PATCH] rocblas alt impl during backward pass only

---
 aten/src/ATen/Context.cpp       | 14 ++++++++++++++
 aten/src/ATen/Context.h         |  8 ++++++++
 aten/src/ATen/cuda/CUDABlas.cpp | 17 +++++++++++++++--
 torch/csrc/autograd/function.h  |  3 +++
 4 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 98590b266be402..4a6da80ad32499 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -293,6 +293,20 @@ bool NoTF32Guard::should_disable_tf32() {
   return override_allow_tf32_flag;
 }
 
+thread_local bool BackwardPassGuard::is_backward_pass_;
+
+BackwardPassGuard::BackwardPassGuard() {
+  is_backward_pass_ = true;
+}
+
+BackwardPassGuard::~BackwardPassGuard() {
+  is_backward_pass_ = false;
+}
+
+bool BackwardPassGuard::is_backward_pass() {
+  return is_backward_pass_;
+}
+
 bool Context::areVmapFallbackWarningsEnabled() const {
   return display_vmap_fallback_warnings_;
 }
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 88cbc3ec0bb3a1..161aeb7b2e1683 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -387,4 +387,12 @@ struct TORCH_API NoTF32Guard {
   bool changed = false;
 };
 
+struct TORCH_API BackwardPassGuard {
+  BackwardPassGuard();
+  ~BackwardPassGuard();
+  static bool is_backward_pass();
+private:
+  static thread_local bool is_backward_pass_;
+};
+
 } // namespace at
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 5e795396d7dbe5..4f74130185918d 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -15,6 +15,11 @@
 #include
 #endif
 
+#ifdef USE_ROCM
+#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
+#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
+#endif
+
 #define CUDABLAS_POSINT_CHECK(FD, X) \
   TORCH_CHECK(                       \
       (X > 0 && X <= INT_MAX),       \
@@ -246,13 +251,17 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
   float falpha = alpha;
   float fbeta = beta;
 #ifdef USE_ROCM
+  int flag = 0;
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
   TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
                                    (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
                                    b, rocblas_datatype_f16_r, (int)ldb, strideb,
                                    (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                    c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                    (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
-                                   0, 0));
+                                   0, flag));
 #else
 #if defined(CUDA_VERSION) && CUDA_VERSION < 11000
   // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
@@ -392,6 +401,10 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
   _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
   GEMM_CHECK_ARGVALUES(at::Half);
 #ifdef USE_ROCM
+  int flag = 0;
+#if USE_GEMM_FLAGS_FP16_ALT_IMPL
+  flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
+#endif
   TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
       handle,
       opa,
@@ -416,7 +429,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
       rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard,
       0,
-      0));
+      flag));
 #else
   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
   if (prop->major >= 5) {
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index cc5fa59e9ed6a2..e258cbf4b6588d 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -151,6 +151,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     // probably operate with names.
     at::NoNamesGuard no_names_guard;
 
+    // Keep track of backward pass for rocblas.
+    at::BackwardPassGuard in_backward;
+
     bool pre_sampled = false;
    if (at::shouldRunRecordFunction(&pre_sampled)) {
      // Using RecordFunction to trigger observers in the backward pass
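
Note (illustration only, not part of the patch): the change hinges on a thread-local RAII guard. The autograd Node call operator constructs an at::BackwardPassGuard, so any rocBLAS fp16 GEMM reached from the engine on that thread sees is_backward_pass() == true and requests the alternate implementation. Below is a minimal standalone C++ sketch of that pattern; select_gemm_flag and the flag value are illustrative stand-ins, not the real rocBLAS API.

// backward_guard_sketch.cpp -- illustrative only; build with: g++ -std=c++17 backward_guard_sketch.cpp
#include <cstdio>

// Thread-local RAII guard mirroring at::BackwardPassGuard from the patch:
// constructing it marks the current thread as running a backward pass,
// destroying it clears the mark.
struct BackwardPassGuard {
  BackwardPassGuard()  { is_backward_pass_ = true; }
  ~BackwardPassGuard() { is_backward_pass_ = false; }
  static bool is_backward_pass() { return is_backward_pass_; }
 private:
  static thread_local bool is_backward_pass_;
};
thread_local bool BackwardPassGuard::is_backward_pass_ = false;

// Stand-in for the flag selection added to gemm/bgemm: request the fp16
// alternate implementation only while a backward pass is running.
// The value is a placeholder, not the real rocblas_gemm_flags_fp16_alt_impl.
int select_gemm_flag() {
  constexpr int kAltImplFlag = 0x4;  // placeholder value
  return BackwardPassGuard::is_backward_pass() ? kAltImplFlag : 0;
}

int main() {
  std::printf("forward:  flag=%d\n", select_gemm_flag());    // prints 0
  {
    BackwardPassGuard in_backward;  // as constructed in the Node call operator
    std::printf("backward: flag=%d\n", select_gemm_flag());  // prints the placeholder flag
  }
  std::printf("after:    flag=%d\n", select_gemm_flag());    // prints 0 again
}

Because the state is thread-local, forward-pass GEMMs issued concurrently on other threads are unaffected by a backward pass running elsewhere.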