Skip to content

Commit 9b15370

Browse files
committed
format and add try/except
1 parent 90d5a6d commit 9b15370

1 file changed

Lines changed: 14 additions & 3 deletions

File tree

src/together/resources/finetune.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,9 @@ def download(
570570
*,
571571
output: Path | str | None = None,
572572
checkpoint_step: int | None = None,
573-
checkpoint_type: Union[DownloadCheckpointType, str] = DownloadCheckpointType.DEFAULT,
573+
checkpoint_type: Union[
574+
DownloadCheckpointType, str
575+
] = DownloadCheckpointType.DEFAULT,
574576
) -> FinetuneDownloadResult:
575577
"""
576578
Downloads compressed fine-tuned model or checkpoint to local disk.
@@ -609,7 +611,13 @@ def download(
609611

610612
# convert to str
611613
if isinstance(checkpoint_type, str):
612-
checkpoint_type = DownloadCheckpointType(checkpoint_type)
614+
try:
615+
checkpoint_type = DownloadCheckpointType(checkpoint_type.lower())
616+
except ValueError:
617+
enum_strs = ", ".join([e.value for e in DownloadCheckpointType])
618+
raise ValueError(
619+
f"Invalid checkpoint type: {checkpoint_type}. Choose one of {{{enum_strs}}}."
620+
)
613621

614622
if isinstance(ft_job.training_type, FullTrainingType):
615623
if checkpoint_type != DownloadCheckpointType.DEFAULT:
@@ -621,7 +629,10 @@ def download(
621629
if checkpoint_type == DownloadCheckpointType.DEFAULT:
622630
checkpoint_type = DownloadCheckpointType.MERGED
623631

624-
if checkpoint_type in {DownloadCheckpointType.MERGED, DownloadCheckpointType.ADAPTER}:
632+
if checkpoint_type in {
633+
DownloadCheckpointType.MERGED,
634+
DownloadCheckpointType.ADAPTER,
635+
}:
625636
url += f"&checkpoint={checkpoint_type.value}"
626637
else:
627638
raise ValueError(

0 commit comments

Comments
 (0)