|
2 | 2 | from dataclasses import dataclass, field |
3 | 3 | from functools import cached_property |
4 | 4 | from itertools import takewhile |
| 5 | +import json |
5 | 6 | import math |
6 | 7 | import random |
7 | 8 | from typing import Any, Generator, cast |
@@ -31,6 +32,40 @@ def _normalize_tools_for_chat_template(tools: Any) -> list[ChatTemplateTool] | N |
31 | 32 | return normalized_tools |
32 | 33 |
|
33 | 34 |
|
def _normalize_tool_call_arguments_for_chat_template(
    tokenizer: PreTrainedTokenizerBase,
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Parse tool-call arguments into mappings when the chat template needs them.

    Chat templates that iterate ``tool_call.arguments|items`` require each
    tool call's ``function["arguments"]`` to be a mapping, whereas
    OpenAI-style messages carry it as a JSON-encoded string.

    Args:
        tokenizer: Tokenizer whose ``chat_template`` is inspected for the
            ``tool_call.arguments|items`` marker.
        messages: Chat messages; any message may carry a ``tool_calls`` list.

    Returns:
        ``messages`` itself (same object) when the tokenizer has no chat
        template or the template does not contain the marker. Otherwise a
        new list in which every message with ``tool_calls`` is replaced by a
        shallow copy whose arguments are dicts. Messages without
        ``tool_calls`` are passed through as the same objects, and the input
        messages are never mutated.

    Raises:
        json.JSONDecodeError: If a non-blank arguments string is not valid
            JSON.
        AssertionError: If the message structure is not the expected
            list-of-dicts shape or the arguments do not decode to a dict.
    """
    chat_template = tokenizer.chat_template
    if chat_template is None:
        # No template configured, so nothing can require mapping arguments.
        return messages
    assert isinstance(chat_template, str)
    if "tool_call.arguments|items" not in chat_template:
        return messages

    normalized_messages: list[dict[str, Any]] = []
    for message in messages:
        tool_calls = message.get("tool_calls")
        if tool_calls is None:
            normalized_messages.append(message)
            continue

        assert isinstance(tool_calls, list)
        normalized_tool_calls: list[dict[str, Any]] = []
        for tool_call in tool_calls:
            assert isinstance(tool_call, dict)
            function = tool_call["function"]
            assert isinstance(function, dict)
            arguments = function["arguments"]
            if isinstance(arguments, str):
                # A blank string is treated as "no arguments" rather than
                # being handed to json.loads, which would reject it.
                arguments = json.loads(arguments) if arguments.strip() else {}
            # Already-parsed mappings fall through unchanged, which keeps the
            # function idempotent when applied to its own output.
            assert isinstance(arguments, dict), (
                "tool call arguments must decode to a JSON object"
            )
            # Build copies instead of mutating the caller's dicts.
            normalized_tool_calls.append(
                {**tool_call, "function": {**function, "arguments": arguments}}
            )
        normalized_messages.append({**message, "tool_calls": normalized_tool_calls})

    return normalized_messages
| 67 | + |
| 68 | + |
34 | 69 | @dataclass |
35 | 70 | class TokenizedResult: |
36 | 71 | advantage: float |
@@ -223,20 +258,23 @@ def tokenize_trajectory( |
223 | 258 | if last_assistant_index == -1: |
224 | 259 | return None |
225 | 260 | messages_and_choices = history.messages_and_choices[: last_assistant_index + 1] |
226 | | - messages = get_messages(messages_and_choices) |
| 261 | + messages = cast(list[dict[str, Any]], get_messages(messages_and_choices)) |
| 262 | + # Qwen3.5's chat template uses `tool_call.arguments|items`, so it needs a |
| 263 | + # mapping here instead of the OpenAI JSON string. |
| 264 | + messages = _normalize_tool_call_arguments_for_chat_template(tokenizer, messages) |
227 | 265 | tools = _normalize_tools_for_chat_template(history.tools) |
228 | 266 | chat = cast( |
229 | 267 | str, |
230 | 268 | tokenizer.apply_chat_template( |
231 | | - cast(list[dict], messages), |
| 269 | + messages, |
232 | 270 | tools=tools, |
233 | 271 | continue_final_message=True, |
234 | 272 | tokenize=False, |
235 | 273 | ), |
236 | 274 | ) |
237 | 275 | original_token_ids = _apply_chat_template_token_ids( |
238 | 276 | tokenizer, |
239 | | - cast(list[dict[str, Any]], messages), |
| 277 | + messages, |
240 | 278 | tools=tools, |
241 | 279 | continue_final_message=True, |
242 | 280 | ) |
|
0 commit comments