Skip to content

Commit f7f28d7

Browse files
author
Alexey Stukalov
committed
replace_observed(): simplify & refactor
remove update_observed!()
1 parent 47199d0 commit f7f28d7

15 files changed

Lines changed: 132 additions & 223 deletions

File tree

ext/SEMNLOptExt/NLopt.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,6 @@ function SemOptimizerNLopt(;
107107
)
108108
end
109109

110-
############################################################################################
111-
### Recommended methods
112-
############################################################################################
113-
114-
SEM.update_observed(optimizer::SemOptimizerNLopt, observed::SemObserved; kwargs...) =
115-
optimizer
116-
117110
############################################################################################
118111
### additional methods
119112
############################################################################################

ext/SEMProximalOptExt/ProximalAlgorithms.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,6 @@ SemOptimizerProximal(;
3434

3535
SEM.sem_optimizer_subtype(::Val{:Proximal}) = SemOptimizerProximal
3636

37-
############################################################################################
38-
### Recommended methods
39-
############################################################################################
40-
41-
SEM.update_observed(optimizer::SemOptimizerProximal, observed::SemObserved; kwargs...) =
42-
optimizer
43-
4437
############################################################################
4538
### Model fitting
4639
############################################################################

src/additional_functions/simulation.jl

Lines changed: 0 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,3 @@
1-
"""
2-
(1) replace_observed(model::AbstractSemSingle; kwargs...)
3-
4-
(2) replace_observed(model::AbstractSemSingle, observed; kwargs...)
5-
6-
(3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...)
7-
8-
Return a new model with swaped observed part.
9-
10-
# Arguments
11-
- `model::AbstractSemSingle`: model to swap the observed part of.
12-
- `kwargs`: additional keyword arguments; typically includes `data` and `specification`
13-
- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`
14-
15-
# For SemEnsemble models:
16-
- `column`: if a DataFrame is passed as `data = ...`, which column signifies the group?
17-
- `weights`: how to weight the different sub-models,
18-
defaults to number of samples per group in the new data
19-
- `kwargs`: has to be a dict with keys equal to the group names.
20-
For `data` can also be a DataFrame with `column` containing the group information,
21-
and for `specification` can also be an `EnsembleParameterTable`.
22-
23-
# Examples
24-
See the online documentation on [Replace observed data](@ref).
25-
"""
26-
function replace_observed end
27-
28-
"""
29-
update_observed(to_update, observed::SemObserved; kwargs...)
30-
31-
Update a `SemImplied`, `SemLossFunction` or `SemOptimizer` object to use a `SemObserved` object.
32-
33-
# Examples
34-
See the online documentation on [Replace observed data](@ref).
35-
36-
# Implementation
37-
You can provide a method for this function when defining a new type, for more information
38-
on this see the online developer documentation on [Update observed data](@ref).
39-
"""
40-
function update_observed end
41-
42-
############################################################################################
43-
# change observed (data) without reconstructing the whole model
44-
############################################################################################
45-
46-
# don't change non-SEM terms
47-
replace_observed(loss::AbstractLoss; kwargs...) = loss
48-
49-
# use the same observed type as before
50-
replace_observed(loss::SemLoss; kwargs...) =
51-
replace_observed(loss, typeof(SEM.observed(loss)).name.wrapper; kwargs...)
52-
53-
# construct a new observed type
54-
replace_observed(loss::SemLoss, observed_type; kwargs...) =
55-
replace_observed(loss, observed_type(; kwargs...); kwargs...)
56-
57-
function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...)
58-
kwargs = Dict{Symbol, Any}(kwargs...)
59-
old_observed = SEM.observed(loss)
60-
implied = SEM.implied(loss)
61-
62-
# get field types
63-
kwargs[:observed_type] = typeof(new_observed)
64-
kwargs[:old_observed_type] = typeof(old_observed)
65-
66-
# update implied
67-
new_implied = update_observed(implied, new_observed; kwargs...)
68-
kwargs[:implied] = new_implied
69-
kwargs[:implied_type] = typeof(new_implied)
70-
kwargs[:nparams] = nparams(new_implied)
71-
72-
# update loss
73-
return update_observed(loss, new_observed; kwargs...)
74-
end
75-
76-
replace_observed(loss::LossTerm; kwargs...) =
77-
LossTerm(replace_observed(loss.loss; kwargs...), loss.id, loss.weight)
78-
79-
function replace_observed(sem::Sem; kwargs...)
80-
updated_terms = Tuple(replace_observed(term; kwargs...) for term in loss_terms(sem))
81-
return Sem(updated_terms...)
82-
end
83-
84-
function replace_observed(
85-
emodel::SemEnsemble;
86-
column = :group,
87-
weights = nothing,
88-
kwargs...,
89-
)
90-
kwargs = Dict{Symbol, Any}(kwargs...)
91-
# allow for EnsembleParameterTable to be passed as specification
92-
if haskey(kwargs, :specification) && isa(kwargs[:specification], EnsembleParameterTable)
93-
kwargs[:specification] = convert(Dict{Symbol, RAMMatrices}, kwargs[:specification])
94-
end
95-
# allow for DataFrame with group variable "column" to be passed as new data
96-
if haskey(kwargs, :data) && isa(kwargs[:data], DataFrame)
97-
kwargs[:data] = Dict(
98-
group =>
99-
select(filter(r -> r[column] == group, kwargs[:data]), Not(column)) for
100-
group in emodel.groups
101-
)
102-
end
103-
# update each model for new data
104-
models = emodel.sems
105-
new_models = Tuple(
106-
replace_observed(m; group_kwargs(g, kwargs)...) for
107-
(m, g) in zip(models, emodel.groups)
108-
)
109-
return SemEnsemble(new_models...; weights = weights, groups = emodel.groups)
110-
end
111-
112-
function group_kwargs(g, kwargs)
113-
return Dict(k => kwargs[k][g] for k in keys(kwargs))
114-
end
115-
1161
############################################################################################
1172
# simulate data
1183
############################################################################################

src/frontend/specification/Sem.jl

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,12 +396,108 @@ function get_SemLoss(loss, observed, implied; kwargs...)
396396
end
397397
end
398398

399-
function update_observed(sem::Sem, new_observed; kwargs...)
400-
new_terms = Tuple(
401-
update_observed(lossterm.loss, new_observed; kwargs...) for
402-
lossterm in loss_terms(sem)
399+
##############################################################
400+
# replace_observed: Sem level
401+
##############################################################
402+
403+
"""
404+
replace_observed(model::Sem, observed::SemObserved)
405+
replace_observed(model::Sem, data::AbstractDict{Symbol})
406+
replace_observed(model::Sem, data::AbstractDataFrame; [semterm_column])
407+
replace_observed(loss::SemLoss, observed::SemObserved)
408+
replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame})
409+
410+
Construct a new SEM model or SEM loss with replaced observed data.
411+
412+
The SEM structure (implied covariance, loss type) is preserved;
413+
only the observed data is swapped.
414+
415+
# Single-term models
416+
417+
Pass a `SemObserved` object, a data matrix, or a `DataFrame`:
418+
```julia
419+
replace_observed(model, new_data_matrix)
420+
replace_observed(model, new_sem_observed)
421+
replace_observed(model, new_df)
422+
```
423+
424+
# Multi-term models
425+
426+
Pass a `Dict{Symbol}` mapping term ids to data or `SemObserved` objects:
427+
```julia
428+
replace_observed(model, Dict(:g1 => data1, :g2 => data2))
429+
```
430+
431+
Or pass a `DataFrame` with a `semterm_column` identifying the group:
432+
```julia
433+
replace_observed(model, new_df; semterm_column = :group)
434+
```
435+
"""
436+
function replace_observed end
437+
438+
function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix})
439+
nsem_terms(sem) > 1 && throw(
440+
ArgumentError(
441+
"Model contains $(nsem_terms(sem)) SEM terms. " *
442+
"Use a Dict{Symbol} or a DataFrame with `semterm_column` to provide per-term data.",
443+
),
444+
)
445+
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
446+
return Sem(updated_terms...)
447+
end
448+
449+
function replace_observed(sem::Sem, data::AbstractDict{Symbol})
450+
term_ids = Set(begin
451+
tid= id(term)
452+
isnothing(tid) && throw(
453+
ArgumentError(
454+
"Multigroup replace_observed requires all SEM terms to have ids.",
455+
),
456+
)
457+
end for term in loss_terms(sem) if issemloss(term)
458+
)
459+
# check for extra ids
460+
for tid in keys(data)
461+
if tid term_ids
462+
@warn "Data provided for term id :$tid, but no SEM term with this id exists in the model"
463+
end
464+
end
465+
466+
updated_terms = map(loss_terms(sem)) do term
467+
issemloss(term) || return term
468+
tid = id(term)
469+
term_data = get(data, tid, nothing)
470+
isnothing(term_data) && throw(
471+
ArgumentError("No data provided for SEM term :$tid"),
472+
)
473+
return replace_observed(term, term_data)
474+
end
475+
return Sem(Tuple(updated_terms)...)
476+
end
477+
478+
function replace_observed(
479+
sem::Sem,
480+
data::AbstractDataFrame;
481+
semterm_column::Union{Symbol, Nothing} = nothing,
482+
)
483+
if isnothing(semterm_column)
484+
# single-term shortcut
485+
nsem_terms(sem) > 1 && throw(
486+
ArgumentError(
487+
"Model contains $(nsem_terms(sem)) SEM terms. " *
488+
"Provide `semterm_column` to specify which DataFrame column identifies the groups.",
489+
),
490+
)
491+
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
492+
return Sem(updated_terms...)
493+
end
494+
495+
# multi-term: split DataFrame by semterm_column
496+
terms_data = Dict(
497+
g[semterm_column] => group_data
498+
for (g, group_data) in pairs(groupby(data, semterm_column))
403499
)
404-
return Sem(new_terms...)
500+
return replace_observed(sem, terms_data)
405501
end
406502

407503
##############################################################

src/implied/RAM/generic.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,3 @@ function update!(targets::EvaluationTargets, implied::RAM, params)
181181
mul!(implied.μ, implied.F⨉I_A⁻¹, implied.M)
182182
end
183183
end
184-
185-
############################################################################################
186-
### Recommended methods
187-
############################################################################################
188-
189-
function update_observed(implied::RAM, observed::SemObserved; kwargs...)
190-
if nobserved_vars(observed) == size(implied.Σ, 1)
191-
return implied
192-
else
193-
return RAM(; observed = observed, kwargs...)
194-
end
195-
end

src/implied/RAM/symbolic.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -192,18 +192,6 @@ function update!(targets::EvaluationTargets, implied::RAMSymbolic, par)
192192
end
193193
end
194194

195-
############################################################################################
196-
### Recommended methods
197-
############################################################################################
198-
199-
function update_observed(implied::RAMSymbolic, observed::SemObserved; kwargs...)
200-
if nobserved_vars(observed) == size(implied.Σ, 1)
201-
return implied
202-
else
203-
return RAMSymbolic(; observed = observed, kwargs...)
204-
end
205-
end
206-
207195
############################################################################################
208196
### additional functions
209197
############################################################################################

src/implied/empty.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,3 @@ end
4646
############################################################################################
4747

4848
update!(targets::EvaluationTargets, implied::ImpliedEmpty, par) = nothing
49-
50-
############################################################################################
51-
### Recommended methods
52-
############################################################################################
53-
54-
update_observed(implied::ImpliedEmpty, observed::SemObserved; kwargs...) = implied

src/loss/ML/FIML.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,6 @@ function evaluate!(objective, gradient, hessian, loss::SemFIML, params)
159159

160160
return objective
161161
end
162-
163-
############################################################################################
164-
### Recommended methods
165-
############################################################################################
166-
167-
update_observed(loss::SemFIML, observed::SemObserved; kwargs...) =
168-
SemFIML(; observed = observed, kwargs...)
169-
170162
############################################################################################
171163
### additional functions
172164
############################################################################################

src/loss/ML/ML.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -235,18 +235,3 @@ function non_posdef_return(par)
235235
return typemax(eltype(par))
236236
end
237237
end
238-
239-
############################################################################################
240-
### recommended methods
241-
############################################################################################
242-
243-
update_observed(loss::SemML, observed::SemObservedMissing; kwargs...) =
244-
error("ML estimation does not work with missing data - use FIML instead")
245-
246-
function update_observed(loss::SemML, observed::SemObserved; kwargs...)
247-
if (obs_cov(loss) == obs_cov(observed)) && (obs_mean(loss) == obs_mean(observed))
248-
return loss # no change
249-
else
250-
return SemML(observed, loss.implied; kwargs...)
251-
end
252-
end

src/loss/ML/abstract.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,34 @@ function check_observed_vars(observed::SemObserved, implied::SemImplied)
4040
end
4141

4242
check_observed_vars(sem::SemLoss) = check_observed_vars(observed(sem), implied(sem))
43+
44+
############################################################################################
45+
# replace_observed: SemLoss, AbstractLoss, LossTerm
46+
############################################################################################
47+
48+
function replace_observed(loss::SemLoss, new_observed::SemObserved)
49+
old_obs = SEM.observed(loss)
50+
observed_vars(old_obs) == observed_vars(new_observed) || throw(
51+
ArgumentError(
52+
"observed_vars of the new data do not match the model: " *
53+
"expected $(observed_vars(old_obs)), got $(observed_vars(new_observed))",
54+
),
55+
)
56+
return typeof(loss).name.wrapper(new_observed, SEM.implied(loss))
57+
end
58+
59+
function replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame})
60+
old_obs = SEM.observed(loss)
61+
new_observed = typeof(old_obs).name.wrapper(
62+
data = data,
63+
observed_vars = observed_vars(old_obs),
64+
)
65+
return replace_observed(loss, new_observed)
66+
end
67+
68+
# non-SEM loss terms are unchanged
69+
replace_observed(loss::AbstractLoss, ::Any) = loss
70+
71+
# LossTerm: delegate to inner loss
72+
replace_observed(term::LossTerm, data) =
73+
LossTerm(replace_observed(loss(term), data), id(term), weight(term))

0 commit comments

Comments
 (0)