Skip to content

Commit e4724ed

Browse files
committed
ParTable: add explicit params field
1 parent 65dfc87 commit e4724ed

3 files changed

Lines changed: 78 additions & 31 deletions

File tree

src/frontend/specification/EnsembleParameterTable.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,19 @@ function EnsembleParameterTable(
2323
spec_ensemble::AbstractDict{K, V};
2424
params::Union{Nothing, Vector{Symbol}} = nothing,
2525
) where {K, V <: SemSpecification}
26-
partables = Dict{K, ParameterTable}(
27-
group => convert(ParameterTable, spec; params = params) for
28-
(group, spec) in pairs(spec_ensemble)
29-
)
30-
31-
if isnothing(params)
26+
params = if isnothing(params)
3227
# collect all SEM parameters in ensemble if not specified
3328
# and apply the set to all partables
34-
params =
35-
unique(mapreduce(SEM.params, vcat, values(partables), init = Vector{Symbol}()))
36-
for partable in values(partables)
37-
if partable.params != params
38-
copyto!(resize!(partable.params, length(params)), params)
39-
#throw(ArgumentError("The parameter sets of the SEM specifications in the ensemble do not match."))
40-
end
41-
end
29+
unique(mapreduce(SEM.params, vcat, values(spec_ensemble), init = Vector{Symbol}()))
4230
else
43-
params = copy(params)
31+
copy(params)
4432
end
33+
34+
# convert each model specification to ParameterTable
35+
partables = Dict{K, ParameterTable}(
36+
group => convert(ParameterTable, spec; params) for
37+
(group, spec) in pairs(spec_ensemble)
38+
)
4539
return EnsembleParameterTable{K}(partables, params)
4640
end
4741

src/frontend/specification/ParameterTable.jl

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ struct ParameterTable{C} <: AbstractParameterTable
77
observed_vars::Vector{Symbol}
88
latent_vars::Vector{Symbol}
99
sorted_vars::Vector{Symbol}
10+
params::Vector{Symbol}
1011
end
1112

1213
############################################################################################
@@ -31,11 +32,32 @@ function ParameterTable(
3132
columns::Dict{Symbol, Vector} = empty_partable_columns();
3233
observed_vars::Union{AbstractVector{Symbol}, Nothing} = nothing,
3334
latent_vars::Union{AbstractVector{Symbol}, Nothing} = nothing,
35+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
3436
)
35-
return ParameterTable(columns,
37+
params = isnothing(params) ? unique!(filter(!=(:const), columns[:param])) : copy(params)
38+
check_params(params, columns[:param])
39+
return ParameterTable(
40+
columns,
3641
!isnothing(observed_vars) ? copy(observed_vars) : Vector{Symbol}(),
3742
!isnothing(latent_vars) ? copy(latent_vars) : Vector{Symbol}(),
38-
Vector{Symbol}())
43+
Vector{Symbol}(),
44+
params,
45+
)
46+
end
47+
48+
# new parameter table with different parameters order
49+
function ParameterTable(
50+
partable::ParameterTable;
51+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
52+
)
53+
isnothing(params) || check_params(params, partable.columns[:param])
54+
55+
return ParameterTable(
56+
Dict(col => copy(values) for (col, values) in pairs(partable.columns)),
57+
observed_vars = copy(partable.observed_vars),
58+
latent_vars = copy(partable.latent_vars),
59+
params = params,
60+
)
3961
end
4062

4163
############################################################################################
@@ -46,6 +68,15 @@ function Base.convert(::Type{Dict}, partable::ParameterTable)
4668
return partable.columns
4769
end
4870

71+
function Base.convert(
72+
::Type{ParameterTable},
73+
partable::ParameterTable;
74+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
75+
)
76+
return isnothing(params) || partable.params == params ? partable :
77+
ParameterTable(partable; params)
78+
end
79+
4980
function DataFrames.DataFrame(
5081
partable::ParameterTable;
5182
columns::Union{AbstractVector{Symbol}, Nothing} = nothing,
@@ -111,12 +142,8 @@ Base.iterate(partable::ParameterTable) = iterate(partable, 1)
111142
Base.iterate(partable::ParameterTable, i::Integer) =
112143
i > length(partable) ? nothing : (partable[i], i + 1)
113144

114-
115-
# get the vector of all parameters in the table
116-
# the position of the parameter is based on its first appearance in the table (and the ensemble)
117-
params(partable::ParameterTable) =
118-
filter!(!=(:const), unique(partable.columns[:param]))
119-
145+
params(partable::ParameterTable) = partable.params
146+
n_par(partable::ParameterTable) = length(params(partable))
120147

121148
# Sorting ----------------------------------------------------------------------------------
122149

src/frontend/specification/RAMMatrices.jl

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -222,21 +222,43 @@ function RAMMatrices(
222222
)
223223
end
224224

225-
Base.convert(::Type{RAMMatrices}, partable::ParameterTable) = RAMMatrices(partable)
225+
Base.convert(
226+
::Type{RAMMatrices},
227+
partable::ParameterTable;
228+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
229+
) = RAMMatrices(partable; params)
226230

227231
############################################################################################
228232
### get parameter table from RAMMatrices
229233
############################################################################################
230234

231-
function ParameterTable(ram_matrices::RAMMatrices)
232-
colnames = ram_matrices.colnames
235+
function ParameterTable(
236+
ram_matrices::RAMMatrices;
237+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
238+
observed_var_prefix::Symbol = :obs,
239+
latent_var_prefix::Symbol = :var,
240+
)
241+
# defer parameter checks until we know which ones are used
242+
if !isnothing(ram_matrices.colnames)
243+
colnames = ram_matrices.colnames
244+
observed_vars = colnames[ram_matrices.F_ind]
245+
latent_vars = colnames[setdiff(eachindex(colnames), ram_matrices.F_ind)]
246+
else
247+
observed_vars =
248+
[Symbol("$(observed_var_prefix)_$i") for i in 1:nobserved_vars(ram_matrices)]
249+
latent_vars =
250+
[Symbol("$(latent_var_prefix)_$i") for i in 1:nlatent_vars(ram_matrices)]
251+
colnames = vcat(observed_vars, latent_vars)
252+
end
233253

254+
# construct an empty table
234255
partable = ParameterTable(
235-
observed_vars = colnames[ram_matrices.F_ind],
236-
latent_vars = colnames[setdiff(eachindex(colnames), ram_matrices.F_ind)],
256+
observed_vars = observed_vars,
257+
latent_vars = latent_vars,
258+
params = isnothing(params) ? SEM.params(ram_matrices) : params,
237259
)
238260

239-
position_names = Dict{Int64, Symbol}(1:length(colnames) .=> colnames)
261+
position_names = Dict{Int, Symbol}(eachindex(colnames) .=> colnames)
240262

241263
# constants
242264
for c in ram_matrices.constants
@@ -256,12 +278,16 @@ function ParameterTable(ram_matrices::RAMMatrices)
256278
ram_matrices.size_F[2],
257279
)
258280
end
281+
check_params(SEM.params(partable), partable.columns[:param])
259282

260283
return partable
261284
end
262285

263-
Base.convert(::Type{<:ParameterTable}, ram_matrices::RAMMatrices) =
264-
ParameterTable(ram_matrices)
286+
Base.convert(
287+
::Type{<:ParameterTable},
288+
ram::RAMMatrices;
289+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
290+
) = ParameterTable(ram; params)
265291

266292
############################################################################################
267293
### Pretty Printing

0 commit comments

Comments
 (0)