Skip to content

Commit 1c490f7

Browse files
author
mdhumphries
committed
finish main scripts and some unit tests
1 parent 5ceeef0 commit 1c490f7

10 files changed

Lines changed: 160 additions & 25 deletions

Functions/Summaries_of_Beta_distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
def summaries_of_Beta_Distribution(Alpha, Beta, stat_type, *args):
66
if np.isnan(Alpha) or np.isnan(Beta):
7-
statistic = None
7+
statistic = np.nan
88
else:
99
if stat_type == 'MAP':
1010
x = np.arange(0, 1, 0.001)
Binary file not shown.
980 Bytes
Binary file not shown.
Binary file not shown.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Wed Aug 24 12:14:59 2022
5+
6+
@author: Mark Humphries
7+
"""
8+
import numpy as np
9+
from Functions.Summaries_of_Beta_distribution import summaries_of_Beta_Distribution
10+
11+
# trial_data is a dict of {Alpha,Beta,MAPprobability,Precision}
12+
# dataframe_row_of_prior_trials is a Pandas dataframe row of the previous trial
13+
# (or the empty dataframe on the first trial)
14+
15+
def interpolate_null_trials(dict_of_trial_data,dataframe_row_of_prior_trial,alpha_zero,beta_zero):
16+
17+
# if not NAN, then no need to interpolate: just assign the supdated alpha etc to the interpolated variables
18+
if not np.isnan(dict_of_trial_data['Alpha']):
19+
interpolated_values = {'Alpha_interpolated': dict_of_trial_data['Alpha'], 'Beta_interpolated': dict_of_trial_data['Beta'], \
20+
'MAPprobability_interpolated': dict_of_trial_data['MAPprobability'], 'Precision_interpolated': dict_of_trial_data['Precision']}
21+
22+
23+
elif dataframe_row_of_prior_trial.empty:
24+
# it is a null trial: check if dataframe row is empty - then is first trial, so assign priors
25+
MAPprobability = summaries_of_Beta_Distribution(alpha_zero,beta_zero,'MAP')
26+
precision = summaries_of_Beta_Distribution(alpha_zero,beta_zero,'precision')
27+
interpolated_values = {'Alpha_interpolated': alpha_zero, 'Beta_interpolated': beta_zero, \
28+
'MAPprobability_interpolated': MAPprobability, 'Precision_interpolated': precision}
29+
else:
30+
# it is a null trial, and the dataframe is not empty, so assign interpolated values of previous trial
31+
interpolated_values = {'Alpha_interpolated': dataframe_row_of_prior_trial.at['Alpha_interpolated'], 'Beta_interpolated': dataframe_row_of_prior_trial.at['Beta_interpolated'], \
32+
'MAPprobability_interpolated': dataframe_row_of_prior_trial.at['MAPprobability_interpolated'], 'Precision_interpolated': dataframe_row_of_prior_trial.at['Precision_interpolated']}
33+
34+
35+
dict_of_trial_data.update(interpolated_values) # add interpolated values to others from this trial
36+
return dict_of_trial_data # a dict of all results for this trial

Functions/update_strategy_posterior_probability.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12

23
def update_strategy_posterior_probability(trial_type, decay_rate, success_total, failure_total, alpha0, beta0):
34
if trial_type == "success":
@@ -11,8 +12,8 @@ def update_strategy_posterior_probability(trial_type, decay_rate, success_total,
1112
alpha = (alpha0 + success_total)
1213
beta = (beta0 + failure_total)
1314
else:
14-
alpha = None
15-
beta = None
15+
alpha = np.nan
16+
beta = np.nan
1617

1718
return success_total, failure_total, alpha, beta
1819

Replicate_Figure1.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from Functions.update_strategy_posterior_probability import update_strategy_posterior_probability
99
from Functions.Summaries_of_Beta_distribution import summaries_of_Beta_Distribution
1010
from Functions.plotSessionStructure import plotSessionStructure
11+
from Functions.interpolate_null_trials import interpolate_null_trials
1112

1213
# initiate TestData variable so that rat 2 testdata can be loaded
1314
TestData = pd.read_csv('data.csv')
1415

1516
# choose strategies to evaluate: subset shown in Figure 1
16-
strategies = ['go_left','go_right']
17+
strategies = ['go_left','go_right','go_cued','win_stay_spatial','lose_shift_cued','lose_shift_spatial']
1718

1819
# set prior
1920
prior_type = 'Uniform' #set prior type
@@ -27,7 +28,7 @@
2728
event_totals = {} # empty dict to store totals of events for each strategy
2829
# initialise dataframes
2930
for index_strategy in range(len(strategies)):
30-
Output_collection[strategies[index_strategy]] = pd.DataFrame(columns = ['Alpha', 'Beta', 'MAPprobability', 'Precision']) # empty Dataframe to input data into
31+
Output_collection[strategies[index_strategy]] = pd.DataFrame(columns = ['Alpha', 'Beta','MAPprobability', 'Precision','Alpha_interpolated', 'Beta_interpolated','MAPprobability_interpolated', 'Precision_interpolated']) # empty Dataframe to input data into
3132
event_totals[strategies[index_strategy]] = {}; # create empty dict for this strategy
3233
event_totals[strategies[index_strategy]]['success_total'] = 0;
3334
event_totals[strategies[index_strategy]]['failure_total'] = 0;
@@ -49,25 +50,53 @@
4950

5051
MAPprobability = summaries_of_Beta_Distribution(Alpha, Beta, 'MAP')
5152
precision = summaries_of_Beta_Distribution(Alpha, Beta, 'precision')
52-
53+
54+
55+
#%% interpolate null trials
56+
this_trials_data= {'Alpha':Alpha, 'Beta':Beta, 'MAPprobability':MAPprobability, 'Precision':precision} # create dict of current data to pass
57+
if trial > 0:
58+
previous_trials_data = Output_collection[strategies[index_strategy]].iloc[trial-1]
59+
else:
60+
previous_trials_data = Output_collection[strategies[index_strategy]] # pass empty dataframe
61+
62+
new_row_of_data = interpolate_null_trials(this_trials_data,previous_trials_data,alpha0,beta0)
63+
5364
# store results - dynamically-defined dataframe...
54-
new_row = {'Alpha':Alpha, 'Beta':Beta, 'MAPprobability':MAPprobability, 'Precision':precision} # create new row for dataframe as a dict
55-
new_df= pd.DataFrame([new_row]) # have to convert to dataframe to use concat!!
65+
new_df= pd.DataFrame([new_row_of_data]) # have to convert to dataframe to use concat!!
5666
Output_collection[strategies[index_strategy]] = pd.concat([Output_collection[strategies[index_strategy]], new_df], ignore_index=True) # add new row to dataframe
57-
67+
68+
5869
#%% plot results
5970
no_Trials = np.size(TestData.TrialIndex)
6071

61-
# plotting time series of MAPprobability
72+
# plotting time series of MAPprobability for Rule Strategies
6273
plt.figure(figsize=(10, 5))
6374
plt.plot(Output_collection['go_left'].MAPprobability, linewidth=0.75) # plots the time series
64-
plt.axis([0, no_Trials, 0, 1.25]) # establishes axis limits
75+
plt.plot(Output_collection['go_right'].MAPprobability, linewidth=0.75, color=(0.4, 0.8, 0.5)) # plots the time series
76+
plt.plot(Output_collection['go_cued'].MAPprobability, linewidth=0.75, color=(0.8,0.6,0.5)) # plots the time series
77+
plt.axis([0, no_Trials, 0, 1.25]) # establishes axis limits
6578
plt.xlabel('Trials'), plt.ylabel('P(Strategy)') # labelling the axis
66-
plt.axhline(y=0.5, color='firebrick', linestyle='--', linewidth=0.75,
67-
label="Chance") # shows the line at which Chance is exceeded
79+
plt.axhline(y=0.5, color='darkgrey', linewidth=0.75, label="Chance") # shows the line at which Chance is exceeded
6880

6981
plotSessionStructure(TestData)
70-
7182
plt.show()
83+
84+
# plotting Precision for the same three strategies (precision identical for go_left and go_right)
85+
plt.figure(figsize=(10, 5))
86+
plt.plot(Output_collection['go_left'].Precision, linewidth=0.75) # plots the time series
87+
plt.plot(Output_collection['go_cued'].Precision, linewidth=0.75, color=(0.8,0.6,0.5)) # plots the time series
88+
plt.xlabel('Trials'), plt.ylabel('Precision') # labelling the axis
89+
plt.show()
7290

73-
91+
92+
# plotting MAP probability for some exploratory strategies - use interpolated values
93+
plt.figure(figsize=(10, 5))
94+
plt.plot(Output_collection['lose_shift_cued'].MAPprobability_interpolated, linewidth=0.75, color=(1, 0.1, 0.6)) # plots the time series
95+
plt.plot(Output_collection['lose_shift_spatial'].MAPprobability_interpolated, linewidth=0.75, color=(0.8, 0.6, 0.5)) # plots the time series
96+
plt.plot(Output_collection['win_stay_spatial'].MAPprobability_interpolated, linewidth=0.75, color=(0.4,0.8,0.5)) # plots the time series
97+
plt.axis([0, no_Trials, 0, 1.25]) # establishes axis limits
98+
plt.xlabel('Trials'), plt.ylabel('P(Strategy)') # labelling the axis
99+
plt.axhline(y=0.5, color='darkgrey', linewidth=0.75, label="Chance") # shows the line at which Chance is exceeded
100+
101+
plotSessionStructure(TestData)
102+
plt.show()
-120 Bytes
Binary file not shown.

strategymodels.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def alternate(rows):
8484
# checks if the subject made a different choice on this trial from the previous obe
8585
nTrials = len(rows)
8686
# "at" selects the value at the row/column location in the dataframe
87-
if nTrials > 1 & rows.at[nTrials-1,'Choice'] != rows.at[nTrials-2,'Choice']: # check the current trial's choice
87+
if nTrials > 1 and rows.at[nTrials-1,'Choice'] != rows.at[nTrials-2,'Choice']: # check the current trial's choice
8888
trial_type = "success"
8989
else:
9090
trial_type = "failure"
@@ -99,11 +99,11 @@ def lose_shift_cued(rows):
9999
nTrials = len(rows)
100100
# "at" selects the value at the row/column location in the dataframe
101101
# check that the previous trial was not rewarded ('lose')
102-
if nTrials > 1 & rows.at[nTrials-2,'Reward'] == "no":
102+
if nTrials > 1 and rows.at[nTrials-2,'Reward'] == "no":
103103
# now check if the subject shifted their cued-based choice
104-
if rows.at[nTrials-2,'Choice'] == rows.at[nTrials-2,'CuePosition'] & rows.at[nTrials-1,'Choice'] != rows.at[nTrials-1,'CuePosition']:
104+
if rows.at[nTrials-2,'Choice'] == rows.at[nTrials-2,'CuePosition'] and rows.at[nTrials-1,'Choice'] != rows.at[nTrials-1,'CuePosition']:
105105
trial_type = "success" # shifted from cued to uncued choice
106-
elif rows.at[nTrials-2,'Choice'] != rows.at[nTrials-2,'CuePosition'] & rows.at[nTrials-1,'Choice'] == rows.at[nTrials-1,'CuePosition']:
106+
elif rows.at[nTrials-2,'Choice'] != rows.at[nTrials-2,'CuePosition'] and rows.at[nTrials-1,'Choice'] == rows.at[nTrials-1,'CuePosition']:
107107
trial_type = "success" # shifted from uncued to cued chpice
108108
else:
109109
trial_type = "failure"
@@ -119,7 +119,7 @@ def lose_shift_spatial(rows):
119119
nTrials = len(rows)
120120
# "at" selects the value at the row/column location in the dataframe
121121
# check that the previous trial was not rewarded ('lose')
122-
if nTrials > 1 & rows.at[nTrials-2,'Reward'] == "no":
122+
if nTrials > 1 and rows.at[nTrials-2,'Reward'] == "no":
123123
# now check if the subject shifted their spatial choice
124124
if rows.at[nTrials-1,'Choice'] != rows.at[nTrials-2,'Choice']:
125125
trial_type = "success"
@@ -134,7 +134,7 @@ def sticky(rows):
134134
# checks if the subject made the same choice on this trial as the previous one
135135
nTrials = len(rows)
136136
# "at" selects the value at the row/column location in the dataframe
137-
if nTrials > 1 & rows.at[nTrials-1,'Choice'] == rows.at[nTrials-2,'Choice']: # check the current trial's choice
137+
if nTrials > 1 and rows.at[nTrials-1,'Choice'] == rows.at[nTrials-2,'Choice']: # check the current trial's choice
138138
trial_type = "success"
139139
else:
140140
trial_type = "failure"
@@ -149,11 +149,11 @@ def win_stay_cued(rows):
149149
nTrials = len(rows)
150150
# "at" selects the value at the row/column location in the dataframe
151151
# check that the previous trial was rewarded ('win')
152-
if nTrials > 1 & rows.at[nTrials-2,'Reward'] == "yes":
152+
if nTrials > 1 and rows.at[nTrials-2,'Reward'] == "yes":
153153
# now check if the subject stayed with the cued-based choice
154-
if rows.at[nTrials-2,'Choice'] == rows.at[nTrials-2,'CuePosition'] & rows.at[nTrials-1,'Choice'] == rows.at[nTrials-1,'CuePosition']:
154+
if rows.at[nTrials-2,'Choice'] == rows.at[nTrials-2,'CuePosition'] and rows.at[nTrials-1,'Choice'] == rows.at[nTrials-1,'CuePosition']:
155155
trial_type = "success" # made the same cued choice
156-
elif rows.at[nTrials-2,'Choice'] != rows.at[nTrials-2,'CuePosition'] & rows.at[nTrials-1,'Choice'] != rows.at[nTrials-1,'CuePosition']:
156+
elif rows.at[nTrials-2,'Choice'] != rows.at[nTrials-2,'CuePosition'] and rows.at[nTrials-1,'Choice'] != rows.at[nTrials-1,'CuePosition']:
157157
trial_type = "success" # made the same uncued choice
158158
else:
159159
trial_type = "failure"
@@ -169,7 +169,7 @@ def win_stay_spatial(rows):
169169
nTrials = len(rows)
170170
# "at" selects the value at the row/column location in the dataframe
171171
# check that the previous trial was rewarded ('win')
172-
if nTrials > 1 & rows.at[nTrials-2,'Reward'] == "yes":
172+
if nTrials > 1 and rows.at[nTrials-2,'Reward'] == "yes":
173173
# now check if the subject stayed with the same spatial choice
174174
if rows.at[nTrials-1,'Choice'] == rows.at[nTrials-2,'Choice']:
175175
trial_type = "success"

unit_test_strategy_models.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Wed Aug 24 19:39:45 2022
5+
6+
Unit tests for the strategy models
7+
8+
@author: Mark Humphries
9+
"""
10+
11+
import pandas as pd
12+
import strategymodels
13+
14+
# make test dataframe
15+
test_data_for_rules = {'Choice' : ['left','left','left','left','right','right','right','right'],
16+
'CuePosition': ['left','right','left','right','left','right','left','right'],
17+
'Reward': ['yes','yes','no','no','yes','yes','no','no']}
18+
TestData = pd.DataFrame(test_data_for_rules)
19+
20+
# target results for rules
21+
22+
go_left_target = ['success','success','success','success','failure','failure','failure','failure']
23+
go_right_target = ['failure','failure','failure','failure','success','success','success','success']
24+
go_cued_target = ['success','failure','success','failure','failure','success','failure','success']
25+
go_uncued_target = ['failure','success','failure','success','success','failure','success','failure']
26+
27+
#%% run rule strategy models
28+
trial_type_left = []; trial_type_right = []; trial_type_cued = []; trial_type_uncued = [];
29+
for trial in range(len(TestData)):
30+
rows_of_data = TestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
31+
trial_type_left.append(strategymodels.go_left(rows_of_data)) # test whether go-left was used
32+
trial_type_right.append(strategymodels.go_right(rows_of_data))
33+
trial_type_cued.append(strategymodels.go_cued(rows_of_data))
34+
trial_type_uncued.append(strategymodels.go_uncued(rows_of_data))
35+
36+
#%% did they pass the tests?
37+
print('Go left passed?')
38+
print(trial_type_left == go_left_target)
39+
40+
print('Go right passed?')
41+
print(trial_type_right == go_right_target)
42+
43+
print('Go cued passed?')
44+
print(trial_type_cued == go_cued_target)
45+
46+
print('Go uncued passed?')
47+
print(trial_type_uncued == go_uncued_target)
48+
49+
#%%
50+
# test Explore strategy models
51+
#
52+
53+
# extend test data to include all win-stay and lose-shift test cases
54+
test_data_for_rules['Choice'].append('right','left','right','right','left')
55+
test_data_for_rules['CuePosition'].append('right','right','left','left','left')
56+
test_data_for_rules['Reward'].append('yes','yes','yes','no','no')
57+
58+
TestData = pd.DataFrame(test_data_for_rules) # overwrite dataframe
59+
60+
# define target results for exploration rules (null, success, failure)
61+
62+
63+
# run explore strategy models
64+
trial_type_alternate = []; trial_type_sticky = []; trial_type_win_stay_spatial= [];
65+
trial_type_win_stay_cue = []; trial_type_lose_shift_spatial= []; trial_type_lose_shift_cue = [];
66+
67+
for trial in range(len(TestData)):
68+
rows_of_data = TestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
69+

0 commit comments

Comments
 (0)