4747_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
4848 "The estimated price of the fine-tuning job is {} which is significantly "
4949 "greater than your current credit limit and balance combined. "
50- "It will likely fail due to insufficient funds. "
51- "Please proceed at your own risk."
50+ "It will likely get cancelled due to insufficient funds. "
51+ "Proceed at your own risk."
5252)
5353
5454
@@ -481,16 +481,25 @@ def create(
481481 hf_api_token = hf_api_token ,
482482 hf_output_repo_name = hf_output_repo_name ,
483483 )
484-
485- price_estimation_result = self .estimate_price (
486- training_file = training_file ,
487- validation_file = validation_file ,
488- model = model_name ,
489- n_epochs = n_epochs ,
490- n_evals = n_evals ,
491- training_type = "lora" if lora else "full" ,
492- training_method = training_method ,
493- )
484+ if from_checkpoint is None :
485+ price_estimation_result = self .estimate_price (
486+ training_file = training_file ,
487+ validation_file = validation_file ,
488+ model = model_name ,
489+ n_epochs = finetune_request .n_epochs ,
490+ n_evals = finetune_request .n_evals ,
491+ training_type = "lora" if lora else "full" ,
492+ training_method = training_method ,
493+ )
494+ else :
495+ # unsupported case
496+ price_estimation_result = FinetunePriceEstimationResponse (
497+ estimated_total_price = 0.0 ,
498+ allowed_to_proceed = True ,
499+ estimated_train_token_count = 0 ,
500+ estimated_eval_token_count = 0 ,
501+ user_limit = 0.0 ,
502+ )
494503
495504 if verbose :
496505 rprint (
@@ -523,10 +532,10 @@ def estimate_price(
523532 self ,
524533 * ,
525534 training_file : str ,
526- model : str | None ,
535+ model : str ,
527536 validation_file : str | None = None ,
528- n_epochs : int | None = None ,
529- n_evals : int | None = None ,
537+ n_epochs : int | None = 1 ,
538+ n_evals : int | None = 0 ,
530539 training_type : str = "lora" ,
531540 training_method : str = "sft" ,
532541 ) -> FinetunePriceEstimationResponse :
@@ -539,8 +548,8 @@ def estimate_price(
539548 Returns:
540549 FinetunePriceEstimationResponse: Object containing the estimated price.
541550 """
542- training_type_cls : TrainingType | None = None
543- training_method_cls : TrainingMethod | None = None
551+ training_type_cls : TrainingType
552+ training_method_cls : TrainingMethod
544553
545554 if training_method == "sft" :
546555 training_method_cls = TrainingMethodSFT (method = "sft" )
@@ -1036,15 +1045,25 @@ async def create(
10361045 hf_output_repo_name = hf_output_repo_name ,
10371046 )
10381047
1039- price_estimation_result = await self .estimate_price (
1040- training_file = training_file ,
1041- validation_file = validation_file ,
1042- model = model_name ,
1043- n_epochs = n_epochs ,
1044- n_evals = n_evals ,
1045- training_type = finetune_request .training_type ,
1046- training_method = finetune_request .training_method ,
1047- )
1048+ if from_checkpoint is not None :
1049+ price_estimation_result = await self .estimate_price (
1050+ training_file = training_file ,
1051+ validation_file = validation_file ,
1052+ model = model_name ,
1053+ n_epochs = finetune_request .n_epochs ,
1054+ n_evals = finetune_request .n_evals ,
1055+ training_type = "lora" if lora else "full" ,
1056+ training_method = training_method ,
1057+ )
1058+ else :
1059+ # unsupported case
1060+ price_estimation_result = FinetunePriceEstimationResponse (
1061+ estimated_total_price = 0.0 ,
1062+ allowed_to_proceed = True ,
1063+ estimated_train_token_count = 0 ,
1064+ estimated_eval_token_count = 0 ,
1065+ user_limit = 0.0 ,
1066+ )
10481067
10491068 if verbose :
10501069 rprint (
@@ -1080,28 +1099,53 @@ async def estimate_price(
10801099 training_file : str ,
10811100 model : str ,
10821101 validation_file : str | None = None ,
1083- n_epochs : int | None = None ,
1084- n_evals : int | None = None ,
1085- training_type : TrainingType | None = None ,
1086- training_method : TrainingMethodSFT | TrainingMethodDPO | None = None ,
1102+ n_epochs : int | None = 1 ,
1103+ n_evals : int | None = 0 ,
1104+ training_type : str = "lora" ,
1105+ training_method : str = "sft" ,
10871106 ) -> FinetunePriceEstimationResponse :
10881107 """
1089- Async method to estimate the price of a fine-tuning job
1108+ Estimates the price of a fine-tuning job
10901109
10911110 Args:
10921111 request (FinetunePriceEstimationRequest): Request object containing the parameters for the price estimation.
10931112
10941113 Returns:
10951114 FinetunePriceEstimationResponse: Object containing the estimated price.
10961115 """
1116+ training_type_cls : TrainingType
1117+ training_method_cls : TrainingMethod
1118+
1119+ if training_method == "sft" :
1120+ training_method_cls = TrainingMethodSFT (method = "sft" )
1121+ elif training_method == "dpo" :
1122+ training_method_cls = TrainingMethodDPO (method = "dpo" )
1123+ else :
1124+ raise ValueError (f"Unknown training method: { training_method } " )
1125+
1126+ if training_type .lower () == "lora" :
1127+ # parameters of lora are unused in price estimation
1128+ # but we need to set them to valid values
1129+ training_type_cls = LoRATrainingType (
1130+ type = "Lora" ,
1131+ lora_r = 16 ,
1132+ lora_alpha = 16 ,
1133+ lora_dropout = 0.0 ,
1134+ lora_trainable_modules = "all-linear" ,
1135+ )
1136+ elif training_type .lower () == "full" :
1137+ training_type_cls = FullTrainingType (type = "Full" )
1138+ else :
1139+ raise ValueError (f"Unknown training type: { training_type } " )
1140+
10971141 request = FinetunePriceEstimationRequest (
10981142 training_file = training_file ,
10991143 validation_file = validation_file ,
11001144 model = model ,
11011145 n_epochs = n_epochs ,
11021146 n_evals = n_evals ,
1103- training_type = training_type ,
1104- training_method = training_method ,
1147+ training_type = training_type_cls ,
1148+ training_method = training_method_cls ,
11051149 )
11061150 parameter_payload = request .model_dump (exclude_none = True )
11071151 requestor = api_requestor .APIRequestor (
0 commit comments