Skip to content

Commit 961a3c8

Browse files
author
Alexey Stukalov
committed
update multi-group correction
deduplicate the correction scale methods and move to Sem.jl
1 parent f7f7452 commit 961a3c8

3 files changed

Lines changed: 31 additions & 27 deletions

File tree

src/additional_functions/helper.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,3 @@ function nonunique(values::AbstractVector)
115115
end
116116
return res
117117
end
118-
119-
# scaling corrections for multigroup models
120-
mg_correction(::SemFIML) = 0
121-
mg_correction(::SemML) = 0
122-
mg_correction(::SemWLS) = -1

src/frontend/fit/fitmeasures/RMSEA.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,9 @@ For multigroup models, the correction proposed by J.H. Steiger is applied
2121
"""
2222
RMSEA(fit::SemFit) = RMSEA(fit, fit.model)
2323

24-
# scaling corrections
25-
RMSEA_corr_scale(::Type{<:SemFIML}) = 0
26-
RMSEA_corr_scale(::Type{<:SemML}) = -1
27-
RMSEA_corr_scale(::Type{<:SemWLS}) = -1
28-
2924
function RMSEA(fit::SemFit, model::AbstractSem)
3025
term_type = check_same_semterm_type(model; throw_error = true)
31-
n = nsamples(fit) + nsem_terms(model) * RMSEA_corr_scale(term_type)
26+
n = nsamples(fit) + nsem_terms(model) * multigroup_correction_scale(term_type)
3227
sqrt(nsem_terms(model)) * RMSEA(dof(fit), χ²(fit), n)
3328
end
3429

src/frontend/specification/Sem.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,46 @@ function Base.show(io::IO, term::LossTerm)
4545
end
4646
end
4747

48-
############################################################################################
49-
# constructor for Sem types
50-
############################################################################################
48+
# scaling corrections for multigroup models
49+
50+
# fallback method for non-standard SemLoss type
51+
multigroup_correction_scale(::Type{<:SemLoss}) = nothing
5152

52-
function multigroup_weights(models, n)
53-
nsamples_total = sum(nsamples, models)
53+
multigroup_correction_scale(::Type{<:SemFIML}) = 0
54+
multigroup_correction_scale(::Type{<:SemML}) = 0
55+
multigroup_correction_scale(::Type{<:SemWLS}) = -1
56+
57+
multigroup_correction_scale(loss::SemLoss) = multigroup_correction_scale(typeof(loss))
58+
59+
# calculate sem term weights for multigroup models
60+
# correcting for the number of samples and the loss type
61+
function multigroup_weights(semterms...)
62+
n = length(semterms)
63+
nsamples_total = sum(nsamples, semterms)
5464
semloss_type = check_same_semterm_type(semterms; throw_error = false)
5565
if isnothing(semloss_type)
5666
@info """
5767
Your ensemble model contains heterogeneous loss functions.
5868
Default weights of (#samples per group/#total samples) will be used
5969
"""
60-
return [(nsamples(model)) / (nsamples_total) for model in models]
61-
end
62-
lossfun = models[1].loss.functions[1]
63-
if !applicable(mg_correction, lossfun)
64-
@info """
65-
We don't know how to choose group weights for the specified loss function.
66-
Default weights of (#samples per group/#total samples) will be used
67-
"""
68-
return [(nsamples(model)) / (nsamples_total) for model in models]
70+
c = 0
71+
else
72+
c = multigroup_correction_scale(semloss_type)
73+
if isnothing(c)
74+
@info """
75+
We don't know how to choose group weights for the specified loss function.
76+
Default weights of (#samples per group/#total samples) will be used
77+
"""
78+
c = 0
79+
end
6980
end
70-
c = mg_correction(lossfun)
71-
return [(nsamples(model)+c) / (nsamples_total+n*c) for model in models]
81+
return [(nsamples(term)+c) / (nsamples_total+n*c) for term in semterms]
7282
end
7383

84+
############################################################################################
85+
# constructor for Sem types
86+
############################################################################################
87+
7488
function Sem(
7589
loss_terms...;
7690
params::Union{Vector{Symbol}, Nothing} = nothing,

0 commit comments

Comments
 (0)