Skip to content

Commit c3071d0

Browse files
alystAlexey Stukalov
authored andcommitted
refactor Sem, SemEnsemble, SemLoss
1 parent e272f4a commit c3071d0

32 files changed

Lines changed: 1148 additions & 1187 deletions

File tree

src/StructuralEquationModels.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ include("frontend/specification/EnsembleParameterTable.jl")
4444
include("frontend/specification/StenoGraphs.jl")
4545
include("frontend/fit/summary.jl")
4646
include("frontend/StatsAPI.jl")
47+
include("frontend/finite_diff.jl")
4748
# pretty printing
4849
include("frontend/pretty_printing.jl")
4950
# observed
@@ -53,26 +54,28 @@ include("observed/covariance.jl")
5354
include("observed/missing_pattern.jl")
5455
include("observed/missing.jl")
5556
include("observed/EM.jl")
56-
# constructor
57-
include("frontend/specification/Sem.jl")
58-
include("frontend/specification/documentation.jl")
5957
# implied
6058
include("implied/abstract.jl")
6159
include("implied/RAM/symbolic.jl")
6260
include("implied/RAM/generic.jl")
6361
include("implied/empty.jl")
6462
# loss
63+
include("loss/ML/abstract.jl")
6564
include("loss/ML/ML.jl")
6665
include("loss/ML/FIML.jl")
6766
include("loss/regularization/ridge.jl")
6867
include("loss/WLS/WLS.jl")
6968
include("loss/constant/constant.jl")
69+
# constructor
70+
include("frontend/specification/Sem.jl")
71+
include("frontend/specification/documentation.jl")
7072
# optimizer
7173
include("optimizer/abstract.jl")
7274
include("optimizer/Empty.jl")
7375
include("optimizer/optim.jl")
7476
# helper functions
7577
include("additional_functions/helper.jl")
78+
include("additional_functions/start_val/common.jl")
7679
include("additional_functions/start_val/start_fabin3.jl")
7780
include("additional_functions/start_val/start_simple.jl")
7881
include("additional_functions/artifacts.jl")
@@ -91,14 +94,11 @@ include("frontend/fit/standard_errors/hessian.jl")
9194
include("frontend/fit/standard_errors/bootstrap.jl")
9295

9396
export AbstractSem,
94-
AbstractSemSingle,
95-
AbstractSemCollection,
9697
coef,
9798
coefnames,
9899
coeftable,
99100
Sem,
100101
SemFiniteDiff,
101-
SemEnsemble,
102102
MeanStruct,
103103
NoMeanStruct,
104104
HasMeanStruct,
@@ -113,15 +113,18 @@ export AbstractSem,
113113
start_val,
114114
start_fabin3,
115115
start_simple,
116+
AbstractLoss,
116117
SemLoss,
117-
SemLossFunction,
118118
SemML,
119119
SemFIML,
120120
em_mvn,
121121
SemRidge,
122122
SemConstant,
123123
SemWLS,
124124
loss,
125+
nsem_terms,
126+
sem_terms,
127+
sem_term,
125128
SemOptimizer,
126129
optimizer,
127130
optimizer_engine,

src/additional_functions/simulation.jl

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -43,36 +43,42 @@ function update_observed end
4343
# change observed (data) without reconstructing the whole model
4444
############################################################################################
4545

46+
# don't change non-SEM terms
47+
replace_observed(loss::AbstractLoss; kwargs...) = loss
48+
4649
# use the same observed type as before
47-
replace_observed(model::AbstractSemSingle; kwargs...) =
48-
replace_observed(model, typeof(observed(model)).name.wrapper; kwargs...)
50+
replace_observed(loss::SemLoss; kwargs...) =
51+
replace_observed(loss, typeof(SEM.observed(loss)).name.wrapper; kwargs...)
52+
53+
# construct a new observed type
54+
replace_observed(loss::SemLoss, observed_type; kwargs...) =
55+
replace_observed(loss, observed_type(; kwargs...); kwargs...)
4956

50-
function replace_observed(model::AbstractSemSingle, observed_type; kwargs...)
51-
new_observed = observed_type(; kwargs...)
57+
function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...)
5258
kwargs = Dict{Symbol, Any}(kwargs...)
59+
old_observed = SEM.observed(loss)
60+
implied = SEM.implied(loss)
5361

5462
# get field types
5563
kwargs[:observed_type] = typeof(new_observed)
56-
kwargs[:old_observed_type] = typeof(model.observed)
57-
kwargs[:implied_type] = typeof(model.implied)
58-
kwargs[:loss_types] = [typeof(lossfun) for lossfun in model.loss.functions]
64+
kwargs[:old_observed_type] = typeof(old_observed)
5965

6066
# update implied
61-
new_implied = update_observed(model.implied, new_observed; kwargs...)
67+
new_implied = update_observed(implied, new_observed; kwargs...)
6268
kwargs[:implied] = new_implied
69+
kwargs[:implied_type] = typeof(new_implied)
6370
kwargs[:nparams] = nparams(new_implied)
6471

6572
# update loss
66-
new_loss = update_observed(model.loss, new_observed; kwargs...)
67-
68-
return Sem(new_observed, new_implied, new_loss)
73+
return update_observed(loss, new_observed; kwargs...)
6974
end
7075

71-
function update_observed(loss::SemLoss, new_observed; kwargs...)
72-
new_functions = Tuple(
73-
update_observed(lossfun, new_observed; kwargs...) for lossfun in loss.functions
74-
)
75-
return SemLoss(new_functions, loss.weights)
76+
replace_observed(loss::LossTerm; kwargs...) =
77+
LossTerm(replace_observed(loss.loss; kwargs...), loss.id, loss.weight)
78+
79+
function replace_observed(sem::Sem; kwargs...)
80+
updated_terms = Tuple(replace_observed(term; kwargs...) for term in loss_terms(sem))
81+
return Sem(updated_terms...)
7682
end
7783

7884
function replace_observed(
@@ -111,39 +117,38 @@ end
111117
# simulate data
112118
############################################################################################
113119
"""
114-
(1) rand(model::AbstractSemSingle, params, n)
115-
116-
(2) rand(model::AbstractSemSingle, n)
120+
rand(sem::Union{Sem, SemLoss, SemImplied}, [params], n)
117121
118-
Sample normally distributed data from the model-implied covariance matrix and mean vector.
122+
Sample from the multivariate normal distribution implied by the SEM model.
119123
120124
# Arguments
121-
- `model::AbstractSemSingle`: model to simulate from.
122-
- `params`: parameter values to simulate from.
123-
- `n::Integer`: Number of samples.
125+
- `sem`: SEM model to use. Ensemble models with multiple SEM terms are not supported.
126+
- `params`: optional SEM model parameters to simulate from, otherwise uses the
127+
current state of implied covariances and means.
128+
- `n::Integer`: Number of samples to draw.
124129
125130
# Examples
126131
```julia
127132
rand(model, start_simple(model), 100)
128133
```
129134
"""
130-
function Distributions.rand(
131-
model::AbstractSemSingle{O, I, L},
132-
params,
133-
n::Integer,
134-
) where {O, I <: Union{RAM, RAMSymbolic}, L}
135-
update!(EvaluationTargets{true, false, false}(), model.implied, model, params)
136-
return rand(model, n)
137-
end
138-
139-
function Distributions.rand(
140-
model::AbstractSemSingle{O, I, L},
141-
n::Integer,
142-
) where {O, I <: Union{RAM, RAMSymbolic}, L}
143-
if MeanStruct(model.implied) === NoMeanStruct
144-
data = permutedims(rand(MvNormal(Symmetric(model.implied.Σ)), n))
145-
elseif MeanStruct(model.implied) === HasMeanStruct
146-
data = permutedims(rand(MvNormal(model.implied.μ, Symmetric(model.implied.Σ)), n))
135+
function Distributions.rand(implied::SemImplied, params, n::Integer)
136+
if !isnothing(params)
137+
# update the implied covariances with the new model params
138+
update!(EvaluationTargets{true, false, false}(), implied, params)
139+
end
140+
Σ = Symmetric(implied.Σ)
141+
if MeanStruct(implied) === NoMeanStruct
142+
return permutedims(rand(MvNormal(Σ), n))
143+
elseif MeanStruct(implied) === HasMeanStruct
144+
return permutedims(rand(MvNormal(implied.μ, Σ), n))
147145
end
148-
return data
149146
end
147+
148+
Distributions.rand(loss::SemLoss, params, n::Integer) = rand(SEM.implied(loss), params, n)
149+
150+
Distributions.rand(model::Sem, params, n::Integer) = rand(sem_term(model), params, n)
151+
152+
# rand() overloads without SEM params
153+
Distributions.rand(implied::Union{SemImplied, SemLoss, Sem}, n::Integer) =
154+
Distributions.rand(implied, nothing, n)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
# start values for SEM Models (including ensembles)
3+
function start_values(f, model::AbstractSem; kwargs...)
4+
start_vals = fill(0.0, nparams(model))
5+
6+
# initialize parameters using the SEM loss terms
7+
# (first SEM loss term that sets given parameter to nonzero value)
8+
for term in loss_terms(model)
9+
issemloss(term) || continue
10+
term_start_vals = f(loss(term); kwargs...)
11+
for (i, val) in enumerate(term_start_vals)
12+
iszero(val) || (start_vals[i] = val)
13+
end
14+
end
15+
16+
return start_vals
17+
end

src/additional_functions/start_val/start_fabin3.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,17 @@ Not available for ensemble models.
77
function start_fabin3 end
88

99
# splice model and loss functions
10-
function start_fabin3(model::AbstractSemSingle; kwargs...)
11-
return start_fabin3(model.observed, model.implied, model.loss.functions..., kwargs...)
10+
function start_fabin3(model::SemLoss; kwargs...)
11+
return start_fabin3(model.observed, model.implied; kwargs...)
1212
end
1313

14-
function start_fabin3(observed::SemObserved, implied::SemImplied, args...; kwargs...)
15-
return start_fabin3(implied.ram_matrices, obs_cov(observed), obs_mean(observed))
14+
function start_fabin3(observed::SemObserved, implied::SemImplied; kwargs...)
15+
return start_fabin3(
16+
implied.ram_matrices,
17+
obs_cov(observed),
18+
# ignore observed means if no meansturcture
19+
!isnothing(implied.ram_matrices.M) ? obs_mean(observed) : nothing,
20+
)
1621
end
1722

1823
function start_fabin3(
@@ -161,3 +166,6 @@ end
161166
function is_in_Λ(ind_vec, F_ind)
162167
return any(ind -> !(ind[2] F_ind) && (ind[1] F_ind), ind_vec)
163168
end
169+
170+
# ensembles
171+
start_fabin3(model::AbstractSem; kwargs...) = start_values(start_fabin3, model; kwargs...)

src/additional_functions/start_val/start_simple.jl

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,11 @@ Return a vector of simple starting values.
1515
"""
1616
function start_simple end
1717

18-
# Single Models ----------------------------------------------------------------------------
19-
function start_simple(model::AbstractSemSingle; kwargs...)
20-
return start_simple(model.observed, model.implied, model.loss.functions...; kwargs...)
21-
end
22-
23-
function start_simple(observed, implied, args...; kwargs...)
24-
return start_simple(implied.ram_matrices; kwargs...)
25-
end
26-
27-
# Ensemble Models --------------------------------------------------------------------------
28-
function start_simple(model::SemEnsemble; kwargs...)
29-
start_vals = []
30-
31-
for sem in model.sems
32-
push!(start_vals, start_simple(sem; kwargs...))
33-
end
34-
35-
has_start_val = [.!iszero.(start_val) for start_val in start_vals]
18+
start_simple(model::SemLoss; kwargs...) =
19+
start_simple(observed(model), implied(model); kwargs...)
3620

37-
start_val = similar(start_vals[1])
38-
start_val .= 0.0
39-
40-
for (j, indices) in enumerate(has_start_val)
41-
start_val[indices] .= start_vals[j][indices]
42-
end
43-
44-
return start_val
45-
end
21+
start_simple(observed::SemObserved, implied::SemImplied; kwargs...) =
22+
start_simple(implied.ram_matrices; kwargs...)
4623

4724
function start_simple(
4825
ram_matrices::RAMMatrices;
@@ -103,3 +80,6 @@ function start_simple(
10380
end
10481
return start_val
10582
end
83+
84+
# multigroup models
85+
start_simple(model::AbstractSem; kwargs...) = start_values(start_simple, model; kwargs...)

src/frontend/finite_diff.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
_unwrap(wrapper::SemFiniteDiff) = wrapper.model
2+
params(wrapper::SemFiniteDiff) = params(wrapper.model)
3+
loss_terms(wrapper::SemFiniteDiff) = loss_terms(wrapper.model)
4+
5+
FiniteDiffLossWrappers = Union{LossFiniteDiff, SemLossFiniteDiff}
6+
7+
_unwrap(term::AbstractLoss) = term
8+
_unwrap(wrapper::FiniteDiffLossWrappers) = wrapper.loss
9+
implied(wrapper::FiniteDiffLossWrappers) = implied(_unwrap(wrapper))
10+
observed(wrapper::FiniteDiffLossWrappers) = observed(_unwrap(wrapper))
11+
12+
FiniteDiffWrapper(model::AbstractSem) = SemFiniteDiff(model)
13+
FiniteDiffWrapper(loss::AbstractLoss) = LossFiniteDiff(loss)
14+
FiniteDiffWrapper(loss::SemLoss) = SemLossFiniteDiff(loss)
15+
16+
function evaluate!(
17+
objective,
18+
gradient,
19+
hessian,
20+
sem::Union{SemFiniteDiff, FiniteDiffLossWrappers},
21+
params,
22+
)
23+
wrapped = _unwrap(sem)
24+
obj(p) = _evaluate!(
25+
objective_zero(objective, gradient, hessian),
26+
nothing,
27+
nothing,
28+
wrapped,
29+
p,
30+
)
31+
isnothing(gradient) || FiniteDiff.finite_difference_gradient!(gradient, obj, params)
32+
isnothing(hessian) || FiniteDiff.finite_difference_hessian!(hessian, obj, params)
33+
# FIXME if objective is not calculated, SemLoss implied states may not correspond to params
34+
return !isnothing(objective) ? obj(params) : nothing
35+
end

src/frontend/fit/fitmeasures/RMSEA.jl

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,18 @@ for the SEM model.
1919
For multigroup models, the correction proposed by J.H. Steiger is applied
2020
(see [Steiger, J. H. (1998). *A note on multiple sample extensions of the RMSEA fit index*](https://doi.org/10.1080/10705519809540115)).
2121
"""
22-
function RMSEA end
23-
2422
RMSEA(fit::SemFit) = RMSEA(fit, fit.model)
2523

26-
function RMSEA(fit::SemFit, model::AbstractSemSingle)
27-
check_single_lossfun(model; throw_error = true)
28-
return RMSEA(dof(fit), χ²(fit), nsamples(fit)+rmsea_correction(model.loss.functions[1]))
29-
end
30-
31-
function RMSEA(fit::SemFit, model::SemEnsemble)
32-
check_single_lossfun(model; throw_error = true)
33-
n = nsamples(fit)+model.n*rmsea_correction(model.sems[1].loss.functions[1])
34-
return sqrt(length(model.sems)) * RMSEA(dof(fit), χ²(fit), n)
35-
end
36-
37-
function RMSEA(dof, chi2, N⁻)
38-
rmsea = (chi2 - dof) / (N⁻ * dof)
39-
rmsea = rmsea > 0 ? rmsea : 0
40-
return sqrt(rmsea)
24+
# scaling corrections
25+
RMSEA_corr_scale(::Type{<:SemFIML}) = 0
26+
RMSEA_corr_scale(::Type{<:SemML}) = -1
27+
RMSEA_corr_scale(::Type{<:SemWLS}) = -1
28+
29+
function RMSEA(fit::SemFit, model::AbstractSem)
30+
term_type = check_single_lossfun(model; throw_error = true)
31+
n = nsamples(fit) + nsem_terms(model) * RMSEA_corr_scale(term_type)
32+
sqrt(nsem_terms(model)) * RMSEA(dof(fit), χ²(fit), n)
4133
end
4234

43-
# scaling corrections
44-
rmsea_correction(::SemFIML) = 0
45-
rmsea_correction(::SemML) = -1
46-
rmsea_correction(::SemWLS) = -1
35+
RMSEA(dof::Number, chi2::Number, nsamples::Number) =
36+
sqrt(max((chi2 - dof) / (nsamples * dof), 0.0))

0 commit comments

Comments
 (0)