Skip to content

Commit e3ec0cf

Browse files
author
Alexey Stukalov
committed
bootstrap: sync with Sem updates
1 parent cbca143 commit e3ec0cf

1 file changed

Lines changed: 54 additions & 98 deletions

File tree

Lines changed: 54 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,127 +1,83 @@
11
"""
2-
se_bootstrap(sem_fit::SemFit; n_boot = 3000, data = nothing, kwargs...)
2+
se_bootstrap(semfit::SemFit, n_boot::Integer = 3000, data = nothing)
33
4-
Return boorstrap standard errors.
4+
Return bootstrap standard errors.
5+
6+
Supports both single-group and multi-group models.
7+
For multi-group models, each group is resampled independently.
58
69
# Arguments
7-
- `n_boot`: number of boostrap samples
8-
- `data`: data to sample from. Only needed if different than the data from `sem_fit`
9-
- `kwargs...`: passed down to `replace_observed`
10+
- `n_boot`: number of bootstrap samples
11+
- `data`: optional new data to sample from; for multi-group models,
12+
pass a `Dict{Symbol}` mapping term ids to data matrices
1013
"""
1114
function se_bootstrap(
12-
semfit::SemFit{Mi, So, St, Mo, O};
13-
n_boot = 3000,
15+
semfit::SemFit,
16+
n_boot::Integer = 3000,
1417
data = nothing,
15-
specification = nothing,
16-
kwargs...,
17-
) where {Mi, So, St, Mo <: AbstractSemSingle, O}
18+
)
19+
sem = model(semfit)
20+
semterms = SEM.sem_terms(sem)
21+
1822
if isnothing(data)
19-
data = samples(observed(model(semfit)))
23+
if length(semterms) == 1
24+
# single-group: extract data matrix
25+
data = samples(observed(loss(semterms[1])))
26+
else
27+
# multi-group: extract per-term data dict
28+
data = Dict{Symbol, AbstractMatrix}(
29+
id(term) => samples(observed(loss(term))) for term in semterms
30+
)
31+
end
2032
end
2133

22-
data = prepare_data_bootstrap(data)
23-
24-
start = solution(semfit)
34+
fit_params = solution(semfit)
2535

26-
new_solution = zero(start)
27-
sum = zero(start)
28-
squared_sum = zero(start)
36+
# accumulators for the sums of bootstrapped params and of their squares
37+
params_sum = zero(fit_params)
38+
params_squared = zero(fit_params)
2939

30-
n_failed = 0.0
31-
32-
converged = true
40+
n_conv = 0
41+
n_failed = 0
3342

3443
for _ in 1:n_boot
35-
sample_data = bootstrap_sample(data)
36-
new_model = replace_observed(
37-
model(semfit);
38-
data = sample_data,
39-
specification = specification,
40-
kwargs...,
41-
)
42-
43-
new_solution .= 0.0
44+
boot_data = resample_with_replacement(data)
45+
boot_model = replace_observed(sem, boot_data)
4446

4547
try
46-
new_solution = solution(fit(new_model; start_val = start))
48+
params_boot = solution(fit(boot_model; start_val = fit_params))
49+
params_sum .+= params_boot
50+
params_squared .+= params_boot .^ 2
51+
n_conv += 1
4752
catch
4853
n_failed += 1
4954
end
50-
51-
@. sum += new_solution
52-
@. squared_sum += new_solution^2
53-
54-
converged = true
5555
end
5656

57-
n_conv = n_boot - n_failed
58-
sd = sqrt.(squared_sum / n_conv - (sum / n_conv) .^ 2)
59-
print("Number of nonconverged models: ", n_failed, "\n")
60-
return sd
61-
end
62-
63-
function se_bootstrap(
64-
semfit::SemFit{Mi, So, St, Mo, O};
65-
n_boot = 3000,
66-
data = nothing,
67-
specification = nothing,
68-
kwargs...,
69-
) where {Mi, So, St, Mo <: SemEnsemble, O}
70-
models = semfit.model.sems
71-
groups = semfit.model.groups
72-
73-
if isnothing(data)
74-
data = Dict(g => samples(observed(m)) for (g, m) in zip(groups, models))
57+
if n_failed > 0
58+
@warn "$n_failed bootstrap attempts did not converge"
7559
end
7660

77-
data = Dict(k => prepare_data_bootstrap(data[k]) for k in keys(data))
78-
79-
start = solution(semfit)
80-
81-
new_solution = zero(start)
82-
sum = zero(start)
83-
squared_sum = zero(start)
84-
85-
n_failed = 0.0
86-
87-
converged = true
88-
89-
for _ in 1:n_boot
90-
sample_data = Dict(k => bootstrap_sample(data[k]) for k in keys(data))
91-
new_model = replace_observed(
92-
semfit.model;
93-
data = sample_data,
94-
specification = specification,
95-
kwargs...,
96-
)
97-
98-
new_solution .= 0.0
99-
100-
try
101-
new_solution = solution(fit(new_model; start_val = start))
102-
catch
103-
n_failed += 1
104-
end
105-
106-
@. sum += new_solution
107-
@. squared_sum += new_solution^2
61+
# calculate element-wise standard errors
62+
return sqrt.(params_squared ./ n_conv .- (params_sum ./ n_conv) .^ 2)
63+
end
10864

109-
converged = true
110-
end
65+
"""
66+
resample_with_replacement(data::AbstractMatrix)
67+
resample_with_replacement(data::AbstractDict)
11168
112-
n_conv = n_boot - n_failed
113-
sd = sqrt.(squared_sum / n_conv - (sum / n_conv) .^ 2)
114-
print("Number of nonconverged models: ", n_failed, "\n")
115-
return sd
116-
end
69+
Resample rows of a data matrix with replacement (bootstrap sample).
11770
118-
function prepare_data_bootstrap(data)
119-
return Matrix(data)
71+
If a dictionary of matrices is passed (for multi-group models),
72+
independently resamples each matrix.
73+
"""
74+
function resample_with_replacement(data::AbstractMatrix)
75+
n = size(data, 1)
76+
return data[rand(1:n, n), :]
12077
end
12178

122-
function bootstrap_sample(data)
123-
nobs = size(data, 1)
124-
index_new = rand(1:nobs, nobs)
125-
data_new = data[index_new, :]
126-
return data_new
79+
function resample_with_replacement(data::AbstractDict)
80+
return typeof(data)(
81+
k => resample_with_replacement(v) for (k, v) in data
82+
)
12783
end

0 commit comments

Comments
 (0)