Skip to content

Commit b5b8fa1

Browse files
authored
chat : fix translategemma crash on common_chat_format_example (ggml-org#19019)
1 parent a14b960 commit b5b8fa1

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

common/chat.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2650,6 +2650,45 @@ static common_chat_params common_chat_params_init_exaone_moe(const common_chat_t
26502650
return data;
26512651
}
26522652

2653+
static common_chat_params common_chat_params_init_translate_gemma(const common_chat_template & tmpl, const struct templates_params & inputs) {
2654+
common_chat_params data;
2655+
2656+
// This template does not support tools or reasoning
2657+
// we just need to transform the messages into the correct schema
2658+
2659+
templates_params inputs_new = inputs;
2660+
json & messages = inputs_new.messages;
2661+
2662+
GGML_ASSERT(messages.is_array());
2663+
for (auto & message : messages) {
2664+
if (message.contains("role") && message["role"].get<std::string>() != "user") {
2665+
continue;
2666+
}
2667+
if (!message.contains("content")) {
2668+
message["content"] = json::array();
2669+
}
2670+
if (message.contains("content") && !message["content"].is_array()) {
2671+
auto content_str = message["content"].get<std::string>();
2672+
// default to en-GB if not specified (to make common_chat_format_example works)
2673+
auto src_lang = message.contains("source_lang_code") ? message["source_lang_code"].get<std::string>() : "en-GB";
2674+
auto tgt_lang = message.contains("target_lang_code") ? message["target_lang_code"].get<std::string>() : "en-GB";
2675+
message["content"] = json::array({
2676+
json{
2677+
{"type", "text"},
2678+
{"text", content_str},
2679+
{"source_lang_code", src_lang},
2680+
{"target_lang_code", tgt_lang},
2681+
}
2682+
});
2683+
}
2684+
}
2685+
2686+
data.prompt = apply(tmpl, inputs_new, std::nullopt, std::nullopt);
2687+
data.format = COMMON_CHAT_FORMAT_GENERIC;
2688+
2689+
return data;
2690+
}
2691+
26532692
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
26542693
common_chat_params data;
26552694
data.prompt = apply(tmpl, inputs);
@@ -3045,6 +3084,12 @@ static common_chat_params common_chat_templates_apply_jinja(
30453084
return common_chat_params_init_solar_open(tmpl, params);
30463085
}
30473086

3087+
// TranslateGemma
3088+
if (src.find("[source_lang_code]") != std::string::npos &&
3089+
src.find("[target_lang_code]") != std::string::npos) {
3090+
return common_chat_params_init_translate_gemma(tmpl, params);
3091+
}
3092+
30483093
// Plain handler (no tools)
30493094
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
30503095
return common_chat_params_init_without_tools(tmpl, params);

0 commit comments

Comments
 (0)