Commit ccea031

jeffdaily authored and pragupta committed
cublaslt/hipblaslt persistent workspace (pytorch#156495)
Similar to cublas/hipblas, LT now allocates one workspace per handle+stream combo.

- fixes hipblaslt issue where memory use increased during graph capture
- preserves CUDA env var TORCH_CUBLASLT_UNIFIED_WORKSPACE
- moves LT workspace and size from CUDABlas.cpp into CublasHandlePool.cpp, new APIs
  - size_t getCUDABlasLtWorkspaceSize()
  - void* getCUDABlasLtWorkspace()

Fixes ROCm#2286.

Pull Request resolved: pytorch#156495
Approved by: https://github.com/eqy
(cherry picked from commit 996206e)
1 parent 38abb1a commit ccea031

File tree

4 files changed: +102, -109 lines changed


aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 6 additions & 79 deletions
@@ -188,82 +188,11 @@ uint32_t _getAlignment(uintptr_t address) {
 }
 #endif
 
-static size_t _parseChosenWorkspaceSize() {
-  auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE");
-#ifdef USE_ROCM
-  if (!val.has_value()) {
-    // accept either env var
-    val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE");
-  }
-  size_t workspace_size = 76*1024; /* Use 76 MB for hipBLASLt */
-#else
-  size_t workspace_size = 1024; /* default size in KiB according to #73328 */
-#endif
-
-  if (val.has_value()) {
-    try {
-      workspace_size = std::stoi(val.value());
-    } catch (std::invalid_argument const&) {
-      TORCH_WARN(
-          "invalid CUBLASLT_WORKSPACE_SIZE,",
-          " using default workspace size of ",
-          workspace_size,
-          " KiB.");
-    } catch (std::out_of_range const&) {
-      TORCH_WARN(
-          "CUBLASLT_WORKSPACE_SIZE out of range,",
-          " using default workspace size of ",
-          workspace_size,
-          " KiB.");
-    }
-  }
-  return workspace_size * 1024;
-}
-
-static size_t _getWorkspaceSize() {
-  static size_t workspace_size = _parseChosenWorkspaceSize();
-  return workspace_size;
-}
-
-void* _getUnifiedWorkspaceWithoutHandle() {
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  auto stream = c10::cuda::getCurrentCUDAStream();
-  cudaStream_t _stream = stream;
-  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
-  auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
-  TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
-  return workspace_it->second.mutable_get();
-}
-
 struct CublasLtWorkspace {
   CublasLtWorkspace() {
-    size = _getWorkspaceSize();
-#ifndef USE_ROCM
-    static bool unified = c10::utils::check_env("TORCH_CUBLASLT_UNIFIED_WORKSPACE") == true;
-    if (unified) {
-      auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize();
-      if (cublasWorkspaceSize < size) {
-        TORCH_WARN_ONCE("Requested unified CUBLASLT workspace size of ", size,
-                        " bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize,
-                        " bytes. Please increase CUBLAS workspace size",
-                        " via CUBLAS_WORKSPACE_CONFIG or decrease requested"
-                        " CUBLASLT_WORKSPACE_SIZE. Otherwise CUBLASLT workspace"
-                        " size will be limited to the CUBLAS workspace size.");
-        size = cublasWorkspaceSize;
-      }
-      ptr = _getUnifiedWorkspaceWithoutHandle();
-    } else {
-      auto allocator = c10::cuda::CUDACachingAllocator::get();
-      stashed_ptr_ = allocator->allocate(size);
-      ptr = stashed_ptr_.mutable_get();
-    }
-#else
-    auto allocator = c10::cuda::CUDACachingAllocator::get();
-    stashed_ptr_ = allocator->allocate(size);
-    ptr = stashed_ptr_.mutable_get();
-#endif
+    size = at::cuda::getCUDABlasLtWorkspaceSize();
+    ptr = at::cuda::getCUDABlasLtWorkspace();
   }
-  at::DataPtr stashed_ptr_;
   void * ptr;
   size_t size;
 };

@@ -2111,10 +2040,8 @@ void int8_gemm(
 
 #ifdef USE_ROCM
   CuBlasLtMatmulPreference preference;
-  size_t workspaceSize = _getWorkspaceSize();
-  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
-  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
-  auto workspace = allocator.allocate(workspaceSize);
+  auto ltworkspace = CublasLtWorkspace();
+  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(

@@ -2152,12 +2079,12 @@ void int8_gemm(
       nullptr, // Heuristics don't seem to work for int8
 #endif
 #ifdef USE_ROCM
-      workspace.mutable_get(),
+      ltworkspace.ptr,
 #else
       nullptr, // Non-zero workspace doesn't seem to work.
 #endif
 #ifdef USE_ROCM
-      workspaceSize,
+      ltworkspace.size,
 #else
       0,
 #endif
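The helper removed here interpreted the workspace-size environment variable in KiB (defaulting to 1024 KiB for cuBLASLt and 76*1024 KiB for hipBLASLt) and fell back to the default on malformed values; the same rules now live in CublasHandlePool.cpp as parseCUDABlasLtWorkspaceSize(). The standalone sketch below mirrors only those parsing rules; it uses std::getenv in place of c10::utils::get_env, and its function names are illustrative, not part of this change.

#include <cstdlib>
#include <iostream>
#include <optional>
#include <string>

// Stand-in for c10::utils::get_env.
static std::optional<std::string> get_env(const char* name) {
  const char* v = std::getenv(name);
  return v ? std::optional<std::string>(v) : std::nullopt;
}

// Mirrors the parsing rules in the diff: the value is a size in KiB, ROCm
// accepts either env var, and a malformed value falls back to the default.
static size_t lt_workspace_bytes() {
  auto val = get_env("CUBLASLT_WORKSPACE_SIZE");
#ifdef USE_ROCM
  if (!val.has_value()) {
    val = get_env("HIPBLASLT_WORKSPACE_SIZE");  // either env var is accepted
  }
  size_t workspace_size = 76 * 1024;            // 76 MiB default for hipBLASLt
#else
  size_t workspace_size = 1024;                 // 1 MiB default for cuBLASLt
#endif
  if (val.has_value()) {
    try {
      workspace_size = std::stoul(val.value());
    } catch (const std::exception&) {
      std::cerr << "invalid size, using default of " << workspace_size << " KiB\n";
    }
  }
  return workspace_size * 1024;                 // KiB -> bytes
}

int main() {
  std::cout << lt_workspace_bytes() << " bytes\n";
}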

aten/src/ATen/cuda/CUDAContextLight.h

Lines changed: 3 additions & 0 deletions
@@ -89,7 +89,10 @@ TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
 
 TORCH_CUDA_CPP_API void clearCublasWorkspaces();
 TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
+TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
 TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
+TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
+TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();
 
 #if defined(CUDART_VERSION) || defined(USE_ROCM)
 TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
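A minimal sketch (assuming an ATen/CUDA build with a current device and stream) of how code outside CUDABlas.cpp might consume the two new accessors declared above. The helper name is illustrative and not part of this change; the returned pointer is owned by the handle pool, so callers never free it.

#include <utility>
#include <ATen/cuda/CUDAContextLight.h>

// Illustrative helper: query the pooled cuBLASLt workspace for the current
// handle + stream combination.
std::pair<void*, size_t> currentLtWorkspace() {
  size_t bytes = at::cuda::getCUDABlasLtWorkspaceSize(); // honors CUBLASLT_WORKSPACE_SIZE / HIPBLASLT_WORKSPACE_SIZE
  void* ptr = at::cuda::getCUDABlasLtWorkspace();        // allocated once per handle+stream, then reused
  return {ptr, bytes};                                   // pool retains ownership; do not free ptr
}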

aten/src/ATen/cuda/CublasHandlePool.cpp

Lines changed: 91 additions & 0 deletions
@@ -23,6 +23,9 @@
  * To work around this difference in behavior, a separate handle pool is available for ROCm builds.
  * For CUDA builds, getCurrentCUDABlasLtHandle will alias for getCurrentCUDABlasHandle,
  * whereas for ROCm builds, it is a distinct function.
+ *
+ * The workspace pools are separate for ROCm. On CUDA, the env var
+ * TORCH_CUBLASLT_UNIFIED_WORKSPACE can be used to opt-in to unifying the workspace pools.
  */
 
 namespace at::cuda {

@@ -109,8 +112,14 @@ std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
   return instance;
 }
 
+std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
+  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
+  return instance;
+}
+
 void clearCublasWorkspaces() {
   cublas_handle_stream_to_workspace().clear();
+  cublaslt_handle_stream_to_workspace().clear();
 }
 
 size_t parseChosenWorkspaceSize() {

@@ -157,15 +166,97 @@ size_t parseChosenWorkspaceSize() {
   }
 }
 
+size_t parseCUDABlasLtWorkspaceSize() {
+  auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE");
+#ifdef USE_ROCM
+  if (!val.has_value()) {
+    // accept either env var
+    val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE");
+  }
+  size_t workspace_size = 76*1024; /* Use 76 MB for hipBLASLt */
+#else
+  size_t workspace_size = 1024; /* default size in KiB according to #73328 */
+#endif
+
+  if (val.has_value()) {
+    try {
+      workspace_size = std::stoi(val.value());
+    } catch (std::invalid_argument const&) {
+      TORCH_WARN(
+          "invalid CUBLASLT_WORKSPACE_SIZE,",
+          " using default workspace size of ",
+          workspace_size,
+          " KiB.");
+    } catch (std::out_of_range const&) {
+      TORCH_WARN(
+          "CUBLASLT_WORKSPACE_SIZE out of range,",
+          " using default workspace size of ",
+          workspace_size,
+          " KiB.");
+    }
+  }
+  return workspace_size * 1024;
+}
+
 size_t getChosenWorkspaceSize() {
   size_t pool_size = parseChosenWorkspaceSize();
   return pool_size;
 }
 
+#define TORCH_CUBLASLT_UNIFIED_WORKSPACE "TORCH_CUBLASLT_UNIFIED_WORKSPACE"
+
+size_t getCUDABlasLtWorkspaceSize() {
+  size_t pool_size = parseCUDABlasLtWorkspaceSize();
+#ifndef USE_ROCM
+  static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
+  if (unified) {
+    auto cublasWorkspaceSize = getChosenWorkspaceSize();
+    if (cublasWorkspaceSize < pool_size) {
+      TORCH_WARN_ONCE("Requested unified CUBLASLT workspace size of ", pool_size,
+                      " bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize,
+                      " bytes. Please increase CUBLAS workspace size",
+                      " via CUBLAS_WORKSPACE_CONFIG or decrease requested"
+                      " CUBLASLT_WORKSPACE_SIZE. Otherwise CUBLASLT workspace"
+                      " size will be limited to the CUBLAS workspace size.");
+      pool_size = cublasWorkspaceSize;
+    }
+  }
+#endif
+  return pool_size;
+}
+
 at::DataPtr getNewWorkspace() {
   return c10::cuda::CUDACachingAllocator::get()->allocate(getChosenWorkspaceSize());
 }
 
+at::DataPtr getNewCUDABlasLtWorkspace() {
+  return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
+}
+
+void* getCUDABlasLtWorkspace() {
+#ifndef USE_ROCM
+  static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
+  if (unified) {
+    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+    auto stream = c10::cuda::getCurrentCUDAStream();
+    cudaStream_t _stream = stream;
+    auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
+    auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
+    TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
+    return workspace_it->second.mutable_get();
+  }
+#endif
+  cublasLtHandle_t handle = getCurrentCUDABlasLtHandle();
+  auto stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t _stream = stream;
+  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
+  auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
+  if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
+    workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
+  }
+  return workspace_it->second.mutable_get();
+}
+
 cublasHandle_t getCurrentCUDABlasHandle() {
   c10::DeviceIndex device = 0;
   AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
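The heart of the fix is the lazy per-(handle, stream) cache added above: the first request for a given handle+stream pair allocates one workspace through the caching allocator, and every later request, including work replayed during CUDA graph capture, returns the same buffer instead of allocating a new one. The standalone sketch below reproduces that caching pattern with plain C++ types; the real code keys on the cublasLt handle and CUDA stream and stores at::DataPtr values.

#include <cstdio>
#include <map>
#include <memory>
#include <tuple>
#include <vector>

// Stand-ins for a BLAS handle, a CUDA stream, and an allocated workspace.
using Key = std::tuple<void*, void*>;                 // (handle, stream)
using Workspace = std::unique_ptr<std::vector<char>>; // owned buffer

std::map<Key, Workspace>& handle_stream_to_workspace() {
  static auto& instance = *new std::map<Key, Workspace>; // leaked on purpose, like the real pool
  return instance;
}

void* get_workspace(void* handle, void* stream, size_t bytes) {
  auto key = std::make_tuple(handle, stream);
  auto it = handle_stream_to_workspace().find(key);
  if (it == handle_stream_to_workspace().end()) {
    // First use of this handle+stream pair: allocate once and cache it.
    it = handle_stream_to_workspace().emplace_hint(
        it, key, std::make_unique<std::vector<char>>(bytes));
  }
  // Every later call for the same pair reuses the cached buffer.
  return it->second->data();
}

int main() {
  int handle = 0, stream = 0;
  void* a = get_workspace(&handle, &stream, 1 << 20);
  void* b = get_workspace(&handle, &stream, 1 << 20);
  std::printf("same buffer reused: %s\n", a == b ? "yes" : "no"); // prints "yes"
}

Because the map outlives individual calls, graph capture no longer drives repeated allocations; clearCublasWorkspaces() is the explicit release point for both the cuBLAS and cuBLASLt pools.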

aten/src/ATen/cuda/tunable/GemmHipblaslt.h

Lines changed: 2 additions & 30 deletions
@@ -381,28 +381,6 @@ static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
   return HIPBLAS_OP_T;
 }
 
-static size_t GetHipblasltWorkspaceSize() {
-  static const auto env = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE");
-  // 256MB is max workspace size allowed for hipblaslt
-  // hipblaslt-bench uses 32MB
-  // recommendation from hipblaslt author was 76MB
-  // TunableOp hipBLASLt workspace size is aligned with
-  // PyTorch's default in CUDABlas.cpp (_parseChosenWorkspaceSize)
-  size_t workspace_size = 76*1024;
-  if (env) {
-    try {
-      workspace_size = std::stoi(env.value());
-    } catch(std::invalid_argument const& e) {
-      TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
-                 " using default workspace size of ", workspace_size, " KiB.");
-    } catch(std::out_of_range const& e) {
-      TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
-                 " using default workspace size of ", workspace_size, " KiB.");
-    }
-  }
-  return workspace_size * 1024;
-}
-
 template <typename T, cublasStatus_t (*destructor)(T*)>
 struct HipBlasLtDeleter {
   void operator()(T* x) {

@@ -550,7 +528,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
      }
    }
 
-    size_t workspace_size = GetHipblasltWorkspaceSize();
+    size_t workspace_size = at::cuda::getCUDABlasLtWorkspaceSize();
 
     auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
 

@@ -575,10 +553,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
      return FAIL;
    }
 
-    void* workspace_buffer = nullptr;
-    if (workspace_size > 0) {
-      workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
-    }
+    void* workspace_buffer = at::cuda::getCUDABlasLtWorkspace();
 
    TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
                                          matmul.descriptor(),

@@ -601,9 +576,6 @@ class HipblasltGemmOp : public Callable<ParamsT> {
    TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
    TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
    TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
-    if (workspace_size > 0) {
-      c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
-    }
    return OK;
  }
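Net effect on the TunableOp hipBLASLt path: the workspace is no longer raw_alloc'd and raw_delete'd around every hipblasLtMatmul call; instead it is borrowed from the persistent pool. A sketch of the resulting calling convention (illustrative; the matmul setup and remaining arguments are elided):

// Size and buffer both come from the shared pool; nothing is allocated or
// freed per GEMM invocation.
size_t workspace_size = at::cuda::getCUDABlasLtWorkspaceSize();
void* workspace_buffer = at::cuda::getCUDABlasLtWorkspace();
// ... hipblasLtMatmul(op_handle, ..., workspace_buffer, workspace_size, stream) ...
// Do not free workspace_buffer: the pool owns it, and clearCublasWorkspaces()
// releases it.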
