Skip to content

Commit 08e9a3f

Browse files
Add FFT option to CSL table.
1 parent 21869b5 commit 08e9a3f

3 files changed

Lines changed: 104 additions & 21 deletions

File tree

djimaging/tables/response/csl/csl.py

Lines changed: 99 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ class CslMetrics(response.CslMetricsTemplate):
1515
presentation_table = Presentation
1616
traces_table = Traces
1717
18+
_stim_frequency = 2.0 # Hz
1819
_kind = 'naka_rushton'
20+
_metric_kind = 'auc'
1921
_w_zero_fit = True
2022
_dt_order = 3
2123
_dt_window = 60
@@ -26,6 +28,8 @@ class CslMetrics(response.CslMetricsTemplate):
2628
_dt_baseline_b = 2.9
2729
_dt_window_plateau = 1.0
2830
31+
_n_fit_repeats = 3
32+
2933
# Populate with plots:
3034
CslMetrics().populate(make_kwargs=dict(plot=True))
3135
"""
@@ -48,7 +52,9 @@ class CslMetricsTemplate(dj.Computed):
4852
database = ""
4953
_stim_restriction = dict(stim_name='csl')
5054

55+
_stim_frequency = 2.0 # Hz
5156
_fit_kind = 'naka_rushton'
57+
_metric_kind = 'auc'
5258
_w_zero_fit = 1
5359
_dt_order = 3
5460
_dt_window = 60
@@ -59,9 +65,11 @@ class CslMetricsTemplate(dj.Computed):
5965
_dt_baseline_b = 2.9
6066
_dt_window_plateau = 1.0
6167

68+
_n_fit_repeats = 3
69+
6270
@property
6371
def definition(self):
64-
definition = '''
72+
definition = f'''
6573
# Normalized contrast step light response
6674
-> self.traces_table
6775
---
@@ -75,8 +83,12 @@ def definition(self):
7583
contrast_sensitivity: float # Relating step responses to contrast responses
7684
tonic_release_index: float # Tonic release index as in Franke et al 2017, but for last contrast step
7785
plateau_index: float # Plateau index (a - b) / (a + b), similar to Franke et al 2017
78-
contrast_aucs: blob # Area under the curve for each contrast, incl. baseline at i=0 if _w_zero_fit=1
79-
fit_half_amp_y = NULL : float # y at half amplitude of fit
86+
contrast_{self._metric_kind}s: blob # Metric per contrast (e.g., auc or fft_f1), incl. baseline at i=0 if _w_zero_fit=1
87+
'''
88+
# Add an optional phases field when using FFT metric
89+
if getattr(self, '_metric_kind', 'auc') == 'fft_f1':
90+
definition += ' contrast_fft_f1_phases: blob # Phase of F1 (degrees) per contrast; baseline at i=0 if present\n'
91+
definition += f''' fit_half_amp_y = NULL : float # y at half amplitude of fit
8092
fit_half_amp_x = NULL : float # x at half amplitude of fit
8193
fit_half_amp_slope = NULL : float # Slope at half amplitude of fit
8294
droppedlastrep_flag: tinyint unsigned # Was the last repetition incomplete and therefore dropped?
@@ -123,7 +135,8 @@ def _make_fetch_and_compute(self, key, plot=False):
123135

124136
d_csl = analyse_csl_response(
125137
trace, trace_t0, trace_dt, triggertimes, ntrigger_rep, fs_resample, plot=plot,
126-
w_zero_fit=self._w_zero_fit, fit_kind=self._fit_kind,
138+
w_zero_fit=self._w_zero_fit, fit_kind=self._fit_kind, metric_kind=self._metric_kind,
139+
stim_frequency=self._stim_frequency, n_fit_repeats=self._n_fit_repeats,
127140
dt_order=self._dt_order, dt_window=self._dt_window, peak_q=self._peak_q,
128141
contrast_levels=self._contrast_levels, dt_breaks=self._dt_breaks,
129142
dt_baseline_a=self._dt_baseline_a, dt_baseline_b=self._dt_baseline_b,
@@ -135,14 +148,19 @@ def _make_fetch_and_compute(self, key, plot=False):
135148
bc_snippets = d_csl['bc_snippets'][::n_lines, :]
136149
avg = d_csl['avg'][::n_lines]
137150

151+
metric_field = f"contrast_{self._metric_kind}s"
138152
entry = dict(
139153
**key, average=avg, snippets=bc_snippets, fs=fs, fs_metrics=fs_resample,
140154
qidx_full=d_csl['qidx_full'], qidx_contrast=d_csl['qidx_contrast'],
141155
on_off_index=d_csl['on_off_index'], contrast_sensitivity=d_csl['contrast_sensitivity'],
142156
tonic_release_index=d_csl['tonic_release_index'], plateau_index=d_csl['plateau_index'],
143157
fit_half_amp_y=d_csl['fit'][0], fit_half_amp_x=d_csl['fit'][1], fit_half_amp_slope=d_csl['fit'][2],
144-
contrast_aucs=d_csl['contrast_aucs'], droppedlastrep_flag=d_csl['droppedlastrep_flag']
158+
droppedlastrep_flag=d_csl['droppedlastrep_flag']
145159
)
160+
entry[metric_field] = d_csl['contrast_metrics']
161+
# Optionally store phases when using fft_f1 metric
162+
if getattr(self, '_metric_kind', 'auc') == 'fft_f1' and 'contrast_phases_deg' in d_csl:
163+
entry['contrast_fft_f1_phases'] = d_csl['contrast_phases_deg']
146164

147165
return entry
148166

@@ -175,7 +193,8 @@ def plot1(self, key=None):
175193

176194
def analyse_csl_response(
177195
trace, trace_t0, trace_dt, triggertimes, ntrigger_rep, fs_resample, plot=False,
178-
w_zero_fit=1, fit_kind='naka_rushton',
196+
w_zero_fit=1, fit_kind='naka_rushton', metric_kind='auc',
197+
stim_frequency=6.0, n_fit_repeats=3,
179198
dt_order=3, dt_window=60, peak_q=98,
180199
contrast_levels=(0.10, 0.20, 0.40, 0.60, 0.80, 1.00),
181200
dt_breaks=3., dt_baseline_a=1.6, dt_baseline_b=2.9, dt_window_plateau=1.0):
@@ -269,19 +288,42 @@ def analyse_csl_response(
269288
amp_100 = np.maximum(0, np.percentile(avg[idxs_cs_a[-1]:idxs_cs_b[-1]], q=peak_q))
270289
contrast_sensitivity = amp_100 / np.maximum(amp_100 + amp_step, 1e-9)
271290

272-
# Fit curve to responses
273-
contrast_aucs = np.array([np.mean(np.abs(avg[ia:ib])) for ia, ib in zip(idxs_cs_a, idxs_cs_b)])
274-
if w_zero_fit:
275-
zero_value = np.mean(np.abs(avg[idxs_baseline]))
276-
contrast_aucs = np.append(zero_value, contrast_aucs)
277-
contrast_levels = np.append(0, contrast_levels)
291+
# Compute contrast metrics
292+
if metric_kind == 'auc':
293+
contrast_metrics = np.array([np.mean(np.abs(avg[ia:ib])) for ia, ib in zip(idxs_cs_a, idxs_cs_b)])
294+
if w_zero_fit:
295+
zero_value = np.mean(np.abs(avg[idxs_baseline]))
296+
contrast_metrics = np.append(zero_value, contrast_metrics)
297+
contrast_levels = np.append(0, contrast_levels)
298+
elif metric_kind == 'fft_f1':
299+
# Contrast windows
300+
amps = []
301+
phases = []
302+
for ia, ib in zip(idxs_cs_a, idxs_cs_b):
303+
a, p = f1_amp_phase(
304+
avg[ia:ib], stim_frequency=stim_frequency, trace_frequency=fs_resample)
305+
amps.append(a)
306+
phases.append(p)
307+
contrast_metrics = np.array(amps, dtype=float)
308+
contrast_phases_deg = np.array(phases, dtype=float)
309+
if w_zero_fit:
310+
zero_a, zero_p = f1_amp_phase(
311+
avg[idxs_baseline], stim_frequency=stim_frequency, trace_frequency=fs_resample)
312+
contrast_metrics = np.append(zero_a, contrast_metrics)
313+
contrast_phases_deg = np.append(zero_p, contrast_phases_deg)
314+
contrast_levels = np.append(0, contrast_levels)
315+
else:
316+
raise ValueError(f"Unknown metric_kind={metric_kind}")
278317

318+
# Fit curve to metric-vs-contrast
279319
if fit_kind == 'naka_rushton':
280320
fit = fit_naka_rushton(
281-
y_data=normalize_zero_one(contrast_aucs), x_data=contrast_levels, ax=None if not plot else axs['H'])
321+
y_data=normalize_zero_one(contrast_metrics), x_data=contrast_levels, n=n_fit_repeats,
322+
ax=None if not plot else axs['H'])
282323
elif fit_kind == 'sigmoid':
283324
fit = fit_sigmoid(
284-
y_data=normalize_zero_one(contrast_aucs), x_data=contrast_levels, ax=None if not plot else axs['H'])
325+
y_data=normalize_zero_one(contrast_metrics), x_data=contrast_levels, n=n_fit_repeats,
326+
ax=None if not plot else axs['H'])
285327
else:
286328
raise ValueError(f'Unknown fit_kind={fit_kind}')
287329

@@ -358,8 +400,9 @@ def analyse_csl_response(
358400

359401
ax.axhline(0, ls=':', color='gray')
360402
ax2 = ax.twinx()
361-
ax2.bar(x=rel_time[(idxs_cs_a + idxs_cs_b) // 2], height=contrast_aucs[-len(idxs_cs_a):], color='k', alpha=0.5)
362-
vabsmax = np.max(np.abs(contrast_aucs[-len(idxs_cs_a):]))
403+
ax2.bar(x=rel_time[(idxs_cs_a + idxs_cs_b) // 2], height=contrast_metrics[-len(idxs_cs_a):], color='k',
404+
alpha=0.5)
405+
vabsmax = np.max(np.abs(contrast_metrics[-len(idxs_cs_a):]))
363406
ax2.set_ylim(-vabsmax * 1.1, vabsmax * 1.1)
364407
ax2.axhline(0, ls='--', color='gray')
365408

@@ -372,7 +415,46 @@ def analyse_csl_response(
372415
bc_snippets=bc_snippets, avg=avg, qidx_full=qidx_full, qidx_contrast=qidx_contrast,
373416
tonic_release_index=tonic_release_index,
374417
plateau_index=plateau_index, on_off_index=on_off_index, contrast_sensitivity=contrast_sensitivity,
375-
contrast_aucs=contrast_aucs, fit=fit, droppedlastrep_flag=droppedlastrep_flag
418+
contrast_metrics=contrast_metrics, fit=fit, droppedlastrep_flag=droppedlastrep_flag
376419
)
420+
if metric_kind == 'fft_f1':
421+
# Attach phases if they were computed
422+
try:
423+
result_dict['contrast_phases_deg'] = contrast_phases_deg
424+
except NameError:
425+
pass
377426

378427
return result_dict
428+
429+
430+
def f1_amp_phase(seg, stim_frequency, trace_frequency):
431+
""" First harmonic at the stimulation frequency within each window."""
432+
x = np.asarray(seg)
433+
N = x.size
434+
if N < 3:
435+
return 0.0, 0.0
436+
# make N odd to avoid Nyquist ambiguity and mimic Igor behavior
437+
if (N % 2) == 0:
438+
x = x[:-1]
439+
N = x.size
440+
# Choose FFT bin closest to stim_frequency
441+
k = int(np.round(stim_frequency * N / trace_frequency))
442+
443+
# rfft length is N//2 + 1; clamp k into valid range
444+
kmax = N // 2
445+
k = int(np.clip(k, 0, kmax))
446+
X = np.fft.rfft(x)
447+
448+
# Handle edge cases
449+
if k == 0: # DC component
450+
return 0.0, 0.0
451+
452+
# Single-sided amplitude scaling: DC=|X0|/N; k>0: 2*|Xk|/N
453+
# At Nyquist (k == kmax), no 2x factor needed
454+
if k == kmax:
455+
amp = np.abs(X[k]) / N
456+
else:
457+
amp = (2.0 * np.abs(X[k])) / N
458+
459+
phase_deg = float(np.degrees(np.angle(X[k])))
460+
return float(amp), phase_deg

djimaging/tables/response/csl/naka_rushton_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def fit_naka_rushton_repeated(x_data: np.ndarray, y_data: np.ndarray, n: int = 3
106106
return popt_best
107107

108108

109-
def fit_naka_rushton(x_data: np.ndarray, y_data: np.ndarray, ax=None):
109+
def fit_naka_rushton(x_data: np.ndarray, y_data: np.ndarray, n: int = 3, ax=None):
110110
"""
111111
Fit Naka-Rushton function to data and estimate half-maximum metrics.
112112
This function will use a linear fit if the Naka-Rushton fit is not good.
@@ -117,14 +117,15 @@ def fit_naka_rushton(x_data: np.ndarray, y_data: np.ndarray, ax=None):
117117
Args:
118118
x_data (array-like): Stimulus data.
119119
y_data (array-like): Response data.
120+
n (int): Number of retries.
120121
ax (matplotlib.axes.Axes, optional): Axis for plotting. Defaults to None.
121122
122123
Returns:
123124
tuple: Half-maximum response, corresponding stimulus, and slope at that point.
124125
"""
125126
np.random.seed(42)
126127

127-
popt = fit_naka_rushton_repeated(x_data, y_data, n=3)
128+
popt = fit_naka_rushton_repeated(x_data, y_data, n=n)
128129
R0_fit, Rm_fit, S50_fit, n_fit = popt
129130

130131
# Compare MSE of linear fit

djimaging/tables/response/csl/sigmoid_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def fit_sigmoid_with_retry(x_data, y_data, max_tries=3):
2828
return p00
2929

3030

31-
def fit_sigmoid(y_data, x_data=None, sign=1, ax=None):
31+
def fit_sigmoid(y_data, x_data=None, sign=1, n: int = 3, ax=None):
3232
"""Fit sigmoid function to data, and estimate half amp"""
3333
np.random.seed(42)
3434

@@ -38,7 +38,7 @@ def fit_sigmoid(y_data, x_data=None, sign=1, ax=None):
3838
y_data = y_data * sign
3939

4040
# Fit sigmoid curve to the data
41-
popt = fit_sigmoid_with_retry(x_data, y_data, max_tries=3)
41+
popt = fit_sigmoid_with_retry(x_data, y_data, max_tries=n)
4242
x0_fit, k_fit, a_fit = popt
4343

4444
if (x0_fit > x_data.max()) | (x0_fit < x_data.min()):

0 commit comments

Comments
 (0)