
Commit 7c8249f

John committed:
cuda malloc:
- added functionality to find the smallest fitting buffer instead of the first found buffer that is >= the requested size
  -- this prevents two buffer allocations in sequence from claiming a huge buffer for a small tensor and then requiring a new buffer for the 2nd tensor
  -- in my test this saved 1 GB of VRAM that is now free for more offloading

cuda free buffers:
- added a helper function that frees all unused buffers of a device, to prevent huge F32 buffers from cuBLAS occupying VRAM needlessly after token ingestion

libfalcon:
- corrected the vram_overhead calculation to account for the actual non-weight buffers needed during inference
- added vram_overhead for n_batch > 1, as this switches ingestion into a 32-bit dequantization mode for cuBLAS which needs almost 2 GB of VRAM buffers
- corrected the automated layer distribution to fill VRAM with as many layers as possible

From here on it is recommended to use --ngl 100 and -b 1 for CUDA processing; see the example below. In addition, -t is best set to the number of CPU cores or one less (depends on the CPU and GPU used).
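In concrete terms (the binary name, model file name, and exact -t value below are illustrative, not part of this commit; pick a thread count equal to your CPU cores or one less):

    ./falcon_main -m falcon-40b-q4.bin --ngl 100 -b 1 -t 7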
1 parent b4028ed commit 7c8249f

File tree

- examples/falcon/falcon_main.cpp
- ggml-cuda.cu
- ggml-cuda.h
- libfalcon.cpp
- libfalcon.h

5 files changed: +104 -49 lines changed

examples/falcon/falcon_main.cpp

Lines changed: 6 additions & 2 deletions
@@ -397,7 +397,9 @@ int main(int argc, char ** argv) {
                 embd.erase(embd.begin(), embd.begin() + i);
             }
         }
-
+        // We have buffers from the warmup run that won't all align with a batched run
+        if (params.n_batch > 1 && embd.size() > 1)
+            ggml_cuda_pool_free_all(-1);
         // evaluate tokens in batches
         // embd is typically prepared beforehand to fit within a batch, but not always
         for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
@@ -411,7 +413,9 @@ int main(int argc, char ** argv) {
             }
             n_past += n_eval;
         }
-
+        // frees unused allocations, those during batch processing are of different size than single token eval
+        if (params.n_batch > 1 && embd.size() > 1)
+            ggml_cuda_pool_free_all(-1);
         if (embd.size() > 0 && !path_session.empty()) {
            session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
            n_session_consumed = session_tokens.size();

ggml-cuda.cu

Lines changed: 42 additions & 6 deletions
@@ -1428,17 +1428,29 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-
+    size_t min_size_diff = SIZE_MAX;
+    size_t min_size_diff_ok = size * 0.05; // wiggle room
+    cuda_buffer* best_fit = nullptr; // candidate pointer
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
         cuda_buffer& b = g_cuda_buffer_pool[id][i];
         if (b.size >= size && b.ptr != nullptr) {
-            void * ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
+            size_t size_diff = b.size - size;
+            if (size_diff < min_size_diff) {
+                best_fit = &b;
+                min_size_diff = size_diff;
+                if (size_diff < min_size_diff_ok) {
+                    break;
+                }
+            }
         }
     }
+    if (best_fit != nullptr) {
+        *actual_size = best_fit->size;
+        void * ptr = best_fit->ptr;
+        best_fit->ptr = nullptr;
+        best_fit->size = 0;
+        return ptr;
+    }
     void * ptr;
     CUDA_CHECK(cudaMalloc((void **) &ptr, size));
     *actual_size = size;
@@ -1462,6 +1474,30 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }

+// free all buffers that are not currently in use
+void ggml_cuda_pool_free_all(int device_id) {
+    while (atomic_flag_test_and_set(&g_cuda_pool_lock)) {}
+
+    int start_id = (device_id < 0) ? 0 : device_id;
+    int end_id = (device_id < 0) ? GGML_CUDA_MAX_DEVICES : device_id + 1;
+
+    for (int id = start_id; id < end_id; ++id) {
+        for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+            cuda_buffer* b = &(g_cuda_buffer_pool[id][i]);
+            if (b->ptr != NULL) {
+                cudaError_t err = cudaFree(b->ptr);
+                if (err != cudaSuccess) {
+                    fprintf(stderr, "ERROR: CUDA buffer free failed: %s\n", cudaGetErrorString(err));
+                } else {
+                    b->ptr = NULL;
+                    b->size = 0;
+                }
+            }
+        }
+    }
+
+    atomic_flag_clear(&g_cuda_pool_lock);
+}

 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default

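For reference, here is the selection logic in isolation: a minimal, self-contained C++ sketch of the best-fit-with-tolerance search used above. The pool_buffer struct, the buffer sizes, and the main() driver are illustrative stand-ins, not the actual ggml-cuda types.

    // Standalone sketch of the best-fit policy; not the real ggml-cuda structures.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct pool_buffer {
        void * ptr;   // nullptr marks an empty slot
        size_t size;
    };

    // Return the index of the smallest free buffer that can hold `size`,
    // stopping early once a candidate wastes less than ~5% of the request;
    // -1 means nothing fits and a fresh allocation would be needed.
    static int find_best_fit(const std::vector<pool_buffer> & pool, size_t size) {
        int    best        = -1;
        size_t min_diff    = SIZE_MAX;
        size_t min_diff_ok = (size_t)(size * 0.05); // wiggle room
        for (int i = 0; i < (int) pool.size(); ++i) {
            const pool_buffer & b = pool[i];
            if (b.ptr != nullptr && b.size >= size) {
                size_t diff = b.size - size;
                if (diff < min_diff) {
                    best     = i;
                    min_diff = diff;
                    if (diff < min_diff_ok) break; // close enough, stop searching
                }
            }
        }
        return best;
    }

    int main() {
        int a, b, c;
        std::vector<pool_buffer> pool = {
            { &a, (size_t) 1024 * 1024 * 1024 }, // 1 GiB buffer left over from batch ingestion
            { &b, (size_t)    4 * 1024 * 1024 }, // 4 MiB buffer
            { &c, (size_t)  512 * 1024        }, // 512 KiB buffer
        };
        // A first-fit scan would hand the 1 GiB buffer to this 3 MiB request;
        // best-fit picks the 4 MiB one and keeps the big buffer available.
        printf("best fit index: %d\n", find_best_fit(pool, (size_t) 3 * 1024 * 1024));
        return 0;
    }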
ggml-cuda.h

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_
 size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);

+void   ggml_cuda_pool_free_all(int device_id);
 // TODO: export these with GGML_API
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);

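Per the implementation in ggml-cuda.cu above, the device_id argument selects which pools to sweep: a negative value walks all GGML_CUDA_MAX_DEVICES devices, a non-negative value frees only that device's unused buffers. A hedged caller-side sketch mirroring the falcon_main.cpp call site (the surrounding guard is illustrative):

    #ifdef GGML_USE_CUBLAS
        // after batched ingestion the pool may still hold oversized F32 cuBLAS
        // buffers; release the unused ones on every device before single-token eval
        ggml_cuda_pool_free_all(-1);
    #endif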
libfalcon.cpp

Lines changed: 53 additions & 41 deletions
@@ -168,6 +168,8 @@ struct falcon_model {
     std::vector<falcon_layer> layers;

     int n_gpu_layers;
+    int i_gpu_start;
+    int i_gpu_last;

     // context
     struct ggml_context * ctx = NULL;
@@ -911,7 +913,9 @@ struct falcon_context_params falcon_context_default_params() {
     struct falcon_context_params result = {
         /*.n_ctx        =*/ 512,
         /*.n_batch      =*/ 512,
-        /*.gpu_layers   =*/ 0,
+        /*.n_gpu_layers =*/ 0,
+        /*.i_gpu_start  =*/ -1,
+        /*.i_gpu_last   =*/ -1,
         /*.main_gpu     =*/ 0,
         /*.tensor_split =*/ {0},
         /*.seed         =*/ -1,
@@ -1068,7 +1072,7 @@ static void falcon_model_load_internal(
     {
         switch (hparams.n_layer) {
             case 32: model.type = e_model::FALCON_7B; break;
-            case 40: model.type = e_model::FALCON_40B; break;
+            case 60: model.type = e_model::FALCON_40B; break;
             default:
                 {
                     if (hparams.version == 7) {
@@ -1166,23 +1170,30 @@ if (n_gpu_layers > 0)
 #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
 #endif

-    size_t vram_total=0;
-    size_t vram_free=0;
-    size_t vram_reserved=1024*1024*512; //will be adapted by model
+    size_t vram_total=0;
+    size_t vram_free=0;
+    const size_t vram_reserved=512*MB; // that amount of VRAM is to stay free on GPU (headroom for other processes - may be reduced in pure server environments)
+    size_t vram_overhead = 1250*MB; // this amount of vram is estimated for non weight storage buffers on VRAM (no big difference between 7B and 40B, needs to increase when more work is offloaded in the future)
+    // cublas is used in 32 bit mode, temporary cuda storage/conversion buffers are needed for batch ingestion (could be run in 16 bit mode without performance downgrade and save half the VRAM)
+    if (model.type == FALCON_40B && n_batch > 1)
+        vram_overhead += (1024+288+256) * MB;
+    if (model.type == FALCON_7B && n_batch > 1)
+        vram_overhead += (315+80+78) * MB;
 #if defined(GGML_USE_CUBLAS)
     cudaMemGetInfo(&vram_free, &vram_total); // this should go in ggml-cuda.cu but I don't want to make Johannes life harder by modifying that yet
-    fprintf(stderr, "%s: VRAM free: %7.2f MB of %7.2f MB (already used: %7.2f MB)\n", __func__, vram_free/MB*1.0, vram_total/MB*1.0, (vram_total-vram_free)/MB*1.0);
+    fprintf(stderr, "%s: VRAM free: %7.2f MB of %7.2f MB (in use: %7.2f MB)\n", __func__, vram_free/MB*1.0, vram_total/MB*1.0, (vram_total-vram_free)/MB*1.0);
 #endif

     // prepare memory for the weights
     size_t vram_weights = 0;
     size_t vram_scratch = 0;
-    size_t vram_overhead = 0;
+
     (void) vram_scratch;
     (void) n_batch;
     // calculate scratch buffer size and allocate it
 #ifdef GGML_USE_CUBLAS
-    vram_scratch = n_batch * MB;
+    // vram_scratch = n_batch * MB;
+    vram_scratch = 0; // these are not used until we have multi operation support
     ggml_cuda_set_scratch_size(vram_scratch);
     if (n_gpu_layers > 0) {

@@ -1203,22 +1214,7 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model
     ml->ggml_ctx = ctx;

     model.tok_embeddings = ml->get_tensor("transformer.word_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
-
-    // I did not analyze the cause but that's the overhead that is dynamically added to the VRAM at first inference
-    // same goes with reserved, most likely we can skip both for a proper size calculation.
-    // If the below values are not correct GPU memory will fill up to 100%, resulting in a extreme slowdown of inference
-    if (model.type == FALCON_40B)
-    {
-        vram_reserved=1900*MB;
-        vram_overhead+=2700*MB;
-    }
-    else
-    {
-        vram_reserved=768*MB;
-        vram_overhead+=1200*MB;
-    }
-
-
+
     ggml_backend backend_norm;
     ggml_backend backend_output;
     // disabled norm/output offloading until further tests, causes silent crash at the moment
@@ -1240,10 +1236,8 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model

     if (backend_norm != GGML_BACKEND_CPU)
     {
-        vram_weights += ggml_nbytes(model.output_norm);
-        vram_weights += ggml_nbytes(model.output_norm_b);
-        vram_free -= ggml_nbytes(model.output_norm);
-        vram_free -= ggml_nbytes(model.output_norm_b);
+        vram_weights += ggml_nbytes(model.output_norm) + ggml_nbytes(model.output_norm_b);
+        vram_free -= ggml_nbytes(model.output_norm) + ggml_nbytes(model.output_norm_b);
     }
     if (backend_output != GGML_BACKEND_CPU)
     {
@@ -1252,12 +1246,14 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model
     }

     const int i_gpu_start = n_layer - n_gpu_layers;
-    int i_gpu_end = n_layer; // allows to terminate the offloading earlier. TODO: instead do a proper calculation run and determine the start before the loop
+    int i_gpu_last = n_layer; // allows to terminate the offloading earlier. TODO: instead do a proper calculation run and determine the start before the loop
+    model.i_gpu_start = i_gpu_start;
+    model.i_gpu_last = i_gpu_last;

     model.layers.resize(n_layer);
     for (uint32_t i = 0; i < n_layer; ++i) {
-        const ggml_backend backend = (int(i) < i_gpu_start || int(i) >= i_gpu_end) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-        const ggml_backend backend_split = (int(i) < i_gpu_start || int(i) >= i_gpu_end) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+        const ggml_backend backend = (int(i) < i_gpu_start || int(i) > i_gpu_last) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+        const ggml_backend backend_split = (int(i) < i_gpu_start || int(i) > i_gpu_last) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT

         auto & layer = model.layers[i];

@@ -1288,14 +1284,15 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model
             vram_layer = calculate_layer_vram_bytes(layer);
             vram_weights += vram_layer;
             vram_free = (vram_layer > vram_free) ? 0 : vram_free - vram_layer; // simulate the layer being loaded in VRAM
-
-            if (vram_free <= (vram_overhead+vram_scratch+vram_reserved))
+            // test if we have enough VRAM to load the next layer
+            if (i < n_layer && vram_free <= (vram_overhead+vram_scratch+vram_reserved+vram_layer))
             {
                 // this needs some polishing (instead of fiddling with --ngl I'd like the option to auto-fill the vram with as many layers as possible as an alternative)
-                fprintf(stderr, "WARNING: Not enough VRAM to load the model as configured - at layer %d of %d\n", i, n_layer);
+                fprintf(stderr, "INFO: Not enough VRAM to load all requested layers - at layer %d of %d: skipping\n", i, n_layer);
                 n_gpu_layers = i+1;
-                model.n_gpu_layers = n_gpu_layers;
-                i_gpu_end = i;
+                model.n_gpu_layers = n_gpu_layers;
+                i_gpu_last = i;
+                model.i_gpu_last = i_gpu_last;
             }
         }

@@ -1335,7 +1332,7 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model
         if (n_gpu_layers > (int) hparams.n_layer) {
             fprintf(stderr, "%s: offloading output layer to GPU\n", __func__);
         }
-        fprintf(stderr, "%s: total VRAM used: %zu MB\n",
+        fprintf(stderr, "%s: estimated VRAM usage: %zu MB\n",
                 __func__, (vram_weights + vram_scratch + vram_overhead + MB - 1) / MB); // round up
 #else
     (void) n_gpu_layers;
@@ -1468,7 +1465,9 @@ static bool falcon_eval_internal(
     // ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
     const int sizeof_wtype = ggml_type_sizef(wtype);

-    const int i_gpu_start = n_layer - n_gpu_layers;
+    // const int i_gpu_start = n_layer - n_gpu_layers;
+    const int i_gpu_start = lctx.model.i_gpu_start;
+    const int i_gpu_last = lctx.model.i_gpu_last > 0 ? lctx.model.i_gpu_last : n_layer;
     (void) i_gpu_start;

     // offload functions set the tensor output backend to GPU
@@ -1492,7 +1491,7 @@ static bool falcon_eval_internal(
         offload_func_t offload_func = llama_nop;

 #ifdef GGML_USE_CUBLAS
-        if (il >= i_gpu_start) {
+        if (il >= i_gpu_start && il < i_gpu_last) {
             offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
         }
 #endif // GGML_USE_CUBLAS
@@ -1507,7 +1506,7 @@ static bool falcon_eval_internal(
             layernorm_output = ggml_norm(ctx0, inpL);

             ggml_tensor * il_a = ggml_mul(ctx0, layernorm_output, model.layers[il].input_layernorm);
-            offload_func(il_a);
+            offload_func(il_a); // (todo: uses vram scratch)

             layernorm_output = ggml_add(ctx0,
                     il_a,
@@ -1737,6 +1736,15 @@ static bool falcon_eval_internal(

     // run the computation
     ggml_build_forward_expand(&gf, cur);
+#if 0
+    // use to confirm vram_overhead is correct
+    size_t vram_total=0;
+    size_t vram_free=0;
+#if defined(GGML_USE_CUBLAS)
+    cudaMemGetInfo(&vram_free, &vram_total); // this should go in ggml-cuda.cu but I don't want to make Johannes life harder by modifying that yet
+    fprintf(stderr, "\n%s: VRAM free: %7.2f MB of %7.2f MB (in use: %7.2f MB)\n", __func__, vram_free/MB*1.0, vram_total/MB*1.0, (vram_total-vram_free)/MB*1.0);
+#endif
+#endif

 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
@@ -2701,7 +2709,11 @@ struct falcon_context * falcon_init_from_file(
         llama_free(ctx);
         return nullptr;
     }
-    params.n_gpu_layers = ctx->model.n_gpu_layers; // model_load_internal() may change this
+    // model_load_internal() may change this if VRAM runs out
+    params.n_gpu_layers = ctx->model.n_gpu_layers;
+    params.i_gpu_start = ctx->model.i_gpu_start;
+    params.i_gpu_last = ctx->model.i_gpu_last;
+

     // reserve memory for context buffers
     if (!params.vocab_only) {

libfalcon.h

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,8 @@ extern "C" {
         int n_ctx;          // text context
         int n_batch;        // prompt processing batch size
         int n_gpu_layers;   // number of layers to store in VRAM
+        int i_gpu_start;    // first gpu layer
+        int i_gpu_last;     // last gpu layer
         int main_gpu;       // the GPU that is used for scratch and small tensors
         float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
         int seed;           // RNG seed, -1 for random

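To tie the new fields together, a hypothetical caller-side sketch follows; the model file name and the argument order of falcon_init_from_file() are assumptions in the llama.cpp style, only the struct fields and their defaults come from this commit:

    struct falcon_context_params params = falcon_context_default_params();
    params.n_batch      = 1;    // -b 1 is recommended for CUDA processing in this commit
    params.n_gpu_layers = 100;  // --ngl 100: offload as many layers as fit
    params.i_gpu_start  = -1;   // default; filled in by the model loader
    params.i_gpu_last   = -1;   // default; may be lowered if VRAM runs out

    struct falcon_context * ctx = falcon_init_from_file("falcon-40b-q4.bin", params);
    // after loading, params.n_gpu_layers / i_gpu_start / i_gpu_last are copied back
    // from the model, so the caller can see how many layers actually fit in VRAM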