@@ -436,12 +436,118 @@ function build_SemTerms(loss, observed, implied; kwargs...)
436436 end
437437end
438438
439- function update_observed (sem:: Sem , new_observed; kwargs... )
440- new_terms = Tuple (
441- update_observed (lossterm. loss, new_observed; kwargs... ) for
442- lossterm in loss_terms (sem)
439+ # #############################################################
440+ # replace_observed: Sem level
441+ # #############################################################
442+
443+ """
444+ replace_observed(model::Sem, observed::SemObserved)
445+ replace_observed(model::Sem, data::AbstractDict{Symbol})
446+ replace_observed(model::Sem, data::AbstractDataFrame; [semterm_column])
447+ replace_observed(loss::SemLoss, observed::SemObserved)
448+ replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame})
449+
450+ Construct a new SEM model or SEM loss with replaced observed data.
451+
452+ The SEM structure (implied covariance, loss type) is preserved;
453+ only the observed data is swapped.
454+
455+ # Single-term models
456+
457+ Pass a `SemObserved` object, a data matrix, or a `DataFrame`:
458+ ```julia
459+ replace_observed(model, new_data_matrix)
460+ replace_observed(model, new_sem_observed)
461+ replace_observed(model, new_df)
462+ ```
463+
464+ # Multi-term models
465+
466+ Pass a `Dict{Symbol}` mapping term ids to data or `SemObserved` objects:
467+ ```julia
468+ replace_observed(model, Dict(:g1 => data1, :g2 => data2))
469+ ```
470+
471+ Or pass a `DataFrame` with a `semterm_column` identifying the group:
472+ ```julia
473+ replace_observed(model, new_df; semterm_column = :group)
474+ ```
475+ """
476+ function replace_observed end
477+
478+ function replace_observed (sem:: Sem , data:: Union{SemObserved, AbstractMatrix} )
479+ nsem_terms (sem) > 1 && throw (
480+ ArgumentError (
481+ " Model contains $(nsem_terms (sem)) SEM terms. " *
482+ " Use a Dict{Symbol} or a DataFrame with `semterm_column` to provide per-term data." ,
483+ ),
484+ )
485+ updated_terms = Tuple (replace_observed (term, data) for term in loss_terms (sem))
486+ return Sem (updated_terms... )
487+ end
488+
489+ function replace_observed (sem:: Sem , data:: AbstractDict{Symbol} )
490+ term_ids = Set (
491+ if ! isnothing (id (term))
492+ id (term)
493+ else
494+ " Multigroup replace_observed(sem, data::Dict) requires all SEM terms to have ids." |>
495+ ArgumentError |>
496+ throw
497+ end for term in loss_terms (sem) if issemloss (term)
498+ )
499+ # check for extra ids
500+ extra_term_ids = setdiff (keys (data), term_ids)
501+ isempty (extra_term_ids) ||
502+ @warn " Ignoring data with ids=$(collect (extra_term_ids)) : no such SEM terms exist in the model"
503+
504+ updated_terms = map (loss_terms (sem)) do term
505+ issemloss (term) || return term
506+ tid = id (term)
507+ term_data = get (data, tid, nothing )
508+ isnothing (term_data) &&
509+ throw (ArgumentError (" No data provided for SEM term :$tid " ))
510+ return replace_observed (term, term_data)
511+ end
512+ return Sem (Tuple (updated_terms)... )
513+ end
514+
515+ function replace_observed (sem:: Sem , data:: AbstractVector )
516+ nsem = nsem_terms (sem)
517+ nsem == length (data) || throw (
518+ ArgumentError (
519+ " Length of data ($(length (data)) ) does not match number of SEM terms ($nsem )" ,
520+ ),
521+ )
522+ updated_terms = map (enumerate (loss_terms (sem))) do (i, term)
523+ issemloss (term) ? replace_observed (term, data[i]) : term
524+ end
525+ return Sem (Tuple (updated_terms)... )
526+ end
527+
528+ function replace_observed (
529+ sem:: Sem ,
530+ data:: AbstractDataFrame ;
531+ semterm_column:: Union{Symbol, Nothing} = nothing ,
532+ )
533+ if isnothing (semterm_column)
534+ # single-term shortcut
535+ nsem_terms (sem) > 1 && throw (
536+ ArgumentError (
537+ " Model contains $(nsem_terms (sem)) SEM terms. " *
538+ " Provide `semterm_column` to specify which DataFrame column identifies the groups." ,
539+ ),
540+ )
541+ updated_terms = Tuple (replace_observed (term, data) for term in loss_terms (sem))
542+ return Sem (updated_terms... )
543+ end
544+
545+ # multi-term: split DataFrame by semterm_column
546+ terms_data = Dict (
547+ g[semterm_column] => group_data for
548+ (g, group_data) in pairs (groupby (data, semterm_column))
443549 )
444- return Sem (new_terms ... )
550+ return replace_observed (sem, terms_data )
445551end
446552
447553# #############################################################
0 commit comments