Skip to content

Commit 6e7a0c1

Browse files
committed
update_partable!(): dict-based generic version
1 parent c3a7831 commit 6e7a0c1

2 files changed

Lines changed: 80 additions & 27 deletions

File tree

src/frontend/specification/EnsembleParameterTable.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,28 @@ Base.getindex(partable::EnsembleParameterTable, group) = partable.tables[group]
117117
### Update Partable from Fitted Model
118118
############################################################################################
119119

120-
# update generic ---------------------------------------------------------------------------
121120
function update_partable!(
122-
partable::EnsembleParameterTable,
121+
partables::EnsembleParameterTable,
122+
column::Symbol,
123+
param_values::AbstractDict{Symbol},
124+
default::Any = nothing,
125+
)
126+
for partable in values(partables.tables)
127+
update_partable!(partable, column, param_values, default)
128+
end
129+
return partables
130+
end
131+
132+
function update_partable!(
133+
partables::EnsembleParameterTable,
134+
column::Symbol,
123135
params::AbstractVector{Symbol},
124136
values::AbstractVector,
125-
column,
137+
default::Any = nothing,
126138
)
127-
for k in keys(partable.tables)
128-
update_partable!(partable.tables[k], params, values, column)
139+
param_values = Dict(zip(params, values))
140+
for (id, partable) in pairs(partables.tables)
141+
update_partable!(partable, column, param_values, default)
129142
end
130-
return partable
143+
return partables
131144
end

src/frontend/specification/ParameterTable.jl

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,33 @@ end
227227
############################################################################################
228228

229229
# update generic ---------------------------------------------------------------------------
230+
function update_partable!(
231+
partable::ParameterTable,
232+
column::Symbol,
233+
param_values::AbstractDict{Symbol, T},
234+
default::Any = nothing,
235+
) where {T}
236+
coldata = get!(() -> Vector{T}(), partable.columns, column)
237+
resize!(coldata, length(partable))
238+
239+
isvec_def = (default isa AbstractVector) && (length(default) == length(partable))
240+
241+
for (i, par) in enumerate(partable.columns[:param])
242+
if par == :const
243+
coldata[i] = !isnothing(default) ? (isvec_def ? default[i] : default) : zero(T)
244+
elseif haskey(param_values, par)
245+
coldata[i] = param_values[par]
246+
else
247+
if isnothing(default)
248+
throw(KeyError(par))
249+
else
250+
coldata[i] = isvec_def ? default[i] : default
251+
end
252+
end
253+
end
254+
255+
return partable
256+
end
230257

231258
"""
232259
update_partable!(partable::AbstractParameterTable, params::Vector{Symbol}, values, column)
@@ -239,49 +266,62 @@ of the `partable`.
239266
"""
240267
function update_partable!(
241268
partable::ParameterTable,
269+
column::Symbol,
242270
params::AbstractVector{Symbol},
243271
values::AbstractVector,
244-
column::Symbol,
272+
default::Any = nothing,
245273
)
246274
length(params) == length(values) || throw(
247275
ArgumentError(
248276
"The length of `params` ($(length(params))) and their `values` ($(length(values))) must be the same",
249277
),
250278
)
251-
coldata = get!(() -> Vector{eltype(values)}(), partable.columns, column)
252-
resize!(coldata, length(partable))
253-
params_index = Dict(zip(params, eachindex(params)))
254-
for (i, param) in enumerate(partable.columns[:param])
255-
coldata[i] = param != :const ? values[params_index[param]] : zero(eltype(values))
279+
param_values = Dict(zip(params, values))
280+
if length(param_values) != length(params)
281+
throw(ArgumentError("Duplicate parameter names in `params`"))
256282
end
257-
return partable
283+
update_partable!(partable, column, param_values, default)
258284
end
259285

260286
# update estimates -------------------------------------------------------------------------
261287
"""
262288
update_estimate!(
263289
partable::AbstractParameterTable,
264-
sem_fit::SemFit)
290+
fit::SemFit)
265291
266-
Write parameter estimates from `sem_fit` to the `:estimate` column of `partable`
292+
Write parameter estimates from `fit` to the `:estimate` column of `partable`
267293
"""
268-
update_estimate!(partable::AbstractParameterTable, sem_fit::SemFit) =
269-
update_partable!(partable, params(sem_fit), sem_fit.solution, :estimate)
294+
update_estimate!(partable::ParameterTable, fit::SemFit) = update_partable!(
295+
partable,
296+
:estimate,
297+
params(fit),
298+
fit.solution,
299+
partable.columns[:value_fixed],
300+
)
301+
302+
# fallback method for ensemble
303+
update_estimate!(partable::AbstractParameterTable, fit::SemFit) =
304+
update_partable!(partable, :estimate, params(fit), fit.solution)
270305

271306
# update starting values -------------------------------------------------------------------
272307
"""
273-
update_start!(partable::AbstractParameterTable, sem_fit::SemFit)
308+
update_start!(partable::AbstractParameterTable, fit::SemFit)
274309
update_start!(partable::AbstractParameterTable, model::AbstractSem, start_val; kwargs...)
275310
276-
Write starting values from `sem_fit` or `start_val` to the `:estimate` column of `partable`.
311+
Write starting values from `fit` or `start_val` to the `:estimate` column of `partable`.
277312
278313
# Arguments
279314
- `start_val`: either a vector of starting values or a function to compute starting values
280315
from `model`
281316
- `kwargs...`: are passed to `start_val`
282317
"""
283-
update_start!(partable::AbstractParameterTable, sem_fit::SemFit) =
284-
update_partable!(partable, params(sem_fit), sem_fit.start_val, :start)
318+
update_start!(partable::AbstractParameterTable, fit::SemFit) = update_partable!(
319+
partable,
320+
:start,
321+
params(fit),
322+
fit.start_val,
323+
partable.columns[:value_fixed],
324+
)
285325

286326
function update_start!(
287327
partable::AbstractParameterTable,
@@ -292,17 +332,17 @@ function update_start!(
292332
if !(start_val isa Vector)
293333
start_val = start_val(model; kwargs...)
294334
end
295-
return update_partable!(partable, params(model), start_val, :start)
335+
return update_partable!(partable, :start, params(model), start_val)
296336
end
297337

298338
# update partable standard errors ----------------------------------------------------------
299339
"""
300340
update_se_hessian!(
301341
partable::AbstractParameterTable,
302-
sem_fit::SemFit;
342+
fit::SemFit;
303343
hessian = :finitediff)
304344
305-
Write hessian standard errors computed for `sem_fit` to the `:se` column of `partable`
345+
Write hessian standard errors computed for `fit` to the `:se` column of `partable`
306346
307347
# Arguments
308348
- `hessian::Symbol`: how to compute the hessian, see [se_hessian](@ref) for more information.
@@ -312,11 +352,11 @@ Write hessian standard errors computed for `sem_fit` to the `:se` column of `par
312352
"""
313353
function update_se_hessian!(
314354
partable::AbstractParameterTable,
315-
sem_fit::SemFit;
355+
fit::SemFit;
316356
hessian = :finitediff,
317357
)
318-
se = se_hessian(sem_fit; hessian = hessian)
319-
return update_partable!(partable, params(sem_fit), se, :se)
358+
se = se_hessian(fit; hessian = hessian)
359+
return update_partable!(partable, :se, params(fit), se)
320360
end
321361

322362
"""

0 commit comments

Comments
 (0)