Skip to content

Commit ed46842

Browse files
committed
playing with neuropixel data and fixing L-curve lambda selection
1 parent 7b395ff commit ed46842

2 files changed

Lines changed: 77 additions & 32 deletions

File tree

figures/npx/dan_kCSD_from_npx.py

Lines changed: 76 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,42 @@
1010
# from matplotlib import gridspec
1111

1212

13-
def make_plot(xx, yy, zz, title='True CSD', cmap=cm.bwr):
14-
fig = plt.figure(figsize=(7, 7))
15-
ax = plt.subplot(111)
13+
def make_plot(ax, xx, yy, zz, title='True CSD', cmap=cm.bwr):
14+
# fig = plt.figure(figsize=(7, 7))
15+
# ax = plt.subplot(111)
1616
ax.set_aspect('equal')
1717
t_max = np.max(np.abs(zz))
1818
levels = np.linspace(-1 * t_max, t_max, 32)
1919
im = ax.contourf(xx, yy, zz, levels=levels, cmap=cmap)
2020
ax.set_xlabel('X (mm)')
2121
ax.set_ylabel('Y (mm)')
2222
ax.set_title(title)
23-
ticks = np.linspace(-1 * t_max, t_max, 3, endpoint=True)
24-
plt.colorbar(im, orientation='horizontal', format='%.2f', ticks=ticks)
23+
# ticks = np.linspace(-1 * t_max, t_max, 3, endpoint=True)
24+
# plt.colorbar(im, orientation='horizontal', format='%.2f', ticks=ticks)
2525
return ax
2626

2727

28+
def dan_make_plot(k):
29+
fig = plt.figure(figsize=(7, 7))
30+
ax1 = plt.subplot(121)
31+
32+
est_csd = k.values('CSD')
33+
est_csd = est_csd.reshape(7, 90)
34+
est_pots = k.values('POT')
35+
est_pots = est_pots.reshape(7, 90)
36+
37+
make_plot(ax1, k.estm_x, k.estm_y, est_csd[:, :],
38+
title='Estimated CSD', cmap=cm.bwr)
39+
40+
ax2 = plt.subplot(122)
41+
make_plot(ax2, k.estm_x, k.estm_y, est_pots[:, :],
42+
title='Estimated POT', cmap=cm.PRGn)
43+
fig.suptitle('lambda = %f, R = %f' % (k.lambd, k.R))
44+
45+
return fig
46+
47+
48+
2849
# Specific to Ewas experimental setup
2950
def load_chann_map():
3051
book = load_workbook('NP_do_map.xlsx')
@@ -59,16 +80,16 @@ def dan_fetch_electrodes(meta):
5980
return(electrode, channel)
6081

6182

62-
def fetch_channels(eles):
63-
chans = []
64-
exist_ele = []
65-
for ii in eles:
66-
try:
67-
chans.append(ele_chan_dict[ii])
68-
exist_ele.append(ii)
69-
except KeyError:
70-
print('Not recording from ele', ii)
71-
return chans, exist_ele
83+
# def fetch_channels(eles):
84+
# chans = []
85+
# exist_ele = []
86+
# for ii in eles:
87+
# try:
88+
# chans.append(ele_chan_dict[ii])
89+
# exist_ele.append(ii)
90+
# except KeyError:
91+
# print('Not recording from ele', ii)
92+
# return chans, exist_ele
7293

7394
def eles_to_rows(eles):
7495
rows = []
@@ -189,31 +210,55 @@ def eles_to_coords(eles):
189210

190211

191212
pots = pots.reshape((len(channels), 1))
192-
R_init = 5. # 0.3
213+
R = 5. # 0.3
214+
lambd = 0.
193215
h = 20. # 50
194216
sigma = 0.3
217+
195218
k = KCSD2D(ele_pos, pots, h=h, sigma=sigma,
196-
xmin=-35, xmax=35,
197-
ymin=1100, ymax=2000,
198-
# ymin=1000, ymax=10000,
199-
gdx=10, gdy=10,
200-
R_init=R_init, n_src_init=1000,
201-
src_type='gauss') # rest of the parameters are set at default
202-
k.cross_validate(Rs=np.logspace(-1., 1., 10), lambdas=None)
219+
xmin=-35, xmax=35,
220+
ymin=1100, ymax=2000,
221+
# ymin=1000, ymax=10000,
222+
gdx=10, gdy=10, lambd=lambd,
223+
R_init=R, n_src_init=10000,
224+
src_type='gauss') # rest of the parameters are set at default
225+
226+
k.L_curve(Rs=np.logspace(-1., 2., 31), lambdas=np.logspace(-5., 1., 11))
227+
plt.imshow(k.curve_surf)
228+
229+
# k.cross_validate(Rs=np.logspace(0., 2., 21), lambdas=np.logspace(-5., 1., 11))
203230
# k.cross_validate(Rs=np.linspace(0.1, 1.001, 2), lambdas=None)
204231
# 2 -> 20
205232

233+
dan_make_plot(k)
234+
235+
236+
# =============================================================================
237+
# for R in np.logspace(0., 2., 21):
238+
# for lambd in np.logspace(-5., 1., 11):
239+
# k = KCSD2D(ele_pos, pots, h=h, sigma=sigma,
240+
# xmin=-35, xmax=35,
241+
# ymin=1100, ymax=2000,
242+
# # ymin=1000, ymax=10000,
243+
# gdx=10, gdy=10, lambd=lambd,
244+
# R_init=R, n_src_init=1000,
245+
# src_type='gauss') # rest of the parameters are set at default
246+
#
247+
# est_csd = k.values('CSD')
248+
# est_csd = est_csd.reshape(7, 90)
249+
# est_pots = k.values('POT')
250+
# est_pots = est_pots.reshape(7, 90)
251+
#
252+
# dan_make_plot(k)
253+
#
254+
# =============================================================================
206255

207-
est_csd = k.values('CSD')
208-
est_csd = est_csd.reshape(7, 90)
209-
est_pots = k.values('POT')
210-
est_pots = est_pots.reshape(7, 90)
211256

212-
make_plot(k.estm_x, k.estm_y, est_csd[:, :],
213-
title='Estimated CSD without CV', cmap=cm.bwr)
257+
# make_plot(k.estm_x, k.estm_y, est_csd[:, :],
258+
# title='Estimated CSD without CV', cmap=cm.bwr)
214259

215-
make_plot(k.estm_x, k.estm_y, est_pots[:, :],
216-
title='Estimated POT without CV', cmap=cm.PRGn)
260+
# make_plot(k.estm_x, k.estm_y, est_pots[:, :],
261+
# title='Estimated POT without CV', cmap=cm.PRGn)
217262

218263

219264
# # ax = plt.subplot(121)

kcsd/KCSD.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def L_curve(self, lambdas=None, Rs=None, n_jobs=1):
410410
self.curve_surf = np.zeros((len(Rs), len(lambdas)))
411411
for R_idx, R in enumerate(Rs):
412412
self.update_R(R)
413-
self.suggest_lambda()
413+
# self.suggest_lambda()
414414
print('l-curve (all lambda): ', np.round(R, decimals=3))
415415
modelnormseq, residualseq = utils.parallel_search(self.k_pot, self.pots, lambdas,
416416
n_jobs=n_jobs)

0 commit comments

Comments
 (0)