Skip to content

Commit 6f4ac26

Browse files
remove bootstrap try-catch and update tests
1 parent b4cc3a3 commit 6f4ac26

3 files changed

Lines changed: 43 additions & 76 deletions

File tree

src/frontend/fit/standard_errors/bootstrap.jl

Lines changed: 41 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -55,28 +55,21 @@ function bootstrap(
5555
# pre-allocations
5656
out = []
5757
conv = []
58-
errors = []
59-
n_failed = Ref(0)
6058
# fit to bootstrap samples
6159
if !parallel
6260
for _ in 1:n_boot
63-
try
64-
sample_data = bootstrap_sample(data)
65-
new_model = replace_observed(
66-
fitted.model;
67-
data = sample_data,
68-
specification = specification,
69-
replace_kwargs...,
70-
)
71-
new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...)
72-
sample = statistic(new_fit)
73-
c = converged(new_fit)
74-
push!(out, sample)
75-
push!(conv, c)
76-
catch e
77-
n_failed[] += 1
78-
push!(errors, e)
79-
end
61+
sample_data = bootstrap_sample(data)
62+
new_model = replace_observed(
63+
fitted.model;
64+
data = sample_data,
65+
specification = specification,
66+
replace_kwargs...,
67+
)
68+
new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...)
69+
sample = statistic(new_fit)
70+
c = converged(new_fit)
71+
push!(out, sample)
72+
push!(conv, c)
8073
end
8174
else
8275
n_threads = Threads.nthreads()
@@ -89,42 +82,28 @@ function bootstrap(
8982
lk = ReentrantLock()
9083
Threads.@threads for _ in 1:n_boot
9184
thread_model = take!(model_pool)
92-
try
93-
sample_data = bootstrap_sample(data)
94-
new_model = replace_observed(
95-
thread_model;
96-
data = sample_data,
97-
specification = specification,
98-
replace_kwargs...,
99-
)
100-
new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...)
101-
sample = statistic(new_fit)
102-
c = converged(new_fit)
103-
lock(lk) do
104-
push!(out, sample)
105-
push!(conv, c)
106-
end
107-
catch e
108-
lock(lk) do
109-
n_failed[] += 1
110-
push!(errors, e)
111-
end
112-
finally
113-
put!(model_pool, thread_model)
85+
sample_data = bootstrap_sample(data)
86+
new_model = replace_observed(
87+
thread_model;
88+
data = sample_data,
89+
specification = specification,
90+
replace_kwargs...,
91+
)
92+
new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...)
93+
sample = statistic(new_fit)
94+
c = converged(new_fit)
95+
lock(lk) do
96+
push!(out, sample)
97+
push!(conv, c)
11498
end
99+
put!(model_pool, thread_model)
115100
end
116101
end
117-
# compute parameters
118-
if !iszero(n_failed[])
119-
@warn "During bootstrap sampling, "*string(n_failed[])*" samples errored."
120-
end
121102
return Dict(
122103
:samples => out,
123104
:n_boot => n_boot,
124105
:n_converged => isempty(conv) ? 0 : sum(conv),
125106
:converged => conv,
126-
:n_errored => n_failed[],
127-
:errors => errors
128107
)
129108
end
130109

@@ -181,8 +160,6 @@ function se_bootstrap(
181160
# pre-allocations
182161
total_sum = zero(start)
183162
total_squared_sum = zero(start)
184-
n_failed = Ref(0)
185-
n_conv = Ref(0)
186163
# fit to bootstrap samples
187164
if !parallel
188165
for _ in 1:n_boot
@@ -217,39 +194,29 @@ function se_bootstrap(
217194
lk = ReentrantLock()
218195
Threads.@threads for _ in 1:n_boot
219196
thread_model = take!(model_pool)
220-
try
221-
sample_data = bootstrap_sample(data)
222-
new_model = replace_observed(
223-
thread_model;
224-
data = sample_data,
225-
specification = specification,
226-
replace_kwargs...,
227-
)
228-
new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...)
229-
sol = solution(new_fit)
230-
conv = converged(new_fit)
231-
if conv
232-
lock(lk) do
233-
n_conv[] += 1
234-
@. total_sum += sol
235-
@. total_squared_sum += sol^2
236-
end
237-
end
238-
catch
197+
sample_data = bootstrap_sample(data)
198+
new_model = replace_observed(
199+
thread_model;
200+
data = sample_data,
201+
specification = specification,
202+
replace_kwargs...,
203+
)
204+
new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...)
205+
sol = solution(new_fit)
206+
conv = converged(new_fit)
207+
if conv
239208
lock(lk) do
240-
n_failed[] += 1
209+
n_conv[] += 1
210+
@. total_sum += sol
211+
@. total_squared_sum += sol^2
241212
end
242-
finally
243-
put!(model_pool, thread_model)
244213
end
214+
put!(model_pool, thread_model)
245215
end
246216
end
247217
# compute parameters
248218
n_conv = n_conv[]
249219
sd = sqrt.(total_squared_sum / n_conv - (total_sum / n_conv) .^ 2)
250-
if !iszero(n_failed[])
251-
@warn "During bootstrap sampling, "*string(n_failed[])*" samples errored"
252-
end
253220
@info string(n_conv)*" models converged"
254221
return sd
255222
end

test/examples/helper.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ function test_bootstrap(
153153
# se_bootstrap and bootstrap |> se are close
154154
if compare_bs
155155
bs_samples = bootstrap(model_fit, spec; n_boot = n_boot)
156-
@test bs_samples[:n_converged] > 0.95*n_boot
156+
@test bs_samples[:n_converged] >= 0.95*n_boot
157157
bs_samples = cat(bs_samples[:samples][BitVector(bs_samples[:converged])]..., dims = 2)
158158
se_bs_2 = sqrt.(var(bs_samples, corrected = false, dims = 2))
159159
@test isapprox(se_bs_2, se_bs, rtol = rtol_bs)

test/examples/multigroup/build_models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ model_g2 = Sem(specification = specification_g2, data = dat_g2, implied = RAM)
1212
SEM.param_labels(model_g2.implied.ram_matrices)
1313

1414
# test the different constructors
15-
model_ml_multigroup = SemEnsemble(model_g1, model_g2)
15+
model_ml_multigroup = SemEnsemble(model_g1, model_g2; groups = [:Pasteur, :Grant_White])
1616
model_ml_multigroup2 = SemEnsemble(
1717
specification = partable,
1818
data = dat,

0 commit comments

Comments
 (0)