Skip to content

Commit ac13a75

Browse files
author
Alexey Stukalov
committed
update multi-group correction
deduplicate the correction scale methods and move to Sem.jl
1 parent 66bbb36 commit ac13a75

3 files changed

Lines changed: 32 additions & 27 deletions

File tree

src/additional_functions/helper.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,3 @@ function nonunique(values::AbstractVector)
115115
end
116116
return res
117117
end
118-
119-
# scaling corrections for multigroup models
120-
mg_correction(::SemFIML) = 0
121-
mg_correction(::SemML) = 0
122-
mg_correction(::SemWLS) = -1

src/frontend/fit/fitmeasures/RMSEA.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,9 @@ For multigroup models, the correction proposed by J.H. Steiger is applied
2121
"""
2222
RMSEA(fit::SemFit) = RMSEA(fit, fit.model)
2323

24-
# scaling corrections
25-
RMSEA_corr_scale(::Type{<:SemFIML}) = 0
26-
RMSEA_corr_scale(::Type{<:SemML}) = -1
27-
RMSEA_corr_scale(::Type{<:SemWLS}) = -1
28-
2924
function RMSEA(fit::SemFit, model::AbstractSem)
3025
term_type = check_same_semterm_type(model; throw_error = true)
31-
n = nsamples(fit) + nsem_terms(model) * RMSEA_corr_scale(term_type)
26+
n = nsamples(fit) + nsem_terms(model) * multigroup_correction_scale(term_type)
3227
sqrt(nsem_terms(model)) * RMSEA(dof(fit), χ²(fit), n)
3328
end
3429

src/frontend/specification/Sem.jl

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,47 @@ function Base.show(io::IO, term::LossTerm)
4545
end
4646
end
4747

48-
############################################################################################
49-
# constructor for Sem types
50-
############################################################################################
48+
# scaling corrections for multigroup models
49+
50+
# fallback method for non-standard SemLoss type
51+
multigroup_correction_scale(::Type{<:SemLoss}) = nothing
5152

52-
function multigroup_weights(models, n)
53-
nsamples_total = sum(nsamples, models)
53+
multigroup_correction_scale(::Type{<:SemFIML}) = 0
54+
multigroup_correction_scale(::Type{<:SemML}) = 0
55+
multigroup_correction_scale(::Type{<:SemWLS}) = -1
56+
57+
multigroup_correction_scale(loss::SemLoss) =
58+
multigroup_correction_scale(typeof(loss))
59+
60+
# calculate sem term weights for multigroup models
61+
# correcting for the number of samples and the loss type
62+
function multigroup_weights(semterms...)
63+
n = length(semterms)
64+
nsamples_total = sum(nsamples, semterms)
5465
semloss_type = check_same_semterm_type(semterms; throw_error = false)
5566
if isnothing(semloss_type)
5667
@info """
5768
Your ensemble model contains heterogeneous loss functions.
5869
Default weights of (#samples per group/#total samples) will be used
5970
"""
60-
return [(nsamples(model)) / (nsamples_total) for model in models]
61-
end
62-
lossfun = models[1].loss.functions[1]
63-
if !applicable(mg_correction, lossfun)
64-
@info """
65-
We don't know how to choose group weights for the specified loss function.
66-
Default weights of (#samples per group/#total samples) will be used
67-
"""
68-
return [(nsamples(model)) / (nsamples_total) for model in models]
71+
c = 0
72+
else
73+
c = multigroup_correction_scale(semloss_type)
74+
if isnothing(c)
75+
@info """
76+
We don't know how to choose group weights for the specified loss function.
77+
Default weights of (#samples per group/#total samples) will be used
78+
"""
79+
c = 0
80+
end
6981
end
70-
c = mg_correction(lossfun)
71-
return [(nsamples(model)+c) / (nsamples_total+n*c) for model in models]
82+
return [(nsamples(term)+c) / (nsamples_total+n*c) for term in semterms]
7283
end
7384

85+
############################################################################################
86+
# constructor for Sem types
87+
############################################################################################
88+
7489
function Sem(
7590
loss_terms...;
7691
params::Union{Vector{Symbol}, Nothing} = nothing,

0 commit comments

Comments
 (0)