From c912c6744902efb8e166539f823dfe225e23948f Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 2 May 2025 11:51:16 +0200
Subject: [PATCH 01/30] wip llama 4 conversion

---
 convert_hf_to_gguf.py          | 26 ++++++++++++++++++++++++++
 gguf-py/gguf/constants.py      | 23 +++++++++++++++--------
 gguf-py/gguf/tensor_mapping.py | 19 +++++++++++++++++--
 3 files changed, 58 insertions(+), 10 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index df3f8a55d5320..cb18624fef1f2 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -2044,6 +2044,32 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Llama4ForConditionalGeneration")
+class Llama4VisionModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this
+        self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"]
+        self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
+        self.gguf_writer.add_vision_projector_scale_factor((1.0 / self.hparams["pixel_shuffle_ratio"]) // 1)
+        assert self.hparams["hidden_act"] == "gelu"
+        self.gguf_writer.add_vision_use_gelu(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if "multi_modal_projector" in name or "vision_model" in name:
+            # process vision tensors
+            if "positional_embedding_vlm" in name:
+                name += ".weight"
+        return []
+
+
 @ModelBase.register("Mistral3ForConditionalGeneration")
 class Mistral3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a2540bd93fd91..a2a9af2429f0f 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -481,15 +481,17 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_EMBD_CLS = auto()
     V_ENC_EMBD_PATCH = auto()
     V_ENC_EMBD_POS = auto()
+    V_ENC_INPUT_NORM = auto()
     V_ENC_ATTN_Q = auto()
     V_ENC_ATTN_K = auto()
     V_ENC_ATTN_V = auto()
-    V_ENC_INPUT_NORM = auto()
-    V_ENC_OUTPUT = auto()
-    V_ENC_OUTPUT_NORM = auto()
+    V_ENC_ATTN_O = auto()
+    V_ENC_ATTN_O_NORM = auto()
+    V_ENC_POST_ATTN_NORM = auto()
     V_ENC_FFN_UP = auto()
     V_ENC_FFN_GATE = auto()
     V_ENC_FFN_DOWN = auto()
+    V_ENC_FFN_POST_NORM = auto()
     V_PRE_NORM = auto()
     V_POST_NORM = auto()
     V_MM_INP_NORM = auto()
@@ -742,11 +744,13 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
     MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v",
     MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1",
-    MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out",
-    MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.blk.{bid}.ln2",
+    MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out",
+    MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm",
+    MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2",
     MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
     MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
     MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
+    MODEL_TENSOR.V_ENC_FFN_POST_NORM: "v.blk.{bid}.ffn_post_norm",
    MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM: "v.post_ln",
     MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
@@ -776,15 +780,17 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_ENC_EMBD_CLS,
     MODEL_TENSOR.V_ENC_EMBD_PATCH,
     MODEL_TENSOR.V_ENC_EMBD_POS,
+    MODEL_TENSOR.V_ENC_INPUT_NORM,
     MODEL_TENSOR.V_ENC_ATTN_Q,
     MODEL_TENSOR.V_ENC_ATTN_K,
     MODEL_TENSOR.V_ENC_ATTN_V,
-    MODEL_TENSOR.V_ENC_INPUT_NORM,
-    MODEL_TENSOR.V_ENC_OUTPUT,
-    MODEL_TENSOR.V_ENC_OUTPUT_NORM,
+    MODEL_TENSOR.V_ENC_ATTN_O,
+    MODEL_TENSOR.V_ENC_ATTN_O_NORM,
+    MODEL_TENSOR.V_ENC_POST_ATTN_NORM,
     MODEL_TENSOR.V_ENC_FFN_UP,
     MODEL_TENSOR.V_ENC_FFN_GATE,
     MODEL_TENSOR.V_ENC_FFN_DOWN,
+    MODEL_TENSOR.V_ENC_FFN_POST_NORM,
     MODEL_TENSOR.V_PRE_NORM,
     MODEL_TENSOR.V_POST_NORM,
     MODEL_TENSOR.V_MM_INP_PROJ,
@@ -2162,6 +2168,7 @@ class VisionProjectorType:
     GEMMA3 = "gemma3"
     IDEFICS3 = "idefics3"
     PIXTRAL = "pixtral"
+    LLAMA4 = "llama4"
 
 
 # Items here are (block size, type size)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 2f6326104ffa7..d3e4edfcc8890 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -900,10 +900,12 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ_FC: (
             "model.connector.modality_projection.proj", # SmolVLM
+            "multi_modal_projector.linear_1", # llama 4
         ),
 
         MODEL_TENSOR.V_MMPROJ_MLP: (
             "model.mm_projector.mlp.mlp.{bid}",
+            "vision_model.vision_adapter.mlp.fc{bid}.weight", # llama 4
         ),
 
         MODEL_TENSOR.V_MMPROJ_PEG: (
@@ -912,6 +914,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_ENC_EMBD_CLS: (
             "vision_tower.vision_model.embeddings.class_embedding",
+            "vision_model.class_embedding", # llama 4
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_PATCH: (
@@ -919,18 +922,21 @@ class TensorNameMap:
             "vpm.embeddings.patch_embedding",
             "model.vision_model.embeddings.patch_embedding", # SmolVLM
             "vision_tower.patch_conv", # pixtral
+            "vision_model.patch_embedding.linear", # llama 4
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_POS: (
             "vision_tower.vision_model.embeddings.position_embedding",
             "vpm.embeddings.position_embedding",
             "model.vision_model.embeddings.position_embedding", # SmolVLM
+            "vision_model.positional_embedding_vlm", # llama 4
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
             "vpm.encoder.layers.{bid}.self_attn.q_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
+            "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
             "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
         ),
 
@@ -938,6 +944,7 @@ class TensorNameMap:
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
             "vpm.encoder.layers.{bid}.self_attn.k_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
+            "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
             "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
         ),
 
@@ -945,6 +952,7 @@ class TensorNameMap:
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
             "vpm.encoder.layers.{bid}.self_attn.v_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
+            "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
             "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
         ),
 
@@ -953,19 +961,22 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.layer_norm1",
             "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
+            "vision_model.model.layers.{bid}.input_layernorm", # llama4
         ),
 
-        MODEL_TENSOR.V_ENC_OUTPUT: (
+        MODEL_TENSOR.V_ENC_ATTN_O: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
             "vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4 "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral ), - MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", "vpm.encoder.layers.{bid}.layer_norm2", "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM + "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4 "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral ), @@ -974,6 +985,7 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.mlp.fc1", "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped) "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc1", # llama4 ), MODEL_TENSOR.V_ENC_FFN_GATE: ( @@ -985,16 +997,19 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.mlp.fc2", "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped) "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc2", # llama4 ), MODEL_TENSOR.V_PRE_NORM: ( "vision_tower.vision_model.pre_layrnorm", "vision_tower.ln_pre", # pixtral + "vision_model.layernorm_pre", # llama4 ), MODEL_TENSOR.V_POST_NORM: ( "vision_tower.vision_model.post_layernorm", "model.vision_model.post_layernorm", # SmolVLM + "vision_model.layernorm_post", # llama4 ), MODEL_TENSOR.V_MM_INP_PROJ: ( From a67a1bed62811f10c70f19b2dd83c675c6e2f3d1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 2 May 2025 11:57:37 +0200 Subject: [PATCH 02/30] rm redundant __init__ --- convert_hf_to_gguf.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cb18624fef1f2..66e4197156d86 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2046,12 +2046,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): @ModelBase.register("Llama4ForConditionalGeneration") class Llama4VisionModel(VisionModel): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this - self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"] - self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"] - def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4) From 10db58b53ffa6e7361a38f86807b094f8a2a814e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 2 May 2025 12:00:38 +0200 Subject: [PATCH 03/30] fix conversion --- convert_hf_to_gguf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 66e4197156d86..4afd529abc4ab 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2050,7 +2050,7 @@ def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4) self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"]) - self.gguf_writer.add_vision_projector_scale_factor((1.0 / self.hparams["pixel_shuffle_ratio"]) // 1) + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"])) assert self.hparams["hidden_act"] == "gelu" self.gguf_writer.add_vision_use_gelu(True) @@ 
-2060,6 +2060,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # process vision tensors if "positional_embedding_vlm" in name: name += ".weight" + return [(self.map_tensor_name(name), data_torch)] return [] From c50e627b48c3b1793a9ac908da1dd9082c5c4c43 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 2 May 2025 23:52:56 +0200 Subject: [PATCH 04/30] fix conversion --- convert_hf_to_gguf.py | 2 +- gguf-py/gguf/tensor_mapping.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d483c2a26e341..876ee644ee321 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2090,7 +2090,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter del bid # unused if "multi_modal_projector" in name or "vision_model" in name: # process vision tensors - if "positional_embedding_vlm" in name: + if "positional_embedding_vlm" in name and ".weight" not in name: name += ".weight" return [(self.map_tensor_name(name), data_torch)] return [] diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 249ab62a1ba8b..4e2e357901e67 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -906,7 +906,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ_MLP: ( "model.mm_projector.mlp.mlp.{bid}", - "vision_model.vision_adapter.mlp.fc{bid}.weight", # llama 4 + "vision_model.vision_adapter.mlp.fc{bid}", # llama 4 ), MODEL_TENSOR.V_MMPROJ_PEG: ( From 8775bc4ed6ff2716bea4407c25e6710f09283c50 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 3 May 2025 00:09:36 +0200 Subject: [PATCH 05/30] test impl --- tools/llava/clip-impl.h | 3 + tools/llava/clip.cpp | 235 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 217 insertions(+), 21 deletions(-) diff --git a/tools/llava/clip-impl.h b/tools/llava/clip-impl.h index b575ca4d7c2a9..3d1cf4e8bb3ca 100644 --- a/tools/llava/clip-impl.h +++ b/tools/llava/clip-impl.h @@ -60,6 +60,7 @@ #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" +#define TN_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s" #define TN_LN_1 "%s.blk.%d.ln1.%s" #define TN_LN_2 "%s.blk.%d.ln2.%s" #define TN_LN_PRE "%s.pre_ln.%s" @@ -103,6 +104,7 @@ enum projector_type { PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_QWEN25VL, + PROJECTOR_TYPE_LLAMA4, PROJECTOR_TYPE_UNKNOWN, }; @@ -117,6 +119,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_GEMMA3, "gemma3"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, + { PROJECTOR_TYPE_LLAMA4, "llama4"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/llava/clip.cpp b/tools/llava/clip.cpp index 7607d4e3ae3a4..9c21b22bfc242 100644 --- a/tools/llava/clip.cpp +++ b/tools/llava/clip.cpp @@ -176,6 +176,10 @@ struct clip_hparams { }; struct clip_layer { + // layernorm 1 (input norm) + struct ggml_tensor * ln_1_w = nullptr; + struct ggml_tensor * ln_1_b = nullptr; + // attention struct ggml_tensor * k_w = nullptr; struct ggml_tensor * k_b = nullptr; @@ -187,29 +191,28 @@ struct clip_layer { struct ggml_tensor * o_w = nullptr; struct ggml_tensor * o_b = nullptr; - // layernorm 1 - struct ggml_tensor * ln_1_w = nullptr; - struct ggml_tensor * ln_1_b = nullptr; + // layernorm 2 (post-attn norm / pre-ffn norm) + struct ggml_tensor * ln_2_w = nullptr; + struct ggml_tensor * ln_2_b = nullptr; // ff 
     struct ggml_tensor * ff_i_w = nullptr; // legacy naming
     struct ggml_tensor * ff_i_b = nullptr; // legacy naming
     struct ggml_tensor * ff_o_w = nullptr; // legacy naming
     struct ggml_tensor * ff_o_b = nullptr; // legacy naming
+    struct ggml_tensor * ff_g_w = nullptr; // legacy naming
+    struct ggml_tensor * ff_g_b = nullptr; // legacy naming
 
-    struct ggml_tensor * ff_up_w = nullptr;
-    struct ggml_tensor * ff_up_b = nullptr;
+    struct ggml_tensor * ff_up_w   = nullptr;
+    struct ggml_tensor * ff_up_b   = nullptr;
     struct ggml_tensor * ff_gate_w = nullptr;
     struct ggml_tensor * ff_gate_b = nullptr;
     struct ggml_tensor * ff_down_w = nullptr;
     struct ggml_tensor * ff_down_b = nullptr;
 
-    struct ggml_tensor * ff_g_w = NULL;
-    struct ggml_tensor * ff_g_b = NULL;
-
-    // layernorm 2
-    struct ggml_tensor * ln_2_w = nullptr;
-    struct ggml_tensor * ln_2_b = nullptr;
+    // post-ffn norm (output layer norm)
+    struct ggml_tensor * post_ffn_norm_w = nullptr;
+    struct ggml_tensor * post_ffn_norm_b = nullptr;
 };
 
 struct clip_vision_model {
@@ -560,9 +563,10 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
 static ggml_tensor * build_rope_2d(
     ggml_context * ctx0,
     ggml_tensor * cur,
-    ggml_tensor * pos_h,
-    ggml_tensor * pos_w,
-    const float freq_base
+    ggml_tensor * pos_a, // first half
+    ggml_tensor * pos_b, // second half
+    const float freq_base,
+    const bool interleave_freq
 ) {
     const int64_t n_dim = cur->ne[0];
     const int64_t n_head = cur->ne[1];
@@ -576,7 +580,9 @@ static ggml_tensor * build_rope_2d(
     // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
     // then for the second half, we use freq_scale to shift the inv_freq
     // ^ why? replace (2i) with (2i+1) in the above equation
-    const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
+    const float freq_scale_odd = interleave_freq
+        ? std::pow(freq_base, (float)-2/n_dim)
+        : 1.0;
 
     // first half
     ggml_tensor * first;
@@ -589,7 +595,7 @@ static ggml_tensor * build_rope_2d(
         first = ggml_rope_ext(
             ctx0,
             first,
-            pos_h,      // positions
+            pos_a,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
             0, 0, freq_base,
@@ -609,7 +615,7 @@ static ggml_tensor * build_rope_2d(
         second = ggml_rope_ext(
             ctx0,
             second,
-            pos_w,      // positions
+            pos_b,      // positions
             nullptr,    // freq factors
             n_dim/2,    // n_dims
             0, 0, freq_base,
@@ -687,13 +693,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
 
         Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
-        Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
+        Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta, true);
         Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
 
         struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
 
         K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
-        K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
+        K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta, true);
         K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
 
         struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
@@ -809,6 +815,174 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     return gf;
 }
 
+static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_image_f32 & img) {
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int patch_size = hparams.patch_size;
+    const int num_patches = ((img.nx / patch_size) * (img.ny / patch_size));
+    const int hidden_size = hparams.hidden_size;
+    const int n_head = hparams.n_head;
+    const int d_head = hidden_size / n_head;
+    const int n_layer = hparams.n_layer;
+    const float eps = hparams.eps;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx0_ptr(ggml_init(params));
+    auto ctx0 = ctx0_ptr.get();
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // input raw
+    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    // 2D input positions
+    struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+    ggml_set_name(pos_h, "pos_h");
+    ggml_set_input(pos_h);
+    struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+    ggml_set_name(pos_w, "pos_w");
+    ggml_set_input(pos_w);
+
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+    inp = ggml_add(ctx0, inp, model.patch_bias);
+
+    // position embeddings
+    struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
+
+    // loop over layers
+    for (int il = 0; il < n_layer; il++) {
+        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+        // layernorm1
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
+        }
+
+        // self-attention
+        {
+
+            struct ggml_tensor * Q =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+            Q = build_rope_2d(ctx0, Q, pos_w, pos_h, hparams.rope_theta, false);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+
+            struct ggml_tensor * K =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+            K = build_rope_2d(ctx0, K, pos_w, pos_h, hparams.rope_theta, false);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+
+            struct ggml_tensor * V =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+        // re-add the layer input, i.e. the residual
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // layernorm2
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+        if (ctx->use_silu) {
+            cur = ggml_silu(ctx0, cur);
+        } else if (ctx->use_gelu) {
+            cur = ggml_gelu(ctx0, cur);
+        } else {
+            GGML_ABORT("llama4: Unsupported activation");
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        // norm output
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].post_ffn_norm_w), model.layers[il].post_ffn_norm_b);
+        }
+
+        embeddings = cur;
+    }
+
+    // post-layernorm
+    if (model.post_ln_w) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+    }
+
+    // Llama4VisionPixelShuffleMLP
+    {
+        ggml_tensor * cur = embeddings;
+        const int scale_factor = model.hparams.proj_scale_factor;
+        const int n_embd = cur->ne[0];
+        const int seq = cur->ne[1];
+        const int bsz = 1; // batch size, always 1 for now since we don't support batching
+        const int height = std::sqrt(seq);
+        const int width = std::sqrt(seq);
+        GGML_ASSERT(scale_factor != 0);
+        cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+        cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+            n_embd * scale_factor * scale_factor,
+            height / scale_factor,
+            width / scale_factor,
+            bsz);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+        cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
+            n_embd * scale_factor * scale_factor,
+            seq / (scale_factor * scale_factor),
+            bsz);
+
+        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        embeddings = cur;
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    return gf;
+}
+
 static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
@@ -1599,6 +1773,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         {
             res = clip_image_build_graph_qwen25vl(ctx, imgs);
         } break;
+    case PROJECTOR_TYPE_LLAMA4:
+        {
+            res = clip_image_build_graph_llama4(ctx, *imgs.entries[0]);
+        } break;
     default:
         {
             // TODO: we should have one build_* function per model
@@ -1781,6 +1959,10 @@ struct clip_model_loader {
                 {
                     get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
                 } break;
+            case PROJECTOR_TYPE_LLAMA4:
+                {
+                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
+                } break;
             default:
                 break;
         }
@@ -1867,6 +2049,9 @@ struct clip_model_loader {
             layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
             layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
 
+            layer.post_ffn_norm_b = get_tensor(string_format(TN_FFN_POST_NORM, "v", il, "bias"), false);
+            layer.post_ffn_norm_w = get_tensor(string_format(TN_FFN_POST_NORM, "v", il, "weight"), false);
+
             // new naming
             layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
             layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
@@ -2008,6 +2193,12 @@ struct clip_model_loader {
                     vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
                     vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
+            case PROJECTOR_TYPE_LLAMA4:
+                {
+                    vision_model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
+                    vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -2796,7 +2987,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
     }
     else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
             || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
-            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
+            || ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
         clip_image_u8 resized_image;
         int sz = params.image_size;
         image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
@@ -2968,7 +3160,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         n_patches = x_patch * y_patch;
     } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         n_patches = 256;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
         n_patches /= ctx->vision_model.hparams.proj_scale_factor;
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
         int n_merge = ctx->vision_model.hparams.spatial_merge_size;
@@ -3550,6 +3742,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_LLAMA4:
             return ctx->vision_model.projection->ne[1];
         default:
             GGML_ABORT("Unknown projector type");

From 7341e70995ca8ac18aeffc46744ac0ce69856112 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 3 May 2025 09:47:29 +0200
Subject: [PATCH 06/30] try this

---
 tools/llava/clip.cpp | 37 ++++++++++++++++++++++++++++++-------
 1 file changed, 30 insertions(+), 7 deletions(-)

diff --git a/tools/llava/clip.cpp b/tools/llava/clip.cpp
index 9c21b22bfc242..241e079befd6d 100644
--- a/tools/llava/clip.cpp
+++ b/tools/llava/clip.cpp
@@ -821,6 +821,7 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
 
     const int patch_size = hparams.patch_size;
     const int num_patches = ((img.nx / patch_size) * (img.ny / patch_size));
+    const int num_pos = num_patches + 1; // +1 for [CLS]
     const int hidden_size = hparams.hidden_size;
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
@@ -843,19 +844,23 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
     ggml_set_name(inp_raw, "inp_raw");
     ggml_set_input(inp_raw);
 
+    // create patches
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+    inp = ggml_add(ctx0, inp, model.patch_bias);
+
+    // add CLS
+    inp_raw = ggml_concat(ctx0, inp_raw, model.class_embedding, 0);
+
     // 2D input positions
-    struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+    struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos);
     ggml_set_name(pos_h, "pos_h");
     ggml_set_input(pos_h);
-    struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+    struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos);
     ggml_set_name(pos_w, "pos_w");
     ggml_set_input(pos_w);
 
-    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
-    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
-    inp = ggml_add(ctx0, inp, model.patch_bias);
-
     // position embeddings
     struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
 
@@ -1961,6 +1966,7 @@ struct clip_model_loader {
             } break;
         case PROJECTOR_TYPE_LLAMA4:
             {
+                hparams.rope_theta = 10000.0f;
                 get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
             } break;
         default:
@@ -3558,6 +3564,23 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             {
                 // do nothing
             } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                // set the 2D positions
+                int n_patches_per_col = image_size_width / patch_size;
+                std::vector<int32_t> pos_data(num_patches + 1, 0); // +1 for the [CLS] token
+                // last pos is always kept 0, it's for CLS
+                // dimension H
+                for (int i = 0; i < num_patches; i++) {
+                    pos_data[i] = i / n_patches_per_col;
+                }
+                set_input_i32("pos_h", pos_data);
+                // dimension W
+                for (int i = 0; i < num_patches; i++) {
+                    pos_data[i] = i % n_patches_per_col;
+                }
+                set_input_i32("pos_w", pos_data);
+            } break;
         default:
             GGML_ABORT("Unknown projector type");
     }

From 893ad9c23ad9f4a668a233d69c3e73952082099f Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 3 May 2025 09:54:04 +0200
Subject: [PATCH 07/30] reshape patch_embeddings_0

---
 tools/llava/clip.cpp | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/tools/llava/clip.cpp b/tools/llava/clip.cpp
index 241e079befd6d..e026fdb8b66fe 100644
--- a/tools/llava/clip.cpp
+++ b/tools/llava/clip.cpp
@@ -828,7 +828,7 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
     const int n_layer = hparams.n_layer;
     const float eps = hparams.eps;
 
-    struct ggml_init_params params = {
+    ggml_init_params params = {
         /*.mem_size   =*/ ctx->buf_compute_meta.size(),
         /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
         /*.no_alloc   =*/ true,
@@ -837,15 +837,20 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im
ggml_context_ptr ctx0_ptr(ggml_init(params)); auto ctx0 = ctx0_ptr.get(); - struct ggml_cgraph * gf = ggml_new_graph(ctx0); + ggml_cgraph * gf = ggml_new_graph(ctx0); // input raw - struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3); + ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3); ggml_set_name(inp_raw, "inp_raw"); ggml_set_input(inp_raw); // create patches - struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + ggml_tensor * patch_embd_view = ggml_view_4d(ctx0, model.patch_embeddings_0, + hidden_size, patch_size, patch_size, 3, + ggml_row_size(model.patch_embeddings_0->type, hidden_size), + ggml_row_size(model.patch_embeddings_0->type, hidden_size * patch_size), + ggml_row_size(model.patch_embeddings_0->type, hidden_size * patch_size * 3), 0); + ggml_tensor * inp = ggml_conv_2d(ctx0, patch_embd_view, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size); inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); inp = ggml_add(ctx0, inp, model.patch_bias); @@ -854,19 +859,19 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im inp_raw = ggml_concat(ctx0, inp_raw, model.class_embedding, 0); // 2D input positions - struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos); + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos); ggml_set_name(pos_h, "pos_h"); ggml_set_input(pos_h); - struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos); + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos); ggml_set_name(pos_w, "pos_w"); ggml_set_input(pos_w); // position embeddings - struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings); + ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings); // loop over layers for (int il = 0; il < n_layer; il++) { - struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states // layernorm1 { @@ -877,30 +882,30 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im // self-attention { - struct ggml_tensor * Q = + ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches); Q = build_rope_2d(ctx0, Q, pos_w, pos_h, hparams.rope_theta, false); Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - struct ggml_tensor * K = + ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches); K = build_rope_2d(ctx0, K, pos_w, pos_h, hparams.rope_theta, false); K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - struct ggml_tensor * V = + ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b); V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches); V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head); KQV = 
ggml_permute(ctx0, KQV, 0, 2, 1, 3); From 15605e4635dac509fb80d6c0b99f28592d218641 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 3 May 2025 09:55:22 +0200 Subject: [PATCH 08/30] fix view --- tools/llava/clip.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/llava/clip.cpp b/tools/llava/clip.cpp index e026fdb8b66fe..e503c6f9fbabd 100644 --- a/tools/llava/clip.cpp +++ b/tools/llava/clip.cpp @@ -846,10 +846,10 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im // create patches ggml_tensor * patch_embd_view = ggml_view_4d(ctx0, model.patch_embeddings_0, - hidden_size, patch_size, patch_size, 3, - ggml_row_size(model.patch_embeddings_0->type, hidden_size), - ggml_row_size(model.patch_embeddings_0->type, hidden_size * patch_size), - ggml_row_size(model.patch_embeddings_0->type, hidden_size * patch_size * 3), 0); + patch_size, patch_size, 3, hidden_size, + ggml_row_size(model.patch_embeddings_0->type, patch_size), + ggml_row_size(model.patch_embeddings_0->type, patch_size * patch_size), + ggml_row_size(model.patch_embeddings_0->type, patch_size * patch_size * 3), 0); ggml_tensor * inp = ggml_conv_2d(ctx0, patch_embd_view, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size); inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); From 32a62d1fc053b720c33cc4f23a4b4b1def630af5 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 3 May 2025 10:02:07 +0200 Subject: [PATCH 09/30] rm ffn_post_norm --- gguf-py/gguf/constants.py | 3 --- tools/llava/clip-impl.h | 3 +-- tools/llava/clip.cpp | 15 +-------------- 3 files changed, 2 insertions(+), 19 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 243d4aa5b1af6..947c6e019b6d4 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -492,7 +492,6 @@ class MODEL_TENSOR(IntEnum): V_ENC_FFN_UP = auto() V_ENC_FFN_GATE = auto() V_ENC_FFN_DOWN = auto() - V_ENC_FFN_POST_NORM = auto() V_PRE_NORM = auto() V_POST_NORM = auto() V_MM_INP_NORM = auto() @@ -751,7 +750,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", - MODEL_TENSOR.V_ENC_FFN_POST_NORM: "v.blk.{bid}.ffn_post_norm", MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", MODEL_TENSOR.V_POST_NORM: "v.post_ln", MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", @@ -791,7 +789,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_ENC_FFN_UP, MODEL_TENSOR.V_ENC_FFN_GATE, MODEL_TENSOR.V_ENC_FFN_DOWN, - MODEL_TENSOR.V_ENC_FFN_POST_NORM, MODEL_TENSOR.V_PRE_NORM, MODEL_TENSOR.V_POST_NORM, MODEL_TENSOR.V_MM_INP_PROJ, diff --git a/tools/llava/clip-impl.h b/tools/llava/clip-impl.h index 3d1cf4e8bb3ca..ae4bd3a4f798f 100644 --- a/tools/llava/clip-impl.h +++ b/tools/llava/clip-impl.h @@ -47,7 +47,7 @@ // tensor name constants // -#define TN_POS_EMBD "%s.position_embd.weight" +#define TN_POS_EMBD "v.position_embd.weight" #define TN_CLASS_EMBD "v.class_embd" #define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat #define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" @@ -60,7 +60,6 @@ #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" -#define TN_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s" #define TN_LN_1 "%s.blk.%d.ln1.%s" #define TN_LN_2 "%s.blk.%d.ln2.%s" #define TN_LN_PRE "%s.pre_ln.%s" diff --git 
a/tools/llava/clip.cpp b/tools/llava/clip.cpp index e503c6f9fbabd..144bbb1206c80 100644 --- a/tools/llava/clip.cpp +++ b/tools/llava/clip.cpp @@ -209,10 +209,6 @@ struct clip_layer { struct ggml_tensor * ff_gate_b = nullptr; struct ggml_tensor * ff_down_w = nullptr; struct ggml_tensor * ff_down_b = nullptr; - - // post-ffn norm (output layer norm) - struct ggml_tensor * post_ffn_norm_w = nullptr; - struct ggml_tensor * post_ffn_norm_b = nullptr; }; struct clip_vision_model { @@ -943,12 +939,6 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im // residual 2 cur = ggml_add(ctx0, embeddings, cur); - // norm output - { - cur = ggml_norm(ctx0, cur, eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].post_ffn_norm_w), model.layers[il].post_ffn_norm_b); - } - embeddings = cur; } @@ -2041,7 +2031,7 @@ struct clip_model_loader { vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); - vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false); + vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false); // layers vision_model.layers.resize(vision_model.hparams.n_layer); @@ -2060,9 +2050,6 @@ struct clip_model_loader { layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false); layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false); - layer.post_ffn_norm_b = get_tensor(string_format(TN_FFN_POST_NORM, "v", il, "bias"), false); - layer.post_ffn_norm_w = get_tensor(string_format(TN_FFN_POST_NORM, "v", il, "weight"), false); - // new naming layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight")); layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false); From 97a5cd13dfa8c59d786c58b5ebc3df1f2894bf5e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 3 May 2025 10:31:50 +0200 Subject: [PATCH 10/30] cgraph ok --- tools/llava/clip.cpp | 77 ++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/tools/llava/clip.cpp b/tools/llava/clip.cpp index 144bbb1206c80..f834fef95e919 100644 --- a/tools/llava/clip.cpp +++ b/tools/llava/clip.cpp @@ -816,7 +816,9 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im const auto & hparams = model.hparams; const int patch_size = hparams.patch_size; - const int num_patches = ((img.nx / patch_size) * (img.ny / patch_size)); + const int px = img.nx / patch_size; + const int py = img.ny / patch_size; + const int num_patches = px * py; const int num_pos = num_patches + 1; // +1 for [CLS] const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; @@ -849,10 +851,9 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im ggml_tensor * inp = ggml_conv_2d(ctx0, patch_embd_view, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size); inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); - inp = ggml_add(ctx0, inp, model.patch_bias); // add CLS - inp_raw = ggml_concat(ctx0, inp_raw, model.class_embedding, 0); + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); // 2D input positions ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_pos); @@ -881,31 +882,31 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); - Q = 
ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches); + Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_pos); Q = build_rope_2d(ctx0, Q, pos_w, pos_h, hparams.rope_theta, false); Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); - K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches); + K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_pos); K = build_rope_2d(ctx0, K, pos_w, pos_h, hparams.rope_theta, false); K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b); - V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches); + V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_pos); V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head); + KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_pos, n_head); KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches); + cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_pos); } // attention output @@ -922,8 +923,8 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); } - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); + cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b); if (ctx->use_silu) { cur = ggml_silu(ctx0, cur); @@ -933,8 +934,8 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im GGML_ABORT("llama4: Unsupported activation"); } - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); + cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b); // residual 2 cur = ggml_add(ctx0, embeddings, cur); @@ -950,33 +951,43 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); } - // Llama4VisionPixelShuffleMLP + // based on Llama4VisionPixelShuffleMLP + // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151 { ggml_tensor * cur = embeddings; + const int batch_size = 1; // always 1 for now since we don't support batching const int scale_factor = model.hparams.proj_scale_factor; - const int n_embd = cur->ne[0]; - const int seq = cur->ne[1]; - const int bsz = 1; // batch size, always 1 for now since we don't support batching - const int height = std::sqrt(seq); - const int width = std::sqrt(seq); GGML_ASSERT(scale_factor != 0); - cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz); + + // remove CLS by doing a view + cur = ggml_view_3d(ctx0, cur, + hidden_size, num_patches, batch_size, + ggml_row_size(cur->type, hidden_size), + ggml_row_size(cur->type, hidden_size * num_patches), 0); + + cur = ggml_reshape_3d(ctx0, cur, + hidden_size * scale_factor, + num_patches / scale_factor, + batch_size); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = 
ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), - n_embd * scale_factor * scale_factor, - height / scale_factor, - width / scale_factor, - bsz); + hidden_size * scale_factor * scale_factor, + py / scale_factor, + px / scale_factor, + batch_size); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur), - n_embd * scale_factor * scale_factor, - seq / (scale_factor * scale_factor), - bsz); - cur = ggml_mul_mat(ctx0, model.projection, cur); + // based on Llama4VisionMLP2 (always uses GELU activation, no bias) + cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur); embeddings = cur; } + // based on Llama4MultiModalProjector + embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings); + // build the graph ggml_build_forward_expand(gf, embeddings); @@ -3135,6 +3146,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im const auto & params = ctx->vision_model.hparams; int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); + int scale_factor = ctx->vision_model.hparams.proj_scale_factor; if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { n_patches /= 4; @@ -3158,8 +3170,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches = x_patch * y_patch; } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { n_patches = 256; - } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_LLAMA4) { - n_patches /= ctx->vision_model.hparams.proj_scale_factor; + } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { + n_patches /= scale_factor; + } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) { + n_patches /= (scale_factor * scale_factor); } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { int n_merge = ctx->vision_model.hparams.spatial_merge_size; int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? 
n_merge : 1); @@ -3757,8 +3771,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_GEMMA3: return ctx->vision_model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: - case PROJECTOR_TYPE_LLAMA4: return ctx->vision_model.projection->ne[1]; + case PROJECTOR_TYPE_LLAMA4: + return ctx->vision_model.mm_model_proj->ne[1]; default: GGML_ABORT("Unknown projector type"); } From c6c2d66fe8fe61cf5e0910004916b470db7de4f1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 3 May 2025 10:34:44 +0200 Subject: [PATCH 11/30] f32 for pos embd --- convert_hf_to_gguf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 876ee644ee321..1f168e8a129ba 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -308,6 +308,7 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED, gguf.MODEL_TENSOR.POSNET_NORM1, gguf.MODEL_TENSOR.POSNET_NORM2, + gguf.MODEL_TENSOR.V_ENC_EMBD_POS, ) ) or not new_name.endswith(".weight") From 224cbbade8200d63736f3b0fdd05ffd0f7fb05a7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 3 May 2025 10:36:51 +0200 Subject: [PATCH 12/30] add image marker tokens --- tools/llava/mtmd.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tools/llava/mtmd.cpp b/tools/llava/mtmd.cpp index d1d7530feb625..ffc4ea2ad7038 100644 --- a/tools/llava/mtmd.cpp +++ b/tools/llava/mtmd.cpp @@ -203,13 +203,17 @@ int32_t mtmd_tokenize(mtmd_context * ctx, // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md marker_modified = ctx->image_marker + "[IMG_END]"; string_replace_all(prompt_modified, ctx->image_marker, marker_modified); - } - else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) { + } else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) { // <|vision_start|> ... (image embeddings) ... <|vision_end|> marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>"; string_replace_all(prompt_modified, ctx->image_marker, marker_modified); + } else if (proj_type == PROJECTOR_TYPE_LLAMA4) { + // <|image_start|> ... (image embeddings) ... 
<|image_end|> + marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>"; + string_replace_all(prompt_modified, ctx->image_marker, marker_modified); + } // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix From 9d1a4d658de5aa8a07b47e2db745bce61f8099f6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 3 May 2025 12:44:28 +0200 Subject: [PATCH 13/30] Llama4UnfoldConvolution --- tools/llava/clip.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tools/llava/clip.cpp b/tools/llava/clip.cpp index f834fef95e919..58fa8dbe22445 100644 --- a/tools/llava/clip.cpp +++ b/tools/llava/clip.cpp @@ -842,15 +842,15 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im ggml_set_name(inp_raw, "inp_raw"); ggml_set_input(inp_raw); - // create patches - ggml_tensor * patch_embd_view = ggml_view_4d(ctx0, model.patch_embeddings_0, - patch_size, patch_size, 3, hidden_size, - ggml_row_size(model.patch_embeddings_0->type, patch_size), - ggml_row_size(model.patch_embeddings_0->type, patch_size * patch_size), - ggml_row_size(model.patch_embeddings_0->type, patch_size * patch_size * 3), 0); - ggml_tensor * inp = ggml_conv_2d(ctx0, patch_embd_view, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size); - inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + // Llama4UnfoldConvolution + ggml_tensor * inp; + { + ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0, + patch_size, patch_size, 3, hidden_size); + inp = ggml_im2col(ctx0, kernel, inp_raw, patch_size, patch_size, 0, 0, 1, 1, true, inp_raw->type); + inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); + inp = ggml_reshape_2d(ctx0, inp, hidden_size, num_patches); + } // add CLS inp = ggml_concat(ctx0, inp, model.class_embedding, 1); @@ -3578,12 +3578,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // last pos is always kept 0, it's for CLS // dimension H for (int i = 0; i < num_patches; i++) { - pos_data[i] = i / n_patches_per_col; + pos_data[i] = (i / n_patches_per_col) + 1; } set_input_i32("pos_h", pos_data); // dimension W for (int i = 0; i < num_patches; i++) { - pos_data[i] = i % n_patches_per_col; + pos_data[i] = (i % n_patches_per_col) + 1; } set_input_i32("pos_w", pos_data); } break; From 532c33210e22225bbf3c4c50ca689223fad3a34b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 3 May 2025 13:47:33 +0200 Subject: [PATCH 14/30] correct pixel shuffle --- tools/llava/clip.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tools/llava/clip.cpp b/tools/llava/clip.cpp index 58fa8dbe22445..01ee9a6091b55 100644 --- a/tools/llava/clip.cpp +++ b/tools/llava/clip.cpp @@ -965,23 +965,30 @@ static ggml_cgraph * clip_image_build_graph_llama4(clip_ctx * ctx, const clip_im ggml_row_size(cur->type, hidden_size), ggml_row_size(cur->type, hidden_size * num_patches), 0); - cur = ggml_reshape_3d(ctx0, cur, + cur = ggml_reshape_4d(ctx0, cur, hidden_size * scale_factor, - num_patches / scale_factor, + px / scale_factor, + py, batch_size); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), hidden_size * scale_factor * scale_factor, - py / scale_factor, px / scale_factor, + py / scale_factor, batch_size); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur), + hidden_size * scale_factor * scale_factor, + num_patches / 
scale_factor / scale_factor, + batch_size); + // based on Llama4VisionMLP2 (always uses GELU activation, no bias) cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur); cur = ggml_gelu(ctx0, cur); cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur); + cur = ggml_gelu(ctx0, cur); embeddings = cur; } From 8caeed5c3dfe16432edc25f6db4c3d48a94f24a6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 00:40:35 +0200 Subject: [PATCH 15/30] fix merge conflicts --- tools/mtmd/clip.cpp | 105 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 94 insertions(+), 11 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 23b2ab93bccf9..8b4c18c8d123c 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -194,10 +194,6 @@ struct clip_hparams { }; struct clip_layer { - // layernorm 1 (input norm) - struct ggml_tensor * ln_1_w = nullptr; - struct ggml_tensor * ln_1_b = nullptr; - // attention ggml_tensor * k_w = nullptr; ggml_tensor * k_b = nullptr; @@ -526,7 +522,7 @@ struct clip_graph { ggml_set_input(pos_w); auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { - return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta); + return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true); }; ggml_tensor * inp = build_inp(); @@ -940,6 +936,90 @@ struct clip_graph { return gf; } + ggml_cgraph * build_llama4() { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + const int n_pos = n_patches + 1; + + // 2D input positions + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + ggml_tensor * inp = build_inp_raw(); + + // Llama4UnfoldConvolution + { + inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + cb(inp, "patch_conv", -1); + inp = ggml_add(ctx0, inp, model.patch_bias); + cb(inp, "patch_bias", -1); + } + + // add CLS token + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + // build ViT with 2D position embeddings + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); + }; + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + model.position_embeddings, + add_pos); + + // remove CLS token + cur = ggml_view_2d(ctx0, cur, + n_embd, n_patches, + ggml_row_size(cur->type, n_embd), 0); + + // pixel shuffle + // based on Llama4VisionPixelShuffleMLP + // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151 + { + const int scale_factor = model.hparams.proj_scale_factor; + const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int height = n_patches_y; + const int width = n_patches_x; + GGML_ASSERT(scale_factor > 0); + GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images + cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + height / scale_factor, + width / scale_factor, + bsz); + cur = 
ggml_permute(ctx0, cur, 0, 2, 1, 3); + // flatten to 2D + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + cur->ne[1] * cur->ne[2]); + } + + // based on Llama4VisionMLP2 (always uses GELU activation, no bias) + { + cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur); + cur = ggml_gelu(ctx0, cur); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; + } + // this graph is used by llava, granite and glm // due to having embedding_stack (used by granite), we cannot reuse build_vit ggml_cgraph * build_llava() { @@ -1634,9 +1714,10 @@ struct clip_graph { static ggml_tensor * build_rope_2d( ggml_context * ctx0, ggml_tensor * cur, - ggml_tensor * pos_h, - ggml_tensor * pos_w, - const float freq_base + ggml_tensor * pos_a, // first half + ggml_tensor * pos_b, // second half + const float freq_base, + const bool interleave_freq ) { const int64_t n_dim = cur->ne[0]; const int64_t n_head = cur->ne[1]; @@ -1650,7 +1731,9 @@ struct clip_graph { // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2) // then for the second half, we use freq_scale to shift the inv_freq // ^ why? replace (2i) with (2i+1) in the above equation - const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim); + const float freq_scale_odd = interleave_freq + ? std::pow(freq_base, (float)-2/n_dim) + : 1.0; // first half ggml_tensor * first; @@ -1663,7 +1746,7 @@ struct clip_graph { first = ggml_rope_ext( ctx0, first, - pos_h, // positions + pos_a, // positions nullptr, // freq factors n_dim/2, // n_dims 0, 0, freq_base, @@ -1683,7 +1766,7 @@ struct clip_graph { second = ggml_rope_ext( ctx0, second, - pos_w, // positions + pos_b, // positions nullptr, // freq factors n_dim/2, // n_dims 0, 0, freq_base, From 2ffafd5d0ce68ec11bf64c7652fb71657c6565da Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 10:01:43 +0200 Subject: [PATCH 16/30] correct --- tools/mtmd/clip.cpp | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 8b4c18c8d123c..d302967027254 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -940,7 +940,7 @@ struct clip_graph { GGML_ASSERT(model.class_embedding != nullptr); GGML_ASSERT(model.position_embeddings != nullptr); - const int n_pos = n_patches + 1; + const int n_pos = n_patches + 1; // +1 for [CLS] // 2D input positions ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); @@ -955,17 +955,19 @@ struct clip_graph { // Llama4UnfoldConvolution { - inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); - inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); - cb(inp, "patch_conv", -1); - inp = ggml_add(ctx0, inp, model.patch_bias); - cb(inp, "patch_bias", -1); + ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0, + patch_size, patch_size, 3, n_embd); + inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type); + inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); } // add CLS token inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + // add position embeddings + inp = ggml_add(ctx0, inp, model.position_embeddings); + // build ViT with 2D position embeddings auto add_pos = [&](ggml_tensor * cur, const 
clip_layer &) { return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); @@ -988,22 +990,24 @@ struct clip_graph { { const int scale_factor = model.hparams.proj_scale_factor; const int bsz = 1; // batch size, always 1 for now since we don't support batching - const int height = n_patches_y; - const int width = n_patches_x; GGML_ASSERT(scale_factor > 0); GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images - cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); + cur = ggml_reshape_4d(ctx0, cur, + n_embd * scale_factor, + n_patches_x / scale_factor, + n_patches_y, + bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd * scale_factor * scale_factor, - height / scale_factor, - width / scale_factor, + n_patches_x / scale_factor, + n_patches_y / scale_factor, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // flatten to 2D cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), n_embd * scale_factor * scale_factor, - cur->ne[1] * cur->ne[2]); + n_patches / scale_factor / scale_factor); } // based on Llama4VisionMLP2 (always uses GELU activation, no bias) From 7d9d4e34655d0237053c2f851368ec7b237f1db6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 10:15:54 +0200 Subject: [PATCH 17/30] add debug_graph --- tools/mtmd/clip-impl.h | 65 ++++++++++++++++++++++++++++++++++++++++++ tools/mtmd/clip.cpp | 28 +++++++++++++++--- 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1a9e835158f79..82142e18b2eb7 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -360,6 +361,70 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { } } +// +// debugging +// + +static void print_tensor_shape(ggml_tensor * t) { + printf("%s.shape = [", t->name); + for (int i = 0; i < ggml_n_dims(t); ++i) { + printf("%" PRId64, t->ne[i]); + if (i < ggml_n_dims(t) - 1) { + printf(", "); + } + } + printf("]\n"); +} + +static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) { + ggml_type type = t->type; + int64_t * ne = t->ne; + size_t * nb = t->nb; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + printf("%s.data: [\n", t->name); + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + if (i2 == n && ne[2] > 2*n) { + printf(" ..., \n"); + i2 = ne[2] - n; + } + printf(" [\n"); + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + if (i1 == n && ne[1] > 2*n) { + printf(" ..., \n"); + i1 = ne[1] - n; + } + printf(" ["); + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + if (i0 == n && ne[0] > 2*n) { + printf("..., "); + i0 = ne[0] - n; + } + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + float v; + if (type == GGML_TYPE_F16) { + v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); + } else if (type == GGML_TYPE_F32) { + v = *(float *) &data[i]; + } else if (type == GGML_TYPE_I32) { + v = (float) *(int32_t *) &data[i]; + } else if (type == GGML_TYPE_I16) { + v = (float) *(int16_t *) &data[i]; + } else if (type == GGML_TYPE_I8) { + v = (float) *(int8_t *) &data[i]; + } else { + GGML_ABORT("fatal error"); + } + printf("%8.4f", v); + if (i0 < ne[0] - 1) printf(", "); + } + printf("],\n"); + } + printf(" ],\n"); + } + printf(" ]\n"); + } +} + // // API used internally with mtmd // diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d302967027254..e2d49197093dc 100644 --- a/tools/mtmd/clip.cpp +++ 
b/tools/mtmd/clip.cpp
@@ -361,7 +361,12 @@ struct clip_ctx {
 
     clip_image_size load_image_size;
 
+    // for debugging
+    bool debug_graph = false;
+    std::vector<ggml_tensor *> debug_print_tensors;
+
     clip_ctx(clip_context_params & ctx_params) {
+        debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
             throw std::runtime_error("failed to initialize CPU backend");
@@ -1404,10 +1409,12 @@ struct clip_graph {
     //
 
     void cb(ggml_tensor * cur, const char * name, int il) const {
-        // TODO: implement this
-        GGML_UNUSED(cur);
-        GGML_UNUSED(name);
-        GGML_UNUSED(il);
+        if (ctx->debug_graph) {
+            std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
+            ggml_set_name(cur, cur_name.c_str());
+            ggml_set_output(cur);
+            ctx->debug_print_tensors.push_back(cur);
+        }
     }
 
     // build vision transformer (ViT) cgraph
@@ -3357,6 +3364,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }
 
     // build the inference graph
+    ctx->debug_print_tensors.clear();
     ggml_backend_sched_reset(ctx->sched.get());
     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
     ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
@@ -3675,6 +3683,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         return false;
     }
 
+    // print debug nodes
+    if (ctx->debug_graph) {
+        LOG_INF("\n\n---\n\n");
+        LOG_INF("\n\nDebug graph:\n\n");
+        for (ggml_tensor * t : ctx->debug_print_tensors) {
+            std::vector<uint8_t> data(ggml_nbytes(t));
+            ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
+            print_tensor_shape(t);
+            print_tensor_data(t, data.data(), 3);
+        }
+    }
+
     // the last node is the embedding tensor
     ggml_tensor * embeddings = ggml_graph_node(gf, -1);

From 919318e3da17947f82720d81fe81abb4b4201881 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 18 May 2025 13:21:52 +0200
Subject: [PATCH 18/30] logits matched, but it still perceives the image incorrectly

---
 tools/mtmd/clip.cpp | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index e2d49197093dc..cab8fb9065c1d 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -445,7 +445,7 @@ struct clip_graph {
         };
         ctx0_ptr.reset(ggml_init(params));
         ctx0 = ctx0_ptr.get();
-        gf = ggml_new_graph(ctx0);
+        gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
     }
 
     ggml_cgraph * build_siglip() {
@@ -965,14 +965,12 @@ struct clip_graph {
             inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
             inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
             inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
+            cb(inp, "patch_conv", -1);
         }
 
         // add CLS token
         inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
 
-        // add position embeddings
-        inp = ggml_add(ctx0, inp, model.position_embeddings);
-
         // build ViT with 2D position embeddings
         auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
             return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
@@ -1013,6 +1011,7 @@ struct clip_graph {
             cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
                 n_embd * scale_factor * scale_factor,
                 n_patches / scale_factor / scale_factor);
+            cb(cur, "pixel_shuffle", -1);
         }
 
         // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
@@ -1021,8 +1020,13 @@ struct clip_graph {
             cur = ggml_gelu(ctx0, cur);
             cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
             cur = ggml_gelu(ctx0, cur);
+            cb(cur, "adapter_mlp", -1);
         }
 
+        // 
Llama4MultiModalProjector + cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); + cb(cur, "projected", -1); + // build the graph ggml_build_forward_expand(gf, cur); @@ -1408,11 +1412,13 @@ struct clip_graph { // utility functions // - void cb(ggml_tensor * cur, const char * name, int il) const { + void cb(ggml_tensor * cur0, const char * name, int il) const { if (ctx->debug_graph) { + ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0)); std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name; ggml_set_name(cur, cur_name.c_str()); ggml_set_output(cur); + ggml_build_forward_expand(gf, cur); ctx->debug_print_tensors.push_back(cur); } } From 5b81972e7a6c07eb92853d28486e7d3a9beab384 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 14:54:26 +0200 Subject: [PATCH 19/30] fix style --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 5b79d985d3102..15e019a10f253 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2113,7 +2113,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] - @ModelBase.register("Mistral3ForConditionalGeneration") class Mistral3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.LLAMA From b74122f0871dc8b46ceea3106bfe70b9fe57b7a0 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 15:02:10 +0200 Subject: [PATCH 20/30] add image_grid_pinpoints --- tools/mtmd/clip.cpp | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index cab8fb9065c1d..994165cf223f4 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2038,6 +2038,16 @@ struct clip_model_loader { { hparams.rope_theta = 10000.0f; get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor); + + // borrowed from llava-1.6 + const int psize = hparams.patch_size; + hparams.image_grid_pinpoints = { + psize, psize*2, // 336, 672 + psize*2, psize, // 672, 336 + psize*2, psize*2, // 672, 672 + psize*3, psize, // 1008, 336 + psize, psize*3, // 336, 1008 + }; } break; default: break; @@ -3091,8 +3101,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); res_imgs->entries.push_back(std::move(img_f32)); return true; - } - else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { + + } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { clip_image_u8 resized_image; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size); image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height); @@ -3100,6 +3110,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); res_imgs->entries.push_back(std::move(img_f32)); return true; + + } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) { + GGML_ASSERT(!params.image_grid_pinpoints.empty()); + auto const inst = llava_uhd::get_slice_instructions(ctx, original_size); + std::vector imgs = llava_uhd::slice_image(img, inst); + + for (size_t i = 0; i < imgs.size(); ++i) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); + } + + return true; + } // the logic below is to pad the 
shorter side to the longer side with a background color: rgb(122, 116, 104)

From 4217d424ce0ceb48b7ac63c9082ce8e658ccc9d8 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 18 May 2025 15:21:24 +0200
Subject: [PATCH 21/30] handle llama 4 preprocessing

---
 tools/mtmd/mtmd.cpp | 66 ++++++++++++++++++++++++++++++++++++---------
 1 file changed, 54 insertions(+), 12 deletions(-)

diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index f13dd7f5e6fcd..875d5db9fcac0 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -64,15 +64,18 @@ struct mtmd_context {
     int n_threads;
     std::string image_marker;
 
-    // for minicpmv, we need special tokens in-between slices
+    // for llava-uhd style models, we need special tokens in-between slices
     mtmd_slice_tmpl slice_tmpl    = MTMD_SLICE_TMPL_NONE;
     llama_token tok_ov_img_start  = LLAMA_TOKEN_NULL; // overview image
     llama_token tok_ov_img_end    = LLAMA_TOKEN_NULL; // overview image
     llama_token tok_slices_start  = LLAMA_TOKEN_NULL; // start of all slices
     llama_token tok_slices_end    = LLAMA_TOKEN_NULL; // end of all slices
-    llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice
-    llama_token tok_sli_img_end   = LLAMA_TOKEN_NULL; // single slice
+    llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start
+    llama_token tok_sli_img_end   = LLAMA_TOKEN_NULL; // single slice end
+    llama_token tok_sli_img_mid   = LLAMA_TOKEN_NULL; // between 2 slices
     llama_token tok_row_end       = LLAMA_TOKEN_NULL; // end of row
+    bool        tok_row_end_trail = false;
+    bool        ov_img_first      = false;
 
     bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
@@ -96,6 +99,7 @@ struct mtmd_context {
 
         use_mrope = clip_is_qwen2vl(ctx_clip);
 
+        projector_type proj = clip_get_projector_type(ctx_clip);
         int minicpmv_version = clip_is_minicpmv(ctx_clip);
         if (minicpmv_version == 2) {
             // minicpmv 2.5 format:
             // <image> (overview) </image><slice><image> (slice) </image><image> (slice) </image>\n ... </slice>
             tok_ov_img_start  = lookup_token("<image>");
             tok_ov_img_end    = lookup_token("</image>");
             tok_slices_start  = lookup_token("<slice>");
             tok_slices_end    = lookup_token("</slice>");
             tok_sli_img_start = tok_ov_img_start;
             tok_sli_img_end   = tok_ov_img_end;
             tok_row_end       = lookup_token("\n");
+            tok_row_end_trail = false; // no trailing end-of-row token
+            ov_img_first      = true;
 
         } else if (minicpmv_version == 3 || minicpmv_version == 4) {
             // minicpmv 2.6 format:
             // <image> (overview) </image><slice> (slice) </slice><slice> (slice) </slice>\n ...
             tok_ov_img_start  = lookup_token("<image>");
             tok_ov_img_end    = lookup_token("</image>");
             tok_sli_img_start = lookup_token("<slice>");
             tok_sli_img_end   = lookup_token("</slice>");
             tok_row_end       = lookup_token("\n");
+            tok_row_end_trail = false; // no trailing end-of-row token
+            ov_img_first      = true;
 
         } else if (minicpmv_version != 0) {
             GGML_ASSERT(false && "unsupported minicpmv version");
+        } else if (proj == PROJECTOR_TYPE_LLAMA4) {
+            // llama 4 format:
+            // <|image_start|>
+            //     (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
+            //     (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
+            //     ... <|tile_y_separator|>   <-- trailing end-of-row token
+            //     <|image|> (overview)       <-- overview image is last
+            // <|image_end|>
+            tok_ov_img_start  = lookup_token("<|image|>");
+            tok_sli_img_mid   = lookup_token("<|tile_x_separator|>");
+            tok_row_end       = lookup_token("<|tile_y_separator|>");
+            tok_row_end_trail = true;  // add trailing end-of-row token
+            ov_img_first      = false; // overview image is last
         }
     }
 
@@ -250,13 +271,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
 
         } else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
-            // <|image_start|> ... (image embeddings) ... 
<|image_end|> + // (more details in mtmd_context constructor) marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>"; string_replace_all(prompt_modified, ctx->image_marker, marker_modified); - } - - else if (proj_type == PROJECTOR_TYPE_INTERNVL) { + } else if (proj_type == PROJECTOR_TYPE_INTERNVL) { // ... (image embeddings) ... marker_modified = "" + ctx->image_marker + ""; string_replace_all(prompt_modified, ctx->image_marker, marker_modified); @@ -347,11 +366,19 @@ int32_t mtmd_tokenize(mtmd_context * ctx, auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id); GGML_ASSERT(chunks.size() > 0); - // add overview image - add_text_chunk({ctx->tok_ov_img_start}); - output->entries.emplace_back(std::move(chunks.front())); + auto ov_chunk = std::move(chunks.front()); chunks.erase(chunks.begin()); - add_text_chunk({ctx->tok_ov_img_end}); + + // add overview image (first) + if (ctx->ov_img_first) { + if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_ov_img_start}); + } + output->entries.emplace_back(std::move(ov_chunk)); + if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_ov_img_end}); + } + } // add slices if (!chunks.empty()) { @@ -364,6 +391,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, } for (int y = 0; y < n_row; y++) { for (int x = 0; x < n_col; x++) { + const bool is_last_in_row = (x == n_col - 1); if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_sli_img_start}); } @@ -371,8 +399,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx, if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_sli_img_end}); } + if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_sli_img_mid}); + } } - if (ctx->tok_row_end != LLAMA_TOKEN_NULL && y != n_row - 1) { + if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_row_end}); } } @@ -381,6 +412,17 @@ int32_t mtmd_tokenize(mtmd_context * ctx, } } + // add overview image (last) + if (!ctx->ov_img_first) { + if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_ov_img_start}); + } + output->entries.emplace_back(std::move(ov_chunk)); + if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_ov_img_end}); + } + } + } else { size_t n_tokens = 0; for (const auto & entry : batch_f32.entries) { From 3645fe0ba1ae14155b83a9b3e4224c5117fe5eda Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 15:29:01 +0200 Subject: [PATCH 22/30] rm load_image_size --- tools/mtmd/clip-impl.h | 5 +++++ tools/mtmd/clip.cpp | 29 +++++++++-------------------- tools/mtmd/clip.h | 4 ---- tools/mtmd/mtmd.cpp | 26 ++++++++++++-------------- 4 files changed, 26 insertions(+), 38 deletions(-) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 82142e18b2eb7..7b7d2df39622c 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -243,6 +243,11 @@ struct clip_image_u8_batch { struct clip_image_f32_batch { std::vector entries; + // for llava-uhd style models, we need to know the grid size + // note: entries.size() == grid_x * grid_y + 1 (one overview image) + int grid_x = 0; + int grid_y = 0; + clip_image_f32_batch clone() const { clip_image_f32_batch new_batch; new_batch.entries.reserve(entries.size()); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 994165cf223f4..68580eef004bf 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -359,8 +359,6 @@ struct clip_ctx { 
int max_nodes = 8192; ggml_backend_sched_ptr sched; - clip_image_size load_image_size; - // for debugging bool debug_graph = false; std::vector debug_print_tensors; @@ -2457,14 +2455,6 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p return ctx_clip; } -void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) { - ctx_clip->load_image_size = *load_image_size; // copy -} - -struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) { - return &ctx_clip->load_image_size; -} - struct clip_image_size * clip_image_size_init() { struct clip_image_size * load_image_size = new struct clip_image_size(); load_image_size->width = 448; @@ -3045,12 +3035,6 @@ struct llava_uhd { } }; -// TODO @ngxson : decprecate the load_image_size singleton pattern -int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { - const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size); - return inst.grid_size.width; -} - // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { @@ -3072,9 +3056,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); res_imgs->entries.push_back(std::move(res)); } + + res_imgs->grid_x = inst.grid_size.width; + res_imgs->grid_y = inst.grid_size.height; return true; - } - else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { + + } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { clip_image_u8 resized; auto patch_size = params.patch_size * 2; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size); @@ -3122,6 +3109,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(res)); } + res_imgs->grid_x = inst.grid_size.width; + res_imgs->grid_y = inst.grid_size.height; return true; } @@ -3409,8 +3398,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int n_pos = num_patches + (model.class_embedding ? 
1 : 0); - const int pos_w = ctx->load_image_size.width / patch_size; - const int pos_h = ctx->load_image_size.height / patch_size; + const int pos_w = image_size_width / patch_size; + const int pos_h = image_size_height / patch_size; const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 2d70eec94736f..e7a1c0782dd6a 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -47,10 +47,6 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * // this should be equal to the embedding dimension of the text model int clip_n_mmproj_embd(const struct clip_ctx * ctx); -int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); -void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); -struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); - struct clip_image_size * clip_image_size_init(void); struct clip_image_u8 * clip_image_u8_init (void); struct clip_image_f32 * clip_image_f32_init(void); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 875d5db9fcac0..4f39c8e9a5c2b 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -42,6 +42,7 @@ enum mtmd_slice_tmpl { MTMD_SLICE_TMPL_NONE, MTMD_SLICE_TMPL_MINICPMV_2_5, MTMD_SLICE_TMPL_MINICPMV_2_6, + MTMD_SLICE_TMPL_LLAMA4, // TODO @ngxson : add support for idefics (SmolVLM) }; @@ -65,6 +66,7 @@ struct mtmd_context { std::string image_marker; // for llava-uhd style models, we need special tokens in-between slices + // minicpmv calls them "slices", llama 4 calls them "tiles" mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE; llama_token tok_ov_img_start = LLAMA_TOKEN_NULL; // overview image llama_token tok_ov_img_end = LLAMA_TOKEN_NULL; // overview image @@ -137,6 +139,7 @@ struct mtmd_context { // ... 
<|tile_y_separator|> <-- trailing end-of-row token // <|image|> (overview) <-- overview image is last // <|image_end|> + slice_tmpl = MTMD_SLICE_TMPL_LLAMA4; tok_ov_img_start = lookup_token("<|image|>"); tok_sli_img_mid = lookup_token("<|tile_x_separator|>"); tok_row_end = lookup_token("<|tile_y_separator|>"); @@ -361,7 +364,12 @@ int32_t mtmd_tokenize(mtmd_context * ctx, return 2; } - if (ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6) { + // handle llava-uhd style preprocessing + if ( + ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 + || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 + || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 + ) { // split batch into chunks of single images auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id); GGML_ASSERT(chunks.size() > 0); @@ -380,12 +388,10 @@ int32_t mtmd_tokenize(mtmd_context * ctx, } } - // add slices + // add slices (or tiles) if (!chunks.empty()) { - clip_add_load_image_size(ctx->ctx_clip, &img_u8_size); - int n_col = clip_uhd_num_image_embeds_col(ctx->ctx_clip); - int n_row = (int)chunks.size() / n_col; - GGML_ASSERT(n_row * n_col == (int)chunks.size()); + const int n_col = batch_f32.grid_x; + const int n_row = batch_f32.grid_y; if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_slices_start}); } @@ -473,14 +479,6 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); bool ok = false; - // only effective for minicpmv and qwen2vl, other models will ignore load_image_size - { - clip_image_size slice_size{ - image_tokens->batch_f32.entries[0]->nx, - image_tokens->batch_f32.entries[0]->ny}; - clip_add_load_image_size(ctx->ctx_clip, &slice_size); - } - if (clip_is_llava(ctx->ctx_clip) || clip_is_minicpmv(ctx->ctx_clip) || clip_is_glm(ctx->ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; From 791f10695d267c25574840af7086007f7569b41d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 15:30:53 +0200 Subject: [PATCH 23/30] rm unused line --- tools/mtmd/mtmd.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 4f39c8e9a5c2b..1234dbb468767 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -354,7 +354,6 @@ int32_t mtmd_tokenize(mtmd_context * ctx, img_u8->ny = bitmaps[i_img]->ny; img_u8->buf.resize(bitmaps[i_img]->data.size()); std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3); - clip_image_size img_u8_size{img_u8->nx, img_u8->ny}; // preprocess image clip_image_f32_batch batch_f32; From 8646240a6cdf4ab651f756779543a8d00c00d94c Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 15:48:49 +0200 Subject: [PATCH 24/30] fix --- tools/mtmd/clip.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 68580eef004bf..e370f21871b74 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2038,13 +2038,13 @@ struct clip_model_loader { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor); // borrowed from llava-1.6 - const int psize = hparams.patch_size; + const int isize = hparams.patch_size; hparams.image_grid_pinpoints = { - psize, psize*2, // 336, 672 - psize*2, psize, // 672, 336 - psize*2, psize*2, // 672, 672 
- psize*3, psize, // 1008, 336 - psize, psize*3, // 336, 1008 + isize, isize*2, // 336, 672 + isize*2, isize, // 672, 336 + isize*2, isize*2, // 672, 672 + isize*3, isize, // 1008, 336 + isize, isize*3, // 336, 1008 }; } break; default: @@ -2968,7 +2968,7 @@ struct llava_uhd { // used by llava 1.6 with custom list of pinpoints static clip_image_size select_best_resolution(const std::vector & pinpoints, const clip_image_size & original_size) { - std::vector possible_resolutions; + std::vector possible_resolutions; // TODO @ngxson : construct this inside hparams, not here for (size_t i = 0; i < pinpoints.size(); i += 2) { possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]}); } @@ -3077,7 +3077,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE || ctx->proj_type == PROJECTOR_TYPE_GEMMA3 || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 - || ctx->proj_type == PROJECTOR_TYPE_LLAMA4 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution ) { clip_image_u8 resized_image; From 53fb622328da128819b362afaebf1cb2086ca1b1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 15:49:50 +0200 Subject: [PATCH 25/30] small fix 2 --- tools/mtmd/clip.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index e370f21871b74..01f85901926f2 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2038,7 +2038,7 @@ struct clip_model_loader { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor); // borrowed from llava-1.6 - const int isize = hparams.patch_size; + const int isize = hparams.image_size; hparams.image_grid_pinpoints = { isize, isize*2, // 336, 672 isize*2, isize, // 672, 336 From f083c1923fcba74703fbc0be267ee3be13c42256 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 16:26:55 +0200 Subject: [PATCH 26/30] add test & docs --- docs/multimodal.md | 3 +++ tools/mtmd/tests.sh | 1 + 2 files changed, 4 insertions(+) diff --git a/docs/multimodal.md b/docs/multimodal.md index 80014ba1cef6d..054778e91f297 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -74,4 +74,7 @@ NOTE: some models may require large context window, for example: `-c 8192` (tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF (tool_name) -hf ggml-org/InternVL3-8B-Instruct-GGUF (tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF + +# Llama 4 Scout +(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF ``` diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 05ac7a04d8fce..9e6029c04120c 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -64,6 +64,7 @@ if [ "$RUN_BIG_TESTS" = true ]; then add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M" # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big + add_test "llama-mtmd-cli" "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S" # too big fi # these models always give the wrong answer, not sure why From b199e70dc8858c507f8164dda90ff3b4f385328e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 18 May 2025 16:36:09 +0200 Subject: [PATCH 27/30] fix llava-1.6 test --- tools/mtmd/tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 9e6029c04120c..c10eb6a27c204 100755 --- a/tools/mtmd/tests.sh +++ 
b/tools/mtmd/tests.sh
@@ -42,7 +42,7 @@ add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
 add_test "llama-mtmd-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K" "vicuna"
-add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K" "vicuna"
+add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M" "vicuna"
 add_test "llama-mtmd-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted
 add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"

From e52481baf22fd5fd77db20b2d035978230b099b3 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 18 May 2025 16:41:03 +0200
Subject: [PATCH 28/30] test: add notion of huge models

---
 tools/mtmd/tests.sh | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh
index c10eb6a27c204..15a37b0d22bb4 100755
--- a/tools/mtmd/tests.sh
+++ b/tools/mtmd/tests.sh
@@ -21,6 +21,13 @@ if [ "${1:-}" = "big" ]; then
     echo "Include BIG models..."
 fi
 
+RUN_HUGE_TESTS=false
+if [ "${1:-}" = "huge" ]; then
+    RUN_HUGE_TESTS=true
+    RUN_BIG_TESTS=true
+    echo "Include HUGE models..."
+fi
+
 ###############
 
 arr_bin=()
@@ -60,11 +67,17 @@ if [ "$RUN_BIG_TESTS" = true ]; then
     add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
     add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
     add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
     # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
-    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
-    add_test "llama-mtmd-cli" "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S" # too big
+fi
+
+# to test the huge models, run: ./tests.sh huge
+# this will run both the big and huge models
+# huge models are > 32B parameters
+if [ "$RUN_HUGE_TESTS" = true ]; then
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S"
+fi
 
 # these models always give the wrong answer, not sure why

From 186d7a88853faca40681a12fcae50f41321bd227 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Mon, 19 May 2025 10:48:47 +0200
Subject: [PATCH 29/30] add comment

---
 tools/mtmd/clip.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 01f85901926f2..7eba741febb2d 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -219,7 +219,7 @@ struct clip_layer {
     ggml_tensor * ff_down_w = nullptr;
     ggml_tensor * ff_down_b = nullptr;
 
-    // layernorm 2 (post-attn norm / pre-ffn norm)
+    // layernorm 2
     ggml_tensor * ln_2_w = nullptr;
     ggml_tensor * ln_2_b = nullptr;
 
@@ -971,6 +971,9 @@ struct clip_graph {
 
         // build ViT with 2D position embeddings
         auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            // first half is X axis and second half is Y axis
+            // ref: 
https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312 + // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441 return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); }; ggml_tensor * cur = build_vit( @@ -990,7 +993,7 @@ struct clip_graph { // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151 { const int scale_factor = model.hparams.proj_scale_factor; - const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int bsz = 1; // batch size, always 1 for now since we don't support batching GGML_ASSERT(scale_factor > 0); GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images cur = ggml_reshape_4d(ctx0, cur, From d5e50aa09b1eeb0525ff1dffe31523aefa2c29de Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 19 May 2025 10:50:59 +0200 Subject: [PATCH 30/30] add warn about degraded quality --- tools/mtmd/clip.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 7eba741febb2d..eba07f6c82eba 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2070,6 +2070,10 @@ struct clip_model_loader { LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str()); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); + + if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) { + LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__); + } } }
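
For reference, the reshape/permute sequence in the pixel-shuffle block of build_llama4() (based on Llama4VisionPixelShuffleMLP) can be hard to follow in ggml notation. Below is a minimal NumPy sketch of the same rearrangement — an illustration only, not part of the patches. It assumes a square patch grid, batch size 1, and a scale factor s that divides the grid size; note that ggml lists dimensions fastest-first in ne[], so ggml's [n_embd, x, y, b] corresponds to a NumPy array of shape (b, y, x, n_embd):

    import numpy as np

    def pixel_shuffle(cur: np.ndarray, s: int) -> np.ndarray:
        # cur: (1, n_patches_y, n_patches_x, n_embd); NumPy C-order keeps
        # n_embd fastest-varying, matching ggml's ne[0]
        b, h, w, e = cur.shape
        x = cur.reshape(b, h, w // s, e * s)          # fold s x-neighbours into channels
        x = x.transpose(0, 2, 1, 3)                   # swap the two spatial axes
        x = x.reshape(b, w // s, h // s, e * s * s)   # fold s y-neighbours as well
        x = x.transpose(0, 2, 1, 3)                   # swap back
        return x.reshape(b, (h * w) // (s * s), e * s * s)  # flatten to tokens

    # e.g. a 24x24 grid of 16-dim embeddings -> 144 tokens of dim 64
    out = pixel_shuffle(np.arange(24 * 24 * 16, dtype=np.float32).reshape(1, 24, 24, 16), 2)
    assert out.shape == (1, 144, 64)

In other words, each s x s neighbourhood of patches is folded into the channel dimension: the token count shrinks by s^2 while the embedding size grows by s^2, which is the shape the adapter MLP (mm_model_mlp_1_w) then consumes.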