Skip to content

Commit cbca143

Browse files
author
Alexey Stukalov
committed
replace_observed(): simplify & refactor
remove update_observed!()
1 parent ac13a75 commit cbca143

16 files changed

Lines changed: 132 additions & 224 deletions

File tree

ext/SEMNLOptExt/NLopt.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,6 @@ function SemOptimizerNLopt(;
107107
)
108108
end
109109

110-
############################################################################################
111-
### Recommended methods
112-
############################################################################################
113-
114-
SEM.update_observed(optimizer::SemOptimizerNLopt, observed::SemObserved; kwargs...) =
115-
optimizer
116-
117110
############################################################################################
118111
### additional methods
119112
############################################################################################

ext/SEMProximalOptExt/ProximalAlgorithms.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,6 @@ SemOptimizerProximal(;
3434

3535
SEM.sem_optimizer_subtype(::Val{:Proximal}) = SemOptimizerProximal
3636

37-
############################################################################################
38-
### Recommended methods
39-
############################################################################################
40-
41-
SEM.update_observed(optimizer::SemOptimizerProximal, observed::SemObserved; kwargs...) =
42-
optimizer
43-
4437
############################################################################
4538
### Model fitting
4639
############################################################################

src/StructuralEquationModels.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ export AbstractSem,
195195
se_bootstrap,
196196
example_data,
197197
replace_observed,
198-
update_observed,
199198
@StenoGraph,
200199
→,
201200
←,

src/additional_functions/simulation.jl

Lines changed: 0 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,3 @@
1-
"""
2-
(1) replace_observed(model::AbstractSemSingle; kwargs...)
3-
4-
(2) replace_observed(model::AbstractSemSingle, observed; kwargs...)
5-
6-
(3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...)
7-
8-
Return a new model with swapped observed part.
9-
10-
# Arguments
11-
- `model::AbstractSemSingle`: model to swap the observed part of.
12-
- `kwargs`: additional keyword arguments; typically includes `data` and `specification`
13-
- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`
14-
15-
# For SemEnsemble models:
16-
- `column`: if a DataFrame is passed as `data = ...`, which column signifies the group?
17-
- `weights`: how to weight the different sub-models,
18-
defaults to number of samples per group in the new data
19-
- `kwargs`: has to be a dict with keys equal to the group names.
20-
For `data` can also be a DataFrame with `column` containing the group information,
21-
and for `specification` can also be an `EnsembleParameterTable`.
22-
23-
# Examples
24-
See the online documentation on [Replace observed data](@ref).
25-
"""
26-
function replace_observed end
27-
28-
"""
29-
update_observed(to_update, observed::SemObserved; kwargs...)
30-
31-
Update a `SemImplied`, `SemLossFunction` or `SemOptimizer` object to use a `SemObserved` object.
32-
33-
# Examples
34-
See the online documentation on [Replace observed data](@ref).
35-
36-
# Implementation
37-
You can provide a method for this function when defining a new type, for more information
38-
on this see the online developer documentation on [Update observed data](@ref).
39-
"""
40-
function update_observed end
41-
42-
############################################################################################
43-
# change observed (data) without reconstructing the whole model
44-
############################################################################################
45-
46-
# don't change non-SEM terms
47-
replace_observed(loss::AbstractLoss; kwargs...) = loss
48-
49-
# use the same observed type as before
50-
replace_observed(loss::SemLoss; kwargs...) =
51-
replace_observed(loss, typeof(SEM.observed(loss)).name.wrapper; kwargs...)
52-
53-
# construct a new observed type
54-
replace_observed(loss::SemLoss, observed_type; kwargs...) =
55-
replace_observed(loss, observed_type(; kwargs...); kwargs...)
56-
57-
function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...)
58-
kwargs = Dict{Symbol, Any}(kwargs...)
59-
old_observed = SEM.observed(loss)
60-
implied = SEM.implied(loss)
61-
62-
# get field types
63-
kwargs[:observed_type] = typeof(new_observed)
64-
kwargs[:old_observed_type] = typeof(old_observed)
65-
66-
# update implied
67-
new_implied = update_observed(implied, new_observed; kwargs...)
68-
kwargs[:implied] = new_implied
69-
kwargs[:implied_type] = typeof(new_implied)
70-
kwargs[:nparams] = nparams(new_implied)
71-
72-
# update loss
73-
return update_observed(loss, new_observed; kwargs...)
74-
end
75-
76-
replace_observed(loss::LossTerm; kwargs...) =
77-
LossTerm(replace_observed(loss.loss; kwargs...), loss.id, loss.weight)
78-
79-
function replace_observed(sem::Sem; kwargs...)
80-
updated_terms = Tuple(replace_observed(term; kwargs...) for term in loss_terms(sem))
81-
return Sem(updated_terms...)
82-
end
83-
84-
function replace_observed(
85-
emodel::SemEnsemble;
86-
column = :group,
87-
weights = nothing,
88-
kwargs...,
89-
)
90-
kwargs = Dict{Symbol, Any}(kwargs...)
91-
# allow for EnsembleParameterTable to be passed as specification
92-
if haskey(kwargs, :specification) && isa(kwargs[:specification], EnsembleParameterTable)
93-
kwargs[:specification] = convert(Dict{Symbol, RAMMatrices}, kwargs[:specification])
94-
end
95-
# allow for DataFrame with group variable "column" to be passed as new data
96-
if haskey(kwargs, :data) && isa(kwargs[:data], DataFrame)
97-
kwargs[:data] = Dict(
98-
group =>
99-
select(filter(r -> r[column] == group, kwargs[:data]), Not(column)) for
100-
group in emodel.groups
101-
)
102-
end
103-
# update each model for new data
104-
models = emodel.sems
105-
new_models = Tuple(
106-
replace_observed(m; group_kwargs(g, kwargs)...) for
107-
(m, g) in zip(models, emodel.groups)
108-
)
109-
return SemEnsemble(new_models...; weights = weights, groups = emodel.groups)
110-
end
111-
112-
function group_kwargs(g, kwargs)
113-
return Dict(k => kwargs[k][g] for k in keys(kwargs))
114-
end
115-
1161
############################################################################################
1172
# simulate data
1183
############################################################################################

src/frontend/specification/Sem.jl

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,12 +396,108 @@ function get_SemLoss(loss, observed, implied; kwargs...)
396396
end
397397
end
398398

399-
function update_observed(sem::Sem, new_observed; kwargs...)
400-
new_terms = Tuple(
401-
update_observed(lossterm.loss, new_observed; kwargs...) for
402-
lossterm in loss_terms(sem)
399+
##############################################################
400+
# replace_observed: Sem level
401+
##############################################################
402+
403+
"""
404+
replace_observed(model::Sem, observed::SemObserved)
405+
replace_observed(model::Sem, data::AbstractDict{Symbol})
406+
replace_observed(model::Sem, data::AbstractDataFrame; [semterm_column])
407+
replace_observed(loss::SemLoss, observed::SemObserved)
408+
replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame})
409+
410+
Construct a new SEM model or SEM loss with replaced observed data.
411+
412+
The SEM structure (implied covariance, loss type) is preserved;
413+
only the observed data is swapped.
414+
415+
# Single-term models
416+
417+
Pass a `SemObserved` object, a data matrix, or a `DataFrame`:
418+
```julia
419+
replace_observed(model, new_data_matrix)
420+
replace_observed(model, new_sem_observed)
421+
replace_observed(model, new_df)
422+
```
423+
424+
# Multi-term models
425+
426+
Pass a `Dict{Symbol}` mapping term ids to data or `SemObserved` objects:
427+
```julia
428+
replace_observed(model, Dict(:g1 => data1, :g2 => data2))
429+
```
430+
431+
Or pass a `DataFrame` with a `semterm_column` identifying the group:
432+
```julia
433+
replace_observed(model, new_df; semterm_column = :group)
434+
```
435+
"""
436+
function replace_observed end
437+
438+
function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix})
439+
nsem_terms(sem) > 1 && throw(
440+
ArgumentError(
441+
"Model contains $(nsem_terms(sem)) SEM terms. " *
442+
"Use a Dict{Symbol} or a DataFrame with `semterm_column` to provide per-term data.",
443+
),
444+
)
445+
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
446+
return Sem(updated_terms...)
447+
end
448+
449+
function replace_observed(sem::Sem, data::AbstractDict{Symbol})
450+
term_ids = Set(begin
451+
tid= id(term)
452+
isnothing(tid) && throw(
453+
ArgumentError(
454+
"Multigroup replace_observed requires all SEM terms to have ids.",
455+
),
456+
)
457+
end for term in loss_terms(sem) if issemloss(term)
458+
)
459+
# check for extra ids
460+
for tid in keys(data)
461+
if tid ∉ term_ids
462+
@warn "Data provided for term id :$tid, but no SEM term with this id exists in the model"
463+
end
464+
end
465+
466+
updated_terms = map(loss_terms(sem)) do term
467+
issemloss(term) || return term
468+
tid = id(term)
469+
term_data = get(data, tid, nothing)
470+
isnothing(term_data) && throw(
471+
ArgumentError("No data provided for SEM term :$tid"),
472+
)
473+
return replace_observed(term, term_data)
474+
end
475+
return Sem(Tuple(updated_terms)...)
476+
end
477+
478+
function replace_observed(
479+
sem::Sem,
480+
data::AbstractDataFrame;
481+
semterm_column::Union{Symbol, Nothing} = nothing,
482+
)
483+
if isnothing(semterm_column)
484+
# single-term shortcut
485+
nsem_terms(sem) > 1 && throw(
486+
ArgumentError(
487+
"Model contains $(nsem_terms(sem)) SEM terms. " *
488+
"Provide `semterm_column` to specify which DataFrame column identifies the groups.",
489+
),
490+
)
491+
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
492+
return Sem(updated_terms...)
493+
end
494+
495+
# multi-term: split DataFrame by semterm_column
496+
terms_data = Dict(
497+
g[semterm_column] => group_data
498+
for (g, group_data) in pairs(groupby(data, semterm_column))
403499
)
404-
return Sem(new_terms...)
500+
return replace_observed(sem, terms_data)
405501
end
406502

407503
##############################################################

src/implied/RAM/generic.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,3 @@ function update!(targets::EvaluationTargets, implied::RAM, params)
179179
mul!(implied.μ, implied.F⨉I_A⁻¹, implied.M)
180180
end
181181
end
182-
183-
############################################################################################
184-
### Recommended methods
185-
############################################################################################
186-
187-
function update_observed(implied::RAM, observed::SemObserved; kwargs...)
188-
if nobserved_vars(observed) == size(implied.Σ, 1)
189-
return implied
190-
else
191-
return RAM(; observed = observed, kwargs...)
192-
end
193-
end

src/implied/RAM/symbolic.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,6 @@ function update!(targets::EvaluationTargets, implied::RAMSymbolic, par)
190190
end
191191
end
192192

193-
############################################################################################
194-
### Recommended methods
195-
############################################################################################
196-
197-
function update_observed(implied::RAMSymbolic, observed::SemObserved; kwargs...)
198-
if nobserved_vars(observed) == size(implied.Σ, 1)
199-
return implied
200-
else
201-
return RAMSymbolic(; observed = observed, kwargs...)
202-
end
203-
end
204-
205193
############################################################################################
206194
### additional functions
207195
############################################################################################

src/implied/empty.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,3 @@ end
4646
############################################################################################
4747

4848
update!(targets::EvaluationTargets, implied::ImpliedEmpty, par) = nothing
49-
50-
############################################################################################
51-
### Recommended methods
52-
############################################################################################
53-
54-
update_observed(implied::ImpliedEmpty, observed::SemObserved; kwargs...) = implied

src/loss/ML/FIML.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,6 @@ function evaluate!(objective, gradient, hessian, loss::SemFIML, params)
159159

160160
return objective
161161
end
162-
163-
############################################################################################
164-
### Recommended methods
165-
############################################################################################
166-
167-
update_observed(loss::SemFIML, observed::SemObserved; kwargs...) =
168-
SemFIML(; observed = observed, kwargs...)
169-
170162
############################################################################################
171163
### additional functions
172164
############################################################################################

src/loss/ML/ML.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -235,18 +235,3 @@ function non_posdef_return(par)
235235
return typemax(eltype(par))
236236
end
237237
end
238-
239-
############################################################################################
240-
### recommended methods
241-
############################################################################################
242-
243-
update_observed(loss::SemML, observed::SemObservedMissing; kwargs...) =
244-
error("ML estimation does not work with missing data - use FIML instead")
245-
246-
function update_observed(loss::SemML, observed::SemObserved; kwargs...)
247-
if (obs_cov(loss) == obs_cov(observed)) && (obs_mean(loss) == obs_mean(observed))
248-
return loss # no change
249-
else
250-
return SemML(observed, loss.implied; kwargs...)
251-
end
252-
end

0 commit comments

Comments
 (0)