
Commit dea7fdf

address None and comments from review
1 parent f73533d commit dea7fdf

4 files changed

Lines changed: 206 additions & 39 deletions

File tree

src/together/cli/api/finetune.py
src/together/resources/finetune.py
src/together/types/finetune.py
tests/unit/test_finetune_resources.py

src/together/cli/api/finetune.py

Lines changed: 2 additions & 2 deletions
@@ -42,8 +42,8 @@

 _WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
     "The estimated price of this job is significantly greater than your current credit limit and balance combined. "
-    "It will likely fail due to insufficient funds. "
-    "Please consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
+    "It will likely get cancelled due to insufficient funds. "
+    "Consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
 )


src/together/resources/finetune.py

Lines changed: 77 additions & 33 deletions
@@ -47,8 +47,8 @@
 _WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
     "The estimated price of the fine-tuning job is {} which is significantly "
     "greater than your current credit limit and balance combined. "
-    "It will likely fail due to insufficient funds. "
-    "Please proceed at your own risk."
+    "It will likely get cancelled due to insufficient funds. "
+    "Proceed at your own risk."
 )


@@ -481,16 +481,25 @@ def create(
             hf_api_token=hf_api_token,
             hf_output_repo_name=hf_output_repo_name,
         )
-
-        price_estimation_result = self.estimate_price(
-            training_file=training_file,
-            validation_file=validation_file,
-            model=model_name,
-            n_epochs=n_epochs,
-            n_evals=n_evals,
-            training_type="lora" if lora else "full",
-            training_method=training_method,
-        )
+        if from_checkpoint is None:
+            price_estimation_result = self.estimate_price(
+                training_file=training_file,
+                validation_file=validation_file,
+                model=model_name,
+                n_epochs=finetune_request.n_epochs,
+                n_evals=finetune_request.n_evals,
+                training_type="lora" if lora else "full",
+                training_method=training_method,
+            )
+        else:
+            # unsupported case
+            price_estimation_result = FinetunePriceEstimationResponse(
+                estimated_total_price=0.0,
+                allowed_to_proceed=True,
+                estimated_train_token_count=0,
+                estimated_eval_token_count=0,
+                user_limit=0.0,
+            )

         if verbose:
             rprint(
@@ -523,10 +532,10 @@ def estimate_price(
         self,
         *,
         training_file: str,
-        model: str | None,
+        model: str,
         validation_file: str | None = None,
-        n_epochs: int | None = None,
-        n_evals: int | None = None,
+        n_epochs: int | None = 1,
+        n_evals: int | None = 0,
         training_type: str = "lora",
         training_method: str = "sft",
     ) -> FinetunePriceEstimationResponse:
@@ -539,8 +548,8 @@ def estimate_price(
         Returns:
             FinetunePriceEstimationResponse: Object containing the estimated price.
         """
-        training_type_cls: TrainingType | None = None
-        training_method_cls: TrainingMethod | None = None
+        training_type_cls: TrainingType
+        training_method_cls: TrainingMethod

         if training_method == "sft":
             training_method_cls = TrainingMethodSFT(method="sft")
@@ -1036,15 +1045,25 @@ async def create(
             hf_output_repo_name=hf_output_repo_name,
         )

-        price_estimation_result = await self.estimate_price(
-            training_file=training_file,
-            validation_file=validation_file,
-            model=model_name,
-            n_epochs=n_epochs,
-            n_evals=n_evals,
-            training_type=finetune_request.training_type,
-            training_method=finetune_request.training_method,
-        )
+        if from_checkpoint is not None:
+            price_estimation_result = await self.estimate_price(
+                training_file=training_file,
+                validation_file=validation_file,
+                model=model_name,
+                n_epochs=finetune_request.n_epochs,
+                n_evals=finetune_request.n_evals,
+                training_type="lora" if lora else "full",
+                training_method=training_method,
+            )
+        else:
+            # unsupported case
+            price_estimation_result = FinetunePriceEstimationResponse(
+                estimated_total_price=0.0,
+                allowed_to_proceed=True,
+                estimated_train_token_count=0,
+                estimated_eval_token_count=0,
+                user_limit=0.0,
+            )

         if verbose:
             rprint(
@@ -1080,28 -1099,53 @@ async def estimate_price(
         training_file: str,
         model: str,
         validation_file: str | None = None,
-        n_epochs: int | None = None,
-        n_evals: int | None = None,
-        training_type: TrainingType | None = None,
-        training_method: TrainingMethodSFT | TrainingMethodDPO | None = None,
+        n_epochs: int | None = 1,
+        n_evals: int | None = 0,
+        training_type: str = "lora",
+        training_method: str = "sft",
     ) -> FinetunePriceEstimationResponse:
         """
-        Async method to estimate the price of a fine-tuning job
+        Estimates the price of a fine-tuning job

         Args:
             request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.

         Returns:
             FinetunePriceEstimationResponse: Object containing the estimated price.
         """
+        training_type_cls: TrainingType
+        training_method_cls: TrainingMethod
+
+        if training_method == "sft":
+            training_method_cls = TrainingMethodSFT(method="sft")
+        elif training_method == "dpo":
+            training_method_cls = TrainingMethodDPO(method="dpo")
+        else:
+            raise ValueError(f"Unknown training method: {training_method}")
+
+        if training_type.lower() == "lora":
+            # parameters of lora are unused in price estimation
+            # but we need to set them to valid values
+            training_type_cls = LoRATrainingType(
+                type="Lora",
+                lora_r=16,
+                lora_alpha=16,
+                lora_dropout=0.0,
+                lora_trainable_modules="all-linear",
+            )
+        elif training_type.lower() == "full":
+            training_type_cls = FullTrainingType(type="Full")
+        else:
+            raise ValueError(f"Unknown training type: {training_type}")
+
         request = FinetunePriceEstimationRequest(
             training_file=training_file,
             validation_file=validation_file,
             model=model,
             n_epochs=n_epochs,
             n_evals=n_evals,
-            training_type=training_type,
-            training_method=training_method,
+            training_type=training_type_cls,
+            training_method=training_method_cls,
         )
         parameter_payload = request.model_dump(exclude_none=True)
         requestor = api_requestor.APIRequestor(
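
With this change, estimate_price on both the sync and async clients takes plain strings for training_type ("lora" or "full") and training_method ("sft" or "dpo") and builds the typed request objects internally. A minimal usage sketch under that assumption; the training file ID and model name below are placeholders for illustration, not values from this commit:

from together.client import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

# Hypothetical file ID and model name, used only for illustration.
estimate = client.fine_tuning.estimate_price(
    training_file="file-1234abcd",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    n_epochs=1,
    n_evals=0,
    training_type="lora",   # "lora" or "full"
    training_method="sft",  # "sft" or "dpo"
)
print(estimate.estimated_total_price, estimate.allowed_to_proceed)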

src/together/types/finetune.py

Lines changed: 4 additions & 4 deletions
@@ -316,10 +316,10 @@ class FinetunePriceEstimationRequest(BaseModel):
     training_file: str
     validation_file: str | None = None
     model: str
-    n_epochs: int | None = None
-    n_evals: int | None = None
-    training_type: TrainingType | None = None
-    training_method: TrainingMethodSFT | TrainingMethodDPO
+    n_epochs: int
+    n_evals: int
+    training_type: TrainingType
+    training_method: TrainingMethod


 class FinetunePriceEstimationResponse(BaseModel):
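
FinetunePriceEstimationRequest now requires n_epochs, n_evals, training_type, and training_method instead of accepting None. A short sketch of how the updated estimate_price assembles this request, mirroring the values in the diff above; the file ID and model name are placeholders, and the imports assume these classes are all exported from together.types.finetune:

from together.types.finetune import (
    FinetunePriceEstimationRequest,
    LoRATrainingType,
    TrainingMethodSFT,
)

request = FinetunePriceEstimationRequest(
    training_file="file-1234abcd",  # hypothetical file ID
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # hypothetical model
    n_epochs=1,
    n_evals=0,
    # LoRA hyperparameters are unused for price estimation but must be valid.
    training_type=LoRATrainingType(
        type="Lora",
        lora_r=16,
        lora_alpha=16,
        lora_dropout=0.0,
        lora_trainable_modules="all-linear",
    ),
    training_method=TrainingMethodSFT(method="sft"),
)
payload = request.model_dump(exclude_none=True)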

tests/unit/test_finetune_resources.py

Lines changed: 123 additions & 0 deletions
@@ -1,6 +1,10 @@
 import pytest
+from unittest.mock import MagicMock, Mock, patch

+from together.client import Together
 from together.resources.finetune import create_finetune_request
+from together.together_response import TogetherResponse
+from together.types import TogetherRequest
 from together.types.finetune import (
     FinetuneFullTrainingLimits,
     FinetuneLoraTrainingLimits,
@@ -31,6 +35,41 @@
 )


+def mock_request(options: TogetherRequest, *args, **kwargs):
+    if options.url == "fine-tunes/estimate-price":
+        return (
+            TogetherResponse(
+                data={
+                    "estimated_total_price": 100,
+                    "allowed_to_proceed": True,
+                    "estimated_train_token_count": 1000,
+                    "estimated_eval_token_count": 100,
+                    "user_limit": 1000,
+                },
+                headers={},
+            ),
+            None,
+            None,
+        )
+    elif options.url == "fine-tunes":
+        return (
+            TogetherResponse(
+                data={
+                    "id": "ft-12345678-1234-1234-1234-1234567890ab",
+                },
+                headers={},
+            ),
+            None,
+            None,
+        )
+    else:
+        return (
+            TogetherResponse(data=_MODEL_LIMITS.model_dump(), headers={}),
+            None,
+            None,
+        )
+
+
 def test_simple_request():
     request = create_finetune_request(
         model_limits=_MODEL_LIMITS,
@@ -335,3 +374,87 @@ def test_train_on_inputs_not_supported_for_dpo():
             training_method="dpo",
             train_on_inputs=True,
         )
+
+
+@patch("together.abstract.api_requestor.APIRequestor.request")
+def test_price_estimation_request(mocker):
+    test_data = [
+        {
+            "training_type": "lora",
+            "training_method": "sft",
+        },
+        {
+            "training_type": "lora",
+            "training_method": "dpo",
+        },
+        {
+            "training_type": "full",
+            "training_method": "sft",
+        },
+    ]
+    mocker.return_value = (
+        TogetherResponse(
+            data={
+                "estimated_total_price": 100,
+                "allowed_to_proceed": True,
+                "estimated_train_token_count": 1000,
+                "estimated_eval_token_count": 100,
+                "user_limit": 1000,
+            },
+            headers={},
+        ),
+        None,
+        None,
+    )
+    client = Together()
+    for test_case in test_data:
+        response = client.fine_tuning.estimate_price(
+            training_file=_TRAINING_FILE,
+            model=_MODEL_NAME,
+            validation_file=_VALIDATION_FILE,
+            n_epochs=1,
+            n_evals=0,
+            training_type=test_case["training_type"],
+            training_method=test_case["training_method"],
+        )
+        assert response.estimated_total_price > 0
+        assert response.allowed_to_proceed
+        assert response.estimated_train_token_count > 0
+        assert response.estimated_eval_token_count > 0
+
+
+def test_create_ft_job(mocker):
+    mock_requestor = Mock()
+    mock_requestor.request = MagicMock()
+    mock_requestor.request.side_effect = mock_request
+    mocker.patch(
+        "together.abstract.api_requestor.APIRequestor", return_value=mock_requestor
+    )
+
+    client = Together()
+    response = client.fine_tuning.create(
+        training_file=_TRAINING_FILE,
+        model=_MODEL_NAME,
+        validation_file=_VALIDATION_FILE,
+        n_epochs=1,
+        n_evals=0,
+        lora=True,
+        training_method="sft",
+    )
+
+    assert mock_requestor.request.call_count == 3
+    assert response.id == "ft-12345678-1234-1234-1234-1234567890ab"
+
+    response = client.fine_tuning.create(
+        training_file=_TRAINING_FILE,
+        model=None,
+        validation_file=_VALIDATION_FILE,
+        n_epochs=1,
+        n_evals=0,
+        lora=True,
+        training_method="sft",
+        from_checkpoint=_FROM_CHECKPOINT,
+    )
+
+    assert mock_requestor.request.call_count == 5
+    assert response.id == "ft-12345678-1234-1234-1234-1234567890ab"
