Skip to content

Commit 7061902

Browse files
committed
Made tests use temporary directory. Added missing unit tests.
1 parent ea24c57 commit 7061902

1 file changed

Lines changed: 51 additions & 35 deletions

File tree

tests/test_trace_class.py

100644100755
Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,33 @@
1+
#!/usr/bin/env python
12
import json
23
import os
4+
import tempfile
35
import unittest
46

57
import matplotlib.pyplot as plt
68
import numpy as np
79
import pandas as pd
810

9-
from syncropatch_export.trace import Trace as tr
11+
from syncropatch_export.trace import Trace
1012
from syncropatch_export.voltage_protocols import VoltageProtocol
1113

1214

1315
class TestTraceClass(unittest.TestCase):
14-
def setUp(self):
15-
filepath = os.path.join('tests', 'test_data', '13112023_MW2_FF',
16-
'staircaseramp (2)_2kHz_15.01.07')
17-
json_file = "staircaseramp (2)_2kHz_15.01.07"
16+
"""
17+
Tests both the Trace and VoltageProtocol classes.
18+
"""
1819

19-
self.output_dir = os.path.join('test_output', 'test_trace_class')
2020

21-
if not os.path.exists(self.output_dir):
22-
os.makedirs(self.output_dir) # pragma: no-cover
23-
self.test_trace = tr(filepath, json_file)
21+
def setUp(self):
22+
f = 'staircaseramp (2)_2kHz_15.01.07'
23+
self.trace = Trace(
24+
os.path.join('tests', 'test_data', '13112023_MW2_FF', f), f)
2425

2526
def test_protocol_descriptions(self):
26-
voltages = self.test_trace.get_voltage()
27-
times = self.test_trace.get_times()
27+
voltages = self.trace.get_voltage()
28+
times = self.trace.get_times()
2829

29-
protocol_from_json = self.test_trace.get_voltage_protocol()
30+
protocol_from_json = self.trace.get_voltage_protocol()
3031
holding_potential = protocol_from_json.get_holding_potential()
3132
protocol_desc = VoltageProtocol.from_voltage_trace(voltages, times,
3233
holding_potential)
@@ -43,18 +44,17 @@ def test_protocol_descriptions(self):
4344
self.assertLess(v_error, 1e-4)
4445

4546
def test_protocol_export(self):
46-
protocol = self.test_trace.get_voltage_protocol()
47-
protocol.export_txt(os.path.join(self.output_dir, 'protocol.txt'))
48-
json_protocol = self.test_trace.get_voltage_protocol_json()
49-
50-
with open(os.path.join(self.output_dir, 'protocol.json'), 'w') as fin:
51-
json.dump(json_protocol, fin)
47+
with tempfile.TemporaryDirectory() as d:
48+
protocol = self.trace.get_voltage_protocol()
49+
protocol.export_txt(os.path.join(d, 'protocol.txt'))
50+
json_protocol = self.trace.get_voltage_protocol_json()
51+
with open(os.path.join(d, 'protocol.json'), 'w') as fin:
52+
json.dump(json_protocol, fin)
5253

5354
def test_protocol_timeseries(self):
54-
voltages = self.test_trace.get_voltage()
55-
times = self.test_trace.get_times()
56-
57-
voltage_protocol = self.test_trace.get_voltage_protocol()
55+
voltages = self.trace.get_voltage()
56+
times = self.trace.get_times()
57+
voltage_protocol = self.trace.get_voltage_protocol()
5858

5959
def voltage_func(t):
6060
for tstart, tend, vstart, vend in voltage_protocol.get_all_sections():
@@ -68,33 +68,46 @@ def voltage_func(t):
6868
for t, v in zip(times, voltages):
6969
self.assertLess(voltage_func(t) - v, 1e-3)
7070

71+
def test_protocol_get_step_start_times(self):
72+
a = list(self.trace.get_voltage_protocol().get_step_start_times())
73+
b = [0, 250, 300, 696, 896, 1896, 2396, 3396, 3896, 4396, 4896, 5396,
74+
5896, 6396, 6896, 7396, 7896, 8396, 8896, 9396, 9896, 10396,
75+
10896, 11396, 11896, 12396, 12896, 13896, 14396, 14406, 14502,
76+
14892]
77+
self.assertEqual(a, b)
78+
79+
def test_protocol_get_ramps(self):
80+
a = np.array(self.trace.get_voltage_protocol().get_ramps())
81+
b = np.array([[300, 696, -120, -80], [14406, 14502, -70, -110]])
82+
self.assertEqual(a.shape, b.shape)
83+
self.assertTrue(np.all(a == b))
84+
7185
def test_get_QC(self):
72-
tr = self.test_trace
73-
QC_values = tr.get_onboard_QC_values()
86+
QC_values = self.trace.get_onboard_QC_values()
7487
self.assertGreater(len(QC_values), 0)
75-
df = tr.get_onboard_QC_df()
88+
df = self.trace.get_onboard_QC_df()
7689

7790
self.assertGreater(df.shape[0], 0)
7891
self.assertGreater(df.shape[1], 0)
7992

8093
def test_get_traces(self):
81-
tr = self.test_trace
82-
v = tr.get_voltage()
83-
ts = tr.get_times()
84-
all_traces = tr.get_all_traces(leakcorrect=True)
85-
all_traces = tr.get_all_traces()
94+
v = self.trace.get_voltage()
95+
ts = self.trace.get_times()
96+
all_traces = self.trace.get_all_traces(leakcorrect=True)
97+
all_traces = self.trace.get_all_traces()
8698

8799
self.assertTrue(np.all(np.isfinite(v)))
88100
self.assertTrue(np.all(np.isfinite(ts)))
89101

90102
for well, trace in all_traces.items():
91103
self.assertTrue(np.all(np.isfinite(trace)))
92104

105+
'''
93106
if self.output_dir:
94107
# plot test output
95108
fig, (ax1, ax2) = plt.subplots(2, 1)
96109
ax1.set_title("Example Sweeps")
97-
some_sweeps = tr.get_trace_sweeps([0])['A01']
110+
some_sweeps = self.trace.get_trace_sweeps([0])['A01']
98111
99112
ax1.plot(ts, np.transpose(some_sweeps), color='grey', alpha=0.5)
100113
ax1.set_ylabel('Current')
@@ -104,13 +117,13 @@ def test_get_traces(self):
104117
ax2.set_ylabel('Voltage')
105118
ax2.set_xlabel('Time')
106119
plt.tight_layout()
107-
plt.savefig(os.path.join(self.output_dir,
108-
'example_trace'))
120+
plt.savefig(os.path.join(self.output_dir, 'example_trace'))
109121
plt.close(fig)
122+
'''
110123

111124
def test_qc_df(self):
112-
dfs = [self.test_trace.get_onboard_QC_df(sweeps=[0]),
113-
self.test_trace.get_onboard_QC_df(sweeps=None)]
125+
dfs = [self.trace.get_onboard_QC_df(sweeps=[0]),
126+
self.trace.get_onboard_QC_df(sweeps=None)]
114127
for res in dfs:
115128
# Check res is a pd.DataFrame
116129
self.assertIsInstance(res, pd.DataFrame)
@@ -125,3 +138,6 @@ def test_qc_df(self):
125138
# Check restricting number of sweeps returns less data
126139
self.assertLess(dfs[0].shape[0], dfs[1].shape[0])
127140

141+
142+
if __name__ == '__main__':
143+
unittest.main()

0 commit comments

Comments
 (0)