More GPU threads for dequantization #1341

Closed
ggml-cuda.cu: 121 changes (100 additions, 21 deletions)
@@ -34,6 +34,8 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");

typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);

+#define GGML_CUDA_MAX_BLOCK_SIZE 256
+
#define QK4_0 32
typedef struct {
float d; // delta
@@ -80,10 +82,14 @@ typedef struct {
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");

-static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
+static __global__ void dequantize_block_q4_0(const void * vx, float * y, int k) {
const block_q4_0 * x = (const block_q4_0 *) vx;

-const int i = blockIdx.x;
+const int i = blockIdx.x*blockDim.x + threadIdx.x;
+
+if (i >= k) {
+    return;
+}

const float d = x[i].d;

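Every dequantization kernel in this diff gets the same treatment: instead of launching one single-thread CUDA block per quantized block, each thread now handles one quantized block, so the index comes from the grid-wide thread id and is guarded against running past the end of the data. A minimal sketch of the pattern in isolation (the kernel and its name are hypothetical, not part of this PR):

// Hypothetical kernel illustrating the indexing pattern used above:
// one array element per thread, with a bounds guard.
static __global__ void scale_array(const float * x, float * y, int k) {
    // Global thread id across the whole grid.
    const int i = blockIdx.x*blockDim.x + threadIdx.x;

    // The grid is rounded up to a multiple of the block size, so
    // trailing threads can point past the end and must exit early.
    if (i >= k) {
        return;
    }

    y[i] = 2.0f*x[i];
}

The q4_1, q4_2, q5_0, q5_1, and q8_0 kernels below follow identically.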
@@ -103,10 +109,14 @@ static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
}
}

-static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
+static __global__ void dequantize_block_q4_1(const void * vx, float * y, int k) {
const block_q4_1 * x = (const block_q4_1 *) vx;

-const int i = blockIdx.x;
+const int i = blockIdx.x*blockDim.x + threadIdx.x;
+
+if (i >= k) {
+    return;
+}

const float d = x[i].d;
const float m = x[i].m;
@@ -127,10 +137,14 @@ static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
}
}

-static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
+static __global__ void dequantize_block_q4_2(const void * vx, float * y, int k) {
const block_q4_2 * x = (const block_q4_2 *) vx;

-const int i = blockIdx.x;
+const int i = blockIdx.x*blockDim.x + threadIdx.x;
+
+if (i >= k) {
+    return;
+}

const float d = x[i].d;

@@ -150,10 +164,14 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
}
}

-static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
+static __global__ void dequantize_block_q5_0(const void * vx, float * y, int k) {
const block_q5_0 * x = (const block_q5_0 *) vx;

-const int i = blockIdx.x;
+const int i = blockIdx.x*blockDim.x + threadIdx.x;
+
+if (i >= k) {
+    return;
+}

const float d = x[i].d;

@@ -179,10 +197,14 @@ static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
}
}

-static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
+static __global__ void dequantize_block_q5_1(const void * vx, float * y, int k) {
const block_q5_1 * x = (const block_q5_1 *) vx;

-const int i = blockIdx.x;
+const int i = blockIdx.x*blockDim.x + threadIdx.x;
+
+if (i >= k) {
+    return;
+}

const float d = x[i].d;
const float m = x[i].m;
@@ -209,10 +231,14 @@ static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
}
}

-static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
+static __global__ void dequantize_block_q8_0(const void * vx, float * y, int k) {
const block_q8_0 * x = (const block_q8_0 *) vx;

-const int i = blockIdx.x;
+const int i = blockIdx.x*blockDim.x + threadIdx.x;
+
+if (i >= k) {
+    return;
+}

const float d = x[i].d;

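The guard is load-bearing: the launchers in the next hunk round the grid size up so that every quantized block is covered, which means the last thread block generally overshoots. With nb = 1000 and block_size = 256, for example, (1000 + 255) / 256 = 4 blocks are launched, i.e. 1024 threads, and the final 24 return early at the i >= k check.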
@@ -227,45 +253,98 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {

static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
-dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+static int block_size = -1;
+if (block_size == -1) {
+    int min_grid_size;
+    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_0, 0, 0));
+    block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+}
+const int grid_size = (nb + block_size - 1) / block_size; // Round up.
+dequantize_block_q4_0<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
}

static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_1;
-dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+static int block_size = -1;
+if (block_size == -1) {
+    int min_grid_size;
+    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_1, 0, 0));
+    block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+}
+const int grid_size = (nb + block_size - 1) / block_size; // Round up.
+dequantize_block_q4_1<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
}

static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_2;
-dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+static int block_size = -1;
+if (block_size == -1) {
+    int min_grid_size;
+    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_2, 0, 0));
+    block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+}
+const int grid_size = (nb + block_size - 1) / block_size; // Round up.
+dequantize_block_q4_2<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
}

static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK5_0;
-dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
+static int block_size = -1;
+if (block_size == -1) {
+    int min_grid_size;
+    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_0, 0, 0));
+    block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+}
+const int grid_size = (nb + block_size - 1) / block_size; // Round up.
+dequantize_block_q5_0<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
}

static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK5_1;
-dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
+static int block_size = -1;
+if (block_size == -1) {
+    int min_grid_size;
+    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_1, 0, 0));
+    block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+}
+const int grid_size = (nb + block_size - 1) / block_size; // Round up.
+dequantize_block_q5_1<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
}

static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK8_0;
-dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
+static int block_size = -1;
+if (block_size == -1) {
+    int min_grid_size;
+    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q8_0, 0, 0));
+    block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+}
+const int grid_size = (nb + block_size - 1) / block_size; // Round up.
+dequantize_block_q8_0<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
}

// TODO: optimize
-static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
+static __global__ void convert_fp16_to_fp32(const void * vx, float * y, int k) {
const half * x = (const half *) vx;

-const int i = blockIdx.x;
+const int i = blockIdx.x*blockDim.x + threadIdx.x;
+
+if (i >= k) {
+    return;
+}

y[i] = __half2float(x[i]);
}

static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
-convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
+static int block_size = -1;
+if (block_size == -1) {
+    int min_grid_size;
+    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, convert_fp16_to_fp32, 0, 0));
+    block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+}
+const int grid_size = (k + block_size - 1) / block_size; // Round up.
+convert_fp16_to_fp32<<<grid_size, block_size, 0, stream>>>(x, y, k);
}

static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
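All seven launchers repeat the same occupancy logic: on first call, cudaOccupancyMaxPotentialBlockSize asks the runtime for the block size that maximizes occupancy for that specific kernel, the result is clamped to GGML_CUDA_MAX_BLOCK_SIZE, and the value is cached in a function-local static. A sketch of that logic factored into a shared helper; the helper name is hypothetical and the PR keeps the logic inline in each launcher:

// Hypothetical helper (not in the PR): one occupancy query per kernel.
template <typename kernel_t>
static int ggml_cuda_block_size(kernel_t kernel) {
    int min_grid_size; // required by the API, otherwise unused here
    int block_size;
    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, kernel, 0, 0));
    return min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
}

A launcher could then keep its once-only caching through a static initializer, e.g. static const int block_size = ggml_cuda_block_size(dequantize_block_q4_0);, followed by the same rounded-up grid-size computation and kernel launch.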