Skip to content

Commit 04e6e9f

Browse files
abrichr and claude authored
fix: align GRPO prompt format with SFT training format (#61)
The GRPO rollout prompt was missing the "Thought:" line and the action history that SFT training uses. Models fine-tuned via SFT output "Thought: ...\nAction: CLICK(...)", but the GRPO prompt did not ask for this format, which caused verbose free-form output that could not be parsed → reward 0.0 → zero gradients.

Changes:
- Add "Thought:" and "Action:" prompt lines matching the SFT format
- Add an action_history parameter for step context
- Parser extracts the action from the "Action: ..." line before regex matching
- Parser handles the JSON format {"action_type": "click", "coordinate": [x,y]}
- Debug logging of raw VLM output for zero-reward diagnosis

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7d095da commit 04e6e9f

1 file changed

Lines changed: 43 additions & 3 deletions

File tree

openadapt_ml/training/grpo/trainer.py

Lines changed: 43 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -110,12 +110,16 @@ def policy_gradient_loss(
110110

111111

112112
def _build_agent_messages(
113-
instruction: str, *, include_image: bool = False
113+
instruction: str,
114+
*,
115+
include_image: bool = False,
116+
action_history: str = "",
114117
) -> list[dict]:
115118
"""Build chat messages for the GRPO agent.
116119
117-
Uses the same SYSTEM_PROMPT as SFT training so GRPO operates in
118-
the same prompt distribution the model was warm-started on.
120+
Uses the same SYSTEM_PROMPT and prompt format as SFT training
121+
(``next_action.py``) so GRPO operates in the same prompt
122+
distribution the model was warm-started on.
119123
120124
This is the **single source of truth** for prompt construction
121125
during both rollout collection and loss computation.
@@ -125,10 +129,15 @@ def _build_agent_messages(
125129
include_image: If True, include an image placeholder in the user
126130
message so ``apply_chat_template`` inserts ``<|image_pad|>``
127131
tokens required by Qwen2.5-VL and similar VLMs.
132+
action_history: Formatted action history from previous steps
133+
(e.g. "Step 1: CLICK(x=0.5, y=0.3)\\nStep 2: TYPE(...)").
128134
"""
135+
history_text = f"{action_history}\n" if action_history else ""
129136
text_content = (
130137
f"Goal: {instruction}\n\n"
138+
f"{history_text}"
131139
"Look at the screenshot and determine the NEXT action.\n\n"
140+
"Thought: [what element to interact with and why]\n"
132141
'Action: [CLICK(x=..., y=...) or TYPE(text="...") or WAIT() or DONE()]'
133142
)
134143
if include_image:
@@ -173,6 +182,37 @@ class BenchmarkAction: # type: ignore[no-redef]
173182
text = text.strip()
174183
width, height = screen_size
175184

185+
# Log raw output for debugging zero-reward issues
186+
logger.debug("Parsing VLM output (%d chars): %.200s", len(text), text)
187+
188+
# Extract action from "Thought: ...\nAction: ..." format (SFT output)
189+
action_match = re.search(r"Action:\s*(.+)", text, re.IGNORECASE)
190+
if action_match:
191+
text = action_match.group(1).strip()
192+
193+
# Try JSON format: {"action_type": "click", "coordinate": [x, y]}
194+
json_match = re.search(r'\{[^}]*"action_type"[^}]*\}', text)
195+
if json_match:
196+
try:
197+
import json as _json
198+
action_data = _json.loads(json_match.group())
199+
atype = action_data.get("action_type", "").lower()
200+
coord = action_data.get("coordinate", action_data.get("coords", []))
201+
if atype == "click" and len(coord) >= 2:
202+
x_val, y_val = float(coord[0]), float(coord[1])
203+
# Handle both normalized (0-1) and pixel coordinates
204+
if x_val <= 1.0 and y_val <= 1.0:
205+
x_val, y_val = x_val * width, y_val * height
206+
return BenchmarkAction(type="click", x=int(x_val), y=int(y_val))
207+
if atype == "type":
208+
return BenchmarkAction(
209+
type="type", text=action_data.get("text", "")
210+
)
211+
if atype in ("done", "wait"):
212+
return BenchmarkAction(type=atype)
213+
except Exception:
214+
pass # Fall through to regex parsing
215+
176216
# CLICK(x=..., y=...)
177217
m = re.search(r"CLICK\(x=(-?[\d.]+),\s*y=(-?[\d.]+)\)", text, re.IGNORECASE)
178218
if m:

0 commit comments

Comments
 (0)