-
Notifications
You must be signed in to change notification settings - Fork 11.7k
feat: First pass at llama_kv_cache_hybrid #13276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome to see this progress!
src/llama-kv-cache.cpp
Outdated
// TODO: Will it cause problems if some caches are able to remove the seq | ||
// but others aren't? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes it will cause problems if this breaks the coherency between caches. (e.g. part of a sequence is removed in one cache but not the other).
This is what I was referring to in #12799 (comment) when I wrote:
The hardest part will be handling errors and properly keeping coherency between the different types of caches (because they don't necessarily roll-back states in the same way).
I think the seq_rm
API might fundamentally be too specific to self-attention KV cache. Recurrent models can't rollback their state, because intermediate states are not kept since keeping them for all tokens would take too much space. (when seq_rm
returns false, it means the states have to be re-calculated from scratch for the affected sequence (at least that was the intention in #5328))
Ideally, if there was some API to create snapshots and rollback to them, the implementation would be simpler for recurrent models (and for hybrid models by extension). (technically, sequences (with seq_id
) already kind of do this (and are copy-on-write), but snapshots within sequences might be more convenient to manage in user code, since managing which state is the latest per sequence could be done transparently)
But that would also mean having to manage the lifetime of explicit state snapshots (in examples/server/server.cpp
among others) instead of directly dealing with ranges of token positions (and might make things like largest-common-prefix context caching harder to handle). I've previously shared some ideas about state snapshots/checkpoints in #7531 (comment) (although the first half of the comment is about session restore as in state_read
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, interesting. I'm definitely still learning on-the-fly here, but based on this description and the logic here in server.cpp
, it seems like the most correct implementation would be to leak implementation details of the child caches or introduce a new member API for can_seq_rm
that is const
but returns the same logic. I think I'll give that a shot and see how far I can get.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I've pushed an attempt at doing this safely. One thing I noticed is that these mutating methods don't seem to have any sort of locking mechanism, so the way I have it implemented could certainly be prone to thread safety problems if concurrent threads tried to call seq_rm
. I don't think this is any different than for the current cache implementations since those would also be sensitive to the same races where the validated condition changes after validation but before the members get mutated, but I wanted to double check if this kind of thread safety is guarded against elsewhere (or just assumed to be handled in the client layer).
src/llama-kv-cache.cpp
Outdated
// If any of the caches are recurrent, require simple split | ||
return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simple split should not be used with recurrent models, they expect equal split.
See #7531 (comment) which illustrates the splits
// If any of the caches are recurrent, require simple split | |
return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all); | |
// If any of the caches are recurrent, require non-simple split | |
return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comment pointer, this is super helpful for understanding what the consequences of these actually are!
src/llama-kv-cache.cpp
Outdated
if (m_has_recurrent) { | ||
return sbatch.split_simple(n_ubatch); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will not work, recurrent models expect split_equal
to be used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I'm following now. I had them backwards in my head
// TODO: Is this correct? | ||
// If any children can shift, return true | ||
for (const auto & cache : m_children) { | ||
if (cache->get_can_shift()) { | ||
return true; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this should be if all children can shift, then return true.
But as you've noticed elsewhere, can_shift
should technically always be true for all currently-implemented cache types, so I don't know if that part of the API will stay anyway.
This implementation covers both `llama_memory_i` and `llama_kv_cache` interfaces, but they could very well not be correct. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
… seq_rm This allows the hybrid cache to check first before mutating any of the children. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
The parent should fully own the lifecycle of the children which is managed by the m_children member holding unique_ptrs. These need to be initialized correctly, so the constructor now takes the input vector of child_cache by value instead of reference so that the child pointers can be transferred to the parent cache. The expectation is that the vector of child_cache instances will be instantiated in-place with move semantics. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
Description
This implementation covers both
llama_memory_i
andllama_kv_cache
interfaces, but they could very well not be correct.Discussion
I'm putting this up for discussion even though it doesn't have much value as standalone. My ultimate goal is support for the just-released granite 4 which is a combination of
mamba2
andgranitemoeshared
layers. I opened #13275 to track the full scope of model architecture changes.