-
Notifications
You must be signed in to change notification settings - Fork 36
Expand file tree
/
Copy pathcore.py
More file actions
357 lines (303 loc) · 16.7 KB
/
core.py
File metadata and controls
357 lines (303 loc) · 16.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
# -*- coding: utf-8 -*-
import numpy as np
import collections
import multiprocess as mp
from . import Dream_shared_vars
from .Dream import Dream, DreamPool
from .model import Model
import traceback
def run_dream(parameters, likelihood, nchains=5, niterations=50000, start=None, restart=False, verbose=True, nverbose=10, tempering=False, mp_context=None, **kwargs):
    """Run DREAM given a set of parameters with priors and a likelihood function.

    Parameters
    ----------
    parameters: iterable of SampledParam class
        A list of parameter priors
    likelihood: function
        A user-defined likelihood function
    nchains: int, optional
        The number of parallel DREAM chains to run. Default = 5
    niterations: int, optional
        The number of algorithm iterations to run. Default = 50,000
    start: iterable of arrays or single array, optional
        Either a list of start locations to initialize chains in, or a single start location to initialize all chains in. Default: None
    restart: Boolean, optional
        Whether run is a continuation of an earlier run. Pass this with the model_name argument to automatically load previous history and crossover probability files. Default: False
    verbose: Boolean, optional
        Whether to print verbose output (including acceptance or rejection of moves and the current acceptance rate). Default: True
    nverbose: int, optional
        Rate at which the acceptance rate is printed if verbose is set to True. Every n-th iteration the acceptance rate
        will be printed and added to the acceptance rate file. Default: 10
    tempering: Boolean, optional
        Whether to use parallel tempering for the DREAM chains. Warning: this feature is untested. Use at your own risk! Default: False
    mp_context: multiprocessing context or None.
        Method used to to start the processes. If it's None, the default context, which depends in Python version and OS, is used.
        For more information please check: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
    kwargs:
        Other arguments that will be passed to the Dream class on initialization. For more information, see Dream class.

    Returns
    -------
    sampled_params : list of arrays
        Sampled parameters for each chain
    log_ps : list of arrays
        Log probability for each sampled point for each chain
    """
    if restart:
        # `is None` rather than `== None`: equality against a NumPy start
        # array is elementwise and would make this `if` raise
        # "truth value of an array is ambiguous".
        if start is None:
            raise Exception('Restart run specified but no start positions given.')
        if 'model_name' not in kwargs:
            raise Exception('Restart run specified but no model name to load history and crossover value files from given.')

    if not isinstance(parameters, list):
        parameters = [parameters]

    model = Model(likelihood=likelihood, sampled_parameters=parameters)

    if restart:
        model_prefix = kwargs['model_name']
        step_instance = Dream(model=model, variables=parameters,
                              history_file=model_prefix + '_DREAM_chain_history.npy',
                              crossover_file=model_prefix + '_DREAM_chain_adapted_crossoverprob.npy',
                              gamma_file=model_prefix + '_DREAM_chain_adapted_gammalevelprob.npy',
                              verbose=verbose, mp_context=mp_context, **kwargs)
        # Reload each chain's saved (naccepts, niterations) record so the
        # running acceptance rate continues from the previous run.
        chains_naccepts_iterations = [np.load(f'{model_prefix}_naccepts_chain{chain}.npy')
                                      for chain in range(nchains)]
    else:
        step_instance = Dream(model=model, variables=parameters, verbose=verbose, mp_context=mp_context, **kwargs)
        # Row 0 = accepted-move count, row 1 = iterations completed.
        # Use one fresh array per chain (the previous `[arr] * nchains`
        # shared a single object) and the builtin int dtype: the `np.int`
        # alias was removed in NumPy 1.24.
        chains_naccepts_iterations = [np.zeros((2, 1), dtype=int) for _ in range(nchains)]

    pool = _setup_mp_dream_pool(nchains, niterations, step_instance, start_pt=start, mp_context=mp_context)

    try:
        if tempering:
            sampled_params, log_ps = _sample_dream_pt(nchains, niterations, step_instance, start, pool, verbose=verbose)
        else:
            # A single (non-iterable) start location is replicated to every chain.
            if not isinstance(start, collections.abc.Iterable):
                start = [start] * nchains
            args = list(zip([step_instance] * nchains, [niterations] * nchains, start, [verbose] * nchains,
                            [nverbose] * nchains, list(range(nchains)), chains_naccepts_iterations))
            returned_vals = pool.map(_sample_dream, args)
            sampled_params = [val[0] for val in returned_vals]
            log_ps = [val[1] for val in returned_vals]
            acceptance_rates = [val[2] for val in returned_vals]
            # Append this run's acceptance-rate trace to each chain's file.
            for chain in range(nchains):
                filename = f'{step_instance.model_name}acceptance_rates_chain{chain}.txt'
                with open(filename, 'ab') as f:
                    np.savetxt(f, acceptance_rates[chain])
    finally:
        pool.close()
        pool.join()

    return sampled_params, log_ps
def _sample_dream(args):
try:
dream_instance = args[0]
iterations = args[1]
start = args[2]
verbose = args[3]
nverbose = args[4]
chain_idx = args[5]
naccepts_iterations_total = args[6]
step_fxn = getattr(dream_instance, 'astep')
sampled_params = np.empty((iterations, dream_instance.total_var_dimension))
log_ps = np.empty((iterations, 1))
acceptance_rates_size = int(np.floor(iterations / nverbose))
if acceptance_rates_size == 0:
acceptance_rates_size = 1
acceptance_rates = np.zeros(acceptance_rates_size)
q0 = start
iterations_total = np.sum(naccepts_iterations_total[1])
naccepts = naccepts_iterations_total[0][-1]
naccepts100win = 0
acceptance_counter = 0
for iteration_idx, iteration in enumerate(range(iterations_total, iterations_total + iterations)):
if iteration%nverbose == 0:
acceptance_rate = float(naccepts)/(iteration+1)
acceptance_rates[acceptance_counter] = acceptance_rate
acceptance_counter += 1
if verbose:
print('Iteration: ',iteration,' acceptance rate: ',acceptance_rate)
if iteration%100 == 0:
acceptance_rate_100win = float(naccepts100win)/100
if verbose:
print('Iteration: ',iteration,' acceptance rate over last 100 iterations: ',acceptance_rate_100win)
naccepts100win = 0
old_params = q0
sampled_params[iteration_idx], log_prior, log_like = step_fxn(q0)
log_ps[iteration_idx] = log_like + log_prior
q0 = sampled_params[iteration_idx]
if old_params is None:
old_params = q0
if np.any(q0 != old_params):
naccepts += 1
naccepts100win += 1
naccepts_iterations_total = np.append(naccepts_iterations_total, np.array([[naccepts], [iterations]]), axis=1)
np.save(f'{dream_instance.model_name}naccepts_chain{chain_idx}.npy', naccepts_iterations_total)
except Exception as e:
traceback.print_exc()
print()
raise e
return sampled_params, log_ps, acceptance_rates
def _sample_dream_pt(nchains, niterations, step_instance, start, pool, verbose):
    """Run parallel-tempered DREAM sampling across ``nchains`` chains.

    Each chain samples at its own temperature; after every in-temperature
    step a random pair of chains is proposed for a temperature swap and
    accepted with a Metropolis criterion. Two states are recorded per
    iteration (pre- and post-swap), so the returned arrays have
    ``2 * niterations`` records per chain.

    Warning (per run_dream): this tempering feature is untested.

    Returns
    -------
    (sampled_params, log_ps) arrays of shape
    (nchains, 2 * niterations, dim) and (nchains, 2 * niterations, 1).
    """
    # Geometric temperature ladder from 1.0 down toward .001.
    # NOTE(review): the T[0] = 1. assignment is redundant — the loop below
    # overwrites it with .001**0 == 1.0.
    T = np.zeros((nchains))
    T[0] = 1.
    for i in range(nchains):
        T[i] = np.power(.001, (float(i)/nchains))
    step_instances = [step_instance]*nchains
    # Per-chain argument tuples for _sample_dream_pt_chain; the last
    # log-likelihood / log-prior start as None on the first iteration.
    if type(start) is list:
        args = list(zip(step_instances, start, T, [None]*nchains, [None]*nchains))
    else:
        args = list(zip(step_instances, [start]*nchains, T, [None]*nchains, [None]*nchains))
    # Two slots per iteration: sampled state and (possibly swapped) state.
    sampled_params = np.zeros((nchains, niterations*2, step_instance.total_var_dimension))
    log_ps = np.zeros((nchains, niterations*2, 1))
    q0 = start
    naccepts = np.zeros((nchains))        # in-temperature acceptances per chain
    naccepts100win = np.zeros((nchains))  # same, windowed over last 100 iterations
    nacceptsT = np.zeros((nchains))       # accepted temperature swaps per chain
    nacceptsT100win = np.zeros((nchains))
    # Expected number of swap tests involving a given chain per 100 iterations
    # (2 chains are picked out of nchains each iteration).
    ttestsper100 = 100./nchains
    for iteration in range(niterations):
        itidx = iteration*2
        if iteration%10 == 0:
            # ttests: expected swap tests per chain so far; ntests: total
            # move attempts (in-temperature steps plus swap tests).
            ttests = iteration/float(nchains)
            ntests = ttests + iteration
            acceptance_rate = naccepts/(ntests+1)
            Tacceptance_rate = nacceptsT/(ttests+1)
            overall_Tacceptance_rate = np.sum(nacceptsT)/(iteration+1)
            if verbose:
                print('Iteration: ',iteration,' overall acceptance rate: ',acceptance_rate,' temp swap acceptance rate per chain: ',Tacceptance_rate,' and overall temp swap acceptance rate: ',overall_Tacceptance_rate)
        if iteration%100 == 0:
            # Windowed rates over the last 100 iterations; reset the windows.
            acceptance_rate_100win = naccepts100win/(100 + ttestsper100)
            Tacceptance_rate_100win = nacceptsT100win/ttestsper100
            overall_Tacceptance_rate_100win = np.sum(nacceptsT100win)/100
            if verbose:
                print('Iteration: ',iteration,' overall acceptance rate over last 100 iterations: ',acceptance_rate_100win,' temp swap acceptance rate: ',Tacceptance_rate_100win,' and overall temp swap acceptance rate: ',overall_Tacceptance_rate_100win)
            naccepts100win = np.zeros((nchains))
            nacceptsT100win = np.zeros((nchains))
        # Advance every chain one step at its own temperature, in parallel.
        returned_vals = pool.map(_sample_dream_pt_chain, args)
        qnews = [val[0] for val in returned_vals]
        logprinews = [val[1] for val in returned_vals]
        loglikenews = [val[2] for val in returned_vals]
        dream_instances = [val[3] for val in returned_vals]
        # Tempered log-probability: likelihood scaled by T, prior untempered.
        logpnews = [T[i]*loglikenews[i] + logprinews[i] for i in range(nchains)]
        # Record the pre-swap states.
        for chain in range(nchains):
            sampled_params[chain][itidx] = qnews[chain]
            log_ps[chain][itidx] = logpnews[chain]
        # Propose exchanging the states of two randomly chosen chains.
        random_chains = np.random.choice(nchains, 2, replace=False)
        loglike1 = loglikenews[random_chains[0]]
        T1 = T[random_chains[0]]
        loglike2 = loglikenews[random_chains[1]]
        T2 = T[random_chains[1]]
        logp1 = logpnews[random_chains[0]]
        logp2 = logpnews[random_chains[1]]
        # Standard parallel-tempering swap criterion in log space.
        alpha = ((T1*loglike2)+(T2*loglike1))-((T1*loglike1)+(T2*loglike2))
        if np.log(np.random.uniform()) < alpha:
            if verbose:
                print('Accepted temperature swap of chains: ',random_chains,' at temperatures: ',T1,' and ',T2,' and logps: ',logp1,' and ',logp2)
            nacceptsT[random_chains[0]] += 1
            nacceptsT[random_chains[1]] += 1
            nacceptsT100win[random_chains[0]] += 1
            nacceptsT100win[random_chains[1]] += 1
            # Exchange positions and all associated log quantities.
            old_qs = list(qnews)
            old_logps = list(logpnews)
            old_loglikes = list(loglikenews)
            old_logpri = list(logprinews)
            qnews[random_chains[0]] = old_qs[random_chains[1]]
            qnews[random_chains[1]] = old_qs[random_chains[0]]
            logpnews[random_chains[0]] = old_logps[random_chains[1]]
            logpnews[random_chains[1]] = old_logps[random_chains[0]]
            loglikenews[random_chains[0]] = old_loglikes[random_chains[1]]
            loglikenews[random_chains[1]] = old_loglikes[random_chains[0]]
            logprinews[random_chains[0]] = old_logpri[random_chains[1]]
            logprinews[random_chains[1]] = old_logpri[random_chains[0]]
        else:
            if verbose:
                print('Did not accept temperature swap of chains: ',random_chains,' at temperatures: ',T1,' and ',T2,' and logps: ',logp1,' and ',logp2)
        # Record the post-swap states in the second slot for this iteration.
        for chain in range(nchains):
            sampled_params[chain][itidx+1] = qnews[chain]
            log_ps[chain][itidx+1] = logpnews[chain]
        # Count in-temperature acceptances by comparing to the previous position.
        for i, q in enumerate(qnews):
            try:
                if not np.all(q == q0[i]):
                    naccepts[i] += 1
                    naccepts100win[i] += 1
            except TypeError:
                #On first iteration without starting points this will fail because q0 == None
                pass
        # Feed each chain's new state and log terms into the next step.
        args = list(zip(dream_instances, qnews, T, loglikenews, logprinews))
        q0 = qnews
    return sampled_params, log_ps
def _sample_dream_pt_chain(args):
dream_instance = args[0]
start = args[1]
T = args[2]
last_loglike = args[3]
last_logpri = args[4]
step_fxn = getattr(dream_instance, 'astep')
q1, logprior1, loglike1 = step_fxn(start, T, last_loglike, last_logpri)
return q1, logprior1, loglike1, dream_instance
def _setup_mp_dream_pool(nchains, niterations, step_instance, start_pt=None, mp_context=None):
    """Create the DreamPool and the shared-memory state used by all chains.

    Sizes the shared history array from the chain count, iteration count,
    thinning rate and (optionally) a preexisting history file, allocates the
    crossover/gamma adaptation arrays, and spawns a pool whose workers are
    initialized with that shared state via ``_mp_dream_init``.

    Returns
    -------
    The initialized DreamPool.
    """
    # DREAM needs at least (2 * DEpairs) + 1 chains to form valid proposals.
    min_njobs = (2*len(step_instance.DEpairs))+1
    if nchains < min_njobs:
        raise Exception('Dream should be run with at least (2*DEpairs)+1 number of chains. For current algorithmic settings, set njobs>=%s.' %str(min_njobs))
    if step_instance.history_file != False:
        old_history = np.load(step_instance.history_file)
        len_old_history = len(old_history.flatten())
        # Integer division: this is a record count and seeds nseedchains,
        # which is used as a count; true division produced a float here.
        nold_history_records = len_old_history // step_instance.total_var_dimension
        step_instance.nseedchains = nold_history_records
        # Reserve room for the thinned samples of this run plus the old history.
        if niterations < step_instance.history_thin:
            arr_dim = ((np.floor(nchains*niterations/step_instance.history_thin)+nchains)*step_instance.total_var_dimension)+len_old_history
        else:
            arr_dim = np.floor((((nchains*niterations)*step_instance.total_var_dimension)/step_instance.history_thin))+len_old_history
    else:
        # No prior history: reserve room for this run plus the seed chains.
        if niterations < step_instance.history_thin:
            arr_dim = ((np.floor(nchains*niterations/step_instance.history_thin)+nchains)*step_instance.total_var_dimension)+(step_instance.nseedchains*step_instance.total_var_dimension)
        else:
            arr_dim = np.floor(((nchains*niterations/step_instance.history_thin)*step_instance.total_var_dimension))+(step_instance.nseedchains*step_instance.total_var_dimension)
    min_nseedchains = 2*len(step_instance.DEpairs)*nchains
    if step_instance.nseedchains < min_nseedchains:
        raise Exception('The size of the seeded starting history is insufficient. Increase nseedchains>=%s.' %str(min_nseedchains))
    current_position_dim = nchains*step_instance.total_var_dimension
    # Resolve the multiprocessing context used to allocate the shared arrays.
    ctx = mp.get_context() if mp_context is None else mp_context
    history_arr = ctx.Array('d', [0] * int(arr_dim))
    if step_instance.history_file != False:
        history_arr[0:len_old_history] = old_history.flatten()
    nCR = step_instance.nCR
    ngamma = step_instance.ngamma
    crossover_setting = step_instance.CR_probabilities
    crossover_probabilities = ctx.Array('d', crossover_setting)
    ncrossover_updates = ctx.Array('d', [0] * nCR)
    delta_m = ctx.Array('d', [0] * nCR)
    gamma_level_setting = step_instance.gamma_probabilities
    gamma_probabilities = ctx.Array('d', gamma_level_setting)
    ngamma_updates = ctx.Array('d', [0] * ngamma)
    delta_m_gamma = ctx.Array('d', [0] * ngamma)
    current_position_arr = ctx.Array('d', [0] * current_position_dim)
    shared_nchains = ctx.Value('i', nchains)
    n = ctx.Value('i', 0)
    tf = ctx.Value('c', b'F')  # flips to b'T' once the history is seeded
    if step_instance.crossover_burnin is None:
        step_instance.crossover_burnin = int(np.floor(niterations/10))
    # Identity check: `start_pt != None` is elementwise for NumPy array start
    # points and makes this `if` raise "truth value ... is ambiguous".
    if start_pt is not None:
        if step_instance.start_random:
            print('Warning: start position provided but random_start set to True. Overrode random_start value and starting walk at provided start position.')
            step_instance.start_random = False
    p = DreamPool(nchains, context=ctx, initializer=_mp_dream_init,
                  initargs=(history_arr, current_position_arr, shared_nchains,
                            crossover_probabilities, ncrossover_updates, delta_m,
                            gamma_probabilities, ngamma_updates, delta_m_gamma, n, tf,))
    return p
def _mp_dream_init(history_arr, position_arr, chain_count, cr_probs,
                   cr_update_counts, cr_delta_m, glevel_probs,
                   glevel_update_counts, glevel_delta_m, counter, seeded_flag):
    """Pool-worker initializer: publish shared state on Dream_shared_vars.

    Runs once in every worker process so that all chains read and write the
    same multiprocessing Arrays/Values: the sample history, current chain
    positions, crossover- and gamma-level adaptation state, the shared
    iteration counter, and the history-seeded flag.
    """
    Dream_shared_vars.history = history_arr
    Dream_shared_vars.current_positions = position_arr
    Dream_shared_vars.nchains = chain_count
    Dream_shared_vars.cross_probs = cr_probs
    Dream_shared_vars.ncr_updates = cr_update_counts
    Dream_shared_vars.delta_m = cr_delta_m
    Dream_shared_vars.gamma_level_probs = glevel_probs
    Dream_shared_vars.ngamma_updates = glevel_update_counts
    Dream_shared_vars.delta_m_gamma = glevel_delta_m
    Dream_shared_vars.count = counter
    Dream_shared_vars.history_seeded = seeded_flag