Stochastic Rounding Optimizers #17
base: master
Changes from all commits
@@ -0,0 +1,63 @@
#include <ATen/ATen.h>
#include <ATen/native/cuda/stochastic_rounding.cuh>

namespace at {
namespace native {

// Casts `input` to FP16 using stochastic rounding; one random draw per element.
template <typename input_t, typename output_t>
__global__ void stochastic_rounding_kernel(
    const input_t* input,
    output_t* output,
    const int64_t numel,
    std::pair<uint64_t, uint64_t> seed_and_offset) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed_and_offset.first, tid, seed_and_offset.second, &state);

  round_stochastically<output_t, input_t, at::Half> rounder;

  for (int64_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
    output[i] = rounder(input[i], curand_uniform(&state));
  }
}

Tensor stochastic_rounding_cuda(const Tensor& input, c10::optional<Generator> gen_) {

  TORCH_CHECK(input.is_contiguous());

  if (input.scalar_type() == kHalf) {
    return input;
  }

  Tensor output = at::empty_like(input, input.options().dtype(kHalf), input.suggest_memory_format());
  const int64_t numel = input.numel();
  if (numel == 0) {
    return output;
  }

  const int block = 256;
  const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block;
Review comment: This is only correct if the kernel's number of registers per thread is <= 32, otherwise register pressure limits your occupancy. You can recompile kernels with …
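A hedged sketch of the occupancy-aware alternative the reviewer is hinting at: ask the runtime how many blocks of this kernel actually fit per SM instead of assuming the 32-registers-per-thread best case. This is an illustration only, not part of the PR; it reuses `block` and one concrete kernel instantiation from above.

// Illustrative only (not in the PR): query achievable occupancy for this kernel
// and block size, which accounts for register pressure and shared memory use.
int max_blocks_per_sm = 0;
AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &max_blocks_per_sm,
    stochastic_rounding_kernel<float, at::Half>,  // one concrete instantiation, for illustration
    block,
    /*dynamicSMemSize=*/0));
// max_blocks_per_sm could then replace the maxThreadsPerMultiProcessor / block estimate.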
  unsigned int grid = (numel + block - 1) / block;

  grid = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid);

  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    std::lock_guard<std::mutex> lock(gen->mutex_);
    // Each thread draws one uniform per grid-stride iteration.
    rng_engine_inputs = gen->philox_engine_inputs((numel + block * grid - 1) / (block * grid));
  }

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "stochastic_rounding_cuda", [&] {
        stochastic_rounding_kernel<scalar_t, at::Half><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
Review comment: My biggest concern is, upstream will probably ask you to rewrite this with TensorIterator in some form, as @zasdfgbnm hinted. (A hedged sketch of what that could look like follows this file's diff.)
            input.data_ptr<scalar_t>(),
            output.data_ptr<at::Half>(),
            numel, rng_engine_inputs);
      });

  return output;
}

} // namespace native
} // namespace at
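Following up on the TensorIterator review comment above, here is a minimal hedged sketch of how the input/output pairing could be expressed with a TensorIterator. The builder names follow more recent ATen and are assumptions here, and the RNG plumbing would still need either a hand-written kernel like the one in this PR or ATen's distribution kernel templates, so this only illustrates the iterator construction.

// Illustrative sketch only, not the PR's implementation.
auto iter = at::TensorIteratorConfig()
    .check_all_same_dtype(false)   // input is float/double, output is half
    .add_output(output)
    .add_input(input)
    .build();
// A plain gpu_kernel(iter, ...) lambda cannot easily carry per-thread curand
// state, which is one reason the PR keeps an explicit grid-stride kernel.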
@@ -0,0 +1,125 @@
#include <ATen/ATen.h>
#include <ATen/native/cuda/stochastic_rounding.cuh>

namespace at {
namespace native {

template <typename scalar_t>
__global__ void stochastic_rounding_adam_step_kernel(
    scalar_t *weights, scalar_t *gradients,
    scalar_t *exp_avg, scalar_t *exp_avg_sq, scalar_t *max_exp_avg_sq,
    float *inv_scale, float *found_inf,
    float lr, float beta1, float beta2,
    float weight_decay, float eps, int step,
    bool is_decoupled, bool is_amsgrad,
    int numel, std::pair<uint64_t, uint64_t> seeds) {

  if (*found_inf) return;

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, tid, seeds.second, &state);

  round_stochastically<scalar_t, float, at::Half> rounder;

  float m_correction = 1.0f - powf(beta1, step);
  float v_correction = 1.0f - powf(beta2, step);

  for (int i = tid; i < numel; i += blockDim.x * gridDim.x) {
    float weight = static_cast<float>(weights[i]);
    float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
    float m = static_cast<float>(exp_avg[i]);
    // Stochastic Rounding Adam tracks the square root of the exponential average of squared gradients.
    float v = static_cast<float>(exp_avg_sq[i]);
    v = v * v;
    float4 random_values = curand_uniform4(&state);

    if (weight_decay != 0.0f) {
      if (is_decoupled)
        weight *= (1 - lr * weight_decay);
      else
        gradient += weight_decay * weight;
    }

    // Update m and v.
    m = beta1 * m + (1.0f - beta1) * gradient;
    v = beta2 * v + (1.0f - beta2) * (gradient * gradient);

    // AMSGrad: keep the running maximum of v.
    float max_v = v;
    if (is_amsgrad) {
      float prev_max_v = static_cast<float>(max_exp_avg_sq[i]);
      prev_max_v = prev_max_v * prev_max_v;
      max_v = fmaxf(prev_max_v, v);
    }

    // Bias-correct m and v, then take the step.
    weight -= (lr / m_correction) * m / (sqrtf(max_v / v_correction) + eps);

    weights[i] = rounder(weight, random_values.x);
    exp_avg[i] = rounder(m, random_values.y);
    exp_avg_sq[i] = rounder(sqrtf(v), random_values.z);
    if (is_amsgrad) {
      max_exp_avg_sq[i] = rounder(sqrtf(max_v), random_values.w);
    }
  }
}
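Why the kernel squares `exp_avg_sq` on load and stores `sqrtf(v)` back: keeping the square root roughly halves the exponent magnitude, so second-moment values that would underflow in FP16 stay representable. A small hedged illustration in plain host C++ (not part of the PR):

// Illustrative only: a second moment v can underflow in FP16 while sqrt(v) survives.
#include <cmath>
#include <cstdio>

int main() {
  const float fp16_min_subnormal = 5.9604644775390625e-8f;  // 2^-24
  float v = 1e-10f;  // plausible second-moment value after many small gradients
  std::printf("v underflows in FP16:       %d\n", v < fp16_min_subnormal);             // prints 1
  std::printf("sqrt(v) underflows in FP16: %d\n", std::sqrt(v) < fp16_min_subnormal);  // prints 0
  return 0;
}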

Tensor stochastic_rounding_adam_step_cuda(
    Tensor& param,
    const Tensor& grad,
    Tensor& exp_avg,
    Tensor& exp_avg_sq,
    Tensor& max_exp_avg_sq,
    const Tensor& inv_scale,
    const Tensor& found_inf,
    double lr, double beta1, double beta2,
    double weight_decay, double eps, int64_t step,
    bool is_decoupled, bool is_amsgrad, c10::optional<Generator> gen_) {

  if (param.numel() == 0) return param;

  TORCH_CHECK(param.is_contiguous());
  TORCH_CHECK(grad.is_contiguous());
  TORCH_CHECK(exp_avg.is_contiguous());
  TORCH_CHECK(exp_avg_sq.is_contiguous());
  TORCH_CHECK(max_exp_avg_sq.is_contiguous());

  const int64_t numel = param.numel();
  const int block_size = 256;
  const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  dim3 dim_block(block_size);
  dim3 grid((numel + block_size - 1) / block_size);

  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());

  // One curand_uniform4 call (4 values) is consumed per grid-stride iteration, hence the factor of 4.
  uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      param.scalar_type(), "stochastic_rounding_adam_step_cuda", [&] {
        stochastic_rounding_adam_step_kernel<scalar_t><<<grid, dim_block, 0, c10::cuda::getCurrentCUDAStream()>>>(
            param.data_ptr<scalar_t>(),
            grad.data_ptr<scalar_t>(),
            exp_avg.data_ptr<scalar_t>(),
            exp_avg_sq.data_ptr<scalar_t>(),
            max_exp_avg_sq.data_ptr<scalar_t>(),
            inv_scale.data_ptr<float>(),
            found_inf.data_ptr<float>(),
            lr, beta1, beta2, weight_decay, eps, step,
            is_decoupled, is_amsgrad,
            numel, rng_engine_inputs);
      }
  );
  AT_CUDA_CHECK(cudaGetLastError());
  return param;
}

} // namespace native
} // namespace at
@@ -0,0 +1,95 @@
#include <ATen/ATen.h>
#include <ATen/native/cuda/stochastic_rounding.cuh>

namespace at {
namespace native {

// SGD update math with stochastic rounding.
template <typename scalar_t>
__global__ void stochastic_rounding_sgd_step_kernel(
    scalar_t *weights, scalar_t *gradients, scalar_t *momentum_buffer,
    float* inv_scale, float* found_inf,
    float weight_decay, float momentum, float dampening, float lr,
    bool nesterov, bool first_run, int numel, std::pair<uint64_t, uint64_t> seeds) {

  if (*found_inf) return;

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, tid, seeds.second, &state);

  round_stochastically<scalar_t, float, at::Half> rounder;

  for (int i = tid; i < numel; i += blockDim.x * gridDim.x) {
    float weight = static_cast<float>(weights[i]);
    float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
    float velocity = static_cast<float>(momentum_buffer[i]);
    float4 random_values = curand_uniform4(&state);
Review comment: you generate 4 rng and only use 2. I don't think that's a big problem though.
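A hedged illustration of the alternative the reviewer mentions (not in the PR): draw exactly the two uniforms this kernel consumes. With Philox this is mostly cosmetic, since the generator still produces values in batches of four internally, which is presumably why the reviewer considers it a non-issue.

// Illustrative only: draw just the two random values used below.
float r_weight   = curand_uniform(&state);  // for rounding the weight
float r_velocity = curand_uniform(&state);  // for rounding the momentum buffer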

    if (weight_decay != 0.0f)
      gradient += weight_decay * weight;

    if (momentum != 0.0f) {
      if (!first_run)
        velocity = velocity * momentum + (1.0f - dampening) * gradient;
      else
        velocity = gradient;

      if (nesterov)
        gradient += momentum * velocity;
      else
        gradient = velocity;
    }

    weight -= lr * gradient;

    weights[i] = rounder(weight, random_values.x);
    if (momentum != 0.0f)
      momentum_buffer[i] = rounder(velocity, random_values.y);
  }
}

Tensor stochastic_rounding_sgd_step_cuda(
    Tensor& param, const Tensor& grad, Tensor& momentum_buffer,
    const Tensor& inv_scale, const Tensor& found_inf,
    double lr, double momentum, double weight_decay, double dampening,
    bool nesterov, bool first_run, c10::optional<Generator> gen_) {

  if (param.numel() == 0) return param;

  TORCH_CHECK(param.is_contiguous());
  TORCH_CHECK(grad.is_contiguous());
  TORCH_CHECK(momentum_buffer.is_contiguous());

  const int64_t numel = param.numel();
  const int block_size = 256;
  const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  dim3 dim_block(block_size);
  dim3 grid((numel + block_size - 1) / block_size);

  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      param.scalar_type(), "stochastic_rounding_sgd_step_cuda", [&] {
        stochastic_rounding_sgd_step_kernel<scalar_t><<<grid, dim_block, 0, c10::cuda::getCurrentCUDAStream()>>>(
            param.data_ptr<scalar_t>(),
            grad.data_ptr<scalar_t>(),
            momentum_buffer.data_ptr<scalar_t>(),
            inv_scale.data_ptr<float>(), found_inf.data_ptr<float>(),
            static_cast<float>(weight_decay), static_cast<float>(momentum), static_cast<float>(dampening), static_cast<float>(lr),
            nesterov, first_run, numel, rng_engine_inputs);
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return param;
}

} // namespace native
} // namespace at
@@ -0,0 +1,67 @@
#pragma once

#include <math.h>
#include <type_traits>  // std::is_same, used by round_stochastically below
#include <utility>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>

#include <ATen/Utils.h>
#include <ATen/Generator.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>

// 2^-10 is the relative step (ULP) for normal FP16 numbers; it is scaled by the exponent below.
// 2^-24 is the smallest positive FP16 subnormal, i.e. the absolute precision limit.
// The 24 is **NOT** related to the number of mantissa bits of the single precision format.
__device__ const float TWO_10 = 0.0009765625;
__device__ const float TWO_24 = 0.000000059604644775390625;

// Upcast a rounded __half back to the output type (no-op when the output is __half).
template<typename T>
__device__ __forceinline__ T maybe_upcast(__half x){
  return T(__half2float(x));
}

template<>
__device__ __forceinline__ __half maybe_upcast<__half>(__half x){
  return x;
}

// Returns the spacing between representable FP16 values around x,
// i.e. the rounding step used by round_stochastically below.
__device__ __forceinline__ float get_delta_fp16(float x) {
  int exponent;
  frexpf(x, &exponent);
  exponent -= 1;
  if (exponent >= -14)
    // ldexpf(TWO_10, exponent) == TWO_10 * 2^exponent; std::pow is not guaranteed to be usable in device code.
    return ldexpf(TWO_10, exponent);
  else
    return TWO_24;
}
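A quick worked example of what get_delta_fp16 returns, added here for illustration only (not in the PR):

// Example: x = 3.0f -> frexpf gives 0.75 * 2^2, so exponent-1 == 1 and
// delta == 2^-10 * 2^1 == 2^-9, the FP16 spacing on [2, 4).
// Example: x = 1e-6f -> exponent-1 == -20 < -14 (FP16 subnormal range),
// so delta == 2^-24, the smallest positive FP16 subnormal.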

// Natalia magic
Review comment: Keep this comment.

template <typename out_type, typename in_type, typename round_to_prec=at::Half>
struct round_stochastically {
  static_assert(std::is_same<round_to_prec, at::Half>::value, "round_stochastically only supports round_to_prec=at::Half");
};

template <typename out_type, typename in_type>
struct round_stochastically<out_type, in_type, at::Half> {
  __device__ __forceinline__ out_type operator()(in_type x, float random_value) {
    if (x == 0.0) {
      return out_type(0.0);
    }
    float delta = get_delta_fp16(static_cast<float>(x));
Review comment: Hardcoding float here is probably fine IMO, but natalia may ask you to change this to …
    float val;
    if (x < 0.0) {
      val = x - random_value * delta;
    } else {
      val = x + random_value * delta;
    }
    return maybe_upcast<out_type>(__float2half_rz(val));
  }
};