From 024bd29445ee5cb84abbd959b2fdb2d149aa8126 Mon Sep 17 00:00:00 2001
From: Vaibhavs10
Date: Tue, 17 Jun 2025 15:03:34 +0200
Subject: [PATCH 1/6] Init - first pass.

---
 convert_hf_to_gguf.py               | 10 ++++
 docs/development/HOWTO-add-model.md | 20 ++++----
 gguf-py/gguf/constants.py           | 17 +++++++
 src/llama-arch.cpp                  | 19 ++++++++
 src/llama-arch.h                    |  1 +
 src/llama-model.cpp                 | 75 +++++++++++++++++++++++++++++
 6 files changed, 133 insertions(+), 9 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 58e455ae645ed..9d12a1e56d66a 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6298,6 +6298,16 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
 
+
+@Model.register("SmolLM3ForCausalLM")
+class SmolLM3Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.SMOLLM3
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        if self.model.config.no_rope_layers is not None:
+            self.gguf_writer.add_array("smollm3.no_rope_layers", self.model.config.no_rope_layers, gguf.GGUFValueType.INT32)
+
 
 ###### CONVERSION LOGIC ######
 
diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md
index 7f71e0247ddc7..51e0b0b20f58d 100644
--- a/docs/development/HOWTO-add-model.md
+++ b/docs/development/HOWTO-add-model.md
@@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv
 
 ### 2. Define the model architecture in `llama.cpp`
 
-The model params and tensors layout must be defined in `llama.cpp`:
-1. Define a new `llm_arch`
-2. Define the tensors layout in `LLM_TENSOR_NAMES`
-3. Add any non-standard metadata in `llm_load_hparams`
-4. Create the tensors for inference in `llm_load_tensors`
-5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
+The model params and tensors layout must be defined in `llama.cpp` source files:
+1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
+2. In `src/llama-arch.cpp`:
+    - Add the architecture name to the `LLM_ARCH_NAMES` map.
+    - Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
+3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
+4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.
 
 NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.
 
 ### 3. Build the GGML graph implementation
 
-This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
-
-Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
+This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
+Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
+Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
+Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.
 
 Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
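For orientation, the graph-builder pattern that step 3 of the revised HOWTO describes, and that this patch series follows for SmolLM3, reduces to a struct whose constructor assembles the compute graph. The sketch below is schematic only and is not part of the patch; the helper calls (`build_inp_embd`, `build_norm`, `build_lora_mm`, `ggml_build_forward_expand`, ...) are the ones the SmolLM3 builder further down ends up using, and the per-layer body is elided.

```cpp
// Schematic outline of a llama.cpp graph builder (illustrative, not from the patch).
struct llm_build_example : public llm_graph_context {
    llm_build_example(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
        ggml_tensor * inpL     = build_inp_embd(model.tok_embd);  // token embeddings
        ggml_tensor * inp_pos  = build_inp_pos();                 // positions for RoPE
        auto        * inp_attn = build_attn_inp_kv_unified();     // unified KV-cache input

        ggml_tensor * cur = inpL;
        for (int il = 0; il < n_layer; ++il) {
            // attention norm -> self-attention -> residual -> FFN norm -> FFN -> residual
            // (see llm_build_llama upstream, or llm_build_smollm3 further down)
        }

        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);  // final norm
        cur = build_lora_mm(model.output, cur);                            // lm_head
        res->t_logits = cur;

        ggml_build_forward_expand(gf, cur);
    }
};
```

Hooking such a struct up then amounts to one extra `case` in `llama_model::build_graph`, which this first patch adds for `LLM_ARCH_SMOLLM3` below.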
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 834a1d5e1a97e..054591dcf8e23 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -346,6 +346,7 @@ class MODEL_ARCH(IntEnum):
     BAILINGMOE = auto()
     DOTS1 = auto()
     ARCEE = auto()
+    SMOLLM3 = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -629,6 +630,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.BAILINGMOE: "bailingmoe",
     MODEL_ARCH.DOTS1: "dots1",
     MODEL_ARCH.ARCEE: "arcee",
+    MODEL_ARCH.SMOLLM3: "smollm3",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2101,6 +2103,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.SMOLLM3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index de8d289cf967e..ec8e5278a6200 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BAILINGMOE, "bailingmoe" },
     { LLM_ARCH_DOTS1, "dots1" },
     { LLM_ARCH_ARCEE, "arcee" },
+    { LLM_ARCH_SMOLLM3, "smollm3" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -1625,6 +1626,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
         },
     },
+    {
+        LLM_ARCH_SMOLLM3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd.weight" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm.weight" },
+            { LLM_TENSOR_OUTPUT, "output.weight" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm.weight" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q.weight" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k.weight" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v.weight" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output.weight" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate.weight" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down.weight" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up.weight" },
+        },
+    },
 };
 
 static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 3e8a61da3c13e..11b71e536338b 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -79,6 +79,7 @@ enum llm_arch {
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
+    LLM_ARCH_SMOLLM3,
     LLM_ARCH_UNKNOWN,
 };
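A note on the tensor mappings registered above: entries such as `"blk.%d.attn_q.weight"` are per-layer name patterns, and the layer index is substituted when tensors are looked up (the first-pass builder below relies on the `tn(...)` helper for this). The snippet below is a standalone illustration of the expansion only, not llama.cpp code.

```cpp
// Standalone illustration: expanding a per-layer tensor name pattern.
#include <cstdio>

int main() {
    char name[64];
    for (int il = 0; il < 3; ++il) {
        std::snprintf(name, sizeof(name), "blk.%d.attn_q.weight", il);
        std::printf("%s\n", name);  // blk.0.attn_q.weight, blk.1.attn_q.weight, ...
    }
    return 0;
}
```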
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index a5eb122f998d8..611b3f6fcff25 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13734,6 +13734,75 @@ struct llm_build_arcee : public llm_graph_context {
     }
 };
 
+struct llm_build_smollm3 : public llm_graph_context {
+    llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        std::vector<int32_t> no_rope_layers;
+        if (arch == LLM_ARCH_SMOLLM3) {
+            const int kid = gguf_find_key(model.meta, "smollm3.no_rope_layers");
+            if (kid != -1) {
+                const uint32_t n = gguf_get_arr_n(model.meta, kid);
+                no_rope_layers.resize(n);
+                const int nb = gguf_get_arr_data(model.meta, kid, no_rope_layers.data(), n * sizeof(int32_t));
+                GGML_ASSERT(nb == int(n * sizeof(int32_t)));
+            }
+        }
+
+        const int64_t n_tokens = params.n_tokens;
+        const int64_t n_layer = hparams.n_layer;
+
+        gf->n_threads = params.n_threads;
+
+        // build the graph
+        inp_tokens->set_input(ubatch);
+        inp_pos->set_input(ubatch);
+        inp_attn_temp->set_input(ubatch);
+
+        struct ggml_tensor * cur = build_inp_embd();
+        struct ggml_tensor * lay_out = nullptr;
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inp_norm = build_norm(cur, hparams.f_norm_eps, il, tn(LLM_TENSOR_ATTN_NORM, il));
+            struct ggml_tensor * qkv = build_attn(inp_norm, il);
+            struct ggml_tensor * q = ggml_view_4d(ctx, qkv, hparams.n_embd_head_v, hparams.n_head(il), n_tokens, 1, ggml_element_size(qkv)*hparams.n_embd_head_v, 0, 0, 0);
+            struct ggml_tensor * k = ggml_view_4d(ctx, qkv, hparams.n_embd_head_k, hparams.n_head_kv(il), n_tokens, 1, ggml_element_size(qkv)*hparams.n_embd_head_k, ggml_element_size(qkv)*hparams.n_embd_k_gqa(il), 0, 0);
+            struct ggml_tensor * v = ggml_view_4d(ctx, qkv, hparams.n_embd_head_v, hparams.n_head_kv(il), n_tokens, 1, ggml_element_size(qkv)*hparams.n_embd_head_v, ggml_element_size(qkv)*hparams.n_embd_k_gqa(il) + ggml_element_size(qkv)*hparams.n_embd_k_gqa(il), 0, 0);
+
+            ggml_set_name(q, "q");
+            ggml_set_name(k, "k");
+            ggml_set_name(v, "v");
+
+            struct ggml_tensor * qcur = q;
+            struct ggml_tensor * kcur = k;
+
+            bool apply_rope = true;
+            if (arch == LLM_ARCH_SMOLLM3) {
+                if (std::find(no_rope_layers.begin(), no_rope_layers.end(), il) != no_rope_layers.end()) {
+                    apply_rope = false;
+                }
+            }
+
+            if (apply_rope && get_tensor_meta(tn(LLM_TENSOR_ROPE_FREQS, il))) {
+                qcur = ggml_rope_ext(ctx, q, inp_pos->pos, get_tensor_meta(tn(LLM_TENSOR_ROPE_FREQS, il)), hparams.rope_type, 0, hparams.n_rot, hparams.n_gqa(il), hparams.rope_freq_base_train, hparams.rope_freq_scale_train, hparams.n_ctx_orig_yarn, hparams.rope_yarn_log_mul);
+                kcur = ggml_rope_ext(ctx, k, inp_pos->pos, get_tensor_meta(tn(LLM_TENSOR_ROPE_FREQS, il)), hparams.rope_type, 0, hparams.n_rot, hparams.n_gqa(il), hparams.rope_freq_base_train, hparams.rope_freq_scale_train, hparams.n_ctx_orig_yarn, hparams.rope_yarn_log_mul);
+            }
+
+            struct ggml_tensor * attn_out = build_attn_out(inp_norm, qcur, kcur, v, il);
+
+            if (hparams.use_par_res) {
+                // parallel residual
+                lay_out = ggml_add(ctx, attn_out, build_ff_par(inp_norm, il));
+            } else {
+                // sequential residual
+                lay_out = ggml_add(ctx, cur, attn_out);
+                lay_out = build_ff_seq(lay_out, il);
+            }
+            cur = lay_out;
+        }
+
+        build_output(cur, lay_out);
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
@@ -14085,6 +14154,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_arcee>(*this, params, gf);
             } break;
+        case LLM_ARCH_SMOLLM3:
+            {
+                llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
 
@@ -14235,9 +14308,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
         case LLM_ARCH_NEO_BERT:
+        case LLM_ARCH_SMOLLM3:
         case LLM_ARCH_ARCEE:
             return LLAMA_ROPE_TYPE_NORM;
 
+        // the pairs of head values are offset by n_rot/2
         case LLM_ARCH_FALCON:
         case LLM_ARCH_GROK:

From 32ea9c5fc1c66c61a2d8a21119c788713c531d22 Mon Sep 17 00:00:00 2001
From: Vaibhavs10
Date: Tue, 17 Jun 2025 15:09:15 +0200
Subject: [PATCH 2/6] Model -> ModelBase.
---
 convert_hf_to_gguf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 9d12a1e56d66a..ef0501883d7b6 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6298,7 +6298,7 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
 
-@Model.register("SmolLM3ForCausalLM")
+@ModelBase.register("SmolLM3ForCausalLM")
 class SmolLM3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.SMOLLM3
 

From 02ff08507146c516f10eda0d709d1897fc13b20f Mon Sep 17 00:00:00 2001
From: Vaibhavs10
Date: Tue, 17 Jun 2025 16:01:53 +0200
Subject: [PATCH 3/6] fix errors in conversion.

---
 convert_hf_to_gguf.py     | 7 +++++--
 gguf-py/gguf/constants.py | 1 +
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index ef0501883d7b6..ce1500d699b9e 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6305,8 +6305,11 @@ class SmolLM3Model(LlamaModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
-        if self.model.config.no_rope_layers is not None:
-            self.gguf_writer.add_array("smollm3.no_rope_layers", self.model.config.no_rope_layers, gguf.GGUFValueType.INT32)
+        # if self.model.config.no_rope_layers is not None:
+        #     self.gguf_writer.add_array("smollm3.no_rope_layers", self.model.config.no_rope_layers, gguf.GGUFValueType.INT32)
+        no_rope_layers = self.hparams.get("no_rope_layers")
+        if no_rope_layers is not None:
+            self.gguf_writer.add_array("smollm3.no_rope_layers", no_rope_layers)
 
 ###### CONVERSION LOGIC ######
 
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 054591dcf8e23..70083cd4e7713 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -2114,6 +2114,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
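Up to this point in the series, the NoPE layers are carried as an explicit `smollm3.no_rope_layers` array, and the builder skips RoPE for any layer index found in it (the `std::find` check from patch 1, kept in patch 4 below). The following is a standalone sketch of that selection logic only, with purely illustrative layer indices.

```cpp
// Standalone sketch of list-based NoPE selection (illustrative values only).
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int32_t> no_rope_layers = {3, 7, 11};  // hypothetical NoPE layers
    const int n_layer = 12;

    for (int il = 0; il < n_layer; ++il) {
        const bool apply_rope =
            std::find(no_rope_layers.begin(), no_rope_layers.end(), il) == no_rope_layers.end();
        std::printf("layer %2d: %s\n", il, apply_rope ? "RoPE" : "NoPE");
    }
    return 0;
}
```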
From 6201b438146d2506240729788adf6e024b070842 Mon Sep 17 00:00:00 2001
From: Vaibhavs10
Date: Thu, 19 Jun 2025 17:13:28 +0200
Subject: [PATCH 4/6] Update the graph.

---
 src/llama-model.cpp | 146 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 113 insertions(+), 33 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 611b3f6fcff25..ac6fc4e1e7662 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13736,6 +13736,11 @@ struct llm_build_arcee : public llm_graph_context {
 
 struct llm_build_smollm3 : public llm_graph_context {
     llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        // collect layers for which RoPE is disabled (metadata key: "smollm3.no_rope_layers")
         std::vector<int32_t> no_rope_layers;
         if (arch == LLM_ARCH_SMOLLM3) {
             const int kid = gguf_find_key(model.meta, "smollm3.no_rope_layers");
@@ -13747,59 +13752,134 @@ struct llm_build_smollm3 : public llm_graph_context {
             }
         }
 
-        const int64_t n_tokens = params.n_tokens;
-        const int64_t n_layer = hparams.n_layer;
+        // token embeddings
+        ggml_tensor * inpL = build_inp_embd(model.tok_embd);
 
-        gf->n_threads = params.n_threads;
+        // positional ids
+        ggml_tensor * inp_pos = build_inp_pos();
 
-        // build the graph
-        inp_tokens->set_input(ubatch);
-        inp_pos->set_input(ubatch);
-        inp_attn_temp->set_input(ubatch);
+        // attention helper (unified KV cache)
+        auto * inp_attn = build_attn_inp_kv_unified();
 
-        struct ggml_tensor * cur = build_inp_embd();
-        struct ggml_tensor * lay_out = nullptr;
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        ggml_tensor * cur = nullptr;
 
         for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inp_norm = build_norm(cur, hparams.f_norm_eps, il, tn(LLM_TENSOR_ATTN_NORM, il));
-            struct ggml_tensor * qkv = build_attn(inp_norm, il);
-            struct ggml_tensor * q = ggml_view_4d(ctx, qkv, hparams.n_embd_head_v, hparams.n_head(il), n_tokens, 1, ggml_element_size(qkv)*hparams.n_embd_head_v, 0, 0, 0);
-            struct ggml_tensor * k = ggml_view_4d(ctx, qkv, hparams.n_embd_head_k, hparams.n_head_kv(il), n_tokens, 1, ggml_element_size(qkv)*hparams.n_embd_head_k, ggml_element_size(qkv)*hparams.n_embd_k_gqa(il), 0, 0);
-            struct ggml_tensor * v = ggml_view_4d(ctx, qkv, hparams.n_embd_head_v, hparams.n_head_kv(il), n_tokens, 1, ggml_element_size(qkv)*hparams.n_embd_head_v, ggml_element_size(qkv)*hparams.n_embd_k_gqa(il) + ggml_element_size(qkv)*hparams.n_embd_k_gqa(il), 0, 0);
+            ggml_tensor * inpSA = inpL;
+
+            // attention norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // ---- self-attention ----
+            {
+                // fused QKV projection
+                ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(qkv, "wqkv", il);
+                if (model.layers[il].bqkv) {
+                    qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv);
+                    cb(qkv, "bqkv", il);
+                }
+
+                const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
 
-            ggml_set_name(q, "q");
-            ggml_set_name(k, "k");
-            ggml_set_name(v, "v");
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd + n_embd_gqa)));
 
-            struct ggml_tensor * qcur = q;
-            struct ggml_tensor * kcur = k;
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            bool apply_rope = true;
-            if (arch == LLM_ARCH_SMOLLM3) {
-                if (std::find(no_rope_layers.begin(), no_rope_layers.end(), il) != no_rope_layers.end()) {
-                    apply_rope = false;
+                if (std::find(no_rope_layers.begin(), no_rope_layers.end(), il) == no_rope_layers.end()) {
+                    ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+                    Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow);
+                    Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow);
                 }
-            }
 
-            if (apply_rope && get_tensor_meta(tn(LLM_TENSOR_ROPE_FREQS, il))) {
-                qcur = ggml_rope_ext(ctx, q, inp_pos->pos, get_tensor_meta(tn(LLM_TENSOR_ROPE_FREQS, il)), hparams.rope_type, 0, hparams.n_rot, hparams.n_gqa(il), hparams.rope_freq_base_train, hparams.rope_freq_scale_train, hparams.n_ctx_orig_yarn, hparams.rope_yarn_log_mul);
-                kcur = ggml_rope_ext(ctx, k, inp_pos->pos, get_tensor_meta(tn(LLM_TENSOR_ROPE_FREQS, il)), hparams.rope_type, 0, hparams.n_rot, hparams.n_gqa(il), hparams.rope_freq_base_train, hparams.rope_freq_scale_train, hparams.n_ctx_orig_yarn, hparams.rope_yarn_log_mul);
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
             }
 
-            struct ggml_tensor * attn_out = build_attn_out(inp_norm, qcur, kcur, v, il);
+            // skip padded tokens for final layer
+            if (il == n_layer - 1) {
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
 
+            // ---- feed-forward ----
             if (hparams.use_par_res) {
                 // parallel residual
-                lay_out = ggml_add(ctx, attn_out, build_ff_par(inp_norm, il));
+                ggml_tensor * ffn_cur = build_norm(inpL,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(ffn_cur, "ffn_norm", il);
+
+                ffn_cur = build_ffn(ffn_cur,
+                        model.layers[il].ffn_up, NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, ffn_cur);
+                cb(cur, "par_res", il);
             } else {
                 // sequential residual
-                lay_out = ggml_add(ctx, cur, attn_out);
-                lay_out = build_ff_seq(lay_out, il);
+                ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up, NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, ffn_inp);
+                cb(cur, "ffn_out", il);
             }
-            cur = lay_out;
+
+            // post-processing
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
         }
 
-        build_output(cur, lay_out);
+        // final RMSNorm
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
     }
 };
 

From 97c64a0974f0d46395ce6ab364b2815c9642d345 Mon Sep 17 00:00:00 2001
From: Vaibhavs10
Date: Fri, 4 Jul 2025 14:15:34 +0200
Subject: [PATCH 5/6] up.

---
 src/llama-model.cpp | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ac6fc4e1e7662..c0be7e6932e96 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13828,12 +13828,13 @@ struct llm_build_smollm3 : public llm_graph_context {
                         LLM_NORM_RMS, il);
                 cb(ffn_cur, "ffn_norm", il);
 
-                ffn_cur = build_ffn(ffn_cur,
-                        model.layers[il].ffn_up, NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                ffn_cur = build_ffn(
+                        ffn_cur,
+                        model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
+                        nullptr,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
                 cb(ffn_cur, "ffn_out", il);
 
                 cur = ggml_add(ctx0, cur, ffn_cur);
@@ -13848,13 +13849,14 @@ struct llm_build_smollm3 : public llm_graph_context {
                         LLM_NORM_RMS, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = build_ffn(cur,
-                        model.layers[il].ffn_up, NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, il);
-                cb(cur, "ffn_out", il);
+                cur = build_ffn(
+                        cur,
+                        model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
+                        nullptr,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
 
                 cur = ggml_add(ctx0, cur, ffn_inp);
                 cb(cur, "ffn_out", il);

From 996195299e36565e5d07ba7bde470ca4822a2c10 Mon Sep 17 00:00:00 2001
From: Vaibhavs10
Date: Mon, 7 Jul 2025 23:42:40 +0200
Subject: [PATCH 6/6] up.
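This final patch drops the explicit layer list in favor of a single `no_rope_layer_interval` hyperparameter: in the diff below, RoPE is applied when `interval == 0 || il % interval != 0`, i.e. it is skipped on layers 0, interval, 2*interval, and so on. A standalone sketch of that rule with an illustrative interval value:

```cpp
// Standalone sketch of interval-based NoPE selection (interval value illustrative).
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t interval = 4;  // hypothetical no_rope_layer_interval
    const int n_layer = 12;

    for (int il = 0; il < n_layer; ++il) {
        const bool apply_rope = (interval == 0 || il % interval != 0);
        std::printf("layer %2d: %s\n", il, apply_rope ? "RoPE" : "NoPE");
    }
    return 0;
}
```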
---
 convert_hf_to_gguf.py | 8 +++-----
 src/llama-hparams.h   | 3 +++
 src/llama-model.cpp   | 18 ++++++------------
 3 files changed, 12 insertions(+), 17 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index ce1500d699b9e..1a6eb52485a4c 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6305,11 +6305,9 @@ class SmolLM3Model(LlamaModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
-        # if self.model.config.no_rope_layers is not None:
-        #     self.gguf_writer.add_array("smollm3.no_rope_layers", self.model.config.no_rope_layers, gguf.GGUFValueType.INT32)
-        no_rope_layers = self.hparams.get("no_rope_layers")
-        if no_rope_layers is not None:
-            self.gguf_writer.add_array("smollm3.no_rope_layers", no_rope_layers)
+        no_rope_layer_interval = self.hparams.get("no_rope_layer_interval")
+        if no_rope_layer_interval is not None:
+            self.gguf_writer.add_uint32("no_rope_layer_interval", no_rope_layer_interval)
 
 ###### CONVERSION LOGIC ######
 
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index b2bcb8b01a18b..3d5225dd474f2 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -186,6 +186,9 @@ struct llama_hparams {
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
 
+    // for NoPE interval
+    uint32_t no_rope_layer_interval = 0;
+
     bool is_swa(uint32_t il) const;
 };
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index c0be7e6932e96..b2e7668311c12 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -443,6 +443,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         return;
     }
 
+    if (arch == LLM_ARCH_SMOLLM3) {
+        ml.get_key("no_rope_layer_interval", hparams.no_rope_layer_interval);
+    }
+
     ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
     ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
     ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
@@ -13740,17 +13744,7 @@ struct llm_build_smollm3 : public llm_graph_context {
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
         GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-        // collect layers for which RoPE is disabled (metadata key: "smollm3.no_rope_layers")
-        std::vector<int32_t> no_rope_layers;
-        if (arch == LLM_ARCH_SMOLLM3) {
-            const int kid = gguf_find_key(model.meta, "smollm3.no_rope_layers");
-            if (kid != -1) {
-                const uint32_t n = gguf_get_arr_n(model.meta, kid);
-                no_rope_layers.resize(n);
-                const int nb = gguf_get_arr_data(model.meta, kid, no_rope_layers.data(), n * sizeof(int32_t));
-                GGML_ASSERT(nb == int(n * sizeof(int32_t)));
-            }
-        }
+        const uint32_t interval = hparams.no_rope_layer_interval;
 
         // token embeddings
         ggml_tensor * inpL = build_inp_embd(model.tok_embd);
@@ -13793,7 +13787,7 @@ struct llm_build_smollm3 : public llm_graph_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                if (std::find(no_rope_layers.begin(), no_rope_layers.end(), il) == no_rope_layers.end()) {
+                if (interval == 0 || il % interval != 0) {
                     ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
                     Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,