|
51 | 51 |
|
52 | 52 | function multigroup_weights(models, n) |
53 | 53 | nsamples_total = sum(nsamples, models) |
54 | | - uniform_lossfun = check_single_lossfun(models...; throw_error = false) |
55 | | - if !uniform_lossfun |
| 54 | + semloss_type = check_same_semterm_type(semterms; throw_error = false) |
| 55 | + if isnothing(semloss_type) |
56 | 56 | @info """ |
57 | 57 | Your ensemble model contains heterogeneous loss functions. |
58 | 58 | Default weights of (#samples per group/#total samples) will be used |
@@ -260,6 +260,47 @@ function sem_term(model::AbstractSem, _::Nothing = nothing) |
260 | 260 | error("Unreachable reached") |
261 | 261 | end |
262 | 262 |
|
| 263 | +# check that all models use the same single loss function |
| 264 | +# returns the type of the single SEM loss function, SemLoss if there are multiple different SEM losses, |
| 265 | +# nothing if there are no SEM terms. |
| 266 | +# If throw_error=true, throws an error if there are multiple different SEM loss functions |
| 267 | +check_same_semterm_type(model::AbstractSem; throw_error::Bool = true) = |
| 268 | + check_same_semterm_type(sem_terms(model); throw_error = throw_error) |
| 269 | + |
| 270 | +# check that all models use the same single loss function |
| 271 | +# returns the type of the single SEM loss function, |
| 272 | +# nothing if there are multiple different SEM losses or no SEM terms. |
| 273 | +# If throw_error=true, throws an error if there are multiple different SEM loss functions |
| 274 | +function check_same_semterm_type(terms::Tuple; throw_error::Bool = true) |
| 275 | + isempty(terms) && return nothing |
| 276 | + |
| 277 | + _semloss(term::SemLoss) = _unwrap(term) |
| 278 | + _semloss(term::LossTerm) = _semloss(loss(term)) |
| 279 | + _semloss(term) = throw(ArgumentError("SemLoss term expected, $(typeof(term)) found")) |
| 280 | + _semloss_label(i::Integer, _::Union{SemLoss, LossTerm{<:SemLoss, Nothing}}) = "#$i" |
| 281 | + _semloss_label(i::Integer, term::LossTerm{<:SemLoss, Symbol}) = "#$i ($(SEM.id(term)))" |
| 282 | + |
| 283 | + term1 = _semloss(terms[1]) |
| 284 | + L = typeof(term1).name |
| 285 | + |
| 286 | + # check that all SemLoss terms are of the same class (ML, FIML, WLS etc), ignore typeparams |
| 287 | + for (i, term) in enumerate(terms) |
| 288 | + lossterm = _semloss(term) |
| 289 | + @assert lossterm isa SemLoss |
| 290 | + if typeof(lossterm).name != L |
| 291 | + if throw_error |
| 292 | + error("SemLoss term $(_semloss_label(i, term)) is $(typeof(lossterm).name), expected $L. Heterogeneous loss functions are not supported") |
| 293 | + else |
| 294 | + return nothing |
| 295 | + end |
| 296 | + end |
| 297 | + end |
| 298 | + |
| 299 | + # return the type of the first SEM term |
| 300 | + # note that type params of the SEM terms might be different |
| 301 | + return typeof(term1) |
| 302 | +end |
| 303 | + |
263 | 304 | # wrappers arounds a single SemLoss term |
264 | 305 | observed(model::AbstractSem, id::Nothing = nothing) = observed(sem_term(model, id)) |
265 | 306 | implied(model::AbstractSem, id::Nothing = nothing) = implied(sem_term(model, id)) |
|
0 commit comments