Skip to content

Commit c1e7b04

Browse files
author
Alexey Stukalov
committed
fix tests after Sem refactor
1 parent 4489906 commit c1e7b04

3 files changed

Lines changed: 28 additions & 30 deletions

File tree

test/examples/multigroup/build_models.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,28 @@ const SEM = StructuralEquationModels
77
obs_g1 = SemObservedData(data = dat_g1, observed_vars = SEM.observed_vars(specification_g1))
88
obs_g2 = SemObservedData(data = dat_g2, observed_vars = SEM.observed_vars(specification_g2))
99

10-
model_ml_multigroup =
11-
Sem(SemML(obs_g1, RAMSymbolic(specification_g1)), SemML(obs_g2, RAM(specification_g2)))
10+
model_ml_multigroup = Sem(
11+
:Pasteur => SemML(obs_g1, RAMSymbolic(specification_g1)),
12+
:Grant_White => SemML(obs_g2, RAM(specification_g2)),
13+
)
1214

1315
@testset "Sem API" begin
1416
@test SEM.nsamples(model_ml_multigroup) == nsamples(obs_g1) + nsamples(obs_g2)
1517
@test SEM.nsem_terms(model_ml_multigroup) == 2
1618
@test length(SEM.sem_terms(model_ml_multigroup)) == 2
1719
end
1820

21+
# replace observed using Dict of data matrices
1922
model_ml_multigroup3 = replace_observed(
20-
model_ml_multigroup2,
21-
column = :school,
22-
specification = partable,
23-
data = dat,
23+
model_ml_multigroup,
24+
Dict(:Pasteur => dat_g1, :Grant_White => dat_g2),
25+
)
26+
27+
# replace observed using DataFrame with group column
28+
model_ml_multigroup4 = replace_observed(
29+
model_ml_multigroup,
30+
dat;
31+
semterm_column = :school,
2432
)
2533

2634
# gradients
@@ -42,8 +50,10 @@ end
4250

4351
@testset "replace_observed_multigroup" begin
4452
sem_fit_1 = fit(semoptimizer, model_ml_multigroup)
45-
sem_fit_2 = fit(semoptimizer, model_ml_multigroup3)
46-
@test sem_fit_1.solution sem_fit_2.solution
53+
sem_fit_3 = fit(semoptimizer, model_ml_multigroup3)
54+
@test sem_fit_1.solution sem_fit_3.solution
55+
sem_fit_4 = fit(semoptimizer, model_ml_multigroup4)
56+
@test sem_fit_1.solution sem_fit_4.solution
4757
end
4858

4959
@testset "fitmeasures/se_ml" begin
@@ -194,7 +204,7 @@ model_ls_multigroup = Sem(
194204
end
195205

196206
@testset "ls_solution_multigroup" begin
197-
solution = sem_fit(semoptimizer, model_ls_multigroup)
207+
solution = fit(semoptimizer, model_ls_multigroup)
198208
update_estimate!(partable, solution)
199209
test_estimates(
200210
partable,

test/examples/political_democracy/constructor.jl

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -149,19 +149,14 @@ end
149149
)
150150
# set seed for simulation
151151
Random.seed!(83472834)
152-
colnames = Symbol.(names(example_data("political_democracy")))
153152
# simulate data
154153
model_ml_new = replace_observed(
155154
model_ml,
156-
data = rand(model_ml, params, 1_000_000),
157-
specification = spec,
158-
observed_vars = colnames,
155+
rand(model_ml, params, 1_000_000),
159156
)
160157
model_ml_sym_new = replace_observed(
161158
model_ml_sym,
162-
data = rand(model_ml_sym, params, 1_000_000),
163-
specification = spec,
164-
observed_vars = colnames,
159+
rand(model_ml_sym, params, 1_000_000),
165160
)
166161
# fit models
167162
sol_ml = solution(fit(semoptimizer, model_ml_new))
@@ -346,21 +341,14 @@ end
346341
)
347342
# set seed for simulation
348343
Random.seed!(83472834)
349-
colnames = Symbol.(names(example_data("political_democracy")))
350344
# simulate data
351345
model_ml_new = replace_observed(
352346
model_ml,
353-
data = rand(model_ml, params, 1_000_000),
354-
specification = spec,
355-
observed_vars = colnames,
356-
meanstructure = true,
347+
rand(model_ml, params, 1_000_000),
357348
)
358349
model_ml_sym_new = replace_observed(
359350
model_ml_sym,
360-
data = rand(model_ml_sym, params, 1_000_000),
361-
specification = spec,
362-
observed_vars = colnames,
363-
meanstructure = true,
351+
rand(model_ml_sym, params, 1_000_000),
364352
)
365353
# fit models
366354
sol_ml = solution(fit(semoptimizer, model_ml_new))

test/examples/recover_parameters/recover_parameters_twofact.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ include(
99
),
1010
)
1111

12-
x = Symbol.("x", 1:13)
12+
pars = Symbol.("x", 1:13)
1313

1414
S = [
1515
:x1 0 0 0 0 0 0 0
@@ -42,7 +42,7 @@ A = [
4242
0 0 0 0 0 0 0 0
4343
]
4444

45-
ram_matrices = RAMMatrices(; A = A, S = S, F = F, param_labels = x, vars = nothing)
45+
ram_matrices = RAMMatrices(; A = A, S = S, F = F, param_labels = pars, vars = nothing)
4646

4747
true_val = [
4848
repeat([1], 8)
@@ -55,8 +55,6 @@ start = [
5555
repeat([0.5], 4)
5656
]
5757

58-
observed = SemObservedData(data = x, specification = ram_matrices)
59-
6058
implied_sym = RAMSymbolic(ram_matrices)
6159

6260
implied_sym.Σ_eval!(implied_sym.Σ, true_val)
@@ -66,7 +64,9 @@ true_dist = MultivariateNormal(implied_sym.Σ)
6664
Random.seed!(1234)
6765
x = permutedims(rand(true_dist, 10^5), (2, 1))
6866

69-
model_ml = Sem(observed, implied_sym; loss = SemML)
67+
observed = SemObservedData(data = x, specification = ram_matrices)
68+
69+
model_ml = Sem(SemML(observed, implied_sym))
7070

7171
objective!(model_ml, true_val)
7272

0 commit comments

Comments
 (0)