Skip to content

Commit f7f7452

Browse files
author
Alexey Stukalov
committed
check_same_semterm_type(): refactor check_single_lossfun()
1 parent bab1317 commit f7f7452

5 files changed

Lines changed: 50 additions & 47 deletions

File tree

src/additional_functions/helper.jl

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -116,35 +116,6 @@ function nonunique(values::AbstractVector)
116116
return res
117117
end
118118

119-
# check that a model only has a single lossfun
120-
function check_single_lossfun(model::AbstractSemSingle; throw_error)
121-
if (length(model.loss.functions) > 1) & throw_error
122-
@error "The model has $(length(sem.loss.functions)) loss functions.
123-
Only a single loss function is supported."
124-
end
125-
return isone(length(model.loss.functions))
126-
end
127-
128-
# check that all models use the same single loss function
129-
function check_single_lossfun(models::AbstractSemSingle...; throw_error)
130-
uniform = true
131-
lossfun = models[1].loss.functions[1]
132-
L = typeof(lossfun)
133-
for (i, model) in enumerate(models)
134-
uniform &= check_single_lossfun(model; throw_error = throw_error)
135-
cur_lossfun = model.loss.functions[1]
136-
if !isa(cur_lossfun, L) & throw_error
137-
@error "Loss function for group #$i model is $(typeof(cur_lossfun)), expected $L.
138-
Heterogeneous loss functions are not supported."
139-
end
140-
uniform &= isa(cur_lossfun, L)
141-
end
142-
return uniform
143-
end
144-
145-
check_single_lossfun(model::SemEnsemble; throw_error) =
146-
check_single_lossfun(model.sems...; throw_error)
147-
148119
# scaling corrections for multigroup models
149120
mg_correction(::SemFIML) = 0
150121
mg_correction(::SemML) = 0

src/frontend/fit/fitmeasures/RMSEA.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ RMSEA_corr_scale(::Type{<:SemML}) = -1
2727
RMSEA_corr_scale(::Type{<:SemWLS}) = -1
2828

2929
function RMSEA(fit::SemFit, model::AbstractSem)
30-
term_type = check_single_lossfun(model; throw_error = true)
30+
term_type = check_same_semterm_type(model; throw_error = true)
3131
n = nsamples(fit) + nsem_terms(model) * RMSEA_corr_scale(term_type)
3232
sqrt(nsem_terms(model)) * RMSEA(dof(fit), χ²(fit), n)
3333
end

src/frontend/fit/fitmeasures/chi2.jl

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,10 @@ with the *observed* covariance matrix.
1414

1515
function χ²(fit::SemFit, model::AbstractSem)
1616
terms = sem_terms(model)
17-
isempty(terms) && return 0.0
17+
@assert !isempty(terms)
1818

19-
term1 = _unwrap(loss(terms[1]))
20-
L = typeof(term1).name
21-
22-
# check that all SemLoss terms are of the same class (ML, FIML, WLS etc), ignore typeparams
23-
for (i, term) in enumerate(terms)
24-
lossterm = _unwrap(loss(term))
25-
@assert lossterm isa SemLoss
26-
if typeof(_unwrap(lossterm)).name != L
27-
@error "SemLoss term #$i is $(typeof(_unwrap(lossterm)).name), expected $L. Heterogeneous loss functions are not supported"
28-
end
29-
end
30-
31-
return χ²(typeof(term1), fit, model)
19+
L = check_same_semterm_type(model; throw_error = true)
20+
return χ²(L, fit, model)
3221
end
3322

3423
# bollen, p. 115, only correct for GLS weight matrix

src/frontend/fit/fitmeasures/minus2ll.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,6 @@ end
6262
############################################################################################
6363

6464
function minus2ll(model::AbstractSem, fit::SemFit)
65-
check_single_lossfun(model; throw_error = true)
65+
check_same_semterm_type(model; throw_error = true)
6666
sum(Base.Fix2(minus2ll, fit) _unwrap loss, sem_terms(model))
6767
end

src/frontend/specification/Sem.jl

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ end
5151

5252
function multigroup_weights(models, n)
5353
nsamples_total = sum(nsamples, models)
54-
uniform_lossfun = check_single_lossfun(models...; throw_error = false)
55-
if !uniform_lossfun
54+
semloss_type = check_same_semterm_type(semterms; throw_error = false)
55+
if isnothing(semloss_type)
5656
@info """
5757
Your ensemble model contains heterogeneous loss functions.
5858
Default weights of (#samples per group/#total samples) will be used
@@ -258,6 +258,49 @@ function sem_term(model::AbstractSem, _::Nothing = nothing)
258258
error("Unreachable reached")
259259
end
260260

261+
# check that all models use the same single loss function
262+
# returns the type of the single SEM loss function, SemLoss if there are multiple different SEM losses,
263+
# nothing if there are no SEM terms.
264+
# If throw_error=true, throws an error if there are multiple different SEM loss functions
265+
check_same_semterm_type(model::AbstractSem; throw_error::Bool = true) =
266+
check_same_semterm_type(sem_terms(model); throw_error = throw_error)
267+
268+
# check that all models use the same single loss function
269+
# returns the type of the single SEM loss function,
270+
# nothing if there are multiple different SEM losses or no SEM terms.
271+
# If throw_error=true, throws an error if there are multiple different SEM loss functions
272+
function check_same_semterm_type(terms::Tuple; throw_error::Bool = true)
273+
isempty(terms) && return nothing
274+
275+
_semloss(term::SemLoss) = _unwrap(term)
276+
_semloss(term::LossTerm) = _semloss(loss(term))
277+
_semloss(term) = throw(ArgumentError("SemLoss term expected, $(typeof(term)) found"))
278+
_semloss_label(i::Integer, _::Union{SemLoss, LossTerm{<:SemLoss, Nothing}}) = "#$i"
279+
_semloss_label(i::Integer, term::LossTerm{<:SemLoss, Symbol}) = "#$i ($(SEM.id(term)))"
280+
281+
term1 = _semloss(terms[1])
282+
L = typeof(term1).name
283+
284+
# check that all SemLoss terms are of the same class (ML, FIML, WLS etc), ignore typeparams
285+
for (i, term) in enumerate(terms)
286+
lossterm = _semloss(term)
287+
@assert lossterm isa SemLoss
288+
if typeof(lossterm).name != L
289+
if throw_error
290+
error(
291+
"SemLoss term $(_semloss_label(i, term)) is $(typeof(lossterm).name), expected $L. Heterogeneous loss functions are not supported",
292+
)
293+
else
294+
return nothing
295+
end
296+
end
297+
end
298+
299+
# return the type of the first SEM term
300+
# note that type params of the SEM terms might be different
301+
return typeof(term1)
302+
end
303+
261304
# wrappers arounds a single SemLoss term
262305
observed(model::AbstractSem, id::Nothing = nothing) = observed(sem_term(model, id))
263306
implied(model::AbstractSem, id::Nothing = nothing) = implied(sem_term(model, id))

0 commit comments

Comments
 (0)