@@ -73,7 +73,6 @@ def create_finetune_request(
7373 training_method : str = "sft" ,
7474 dpo_beta : float | None = None ,
7575 dpo_normalize_logratios_by_length : bool = False ,
76- dpo_reference_free : bool = False ,
7776 rpo_alpha : float | None = None ,
7877 simpo_gamma : float | None = None ,
7978 from_checkpoint : str | None = None ,
@@ -188,8 +187,6 @@ def create_finetune_request(
188187 raise ValueError ("dpo_beta is only supported for DPO training" )
189188 if dpo_normalize_logratios_by_length and training_method != "dpo" :
190189 raise ValueError ("dpo_normalize_logratios_by_length=True is only supported for DPO training" )
191- if dpo_reference_free and training_method != "dpo" :
192- raise ValueError ("dpo_reference_free=True is only supported for DPO training" )
193190 if rpo_alpha is not None and training_method != "dpo" :
194191 raise ValueError ("rpo_alpha is only supported for DPO training" )
195192 if simpo_gamma is not None and training_method != "dpo" :
@@ -216,6 +213,12 @@ def create_finetune_request(
216213 if training_method == "sft" :
217214 training_method_cls = TrainingMethodSFT (train_on_inputs = train_on_inputs )
218215 elif training_method == "dpo" :
216+ if simpo_gamma is not None and simpo_gamma > 0 :
217+ dpo_reference_free = True
218+ rprint (
219+ f"Parameter simpo_gamma was set to { simpo_gamma } . "
220+ "SimPO training detected. Reference logits will not be used."
221+ )
219222 training_method_cls = TrainingMethodDPO (
220223 dpo_beta = dpo_beta ,
221224 dpo_normalize_logratios_by_length = dpo_normalize_logratios_by_length ,
@@ -321,7 +324,6 @@ def create(
321324 training_method : str = "sft" ,
322325 dpo_beta : float | None = None ,
323326 dpo_normalize_logratios_by_length : bool = False ,
324- dpo_reference_free : bool = False ,
325327 rpo_alpha : float | None = None ,
326328 simpo_gamma : float | None = None ,
327329 from_checkpoint : str | None = None ,
@@ -376,7 +378,6 @@ def create(
376378 Supported methods: "sft", "dpo".
377379 dpo_beta (float, optional): DPO beta parameter. Defaults to None.
378380 dpo_normalize_logratios_by_length (bool): Whether or not to normalize log-ratios by sample length. Defaults to False.
379- dpo_reference_free (bool): Whether to skip reference logits usage. Defaults to False.
380381 rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
381382 simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
382383 from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
@@ -432,7 +433,6 @@ def create(
432433 training_method = training_method ,
433434 dpo_beta = dpo_beta ,
434435 dpo_normalize_logratios_by_length = dpo_normalize_logratios_by_length ,
435- dpo_reference_free = dpo_reference_free ,
436436 rpo_alpha = rpo_alpha ,
437437 simpo_gamma = simpo_gamma ,
438438 from_checkpoint = from_checkpoint ,
@@ -745,7 +745,6 @@ async def create(
745745 training_method : str = "sft" ,
746746 dpo_beta : float | None = None ,
747747 dpo_normalize_logratios_by_length : bool = False ,
748- dpo_reference_free : bool = False ,
749748 rpo_alpha : float | None = None ,
750749 simpo_gamma : float | None = None ,
751750 from_checkpoint : str | None = None ,
@@ -800,7 +799,6 @@ async def create(
800799 Supported methods: "sft", "dpo".
801800 dpo_beta (float, optional): DPO beta parameter. Defaults to None.
802801 dpo_normalize_logratios_by_length (bool): Whether or not to normalize log-ratios by sample length. Defaults to False.
803- dpo_reference_free (bool): Whether to skip reference logits usage. Defaults to False.
804802 rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
805803 simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
806804 from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
@@ -856,7 +854,6 @@ async def create(
856854 training_method = training_method ,
857855 dpo_beta = dpo_beta ,
858856 dpo_normalize_logratios_by_length = dpo_normalize_logratios_by_length ,
859- dpo_reference_free = dpo_reference_free ,
860857 rpo_alpha = rpo_alpha ,
861858 simpo_gamma = simpo_gamma ,
862859 from_checkpoint = from_checkpoint ,
0 commit comments