|
| 1 | +import numpy as np |
| 2 | +from kcsd import KCSD2D |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import matplotlib.cm as cm |
| 5 | +from scipy.signal import filtfilt, butter |
| 6 | +from figure_properties import * |
| 7 | +plt.close('all') |
| 8 | +#%% |
| 9 | +def make_plot_spacetime(ax, xx, yy, zz, title='True CSD', cmap=cm.bwr_r, ymin=0, ymax=10000): |
| 10 | + im = ax.imshow(zz,extent=[0, zz.shape[1]/Fs*1000,-3500, 500], aspect='auto', |
| 11 | + vmax = 1*zz.max(),vmin = -1*zz.max(), cmap=cmap) |
| 12 | + ax.set_xlabel('Time (ms)') |
| 13 | + ax.set_ylabel('Y ($\mu$m)') |
| 14 | + if 'Pot' in title: ax.set_ylabel('Y ($\mu$m)') |
| 15 | + ax.set_title(title) |
| 16 | + if 'CSD' in title: |
| 17 | + plt.colorbar(im, orientation='vertical', format='%.2f', ticks = [-0.01,0,0.01]) |
| 18 | + else: |
| 19 | + plt.colorbar(im, orientation='vertical', format='%.1f', ticks = [-0.6,0,0.6]) |
| 20 | + # plt.gca().invert_yaxis() |
| 21 | + |
| 22 | +def make_plot(ax, xx, yy, zz, title='True CSD', cmap=cm.bwr): |
| 23 | + ax.set_aspect('auto') |
| 24 | + levels = np.linspace(zz.min(), -zz.min(), 61) |
| 25 | + im = ax.contourf(xx, -(yy-500), zz, levels=levels, cmap=cmap) |
| 26 | + ax.set_xlabel('X ($\mu$m)') |
| 27 | + ax.set_ylabel('Y ($\mu$m)') |
| 28 | + ax.set_title(title) |
| 29 | + if 'CSD' in title: |
| 30 | + plt.colorbar(im, orientation='vertical', format='%.2f', ticks=[-0.02,0,0.02]) |
| 31 | + else: plt.colorbar(im, orientation='vertical', format='%.1f', ticks=[-0.6,0,0.6]) |
| 32 | + plt.scatter(ele_pos[:, 0], |
| 33 | + -(ele_pos[:, 1]-500), |
| 34 | + s=0.8, color='black') |
| 35 | + # plt.gca().invert_yaxis() |
| 36 | + return ax |
| 37 | + |
| 38 | +def eles_to_ycoord(eles): |
| 39 | + y_coords = [] |
| 40 | + for ii in range(192): |
| 41 | + y_coords.append(ii*20) |
| 42 | + y_coords.append(ii*20) |
| 43 | + return y_coords[::-1] |
| 44 | + |
| 45 | +def eles_to_xcoord(eles): |
| 46 | + x_coords = [] |
| 47 | + for ele in eles: |
| 48 | + off = ele%4 |
| 49 | + if off == 1: x_coords.append(-24) |
| 50 | + elif off == 2: x_coords.append(8) |
| 51 | + elif off == 3: x_coords.append(-8) |
| 52 | + elif off==0: x_coords.append(24) |
| 53 | + return x_coords |
| 54 | + |
| 55 | +def eles_to_coords(eles): |
| 56 | + xs = eles_to_xcoord(eles) |
| 57 | + ys = eles_to_ycoord(eles) |
| 58 | + return np.array((xs, ys)).T |
| 59 | + |
| 60 | +def plot_1D_pics(k, est_csd, est_pots, tp, cut=9): |
| 61 | + plt.figure(figsize=(12, 8)) |
| 62 | + # plt.suptitle('plane: '+str(k.estm_x[cut,0])+' $\mu$m '+' $\lambda$ : '+str(k.lambd)+ |
| 63 | + # ' R: '+ str(k.R)) |
| 64 | + ax1 = plt.subplot(122) |
| 65 | + set_axis(ax1, -0.05, 1.05, letter= 'B') |
| 66 | + make_plot_spacetime(ax1, k.estm_x, k.estm_y, est_csd[cut,:,:], |
| 67 | + title='Estimated CSD', cmap='bwr') |
| 68 | + for lvl, name in zip([-500,-850,-2000], ['II/III', 'IV', 'V/VI']): |
| 69 | + plt.axhline(lvl, ls='--', color='grey') |
| 70 | + plt.text(340, lvl+20, name) |
| 71 | + plt.xlim(250, 400) |
| 72 | + ax2 = plt.subplot(121) |
| 73 | + set_axis(ax2, -0.05, 1.05, letter= 'A') |
| 74 | + make_plot_spacetime(ax2, k.estm_x, k.estm_y, est_pots[cut,:,:], |
| 75 | + title='Estimated LFP', cmap='PRGn') |
| 76 | + plt.axvline(tp/Fs*1000, ls='--', color ='grey', lw=2) |
| 77 | + plt.xlim(250, 400) |
| 78 | + plt.tight_layout() |
| 79 | + |
| 80 | +def plot_2D_pics(k, est_csd, est_pots, tp, cut, save=0): |
| 81 | + plt.figure(figsize=(12, 8)) |
| 82 | + ax1 = plt.subplot(122) |
| 83 | + set_axis(ax1, -0.05, 1.05, letter= 'B') |
| 84 | + make_plot(ax1, k.estm_x, k.estm_y, est_csd[:,:,tp], |
| 85 | + title='Estimated CSD', cmap='bwr') |
| 86 | + # for i in range(383): plt.text(ele_pos_for_csd[i,0], ele_pos_for_csd[i,1]+8, str(i+1)) |
| 87 | + plt.axvline(k.estm_x[cut][0], ls='--', color ='grey', lw=2) |
| 88 | + ax2 = plt.subplot(121) |
| 89 | + set_axis(ax2, -0.05, 1.05, letter= 'A') |
| 90 | + make_plot(ax2, k.estm_x, k.estm_y, est_pots[:,:,tp], |
| 91 | + title='Estimated LFP', cmap='PRGn') |
| 92 | + # plt.suptitle(' $\lambda$ : '+str(k.lambd)+ ' R: '+ str(k.R)) |
| 93 | + plt.tight_layout() |
| 94 | + |
| 95 | +def do_kcsd(ele_pos_for_csd, pots_for_csd, ele_limit): |
| 96 | + ele_position = ele_pos_for_csd[:ele_limit[1]][0::1] |
| 97 | + csd_pots = pots_for_csd[:ele_limit[1]][0::1] |
| 98 | + k = KCSD2D(ele_position, csd_pots, |
| 99 | + h=1, sigma=1, R_init=32, lambd=1e-9, |
| 100 | + xmin= -42, xmax=42, gdx=4, |
| 101 | + ymin=0, ymax=4000, gdy=4) |
| 102 | + # k.L_curve(Rs=np.linspace(16, 48, 3), lambdas=np.logspace(-9, -3, 20)) |
| 103 | + return k, k.values('CSD'), k.values('POT'), ele_position |
| 104 | +#%% |
| 105 | +if __name__ == '__main__': |
| 106 | + lowpass = 0.5 |
| 107 | + highpass = 300 |
| 108 | + Fs = 30000 |
| 109 | + resamp = 12 |
| 110 | + tp= 760 |
| 111 | + |
| 112 | + forfilt=np.load('npx_data.npy') |
| 113 | + |
| 114 | + [b,a] = butter(3, [lowpass/(Fs/2.0), highpass/(Fs/2.0)] ,btype = 'bandpass') |
| 115 | + filtData = filtfilt(b,a, forfilt) |
| 116 | + pots_resamp = filtData[:,::resamp] |
| 117 | + pots = pots_resamp[:, :] |
| 118 | + Fs=int(Fs/resamp) |
| 119 | + |
| 120 | + pots_for_csd = np.delete(pots, 191, axis=0) |
| 121 | + ele_pos_def = eles_to_coords(np.arange(384,0,-1)) |
| 122 | + ele_pos_for_csd = np.delete(ele_pos_def, 191, axis=0) |
| 123 | + |
| 124 | + k, est_csd, est_pots, ele_pos = do_kcsd(ele_pos_for_csd, pots_for_csd, ele_limit = (0,320)) |
| 125 | + |
| 126 | + plot_1D_pics(k, est_csd, est_pots, tp, 15) |
| 127 | + plot_2D_pics(k, est_csd, est_pots, tp=tp, cut=15) |
0 commit comments