44import numpy as np
55import torch
66import torch .distributed as dist
7+ import json
8+ import io
79
810from mala .common .parameters import printout
911from mala .common .parallelizer import parallel_warn
@@ -517,7 +519,7 @@ def inverse_transform(self, scaled, copy=False, as_numpy=False):
517519 else :
518520 return unscaled
519521
def save(self, filename, save_format="json"):
    """
    Save the Scaler object so that it can be accessed again later.

    Parameters
    ----------
    filename : string
        File in which the parameters will be saved.

    save_format : string
        File format which will be used for saving. Default is "json".
        Pickle format is deprecated and will be removed in future
        versions. The file extension is also consulted: a ".pkl"
        extension selects pickle, and a ".json" extension selects JSON
        when ``save_format`` names neither supported format.

    Raises
    ------
    Exception
        If neither ``save_format`` nor the file extension identifies a
        supported format.
    """
    # If we use ddp, only save the network on root.
    if self.use_ddp and dist.get_rank() != 0:
        return

    # rpartition never raises for names without a "." (rsplit(...)[1]
    # would); such names simply have no extension to consult.
    _, dot, suffix = filename.rpartition(".")
    filename_format = suffix if dot else ""

    def _as_list(values):
        # Tensors/arrays are stored as plain lists; anything without a
        # tolist() method (e.g. an unset statistic) becomes an empty list.
        return values.tolist() if hasattr(values, "tolist") else []

    if save_format == "pickle" or filename_format == "pkl":
        parallel_warn(
            "Pickle format is deprecated and will be removed in future versions. "
            "Please use JSON format instead.",
            min_verbosity=0,
            category=FutureWarning,
        )
        with open(filename, "wb") as handle:
            pickle.dump(self, handle, protocol=4)
    elif save_format == "json" or filename_format == "json":
        # Objects that only carry "scale_normal" are written out under
        # the "scale_minmax" key.
        if hasattr(self, "scale_minmax"):
            scale_minmax = self.scale_minmax
        else:
            scale_minmax = self.scale_normal

        data_dict = {
            "typestring": self.typestring,
            "use_ddp": self.use_ddp,
            "scale_standard": self.scale_standard,
            "scale_minmax": scale_minmax,
            "feature_wise": self.feature_wise,
            "cantransform": self.cantransform,
            "means": _as_list(self.means),
            "stds": _as_list(self.stds),
            "maxs": _as_list(self.maxs),
            "mins": _as_list(self.mins),
            "total_mean": float(self.total_mean),
            "total_std": float(self.total_std),
            "total_max": float(self.total_max),
            "total_min": float(self.total_min),
            "total_data_count": self.total_data_count,
        }

        with open(filename, "w") as handle:
            json.dump(data_dict, handle, indent=4)
    else:
        raise Exception("Unsupported parameter save format.")
541578
@classmethod
def load_from_file(cls, file, save_format="json", auto_convert=True):
    """
    Load a saved Scaler object.

    Parameters
    ----------
    file : string or ZipExtFile
        Path to the file, or an open file-like object exposing a
        ``name`` attribute, from which the scaler is read.

    save_format : string
        File format which was used for saving. The file extension is
        also consulted: a ".pkl" extension selects pickle, and a ".json"
        extension selects JSON when ``save_format`` names neither
        supported format.

    auto_convert : bool
        If True and loading from a pickle file given as a path,
        additionally save the scaler as JSON next to it for future use.

    Returns
    -------
    data_scaler : DataScaler
        DataScaler which was read from the file.

    Raises
    ------
    Exception
        If ``file`` is neither a path nor an object with a ``name``, or
        if no supported format can be determined.
    """
    is_path = isinstance(file, str)
    if is_path:
        filename = file
    elif hasattr(file, "name"):  # getting fname from zip file
        filename = file.name
    else:
        raise Exception("File must be either a string path or a ZipFile object")

    # rpartition never raises for names without a "." (rsplit(...)[1]
    # would); such names simply have no extension to consult.
    stem, dot, suffix = filename.rpartition(".")
    filename_format = suffix if dot else ""
    base_name = stem if dot else filename

    if save_format == "pickle" or filename_format == "pkl":
        parallel_warn(
            "Loading from pickle format is deprecated and will be removed in future versions. "
            "Please convert your files to JSON format.",
            min_verbosity=0,
            category=FutureWarning,
        )
        # NOTE(security): unpickling can execute arbitrary code; only
        # load pickle files that come from a trusted source.
        if is_path:
            with open(file, "rb") as handle:
                loaded_scaler = pickle.load(handle)

            if auto_convert:
                loaded_scaler.save(base_name + ".json", save_format="json")
                # Only announce the conversion when it actually happened.
                parallel_warn(
                    "Pickle file has been automatically converted to JSON format.",
                    min_verbosity=0,
                    category=FutureWarning,
                )
        else:
            loaded_scaler = pickle.load(file)

    elif save_format == "json" or filename_format == "json":
        if is_path:
            with open(file, "r", encoding="utf-8") as handle:
                data_dict = json.load(handle)
        else:
            # json.loads accepts both str and bytes, so the caller's
            # handle is read as-is and stays open (wrapping it in a
            # TextIOWrapper would close it on garbage collection).
            data_dict = json.loads(file.read())

        loaded_scaler = cls(data_dict["typestring"], data_dict["use_ddp"])

        # Plain (non-tensor) settings are copied over unchanged.
        for key in (
            "scale_standard",
            "scale_minmax",
            "feature_wise",
            "cantransform",
            "total_data_count",
        ):
            setattr(loaded_scaler, key, data_dict[key])

        # Statistics were stored as lists/floats; restore them as tensors.
        for key in (
            "means",
            "stds",
            "maxs",
            "mins",
            "total_mean",
            "total_std",
            "total_max",
            "total_min",
        ):
            setattr(loaded_scaler, key, torch.tensor(data_dict[key]))
    else:
        raise Exception("Unsupported parameter save format. Use 'json' or 'pickle'.")

    return loaded_scaler
0 commit comments