From b4e7a71358486a0f3479937071a94385b2b44d92 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Thu, 26 Dec 2024 21:49:44 +0000
Subject: [PATCH 1/3] Compatibility adjustments for chats that don't support tool calls

---
 include/minja/chat-template.hpp | 114 ++++++++++++++++++++++++++------
 1 file changed, 95 insertions(+), 19 deletions(-)

diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp
index 11c4d9c..d9e3e8c 100644
--- a/include/minja/chat-template.hpp
+++ b/include/minja/chat-template.hpp
@@ -26,30 +26,56 @@ class chat_template {
     // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
     bool _requires_object_arguments = false;
     bool _supports_system_role = true;
+    bool _supports_parallel_tool_calls = false;
     std::string _source;
     std::string _bos_token;
    std::string _eos_token;
     std::shared_ptr<minja::TemplateNode> _template_root;
 
+    bool renders_needles(
+        const std::vector<std::string> & needles,
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+    {
+        try {
+            auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
+            for (const auto & needle : needles) {
+                if (prompt.find(needle) == std::string::npos) {
+                    return false;
+                }
+            }
+            return true;
+        } catch (const std::exception & e) {
+            return false;
+        }
+    }
+
 public:
     chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
         : _source(source), _bos_token(bos_token), _eos_token(eos_token)
     {
-        _supports_tools = source.find("tools") != std::string::npos;
-        _requires_object_arguments =
-            source.find("tool_call.arguments | items") != std::string::npos
-            || source.find("tool_call.arguments | tojson") != std::string::npos;
-        _supports_system_role = source.find("System role not supported") == std::string::npos;
-
         _template_root = minja::Parser::parse(_source, {
             /* .trim_blocks = */ true,
             /* .lstrip_blocks = */ true,
             /* .keep_trailing_newline = */ false,
         });
+        _supports_tools = source.find("tools") != std::string::npos;
+        _requires_object_arguments =
+            source.find("tool_call.arguments | items") != std::string::npos
+            || source.find("tool_call.arguments | tojson") != std::string::npos;
+        _supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;
+
+        _supports_system_role = renders_needles({"<System Needle>"}, {
+            {{"role", "system"}, {"content", "<System Needle>"}},
+            {{"role", "user"}, {"content", "Hey"}}
+        }, {}, false);
     }
 
     const std::string & source() const { return _source; }
     bool supports_tools() const { return _supports_tools; }
+    bool supports_parallel_tool_calls() const { return _supports_parallel_tool_calls; }
 
     std::string apply(
         const nlohmann::ordered_json & messages,
@@ -57,11 +83,13 @@ class chat_template {
         bool add_generation_prompt,
         const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
     {
-        auto actual_messages = messages;
+        json actual_messages;
 
         // First, "fix" messages so they have a chance to be rendered correctly by the template
 
-        if (_requires_object_arguments || !_supports_system_role) {
+        if (_requires_object_arguments || !_supports_system_role || !_supports_tools) {
+            actual_messages = json::array();
+
             std::string pending_system;
             auto flush_sys = [&]() {
                 if (!pending_system.empty()) {
@@ -72,12 +100,66 @@ class chat_template {
                     pending_system.clear();
                 }
             };
-            for (auto & message : actual_messages) {
+            for (const auto & message_ : messages) {
+                auto message = message_;
                 if (!message.contains("role") || !message.contains("content")) {
                     throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
                 }
                 std::string role = message.at("role");
 
+                if (message.contains("tool_calls")) {
+                    if (_requires_object_arguments || !_supports_tools) {
+                        for (auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call["type"] == "function") {
+                                auto & function = tool_call.at("function");
+                                std::string arguments = function.at("arguments");
+                                function["arguments"] = json::parse(arguments);
+                            }
+                        }
+                    }
+                    if (!_supports_tools) {
+                        auto content = message.at("content");
+                        auto tool_calls = json::array();
+                        for (const auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call.at("type") != "function") {
+                                continue;
+                            }
+                            const auto & function = tool_call.at("function");
+                            auto tc = json {
+                                {"name", function.at("name")},
+                                {"arguments", function.at("arguments")},
+                            };
+                            if (tool_call.contains("id")) {
+                                tc["id"] = tool_call["id"];
+                            }
+                            tool_calls.push_back(tc);
+                        }
+                        auto obj = json {
+                            {"tool_calls", tool_calls},
+                        };
+                        if (!content.is_null() && content != "") {
+                            obj["content"] = content;
+                        }
+                        message["content"] = obj.dump(2);
+                        message.erase("tool_calls");
+                    }
+                }
+                if (!_supports_tools && role == "tool") {
+                    message["role"] = "user";
+                    auto obj = json {
+                        {"tool_response", {
+                            {"tool", message.at("name")},
+                            {"content", message.at("content")},
+                        }},
+                    };
+                    if (message.contains("tool_call_id")) {
+                        obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
+                    }
+                    message["content"] = obj.dump(2);
+                    message.erase("name");
+                }
+
+                // std::string content = message["content"];
                 if (!message["content"].is_null() && !_supports_system_role) {
                     std::string content = message.at("content");
                     if (role == "system") {
@@ -95,17 +177,11 @@ class chat_template {
                     }
                 }
             }
-            if (_requires_object_arguments && message.contains("tool_calls")) {
-                for (auto & tool_call : message.at("tool_calls")) {
-                    if (tool_call["type"] == "function") {
-                        auto & function = tool_call.at("function");
-                        std::string arguments = function.at("arguments");
-                        function["arguments"] = json::parse(arguments);
-                    }
-                }
-            }
+                actual_messages.push_back(message);
             }
             flush_sys();
+        } else {
+            actual_messages = messages;
         }
 
         auto context = minja::Context::make(json({
@@ -130,4 +206,4 @@ class chat_template {
     }
 };
 
-} // namespace minja
+} // namespace minja

From 4b835addce9796a7f59981c7ed3b72b138808939 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sun, 29 Dec 2024 19:53:56 +0000
Subject: [PATCH 2/3] update field naming (_ suffix not prefix)

---
 include/minja/chat-template.hpp | 50 ++++++++++++++++-----------------
 include/minja/minja.hpp         |  4 +--
 2 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp
index d9e3e8c..c4aef40 100644
--- a/include/minja/chat-template.hpp
+++ b/include/minja/chat-template.hpp
@@ -21,16 +21,16 @@ class chat_template {
 
 public:
 
 private:
-    bool _supports_tools = true;
+    bool supports_tools_ = true;
     // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
     // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
-    bool _requires_object_arguments = false;
-    bool _supports_system_role = true;
-    bool _supports_parallel_tool_calls = false;
-    std::string _source;
-    std::string _bos_token;
-    std::string _eos_token;
-    std::shared_ptr<minja::TemplateNode> _template_root;
+    bool requires_object_arguments_ = false;
+    bool supports_system_role_ = true;
+    bool supports_parallel_tool_calls_ = false;
+    std::string source_;
+    std::string bos_token_;
+    std::string eos_token_;
+    std::shared_ptr<minja::TemplateNode> template_root_;
 
     bool renders_needles(
         const std::vector<std::string> & needles,
@@ -54,28 +54,28 @@ class chat_template {
 public:
     chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
-        : _source(source), _bos_token(bos_token), _eos_token(eos_token)
+        : source_(source), bos_token_(bos_token), eos_token_(eos_token)
     {
-        _template_root = minja::Parser::parse(_source, {
+        template_root_ = minja::Parser::parse(source_, {
             /* .trim_blocks = */ true,
             /* .lstrip_blocks = */ true,
             /* .keep_trailing_newline = */ false,
         });
-        _supports_tools = source.find("tools") != std::string::npos;
-        _requires_object_arguments =
+        supports_tools_ = source.find("tools") != std::string::npos;
+        requires_object_arguments_ =
             source.find("tool_call.arguments | items") != std::string::npos
             || source.find("tool_call.arguments | tojson") != std::string::npos;
-        _supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;
+        supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
 
-        _supports_system_role = renders_needles({"<System Needle>"}, {
+        supports_system_role_ = renders_needles({"<System Needle>"}, {
             {{"role", "system"}, {"content", "<System Needle>"}},
             {{"role", "user"}, {"content", "Hey"}}
         }, {}, false);
     }
 
-    const std::string & source() const { return _source; }
-    bool supports_tools() const { return _supports_tools; }
-    bool supports_parallel_tool_calls() const { return _supports_parallel_tool_calls; }
+    const std::string & source() const { return source_; }
+    bool supports_tools() const { return supports_tools_; }
+    bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
 
     std::string apply(
         const nlohmann::ordered_json & messages,
@@ -87,7 +87,7 @@ class chat_template {
 
         // First, "fix" messages so they have a chance to be rendered correctly by the template
 
-        if (_requires_object_arguments || !_supports_system_role || !_supports_tools) {
+        if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
             actual_messages = json::array();
 
             std::string pending_system;
@@ -108,7 +108,7 @@ class chat_template {
                 std::string role = message.at("role");
 
                 if (message.contains("tool_calls")) {
-                    if (_requires_object_arguments || !_supports_tools) {
+                    if (requires_object_arguments_ || !supports_tools_) {
                         for (auto & tool_call : message.at("tool_calls")) {
                             if (tool_call["type"] == "function") {
                                 auto & function = tool_call.at("function");
@@ -117,7 +117,7 @@ class chat_template {
                             }
                         }
                    }
-                    if (!_supports_tools) {
+                    if (!supports_tools_) {
                        auto content = message.at("content");
                         auto tool_calls = json::array();
                         for (const auto & tool_call : message.at("tool_calls")) {
@@ -144,7 +144,7 @@ class chat_template {
                         message.erase("tool_calls");
                     }
                 }
-                if (!_supports_tools && role == "tool") {
+                if (!supports_tools_ && role == "tool") {
                     message["role"] = "user";
                     auto obj = json {
                         {"tool_response", {
@@ -160,7 +160,7 @@ class chat_template {
                 }
 
                 // std::string content = message["content"];
-                if (!message["content"].is_null() && !_supports_system_role) {
+                if (!message["content"].is_null() && !supports_system_role_) {
                     std::string content = message.at("content");
                     if (role == "system") {
                         if (!pending_system.empty()) pending_system += "\n";
@@ -187,8 +187,8 @@ class chat_template {
 
         auto context = minja::Context::make(json({
             {"messages", actual_messages},
             {"add_generation_prompt", add_generation_prompt},
-            {"bos_token", _bos_token},
-            {"eos_token", _eos_token},
+            {"bos_token", bos_token_},
+            {"eos_token", eos_token_},
         }));
 
         if (!tools.is_null()) {
@@ -202,7 +202,7 @@ class chat_template {
             }
         }
 
-        return _template_root->render(context);
+        return template_root_->render(context);
     }
 };
 
diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp
index c5472a0..26f20fd 100644
--- a/include/minja/minja.hpp
+++ b/include/minja/minja.hpp
@@ -1009,7 +1009,7 @@ class FilterNode : public TemplateNode {
             throw std::runtime_error("Filter must be a callable: " + filter_value.dump());
         }
         std::string rendered_body = body->render(context);
-        
+
         ArgumentsValue filter_args = {{Value(rendered_body)}, {}};
         auto result = filter_value.call(context, filter_args);
         out << result.to_str();
@@ -1181,7 +1181,7 @@ class UnaryOpExpr : public Expression {
             case Op::Expansion:
             case Op::ExpansionDict:
                 throw std::runtime_error("Expansion operator is only supported in function calls and collections");
-            
+
         }
         throw std::runtime_error("Unknown unary operator");
     }

From bbc02d68f9db4603fe20d9c78011a793bbd85bd6 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sun, 29 Dec 2024 23:46:01 +0000
Subject: [PATCH 3/3] Robust detection of system support & tool arguments variants

---
 include/minja/chat-template.hpp        | 68 ++++++++++++++++++------
 include/minja/minja.hpp                |  2 +-
 scripts/fetch_templates_and_goldens.py | 73 ++++++++++++++++++++------
 scripts/run_tests.sh                   |  4 +-
 tests/CMakeLists.txt                   |  2 +-
 5 files changed, 115 insertions(+), 34 deletions(-)

diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp
index c4aef40..aea5d36 100644
--- a/include/minja/chat-template.hpp
+++ b/include/minja/chat-template.hpp
@@ -32,8 +32,7 @@ class chat_template {
     std::string eos_token_;
     std::shared_ptr<minja::TemplateNode> template_root_;
 
-    bool renders_needles(
-        const std::vector<std::string> & needles,
+    std::string try_render(
         const nlohmann::ordered_json & messages,
         const nlohmann::ordered_json & tools,
         bool add_generation_prompt,
@@ -41,14 +40,11 @@ class chat_template {
     {
         try {
             auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
-            for (const auto & needle : needles) {
-                if (prompt.find(needle) == std::string::npos) {
-                    return false;
-                }
-            }
-            return true;
+            // fprintf(stderr, "Prompt: %s\n", prompt.c_str());
+            return prompt;
         } catch (const std::exception & e) {
-            return false;
+            // fprintf(stderr, "Error: %s\n", e.what());
+            return "";
         }
     }
 
@@ -62,15 +58,58 @@ class chat_template {
             /* .keep_trailing_newline = */ false,
         });
         supports_tools_ = source.find("tools") != std::string::npos;
-        requires_object_arguments_ =
-            source.find("tool_call.arguments | items") != std::string::npos
-            || source.find("tool_call.arguments | tojson") != std::string::npos;
+
+        auto renders_string_arguments =
+            try_render({
+                {
+                    {"role", "user"},
+                    {"content", "Hey"}
+                },
+                {
+                    {"role", "assistant"},
+                    {"tool_calls", json::array({
+                        {
+                            {"id", "call_1___"},
+                            {"type", "function"},
+                            {"function", {
+                                {"arguments", "{\"code\": \"print('Hello, World!')\"}"},
+                                {"name", "ipython"},
+                            }},
+                        },
+                    })},
+                }
+            }, {}, false).find("{\"code\": \"print") != std::string::npos;
+        if (!renders_string_arguments) {
+            auto renders_object_arguments =
+                try_render({
+                    {
+                        {"role", "user"},
+                        {"content", "Hey"}
+                    },
+                    {
+                        {"role", "assistant"},
+                        {"tool_calls", json::array({
+                            {
+                                {"id", "call_1___"},
+                                {"type", "function"},
+                                {"function", {
+                                    {"arguments", {
+                                        {"code", "print('Hello, World!')"},
+                                    }},
+                                    {"name", "ipython"},
+                                }},
+                            },
+                        })},
+                    }
+                }, {}, false).find("{\"code\": \"print") != std::string::npos;
+            requires_object_arguments_ = renders_object_arguments;
+        }
         supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
 
-        supports_system_role_ = renders_needles({"<System Needle>"}, {
+        supports_system_role_ = try_render({
             {{"role", "system"}, {"content", "<System Needle>"}},
             {{"role", "user"}, {"content", "Hey"}}
-        }, {}, false);
+        }, {}, false).find("<System Needle>") != std::string::npos;
     }
 
     const std::string & source() const { return source_; }
@@ -159,7 +198,6 @@ class chat_template {
                     message.erase("name");
                 }
 
-                // std::string content = message["content"];
                 if (!message["content"].is_null() && !supports_system_role_) {
                     std::string content = message.at("content");
                     if (role == "system") {
diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp
index 26f20fd..9d9a1a0 100644
--- a/include/minja/minja.hpp
+++ b/include/minja/minja.hpp
@@ -2557,7 +2557,7 @@ inline std::shared_ptr<Context> Context::builtins() {
         return (int64_t) items.size();
     }));
     globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
-        return args.at("value");
+        return args.at("value").to_str();
     }));
     globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
         return args.at("value").to_str();
diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py
index d4fce98..e8beaa6 100644
--- a/scripts/fetch_templates_and_goldens.py
+++ b/scripts/fetch_templates_and_goldens.py
@@ -11,16 +11,17 @@
 All files are written to the specified output folder.
 
 Usage:
-    python tests/fetch_templates_and_goldens.py output_folder context_file1.json context_file2.json ... model_id1 model_id2 ...
+    python scripts/fetch_templates_and_goldens.py output_folder context_file1.json context_file2.json ... model_id1 model_id2 ...
 
 Example:
     pip install -r requirements.txt
-    python tests/fetch_templates_and_goldens.py ./test_files tests/contexts/*.json microsoft/Phi-3-medium-4k-instruct Qwen/Qwen2-7B-Instruct
+    python scripts/fetch_templates_and_goldens.py ./test_files tests/contexts/*.json mistralai/Mistral-Large-Instruct-2407 meetkai/functionary-medium-v3.1.jinja microsoft/Phi-3-medium-4k-instruct Qwen/Qwen2-7B-Instruct
 '''
 
 import logging
 import datetime
 import os
+import sys
 from huggingface_hub import hf_hub_download
 import json
 import jinja2
@@ -37,9 +38,8 @@ def raise_exception(message: str):
     raise ValueError(message)
 
 
-def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
-    return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
-
+def tojson(eval_ctx, value, indent=None):
+    return json.dumps(value, indent=indent)
 
 TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')
 
@@ -72,16 +72,61 @@ def handle_chat_template(output_folder, model_id, variant, template_src, context_files):
         lstrip_blocks=True,
         extensions=[jinja2.ext.loopcontrols]
     )
+    template = env.from_string(template_src)
+
+    env.filters['safe'] = lambda x: x
     env.filters['tojson'] = tojson
     env.globals['raise_exception'] = raise_exception
     env.globals['strftime_now'] = strftime_now
 
     template_handles_tools = 'tools' in template_src
-    template_hates_the_system = 'System role not supported' in template_src
-
-    template = env.from_string(template_src)
+
+    def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]):
+        try:
+            prompt = template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **extra_context)
+            for str in expect_strings:
+                if str not in prompt:
+                    # print(f"Expected string not found: {str}\nin prompt:\n{prompt}", file=sys.stderr, flush=True)
+                    return False
+            return True
+        except Exception as e:
+            # print(f"Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
+            return False
+
+    basic_extra_context = {
+        "bos_token": "<|startoftext|>",
+        "eos_token": "<|endoftext|>",
+    }
+
+    renders_string_arguments = renders([
+        {"role": "user", "content": "Hey"},
+        {"role": "assistant", "tool_calls": [{
+            "id": "call_1___",
+            "type": "function",
+            "function": {
+                "arguments": "{\"code\": \"print('Hello, World!')\"}",
+                "name": "ipython"
+            }
+        }]}
+    ], extra_context=basic_extra_context, expect_strings=[r'{"code": "print'])
+    renders_object_arguments = renders([
+        {"role": "user", "content": "Hey"},
+        {"role": "assistant", "tool_calls": [{
+            "id": "call_1___",
+            "type": "function",
+            "function": {
+                "arguments": {"code": "print('Hello, World!')"},
+                "name": "ipython"
+            }
+        }]}
+    ], extra_context=basic_extra_context, expect_strings=[r'{"code": "print'])
+    requires_object_arguments = not renders_string_arguments and renders_object_arguments
+
+    supports_system_role = renders([
+        {"role": "system", "content": "System Needle"},
+        {"role": "user", "content": "Hey"}
+    ], extra_context=basic_extra_context, expect_strings=["System Needle"])
+
     for context_file in context_files:
         context_name = os.path.basename(context_file).replace(".json", "")
         with open(context_file, 'r') as f:
@@ -90,15 +135,13 @@ def handle_chat_template(output_folder, model_id, variant, template_src, context_files):
         if not template_handles_tools and 'tools' in context:
             continue
 
-        if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']):
+        if not supports_system_role and any(m['role'] == 'system' for m in context['messages']):
             continue
 
         output_file = join_cmake_path(output_folder, f'{base_name}-{context_name}.txt')
 
-        render_context = json.loads(json.dumps(context))
-
-        if 'tool_call.arguments | items' in template_src or 'tool_call.arguments | tojson' in template_src:
-            for message in render_context['messages']:
+        if requires_object_arguments:
+            for message in context['messages']:
                 if 'tool_calls' in message:
                     for tool_call in message['tool_calls']:
                         if tool_call.get('type') == 'function':
@@ -106,14 +149,14 @@ def handle_chat_template(output_folder, model_id, variant, template_src, context_files):
                             tool_call['function']['arguments'] = json.loads(arguments)
 
         try:
-            output = template.render(**render_context)
+            output = template.render(**context)
        except Exception as e1:
             for message in context["messages"]:
                 if message.get("content") is None:
                     message["content"] = ""
 
             try:
-                output = template.render(**render_context)
+                output = template.render(**context)
             except Exception as e2:
                 logger.info(f"  ERROR: {e2} (after first error: {e1})")
                 output = f"ERROR: {e2}"
diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh
index a3efbdd..3965a01 100755
--- a/scripts/run_tests.sh
+++ b/scripts/run_tests.sh
@@ -13,6 +13,6 @@
 #
 set -euo pipefail
 
-cmake -B build "$@" -DCMAKE_BUILD_TYPE=Release && \
+cmake -B build -DCMAKE_BUILD_TYPE=Release && \
     cmake --build build -j --config Release && \
-    ctest --test-dir build -j -C Release --output-on-failure
+    ctest --test-dir build -j -C Release --output-on-failure "$@"
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 0adab92..e9a3f04 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -43,7 +43,6 @@ set(MODEL_IDS
     google/gemma-7b-it # Gated
     indischepartij/MiniCPM-3B-OpenHermes-2.5-v2
     mattshumer/Reflection-Llama-3.1-70B
-    meetkai/functionary-medium-v3.1
     meetkai/functionary-medium-v3.2
     meta-llama/Llama-3.2-3B-Instruct # Gated
     meta-llama/Meta-Llama-3.1-8B-Instruct # Gated
@@ -74,6 +73,7 @@ set(MODEL_IDS
     TheBloke/FusionNet_34Bx2_MoE-AWQ
 
     # Broken, TODO:
+    # meetkai/functionary-medium-v3.1 # jinja2 expectation is computed w/ wrong escapes
     # fireworks-ai/llama-3-firefunction-v2 # https://github.com/google/minja/issues/7
     # ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8
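---

Usage note: the sketch below shows how a consumer might drive the probing API this series introduces. It is hypothetical, not part of the patches: the toy template string and the "<s>"/"</s>" tokens are placeholders for what would normally come from a model's tokenizer_config.json or GGUF metadata, and it assumes the minja headers (which pull in nlohmann::ordered_json) are on the include path.

// chat_template_demo.cpp (hypothetical consumer of minja::chat_template)
#include <minja/chat-template.hpp>

#include <iostream>
#include <string>

int main() {
    // Toy Jinja template: renders every message, never mentions "tools".
    std::string source =
        "{{ bos_token }}{% for m in messages %}"
        "{{ m['role'] }}: {{ m['content'] }}\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}assistant:{% endif %}";

    minja::chat_template tmpl(source, /* bos_token= */ "<s>", /* eos_token= */ "</s>");

    // The constructor probes the template (substring sniffing in patches 1-2,
    // trial renders via try_render in patch 3), so callers can branch on what
    // the template actually supports before building a prompt.
    std::cout << "supports_tools: " << tmpl.supports_tools() << "\n"
              << "supports_parallel_tool_calls: " << tmpl.supports_parallel_tool_calls() << "\n";

    nlohmann::ordered_json messages = nlohmann::ordered_json::array({
        // If the template had no system role, apply() would fold this turn
        // into the following user message rather than dropping it.
        {{"role", "system"}, {"content", "Be terse."}},
        {{"role", "user"}, {"content", "Hey"}},
    });

    // A null json for tools makes apply() omit the "tools" context variable.
    std::cout << tmpl.apply(messages, nlohmann::ordered_json(),
                            /* add_generation_prompt= */ true);
    return 0;
}

The same branching applies to tool-call messages: when supports_tools() is false, apply() serializes "tool_calls" into the message content, so the caller never needs to special-case such templates itself.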