Skip to content

Commit 5ceeef0

Browse files
author
mdhumphries
committed
updated Demonstrate script
1 parent b135915 commit 5ceeef0

2 files changed

Lines changed: 47 additions & 67 deletions

File tree

Demonstrate_Bayesian_Strategy_Analysis.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
import numpy as np
55
import pandas as pd
66
import matplotlib.pyplot as plt
7+
import strategymodels
8+
79
from Functions.set_Beta_prior import set_priors
8-
from strategy_models.go_left import go_left
910
from Functions.update_strategy_posterior_probability import update_strategy_posterior_probability
1011
from Functions.Summaries_of_Beta_distribution import summaries_of_Beta_Distribution
12+
from Functions.plotSessionStructure import plotSessionStructure
1113

1214
# initiate TestData variable so that rat 2 testdata can be loaded
1315
TestData = pd.read_csv('data.csv')
@@ -18,15 +20,8 @@
1820

1921
decay_rate = 0.9 # Set Decay rate (gamma)
2022

21-
# MAIN LOOP: FOR EACH TRIAL TO UPDATE STRATEGY INDEX
22-
2323
no_Trials = np.size(TestData.TrialIndex)
2424

25-
Alpha = 0 # initialises storage
26-
Beta = 0
27-
MAPprobability = 0
28-
precision = 0
29-
3025
success_total = 0 # initialise variables to zero
3126
failure_total = 0
3227

@@ -35,7 +30,7 @@
3530
#%% run strategy analysis
3631
for trial in range(len(TestData)):
3732
rows_of_data = TestData.iloc[0:trial+1] # select all rows of data up to the curren trial; is trial+1 as dataframe includes column row as row 0????
38-
trial_type = go_left(rows_of_data) # test whether go-left was used
33+
trial_type = strategymodels.go_left(rows_of_data) # test whether go-left was used
3934

4035
[success_total, failure_total, Alpha, Beta] = update_strategy_posterior_probability(trial_type, decay_rate,
4136
success_total, failure_total,
@@ -58,22 +53,7 @@
5853
plt.xlabel('Trials'), plt.ylabel('P(Strategy)') # labelling the axis
5954
plt.axhline(y=0.5, color='firebrick', linestyle='--', linewidth=0.75,
6055
label="Chance") # shows the line at which Chance is exceeded
61-
sessionLines = TestData[TestData['NewSessionTrials'] == 1].index # indices list when new session was started
62-
plt.vlines(sessionLines, 0, 1, colors='lightgray', linestyles='--', linewidth=0.75,
63-
label="New Sessions") # vertical lines indicate the new session trials
64-
ruleLines = np.array(TestData[TestData['RuleChangeTrials'] == 1].index) # indices list when new session was started
65-
ruleLines = np.insert(ruleLines, 0, 0) # sets array for x values of rule change
6656

67-
for x in ruleLines: # to get shade change for rule change and labels
68-
minx = x / no_Trials
69-
plt.axhspan(1, 1.25, xmin=minx, alpha=0.3, edgecolor='k') # change transparency for separation
70-
plt.axvline(x, 0.8, 1) # dividing lines
57+
plotSessionStructure(TestData)
7158

72-
# creating labels
73-
plt.text(1, 1.125, "Right Arm", label='Go to the Right')
74-
plt.text(120, 1.125, "Lit Arm", label='Go to the Lit Arm')
75-
plt.text(225, 1.125, "Left Arm", label='Go to the Left')
76-
plt.text(330, 1.125, "Unlit Arm", label='Go to the Dark Arm')
77-
plt.text(150, 1.3, "Rule for Reward")
78-
plt.legend() # add legend
7959
plt.show()

Output.csv

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ Alpha,Beta,MAPprobability,Precision
218218
5.300203773019574,6.69979622580384,0.43,52.71712088773742
219219
4.870183395717617,7.129816603223456,0.387,53.911593626064956
220220
5.483165056145856,6.51683494290111,0.448,52.38872192443417
221-
6.034848550531271,5.9651514486109996,0.503,52.001754220140775
221+
6.034848550531271,5.9651514486109996,0.503,52.00175422014077
222222
6.531363695478143,5.4686363037499,0.553,52.41105902631707
223223
5.978227325930329,6.02177267337491,0.498,52.0006847441459
224224
5.480404593337297,6.5195954060374195,0.448,52.39291687376341
@@ -229,23 +229,23 @@ Alpha,Beta,MAPprobability,Precision
229229
6.18463410831974,5.815365891311036,0.518,52.04928742645532
230230
5.666170697487766,6.333829302179933,0.467,52.161471634341666
231231
6.1995536277389895,5.80044637196194,0.52,52.05758385728475
232-
5.679598264965091,6.320401734765746,0.468,52.14870677564352
232+
5.679598264965091,6.320401734765746,0.468,52.14870677564354
233233
6.211638438468582,5.788361561289172,0.521,52.06477845938611
234234
5.690474594621724,6.3095254051602545,0.46900000000000003,52.13875567695228
235235
6.221427135159552,5.77857286464423,0.522,52.07091766186058
236-
6.699284421643597,5.3007155781798065,0.5700000000000001,52.71605788558549
237-
7.129355979479238,4.870644020361826,0.613,53.90997800052084
236+
6.699284421643597,5.3007155781798065,0.5700000000000001,52.716057885585485
237+
7.129355979479238,4.870644020361826,0.613,53.90997800052083
238238
6.516420381531314,5.483579618325644,0.552,52.38809392501376
239-
6.964778343378183,5.035221656493079,0.596,53.38016987249572
239+
6.964778343378183,5.035221656493079,0.596,53.3801698724957
240240
7.368300509040364,4.631699490843772,0.637,54.852716516871105
241241
7.731470458136328,4.268529541759395,0.673,56.723817640635886
242242
8.058323412322695,3.9416765875834554,0.706,58.93593301760396
243-
8.352491071090427,3.64750892882511,0.735,61.445977948011674
243+
8.352491071090427,3.64750892882511,0.735,61.44597794801169
244244
8.617241963981384,3.382758035942599,0.762,64.21945599021568
245245
7.855517767583246,4.14448223234834,0.686,57.49906028610618
246-
8.16996599082492,3.8300340091135054,0.717,59.8250355502601
246+
8.16996599082492,3.8300340091135054,0.717,59.82503555026008
247247
8.452969391742428,3.5470306082021548,0.745,62.43550256847977
248-
8.707672452568186,3.2923275473819396,0.771,65.29812701163428
248+
8.707672452568186,3.2923275473819396,0.771,65.29812701163429
249249
8.936905207311368,3.0630947926437457,0.794,68.38459133925393
250250
9.143214686580231,2.8567853133793712,0.8140000000000001,71.66866629009075
251251
9.328893217922207,2.6711067820414343,0.833,75.12499781248921
@@ -258,59 +258,59 @@ Alpha,Beta,MAPprobability,Precision
258258
10.20071480656322,1.7992851934193939,0.92,101.9941428768469
259259
9.280643325906897,2.7193566740774546,0.8280000000000001,74.17568887377108
260260
9.452578993316209,2.547421006669709,0.845,77.74183783687629
261-
9.607321093984588,2.3926789060027382,0.861,81.43650328954836
262-
9.746588984586129,2.2534110154024645,0.875,85.2339821539392
261+
9.607321093984588,2.3926789060027382,0.861,81.43650328954834
262+
9.746588984586129,2.2534110154024645,0.875,85.23398215393918
263263
8.871930086127517,3.128069913862218,0.787,67.454555135619
264264
9.084737077514765,2.9152629224759963,0.808,70.68312669184965
265-
9.276263369763289,2.7237366302283963,0.8280000000000001,74.09137633391158
265+
9.276263369763289,2.7237366302283963,0.8280000000000001,74.0913763339116
266266
9.448637032786959,2.5513629672055567,0.845,77.65410719967797
267267
9.603773329508263,2.3962266704850013,0.86,81.34597062921584
268268
9.743395996557437,2.256604003436501,0.874,85.1412724255392
269-
9.869056396901694,2.130943603092851,0.887,89.01398778815889
269+
9.869056396901694,2.130943603092851,0.887,89.0139877881589
270270
9.982150757211524,2.0178492427835657,0.898,92.93793179675056
271271
10.083935681490372,1.9160643185052093,0.908,96.88704136032035
272272
10.175542113341335,1.8244578866546886,0.918,100.83572965321693
273273
10.257987902007201,1.7420120979892197,0.926,104.7592776310976
274-
9.332189111806482,2.6678108881902975,0.833,75.19124446235213
274+
9.332189111806482,2.6678108881902975,0.833,75.19124446235212
275275
9.498970200625834,2.501029799371268,0.85,78.79713965677631
276276
9.649073180563251,2.3509268194341413,0.865,82.52416405662672
277-
9.784165862506926,2.215834137490727,0.878,86.34651009702914
278-
9.905749276256234,2.0942507237416543,0.891,90.23807845854381
277+
9.784165862506926,2.215834137490727,0.878,86.34651009702911
278+
9.905749276256234,2.0942507237416543,0.891,90.23807845854378
279279
9.01517434863061,2.9848256513674887,0.802,69.56851705539694
280280
9.21365691376755,2.7863430862307403,0.8210000000000001,72.91875033841015
281281
9.392291222390796,2.607708777607666,0.839,76.43199675073127
282-
9.553062100151717,2.4469378998468994,0.855,80.08299297464482
283-
9.697755890136545,2.30224410986221,0.87,83.84616934721731
284-
9.82798030112289,2.172019698875989,0.883,87.6956007161701
282+
9.553062100151717,2.4469378998468994,0.855,80.0829929746448
283+
9.697755890136545,2.30224410986221,0.87,83.84616934721734
284+
9.82798030112289,2.172019698875989,0.883,87.69560071617009
285285
9.945182271010601,2.05481772898839,0.895,91.60512952893302
286286
10.050664043909542,1.9493359560895511,0.905,95.54861460825909
287287
10.145597639518588,1.854402360480596,0.915,99.50026517374877
288288
10.231037875566729,1.7689621244325364,0.923,103.43502356110339
289289
10.307934088010056,1.6920659119892827,0.931,107.32896346858584
290290
10.37714067920905,1.6228593207903543,0.9380000000000001,111.15967436062158
291-
10.439426611288146,1.560573388711319,0.9440000000000001,114.90660727681917
291+
10.439426611288146,1.560573388711319,0.9440000000000001,114.90660727681914
292292
10.495483950159333,1.5045160498401873,0.9500000000000001,118.55136273368429
293-
10.5459355551434,1.4540644448561686,0.9550000000000001,122.07790740279881
293+
10.5459355551434,1.4540644448561686,0.9550000000000001,122.07790740279879
294294
10.59134199962906,1.4086580003705516,0.9590000000000001,125.47271239115283
295295
10.632207799666155,1.3677922003334966,0.963,128.72481178164438
296296
10.66898701969954,1.331012980300147,0.967,131.8257852037574
297297
10.702088317729586,1.2979116822701322,0.97,134.76967229067563
298298
10.731879485956629,1.2681205140431189,0.973,137.55282977076868
299-
10.758691537360967,1.2413084626388071,0.976,140.17374360942387
300-
10.78282238362487,1.2171776163749264,0.978,142.632809153164
301-
10.804540145262383,1.1954598547374338,0.98,144.93209181124323
299+
10.758691537360967,1.2413084626388071,0.976,140.17374360942384
300+
10.78282238362487,1.2171776163749264,0.978,142.63280915316398
301+
10.804540145262383,1.1954598547374338,0.98,144.9320918112432
302302
10.824086130736145,1.1759138692636903,0.982,147.07507966950106
303303
10.84167751766253,1.1583224823373213,0.984,149.0664378101318
304-
10.857509765896278,1.1424902341035892,0.986,150.91177223717548
304+
10.857509765896278,1.1424902341035892,0.986,150.91177223717543
305305
9.87175878930665,2.12824121069323,0.887,89.10261722380483
306306
9.984582910375984,2.015417089623907,0.898,93.02742073562489
307307
10.086124619338387,1.9138753806615165,0.909,96.97680209535673
308-
10.177512157404548,1.8224878425953648,0.918,100.92518976117528
308+
10.177512157404548,1.8224878425953648,0.918,100.92518976117525
309309
10.259760941664094,1.7402390583358285,0.926,104.84788903569286
310-
10.333784847497686,1.6662151525022455,0.933,108.72147847681366
310+
10.333784847497686,1.6662151525022455,0.933,108.72147847681363
311311
10.400406362747917,1.599593637252021,0.9400000000000001,112.52418294349124
312312
10.460365726473126,1.539634273526819,0.9460000000000001,116.2362004044769
313-
10.514329153825814,1.485670846174137,0.9510000000000001,119.8399653191584
313+
10.514329153825814,1.485670846174137,0.9510000000000001,119.83996531915837
314314
10.562896238443233,1.4371037615567233,0.9560000000000001,123.32033750054764
315315
10.60660661459891,1.393393385401051,0.961,126.66471145234812
316316
10.64594595313902,1.3540540468609459,0.965,129.8630467869815
@@ -338,21 +338,21 @@ Alpha,Beta,MAPprobability,Precision
338338
10.965133787685096,1.0348662123149013,0.996,164.97103598462877
339339
10.968620408916586,1.0313795910834112,0.997,165.47611036668067
340340
10.971758368024927,1.02824163197507,0.997,165.9336345968057
341-
10.974582531222435,1.025417468777563,0.997,166.34782385825866
342-
10.977124278100192,1.0228757218998068,0.998,166.72256859628868
341+
10.974582531222435,1.025417468777563,0.997,166.34782385825864
342+
10.977124278100192,1.0228757218998068,0.998,166.72256859628865
343343
10.979411850290173,1.0205881497098261,0.998,167.0614501090752
344344
10.981470665261156,1.0185293347388433,0.998,167.36775740035435
345345
10.98332359873504,1.0166764012649592,0.998,167.64450472953752
346346
10.984991238861538,1.0150087611384633,0.998,167.89444941051255
347347
10.986492114975384,1.0135078850246169,0.999,168.120109508795
348-
10.987842903477846,1.0121570965221551,0.999,168.3237811696369
349-
10.989058613130062,1.0109413868699397,0.999,168.50755537867585
350-
10.990152751817057,1.0098472481829457,0.999,168.67333401337868
348+
10.987842903477846,1.0121570965221551,0.999,168.32378116963687
349+
10.989058613130062,1.0109413868699397,0.999,168.50755537867582
350+
10.990152751817057,1.0098472481829457,0.999,168.67333401337865
351351
10.991137476635352,1.008862523364651,0.999,168.82284508951886
352352
10.992023728971818,1.007976271028186,0.999,168.9576571437698
353353
10.992821356074636,1.0071786439253674,0.999,169.07919272259343
354354
10.993539220467172,1.0064607795328306,0.999,169.18874097022754
355-
10.994185298420454,1.0058147015795476,0.999,169.2874693258358
355+
10.994185298420454,1.0058147015795476,0.999,169.28746932583576
356356
10.99476676857841,1.0052332314215928,0.999,169.3764343527415
357357
10.99529009172057,1.0047099082794335,0.999,169.45659173195799
358358
10.995761082548514,1.0042389174514903,0.999,169.52880545864426
@@ -364,25 +364,25 @@ Alpha,Beta,MAPprobability,Precision
364364
10.997747265470666,1.0022527345293375,0.999,169.83408643350313
365365
10.9979725389236,1.0020274610764037,0.999,169.86878860839823
366366
10.99817528503124,1.0018247149687634,0.999,169.90003403124143
367-
10.998357756528115,1.001642243471887,0.999,169.92816582544694
367+
10.998357756528115,1.001642243471887,0.999,169.9281658254469
368368
10.998521980875305,1.0014780191246984,0.999,169.9534932851317
369-
10.998669782787774,1.0013302172122285,0.999,169.97629516675184
369+
10.998669782787774,1.0013302172122285,0.999,169.9762951667518
370370
10.998802804508996,1.0011971954910055,0.999,169.99682266879097
371371
10.998922524058097,1.001077475941905,0.999,170.01530212745763
372372
10.999030271652288,1.0009697283477146,0.999,170.03193745416306
373373
10.99912724448706,1.000872755512943,0.999,170.04691233846205
374374
10.999214520038354,1.0007854799616487,0.999,170.06039223817461
375375
10.999293068034518,1.000706931965484,0.999,170.07252617656127
376-
10.999363761231066,1.0006362387689356,0.999,170.08344836470027
377-
10.99942738510796,1.000572614892042,0.999,170.09327966561696
376+
10.999363761231066,1.0006362387689356,0.999,170.08344836470025
377+
10.99942738510796,1.000572614892042,0.999,170.09327966561693
378378
10.999484646597164,1.0005153534028377,0.999,170.10212891523753
379379
10.999536181937447,1.000463818062554,0.999,170.11009411387093
380-
10.999582563743703,1.0004174362562985,0.999,170.1172635006704
381-
10.999624307369332,1.0003756926306688,0.999,170.12371652237363
382-
10.9996618766324,1.0003381233676019,0.999,170.12952470656782
380+
10.999582563743703,1.0004174362562985,0.999,170.11726350067036
381+
10.999624307369332,1.0003756926306688,0.999,170.1237165223736
382+
10.9996618766324,1.0003381233676019,0.999,170.1295247065678
383383
10.99969568896916,1.0003043110308416,0.999,170.13475244876065
384-
10.999726120072244,1.0002738799277575,0.999,170.1394577216639
385-
10.99975350806502,1.0002464919349818,0.999,170.14369271429243
384+
10.999726120072244,1.0002738799277575,0.999,170.13945772166386
385+
10.99975350806502,1.0002464919349818,0.999,170.1436927142924
386386
10.999778157258518,1.0002218427414835,0.999,170.1475044077573
387387
10.999800341532668,1.0001996584673352,0.999,170.15093509396803
388388
10.999820307379402,1.0001796926206017,0.999,170.1540228428612

0 commit comments

Comments
 (0)