Feature Request: Support Falcon Mamba 7B #9009
Comments
I agree. Please support it: https://huggingface.co/tiiuae/falcon-mamba-7b-instruct/tree/main
This is a Mamba model (not Mamba 2), so it should be relatively easy to adapt the existing Mamba support. The only difference I see is that FalconMamba applies an RMS norm to the dt, B and C tensors. Since these norms don't use any learned parameters, this will require either a new metadata field to signal that norms for dt, B and C are used, or a new architecture could be added (e.g. a dedicated FalconMamba architecture).
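To make the first option a bit more concrete, here is a rough sketch (not the actual llama.cpp wiring) of what reading such a boolean metadata flag from a GGUF file could look like; the key name used below is purely an assumption for illustration:

```cpp
// Hedged sketch only: reads a hypothetical boolean GGUF key that would signal
// that dt, B and C receive an RMS norm. The key name is an assumption, not the
// real metadata field used by llama.cpp.
#include "ggml.h"

#include <cstdio>

static bool read_dt_b_c_rms_flag(const char * fname) {
    struct gguf_init_params params = {
        /*.no_alloc =*/ true,
        /*.ctx      =*/ nullptr,
    };
    struct gguf_context * gctx = gguf_init_from_file(fname, params);
    if (gctx == nullptr) {
        return false;
    }

    bool flag = false;
    const int key_id = gguf_find_key(gctx, "falcon_mamba.ssm.dt_b_c_rms"); // assumed key name
    if (key_id >= 0) {
        flag = gguf_get_val_bool(gctx, key_id);
    }

    gguf_free(gctx);
    return flag;
}

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }
    printf("dt/B/C RMS norm flag: %s\n", read_dt_b_c_rms_flag(argv[1]) ? "true" : "false");
    return 0;
}
```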
Converted models can be found here: https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a
Conversion (using b3616) goes fine, but trying to run the server (b3616, too) with the converted model fails. Same problem when trying to use the official GGUF from the https://huggingface.co/tiiuae/falcon-mamba-7b-instruct-Q8_0-GGUF repo.
Oh, right. When split, the dt, B and C tensors are not contiguous, and the CUDA implementation of ggml_rms_norm requires contiguous tensors. I think the fix might be something like this:

```diff
diff --git a/src/llama.cpp b/src/llama.cpp
index bd7f1508..57fa5450 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9118,9 +9118,9 @@ static struct ggml_tensor * llm_build_mamba(
         // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
         if (ssm_dt_b_c_rms) {
-            dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
-            B = ggml_rms_norm(ctx, B, norm_rms_eps);
-            C = ggml_rms_norm(ctx, C, norm_rms_eps);
+            dt = ggml_rms_norm(ctx, ggml_cont(ctx, dt), norm_rms_eps);
+            B = ggml_rms_norm(ctx, ggml_cont(ctx, B), norm_rms_eps);
+            C = ggml_rms_norm(ctx, ggml_cont(ctx, C), norm_rms_eps);
         }

         // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
```

Either that, or making the CUDA implementation of ggml_rms_norm support non-contiguous tensors. The CPU-only build does not have this problem, because the CPU-only implementation of ggml_rms_norm handles non-contiguous tensors.
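For anyone who wants to see the contiguity issue in isolation, below is a minimal CPU-only sketch (my own illustration, not FalconMamba or llama.cpp code) of the same pattern: a non-contiguous view, as you get when a bigger projection output is split into dt, B and C, is made contiguous with ggml_cont before ggml_rms_norm:

```cpp
// Minimal illustrative example: RMS-norm a non-contiguous view after ggml_cont.
// CPU-only; the CUDA limitation discussed above does not apply here.
#include "ggml.h"

#include <cstdio>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 8x4 f32 tensor standing in for a concatenated projection output.
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
    for (int i = 0; i < 8 * 4; ++i) {
        ((float *) x->data)[i] = (float) (i + 1);
    }

    // Non-contiguous view: the first 2 elements of each of the 4 rows
    // (the row stride is still that of the full 8-wide tensor).
    struct ggml_tensor * view = ggml_view_2d(ctx, x, 2, 4, x->nb[1], 0);

    // ggml_cont copies the view into a contiguous tensor, so ggml_rms_norm
    // also works on backends that require contiguous inputs.
    struct ggml_tensor * normed = ggml_rms_norm(ctx, ggml_cont(ctx, view), 1e-6f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, normed);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("normed[0][0] = %f\n", ((float *) normed->data)[0]);
    ggml_free(ctx);
    return 0;
}
```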
With this change it still ends with an error on my side. If GPU support for the Mamba architecture is still not implemented, then maybe this issue shouldn't be closed yet? Most people would like to use GPU rather than CPU, I think.
@MoonRide303, there is already #6758, which tracks GPU support for Mamba.
Not Falcon Mamba, but there is also https://huggingface.co/TRI-ML/mamba-7b-rw/ (a 7B Mamba trained on RefinedWeb, but using the same config, GPT-NeoX tokenizer, and vocab as the original Mamba).
Feature Description
Please support Falcon Mamba 7B from TII (Technology Innovation Institute, UAE).
Motivation
Support for all models is helpful.
My acid test for whether a model will run is to try to make a quant using "gguf-my-repo".
Admittedly it is hot off the presses, yet it ought to run at least in theory, but it doesn't.
Possible Implementation
They discuss an implementation here: https://falconllm.tii.ae/tii-releases-first-sslm-with-falcon-mamba-7b.html
Any functional Mamba or Mamba 2 models would be great, but this one is slightly changed.