@@ -48,6 +48,7 @@
     build_rl_train_configs,
 )
 from ..backend import AnyTrainableModel, Backend
+from ..costs import build_cost_calculator, get_model_pricing
 from ..metrics_taxonomy import (
     TRAIN_GRADIENT_STEPS_KEY,
     build_training_summary_metrics,
@@ -185,26 +186,23 @@ async def close(self) -> None:
         """
         If running vLLM in a separate process, this will kill that process and close the communication threads.
         """
-        await self._aclose()
+        for service in self._services.values():
+            aclose = getattr(service, "aclose", None)
+            if aclose is None:
+                close = getattr(service, "close", None)
+                if close is not None:
+                    close()
+            else:
+                await aclose()
+            close_proxy(service)

     def _close(self) -> None:
-        for _, service in self._services.items():
+        for service in self._services.values():
             close = getattr(service, "close", None)
             if close is not None:
                 close()
             close_proxy(service)

-    async def _aclose(self) -> None:
-        for _, service in self._services.items():
-            aclose = getattr(service, "aclose", None)
-            if aclose is not None:
-                await aclose()
-            else:
-                close = getattr(service, "close", None)
-                if close is not None:
-                    close()
-            close_proxy(service)
-
     async def register(
         self,
         model: Model,
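The merged `close()` above prefers each service's async `aclose()` and falls back to a synchronous `close()`. A minimal standalone sketch of that duck-typing fallback, with illustrative service classes that are not part of this change:

```python
import asyncio


class SyncService:
    """Illustrative service exposing only a synchronous close()."""

    def close(self) -> None:
        print("sync close")


class AsyncService:
    """Illustrative service exposing an async aclose()."""

    async def aclose(self) -> None:
        print("async aclose")


async def shutdown(services: list[object]) -> None:
    # Prefer the async aclose() when a service provides one; otherwise
    # fall back to a synchronous close(), mirroring the diff's logic.
    for service in services:
        aclose = getattr(service, "aclose", None)
        if aclose is None:
            close = getattr(service, "close", None)
            if close is not None:
                close()
        else:
            await aclose()


asyncio.run(shutdown([SyncService(), AsyncService()]))
```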
@@ -231,6 +229,11 @@ async def register(
         # (wandb initialization is now handled by the model's _get_wandb_run method)
         if model.trainable and "WANDB_API_KEY" in os.environ:
             _ = model._get_wandb_run()
+        if model.trainable:
+            trainable_model = cast(TrainableModel, model)
+            pricing = get_model_pricing(trainable_model.base_model)
+            if pricing is not None:
+                trainable_model.set_cost_calculator(build_cost_calculator(pricing))

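`get_model_pricing` and `build_cost_calculator` come from the new `..costs` import; their exact shapes are not visible in this diff. A rough sketch of the idea, assuming per-token input/output rates (the `Pricing` dataclass below is a hypothetical stand-in):

```python
from dataclasses import dataclass


@dataclass
class Pricing:
    # Hypothetical shape; the real return type of get_model_pricing()
    # is not shown in this diff.
    input_per_token: float
    output_per_token: float


def build_cost_calculator_sketch(pricing: Pricing):
    # Return a callable that maps token counts to a dollar cost.
    def calculate(input_tokens: int, output_tokens: int) -> float:
        return (
            input_tokens * pricing.input_per_token
            + output_tokens * pricing.output_per_token
        )

    return calculate


calc = build_cost_calculator_sketch(Pricing(1e-6, 3e-6))
print(calc(1000, 500))  # ~0.0025
```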
235238 def _model_inference_name (self , model : Model , step : int | None = None ) -> str :
236239 """Return the inference name for a model checkpoint.
@@ -244,25 +247,27 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
                 If None, returns name for latest checkpoint (step 0 initially).
         """

-        # For LocalBackend, vLLM always serves LoRA adapters with @step suffix
-        # Default to step 0 when not specified (the initial checkpoint created at registration)
-        if step is not None:
-            actual_step = step
-        elif model.name in self._services and self._in_process:
-            # In dedicated mode the service tracks which adapter vLLM has
-            # actually loaded. Reading the filesystem would race: the
-            # checkpoint directory appears before the HTTP reload completes.
-            svc = self._services[model.name]
-            loaded_step = getattr(svc, "_latest_step", None)
-            actual_step = (
-                loaded_step if loaded_step is not None else self.__get_step(model)
-            )
-        else:
-            actual_step = self.__get_step(model)
-        name = f"{model.name}@{actual_step}"
+        requested_step = step
+
+        if step is None and isinstance(model, TrainableModel):
+            from ..dev.validate import is_dedicated_mode
+
+            service = self._services.get(model.name)
+            if service is not None and is_dedicated_mode(
+                model._internal_config or dev.InternalModelConfig()
+            ):
+                loaded_step = getattr(service, "_latest_step", None)
+                if isinstance(loaded_step, int):
+                    step = loaded_step
+
+        if step is None:
+            # The checkpoint directory is written before dedicated-mode
+            # vLLM finishes reloading the new adapter.
+            step = self.__get_step(model)
+        name = f"{model.name}@{step}"
         logger.debug(
-            f"[BACKEND] _model_inference_name: step_arg={step} "
-            f"actual_step={actual_step} -> {name}"
+            f"[BACKEND] _model_inference_name: step_arg={requested_step} "
+            f"actual_step={step} -> {name}"
         )
         return name

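The resolution order above — explicit `step` argument, then the adapter step the dedicated-mode service reports, then a filesystem lookup — can be summarized with a toy standalone function (illustrative only; `self.__get_step(model)` is stubbed as a constant):

```python
def inference_name_sketch(
    name: str, step: int | None, service_step: int | None
) -> str:
    # Explicit step wins; otherwise prefer the adapter step the
    # dedicated-mode vLLM service reports; otherwise fall back to
    # a filesystem lookup (stubbed here as 0).
    if step is None and service_step is not None:
        step = service_step
    if step is None:
        step = 0  # stand-in for self.__get_step(model)
    return f"{name}@{step}"


assert inference_name_sketch("agent", None, 3) == "agent@3"
assert inference_name_sketch("agent", 0, 3) == "agent@0"
assert inference_name_sketch("agent", None, None) == "agent@0"
```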
@@ -527,13 +532,14 @@ async def train( # type: ignore[override]
         *,
         # Core training parameters
         learning_rate: float = 5e-6,
-        loss_fn: Literal["cispo", "ppo"] | None = None,
+        loss_fn: Literal["cispo", "ppo"] = "cispo",
+        loss_fn_config: dict | None = None,
+        normalize_advantages: bool = True,
+        adam_params: object | None = None,
         # KL-penalized advantage adjustment
         kl_penalty_coef: float = 0.0,
         kl_penalty_reference_step: int | None = None,
         kl_ref_adapter_path: str | None = None,
-        # RL algorithm settings
-        ppo: bool = False,
         epsilon: float | None = None,
         epsilon_high: float | None = None,
         # Advantage computation
@@ -570,6 +576,14 @@ async def train( # type: ignore[override]
             model: The trainable model to train.
             trajectory_groups: Batches of trajectories to train on.
             learning_rate: Learning rate for training. Defaults to 5e-6.
+            loss_fn: RL loss function. LocalBackend currently supports
+                "cispo" and "ppo".
+            loss_fn_config: Additional loss-function config. Not supported by
+                LocalBackend.
+            normalize_advantages: Whether to normalize advantages. LocalBackend
+                currently requires True.
+            adam_params: Custom optimizer params. Not supported by
+                LocalBackend.
             kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
                 Tokens diverging more from the reference get reduced advantages.
                 Defaults to 0.0 (disabled).
@@ -579,8 +593,7 @@ async def train( # type: ignore[override]
             kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
                 checkpoint to use as the KL reference. Alternative to
                 kl_penalty_reference_step.
-            ppo: Whether to use PPO clipping. Defaults to False.
-            epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
+            epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
             epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
             advantage_balance: Balance between negative and positive advantages
                 in range [-1.0, 1.0]. Defaults to 0.0 (balanced).
@@ -623,8 +636,14 @@ async def train( # type: ignore[override]
         # await model.log(metrics=result.metrics, step=result.step)
         """
         groups_list = list(trajectory_groups)
-        if loss_fn is not None:
-            ppo = loss_fn == "ppo"
+        if loss_fn not in {"cispo", "ppo"}:
+            raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.")
+        if loss_fn_config is not None:
+            raise ValueError("LocalBackend requires loss_fn_config=None.")
+        if not normalize_advantages:
+            raise ValueError("LocalBackend requires normalize_advantages=True.")
+        if adam_params is not None:
+            raise ValueError("LocalBackend requires adam_params=None.")

         resolved_kl_ref_adapter_path = kl_ref_adapter_path
         if (
@@ -641,7 +660,7 @@ async def train( # type: ignore[override]
             scale_rewards=scale_rewards,
             importance_sampling_level=importance_sampling_level,
             mask_prob_ratio=mask_prob_ratio,
-            ppo=ppo,
+            ppo=loss_fn == "ppo",
             precalculate_logprobs=precalculate_logprobs,
             epsilon=epsilon,
             epsilon_high=epsilon_high,
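With the boolean `ppo` flag removed, callers now pick the algorithm via `loss_fn`. A hedged usage sketch (the `backend`, `model`, and `trajectory_groups` objects are assumed to exist; constructing them is outside this diff):

```python
async def train_with_ppo(backend, model, trajectory_groups):
    # Old call style (no longer accepted):
    #     await backend.train(model, trajectory_groups, ppo=True)
    return await backend.train(
        model,
        trajectory_groups,
        learning_rate=5e-6,
        loss_fn="ppo",  # "cispo" is the default
        # loss_fn_config and adam_params must stay None on LocalBackend,
        # and normalize_advantages must stay True, or train() raises
        # ValueError per the guards added above.
    )
```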