
Commit c7cbc34

tristandruyen authored and ngxson committed
Fix phi3 chat template confusion with zephyr (ggml-org#7449)
* Fix phi3 template matching vs zephyr
* Add regression test for new phi3 chat template
* Implement review suggestions
* Fix phi3 jinja test templates & match by <|end|>
* Apply suggestion
  Co-authored-by: Xuan Son Nguyen <[email protected]>
* Add all phi3 template variants in tests
* Remove unneeded message trimming
  Co-authored-by: Xuan Son Nguyen <[email protected]>
* Fix tests to not expect trimmed messages

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
1 parent 9a29b6e commit c7cbc34

File tree: llama.cpp, tests/test-chat-template.cpp

2 files changed: +25 -13 lines changed

llama.cpp

Lines changed: 9 additions & 9 deletions
@@ -17856,6 +17856,15 @@ static int32_t llama_chat_apply_template_internal(
             }
         }
         // llama2 templates seem to not care about "add_generation_prompt"
+    } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) {
+        // Phi 3
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
     } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
         // zephyr template
         for (auto message : chat) {
@@ -17988,15 +17997,6 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
         }
-    } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos )) {
-        // Phi 3
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|" << role << "|>\n" << trim(message->content) << "<|end|>\n";
-        }
-        if (add_ass) {
-            ss << "<|assistant|>\n";
-        }
     } else {
         // template not supported
         return -1;
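
Why the Phi-3 branch was moved ahead of the zephyr branch: Phi-3 Jinja templates also contain <|user|> and <|assistant|>, so the earlier zephyr check tmpl.find("<|user|>") claimed them first and chats were formatted zephyr-style. Only Phi-3 templates terminate turns with <|end|>, which is what the relocated branch now keys on. Below is a minimal, self-contained C++ sketch of that detection heuristic (illustration only, not part of the commit); the template string is the Phi-3-mini template added to the tests further down.

// Sketch: reproduces the substring checks from llama_chat_apply_template_internal
// to show why check order matters for Phi-3 vs zephyr detection.
#include <iostream>
#include <string>

int main() {
    // Phi-3-mini Jinja template (copied from tests/test-chat-template.cpp below).
    const std::string tmpl =
        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}"
        "{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
        "{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}"
        "{% endif %}{% endfor %}";

    // zephyr heuristic: presence of <|user|> -- Phi-3 templates match this too,
    // which is the confusion this commit fixes.
    const bool zephyr_check = tmpl.find("<|user|>") != std::string::npos;

    // Phi-3 heuristic used by the new branch: <|assistant|> plus <|end|>,
    // which zephyr-style templates do not contain.
    const bool phi3_check = tmpl.find("<|assistant|>") != std::string::npos
                         && tmpl.find("<|end|>")       != std::string::npos;

    std::cout << "zephyr check matches Phi-3 template: " << zephyr_check << "\n"; // prints 1
    std::cout << "phi3 check matches Phi-3 template:   " << phi3_check   << "\n"; // prints 1
    return 0;
}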

tests/test-chat-template.cpp

Lines changed: 16 additions & 4 deletions
@@ -49,8 +49,14 @@ int main(void) {
         "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
         // Llama-3
         "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
-        // Phi-3
-        "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + ' ' + message['content'] + '<|end|> ' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> ' }}{% else %}{{ eos_token }}{% endif %}"
+        //Phi-3-mini
+        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+        //Phi-3-small
+        "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+        //Phi-3-medium
+        "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+        //Phi-3-vision
+        "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}"
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -79,8 +85,14 @@ int main(void) {
         "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
         // Llama 3
         "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-        // Phi 3
-        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\nI am an assistant<|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-mini
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-small
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-medium
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        //Phi-3-vision
+        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
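
For context, the code path exercised by these tests can also be driven through the public C API declared in llama.h. The sketch below is a hedged usage example, not part of the commit: it assumes the llama_chat_apply_template signature current at the time of this change (a null model pointer plus an explicit template string, exactly as the test does), and the conversation and buffer size are illustrative.

// Usage sketch: format a conversation with the Phi-3-mini template via the
// public API. With this commit the Phi-3 branch is selected (matched by
// <|end|>) instead of falling through to the zephyr branch.
#include <cstdio>
#include <string>
#include <vector>
#include "llama.h"

int main() {
    // Phi-3-mini Jinja template (same string as in the test above).
    const std::string tmpl =
        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}"
        "{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
        "{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}"
        "{% endif %}{% endfor %}";

    // Illustrative conversation (not the one used by the test).
    llama_chat_message chat[] = {
        {"system",    "You are a helpful assistant"},
        {"user",      "Hello"},
        {"assistant", "Hi there"},
        {"user",      "Who are you"},
    };
    const size_t n_msg = sizeof(chat) / sizeof(chat[0]);

    std::vector<char> buf(1024);
    // model == nullptr: the template is matched purely on the string contents.
    int32_t res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, n_msg,
                                            /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    if (res < 0) {
        fprintf(stderr, "template not supported\n");
        return 1;
    }
    if ((size_t) res > buf.size()) {
        // The return value is the required length; grow the buffer and retry.
        buf.resize(res);
        res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, n_msg,
                                        true, buf.data(), (int32_t) buf.size());
    }
    // Expected shape: "<|system|>\n...<|end|>\n<|user|>\n...<|end|>\n...<|assistant|>\n"
    printf("%.*s\n", res, buf.data());
    return 0;
}

The regression test above does essentially the same thing for every template/expected-output pair, which is how the phi3-vs-zephyr confusion is now caught.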
