|
1 | 1 | """ |
2 | | - se_bootstrap(sem_fit::SemFit; n_boot = 3000, data = nothing, kwargs...) |
| 2 | + se_bootstrap(semfit::SemFit, n_boot::Integer = 3000, data = nothing) |
3 | 3 |
|
4 | | -Return boorstrap standard errors. |
| 4 | +Return bootstrap standard errors. |
| 5 | +
|
| 6 | +Supports both single-group and multi-group models. |
| 7 | +For multi-group models, each group is resampled independently. |
5 | 8 |
|
6 | 9 | # Arguments |
7 | | -- `n_boot`: number of boostrap samples |
8 | | -- `data`: data to sample from. Only needed if different than the data from `sem_fit` |
9 | | -- `kwargs...`: passed down to `replace_observed` |
| 10 | +- `n_boot`: number of bootstrap samples |
| 11 | +- `data`: optional new data to sample from; for multi-group models, |
| 12 | + pass a `Dict{Symbol}` mapping term ids to data matrices |
10 | 13 | """ |
11 | 14 | function se_bootstrap( |
12 | | - semfit::SemFit{Mi, So, St, Mo, O}; |
13 | | - n_boot = 3000, |
| 15 | + semfit::SemFit, |
| 16 | + n_boot::Integer = 3000, |
14 | 17 | data = nothing, |
15 | | - specification = nothing, |
16 | | - kwargs..., |
17 | | -) where {Mi, So, St, Mo <: AbstractSemSingle, O} |
| 18 | +) |
| 19 | + sem = model(semfit) |
| 20 | + semterms = SEM.sem_terms(sem) |
| 21 | + |
18 | 22 | if isnothing(data) |
19 | | - data = samples(observed(model(semfit))) |
| 23 | + if length(semterms) == 1 |
| 24 | + # single-group: extract data matrix |
| 25 | + data = samples(observed(loss(semterms[1]))) |
| 26 | + else |
| 27 | + # multi-group: extract per-term data dict |
| 28 | + data = Dict{Symbol, AbstractMatrix}( |
| 29 | + id(term) => samples(observed(loss(term))) for term in semterms |
| 30 | + ) |
| 31 | + end |
20 | 32 | end |
21 | 33 |
|
22 | | - data = prepare_data_bootstrap(data) |
23 | | - |
24 | | - start = solution(semfit) |
| 34 | + fit_params = solution(semfit) |
25 | 35 |
|
26 | | - new_solution = zero(start) |
27 | | - sum = zero(start) |
28 | | - squared_sum = zero(start) |
| 36 | + # accumulator of bootstrapped params and their squares |
| 37 | + params_sum = zero(fit_params) |
| 38 | + params_squared = zero(fit_params) |
29 | 39 |
|
30 | | - n_failed = 0.0 |
31 | | - |
32 | | - converged = true |
| 40 | + n_conv = 0 |
| 41 | + n_failed = 0 |
33 | 42 |
|
34 | 43 | for _ in 1:n_boot |
35 | | - sample_data = bootstrap_sample(data) |
36 | | - new_model = replace_observed( |
37 | | - model(semfit); |
38 | | - data = sample_data, |
39 | | - specification = specification, |
40 | | - kwargs..., |
41 | | - ) |
42 | | - |
43 | | - new_solution .= 0.0 |
| 44 | + boot_data = resample_with_replacement(data) |
| 45 | + boot_model = replace_observed(sem, boot_data) |
44 | 46 |
|
45 | 47 | try |
46 | | - new_solution = solution(fit(new_model; start_val = start)) |
| 48 | + params_boot = solution(fit(boot_model; start_val = fit_params)) |
| 49 | + params_sum .+= params_boot |
| 50 | + params_squared .+= params_boot .^ 2 |
| 51 | + n_conv += 1 |
47 | 52 | catch |
48 | 53 | n_failed += 1 |
49 | 54 | end |
50 | | - |
51 | | - @. sum += new_solution |
52 | | - @. squared_sum += new_solution^2 |
53 | | - |
54 | | - converged = true |
55 | 55 | end |
56 | 56 |
|
57 | | - n_conv = n_boot - n_failed |
58 | | - sd = sqrt.(squared_sum / n_conv - (sum / n_conv) .^ 2) |
59 | | - print("Number of nonconverged models: ", n_failed, "\n") |
60 | | - return sd |
61 | | -end |
62 | | - |
63 | | -function se_bootstrap( |
64 | | - semfit::SemFit{Mi, So, St, Mo, O}; |
65 | | - n_boot = 3000, |
66 | | - data = nothing, |
67 | | - specification = nothing, |
68 | | - kwargs..., |
69 | | -) where {Mi, So, St, Mo <: SemEnsemble, O} |
70 | | - models = semfit.model.sems |
71 | | - groups = semfit.model.groups |
72 | | - |
73 | | - if isnothing(data) |
74 | | - data = Dict(g => samples(observed(m)) for (g, m) in zip(groups, models)) |
| 57 | + if n_failed > 0 |
| 58 | + @warn "$n_failed bootstrap attempts did not converge" |
75 | 59 | end |
76 | 60 |
|
77 | | - data = Dict(k => prepare_data_bootstrap(data[k]) for k in keys(data)) |
78 | | - |
79 | | - start = solution(semfit) |
80 | | - |
81 | | - new_solution = zero(start) |
82 | | - sum = zero(start) |
83 | | - squared_sum = zero(start) |
84 | | - |
85 | | - n_failed = 0.0 |
86 | | - |
87 | | - converged = true |
88 | | - |
89 | | - for _ in 1:n_boot |
90 | | - sample_data = Dict(k => bootstrap_sample(data[k]) for k in keys(data)) |
91 | | - new_model = replace_observed( |
92 | | - semfit.model; |
93 | | - data = sample_data, |
94 | | - specification = specification, |
95 | | - kwargs..., |
96 | | - ) |
97 | | - |
98 | | - new_solution .= 0.0 |
99 | | - |
100 | | - try |
101 | | - new_solution = solution(fit(new_model; start_val = start)) |
102 | | - catch |
103 | | - n_failed += 1 |
104 | | - end |
105 | | - |
106 | | - @. sum += new_solution |
107 | | - @. squared_sum += new_solution^2 |
| 61 | + # calculate element-wise standard errors |
| 62 | + return sqrt.(params_squared ./ n_conv .- (params_sum ./ n_conv) .^ 2) |
| 63 | +end |
108 | 64 |
|
109 | | - converged = true |
110 | | - end |
| 65 | +""" |
| 66 | + resample_with_replacement(data::AbstractMatrix) |
| 67 | + resample_with_replacement(data::AbstractDict) |
111 | 68 |
|
112 | | - n_conv = n_boot - n_failed |
113 | | - sd = sqrt.(squared_sum / n_conv - (sum / n_conv) .^ 2) |
114 | | - print("Number of nonconverged models: ", n_failed, "\n") |
115 | | - return sd |
116 | | -end |
| 69 | +Resample rows of a data matrix with replacement (bootstrap sample). |
117 | 70 |
|
118 | | -function prepare_data_bootstrap(data) |
119 | | - return Matrix(data) |
| 71 | +If the dictionary of matrices is passed (for multi-group models), |
| 72 | +independently resamples each matrix. |
| 73 | +""" |
| 74 | +function resample_with_replacement(data::AbstractMatrix) |
| 75 | + n = size(data, 1) |
| 76 | + return data[rand(1:n, n), :] |
120 | 77 | end |
121 | 78 |
|
122 | | -function bootstrap_sample(data) |
123 | | - nobs = size(data, 1) |
124 | | - index_new = rand(1:nobs, nobs) |
125 | | - data_new = data[index_new, :] |
126 | | - return data_new |
| 79 | +function resample_with_replacement(data::AbstractDict) |
| 80 | + return typeof(data)( |
| 81 | + k => resample_with_replacement(v) for (k, v) in data |
| 82 | + ) |
127 | 83 | end |
0 commit comments