Commit 0635ea3

Update documentation for amortized inference example
1 parent 5858fdd commit 0635ea3

1 file changed: docs/src/neuralestimators_amorized.md (244 additions & 42 deletions)
@@ -6,11 +6,11 @@
 
 ## Example
 
-We'll estimate parameters of the LCA model, which is particularly challenging due to its complex dynamics, where parameters like leak rate (λ) and lateral inhibition (β) can be difficult to recover. This example draws from a more in-depth case that highlights many of the steps one ought to consider when utilizing amortized inference for cognitive modeling; see [Principled Amortized Bayesian Workflow for Cognitive Modeling](https://bayesflow.org/stable-legacy/_examples/Amortized_Bayesian_Workflow_for_Cognitive_Modeling.html).
+We'll estimate parameters of the LCA model, which is particularly challenging to fit because of its complex dynamics: parameters such as the leak rate (λ) and lateral inhibition (β) can be difficult to recover. This example draws on a more in-depth case study that highlights many of the steps to consider when using amortized inference for cognitive modeling; see [Principled Amortized Bayesian Workflow for Cognitive Modeling](https://bayesflow.org/stable-legacy/_examples/LCA_Model_Posterior_Estimation.html).
 
 ## Load Packages
 
-```@example
+```julia
 using NeuralEstimators
 using SequentialSamplingModels
 using Flux
@@ -22,30 +22,19 @@ using Plots
 Random.seed!(123)
 ```
 
-## Define Parameter Bounds
-
-```@example
-# Define parameter bounds for the LCA model
-const ν_min, ν_max = 0.1, 4.0   # Drift rates
-const α_min, α_max = 0.5, 3.5   # Threshold
-const β_min, β_max = 0.0, 0.5   # Lateral inhibition
-const λ_min, λ_max = 0.0, 0.5   # Leak rate
-const τ_min, τ_max = 0.1, 0.5   # Non-decision time
-```
-
 ## Define Parameter Sampling
 
-Unlike traditional Bayesian approaches, simulation based inference methods require us to define a prior sampling function to generate training data. We will use this function to sample a range of parameters for training:
+Unlike traditional Bayesian inference methods, simulation-based inference (SBI) approaches require us to define a prior sampling function specifically to generate synthetic training data. Traditional methods such as MCMC also sample from the prior, but those samples are used directly during inference rather than to create a separate training dataset. In SBI, we use the prior to sample a wide range of parameters and simulate corresponding data, which we then use to train a model (e.g., a neural network) to approximate the posterior. We will use the following function to sample a range of parameters for training:
 
-```@example
+```julia
 # Function to sample parameters from priors
 function sample(K::Integer)
-    ν1 = rand(Gamma(2, 1/1.2), K)    # Drift rate 1
-    ν2 = rand(Gamma(2, 1/1.2), K)    # Drift rate 2
-    α = rand(Gamma(10, 1/6), K)      # Threshold
-    β = rand(Beta(1, 5), K)          # Lateral inhibition
-    λ = rand(Beta(1, 5), K)          # Leak rate
-    τ = rand(Gamma(1.5, 1/5.0), K)   # Non-decision time
+    ν1 = rand(Gamma(2, 1/1.2f0), K)   # Drift rate 1
+    ν2 = rand(Gamma(2, 1/1.2f0), K)   # Drift rate 2
+    α = rand(Gamma(10, 1/6f0), K)     # Threshold
+    β = rand(Beta(1, 5f0), K)         # Lateral inhibition
+    λ = rand(Beta(1, 5f0), K)         # Leak rate
+    τ = rand(Gamma(1.5f0, 1/5f0), K)  # Non-decision time (Float32 shape keeps all draws Float32)
 
     # Stack parameters into a matrix (d×K)
     θ = vcat(ν1', ν2', α', β', λ', τ')
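
As a quick orientation to the sampler's output (an illustrative check, not part of the committed file): each column of the returned matrix is one draw of (ν1, ν2, α, β, λ, τ) from the prior.

```julia
# Illustrative usage of the sampler defined above
θ = sample(1000)   # 1,000 prior draws
size(θ)            # (6, 1000): one parameter vector per column
```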
@@ -58,7 +47,7 @@ end
 
 Neural estimators learn the mapping from data to parameters through simulation. Here we define a function to simulate data from the LCA model, using the [LCA](https://itsdfish.github.io/SequentialSamplingModels.jl/dev/lca/) implementation provided by SequentialSamplingModels.jl.
 
-```@example
+```julia
 # Function to simulate data from the LCA model
 function simulate(θ, n_trials_per_param)
     # Simulate data for each parameter vector
@@ -74,7 +63,7 @@ function simulate(θ, n_trials_per_param)
         choices, rts = rand(model, n_trials_per_param)
 
         # Return as a transposed matrix where each column is a trial
-        return Float32.([choices rts]')
+        return Float32.([choices rts]')   # keep Float32: the Flux network expects Float32 inputs
    end
 
    return simulated_data
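
A shape check for the simulator may also help (illustrative, not part of the committed file): `simulate` returns one 2×m matrix per parameter vector, with choices in the first row and reaction times in the second.

```julia
# Illustrative usage of the simulator defined above
Z = simulate(sample(10), 100)   # 10 datasets of 100 trials each
length(Z)    # 10
size(Z[1])   # (2, 100): row 1 = choices, row 2 = RTs
```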
@@ -83,11 +72,24 @@ end
 
 ## Define Neural Network Architecture
 
-For LCA parameter recovery, we use a DeepSet architecture which respects the permutation invariance of trial data. For more details on the method see NeuralEstimators.jl documentation. To construct the network architecture we will use the Flux.jl package.
+For LCA parameter recovery, we use a DeepSet architecture, which respects the permutation invariance of trial data; for details on the method, see the [NeuralEstimators.jl documentation](https://msainsburydale.github.io/NeuralEstimators.jl/dev/API/architectures/#NeuralEstimators.DeepSet). We construct the network with the Flux.jl package. (A quick invariance check is sketched below.)
 
-```@example
+```julia
 # Create neural network architecture for parameter recovery
-function create_neural_estimator()
+function create_neural_estimator(;
+    ν_bounds = (0.1, 4.0),
+    α_bounds = (0.5, 3.5),
+    β_bounds = (0.0, 0.5),
+    λ_bounds = (0.0, 0.5),
+    τ_bounds = (0.1, 0.5)
+)
+    # Unpack parameter bounds
+    ν_min, ν_max = ν_bounds   # Drift rates
+    α_min, α_max = α_bounds   # Threshold
+    β_min, β_max = β_bounds   # Lateral inhibition
+    λ_min, λ_max = λ_bounds   # Leak rate
+    τ_min, τ_max = τ_bounds   # Non-decision time
+
     # Input dimension: 2 (choice and RT for each trial)
     n = 2
     # Output dimension: 6 parameters
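
The permutation invariance mentioned above can be verified empirically. Below is a minimal sketch (not part of the committed file; it assumes the estimator is callable on a vector of data matrices, as in the inference section later in this document): shuffling the order of trials should leave the estimates unchanged.

```julia
# A DeepSet pools over trials with a permutation-invariant operation,
# so reordering the columns (trials) of a dataset should not change the output
est = create_neural_estimator()
Z = rand(Float32, 2, 100)      # 100 synthetic trials: (choice, RT) per column
Zperm = Z[:, randperm(100)]    # same trials, shuffled order
est([Z]) ≈ est([Zperm])        # true, up to floating-point error
```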
@@ -132,25 +134,25 @@ end
 
 ## Training the Neural Estimator
 
-Neural estimators, like all deep learning methods, require a training phase where they learn the mapping from data to parameters. Here we will train the estimator, simulating data as we go where the sampler provides new parameter vectors from the prior, and a simulator can be provided to continuously simulate new data conditional on the parameters. For more details on the training see the API for arguments [here](https://msainsburydale.github.io/NeuralEstimators.jl/dev/API/core/#Training).
+Neural estimators, like all deep learning methods, require a training phase during which they learn the mapping from data to parameters. Here, we train the estimator by simulating data on the fly: the sampler provides new parameter vectors from the prior, and the simulator generates corresponding data conditional on those parameters. Because training is online, the network never sees the same simulated dataset twice, which makes overfitting less likely. For more details on the training arguments, see the [API documentation](https://msainsburydale.github.io/NeuralEstimators.jl/dev/API/core/#Training).
 
-```@example
+```julia
 # Create the neural estimator
 estimator = create_neural_estimator()
 
 # Train network
 trained_estimator = train(
     estimator,
-    sample,                     # Parameter sampler function
-    simulate,                   # Data simulator function
-    m = 100,                    # Number of trials per parameter vector
-    K = 10000,                  # Number of training parameter vectors
-    K_val = 2000,               # Number of validation parameter vectors
-    loss = Flux.mae,            # Mean absolute error loss
-    epochs = 50,                # Number of training epochs
-    epochs_per_Z_refresh = 1,   # Refresh data every epoch
-    epochs_per_θ_refresh = 5,   # Refresh parameters every 5 epochs
-    batchsize = 64,             # Batch size for training
+    sample,                     # Parameter sampler function
+    simulate,                   # Data simulator function
+    m = 100,                    # Number of trials per parameter vector
+    K = 10000,                  # Number of training parameter vectors
+    K_val = 2000,               # Number of validation parameter vectors
+    loss = Flux.mae,            # Mean absolute error loss
+    epochs = 60,                # Number of training epochs
+    epochs_per_Z_refresh = 1,   # Refresh data every epoch
+    epochs_per_θ_refresh = 5,   # Refresh parameters every 5 epochs
+    batchsize = 16,             # Batch size for training
     verbose = true
 )
 ```
@@ -159,7 +161,7 @@ trained_estimator = train(
 
 We can assess the performance of our trained estimator on held-out test data:
 
-```@example
+```julia
 # Generate test data
 n_test = 100
 θ_test = sample(n_test)
@@ -183,9 +185,9 @@ println("RMSE: ", rmse_results)
 
 ## Visualizing Parameter Recovery
 
-A key advantage of neural estimation is the ability to quickly conduct inference after training. For example, we can visualize the recovery of parameters:
+A key advantage of neural estimation is the ability to quickly conduct inference after training. For example, we can visualize the recovery of parameters. NeuralEstimators provides built-in visualization through the [AlgebraOfGraphics.jl](https://github.com/MakieOrg/AlgebraOfGraphics.jl) package (see the sketch below), but here we demonstrate custom plotting:
 
-```@example
+```julia
 # Extract data from assessment
 df = assessment.df
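
For reference, the built-in route is a one-liner once a Makie backend is loaded. The sketch below is not part of the committed file, and the `plot` method for assessment objects is assumed from the NeuralEstimators.jl documentation:

```julia
# Built-in recovery plots via the AlgebraOfGraphics/Makie extension
# (assumed API; see the NeuralEstimators.jl documentation)
using AlgebraOfGraphics, CairoMakie
figure = plot(assessment)
save("assessment.png", figure)
```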
@@ -219,7 +221,7 @@ display(p_combined)
 
 Once trained, the estimator can instantly recover parameters from new data via a forward pass:
 
-```@example
+```julia
 # Generate "observed" data
 ν = [2.5, 2.0]
 α = 1.5
@@ -249,6 +251,206 @@
 
 Additional details can be found in the [NeuralEstimators.jl documentation](https://github.com/msainsburydale/NeuralEstimators).
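
To make the amortization concrete (an illustrative sketch, not part of the committed file; it reuses `sample`, `simulate`, and the `estimate` call shown in the inference section above): once trained, the same network can process any number of new datasets at negligible cost.

```julia
# Amortization in action: estimates for 1,000 new datasets
# via a single batched pass through the trained network
θ_new = sample(1000)            # 1,000 parameter vectors
Z_new = simulate(θ_new, 100)    # 1,000 datasets of 100 trials each
θ̂ = NeuralEstimators.estimate(trained_estimator, Z_new)   # 6×1000 matrix of estimates
```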
+# Complete Code
+```@raw html
+<details>
+<summary><b>Show Details</b></summary>
+```
+```julia
+using NeuralEstimators
+using SequentialSamplingModels
+using Flux
+using Distributions
+using Random
+using Plots
+
+Random.seed!(123)
+
+# Function to sample parameters from priors
+function sample(K::Integer)
+    ν1 = rand(Gamma(2, 1/1.2f0), K)   # Drift rate 1
+    ν2 = rand(Gamma(2, 1/1.2f0), K)   # Drift rate 2
+    α = rand(Gamma(10, 1/6f0), K)     # Threshold
+    β = rand(Beta(1, 5f0), K)         # Lateral inhibition
+    λ = rand(Beta(1, 5f0), K)         # Leak rate
+    τ = rand(Gamma(1.5f0, 1/5f0), K)  # Non-decision time (Float32 shape keeps all draws Float32)
+
+    # Stack parameters into a matrix (d×K)
+    θ = vcat(ν1', ν2', α', β', λ', τ')
+
+    return θ
+end
+
+# Function to simulate data from the LCA model
+function simulate(θ, n_trials_per_param)
+    # Simulate data for each parameter vector
+    simulated_data = map(eachcol(θ)) do param
+        # Extract parameters for this model
+        ν1, ν2, α, β, λ, τ = param
+        ν = [ν1, ν2]   # Two-choice LCA
+
+        # Create LCA model with SSM
+        model = LCA(; ν, α, β, λ, τ)
+
+        # Generate choices and reaction times
+        choices, rts = rand(model, n_trials_per_param)
+
+        # Return as a transposed matrix where each column is a trial;
+        # Float32 matches the Flux network's precision
+        return Float32.([choices rts]')
+    end
+
+    return simulated_data
+end
+
+# Create neural network architecture for parameter recovery
+function create_neural_estimator(;
+    ν_bounds = (0.1, 4.0),
+    α_bounds = (0.5, 3.5),
+    β_bounds = (0.0, 0.5),
+    λ_bounds = (0.0, 0.5),
+    τ_bounds = (0.1, 0.5)
+)
+    # Unpack parameter bounds
+    ν_min, ν_max = ν_bounds   # Drift rates
+    α_min, α_max = α_bounds   # Threshold
+    β_min, β_max = β_bounds   # Lateral inhibition
+    λ_min, λ_max = λ_bounds   # Leak rate
+    τ_min, τ_max = τ_bounds   # Non-decision time
+
+    # Input dimension: 2 (choice and RT for each trial)
+    n = 2
+    # Output dimension: 6 parameters
+    d = 6   # ν[1], ν[2], α, β, λ, τ
+    # Width of hidden layers
+    w = 128
+
+    # Inner network - processes each trial independently
+    ψ = Chain(
+        Dense(n, w, relu),
+        Dense(w, w, relu),
+        Dense(w, w, relu)
+    )
+
+    # Final layer with parameter constraints: a scaled sigmoid keeps each output within its bounds
+    final_layer = Parallel(
+        vcat,
+        Dense(w, 1, x -> ν_min + (ν_max - ν_min) * σ(x)),   # ν1
+        Dense(w, 1, x -> ν_min + (ν_max - ν_min) * σ(x)),   # ν2
+        Dense(w, 1, x -> α_min + (α_max - α_min) * σ(x)),   # α
+        Dense(w, 1, x -> β_min + (β_max - β_min) * σ(x)),   # β
+        Dense(w, 1, x -> λ_min + (λ_max - λ_min) * σ(x)),   # λ
+        Dense(w, 1, x -> τ_min + (τ_max - τ_min) * σ(x))    # τ
+    )
+
+    # Outer network - maps aggregated features to parameters
+    ϕ = Chain(
+        Dense(w, w, relu),
+        Dense(w, w, relu),
+        final_layer
+    )
+
+    # Combine into a DeepSet
+    network = DeepSet(ψ, ϕ)
+
+    # Initialize neural Bayes estimator
+    estimator = PointEstimator(network)
+
+    return estimator
+end
+
+# Create the neural estimator
+estimator = create_neural_estimator()
+
+# Train network
+trained_estimator = train(
+    estimator,
+    sample,                     # Parameter sampler function
+    simulate,                   # Data simulator function
+    m = 100,                    # Number of trials per parameter vector
+    K = 10000,                  # Number of training parameter vectors
+    K_val = 2000,               # Number of validation parameter vectors
+    loss = Flux.mae,            # Mean absolute error loss
+    epochs = 60,                # Number of training epochs
+    epochs_per_Z_refresh = 1,   # Refresh data every epoch
+    epochs_per_θ_refresh = 5,   # Refresh parameters every 5 epochs
+    batchsize = 16,             # Batch size for training
+    verbose = true
+)
+
+# Generate test data
+n_test = 100
+θ_test = sample(n_test)
+Z_test = simulate(θ_test, 100)
+
+# Assess the estimator
+parameter_names = ["ν1", "ν2", "α", "β", "λ", "τ"]
+assessment = assess(
+    trained_estimator,
+    θ_test,
+    Z_test;
+    parameter_names = parameter_names
+)
+
+# Calculate performance metrics
+bias_results = bias(assessment)
+rmse_results = rmse(assessment)
+println("Bias: ", bias_results)
+println("RMSE: ", rmse_results)
+
+# Extract data from assessment
+df = assessment.df
+
+# Create recovery plots for each parameter
+param_names = unique(df.parameter)   # avoid shadowing Flux's exported `params`
+p_plots = []
+
+for param in param_names
+    param_data = filter(row -> row.parameter == param, df)
+    p = scatter(
+        param_data.truth,
+        param_data.estimate,
+        xlabel = "Ground Truth",
+        ylabel = "Estimated",
+        title = param,
+        legend = false
+    )
+    # Add the identity line (perfect recovery)
+    plot!(p, [minimum(param_data.truth), maximum(param_data.truth)],
+        [minimum(param_data.truth), maximum(param_data.truth)],
+        line = :dash, color = :black)
+    push!(p_plots, p)
+end
+
+# Combine plots
+p_combined = plot(p_plots..., layout = (3, 2), size = (800, 600))
+display(p_combined)
+
+# Generate "observed" data
+ν = [2.5, 2.0]
+α = 1.5
+β = 0.2
+λ = 0.1
+τ = 0.3
+
+# Create model and generate data
+true_model = LCA(; ν, α, β, λ, τ)
+observed_choices, observed_rts = rand(true_model, 100)
+
+# Format the data
+observed_data = Float32.([observed_choices observed_rts]')
+
+# Recover parameters
+recovered_params = NeuralEstimators.estimate(trained_estimator, [observed_data])
+
+# Compare true and recovered parameters
+println("True parameters: ", [ν[1], ν[2], α, β, λ, τ])
+println("Recovered parameters: ", recovered_params)
+```
+```@raw html
+</details>
+```
+
 # References
 
 Miletić, S., Turner, B. M., Forstmann, B. U., & van Maanen, L. (2017). Parameter recovery for the leaky competing accumulator model. Journal of Mathematical Psychology, 76, 25-50.
