Skip to content

Commit dfe5658

Browse files
committed
refactor SemSpec tests
1 parent e61d4fb commit dfe5658

3 files changed

Lines changed: 127 additions & 11 deletions

File tree

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1212
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1313
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1414
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
15+
StenoGraphs = "78862bba-adae-4a83-bb4d-33c106177f81"
1516
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/unit_tests/specification.jl

Lines changed: 122 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,133 @@
1-
@testset "ParameterTable - RAMMatrices conversion" begin
2-
partable = ParameterTable(ram_matrices)
3-
@test ram_matrices == RAMMatrices(partable)
4-
end
1+
using StenoGraphs, StructuralEquationModels
2+
using StructuralEquationModels:
3+
vars, nvars, observed_vars, latent_vars, nobserved_vars, nlatent_vars, params, nparams
54

6-
@testset "params()" begin
7-
@test params(model_ml)[2, 10, 28] == [:x2, :x10, :x28]
8-
@test params(model_ml) == params(partable)
9-
@test params(model_ml) == params(RAMMatrices(partable))
10-
end
5+
obs_vars = Symbol.("x", 1:9)
6+
lat_vars = [:visual, :textual, :speed]
117

128
graph = @StenoGraph begin
9+
# measurement model
10+
visual fixed(1.0) * x1 + fixed(0.5) * x2 + fixed(0.6) * x3
11+
textual fixed(1.0) * x4 + x5 + label(:a₁) * x6
12+
speed fixed(1.0) * x7 + fixed(1.0) * x8 + label(:λ₉) * x9
13+
# variances and covariances
14+
_(obs_vars) _(obs_vars)
15+
_(lat_vars) _(lat_vars)
16+
visual textual + speed
17+
textual speed
18+
end
19+
20+
ens_graph = @StenoGraph begin
1321
# measurement model
1422
visual fixed(1.0, 1.0) * x1 + fixed(0.5, 0.5) * x2 + fixed(0.6, 0.8) * x3
1523
textual fixed(1.0, 1.0) * x4 + x5 + label(:a₁, :a₂) * x6
1624
speed fixed(1.0, 1.0) * x7 + fixed(1.0, NaN) * x8 + label(:λ₉, :λ₉) * x9
1725
# variances and covariances
18-
_(observed_vars) _(observed_vars)
19-
_(latent_vars) _(latent_vars)
26+
_(obs_vars) _(obs_vars)
27+
_(lat_vars) _(lat_vars)
2028
visual textual + speed
2129
textual speed
2230
end
31+
32+
@testset "ParameterTable" begin
33+
@testset "from StenoGraph" begin
34+
@test_throws UndefKeywordError(:observed_vars) ParameterTable(graph)
35+
@test_throws UndefKeywordError(:latent_vars) ParameterTable(
36+
graph,
37+
observed_vars = obs_vars,
38+
)
39+
partable = @inferred(
40+
ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars)
41+
)
42+
43+
@test partable isa ParameterTable
44+
45+
# vars API
46+
@test observed_vars(partable) == obs_vars
47+
@test nobserved_vars(partable) == length(obs_vars)
48+
@test latent_vars(partable) == lat_vars
49+
@test nlatent_vars(partable) == length(lat_vars)
50+
@test nvars(partable) == length(obs_vars) + length(lat_vars)
51+
@test issetequal(vars(partable), [obs_vars; lat_vars])
52+
53+
# params API
54+
@test params(partable) == [[:θ_1, :a₁, :λ₉]; Symbol.("θ_", 2:16)]
55+
@test nparams(partable) == 18
56+
57+
# don't allow constructing ParameterTable from a graph for an ensemble
58+
@test_throws ArgumentError ParameterTable(
59+
ens_graph,
60+
observed_vars = obs_vars,
61+
latent_vars = lat_vars,
62+
)
63+
end
64+
65+
@testset "from RAMMatrices" begin
66+
partable_orig =
67+
ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars)
68+
ram_matrices = RAMMatrices(partable_orig)
69+
70+
partable = @inferred(ParameterTable(ram_matrices))
71+
@test partable isa ParameterTable
72+
@test issetequal(keys(partable.columns), keys(partable_orig.columns))
73+
# FIXME nrow()?
74+
@test length(partable.columns[:from]) == length(partable_orig.columns[:from])
75+
@test partable == partable_orig broken = true
76+
end
77+
end
78+
79+
@testset "EnsembleParameterTable" begin
80+
groups = [:Pasteur, :Grant_White],
81+
@test_throws UndefKeywordError(:observed_vars) EnsembleParameterTable(ens_graph)
82+
@test_throws UndefKeywordError(:latent_vars) EnsembleParameterTable(
83+
ens_graph,
84+
observed_vars = obs_vars,
85+
)
86+
@test_throws UndefKeywordError(:groups) EnsembleParameterTable(
87+
ens_graph,
88+
observed_vars = obs_vars,
89+
latent_vars = lat_vars,
90+
)
91+
92+
enspartable = @inferred(
93+
EnsembleParameterTable(
94+
ens_graph,
95+
observed_vars = obs_vars,
96+
latent_vars = lat_vars,
97+
groups = [:Pasteur, :Grant_White],
98+
)
99+
)
100+
@test enspartable isa EnsembleParameterTable
101+
102+
@test nobserved_vars(enspartable) == length(obs_vars) broken = true
103+
@test observed_vars(enspartable) == obs_vars broken = true
104+
@test nlatent_vars(enspartable) == length(lat_vars) broken = true
105+
@test latent_vars(enspartable) == lat_vars broken = true
106+
@test nvars(enspartable) == length(obs_vars) + length(lat_vars) broken = true
107+
@test issetequal(vars(enspartable), [obs_vars; lat_vars]) broken = true
108+
109+
@test nparams(enspartable) == 36
110+
@test issetequal(
111+
params(enspartable),
112+
[Symbol.("gPasteur_", 1:16); Symbol.("gGrant_White_", 1:17); [:a₁, :a₂, :λ₉]],
113+
)
114+
end
115+
116+
@testset "RAMMatrices" begin
117+
partable = ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars)
118+
119+
ram_matrices = @inferred(RAMMatrices(partable))
120+
@test ram_matrices isa RAMMatrices
121+
122+
# vars API
123+
@test nobserved_vars(ram_matrices) == length(obs_vars)
124+
@test observed_vars(ram_matrices) == obs_vars
125+
@test nlatent_vars(ram_matrices) == length(lat_vars)
126+
@test latent_vars(ram_matrices) == lat_vars
127+
@test nvars(ram_matrices) == length(obs_vars) + length(lat_vars)
128+
@test issetequal(vars(ram_matrices), [obs_vars; lat_vars])
129+
130+
# params API
131+
@test nparams(ram_matrices) == nparams(partable)
132+
@test params(ram_matrices) == params(partable)
133+
end

test/unit_tests/unit_tests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ end
1111
@safetestset "SemObserved" begin
1212
include("data_input_formats.jl")
1313
end
14+
15+
@safetestset "SemSpecification" begin
16+
include("specification.jl")
17+
end

0 commit comments

Comments
 (0)