Commit 66b5d8e
Made it possible to exclude the output layer from the activation function; classes and instances are now handled correctly
1 parent bbef691 commit 66b5d8e

2 files changed: 51 additions & 31 deletions

mala/common/parameters.py

Lines changed: 18 additions & 10 deletions
```diff
@@ -344,26 +344,33 @@ class ParametersNetwork(ParametersBase):
         network. Please note that the input layer is included therein.
         Default: [10,10,0]
 
-    layer_activations : list or str
-        A list of strings detailing the activation functions to be used
-        by the neural network. If a single string is supplied, then this
-        activation function is used for all layers (including the output layer,
-        i.e., an output activation is used!). Otherwise, the activation
-        functions are added layer by layer.
+    layer_activations : list or str or class or nn.Module
+        The activation function(s) to be used
+        by the neural network. If a single object is supplied, then this
+        activation function is used for all layers (whether this applies to the
+        output layer is controlled by layer_activations_include_output_layer).
+        Otherwise, the activation functions are added layer by layer.
         Note that no activation function is applied between input layer and
         first hidden layer!
         The items in the list can either be strings, which MALA
-        will map them directly to the correct activation functions or
-        torch.nn.Module objects containing the activation functions directly
-        OR None, in which case no activation function is used. The None
-        can be ommitted at the end, but is useful when layers without
+        will map to the correct activation functions,
+        torch.nn.Module objects, torch.nn.Module classes (which MALA will
+        instantiate) OR None, in which case no activation function is used.
+        The None can be omitted at the end, but is useful when layers without
         activation functions are to be added in the middle
         Currently supported activation function strings are:
 
         - "Sigmoid"
         - "ReLU"
         - "LeakyReLU" (default)
 
+    layer_activations_include_output_layer : bool
+        If False, no activation function is added to the output layer. This
+        can of course also be done by supplying just the right amount of
+        activation functions; this parameter mainly exists to control the
+        activation of the last layer when layer_activations is given as
+        a single object.
+
     loss_function_type : string
         Loss function for the neural network
         Currently supported loss functions include:
@@ -398,6 +405,7 @@ def __init__(self):
         self.nn_type = "feed-forward"
         self.layer_sizes = [10, 10, 10]
         self.layer_activations = "LeakyReLU"
+        self.layer_activations_include_output_layer = True
         self.loss_function_type = "mse"
 
         # for LSTM/Gru
```
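
To illustrate the extended parameter, here is a minimal usage sketch, assuming the usual `mala.Parameters` entry point with its `network` sub-parameters; the layer sizes and activation choices are made up for illustration.

```python
import torch.nn as nn

import mala

parameters = mala.Parameters()
# Five widths -> four linear layers, so four activation entries below.
parameters.network.layer_sizes = [10, 10, 10, 10, 10]

# A mixed list: a string (mapped by MALA), an already-instantiated module,
# a module class (instantiated by MALA), and None (no activation function).
parameters.network.layer_activations = [
    "ReLU",
    nn.Sigmoid(),
    nn.LeakyReLU,
    None,
]

# Alternatively: one activation type everywhere, but skip the output layer.
parameters.network.layer_activations = "LeakyReLU"
parameters.network.layer_activations_include_output_layer = False
```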

mala/network/network.py

Lines changed: 33 additions & 21 deletions
```diff
@@ -231,15 +231,21 @@ def __init__(self, params):
         # We should NOT modify the list itself. This would break the
         # hyperparameter algorithms.
         use_only_one_activation_type = False
-
-        if not isinstance(self.params.layer_activations, str):
+        if isinstance(self.params.layer_activations, list):
             if len(self.params.layer_activations) > self.number_of_layers:
+
+                number_of_ignored_layers = (
+                    len(self.params.layer_activations) - self.number_of_layers
+                )
+                number_of_ignored_layers += (
+                    1
+                    if self.params.layer_activations_include_output_layer
+                    is False
+                    else 0
+                )
                 printout(
                     "Too many activation layers provided. The last",
-                    str(
-                        len(self.params.layer_activations)
-                        - self.number_of_layers
-                    ),
+                    str(number_of_ignored_layers),
                     "activation function(s) will be ignored.",
                     min_verbosity=1,
                 )
```
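
As a quick sanity check of the new count (a worked example, not code from the commit): with five activation entries, four layers, and the output activation excluded, two entries end up unused.

```python
layer_activations = ["ReLU"] * 5   # five entries supplied
number_of_layers = 4               # the network has four layers
include_output_layer = False       # output layer gets no activation

number_of_ignored_layers = len(layer_activations) - number_of_layers  # 1
number_of_ignored_layers += 1 if include_output_layer is False else 0  # 2
print(number_of_ignored_layers)  # -> 2 activation function(s) ignored
```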
```diff
@@ -256,20 +262,24 @@ def __init__(self, params):
                     )
                 )
             )
-            try:
-                if isinstance(self.params.layer_activations, str):
-                    self._append_activation_function(
-                        self.params.layer_activations
-                    )
-                else:
-                    self._append_activation_function(
-                        self.params.layer_activations[i]
-                    )
-            except KeyError:
-                raise Exception("Invalid activation type seleceted.")
-            except IndexError:
-                # No activation functions left to append at the end.
-                pass
+            if (
+                i < self.number_of_layers - 1
+            ) or self.params.layer_activations_include_output_layer:
+                try:
+                    if isinstance(self.params.layer_activations, list):
+                        self._append_activation_function(
+                            self.params.layer_activations[i]
+                        )
+                    else:
+                        self._append_activation_function(
+                            self.params.layer_activations
+                        )
+
+                except KeyError:
+                    raise Exception("Invalid activation type selected.")
+                except IndexError:
+                    # No activation functions left to append at the end.
+                    pass
 
         # Once everything is done, we can move the Network on the target
         # device.
```
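
The effect of the new guard is easiest to see in isolation. The following is a simplified standalone sketch of a layer-construction loop, not MALA's actual `Network` class; it only reproduces the skip condition for the output layer.

```python
import torch.nn as nn

layer_sizes = [10, 10, 10]         # input, hidden, output widths
include_output_activation = False  # mirrors the new parameter
number_of_layers = len(layer_sizes) - 1

layers = []
for i in range(number_of_layers):
    layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
    # The new condition: append an activation after every layer except
    # the last one, unless output activations are explicitly enabled.
    if (i < number_of_layers - 1) or include_output_activation:
        layers.append(nn.LeakyReLU())

net = nn.Sequential(*layers)
print(net)  # Linear -> LeakyReLU -> Linear (no trailing activation)
```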
```diff
@@ -300,7 +310,7 @@ def _append_activation_function(self, activation_function):
 
         Parameters
         ----------
-        activation_function : str
+        activation_function : str or nn.Module or class
             Activation function to be appended.
         """
         if activation_function is None:
@@ -311,6 +321,8 @@ def _append_activation_function(self, activation_function):
             )
         elif isinstance(activation_function, nn.Module):
             self.layers.append(activation_function)
+        elif issubclass(activation_function, nn.Module):
+            self.layers.append(activation_function())
 
 
 class LSTM(Network):
```
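
For reference, the dispatch that `_append_activation_function` now performs can be summarized as below. This is a hypothetical standalone rendering, not the method itself; the `ACTIVATIONS` mapping and the extra `isinstance(..., type)` guard are additions here, the latter because `issubclass` raises a `TypeError` when handed a non-class argument.

```python
import torch.nn as nn

# Assumed mapping for the supported strings; MALA keeps its own lookup.
ACTIVATIONS = {
    "Sigmoid": nn.Sigmoid,
    "ReLU": nn.ReLU,
    "LeakyReLU": nn.LeakyReLU,
}


def resolve_activation(activation_function):
    """Return an instantiated activation module, or None for no activation."""
    if activation_function is None:
        return None
    if isinstance(activation_function, str):
        return ACTIVATIONS[activation_function]()  # KeyError for bad strings
    if isinstance(activation_function, nn.Module):
        return activation_function                 # already an instance
    if isinstance(activation_function, type) and issubclass(
        activation_function, nn.Module
    ):
        return activation_function()               # a class: instantiate it
    raise TypeError("Invalid activation type selected.")
```

In the commit itself the `isinstance`/`issubclass` branches form an elif chain, so the class branch is only reached for inputs that are neither strings nor instances.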
