1111import pandas as pd
1212import strategymodels
1313
14- # make test dataframe
15- test_data_for_rules = {'Choice' : ['left' ,'left' ,'left' ,'left' ,'right' ,'right' ,'right' ,'right' ],
16- 'CuePosition' : ['left' ,'right' ,'left' ,'right' ,'left' ,'right' ,'left' ,'right' ],
17- 'Reward' : ['yes' ,'yes' ,'no' ,'no' ,'yes' ,'yes' ,'no' ,'no' ]}
18- TestData = pd .DataFrame (test_data_for_rules )
14+ # load test dataframe
15+ # set na_filter = False to load "null" as strings
16+ UnitTestData = pd .read_csv ('UnitTestData.csv' ,na_filter = False )
1917
20- # target results for rules
21-
22- go_left_target = ['success' ,'success' ,'success' ,'success' ,'failure' ,'failure' ,'failure' ,'failure' ]
23- go_right_target = ['failure' ,'failure' ,'failure' ,'failure' ,'success' ,'success' ,'success' ,'success' ]
24- go_cued_target = ['success' ,'failure' ,'success' ,'failure' ,'failure' ,'success' ,'failure' ,'success' ]
25- go_uncued_target = ['failure' ,'success' ,'failure' ,'success' ,'success' ,'failure' ,'success' ,'failure' ]
18+ TestResults = pd .DataFrame ();
2619
2720#%% run rule strategy models
2821trial_type_left = []; trial_type_right = []; trial_type_cued = []; trial_type_uncued = [];
29- for trial in range (len (TestData )):
30- rows_of_data = TestData .iloc [0 :trial + 1 ] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
22+ for trial in range (len (UnitTestData )):
23+ rows_of_data = UnitTestData .iloc [0 :trial + 1 ] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
3124 trial_type_left .append (strategymodels .go_left (rows_of_data )) # test whether go-left was used
3225 trial_type_right .append (strategymodels .go_right (rows_of_data ))
3326 trial_type_cued .append (strategymodels .go_cued (rows_of_data ))
3427 trial_type_uncued .append (strategymodels .go_uncued (rows_of_data ))
3528
3629#%% did they pass the tests?
37- print ('Go left passed?' )
38- print (trial_type_left == go_left_target )
30+ TestResults ['go_left_result' ] = UnitTestData ['go_left' ].eq (trial_type_left )
31+ TestResults ['go_right_result' ] = UnitTestData ['go_right' ].eq (trial_type_right )
32+ TestResults ['go_cued_result' ] = UnitTestData ['go_cued' ].eq (trial_type_cued )
33+ TestResults ['go_uncued_result' ] = UnitTestData ['go_uncued' ].eq (trial_type_uncued )
3934
40- print ( 'Go right passed?' )
41- print (trial_type_right == go_right_target )
35+ # print to screen
36+ print ('Go left passed? ' + str ( TestResults [ 'go_left_result' ]. all ()) )
4237
43- print ('Go cued passed?' )
44- print ( trial_type_cued == go_cued_target )
45-
46- print ( 'Go uncued passed?' )
47- print (trial_type_uncued == go_uncued_target )
38+ print ('Go right passed? ' + str ( TestResults [ 'go_right_result' ]. all ()) )
39+
40+ print ( 'Go cued passed? ' + str (( TestResults [ 'go_cued_result' ]. all ())))
41+
42+ print ('Go uncued passed? ' + str ( TestResults [ 'go_uncued_result' ]. all ()))
4843
4944#%%
5045# test Explore strategy models
5146#
5247
53- # extend test data to include all win-stay and lose-shift test cases
54- test_data_for_rules ['Choice' ].append ('right' ,'left' ,'right' ,'right' ,'left' )
55- test_data_for_rules ['CuePosition' ].append ('right' ,'right' ,'left' ,'left' ,'left' )
56- test_data_for_rules ['Reward' ].append ('yes' ,'yes' ,'yes' ,'no' ,'no' )
57-
58- TestData = pd .DataFrame (test_data_for_rules ) # overwrite dataframe
48+ trial_type_alternate = []; trial_type_sticky = []; trial_type_win_stay_spatial = [];
49+ trial_type_win_stay_cued = []; trial_type_lose_shift_spatial = []; trial_type_lose_shift_cued = [];
50+
51+ for trial in range (len (UnitTestData )):
52+ rows_of_data = UnitTestData .iloc [0 :trial + 1 ] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
53+ trial_type_alternate .append (strategymodels .alternate (rows_of_data ))
54+ trial_type_sticky .append (strategymodels .sticky (rows_of_data ))
55+ trial_type_win_stay_spatial .append (strategymodels .win_stay_spatial (rows_of_data ))
56+ trial_type_win_stay_cued .append (strategymodels .win_stay_cued (rows_of_data ))
57+ trial_type_lose_shift_spatial .append (strategymodels .lose_shift_spatial (rows_of_data ))
58+ trial_type_lose_shift_cued .append (strategymodels .lose_shift_cued (rows_of_data ))
59+
60+ # did they pass test?
61+ TestResults ['alternate_result' ] = UnitTestData ['alternate' ].eq (trial_type_alternate )
62+ TestResults ['sticky_result' ] = UnitTestData ['sticky' ].eq (trial_type_sticky )
63+ TestResults ['win_stay_spatial_result' ] = UnitTestData ['win_stay_spatial' ].eq (trial_type_win_stay_spatial )
64+ TestResults ['win_stay_cued_result' ] = UnitTestData ['win_stay_cued' ].eq (trial_type_win_stay_cued )
65+ TestResults ['lose_shift_spatial_result' ] = UnitTestData ['lose_shift_spatial' ].eq (trial_type_lose_shift_spatial )
66+ TestResults ['lose_shift_cued_result' ] = UnitTestData ['lose_shift_cued' ].eq (trial_type_lose_shift_cued )
5967
60- # define target results for exploration rules (null, success, failure)
68+ # print to screen
69+ print ('Alternate passed? ' + str (TestResults ['alternate_result' ].all ()))
6170
71+ print ('Sticky passed? ' + str (TestResults ['sticky_result' ].all ()))
6272
63- # run explore strategy models
64- trial_type_alternate = []; trial_type_sticky = []; trial_type_win_stay_spatial = [];
65- trial_type_win_stay_cue = []; trial_type_lose_shift_spatial = []; trial_type_lose_shift_cue = [];
73+ print ( 'Win-stay-spatial passed? ' + str ( TestResults [ 'win_stay_spatial_result' ]. all ()))
74+
75+ print ( 'Win-stay-cued passed? ' + str ( TestResults [ 'win_stay_cued_result' ]. all ()))
6676
67- for trial in range ( len ( TestData )):
68- rows_of_data = TestData . iloc [ 0 : trial + 1 ] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
69-
77+ print ( 'Lose-shift-spatial passed? ' + str ( TestResults [ 'lose_shift_spatial_result' ]. all ()))
78+
79+ print ( 'Lose-shift-cued passed? ' + str ( TestResults [ 'lose_shift_cued_result' ]. all ()))
0 commit comments