Skip to content

Commit 77049ae

Browse files
committed
fixing multiple strategy script
solved problems of: (a) calling functions dynamically [use getattr(()] (b) dynamic storage of results [dict of dataframes] (c) dynamic updates of scalars [dict of dicts - like a multi-level struct]
1 parent dd89f96 commit 77049ae

3 files changed

Lines changed: 100 additions & 8 deletions

File tree

Replicate_Figure1.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import numpy as np
22
import pandas as pd
3-
4-
53
import matplotlib.pyplot as plt
64

7-
from strategy_models.go_left import go_left
8-
from strategy_models.go_right import go_right
9-
from strategy_models.alternate import alternate
5+
import strategymodels
6+
7+
#from strategy_models.go_left import go_left
8+
#from strategy_models.go_right import go_right
9+
#from strategy_models.alternate import alternate
1010
from Functions.set_Beta_prior import set_priors
1111
from Functions.update_strategy_posterior_probability import update_strategy_posterior_probability
1212
from Functions.Summaries_of_Beta_distribution import summaries_of_Beta_Distribution
@@ -25,12 +25,66 @@
2525

2626
no_Trials = np.size(TestData.TrialIndex)
2727

28-
Output = pd.DataFrame(columns = ['Alpha', 'Beta', 'MAPprobability', 'Precision']) # empty Dataframe to input data into
2928

29+
#%% initialise storage
30+
Output_collection = {} # empty dict in which to store dataframes
31+
event_totals = {} # empty dict to store totals of events for each strategy
32+
# initialise dataframes
33+
for index_strategy in range(len(strategies)):
34+
Output_collection[strategies[index_strategy]] = pd.DataFrame(columns = ['Alpha', 'Beta', 'MAPprobability', 'Precision']) # empty Dataframe to input data into
35+
event_totals[strategies[index_strategy]] = {}; # create empty dict for this strategy
36+
event_totals[strategies[index_strategy]]['success_total'] = 0;
37+
event_totals[strategies[index_strategy]]['failure_total'] = 0;
38+
39+
#%%
3040
for trial in range(len(TestData)):
3141

32-
rows_of_data = TestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes colu
42+
rows_of_data = TestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes colu
43+
44+
#%%
3345
for index_strategy in range(len(strategies)):
3446
# run current strategy model on data up to current trial
47+
strategy_fcn = getattr(strategymodels,strategies[index_strategy]) # dynamically assign string as function to be called
48+
trial_type = strategy_fcn(rows_of_data) #call currently assigned function
49+
50+
# update probability of strategy
51+
[event_totals[strategies[index_strategy]]['success_total'], event_totals[strategies[index_strategy]]['failure_total'], Alpha, Beta] \
52+
= update_strategy_posterior_probability(trial_type, decay_rate,event_totals[strategies[index_strategy]]['success_total'], event_totals[strategies[index_strategy]]['failure_total'],alpha0, beta0)
53+
54+
MAPprobability = summaries_of_Beta_Distribution(Alpha, Beta, 'MAP')
55+
precision = summaries_of_Beta_Distribution(Alpha, Beta, 'precision')
56+
57+
# store results - dynamically-defined dataframe...
58+
new_row = {'Alpha':Alpha, 'Beta':Beta, 'MAPprobability':MAPprobability, 'Precision':precision} # create new row for dataframe as a dict
59+
new_df= pd.DataFrame([new_row]) # have to convert to dataframe to use concat!!
60+
Output_collection[strategies[index_strategy]] = pd.concat([Output_collection[strategies[index_strategy]], new_df], ignore_index=True) # add new row to dataframe
61+
62+
#%% plot results
63+
# plotting time series of MAPprobability
64+
plt.figure(figsize=(10, 5))
65+
plt.plot(Output_collection['go_right'].MAPprobability, linewidth=0.75) # plots the time series
66+
plt.axis([0, no_Trials, 0, 1.25]) # establishes axis limits
67+
plt.xlabel('Trials'), plt.ylabel('P(Strategy)') # labelling the axis
68+
plt.axhline(y=0.5, color='firebrick', linestyle='--', linewidth=0.75,
69+
label="Chance") # shows the line at which Chance is exceeded
70+
sessionLines = TestData[TestData['NewSessionTrials'] == 1].index # indices list when new session was started
71+
plt.vlines(sessionLines, 0, 1, colors='lightgray', linestyles='--', linewidth=0.75,
72+
label="New Sessions") # vertical lines indicate the new session trials
73+
ruleLines = np.array(TestData[TestData['RuleChangeTrials'] == 1].index) # indices list when new session was started
74+
ruleLines = np.insert(ruleLines, 0, 0) # sets array for x values of rule change
75+
76+
for x in ruleLines: # to get shade change for rule change and labels
77+
minx = x / no_Trials
78+
plt.axhspan(1, 1.25, xmin=minx, alpha=0.3, edgecolor='k') # change transparency for separation
79+
plt.axvline(x, 0.8, 1) # dividing lines
80+
81+
# creating labels
82+
plt.text(1, 1.125, "Right Arm", label='Go to the Right')
83+
plt.text(120, 1.125, "Lit Arm", label='Go to the Lit Arm')
84+
plt.text(225, 1.125, "Left Arm", label='Go to the Left')
85+
plt.text(330, 1.125, "Unlit Arm", label='Go to the Dark Arm')
86+
plt.text(150, 1.3, "Rule for Reward")
87+
plt.legend() # add legend
88+
plt.show()
3589

36-
# store results
90+
952 Bytes
Binary file not shown.

strategymodels.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Strategy models module
4+
5+
A set of functions that define each strategy model
6+
Each function takes a Pandas dataframe, where each row is the data for one trial,
7+
up to the current trial
8+
9+
The column names in the dataframe are used to find the values
10+
11+
Created on Tue Aug 23 12:42:13 2022
12+
13+
@author: lpzmdh
14+
"""
15+
16+
def go_left(rows):
17+
# checks if the subject chose the left option on this trial
18+
nTrials = len(rows);
19+
# "at" selects the value at the row/column location in the dataframe
20+
if rows.at[nTrials-1,'Choice'] == "left": # check the current trial's choice
21+
trial_type = "success"
22+
elif rows.at[nTrials-1,'Choice'] == "right":
23+
trial_type = "failure"
24+
return trial_type
25+
26+
27+
def go_right(rows):
28+
# checks if the subject chose the right-hand option on this trial
29+
nTrials = len(rows);
30+
# "at" selects the value at the row/column location in the dataframe
31+
if rows.at[nTrials-1,'Choice'] == "right": # check the current trial's choice
32+
trial_type = "success"
33+
elif rows.at[nTrials-1,'Choice'] == "left":
34+
trial_type = "failure"
35+
return trial_type
36+
37+
38+

0 commit comments

Comments
 (0)