Fix parsing single-byte UTF-8 tokens by manually parsing the protobuf #73

Status: Closed · wants to merge 2 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -15,9 +15,12 @@ build-sanitize-addr/
 build-sanitize-thread/

 models/*
+*.bin*

 /main
 /quantize

 arm_neon.h
 compile_commands.json
+.venv
+__pycache__
2 changes: 1 addition & 1 deletion README.md
@@ -139,7 +139,7 @@ ls ./models
 65B 30B 13B 7B tokenizer_checklist.chk tokenizer.model

 # install Python dependencies
-python3 -m pip install torch numpy sentencepiece
+python3 -m pip install torch numpy protobuf

 # convert the 7B model to ggml FP16 format
 python3 convert-pth-to-ggml.py models/7B/ 1
37 changes: 27 additions & 10 deletions convert-pth-to-ggml.py
@@ -23,7 +23,7 @@
 import numpy as np
 import torch

-from sentencepiece import SentencePieceProcessor
+import sentencepiece_model_pb2

 if len(sys.argv) < 3:
     print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
@@ -68,9 +68,11 @@ def get_n_parts(dim):
 with open(fname_hparams, "r") as f:
     hparams = json.load(f)

-tokenizer = SentencePieceProcessor(fname_tokenizer)
+tokenizer = sentencepiece_model_pb2.ModelProto()
+with open(fname_tokenizer, "rb") as f:
+    tokenizer.ParseFromString(f.read())

-hparams.update({"vocab_size": tokenizer.vocab_size()})
+hparams.update({"vocab_size": len(tokenizer.pieces)})

 n_parts = get_n_parts(hparams["dim"])

@@ -100,13 +102,28 @@ def get_n_parts(dim):
 fout.write(struct.pack("i", ftype))

-# Is this correct??
-for i in range(32000):
-    # TODO: this is probably wrong - not sure how this tokenizer works
-    text = tokenizer.decode([29889, i]).encode('utf-8')
-    # remove the first byte (it's always '.')
-    text = text[1:]
-    fout.write(struct.pack("i", len(text)))
-    fout.write(text)
+for token in tokenizer.pieces:
+    if token.type == 1:
+        # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
+        text = token.piece.replace("\u2581", " ").encode("utf-8")
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+    elif token.type == 2:
+        # "<unk>" token (translated as ??)
+        text = " \u2047 ".encode("utf-8")
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+    elif token.type == 3:
+        # "<s>"/"</s>" tokens
+        fout.write(struct.pack("i", 0))
+    elif token.type == 6:
+        # "<U+XX>" tokens (which may be invalid UTF-8)
+        if len(token.piece) != 6:
+            print("Invalid token: " + token.piece)
+            sys.exit(1)
+        byte_value = int(token.piece[3:-1], 16)
+        fout.write(struct.pack("i", 1))
+        fout.write(struct.pack("B", byte_value))

 for k, v in model.items():
     name = k
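
For context on the byte-fallback branch (`token.type == 6`), here is a minimal, self-contained sketch of the same logic, separate from the diff; the `pieces` list is a hypothetical stand-in for entries in `tokenizer.pieces`. It shows how a piece such as `<0x0A>` is reduced to the single raw byte `0x0A`, which is exactly the case the old `decode()`-based loop mishandled:

```python
import struct

# Hypothetical (piece, type) pairs mimicking ModelProto.SentencePiece entries;
# type 1 = NORMAL, 2 = UNKNOWN, 3 = CONTROL, 6 = BYTE (see the proto below).
pieces = [("\u2581hello", 1), ("<unk>", 2), ("<s>", 3), ("<0x0A>", 6)]

records = []
for piece, ptype in pieces:
    if ptype == 1:
        # NORMAL: U+2581 encodes a leading space in SentencePiece pieces.
        text = piece.replace("\u2581", " ").encode("utf-8")
    elif ptype == 2:
        # UNKNOWN: emitted as " ⁇ " (U+2047), matching the proto's unk_surface.
        text = " \u2047 ".encode("utf-8")
    elif ptype == 3:
        # CONTROL (<s>, </s>): emitted as an empty string.
        text = b""
    else:  # ptype == 6
        # BYTE: "<0xNN>" becomes the raw byte 0xNN, even if it is invalid UTF-8.
        text = bytes([int(piece[3:-1], 16)])
    # Same on-disk format as the diff: native-endian 32-bit length prefix,
    # then the token bytes.
    records.append(struct.pack("i", len(text)) + text)

print(records[-1])  # b'\x01\x00\x00\x00\n' on a little-endian machine
```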
324 changes: 324 additions & 0 deletions sentencepiece_model.proto
@@ -0,0 +1,324 @@
// SOURCE: https://github.com/google/sentencepiece/blob/9ffb33a14c97c512103be0ee74740099660b39aa/src/sentencepiece_model.proto#L282


// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

// TODO(taku): Needs to use LITE RUNTIME in OSS release.
option optimize_for = LITE_RUNTIME;

package sentencepiece;

// TrainerSpec encodes the various parameters for SentencePiece training.
// Next id: 53
message TrainerSpec {
  ///////////////////////////////////////////////////////////////////
  // General parameters
  //
  // Input corpus files.
  // Trainer accepts the following two formats:
  // A) Monolingual: plain text, one sentence per line.
  // B) Bilingual: TSV, source sentence <tab> target sentence
  // When bilingual data is passed, a shared vocabulary model is built.
  // Note that the input file must be raw corpus, not a preprocessed corpus.
  // Trainer only loads the first `input_sentence_size` sentences specified
  // with this parameter.
  repeated string input = 1;

  // Input corpus format:
  // "text": one-sentence-per-line text format (default)
  // "tsv":  sentence <tab> freq
  optional string input_format = 7;

  // Output model file prefix.
  // <model_prefix>.model and <model_prefix>.vocab are generated.
  optional string model_prefix = 2;

  // Model type. Only UNIGRAM is available now.
  enum ModelType {
    UNIGRAM = 1;  // Unigram language model with dynamic algorithm
    BPE = 2;      // Byte Pair Encoding
    WORD = 3;     // Delimited by whitespace.
    CHAR = 4;     // Tokenizes into character sequence
  }
  optional ModelType model_type = 3 [default = UNIGRAM];

  // Vocabulary size. 8k is the default size.
  optional int32 vocab_size = 4 [default = 8000];

  // List of the languages this model can accept.
  // Since the model is language-agnostic, this field is used as a reference.
  repeated string accept_language = 5;

  // Size of self-test samples, which are encoded in the model file.
  optional int32 self_test_sample_size = 6 [default = 0];

  // Whether to use the DP version of sentencepiece. Use it with TSV input
  // format (requires precomputed word tab counts to work).
  optional bool enable_differential_privacy = 50 [default = false];
  // Set these parameters if you need the DP version of sentencepiece.
  // std of noise to add.
  optional float differential_privacy_noise_level = 51 [default = 0.0];
  // Clipping threshold to apply after adding noise. All the words with
  // frequency less than this value are dropped.
  optional uint64 differential_privacy_clipping_threshold = 52 [default = 0];

  ///////////////////////////////////////////////////////////////////
  // Training parameters.
  //
  // Uses characters which cover the corpus with the ratio of `chars_coverage`.
  // This parameter determines the set of basic alphabet of sentence piece.
  // 1.0 - `chars_coverage` characters are treated as UNK.
  // See also the required_chars field.
  optional float character_coverage = 10 [default = 0.9995];

  // Maximum size of sentences the trainer loads from the `input` parameter.
  // Trainer simply loads the `input` files in sequence.
  // It is better to shuffle the input corpus randomly.
  optional uint64 input_sentence_size = 11 [default = 0];
  optional bool shuffle_input_sentence = 19 [default = true];

  // Maximum size of sentences to make seed sentence pieces.
  // An extended suffix array is constructed to extract frequent
  // sub-strings from the corpus. This uses 20N working space,
  // where N is the size of the corpus.
  optional int32 mining_sentence_size = 12 [deprecated = true];

  // Maximum size of sentences to train sentence pieces.
  optional int32 training_sentence_size = 13 [deprecated = true];

  // The size of seed sentencepieces.
  // `seed_sentencepiece_size` must be larger than `vocab_size`.
  optional int32 seed_sentencepiece_size = 14 [default = 1000000];

  // In every EM sub-iteration, keeps the top
  // `shrinking_factor` * `current sentencepieces size` pieces with respect to
  // the loss of the sentence piece. This value should be smaller than 1.0.
  optional float shrinking_factor = 15 [default = 0.75];

  // The maximum sentence length in bytes. Sentences with a length
  // larger than `max_sentence_length` are simply ignored.
  // Longer input tends to bring the following risks:
  //   * Overflow during EM training (unigram language model only)
  //   * Performance drop because of O(n log n) cost in BPE.
  optional int32 max_sentence_length = 18 [default = 4192];

  // Number of threads in the training.
  optional int32 num_threads = 16 [default = 16];

  // Number of EM sub-iterations.
  optional int32 num_sub_iterations = 17 [default = 2];

  ///////////////////////////////////////////////////////////////////
  // SentencePiece parameters which control the shapes of sentence pieces.
  //
  // Maximum length of a sentencepiece.
  optional int32 max_sentencepiece_length = 20 [default = 16];

  // Uses Unicode script to split sentence pieces.
  // When `split_by_unicode_script` is true, we do not allow a sentence piece
  // to include multiple Unicode scripts, e.g. "F1" is not a valid piece.
  // Exception: CJ characters (Hiragana/Katakana/Han) are all handled
  // as one script type, since a Japanese word can consist of multiple scripts.
  // This exception is always applied regardless of the accept-language
  // parameter.
  optional bool split_by_unicode_script = 21 [default = true];

  // When `split_by_number` is true, put a boundary at number and
  // non-number transitions. If we want to treat "F1" as one token, set this
  // flag to false.
  optional bool split_by_number = 23 [default = true];

  // Use whitespace to split sentence pieces.
  // When `split_by_whitespace` is false, we may have pieces containing
  // whitespace in the middle, e.g., "in_the".
  optional bool split_by_whitespace = 22 [default = true];

  // Adds the whitespace symbol (_) as a suffix instead of a prefix, e.g.,
  // _hello => hello_. When `treat_whitespace_as_suffix` is true,
  // NormalizerSpec::add_dummy_prefix will add the dummy whitespace to the end
  // of the sentence.
  optional bool treat_whitespace_as_suffix = 24 [default = false];

  // Allows pieces that only contain whitespace instead of appearing only as
  // a prefix or suffix of other pieces.
  optional bool allow_whitespace_only_pieces = 26 [default = false];

  // Split all digits (0-9) into separate pieces.
  optional bool split_digits = 25 [default = false];

  ///////////////////////////////////////////////////////////////////
  // Vocabulary management
  //
  // Defines control symbols used as an indicator to
  // change the behavior of the decoder. <s> and </s> are pre-defined.
  // We can use this field to encode various meta information,
  // including the language indicator in a multilingual model.
  // These symbols are not visible to users, but visible to
  // the decoder. Note that when the input sentence contains control symbols,
  // they are not treated as one token, but segmented into normal pieces.
  // Control symbols must be inserted independently from the segmentation.
  repeated string control_symbols = 30;

  // Defines user-defined symbols.
  // These symbols are added with an extremely high score
  // so they are always treated as one unique symbol in any context.
  // Typical usage of user_defined_symbols is as placeholders for named
  // entities.
  repeated string user_defined_symbols = 31;

  // Defines required characters. Each UTF-8 character in this string is
  // included in the character set regardless of the character_coverage value.
  // Unlike user_defined_symbols, these characters have scores based on their
  // frequency in the input sentences, and the model can form subwords using
  // characters in this field.
  optional string required_chars = 36;

  // Decomposes unknown pieces into UTF-8 bytes.
  optional bool byte_fallback = 35 [default = false];

  // When creating the vocabulary file, defines whether or not to additionally
  // output the score for each piece.
  optional bool vocabulary_output_piece_score = 32 [default = true];

  // `vocab_size` is treated as a hard limit. Crash if
  // the model can not produce a vocab of size `vocab_size`.
  // When `hard_vocab_limit` is false, vocab_size is treated
  // as a soft limit. Note that when model_type=char,
  // hard_vocab_limit = false is always assumed.
  optional bool hard_vocab_limit = 33 [default = true];

  // Use all symbols for vocab extraction. This flag is valid
  // if the model type is either CHAR or WORD.
  optional bool use_all_vocab = 34 [default = false];

  ///////////////////////////////////////////////////////////////////
  // Reserved special meta tokens.
  // * -1 is not used.
  // * unk_id must not be -1.
  // Ids must start with 0 and be contiguous.
  optional int32 unk_id = 40 [default = 0];   // <unk>
  optional int32 bos_id = 41 [default = 1];   // <s>
  optional int32 eos_id = 42 [default = 2];   // </s>
  optional int32 pad_id = 43 [default = -1];  // <pad> (padding)
  optional string unk_piece = 45 [default = "<unk>"];
  optional string bos_piece = 46 [default = "<s>"];
  optional string eos_piece = 47 [default = "</s>"];
  optional string pad_piece = 48 [default = "<pad>"];

  // Encodes <unk> as U+2047 (DOUBLE QUESTION MARK),
  // since this character can be useful both for the user and the
  // developer. We can easily figure out that <unk> is emitted.
  optional string unk_surface = 44 [default = " \xE2\x81\x87 "];

  // Increase bit depth to allow unigram model training on large
  // (>10M sentences) corpora. A side-effect of enabling this flag
  // is increased memory usage.
  optional bool train_extremely_large_corpus = 49 [default = false];

  // Customized extensions: the range of field numbers
  // is open to third-party extensions.
  extensions 200 to max;
}

// NormalizerSpec encodes the various parameters for string normalization.
message NormalizerSpec {
  // Name of the normalization rule.
  optional string name = 1;

  // Pre-compiled normalization rule created by the
  // Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method.
  // Usually this field is set by the Builder::GetNormalizerSpec() method.
  optional bytes precompiled_charsmap = 2;

  // Adds dummy whitespace at the beginning of text in order to
  // treat "world" in "world" and "hello world" in the same way.
  optional bool add_dummy_prefix = 3 [default = true];

  // Removes leading, trailing, and duplicate internal whitespace.
  optional bool remove_extra_whitespaces = 4 [default = true];

  // Replaces whitespace with the meta symbol.
  // This field must be true to train a sentence piece model.
  optional bool escape_whitespaces = 5 [default = true];

  // Custom normalization rule file in TSV format.
  // https://github.com/google/sentencepiece/blob/master/doc/normalization.md
  // This field is only used in the SentencePieceTrainer::Train() method, which
  // compiles the rule into the binary rule stored in `precompiled_charsmap`.
  optional string normalization_rule_tsv = 6;

  // Customized extensions: the range of field numbers
  // is open to third-party extensions.
  extensions 200 to max;
}

// Proto to store samples for self-testing.
message SelfTestData {
  message Sample {
    optional string input = 1;
    optional string expected = 2;
  }
  repeated Sample samples = 1;

  // Customized extensions: the range of field numbers
  // is open to third-party extensions.
  extensions 200 to max;
}

// ModelProto stores model parameters.
// SentencePieceProcessor is supposed to be self-contained.
// All settings/parameters which may change the behavior must be encoded
// in ModelProto.
message ModelProto {
  message SentencePiece {
    enum Type {
      NORMAL = 1;        // normal symbol
      UNKNOWN = 2;       // unknown symbol. only <unk> for now.
      CONTROL = 3;       // control symbols. </s>, <s>, <2ja> etc.
      USER_DEFINED = 4;  // user defined symbols.
                         // Typical usage of USER_DEFINED symbol
                         // is placeholder.
      BYTE = 6;          // byte symbols. Used when `byte_fallback` is true.
      UNUSED = 5;        // this piece is not used.
    }
    optional string piece = 1;  // piece must not be empty.
    optional float score = 2;
    optional Type type = 3 [default = NORMAL];

    // Customized extensions: the range of field numbers
    // is open to third-party extensions.
    extensions 200 to max;
  }

  // Sentence pieces with scores.
  repeated SentencePiece pieces = 1;

  // Spec used to generate this model file.
  optional TrainerSpec trainer_spec = 2;

  // Spec for text normalization.
  optional NormalizerSpec normalizer_spec = 3;

  // Stores sample input and its expected segmentation to verify the model.
  optional SelfTestData self_test_data = 4;

  // Spec for text de-normalization.
  optional NormalizerSpec denormalizer_spec = 5;

  // Customized extensions: the range of field numbers
  // is open to third-party extensions.
  extensions 200 to max;
}
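
As a quick way to sanity-check this approach, the proto above can be compiled with `protoc --python_out=. sentencepiece_model.proto` (the PR checks in the generated `sentencepiece_model_pb2.py` below) and used to inspect a LLaMA `tokenizer.model` directly. A minimal sketch; the model path is an assumption, adjust it to your layout:

```python
from collections import Counter

import sentencepiece_model_pb2

# Assumed path to the LLaMA tokenizer model; adjust as needed.
model = sentencepiece_model_pb2.ModelProto()
with open("models/tokenizer.model", "rb") as f:
    model.ParseFromString(f.read())

print("vocab size:", len(model.pieces))
# Tally pieces by type; with byte_fallback enabled one would expect
# 256 BYTE (type 6) entries covering every possible raw byte.
print(Counter(p.type for p in model.pieces))
```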
44 changes: 44 additions & 0 deletions sentencepiece_model_pb2.py