
Commit 959f9f1

Add the field for RPO-alpha
1 parent c09481b

3 files changed

Lines changed: 18 additions & 1 deletion


src/together/cli/api/finetune.py

Lines changed: 9 additions & 0 deletions
@@ -142,6 +142,13 @@ def fine_tuning(ctx: click.Context) -> None:
     default=0.1,
     help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
 )
+@click.option(
+    "--rpo-alpha",
+    type=float,
+    default=1.0,
+    help="RPO alpha to control the weight of NLL loss component for chosen responses "
+    "(only used when '--training-method' is 'dpo')",
+)
 @click.option(
     "--suffix",
     "-s",
@@ -206,6 +213,7 @@ def create(
     train_on_inputs: bool | Literal["auto"],
     training_method: str,
     dpo_beta: float,
+    rpo_alpha: float,
     from_checkpoint: str,
 ) -> None:
     """Start fine-tuning"""
@@ -239,6 +247,7 @@ def create(
         train_on_inputs=train_on_inputs,
         training_method=training_method,
         dpo_beta=dpo_beta,
+        rpo_alpha=rpo_alpha,
         from_checkpoint=from_checkpoint,
     )
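For reference, a usage sketch of the new flag, assuming the 'together fine-tuning create' entry point this file defines; the '--training-file' and '--model' values below are placeholders, not part of this diff:

together fine-tuning create \
  --training-file file-abc123 \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct-Reference \
  --training-method dpo \
  --dpo-beta 0.1 \
  --rpo-alpha 1.0

As the help text states, '--rpo-alpha' is only used when '--training-method' is 'dpo'; the default of 1.0 gives the NLL component on chosen responses full weight.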

src/together/resources/finetune.py

Lines changed: 8 additions & 1 deletion
@@ -80,6 +80,7 @@ def create_finetune_request(
     train_on_inputs: bool | Literal["auto"] = "auto",
     training_method: str = "sft",
     dpo_beta: float | None = None,
+    rpo_alpha: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
     if model is not None and from_checkpoint is not None:
@@ -193,7 +194,7 @@ def create_finetune_request(
 
     training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
     if training_method == "dpo":
-        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta, rpo_alpha=rpo_alpha)
 
     finetune_request = FinetuneRequest(
         model=model,
@@ -322,6 +323,7 @@ def create(
         train_on_inputs: bool | Literal["auto"] = "auto",
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        rpo_alpha: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -373,6 +375,7 @@ def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            rpo_alpha (float, optional): RPO alpha to control the weight of NLL loss component for chosen responses. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -425,6 +428,7 @@ def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            rpo_alpha=rpo_alpha,
             from_checkpoint=from_checkpoint,
         )
 
@@ -710,6 +714,7 @@ async def create(
         train_on_inputs: bool | Literal["auto"] = "auto",
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        rpo_alpha: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -761,6 +766,7 @@ async def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            rpo_alpha (float, optional): RPO alpha to control the weight of NLL loss component for chosen responses. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -813,6 +819,7 @@ async def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            rpo_alpha=rpo_alpha,
             from_checkpoint=from_checkpoint,
         )
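The same parameter flows through the Python client. A minimal sketch, assuming the SDK exposes this resource as client.fine_tuning and that TOGETHER_API_KEY is set; the training file ID and model name are placeholders, and only the rpo_alpha argument is new in this commit:

from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

# Placeholder training file ID and model name. Like dpo_beta, rpo_alpha
# is only consulted when training_method="dpo".
job = client.fine_tuning.create(
    training_file="file-abc123",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    training_method="dpo",
    dpo_beta=0.1,
    rpo_alpha=1.0,
)
print(job.id)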

src/together/types/finetune.py

Lines changed: 1 addition & 0 deletions
@@ -158,6 +158,7 @@ class TrainingMethodDPO(TrainingMethod):
 
     method: Literal["dpo"] = "dpo"
     dpo_beta: float | None = None
+    rpo_alpha: float | None = None
 
 
 class FinetuneRequest(BaseModel):
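For context on what the new field controls: in RPO-style training (e.g., Pang et al., 2024, "Iterative Reasoning Preference Optimization"), the DPO preference loss is augmented with a negative log-likelihood term on the chosen response. A sketch of that common formulation, with alpha playing the role of rpo_alpha; the exact server-side loss is not part of this diff:

    L_RPO(theta) = L_DPO(y_w, y_l | x; theta) + alpha * L_NLL(y_w | x; theta)

where y_w and y_l are the chosen and rejected responses for prompt x. Setting alpha to 0 recovers plain DPO, while larger values push the model harder toward reproducing the chosen responses.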
