@@ -43,36 +43,42 @@ function update_observed end
4343# change observed (data) without reconstructing the whole model
4444# ###########################################################################################
4545
46+ # don't change non-SEM terms
47+ replace_observed (loss:: AbstractLoss ; kwargs... ) = loss
48+
4649# use the same observed type as before
47- replace_observed (model:: AbstractSemSingle ; kwargs... ) =
48- replace_observed (model, typeof (observed (model)). name. wrapper; kwargs... )
50+ replace_observed (loss:: SemLoss ; kwargs... ) =
51+ replace_observed (loss, typeof (SEM. observed (loss)). name. wrapper; kwargs... )
52+
53+ # construct a new observed type
54+ replace_observed (loss:: SemLoss , observed_type; kwargs... ) =
55+ replace_observed (loss, observed_type (; kwargs... ); kwargs... )
4956
50- function replace_observed (model:: AbstractSemSingle , observed_type; kwargs... )
51- new_observed = observed_type (; kwargs... )
57+ function replace_observed (loss:: SemLoss , new_observed:: SemObserved ; kwargs... )
5258 kwargs = Dict {Symbol, Any} (kwargs... )
59+ old_observed = SEM. observed (loss)
60+ implied = SEM. implied (loss)
5361
5462 # get field types
5563 kwargs[:observed_type ] = typeof (new_observed)
56- kwargs[:old_observed_type ] = typeof (model. observed)
57- kwargs[:implied_type ] = typeof (model. implied)
58- kwargs[:loss_types ] = [typeof (lossfun) for lossfun in model. loss. functions]
64+ kwargs[:old_observed_type ] = typeof (old_observed)
5965
6066 # update implied
61- new_implied = update_observed (model . implied, new_observed; kwargs... )
67+ new_implied = update_observed (implied, new_observed; kwargs... )
6268 kwargs[:implied ] = new_implied
69+ kwargs[:implied_type ] = typeof (new_implied)
6370 kwargs[:nparams ] = nparams (new_implied)
6471
6572 # update loss
66- new_loss = update_observed (model. loss, new_observed; kwargs... )
67-
68- return Sem (new_observed, new_implied, new_loss)
73+ return update_observed (loss, new_observed; kwargs... )
6974end
7075
71- function update_observed (loss:: SemLoss , new_observed; kwargs... )
72- new_functions = Tuple (
73- update_observed (lossfun, new_observed; kwargs... ) for lossfun in loss. functions
74- )
75- return SemLoss (new_functions, loss. weights)
76+ replace_observed (loss:: LossTerm ; kwargs... ) =
77+ LossTerm (replace_observed (loss. loss; kwargs... ), loss. id, loss. weight)
78+
79+ function replace_observed (sem:: Sem ; kwargs... )
80+ updated_terms = Tuple (replace_observed (term; kwargs... ) for term in loss_terms (sem))
81+ return Sem (updated_terms... )
7682end
7783
7884function replace_observed (
@@ -111,39 +117,38 @@ end
111117# simulate data
112118# ###########################################################################################
113119"""
114- (1) rand(model::AbstractSemSingle, params, n)
115-
116- (2) rand(model::AbstractSemSingle, n)
120+ rand(sem::Union{Sem, SemLoss, SemImplied}, [params], n)
117121
118- Sample normally distributed data from the model- implied covariance matrix and mean vector .
122+ Sample from the multivariate normal distribution implied by the SEM model .
119123
120124# Arguments
121- - `model::AbstractSemSingle`: model to simulate from.
122- - `params`: parameter values to simulate from.
123- - `n::Integer`: Number of samples.
125+ - `sem`: SEM model to use. Ensemble models with multiple SEM terms are not supported.
126+ - `params`: optional SEM model parameters to simulate from, otherwise uses the
127+ current state of implied covariances and means.
128+ - `n::Integer`: Number of samples to draw.
124129
125130# Examples
126131```julia
127132rand(model, start_simple(model), 100)
128133```
129134"""
130- function Distributions. rand (
131- model:: AbstractSemSingle{O, I, L} ,
132- params,
133- n:: Integer ,
134- ) where {O, I <: Union{RAM, RAMSymbolic} , L}
135- update! (EvaluationTargets {true, false, false} (), model. implied, model, params)
136- return rand (model, n)
137- end
138-
139- function Distributions. rand (
140- model:: AbstractSemSingle{O, I, L} ,
141- n:: Integer ,
142- ) where {O, I <: Union{RAM, RAMSymbolic} , L}
143- if MeanStruct (model. implied) === NoMeanStruct
144- data = permutedims (rand (MvNormal (Symmetric (model. implied. Σ)), n))
145- elseif MeanStruct (model. implied) === HasMeanStruct
146- data = permutedims (rand (MvNormal (model. implied. μ, Symmetric (model. implied. Σ)), n))
135+ function Distributions. rand (implied:: SemImplied , params, n:: Integer )
136+ if ! isnothing (params)
137+ # update the implied covariances with the new model params
138+ update! (EvaluationTargets {true, false, false} (), implied, params)
139+ end
140+ Σ = Symmetric (implied. Σ)
141+ if MeanStruct (implied) === NoMeanStruct
142+ return permutedims (rand (MvNormal (Σ), n))
143+ elseif MeanStruct (implied) === HasMeanStruct
144+ return permutedims (rand (MvNormal (implied. μ, Σ), n))
147145 end
148- return data
149146end
147+
148+ Distributions. rand (loss:: SemLoss , params, n:: Integer ) = rand (SEM. implied (loss), params, n)
149+
150+ Distributions. rand (model:: Sem , params, n:: Integer ) = rand (sem_term (model), params, n)
151+
152+ # rand() overloads without SEM params
153+ Distributions. rand (implied:: Union{SemImplied, SemLoss, Sem} , n:: Integer ) =
154+ Distributions. rand (implied, nothing , n)
0 commit comments