diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 2a7d575ce..cdf368fae 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3513,6 +3513,57 @@ def __call__(self, **kwargs): return super().__call__(**kwargs) +class Gemma3ChatHandler(Llava15ChatHandler): + # Chat Format: + # 'user\n{system_prompt}\n\n{prompt}\nmodel\n' + + DEFAULT_SYSTEM_MESSAGE = None + + CHAT_FORMAT = ( + "{% if messages[0]['role'] == 'system' %}" + "{% if messages[0]['content'] is string %}" + "{% set first_user_prefix = messages[0]['content'] + '\n\n' %}" + "{% else %}" + "{% set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' %}" + "{% endif %}" + "{% set loop_messages = messages[1:] %}" + "{% else %}" + "{% set first_user_prefix = \"\" %}" + "{% set loop_messages = messages %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}" + "{% endif %}" + "{% if (message['role'] == 'assistant') %}" + "{% set role = \"model\" %}" + "{% else %}" + "{% set role = message['role'] %}" + "{% endif %}" + "{{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}" + "{% if message['content'] is string %}" + "{{ message['content'] | trim }}" + "{% elif message['content'] is iterable %}" + "{% for item in message['content'] %}" + "{% if item['type'] == 'image_url' and item['image_url'] is string %}" + "{{ '\n\n' + item['image_url'] + '\n\n' }}" + "{% elif item['type'] == 'image_url' and item['image_url'] is mapping %}" + "{{ '\n\n' + item['image_url']['url'] + '\n\n' }}" + "{% elif item['type'] == 'text' %}" + "{{ item['text'] | trim }}" + "{% endif %}" + "{% endfor %}" + "{% else %}" + "{{ raise_exception(\"Invalid content type\") }}" + "{% endif %}" + "{{ '\n' }}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ 'model\n' }}" + "{% endif %}" + ) + + @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( llama: llama.Llama,