Skip to content

Commit efefbbb

Browse files
add price estimation
1 parent 1201470 commit efefbbb

4 files changed

Lines changed: 238 additions & 13 deletions

File tree

src/together/cli/api/finetune.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
DownloadCheckpointType,
1818
FinetuneEventType,
1919
FinetuneTrainingLimits,
20+
FullTrainingType,
21+
LoRATrainingType,
2022
)
2123
from together.utils import (
2224
finetune_price_to_dollars,
@@ -36,6 +38,15 @@
3638
"Do you want to proceed?"
3739
)
3840

41+
_PRICE_ESTIMATION_CONFIRMATION_MESSAGE = (
42+
"The estimated price of the fine-tuning job is {} which is significantly "
43+
"greater than your current credit limit and balance. "
44+
"It will likely fail due to insufficient funds. "
45+
"Please consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
46+
"You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
47+
"Do you want to proceed?"
48+
)
49+
3950

4051
class DownloadCheckpointTypeChoice(click.Choice):
4152
def __init__(self) -> None:
@@ -358,20 +369,49 @@ def create(
358369
)
359370

360371
if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
361-
response = client.fine_tuning.create(
362-
**training_args,
363-
verbose=True,
372+
price_estimation_response = client.fine_tuning.estimate_price(
373+
training_file=training_file,
374+
validation_file=validation_file,
375+
model=model,
376+
n_epochs=n_epochs,
377+
n_evals=n_evals,
378+
training_type="lora" if lora else "full",
379+
training_method=training_method,
364380
)
365-
366-
report_string = f"Successfully submitted a fine-tuning job {response.id}"
367-
if response.created_at is not None:
368-
created_time = datetime.strptime(
369-
response.created_at, "%Y-%m-%dT%H:%M:%S.%f%z"
381+
proceed = (
382+
confirm
383+
or price_estimation_response.allowed_to_proceed
384+
or (
385+
not price_estimation_response.allowed_to_proceed
386+
and click.confirm(
387+
click.style(
388+
_PRICE_ESTIMATION_CONFIRMATION_MESSAGE.format(
389+
price_estimation_response.estimated_total_price
390+
),
391+
fg="red",
392+
bold=True,
393+
),
394+
default=True,
395+
show_default=True,
396+
)
397+
)
398+
)
399+
if proceed:
400+
response = client.fine_tuning.create(
401+
**training_args,
402+
verbose=True,
370403
)
371-
# created_at reports UTC time, we use .astimezone() to convert to local time
372-
formatted_time = created_time.astimezone().strftime("%m/%d/%Y, %H:%M:%S")
373-
report_string += f" at {formatted_time}"
374-
rprint(report_string)
404+
report_string = f"Successfully submitted a fine-tuning job {response.id}"
405+
if response.created_at is not None:
406+
created_time = datetime.strptime(
407+
response.created_at, "%Y-%m-%dT%H:%M:%S.%f%z"
408+
)
409+
# created_at reports UTC time, we use .astimezone() to convert to local time
410+
formatted_time = created_time.astimezone().strftime(
411+
"%m/%d/%Y, %H:%M:%S"
412+
)
413+
report_string += f" at {formatted_time}"
414+
rprint(report_string)
375415
else:
376416
click.echo("No confirmation received, stopping job launch")
377417

src/together/resources/finetune.py

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
FinetuneLRScheduler,
2121
FinetuneRequest,
2222
FinetuneResponse,
23+
FinetunePriceEstimationRequest,
24+
FinetunePriceEstimationResponse,
2325
FinetuneTrainingLimits,
2426
FullTrainingType,
2527
LinearLRScheduler,
@@ -31,7 +33,7 @@
3133
TrainingMethodSFT,
3234
TrainingType,
3335
)
34-
from together.types.finetune import DownloadCheckpointType
36+
from together.types.finetune import DownloadCheckpointType, TrainingMethod
3537
from together.utils import log_warn_once, normalize_key
3638

3739

@@ -42,6 +44,12 @@
4244
TrainingMethodSFT().method,
4345
TrainingMethodDPO().method,
4446
}
47+
_CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS = (
48+
"The estimated price of the fine-tuning job is {} which is significantly "
49+
"greater than your current credit limit and balance. "
50+
"It will likely fail due to insufficient funds. "
51+
"Please proceed at your own risk."
52+
)
4553

4654

4755
def create_finetune_request(
@@ -474,11 +482,29 @@ def create(
474482
hf_output_repo_name=hf_output_repo_name,
475483
)
476484

485+
price_estimation_result = self.estimate_price(
486+
training_file=training_file,
487+
validation_file=validation_file,
488+
model=model_name,
489+
n_epochs=n_epochs,
490+
n_evals=n_evals,
491+
training_type="lora" if lora else "full",
492+
training_method=training_method,
493+
)
494+
477495
if verbose:
478496
rprint(
479497
"Submitting a fine-tuning job with the following parameters:",
480498
finetune_request,
481499
)
500+
if not price_estimation_result.allowed_to_proceed:
501+
rprint(
502+
"[red]"
503+
+ _CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS.format(
504+
price_estimation_result.estimated_total_price
505+
)
506+
+ "[/red]",
507+
)
482508
parameter_payload = finetune_request.model_dump(exclude_none=True)
483509

484510
response, _, _ = requestor.request(
@@ -493,6 +519,73 @@ def create(
493519

494520
return FinetuneResponse(**response.data)
495521

522+
def estimate_price(
    self,
    *,
    training_file: str,
    model: str | None,
    validation_file: str | None = None,
    n_epochs: int | None = None,
    n_evals: int | None = None,
    training_type: str = "lora",
    training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
    """
    Estimate the price of a fine-tuning job.

    Args:
        training_file (str): File-ID of the uploaded training dataset.
        model (str, optional): Name of the base model to fine-tune.
            NOTE(review): annotated as optional, but the request model declares
            `model: str` (required) — confirm whether None is actually accepted.
        validation_file (str, optional): File-ID of the uploaded validation dataset.
        n_epochs (int, optional): Number of training epochs.
        n_evals (int, optional): Number of evaluation passes over the validation set.
        training_type (str): Either "lora" or "full" (case-insensitive).
            Defaults to "lora".
        training_method (str): Either "sft" or "dpo" (case-insensitive).
            Defaults to "sft".

    Returns:
        FinetunePriceEstimationResponse: Object containing the estimated price.

    Raises:
        ValueError: If `training_method` or `training_type` is not recognized.
    """
    # Normalize once so both selectors are case-insensitive and the two
    # string parameters are handled consistently (previously only
    # training_type was lower-cased).
    method_key = training_method.lower()
    type_key = training_type.lower()

    if method_key == "sft":
        training_method_cls: TrainingMethod = TrainingMethodSFT(method="sft")
    elif method_key == "dpo":
        training_method_cls = TrainingMethodDPO(method="dpo")
    else:
        raise ValueError(f"Unknown training method: {training_method}")

    if type_key == "lora":
        # The concrete adapter hyperparameters do not matter for pricing;
        # these defaults only satisfy the request schema.
        training_type_cls: TrainingType = LoRATrainingType(
            type="Lora",
            lora_r=16,
            lora_alpha=16,
            lora_dropout=0.0,
            lora_trainable_modules="all-linear",
        )
    elif type_key == "full":
        training_type_cls = FullTrainingType(type="Full")
    else:
        raise ValueError(f"Unknown training type: {training_type}")

    request = FinetunePriceEstimationRequest(
        training_file=training_file,
        validation_file=validation_file,
        model=model,
        n_epochs=n_epochs,
        n_evals=n_evals,
        training_type=training_type_cls,
        training_method=training_method_cls,
    )
    # exclude_none keeps unset optional fields out of the JSON payload.
    parameter_payload = request.model_dump(exclude_none=True)
    requestor = api_requestor.APIRequestor(
        client=self._client,
    )

    response, _, _ = requestor.request(
        options=TogetherRequest(
            method="POST", url="fine-tunes/estimate-price", params=parameter_payload
        ),
        stream=False,
    )
    assert isinstance(response, TogetherResponse)

    return FinetunePriceEstimationResponse(**response.data)
496589
def list(self) -> FinetuneList:
497590
"""
498591
Lists fine-tune job history
@@ -941,11 +1034,29 @@ async def create(
9411034
hf_output_repo_name=hf_output_repo_name,
9421035
)
9431036

1037+
price_estimation_result = await self.estimate_price(
1038+
training_file=training_file,
1039+
validation_file=validation_file,
1040+
model=model_name,
1041+
n_epochs=n_epochs,
1042+
n_evals=n_evals,
1043+
training_type=finetune_request.training_type,
1044+
training_method=finetune_request.training_method,
1045+
)
1046+
9441047
if verbose:
9451048
rprint(
9461049
"Submitting a fine-tuning job with the following parameters:",
9471050
finetune_request,
9481051
)
1052+
if not price_estimation_result.allowed_to_proceed:
1053+
rprint(
1054+
"[red]"
1055+
+ _CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS.format(
1056+
price_estimation_result.estimated_total_price
1057+
)
1058+
+ "[/red]",
1059+
)
9491060
parameter_payload = finetune_request.model_dump(exclude_none=True)
9501061

9511062
response, _, _ = await requestor.arequest(
@@ -961,6 +1072,50 @@ async def create(
9611072

9621073
return FinetuneResponse(**response.data)
9631074

1075+
async def estimate_price(
    self,
    *,
    training_file: str,
    model: str,
    validation_file: str | None = None,
    n_epochs: int | None = None,
    n_evals: int | None = None,
    training_type: TrainingType | str | None = None,
    training_method: TrainingMethodSFT | TrainingMethodDPO | str | None = None,
) -> FinetunePriceEstimationResponse:
    """
    Async method to estimate the price of a fine-tuning job.

    Args:
        training_file (str): File-ID of the uploaded training dataset.
        model (str): Name of the base model to fine-tune.
        validation_file (str, optional): File-ID of the uploaded validation dataset.
        n_epochs (int, optional): Number of training epochs.
        n_evals (int, optional): Number of evaluation passes over the validation set.
        training_type (TrainingType | str, optional): Training configuration, or the
            plain strings "lora"/"full" as accepted by the synchronous client.
        training_method (TrainingMethodSFT | TrainingMethodDPO | str, optional):
            Training method object, or the plain strings "sft"/"dpo".
            NOTE(review): the request model declares `training_method` as required,
            so leaving this as None will fail validation — confirm intended default.

    Returns:
        FinetunePriceEstimationResponse: Object containing the estimated price.

    Raises:
        ValueError: If a string `training_type`/`training_method` is not recognized.
    """
    # Accept the plain-string spellings used by the synchronous client so the
    # two APIs stay consistent; already-constructed objects pass through as-is.
    if isinstance(training_type, str):
        type_key = training_type.lower()
        if type_key == "lora":
            # Adapter hyperparameters are irrelevant for pricing; these
            # defaults only satisfy the request schema.
            training_type = LoRATrainingType(
                type="Lora",
                lora_r=16,
                lora_alpha=16,
                lora_dropout=0.0,
                lora_trainable_modules="all-linear",
            )
        elif type_key == "full":
            training_type = FullTrainingType(type="Full")
        else:
            raise ValueError(f"Unknown training type: {training_type}")

    if isinstance(training_method, str):
        method_key = training_method.lower()
        if method_key == "sft":
            training_method = TrainingMethodSFT(method="sft")
        elif method_key == "dpo":
            training_method = TrainingMethodDPO(method="dpo")
        else:
            raise ValueError(f"Unknown training method: {training_method}")

    request = FinetunePriceEstimationRequest(
        training_file=training_file,
        validation_file=validation_file,
        model=model,
        n_epochs=n_epochs,
        n_evals=n_evals,
        training_type=training_type,
        training_method=training_method,
    )
    # exclude_none keeps unset optional fields out of the JSON payload.
    parameter_payload = request.model_dump(exclude_none=True)
    requestor = api_requestor.APIRequestor(
        client=self._client,
    )

    response, _, _ = await requestor.arequest(
        options=TogetherRequest(
            method="POST", url="fine-tunes/estimate-price", params=parameter_payload
        ),
        stream=False,
    )
    assert isinstance(response, TogetherResponse)

    return FinetunePriceEstimationResponse(**response.data)
1118+
9641119
async def list(self) -> FinetuneList:
9651120
"""
9661121
Async method to list fine-tune job history

src/together/types/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
FinetuneListEvents,
5555
FinetuneRequest,
5656
FinetuneResponse,
57+
FinetunePriceEstimationRequest,
58+
FinetunePriceEstimationResponse,
5759
FinetuneDeleteResponse,
5860
FinetuneTrainingLimits,
5961
FullTrainingType,
@@ -103,6 +105,8 @@
103105
"FinetuneDeleteResponse",
104106
"FinetuneDownloadResult",
105107
"FinetuneLRScheduler",
108+
"FinetunePriceEstimationRequest",
109+
"FinetunePriceEstimationResponse",
106110
"LinearLRScheduler",
107111
"LinearLRSchedulerArgs",
108112
"CosineLRScheduler",

src/together/types/finetune.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,32 @@ def validate_training_type(cls, v: TrainingType) -> TrainingType:
308308
raise ValueError("Unknown training type")
309309

310310

311+
class FinetunePriceEstimationRequest(BaseModel):
    """
    Fine-tune price estimation request type
    """

    # File-ID of the uploaded training dataset
    training_file: str
    # File-ID of the uploaded validation dataset, if any
    validation_file: str | None = None
    # Name of the base model to fine-tune
    model: str
    # Number of training epochs
    n_epochs: int | None = None
    # Number of evaluation passes over the validation set
    n_evals: int | None = None
    # Full or LoRA training configuration; optional
    training_type: TrainingType | None = None
    # Training method (SFT or DPO). NOTE(review): required here, but callers
    # that default it to None would fail validation — confirm intended optionality
    training_method: TrainingMethodSFT | TrainingMethodDPO
323+
324+
325+
class FinetunePriceEstimationResponse(BaseModel):
    """
    Fine-tune price estimation response type
    """

    # Estimated total price of the job — presumably in USD; confirm units with the API
    estimated_total_price: float
    # User's current credit limit used for the allowed_to_proceed check
    user_limit: float
    # Estimated number of tokens processed during training
    estimated_train_token_count: int
    # Estimated number of tokens processed during evaluation
    estimated_eval_token_count: int
    # Whether the estimated price fits within the user's limit/balance
    allowed_to_proceed: bool
336+
311337
class FinetuneList(BaseModel):
312338
# object type
313339
object: Literal["list"] | None = None

0 commit comments

Comments
 (0)