@@ -772,15 +772,16 @@ struct llama_vocab {
772
772
using id = int32_t ;
773
773
using token = std::string;
774
774
775
- struct token_score {
775
+ struct token_data {
776
776
token tok;
777
777
float score;
778
+ int toktype;
778
779
};
779
780
780
781
llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
781
782
782
783
std::unordered_map<token, id> token_to_id;
783
- std::vector<token_score> id_to_token;
784
+ std::vector<token_data> id_to_token;
784
785
785
786
// default LLaMA special tokens
786
787
id special_bos_id = 1 ;
@@ -1507,17 +1508,25 @@ static void llama_model_load_internal(
1507
1508
1508
1509
const float * scores = (const float * ) gguf_get_arr_data (ctx, score_idx);
1509
1510
1511
+ const int toktype_idx = gguf_find_key (ctx, " tokenizer.ggml.token_type" );
1512
+ if (toktype_idx == -1 ) {
1513
+ throw std::runtime_error (" cannot find token type list in GGUF file\n " );
1514
+ }
1515
+
1516
+ const int * toktypes = (const int * ) gguf_get_arr_data (ctx, toktype_idx);
1517
+
1510
1518
for (uint32_t i = 0 ; i < hparams.n_vocab ; i++) {
1511
1519
std::string word = gguf_get_arr_str (ctx, token_idx, i);
1512
1520
1513
1521
vocab.token_to_id [word] = i;
1514
1522
1515
- auto & tok_score = vocab.id_to_token [i];
1516
- tok_score.tok = std::move (word);
1517
- tok_score.score = scores[i];
1523
+ auto & token_data = vocab.id_to_token [i];
1524
+ token_data.tok = std::move (word);
1525
+ token_data.score = scores[i];
1526
+ token_data.toktype = toktypes[i];
1518
1527
1519
1528
// determine the newline token: 0x0A == 10 == '\n'
1520
- if (tok_score .tok == " <0x0A>" ) {
1529
+ if (token_data .tok == " <0x0A>" ) {
1521
1530
vocab.linefeed_id = i;
1522
1531
}
1523
1532
}
@@ -2345,92 +2354,57 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
2345
2354
return vocab.type ;
2346
2355
}
2347
2356
2348
- static bool llama_is_normal_token (const llama_vocab & vocab, llama_token token) {
2349
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_SPM) {
2350
- return token >= 259 ;
2351
- }
2352
-
2353
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_BPE) {
2354
- return token >= 95 ;
2355
- }
2356
-
2357
- return false ;
2357
+ static bool llama_is_normal_token (const llama_vocab & vocab, llama_token id) {
2358
+ return vocab.id_to_token [id].toktype == 1 ;
2358
2359
}
2359
2360
2360
- static bool llama_is_bos_token (const llama_vocab & vocab, llama_token token ) {
2361
- return token == vocab. special_bos_id ;
2361
+ static bool llama_is_unknown_token (const llama_vocab & vocab, llama_token id ) {
2362
+ return vocab. id_to_token [id]. toktype == 2 ;
2362
2363
}
2363
2364
2364
- static bool llama_is_eos_token (const llama_vocab & vocab, llama_token token ) {
2365
- return token == vocab. special_eos_id ;
2365
+ static bool llama_is_control_token (const llama_vocab & vocab, llama_token id ) {
2366
+ return vocab. id_to_token [id]. toktype == 3 ;
2366
2367
}
2367
2368
2368
- static bool llama_is_control_token (const llama_vocab & vocab, llama_token token) {
2369
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_SPM) {
2370
- return token == llama_is_bos_token (vocab, token) || token == llama_is_eos_token (vocab, token);
2371
- }
2372
-
2373
- // TODO: improve?
2374
- return false ;
2369
+ static bool llama_is_bos_token (const llama_vocab & vocab, llama_token id) {
2370
+ GGML_ASSERT (llama_is_control_token (vocab, id));
2371
+ return id == vocab.special_bos_id ;
2375
2372
}
2376
2373
2377
- static bool llama_is_unknown_token (const llama_vocab & vocab, llama_token token) {
2378
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_SPM) {
2379
- return token == 0 ;
2380
- }
2381
-
2382
- // TODO: improve?
2383
- return false ;
2374
+ static bool llama_is_eos_token (const llama_vocab & vocab, llama_token id ) {
2375
+ GGML_ASSERT (llama_is_control_token (vocab, id));
2376
+ return id == vocab.special_eos_id ;
2384
2377
}
2385
2378
2386
- static bool llama_is_user_defined_token (const llama_vocab & vocab, llama_token token) {
2387
- GGML_UNUSED (vocab);
2388
- GGML_UNUSED (token);
2389
- // TODO: improve?
2390
- return false ;
2379
+ static bool llama_is_pad_token (const llama_vocab & vocab, llama_token id ) {
2380
+ GGML_ASSERT (id < 0 || llama_is_control_token (vocab, id));
2381
+ return id == vocab.special_pad_id ;
2391
2382
}
2392
2383
2393
- static bool llama_is_unused_token (const llama_vocab & vocab, llama_token token) {
2394
- GGML_UNUSED (vocab);
2395
- GGML_UNUSED (token);
2396
- // TODO: improve?
2397
- return false ;
2384
+ static bool llama_is_user_defined_token (const llama_vocab & vocab, llama_token id) {
2385
+ return vocab.id_to_token [id].toktype == 4 ;
2398
2386
}
2399
2387
2400
- static bool llama_is_byte_token (const llama_vocab & vocab, llama_token token) {
2401
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_SPM) {
2402
- return 3 <= token && token < 259 ;
2403
- }
2404
-
2405
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_BPE) {
2406
- return 1 <= token && token < 95 ;
2407
- }
2408
-
2409
- return false ;
2388
+ static bool llama_is_unused_token (const llama_vocab & vocab, llama_token id) {
2389
+ return vocab.id_to_token [id].toktype == 5 ;
2410
2390
}
2411
2391
2412
- static uint8_t llama_byte_to_char (const llama_vocab & vocab, uint8_t byte) {
2413
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_SPM) {
2414
- return byte - 3 ;
2415
- }
2416
-
2417
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_BPE) {
2418
- return byte + 32 ;
2419
- }
2420
-
2421
- return false ;
2392
+ static bool llama_is_byte_token (const llama_vocab & vocab, llama_token id) {
2393
+ return vocab.id_to_token [id].toktype == 6 ;
2422
2394
}
2423
2395
2424
- static uint8_t llama_char_to_byte (const llama_vocab & vocab, uint8_t ch) {
2425
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_SPM) {
2426
- return ch + 3 ;
2427
- }
2428
-
2429
- if (llama_vocab_get_type (vocab) == LLAMA_VOCAB_TYPE_BPE) {
2430
- return ch - 32 ;
2431
- }
2396
+ static uint8_t llama_token_to_byte (const llama_vocab & vocab, llama_token id) {
2397
+ GGML_ASSERT (llama_is_byte_token (vocab, id));
2398
+ const auto & token_data = vocab.id_to_token .at (id);
2399
+ auto buf = token_data.tok .substr (3 , 2 );
2400
+ return strtol (buf.c_str (), NULL , 16 );
2401
+ }
2432
2402
2433
- return false ;
2403
+ static llama_token llama_byte_to_token (const llama_vocab & vocab, uint8_t ch) {
2404
+ char buf[7 ];
2405
+ int result = snprintf (buf, sizeof (buf), " <0x%02X>" , ch);
2406
+ GGML_ASSERT (0 <= result && result < 7 );
2407
+ return vocab.token_to_id .at (buf);
2434
2408
}
2435
2409
2436
2410
static std::string llama_escape_whitespace (const std::string& text) {
@@ -2569,7 +2543,7 @@ struct llama_tokenizer {
2569
2543
if (p == rev_merge.end ()) {
2570
2544
// output any symbols that did not form tokens as bytes.
2571
2545
for (int j = 0 ; j < (int )symbol.n ; ++j) {
2572
- llama_vocab::id token_id = llama_char_to_byte (vocab_, symbol.text [j]);
2546
+ llama_vocab::id token_id = llama_byte_to_token (vocab_, symbol.text [j]);
2573
2547
output.push_back (token_id);
2574
2548
}
2575
2549
return ;
@@ -2595,12 +2569,12 @@ struct llama_tokenizer {
2595
2569
return ;
2596
2570
}
2597
2571
2598
- const auto &tok_score = vocab_.id_to_token [(*token).second ];
2572
+ const auto &tok_data = vocab_.id_to_token [(*token).second ];
2599
2573
2600
2574
llama_sp_bigram bigram;
2601
2575
bigram.left = left;
2602
2576
bigram.right = right;
2603
- bigram.score = tok_score .score ;
2577
+ bigram.score = tok_data .score ;
2604
2578
bigram.size = text.size ();
2605
2579
work_queue_.push (bigram);
2606
2580
@@ -5109,7 +5083,7 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token
5109
5083
if (length < 1 ) {
5110
5084
return -1 ;
5111
5085
}
5112
- buf[0 ] = llama_byte_to_char (model->vocab , token);
5086
+ buf[0 ] = llama_token_to_byte (model->vocab , token);
5113
5087
return 1 ;
5114
5088
}
5115
5089
}
0 commit comments