|
| 1 | +using LinearAlgebra: norm |
| 2 | + |
1 | 3 | function test_gradient(model, params; rtol = 1e-10, atol = 0) |
2 | 4 | true_grad = FiniteDiff.finite_difference_gradient(Base.Fix1(objective!, model), params) |
3 | 5 | gradient = similar(params) |
@@ -71,262 +73,63 @@ function test_fitmeasures( |
71 | 73 | end |
72 | 74 | end |
73 | 75 |
|
74 | | -function compare_estimates( |
| 76 | +function test_estimates( |
75 | 77 | partable::ParameterTable, |
76 | 78 | partable_lav; |
77 | 79 | rtol = 1e-10, |
78 | 80 | atol = 0, |
79 | 81 | col = :estimate, |
80 | 82 | lav_col = :est, |
| 83 | + lav_group = nothing, |
| 84 | + skip::Bool = false, |
81 | 85 | ) |
82 | | - correct = [] |
83 | | - |
84 | | - for i in findall(partable.columns[:free]) |
85 | | - from = partable.columns[:from][i] |
86 | | - to = partable.columns[:to][i] |
87 | | - type = partable.columns[:parameter_type][i] |
88 | | - estimate = partable.columns[col][i] |
89 | | - |
90 | | - if from == Symbol("1") |
91 | | - lav_ind = |
92 | | - findall((partable_lav.lhs .== String(to)) .& (partable_lav.op .== "~1")) |
93 | | - |
94 | | - if length(lav_ind) == 0 |
95 | | - throw( |
96 | | - ErrorException( |
97 | | - "Parameter from: $from, to: $to, type: $type, could not be found in the lavaan solution", |
98 | | - ), |
99 | | - ) |
100 | | - elseif length(lav_ind) > 1 |
101 | | - throw( |
102 | | - ErrorException( |
103 | | - "At least one parameter was found twice in the lavaan solution", |
104 | | - ), |
105 | | - ) |
106 | | - else |
107 | | - is_correct = isapprox( |
108 | | - estimate, |
109 | | - partable_lav[:, lav_col][lav_ind[1]]; |
110 | | - rtol = rtol, |
111 | | - atol = atol, |
112 | | - ) |
113 | | - push!(correct, is_correct) |
114 | | - end |
115 | | - |
116 | | - else |
117 | | - if type == :↔ |
118 | | - type = "~~" |
119 | | - elseif type == :→ |
120 | | - if (from ∈ partable.latent_vars) && (to ∈ partable.observed_vars) |
121 | | - type = "=~" |
122 | | - else |
123 | | - type = "~" |
124 | | - from, to = to, from |
125 | | - end |
126 | | - end |
127 | | - |
128 | | - if type == "~~" |
129 | | - lav_ind = findall( |
130 | | - ( |
131 | | - ( |
132 | | - (partable_lav.lhs .== String(from)) .& |
133 | | - (partable_lav.rhs .== String(to)) |
134 | | - ) .| ( |
135 | | - (partable_lav.lhs .== String(to)) .& |
136 | | - (partable_lav.rhs .== String(from)) |
137 | | - ) |
138 | | - ) .& (partable_lav.op .== type), |
139 | | - ) |
140 | | - |
141 | | - if length(lav_ind) == 0 |
142 | | - throw( |
143 | | - ErrorException( |
144 | | - "Parameter from: $from, to: $to, type: $type, could not be found in the lavaan solution", |
145 | | - ), |
146 | | - ) |
147 | | - elseif length(lav_ind) > 1 |
148 | | - throw( |
149 | | - ErrorException( |
150 | | - "At least one parameter was found twice in the lavaan solution", |
151 | | - ), |
152 | | - ) |
153 | | - else |
154 | | - is_correct = isapprox( |
155 | | - estimate, |
156 | | - partable_lav[:, lav_col][lav_ind[1]]; |
157 | | - rtol = rtol, |
158 | | - atol = atol, |
159 | | - ) |
160 | | - push!(correct, is_correct) |
161 | | - end |
162 | | - |
163 | | - else |
164 | | - lav_ind = findall( |
165 | | - (partable_lav.lhs .== String(from)) .& |
166 | | - (partable_lav.rhs .== String(to)) .& |
167 | | - (partable_lav.op .== type), |
168 | | - ) |
169 | | - |
170 | | - if length(lav_ind) == 0 |
171 | | - throw( |
172 | | - ErrorException( |
173 | | - "Parameter from: $from, to: $to, type: $type, could not be found in the lavaan solution", |
174 | | - ), |
175 | | - ) |
176 | | - elseif length(lav_ind) > 1 |
177 | | - throw( |
178 | | - ErrorException( |
179 | | - "At least one parameter was found twice in the lavaan solution", |
180 | | - ), |
181 | | - ) |
182 | | - else |
183 | | - is_correct = isapprox( |
184 | | - estimate, |
185 | | - partable_lav[:, lav_col][lav_ind[1]]; |
186 | | - rtol = rtol, |
187 | | - atol = atol, |
188 | | - ) |
189 | | - push!(correct, is_correct) |
190 | | - end |
191 | | - end |
192 | | - end |
| 86 | + actual = StructuralEquationModels.param_values(partable, col) |
| 87 | + expected = StructuralEquationModels.lavaan_param_values( |
| 88 | + partable_lav, |
| 89 | + partable, |
| 90 | + lav_col, |
| 91 | + lav_group, |
| 92 | + ) |
| 93 | + @test !any(isnan, actual) |
| 94 | + @test !any(isnan, expected) |
| 95 | + |
| 96 | + if skip # workaround skip=false not supported in earlier versions |
| 97 | + @test actual ≈ expected rtol = rtol atol = atol norm = Base.Fix2(norm, Inf) skip = |
| 98 | + skip |
| 99 | + else |
| 100 | + @test actual ≈ expected rtol = rtol atol = atol norm = Base.Fix2(norm, Inf) |
193 | 101 | end |
194 | | - |
195 | | - return all(correct) |
196 | 102 | end |
197 | 103 |
|
198 | | -function compare_estimates( |
| 104 | +function test_estimates( |
199 | 105 | ens_partable::EnsembleParameterTable, |
200 | 106 | partable_lav; |
201 | 107 | rtol = 1e-10, |
202 | 108 | atol = 0, |
203 | 109 | col = :estimate, |
204 | 110 | lav_col = :est, |
205 | | - lav_groups, |
| 111 | + lav_groups::AbstractDict, |
| 112 | + skip::Bool = false, |
206 | 113 | ) |
207 | | - correct = [] |
208 | | - |
209 | | - for key in keys(ens_partable.tables) |
210 | | - group = lav_groups[key] |
211 | | - partable = ens_partable.tables[key] |
212 | | - |
213 | | - for i in findall(partable.columns[:free]) |
214 | | - from = partable.columns[:from][i] |
215 | | - to = partable.columns[:to][i] |
216 | | - type = partable.columns[:parameter_type][i] |
217 | | - estimate = partable.columns[col][i] |
218 | | - |
219 | | - if from == Symbol("1") |
220 | | - lav_ind = findall( |
221 | | - (partable_lav.lhs .== String(to)) .& |
222 | | - (partable_lav.op .== "~1") .& |
223 | | - (partable_lav.group .== group), |
224 | | - ) |
225 | | - |
226 | | - if length(lav_ind) == 0 |
227 | | - throw( |
228 | | - ErrorException( |
229 | | - "Mean parameter of variable $to could not be found in the lavaan solution", |
230 | | - ), |
231 | | - ) |
232 | | - elseif length(lav_ind) > 1 |
233 | | - throw( |
234 | | - ErrorException( |
235 | | - "At least one parameter was found twice in the lavaan solution", |
236 | | - ), |
237 | | - ) |
238 | | - else |
239 | | - is_correct = isapprox( |
240 | | - estimate, |
241 | | - partable_lav[:, lav_col][lav_ind[1]]; |
242 | | - rtol = rtol, |
243 | | - atol = atol, |
244 | | - ) |
245 | | - push!(correct, is_correct) |
246 | | - end |
247 | | - |
248 | | - else |
249 | | - if type == :↔ |
250 | | - type = "~~" |
251 | | - elseif type == :→ |
252 | | - if (from ∈ partable.latent_vars) && (to ∈ partable.observed_vars) |
253 | | - type = "=~" |
254 | | - else |
255 | | - type = "~" |
256 | | - from, to = to, from |
257 | | - end |
258 | | - end |
259 | | - |
260 | | - if type == "~~" |
261 | | - lav_ind = findall( |
262 | | - ( |
263 | | - ( |
264 | | - (partable_lav.lhs .== String(from)) .& |
265 | | - (partable_lav.rhs .== String(to)) |
266 | | - ) .| ( |
267 | | - (partable_lav.lhs .== String(to)) .& |
268 | | - (partable_lav.rhs .== String(from)) |
269 | | - ) |
270 | | - ) .& |
271 | | - (partable_lav.op .== type) .& |
272 | | - (partable_lav.group .== group), |
273 | | - ) |
274 | | - |
275 | | - if length(lav_ind) == 0 |
276 | | - throw( |
277 | | - ErrorException( |
278 | | - "Parameter from: $from, to: $to, type: $type, could not be found in the lavaan solution", |
279 | | - ), |
280 | | - ) |
281 | | - elseif length(lav_ind) > 1 |
282 | | - throw( |
283 | | - ErrorException( |
284 | | - "At least one parameter was found twice in the lavaan solution", |
285 | | - ), |
286 | | - ) |
287 | | - else |
288 | | - is_correct = isapprox( |
289 | | - estimate, |
290 | | - partable_lav[:, lav_col][lav_ind[1]]; |
291 | | - rtol = rtol, |
292 | | - atol = atol, |
293 | | - ) |
294 | | - push!(correct, is_correct) |
295 | | - end |
296 | | - |
297 | | - else |
298 | | - lav_ind = findall( |
299 | | - (partable_lav.lhs .== String(from)) .& |
300 | | - (partable_lav.rhs .== String(to)) .& |
301 | | - (partable_lav.op .== type) .& |
302 | | - (partable_lav.group .== group), |
303 | | - ) |
304 | | - |
305 | | - if length(lav_ind) == 0 |
306 | | - throw( |
307 | | - ErrorException( |
308 | | - "Parameter $from $type $to could not be found in the lavaan solution", |
309 | | - ), |
310 | | - ) |
311 | | - elseif length(lav_ind) > 1 |
312 | | - throw( |
313 | | - ErrorException( |
314 | | - "At least one parameter was found twice in the lavaan solution", |
315 | | - ), |
316 | | - ) |
317 | | - else |
318 | | - is_correct = isapprox( |
319 | | - estimate, |
320 | | - partable_lav[:, lav_col][lav_ind[1]]; |
321 | | - rtol = rtol, |
322 | | - atol = atol, |
323 | | - ) |
324 | | - push!(correct, is_correct) |
325 | | - end |
326 | | - end |
327 | | - end |
328 | | - end |
| 114 | + actual = fill(NaN, nparams(ens_partable)) |
| 115 | + expected = fill(NaN, nparams(ens_partable)) |
| 116 | + for (key, partable) in pairs(ens_partable.tables) |
| 117 | + StructuralEquationModels.param_values!(actual, partable, col) |
| 118 | + StructuralEquationModels.lavaan_param_values!( |
| 119 | + expected, |
| 120 | + partable_lav, |
| 121 | + partable, |
| 122 | + lav_col, |
| 123 | + lav_groups[key], |
| 124 | + ) |
| 125 | + end |
| 126 | + @test !any(isnan, actual) |
| 127 | + @test !any(isnan, expected) |
| 128 | + |
| 129 | + if skip # workaround skip=false not supported in earlier versions |
| 130 | + @test actual ≈ expected rtol = rtol atol = atol norm = Base.Fix2(norm, Inf) skip = |
| 131 | + skip |
| 132 | + else |
| 133 | + @test actual ≈ expected rtol = rtol atol = atol norm = Base.Fix2(norm, Inf) |
329 | 134 | end |
330 | | - |
331 | | - return all(correct) |
332 | 135 | end |
0 commit comments