Skip to content

Commit 293c88b

Browse files
author
Alexey Stukalov
committed
tests/model: test multi-group data ctor
1 parent b5e920a commit 293c88b

2 files changed

Lines changed: 86 additions & 4 deletions

File tree

src/frontend/specification/Sem.jl

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,56 @@ function set_field_type_kwargs!(kwargs, observed, implied, loss, O, I)
335335
end
336336
end
337337

338+
"""
    build_ensemble_observed(observed_type, spec::EnsembleParameterTable, kwargs)

Build the observed part of an ensemble (multi-group) SEM from the specification and
the `Sem(...)` keyword arguments: one `observed_type` object per SEM term of `spec`.

Used by `Sem(...)` and `replace_observed()`.

The per-term data is derived from `kwargs[:data]`:
- an `AbstractDataFrame` is split with `groupby` on the column named by
  `kwargs[:semterm_column]`; each group is keyed by its value in that column;
- an `AbstractDict` is used as is, as a mapping from term ids to data;
- if `kwargs` has no `:data` entry, a warning is logged and every term is constructed
  from an unmodified copy of `kwargs`.

Data whose id does not match any term in `spec.tables` is ignored with a warning.

# Returns
A `Dict` mapping each term id of `spec.tables` to
`observed_type(; specification = term_spec, term_kwargs...)`, where `term_kwargs` is a
copy of `kwargs` with `:data` replaced by the term's data and `:semterm_column` removed
(when data was provided).

# Throws
- `ArgumentError` if the data is a data frame and `:semterm_column` is not specified;
- `ArgumentError` if the data is neither a data frame nor a dictionary;
- `ArgumentError` if data was provided, but looking up a SEM term id in it yields `nothing`.
"""
function build_ensemble_observed(observed_type, spec::EnsembleParameterTable, kwargs)
    # resolve the term id => data mapping (or `nothing` if no data was given)
    data_by_term = if haskey(kwargs, :data)
        raw_data = kwargs[:data]
        grouped = if raw_data isa AbstractDataFrame
            group_col = get(kwargs, :semterm_column, nothing)
            if isnothing(group_col)
                throw(ArgumentError("No semterm_column specified for ensemble data."))
            end
            # one sub-frame per distinct value of the grouping column
            Dict(
                group_key[group_col] => group_df for
                (group_key, group_df) in pairs(groupby(raw_data, group_col))
            )
        elseif raw_data isa AbstractDict
            raw_data
        else
            throw(ArgumentError("""
            Unsupported data type for ensemble SEM model: $(typeof(raw_data)).
            Provide a DataFrame with a column identifying the term groups or a Dict mapping term ids to data.
            """))
        end

        # report the data that doesn't belong to any SEM term
        extra_ids = setdiff(keys(grouped), keys(spec.tables))
        if !isempty(extra_ids)
            @warn "Ignoring data with ids=$(collect(extra_ids)): no such SEM terms exist"
        end
        grouped
    else
        @warn """
        No data provided for ensemble SEM model. Each SEM term will be constructed with empty data.
        To provide data for each term, pass a DataFrame with a column identifying the term groups or a Dict mapping term ids to data
        """
        nothing
    end

    # construct SemObserved for a single term
    function observed_for_term(term_id, term_spec)
        term_kwargs = copy(kwargs)
        # without data the kwargs are passed on unchanged
        isnothing(data_by_term) &&
            return observed_type(; specification = term_spec, term_kwargs...)

        term_data = get(data_by_term, term_id, nothing)
        if isnothing(term_data)
            throw(ArgumentError("No data provided for SEM term :$term_id"))
        end
        term_kwargs[:data] = term_data
        # the grouping column is only meaningful for the ensemble as a whole
        delete!(term_kwargs, :semterm_column)
        return observed_type(; specification = term_spec, term_kwargs...)
    end

    return Dict(
        term_id => observed_for_term(term_id, term_spec) for
        (term_id, term_spec) in pairs(spec.tables)
    )
end
387+
338388
# construct Sem fields
339389
function get_fields!(kwargs, spec, observed, implied, loss)
340390
if !isa(spec, SemSpecification)
@@ -344,10 +394,7 @@ function get_fields!(kwargs, spec, observed, implied, loss)
344394
# observed
345395
if !isa(observed, SemObserved)
346396
observed = if spec isa EnsembleParameterTable
347-
Dict(
348-
term_id => observed(; specification = term_spec, kwargs...) for
349-
(term_id, term_spec) in pairs(spec.tables)
350-
)
397+
build_ensemble_observed(observed, spec, kwargs)
351398
else
352399
observed(; specification = spec, kwargs...)
353400
end

test/unit_tests/model.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,38 @@ end
106106
@test loss_newstate.V !== loss_orig.V
107107
@test observed_vars(loss_oldstate) == observed_vars(loss_orig)
108108
end
109+
110+
# Check that Sem(...) given a single data frame plus `semterm_column` hands each
# SEM term of a multi-group specification only the rows of its own group.
@testset "Sem(...; semterm_column=...) splits ensemble data by group" begin
    # two observed variables, rows assigned to :g1 (first half) and :g2 (the rest)
    dat_grouped = copy(dat[:, [:x1, :x2]])
    n_total = size(dat_grouped, 1)
    n_g1 = n_total ÷ 2
    n_g2 = n_total - n_g1
    dat_grouped.group = [fill(:g1, n_g1); fill(:g2, n_g2)]

    # one-factor model; modifiers take one value per group
    group_graph = @StenoGraph begin
        # loadings
        f1 → fixed(1.0, 1.0) * x1 + label(:λ₂, :λ₂) * x2
        # variances and covariances
        _(Symbol[:x1, :x2]) ↔ _(Symbol[:x1, :x2])
        _(Symbol[:f1]) ↔ _(Symbol[:f1])
    end

    grouped_partable = EnsembleParameterTable(
        group_graph;
        observed_vars = [:x1, :x2],
        latent_vars = [:f1],
        groups = [:g1, :g2],
    )

    grouped_model = Sem(
        specification = grouped_partable,
        data = dat_grouped,
        semterm_column = :group,
        observed = SemObservedData,
        implied = RAM,
        loss = SemML,
    )

    # exactly one loss term per group is expected
    term_g1 = only(filter(term -> SEM.id(term) == :g1, SEM.loss_terms(grouped_model)))
    term_g2 = only(filter(term -> SEM.id(term) == :g2, SEM.loss_terms(grouped_model)))

    # each term sees only its group's rows, the model sees all of them
    @test nsamples(observed(term_g1)) == n_g1
    @test nsamples(observed(term_g2)) == n_g2
    @test nsamples(grouped_model) == n_total
end

0 commit comments

Comments
 (0)