|
1 | | -""" |
2 | | - (1) replace_observed(model::AbstractSemSingle; kwargs...) |
3 | | -
|
4 | | - (2) replace_observed(model::AbstractSemSingle, observed; kwargs...) |
5 | | -
|
6 | | - (3) replace_observed(model::SemEnsemble; column = :group, weights = nothing, kwargs...) |
7 | | -
|
8 | | -Return a new model with swaped observed part. |
9 | | -
|
10 | | -# Arguments |
11 | | -- `model::AbstractSemSingle`: model to swap the observed part of. |
12 | | -- `kwargs`: additional keyword arguments; typically includes `data` and `specification` |
13 | | -- `observed`: Either an object of subtype of `SemObserved` or a subtype of `SemObserved` |
14 | | -
|
15 | | -# For SemEnsemble models: |
16 | | -- `column`: if a DataFrame is passed as `data = ...`, which column signifies the group? |
17 | | -- `weights`: how to weight the different sub-models, |
18 | | - defaults to number of samples per group in the new data |
19 | | -- `kwargs`: has to be a dict with keys equal to the group names. |
20 | | - For `data` can also be a DataFrame with `column` containing the group information, |
21 | | - and for `specification` can also be an `EnsembleParameterTable`. |
22 | | -
|
23 | | -# Examples |
24 | | -See the online documentation on [Replace observed data](@ref). |
25 | | -""" |
26 | | -function replace_observed end |
27 | | - |
28 | | -""" |
29 | | - update_observed(to_update, observed::SemObserved; kwargs...) |
30 | | -
|
31 | | -Update a `SemImplied`, `SemLossFunction` or `SemOptimizer` object to use a `SemObserved` object. |
32 | | -
|
33 | | -# Examples |
34 | | -See the online documentation on [Replace observed data](@ref). |
35 | | -
|
36 | | -# Implementation |
37 | | -You can provide a method for this function when defining a new type, for more information |
38 | | -on this see the online developer documentation on [Update observed data](@ref). |
39 | | -""" |
40 | | -function update_observed end |
41 | | - |
42 | | -############################################################################################ |
43 | | -# change observed (data) without reconstructing the whole model |
44 | | -############################################################################################ |
45 | | - |
46 | | -# don't change non-SEM terms |
47 | | -replace_observed(loss::AbstractLoss; kwargs...) = loss |
48 | | - |
49 | | -# use the same observed type as before |
50 | | -replace_observed(loss::SemLoss; kwargs...) = |
51 | | - replace_observed(loss, typeof(SEM.observed(loss)).name.wrapper; kwargs...) |
52 | | - |
53 | | -# construct a new observed type |
54 | | -replace_observed(loss::SemLoss, observed_type; kwargs...) = |
55 | | - replace_observed(loss, observed_type(; kwargs...); kwargs...) |
56 | | - |
57 | | -function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...) |
58 | | - kwargs = Dict{Symbol, Any}(kwargs...) |
59 | | - old_observed = SEM.observed(loss) |
60 | | - implied = SEM.implied(loss) |
61 | | - |
62 | | - # get field types |
63 | | - kwargs[:observed_type] = typeof(new_observed) |
64 | | - kwargs[:old_observed_type] = typeof(old_observed) |
65 | | - |
66 | | - # update implied |
67 | | - new_implied = update_observed(implied, new_observed; kwargs...) |
68 | | - kwargs[:implied] = new_implied |
69 | | - kwargs[:implied_type] = typeof(new_implied) |
70 | | - kwargs[:nparams] = nparams(new_implied) |
71 | | - |
72 | | - # update loss |
73 | | - return update_observed(loss, new_observed; kwargs...) |
74 | | -end |
75 | | - |
76 | | -replace_observed(loss::LossTerm; kwargs...) = |
77 | | - LossTerm(replace_observed(loss.loss; kwargs...), loss.id, loss.weight) |
78 | | - |
79 | | -function replace_observed(sem::Sem; kwargs...) |
80 | | - updated_terms = Tuple(replace_observed(term; kwargs...) for term in loss_terms(sem)) |
81 | | - return Sem(updated_terms...) |
82 | | -end |
83 | | - |
84 | | -function replace_observed( |
85 | | - emodel::SemEnsemble; |
86 | | - column = :group, |
87 | | - weights = nothing, |
88 | | - kwargs..., |
89 | | -) |
90 | | - kwargs = Dict{Symbol, Any}(kwargs...) |
91 | | - # allow for EnsembleParameterTable to be passed as specification |
92 | | - if haskey(kwargs, :specification) && isa(kwargs[:specification], EnsembleParameterTable) |
93 | | - kwargs[:specification] = convert(Dict{Symbol, RAMMatrices}, kwargs[:specification]) |
94 | | - end |
95 | | - # allow for DataFrame with group variable "column" to be passed as new data |
96 | | - if haskey(kwargs, :data) && isa(kwargs[:data], DataFrame) |
97 | | - kwargs[:data] = Dict( |
98 | | - group => |
99 | | - select(filter(r -> r[column] == group, kwargs[:data]), Not(column)) for |
100 | | - group in emodel.groups |
101 | | - ) |
102 | | - end |
103 | | - # update each model for new data |
104 | | - models = emodel.sems |
105 | | - new_models = Tuple( |
106 | | - replace_observed(m; group_kwargs(g, kwargs)...) for |
107 | | - (m, g) in zip(models, emodel.groups) |
108 | | - ) |
109 | | - return SemEnsemble(new_models...; weights = weights, groups = emodel.groups) |
110 | | -end |
111 | | - |
112 | | -function group_kwargs(g, kwargs) |
113 | | - return Dict(k => kwargs[k][g] for k in keys(kwargs)) |
114 | | -end |
115 | | - |
116 | 1 | ############################################################################################ |
117 | 2 | # simulate data |
118 | 3 | ############################################################################################ |
|
0 commit comments