@@ -45,32 +45,46 @@ function Base.show(io::IO, term::LossTerm)
4545 end
4646end
4747
48- # ###########################################################################################
49- # constructor for Sem types
50- # ###########################################################################################
48+ # scaling corrections for multigroup models
49+
50+ # fallback method for non-standard SemLoss type
51+ multigroup_correction_scale (:: Type{<:SemLoss} ) = nothing
5152
52- function multigroup_weights (models, n)
53- nsamples_total = sum (nsamples, models)
53+ multigroup_correction_scale (:: Type{<:SemFIML} ) = 0
54+ multigroup_correction_scale (:: Type{<:SemML} ) = 0
55+ multigroup_correction_scale (:: Type{<:SemWLS} ) = - 1
56+
57+ multigroup_correction_scale (loss:: SemLoss ) = multigroup_correction_scale (typeof (loss))
58+
59+ # calculate sem term weights for multigroup models
60+ # correcting for the number of samples and the loss type
61+ function multigroup_weights (semterms... )
62+ n = length (semterms)
63+ nsamples_total = sum (nsamples, semterms)
5464 semloss_type = check_same_semterm_type (semterms; throw_error = false )
5565 if isnothing (semloss_type)
5666 @info """
5767 Your ensemble model contains heterogeneous loss functions.
5868 Default weights of (#samples per group/#total samples) will be used
5969 """
60- return [(nsamples (model)) / (nsamples_total) for model in models]
61- end
62- lossfun = models[1 ]. loss. functions[1 ]
63- if ! applicable (mg_correction, lossfun)
64- @info """
65- We don't know how to choose group weights for the specified loss function.
66- Default weights of (#samples per group/#total samples) will be used
67- """
68- return [(nsamples (model)) / (nsamples_total) for model in models]
70+ c = 0
71+ else
72+ c = multigroup_correction_scale (semloss_type)
73+ if isnothing (c)
74+ @info """
75+ We don't know how to choose group weights for the specified loss function.
76+ Default weights of (#samples per group/#total samples) will be used
77+ """
78+ c = 0
79+ end
6980 end
70- c = mg_correction (lossfun)
71- return [(nsamples (model)+ c) / (nsamples_total+ n* c) for model in models]
81+ return [(nsamples (term)+ c) / (nsamples_total+ n* c) for term in semterms]
7282end
7383
84+ # ###########################################################################################
85+ # constructor for Sem types
86+ # ###########################################################################################
87+
7488function Sem (
7589 loss_terms... ;
7690 params:: Union{Vector{Symbol}, Nothing} = nothing ,
0 commit comments