@@ -28,11 +28,16 @@ def set_axis(ax, letter=None):
2828
2929
3030def make_subplot (ax , true_csd , est_csd , estm_x , title = None , ele_pos = None ,
31- xlabel = False , ylabel = False , letter = '' , t_max = None ):
31+ xlabel = False , ylabel = False , letter = '' , t_max = None ,
32+ est_csd_LC = None ):
3233
3334 x = np .linspace (0 , 1 , 100 )
3435 l1 = ax .plot (x , true_csd , label = 'True CSD' , lw = 2. )
35- l2 = ax .plot (estm_x , est_csd , label = 'kCSD' , lw = 2. )
36+ if est_csd_LC is not None :
37+ l2 = ax .plot (estm_x , est_csd , label = 'kCSD_CV' , lw = 2. )
38+ l3 = ax .plot (estm_x , est_csd_LC , label = 'kCSD_LC' , lw = 2. )
39+ else :
40+ l2 = ax .plot (estm_x , est_csd , label = 'kCSD' , lw = 2. )
3641 s1 = ax .scatter (ele_pos , np .zeros (len (ele_pos )), 13 , 'k' , label = 'Electrodes' )
3742 #ax.legend(fontsize=10)
3843 ax .set_xlim ([0 , 1 ])
@@ -68,7 +73,8 @@ def generate_figure(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH,
6873 fig = plt .figure (figsize = (15 , 12 ))
6974 widths = [1 , 1 , 1 ]
7075 heights = [1 , 1 , 1 ]
71- gs = gridspec .GridSpec (3 , 3 , height_ratios = heights , width_ratios = widths , hspace = 0.45 , wspace = 0.3 )
76+ gs = gridspec .GridSpec (3 , 3 , height_ratios = heights , width_ratios = widths ,
77+ hspace = 0.45 , wspace = 0.3 )
7278
7379 ax = fig .add_subplot (gs [0 , 0 ])
7480 xmin = 0
@@ -171,12 +177,177 @@ def generate_figure(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH,
171177 obj = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
172178 sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
173179 xmax = xmax , method = method , Rs = Rs , lambdas = lambdas )
174- ax = make_subplot (ax , true_csd , obj .values ('CSD' ), obj .estm_x , ele_pos = ele_pos ,
175- title = None , xlabel = True , ylabel = False , letter = 'I' )
180+ ax = make_subplot (ax , true_csd , obj .values ('CSD' ), obj .estm_x ,
181+ ele_pos = ele_pos , title = None , xlabel = True , ylabel = False ,
182+ letter = 'I' )
176183 handles , labels = ax .get_legend_handles_labels ()
177184 fig .legend (handles , labels , loc = 'lower center' , ncol = 3 , frameon = False )
178-
179- #plt.tight_layout()
185+
186+ fig .savefig (os .path .join (SAVE_PATH , 'targeted_basis_' + method +
187+ '_noise_' + str (noise ) + '.png' ), dpi = 300 )
188+ plt .show ()
189+
190+
191+ def generate_figure_CVLC (R , MU , N_SRC , TRUE_CSD_XLIMS , TOTAL_ELE , SAVE_PATH ,
192+ Rs = None , lambdas = None , noise = None ):
193+
194+ m_cv = 'cross-validation'
195+ m_lc = 'L-curve'
196+ method = 'CV_LC'
197+ ELE_LIMS = [0 , 1. ]
198+ csd_at , true_csd , ele_pos , pots , val = tb .simulate_data (tb .csd_profile ,
199+ TRUE_CSD_XLIMS , R ,
200+ MU , TOTAL_ELE ,
201+ ELE_LIMS ,
202+ noise = noise )
203+
204+ fig = plt .figure (figsize = (15 , 12 ))
205+ widths = [1 , 1 , 1 ]
206+ heights = [1 , 1 , 1 ]
207+ gs = gridspec .GridSpec (3 , 3 , height_ratios = heights , width_ratios = widths ,
208+ hspace = 0.45 , wspace = 0.3 )
209+
210+ ax = fig .add_subplot (gs [0 , 0 ])
211+ xmin = 0
212+ xmax = 1
213+ ext_x = 0
214+ obj_CV = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
215+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
216+ xmax = xmax , method = m_cv , Rs = Rs , lambdas = lambdas )
217+ obj_LC = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
218+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
219+ xmax = xmax , method = m_lc , Rs = Rs , lambdas = lambdas )
220+ make_subplot (ax , true_csd , obj_CV .values ('CSD' ), obj_CV .estm_x ,
221+ ele_pos = ele_pos , title = 'Basis limits = [0, 1]' , xlabel = False ,
222+ ylabel = True , letter = 'A' , est_csd_LC = obj_LC .values ('CSD' ))
223+
224+ ax = fig .add_subplot (gs [0 , 1 ])
225+ xmin = - 0.5
226+ xmax = 1
227+ ext_x = - 0.5
228+ obj_CV = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
229+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
230+ xmax = xmax , method = m_cv , Rs = Rs , lambdas = lambdas )
231+ obj_LC = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
232+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
233+ xmax = xmax , method = m_lc , Rs = Rs , lambdas = lambdas )
234+ make_subplot (ax , true_csd , obj_CV .values ('CSD' ), obj_CV .estm_x ,
235+ ele_pos = ele_pos , title = 'Basis limits = [0, 0.5]' ,
236+ xlabel = False , ylabel = False , letter = 'B' ,
237+ est_csd_LC = obj_LC .values ('CSD' ))
238+
239+ ax = fig .add_subplot (gs [0 , 2 ])
240+ xmin = 0
241+ xmax = 1.5
242+ ext_x = - 0.5
243+ obj_CV = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
244+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
245+ xmax = xmax , method = m_cv , Rs = Rs , lambdas = lambdas )
246+ obj_LC = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
247+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
248+ xmax = xmax , method = m_lc , Rs = Rs , lambdas = lambdas )
249+ make_subplot (ax , true_csd , obj_CV .values ('CSD' ), obj_CV .estm_x ,
250+ ele_pos = ele_pos , title = 'Basis limits = [0.5, 1]' ,
251+ xlabel = False , ylabel = False , letter = 'C' ,
252+ est_csd_LC = obj_LC .values ('CSD' ))
253+
254+ ELE_LIMS = [0 , 0.5 ]
255+ # TOTAL_ELE = 6
256+ csd_at , true_csd , ele_pos , pots , val = tb .simulate_data (tb .csd_profile ,
257+ TRUE_CSD_XLIMS , R ,
258+ MU , TOTAL_ELE ,
259+ ELE_LIMS )
260+ ax = fig .add_subplot (gs [1 , 0 ])
261+ xmin = 0
262+ xmax = 1
263+ ext_x = 0
264+ obj_CV = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
265+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
266+ xmax = xmax , method = m_cv , Rs = Rs , lambdas = lambdas )
267+ obj_LC = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
268+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
269+ xmax = xmax , method = m_lc , Rs = Rs , lambdas = lambdas )
270+ make_subplot (ax , true_csd , obj_CV .values ('CSD' ), obj_CV .estm_x ,
271+ ele_pos = ele_pos , title = None , xlabel = False , ylabel = True ,
272+ letter = 'D' , est_csd_LC = obj_LC .values ('CSD' ))
273+
274+ ax = fig .add_subplot (gs [1 , 1 ])
275+ xmin = - 0.5
276+ xmax = 1
277+ ext_x = - 0.5
278+ obj_CV = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
279+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
280+ xmax = xmax , method = m_cv , Rs = Rs , lambdas = lambdas )
281+ obj_LC = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
282+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
283+ xmax = xmax , method = m_lc , Rs = Rs , lambdas = lambdas )
284+ make_subplot (ax , true_csd , obj_CV .values ('CSD' ), obj_CV .estm_x ,
285+ ele_pos = ele_pos , title = None , xlabel = False , ylabel = False ,
286+ letter = 'E' , est_csd_LC = obj_LC .values ('CSD' ))
287+
288+ ax = fig .add_subplot (gs [1 , 2 ])
289+ xmin = 0
290+ xmax = 1.5
291+ ext_x = - 0.5
292+ obj_CV = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
293+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
294+ xmax = xmax , method = m_cv , Rs = Rs , lambdas = lambdas )
295+ obj_LC = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
296+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
297+ xmax = xmax , method = m_lc , Rs = Rs , lambdas = lambdas )
298+ make_subplot (ax , true_csd , obj_CV .values ('CSD' ), obj_CV .estm_x ,
299+ ele_pos = ele_pos , title = None , xlabel = False , ylabel = False ,
300+ letter = 'F' , est_csd_LC = obj_LC .values ('CSD' ))
301+
302+ ELE_LIMS = [0.5 , 1. ]
303+ csd_at , true_csd , ele_pos , pots , val = tb .simulate_data (tb .csd_profile ,
304+ TRUE_CSD_XLIMS , R ,
305+ MU , TOTAL_ELE ,
306+ ELE_LIMS )
307+ ax = fig .add_subplot (gs [2 , 0 ])
308+ xmin = 0
309+ xmax = 1
310+ ext_x = 0
311+ obj_CV = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
312+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
313+ xmax = xmax , method = m_cv , Rs = Rs , lambdas = lambdas )
314+ obj_LC = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
315+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
316+ xmax = xmax , method = m_lc , Rs = Rs , lambdas = lambdas )
317+ make_subplot (ax , true_csd , obj_CV .values ('CSD' ), obj_CV .estm_x ,
318+ ele_pos = ele_pos , title = None , xlabel = True , ylabel = True ,
319+ letter = 'G' , est_csd_LC = obj_LC .values ('CSD' ))
320+
321+ ax = fig .add_subplot (gs [2 , 1 ])
322+ xmin = - 0.5
323+ xmax = 1
324+ ext_x = - 0.5
325+ obj_CV = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
326+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
327+ xmax = xmax , method = m_cv , Rs = Rs , lambdas = lambdas )
328+ obj_LC = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
329+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
330+ xmax = xmax , method = m_lc , Rs = Rs , lambdas = lambdas )
331+ make_subplot (ax , true_csd , obj_CV .values ('CSD' ), obj_CV .estm_x ,
332+ ele_pos = ele_pos , title = None , xlabel = True , ylabel = False ,
333+ letter = 'H' , est_csd_LC = obj_LC .values ('CSD' ))
334+
335+ ax = fig .add_subplot (gs [2 , 2 ])
336+ xmin = 0
337+ xmax = 1.5
338+ ext_x = - 0.5
339+ obj_CV = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
340+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
341+ xmax = xmax , method = m_cv , Rs = Rs , lambdas = lambdas )
342+ obj_LC = tb .modified_bases (val , pots , ele_pos , N_SRC , h = 0.25 ,
343+ sigma = 0.3 , gdx = 0.01 , ext_x = ext_x , xmin = xmin ,
344+ xmax = xmax , method = m_lc , Rs = Rs , lambdas = lambdas )
345+ ax = make_subplot (ax , true_csd , obj_CV .values ('CSD' ), obj_CV .estm_x ,
346+ ele_pos = ele_pos , title = None , xlabel = True , ylabel = False ,
347+ letter = 'I' , est_csd_LC = obj_LC .values ('CSD' ))
348+ handles , labels = ax .get_legend_handles_labels ()
349+ fig .legend (handles , labels , loc = 'lower center' , ncol = 4 , frameon = False )
350+
180351 fig .savefig (os .path .join (SAVE_PATH , 'targeted_basis_' + method +
181352 '_noise_' + str (noise ) + '.png' ), dpi = 300 )
182353 plt .show ()
@@ -201,5 +372,7 @@ def generate_figure(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH,
201372 Rs = np .arange (0.1 , 0.4 , 0.05 )
202373# Rs = np.array([0.2])
203374 lambdas = np .zeros (1 )
204- generate_figure (R , MU , N_SRC , TRUE_CSD_XLIMS , TOTAL_ELE , SAVE_PATH , method ,
205- Rs , lambdas = None , noise = 10 )
375+ # generate_figure(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH,
376+ # method, Rs, lambdas=None, noise=10)
377+ generate_figure_CVLC (R , MU , N_SRC , TRUE_CSD_XLIMS , TOTAL_ELE , SAVE_PATH ,
378+ Rs = Rs , lambdas = None , noise = 10 )
0 commit comments