Skip to content

Commit 0c5d875

Browse files
author
mdhumphries
committed
finished unit testing
1 parent 1c490f7 commit 0c5d875

4 files changed

Lines changed: 68 additions & 39 deletions

File tree

UnitTestData.csv

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
Choice,CuePosition,Reward,go_left,go_right,go_cued,go_uncued,sticky,alternate,win_stay_spatial,win_stay_cued,lose_shift_spatial,lose_shift_cued
2+
left,left,yes,success,failure,success,failure,null,null,null,null,null,null
3+
left,right,yes,success,failure,failure,success,success,failure,success,failure,null,null
4+
left,left,no,success,failure,success,failure,success,failure,success,failure,null,null
5+
left,right,no,success,failure,failure,success,success,failure,null,null,failure,success
6+
right,left,yes,failure,success,failure,success,failure,success,null,null,success,failure
7+
right,right,yes,failure,success,success,failure,success,failure,success,failure,null,null
8+
right,left,no,failure,success,failure,success,success,failure,success,failure,null,null
9+
right,right,no,failure,success,success,failure,success,failure,null,null,failure,success
10+
right,right,yes,failure,success,success,failure,success,failure,null,null,failure,failure
11+
left,right,yes,success,failure,failure,success,failure,success,failure,failure,null,null
12+
right,left,yes,failure,success,failure,success,failure,success,failure,success,null,null
13+
right,left,no,failure,success,failure,success,success,failure,success,success,null,null
14+
left,left,no,success,failure,success,failure,failure,success,null,null,success,success
41 Bytes
Binary file not shown.

strategymodels.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,13 @@ def alternate(rows):
8484
# checks if the subject made a different choice on this trial from the previous obe
8585
nTrials = len(rows)
8686
# "at" selects the value at the row/column location in the dataframe
87-
if nTrials > 1 and rows.at[nTrials-1,'Choice'] != rows.at[nTrials-2,'Choice']: # check the current trial's choice
87+
if nTrials == 1:
88+
trial_type = "null" # undefined on first trial
89+
elif nTrials > 1 and rows.at[nTrials-1,'Choice'] != rows.at[nTrials-2,'Choice']: # check the current trial's choice
8890
trial_type = "success"
89-
else:
91+
else:
9092
trial_type = "failure"
93+
9194
return trial_type
9295

9396

@@ -134,7 +137,9 @@ def sticky(rows):
134137
# checks if the subject made the same choice on this trial as the previous one
135138
nTrials = len(rows)
136139
# "at" selects the value at the row/column location in the dataframe
137-
if nTrials > 1 and rows.at[nTrials-1,'Choice'] == rows.at[nTrials-2,'Choice']: # check the current trial's choice
140+
if nTrials == 1:
141+
trial_type = "null" # undefined on first trial
142+
elif nTrials > 1 and rows.at[nTrials-1,'Choice'] == rows.at[nTrials-2,'Choice']: # check the current trial's choice
138143
trial_type = "success"
139144
else:
140145
trial_type = "failure"
@@ -143,7 +148,7 @@ def sticky(rows):
143148

144149

145150
def win_stay_cued(rows):
146-
# checks if the subject made the same cued choice on this trial after being rewarded on the previous one
151+
# checks if the subject made the same cue-driven choice on this trial after being rewarded on the previous one
147152
trial_type = "null" # default is that this trial does not meet criterion for win-stay
148153

149154
nTrials = len(rows)

unit_test_strategy_models.py

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,59 +11,69 @@
1111
import pandas as pd
1212
import strategymodels
1313

14-
# make test dataframe
15-
test_data_for_rules = {'Choice' : ['left','left','left','left','right','right','right','right'],
16-
'CuePosition': ['left','right','left','right','left','right','left','right'],
17-
'Reward': ['yes','yes','no','no','yes','yes','no','no']}
18-
TestData = pd.DataFrame(test_data_for_rules)
14+
# load test dataframe
15+
# set na_filter = False to load "null" as strings
16+
UnitTestData = pd.read_csv('UnitTestData.csv',na_filter=False)
1917

20-
# target results for rules
21-
22-
go_left_target = ['success','success','success','success','failure','failure','failure','failure']
23-
go_right_target = ['failure','failure','failure','failure','success','success','success','success']
24-
go_cued_target = ['success','failure','success','failure','failure','success','failure','success']
25-
go_uncued_target = ['failure','success','failure','success','success','failure','success','failure']
18+
TestResults = pd.DataFrame();
2619

2720
#%% run rule strategy models
2821
trial_type_left = []; trial_type_right = []; trial_type_cued = []; trial_type_uncued = [];
29-
for trial in range(len(TestData)):
30-
rows_of_data = TestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
22+
for trial in range(len(UnitTestData)):
23+
rows_of_data = UnitTestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
3124
trial_type_left.append(strategymodels.go_left(rows_of_data)) # test whether go-left was used
3225
trial_type_right.append(strategymodels.go_right(rows_of_data))
3326
trial_type_cued.append(strategymodels.go_cued(rows_of_data))
3427
trial_type_uncued.append(strategymodels.go_uncued(rows_of_data))
3528

3629
#%% did they pass the tests?
37-
print('Go left passed?')
38-
print(trial_type_left == go_left_target)
30+
TestResults['go_left_result'] = UnitTestData['go_left'].eq(trial_type_left)
31+
TestResults['go_right_result'] = UnitTestData['go_right'].eq(trial_type_right)
32+
TestResults['go_cued_result'] = UnitTestData['go_cued'].eq(trial_type_cued)
33+
TestResults['go_uncued_result'] = UnitTestData['go_uncued'].eq(trial_type_uncued)
3934

40-
print('Go right passed?')
41-
print(trial_type_right == go_right_target)
35+
# print to screen
36+
print('Go left passed? ' + str(TestResults['go_left_result'].all()))
4237

43-
print('Go cued passed?')
44-
print(trial_type_cued == go_cued_target)
45-
46-
print('Go uncued passed?')
47-
print(trial_type_uncued == go_uncued_target)
38+
print('Go right passed? ' + str(TestResults['go_right_result'].all()))
39+
40+
print('Go cued passed? ' + str((TestResults['go_cued_result'].all())))
41+
42+
print('Go uncued passed? ' + str(TestResults['go_uncued_result'].all()))
4843

4944
#%%
5045
# test Explore strategy models
5146
#
5247

53-
# extend test data to include all win-stay and lose-shift test cases
54-
test_data_for_rules['Choice'].append('right','left','right','right','left')
55-
test_data_for_rules['CuePosition'].append('right','right','left','left','left')
56-
test_data_for_rules['Reward'].append('yes','yes','yes','no','no')
57-
58-
TestData = pd.DataFrame(test_data_for_rules) # overwrite dataframe
48+
trial_type_alternate = []; trial_type_sticky = []; trial_type_win_stay_spatial= [];
49+
trial_type_win_stay_cued = []; trial_type_lose_shift_spatial= []; trial_type_lose_shift_cued= [];
50+
51+
for trial in range(len(UnitTestData)):
52+
rows_of_data = UnitTestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
53+
trial_type_alternate.append(strategymodels.alternate(rows_of_data))
54+
trial_type_sticky.append(strategymodels.sticky(rows_of_data))
55+
trial_type_win_stay_spatial.append(strategymodels.win_stay_spatial(rows_of_data))
56+
trial_type_win_stay_cued.append(strategymodels.win_stay_cued(rows_of_data))
57+
trial_type_lose_shift_spatial.append(strategymodels.lose_shift_spatial(rows_of_data))
58+
trial_type_lose_shift_cued.append(strategymodels.lose_shift_cued(rows_of_data))
59+
60+
# did they pass test?
61+
TestResults['alternate_result'] = UnitTestData['alternate'].eq(trial_type_alternate)
62+
TestResults['sticky_result'] = UnitTestData['sticky'].eq(trial_type_sticky)
63+
TestResults['win_stay_spatial_result'] = UnitTestData['win_stay_spatial'].eq(trial_type_win_stay_spatial)
64+
TestResults['win_stay_cued_result'] = UnitTestData['win_stay_cued'].eq(trial_type_win_stay_cued)
65+
TestResults['lose_shift_spatial_result'] = UnitTestData['lose_shift_spatial'].eq(trial_type_lose_shift_spatial)
66+
TestResults['lose_shift_cued_result'] = UnitTestData['lose_shift_cued'].eq(trial_type_lose_shift_cued)
5967

60-
# define target results for exploration rules (null, success, failure)
68+
# print to screen
69+
print('Alternate passed? ' + str(TestResults['alternate_result'].all()))
6170

71+
print('Sticky passed? ' + str(TestResults['sticky_result'].all()))
6272

63-
# run explore strategy models
64-
trial_type_alternate = []; trial_type_sticky = []; trial_type_win_stay_spatial= [];
65-
trial_type_win_stay_cue = []; trial_type_lose_shift_spatial= []; trial_type_lose_shift_cue = [];
73+
print('Win-stay-spatial passed? ' + str(TestResults['win_stay_spatial_result'].all()))
74+
75+
print('Win-stay-cued passed? ' + str(TestResults['win_stay_cued_result'].all()))
6676

67-
for trial in range(len(TestData)):
68-
rows_of_data = TestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
69-
77+
print('Lose-shift-spatial passed? ' + str(TestResults['lose_shift_spatial_result'].all()))
78+
79+
print('Lose-shift-cued passed? ' + str(TestResults['lose_shift_cued_result'].all()))

0 commit comments

Comments
 (0)