Skip to content

Commit c897ceb

Browse files
Merge pull request mala-project#660 from karanprime/develop
Added json functionality for data scaling
2 parents cdf7945 + 022adcd commit c897ceb

3 files changed

Lines changed: 168 additions & 19 deletions

File tree

mala/datahandling/data_scaler.py

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
import torch
66
import torch.distributed as dist
7+
import json
8+
import io
79

810
from mala.common.parameters import printout
911
from mala.common.parallelizer import parallel_warn
@@ -517,7 +519,7 @@ def inverse_transform(self, scaled, copy=False, as_numpy=False):
517519
else:
518520
return unscaled
519521

520-
def save(self, filename, save_format="pickle"):
522+
def save(self, filename, save_format="json"):
521523
"""
522524
Save the Scaler object so that it can be accessed again later.
523525
@@ -527,23 +529,58 @@ def save(self, filename, save_format="pickle"):
527529
File in which the parameters will be saved.
528530
529531
save_format :
530-
File format which will be used for saving.
532+
File format which will be used for saving. Default is "json".
533+
Pickle format is deprecated and will be removed in future versions.
531534
"""
532535
# If we use ddp, only save the scaler on root (rank 0).
533536
if self.use_ddp:
534537
if dist.get_rank() != 0:
535538
return
536-
if save_format == "pickle":
539+
540+
filename_format = filename.rsplit(".", 1)[1]
541+
if save_format == "pickle" or filename_format == "pkl": # similar to "normal" string warning
542+
parallel_warn(
543+
"Pickle format is deprecated and will be removed in future versions. "
544+
"Please use JSON format instead.",
545+
min_verbosity=0,
546+
category=FutureWarning,
547+
)
537548
with open(filename, "wb") as handle:
538549
pickle.dump(self, handle, protocol=4)
550+
elif save_format == "json" or filename_format == "json":
551+
# saving tensors as lists for json
552+
# legacy scalers that only define scale_normal are saved under the scale_minmax key
553+
data_dict = {
554+
"typestring": self.typestring,
555+
"use_ddp": self.use_ddp,
556+
"scale_standard": self.scale_standard,
557+
"scale_minmax": (self.scale_minmax
558+
if hasattr(self, "scale_minmax")
559+
else self.scale_normal),
560+
"feature_wise": self.feature_wise,
561+
"cantransform": self.cantransform,
562+
"means": self.means.tolist() if hasattr(self.means, "tolist") else [],
563+
"stds": self.stds.tolist() if hasattr(self.stds, "tolist") else [],
564+
"maxs": self.maxs.tolist() if hasattr(self.maxs, "tolist") else [],
565+
"mins": self.mins.tolist() if hasattr(self.mins, "tolist") else [],
566+
"total_mean": float(self.total_mean),
567+
"total_std": float(self.total_std),
568+
"total_max": float(self.total_max),
569+
"total_min": float(self.total_min),
570+
"total_data_count": self.total_data_count
571+
}
572+
573+
with open(filename, "w") as handle:
574+
json.dump(data_dict, handle, indent=4)
575+
539576
else:
540577
raise Exception("Unsupported parameter save format.")
541578

542579
@classmethod
543-
def load_from_file(cls, file, save_format="pickle"):
580+
def load_from_file(cls, file, save_format="json", auto_convert=True):
544581
"""
545582
Load a saved Scaler object.
546-
583+
547584
Parameters
548585
----------
549586
file : string or ZipExtFile
@@ -552,17 +589,73 @@ def load_from_file(cls, file, save_format="pickle"):
552589
save_format :
553590
File format which was used for saving.
554591
592+
auto_convert : bool
593+
If True and ``file`` is a string path to a pickle file, the loaded scaler is also saved as JSON next to it (same name, ``.json`` extension) for future use.
594+
555595
Returns
556596
-------
557597
data_scaler : DataScaler
558598
DataScaler which was read from the file.
559599
"""
560-
if save_format == "pickle":
600+
if isinstance(file, str):
601+
filename = file
602+
elif hasattr(file, 'name'): # getting fname from zip file
603+
filename = file.name
604+
else:
605+
raise Exception("File must be either a string path or a ZipFile object")
606+
607+
filename_format = filename.rsplit(".", 1)[1]
608+
609+
if save_format == "pickle" or filename_format == "pkl":
610+
parallel_warn(
611+
"Loading from pickle format is deprecated and will be removed in future versions. "
612+
"Please convert your files to JSON format.",
613+
min_verbosity=0,
614+
category=FutureWarning,
615+
)
561616
if isinstance(file, str):
562617
loaded_scaler = pickle.load(open(file, "rb"))
563-
else:
618+
619+
if auto_convert:
620+
json_file_path = filename.rsplit(".", 1)[0] + ".json"
621+
loaded_scaler.save(json_file_path, save_format="json")
622+
623+
624+
elif hasattr(file, 'name'):
564625
loaded_scaler = pickle.load(file)
565-
else:
566-
raise Exception("Unsupported parameter save format.")
567626

627+
parallel_warn(
628+
"Pickle file has been automatically converted to JSON format.",
629+
min_verbosity=0,
630+
category=FutureWarning,
631+
)
632+
633+
elif save_format == "json" or filename_format == "json":
634+
if isinstance(file, str):
635+
with open(file, "r") as handle:
636+
data_dict = json.load(handle)
637+
elif hasattr(file, 'name'):
638+
text_handle = io.TextIOWrapper(file, encoding="utf-8")
639+
data_dict = json.load(text_handle)
640+
641+
loaded_scaler = cls(data_dict["typestring"], data_dict["use_ddp"])
642+
643+
loaded_scaler.scale_standard = data_dict["scale_standard"]
644+
loaded_scaler.scale_minmax = data_dict["scale_minmax"]
645+
loaded_scaler.feature_wise = data_dict["feature_wise"]
646+
loaded_scaler.cantransform = data_dict["cantransform"]
647+
648+
loaded_scaler.means = torch.tensor(data_dict["means"])
649+
loaded_scaler.stds = torch.tensor(data_dict["stds"])
650+
loaded_scaler.maxs = torch.tensor(data_dict["maxs"])
651+
loaded_scaler.mins = torch.tensor(data_dict["mins"])
652+
653+
loaded_scaler.total_mean = torch.tensor(data_dict["total_mean"])
654+
loaded_scaler.total_std = torch.tensor(data_dict["total_std"])
655+
loaded_scaler.total_max = torch.tensor(data_dict["total_max"])
656+
loaded_scaler.total_min = torch.tensor(data_dict["total_min"])
657+
loaded_scaler.total_data_count = data_dict["total_data_count"]
658+
else:
659+
raise Exception("Unsupported parameter save format. Use 'json' or 'pickle'.")
660+
568661
return loaded_scaler

mala/network/runner.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Runner class for running networks."""
22

33
import os
4+
import tempfile
45
from zipfile import ZipFile, ZIP_STORED
56

67
from mala.common.parallelizer import printout
8+
from mala.common.parallelizer import parallel_warn
79

810
import numpy as np
911
import torch
@@ -580,8 +582,8 @@ def save_run(
580582
# performed on rank 0.
581583
if get_rank() == 0:
582584
model_file = run_name + ".network.pth"
583-
iscaler_file = run_name + ".iscaler.pkl"
584-
oscaler_file = run_name + ".oscaler.pkl"
585+
iscaler_file = run_name + ".iscaler.json"
586+
oscaler_file = run_name + ".oscaler.json"
585587
params_file = run_name + ".params.json"
586588
if save_runner:
587589
optimizer_file = run_name + ".optimizer.pth"
@@ -632,6 +634,7 @@ def load_run(
632634
path="./",
633635
zip_run=True,
634636
params_format="json",
637+
scalers_format="json",
635638
load_runner=True,
636639
prepare_data=False,
637640
load_with_mpi=None,
@@ -669,6 +672,10 @@ def load_run(
669672
Can be "json" or "pkl", depending on what was saved by the model.
670673
Default is "json".
671674
675+
scalers_format: str
676+
Can be "json" or "pkl", depending on what was saved by the model.
677+
Default is "json".
678+
672679
load_runner : bool
673680
If True, a Runner object will be created/loaded for further use.
674681
@@ -719,26 +726,44 @@ def load_run(
719726
loaded_info = None
720727
if zip_run is True:
721728
loaded_network = run_name + ".network.pth"
722-
loaded_iscaler = run_name + ".iscaler.pkl"
723-
loaded_oscaler = run_name + ".oscaler.pkl"
729+
loaded_iscaler = run_name + ".iscaler." + scalers_format
730+
loaded_oscaler = run_name + ".oscaler." + scalers_format
724731
loaded_params = run_name + ".params." + params_format
725732
loaded_info = run_name + ".info.json"
726733

734+
iscale_pickle_flag = False
735+
oscale_pickle_flag = False
736+
727737
zip_path = os.path.join(path, run_name + ".zip")
728738
with ZipFile(zip_path, "r") as zip_obj:
729739
loaded_params = zip_obj.open(loaded_params)
730740
loaded_network = zip_obj.open(loaded_network)
731-
loaded_iscaler = zip_obj.open(loaded_iscaler)
732-
loaded_oscaler = zip_obj.open(loaded_oscaler)
741+
742+
# If json scaler files not found, try pickle format
743+
try:
744+
loaded_iscaler = zip_obj.open(loaded_iscaler)
745+
except KeyError:
746+
iscale_pickle_flag = True
747+
loaded_iscaler = zip_obj.open(loaded_iscaler.replace(".json", ".pkl"))
748+
try:
749+
loaded_oscaler = zip_obj.open(loaded_oscaler)
750+
except KeyError:
751+
oscale_pickle_flag = True
752+
loaded_oscaler = zip_obj.open(loaded_oscaler.replace(".json", ".pkl"))
753+
733754
if loaded_info in zip_obj.namelist():
734755
loaded_info = zip_obj.open(loaded_info)
735756
else:
736757
loaded_info = None
737758

738759
else:
739760
loaded_network = os.path.join(path, run_name + ".network.pth")
740-
loaded_iscaler = os.path.join(path, run_name + ".iscaler.pkl")
741-
loaded_oscaler = os.path.join(path, run_name + ".oscaler.pkl")
761+
loaded_iscaler = os.path.join(
762+
path, run_name + ".iscaler." + scalers_format
763+
)
764+
loaded_oscaler = os.path.join(
765+
path, run_name + ".oscaler." + scalers_format
766+
)
742767
loaded_params = os.path.join(
743768
path, run_name + ".params." + params_format
744769
)
@@ -772,6 +797,37 @@ def load_run(
772797
loaded_network = Network.load_from_file(loaded_params, loaded_network)
773798
loaded_iscaler = DataScaler.load_from_file(loaded_iscaler)
774799
loaded_oscaler = DataScaler.load_from_file(loaded_oscaler)
800+
801+
# only on rank 0, if pickle scaler files are found,
802+
# add their json versions to the existing zip file
803+
if get_rank() == 0 and (zip_run and (iscale_pickle_flag or oscale_pickle_flag)):
804+
parallel_warn(
805+
"Pickle file has been automatically converted to JSON format.",
806+
min_verbosity=0,
807+
category=FutureWarning,
808+
)
809+
with tempfile.TemporaryDirectory() as temp_dir:
810+
with ZipFile(zip_path, 'r') as zip_read:
811+
zip_read.extractall(temp_dir)
812+
813+
iscaler_file = run_name + ".iscaler.json"
814+
oscaler_file = run_name + ".oscaler.json"
815+
816+
loaded_iscaler.save(os.path.join(temp_dir, iscaler_file))
817+
loaded_oscaler.save(os.path.join(temp_dir, oscaler_file))
818+
819+
temp_zip_path = zip_path + ".temp"
820+
with ZipFile(temp_zip_path, 'w') as zip_write:
821+
for foldername, subfolders, filenames in os.walk(temp_dir):
822+
for filename in filenames:
823+
file_path = os.path.join(foldername, filename)
824+
arcname = os.path.relpath(file_path, temp_dir)
825+
zip_write.write(file_path, arcname)
826+
827+
os.replace(temp_zip_path, zip_path)
828+
829+
830+
775831
new_datahandler = DataHandler(
776832
loaded_params,
777833
input_data_scaler=loaded_iscaler,

test/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Test directory
22

3-
Different tests that can be rerun at any time to make sure a certain function or idea stillm works.
3+
Different tests that can be rerun at any time to make sure a certain function or idea still works.
44

55
## tensor_memory.py
66

7-
Verifies that the way we create torch tensors from numpy arrays is in fact by referencing, and not by copying.
7+
Verifies that the way we create torch tensors from numpy arrays is in fact by referencing, and not by copying.

0 commit comments

Comments
 (0)