@@ -45,32 +45,47 @@ function Base.show(io::IO, term::LossTerm)
4545 end
4646end
4747
48- # ###########################################################################################
49- # constructor for Sem types
50- # ###########################################################################################
48+ # scaling corrections for multigroup models
49+
50+ # fallback method for non-standard SemLoss type
51+ multigroup_correction_scale (:: Type{<:SemLoss} ) = nothing
5152
52- function multigroup_weights (models, n)
53- nsamples_total = sum (nsamples, models)
53+ multigroup_correction_scale (:: Type{<:SemFIML} ) = 0
54+ multigroup_correction_scale (:: Type{<:SemML} ) = 0
55+ multigroup_correction_scale (:: Type{<:SemWLS} ) = - 1
56+
57+ multigroup_correction_scale (loss:: SemLoss ) =
58+ multigroup_correction_scale (typeof (loss))
59+
60+ # calculate sem term weights for multigroup models
61+ # correcting for the number of samples and the loss type
62+ function multigroup_weights (semterms... )
63+ n = length (semterms)
64+ nsamples_total = sum (nsamples, semterms)
5465 semloss_type = check_same_semterm_type (semterms; throw_error = false )
5566 if isnothing (semloss_type)
5667 @info """
5768 Your ensemble model contains heterogeneous loss functions.
5869 Default weights of (#samples per group/#total samples) will be used
5970 """
60- return [(nsamples (model)) / (nsamples_total) for model in models]
61- end
62- lossfun = models[1 ]. loss. functions[1 ]
63- if ! applicable (mg_correction, lossfun)
64- @info """
65- We don't know how to choose group weights for the specified loss function.
66- Default weights of (#samples per group/#total samples) will be used
67- """
68- return [(nsamples (model)) / (nsamples_total) for model in models]
71+ c = 0
72+ else
73+ c = multigroup_correction_scale (semloss_type)
74+ if isnothing (c)
75+ @info """
76+ We don't know how to choose group weights for the specified loss function.
77+ Default weights of (#samples per group/#total samples) will be used
78+ """
79+ c = 0
80+ end
6981 end
70- c = mg_correction (lossfun)
71- return [(nsamples (model)+ c) / (nsamples_total+ n* c) for model in models]
82+ return [(nsamples (term)+ c) / (nsamples_total+ n* c) for term in semterms]
7283end
7384
85+ # ###########################################################################################
86+ # constructor for Sem types
87+ # ###########################################################################################
88+
7489function Sem (
7590 loss_terms... ;
7691 params:: Union{Vector{Symbol}, Nothing} = nothing ,
0 commit comments