@@ -231,15 +231,21 @@ def __init__(self, params):
231231 # We should NOT modify the list itself. This would break the
232232 # hyperparameter algorithms.
233233 use_only_one_activation_type = False
234-
235- if not isinstance (self .params .layer_activations , str ):
234+ if isinstance (self .params .layer_activations , list ):
236235 if len (self .params .layer_activations ) > self .number_of_layers :
236+
237+ number_of_ignored_layers = (
238+ len (self .params .layer_activations ) - self .number_of_layers
239+ )
240+ number_of_ignored_layers += (
241+ 1
242+ if self .params .layer_activations_include_output_layer
243+ is False
244+ else 0
245+ )
237246 printout (
238247 "Too many activation layers provided. The last" ,
239- str (
240- len (self .params .layer_activations )
241- - self .number_of_layers
242- ),
248+ str (number_of_ignored_layers ),
243249 "activation function(s) will be ignored." ,
244250 min_verbosity = 1 ,
245251 )
@@ -256,20 +262,24 @@ def __init__(self, params):
256262 )
257263 )
258264 )
259- try :
260- if isinstance (self .params .layer_activations , str ):
261- self ._append_activation_function (
262- self .params .layer_activations
263- )
264- else :
265- self ._append_activation_function (
266- self .params .layer_activations [i ]
267- )
268- except KeyError :
269- raise Exception ("Invalid activation type seleceted." )
270- except IndexError :
271- # No activation functions left to append at the end.
272- pass
265+ if (
266+ i < self .number_of_layers - 1
267+ ) or self .params .layer_activations_include_output_layer :
268+ try :
269+ if isinstance (self .params .layer_activations , list ):
270+ self ._append_activation_function (
271+ self .params .layer_activations [i ]
272+ )
273+ else :
274+ self ._append_activation_function (
275+ self .params .layer_activations
276+ )
277+
278+ except KeyError :
279+ raise Exception ("Invalid activation type seleceted." )
280+ except IndexError :
281+ # No activation functions left to append at the end.
282+ pass
273283
274284 # Once everything is done, we can move the Network on the target
275285 # device.
@@ -300,7 +310,7 @@ def _append_activation_function(self, activation_function):
300310
301311 Parameters
302312 ----------
303- activation_function : str
313+ activation_function : str or nn.Module or class
304314 Activation function to be appended.
305315 """
306316 if activation_function is None :
@@ -311,6 +321,8 @@ def _append_activation_function(self, activation_function):
311321 )
312322 elif isinstance (activation_function , nn .Module ):
313323 self .layers .append (activation_function )
324+ elif issubclass (activation_function , nn .Module ):
325+ self .layers .append (activation_function ())
314326
315327
316328class LSTM (Network ):
0 commit comments