Skip to content

Commit 19b16d4

Browse files
committed
compare_estimates() -> test_estimates()
* do tests inside * use param_values()/lavaan_param_values()
1 parent d0b0294 commit 19b16d4

5 files changed

Lines changed: 79 additions & 298 deletions

File tree

test/examples/helper.jl

Lines changed: 43 additions & 240 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using LinearAlgebra: norm
2+
13
function test_gradient(model, params; rtol = 1e-10, atol = 0)
24
true_grad = FiniteDiff.finite_difference_gradient(Base.Fix1(objective!, model), params)
35
gradient = similar(params)
@@ -71,262 +73,63 @@ function test_fitmeasures(
7173
end
7274
end
7375

74-
function compare_estimates(
76+
function test_estimates(
7577
partable::ParameterTable,
7678
partable_lav;
7779
rtol = 1e-10,
7880
atol = 0,
7981
col = :estimate,
8082
lav_col = :est,
83+
lav_group = nothing,
84+
skip::Bool = false,
8185
)
82-
correct = []
83-
84-
for i in findall(partable.columns[:free])
85-
from = partable.columns[:from][i]
86-
to = partable.columns[:to][i]
87-
type = partable.columns[:parameter_type][i]
88-
estimate = partable.columns[col][i]
89-
90-
if from == Symbol("1")
91-
lav_ind =
92-
findall((partable_lav.lhs .== String(to)) .& (partable_lav.op .== "~1"))
93-
94-
if length(lav_ind) == 0
95-
throw(
96-
ErrorException(
97-
"Parameter from: $from, to: $to, type: $type, could not be found in the lavaan solution",
98-
),
99-
)
100-
elseif length(lav_ind) > 1
101-
throw(
102-
ErrorException(
103-
"At least one parameter was found twice in the lavaan solution",
104-
),
105-
)
106-
else
107-
is_correct = isapprox(
108-
estimate,
109-
partable_lav[:, lav_col][lav_ind[1]];
110-
rtol = rtol,
111-
atol = atol,
112-
)
113-
push!(correct, is_correct)
114-
end
115-
116-
else
117-
if type == :
118-
type = "~~"
119-
elseif type == :
120-
if (from partable.latent_vars) && (to partable.observed_vars)
121-
type = "=~"
122-
else
123-
type = "~"
124-
from, to = to, from
125-
end
126-
end
127-
128-
if type == "~~"
129-
lav_ind = findall(
130-
(
131-
(
132-
(partable_lav.lhs .== String(from)) .&
133-
(partable_lav.rhs .== String(to))
134-
) .| (
135-
(partable_lav.lhs .== String(to)) .&
136-
(partable_lav.rhs .== String(from))
137-
)
138-
) .& (partable_lav.op .== type),
139-
)
140-
141-
if length(lav_ind) == 0
142-
throw(
143-
ErrorException(
144-
"Parameter from: $from, to: $to, type: $type, could not be found in the lavaan solution",
145-
),
146-
)
147-
elseif length(lav_ind) > 1
148-
throw(
149-
ErrorException(
150-
"At least one parameter was found twice in the lavaan solution",
151-
),
152-
)
153-
else
154-
is_correct = isapprox(
155-
estimate,
156-
partable_lav[:, lav_col][lav_ind[1]];
157-
rtol = rtol,
158-
atol = atol,
159-
)
160-
push!(correct, is_correct)
161-
end
162-
163-
else
164-
lav_ind = findall(
165-
(partable_lav.lhs .== String(from)) .&
166-
(partable_lav.rhs .== String(to)) .&
167-
(partable_lav.op .== type),
168-
)
169-
170-
if length(lav_ind) == 0
171-
throw(
172-
ErrorException(
173-
"Parameter from: $from, to: $to, type: $type, could not be found in the lavaan solution",
174-
),
175-
)
176-
elseif length(lav_ind) > 1
177-
throw(
178-
ErrorException(
179-
"At least one parameter was found twice in the lavaan solution",
180-
),
181-
)
182-
else
183-
is_correct = isapprox(
184-
estimate,
185-
partable_lav[:, lav_col][lav_ind[1]];
186-
rtol = rtol,
187-
atol = atol,
188-
)
189-
push!(correct, is_correct)
190-
end
191-
end
192-
end
86+
actual = StructuralEquationModels.param_values(partable, col)
87+
expected = StructuralEquationModels.lavaan_param_values(
88+
partable_lav,
89+
partable,
90+
lav_col,
91+
lav_group,
92+
)
93+
@test !any(isnan, actual)
94+
@test !any(isnan, expected)
95+
96+
if skip # workaround skip=false not supported in earlier versions
97+
@test actual expected rtol = rtol atol = atol norm = Base.Fix2(norm, Inf) skip =
98+
skip
99+
else
100+
@test actual expected rtol = rtol atol = atol norm = Base.Fix2(norm, Inf)
193101
end
194-
195-
return all(correct)
196102
end
197103

198-
function compare_estimates(
104+
function test_estimates(
199105
ens_partable::EnsembleParameterTable,
200106
partable_lav;
201107
rtol = 1e-10,
202108
atol = 0,
203109
col = :estimate,
204110
lav_col = :est,
205-
lav_groups,
111+
lav_groups::AbstractDict,
112+
skip::Bool = false,
206113
)
207-
correct = []
208-
209-
for key in keys(ens_partable.tables)
210-
group = lav_groups[key]
211-
partable = ens_partable.tables[key]
212-
213-
for i in findall(partable.columns[:free])
214-
from = partable.columns[:from][i]
215-
to = partable.columns[:to][i]
216-
type = partable.columns[:parameter_type][i]
217-
estimate = partable.columns[col][i]
218-
219-
if from == Symbol("1")
220-
lav_ind = findall(
221-
(partable_lav.lhs .== String(to)) .&
222-
(partable_lav.op .== "~1") .&
223-
(partable_lav.group .== group),
224-
)
225-
226-
if length(lav_ind) == 0
227-
throw(
228-
ErrorException(
229-
"Mean parameter of variable $to could not be found in the lavaan solution",
230-
),
231-
)
232-
elseif length(lav_ind) > 1
233-
throw(
234-
ErrorException(
235-
"At least one parameter was found twice in the lavaan solution",
236-
),
237-
)
238-
else
239-
is_correct = isapprox(
240-
estimate,
241-
partable_lav[:, lav_col][lav_ind[1]];
242-
rtol = rtol,
243-
atol = atol,
244-
)
245-
push!(correct, is_correct)
246-
end
247-
248-
else
249-
if type == :
250-
type = "~~"
251-
elseif type == :
252-
if (from partable.latent_vars) && (to partable.observed_vars)
253-
type = "=~"
254-
else
255-
type = "~"
256-
from, to = to, from
257-
end
258-
end
259-
260-
if type == "~~"
261-
lav_ind = findall(
262-
(
263-
(
264-
(partable_lav.lhs .== String(from)) .&
265-
(partable_lav.rhs .== String(to))
266-
) .| (
267-
(partable_lav.lhs .== String(to)) .&
268-
(partable_lav.rhs .== String(from))
269-
)
270-
) .&
271-
(partable_lav.op .== type) .&
272-
(partable_lav.group .== group),
273-
)
274-
275-
if length(lav_ind) == 0
276-
throw(
277-
ErrorException(
278-
"Parameter from: $from, to: $to, type: $type, could not be found in the lavaan solution",
279-
),
280-
)
281-
elseif length(lav_ind) > 1
282-
throw(
283-
ErrorException(
284-
"At least one parameter was found twice in the lavaan solution",
285-
),
286-
)
287-
else
288-
is_correct = isapprox(
289-
estimate,
290-
partable_lav[:, lav_col][lav_ind[1]];
291-
rtol = rtol,
292-
atol = atol,
293-
)
294-
push!(correct, is_correct)
295-
end
296-
297-
else
298-
lav_ind = findall(
299-
(partable_lav.lhs .== String(from)) .&
300-
(partable_lav.rhs .== String(to)) .&
301-
(partable_lav.op .== type) .&
302-
(partable_lav.group .== group),
303-
)
304-
305-
if length(lav_ind) == 0
306-
throw(
307-
ErrorException(
308-
"Parameter $from $type $to could not be found in the lavaan solution",
309-
),
310-
)
311-
elseif length(lav_ind) > 1
312-
throw(
313-
ErrorException(
314-
"At least one parameter was found twice in the lavaan solution",
315-
),
316-
)
317-
else
318-
is_correct = isapprox(
319-
estimate,
320-
partable_lav[:, lav_col][lav_ind[1]];
321-
rtol = rtol,
322-
atol = atol,
323-
)
324-
push!(correct, is_correct)
325-
end
326-
end
327-
end
328-
end
114+
actual = fill(NaN, nparams(ens_partable))
115+
expected = fill(NaN, nparams(ens_partable))
116+
for (key, partable) in pairs(ens_partable.tables)
117+
StructuralEquationModels.param_values!(actual, partable, col)
118+
StructuralEquationModels.lavaan_param_values!(
119+
expected,
120+
partable_lav,
121+
partable,
122+
lav_col,
123+
lav_groups[key],
124+
)
125+
end
126+
@test !any(isnan, actual)
127+
@test !any(isnan, expected)
128+
129+
if skip # workaround skip=false not supported in earlier versions
130+
@test actual expected rtol = rtol atol = atol norm = Base.Fix2(norm, Inf) skip =
131+
skip
132+
else
133+
@test actual expected rtol = rtol atol = atol norm = Base.Fix2(norm, Inf)
329134
end
330-
331-
return all(correct)
332135
end

test/examples/multigroup/build_models.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
@testset "ml_solution_multigroup" begin
1818
solution = sem_fit(model_ml_multigroup)
1919
update_estimate!(partable, solution)
20-
@test compare_estimates(
20+
test_estimates(
2121
partable,
2222
solution_lav[:parameter_estimates_ml];
2323
atol = 1e-4,
@@ -35,7 +35,7 @@ end
3535
)
3636

3737
update_se_hessian!(partable, solution_ml)
38-
@test compare_estimates(
38+
test_estimates(
3939
partable,
4040
solution_lav[:parameter_estimates_ml];
4141
atol = 1e-3,
@@ -78,7 +78,7 @@ grad_fd = FiniteDiff.finite_difference_gradient(
7878
@testset "ml_solution_multigroup | sorted" begin
7979
solution = sem_fit(model_ml_multigroup)
8080
update_estimate!(partable_s, solution)
81-
@test compare_estimates(
81+
test_estimates(
8282
partable_s,
8383
solution_lav[:parameter_estimates_ml];
8484
atol = 1e-4,
@@ -96,7 +96,7 @@ end
9696
)
9797

9898
update_se_hessian!(partable_s, solution_ml)
99-
@test compare_estimates(
99+
test_estimates(
100100
partable_s,
101101
solution_lav[:parameter_estimates_ml];
102102
atol = 1e-3,
@@ -152,7 +152,7 @@ end
152152
@testset "solution_user_defined_loss" begin
153153
solution = sem_fit(model_ml_multigroup)
154154
update_estimate!(partable, solution)
155-
@test compare_estimates(
155+
test_estimates(
156156
partable,
157157
solution_lav[:parameter_estimates_ml];
158158
atol = 1e-4,
@@ -179,7 +179,7 @@ end
179179
@testset "ls_solution_multigroup" begin
180180
solution = sem_fit(model_ls_multigroup)
181181
update_estimate!(partable, solution)
182-
@test compare_estimates(
182+
test_estimates(
183183
partable,
184184
solution_lav[:parameter_estimates_ls];
185185
atol = 1e-4,
@@ -198,7 +198,7 @@ end
198198
)
199199

200200
update_se_hessian!(partable, solution_ls)
201-
@test compare_estimates(
201+
test_estimates(
202202
partable,
203203
solution_lav[:parameter_estimates_ls];
204204
atol = 1e-2,
@@ -266,7 +266,7 @@ if !isnothing(specification_miss_g1)
266266
@testset "fiml_solution_multigroup" begin
267267
solution = sem_fit(model_ml_multigroup)
268268
update_estimate!(partable_miss, solution)
269-
@test compare_estimates(
269+
test_estimates(
270270
partable_miss,
271271
solution_lav[:parameter_estimates_fiml];
272272
atol = 1e-4,
@@ -284,7 +284,7 @@ if !isnothing(specification_miss_g1)
284284
)
285285

286286
update_se_hessian!(partable_miss, solution)
287-
@test compare_estimates(
287+
test_estimates(
288288
partable_miss,
289289
solution_lav[:parameter_estimates_fiml];
290290
atol = 1e-3,

0 commit comments

Comments
 (0)