Skip to content

Commit b135915

Browse files
author
mdhumphries
committed
all strategy models drafted
1 parent 6be9f83 commit b135915

19 files changed

Lines changed: 175 additions & 69 deletions
1.2 KB
Binary file not shown.

Functions/plotSessionStructure.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Wed Aug 24 11:48:26 2022
5+
6+
@author: Lowri Powell & Mark Humphries
7+
"""
8+
import matplotlib.pyplot as plt
9+
import numpy as np
10+
11+
12+
def plotSessionStructure(TestData):
13+
no_Trials = np.size(TestData.TrialIndex)
14+
15+
sessionLines = TestData[TestData['NewSessionTrials'] == 1].index # indices list when new session was started
16+
plt.vlines(sessionLines, 0, 1, colors='lightgray', linestyles='--', linewidth=0.75,
17+
label="New Sessions") # vertical lines indicate the new session trials
18+
ruleLines = np.array(TestData[TestData['RuleChangeTrials'] == 1].index) # indices list when new session was started
19+
ruleLines = np.insert(ruleLines, 0, 0) # sets array for x values of rule change
20+
21+
for x in ruleLines: # to get shade change for rule change and labels
22+
minx = x / no_Trials
23+
plt.axhspan(1, 1.25, xmin=minx, alpha=0.3, edgecolor='k') # change transparency for separation
24+
plt.axvline(x, 0.8, 1) # dividing lines
25+
26+
# creating labels
27+
plt.text(1, 1.125, "Right Arm", label='Go to the Right')
28+
plt.text(120, 1.125, "Lit Arm", label='Go to the Lit Arm')
29+
plt.text(225, 1.125, "Left Arm", label='Go to the Left')
30+
plt.text(330, 1.125, "Unlit Arm", label='Go to the Dark Arm')
31+
plt.text(150, 1.3, "Rule for Reward")
32+
plt.legend() # add legend

Replicate_Figure1.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44

55
import strategymodels
66

7-
#from strategy_models.go_left import go_left
8-
#from strategy_models.go_right import go_right
9-
#from strategy_models.alternate import alternate
107
from Functions.set_Beta_prior import set_priors
118
from Functions.update_strategy_posterior_probability import update_strategy_posterior_probability
129
from Functions.Summaries_of_Beta_distribution import summaries_of_Beta_Distribution
10+
from Functions.plotSessionStructure import plotSessionStructure
1311

1412
# initiate TestData variable so that rat 2 testdata can be loaded
1513
TestData = pd.read_csv('data.csv')
@@ -23,8 +21,6 @@
2321

2422
[alpha0, beta0] = set_priors(prior_type) # define priors
2523

26-
no_Trials = np.size(TestData.TrialIndex)
27-
2824

2925
#%% initialise storage
3026
Output_collection = {} # empty dict in which to store dataframes
@@ -41,7 +37,6 @@
4137

4238
rows_of_data = TestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes colu
4339

44-
print(rows_of_data)
4540
#%%
4641
for index_strategy in range(len(strategies)):
4742
# run current strategy model on data up to current trial
@@ -61,31 +56,18 @@
6156
Output_collection[strategies[index_strategy]] = pd.concat([Output_collection[strategies[index_strategy]], new_df], ignore_index=True) # add new row to dataframe
6257

6358
#%% plot results
59+
no_Trials = np.size(TestData.TrialIndex)
60+
6461
# plotting time series of MAPprobability
6562
plt.figure(figsize=(10, 5))
66-
plt.plot(Output_collection['go_right'].MAPprobability, linewidth=0.75) # plots the time series
63+
plt.plot(Output_collection['go_left'].MAPprobability, linewidth=0.75) # plots the time series
6764
plt.axis([0, no_Trials, 0, 1.25]) # establishes axis limits
6865
plt.xlabel('Trials'), plt.ylabel('P(Strategy)') # labelling the axis
6966
plt.axhline(y=0.5, color='firebrick', linestyle='--', linewidth=0.75,
7067
label="Chance") # shows the line at which Chance is exceeded
71-
sessionLines = TestData[TestData['NewSessionTrials'] == 1].index # indices list when new session was started
72-
plt.vlines(sessionLines, 0, 1, colors='lightgray', linestyles='--', linewidth=0.75,
73-
label="New Sessions") # vertical lines indicate the new session trials
74-
ruleLines = np.array(TestData[TestData['RuleChangeTrials'] == 1].index) # indices list when new session was started
75-
ruleLines = np.insert(ruleLines, 0, 0) # sets array for x values of rule change
7668

77-
for x in ruleLines: # to get shade change for rule change and labels
78-
minx = x / no_Trials
79-
plt.axhspan(1, 1.25, xmin=minx, alpha=0.3, edgecolor='k') # change transparency for separation
80-
plt.axvline(x, 0.8, 1) # dividing lines
69+
plotSessionStructure(TestData)
8170

82-
# creating labels
83-
plt.text(1, 1.125, "Right Arm", label='Go to the Right')
84-
plt.text(120, 1.125, "Lit Arm", label='Go to the Lit Arm')
85-
plt.text(225, 1.125, "Left Arm", label='Go to the Left')
86-
plt.text(330, 1.125, "Unlit Arm", label='Go to the Dark Arm')
87-
plt.text(150, 1.3, "Rule for Reward")
88-
plt.legend() # add legend
8971
plt.show()
9072

9173

1.79 KB
Binary file not shown.
-415 Bytes
Binary file not shown.
-371 Bytes
Binary file not shown.
-420 Bytes
Binary file not shown.
-395 Bytes
Binary file not shown.

strategy_models/alternate.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

strategy_models/go_cued.py

Whitespace-only changes.

0 commit comments

Comments
 (0)