Skip to content

Commit 0ba27a2

Browse files
committed
update_partable!(): dict-based generic version
1 parent f5b45bf commit 0ba27a2

2 files changed

Lines changed: 80 additions & 27 deletions

File tree

src/frontend/specification/EnsembleParameterTable.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,28 @@ Base.getindex(partable::EnsembleParameterTable, group) = partable.tables[group]
119119
### Update Partable from Fitted Model
120120
############################################################################################
121121

122-
# update generic ---------------------------------------------------------------------------
123122
function update_partable!(
124-
partable::EnsembleParameterTable,
123+
partables::EnsembleParameterTable,
124+
column::Symbol,
125+
param_values::AbstractDict{Symbol},
126+
default::Any = nothing,
127+
)
128+
for partable in values(partables.tables)
129+
update_partable!(partable, column, param_values, default)
130+
end
131+
return partables
132+
end
133+
134+
function update_partable!(
135+
partables::EnsembleParameterTable,
136+
column::Symbol,
125137
params::AbstractVector{Symbol},
126138
values::AbstractVector,
127-
column,
139+
default::Any = nothing,
128140
)
129-
for k in keys(partable.tables)
130-
update_partable!(partable.tables[k], params, values, column)
141+
param_values = Dict(zip(params, values))
142+
for (id, partable) in pairs(partables.tables)
143+
update_partable!(partable, column, param_values, default)
131144
end
132-
return partable
145+
return partables
133146
end

src/frontend/specification/ParameterTable.jl

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,33 @@ end
224224
############################################################################################
225225

226226
# update generic ---------------------------------------------------------------------------
227+
function update_partable!(
228+
partable::ParameterTable,
229+
column::Symbol,
230+
param_values::AbstractDict{Symbol, T},
231+
default::Any = nothing,
232+
) where {T}
233+
coldata = get!(() -> Vector{T}(), partable.columns, column)
234+
resize!(coldata, length(partable))
235+
236+
isvec_def = (default isa AbstractVector) && (length(default) == length(partable))
237+
238+
for (i, par) in enumerate(partable.columns[:param])
239+
if par == :const
240+
coldata[i] = !isnothing(default) ? (isvec_def ? default[i] : default) : zero(T)
241+
elseif haskey(param_values, par)
242+
coldata[i] = param_values[par]
243+
else
244+
if isnothing(default)
245+
throw(KeyError(par))
246+
else
247+
coldata[i] = isvec_def ? default[i] : default
248+
end
249+
end
250+
end
251+
252+
return partable
253+
end
227254

228255
"""
229256
update_partable!(partable::AbstractParameterTable, params::Vector{Symbol}, values, column)
@@ -236,49 +263,62 @@ of the `partable`.
236263
"""
237264
function update_partable!(
238265
partable::ParameterTable,
266+
column::Symbol,
239267
params::AbstractVector{Symbol},
240268
values::AbstractVector,
241-
column::Symbol,
269+
default::Any = nothing,
242270
)
243271
length(params) == length(values) || throw(
244272
ArgumentError(
245273
"The length of `params` ($(length(params))) and their `values` ($(length(values))) must be the same",
246274
),
247275
)
248-
coldata = get!(() -> Vector{eltype(values)}(), partable.columns, column)
249-
resize!(coldata, length(partable))
250-
params_index = Dict(zip(params, eachindex(params)))
251-
for (i, param) in enumerate(partable.columns[:param])
252-
coldata[i] = param != :const ? values[params_index[param]] : zero(eltype(values))
276+
param_values = Dict(zip(params, values))
277+
if length(param_values) != length(params)
278+
throw(ArgumentError("Duplicate parameter names in `params`"))
253279
end
254-
return partable
280+
update_partable!(partable, column, param_values, default)
255281
end
256282

257283
# update estimates -------------------------------------------------------------------------
258284
"""
259285
update_estimate!(
260286
partable::AbstractParameterTable,
261-
sem_fit::SemFit)
287+
fit::SemFit)
262288
263-
Write parameter estimates from `sem_fit` to the `:estimate` column of `partable`
289+
Write parameter estimates from `fit` to the `:estimate` column of `partable`
264290
"""
265-
update_estimate!(partable::AbstractParameterTable, sem_fit::SemFit) =
266-
update_partable!(partable, params(sem_fit), sem_fit.solution, :estimate)
291+
update_estimate!(partable::ParameterTable, fit::SemFit) = update_partable!(
292+
partable,
293+
:estimate,
294+
params(fit),
295+
fit.solution,
296+
partable.columns[:value_fixed],
297+
)
298+
299+
# fallback method for ensemble
300+
update_estimate!(partable::AbstractParameterTable, fit::SemFit) =
301+
update_partable!(partable, :estimate, params(fit), fit.solution)
267302

268303
# update starting values -------------------------------------------------------------------
269304
"""
270-
update_start!(partable::AbstractParameterTable, sem_fit::SemFit)
305+
update_start!(partable::AbstractParameterTable, fit::SemFit)
271306
update_start!(partable::AbstractParameterTable, model::AbstractSem, start_val; kwargs...)
272307
273-
Write starting values from `sem_fit` or `start_val` to the `:estimate` column of `partable`.
308+
Write starting values from `fit` or `start_val` to the `:estimate` column of `partable`.
274309
275310
# Arguments
276311
- `start_val`: either a vector of starting values or a function to compute starting values
277312
from `model`
278313
- `kwargs...`: are passed to `start_val`
279314
"""
280-
update_start!(partable::AbstractParameterTable, sem_fit::SemFit) =
281-
update_partable!(partable, params(sem_fit), sem_fit.start_val, :start)
315+
update_start!(partable::AbstractParameterTable, fit::SemFit) = update_partable!(
316+
partable,
317+
:start,
318+
params(fit),
319+
fit.start_val,
320+
partable.columns[:value_fixed],
321+
)
282322

283323
function update_start!(
284324
partable::AbstractParameterTable,
@@ -289,17 +329,17 @@ function update_start!(
289329
if !(start_val isa Vector)
290330
start_val = start_val(model; kwargs...)
291331
end
292-
return update_partable!(partable, params(model), start_val, :start)
332+
return update_partable!(partable, :start, params(model), start_val)
293333
end
294334

295335
# update partable standard errors ----------------------------------------------------------
296336
"""
297337
update_se_hessian!(
298338
partable::AbstractParameterTable,
299-
sem_fit::SemFit;
339+
fit::SemFit;
300340
hessian = :finitediff)
301341
302-
Write hessian standard errors computed for `sem_fit` to the `:se` column of `partable`
342+
Write hessian standard errors computed for `fit` to the `:se` column of `partable`
303343
304344
# Arguments
305345
- `hessian::Symbol`: how to compute the hessian, see [se_hessian](@ref) for more information.
@@ -309,11 +349,11 @@ Write hessian standard errors computed for `sem_fit` to the `:se` column of `par
309349
"""
310350
function update_se_hessian!(
311351
partable::AbstractParameterTable,
312-
sem_fit::SemFit;
352+
fit::SemFit;
313353
hessian = :finitediff,
314354
)
315-
se = se_hessian(sem_fit; hessian = hessian)
316-
return update_partable!(partable, params(sem_fit), se, :se)
355+
se = se_hessian(fit; hessian = hessian)
356+
return update_partable!(partable, :se, params(fit), se)
317357
end
318358

319359
"""

0 commit comments

Comments
 (0)