@@ -62,7 +62,7 @@ multigroup_correction_scale(loss::SemLoss) =
6262function multigroup_weights (semterms... )
6363 n = length (semterms)
6464 nsamples_total = sum (nsamples, semterms)
65- semloss_type = check_single_lossfun (semterms; throw_error = false )
65+ semloss_type = check_same_semterm_type (semterms; throw_error = false )
6666 if isnothing (semloss_type)
6767 @info """
6868 Your ensemble model contains heterogeneous loss functions.
@@ -274,32 +274,46 @@ function sem_term(model::AbstractSem, id::Nothing = nothing)
274274 error (" Unreachable reached" )
275275end
276276
277- # check that all SEM terms in the model use the same loss type
278- function check_same_semloss_type (model:: AbstractSem ; throw_error:: Bool = true )
279- terms = sem_terms (model)
277+ # check that all models use the same single loss function
278+ # returns the type of the single SEM loss function, SemLoss if there are multiple different SEM losses,
279+ # nothing if there are no SEM terms.
280+ # If throw_error=true, throws an error if there are multiple different SEM loss functions
281+ check_same_semterm_type (model:: AbstractSem ; throw_error:: Bool = true ) =
282+ check_same_semterm_type (sem_terms (model); throw_error = throw_error)
283+
284+ # check that all models use the same single loss function
285+ # returns the type of the single SEM loss function,
286+ # nothing if there are multiple different SEM losses or no SEM terms.
287+ # If throw_error=true, throws an error if there are multiple different SEM loss functions
288+ function check_same_semterm_type (terms:: Tuple ; throw_error:: Bool = true )
280289 isempty (terms) && return nothing
281290
282- L = typeof (lossfun)
283- end
284-
285- function check_single_lossfun (models:: AbstractSemSingle... ; throw_error)
286- uniform = true
287- lossfun = models[1 ]. loss. functions[1 ]
288- L = typeof (lossfun)
289- for (i, model) in enumerate (models)
290- uniform &= check_single_lossfun (model; throw_error = throw_error)
291- cur_lossfun = model. loss. functions[1 ]
292- if ! isa (cur_lossfun, L) & throw_error
293- @error " Loss function for group #$i model is $(typeof (cur_lossfun)) , expected $L .
294- Heterogeneous loss functions are not supported."
291+ _semloss (term:: SemLoss ) = _unwrap (term)
292+ _semloss (term:: LossTerm ) = _semloss (loss (term))
293+ _semloss (term) = throw (ArgumentError (" SemLoss term expected, $(typeof (term)) found" ))
294+ _semloss_label (i:: Integer , _:: Union{SemLoss, LossTerm{<:SemLoss, Nothing}} ) = " #$i "
295+ _semloss_label (i:: Integer , term:: LossTerm{<:SemLoss, Symbol} ) = " #$i ($(SEM. id (term)) )"
296+
297+ term1 = _semloss (terms[1 ])
298+ L = typeof (term1). name
299+
300+ # check that all SemLoss terms are of the same class (ML, FIML, WLS etc), ignore typeparams
301+ for (i, term) in enumerate (terms)
302+ lossterm = _semloss (term)
303+ @assert lossterm isa SemLoss
304+ if typeof (lossterm). name != L
305+ if throw_error
306+ error (" SemLoss term $(_semloss_label (i, term)) is $(typeof (lossterm). name) , expected $L . Heterogeneous loss functions are not supported" )
307+ else
308+ return nothing
309+ end
295310 end
296- uniform &= isa (cur_lossfun, L)
297311 end
298- return uniform
299- end
300312
301- check_single_lossfun (model:: SemEnsemble ; throw_error) =
302- check_single_lossfun (model. sems... ; throw_error)
313+ # return the type of the first SEM term
314+ # note that type params of the SEM terms might be different
315+ return typeof (term1)
316+ end
303317
304318# wrappers arounds a single SemLoss term
305319observed (model:: AbstractSem , id:: Nothing = nothing ) = observed (sem_term (model, id))
0 commit comments