11"""Neural network for MALA."""
22
3- from copy import deepcopy
4-
53from abc import abstractmethod
64import numpy as np
75import torch
86import torch .distributed as dist
97import torch .nn as nn
108import torch .nn .functional as functional
119
12- from mala .common .parameters import Parameters , ParametersNetwork
10+ from mala .common .parameters import Parameters
1311from mala .common .parallelizer import printout , parallel_warn
1412
1513
@@ -44,7 +42,7 @@ class Network(nn.Module):
4442 If True, the torch distributed data parallel formalism will be used.
4543 """
4644
47- def __new__ (cls , params : Parameters ):
45+ def __new__ (cls , params : Parameters = None ):
4846 """
4947 Create a neural network instance.
5048
@@ -59,46 +57,34 @@ def __new__(cls, params: Parameters):
5957 """
6058 model = None
6159
62- # Check if we're accessing through base class.
63- # If not, we need to return the correct object directly.
64- if cls == Network :
65- if isinstance (params , Parameters ):
66- params_network = params .network
67- elif isinstance (params , ParametersNetwork ):
68- params_network = params
69- else :
70- raise Exception ("Incompativle parameters supplied" )
60+ # Check if we're accessing through base class. If not, we need to
61+ # return the correct object directly.
62+ #
63+ # params=None if we load a serialized object, for instance. Then we
64+ # just want an empty object which gets populated during
65+ # deserialization.
7166
72- if params_network .nn_type == "feed-forward" :
67+ if params is None or cls != Network :
68+ model = super ().__new__ (cls )
69+ else :
70+ if params .network .nn_type == "feed-forward" :
7371 model = super (Network , FeedForwardNet ).__new__ (FeedForwardNet )
7472
75- if model is None :
73+ else :
7674 raise Exception ("Unsupported network architecture." )
77- else :
78- model = super (Network , cls ).__new__ (cls )
7975
8076 return model
8177
8278 def __init__ (self , params : Parameters ):
83- if isinstance (params , Parameters ):
84- params_network = params .network
85- self .use_ddp = params .use_ddp
86- seed = params .manual_seed
87- elif isinstance (params , ParametersNetwork ):
88- params_network = params
89- self .use_ddp = params_network ._configuration ["ddp" ]
90- seed = params ._configuration ["manual_seed" ]
91- else :
92- raise Exception ("Incompativle parameters supplied" )
93-
9479 # copy the network params from the input parameter object
95- self .params = params_network
80+ self .use_ddp = params .use_ddp
81+ self .params = params .network
9682
9783 # if the user has planted a seed (for comparibility purposes) we
9884 # should use it.
99- if seed is not None :
100- torch .manual_seed (seed )
101- torch .cuda .manual_seed (seed )
85+ if params . manual_seed is not None :
86+ torch .manual_seed (params . manual_seed )
87+ torch .cuda .manual_seed (params . manual_seed )
10288
10389 # initialize the parent class
10490 super (Network , self ).__init__ ()
@@ -112,20 +98,6 @@ def __init__(self, params: Parameters):
11298 else :
11399 raise Exception ("Unsupported loss function." )
114100
115- def __copy__ (self ):
116- """Copy a Network instance."""
117- result = Network .__new__ (Network , params = self .params )
118- result .__dict__ .update (self .__dict__ )
119- return result
120-
121- def __deepcopy__ (self , memo ):
122- """Deepcopy a Network instance."""
123- result = Network .__new__ (Network , params = self .params )
124- memo [id (self )] = result
125- for k , v in self .__dict__ .items ():
126- setattr (result , k , deepcopy (v , memo ))
127- return result
128-
129101 @abstractmethod
130102 def forward (self , inputs ):
131103 """
0 commit comments