MODEL: Falcon-H1 support #14238


Draft · wants to merge 105 commits into base: master

Commits (105)
1f0fea7
llama : initial Mamba-2 support
compilade Aug 1, 2024
dceff23
ggml : SIMD ggml_ssm_scan for Mamba-2
compilade Aug 19, 2024
2bfe9de
llama : support running Mamba-Codestral-7B-v0.1
compilade Aug 19, 2024
aff9692
llama : fix Mamba-2 conv state saving
compilade Aug 21, 2024
e04910d
llama : remove unused variable
compilade Aug 22, 2024
fa358e7
llama : add missing break
compilade Aug 22, 2024
38913dc
convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
compilade Aug 22, 2024
0e601ca
Merge branch 'master' into compilade/mamba2
compilade Sep 18, 2024
273e7a4
llama : avoid redundant state copy for Mamba 1 and 2
compilade Sep 30, 2024
7d6cb36
Merge branch 'master' into compilade/mamba2
compilade Oct 1, 2024
2c77d79
metal : attempt to adapt SSM_SCAN for Mamba-2
compilade Oct 2, 2024
87b97d0
metal : fix SSM_SCAN pipeline scope
compilade Oct 2, 2024
03d0e6e
metal : use log and exp instead of log1pf and expf in SSM_SCAN
compilade Oct 2, 2024
7a351ab
metal : remove unused arguments for SSM_SCAN
compilade Oct 2, 2024
8b15bc6
metal : add back n_seqs to SSM_SCAN args
compilade Oct 2, 2024
5b8ec2b
metal : fix SSM_SCAN state head offset
compilade Oct 2, 2024
62b09b3
metal : fix wrong number of tokens per sequence in SSM_SCAN
compilade Oct 3, 2024
038d958
Merge branch 'master' into compilade/mamba2
compilade Oct 12, 2024
805512a
ggml : remove unused fast broadcast path in GGML_MUL
compilade Oct 12, 2024
7d16e1b
Merge branch 'master' into compilade/mamba2
compilade Nov 1, 2024
3bc7103
ggml : avoid multiply by D in GGML_OP_SSM_SCAN
compilade Nov 4, 2024
8d8f065
Merge branch 'master' into compilade/mamba2
compilade Nov 4, 2024
b4e9c59
convert : fix flake8 lint
compilade Nov 4, 2024
1ee6c48
Merge branch 'master' into compilade/mamba2
compilade Nov 25, 2024
c9ecf62
Merge branch 'master' into compilade/mamba2
compilade Feb 26, 2025
35d06fa
Merge branch 'master' into compilade/mamba2
compilade May 1, 2025
cf4f0a4
metal : fix confusion between ; and ,
compilade May 1, 2025
6def5cd
metal : add missing args for nb references in ssm_scan_f32_group
compilade May 1, 2025
791998b
metal : single-user mamba2 inference works
compilade May 2, 2025
94c3d53
kv-cache : remove const_cast when setting inputs for s_copy
compilade May 2, 2025
929fe85
Merge branch 'master' into compilade/mamba2
compilade May 2, 2025
d55b0d0
convert : avoid AutoConfig for Mamba and Mamba2 hparams
compilade May 2, 2025
e94f393
kv-cache : allow context shift for recurrent models
compilade May 2, 2025
9864bfc
Merge branch 'master' into compilade/mamba2
compilade Jun 10, 2025
2fa5f2c
graph : fix recurrent state copies when avoiding copies
compilade Jun 11, 2025
757aa62
ggml : fix mamba2 ssm scan when compiled with SVE
compilade Jun 11, 2025
0b6f6be
ggml-cpu : reorder SVE FMA for consistency with other SIMD arches
compilade Jun 11, 2025
4dff845
feat: Add llama_model_is_hybrid API call
gabe-l-hart May 9, 2025
18224b7
feat: Add c++ side constants for attention layer indices hparam
gabe-l-hart May 9, 2025
4782ac0
feat: Add support for distinguishing recurrent vs non-recurrent layer…
gabe-l-hart May 9, 2025
6a397e5
feat: Auto-fill hparams.recurrent_layer_arr based on whether the mode…
gabe-l-hart May 9, 2025
4521708
refactor: rename *_is_hybrid -> *_is_hybrid_recurrent
gabe-l-hart May 28, 2025
263d913
feat: Add layer filter to recurrent cache
gabe-l-hart May 20, 2025
a7bd782
fix: Use per-layer sizing everywhere in kv caches
gabe-l-hart May 14, 2025
af916a0
feat: First pass at llama_kv_cache_hybrid_recurrent
gabe-l-hart May 30, 2025
1ee27a4
feat: Construct hybrid recurrent cache for hybrid recurrent models
gabe-l-hart May 28, 2025
781d5aa
fix: Fix wrong bool condition for split equal in hybrid cache
gabe-l-hart May 28, 2025
a3f58d9
fix: Fix shift logic to defer to unified cache
gabe-l-hart Jun 3, 2025
c8c9aa2
feat: Support hybrid recurrent in llama-graph
gabe-l-hart Jun 4, 2025
e77503f
fix: Fix logic for initializing inputs and attn layers for hybrid caches
gabe-l-hart Jun 4, 2025
d327af7
fix: Update recurrent cache for changes to remove intermediate kv_cac…
gabe-l-hart Jun 5, 2025
ee9b31c
fix: Fix status for init_update sig for recurrent cache state
gabe-l-hart Jun 5, 2025
c3c0ee6
fix: Add missing padding to n_ctx for hybrid cache construction
gabe-l-hart Jun 5, 2025
df695b3
fix: Update clear signature for data argument after rebase
gabe-l-hart Jun 6, 2025
9520af2
fix: Remove errant virtual destructor leftover from previous impl att…
gabe-l-hart Jun 10, 2025
6361a73
fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers
gabe-l-hart Jun 10, 2025
b115052
refactor: Remove n_embd_k/v_s from unified cache
gabe-l-hart Jun 11, 2025
8690495
refactor: Remove layer index from n_embd_k/v_s
gabe-l-hart Jun 11, 2025
fa6dbe8
refactor: Remove n_embd_k/v_gqa from recurrent cache
gabe-l-hart Jun 11, 2025
42459d0
feat: Allow custom layer filters for hybrid recurrent
gabe-l-hart Jun 11, 2025
74ad4f8
fix: Remove logits_all after rebase
gabe-l-hart Jun 12, 2025
b3c948a
fix: Remove llama_model_is_hybrid_Recurrent public API
gabe-l-hart Jun 12, 2025
f3e34bb
refactor: Use llama_memory_state_ptr for child states in hybrid memor…
gabe-l-hart Jun 12, 2025
6253c7c
feat: Overhaul build_recurrent_state / build_inp_s_copy to match atte…
gabe-l-hart Jun 12, 2025
0d80ee7
Merge remote-tracking branch 'origin/compilade/mamba2' into mamba2-sync
gabe-l-hart Jun 12, 2025
0385fc0
some changes
younesbelkada Jun 16, 2025
c5f9f36
feat: Add llama_model_is_hybrid API call
gabe-l-hart May 9, 2025
665a357
feat: Add c++ side constants for attention layer indices hparam
gabe-l-hart May 9, 2025
6497b80
feat: Add support for distinguishing recurrent vs non-recurrent layer…
gabe-l-hart May 9, 2025
acd5564
feat: Auto-fill hparams.recurrent_layer_arr based on whether the mode…
gabe-l-hart May 9, 2025
cc40ce9
refactor: rename *_is_hybrid -> *_is_hybrid_recurrent
gabe-l-hart May 28, 2025
f56561b
feat: Add layer filter to recurrent cache
gabe-l-hart May 20, 2025
17abb2b
fix: Use per-layer sizing everywhere in kv caches
gabe-l-hart May 14, 2025
b63beac
feat: First pass at llama_kv_cache_hybrid_recurrent
gabe-l-hart May 30, 2025
4c37343
feat: Construct hybrid recurrent cache for hybrid recurrent models
gabe-l-hart May 28, 2025
0ffbec5
fix: Fix wrong bool condition for split equal in hybrid cache
gabe-l-hart May 28, 2025
936da1a
fix: Fix shift logic to defer to unified cache
gabe-l-hart Jun 3, 2025
fb220cc
feat: Support hybrid recurrent in llama-graph
gabe-l-hart Jun 4, 2025
40e7abd
fix: Fix logic for initializing inputs and attn layers for hybrid caches
gabe-l-hart Jun 4, 2025
309d75e
fix: Update recurrent cache for changes to remove intermediate kv_cac…
gabe-l-hart Jun 5, 2025
eaaf4a9
fix: Fix status for init_update sig for recurrent cache state
gabe-l-hart Jun 5, 2025
581113f
fix: Add missing padding to n_ctx for hybrid cache construction
gabe-l-hart Jun 5, 2025
84e300f
fix: Update clear signature for data argument after rebase
gabe-l-hart Jun 6, 2025
ae7a02e
fix: Remove errant virtual destructor leftover from previous impl att…
gabe-l-hart Jun 10, 2025
1e49029
fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers
gabe-l-hart Jun 10, 2025
ddc85f0
refactor: Remove n_embd_k/v_s from unified cache
gabe-l-hart Jun 11, 2025
368c9f4
refactor: Remove layer index from n_embd_k/v_s
gabe-l-hart Jun 11, 2025
b82a225
refactor: Remove n_embd_k/v_gqa from recurrent cache
gabe-l-hart Jun 11, 2025
8b7f687
feat: Allow custom layer filters for hybrid recurrent
gabe-l-hart Jun 11, 2025
3ee2222
fix: Remove logits_all after rebase
gabe-l-hart Jun 12, 2025
bd37fc8
fix: Remove llama_model_is_hybrid_Recurrent public API
gabe-l-hart Jun 12, 2025
2e5e45c
refactor: Use llama_memory_state_ptr for child states in hybrid memor…
gabe-l-hart Jun 12, 2025
a42b9cb
feat: Overhaul build_recurrent_state / build_inp_s_copy to match atte…
gabe-l-hart Jun 12, 2025
beddd62
fix: Fix resize vs reserve and skip null tensors in size computation
gabe-l-hart Jun 16, 2025
8d24073
fix: Fix initialization of child states
gabe-l-hart Jun 16, 2025
95b6698
refactor: Use a common build_recurrent_state method that is cache-agn…
gabe-l-hart Jun 16, 2025
e16aacc
Merge remote-tracking branch 'origin/compilade/mamba2' into mamba2-sync
gabe-l-hart Jun 16, 2025
9d30fd4
Merge remote-tracking branch 'gabe/mamba2-sync' into add-fh1-clean + …
younesbelkada Jun 17, 2025
aff19ae
more clean ups
younesbelkada Jun 17, 2025
9150fc9
fix inference
younesbelkada Jun 17, 2025
11857e8
Commit credits attribution
HDElectronics Jun 17, 2025
ebfcb6f
Commit credits attribution
brahimfarhat Jun 17, 2025
82f59d0
Commit credits attribution
HamzaYousLM Jun 17, 2025
837be3e
add convert script
younesbelkada Jun 17, 2025
41f25cf
add missing elements in py files
younesbelkada Jun 17, 2025
263 changes: 261 additions & 2 deletions convert_hf_to_gguf.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion ggml/include/ggml.h
@@ -1878,7 +1878,8 @@ extern "C" {
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C);
struct ggml_tensor * C,
struct ggml_tensor * ids);

// partition into non-overlapping windows with padding if needed
// example:
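Note: the hunk above adds an ids tensor as the last argument of ggml_ssm_scan, used to select which recurrent state slots each sequence reads. A minimal call-site sketch, assuming the leading arguments (context, state s, input x, dt) keep their existing meaning; the wrapper name and comments are illustrative only:

// Hypothetical call-site sketch (wrapper name is illustrative, not part of
// the PR); it only shows where the new ids argument goes.
struct ggml_tensor * build_ssm_scan_node(
        struct ggml_context * ctx,
        struct ggml_tensor  * s,    // recurrent states
        struct ggml_tensor  * x,    // input activations
        struct ggml_tensor  * dt,   // per-token time step
        struct ggml_tensor  * A,    // state-transition parameter
        struct ggml_tensor  * B,
        struct ggml_tensor  * C,
        struct ggml_tensor  * ids)  // which state slot each sequence reads
{
    // After this change, ids is passed as the final argument.
    return ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
}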
278 changes: 184 additions & 94 deletions ggml/src/ggml-cpu/ops.cpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ggml/src/ggml-cpu/simd-mappings.h
@@ -32,7 +32,7 @@
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
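Note: the reordering above changes the macro's argument convention. Assuming svmad_f32_m(pg, x, y, z) computes x*y + z per active lane, GGML_F32xt_FMA(a, b, c) now evaluates to a + b*c, matching the accumulator-first convention of the other SIMD backends. A minimal sketch of that convention (helper name is illustrative):

#include <arm_sve.h>

// acc + x*y, written with the multiplicands in the svmad positions that the
// updated macro uses (svmad computes op1*op2 + op3 for active lanes).
static inline svfloat32_t fma_acc_sketch(svbool_t pg, svfloat32_t acc,
                                         svfloat32_t x, svfloat32_t y) {
    return svmad_f32_m(pg, x, y, acc); // x*y + acc == acc + x*y
}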
18 changes: 9 additions & 9 deletions ggml/src/ggml-cpu/vec.cpp
@@ -37,43 +37,43 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
for (int i = 0; i < np; i += ggml_f32_step) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);

ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);

ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);

ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);

ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);

ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);

ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);

ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
}
// leftovers
// Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
const int np2 = (n & ~(ggml_f32_epr - 1));
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
}
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
if (np2 < n) {
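Note: the call sites above only move the accumulator into the first FMA argument; the computed value is an ordinary dot product. A scalar reference sketch of what the unrolled SVE loop accumulates (illustrative, not the PR's code):

// s = sum over i of x[i]*y[i], with each step written as acc = acc + x*y,
// i.e. GGML_F32_VEC_FMA(acc, x, y) in the SIMD path.
static float vec_dot_f32_ref(int n, const float * x, const float * y) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) {
        acc = acc + x[i]*y[i];
    }
    return acc;
}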
18 changes: 9 additions & 9 deletions ggml/src/ggml-cpu/vec.h
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const

ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

GGML_F32_VEC_STORE(y + i, ay1);

ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);

GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);

ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);

GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);

ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);

GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);

ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);

GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);

ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);

GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);

ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);

GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);

ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);

GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
}
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

GGML_F32_VEC_STORE(y + i, ay1);
}
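Note: ggml_vec_mad_f32 is an axpy-style update, y[i] += x[i]*v; the hunk only swaps the FMA arguments to the accumulator-first order. A scalar reference sketch (illustrative):

// y <- y + x*v, matching GGML_F32_VEC_FMA(ay, ax, vx) in the SIMD path.
static void vec_mad_f32_ref(int n, float * y, const float * x, float v) {
    for (int i = 0; i < n; ++i) {
        y[i] = y[i] + x[i]*v;
    }
}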
11 changes: 5 additions & 6 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -488,26 +488,25 @@ typedef struct {
typedef struct {
int64_t d_state;
int64_t d_inner;
int64_t n_head;
int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb10;
uint64_t nb03;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb30;
uint64_t nb31;
uint64_t nb40;
uint64_t nb41;
uint64_t nb42;
uint64_t nb50;
uint64_t nb43;
uint64_t nb51;
uint64_t nb52;
uint64_t nb53;
} ggml_metal_kargs_ssm_scan;

typedef struct {
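Note: the reworked ggml_metal_kargs_ssm_scan carries the per-dimension byte strides (nbXY) of each source tensor so the kernel can address elements directly. In ggml, element (i0, i1, i2, i3) of a tensor with byte strides nb[0..3] starts at data + i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3]; a small sketch of that addressing convention (helper name is illustrative):

#include <stdint.h>

// Returns a pointer to element (i0, i1, i2, i3) given ggml-style byte strides.
static inline const float * elem_f32(const char * data, const uint64_t nb[4],
                                     int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    return (const float *)(data + i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3]);
}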
93 changes: 61 additions & 32 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -191,6 +191,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -1147,6 +1148,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -2632,71 +2634,91 @@ static bool ggml_metal_encode_node(
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src4 = node->src[4];
struct ggml_tensor * src5 = node->src[5];
struct ggml_tensor * src6 = node->src[6];

GGML_ASSERT(src3);
GGML_ASSERT(src4);
GGML_ASSERT(src5);
GGML_ASSERT(src6);

size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
size_t offs_src6 = 0;

id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;

const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
const int64_t ne30 = src3->ne[0];
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);

const uint64_t nb30 = src3->nb[0];
const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
const uint64_t nb31 = src3->nb[1];

const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
const int64_t ne41 = src4->ne[1];
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);

const uint64_t nb40 = src4->nb[0];
const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
const uint64_t nb41 = src4->nb[1];
const uint64_t nb42 = src4->nb[2];
const uint64_t nb43 = src4->nb[3];

const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);

const uint64_t nb50 = src5->nb[0];
const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
const uint64_t nb51 = src5->nb[1];
const uint64_t nb52 = src5->nb[2];
const uint64_t nb53 = src5->nb[3];

const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);

const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);

const int64_t d_state = ne00;
const int64_t d_inner = ne01;
const int64_t n_seq_tokens = ne11;
const int64_t n_seqs = ne02;
const int64_t n_head = ne02;
const int64_t n_group = ne41;
const int64_t n_seq_tokens = ne12;
const int64_t n_seqs = ne13;

id<MTLComputePipelineState> pipeline = nil;

id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
if (ne30 == 1) {
// Mamba-2
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
}

ggml_metal_kargs_ssm_scan args = {
/*.d_state =*/ d_state,
/*.d_inner =*/ d_inner,
/*.d_state =*/ d_state,
/*.d_inner =*/ d_inner,
/*.n_head =*/ n_head,
/*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb30 =*/ nb30,
/*.nb31 =*/ nb31,
/*.nb40 =*/ nb40,
/*.nb41 =*/ nb41,
/*.nb42 =*/ nb42,
/*.nb50 =*/ nb50,
/*.nb51 =*/ nb51,
/*.nb52 =*/ nb52,
/*.n_seqs =*/ n_seqs,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb31 =*/ nb31,
/*.nb41 =*/ nb41,
/*.nb42 =*/ nb42,
/*.nb43 =*/ nb43,
/*.nb51 =*/ nb51,
/*.nb52 =*/ nb52,
/*.nb53 =*/ nb53,
};

[encoder setComputePipelineState:pipeline];
@@ -2706,10 +2728,17 @@ static bool ggml_metal_encode_node(
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setBytes:&args length:sizeof(args) atIndex:7];
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
[encoder setBytes:&args length:sizeof(args) atIndex:8];

[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
if (ne30 == 1) {
// Mamba-2
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
}
} break;
case GGML_OP_RWKV_WKV6:
{
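Note: the host code above picks the grouped kernel when ne30 == 1 (Mamba-2, where A holds one value per head) and gives the dispatch grid an extra head axis; otherwise the original kernel keeps the (head, sequence) grid with d_inner == 1. A small recap of that dispatch choice (struct and function names are illustrative, not the Metal API):

#include <stdint.h>

struct grid3 { int64_t x, y, z; };

// Mirrors the threadgroup grid selected above.
static struct grid3 ssm_scan_grid(int64_t ne30, int64_t d_inner,
                                  int64_t n_head, int64_t n_seqs) {
    if (ne30 == 1) { // Mamba-2: one threadgroup per (inner dim, head, sequence)
        return (struct grid3){ d_inner, n_head, n_seqs };
    }
    return (struct grid3){ n_head, n_seqs, 1 }; // Mamba-1 path, d_inner == 1
}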