Skip to content

Commit 43b5626

Browse files
author
ochafik
committed
tools: enable hermes2/qwen chat logic even w/o tools
1 parent f5cd27b commit 43b5626

File tree

2 files changed

+92
-90
lines changed

2 files changed

+92
-90
lines changed

common/chat.cpp

Lines changed: 90 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,98 +1466,100 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
14661466
data.thinking_forced_open = true;
14671467
}
14681468

1469-
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
1470-
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1471-
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1472-
std::vector<std::string> tool_rules;
1473-
std::vector<std::string> tool_call_alts;
1474-
std::vector<std::string> escaped_names;
1475-
foreach_function(inputs.tools, [&](const json & tool) {
1476-
const auto & function = tool.at("function");
1477-
std::string name = function.at("name");
1478-
auto parameters = function.at("parameters");
1479-
builder.resolve_refs(parameters);
1480-
tool_rules.push_back(builder.add_schema(name + "-call", {
1481-
{"type", "object"},
1482-
{"properties", json {
1483-
{"name", json {{"const", name}}},
1484-
{"arguments", parameters},
1485-
}},
1486-
{"required", json::array({"name", "arguments"})},
1487-
}));
1488-
tool_call_alts.push_back(builder.add_rule(
1489-
name + "-function-tag",
1490-
"\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
1491-
builder.add_schema(name + "-args", parameters) + " "
1492-
"\"</function>\" space"));
1469+
if (!inputs.tools.is_null()) {
1470+
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
1471+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1472+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1473+
std::vector<std::string> tool_rules;
1474+
std::vector<std::string> tool_call_alts;
1475+
std::vector<std::string> escaped_names;
1476+
foreach_function(inputs.tools, [&](const json & tool) {
1477+
const auto & function = tool.at("function");
1478+
std::string name = function.at("name");
1479+
auto parameters = function.at("parameters");
1480+
builder.resolve_refs(parameters);
1481+
tool_rules.push_back(builder.add_schema(name + "-call", {
1482+
{"type", "object"},
1483+
{"properties", json {
1484+
{"name", json {{"const", name}}},
1485+
{"arguments", parameters},
1486+
}},
1487+
{"required", json::array({"name", "arguments"})},
1488+
}));
1489+
tool_call_alts.push_back(builder.add_rule(
1490+
name + "-function-tag",
1491+
"\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
1492+
builder.add_schema(name + "-args", parameters) + " "
1493+
"\"</function>\" space"));
14931494

1494-
data.grammar_triggers.push_back({
1495-
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
1496-
"<function=" + name + ">",
1495+
data.grammar_triggers.push_back({
1496+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
1497+
"<function=" + name + ">",
1498+
});
1499+
auto escaped_name = regex_escape(name);
1500+
data.grammar_triggers.push_back({
1501+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1502+
"<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
1503+
});
1504+
escaped_names.push_back(escaped_name);
14971505
});
1498-
auto escaped_name = regex_escape(name);
1506+
auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
1507+
std::vector<std::string> alt_tags {
1508+
any_tool_call,
1509+
"\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
1510+
// The rest is just to accommodate common "good bad" outputs.
1511+
"\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
1512+
"\"<response>\" space " + any_tool_call + " \"</response>\"",
1513+
"\"<tools>\" space " + any_tool_call + " \"</tools>\"",
1514+
"\"<json>\" space " + any_tool_call + " \"</json>\"",
1515+
"\"<xml>\" space " + any_tool_call + " \"</xml>\"",
1516+
"\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
1517+
};
1518+
auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
1519+
tool_call_alts.push_back(wrappable_tool_call);
1520+
tool_call_alts.push_back(
1521+
"( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
1522+
auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
1523+
builder.add_rule("root",
1524+
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
1525+
(inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
1526+
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
14991527
data.grammar_triggers.push_back({
1500-
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1501-
"<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
1528+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
1529+
// If thinking_forced_open, then we capture the </think> tag in the grammar,
1530+
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
1531+
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
1532+
"(\\s*"
1533+
"(?:<tool_call>"
1534+
"|<function"
1535+
"|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
1536+
"\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
1537+
")"
1538+
")[\\s\\S]*"
1539+
),
15021540
});
1503-
escaped_names.push_back(escaped_name);
1504-
});
1505-
auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
1506-
std::vector<std::string> alt_tags {
1507-
any_tool_call,
1508-
"\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
1509-
// The rest is just to accommodate common "good bad" outputs.
1510-
"\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
1511-
"\"<response>\" space " + any_tool_call + " \"</response>\"",
1512-
"\"<tools>\" space " + any_tool_call + " \"</tools>\"",
1513-
"\"<json>\" space " + any_tool_call + " \"</json>\"",
1514-
"\"<xml>\" space " + any_tool_call + " \"</xml>\"",
1515-
"\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
1516-
};
1517-
auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
1518-
tool_call_alts.push_back(wrappable_tool_call);
1519-
tool_call_alts.push_back(
1520-
"( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
1521-
auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
1522-
builder.add_rule("root",
1523-
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
1524-
(inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
1525-
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
1526-
data.grammar_triggers.push_back({
1527-
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
1528-
// If thinking_forced_open, then we capture the </think> tag in the grammar,
1529-
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
1530-
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
1531-
"(\\s*"
1532-
"(?:<tool_call>"
1533-
"|<function"
1534-
"|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
1535-
"\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
1536-
")"
1537-
")[\\s\\S]*"
1538-
),
1541+
data.preserved_tokens = {
1542+
"<think>",
1543+
"</think>",
1544+
"<tool_call>",
1545+
"</tool_call>",
1546+
"<function",
1547+
"<tools>",
1548+
"</tools>",
1549+
"<response>",
1550+
"</response>",
1551+
"<function_call>",
1552+
"</function_call>",
1553+
"<json>",
1554+
"</json>",
1555+
"<JSON>",
1556+
"</JSON>",
1557+
"```",
1558+
"```json",
1559+
"```xml",
1560+
};
15391561
});
1540-
data.preserved_tokens = {
1541-
"<think>",
1542-
"</think>",
1543-
"<tool_call>",
1544-
"</tool_call>",
1545-
"<function",
1546-
"<tools>",
1547-
"</tools>",
1548-
"<response>",
1549-
"</response>",
1550-
"<function_call>",
1551-
"</function_call>",
1552-
"<json>",
1553-
"</json>",
1554-
"<JSON>",
1555-
"</JSON>",
1556-
"```",
1557-
"```json",
1558-
"```xml",
1559-
};
1560-
});
1562+
}
15611563

15621564
return data;
15631565
}
@@ -1702,7 +1704,7 @@ static common_chat_params common_chat_templates_apply_jinja(
17021704
}
17031705

17041706
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
1705-
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null() && params.tools.is_array() && params.json_schema.is_null()) {
1707+
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
17061708
return common_chat_params_init_hermes_2_pro(tmpl, params);
17071709
}
17081710

tests/test-chat.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,14 +737,14 @@ static void test_template_output_parsers() {
737737
auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja");
738738
std::vector<std::string> end_tokens{ "<|im_end|>" };
739739

740-
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
740+
assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
741741
assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
742742
}
743743
{
744744
auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
745745
std::vector<std::string> end_tokens{ "<|im_end|>" };
746746

747-
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
747+
assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
748748
assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
749749
assert_equals(
750750
COMMON_CHAT_FORMAT_HERMES_2_PRO,

0 commit comments

Comments
 (0)