Skip to content

Commit 96f527b

Browse files
Simplified implementation
1 parent f17de3d commit 96f527b

2 files changed

Lines changed: 27 additions & 46 deletions

File tree

mala/network/network.py

Lines changed: 18 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
"""Neural network for MALA."""
22

3-
from copy import deepcopy
4-
53
from abc import abstractmethod
64
import numpy as np
75
import torch
86
import torch.distributed as dist
97
import torch.nn as nn
108
import torch.nn.functional as functional
119

12-
from mala.common.parameters import Parameters, ParametersNetwork
10+
from mala.common.parameters import Parameters
1311
from mala.common.parallelizer import printout, parallel_warn
1412

1513

@@ -44,7 +42,7 @@ class Network(nn.Module):
4442
If True, the torch distributed data parallel formalism will be used.
4543
"""
4644

47-
def __new__(cls, params: Parameters):
45+
def __new__(cls, params: Parameters = None):
4846
"""
4947
Create a neural network instance.
5048
@@ -59,46 +57,34 @@ def __new__(cls, params: Parameters):
5957
"""
6058
model = None
6159

62-
# Check if we're accessing through base class.
63-
# If not, we need to return the correct object directly.
64-
if cls == Network:
65-
if isinstance(params, Parameters):
66-
params_network = params.network
67-
elif isinstance(params, ParametersNetwork):
68-
params_network = params
69-
else:
70-
raise Exception("Incompativle parameters supplied")
60+
# Check if we're accessing through base class. If not, we need to
61+
# return the correct object directly.
62+
#
63+
# params=None if we load a serialized object, for instance. Then we
64+
# just want an empty object which gets populated during
65+
# deserialization.
7166

72-
if params_network.nn_type == "feed-forward":
67+
if params is None or cls != Network:
68+
model = super().__new__(cls)
69+
else:
70+
if params.network.nn_type == "feed-forward":
7371
model = super(Network, FeedForwardNet).__new__(FeedForwardNet)
7472

75-
if model is None:
73+
else:
7674
raise Exception("Unsupported network architecture.")
77-
else:
78-
model = super(Network, cls).__new__(cls)
7975

8076
return model
8177

8278
def __init__(self, params: Parameters):
83-
if isinstance(params, Parameters):
84-
params_network = params.network
85-
self.use_ddp = params.use_ddp
86-
seed = params.manual_seed
87-
elif isinstance(params, ParametersNetwork):
88-
params_network = params
89-
self.use_ddp = params_network._configuration["ddp"]
90-
seed = params._configuration["manual_seed"]
91-
else:
92-
raise Exception("Incompativle parameters supplied")
93-
9479
# copy the network params from the input parameter object
95-
self.params = params_network
80+
self.use_ddp = params.use_ddp
81+
self.params = params.network
9682

9783
# if the user has planted a seed (for comparibility purposes) we
9884
# should use it.
99-
if seed is not None:
100-
torch.manual_seed(seed)
101-
torch.cuda.manual_seed(seed)
85+
if params.manual_seed is not None:
86+
torch.manual_seed(params.manual_seed)
87+
torch.cuda.manual_seed(params.manual_seed)
10288

10389
# initialize the parent class
10490
super(Network, self).__init__()
@@ -112,20 +98,6 @@ def __init__(self, params: Parameters):
11298
else:
11399
raise Exception("Unsupported loss function.")
114100

115-
def __copy__(self):
116-
"""Copy a Network instance."""
117-
result = Network.__new__(Network, params=self.params)
118-
result.__dict__.update(self.__dict__)
119-
return result
120-
121-
def __deepcopy__(self, memo):
122-
"""Deepcopy a Network instance."""
123-
result = Network.__new__(Network, params=self.params)
124-
memo[id(self)] = result
125-
for k, v in self.__dict__.items():
126-
setattr(result, k, deepcopy(v, memo))
127-
return result
128-
129101
@abstractmethod
130102
def forward(self, inputs):
131103
"""

test/workflow_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ def test_model_copying(self):
471471
data_path,
472472
"te",
473473
)
474+
parameters.manual_seed = 123456252
474475
data_handler.prepare_data(reparametrize_scaler=False)
475476

476477
actual_ldos, predicted_ldos = tester.predict_targets(0)
@@ -489,11 +490,19 @@ def test_model_copying(self):
489490
os.path.join(data_path, "Be_snapshot3.out"), "espresso-out"
490491
)
491492
band_energy_2 = ldos_calculator.get_band_energy(predicted_ldos)
493+
print(
494+
network.params._configuration, copied_network.params._configuration
495+
)
492496
assert np.isclose(
493497
band_energy_1,
494498
band_energy_2,
495499
atol=accuracy_strict,
496500
)
501+
assert (
502+
network.params._configuration["manual_seed"]
503+
== copied_network.params._configuration["manual_seed"]
504+
== 123456252
505+
)
497506

498507
@pytest.mark.skipif(
499508
importlib.util.find_spec("total_energy") is None

0 commit comments

Comments
 (0)