@@ -94,14 +94,6 @@ def __init__(self, params: Parameters):
         # initialize the parent class
         super(Network, self).__init__()
 
-        # Mappings for parsing of the activation layers.
-        self._activation_mappings = {
-            "Sigmoid": nn.Sigmoid,
-            "ReLU": nn.ReLU,
-            "LeakyReLU": nn.LeakyReLU,
-            "Tanh": nn.Tanh,
-        }
-
         # initialize the layers
         self.number_of_layers = len(self.params.layer_sizes) - 1
 
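This hunk removes the hard-coded `_activation_mappings` whitelist; activation names are instead resolved dynamically against `torch.nn` (see the new `_append_activation_function` helper further down). A minimal sketch of the two lookup styles, with activation names chosen purely for illustration:

```python
import torch.nn as nn

# Old style: only the four whitelisted names were accepted.
activation = {"Sigmoid": nn.Sigmoid, "ReLU": nn.ReLU}["ReLU"]()

# New style: any activation class in torch.nn resolves by name,
# including ones the old mapping did not cover, e.g. GELU.
activation = getattr(nn, "GELU")()
```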
@@ -231,22 +223,27 @@ def __init__(self, params):
         # We should NOT modify the list itself. This would break the
         # hyperparameter algorithms.
         use_only_one_activation_type = False
-        if len(self.params.layer_activations) == 1:
-            use_only_one_activation_type = True
-        elif len(self.params.layer_activations) < self.number_of_layers:
-            raise Exception("Not enough activation layers provided.")
-        elif len(self.params.layer_activations) > self.number_of_layers:
-            printout(
-                "Too many activation layers provided. The last",
-                str(
+        if isinstance(self.params.layer_activations, list):
+            if len(self.params.layer_activations) > self.number_of_layers:
+
+                number_of_ignored_layers = (
                     len(self.params.layer_activations) - self.number_of_layers
-                ),
-                "activation function(s) will be ignored.",
-                min_verbosity=1,
-            )
+                )
+                number_of_ignored_layers += (
+                    1
+                    if self.params.layer_activations_include_output_layer
+                    is False
+                    else 0
+                )
+                printout(
+                    "Too many activation layers provided. The last",
+                    str(number_of_ignored_layers),
+                    "activation function(s) will be ignored.",
+                    min_verbosity=1,
+                )
 
         # Add the layers.
-        # As this is a feedforward layer we always add linear layers, and then
+        # As this is a feedforward NN we always add linear layers, and then
         # an activation function
         for i in range(0, self.number_of_layers):
             self.layers.append(
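The rewritten warning also counts the activation that goes unused when `layer_activations_include_output_layer` is false. A standalone sketch of the arithmetic, using hypothetical values (three linear layers, five configured activations):

```python
layer_activations = ["ReLU", "ReLU", "ReLU", "Tanh", "Sigmoid"]
number_of_layers = 3               # linear layers in the network
include_output_layer = False       # output layer gets no activation

ignored = len(layer_activations) - number_of_layers  # 5 - 3 = 2
if not include_output_layer:
    ignored += 1                   # plus the would-be output activation
print(ignored)                     # 3 activation function(s) ignored
```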
@@ -257,21 +254,24 @@ def __init__(self, params):
                     )
                 )
             )
-            try:
-                if use_only_one_activation_type:
-                    self.layers.append(
-                        self._activation_mappings[
-                            self.params.layer_activations[0]
-                        ]()
-                    )
-                else:
-                    self.layers.append(
-                        self._activation_mappings[
+            if (
+                i < self.number_of_layers - 1
+            ) or self.params.layer_activations_include_output_layer:
+                try:
+                    if isinstance(self.params.layer_activations, list):
+                        self._append_activation_function(
                             self.params.layer_activations[i]
-                        ]()
-                    )
-            except KeyError:
-                raise Exception("Invalid activation type selected.")
+                        )
+                    else:
+                        self._append_activation_function(
+                            self.params.layer_activations
+                        )
+
+                except KeyError:
+                    raise Exception("Invalid activation type selected.")
+                except IndexError:
+                    # No activation functions left to append at the end.
+                    pass
 
         # Once everything is done, we can move the Network on the target
         # device.
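With this control flow, a non-list activation entry is reused after every layer, and the output layer only receives an activation when `layer_activations_include_output_layer` is set. A sketch of the stack the loop would produce for hypothetical parameters, `layer_sizes = [91, 100, 10]` (two linear layers), `layer_activations = "ReLU"`, and the flag left at `False`:

```python
import torch.nn as nn

# Equivalent hand-built stack for the hypothetical parameters above:
layers = [
    nn.Linear(91, 100),
    nn.ReLU(),           # activation after the hidden layer
    nn.Linear(100, 10),  # no activation after the output layer
]
```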
@@ -296,6 +296,30 @@ def forward(self, inputs):
             inputs = layer(inputs)
         return inputs
 
+    def _append_activation_function(self, activation_function):
+        """
+        Append an activation function to the network.
+
+        Parameters
+        ----------
+        activation_function : str or nn.Module or class
+            Activation function to be appended.
+        """
+        if activation_function is None:
+            pass
+        elif isinstance(activation_function, str):
+            try:
+                self.layers.append(getattr(torch.nn, activation_function)())
+            except AttributeError:
+                raise Exception(
+                    "Torch does not contain the specified "
+                    "activation function: " + activation_function
+                )
+        elif isinstance(activation_function, nn.Module):
+            self.layers.append(activation_function)
+        elif issubclass(activation_function, nn.Module):
+            self.layers.append(activation_function())
+
 
 class LSTM(Network):
     """Initialize this network as a LSTM network."""
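The new helper accepts an activation given as a `torch.nn` class name, as an already-instantiated module, or as a module class, and treats `None` as a no-op. A usage sketch, assuming `net` is an instance of a `Network` subclass:

```python
import torch.nn as nn

net._append_activation_function("LeakyReLU")    # resolved via getattr(torch.nn, ...)
net._append_activation_function(nn.Softplus())  # instance, appended as-is
net._append_activation_function(nn.Tanh)        # class, instantiated first
net._append_activation_function(None)           # silently skipped
```

An unknown string raises the wrapped `Exception`; anything that is neither a string, a module, nor a class would reach the final `issubclass` check and raise a `TypeError`.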
@@ -339,9 +363,7 @@ def __init__(self, params):
             self.params.num_hidden_layers,
             batch_first=True,
         )
-        self.activation = self._activation_mappings[
-            self.params.layer_activations[0]
-        ]()
+        self.activation = getattr(torch.nn, self.params.layer_activations[0])()
 
         self.batch_size = None
         # Once everything is done, we can move the Network on the target
@@ -477,9 +499,7 @@ def __init__(self, params):
             self.params.num_hidden_layers,
             batch_first=True,
         )
-        self.activation = self._activation_mappings[
-            self.params.layer_activations[0]
-        ]()
+        self.activation = getattr(torch.nn, self.params.layer_activations[0])()
 
         if params.use_gpu:
             self.to("cuda")
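Both recurrent networks now resolve their single activation with the same `getattr` lookup. Unlike the feedforward path, these calls are not wrapped in a `try` block, so a name missing from `torch.nn` surfaces as a raw `AttributeError`. A defensive variant, shown only as an illustrative sketch:

```python
import torch.nn

name = "Sigmoid"  # hypothetical value of params.layer_activations[0]
if hasattr(torch.nn, name):
    activation = getattr(torch.nn, name)()
else:
    raise Exception("Torch does not contain activation function: " + name)
```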