
Commit 7f9753f

CUDA GPU acceleration for LoRAs + f16 models (#1970)
1 parent cfa0750 commit 7f9753f

4 files changed: 78 additions, 19 deletions

examples/common.cpp

Lines changed: 0 additions & 7 deletions
@@ -416,13 +416,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         exit(1);
     }
 
-#ifdef GGML_USE_CUBLAS
-    if (!params.lora_adapter.empty() && params.n_gpu_layers > 0) {
-        fprintf(stderr, "%s: error: the simultaneous use of LoRAs and GPU acceleration is not supported", __func__);
-        exit(1);
-    }
-#endif // GGML_USE_CUBLAS
-
     if (escape_prompt) {
         process_escapes(params.prompt);
     }

ggml-cuda.cu

Lines changed: 42 additions & 11 deletions
@@ -223,6 +223,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     dst[i] = x[i] + y[i];
 }
 
+static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __hadd(x[i], __float2half(y[i]));
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1459,6 +1468,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
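
The launch arithmetic used above rounds the element count up to a whole number of blocks. The following standalone sketch is not part of the commit; it only illustrates the same f16 + f32 -> f16 pattern end to end, with the name demo_add_f16_f32_f16 and the fixed block size 256 standing in for the real kernel and CUDA_ADD_BLOCK_SIZE.

// Illustrative sketch only: x holds f16 base weights, y holds an f32 delta,
// and the sum is written back as f16, mirroring add_f16_f32_f16 above.
#include <cstdio>
#include <cuda_fp16.h>

#define DEMO_BLOCK_SIZE 256

static __global__ void demo_add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    // convert the f32 operand on the fly and add in half precision
    dst[i] = __hadd(x[i], __float2half(y[i]));
}

int main() {
    const int k = 1000;
    half * x; float * y; half * dst;
    cudaMallocManaged(&x,   k*sizeof(half));
    cudaMallocManaged(&y,   k*sizeof(float));
    cudaMallocManaged(&dst, k*sizeof(half));
    for (int i = 0; i < k; ++i) {
        x[i] = __float2half(1.0f);
        y[i] = 0.5f;
    }
    // same launch arithmetic as add_f16_f32_f16_cuda: ceil(k / block size) blocks
    const int num_blocks = (k + DEMO_BLOCK_SIZE - 1) / DEMO_BLOCK_SIZE;
    demo_add_f16_f32_f16<<<num_blocks, DEMO_BLOCK_SIZE>>>(x, y, dst, k);
    cudaDeviceSynchronize();
    printf("dst[0] = %.2f (expected 1.50)\n", __half2float(dst[0]));
    cudaFree(x); cudaFree(y); cudaFree(dst);
    return 0;
}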
@@ -1941,15 +1955,21 @@ inline void ggml_cuda_op_add(
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
     cudaStream_t & cudaStream_main){
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne0 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
     // compute
-    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }
     CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
@@ -2547,8 +2567,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
+    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
+    // Due to flatten_rows == true this does in practice not make a difference however.
+    // Better solution would be nice but right now that would require disproportionate changes.
+    GGML_ASSERT(
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
+        src1->type == GGML_TYPE_F32 &&
+        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2801,7 +2827,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     delete extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -2810,23 +2836,24 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
-            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
     struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
     memset(extra, 0, sizeof(*extra));
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
-        tensor->op == GGML_OP_VIEW;
+        tensor->op == GGML_OP_VIEW ||
+        force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
+    if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
@@ -2865,11 +2892,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 }
 
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true);
+    ggml_cuda_assign_buffers_impl(tensor, true, false);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false);
+}
+
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false, true);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
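
Taken together, the changes to ggml_cuda_assign_buffers_impl mean a tensor reuses its src0's device buffer when it already aliases src0->data, when it is a view, or now when force_inplace is set; the reuse also applies when src0 lives in the split GPU backend. A hedged restatement of that decision follows; the helper name is made up for illustration and assumes ggml.h from this revision.

// Hypothetical helper, not from the commit: summarizes when
// ggml_cuda_assign_buffers_impl reuses src0's device buffer after this change.
#include "ggml.h"

static bool reuses_src0_device_buffer(const struct ggml_tensor * t, bool force_inplace) {
    const bool inplace = (t->src0 != nullptr && t->src0->data == t->data) ||
        t->op == GGML_OP_VIEW ||
        force_inplace;
    return inplace && t->src0 != nullptr &&
        (t->src0->backend == GGML_BACKEND_GPU || t->src0->backend == GGML_BACKEND_GPU_SPLIT);
}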

ggml-cuda.h

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 void ggml_cuda_free_data(struct ggml_tensor * tensor);
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 void ggml_cuda_set_main_device(int main_device);
 void ggml_cuda_set_scratch_size(size_t scratch_size);
 void ggml_cuda_free_scratch(void);
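
How the new entry point is meant to be used can be seen in the llama.cpp hunk below. A minimal sketch of that pattern follows; the names ctx, dest_t, and delta are illustrative, and it assumes ggml.h and ggml-cuda.h from this revision.

// Sketch, not from the commit: apply a delta to a GPU-resident f16 weight in place,
// reusing its existing device buffer instead of allocating a new one for the result.
#include "ggml.h"
#include "ggml-cuda.h"

static struct ggml_tensor * add_delta_inplace(struct ggml_context * ctx,
        struct ggml_tensor * dest_t,    // f16 weight already offloaded to the GPU
        struct ggml_tensor * delta) {   // f32 update, e.g. the scaled BA product
    struct ggml_tensor * r = ggml_add_inplace(ctx, dest_t, delta);
    ggml_cuda_assign_buffers_force_inplace(r);  // r aliases dest_t's device buffer
    return r;
}

The force variant makes the result tensor reuse dest_t's device allocation, so the updated weights end up in the buffer the model already uses rather than in a scratch area.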

llama.cpp

Lines changed: 35 additions & 1 deletion
@@ -2976,14 +2976,15 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
                return false;
            }
        }
-        ggml_tensor* lora_tensor;
+        ggml_tensor * lora_tensor;
        if (n_dims == 2) {
            lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
        }
        else {
            fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
            return 1;
        }
+        ggml_set_name(lora_tensor, "lora_tensor");
 
        // load tensor data
        size_t offset = fin.tellg();
@@ -2999,6 +3000,21 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
            lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
 
            ggml_tensor * dest_t = model_tensors[base_name];
+
+            offload_func_t offload_func = llama_nop;
+            offload_func_t offload_func_force_inplace = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+            if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) {
+                if (dest_t->type != GGML_TYPE_F16) {
+                    throw std::runtime_error(format(
+                        "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__));
+                }
+                offload_func = ggml_cuda_assign_buffers;
+                offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace;
+            }
+#endif // GGML_USE_CUBLAS
+
            ggml_tensor * base_t;
            if (model_loader) {
                // load from base model
@@ -3026,7 +3042,12 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
            }
 
            ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
+            GGML_ASSERT(loraA->type == GGML_TYPE_F32);
+            ggml_set_name(loraA, "loraA");
+
            ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
+            GGML_ASSERT(loraB->type == GGML_TYPE_F32);
+            ggml_set_name(loraB, "loraB");
 
            if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
                fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
@@ -3036,19 +3057,32 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
 
            // w = w + BA*s
            ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
+            offload_func(BA);
+            ggml_set_name(BA, "BA");
 
            if (scaling != 1.0f) {
                ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
+                ggml_set_name(scale_tensor, "scale_tensor");
+
                BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
+                offload_func(BA);
+                ggml_set_name(BA, "BA_scaled");
            }
 
            ggml_tensor * r;
            if (base_t == dest_t) {
                r = ggml_add_inplace(lora_ctx, dest_t, BA);
+                offload_func_force_inplace(r);
+                ggml_set_name(r, "r_add_inplace");
            }
            else {
                r = ggml_add(lora_ctx, base_t, BA);
+                offload_func(r);
+                ggml_set_name(r, "r_add");
+
                r = ggml_cpy(lora_ctx, r, dest_t);
+                offload_func(r);
+                ggml_set_name(r, "r_cpy");
            }
 
            struct ggml_cgraph gf = ggml_build_forward(r);
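
For reference, the update that this graph evaluates is unchanged by the commit; only where it runs has changed. Writing the comment "w = w + BA*s" in math form, with W the base weight (kept in f16 on the GPU path) and A, B the f32 LoRA factors:

\[
W \;\leftarrow\; W + s \,(B A)
\]

Here s is the scaling factor computed earlier in this function from the adapter's alpha and rank. On the CUDA path the delta s*(BA) is produced in f32 and folded into the f16 weight by the add_f16_f32_f16 kernel added in ggml-cuda.cu above, which is why the GPU LoRA path requires an f16 base model.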
