Skip to content

Commit 363fb3e

Browse files
committed
improve hyperparam parsing for model saving
1 parent e7e46b0 commit 363fb3e

3 files changed

Lines changed: 47 additions & 12 deletions

File tree

codes/surrogates/AbstractSurrogate/abstract_surrogate.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch.utils.data import DataLoader
1414
from tqdm import tqdm
1515

16-
from codes.utils import create_model_dir
16+
from codes.utils import create_model_dir, parse_hyperparameters
1717

1818

1919
class AbstractSurrogateModel(ABC, nn.Module):
@@ -334,17 +334,7 @@ def save(
334334
)
335335
hyperparameters["date"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
336336

337-
# Recursively parse hyperparameters for numpy arrays and convert them to lists
338-
def parse_hyperparameters(hyperparams: dict) -> dict:
339-
for key, value in hyperparams.items():
340-
if isinstance(value, np.ndarray):
341-
hyperparams[key] = value.tolist()
342-
elif isinstance(value, Tensor):
343-
hyperparams[key] = value.cpu().detach().numpy().tolist()
344-
elif isinstance(value, dict):
345-
hyperparams[key] = parse_hyperparameters(value)
346-
return hyperparams
347-
337+
# Recursively parse hyperparameters to make them yaml-serializable
348338
hyperparameters = parse_hyperparameters(hyperparameters)
349339

350340
# Reduce the precision of the losses and accuracy

codes/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
load_task_list,
1717
make_description,
1818
nice_print,
19+
parse_hyperparameters,
1920
read_yaml_config,
2021
save_task_list,
2122
set_random_seeds,
@@ -44,4 +45,5 @@
4445
"check_training_status",
4546
"determine_batch_size",
4647
"batch_factor_to_float",
48+
"parse_hyperparameters",
4749
]

codes/utils/utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
import torch
1212
import yaml
13+
from torch import Tensor
1314
from tqdm import tqdm
1415

1516

@@ -391,3 +392,45 @@ def batch_factor_to_float(batch_factor: str | int | float) -> float:
391392
raise ValueError(
392393
f"Invalid batch factor: {batch_factor}. Must be a float, int, or a valid fraction string."
393394
) from e
395+
396+
397+
def parse_hyperparameters(hyperparams: dict) -> dict:
398+
for key, value in hyperparams.items():
399+
if isinstance(value, np.ndarray):
400+
hyperparams[key] = value.tolist()
401+
elif isinstance(value, Tensor):
402+
hyperparams[key] = value.cpu().detach().numpy().tolist()
403+
elif isinstance(value, (np.number, np.bool_)):
404+
# Handle numpy scalars (like numpy.float32, numpy.int64, etc.)
405+
hyperparams[key] = value.item()
406+
elif isinstance(value, dict):
407+
hyperparams[key] = parse_hyperparameters(value)
408+
elif isinstance(value, (list, tuple)):
409+
# Recursively handle lists and tuples that might contain numpy objects
410+
hyperparams[key] = [
411+
(
412+
item.item()
413+
if isinstance(item, (np.number, np.bool_))
414+
else (
415+
item.tolist()
416+
if isinstance(item, np.ndarray)
417+
else (
418+
item.cpu().detach().numpy().tolist()
419+
if isinstance(item, Tensor)
420+
else item
421+
)
422+
)
423+
)
424+
for item in value
425+
]
426+
elif hasattr(value, "item") and callable(getattr(value, "item")):
427+
# Catch-all for any object with an .item() method (like numpy scalars)
428+
try:
429+
hyperparams[key] = value.item()
430+
except (ValueError, TypeError):
431+
# If .item() fails, convert to string as fallback
432+
hyperparams[key] = str(value)
433+
elif not isinstance(value, (str, int, float, bool, type(None))):
434+
# Convert any remaining non-serializable objects to string
435+
hyperparams[key] = str(value)
436+
return hyperparams

0 commit comments

Comments
 (0)