@@ -113,7 +113,7 @@ class RLAIFTrainer(BaseTrainer):
113113 The VPC configuration for the training job.
114114 stopping_condition (Optional[StoppingCondition]):
115115 The stopping condition to override training runtime limit.
116- If not specified, defaults to 1 hour max runtime .
116+ If not specified, uses SageMaker service default (24 hours for serverless training) .
117117 """
118118
119119 def __init__ (
@@ -221,10 +221,6 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
221221 current_training_job_name = _get_unique_name (
222222 self .base_job_name or f"{ self ._model_name } -rlaif"
223223 )
224-
225- stopping_condition = TrainDefaults .get_stopping_condition (
226- stopping_condition = self .stopping_condition
227- )
228224
229225 logger .info (f"Training Job Name: { current_training_job_name } " )
230226
@@ -269,22 +265,28 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
269265 vpc_config = self .networking if self .networking else None
270266 tags = _get_studio_tags (self ._model_name , HUB_NAME )
271267
268+ # Build TrainingJob.create() arguments
269+ create_args = {
270+ "training_job_name" : current_training_job_name ,
271+ "role_arn" : role ,
272+ "input_data_config" : channels ,
273+ "output_data_config" : output_config ,
274+ "serverless_job_config" : serverless_config ,
275+ "mlflow_config" : mlflow_config ,
276+ "hyper_parameters" : final_hyperparameters ,
277+ "model_package_config" : model_package_config ,
278+ "vpc_config" : vpc_config ,
279+ "session" : sagemaker_session .boto_session ,
280+ "region" : sagemaker_session .boto_session .region_name ,
281+ "tags" : tags ,
282+ }
283+
284+ # Only pass stopping_condition if explicitly provided by user
285+ if self .stopping_condition is not None :
286+ create_args ["stopping_condition" ] = self .stopping_condition
287+
272288 try :
273- training_job = TrainingJob .create (
274- training_job_name = current_training_job_name ,
275- role_arn = role ,
276- input_data_config = channels ,
277- output_data_config = output_config ,
278- serverless_job_config = serverless_config ,
279- mlflow_config = mlflow_config ,
280- hyper_parameters = final_hyperparameters ,
281- model_package_config = model_package_config ,
282- vpc_config = vpc_config ,
283- stopping_condition = stopping_condition ,
284- session = sagemaker_session .boto_session ,
285- region = sagemaker_session .boto_session .region_name ,
286- tags = tags ,
287- )
289+ training_job = TrainingJob .create (** create_args )
288290 except Exception as e :
289291 logger .error ("Error: %s" , e )
290292 raise e
0 commit comments