33
44 (2) replace_observed(model::AbstractSemSingle, observed; kwargs...)
55
6+ (3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...)
7+
68Return a new model with swaped observed part.
79
810# Arguments
911- `model::AbstractSemSingle`: model to swap the observed part of.
1012- `kwargs`: additional keyword arguments; typically includes `data` and `specification`
1113- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved`
1214
15+ # For SemEnsemble models:
16+ - `column`: if a DataFrame is passed as `data = ...`, which column signifies the group?
17+ - `weights`: how to weight the different sub-models,
18+ defaults to number of samples per group in the new data
19+ - `kwargs`: has to be a dict with keys equal to the group names.
20+ For `data` can also be a DataFrame with `column` containing the group information,
21+ and for `specification` can also be an `EnsembleParameterTable`.
22+
1323# Examples
1424See the online documentation on [Replace observed data](@ref).
1525"""
@@ -37,51 +47,28 @@ function update_observed end
3747replace_observed (model:: AbstractSemSingle ; kwargs... ) =
3848 replace_observed (model, typeof (observed (model)). name. wrapper; kwargs... )
3949
40- # construct a new observed type
41- replace_observed (model:: AbstractSemSingle , observed_type; kwargs... ) =
42- replace_observed (model, observed_type (; kwargs... ); kwargs... )
43-
44- replace_observed (model:: AbstractSemSingle , new_observed:: SemObserved ; kwargs... ) =
45- replace_observed (
46- model,
47- observed (model),
48- implied (model),
49- loss (model),
50- new_observed;
51- kwargs... ,
52- )
53-
54- function replace_observed (
55- model:: AbstractSemSingle ,
56- old_observed,
57- implied,
58- loss,
59- new_observed:: SemObserved ;
60- kwargs... ,
61- )
50+ function replace_observed (model:: AbstractSemSingle , observed_type; kwargs... )
51+ new_observed = observed_type (;kwargs... )
6252 kwargs = Dict {Symbol, Any} (kwargs... )
6353
6454 # get field types
6555 kwargs[:observed_type ] = typeof (new_observed)
66- kwargs[:old_observed_type ] = typeof (old_observed )
67- kwargs[:implied_type ] = typeof (implied)
68- kwargs[:loss_types ] = [typeof (lossfun) for lossfun in loss. functions]
56+ kwargs[:old_observed_type ] = typeof (model . observed )
57+ kwargs[:implied_type ] = typeof (model . implied)
58+ kwargs[:loss_types ] = [typeof (lossfun) for lossfun in model . loss. functions]
6959
7060 # update implied
71- implied = update_observed (implied, new_observed; kwargs... )
72- kwargs[:implied ] = implied
73- kwargs[:nparams ] = nparams (implied )
61+ new_implied = update_observed (model . implied, new_observed; kwargs... )
62+ kwargs[:implied ] = new_implied
63+ kwargs[:nparams ] = nparams (new_implied )
7464
7565 # update loss
76- loss = update_observed (loss, new_observed; kwargs... )
77- kwargs[:loss ] = loss
78-
79- # new_implied = update_observed(model.implied, new_observed; kwargs...)
66+ new_loss = update_observed (model. loss, new_observed; kwargs... )
8067
8168 return Sem (
8269 new_observed,
83- update_observed (model . implied, new_observed; kwargs ... ) ,
84- update_observed (model . loss, new_observed; kwargs ... ),
70+ new_implied ,
71+ new_loss
8572 )
8673end
8774
@@ -92,6 +79,39 @@ function update_observed(loss::SemLoss, new_observed; kwargs...)
9279 return SemLoss (new_functions, loss. weights)
9380end
9481
82+
83+ function replace_observed (
84+ emodel:: SemEnsemble ;
85+ column = :group ,
86+ weights = nothing ,
87+ kwargs... ,
88+ )
89+ kwargs = Dict {Symbol, Any} (kwargs... )
90+ # allow for EnsembleParameterTable to be passed as specification
91+ if haskey (kwargs, :specification ) && isa (kwargs[:specification ], EnsembleParameterTable)
92+ kwargs[:specification ] = convert (Dict{Symbol, RAMMatrices}, kwargs[:specification ])
93+ end
94+ # allow for DataFrame with group variable "column" to be passed as new data
95+ if haskey (kwargs, :data ) && isa (kwargs[:data ], DataFrame)
96+ kwargs[:data ] = Dict (
97+ group => select (
98+ filter (
99+ r -> r[column] == group,
100+ kwargs[:data ]),
101+ Not (column)) for group in emodel. groups)
102+ end
103+ # update each model for new data
104+ models = emodel. sems
105+ new_models = Tuple (
106+ replace_observed (m; group_kwargs (g, kwargs)... ) for (m, g) in zip (models, emodel. groups)
107+ )
108+ return SemEnsemble (new_models... ; weights = weights, groups = emodel. groups)
109+ end
110+
111+ function group_kwargs (g, kwargs)
112+ return Dict (k => kwargs[k][g] for k in keys (kwargs))
113+ end
114+
95115# ###########################################################################################
96116# simulate data
97117# ###########################################################################################
0 commit comments