Description
I ported the GritLM embeddings sample to llama.cpp, and llama.cpp returns unexpected embedding values. Embeddings work fine for me with other models. I verified the tokenization and it looks correct (with and without special tokens and BOS/EOS).
Below is a sample program that reproduces the unexpected results. The GritLM Python inference code can be found here. The model supports both text generation and text representation (embeddings) and is based on Mistral 7B.
There is nothing special in this code; it is based on the embeddings sample in llama.cpp, so I'm not sure what is going on. Any guidance is appreciated.
Note: HF seems to be under maintenance right now; I'll add a GGUF link once it's back online.
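For comparison with the sample below, which reads a single token's output via `llama_get_embeddings_ith(ctx, 0)`, here is a minimal sketch of mean pooling over per-token embeddings that skips a leading run of instruction tokens. It assumes that the GritLM Python code pools by averaging the hidden states of the non-instruction tokens, and that the per-token outputs are available as one row-major `n_tokens × n_embd` float buffer; `mean_pool` and `n_skip` are hypothetical names, not part of the llama.cpp API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: average rows [n_skip, n_tokens) of a row-major buffer
// that holds one n_embd-sized embedding per token. n_skip is the number of
// leading instruction/prompt tokens excluded from pooling (an assumption about
// how the reference implementation pools).
std::vector<float> mean_pool(const std::vector<float>& token_embd,
                             std::size_t n_tokens, std::size_t n_embd,
                             std::size_t n_skip) {
    assert(token_embd.size() == n_tokens * n_embd);
    assert(n_skip < n_tokens); // at least one token must remain to average

    std::vector<float> pooled(n_embd, 0.0f);
    for (std::size_t t = n_skip; t < n_tokens; ++t) {
        const float* row = token_embd.data() + t * n_embd;
        for (std::size_t k = 0; k < n_embd; ++k)
            pooled[k] += row[k];
    }

    // divide the per-dimension sums by the number of pooled tokens
    const float count = static_cast<float>(n_tokens - n_skip);
    for (float& x : pooled)
        x /= count;
    return pooled;
}
```

The pooled vector would then be L2-normalized, as `normalize` does in the sample, before computing cosine similarities.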
sample source code
```cpp
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <format>
#include <span>
#include <string>
#include <vector>

static float dot_product(const std::vector<float>& v1, const std::vector<float>& v2) {
    auto dot = 0.0f;
    for (std::size_t i = 0; i < v1.size(); ++i)
        dot += v1[i] * v2[i];
    return dot;
}

static float norm(const std::vector<float>& v) {
    return std::sqrt(dot_product(v, v));
}

static float cosine_similarity(const std::vector<float>& v1, const std::vector<float>& v2) {
    return dot_product(v1, v2) / (norm(v1) * norm(v2));
}

// L2-normalize `in` into `out` (both must have the same size)
static void normalize(std::span<const float> in, std::span<float> out) {
    auto norm = 0.0f;
    for (std::size_t i = 0; i < in.size(); i++)
        norm += in[i] * in[i];
    norm = std::sqrt(norm);
    for (std::size_t i = 0; i < out.size(); i++)
        out[i] = in[i] / norm;
}

static std::vector<std::vector<float>> encode(llama_context* ctx, const std::vector<std::string>& sentences, const std::string& instruction) {
    auto result = std::vector<std::vector<float>>{};
    auto mdl = llama_get_model(ctx);

    for (std::size_t i = 0; i < sentences.size(); i++) {
        auto batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

        // testing with and without EOS - unexpected embeddings in both cases - GritLM seems to have EOS = ""
        auto inputs = llama_tokenize(mdl, std::format("{}{}", instruction, sentences[i]), true, false);
        // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L116
        //inputs.push_back(llama_token_eos(mdl));

        // debug tokens - these match the tokens referenced in their sample, so this doesn't appear to be a tokenization issue
        // (llama_token is a signed 32-bit integer, hence %d)
        std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) { std::printf("[%d:%s]", t, llama_token_to_piece(ctx, t).c_str()); });
        std::printf("\n");

        for (std::size_t j = 0; j < inputs.size(); j++)
            llama_batch_add(batch, inputs[j], static_cast<llama_pos>(j), { 0 }, false);

        // clear previous kv_cache values (irrelevant for embeddings)
        llama_kv_cache_seq_rm(ctx, 0, 0, -1);

        // run model
        llama_decode(ctx, batch);

        // requests the embedding for batch position 0, i.e. the BOS token
        auto emb_unorm = std::span<float>(llama_get_embeddings_ith(ctx, 0), llama_n_embd(mdl));
        auto emb_norm = std::vector<float>(emb_unorm.size());
        normalize(emb_unorm, emb_norm);
        result.push_back(emb_norm);

        llama_batch_free(batch);
    }

    return result;
}

// ./embeddings -m ggml-gritlm-7b-q8_0.gguf -ngl 33
int main(int argc, char* argv[])
{
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params))
        return 1;

    auto mparams = llama_model_params_from_gpt_params(params);
    auto cparams = llama_context_params_from_gpt_params(params);

    // the callback must take a plain float to convert to llama_progress_callback
    mparams.progress_callback = [](float progress, void* state) { std::printf("%s\rLoading model... %u%%\r", std::string(32, ' ').c_str(), static_cast<unsigned>(progress * 100)); return true; };
    cparams.embedding = true;

    llama_backend_init();

    auto mdl = llama_load_model_from_file(params.model.c_str(), mparams);
    auto ctx = llama_new_context_with_model(mdl, cparams);
    auto bat = llama_batch_init(llama_n_ctx(ctx), 0, 1);

    // ### Embedding/Representation ### taken sample from here:
    // https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic
    {
        auto instruction = std::string{ "Given a scientific paper title, retrieve the paper's abstract" };

        auto queries = std::vector<std::string>{
            "Bitcoin: A Peer-to-Peer Electronic Cash System",
            "Generative Representational Instruction Tuning",
        };

        auto documents = std::vector<std::string>{
            "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
            "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
        };

        auto gritlm_instruction = [](const std::string& instruction) -> std::string { return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n"; };

        // No need to add instruction for retrieval documents
        auto d_rep = encode(ctx, documents, gritlm_instruction(""));
        auto q_rep = encode(ctx, queries, gritlm_instruction(instruction));

        // note: 1 - cosine_similarity() is the cosine distance (lower means closer)
        auto cosine_sim_q0_d0 = 1 - cosine_similarity(q_rep[0], d_rep[0]);
        auto cosine_sim_q0_d1 = 1 - cosine_similarity(q_rep[0], d_rep[1]);
        auto cosine_sim_q1_d0 = 1 - cosine_similarity(q_rep[1], d_rep[0]);
        auto cosine_sim_q1_d1 = 1 - cosine_similarity(q_rep[1], d_rep[1]);

        std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
        std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
        std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[0].c_str(), cosine_sim_q1_d0);
        std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
    }

    llama_batch_free(bat);
    llama_free(ctx);
    llama_free_model(mdl);
    llama_backend_free();

    return 0;
}
```
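The variable names above mirror the GritLM README sample, which appears to compute `1 - cosine(q_rep[0], d_rep[0])`. Assuming `cosine` there is `scipy.spatial.distance.cosine`, which returns a cosine *distance* (1 minus the similarity), that Python expression is the similarity itself, whereas `1 - cosine_similarity(...)` in this C++ port is the distance. The sketch below makes the relationship explicit; `cosine_distance` is a hypothetical helper that mirrors the SciPy function.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Cosine similarity: dot(a, b) / (|a| * |b|), in [-1, 1].
double cosine_similarity(const std::vector<float>& a, const std::vector<float>& b) {
    assert(a.size() == b.size());
    double dot = 0.0, na = 0.0, nb = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        dot += double(a[i]) * b[i];
        na  += double(a[i]) * a[i];
        nb  += double(b[i]) * b[i];
    }
    return dot / (std::sqrt(na) * std::sqrt(nb));
}

// Hypothetical helper mirroring scipy.spatial.distance.cosine:
// 1 - cosine similarity, in [0, 2].
double cosine_distance(const std::vector<float>& a, const std::vector<float>& b) {
    return 1.0 - cosine_similarity(a, b);
}
```

Under that assumption, a literal port of `1 - cosine(...)` prints `cosine_similarity(...)` directly.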
output
./embeddings -m ggml-gritlm-7b-q8_0.gguf -ngl 33
ggml_init_cublas: GGML_CUDA_FORCE_MMQ: no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
llama_model_loader: loaded meta data with 24 key-value pairs and 291 tensors from ggml-gritlm-7b-q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = llama
llama_model_loader: - kv 1: general.name str = GritLM
llama_model_loader: - kv 2: llama.context_length u32 = 32768
llama_model_loader: - kv 3: llama.embedding_length u32 = 4096
llama_model_loader: - kv 4: llama.block_count u32 = 32
llama_model_loader: - kv 5: llama.feed_forward_length u32 = 14336
llama_model_loader: - kv 6: llama.rope.dimension_count u32 = 128
llama_model_loader: - kv 7: llama.attention.head_count u32 = 32
llama_model_loader: - kv 8: llama.attention.head_count_kv u32 = 8
llama_model_loader: - kv 9: llama.attention.layer_norm_rms_epsilon f32 = 0.000010
llama_model_loader: - kv 10: llama.rope.freq_base f32 = 10000.000000
llama_model_loader: - kv 11: general.file_type u32 = 7
llama_model_loader: - kv 12: tokenizer.ggml.model str = llama
llama_model_loader: - kv 13: tokenizer.ggml.tokens arr[str,32000] = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv 14: tokenizer.ggml.scores arr[f32,32000] = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv 15: tokenizer.ggml.token_type arr[i32,32000] = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv 16: tokenizer.ggml.bos_token_id u32 = 1
llama_model_loader: - kv 17: tokenizer.ggml.eos_token_id u32 = 2
llama_model_loader: - kv 18: tokenizer.ggml.unknown_token_id u32 = 0
llama_model_loader: - kv 19: tokenizer.ggml.padding_token_id u32 = 1
llama_model_loader: - kv 20: tokenizer.ggml.add_bos_token bool = true
llama_model_loader: - kv 21: tokenizer.ggml.add_eos_token bool = false
llama_model_loader: - kv 22: tokenizer.chat_template str = {{ bos_token }}{% for message in mess...
llama_model_loader: - kv 23: general.quantization_version u32 = 2
llama_model_loader: - type f32: 65 tensors
llama_model_loader: - type q8_0: 226 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = llama
llm_load_print_meta: vocab type = SPM
llm_load_print_meta: n_vocab = 32000
llm_load_print_meta: n_merges = 0
llm_load_print_meta: n_ctx_train = 32768
llm_load_print_meta: n_embd = 4096
llm_load_print_meta: n_head = 32
llm_load_print_meta: n_head_kv = 8
llm_load_print_meta: n_layer = 32
llm_load_print_meta: n_rot = 128
llm_load_print_meta: n_embd_head_k = 128
llm_load_print_meta: n_embd_head_v = 128
llm_load_print_meta: n_gqa = 4
llm_load_print_meta: n_embd_k_gqa = 1024
llm_load_print_meta: n_embd_v_gqa = 1024
llm_load_print_meta: f_norm_eps = 0.0e+00
llm_load_print_meta: f_norm_rms_eps = 1.0e-05
llm_load_print_meta: f_clamp_kqv = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff = 14336
llm_load_print_meta: n_expert = 0
llm_load_print_meta: n_expert_used = 0
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = 0
llm_load_print_meta: rope scaling = linear
llm_load_print_meta: freq_base_train = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx = 32768
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: model type = 7B
llm_load_print_meta: model ftype = Q8_0
llm_load_print_meta: model params = 7.24 B
llm_load_print_meta: model size = 7.17 GiB (8.50 BPW)
llm_load_print_meta: general.name = GritLM
llm_load_print_meta: BOS token = 1 '<s>'
llm_load_print_meta: EOS token = 2 '</s>'
llm_load_print_meta: UNK token = 0 '<unk>'
llm_load_print_meta: PAD token = 1 '<s>'
llm_load_print_meta: LF token = 13 '<0x0A>'
llm_load_tensors: ggml ctx size = 0.22 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors: CPU buffer size = 132.81 MiB
llm_load_tensors: CUDA0 buffer size = 7205.83 MiB
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA0 KV buffer size = 64.00 MiB
llama_new_context_with_model: KV self size = 64.00 MiB, K (f16): 32.00 MiB, V (f16): 32.00 MiB
llama_new_context_with_model: CUDA_Host input buffer size = 10.01 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 73.00 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 8.00 MiB
llama_new_context_with_model: graph splits (measure): 2
[1:][523: <][28766:|][18320:embed][28766:|][28767:>][13:
][28741:A][21690: purely][13669: peer][28733:-][532:to][28733:-][14720:peer][2751: version][302: of][13176: electronic][7877: cash][682: would][1914: allow][3270: online][14923: payments][298: to][347: be][2662: sent][5090: directly][477: from][624: one][4150: party][298: to][1698: another][1671: without][1404: going][1059: through][264: a][5593: financial][16854: institution][28723:.][13770: Digital][1492: sign][2863:atures][3084: provide][744: part][302: of][272: the][5165: solution][28725:,][562: but][272: the][2191: main][7196: benefits][460: are][3654: lost][513: if][264: a][16437: trusted][4008: third][4150: party][349: is][1309: still][3030: required][298: to][5297: prevent][3579: double][28733:-][886:sp][2570:ending][28723:.][816: We][19333: propose][264: a][5165: solution][298: to][272: the][3579: double][28733:-][886:sp][2570:ending][2700: problem][1413: using][264: a][13669: peer][28733:-][532:to][28733:-][14720:peer][3681: network][28723:.][415: The][3681: network][5104: tim][374:est][10991:amps][15852: transactions][486: by][659: has][2299:hing][706: them][778: into][396: an][15260: ongoing][7650: chain][302: of][7135: hash][28733:-][5527:based][7167: proof][28733:-][1009:of][28733:-][1328:work][28725:,][20345: forming][264: a][2395: record][369: that][3573: cannot][347: be][4648: changed][1671: without][312: re][2432:do][288:ing][272: the][7167: proof][28733:-][1009:of][28733:-][1328:work][28723:.][415: The][23397: longest][7650: chain][459: not][865: only][14449: serves][390: as][7167: proof][302: of][272: the][7768: sequence][302: of][3926: events][24385: witnessed][28725:,][562: but][7167: proof][369: that][378: it][1988: came][477: from][272: the][7639: largest][6313: pool][302: of][14865: CPU][1982: power][28723:.][1136: As][1043: long][390: as][264: a][7757: majority][302: of][14865: CPU][1982: power][349: is][12888: controlled][486: by][9249: nodes][369: that][460: are][459: not][18468: cooper][1077:ating][298: to][3517: attack][272: the][3681: 
network][28725:,][590: they][28742:'][584:ll][8270: generate][272: the][23397: longest][7650: chain][304: and][575: out][2644:pace][3517: attack][404:ers][28723:.][415: The][3681: network][3837: itself][6948: requires][13383: minimal][4693: structure][28723:.][351: M][9251:essages][460: are][11837: broadcast][356: on][264: a][1489: best][4261: effort][6451: basis][28725:,][304: and][9249: nodes][541: can][3530: leave][304: and][312: re][5906:join][272: the][3681: network][438: at][622: will][28725:,][22368: accepting][272: the][23397: longest][7167: proof][28733:-][1009:of][28733:-][1328:work][7650: chain][390: as][7167: proof][302: of][767: what][4243: happened][1312: while][590: they][654: were][4214: gone][28723:.]
[1:][523: <][28766:|][18320:embed][28766:|][28767:>][13:
][2595:All][2245: text][28733:-][5527:based][3842: language][4418: problems][541: can][347: be][9397: reduced][298: to][2477: either][8342: generation][442: or][28643: embedding][28723:.][10929: Current][4994: models][865: only][2225: perform][1162: well][438: at][624: one][442: or][272: the][799: other][28723:.][816: We][13097: introduce][1350: gener][1197:ative][2904: represent][1249:ational][13126: instruction][15013: tun][288:ing][325: (][8369:GR][1153:IT][28731:)][970: where][1403:by][264: a][2475: large][3842: language][2229: model][349: is][10898: trained][298: to][4269: handle][1560: both][1350: gener][1197:ative][304: and][28643: embedding][9796: tasks][486: by][11731: distingu][5596:ishing][1444: between][706: them][1059: through][11382: instructions][28723:.][3880: Comp][1327:ared][298: to][799: other][1565: open][4994: models][28725:,][813: our][10503: resulting][420: G][872:rit][27149:LM][28705: ][28787:7][28760:B][6491: sets][264: a][633: new][1665: state][302: of][272: the][1524: art][356: on][272: the][7576: Mass][495:ive][7379: Text][18065: Emb][286:ed][3202:ding][4121: Ben][338:ch][3325:mark][325: (][28755:M][3392:TE][28760:B][28731:)][304: and][575: out][487:per][14367:forms][544: all][4994: models][582: up][298: to][871: its][1669: size][356: on][264: a][2819: range][302: of][1350: gener][1197:ative][9796: tasks][28723:.][2463: By][19903: scaling][582: up][3629: further][28725:,][420: G][872:rit][27149:LM][28705: ][28783:8][28814:X][28787:7][28760:B][575: out][487:per][14367:forms][544: all][1565: open][1350: gener][1197:ative][3842: language][4994: models][369: that][478: we][3851: tried][1312: while][1309: still][1250: being][3352: among][272: the][1489: best][28643: embedding][4994: models][28723:.][2280: Not][1907:ably][28725:,][478: we][1300: find][369: that][19348: GR][1153:IT][9019: matches][4154: training][356: on][865: only][1350: gener][1197:ative][442: or][28643: embedding][1178: data][28725:,][5884: thus][478: we][541: can][521: 
un][1575:ify][1560: both][438: at][708: no][4397: performance][4320: loss][28723:.][13927: Among][799: other][7196: benefits][28725:,][272: the][521: un][2500:ification][4213: via][19348: GR][1153:IT][27480: speeds][582: up][8337: Ret][10212:riev][282:al][28733:-][21575:Aug][466:ment][286:ed][26802: Generation][325: (][28754:R][2377:AG][28731:)][486: by][876: >][28705: ][28784:6][28734:0][28823:%][354: for][1043: long][10181: documents][28725:,][486: by][708: no][3774: longer][22579: requiring][7681: separate][17913: retriev][282:al][304: and][8342: generation][4994: models][28723:.][3813: Mod][1190:els][28725:,][2696: code][28725:,][4345: etc][28723:.][460: are][21964: freely][2632: available][438: at][4449: https][1508:://][6222:github][28723:.][675:com][28748:/][2083:Context][840:ual][11741:AI][28748:/][820:gr][279:it][24174:lm][28723:.]
[1:][523: <][28766:|][1838:user][28766:|][28767:>][13:
][28777:G][5067:iven][264: a][10469: scientific][3830: paper][3941: title][28725:,][20132: retrieve][272: the][3830: paper][28742:'][28713:s][11576: abstract][13:
][28789:<][28766:|][18320:embed][28766:|][28767:>][13:
][8443:Bit][10817:coin][28747::][330: A][3242: Pe][263:er][28733:-][532:to][28733:-][22163:Peer][10394: Elect][7624:ronic][23439: Cash][2135: System]
[1:][523: <][28766:|][1838:user][28766:|][28767:>][13:
][28777:G][5067:iven][264: a][10469: scientific][3830: paper][3941: title][28725:,][20132: retrieve][272: the][3830: paper][28742:'][28713:s][11576: abstract][13:
][28789:<][28766:|][18320:embed][28766:|][28767:>][13:
][3602:Gener][1197:ative][17891: Represent][1249:ational][3133: Inst][3112:ruction][22756: Tun][288:ing]
Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "A purely peer-to-peer version of electronic cash w" is: 0.551
Cosine similarity between "Bitcoin: A Peer-to-Peer Electronic Cash System" and "All text-based language problems can be reduced to" is: 0.794
Cosine similarity between "Generative Representational Instruction Tuning" and "A purely peer-to-peer version of electronic cash w" is: 0.730
Cosine similarity between "Generative Representational Instruction Tuning" and "All text-based language problems can be reduced to" is: 0.803
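Because the sample prints `1 - cosine_similarity(...)`, the four values above are cosine distances, so lower means closer. The sketch below copies those printed numbers, converts them back to similarities, and checks which document each query ranks first. The expectation that query i should retrieve document i comes from the paired titles and abstracts in the sample; `best_document` is a hypothetical helper.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Values printed by the run above: dist[q][d] = 1 - cos(q_rep[q], d_rep[d]).
constexpr std::array<std::array<double, 2>, 2> dist = {{
    {0.551, 0.794},  // query 0 (Bitcoin title) vs documents 0, 1
    {0.730, 0.803},  // query 1 (GRIT title)    vs documents 0, 1
}};

// Cosine similarity recovered from the printed distance.
constexpr double similarity(std::size_t q, std::size_t d) { return 1.0 - dist[q][d]; }

// Index of the document with the highest similarity to query q.
std::size_t best_document(std::size_t q) {
    assert(q < dist.size());
    std::size_t best = 0;
    for (std::size_t d = 1; d < dist[q].size(); ++d)
        if (similarity(q, d) > similarity(q, best))
            best = d;
    return best;
}
```

With these numbers, query 0 ranks document 0 first, but query 1 also ranks document 0 first, so the GRIT title retrieves the wrong abstract even after reading the values as distances.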