Skip to content

Commit 245ef87

Browse files
author
Alexey Stukalov
committed
check_same_semterm_type(): refactor check_single_lossfun()
1 parent 05ede2e commit 245ef87

4 files changed

Lines changed: 39 additions & 25 deletions

File tree

src/frontend/fit/fitmeasures/RMSEA.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ For multigroup models, the correction proposed by J.H. Steiger is applied
2222
RMSEA(fit::SemFit) = RMSEA(fit, fit.model)
2323

2424
function RMSEA(fit::SemFit, model::AbstractSem)
25-
term_type = check_single_lossfun(model; throw_error = true)
25+
term_type = check_same_semterm_type(model; throw_error = true)
2626
n = nsamples(fit) + nsem_terms(model) * multigroup_correction_scale(term_type)
2727
sqrt(nsem_terms(model)) * RMSEA(dof(fit), χ²(fit), n)
2828
end

src/frontend/fit/fitmeasures/chi2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function χ²(fit::SemFit, model::AbstractSem)
1616
terms = sem_terms(model)
1717
isempty(terms) && return 0.0
1818

19-
L = check_single_lossfun(model; throw_error = true)
19+
L = check_same_semterm_type(model; throw_error = true)
2020
term1 = _unwrap(loss(terms[1]))
2121
L = typeof(term1).name
2222

src/frontend/fit/fitmeasures/minus2ll.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,6 @@ end
6464
############################################################################################
6565

6666
function minus2ll(model::AbstractSem, fit::SemFit)
67-
check_single_lossfun(model; throw_error = true)
67+
check_same_semterm_type(model; throw_error = true)
6868
sum(Base.Fix2(minus2ll, fit) _unwrap loss, sem_terms(model))
6969
end

src/frontend/specification/Sem.jl

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ multigroup_correction_scale(loss::SemLoss) =
6262
function multigroup_weights(semterms...)
6363
n = length(semterms)
6464
nsamples_total = sum(nsamples, semterms)
65-
semloss_type = check_single_lossfun(semterms; throw_error = false)
65+
semloss_type = check_same_semterm_type(semterms; throw_error = false)
6666
if isnothing(semloss_type)
6767
@info """
6868
Your ensemble model contains heterogeneous loss functions.
@@ -274,32 +274,46 @@ function sem_term(model::AbstractSem, id::Nothing = nothing)
274274
error("Unreachable reached")
275275
end
276276

277-
# check that all SEM terms in the model use the same loss type
278-
function check_same_semloss_type(model::AbstractSem; throw_error::Bool = true)
279-
terms = sem_terms(model)
277+
# check that all models use the same single loss function
278+
# returns the type of the single SEM loss function, SemLoss if there are multiple different SEM losses,
279+
# nothing if there are no SEM terms.
280+
# If throw_error=true, throws an error if there are multiple different SEM loss functions
281+
check_same_semterm_type(model::AbstractSem; throw_error::Bool = true) =
282+
check_same_semterm_type(sem_terms(model); throw_error = throw_error)
283+
284+
# check that all models use the same single loss function
285+
# returns the type of the single SEM loss function,
286+
# nothing if there are multiple different SEM losses or no SEM terms.
287+
# If throw_error=true, throws an error if there are multiple different SEM loss functions
288+
function check_same_semterm_type(terms::Tuple; throw_error::Bool = true)
280289
isempty(terms) && return nothing
281290

282-
L = typeof(lossfun)
283-
end
284-
285-
function check_single_lossfun(models::AbstractSemSingle...; throw_error)
286-
uniform = true
287-
lossfun = models[1].loss.functions[1]
288-
L = typeof(lossfun)
289-
for (i, model) in enumerate(models)
290-
uniform &= check_single_lossfun(model; throw_error = throw_error)
291-
cur_lossfun = model.loss.functions[1]
292-
if !isa(cur_lossfun, L) & throw_error
293-
@error "Loss function for group #$i model is $(typeof(cur_lossfun)), expected $L.
294-
Heterogeneous loss functions are not supported."
291+
_semloss(term::SemLoss) = _unwrap(term)
292+
_semloss(term::LossTerm) = _semloss(loss(term))
293+
_semloss(term) = throw(ArgumentError("SemLoss term expected, $(typeof(term)) found"))
294+
_semloss_label(i::Integer, _::Union{SemLoss, LossTerm{<:SemLoss, Nothing}}) = "#$i"
295+
_semloss_label(i::Integer, term::LossTerm{<:SemLoss, Symbol}) = "#$i ($(SEM.id(term)))"
296+
297+
term1 = _semloss(terms[1])
298+
L = typeof(term1).name
299+
300+
# check that all SemLoss terms are of the same class (ML, FIML, WLS etc), ignore typeparams
301+
for (i, term) in enumerate(terms)
302+
lossterm = _semloss(term)
303+
@assert lossterm isa SemLoss
304+
if typeof(lossterm).name != L
305+
if throw_error
306+
error("SemLoss term $(_semloss_label(i, term)) is $(typeof(lossterm).name), expected $L. Heterogeneous loss functions are not supported")
307+
else
308+
return nothing
309+
end
295310
end
296-
uniform &= isa(cur_lossfun, L)
297311
end
298-
return uniform
299-
end
300312

301-
check_single_lossfun(model::SemEnsemble; throw_error) =
302-
check_single_lossfun(model.sems...; throw_error)
313+
# return the type of the first SEM term
314+
# note that type params of the SEM terms might be different
315+
return typeof(term1)
316+
end
303317

304318
# wrappers arounds a single SemLoss term
305319
observed(model::AbstractSem, id::Nothing = nothing) = observed(sem_term(model, id))

0 commit comments

Comments
 (0)