|
192 | 192 | function SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...) |
193 | 193 | n = length(models) |
194 | 194 | # default weights |
195 | | - if isnothing(weights) |
196 | | - nsamples_total = sum(nsamples, models) |
197 | | - weights = [nsamples(model) / nsamples_total for model in models] |
198 | | - end |
| 195 | + weights = isnothing(weights) ? multigroup_weights(models, n) : weights |
199 | 196 | # default group labels |
200 | 197 | groups = isnothing(groups) ? Symbol.(:g, 1:n) : groups |
201 | 198 | # check parameters equality |
@@ -226,7 +223,25 @@ function SemEnsemble(; specification, data, groups, column = :group, kwargs...) |
226 | 223 | model = Sem(; specification = ram_matrices, data = data_group, kwargs...) |
227 | 224 | push!(models, model) |
228 | 225 | end |
229 | | - return SemEnsemble(models...; weights = nothing, groups = groups, kwargs...) |
| 226 | + return SemEnsemble(models...; groups = groups, kwargs...) |
| 227 | +end |
| 228 | + |
| 229 | +function multigroup_weights(models, n) |
| 230 | + nsamples_total = sum(nsamples, models) |
| 231 | + uniform_lossfun = check_single_lossfun(models...; throw_error = false) |
| 232 | + if !uniform_lossfun |
| 233 | + @info "Your ensemble model contains heterogeneous loss functions. |
| 234 | + Default weights of (#samples per group/#total samples) will be used". |
| 235 | + return [(nsamples(model)) / (nsamples_total) for model in models] |
| 236 | + end |
| 237 | + lossfun = models[1].loss.functions[1] |
| 238 | + if !applicable(dof_correction, lossfun) |
| 239 | + @info "We don't know how to choose group weights for the specified loss function. |
| 240 | + Default weights of (#samples per group/#total samples) will be used". |
| 241 | + return [(nsamples(model)) / (nsamples_total) for model in models] |
| 242 | + end |
| 243 | + dc = dof_correction(lossfun) |
| 244 | + return [(nsamples(model)-dc) / (nsamples_total-n*dc) for model in models] |
230 | 245 | end |
231 | 246 |
|
232 | 247 | param_labels(ensemble::SemEnsemble) = ensemble.param_labels |
|
0 commit comments