Skip to content

Commit 51fa458

Browse files
authored
server : support preserving reasoning_content in assistant message (ggml-org#18994)
* support reasoning_content input
* report template caps to webui
* add docs
* remove commented code
1 parent a5eaa1d commit 51fa458

10 files changed

Lines changed: 164 additions & 130 deletions

File tree

common/chat-parser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,7 +1630,7 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
16301630
}
16311631
auto msg = builder.result();
16321632
if (!is_partial) {
1633-
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
1633+
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
16341634
}
16351635
return msg;
16361636
}
@@ -1663,7 +1663,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std
16631663
mapper.from_ast(ctx.ast, result);
16641664
}
16651665
if (!is_partial) {
1666-
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
1666+
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
16671667
}
16681668
return msg;
16691669
}

common/chat.cpp

Lines changed: 71 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
#include "log.h"
88
#include "regex-partial.h"
99

10-
// #include <minja/chat-template.hpp>
11-
// #include <minja/minja.hpp>
12-
1310
#include "jinja/parser.h"
1411
#include "jinja/value.h"
1512
#include "jinja/runtime.h"
@@ -56,39 +53,73 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) {
5653
return !msg.content.empty() || !msg.tool_calls.empty();
5754
}
5855

59-
template <>
60-
json common_chat_msg::to_json_oaicompat() const
61-
{
62-
json message {
63-
{"role", "assistant"},
56+
json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
57+
if (!content.empty() && !content_parts.empty()) {
58+
throw std::runtime_error("Cannot specify both content and content_parts");
59+
}
60+
json jmsg {
61+
{"role", role},
6462
};
63+
if (!content.empty()) {
64+
jmsg["content"] = content;
65+
} else if (!content_parts.empty()) {
66+
if (concat_typed_text) {
67+
std::string text;
68+
for (const auto & part : content_parts) {
69+
if (part.type != "text") {
70+
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
71+
continue;
72+
}
73+
if (!text.empty()) {
74+
text += '\n';
75+
}
76+
text += part.text;
77+
}
78+
jmsg["content"] = text;
79+
} else {
80+
auto & parts = jmsg["content"] = json::array();
81+
for (const auto & part : content_parts) {
82+
parts.push_back({
83+
{"type", part.type},
84+
{"text", part.text},
85+
});
86+
}
87+
}
88+
} else {
89+
jmsg["content"] = "";
90+
}
6591
if (!reasoning_content.empty()) {
66-
message["reasoning_content"] = reasoning_content;
92+
jmsg["reasoning_content"] = reasoning_content;
6793
}
68-
if (content.empty() && !tool_calls.empty()) {
69-
message["content"] = json();
70-
} else {
71-
message["content"] = content;
94+
if (!tool_name.empty()) {
95+
jmsg["name"] = tool_name;
96+
}
97+
if (!tool_call_id.empty()) {
98+
jmsg["tool_call_id"] = tool_call_id;
7299
}
73100
if (!tool_calls.empty()) {
74-
auto arr = json::array();
75-
for (const auto & tc : tool_calls) {
76-
arr.push_back({
101+
jmsg["tool_calls"] = json::array();
102+
auto & jtool_calls = jmsg["tool_calls"];
103+
for (const auto & tool_call : tool_calls) {
104+
json tc {
77105
{"type", "function"},
78106
{"function", {
79-
{"name", tc.name},
80-
{"arguments", tc.arguments},
107+
{"name", tool_call.name},
108+
{"arguments", tool_call.arguments},
81109
}},
82-
{"id", tc.id},
83-
// // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
84-
// // We only generate a random id for the ones that don't generate one by themselves
85-
// // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
86-
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
87-
});
110+
};
111+
if (!tool_call.id.empty()) {
112+
tc["id"] = tool_call.id;
113+
}
114+
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
115+
// We only generate a random id for the ones that don't generate one by themselves
116+
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
117+
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
118+
jtool_calls.push_back(tc);
88119
}
89-
message["tool_calls"] = arr;
90120
}
91-
return message;
121+
122+
return jmsg;
92123
}
93124

94125
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
@@ -256,7 +287,6 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
256287
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
257288
}
258289

259-
template <>
260290
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
261291
std::vector<common_chat_msg> msgs;
262292

@@ -350,80 +380,15 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
350380
return msgs;
351381
}
352382

353-
template <>
354383
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
355384
json messages = json::array();
356385
for (const auto & msg : msgs) {
357-
if (!msg.content.empty() && !msg.content_parts.empty()) {
358-
throw std::runtime_error("Cannot specify both content and content_parts");
359-
}
360-
json jmsg {
361-
{"role", msg.role},
362-
};
363-
if (!msg.content.empty()) {
364-
jmsg["content"] = msg.content;
365-
} else if (!msg.content_parts.empty()) {
366-
if (concat_typed_text) {
367-
std::string text;
368-
for (const auto & part : msg.content_parts) {
369-
if (part.type != "text") {
370-
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
371-
continue;
372-
}
373-
if (!text.empty()) {
374-
text += '\n';
375-
}
376-
text += part.text;
377-
}
378-
jmsg["content"] = text;
379-
} else {
380-
auto & parts = jmsg["content"] = json::array();
381-
for (const auto & part : msg.content_parts) {
382-
parts.push_back({
383-
{"type", part.type},
384-
{"text", part.text},
385-
});
386-
}
387-
}
388-
} else {
389-
jmsg["content"] = "";
390-
}
391-
if (!msg.reasoning_content.empty()) {
392-
jmsg["reasoning_content"] = msg.reasoning_content;
393-
}
394-
if (!msg.tool_name.empty()) {
395-
jmsg["name"] = msg.tool_name;
396-
}
397-
if (!msg.tool_call_id.empty()) {
398-
jmsg["tool_call_id"] = msg.tool_call_id;
399-
}
400-
if (!msg.tool_calls.empty()) {
401-
auto & tool_calls = jmsg["tool_calls"] = json::array();
402-
for (const auto & tool_call : msg.tool_calls) {
403-
json tc {
404-
{"type", "function"},
405-
{"function", {
406-
{"name", tool_call.name},
407-
{"arguments", tool_call.arguments},
408-
}},
409-
};
410-
if (!tool_call.id.empty()) {
411-
tc["id"] = tool_call.id;
412-
}
413-
tool_calls.push_back(tc);
414-
}
415-
}
386+
json jmsg = msg.to_json_oaicompat(concat_typed_text);
416387
messages.push_back(jmsg);
417388
}
418389
return messages;
419390
}
420391

421-
template <>
422-
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
423-
return common_chat_msgs_parse_oaicompat(json::parse(messages));
424-
}
425-
426-
template <>
427392
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
428393
std::vector<common_chat_tool> result;
429394

@@ -459,12 +424,6 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
459424
return result;
460425
}
461426

462-
template <>
463-
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
464-
return common_chat_tools_parse_oaicompat(json::parse(tools));
465-
}
466-
467-
template <>
468427
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
469428
if (tools.empty()) {
470429
return json();
@@ -484,7 +443,7 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
484443
return result;
485444
}
486445

487-
template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
446+
json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
488447
json delta = json::object();
489448
if (!diff.reasoning_content_delta.empty()) {
490449
delta["reasoning_content"] = diff.reasoning_content_delta;
@@ -2867,13 +2826,13 @@ static common_chat_params common_chat_templates_apply_jinja(
28672826
const struct common_chat_templates_inputs & inputs)
28682827
{
28692828
templates_params params;
2870-
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
2829+
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
28712830
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
28722831
? *tmpls->template_tool_use
28732832
: *tmpls->template_default;
28742833
const auto & src = tmpl.source();
28752834
const auto & caps = tmpl.original_caps();
2876-
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
2835+
params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
28772836
params.add_generation_prompt = inputs.add_generation_prompt;
28782837
params.tool_choice = inputs.tool_choice;
28792838
params.reasoning_format = inputs.reasoning_format;
@@ -2943,6 +2902,10 @@ static common_chat_params common_chat_templates_apply_jinja(
29432902
src.find("<arg_value>") != std::string::npos &&
29442903
params.json_schema.is_null()) {
29452904
workaround::func_args_not_string(params.messages);
2905+
if (!params.extra_context.contains("clear_thinking")) {
2906+
// by default, do not clear reasoning_content (added since GLM-4.7)
2907+
params.extra_context["clear_thinking"] = false;
2908+
}
29462909
return common_chat_params_init_glm_4_5(tmpl, params);
29472910
}
29482911

@@ -3174,3 +3137,9 @@ common_chat_params common_chat_templates_apply(
31743137
? common_chat_templates_apply_jinja(tmpls, inputs)
31753138
: common_chat_templates_apply_legacy(tmpls, inputs);
31763139
}
3140+
3141+
// Returns the capability flags of the default chat template as a
// name -> enabled map.
//
// Both chat_templates and its template_default must be non-null; this is
// enforced with GGML_ASSERT.
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
    GGML_ASSERT(chat_templates != nullptr);
    GGML_ASSERT(chat_templates->template_default != nullptr);

    const auto & default_caps = chat_templates->template_default->caps;
    return default_caps.to_map();
}

common/chat.h

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <vector>
1111
#include <map>
1212

13+
#include <nlohmann/json_fwd.hpp>
14+
1315
struct common_chat_templates;
1416

1517
struct common_chat_tool_call {
@@ -26,6 +28,11 @@ struct common_chat_msg_content_part {
2628
std::string type;
2729
std::string text;
2830

31+
// TODO @ngxson : no known chat templates support reasoning_content in content parts yet
32+
// this can be useful for models with interleaved thinking (like Kimi-K2)
33+
// if you see any templates explicitly support this, please ping me
34+
// std::string reasoning_content;
35+
2936
bool operator==(const common_chat_msg_content_part & other) const {
3037
return type == other.type && text == other.text;
3138
}
@@ -40,7 +47,7 @@ struct common_chat_msg {
4047
std::string tool_name;
4148
std::string tool_call_id;
4249

43-
template <class T> T to_json_oaicompat() const;
50+
nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
4451

4552
bool empty() const {
4653
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
@@ -232,13 +239,13 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
232239
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
233240

234241
// Parses a JSON array of messages in OpenAI's chat completion API format.
235-
// T can be std::string containing JSON or nlohmann::ordered_json
236-
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
237-
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
242+
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
243+
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
244+
245+
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
246+
nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
238247

239-
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
240-
// T can be std::string containing JSON or nlohmann::ordered_json
241-
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
242-
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
248+
nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
243249

244-
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
250+
// get template caps, useful for reporting to server /props endpoint
251+
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);

common/jinja/caps.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,23 @@ static void caps_print_stats(value & v, const std::string & path) {
6161
ops.c_str());
6262
}
6363

64+
std::map<std::string, bool> caps::to_map() const {
65+
return {
66+
{"requires_typed_content", requires_typed_content},
67+
{"supports_tools", supports_tools},
68+
{"supports_tool_calls", supports_tool_calls},
69+
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
70+
{"supports_system_role", supports_system_role},
71+
{"supports_preserve_reasoning", supports_preserve_reasoning},
72+
};
73+
}
74+
6475
std::string caps::to_string() const {
6576
std::ostringstream ss;
6677
ss << "Caps(\n";
67-
ss << " requires_typed_content=" << requires_typed_content << "\n";
68-
ss << " supports_tools=" << supports_tools << "\n";
69-
ss << " supports_tool_calls=" << supports_tool_calls << "\n";
70-
ss << " supports_parallel_tool_calls=" << supports_parallel_tool_calls << "\n";
71-
ss << " supports_system_role=" << supports_system_role << "\n";
78+
for (const auto & [key, value] : to_map()) {
79+
ss << " " << key << "=" << (value ? "true" : "false") << "\n";
80+
}
7281
ss << ")";
7382
return ss.str();
7483
}
@@ -229,6 +238,40 @@ caps caps_get(jinja::program & prog) {
229238
}
230239
);
231240

241+
// case: preserve reasoning content in chat history
242+
caps_try_execute(
243+
prog,
244+
[&]() {
245+
// messages
246+
return json::array({
247+
{
248+
{"role", "user"},
249+
{"content", "User message"}
250+
},
251+
{
252+
{"role", "assistant"},
253+
{"content", "Assistant message"},
254+
{"reasoning_content", "Reasoning content"}
255+
},
256+
{
257+
{"role", "user"},
258+
{"content", "User message"}
259+
},
260+
});
261+
},
262+
[&]() {
263+
// tools
264+
return json::array();
265+
},
266+
[&](bool, value & messages, value &) {
267+
auto & content = messages->at(1)->at("reasoning_content");
268+
caps_print_stats(content, "messages[1].reasoning_content");
269+
if (content->stats.used) {
270+
result.supports_preserve_reasoning = true;
271+
}
272+
}
273+
);
274+
232275
JJ_DEBUG("%s\n", result.to_string().c_str());
233276

234277
return result;

0 commit comments

Comments
 (0)