Skip to content

Commit 1db3106

Browse files
adapt default multigroup weights and give info about defaults used
1 parent 57ec89b commit 1db3106

1 file changed

Lines changed: 20 additions & 5 deletions

File tree

src/types.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,7 @@ end
192192
function SemEnsemble(models...; weights = nothing, groups = nothing, kwargs...)
193193
n = length(models)
194194
# default weights
195-
if isnothing(weights)
196-
nsamples_total = sum(nsamples, models)
197-
weights = [nsamples(model) / nsamples_total for model in models]
198-
end
195+
weights = isnothing(weights) ? multigroup_weights(models, n) : weights
199196
# default group labels
200197
groups = isnothing(groups) ? Symbol.(:g, 1:n) : groups
201198
# check parameters equality
@@ -226,7 +223,25 @@ function SemEnsemble(; specification, data, groups, column = :group, kwargs...)
226223
model = Sem(; specification = ram_matrices, data = data_group, kwargs...)
227224
push!(models, model)
228225
end
229-
return SemEnsemble(models...; weights = nothing, groups = groups, kwargs...)
226+
return SemEnsemble(models...; groups = groups, kwargs...)
227+
end
228+
229+
function multigroup_weights(models, n)
230+
nsamples_total = sum(nsamples, models)
231+
uniform_lossfun = check_single_lossfun(models...; throw_error = false)
232+
if !uniform_lossfun
233+
@info "Your ensemble model contains heterogeneous loss functions.
234+
Default weights of (#samples per group/#total samples) will be used".
235+
return [(nsamples(model)) / (nsamples_total) for model in models]
236+
end
237+
lossfun = models[1].loss.functions[1]
238+
if !applicable(dof_correction, lossfun)
239+
@info "We don't know how to choose group weights for the specified loss function.
240+
Default weights of (#samples per group/#total samples) will be used".
241+
return [(nsamples(model)) / (nsamples_total) for model in models]
242+
end
243+
dc = dof_correction(lossfun)
244+
return [(nsamples(model)-dc) / (nsamples_total-n*dc) for model in models]
230245
end
231246

232247
param_labels(ensemble::SemEnsemble) = ensemble.param_labels

0 commit comments

Comments
 (0)