1+ import numpy as np
2+ from kcsd import KCSD2D
3+ from pathlib import Path
4+ import DemoReadSGLXData .readSGLX as readSGLX
5+ import matplotlib .pyplot as plt
6+ import matplotlib .cm as cm
7+ from scipy .signal import filtfilt , butter
8+ from figure_properties import *
9+ # from matplotlib import gridspec
10+ plt .close ('all' )
11+ #%%
12+ def make_plot_spacetime (ax , xx , yy , zz , title = 'True CSD' , cmap = cm .bwr_r , ymin = 0 , ymax = 10000 ):
13+ im = ax .imshow (zz ,extent = [0 , zz .shape [1 ]/ Fs * 1000 ,- 3500 , 500 ], aspect = 'auto' ,
14+ vmax = 1 * zz .max (),vmin = - 1 * zz .max (), cmap = cmap )
15+ ax .set_xlabel ('Time (ms)' )
16+ ax .set_ylabel ('Y ($\mu$m)' )
17+ if 'Pot' in title : ax .set_ylabel ('Y ($\mu$m)' )
18+ ax .set_title (title )
19+ if 'CSD' in title :
20+ plt .colorbar (im , orientation = 'vertical' , format = '%.2f' , ticks = [- 0.01 ,0 ,0.01 ])
21+ else :
22+ plt .colorbar (im , orientation = 'vertical' , format = '%.1f' , ticks = [- 0.6 ,0 ,0.6 ])
23+ # plt.gca().invert_yaxis()
24+
25+ def make_plot (ax , xx , yy , zz , title = 'True CSD' , cmap = cm .bwr ):
26+ ax .set_aspect ('auto' )
27+ levels = np .linspace (zz .min (), - zz .min (), 61 )
28+ # if 'CSD' in title: levels = levels = np.linspace(zz.min(), -zz.min(), 32)
29+ # if 'POT' in title: levels = np.linspace(zz.min(), -zz.min(), 64)
30+ im = ax .contourf (xx , - (yy - 500 ), zz , levels = levels , cmap = cmap )
31+ ax .set_xlabel ('X ($\mu$m)' )
32+ ax .set_ylabel ('Y ($\mu$m)' )
33+ ax .set_title (title )
34+ # ticks = np.linspace(-100,100, 7, endpoint=True)
35+ if 'CSD' in title :
36+ plt .colorbar (im , orientation = 'vertical' , format = '%.2f' , ticks = [- 0.02 ,0 ,0.02 ])
37+ else : plt .colorbar (im , orientation = 'vertical' , format = '%.1f' , ticks = [- 0.6 ,0 ,0.6 ])
38+ plt .scatter (ele_pos [:, 0 ],
39+ - (ele_pos [:, 1 ]- 500 ),
40+ s = 0.8 , color = 'black' )
41+ # plt.gca().invert_yaxis()
42+ return ax
43+
44+ def dan_fetch_electrodes (meta ):
45+ imroList = meta ['imroTbl' ].split (sep = ')' )
46+ # One entry for each channel plus header entry,
47+ # plus a final empty entry following the last ')'
48+ nChan = len (imroList ) - 2
49+ electrode = np .zeros (nChan , dtype = int ) # default type = float
50+ channel = np .zeros (nChan , dtype = int )
51+ bank = np .zeros (nChan , dtype = int )
52+ for i in range (0 , nChan ):
53+ currList = imroList [i + 1 ].split (sep = ' ' )
54+ print (currList )
55+ channel [i ] = int (currList [0 ][1 :])
56+ bank [i ] = int (currList [1 ])
57+ # reference_electrode[i] = currList[2]
58+ # Channel N => Electrode (1+N+384*A), where N = 0:383, A=0:2
59+ electrode = 1 + channel + 384 * bank
60+ return (electrode , channel )
61+
62+ def eles_to_ycoord (eles ):
63+ y_coords = []
64+ for ii in range (192 ):
65+ y_coords .append (ii * 20 )
66+ y_coords .append (ii * 20 )
67+ return y_coords [::- 1 ]
68+
69+ def eles_to_xcoord (eles ):
70+ x_coords = []
71+ for ele in eles :
72+ off = ele % 4
73+ if off == 1 : x_coords .append (- 24 )
74+ elif off == 2 : x_coords .append (8 )
75+ elif off == 3 : x_coords .append (- 8 )
76+ elif off == 0 : x_coords .append (24 )
77+ return x_coords
78+
79+ def eles_to_coords (eles ):
80+ xs = eles_to_xcoord (eles )
81+ ys = eles_to_ycoord (eles )
82+ return np .array ((xs , ys )).T
83+ #%%
84+ binFullPath = Path ('/Users/Wladek/Dysk Google/kCSD_lcurve/validation/'
85+ '15_3800_bank0_defauld PnoFIltr_OLD_headsage_OLD_electrode_g0_t0.imec0.ap.bin' )
86+
87+ meta = readSGLX .readMeta (binFullPath )
88+ Fss = int (readSGLX .SampRate (meta ))
89+
90+ path = '/Users/Wladek/Dysk Google/kCSD_lcurve/validation/'
91+ electrodes , channels = dan_fetch_electrodes (meta )
92+ ch_order = electrodes .argsort ()
93+
94+ rawData = readSGLX .makeMemMapRaw (binFullPath , meta )
95+ selectData = rawData [channels , 30 * Fss :50 * Fss ]
96+ # convData is the potential in uV or mV
97+ if meta ['typeThis' ] == 'imec' :
98+ rawData = 1e3 * readSGLX .GainCorrectIM (selectData , channels , meta )
99+ else :
100+ rawData = 1e3 * readSGLX .GainCorrectNI (selectData , channels , meta )
101+
102+ electrodes .sort ()
103+ ele_pos_def = eles_to_coords (electrodes [::- 1 ])
104+ #%%
105+ ex_time = 16.6 #12.7
106+ lowpass = 0.5 #20 beta
107+ highpass = 300 #50 beta
108+ after = 0.3
109+ forfilt = rawData [:,int ((ex_time - 0.3 )* Fss ):int ((ex_time + after )* Fss )]
110+ # forfilt = detrend(forfilt, bp=np.array([0,int(0.1*Fss)]))
111+ # for i in range(384): forfilt[i] = forfilt[i]/np.std(forfilt[i])
112+ [b ,a ] = butter (3 , [lowpass / (Fss / 2.0 ), highpass / (Fss / 2.0 )] ,btype = 'bandpass' )
113+ filtData = filtfilt (b ,a , forfilt )
114+ np .save ('npx_data' , filtData )
115+ #%%
116+ resamp = 12
117+ pots_resamp = filtData [:,::resamp ]
118+ pots = pots_resamp [:, :]
119+ Fs = int (Fss / resamp )
120+ #%%
121+ time = np .linspace (0 , pots .shape [1 ]/ Fs , pots .shape [1 ])
122+ plt .figure ()
123+ plt .subplot (121 )
124+ for ch in range (0 ,384 ,8 ):#, potsy.shape[0], 8):
125+ plt .plot (time , pots [ch ,:]+ 1 * ch , color = 'grey' , lw = 0.3 )
126+ print ('start averaging' )
127+ plt .subplot (122 )
128+ plt .imshow (pots [::- 1 ], extent = [0 ,pots .shape [1 ]/ Fs ,pots .shape [0 ],0 ],
129+ aspect = 'auto' , cmap = 'PRGn' ,
130+ vmin = - pots .max (), vmax = pots .max ())
131+ # plt.xlim(280, 330)
132+ # plt.subplot(133)
133+ # %%
134+ pots_for_csd = np .delete (pots , 191 , axis = 0 )
135+ ele_pos_for_csd = np .delete (ele_pos_def , 191 , axis = 0 )
136+ # pots_for_csd = pots
137+ # ele_pos_for_csd = ele_pos_def
138+ def do_kcsd (ele_pos_for_csd , pots_for_csd , ele_limit ):
139+ ele_position = ele_pos_for_csd [:ele_limit [1 ]][0 ::1 ]
140+ csd_pots = pots_for_csd [:ele_limit [1 ]][0 ::1 ]
141+ k = KCSD2D (ele_position , csd_pots ,
142+ h = 1 , sigma = 1 ,
143+ xmin = - 42 , xmax = 42 , gdx = 4 ,
144+ ymin = 0 , ymax = 4000 , gdy = 4 )
145+ k .L_curve (Rs = np .linspace (32 , 90 , 1 ), lambdas = np .logspace (- 9 , - 7 , 1 ))
146+ # k.cross_validate(Rs=np.linspace(20, 30, 1), lambdas=np.logspace(-5, -3, 20))
147+ plt .figure ()
148+ plt .imshow (k .curve_surf )#, vmin=-k.curve_surf.max(), vmax=k.curve_surf.max(), cmap='BrBG_r')
149+ plt .colorbar ()
150+ return k , k .values ('CSD' ), k .values ('POT' ), ele_position
151+
152+ k , est_csd , est_pots , ele_pos = do_kcsd (ele_pos_for_csd , pots_for_csd , ele_limit = (0 ,320 ))
153+ #%%
154+ plt .close ('all' )
155+ save = 1
156+ tp = 760
157+ def plot_1D_pics (k , est_csd , est_pots , cut = 9 ):
158+ plt .figure (figsize = (12 , 8 ))
159+ # plt.suptitle('plane: '+str(k.estm_x[cut,0])+' $\mu$m '+' $\lambda$ : '+str(k.lambd)+
160+ # ' R: '+ str(k.R))
161+ ax1 = plt .subplot (122 )
162+ set_axis (ax1 , - 0.05 , 1.05 , letter = 'B' )
163+ make_plot_spacetime (ax1 , k .estm_x , k .estm_y , est_csd [cut ,:,:],
164+ title = 'Estimated CSD' , cmap = cm .bwr )
165+ for lvl , name in zip ([- 500 ,- 850 ,- 2000 ], ['II/III' , 'IV' , 'V/VI' ]):
166+ plt .axhline (lvl , ls = '--' , color = 'grey' )
167+ plt .text (340 , lvl + 20 , name )
168+ plt .xlim (250 , 400 )
169+ ax2 = plt .subplot (121 )
170+ set_axis (ax2 , - 0.05 , 1.05 , letter = 'A' )
171+ make_plot_spacetime (ax2 , k .estm_x , k .estm_y , est_pots [cut ,:,:],
172+ title = 'Estimated LFP' , cmap = cm .PRGn )
173+ plt .axvline (tp / Fs * 1000 , ls = '--' , color = 'grey' , lw = 2 )
174+ plt .xlim (250 , 400 )
175+ plt .tight_layout ()
176+ plt .savefig (savedir + 'Figure_15' , dpi = 300 )
177+ savedir = '/Users/Wladek/Dysk Google/kCSD_lcurve/validation/'
178+ for cut in range (15 ,16 ,1 ): plot_1D_pics (k , est_csd , est_pots , cut )
179+ # plt.close('all')
180+
181+ def plot_2D_pics (tp , cut , save = 0 ):
182+ plt .figure (figsize = (12 , 8 ))
183+ ax1 = plt .subplot (122 )
184+ set_axis (ax1 , - 0.05 , 1.05 , letter = 'B' )
185+ make_plot (ax1 , k .estm_x , k .estm_y , est_csd [:,:,tp ],
186+ title = 'Estimated CSD' , cmap = cm .bwr )
187+ # for i in range(383): plt.text(ele_pos_for_csd[i,0], ele_pos_for_csd[i,1]+8, str(i+1))
188+ plt .axvline (k .estm_x [cut ][0 ], ls = '--' , color = 'grey' , lw = 2 )
189+ ax2 = plt .subplot (121 )
190+ set_axis (ax2 , - 0.05 , 1.05 , letter = 'A' )
191+ make_plot (ax2 , k .estm_x , k .estm_y , est_pots [:,:,tp ],
192+ title = 'Estimated LFP' , cmap = cm .PRGn )
193+ # plt.suptitle(' $\lambda$ : '+str(k.lambd)+ ' R: '+ str(k.R))
194+ plt .tight_layout ()
195+ plt .savefig (savedir + 'Figure_14' , dpi = 300 )
196+
197+ plot_2D_pics (tp = tp , cut = 15 )
0 commit comments