Commit a0270e8

Add dpo improvements arguments

1 parent ecd68a4 commit a0270e8

3 files changed: 85 additions & 1 deletion

src/together/cli/api/finetune.py (38 additions, 0 deletions)
@@ -142,6 +142,36 @@ def fine_tuning(ctx: click.Context) -> None:
     default=0.1,
     help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
 )
+@click.option(
+    "--dpo-normalize-logratios-by-length",
+    type=bool,
+    default=False,
+    help=(
+        "Whether to normalize logratios by sample length "
+        "(only used when '--training-method' is 'dpo')"
+    ),
+)
+@click.option(
+    "--dpo-reference-free",
+    type=bool,
+    default=False,
+    help="Whether to skip reference logits usage (only used when '--training-method' is 'dpo')",
+)
+@click.option(
+    "--rpo-alpha",
+    type=float,
+    default=0.0,
+    help=(
+        "RPO alpha parameter of DPO training to include NLL in the loss "
+        "(only used when '--training-method' is 'dpo')"
+    ),
+)
+@click.option(
+    "--simpo-gamma",
+    type=float,
+    default=0.1,
+    help="SimPO gamma parameter (only used when '--training-method' is 'dpo')",
+)
 @click.option(
     "--suffix",
     "-s",
@@ -206,6 +236,10 @@ def create(
     train_on_inputs: bool | Literal["auto"],
     training_method: str,
     dpo_beta: float,
+    dpo_normalize_logratios_by_length: bool,
+    dpo_reference_free: bool,
+    rpo_alpha: float,
+    simpo_gamma: float,
     from_checkpoint: str,
 ) -> None:
     """Start fine-tuning"""
@@ -239,6 +273,10 @@ def create(
         train_on_inputs=train_on_inputs,
         training_method=training_method,
         dpo_beta=dpo_beta,
+        dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+        dpo_reference_free=dpo_reference_free,
+        rpo_alpha=rpo_alpha,
+        simpo_gamma=simpo_gamma,
        from_checkpoint=from_checkpoint,
     )
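
Taken together, the new options extend the `create` command's DPO surface. A hypothetical invocation (the `together fine-tuning create` entry point, the file ID, and the model name are placeholders/assumptions, not part of this diff; Click's BOOL type accepts literals such as `true`/`false` for the two boolean options):

```
together fine-tuning create \
    --training-file file-abc123 \
    --model meta-llama/Meta-Llama-3.1-8B-Instruct-Reference \
    --training-method dpo \
    --dpo-beta 0.1 \
    --dpo-normalize-logratios-by-length true \
    --rpo-alpha 0.5 \
    --simpo-gamma 0.3
```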

src/together/resources/finetune.py (43 additions, 1 deletion)
@@ -72,6 +72,10 @@ def create_finetune_request(
     train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
+    dpo_normalize_logratios_by_length: bool = False,
+    dpo_reference_free: bool = False,
+    rpo_alpha: float | None = None,
+    simpo_gamma: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
     if model is not None and from_checkpoint is not None:
@@ -182,6 +186,14 @@
 
     if dpo_beta is not None and training_method != "dpo":
         raise ValueError("dpo_beta is only supported for DPO training")
+    if dpo_normalize_logratios_by_length and training_method != "dpo":
+        raise ValueError("dpo_normalize_logratios_by_length=True is only supported for DPO training")
+    if dpo_reference_free and training_method != "dpo":
+        raise ValueError("dpo_reference_free=True is only supported for DPO training")
+    if rpo_alpha is not None and training_method != "dpo":
+        raise ValueError("rpo_alpha is only supported for DPO training")
+    if simpo_gamma is not None and training_method != "dpo":
+        raise ValueError("simpo_gamma is only supported for DPO training")
 
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
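
One detail worth noting in these guards: the float-valued knobs are compared against None, so an explicit rpo_alpha=0.0 still counts as "set", while the boolean knobs only trip the check when True. A minimal standalone sketch of the pattern (illustrative, not the library's code):

```python
def check_dpo_only(
    training_method: str,
    rpo_alpha: float | None,
    dpo_reference_free: bool,
) -> None:
    # Float knobs are tested against None, so rpo_alpha=0.0 is rejected for
    # non-DPO runs; bool knobs only trigger the guard when True.
    if rpo_alpha is not None and training_method != "dpo":
        raise ValueError("rpo_alpha is only supported for DPO training")
    if dpo_reference_free and training_method != "dpo":
        raise ValueError("dpo_reference_free=True is only supported for DPO training")

check_dpo_only("dpo", rpo_alpha=0.0, dpo_reference_free=False)   # passes
# check_dpo_only("sft", rpo_alpha=0.0, dpo_reference_free=False) # would raise
```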
@@ -204,7 +216,13 @@
     if training_method == "sft":
         training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
     elif training_method == "dpo":
-        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+        training_method_cls = TrainingMethodDPO(
+            dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            dpo_reference_free=dpo_reference_free,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
+        )
 
     finetune_request = FinetuneRequest(
         model=model,
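
TrainingMethodDPO is a plain data model (defined in src/together/types/finetune.py, shown below), so the payload this branch builds can be inspected directly. A quick sketch, assuming the class is a pydantic v2 model with model_dump available:

```python
from together.types.finetune import TrainingMethodDPO

method = TrainingMethodDPO(
    dpo_beta=0.1,
    dpo_normalize_logratios_by_length=True,
    dpo_reference_free=False,
    rpo_alpha=0.5,
    simpo_gamma=0.3,
)
# method="dpo" is fixed by the Literal default; the rest mirror the new fields.
print(method.model_dump())
```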
@@ -302,6 +320,10 @@ def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        dpo_reference_free: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -353,6 +375,10 @@ def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool): Whether to normalize logratios by sample length. Defaults to False.
+            dpo_reference_free (bool): Whether to skip reference logits usage. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -405,6 +431,10 @@ def create(
            train_on_inputs=train_on_inputs,
            training_method=training_method,
            dpo_beta=dpo_beta,
+           dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+           dpo_reference_free=dpo_reference_free,
+           rpo_alpha=rpo_alpha,
+           simpo_gamma=simpo_gamma,
            from_checkpoint=from_checkpoint,
        )
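
At the SDK level, the sync client now forwards the new knobs end to end. A hedged usage sketch (the file ID and model name are placeholders; the training file is expected to be an uploaded preference dataset, per the existing DPO flow):

```python
from together import Together

client = Together()  # reads TOGETHER_API_KEY from the environment

job = client.fine_tuning.create(
    training_file="file-abc123",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    training_method="dpo",
    dpo_beta=0.1,
    dpo_normalize_logratios_by_length=True,
    rpo_alpha=0.5,
    simpo_gamma=0.3,
)
print(job.id)
```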

@@ -714,6 +744,10 @@ async def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        dpo_reference_free: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -765,6 +799,10 @@ async def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool): Whether to normalize logratios by sample length. Defaults to False.
+            dpo_reference_free (bool): Whether to skip reference logits usage. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -817,6 +855,10 @@ async def create(
            train_on_inputs=train_on_inputs,
            training_method=training_method,
            dpo_beta=dpo_beta,
+           dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+           dpo_reference_free=dpo_reference_free,
+           rpo_alpha=rpo_alpha,
+           simpo_gamma=simpo_gamma,
            from_checkpoint=from_checkpoint,
        )

src/together/types/finetune.py (4 additions, 0 deletions)
@@ -159,6 +159,10 @@ class TrainingMethodDPO(TrainingMethod):
 
     method: Literal["dpo"] = "dpo"
     dpo_beta: float | None = None
+    dpo_normalize_logratios_by_length: bool = False
+    dpo_reference_free: bool = False
+    rpo_alpha: float | None = None
+    simpo_gamma: float | None = None
 
 
 class FinetuneRequest(BaseModel):
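
The training-side semantics are not part of this commit, but the four fields map onto well-known DPO variants: length-normalized log-ratios, the reference-free objective popularized by SimPO, SimPO's target reward margin gamma, and RPO's alpha-weighted NLL term. A rough, illustrative sketch of how such a loss is conventionally assembled in the literature (an assumption, not Together's implementation; the function name is hypothetical):

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def dpo_style_loss(
    logp_chosen: float,       # policy log-prob of the chosen completion (token sum)
    logp_rejected: float,     # policy log-prob of the rejected completion
    ref_logp_chosen: float,   # reference-model log-probs
    ref_logp_rejected: float,
    len_chosen: int,
    len_rejected: int,
    dpo_beta: float = 0.1,
    dpo_normalize_logratios_by_length: bool = False,
    dpo_reference_free: bool = False,
    rpo_alpha: float = 0.0,
    simpo_gamma: float = 0.0,
) -> float:
    # Reference-free mode (as in SimPO) drops the reference terms entirely.
    if dpo_reference_free:
        lr_chosen, lr_rejected = logp_chosen, logp_rejected
    else:
        lr_chosen = logp_chosen - ref_logp_chosen
        lr_rejected = logp_rejected - ref_logp_rejected
    # Optional length normalization of the log-ratios.
    if dpo_normalize_logratios_by_length:
        lr_chosen /= len_chosen
        lr_rejected /= len_rejected
    # simpo_gamma acts as a target reward margin; gamma=0 recovers plain DPO.
    margin = dpo_beta * (lr_chosen - lr_rejected) - simpo_gamma
    loss = -math.log(sigmoid(margin))
    # rpo_alpha weights an extra NLL term on the chosen completion (RPO).
    return loss + rpo_alpha * (-logp_chosen)
```

With all four extras left at their defaults this reduces to the standard DPO objective, which matches the defaults of the new fields above.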
