1+ #!/usr/bin/env python
12import json
23import os
4+ import tempfile
35import unittest
46
57import matplotlib .pyplot as plt
68import numpy as np
79import pandas as pd
810
9- from syncropatch_export .trace import Trace as tr
11+ from syncropatch_export .trace import Trace
1012from syncropatch_export .voltage_protocols import VoltageProtocol
1113
1214
1315class TestTraceClass (unittest .TestCase ):
14- def setUp (self ):
15- filepath = os .path .join ('tests' , 'test_data' , '13112023_MW2_FF' ,
16- 'staircaseramp (2)_2kHz_15.01.07' )
17- json_file = "staircaseramp (2)_2kHz_15.01.07"
16+ """
17+ Tests both the Trace and VoltageProtocol classes.
18+ """
1819
19- self .output_dir = os .path .join ('test_output' , 'test_trace_class' )
2020
21- if not os .path .exists (self .output_dir ):
22- os .makedirs (self .output_dir ) # pragma: no-cover
23- self .test_trace = tr (filepath , json_file )
21+ def setUp (self ):
22+ f = 'staircaseramp (2)_2kHz_15.01.07'
23+ self .trace = Trace (
24+ os .path .join ('tests' , 'test_data' , '13112023_MW2_FF' , f ), f )
2425
2526 def test_protocol_descriptions (self ):
26- voltages = self .test_trace .get_voltage ()
27- times = self .test_trace .get_times ()
27+ voltages = self .trace .get_voltage ()
28+ times = self .trace .get_times ()
2829
29- protocol_from_json = self .test_trace .get_voltage_protocol ()
30+ protocol_from_json = self .trace .get_voltage_protocol ()
3031 holding_potential = protocol_from_json .get_holding_potential ()
3132 protocol_desc = VoltageProtocol .from_voltage_trace (voltages , times ,
3233 holding_potential )
@@ -43,18 +44,17 @@ def test_protocol_descriptions(self):
4344 self .assertLess (v_error , 1e-4 )
4445
4546 def test_protocol_export (self ):
46- protocol = self . test_trace . get_voltage_protocol ()
47- protocol . export_txt ( os . path . join ( self .output_dir , 'protocol.txt' ) )
48- json_protocol = self . test_trace . get_voltage_protocol_json ( )
49-
50- with open (os .path .join (self . output_dir , 'protocol.json' ), 'w' ) as fin :
51- json .dump (json_protocol , fin )
47+ with tempfile . TemporaryDirectory () as d :
48+ protocol = self .trace . get_voltage_protocol ( )
49+ protocol . export_txt ( os . path . join ( d , 'protocol.txt' ) )
50+ json_protocol = self . trace . get_voltage_protocol_json ()
51+ with open (os .path .join (d , 'protocol.json' ), 'w' ) as fin :
52+ json .dump (json_protocol , fin )
5253
5354 def test_protocol_timeseries (self ):
54- voltages = self .test_trace .get_voltage ()
55- times = self .test_trace .get_times ()
56-
57- voltage_protocol = self .test_trace .get_voltage_protocol ()
55+ voltages = self .trace .get_voltage ()
56+ times = self .trace .get_times ()
57+ voltage_protocol = self .trace .get_voltage_protocol ()
5858
5959 def voltage_func (t ):
6060 for tstart , tend , vstart , vend in voltage_protocol .get_all_sections ():
@@ -68,33 +68,46 @@ def voltage_func(t):
6868 for t , v in zip (times , voltages ):
6969 self .assertLess (voltage_func (t ) - v , 1e-3 )
7070
71+ def test_protocol_get_step_start_times (self ):
72+ a = list (self .trace .get_voltage_protocol ().get_step_start_times ())
73+ b = [0 , 250 , 300 , 696 , 896 , 1896 , 2396 , 3396 , 3896 , 4396 , 4896 , 5396 ,
74+ 5896 , 6396 , 6896 , 7396 , 7896 , 8396 , 8896 , 9396 , 9896 , 10396 ,
75+ 10896 , 11396 , 11896 , 12396 , 12896 , 13896 , 14396 , 14406 , 14502 ,
76+ 14892 ]
77+ self .assertEqual (a , b )
78+
79+ def test_protocol_get_ramps (self ):
80+ a = np .array (self .trace .get_voltage_protocol ().get_ramps ())
81+ b = np .array ([[300 , 696 , - 120 , - 80 ], [14406 , 14502 , - 70 , - 110 ]])
82+ self .assertEqual (a .shape , b .shape )
83+ self .assertTrue (np .all (a == b ))
84+
7185 def test_get_QC (self ):
72- tr = self .test_trace
73- QC_values = tr .get_onboard_QC_values ()
86+ QC_values = self .trace .get_onboard_QC_values ()
7487 self .assertGreater (len (QC_values ), 0 )
75- df = tr .get_onboard_QC_df ()
88+ df = self . trace .get_onboard_QC_df ()
7689
7790 self .assertGreater (df .shape [0 ], 0 )
7891 self .assertGreater (df .shape [1 ], 0 )
7992
8093 def test_get_traces (self ):
81- tr = self .test_trace
82- v = tr .get_voltage ()
83- ts = tr .get_times ()
84- all_traces = tr .get_all_traces (leakcorrect = True )
85- all_traces = tr .get_all_traces ()
94+ v = self .trace .get_voltage ()
95+ ts = self .trace .get_times ()
96+ all_traces = self .trace .get_all_traces (leakcorrect = True )
97+ all_traces = self .trace .get_all_traces ()
8698
8799 self .assertTrue (np .all (np .isfinite (v )))
88100 self .assertTrue (np .all (np .isfinite (ts )))
89101
90102 for well , trace in all_traces .items ():
91103 self .assertTrue (np .all (np .isfinite (trace )))
92104
105+ '''
93106 if self.output_dir:
94107 # plot test output
95108 fig, (ax1, ax2) = plt.subplots(2, 1)
96109 ax1.set_title("Example Sweeps")
97- some_sweeps = tr .get_trace_sweeps ([0 ])['A01' ]
110+ some_sweeps = self.trace .get_trace_sweeps([0])['A01']
98111
99112 ax1.plot(ts, np.transpose(some_sweeps), color='grey', alpha=0.5)
100113 ax1.set_ylabel('Current')
@@ -104,13 +117,13 @@ def test_get_traces(self):
104117 ax2.set_ylabel('Voltage')
105118 ax2.set_xlabel('Time')
106119 plt.tight_layout()
107- plt .savefig (os .path .join (self .output_dir ,
108- 'example_trace' ))
120+ plt.savefig(os.path.join(self.output_dir, 'example_trace'))
109121 plt.close(fig)
122+ '''
110123
111124 def test_qc_df (self ):
112- dfs = [self .test_trace .get_onboard_QC_df (sweeps = [0 ]),
113- self .test_trace .get_onboard_QC_df (sweeps = None )]
125+ dfs = [self .trace .get_onboard_QC_df (sweeps = [0 ]),
126+ self .trace .get_onboard_QC_df (sweeps = None )]
114127 for res in dfs :
115128 # Check res is a pd.DataFrame
116129 self .assertIsInstance (res , pd .DataFrame )
@@ -125,3 +138,6 @@ def test_qc_df(self):
125138 # Check restricting number of sweeps returns less data
126139 self .assertLess (dfs [0 ].shape [0 ], dfs [1 ].shape [0 ])
127140
141+
142+ if __name__ == '__main__' :
143+ unittest .main ()
0 commit comments