Skip to content

Commit e272f4a

Browse files
alystAlexey Stukalov
authored andcommitted
SemImplied/SemLossFun: drop meanstructure kwarg
- for SemImplied require spec::SemSpec as positional - for SemLossFunction require implied argument
1 parent a81709c commit e272f4a

9 files changed

Lines changed: 59 additions & 99 deletions

File tree

src/frontend/specification/Sem.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,15 @@ function get_fields!(kwargs, specification, observed, implied, loss)
109109

110110
# implied
111111
if !isa(implied, SemImplied)
112-
implied = implied(; specification, kwargs...)
112+
# FIXME remove this implicit logic
113+
# SemWLS only accepts vech-ed implied covariance
114+
if isa(loss, Type) && (loss <: SemWLS) && !haskey(kwargs, :vech)
115+
implied_kwargs = copy(kwargs)
116+
implied_kwargs[:vech] = true
117+
else
118+
implied_kwargs = kwargs
119+
end
120+
implied = implied(specification; implied_kwargs...)
113121
end
114122

115123
kwargs[:implied] = implied

src/implied/RAM/generic.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,10 @@ Model implied covariance and means via RAM notation.
66
77
# Constructor
88
9-
RAM(;specification,
10-
meanstructure = false,
11-
gradient = true,
12-
kwargs...)
9+
RAM(; specification, gradient = true, kwargs...)
1310
1411
# Arguments
1512
- `specification`: either a `RAMMatrices` or `ParameterTable` object
16-
- `meanstructure::Bool`: does the model have a meanstructure?
1713
- `gradient::Bool`: is gradient-based optimization used
1814
1915
# Extended help
@@ -53,9 +49,9 @@ Vector of indices of each parameter in the respective RAM matrix:
5349
- `ram.M_indices`
5450
5551
Additional interfaces
56-
- `ram.F⨉I_A⁻¹` -> ``F(I-A)^{-1}``
57-
- `ram.F⨉I_A⁻¹S` -> ``F(I-A)^{-1}S``
58-
- `ram.I_A` -> ``I-A``
52+
- `F⨉I_A⁻¹(::RAM)` -> ``F(I-A)^{-1}``
53+
- `F⨉I_A⁻¹S(::RAM)` -> ``F(I-A)^{-1}S``
54+
- `I_A(::RAM)` -> ``I-A``
5955
6056
Only available in gradient! calls:
6157
- `ram.I_A⁻¹` -> ``(I-A)^{-1}``
@@ -90,15 +86,13 @@ end
9086
### Constructors
9187
############################################################################################
9288

93-
function RAM(;
94-
specification::SemSpecification,
89+
function RAM(
90+
spec::SemSpecification;
91+
#vech = false,
9592
gradient_required = true,
96-
meanstructure = false,
9793
kwargs...,
9894
)
99-
ram_matrices = convert(RAMMatrices, specification)
100-
101-
check_meanstructure_specification(meanstructure, ram_matrices)
95+
ram_matrices = convert(RAMMatrices, spec)
10296

10397
# get dimensions of the model
10498
n_par = nparams(ram_matrices)
@@ -126,7 +120,7 @@ function RAM(;
126120
end
127121

128122
# μ
129-
if meanstructure
123+
if !isnothing(ram_matrices.M)
130124
MS = HasMeanStruct
131125
M_pre = materialize(ram_matrices.M, rand_params)
132126
∇M = gradient_required ? sparse_gradient(ram_matrices.M) : nothing

src/implied/RAM/symbolic.jl

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,10 @@ Subtype of `SemImplied` that implements the RAM notation with symbolic precomput
1212
gradient = true,
1313
hessian = false,
1414
approximate_hessian = false,
15-
meanstructure = false,
1615
kwargs...)
1716
1817
# Arguments
1918
- `specification`: either a `RAMMatrices` or `ParameterTable` object
20-
- `meanstructure::Bool`: does the model have a meanstructure?
2119
- `gradient::Bool`: is gradient-based optimization used
2220
- `hessian::Bool`: is hessian-based optimization used
2321
- `approximate_hessian::Bool`: for hessian based optimization: should the hessian be approximated
@@ -79,20 +77,16 @@ end
7977
### Constructors
8078
############################################################################################
8179

82-
function RAMSymbolic(;
83-
specification::SemSpecification,
84-
loss_types = nothing,
85-
vech = false,
86-
simplify_symbolics = false,
87-
gradient = true,
88-
hessian = false,
89-
meanstructure = false,
90-
approximate_hessian = false,
80+
function RAMSymbolic(
81+
spec::SemSpecification;
82+
vech::Bool = false,
83+
simplify_symbolics::Bool = false,
84+
gradient::Bool = true,
85+
hessian::Bool = false,
86+
approximate_hessian::Bool = false,
9187
kwargs...,
9288
)
93-
ram_matrices = convert(RAMMatrices, specification)
94-
95-
check_meanstructure_specification(meanstructure, ram_matrices)
89+
ram_matrices = convert(RAMMatrices, spec)
9690

9791
n_par = nparams(ram_matrices)
9892
par = (Symbolics.@variables θ[1:n_par])[1]
@@ -102,10 +96,6 @@ function RAMSymbolic(;
10296
M = !isnothing(ram_matrices.M) ? materialize(Num, ram_matrices.M, par) : nothing
10397
F = ram_matrices.F
10498

105-
if !isnothing(loss_types) && any(T -> T <: SemWLS, loss_types)
106-
vech = true
107-
end
108-
10999
I_A⁻¹ = neumann_series(A)
110100

111101
# Σ
@@ -146,7 +136,7 @@ function RAMSymbolic(;
146136
end
147137

148138
# μ
149-
if meanstructure
139+
if !isnothing(ram_matrices.M)
150140
MS = HasMeanStruct
151141
μ_sym = eval_μ_symbolic(M, I_A⁻¹, F; simplify = simplify_symbolics)
152142
μ_eval! = Symbolics.build_function(μ_sym, par, expression = Val{false})[2]
@@ -222,10 +212,10 @@ end
222212
############################################################################################
223213

224214
# expected covariations of observed vars
225-
function eval_Σ_symbolic(S, I_A⁻¹, F; vech = false, simplify = false)
215+
function eval_Σ_symbolic(S, I_A⁻¹, F; vech::Bool = false, simplify::Bool = false)
226216
Σ = F * I_A⁻¹ * S * permutedims(I_A⁻¹) * permutedims(F)
227217
Σ = Array(Σ)
228-
vech &&= Σ[tril(trues(size(F, 1), size(F, 1)))])
218+
vech &&= SEM.vech(Σ))
229219
if simplify
230220
Threads.@threads for i in eachindex(Σ)
231221
Σ[i] = Symbolics.simplify(Σ[i])

src/implied/abstract.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,3 @@ function check_acyclic(A::AbstractMatrix; verbose::Bool = false)
3131
return A
3232
end
3333
end
34-
35-
# Verify that the `meanstructure` argument aligns with the model specification.
36-
function check_meanstructure_specification(meanstructure, ram_matrices)
37-
if meanstructure & isnothing(ram_matrices.M)
38-
throw(ArgumentError(
39-
"You set `meanstructure = true`, but your model specification contains no mean parameters."
40-
))
41-
end
42-
if !meanstructure & !isnothing(ram_matrices.M)
43-
throw(ArgumentError(
44-
"If your model specification contains mean parameters, you have to set `Sem(..., meanstructure = true)`."
45-
))
46-
end
47-
end

src/loss/ML/FIML.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,16 @@ Can handle observed data with missing values.
7575
7676
# Constructor
7777
78-
SemFIML(; observed::SemObservedMissing, specification, kwargs...)
78+
SemFIML(; observed::SemObservedMissing, implied::SemImplied, kwargs...)
7979
8080
# Arguments
81-
- `observed`: the observed data with missing values (see [`SemObservedMissing`](@ref))
82-
- `specification`: [`SemSpecification`](@ref) object
81+
- `observed::SemObservedMissing`: the observed part of the model
82+
(see [`SemObservedMissing`](@ref))
83+
- `implied::SemImplied`: the implied part of the model
8384
8485
# Examples
8586
```julia
86-
my_fiml = SemFIML(observed = my_observed, specification = my_parameter_table)
87+
my_fiml = SemFIML(observed = my_observed, implied = my_implied)
8788
```
8889
8990
# Interfaces
@@ -118,7 +119,7 @@ function SemFIML(; observed::SemObservedMissing, implied, specification, kwargs.
118119
ExactHessian(),
119120
[SemFIMLPattern(pat) for pat in observed.patterns],
120121
zeros(nobserved_vars(observed), nobserved_vars(observed)),
121-
CommutationMatrix(nvars(specification)),
122+
CommutationMatrix(nvars(implied)),
122123
nothing,
123124
)
124125
end

src/loss/ML/ML.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@ Maximum likelihood estimation.
88
99
# Constructor
1010
11-
SemML(;observed, meanstructure = false, approximate_hessian = false, kwargs...)
11+
SemML(; observed, approximate_hessian = false, kwargs...)
1212
1313
# Arguments
1414
- `observed::SemObserved`: the observed part of the model
15-
- `meanstructure::Bool`: does the model have a meanstructure?
1615
- `approximate_hessian::Bool`: if hessian-based optimization is used, should the hessian be swapped for an approximation
1716
1817
# Examples

src/loss/WLS/WLS.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,15 @@ At the moment only available with the `RAMSymbolic` implied type.
1010
# Constructor
1111
1212
SemWLS(;
13-
observed,
14-
meanstructure = false,
13+
observed, implied,
1514
wls_weight_matrix = nothing,
1615
wls_weight_matrix_mean = nothing,
1716
approximate_hessian = false,
1817
kwargs...)
1918
2019
# Arguments
2120
- `observed`: the `SemObserved` part of the model
22-
- `meanstructure::Bool`: does the model have a meanstructure?
21+
- `implied::SemImplied`: the implied part of the model
2322
- `approximate_hessian::Bool`: should the hessian be swapped for an approximation
2423
- `wls_weight_matrix`: the weight matrix for weighted least squares.
2524
Defaults to GLS estimation (``0.5*(D^T*kron(S,S)*D)`` where D is the duplication matrix
@@ -29,7 +28,7 @@ At the moment only available with the `RAMSymbolic` implied type.
2928
3029
# Examples
3130
```julia
32-
my_wls = SemWLS(observed = my_observed)
31+
my_wls = SemWLS(observed = my_observed, implied = my_implied)
3332
```
3433
3534
# Interfaces
@@ -50,12 +49,11 @@ SemWLS{HE}(args...) where {HE <: HessianEval} =
5049
SemWLS{HE, map(typeof, args)...}(HE(), args...)
5150

5251
function SemWLS(;
53-
observed,
54-
implied,
55-
wls_weight_matrix = nothing,
56-
wls_weight_matrix_mean = nothing,
57-
approximate_hessian = false,
58-
meanstructure = false,
52+
observed::SemObserved,
53+
implied::SemImplied,
54+
wls_weight_matrix::Union{AbstractMatrix, Nothing} = nothing,
55+
wls_weight_matrix_mean::Union{AbstractMatrix, Nothing} = nothing,
56+
approximate_hessian::Bool = false,
5957
kwargs...,
6058
)
6159
if observed isa SemObservedMissing
@@ -81,6 +79,10 @@ function SemWLS(;
8179
nobs_vars = nobserved_vars(observed)
8280
tril_ind = filter(x -> (x[1] >= x[2]), CartesianIndices(obs_cov(observed)))
8381
s = obs_cov(observed)[tril_ind]
82+
size(s) == size(implied.Σ) ||
83+
throw(DimensionMismatch("SemWLS requires implied covariance to be in vech-ed form " *
84+
"(vectorized lower triangular part of Σ matrix): $(size(s)) expected, $(size(implied.Σ)) found.\n" *
85+
"$(nameof(typeof(implied))) must be constructed with vech=true."))
8486

8587
# compute V here
8688
if isnothing(wls_weight_matrix)
@@ -94,9 +96,12 @@ function SemWLS(;
9496
"wls_weight_matrix has to be of size $(length(tril_ind))×$(length(tril_ind))",
9597
)
9698
end
99+
size(wls_weight_matrix) == (length(s), length(s)) ||
100+
DimensionMismatch("wls_weight_matrix has to be of size $(length(s))×$(length(s))")
97101

98-
if meanstructure
102+
if MeanStruct(implied) == HasMeanStruct
99103
if isnothing(wls_weight_matrix_mean)
104+
@warn "Computing WLS weight matrix for the meanstructure using obs_cov()"
100105
wls_weight_matrix_mean = inv(obs_cov(observed))
101106
else
102107
size(wls_weight_matrix_mean) == (nobs_vars, nobs_vars) || DimensionMismatch(

test/examples/recover_parameters/recover_parameters_twofact.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ start = [
5353
repeat([0.5], 4)
5454
]
5555

56-
implied_ml = RAMSymbolic(; specification = ram_matrices, start_val = start)
56+
implied_ml = RAMSymbolic(ram_matrices; start_val = start)
5757

5858
implied_ml.Σ_eval!(implied_ml.Σ, true_val)
5959

test/unit_tests/model.jl

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,16 @@ function test_params_api(semobj, spec::SemSpecification)
4646
@test @inferred(param_labels(semobj)) == param_labels(spec)
4747
end
4848

49-
@testset "Sem(implied=$impliedtype, loss=SemML)" for impliedtype in (RAM, RAMSymbolic)
50-
49+
@testset "Sem(implied=$impliedtype, loss=$losstype)" for (impliedtype, losstype) in [
50+
(RAM, SemML),
51+
(RAMSymbolic, SemML),
52+
(RAMSymbolic, SemWLS),
53+
]
5154
model = Sem(
5255
specification = ram_matrices,
5356
observed = obs,
5457
implied = impliedtype,
55-
loss = SemML,
58+
loss = losstype,
5659
)
5760

5861
@test model isa Sem
@@ -71,29 +74,3 @@ end
7174

7275
@test @inferred(nsamples(model)) == nsamples(obs)
7376
end
74-
75-
@testset "Sem(implied=RAMSymbolic, loss=SemWLS)" begin
76-
77-
model = Sem(
78-
specification = ram_matrices,
79-
observed = obs,
80-
implied = RAMSymbolic,
81-
loss = SemWLS,
82-
)
83-
84-
@test model isa Sem
85-
@test @inferred(implied(model)) isa RAMSymbolic
86-
@test @inferred(observed(model)) isa SemObserved
87-
88-
test_vars_api(model, ram_matrices)
89-
test_params_api(model, ram_matrices)
90-
91-
test_vars_api(implied(model), ram_matrices)
92-
test_params_api(implied(model), ram_matrices)
93-
94-
@test @inferred(loss(model)) isa SemLoss
95-
semloss = loss(model).functions[1]
96-
@test semloss isa SemWLS
97-
98-
@test @inferred(nsamples(model)) == nsamples(obs)
99-
end

0 commit comments

Comments
 (0)