forked from sanmitraghosh/hergModels
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmake_predictions.py
More file actions
executable file
·175 lines (147 loc) · 7.16 KB
/
make_predictions.py
File metadata and controls
executable file
·175 lines (147 loc) · 7.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#!/usr/bin/env python3
#
# Make predictions for every hERG model from its best CMA-ES/MCMC fit to the sine-wave data
#
import models_forward.LogPrior as prior
import models_forward.pintsForwardModel as forwardModel
import models_forward.Rates as Rates
import models_forward.util as util
import os
import pints
import numpy as np
import cPickle
import myokit
import argparse
# Directories holding the CMA-ES and MCMC fitting results for each model/cell.
cmaes_result_files = 'cmaes_results/'
mcmc_result_files = 'mcmc_results/'

# Command-line interface: choose which cell's recordings to predict.
parser = argparse.ArgumentParser(
    description='Make AP predictions based on the CMAES fit to sine wave data')
parser.add_argument(
    '--cell',
    type=int,
    default=5,
    metavar='N',
    help='cell number : 1, 2, ..., 5',
)
args = parser.parse_args()
cell = args.cell

# Order matters: 'sine-wave' must be first, because that iteration sets the
# noise estimate and the best-fit parameters reused by the later protocols.
protocols = ['sine-wave', 'ap', 'original-sine']
indices = range(len(protocols))
num_models = 30

# One row per model: column 0 holds the model number and column k + 1 the
# log-likelihood obtained under protocols[k].
likelihood_results = np.zeros((num_models, len(protocols) + 1))
def _load_rate_dictionary(model_name):
    """Load the pickled rate dictionary for one Markov model.

    Reads ``rate_dictionaries/<model_name>-priors.p`` and closes the file
    afterwards (previously the handle was left open for every model).
    """
    rate_file = os.path.join(os.path.abspath('rate_dictionaries'),
                             model_name + '-priors.p')
    # NOTE(review): unpickling can execute arbitrary code -- only load rate
    # dictionaries that come from a trusted source.
    with open(rate_file, 'rb') as f:
        return cPickle.load(f)


def _load_protocol_data(protocol_name, cell):
    """Load one cell's recording and the protocol used to simulate it.

    Reads ``<protocol_name>-data/cell-<cell>.csv`` (time, ``voltage`` and
    ``current`` columns).  For 'sine-wave' and 'original-sine' the recording
    is capacitance filtered using the step protocol ``steps.mmt`` found in the
    same data directory.

    Returns ``(time, voltage, current, protocol, sw)``:

    - 'sine-wave': ``protocol`` is the myokit protocol loaded from
      ``steps.mmt`` and ``sw`` is 1;
    - any other protocol: ``protocol`` is ``[times, voltages]`` read from
      ``<protocol_name>.csv`` in the data directory and ``sw`` is 0.
    """
    data_root = os.path.abspath(protocol_name + '-data')
    data_file = os.path.join(data_root, 'cell-' + str(cell) + '.csv')
    log = myokit.DataLog.load_csv(data_file).npview()
    time = log.time()
    current = log['current']
    voltage = log['voltage']
    del log

    if protocol_name in ('sine-wave', 'original-sine'):
        # Both sine protocols begin with the same voltage steps, so the same
        # steps.mmt drives the capacitance filter for either of them.
        protocol = myokit.load_protocol(os.path.join(data_root, 'steps.mmt'))
        print('Applying capacitance filtering')
        time, voltage, current = forwardModel.capacitance(
            protocol, 0.1, time, voltage, current)
        if protocol_name == 'sine-wave':
            return time, voltage, current, protocol, 1

    # All other protocols are driven by a recorded voltage trace (CSV file).
    print('Defining the protocol from {}'.format(data_root))
    protocol_file = os.path.join(data_root, protocol_name + '.csv')
    log = myokit.DataLog.load_csv(protocol_file).npview()
    protocol = [log.time(), log['voltage']]
    del log
    return time, voltage, current, protocol, 0


def _best_parameter_set(model_name, cell, rate_dict, log_likelihood):
    """Return ``(parameter_set, ll_score)`` for the best known fit.

    Starts from the CMA-ES result stored under ``cmaes_result_files``.  If an
    MCMC best-parameters file exists under ``mcmc_result_files`` and its
    back-transformed parameters achieve a higher log-likelihood, those
    parameters (and that score) are returned instead.
    """
    suffix = '-cell-' + str(cell)
    parameter_set = np.loadtxt(
        cmaes_result_files + model_name + suffix + '-cmaes.txt')
    ll_score = log_likelihood(parameter_set)
    print('CMAES model parameters start point: {}'.format(parameter_set))
    print('LogLikelihood (proportional to square error): {}'.format(ll_score))

    mcmc_best_param_file = (mcmc_result_files + model_name + suffix
                            + '-best-parameters.txt')
    if os.path.isfile(mcmc_best_param_file):
        mcmc_parameter_set = np.loadtxt(mcmc_best_param_file)
        # Transform type is hard coded to 1; presumably this matches how the
        # MCMC samples were stored -- confirm against the MCMC script.
        mcmc_parameter_set = util.transformer(
            1, mcmc_parameter_set, rate_dict, False)
        mcmc_ll_score = log_likelihood(mcmc_parameter_set)
        print('MCMC model parameters start point: {}'.format(mcmc_parameter_set))
        print('LogLikelihood (proportional to square error): {}'.format(
            mcmc_ll_score))
        if mcmc_ll_score > ll_score:
            ll_score = mcmc_ll_score
            parameter_set = mcmc_parameter_set
            print('Replacing best fit parameters with MCMC max posterior sample')
    return parameter_set, ll_score


# For every model: pick its best parameters on the sine-wave protocol, then
# simulate every protocol with those parameters and save the predictions.
for model_num in range(1, num_models + 1):
    # Import markov models from the models file, and rate dictionaries.
    model_name = 'model-' + str(model_num)
    myo_model = os.path.join(os.path.abspath('models_myokit'),
                             model_name + '.mmt')
    rate_dict = _load_rate_dictionary(model_name)
    print("LOADING MODEL " + str(model_num))

    for protocol_index, protocol_name in enumerate(protocols):
        print('Looking at Model {} and protocol {} {}'.format(
            model_num, protocol_name, protocol_index))

        time, voltage, current, protocol, sw = _load_protocol_data(
            protocol_name, cell)

        pred_root = os.path.abspath(
            os.path.join('predictions', protocol_name, 'cell-' + str(cell)))
        if not os.path.exists(pred_root):
            os.makedirs(pred_root)
        if model_num == 1:
            # The filtered recording does not depend on the model; save once.
            np.savetxt(os.path.join(pred_root, 'spike-filtered-data.csv'),
                       np.transpose([time, voltage, current]), delimiter=',')

        # Cell-specific parameters
        temperature = forwardModel.temperature(cell)
        lower_conductance = forwardModel.conductance_limit(cell)
        if protocol_name in ('sine-wave', 'original-sine'):
            # Estimate noise from the start of the data: Kylie uses the first
            # 200 ms, where I = 0 + noise (assumes 2000 samples == 200 ms).
            # For 'ap' the estimate from the preceding sine-wave run is reused.
            sigma_noise = np.std(current[:2000], ddof=1)

        # Create the forward model; transforms are unnecessary for a forward run.
        transform = 0
        model = forwardModel.ForwardModel(
            protocol, temperature, myo_model, rate_dict, transform,
            sine_wave=sw)
        n_params = model.n_params

        problem = pints.SingleOutputProblem(model, time, current)
        log_likelihood = pints.GaussianKnownSigmaLogLikelihood(
            problem, sigma_noise)
        # NOTE(review): log_prior, log_posterior and rate_checker are not used
        # below; kept in case their constructors validate rate_dict -- confirm
        # and remove.
        log_prior = prior.LogPrior(
            rate_dict, lower_conductance, n_params, transform)
        log_posterior = pints.LogPosterior(log_likelihood, log_prior)
        rate_checker = Rates.ratesPrior(transform, lower_conductance)

        if protocol_name == 'sine-wave':
            # Only refresh the best parameters on the sine-wave fit protocol.
            parameter_set, ll_score = _best_parameter_set(
                model_name, cell, rate_dict, log_likelihood)
        else:
            # Score the sine-wave parameters on this (prediction) protocol.
            ll_score = log_likelihood(parameter_set)

        # Keep track of the likelihoods for each protocol.
        likelihood_results[model_num - 1, 0] = model_num
        likelihood_results[model_num - 1, protocol_index + 1] = ll_score

        sol = model.simulate(parameter_set, time)
        np.savetxt(os.path.join(pred_root, 'model-' + str(model_num) + '.csv'),
                   np.transpose([time, sol]), delimiter=',')
        np.savetxt(
            os.path.join(pred_root, 'for-teun-model-' + str(model_num) + '.csv'),
            np.transpose([time, model.simulated_v, sol, model.simulated_o]),
            delimiter=',')
# Write the table of log-likelihoods (one row per model, one column per
# protocol after the model-number column) for this cell.
likelihood_path = 'predictions/likelihoods-cell-{}.csv'.format(cell)
np.savetxt(likelihood_path, likelihood_results, delimiter=',')