Commit 76b1017

Removed mappings dictionary, now handled via getattr
also updated docs
1 parent 2f9556b commit 76b1017

2 files changed: 20 additions & 28 deletions
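
The gist of the change: activation functions named by string are no longer looked up in a hand-maintained dictionary but resolved directly on torch.nn via getattr. A minimal sketch of the pattern (the helper name resolve_activation is hypothetical, for illustration only; the real code lives in Network._append_activation_function, see the diff below):

    import torch.nn

    # Hypothetical helper mirroring the commit's getattr-based lookup.
    def resolve_activation(name):
        try:
            # Any activation class that torch.nn exports can be resolved by name.
            return getattr(torch.nn, name)()
        except AttributeError:
            raise Exception(
                "Torch does not contain the specified "
                "activation function: " + name
            )

    print(resolve_activation("LeakyReLU"))  # LeakyReLU(negative_slope=0.01)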

mala/common/parameters.py

Lines changed: 11 additions & 11 deletions
@@ -352,18 +352,18 @@ class ParametersNetwork(ParametersBase):
         Otherwise, the activation functions are added layer by layer.
         Note that no activation function is applied between input layer and
         first hidden layer!
-        The items in the list can either be strings, which MALA
-        will map to the correct activation functions,
-        torch.nn.Module objects, torch.nn.Module classes (which MALA will
-        instantiate) OR None, in which case no activation function is used.
+        The items in the list can either be strings (=names of torch.nn.Module
+        activation functions), which MALA will map to the correct activation
+        functions, torch.nn.Module objects, torch.nn.Module classes (which MALA
+        will instantiate) OR None, in which case no activation function is
+        used.
         The None can be omitted at the end, but is useful when layers without
-        activation functions are to be added in the middle
-        Currently supported activation function strings are:
-
-        - "Sigmoid"
-        - "ReLU"
-        - "LeakyReLU" (default)
-        - "Tanh"
+        activation functions are to be skipped in the middle.
+        Note that output from the output layer is by default restricted to
+        only have positive values via restrict_targets in the ParameterTargets
+        subclass. This is similar to having a ReLU function as a final
+        activation function and ensures the physicality of the outputs (since
+        the (L)DOS can never be negative).
 
     layer_activations_include_output_layer : bool
         If False, no activation function is added to the output layer. This
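
To make the documented options concrete, here is a hedged configuration sketch mixing all four accepted item types; it assumes the usual MALA entry point mala.Parameters() with the network subclass documented above:

    import torch.nn as nn
    import mala

    params = mala.Parameters()
    # Items may be strings (names of torch.nn activation functions),
    # nn.Module instances, nn.Module classes, or None.
    params.network.layer_activations = [
        "LeakyReLU",  # string, resolved via getattr(torch.nn, "LeakyReLU")
        nn.Tanh(),    # instance, appended as-is
        nn.Sigmoid,   # class, instantiated by MALA
        None,         # no activation function at this position
    ]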

mala/network/network.py

Lines changed: 9 additions & 17 deletions
@@ -94,14 +94,6 @@ def __init__(self, params: Parameters):
         # initialize the parent class
         super(Network, self).__init__()
 
-        # Mappings for parsing of the activation layers.
-        self._activation_mappings = {
-            "Sigmoid": nn.Sigmoid,
-            "ReLU": nn.ReLU,
-            "LeakyReLU": nn.LeakyReLU,
-            "Tanh": nn.Tanh,
-        }
-
         # initialize the layers
         self.number_of_layers = len(self.params.layer_sizes) - 1
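
The four entries of the removed dictionary resolve to exactly the classes getattr now finds, so the refactor preserves behavior for existing configurations. A quick equivalence check (a sketch, reproducing the removed mapping only for the assertion):

    import torch.nn as nn

    old_mappings = {
        "Sigmoid": nn.Sigmoid,
        "ReLU": nn.ReLU,
        "LeakyReLU": nn.LeakyReLU,
        "Tanh": nn.Tanh,
    }
    for name, cls in old_mappings.items():
        assert getattr(nn, name) is cls  # getattr finds the very same class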

@@ -316,9 +308,13 @@ def _append_activation_function(self, activation_function):
         if activation_function is None:
             pass
         elif isinstance(activation_function, str):
-            self.layers.append(
-                self._activation_mappings[activation_function]()
-            )
+            try:
+                self.layers.append(getattr(torch.nn, activation_function)())
+            except AttributeError:
+                raise Exception(
+                    "Torch does not contain the specified "
+                    "activation function: " + activation_function
+                )
         elif isinstance(activation_function, nn.Module):
             self.layers.append(activation_function)
         elif issubclass(activation_function, nn.Module):
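
Beyond preserving the old behavior, the getattr lookup accepts every activation torch.nn exports, and unknown names now fail with an explicit message. A small sketch of both paths (GELU and SiLU assume a reasonably recent PyTorch):

    import torch.nn

    # Previously only four names were accepted; now any torch.nn activation works.
    for name in ["Sigmoid", "ReLU", "LeakyReLU", "Tanh", "GELU", "SiLU"]:
        print(name, "->", getattr(torch.nn, name)())

    # Unknown names trigger the except branch shown above.
    try:
        getattr(torch.nn, "RelU")()  # deliberate typo
    except AttributeError:
        print("Torch does not contain the specified activation function: RelU")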
@@ -367,9 +363,7 @@ def __init__(self, params):
             self.params.num_hidden_layers,
             batch_first=True,
         )
-        self.activation = self._activation_mappings[
-            self.params.layer_activations[0]
-        ]()
+        self.activation = getattr(torch.nn, self.params.layer_activations[0])()
 
         self.batch_size = None
         # Once everything is done, we can move the Network on the target
@@ -505,9 +499,7 @@ def __init__(self, params):
             self.params.num_hidden_layers,
             batch_first=True,
         )
-        self.activation = self._activation_mappings[
-            self.params.layer_activations[0]
-        ]()
+        self.activation = getattr(torch.nn, self.params.layer_activations[0])()
 
         if params.use_gpu:
             self.to("cuda")
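
Note that both recurrent subclasses consult only the first entry of layer_activations, and the getattr call presumes that entry is a string. A minimal sketch of what these two identical lines do:

    import torch.nn

    layer_activations = ["Tanh"]  # only the first entry is used here,
                                  # and it must be a string name
    activation = getattr(torch.nn, layer_activations[0])()
    print(activation)  # Tanh()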
