Skip to content

Commit 878d5f5

Browse files
committed
Scripts to generate figures from NPX
1 parent 2e50768 commit 878d5f5

4 files changed

Lines changed: 352 additions & 0 deletions

File tree

figures/npx/figure_properties.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import matplotlib.pyplot as plt
2+
3+
def set_axis(ax, x, y, letter=None):
4+
ax.text(
5+
x,
6+
y,
7+
letter,
8+
fontsize=25,
9+
weight='bold',
10+
transform=ax.transAxes)
11+
return ax
12+
13+
plt.rcParams.update({
14+
'xtick.labelsize': 15,
15+
'xtick.major.size': 10,
16+
'ytick.labelsize': 15,
17+
'ytick.major.size': 10,
18+
'font.size': 12,
19+
'axes.labelsize': 15,
20+
'axes.titlesize': 20,
21+
'axes.titlepad' : 30,
22+
'legend.fontsize': 15,
23+
# 'figure.subplot.wspace': 0.4,
24+
# 'figure.subplot.hspace': 0.4,
25+
# 'figure.subplot.left': 0.1,
26+
})
27+
28+
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import numpy as np
2+
from kcsd import KCSD2D
3+
from pathlib import Path
4+
import DemoReadSGLXData.readSGLX as readSGLX
5+
import matplotlib.pyplot as plt
6+
import matplotlib.cm as cm
7+
from scipy.signal import filtfilt, butter
8+
from figure_properties import *
9+
# from matplotlib import gridspec
10+
plt.close('all')
11+
#%%
12+
def make_plot_spacetime(ax, xx, yy, zz, title='True CSD', cmap=cm.bwr_r, ymin=0, ymax=10000):
13+
im = ax.imshow(zz,extent=[0, zz.shape[1]/Fs*1000,-3500, 500], aspect='auto',
14+
vmax = 1*zz.max(),vmin = -1*zz.max(), cmap=cmap)
15+
ax.set_xlabel('Time (ms)')
16+
ax.set_ylabel('Y ($\mu$m)')
17+
if 'Pot' in title: ax.set_ylabel('Y ($\mu$m)')
18+
ax.set_title(title)
19+
if 'CSD' in title:
20+
plt.colorbar(im, orientation='vertical', format='%.2f', ticks = [-0.01,0,0.01])
21+
else:
22+
plt.colorbar(im, orientation='vertical', format='%.1f', ticks = [-0.6,0,0.6])
23+
# plt.gca().invert_yaxis()
24+
25+
def make_plot(ax, xx, yy, zz, title='True CSD', cmap=cm.bwr):
26+
ax.set_aspect('auto')
27+
levels = np.linspace(zz.min(), -zz.min(), 61)
28+
# if 'CSD' in title: levels = levels = np.linspace(zz.min(), -zz.min(), 32)
29+
# if 'POT' in title: levels = np.linspace(zz.min(), -zz.min(), 64)
30+
im = ax.contourf(xx, -(yy-500), zz, levels=levels, cmap=cmap)
31+
ax.set_xlabel('X ($\mu$m)')
32+
ax.set_ylabel('Y ($\mu$m)')
33+
ax.set_title(title)
34+
# ticks = np.linspace(-100,100, 7, endpoint=True)
35+
if 'CSD' in title:
36+
plt.colorbar(im, orientation='vertical', format='%.2f', ticks=[-0.02,0,0.02])
37+
else: plt.colorbar(im, orientation='vertical', format='%.1f', ticks=[-0.6,0,0.6])
38+
plt.scatter(ele_pos[:, 0],
39+
-(ele_pos[:, 1]-500),
40+
s=0.8, color='black')
41+
# plt.gca().invert_yaxis()
42+
return ax
43+
44+
def dan_fetch_electrodes(meta):
45+
imroList = meta['imroTbl'].split(sep=')')
46+
# One entry for each channel plus header entry,
47+
# plus a final empty entry following the last ')'
48+
nChan = len(imroList) - 2
49+
electrode = np.zeros(nChan, dtype=int) # default type = float
50+
channel = np.zeros(nChan, dtype=int)
51+
bank = np.zeros(nChan, dtype=int)
52+
for i in range(0, nChan):
53+
currList = imroList[i+1].split(sep=' ')
54+
print(currList)
55+
channel[i] = int(currList[0][1:])
56+
bank[i] = int(currList[1])
57+
# reference_electrode[i] = currList[2]
58+
# Channel N => Electrode (1+N+384*A), where N = 0:383, A=0:2
59+
electrode = 1 + channel + 384 * bank
60+
return(electrode, channel)
61+
62+
def eles_to_ycoord(eles):
63+
y_coords = []
64+
for ii in range(192):
65+
y_coords.append(ii*20)
66+
y_coords.append(ii*20)
67+
return y_coords[::-1]
68+
69+
def eles_to_xcoord(eles):
70+
x_coords = []
71+
for ele in eles:
72+
off = ele%4
73+
if off == 1: x_coords.append(-24)
74+
elif off == 2: x_coords.append(8)
75+
elif off == 3: x_coords.append(-8)
76+
elif off==0: x_coords.append(24)
77+
return x_coords
78+
79+
def eles_to_coords(eles):
80+
xs = eles_to_xcoord(eles)
81+
ys = eles_to_ycoord(eles)
82+
return np.array((xs, ys)).T
83+
#%%
84+
binFullPath = Path('/Users/Wladek/Dysk Google/kCSD_lcurve/validation/'
85+
'15_3800_bank0_defauld PnoFIltr_OLD_headsage_OLD_electrode_g0_t0.imec0.ap.bin')
86+
87+
meta = readSGLX.readMeta(binFullPath)
88+
Fss = int(readSGLX.SampRate(meta))
89+
90+
path = '/Users/Wladek/Dysk Google/kCSD_lcurve/validation/'
91+
electrodes, channels = dan_fetch_electrodes(meta)
92+
ch_order = electrodes.argsort()
93+
94+
rawData = readSGLX.makeMemMapRaw(binFullPath, meta)
95+
selectData = rawData[channels, 30*Fss:50*Fss]
96+
# convData is the potential in uV or mV
97+
if meta['typeThis'] == 'imec':
98+
rawData = 1e3*readSGLX.GainCorrectIM(selectData, channels, meta)
99+
else:
100+
rawData = 1e3*readSGLX.GainCorrectNI(selectData, channels, meta)
101+
102+
electrodes.sort()
103+
ele_pos_def = eles_to_coords(electrodes[::-1])
104+
#%%
105+
ex_time = 16.6#12.7
106+
lowpass = 0.5#20 beta
107+
highpass = 300#50 beta
108+
after = 0.3
109+
forfilt = rawData[:,int((ex_time-0.3)*Fss):int((ex_time+after)*Fss)]
110+
# forfilt = detrend(forfilt, bp=np.array([0,int(0.1*Fss)]))
111+
# for i in range(384): forfilt[i] = forfilt[i]/np.std(forfilt[i])
112+
[b,a] = butter(3, [lowpass/(Fss/2.0), highpass/(Fss/2.0)] ,btype = 'bandpass')
113+
filtData = filtfilt(b,a, forfilt)
114+
np.save('npx_data', filtData)
115+
#%%
116+
resamp = 12
117+
pots_resamp = filtData[:,::resamp]
118+
pots = pots_resamp[:, :]
119+
Fs=int(Fss/resamp)
120+
#%%
121+
time = np.linspace(0, pots.shape[1]/Fs, pots.shape[1])
122+
plt.figure()
123+
plt.subplot(121)
124+
for ch in range(0,384,8):#, potsy.shape[0], 8):
125+
plt.plot(time, pots[ch,:]+1*ch, color='grey', lw=0.3)
126+
print('start averaging')
127+
plt.subplot(122)
128+
plt.imshow(pots[::-1], extent=[0,pots.shape[1]/Fs,pots.shape[0],0],
129+
aspect='auto', cmap = 'PRGn',
130+
vmin = -pots.max(), vmax = pots.max())
131+
# plt.xlim(280, 330)
132+
# plt.subplot(133)
133+
# %%
134+
pots_for_csd = np.delete(pots, 191, axis=0)
135+
ele_pos_for_csd = np.delete(ele_pos_def, 191, axis=0)
136+
# pots_for_csd = pots
137+
# ele_pos_for_csd = ele_pos_def
138+
def do_kcsd(ele_pos_for_csd, pots_for_csd, ele_limit):
139+
ele_position = ele_pos_for_csd[:ele_limit[1]][0::1]
140+
csd_pots = pots_for_csd[:ele_limit[1]][0::1]
141+
k = KCSD2D(ele_position, csd_pots,
142+
h=1, sigma=1,
143+
xmin= -42, xmax=42, gdx=4,
144+
ymin=0, ymax=4000, gdy=4)
145+
k.L_curve(Rs=np.linspace(32, 90, 1), lambdas=np.logspace(-9, -7, 1))
146+
# k.cross_validate(Rs=np.linspace(20, 30, 1), lambdas=np.logspace(-5, -3, 20))
147+
plt.figure()
148+
plt.imshow(k.curve_surf)#, vmin=-k.curve_surf.max(), vmax=k.curve_surf.max(), cmap='BrBG_r')
149+
plt.colorbar()
150+
return k, k.values('CSD'), k.values('POT'), ele_position
151+
152+
k, est_csd, est_pots, ele_pos = do_kcsd(ele_pos_for_csd, pots_for_csd, ele_limit = (0,320))
153+
#%%
154+
plt.close('all')
155+
save= 1
156+
tp= 760
157+
def plot_1D_pics(k, est_csd, est_pots, cut=9):
158+
plt.figure(figsize=(12, 8))
159+
# plt.suptitle('plane: '+str(k.estm_x[cut,0])+' $\mu$m '+' $\lambda$ : '+str(k.lambd)+
160+
# ' R: '+ str(k.R))
161+
ax1 = plt.subplot(122)
162+
set_axis(ax1, -0.05, 1.05, letter= 'B')
163+
make_plot_spacetime(ax1, k.estm_x, k.estm_y, est_csd[cut,:,:],
164+
title='Estimated CSD', cmap=cm.bwr)
165+
for lvl, name in zip([-500,-850,-2000], ['II/III', 'IV', 'V/VI']):
166+
plt.axhline(lvl, ls='--', color='grey')
167+
plt.text(340, lvl+20, name)
168+
plt.xlim(250, 400)
169+
ax2 = plt.subplot(121)
170+
set_axis(ax2, -0.05, 1.05, letter= 'A')
171+
make_plot_spacetime(ax2, k.estm_x, k.estm_y, est_pots[cut,:,:],
172+
title='Estimated LFP', cmap=cm.PRGn)
173+
plt.axvline(tp/Fs*1000, ls='--', color ='grey', lw=2)
174+
plt.xlim(250, 400)
175+
plt.tight_layout()
176+
plt.savefig(savedir +'Figure_15', dpi=300)
177+
savedir = '/Users/Wladek/Dysk Google/kCSD_lcurve/validation/'
178+
for cut in range(15,16,1): plot_1D_pics(k, est_csd, est_pots, cut)
179+
# plt.close('all')
180+
181+
def plot_2D_pics(tp, cut, save=0):
182+
plt.figure(figsize=(12, 8))
183+
ax1 = plt.subplot(122)
184+
set_axis(ax1, -0.05, 1.05, letter= 'B')
185+
make_plot(ax1, k.estm_x, k.estm_y, est_csd[:,:,tp],
186+
title='Estimated CSD', cmap=cm.bwr)
187+
# for i in range(383): plt.text(ele_pos_for_csd[i,0], ele_pos_for_csd[i,1]+8, str(i+1))
188+
plt.axvline(k.estm_x[cut][0], ls='--', color ='grey', lw=2)
189+
ax2 = plt.subplot(121)
190+
set_axis(ax2, -0.05, 1.05, letter= 'A')
191+
make_plot(ax2, k.estm_x, k.estm_y, est_pots[:,:,tp],
192+
title='Estimated LFP', cmap=cm.PRGn)
193+
# plt.suptitle(' $\lambda$ : '+str(k.lambd)+ ' R: '+ str(k.R))
194+
plt.tight_layout()
195+
plt.savefig(savedir +'Figure_14', dpi=300)
196+
197+
plot_2D_pics(tp = tp, cut=15)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import numpy as np
2+
from kcsd import KCSD2D
3+
import matplotlib.pyplot as plt
4+
import matplotlib.cm as cm
5+
from scipy.signal import filtfilt, butter
6+
from figure_properties import *
7+
plt.close('all')
8+
#%%
9+
def make_plot_spacetime(ax, xx, yy, zz, title='True CSD', cmap=cm.bwr_r, ymin=0, ymax=10000):
10+
im = ax.imshow(zz,extent=[0, zz.shape[1]/Fs*1000,-3500, 500], aspect='auto',
11+
vmax = 1*zz.max(),vmin = -1*zz.max(), cmap=cmap)
12+
ax.set_xlabel('Time (ms)')
13+
ax.set_ylabel('Y ($\mu$m)')
14+
if 'Pot' in title: ax.set_ylabel('Y ($\mu$m)')
15+
ax.set_title(title)
16+
if 'CSD' in title:
17+
plt.colorbar(im, orientation='vertical', format='%.2f', ticks = [-0.01,0,0.01])
18+
else:
19+
plt.colorbar(im, orientation='vertical', format='%.1f', ticks = [-0.6,0,0.6])
20+
# plt.gca().invert_yaxis()
21+
22+
def make_plot(ax, xx, yy, zz, title='True CSD', cmap=cm.bwr):
23+
ax.set_aspect('auto')
24+
levels = np.linspace(zz.min(), -zz.min(), 61)
25+
im = ax.contourf(xx, -(yy-500), zz, levels=levels, cmap=cmap)
26+
ax.set_xlabel('X ($\mu$m)')
27+
ax.set_ylabel('Y ($\mu$m)')
28+
ax.set_title(title)
29+
if 'CSD' in title:
30+
plt.colorbar(im, orientation='vertical', format='%.2f', ticks=[-0.02,0,0.02])
31+
else: plt.colorbar(im, orientation='vertical', format='%.1f', ticks=[-0.6,0,0.6])
32+
plt.scatter(ele_pos[:, 0],
33+
-(ele_pos[:, 1]-500),
34+
s=0.8, color='black')
35+
# plt.gca().invert_yaxis()
36+
return ax
37+
38+
def eles_to_ycoord(eles):
39+
y_coords = []
40+
for ii in range(192):
41+
y_coords.append(ii*20)
42+
y_coords.append(ii*20)
43+
return y_coords[::-1]
44+
45+
def eles_to_xcoord(eles):
46+
x_coords = []
47+
for ele in eles:
48+
off = ele%4
49+
if off == 1: x_coords.append(-24)
50+
elif off == 2: x_coords.append(8)
51+
elif off == 3: x_coords.append(-8)
52+
elif off==0: x_coords.append(24)
53+
return x_coords
54+
55+
def eles_to_coords(eles):
56+
xs = eles_to_xcoord(eles)
57+
ys = eles_to_ycoord(eles)
58+
return np.array((xs, ys)).T
59+
60+
def plot_1D_pics(k, est_csd, est_pots, tp, cut=9):
61+
plt.figure(figsize=(12, 8))
62+
# plt.suptitle('plane: '+str(k.estm_x[cut,0])+' $\mu$m '+' $\lambda$ : '+str(k.lambd)+
63+
# ' R: '+ str(k.R))
64+
ax1 = plt.subplot(122)
65+
set_axis(ax1, -0.05, 1.05, letter= 'B')
66+
make_plot_spacetime(ax1, k.estm_x, k.estm_y, est_csd[cut,:,:],
67+
title='Estimated CSD', cmap='bwr')
68+
for lvl, name in zip([-500,-850,-2000], ['II/III', 'IV', 'V/VI']):
69+
plt.axhline(lvl, ls='--', color='grey')
70+
plt.text(340, lvl+20, name)
71+
plt.xlim(250, 400)
72+
ax2 = plt.subplot(121)
73+
set_axis(ax2, -0.05, 1.05, letter= 'A')
74+
make_plot_spacetime(ax2, k.estm_x, k.estm_y, est_pots[cut,:,:],
75+
title='Estimated LFP', cmap='PRGn')
76+
plt.axvline(tp/Fs*1000, ls='--', color ='grey', lw=2)
77+
plt.xlim(250, 400)
78+
plt.tight_layout()
79+
80+
def plot_2D_pics(k, est_csd, est_pots, tp, cut, save=0):
81+
plt.figure(figsize=(12, 8))
82+
ax1 = plt.subplot(122)
83+
set_axis(ax1, -0.05, 1.05, letter= 'B')
84+
make_plot(ax1, k.estm_x, k.estm_y, est_csd[:,:,tp],
85+
title='Estimated CSD', cmap='bwr')
86+
# for i in range(383): plt.text(ele_pos_for_csd[i,0], ele_pos_for_csd[i,1]+8, str(i+1))
87+
plt.axvline(k.estm_x[cut][0], ls='--', color ='grey', lw=2)
88+
ax2 = plt.subplot(121)
89+
set_axis(ax2, -0.05, 1.05, letter= 'A')
90+
make_plot(ax2, k.estm_x, k.estm_y, est_pots[:,:,tp],
91+
title='Estimated LFP', cmap='PRGn')
92+
# plt.suptitle(' $\lambda$ : '+str(k.lambd)+ ' R: '+ str(k.R))
93+
plt.tight_layout()
94+
95+
def do_kcsd(ele_pos_for_csd, pots_for_csd, ele_limit):
96+
ele_position = ele_pos_for_csd[:ele_limit[1]][0::1]
97+
csd_pots = pots_for_csd[:ele_limit[1]][0::1]
98+
k = KCSD2D(ele_position, csd_pots,
99+
h=1, sigma=1, R_init=32, lambd=1e-9,
100+
xmin= -42, xmax=42, gdx=4,
101+
ymin=0, ymax=4000, gdy=4)
102+
# k.L_curve(Rs=np.linspace(16, 48, 3), lambdas=np.logspace(-9, -3, 20))
103+
return k, k.values('CSD'), k.values('POT'), ele_position
104+
#%%
105+
if __name__ == '__main__':
106+
lowpass = 0.5
107+
highpass = 300
108+
Fs = 30000
109+
resamp = 12
110+
tp= 760
111+
112+
forfilt=np.load('npx_data.npy')
113+
114+
[b,a] = butter(3, [lowpass/(Fs/2.0), highpass/(Fs/2.0)] ,btype = 'bandpass')
115+
filtData = filtfilt(b,a, forfilt)
116+
pots_resamp = filtData[:,::resamp]
117+
pots = pots_resamp[:, :]
118+
Fs=int(Fs/resamp)
119+
120+
pots_for_csd = np.delete(pots, 191, axis=0)
121+
ele_pos_def = eles_to_coords(np.arange(384,0,-1))
122+
ele_pos_for_csd = np.delete(ele_pos_def, 191, axis=0)
123+
124+
k, est_csd, est_pots, ele_pos = do_kcsd(ele_pos_for_csd, pots_for_csd, ele_limit = (0,320))
125+
126+
plot_1D_pics(k, est_csd, est_pots, tp, 15)
127+
plot_2D_pics(k, est_csd, est_pots, tp=tp, cut=15)

figures/npx/npx_data.npy

52.7 MB
Binary file not shown.

0 commit comments

Comments
 (0)