
Commit 6fd8b50

Feature distributed fused adam (#184)
* Update the distributed fused Adam feature from upstream, along with its dependencies (fused Adam, distributed Adam) and the distributed fused Adam unit test case.
* Raise an exception when NCCL user buffers or CUDA graphs are used with distributed fused Adam, and skip the corresponding unit tests.
* Add support for rccl_ub in distributed_fused_adam.
* Build the nccl_allocator module when the cuda_ext flag is given.
1 parent 8051f20 commit 6fd8b50

12 files changed: +4365 additions, -1081 deletions
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Exception.h>
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
#include <torch/extension.h>

#include <nccl.h>

#define NCCL_CHECK(cmd)                                                   \
  do {                                                                    \
    ncclResult_t result = cmd;                                            \
    if (result != ncclSuccess) {                                          \
      std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
                        std::to_string(__LINE__) + ", " +                 \
                        std::string(ncclGetErrorString(result));          \
      TORCH_CHECK(false, err);                                            \
    }                                                                     \
  } while (0)

void *nccl_alloc_plug(size_t size, int device, void *stream) {
  void *ptr;
  NCCL_CHECK(ncclMemAlloc(&ptr, size));
  return ptr;
}

void nccl_free_plug(void *ptr, std::size_t size, int device, void *stream) {
  NCCL_CHECK(ncclMemFree(ptr));
}

std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> nccl_allocator;

void maybe_init() {
  if (!nccl_allocator) {
    nccl_allocator = std::make_shared<
        torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator>(
        nccl_alloc_plug, nccl_free_plug);
  }
}

std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
get_nccl_allocator() {
  maybe_init();
  return nccl_allocator;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_nccl_allocator", []() { return get_nccl_allocator(); });
};
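Editor's note: this new extension wraps ncclMemAlloc/ncclMemFree in a torch CUDAPluggableAllocator, so memory allocated through it is NCCL-registered; presumably this is what backs the NCCL/RCCL user-buffer support mentioned in the commit message. Below is a minimal standalone sketch of the same alloc/free hooks outside of PyTorch. It assumes NCCL >= 2.19 (where ncclMemAlloc/ncclMemFree are assumed to be available) and a visible CUDA device; it is illustrative only and not part of this commit.

#include <cstdio>
#include <cstdlib>

#include <cuda_runtime.h>
#include <nccl.h>

// Same error-handling pattern as NCCL_CHECK above, but without TORCH_CHECK.
#define NCCL_CHECK_OR_EXIT(cmd)                                          \
  do {                                                                   \
    ncclResult_t result = (cmd);                                         \
    if (result != ncclSuccess) {                                         \
      std::fprintf(stderr, "NCCL error %s:%d: %s\n", __FILE__, __LINE__, \
                   ncclGetErrorString(result));                          \
      std::exit(EXIT_FAILURE);                                           \
    }                                                                    \
  } while (0)

int main() {
  cudaSetDevice(0);
  void *ptr = nullptr;
  // What nccl_alloc_plug does: allocate NCCL-registered device memory.
  NCCL_CHECK_OR_EXIT(ncclMemAlloc(&ptr, 1 << 20));
  std::printf("allocated 1 MiB at %p via ncclMemAlloc\n", ptr);
  // What nccl_free_plug does: release it through NCCL.
  NCCL_CHECK_OR_EXIT(ncclMemFree(ptr));
  return 0;
}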
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
// This file is used to check the version of NCCL detected.
#include <tuple>

#include <torch/extension.h>

std::tuple<int, int> get_nccl_version();

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_nccl_version", &get_nccl_version);
}
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

// This file is used to check the version of NCCL detected.
#include <tuple>
#include <nccl.h>

std::tuple<int, int> get_nccl_version() {
  return { int(NCCL_MAJOR), int(NCCL_MINOR) };
}
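Editor's note: these two small files expose the NCCL version the extension was compiled against as a (major, minor) tuple. A plausible use is gating features such as ncclMemAlloc-backed user buffers on the detected version. The sketch below shows both a compile-time check using NCCL's own version macros and a runtime check via ncclGetVersion; the 2.19 threshold and the helper names are assumptions for illustration, not taken from this commit.

#include <nccl.h>

// Compile-time gate: true if the nccl.h we built against is assumed to declare
// ncclMemAlloc/ncclMemFree (assumed to require NCCL >= 2.19).
bool nccl_mem_alloc_available_at_build() {
#if defined(NCCL_VERSION_CODE) && (NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0))
  return true;
#else
  return false;
#endif
}

// Runtime gate: ask the loaded library for its version, mirroring how a
// (major, minor) tuple from get_nccl_version() might be consumed.
bool nccl_mem_alloc_available_at_runtime() {
  int version = 0;
  if (ncclGetVersion(&version) != ncclSuccess) {
    return false;
  }
  // NCCL encodes 2.19.0 as 2 * 10000 + 19 * 100 + 0 = 21900.
  return version >= 21900;
}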

apex/contrib/csrc/optimizers/fused_adam_cuda.cpp

Lines changed: 7 additions & 7 deletions
@@ -76,11 +76,11 @@ void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_ou
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
-  m.def("adam", &adam, "Adam optimized CUDA implementation.");
-  m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.");
-  m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
-  m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
-  m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
-  m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
+  m.def("strided_check_finite", &strided_check_finite, "Strided finite check.", py::call_guard<py::gil_scoped_release>());
+  m.def("adam", &adam, "Adam optimized CUDA implementation.", py::call_guard<py::gil_scoped_release>());
+  m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.", py::call_guard<py::gil_scoped_release>());
+  m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.", py::call_guard<py::gil_scoped_release>());
+  m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.", py::call_guard<py::gil_scoped_release>());
+  m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.", py::call_guard<py::gil_scoped_release>());
+  m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.", py::call_guard<py::gil_scoped_release>());
 }
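Editor's note: the only change here is appending py::call_guard<py::gil_scoped_release>() to every binding, so the Python GIL is released while the C++/CUDA work runs and re-acquired on return; other Python threads (data loading, communication hooks, etc.) can make progress in the meantime. A minimal pybind11 sketch of the pattern is shown below, with a hypothetical function standing in for a kernel launch; it is not part of this commit.

#include <chrono>
#include <thread>

#include <torch/extension.h>

// Stand-in for a long-running C++/CUDA call. Because the binding below installs
// py::gil_scoped_release as a call guard, the GIL is dropped for the duration
// of this function and other Python threads can run concurrently.
void long_running_launch() {
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("long_running_launch", &long_running_launch,
        "Example binding that releases the GIL while the C++ body runs.",
        py::call_guard<py::gil_scoped_release>());
}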
Lines changed: 30 additions & 14 deletions
@@ -1,20 +1,36 @@
 #include <torch/extension.h>
 
 void multi_tensor_fused_adam_cuda(
-    int chunk_size,
-    at::Tensor noop_flag,
-    std::vector<std::vector<at::Tensor>> tensor_lists,
-    at::Tensor per_tensor_beta1,
-    at::Tensor per_tensor_beta2,
-    at::Tensor per_tensor_bias_correction,
-    at::Tensor per_tensor_eps,
-    at::Tensor per_tensor_weight_decay,
-    float lr,
-    float grad_scale,
-    int step,
-    int mode);
+    int chunk_size, at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor grad_scale,
+    float lr, float beta1, float beta2, float eps, int step, int mode,
+    int bias_correction, float weight_decay);
+
+void multi_tensor_fused_adam_capturable_cuda(
+    int chunk_size, at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor grad_scale,
+    at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step,
+    int mode, int bias_correction, float weight_decay);
+
+void multi_tensor_fused_adam_with_param_remainders_cuda(
+    int chunk_size, at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor grad_scale,
+    float lr, float beta1, float beta2, float eps, int step, int mode,
+    int bias_correction, float weight_decay);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda,
-        "Multi tensor Adam optimized CUDA implementation.");
-}
+        "CUDA kernels for multi-tensor Adam, "
+        "with param copy",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("multi_tensor_fused_adam_capturable",
+        &multi_tensor_fused_adam_capturable_cuda,
+        "CUDA kernels for multi-tensor Adam, "
+        "with param copy, capturable for CUDA graph",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("multi_tensor_fused_adam_with_param_remainders",
+        &multi_tensor_fused_adam_with_param_remainders_cuda,
+        "CUDA kernel for multi-tensor Adam, "
+        "with stored param remainders and param copy",
+        py::call_guard<py::gil_scoped_release>());
+}
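Editor's note: the header now declares three entry points instead of one: a fused Adam kernel with flattened scalar hyperparameters, a capturable variant whose lr and step are device tensors, and a variant operating on stored parameter remainders (presumably the low-order bits needed to reconstruct an FP32 master weight from a BF16 parameter). Taking lr and step as tensors is what makes the capturable variant CUDA-graph friendly: values read from device memory can be updated in place between graph replays, whereas host scalars are frozen into the captured launch. The toy CPU sketch below illustrates only that signature rationale with plain tensor ops; the function name and logic are hypothetical and are not the apex kernels.

#include <iostream>

#include <torch/torch.h>

// Toy "capturable style" update: every operand is a tensor, so nothing in the
// body depends on host scalars and the same call sequence could, in principle,
// be captured once and replayed with different hyperparameter values.
void toy_capturable_step(at::Tensor param, at::Tensor grad,
                         at::Tensor lr, at::Tensor step) {
  step += 1;
  param -= lr * grad;
}

int main() {
  auto opts = torch::TensorOptions().dtype(torch::kFloat32);
  auto param = torch::ones({4}, opts);
  auto grad = torch::full({4}, 0.5, opts);
  auto lr = torch::full({}, 0.1, opts);  // 0-dim tensor, like the capturable lr
  auto step = torch::zeros({}, opts);    // 0-dim tensor, like the capturable step

  toy_capturable_step(param, grad, lr, step);
  lr.fill_(0.01);                              // update the hyperparameter in place...
  toy_capturable_step(param, grad, lr, step);  // ...without rebuilding the call
  std::cout << param << std::endl << step << std::endl;
  return 0;
}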
