4545
4646from openadapt_ml .datasets .next_action import SYSTEM_PROMPT
4747from openadapt_ml .training .grpo .config import GRPOConfig
48- from openadapt_ml .training .grpo .reward import compute_group_advantages
48+ from openadapt_ml .training .grpo .reward import (
49+ compute_group_advantages ,
50+ evaluate_milestones_screenshot ,
51+ )
4952from openadapt_ml .training .grpo .rollout_collector import (
5053 GRPORolloutCollector ,
5154 Rollout ,
5255)
5356
# Optional import for TaskConfig (openadapt-evals may not be installed).
# Downstream code must check _HAS_TASK_CONFIG (or handle TaskConfig is None)
# before touching TaskConfig; _load_task_configs raises a helpful ImportError
# when the flag is False.
try:
    from openadapt_evals.task_config import TaskConfig

    _HAS_TASK_CONFIG = True
except ImportError:
    TaskConfig = None  # type: ignore[assignment, misc]
    _HAS_TASK_CONFIG = False

5466logger = logging .getLogger (__name__ )
5567
5668DEFAULT_SCREEN_SIZE : tuple [int , int ] = (1920 , 1080 )
@@ -301,6 +313,106 @@ def __init__(self, config: GRPOConfig) -> None:
301313 self ._optimizer : Any = None
302314 self ._collector : GRPORolloutCollector | None = None
303315 self ._step : int = 0
316+ self ._task_configs : dict [str , Any ] = {}
317+
318+ # Load task configs from --task-dir if specified
319+ if config .task_dir :
320+ self ._load_task_configs (config .task_dir )
321+
322+ def _load_task_configs (self , task_dir : str ) -> None :
323+ """Load TaskConfig YAMLs from a directory.
324+
325+ Populates ``self._task_configs`` (keyed by task ID) and auto-fills
326+ ``config.task_ids`` if it was left empty.
327+
328+ Args:
329+ task_dir: Path to directory containing YAML/JSON task configs.
330+
331+ Raises:
332+ ImportError: If openadapt-evals is not installed.
333+ FileNotFoundError: If the directory does not exist.
334+ """
335+ if not _HAS_TASK_CONFIG :
336+ raise ImportError (
337+ "openadapt-evals is required for --task-dir support. "
338+ "Install with: pip install openadapt-evals"
339+ )
340+
341+ task_dir_path = Path (task_dir )
342+ if not task_dir_path .is_dir ():
343+ raise FileNotFoundError (f"Task directory not found: { task_dir } " )
344+
345+ configs = TaskConfig .from_dir (str (task_dir_path ))
346+ if not configs :
347+ raise ValueError (f"No task configs found in { task_dir } " )
348+
349+ for tc in configs :
350+ self ._task_configs [tc .id ] = tc
351+ logger .info (
352+ "Loaded task config: %s (%s) — %d milestones" ,
353+ tc .id ,
354+ tc .name [:50 ],
355+ len (tc .milestones ),
356+ )
357+
358+ # Auto-populate task_ids if empty
359+ if not self ._config .task_ids :
360+ self ._config .task_ids = list (self ._task_configs .keys ())
361+ logger .info (
362+ "Auto-populated task_ids from task_dir: %s" ,
363+ self ._config .task_ids ,
364+ )
365+
366+ def _compute_milestone_reward (
367+ self ,
368+ task_id : str ,
369+ screenshot_bytes : bytes ,
370+ ) -> float :
371+ """Compute milestone-based reward for a task using VLM judge.
372+
373+ Evaluates screenshot-type milestones locally without needing the
374+ WAA /evaluate endpoint. Falls back to 0.0 if the task has no
375+ milestones or the task_id is not found in loaded configs.
376+
377+ Args:
378+ task_id: The task ID to look up in loaded configs.
379+ screenshot_bytes: PNG screenshot bytes to evaluate.
380+
381+ Returns:
382+ Fraction of screenshot milestones passed (0.0 to 1.0).
383+ """
384+ task_config = self ._task_configs .get (task_id )
385+ if task_config is None :
386+ return 0.0
387+ return evaluate_milestones_screenshot (task_config , screenshot_bytes )
388+
389+ def _compute_milestone_reward_from_rollout (
390+ self ,
391+ rollout : Rollout ,
392+ ) -> float | None :
393+ """Extract the last screenshot from a rollout and compute milestone reward.
394+
395+ Returns None if no task config or no screenshot is available,
396+ signalling the caller to keep the existing reward.
397+ """
398+ task_config = self ._task_configs .get (rollout .task_id )
399+ if task_config is None or not getattr (task_config , "milestones" , None ):
400+ return None
401+
402+ # Find the last step with a screenshot
403+ screenshot_bytes : bytes | None = None
404+ for step in reversed (rollout .steps ):
405+ obs = getattr (step , "observation" , None )
406+ if obs is not None :
407+ ss = getattr (obs , "screenshot" , None )
408+ if ss :
409+ screenshot_bytes = ss
410+ break
411+
412+ if not screenshot_bytes :
413+ return None
414+
415+ return evaluate_milestones_screenshot (task_config , screenshot_bytes )
304416
305417 def _make_agent_fn (self ) -> Callable :
306418 """Create agent closure: observation -> BenchmarkAction.
@@ -381,20 +493,26 @@ def train(self) -> str:
381493 if not self ._config .task_ids :
382494 raise ValueError (
383495 "config.task_ids must be non-empty. Provide at least one "
384- "WAA task ID to train on."
496+ "WAA task ID to train on, or use --task-dir to load from "
497+ "YAML files."
385498 )
386499
387500 logger .info ("Starting GRPO training" )
388501 logger .info (" Model: %s" , self ._config .model_name )
389502 logger .info (" Tasks: %s" , self ._config .task_ids )
503+ logger .info (" Task dir: %s" , self ._config .task_dir or "(none)" )
504+ logger .info (" Task configs loaded: %d" , len (self ._task_configs ))
390505 logger .info (" Rollouts/step: %d" , self ._config .num_rollouts_per_step )
391506 logger .info (" Training steps: %d" , self ._config .num_training_steps )
392507
393508 # Setup
394509 self ._model , self ._processor = _load_model_and_processor (self ._config )
395510 trainable = [p for p in self ._model .parameters () if p .requires_grad ]
396511 self ._optimizer = torch .optim .AdamW (trainable , lr = self ._config .learning_rate )
397- self ._collector = GRPORolloutCollector (self ._config )
512+ self ._collector = GRPORolloutCollector (
513+ self ._config ,
514+ task_configs = self ._task_configs if self ._task_configs else None ,
515+ )
398516
399517 Path (self ._config .output_dir ).mkdir (parents = True , exist_ok = True )
400518 agent_fn = self ._make_agent_fn ()
@@ -409,6 +527,16 @@ def train(self) -> str:
409527 self ._model .eval ()
410528 rollouts = self ._collector .collect_group (agent_fn = agent_fn , task_id = task_id )
411529
530+ # If task configs with milestones are loaded, override the
531+ # binary rewards with milestone-based dense rewards.
532+ if self ._task_configs :
533+ for rollout in rollouts :
534+ milestone_reward = self ._compute_milestone_reward_from_rollout (
535+ rollout
536+ )
537+ if milestone_reward is not None :
538+ rollout .reward = max (rollout .reward , milestone_reward )
539+
412540 # Train (gradient update)
413541 self ._model .train ()
414542 metrics = self ._training_step (rollouts )
0 commit comments