@@ -30,10 +30,11 @@ llama_kv_cache_unified::llama_kv_cache_unified(
                   bool   v_trans,
                   bool   offload,
               uint32_t   kv_size,
-              uint32_t   padding,
+              uint32_t   n_seq_max,
+              uint32_t   n_pad,
               uint32_t   n_swa,
-        llama_swa_type   swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
-    GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
+        llama_swa_type   swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+    GGML_ASSERT(kv_size % n_pad == 0 && "kv_size must be a multiple of padding");
 
     this->type_k = type_k;
     this->type_v = type_v;
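The renamed `n_pad` parameter keeps the old contract: `kv_size` must be a multiple of it. A minimal standalone sketch of how a caller could round a requested context length up to satisfy the assert; the helper and the values are illustrative, not from this PR (`GGML_PAD` behaves the same for power-of-two `n`):

#include <cstdint>
#include <cstdio>

// round x up to the next multiple of n (illustrative stand-in for GGML_PAD)
static uint32_t round_up(uint32_t x, uint32_t n) {
    return (x + n - 1) / n * n;
}

int main() {
    const uint32_t n_ctx = 5000; // requested context size (example value)
    const uint32_t n_pad = 256;  // assumed padding granularity
    printf("kv_size = %u\n", round_up(n_ctx, n_pad)); // kv_size = 5120
    return 0;
}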
@@ -442,7 +443,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 void llama_kv_cache_unified::defrag_sched(float thold) {
     // - do not defrag small contexts (i.e. < 2048 tokens)
     // - count the padding towards the number of used tokens
-    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
+    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
 
     // queue defragmentation for next llama_kv_cache_update
     if (fragmentation > thold) {
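For a sense of the numbers: with an attention window of 4096 cells, 3000 of them occupied, and `n_pad = 256`, fragmentation comes out around 0.205, so a threshold of, say, 0.1 would queue a defrag. A runnable sketch of the same expression, with illustrative values:

#include <algorithm>
#include <cstdio>

int main() {
    const float n     = 4096.0f; // current attention window, in cells
    const float used  = 3000.0f; // cells actually occupied
    const float n_pad = 256.0f;  // padding counted towards used tokens

    // same shape as defrag_sched(): contexts under 2048 are never defragged
    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (used + n_pad)/n) : 0.0f;
    printf("fragmentation = %.3f\n", fragmentation); // prints 0.205
    return 0;
}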
@@ -558,7 +559,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
+    n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
 
 #ifdef FIND_SLOT_DEBUG
     LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
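The heuristic rounds the highest occupied cell up to the padding granularity, clamped to [`n_pad`, `size`], so a barely used cache only attends over a small prefix. A standalone sketch with illustrative values (`pad_up` mirrors `GGML_PAD` for power-of-two `n`):

#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t pad_up(uint32_t x, uint32_t n) { // like GGML_PAD for power-of-two n
    return (x + n - 1) & ~(n - 1);
}

int main() {
    const uint32_t size       = 8192; // total cache cells (example)
    const uint32_t n_pad      = 256;  // assumed padding granularity
    const uint32_t cell_max_v = 1000; // index past the last occupied cell (example)

    // same clamp as find_slot(): at least n_pad, at most the full cache
    const uint32_t n = std::min(size, std::max(n_pad, pad_up(cell_max_v, n_pad)));
    printf("n = %u\n", n); // prints 1024
    return 0;
}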
@@ -567,20 +568,6 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     return true;
 }
 
-int32_t llama_kv_cache_unified::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_unified::get_used_cells() const {
-    return used;
-}
-
 bool llama_kv_cache_unified::get_can_shift() const {
     return true;
 }
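The removed `get_n_tokens()` counted a cell once per sequence referencing it, so a cell shared by two sequences contributed two tokens. A toy standalone illustration of that semantics (types simplified; not llama.cpp code):

#include <cstdio>
#include <set>
#include <vector>

struct toy_cell {
    std::set<int> seq_id; // sequences that reference this cell
};

int main() {
    std::vector<toy_cell> cells(4);
    cells[0].seq_id = {0};
    cells[1].seq_id = {0, 1}; // shared cell counts once per sequence

    int n_tokens = 0;
    for (const auto & cell : cells) {
        n_tokens += (int) cell.seq_id.size();
    }
    printf("n_tokens = %d\n", n_tokens); // prints 3
    return 0;
}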
@@ -802,16 +789,6 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     }
 }
 
-llama_pos llama_kv_cache_unified::get_pos_max() const {
-    llama_pos pos_max = -1;
-
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 size_t llama_kv_cache_unified::total_size() const {
     size_t size = 0;
@@ -1655,17 +1632,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
             ggml_type   type_v,
                  bool   v_trans,
                  bool   offload,
-             uint32_t   kv_size,
                  bool   swa_full,
+             uint32_t   kv_size,
             uint32_t   n_seq_max,
             uint32_t   n_batch,
-            uint32_t   padding) : hparams(model.hparams) {
+            uint32_t   n_pad) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
     if (swa_full) {
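The SWA cache only needs to hold the sliding window of every sequence plus one batch of new tokens, rounded up to the padding granularity. Plugging in illustrative values (these numbers are not from the PR; `pad_up` mirrors `GGML_PAD` for power-of-two `n`):

#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t pad_up(uint32_t x, uint32_t n) { // like GGML_PAD for power-of-two n
    return (x + n - 1) & ~(n - 1);
}

int main() {
    const uint32_t size_base = 8192; // full-context cache size (kv_size)
    const uint32_t n_swa     = 1024; // sliding-window width
    const uint32_t n_seq_max = 2;    // parallel sequences
    const uint32_t n_batch   = 512;  // logical batch size
    const uint32_t n_pad     = 256;  // assumed padding granularity

    const uint32_t size_swa = std::min(size_base, pad_up(n_swa*n_seq_max + n_batch, n_pad));
    printf("size_swa = %u\n", size_swa); // prints 2560, well under size_base
    return 0;
}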
@@ -1680,14 +1657,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, padding,
+            v_trans, offload, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, padding,
+            v_trans, offload, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
@@ -1810,18 +1787,6 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
     return res;
 }
 
-int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
-    return kv_base->get_n_tokens();
-}
-
-int32_t llama_kv_cache_unified_iswa::get_used_cells() const {
-    return kv_base->get_used_cells();
-}
-
-llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
-    return kv_base->get_pos_max();
-}
-
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }
@@ -1853,7 +1818,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             ggml_type   type_k,
             ggml_type   type_v,
                  bool   offload,
-             uint32_t   kv_size) : hparams(model.hparams) {
+             uint32_t   kv_size,
+             uint32_t   n_seq_max) : hparams(model.hparams) {
     const int32_t n_layer = hparams.n_layer;
 
     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
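Unlike the unified cache, a recurrent cache stores one state cell per sequence rather than one per token, so the natural `kv_size` is `n_seq_max` itself. The sketch below illustrates that sizing; it is an assumption about typical callers, not something this hunk shows:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_seq_max = 8; // example sequence count
    // one state cell per sequence, with a floor of 1
    const uint32_t kv_size = std::max(1u, n_seq_max);
    printf("recurrent kv_size = %u\n", kv_size); // prints 8
    return 0;
}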
@@ -2203,8 +2169,8 @@ void llama_kv_cache_recurrent::commit() {
     pending.ranges.clear();
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
+bool llama_kv_cache_recurrent::update(llama_context & ctx) {
+    GGML_UNUSED(ctx);
     return false;
 }
@@ -2408,29 +2374,6 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
-int32_t llama_kv_cache_recurrent::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_recurrent::get_used_cells() const {
-    return used;
-}
-
-llama_pos llama_kv_cache_recurrent::get_pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 bool llama_kv_cache_recurrent::get_can_shift() const {
     return false;
 }