
Commit 807b0c4

fairydreaming, sszymczy, and ggerganov authored
Inference support for T5 and FLAN-T5 model families (#5763)
* llama : add inference support and model types for T5 and FLAN-T5 model families
* llama : add new API functions to support encoder-decoder models: llama_encode(), llama_model_has_encoder(), llama_model_decoder_start_token()
* common, llama-cli, llama-batched : add support for encoder-decoder models
* convert-hf : handle shared token embeddings tensors in T5Model
* convert-hf : add support for SentencePiece BPE tokenizer in T5Model (for Pile-T5 models)
* convert-hf : add MT5ForConditionalGeneration and UMT5ForConditionalGeneration to architectures supported by T5Model
* convert : add t5 tokenizer tests, use "slow" HF tokenizer for t5

Co-authored-by: Stanisław Szymczyk <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent f8c4c07 · commit 807b0c4

33 files changed: +946 -31 lines
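
Taken together, the diffs below put an encoder pass in front of the usual decoding loop. The sketch below is not part of the commit; it condenses the call sequence that the examples/main changes follow, assuming `ctx` and `model` are a valid llama_context and llama_model, that the caller passes the already tokenized prompt, and with the helper name prepare_generation chosen for illustration:

#include <vector>
#include "llama.h"

// Minimal sketch (not from this commit) of the encoder-decoder flow used by the
// examples below. Returns false if the encoder pass fails.
static bool prepare_generation(llama_context * ctx, const llama_model * model,
                               std::vector<llama_token> & embd_inp) {
    if (!llama_model_has_encoder(model)) {
        return true; // decoder-only model: decode the prompt as usual
    }

    // run the encoder once over the whole prompt; llama_encode() returns 0 on success
    if (llama_encode(ctx, llama_batch_get_one(embd_inp.data(), (int32_t) embd_inp.size(), 0, 0)) != 0) {
        return false;
    }

    // prime the decoder with its start token; -1 means the model does not define
    // one, in which case BOS is used instead
    llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
    if (decoder_start_token_id == -1) {
        decoder_start_token_id = llama_token_bos(model);
    }

    embd_inp.clear();
    embd_inp.push_back(decoder_start_token_id);
    return true;
}

From here on the regular llama_decode() sampling loop runs unchanged; for decoder-only models the helper is a no-op.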

common/common.cpp (+18 -1)

@@ -2070,7 +2070,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");

-        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
+        std::vector<llama_token> tmp;
+        llama_token bos = llama_token_bos(model);
+        llama_token eos = llama_token_eos(model);
+        // some models (e.g. T5) don't have a BOS token
+        if (bos != -1) {
+            tmp.push_back(bos);
+        }
+        tmp.push_back(eos);
+
+        if (llama_model_has_encoder(model)) {
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+            if (decoder_start_token_id == -1) {
+                decoder_start_token_id = bos;
+            }
+            tmp.clear();
+            tmp.push_back(decoder_start_token_id);
+        }
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);

convert-hf-to-gguf-update.py (+16 -3)

@@ -45,6 +45,7 @@ class TOKENIZER_TYPE(IntEnum):
     SPM = auto()
     BPE = auto()
     WPM = auto()
+    UGM = auto()


 # TODO: this string has to exercise as much pre-tokenizer functionality as possible
@@ -89,6 +90,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
     {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
     {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
+    {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
 ]


@@ -110,9 +112,13 @@ def download_model(model):
     os.makedirs(f"models/tokenizers/{name}", exist_ok=True)

     files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
+
     if tokt == TOKENIZER_TYPE.SPM:
         files.append("tokenizer.model")

+    if tokt == TOKENIZER_TYPE.UGM:
+        files.append("spiece.model")
+
     for file in files:
         save_path = f"models/tokenizers/{name}/{file}"
         if os.path.isfile(save_path):
@@ -135,7 +141,7 @@ def download_model(model):
     name = model["name"]
     tokt = model["tokt"]

-    if tokt == TOKENIZER_TYPE.SPM:
+    if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
         continue

     # Skip if the tokenizer folder does not exist or there are other download issues previously
@@ -145,7 +151,10 @@ def download_model(model):

     # create the tokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+        if name == "t5":
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
     except OSError as e:
         logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
         continue  # Skip to the next model if the tokenizer can't be loaded
@@ -266,6 +275,7 @@ def get_vocab_base_pre(self, tokenizer) -> str:
     "\n =",
     "' era",
     "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
+    "!!!!!!",
     "3",
     "33",
     "333",
@@ -304,7 +314,10 @@ def get_vocab_base_pre(self, tokenizer) -> str:

     # create the tokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+        if name == "t5":
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
     except OSError as e:
         logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
         continue  # Skip this model and continue with the next one in the loop

convert-hf-to-gguf.py (+36 -10)

@@ -2853,29 +2853,47 @@ def write_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")


-@Model.register("T5ForConditionalGeneration")
 @Model.register("T5WithLMHeadModel")
+@Model.register("T5ForConditionalGeneration")
+@Model.register("MT5ForConditionalGeneration")
+@Model.register("UMT5ForConditionalGeneration")
 class T5Model(Model):
     model_arch = gguf.MODEL_ARCH.T5

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.shared_token_embeddings_found = False
+
     def set_vocab(self):
         # to avoid TypeError: Descriptors cannot be created directly
         # exception when importing sentencepiece_model_pb2
         os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
         from sentencepiece import SentencePieceProcessor
         from sentencepiece import sentencepiece_model_pb2 as model

-        tokenizer_path = self.dir_model / 'spiece.model'
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        # many older models use spiece.model tokenizer model filename
+        if not tokenizer_path.is_file():
+            tokenizer_path = self.dir_model / 'spiece.model'

         if not tokenizer_path.is_file():
             raise FileNotFoundError(f"File not found: {tokenizer_path}")

         sentencepiece_model = model.ModelProto()
         sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+
+        # some models like Pile-T5 family use BPE tokenizer instead of Unigram
+        if sentencepiece_model.trainer_spec.model_type == 2:  # BPE
+            # assure the tokenizer model file name is correct
+            assert tokenizer_path.name == 'tokenizer.model'
+            return self._set_vocab_sentencepiece()
+        else:
+            assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
         add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
         remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
         precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
-        assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM

         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))
@@ -2945,7 +2963,10 @@ def set_vocab(self):

     def set_gguf_parameters(self):
         self.gguf_writer.add_name("T5")
-        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
+            logger.warning("Couldn't find context length in config.json, assuming default value of 512")
+            n_ctx = 512
+        self.gguf_writer.add_context_length(n_ctx)
         self.gguf_writer.add_embedding_length(self.hparams["d_model"])
         self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
         self.gguf_writer.add_block_count(self.hparams["num_layers"])
@@ -2961,12 +2982,17 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused

-        # Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
-        # "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
-        # To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
-        if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
-            logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
-            return []
+        # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
+        # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
+        # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
+        # and decoder and ignore the remaining ones.
+        if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
+            if not self.shared_token_embeddings_found:
+                name = "shared.weight"
+                self.shared_token_embeddings_found = True
+            else:
+                logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
+                return []

         return [(self.map_tensor_name(name), data_torch)]

examples/batched/batched.cpp (+27 -7)

@@ -93,14 +93,34 @@ int main(int argc, char ** argv) {

     // create a llama_batch
     // we use this object to submit token data for decoding
-    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
+    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
+
+    std::vector<llama_seq_id> seq_ids(n_parallel, 0);
+    for (int32_t i = 0; i < n_parallel; ++i) {
+        seq_ids[i] = i;
+    }

     // evaluate the initial prompt
     for (size_t i = 0; i < tokens_list.size(); ++i) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+        llama_batch_add(batch, tokens_list[i], i, seq_ids, false);
     }
     GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());

+    if (llama_model_has_encoder(model)) {
+        if (llama_encode(ctx, batch)) {
+            LOG_TEE("%s : failed to eval\n", __func__);
+            return 1;
+        }
+
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == -1) {
+            decoder_start_token_id = llama_token_bos(model);
+        }
+
+        llama_batch_clear(batch);
+        llama_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
+    }
+
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;

@@ -109,11 +129,11 @@
         return 1;
     }

-    // assign the system KV cache to all parallel sequences
-    // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
-    for (int32_t i = 1; i < n_parallel; ++i) {
-        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
-    }
+    //// assign the system KV cache to all parallel sequences
+    //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
+    //for (int32_t i = 1; i < n_parallel; ++i) {
+    //    llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+    //}

     if (n_parallel > 1) {
         LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);

examples/main/main.cpp (+21 -1)

@@ -255,7 +255,9 @@ int main(int argc, char ** argv) {
     }

     const bool add_bos = llama_should_add_bos_token(model);
-    GGML_ASSERT(llama_add_eos_token(model) != 1);
+    if (!llama_model_has_encoder(model)) {
+        GGML_ASSERT(llama_add_eos_token(model) != 1);
+    }
     LOG("add_bos: %d\n", add_bos);

     std::vector<llama_token> embd_inp;
@@ -517,6 +519,24 @@
         exit(1);
     }

+    if (llama_model_has_encoder(model)) {
+        int enc_input_size = embd_inp.size();
+        llama_token * enc_input_buf = embd_inp.data();
+
+        if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+            LOG_TEE("%s : failed to eval\n", __func__);
+            return 1;
+        }
+
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == -1) {
+            decoder_start_token_id = llama_token_bos(model);
+        }
+
+        embd_inp.clear();
+        embd_inp.push_back(decoder_start_token_id);
+    }
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {

include/llama.h (+15)

@@ -485,6 +485,13 @@ extern "C" {
     // Get a llama model tensor
     LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);

+    // Returns true if the model contains an encoder that requires llama_encode() call
+    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
+
+    // For encoder-decoder models, this function returns the id of the token that must be provided
+    // to the decoder to start generating output sequence. For other models, it returns -1.
+    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -770,6 +777,14 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);

+    // Processes a batch of tokens with the encoder part of the encoder-decoder model.
+    // Stores the encoder output internally for later use by the decoder cross-attention layers.
+    // 0 - success
+    // < 0 - error
+    LLAMA_API int32_t llama_encode(
+            struct llama_context * ctx,
+            struct llama_batch batch);
+
     // Positive return values do not mean a fatal error, but rather a warning.
     // 0 - success
     // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
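
For context, a hedged sketch of what happens after llama_encode(): the decoder is driven with llama_decode() exactly as for decoder-only models and attends to the stored encoder output through cross-attention. The helper name, n_predict, and the greedy argmax (standing in for a real sampler) are illustrative, not part of this commit.

#include <algorithm>
#include "llama.h"

// Illustrative decode loop run after llama_encode(); `start` is the token returned
// by llama_model_decoder_start_token() (or BOS if that call returned -1).
static void greedy_decode(llama_context * ctx, const llama_model * model,
                          llama_token start, int n_predict) {
    llama_token cur = start;
    for (int i = 0; i < n_predict; ++i) {
        if (llama_decode(ctx, llama_batch_get_one(&cur, 1, i, 0)) != 0) {
            break; // decoding failed
        }
        const float * logits  = llama_get_logits(ctx);   // logits of the last evaluated token
        const int32_t n_vocab = llama_n_vocab(model);
        // greedy argmax; a real application would use a proper sampler here
        cur = (llama_token) (std::max_element(logits, logits + n_vocab) - logits);
        if (cur == llama_token_eos(model)) {
            break; // end of sequence
        }
    }
}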

models/ggml-vocab-bert-bge.gguf.inp (+2)

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

models/ggml-vocab-bert-bge.gguf.out (+1)

@@ -31,6 +31,7 @@
 1027
 1005 3690
 7592 1010 1061 1005 2035 999 2129 2024 2017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995
+999 999 999 999 999 999
 1017
 3943
 21211

models/ggml-vocab-command-r.gguf.inp (+2)

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

models/ggml-vocab-command-r.gguf.out (+1)

@@ -31,6 +31,7 @@
 206 1857
 14 4515
 28339 19 1770 14 1954 8 4070 1955 1933 80503 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372
+57178 10251
 26
 26 26
 26 26 26

models/ggml-vocab-deepseek-coder.gguf.inp (+2)

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

models/ggml-vocab-deepseek-coder.gguf.out (+1)

@@ -31,6 +31,7 @@
 185 405
 6 2895
 17535 11 320 6 435 0 1717 417 340 12394 233 210 3015 19100 608 9413 2668 16 18 16 19 16 20 16 1393 169 121 239
+15330 3023
 18
 18 18
 18 18 18

models/ggml-vocab-deepseek-llm.gguf.inp (+2)

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

models/ggml-vocab-deepseek-llm.gguf.out (+1)

@@ -31,6 +31,7 @@
 185 403
 6 2906
 17464 11 320 6 436 0 1724 418 340 33701 210 3025 19017 612 9407 2681 16 18 16 19 16 20 16 1398 68940 239
+15278 3033
 18
 18 18
 18 18 18

models/ggml-vocab-falcon.gguf.inp (+2)

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

models/ggml-vocab-falcon.gguf.out (+1)

@@ -31,6 +31,7 @@
 1212 40
 18 4932
 9856 23 291 18 436 12 1265 362 299 8196 207 204 42 50087 123 2727 20300 32022 133 234 17419 30137 28 7858 181 133 236
+51520
 30
 3138
 22287

models/ggml-vocab-gpt-2.gguf.inp (+2)

@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33
