
Commit 580b359

Add enhanced documentation for neural parameter estimation
1 parent 0635ea3 commit 580b359

3 files changed: 69 additions & 29 deletions

File tree:

- docs/src/assets/npe_example.png (−49.2 KB, 29.5 KB)
- docs/src/neuralestimators_amorized.md

docs/src/neuralestimators_amorized.md (69 additions & 29 deletions)
@@ -1,6 +1,6 @@
 # Neural Parameter Estimation

-Neural parameter estimation provides a likelihood-free approach to parameter recovery, especially useful for models with computationally intractable likelihoods. This method is based on training neural networks to learn the mapping from data to parameters. See the review paper by Zammit-Mangion et al. (2025) for more details. Once trained, these networks can perform inference rapidly across multiple datasets, making them particularly valuable for models like the Leaky Competing Accumulator (LCA; Usher & McClelland, 2001).
+Neural parameter estimation uses neural networks to learn the mapping between simulated data and model parameters (for a detailed review, see Zammit-Mangion et al., 2024). It is particularly useful for models with computationally intractable likelihoods, such as the Leaky Competing Accumulator (LCA; Usher & McClelland, 2001). Once trained, a network can be saved and reused to perform inference rapidly across multiple datasets, or used for computationally intensive parameter recovery simulations that assess the quality of parameter estimates under ideal conditions.

 Below, we demonstrate how to estimate parameters of the LCA model using the [NeuralEstimators.jl](https://github.com/msainsburydale/NeuralEstimators) package.

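For orientation, a minimal setup consistent with the snippets in this diff might look like the sketch below; the package list is an assumption, since only `Random.seed!(123)` is visible in this commit:

```julia
# Hypothetical setup; the exact imports are not shown in this commit.
using NeuralEstimators, Flux            # DeepSet estimators and network layers
using SequentialSamplingModels          # provides the LCA model
using Distributions, DataFrames, Plots  # priors, assessment table, recovery plots
using Random, Statistics                # seeding and cor()

Random.seed!(123)
```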
@@ -24,7 +24,12 @@ Random.seed!(123)

 ## Define Parameter Sampling

-Unlike traditional Bayesian inference methods, simulation-based inference approaches require us to define a prior sampling function specifically to generate synthetic training data. While traditional methods like MCMC also sample from the prior, those samples are used directly during inference rather than to create a separate training dataset. In SBI, we use the prior to sample a wide range of parameters and simulate corresponding data, which we then use to train a model (e.g., a neural network) to approximate the posterior. We will use the following function to sample a range of parameters for training:
+Unlike traditional Bayesian inference methods, neural parameter estimation requires us to define two functions so that the neural network can learn the mapping between simulated data and parameters: one samples parameters from a prior distribution, and the other generates simulated data conditional on a sampled parameter vector. While traditional methods like MCMC also sample from the prior, those samples are used directly during inference rather than to create a separate training dataset.
+
+![](assets/npe_example.png)
+*Schematic of neural parameter estimation. Once trained, the neural network provides a direct mapping from observed data (Z) to parameter estimates (θ̂), enabling rapid inference without the computational burden of traditional methods.*
+
+In neural parameter estimation, we use the prior to sample a wide range of parameters and simulate corresponding data, which we then use to train a model (e.g., a neural network) to approximate a point estimate or the posterior. We use the following function to sample a range of parameters for training:

 ```julia
 # Function to sample parameters from priors
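The body of the sampler is cut off by this hunk. A minimal sketch of what such a function could look like, assuming independent uniform priors over the bounds used in `create_neural_estimator` below, is:

```julia
# Hypothetical sketch of the elided sampler; the priors are assumptions.
function sample(K)
    ν1 = rand(Uniform(0.1, 6.0), K)   # drift rate, accumulator 1
    ν2 = rand(Uniform(0.1, 6.0), K)   # drift rate, accumulator 2
    α  = rand(Uniform(0.3, 4.5), K)   # threshold
    β  = rand(Uniform(0.0, 0.8), K)   # lateral inhibition
    λ  = rand(Uniform(0.0, 0.8), K)   # leak rate
    τ  = rand(Uniform(0.1, 2.0), K)   # non-decision time
    # NeuralEstimators.jl expects a d × K matrix, one column per parameter vector
    return Float32.(permutedims(hcat(ν1, ν2, α, β, λ, τ)))
end
```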
@@ -63,7 +68,8 @@ function simulate(θ, n_trials_per_param)
         choices, rts = rand(model, n_trials_per_param)

         # Return as a transposed matrix where each column is a trial
-        return [choices rts]'
+        return Float32.([choices rts]')
+
     end

     return simulated_data
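The `Float32.` conversion is deliberate: Flux-based networks carry `Float32` weights by default, so converting the simulations up front avoids element-type promotion during training. A quick illustrative check, assuming `simulate` returns a vector of per-parameter datasets:

```julia
# Sanity check (illustrative); the exact return structure of simulate is assumed.
Z = simulate(sample(1), 10)
@assert eltype(first(Z)) == Float32
```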
@@ -77,11 +83,11 @@ For LCA parameter recovery, we use a DeepSet architecture which respects the per
 ```julia
 # Create neural network architecture for parameter recovery
 function create_neural_estimator(;
-    ν_bounds = (0.1, 4.0),
-    α_bounds = (0.5, 3.5),
-    β_bounds = (0.0, 0.5),
-    λ_bounds = (0.0, 0.5),
-    τ_bounds = (0.1, 0.5)
+    ν_bounds = (0.1, 6.0),
+    α_bounds = (0.3, 4.5),
+    β_bounds = (0.0, 0.8),
+    λ_bounds = (0.0, 0.8),
+    τ_bounds = (0.1, 2.0)
 )
     # Unpack defined parameter bounds
     ν_min, ν_max = ν_bounds  # Drift rates
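The remainder of the constructor falls between hunks. For orientation, a minimal sketch of a DeepSet point estimator in the NeuralEstimators.jl style is shown below; the layer widths and depths are illustrative assumptions, not the file's actual code:

```julia
# Illustrative sketch only; sizes are assumptions.
ψ = Chain(Dense(2 => 64, relu), Dense(64 => 64, relu))    # per-trial net: (choice, rt) → summary
ϕ = Chain(
    Dense(64 => 64, relu),
    Dense(64 => 6),                                       # one output per parameter
    Compress([ν_min, ν_min, α_min, β_min, λ_min, τ_min],  # squash estimates into
             [ν_max, ν_max, α_max, β_max, λ_max, τ_max])  # the prior bounds
)
return PointEstimator(DeepSet(ψ, ϕ))
```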
@@ -132,6 +138,8 @@ function create_neural_estimator(;
 end
 ```

+The constructed neural network is a point estimator that corresponds to a Bayes estimator, i.e., a functional of the posterior distribution. Under the specified loss function, this point estimate corresponds to the posterior mean. For the theoretical foundations of neural Bayes estimators, see Sainsbury-Dale et al. (2024).
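In symbols, a Bayes estimator minimizes the expected posterior loss, and under squared-error loss the minimizer is the posterior mean:

```math
\hat{\theta}(Z) = \underset{t}{\mathrm{argmin}}\; \mathbb{E}\left[L(\theta, t) \mid Z\right],
\qquad
L(\theta, t) = \lVert \theta - t \rVert^2 \;\Rightarrow\; \hat{\theta}(Z) = \mathbb{E}[\theta \mid Z].
```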
+
 ## Training the Neural Estimator

 Neural estimators, like all deep learning methods, require a training phase during which they learn the mapping from data to parameters. Here, we train the estimator by simulating data on the fly: the sampler provides new parameter vectors from the prior, and the simulator generates corresponding data conditional on those parameters. Since we use online training and the network never sees the same simulated dataset twice, overfitting is less likely. For more details on training, see the API for arguments [here](https://msainsburydale.github.io/NeuralEstimators.jl/dev/API/core/#Training).
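The full training call appears later in the script; as a hedged sketch, online training with NeuralEstimators.jl takes roughly this shape (the keyword values here are assumptions):

```julia
# Sketch of online training; m and epochs values are assumptions.
estimator = create_neural_estimator()
trained_estimator = train(
    estimator, sample, simulate;
    m = 100,      # trials simulated per parameter vector
    epochs = 50   # passes over freshly simulated data
)
```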
@@ -163,9 +171,9 @@ We can assess the performance of our trained estimator on held-out test data:

 ```julia
 # Generate test data
-n_test = 100
+n_test = 500
 θ_test = sample(n_test)
-Z_test = simulate(θ_test, 100)
+Z_test = simulate(θ_test, 500)

 # Assess the estimator
 parameter_names = ["ν1", "ν2", "α", "β", "λ", "τ"]
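The assessment call itself is elided by the hunk; in NeuralEstimators.jl it plausibly looks like the following, where `assessment.df` supplies the `parameter`, `truth`, and `estimate` columns used by the plotting loop below:

```julia
# Assumed continuation of the elided code, based on the NeuralEstimators.jl API.
assessment = assess(trained_estimator, θ_test, Z_test; parameter_names = parameter_names)
df = assessment.df   # long-format table: one row per parameter × test dataset
```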
@@ -197,17 +205,34 @@ p_plots = []

 for param in params
     param_data = filter(row -> row.parameter == param, df)
+
+    # Calculate correlation coefficient
+    truth = param_data.truth
+    estimate = param_data.estimate
+    correlation = cor(truth, estimate)
+
+    # Create plot
     p = scatter(
-        param_data.truth,
-        param_data.estimate,
+        truth,
+        estimate,
         xlabel="Ground Truth",
         ylabel="Estimated",
         title=param,
         legend=false
     )
-    plot!(p, [minimum(param_data.truth), maximum(param_data.truth)],
-        [minimum(param_data.truth), maximum(param_data.truth)],
+
+    # Add diagonal reference line
+    plot!(p, [minimum(truth), maximum(truth)],
+        [minimum(truth), maximum(truth)],
         line=:dash, color=:black)
+
+    # Get current axis limits after plot is created
+    x_min, x_max = xlims(p)
+    y_min, y_max = ylims(p)
+
+    # Position text at the top-left corner of the plot
+    annotate!(p, x_min + 0.1, y_max, text("R = $(round(correlation, digits=3))", :left, 10))
+
     push!(p_plots, p)
 end

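To view all six panels at once, the script presumably combines them; a sketch consistent with the later `display(p_combined)` call, with layout and size as assumptions:

```julia
# Assumed combination step; only display(p_combined) is visible in this commit.
p_combined = plot(p_plots..., layout = (2, 3), size = (900, 600))
display(p_combined)
```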
@@ -228,7 +253,6 @@ Once trained, the estimator can instantly recover parameters from new data via a
 β = 0.2
 λ = 0.1
 τ = 0.3
-σ = 1.0

 # Create model and generate data
 true_model = LCA(; ν, α, β, λ, τ)
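Dropping `σ = 1.0` is consistent with treating the diffusion noise as a fixed scaling constant (the package default) rather than a free parameter. A hedged sketch of the subsequent recovery step, whose exact code falls between hunks:

```julia
# Illustrative only; the data formatting and call mirror the training setup above.
choices, rts = rand(true_model, 1_000)        # synthetic "observed" data
Z_obs = Float32.([choices rts]')              # same layout the estimator was trained on
recovered_params = trained_estimator(Z_obs)   # one forward pass through the network
println("Recovered parameters: ", recovered_params)
```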
@@ -247,7 +271,7 @@ println("Recovered parameters: ", recovered_params)

 ## Notes on Performance

-Neural estimators are particularly effective for models with computationally intractable likelihoods like the LCA model. However, certain parameters (particularly β and λ) can be difficult to recover, even with advanced neural network architectures. This is a property of the LCA model rather than a limitation of the estimation technique.
+Neural estimators are particularly effective for models with computationally intractable likelihoods like the LCA model. However, certain parameters (particularly β and λ) can be difficult to recover, even with advanced neural network architectures. This is a property of the LCA model rather than a limitation of the estimation technique.

 Additional details can be found in the [NeuralEstimators.jl documentation](https://github.com/msainsburydale/NeuralEstimators).

@@ -296,7 +320,7 @@ function simulate(θ, n_trials_per_param)
         choices, rts = rand(model, n_trials_per_param)

         # Return as a transposed matrix where each column is a trial
-        return [choices rts]'
+        return Float32.([choices rts]')

     end

@@ -305,11 +329,11 @@ end

 # Create neural network architecture for parameter recovery
 function create_neural_estimator(;
-    ν_bounds = (0.1, 4.0),
-    α_bounds = (0.5, 3.5),
-    β_bounds = (0.0, 0.5),
-    λ_bounds = (0.0, 0.5),
-    τ_bounds = (0.1, 0.5)
+    ν_bounds = (0.1, 6.0),
+    α_bounds = (0.3, 4.5),
+    β_bounds = (0.0, 0.8),
+    λ_bounds = (0.0, 0.8),
+    τ_bounds = (0.1, 2.0)
 )
     # Unpack defined parameter bounds
     ν_min, ν_max = ν_bounds  # Drift rates
@@ -379,9 +403,9 @@ trained_estimator = train(
 )

 # Generate test data
-n_test = 100
+n_test = 500
 θ_test = sample(n_test)
-Z_test = simulate(θ_test, 100)
+Z_test = simulate(θ_test, 500)

 # Assess the estimator
 parameter_names = ["ν1", "ν2", "α", "β", "λ", "τ"]
@@ -407,17 +431,34 @@ p_plots = []

 for param in params
     param_data = filter(row -> row.parameter == param, df)
+
+    # Calculate correlation coefficient
+    truth = param_data.truth
+    estimate = param_data.estimate
+    correlation = cor(truth, estimate)
+
+    # Create plot
     p = scatter(
-        param_data.truth,
-        param_data.estimate,
+        truth,
+        estimate,
         xlabel="Ground Truth",
         ylabel="Estimated",
         title=param,
         legend=false
     )
-    plot!(p, [minimum(param_data.truth), maximum(param_data.truth)],
-        [minimum(param_data.truth), maximum(param_data.truth)],
+
+    # Add diagonal reference line
+    plot!(p, [minimum(truth), maximum(truth)],
+        [minimum(truth), maximum(truth)],
         line=:dash, color=:black)
+
+    # Get current axis limits after plot is created
+    x_min, x_max = xlims(p)
+    y_min, y_max = ylims(p)
+
+    # Position text at the top-left corner of the plot
+    annotate!(p, x_min + 0.1, y_max, text("R = $(round(correlation, digits=3))", :left, 10))
+
     push!(p_plots, p)
 end

@@ -431,7 +472,6 @@ display(p_combined)
 β = 0.2
 λ = 0.1
 τ = 0.3
-σ = 1.0

 # Create model and generate data
 true_model = LCA(; ν, α, β, λ, τ)
