Skip to content

Commit 66bbb36

Browse files
author
Alexey Stukalov
committed
check_same_semterm_type(): refactor check_single_lossfun()
1 parent 73b015c commit 66bbb36

5 files changed

Lines changed: 46 additions & 34 deletions

File tree

src/additional_functions/helper.jl

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -116,35 +116,6 @@ function nonunique(values::AbstractVector)
116116
return res
117117
end
118118

119-
# check that a model only has a single lossfun
120-
function check_single_lossfun(model::AbstractSemSingle; throw_error)
121-
if (length(model.loss.functions) > 1) & throw_error
122-
@error "The model has $(length(sem.loss.functions)) loss functions.
123-
Only a single loss function is supported."
124-
end
125-
return isone(length(model.loss.functions))
126-
end
127-
128-
# check that all models use the same single loss function
129-
function check_single_lossfun(models::AbstractSemSingle...; throw_error)
130-
uniform = true
131-
lossfun = models[1].loss.functions[1]
132-
L = typeof(lossfun)
133-
for (i, model) in enumerate(models)
134-
uniform &= check_single_lossfun(model; throw_error = throw_error)
135-
cur_lossfun = model.loss.functions[1]
136-
if !isa(cur_lossfun, L) & throw_error
137-
@error "Loss function for group #$i model is $(typeof(cur_lossfun)), expected $L.
138-
Heterogeneous loss functions are not supported."
139-
end
140-
uniform &= isa(cur_lossfun, L)
141-
end
142-
return uniform
143-
end
144-
145-
check_single_lossfun(model::SemEnsemble; throw_error) =
146-
check_single_lossfun(model.sems...; throw_error)
147-
148119
# scaling corrections for multigroup models
149120
mg_correction(::SemFIML) = 0
150121
mg_correction(::SemML) = 0

src/frontend/fit/fitmeasures/RMSEA.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ RMSEA_corr_scale(::Type{<:SemML}) = -1
2727
RMSEA_corr_scale(::Type{<:SemWLS}) = -1
2828

2929
function RMSEA(fit::SemFit, model::AbstractSem)
30-
term_type = check_single_lossfun(model; throw_error = true)
30+
term_type = check_same_semterm_type(model; throw_error = true)
3131
n = nsamples(fit) + nsem_terms(model) * RMSEA_corr_scale(term_type)
3232
sqrt(nsem_terms(model)) * RMSEA(dof(fit), χ²(fit), n)
3333
end

src/frontend/fit/fitmeasures/chi2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function χ²(fit::SemFit, model::AbstractSem)
1616
terms = sem_terms(model)
1717
isempty(terms) && return 0.0
1818

19-
L = check_single_lossfun(model; throw_error = true)
19+
L = check_same_semterm_type(model; throw_error = true)
2020
term1 = _unwrap(loss(terms[1]))
2121
L = typeof(term1).name
2222

src/frontend/fit/fitmeasures/minus2ll.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,6 @@ end
6262
############################################################################################
6363

6464
function minus2ll(model::AbstractSem, fit::SemFit)
65-
check_single_lossfun(model; throw_error = true)
65+
check_same_semterm_type(model; throw_error = true)
6666
sum(Base.Fix2(minus2ll, fit) _unwrap loss, sem_terms(model))
6767
end

src/frontend/specification/Sem.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ end
5151

5252
function multigroup_weights(models, n)
5353
nsamples_total = sum(nsamples, models)
54-
uniform_lossfun = check_single_lossfun(models...; throw_error = false)
55-
if !uniform_lossfun
54+
semloss_type = check_same_semterm_type(semterms; throw_error = false)
55+
if isnothing(semloss_type)
5656
@info """
5757
Your ensemble model contains heterogeneous loss functions.
5858
Default weights of (#samples per group/#total samples) will be used
@@ -260,6 +260,47 @@ function sem_term(model::AbstractSem, _::Nothing = nothing)
260260
error("Unreachable reached")
261261
end
262262

263+
# check that all models use the same single loss function
264+
# returns the type of the single SEM loss function, SemLoss if there are multiple different SEM losses,
265+
# nothing if there are no SEM terms.
266+
# If throw_error=true, throws an error if there are multiple different SEM loss functions
267+
check_same_semterm_type(model::AbstractSem; throw_error::Bool = true) =
268+
check_same_semterm_type(sem_terms(model); throw_error = throw_error)
269+
270+
# check that all models use the same single loss function
271+
# returns the type of the single SEM loss function,
272+
# nothing if there are multiple different SEM losses or no SEM terms.
273+
# If throw_error=true, throws an error if there are multiple different SEM loss functions
274+
function check_same_semterm_type(terms::Tuple; throw_error::Bool = true)
275+
isempty(terms) && return nothing
276+
277+
_semloss(term::SemLoss) = _unwrap(term)
278+
_semloss(term::LossTerm) = _semloss(loss(term))
279+
_semloss(term) = throw(ArgumentError("SemLoss term expected, $(typeof(term)) found"))
280+
_semloss_label(i::Integer, _::Union{SemLoss, LossTerm{<:SemLoss, Nothing}}) = "#$i"
281+
_semloss_label(i::Integer, term::LossTerm{<:SemLoss, Symbol}) = "#$i ($(SEM.id(term)))"
282+
283+
term1 = _semloss(terms[1])
284+
L = typeof(term1).name
285+
286+
# check that all SemLoss terms are of the same class (ML, FIML, WLS etc), ignore typeparams
287+
for (i, term) in enumerate(terms)
288+
lossterm = _semloss(term)
289+
@assert lossterm isa SemLoss
290+
if typeof(lossterm).name != L
291+
if throw_error
292+
error("SemLoss term $(_semloss_label(i, term)) is $(typeof(lossterm).name), expected $L. Heterogeneous loss functions are not supported")
293+
else
294+
return nothing
295+
end
296+
end
297+
end
298+
299+
# return the type of the first SEM term
300+
# note that type params of the SEM terms might be different
301+
return typeof(term1)
302+
end
303+
263304
# wrappers arounds a single SemLoss term
264305
observed(model::AbstractSem, id::Nothing = nothing) = observed(sem_term(model, id))
265306
implied(model::AbstractSem, id::Nothing = nothing) = implied(sem_term(model, id))

0 commit comments

Comments
 (0)