Skip to content

Commit 355d1bf

Browse files
author
Alexey Stukalov
committed
replace_observed(): simplify & refactor
remove update_observed!()
1 parent 8a2393c commit 355d1bf

16 files changed

Lines changed: 140 additions & 253 deletions

File tree

ext/SEMNLOptExt/NLopt.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,6 @@ function SemOptimizerNLopt(;
107107
)
108108
end
109109

110-
############################################################################################
111-
### Recommended methods
112-
############################################################################################
113-
114-
SEM.update_observed(optimizer::SemOptimizerNLopt, observed::SemObserved; kwargs...) =
115-
optimizer
116-
117110
############################################################################################
118111
### additional methods
119112
############################################################################################

ext/SEMProximalOptExt/ProximalAlgorithms.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,6 @@ SemOptimizerProximal(;
3434

3535
SEM.sem_optimizer_subtype(::Val{:Proximal}) = SemOptimizerProximal
3636

37-
############################################################################################
38-
### Recommended methods
39-
############################################################################################
40-
41-
SEM.update_observed(optimizer::SemOptimizerProximal, observed::SemObserved; kwargs...) =
42-
optimizer
43-
4437
############################################################################
4538
### Model fitting
4639
############################################################################

src/StructuralEquationModels.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ export AbstractSem,
205205
z_test!,
206206
example_data,
207207
replace_observed,
208-
update_observed,
209208
@StenoGraph,
210209
,
211210
,

src/additional_functions/simulation.jl

Lines changed: 0 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,3 @@
1-
"""
2-
(1) replace_observed(model::AbstractSemSingle; kwargs...)
3-
4-
(2) replace_observed(model::AbstractSemSingle, observed; kwargs...)
5-
6-
(3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...)
7-
8-
Return a new model with swaped observed part.
9-
10-
# Arguments
11-
- `model::AbstractSemSingle`: model to swap the observed part of.
12-
- `kwargs`: additional keyword arguments; typically includes `data` and `specification`
13-
- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`
14-
15-
# For SemEnsemble models:
16-
- `column`: if a DataFrame is passed as `data = ...`, which column signifies the group?
17-
- `weights`: how to weight the different sub-models,
18-
defaults to number of samples per group in the new data
19-
- `kwargs`: has to be a dict with keys equal to the group names.
20-
For `data` can also be a DataFrame with `column` containing the group information,
21-
and for `specification` can also be an `EnsembleParameterTable`.
22-
23-
# Examples
24-
See the online documentation on [Replace observed data](@ref).
25-
"""
26-
function replace_observed end
27-
28-
"""
29-
update_observed(to_update, observed::SemObserved; kwargs...)
30-
31-
Update a `SemImplied`, `SemLossFunction` or `SemOptimizer` object to use a `SemObserved` object.
32-
33-
# Examples
34-
See the online documentation on [Replace observed data](@ref).
35-
36-
# Implementation
37-
You can provide a method for this function when defining a new type, for more information
38-
on this see the online developer documentation on [Update observed data](@ref).
39-
"""
40-
function update_observed end
41-
42-
############################################################################################
43-
# change observed (data) without reconstructing the whole model
44-
############################################################################################
45-
46-
# don't change non-SEM terms
47-
replace_observed(loss::AbstractLoss; kwargs...) = loss
48-
49-
# use the same observed type as before
50-
replace_observed(loss::SemLoss; kwargs...) =
51-
replace_observed(loss, typeof(SEM.observed(loss)).name.wrapper; kwargs...)
52-
53-
# construct a new observed type
54-
replace_observed(loss::SemLoss, observed_type; kwargs...) =
55-
replace_observed(loss, observed_type(; kwargs...); kwargs...)
56-
57-
function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...)
58-
kwargs = Dict{Symbol, Any}(kwargs...)
59-
old_observed = SEM.observed(loss)
60-
implied = SEM.implied(loss)
61-
62-
# get field types
63-
kwargs[:observed_type] = typeof(new_observed)
64-
kwargs[:old_observed_type] = typeof(old_observed)
65-
66-
# update implied
67-
new_implied = update_observed(implied, new_observed; kwargs...)
68-
kwargs[:implied] = new_implied
69-
kwargs[:implied_type] = typeof(new_implied)
70-
kwargs[:nparams] = nparams(new_implied)
71-
72-
# update loss
73-
return update_observed(loss, new_observed; kwargs...)
74-
end
75-
76-
replace_observed(loss::LossTerm; kwargs...) =
77-
LossTerm(replace_observed(loss.loss; kwargs...), loss.id, loss.weight)
78-
79-
function replace_observed(sem::Sem; kwargs...)
80-
updated_terms = Tuple(replace_observed(term; kwargs...) for term in loss_terms(sem))
81-
return Sem(updated_terms...)
82-
end
83-
84-
function replace_observed(
85-
emodel::SemEnsemble;
86-
column = :group,
87-
weights = nothing,
88-
kwargs...,
89-
)
90-
kwargs = Dict{Symbol, Any}(kwargs...)
91-
# allow for EnsembleParameterTable to be passed as specification
92-
if haskey(kwargs, :specification) && isa(kwargs[:specification], EnsembleParameterTable)
93-
kwargs[:specification] = convert(Dict{Symbol, RAMMatrices}, kwargs[:specification])
94-
end
95-
# allow for DataFrame with group variable "column" to be passed as new data
96-
if haskey(kwargs, :data) && isa(kwargs[:data], DataFrame)
97-
kwargs[:data] = Dict(
98-
group =>
99-
select(filter(r -> r[column] == group, kwargs[:data]), Not(column)) for
100-
group in emodel.groups
101-
)
102-
end
103-
# update each model for new data
104-
models = emodel.sems
105-
new_models = Tuple(
106-
replace_observed(m; group_kwargs(g, kwargs)...) for
107-
(m, g) in zip(models, emodel.groups)
108-
)
109-
return SemEnsemble(new_models...; weights = weights, groups = emodel.groups)
110-
end
111-
112-
function group_kwargs(g, kwargs)
113-
return Dict(k => kwargs[k][g] for k in keys(kwargs))
114-
end
115-
1161
############################################################################################
1172
# simulate data
1183
############################################################################################

src/frontend/specification/Sem.jl

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,118 @@ function build_SemTerms(loss, observed, implied; kwargs...)
436436
end
437437
end
438438

439-
function update_observed(sem::Sem, new_observed; kwargs...)
440-
new_terms = Tuple(
441-
update_observed(lossterm.loss, new_observed; kwargs...) for
442-
lossterm in loss_terms(sem)
439+
##############################################################
440+
# replace_observed: Sem level
441+
##############################################################
442+
443+
"""
444+
replace_observed(model::Sem, observed::SemObserved)
445+
replace_observed(model::Sem, data::AbstractDict{Symbol})
446+
replace_observed(model::Sem, data::AbstractDataFrame; [semterm_column])
447+
replace_observed(loss::SemLoss, observed::SemObserved)
448+
replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame})
449+
450+
Construct a new SEM model or SEM loss with replaced observed data.
451+
452+
The SEM structure (implied covariance, loss type) is preserved;
453+
only the observed data is swapped.
454+
455+
# Single-term models
456+
457+
Pass a `SemObserved` object, a data matrix, or a `DataFrame`:
458+
```julia
459+
replace_observed(model, new_data_matrix)
460+
replace_observed(model, new_sem_observed)
461+
replace_observed(model, new_df)
462+
```
463+
464+
# Multi-term models
465+
466+
Pass a `Dict{Symbol}` mapping term ids to data or `SemObserved` objects:
467+
```julia
468+
replace_observed(model, Dict(:g1 => data1, :g2 => data2))
469+
```
470+
471+
Or pass a `DataFrame` with a `semterm_column` identifying the group:
472+
```julia
473+
replace_observed(model, new_df; semterm_column = :group)
474+
```
475+
"""
476+
function replace_observed end
477+
478+
function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix})
479+
nsem_terms(sem) > 1 && throw(
480+
ArgumentError(
481+
"Model contains $(nsem_terms(sem)) SEM terms. " *
482+
"Use a Dict{Symbol} or a DataFrame with `semterm_column` to provide per-term data.",
483+
),
484+
)
485+
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
486+
return Sem(updated_terms...)
487+
end
488+
489+
function replace_observed(sem::Sem, data::AbstractDict{Symbol})
490+
term_ids = Set(
491+
if !isnothing(id(term))
492+
id(term)
493+
else
494+
"Multigroup replace_observed(sem, data::Dict) requires all SEM terms to have ids." |>
495+
ArgumentError |>
496+
throw
497+
end for term in loss_terms(sem) if issemloss(term)
498+
)
499+
# check for extra ids
500+
extra_term_ids = setdiff(keys(data), term_ids)
501+
isempty(extra_term_ids) ||
502+
@warn "Ignoring data with ids=$(collect(extra_term_ids)): no such SEM terms exist in the model"
503+
504+
updated_terms = map(loss_terms(sem)) do term
505+
issemloss(term) || return term
506+
tid = id(term)
507+
term_data = get(data, tid, nothing)
508+
isnothing(term_data) &&
509+
throw(ArgumentError("No data provided for SEM term :$tid"))
510+
return replace_observed(term, term_data)
511+
end
512+
return Sem(Tuple(updated_terms)...)
513+
end
514+
515+
function replace_observed(sem::Sem, data::AbstractVector)
516+
nsem = nsem_terms(sem)
517+
nsem == length(data) || throw(
518+
ArgumentError(
519+
"Length of data ($(length(data))) does not match number of SEM terms ($nsem)",
520+
),
521+
)
522+
updated_terms = map(enumerate(loss_terms(sem))) do (i, term)
523+
issemloss(term) ? replace_observed(term, data[i]) : term
524+
end
525+
return Sem(Tuple(updated_terms)...)
526+
end
527+
528+
function replace_observed(
529+
sem::Sem,
530+
data::AbstractDataFrame;
531+
semterm_column::Union{Symbol, Nothing} = nothing,
532+
)
533+
if isnothing(semterm_column)
534+
# single-term shortcut
535+
nsem_terms(sem) > 1 && throw(
536+
ArgumentError(
537+
"Model contains $(nsem_terms(sem)) SEM terms. " *
538+
"Provide `semterm_column` to specify which DataFrame column identifies the groups.",
539+
),
540+
)
541+
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
542+
return Sem(updated_terms...)
543+
end
544+
545+
# multi-term: split DataFrame by semterm_column
546+
terms_data = Dict(
547+
g[semterm_column] => group_data for
548+
(g, group_data) in pairs(groupby(data, semterm_column))
443549
)
444-
return Sem(new_terms...)
550+
return replace_observed(sem, terms_data)
445551
end
446552

447553
##############################################################

src/implied/RAM/generic.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -179,20 +179,3 @@ function update!(targets::EvaluationTargets, implied::RAM, params)
179179
mul!(implied.μ, implied.F⨉I_A⁻¹, implied.M)
180180
end
181181
end
182-
183-
############################################################################################
184-
### Recommended methods
185-
############################################################################################
186-
187-
function update_observed(implied::RAM, observed::SemObserved; kwargs...)
188-
if nobserved_vars(observed) == nobserved_vars(implied)
189-
return implied
190-
else
191-
return RAM(;
192-
observed = observed,
193-
gradient_required = !isnothing(implied.∇A),
194-
meanstructure = MeanStruct(implied) == HasMeanStruct,
195-
kwargs...,
196-
)
197-
end
198-
end

src/implied/RAM/symbolic.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -190,26 +190,6 @@ function update!(targets::EvaluationTargets, implied::RAMSymbolic, par)
190190
end
191191
end
192192

193-
############################################################################################
194-
### Recommended methods
195-
############################################################################################
196-
197-
function update_observed(implied::RAMSymbolic, observed::SemObserved; kwargs...)
198-
if nobserved_vars(observed) == nobserved_vars(implied)
199-
return implied
200-
else
201-
return RAMSymbolic(;
202-
observed = observed,
203-
vech = implied.Σ isa Vector,
204-
gradient = !isnothing(implied.∇Σ),
205-
hessian = !isnothing(implied.∇²Σ),
206-
meanstructure = MeanStruct(implied) == HasMeanStruct,
207-
approximate_hessian = isnothing(implied.∇²Σ),
208-
kwargs...,
209-
)
210-
end
211-
end
212-
213193
############################################################################################
214194
### additional functions
215195
############################################################################################

src/implied/empty.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,3 @@ end
4646
############################################################################################
4747

4848
update!(targets::EvaluationTargets, implied::ImpliedEmpty, par) = nothing
49-
50-
############################################################################################
51-
### Recommended methods
52-
############################################################################################
53-
54-
update_observed(implied::ImpliedEmpty, observed::SemObserved; kwargs...) = implied

src/loss/ML/FIML.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,6 @@ function evaluate!(objective, gradient, hessian, loss::SemFIML, params)
159159

160160
return objective
161161
end
162-
163-
############################################################################################
164-
### Recommended methods
165-
############################################################################################
166-
167-
update_observed(loss::SemFIML, observed::SemObserved; kwargs...) =
168-
SemFIML(; observed = observed, kwargs...)
169-
170162
############################################################################################
171163
### additional functions
172164
############################################################################################

src/loss/ML/ML.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -235,23 +235,3 @@ function non_posdef_return(par)
235235
return typemax(eltype(par))
236236
end
237237
end
238-
239-
############################################################################################
240-
### recommended methods
241-
############################################################################################
242-
243-
update_observed(loss::SemML, observed::SemObservedMissing; kwargs...) =
244-
error("ML estimation does not work with missing data - use FIML instead")
245-
246-
function update_observed(loss::SemML, observed::SemObserved; kwargs...)
247-
if (obs_cov(loss) == obs_cov(observed)) && (obs_mean(loss) == obs_mean(observed))
248-
return loss # no change
249-
else
250-
return SemML(
251-
observed,
252-
loss.implied;
253-
approximate_hessian = HessianEval(loss) == ApproxHessian,
254-
kwargs...,
255-
)
256-
end
257-
end

0 commit comments

Comments
 (0)