Skip to content

Commit 5f72f9e

Browse files
author
Alexey Stukalov
committed
replace_observed(): simplify & refactor
remove update_observed!()
1 parent ef4e08a commit 5f72f9e

16 files changed

Lines changed: 143 additions & 253 deletions

File tree

ext/SEMNLOptExt/NLopt.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,6 @@ function SemOptimizerNLopt(;
107107
)
108108
end
109109

110-
############################################################################################
111-
### Recommended methods
112-
############################################################################################
113-
114-
SEM.update_observed(optimizer::SemOptimizerNLopt, observed::SemObserved; kwargs...) =
115-
optimizer
116-
117110
############################################################################################
118111
### additional methods
119112
############################################################################################

ext/SEMProximalOptExt/ProximalAlgorithms.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,6 @@ SemOptimizerProximal(;
3434

3535
SEM.sem_optimizer_subtype(::Val{:Proximal}) = SemOptimizerProximal
3636

37-
############################################################################################
38-
### Recommended methods
39-
############################################################################################
40-
41-
SEM.update_observed(optimizer::SemOptimizerProximal, observed::SemObserved; kwargs...) =
42-
optimizer
43-
4437
############################################################################
4538
### Model fitting
4639
############################################################################

src/StructuralEquationModels.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ export AbstractSem,
205205
z_test!,
206206
example_data,
207207
replace_observed,
208-
update_observed,
209208
@StenoGraph,
210209
→,
211210
←,

src/additional_functions/simulation.jl

Lines changed: 0 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,3 @@
1-
"""
2-
(1) replace_observed(model::AbstractSemSingle; kwargs...)
3-
4-
(2) replace_observed(model::AbstractSemSingle, observed; kwargs...)
5-
6-
(3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...)
7-
8-
Return a new model with swapped observed part.
9-
10-
# Arguments
11-
- `model::AbstractSemSingle`: model to swap the observed part of.
12-
- `kwargs`: additional keyword arguments; typically includes `data` and `specification`
13-
- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`
14-
15-
# For SemEnsemble models:
16-
- `column`: if a DataFrame is passed as `data = ...`, which column signifies the group?
17-
- `weights`: how to weight the different sub-models,
18-
defaults to number of samples per group in the new data
19-
- `kwargs`: has to be a dict with keys equal to the group names.
20-
For `data` can also be a DataFrame with `column` containing the group information,
21-
and for `specification` can also be an `EnsembleParameterTable`.
22-
23-
# Examples
24-
See the online documentation on [Replace observed data](@ref).
25-
"""
26-
function replace_observed end
27-
28-
"""
29-
update_observed(to_update, observed::SemObserved; kwargs...)
30-
31-
Update a `SemImplied`, `SemLossFunction` or `SemOptimizer` object to use a `SemObserved` object.
32-
33-
# Examples
34-
See the online documentation on [Replace observed data](@ref).
35-
36-
# Implementation
37-
You can provide a method for this function when defining a new type, for more information
38-
on this see the online developer documentation on [Update observed data](@ref).
39-
"""
40-
function update_observed end
41-
42-
############################################################################################
43-
# change observed (data) without reconstructing the whole model
44-
############################################################################################
45-
46-
# don't change non-SEM terms
47-
replace_observed(loss::AbstractLoss; kwargs...) = loss
48-
49-
# use the same observed type as before
50-
replace_observed(loss::SemLoss; kwargs...) =
51-
replace_observed(loss, typeof(SEM.observed(loss)).name.wrapper; kwargs...)
52-
53-
# construct a new observed type
54-
replace_observed(loss::SemLoss, observed_type; kwargs...) =
55-
replace_observed(loss, observed_type(; kwargs...); kwargs...)
56-
57-
function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...)
58-
kwargs = Dict{Symbol, Any}(kwargs...)
59-
old_observed = SEM.observed(loss)
60-
implied = SEM.implied(loss)
61-
62-
# get field types
63-
kwargs[:observed_type] = typeof(new_observed)
64-
kwargs[:old_observed_type] = typeof(old_observed)
65-
66-
# update implied
67-
new_implied = update_observed(implied, new_observed; kwargs...)
68-
kwargs[:implied] = new_implied
69-
kwargs[:implied_type] = typeof(new_implied)
70-
kwargs[:nparams] = nparams(new_implied)
71-
72-
# update loss
73-
return update_observed(loss, new_observed; kwargs...)
74-
end
75-
76-
replace_observed(loss::LossTerm; kwargs...) =
77-
LossTerm(replace_observed(loss.loss; kwargs...), loss.id, loss.weight)
78-
79-
function replace_observed(sem::Sem; kwargs...)
80-
updated_terms = Tuple(replace_observed(term; kwargs...) for term in loss_terms(sem))
81-
return Sem(updated_terms...)
82-
end
83-
84-
function replace_observed(
85-
emodel::SemEnsemble;
86-
column = :group,
87-
weights = nothing,
88-
kwargs...,
89-
)
90-
kwargs = Dict{Symbol, Any}(kwargs...)
91-
# allow for EnsembleParameterTable to be passed as specification
92-
if haskey(kwargs, :specification) && isa(kwargs[:specification], EnsembleParameterTable)
93-
kwargs[:specification] = convert(Dict{Symbol, RAMMatrices}, kwargs[:specification])
94-
end
95-
# allow for DataFrame with group variable "column" to be passed as new data
96-
if haskey(kwargs, :data) && isa(kwargs[:data], DataFrame)
97-
kwargs[:data] = Dict(
98-
group =>
99-
select(filter(r -> r[column] == group, kwargs[:data]), Not(column)) for
100-
group in emodel.groups
101-
)
102-
end
103-
# update each model for new data
104-
models = emodel.sems
105-
new_models = Tuple(
106-
replace_observed(m; group_kwargs(g, kwargs)...) for
107-
(m, g) in zip(models, emodel.groups)
108-
)
109-
return SemEnsemble(new_models...; weights = weights, groups = emodel.groups)
110-
end
111-
112-
function group_kwargs(g, kwargs)
113-
return Dict(k => kwargs[k][g] for k in keys(kwargs))
114-
end
115-
1161
############################################################################################
1172
# simulate data
1183
############################################################################################

src/frontend/specification/Sem.jl

Lines changed: 114 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -438,12 +438,121 @@ function build_SemTerms(loss, observed, implied; kwargs...)
438438
end
439439
end
440440

441-
function update_observed(sem::Sem, new_observed; kwargs...)
442-
new_terms = Tuple(
443-
update_observed(lossterm.loss, new_observed; kwargs...) for
444-
lossterm in loss_terms(sem)
441+
##############################################################
442+
# replace_observed: Sem level
443+
##############################################################
444+
445+
"""
446+
replace_observed(model::Sem, observed::SemObserved)
447+
replace_observed(model::Sem, data::AbstractDict{Symbol})
448+
replace_observed(model::Sem, data::AbstractDataFrame; [semterm_column])
449+
replace_observed(loss::SemLoss, observed::SemObserved)
450+
replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame})
451+
452+
Construct a new SEM model or SEM loss with replaced observed data.
453+
454+
The SEM structure (implied covariance, loss type) is preserved;
455+
only the observed data is swapped.
456+
457+
# Single-term models
458+
459+
Pass a `SemObserved` object, a data matrix, or a `DataFrame`:
460+
```julia
461+
replace_observed(model, new_data_matrix)
462+
replace_observed(model, new_sem_observed)
463+
replace_observed(model, new_df)
464+
```
465+
466+
# Multi-term models
467+
468+
Pass a `Dict{Symbol}` mapping term ids to data or `SemObserved` objects:
469+
```julia
470+
replace_observed(model, Dict(:g1 => data1, :g2 => data2))
471+
```
472+
473+
Or pass a `DataFrame` with a `semterm_column` identifying the group:
474+
```julia
475+
replace_observed(model, new_df; semterm_column = :group)
476+
```
477+
"""
478+
function replace_observed end
479+
480+
function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix})
481+
nsem_terms(sem) > 1 && throw(
482+
ArgumentError(
483+
"Model contains $(nsem_terms(sem)) SEM terms. " *
484+
"Use a Dict{Symbol} or a DataFrame with `semterm_column` to provide per-term data.",
485+
),
486+
)
487+
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
488+
return Sem(updated_terms...)
489+
end
490+
491+
function replace_observed(sem::Sem, data::AbstractDict{Symbol})
492+
term_ids = Set(
493+
begin
494+
tid = id(term)
495+
isnothing(tid) && throw(
496+
ArgumentError(
497+
"Multigroup replace_observed requires all SEM terms to have ids.",
498+
),
499+
)
500+
end for term in loss_terms(sem) if issemloss(term)
501+
)
502+
# check for extra ids
503+
for tid in keys(data)
504+
if tid ∉ term_ids
505+
@warn "Data provided for term id :$tid, but no SEM term with this id exists in the model"
506+
end
507+
end
508+
509+
updated_terms = map(loss_terms(sem)) do term
510+
issemloss(term) || return term
511+
tid = id(term)
512+
term_data = get(data, tid, nothing)
513+
isnothing(term_data) &&
514+
throw(ArgumentError("No data provided for SEM term :$tid"))
515+
return replace_observed(term, term_data)
516+
end
517+
return Sem(Tuple(updated_terms)...)
518+
end
519+
520+
function replace_observed(sem::Sem, data::AbstractVector)
521+
nsem = nsem_terms(sem)
522+
nsem == length(data) || throw(
523+
ArgumentError(
524+
"Length of data ($(length(data))) does not match number of SEM terms ($nsem)",
525+
),
526+
)
527+
updated_terms = map(enumerate(loss_terms(sem))) do (i, term)
528+
issemloss(term) ? replace_observed(term, data[i]) : term
529+
end
530+
return Sem(Tuple(updated_terms)...)
531+
end
532+
533+
function replace_observed(
534+
sem::Sem,
535+
data::AbstractDataFrame;
536+
semterm_column::Union{Symbol, Nothing} = nothing,
537+
)
538+
if isnothing(semterm_column)
539+
# single-term shortcut
540+
nsem_terms(sem) > 1 && throw(
541+
ArgumentError(
542+
"Model contains $(nsem_terms(sem)) SEM terms. " *
543+
"Provide `semterm_column` to specify which DataFrame column identifies the groups.",
544+
),
545+
)
546+
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
547+
return Sem(updated_terms...)
548+
end
549+
550+
# multi-term: split DataFrame by semterm_column
551+
terms_data = Dict(
552+
g[semterm_column] => group_data for
553+
(g, group_data) in pairs(groupby(data, semterm_column))
445554
)
446-
return Sem(new_terms...)
555+
return replace_observed(sem, terms_data)
447556
end
448557

449558
##############################################################

src/implied/RAM/generic.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -179,20 +179,3 @@ function update!(targets::EvaluationTargets, implied::RAM, params)
179179
mul!(implied.μ, implied.F⨉I_A⁻¹, implied.M)
180180
end
181181
end
182-
183-
############################################################################################
184-
### Recommended methods
185-
############################################################################################
186-
187-
function update_observed(implied::RAM, observed::SemObserved; kwargs...)
188-
if nobserved_vars(observed) == nobserved_vars(implied)
189-
return implied
190-
else
191-
return RAM(;
192-
observed = observed,
193-
gradient_required = !isnothing(implied.∇A),
194-
meanstructure = MeanStruct(implied) == HasMeanStruct,
195-
kwargs...,
196-
)
197-
end
198-
end

src/implied/RAM/symbolic.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -190,26 +190,6 @@ function update!(targets::EvaluationTargets, implied::RAMSymbolic, par)
190190
end
191191
end
192192

193-
############################################################################################
194-
### Recommended methods
195-
############################################################################################
196-
197-
function update_observed(implied::RAMSymbolic, observed::SemObserved; kwargs...)
198-
if nobserved_vars(observed) == nobserved_vars(implied)
199-
return implied
200-
else
201-
return RAMSymbolic(;
202-
observed = observed,
203-
vech = implied.Σ isa Vector,
204-
gradient = !isnothing(implied.∇Σ),
205-
hessian = !isnothing(implied.∇²Σ),
206-
meanstructure = MeanStruct(implied) == HasMeanStruct,
207-
approximate_hessian = isnothing(implied.∇²Σ),
208-
kwargs...,
209-
)
210-
end
211-
end
212-
213193
############################################################################################
214194
### additional functions
215195
############################################################################################

src/implied/empty.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,3 @@ end
4646
############################################################################################
4747

4848
update!(targets::EvaluationTargets, implied::ImpliedEmpty, par) = nothing
49-
50-
############################################################################################
51-
### Recommended methods
52-
############################################################################################
53-
54-
update_observed(implied::ImpliedEmpty, observed::SemObserved; kwargs...) = implied

src/loss/ML/FIML.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,6 @@ function evaluate!(objective, gradient, hessian, loss::SemFIML, params)
159159

160160
return objective
161161
end
162-
163-
############################################################################################
164-
### Recommended methods
165-
############################################################################################
166-
167-
update_observed(loss::SemFIML, observed::SemObserved; kwargs...) =
168-
SemFIML(; observed = observed, kwargs...)
169-
170162
############################################################################################
171163
### additional functions
172164
############################################################################################

src/loss/ML/ML.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -235,23 +235,3 @@ function non_posdef_return(par)
235235
return typemax(eltype(par))
236236
end
237237
end
238-
239-
############################################################################################
240-
### recommended methods
241-
############################################################################################
242-
243-
update_observed(loss::SemML, observed::SemObservedMissing; kwargs...) =
244-
error("ML estimation does not work with missing data - use FIML instead")
245-
246-
function update_observed(loss::SemML, observed::SemObserved; kwargs...)
247-
if (obs_cov(loss) == obs_cov(observed)) && (obs_mean(loss) == obs_mean(observed))
248-
return loss # no change
249-
else
250-
return SemML(
251-
observed,
252-
loss.implied;
253-
approximate_hessian = HessianEval(loss) == ApproxHessian,
254-
kwargs...,
255-
)
256-
end
257-
end

0 commit comments

Comments
 (0)