Skip to content

Commit 1f835a0

Browse files
committed
ParTable: add explicit params field
1 parent 1cb865d commit 1f835a0

3 files changed

Lines changed: 78 additions & 31 deletions

File tree

src/frontend/specification/EnsembleParameterTable.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,19 @@ function EnsembleParameterTable(
2323
spec_ensemble::AbstractDict{K, V};
2424
params::Union{Nothing, Vector{Symbol}} = nothing,
2525
) where {K, V <: SemSpecification}
26-
partables = Dict{K, ParameterTable}(
27-
group => convert(ParameterTable, spec; params = params) for
28-
(group, spec) in pairs(spec_ensemble)
29-
)
30-
31-
if isnothing(params)
26+
params = if isnothing(params)
3227
# collect all SEM parameters in ensemble if not specified
3328
# and apply the set to all partables
34-
params =
35-
unique(mapreduce(SEM.params, vcat, values(partables), init = Vector{Symbol}()))
36-
for partable in values(partables)
37-
if partable.params != params
38-
copyto!(resize!(partable.params, length(params)), params)
39-
#throw(ArgumentError("The parameter sets of the SEM specifications in the ensemble do not match."))
40-
end
41-
end
29+
unique(mapreduce(SEM.params, vcat, values(spec_ensemble), init = Vector{Symbol}()))
4230
else
43-
params = copy(params)
31+
copy(params)
4432
end
33+
34+
# convert each model specification to ParameterTable
35+
partables = Dict{K, ParameterTable}(
36+
group => convert(ParameterTable, spec; params) for
37+
(group, spec) in pairs(spec_ensemble)
38+
)
4539
return EnsembleParameterTable{K}(partables, params)
4640
end
4741

src/frontend/specification/ParameterTable.jl

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ struct ParameterTable{C} <: AbstractParameterTable
77
observed_vars::Vector{Symbol}
88
latent_vars::Vector{Symbol}
99
sorted_vars::Vector{Symbol}
10+
params::Vector{Symbol}
1011
end
1112

1213
############################################################################################
@@ -31,11 +32,32 @@ function ParameterTable(
3132
columns::Dict{Symbol, Vector} = empty_partable_columns();
3233
observed_vars::Union{AbstractVector{Symbol}, Nothing} = nothing,
3334
latent_vars::Union{AbstractVector{Symbol}, Nothing} = nothing,
35+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
3436
)
35-
return ParameterTable(columns,
37+
params = isnothing(params) ? unique!(filter(!=(:const), columns[:param])) : copy(params)
38+
check_params(params, columns[:param])
39+
return ParameterTable(
40+
columns,
3641
!isnothing(observed_vars) ? copy(observed_vars) : Vector{Symbol}(),
3742
!isnothing(latent_vars) ? copy(latent_vars) : Vector{Symbol}(),
38-
Vector{Symbol}())
43+
Vector{Symbol}(),
44+
params,
45+
)
46+
end
47+
48+
# new parameter table with different parameters order
49+
function ParameterTable(
50+
partable::ParameterTable;
51+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
52+
)
53+
isnothing(params) || check_params(params, partable.columns[:param])
54+
55+
return ParameterTable(
56+
Dict(col => copy(values) for (col, values) in pairs(partable.columns)),
57+
observed_vars = copy(partable.observed_vars),
58+
latent_vars = copy(partable.latent_vars),
59+
params = params,
60+
)
3961
end
4062

4163
############################################################################################
@@ -46,6 +68,15 @@ function Base.convert(::Type{Dict}, partable::ParameterTable)
4668
return partable.columns
4769
end
4870

71+
function Base.convert(
72+
::Type{ParameterTable},
73+
partable::ParameterTable;
74+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
75+
)
76+
return isnothing(params) || partable.params == params ? partable :
77+
ParameterTable(partable; params)
78+
end
79+
4980
function DataFrames.DataFrame(
5081
partable::ParameterTable;
5182
columns::Union{AbstractVector{Symbol}, Nothing} = nothing,
@@ -111,12 +142,8 @@ Base.iterate(partable::ParameterTable) = iterate(partable, 1)
111142
Base.iterate(partable::ParameterTable, i::Integer) =
112143
i > length(partable) ? nothing : (partable[i], i + 1)
113144

114-
115-
# get the vector of all parameters in the table
116-
# the position of the parameter is based on its first appearance in the table (and the ensemble)
117-
params(partable::ParameterTable) =
118-
filter!(!=(:const), unique(partable.columns[:param]))
119-
145+
params(partable::ParameterTable) = partable.params
146+
n_par(partable::ParameterTable) = length(params(partable))
120147

121148
# Sorting ----------------------------------------------------------------------------------
122149

src/frontend/specification/RAMMatrices.jl

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,21 +223,43 @@ function RAMMatrices(
223223
)
224224
end
225225

226-
Base.convert(::Type{RAMMatrices}, partable::ParameterTable) = RAMMatrices(partable)
226+
Base.convert(
227+
::Type{RAMMatrices},
228+
partable::ParameterTable;
229+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
230+
) = RAMMatrices(partable; params)
227231

228232
############################################################################################
229233
### get parameter table from RAMMatrices
230234
############################################################################################
231235

232-
function ParameterTable(ram_matrices::RAMMatrices)
233-
colnames = ram_matrices.colnames
236+
function ParameterTable(
237+
ram_matrices::RAMMatrices;
238+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
239+
observed_var_prefix::Symbol = :obs,
240+
latent_var_prefix::Symbol = :var,
241+
)
242+
# defer parameter checks until we know which ones are used
243+
if !isnothing(ram_matrices.colnames)
244+
colnames = ram_matrices.colnames
245+
observed_vars = colnames[ram_matrices.F_ind]
246+
latent_vars = colnames[setdiff(eachindex(colnames), ram_matrices.F_ind)]
247+
else
248+
observed_vars =
249+
[Symbol("$(observed_var_prefix)_$i") for i in 1:nobserved_vars(ram_matrices)]
250+
latent_vars =
251+
[Symbol("$(latent_var_prefix)_$i") for i in 1:nlatent_vars(ram_matrices)]
252+
colnames = vcat(observed_vars, latent_vars)
253+
end
234254

255+
# construct an empty table
235256
partable = ParameterTable(
236-
observed_vars = colnames[ram_matrices.F_ind],
237-
latent_vars = colnames[setdiff(eachindex(colnames), ram_matrices.F_ind)],
257+
observed_vars = observed_vars,
258+
latent_vars = latent_vars,
259+
params = isnothing(params) ? SEM.params(ram_matrices) : params,
238260
)
239261

240-
position_names = Dict{Int64, Symbol}(1:length(colnames) .=> colnames)
262+
position_names = Dict{Int, Symbol}(eachindex(colnames) .=> colnames)
241263

242264
# constants
243265
for c in ram_matrices.constants
@@ -257,12 +279,16 @@ function ParameterTable(ram_matrices::RAMMatrices)
257279
ram_matrices.size_F[2],
258280
)
259281
end
282+
check_params(SEM.params(partable), partable.columns[:param])
260283

261284
return partable
262285
end
263286

264-
Base.convert(::Type{<:ParameterTable}, ram_matrices::RAMMatrices) =
265-
ParameterTable(ram_matrices)
287+
Base.convert(
288+
::Type{<:ParameterTable},
289+
ram::RAMMatrices;
290+
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
291+
) = ParameterTable(ram; params)
266292

267293
############################################################################################
268294
### Pretty Printing

0 commit comments

Comments
 (0)