Skip to content

Commit 8e1ee62

Browse files
committed
Implicitly set reference_free when simpo_gamma is set
1 parent b92bc17 commit 8e1ee62

2 files changed

Lines changed: 6 additions & 11 deletions

File tree

src/together/cli/api/finetune.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,6 @@ def create(
237237
training_method: str,
238238
dpo_beta: float,
239239
dpo_normalize_logratios_by_length: bool,
240-
dpo_reference_free: bool,
241240
rpo_alpha: float,
242241
simpo_gamma: float,
243242
from_checkpoint: str,
@@ -274,7 +273,6 @@ def create(
274273
training_method=training_method,
275274
dpo_beta=dpo_beta,
276275
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
277-
dpo_reference_free=dpo_reference_free,
278276
rpo_alpha=rpo_alpha,
279277
simpo_gamma=simpo_gamma,
280278
from_checkpoint=from_checkpoint,

src/together/resources/finetune.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def create_finetune_request(
7373
training_method: str = "sft",
7474
dpo_beta: float | None = None,
7575
dpo_normalize_logratios_by_length: bool = False,
76-
dpo_reference_free: bool = False,
7776
rpo_alpha: float | None = None,
7877
simpo_gamma: float | None = None,
7978
from_checkpoint: str | None = None,
@@ -188,8 +187,6 @@ def create_finetune_request(
188187
raise ValueError("dpo_beta is only supported for DPO training")
189188
if dpo_normalize_logratios_by_length and training_method != "dpo":
190189
raise ValueError("dpo_normalize_logratios_by_length=True is only supported for DPO training")
191-
if dpo_reference_free and training_method != "dpo":
192-
raise ValueError("dpo_reference_free=True is only supported for DPO training")
193190
if rpo_alpha is not None and training_method != "dpo":
194191
raise ValueError("rpo_alpha is only supported for DPO training")
195192
if simpo_gamma is not None and training_method != "dpo":
@@ -216,6 +213,12 @@ def create_finetune_request(
216213
if training_method == "sft":
217214
training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
218215
elif training_method == "dpo":
216+
if simpo_gamma is not None and simpo_gamma > 0:
217+
dpo_reference_free = True
218+
rprint(
219+
f"Parameter simpo_gamma was set to {simpo_gamma}. "
220+
"SimPO training detected. Reference logits will not be used."
221+
)
219222
training_method_cls = TrainingMethodDPO(
220223
dpo_beta=dpo_beta,
221224
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
@@ -321,7 +324,6 @@ def create(
321324
training_method: str = "sft",
322325
dpo_beta: float | None = None,
323326
dpo_normalize_logratios_by_length: bool = False,
324-
dpo_reference_free: bool = False,
325327
rpo_alpha: float | None = None,
326328
simpo_gamma: float | None = None,
327329
from_checkpoint: str | None = None,
@@ -376,7 +378,6 @@ def create(
376378
Supported methods: "sft", "dpo".
377379
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
378380
dpo_normalize_logratios_by_length (bool): Whether or not to normalize logratios by sample length. Defaults to False,
379-
dpo_reference_free (bool): Whether to skip reference logits usage. Defaults to False.
380381
rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
381382
simpo_gamma: (float, optional): SimPO gamma parameter. Defaults to None.
382383
from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
@@ -432,7 +433,6 @@ def create(
432433
training_method=training_method,
433434
dpo_beta=dpo_beta,
434435
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
435-
dpo_reference_free=dpo_reference_free,
436436
rpo_alpha=rpo_alpha,
437437
simpo_gamma=simpo_gamma,
438438
from_checkpoint=from_checkpoint,
@@ -745,7 +745,6 @@ async def create(
745745
training_method: str = "sft",
746746
dpo_beta: float | None = None,
747747
dpo_normalize_logratios_by_length: bool = False,
748-
dpo_reference_free: bool = False,
749748
rpo_alpha: float | None = None,
750749
simpo_gamma: float | None = None,
751750
from_checkpoint: str | None = None,
@@ -800,7 +799,6 @@ async def create(
800799
Supported methods: "sft", "dpo".
801800
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
802801
dpo_normalize_logratios_by_length (bool): Whether or not to normalize logratios by sample length. Defaults to False,
803-
dpo_reference_free (bool): Whether to skip reference logits usage. Defaults to False.
804802
rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
805803
simpo_gamma: (float, optional): SimPO gamma parameter. Defaults to None.
806804
from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
@@ -856,7 +854,6 @@ async def create(
856854
training_method=training_method,
857855
dpo_beta=dpo_beta,
858856
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
859-
dpo_reference_free=dpo_reference_free,
860857
rpo_alpha=rpo_alpha,
861858
simpo_gamma=simpo_gamma,
862859
from_checkpoint=from_checkpoint,

0 commit comments

Comments
 (0)