Skip to content

Feature distributed fused adam #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Exception.h>
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
#include <torch/extension.h>

#include <nccl.h>

#define NCCL_CHECK(cmd) \
do { \
ncclResult_t result = cmd; \
if (result != ncclSuccess) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + \
std::string(ncclGetErrorString(result)); \
TORCH_CHECK(false, err); \
} \
} while (0)

void *nccl_alloc_plug(size_t size, int device, void *stream) {
void *ptr;
NCCL_CHECK(ncclMemAlloc(&ptr, size));
return ptr;
}

void nccl_free_plug(void *ptr, std::size_t size, int device, void *stream) {
NCCL_CHECK(ncclMemFree(ptr));
}

std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> nccl_allocator;

void maybe_init() {
if (!nccl_allocator) {
nccl_allocator = std::make_shared<
torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator>(
nccl_alloc_plug, nccl_free_plug);
}
}

std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
get_nccl_allocator() {
maybe_init();
return nccl_allocator;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_nccl_allocator", []() { return get_nccl_allocator(); });
};
11 changes: 11 additions & 0 deletions apex/contrib/csrc/nccl_p2p/nccl_version.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
// This file is used to check the version of NCCL detected.
#include <tuple>

#include <torch/extension.h>

std::tuple<int, int> get_nccl_version();

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_nccl_version", &get_nccl_version);
}
10 changes: 10 additions & 0 deletions apex/contrib/csrc/nccl_p2p/nccl_version_check.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

// This file is used to check the version of NCCL detected.
#include <tuple>
#include <nccl.h>


std::tuple<int, int> get_nccl_version() {
return { int(NCCL_MAJOR), int(NCCL_MINOR) };
}
14 changes: 7 additions & 7 deletions apex/contrib/csrc/optimizers/fused_adam_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_ou
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
m.def("strided_check_finite", &strided_check_finite, "Strided finite check.", py::call_guard<py::gil_scoped_release>());
m.def("adam", &adam, "Adam optimized CUDA implementation.", py::call_guard<py::gil_scoped_release>());
m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.", py::call_guard<py::gil_scoped_release>());
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.", py::call_guard<py::gil_scoped_release>());
m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.", py::call_guard<py::gil_scoped_release>());
m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.", py::call_guard<py::gil_scoped_release>());
m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.", py::call_guard<py::gil_scoped_release>());
}
44 changes: 30 additions & 14 deletions apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
#include <torch/extension.h>

void multi_tensor_fused_adam_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_bias_correction,
at::Tensor per_tensor_eps,
at::Tensor per_tensor_weight_decay,
float lr,
float grad_scale,
int step,
int mode);
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor grad_scale,
float lr, float beta1, float beta2, float eps, int step, int mode,
int bias_correction, float weight_decay);

void multi_tensor_fused_adam_capturable_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor grad_scale,
at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step,
int mode, int bias_correction, float weight_decay);

void multi_tensor_fused_adam_with_param_remainders_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor grad_scale,
float lr, float beta1, float beta2, float eps, int step, int mode,
int bias_correction, float weight_decay);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda,
"Multi tensor Adam optimized CUDA implementation.");
}
"CUDA kernels for multi-tensor Adam, "
"with param copy",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_fused_adam_capturable",
&multi_tensor_fused_adam_capturable_cuda,
"CUDA kernels for multi-tensor Adam, "
"with param copy, capturable for CUDA graph",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_fused_adam_with_param_remainders",
&multi_tensor_fused_adam_with_param_remainders_cuda,
"CUDA kernel for multi-tensor Adam, "
"with stored param remainders and param copy",
py::call_guard<py::gil_scoped_release>());
}
Loading