2020 FinetuneLRScheduler ,
2121 FinetuneRequest ,
2222 FinetuneResponse ,
23+ FinetunePriceEstimationRequest ,
24+ FinetunePriceEstimationResponse ,
2325 FinetuneTrainingLimits ,
2426 FullTrainingType ,
2527 LinearLRScheduler ,
3133 TrainingMethodSFT ,
3234 TrainingType ,
3335)
34- from together .types .finetune import DownloadCheckpointType
36+ from together .types .finetune import DownloadCheckpointType , TrainingMethod
3537from together .utils import log_warn_once , normalize_key
3638
3739
4244 TrainingMethodSFT ().method ,
4345 TrainingMethodDPO ().method ,
4446}
47+ _CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS = (
48+ "The estimated price of the fine-tuning job is {} which is significantly "
49+ "greater than your current credit limit and balance. "
50+ "It will likely fail due to insufficient funds. "
51+ "Please proceed at your own risk."
52+ )
4553
4654
4755def create_finetune_request (
@@ -474,11 +482,29 @@ def create(
474482 hf_output_repo_name = hf_output_repo_name ,
475483 )
476484
485+ price_estimation_result = self .estimate_price (
486+ training_file = training_file ,
487+ validation_file = validation_file ,
488+ model = model_name ,
489+ n_epochs = n_epochs ,
490+ n_evals = n_evals ,
491+ training_type = "lora" if lora else "full" ,
492+ training_method = training_method ,
493+ )
494+
477495 if verbose :
478496 rprint (
479497 "Submitting a fine-tuning job with the following parameters:" ,
480498 finetune_request ,
481499 )
500+ if not price_estimation_result .allowed_to_proceed :
501+ rprint (
502+ "[red]"
503+ + _CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS .format (
504+ price_estimation_result .estimated_total_price
505+ )
506+ + "[/red]" ,
507+ )
482508 parameter_payload = finetune_request .model_dump (exclude_none = True )
483509
484510 response , _ , _ = requestor .request (
@@ -493,6 +519,73 @@ def create(
493519
494520 return FinetuneResponse (** response .data )
495521
522+ def estimate_price (
523+ self ,
524+ * ,
525+ training_file : str ,
526+ model : str | None ,
527+ validation_file : str | None = None ,
528+ n_epochs : int | None = None ,
529+ n_evals : int | None = None ,
530+ training_type : str = "lora" ,
531+ training_method : str = "sft" ,
532+ ) -> FinetunePriceEstimationResponse :
533+ """
534+ Estimates the price of a fine-tuning job
535+
536+ Args:
537+ request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.
538+
539+ Returns:
540+ FinetunePriceEstimationResponse: Object containing the estimated price.
541+ """
542+ training_type_cls : TrainingType | None = None
543+ training_method_cls : TrainingMethod | None = None
544+
545+ if training_method == "sft" :
546+ training_method_cls = TrainingMethodSFT (method = "sft" )
547+ elif training_method == "dpo" :
548+ training_method_cls = TrainingMethodDPO (method = "dpo" )
549+ else :
550+ raise ValueError (f"Unknown training method: { training_method } " )
551+
552+ if training_type .lower () == "lora" :
553+ training_type_cls = LoRATrainingType (
554+ type = "Lora" ,
555+ lora_r = 16 ,
556+ lora_alpha = 16 ,
557+ lora_dropout = 0.0 ,
558+ lora_trainable_modules = "all-linear" ,
559+ )
560+ elif training_type .lower () == "full" :
561+ training_type_cls = FullTrainingType (type = "Full" )
562+ else :
563+ raise ValueError (f"Unknown training type: { training_type } " )
564+
565+ request = FinetunePriceEstimationRequest (
566+ training_file = training_file ,
567+ validation_file = validation_file ,
568+ model = model ,
569+ n_epochs = n_epochs ,
570+ n_evals = n_evals ,
571+ training_type = training_type_cls ,
572+ training_method = training_method_cls ,
573+ )
574+ parameter_payload = request .model_dump (exclude_none = True )
575+ requestor = api_requestor .APIRequestor (
576+ client = self ._client ,
577+ )
578+
579+ response , _ , _ = requestor .request (
580+ options = TogetherRequest (
581+ method = "POST" , url = "fine-tunes/estimate-price" , params = parameter_payload
582+ ),
583+ stream = False ,
584+ )
585+ assert isinstance (response , TogetherResponse )
586+
587+ return FinetunePriceEstimationResponse (** response .data )
588+
496589 def list (self ) -> FinetuneList :
497590 """
498591 Lists fine-tune job history
@@ -941,11 +1034,29 @@ async def create(
9411034 hf_output_repo_name = hf_output_repo_name ,
9421035 )
9431036
1037+ price_estimation_result = await self .estimate_price (
1038+ training_file = training_file ,
1039+ validation_file = validation_file ,
1040+ model = model_name ,
1041+ n_epochs = n_epochs ,
1042+ n_evals = n_evals ,
1043+ training_type = finetune_request .training_type ,
1044+ training_method = finetune_request .training_method ,
1045+ )
1046+
9441047 if verbose :
9451048 rprint (
9461049 "Submitting a fine-tuning job with the following parameters:" ,
9471050 finetune_request ,
9481051 )
1052+ if not price_estimation_result .allowed_to_proceed :
1053+ rprint (
1054+ "[red]"
1055+ + _CONFIRMATION_MESSAGE_INSUFFICIENT_FUNDS .format (
1056+ price_estimation_result .estimated_total_price
1057+ )
1058+ + "[/red]" ,
1059+ )
9491060 parameter_payload = finetune_request .model_dump (exclude_none = True )
9501061
9511062 response , _ , _ = await requestor .arequest (
@@ -961,6 +1072,50 @@ async def create(
9611072
9621073 return FinetuneResponse (** response .data )
9631074
1075+ async def estimate_price (
1076+ self ,
1077+ * ,
1078+ training_file : str ,
1079+ model : str ,
1080+ validation_file : str | None = None ,
1081+ n_epochs : int | None = None ,
1082+ n_evals : int | None = None ,
1083+ training_type : TrainingType | None = None ,
1084+ training_method : TrainingMethodSFT | TrainingMethodDPO | None = None ,
1085+ ) -> FinetunePriceEstimationResponse :
1086+ """
1087+ Async method to estimate the price of a fine-tuning job
1088+
1089+ Args:
1090+ request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.
1091+
1092+ Returns:
1093+ FinetunePriceEstimationResponse: Object containing the estimated price.
1094+ """
1095+ request = FinetunePriceEstimationRequest (
1096+ training_file = training_file ,
1097+ validation_file = validation_file ,
1098+ model = model ,
1099+ n_epochs = n_epochs ,
1100+ n_evals = n_evals ,
1101+ training_type = training_type ,
1102+ training_method = training_method ,
1103+ )
1104+ parameter_payload = request .model_dump (exclude_none = True )
1105+ requestor = api_requestor .APIRequestor (
1106+ client = self ._client ,
1107+ )
1108+
1109+ response , _ , _ = await requestor .arequest (
1110+ options = TogetherRequest (
1111+ method = "POST" , url = "fine-tunes/estimate-price" , params = parameter_payload
1112+ ),
1113+ stream = False ,
1114+ )
1115+ assert isinstance (response , TogetherResponse )
1116+
1117+ return FinetunePriceEstimationResponse (** response .data )
1118+
9641119 async def list (self ) -> FinetuneList :
9651120 """
9661121 Async method to list fine-tune job history
0 commit comments