Skip to content

Commit 2ddadf0

Browse files
author
Alexey Stukalov
committed
Sem(): cleanup constructor
* rename get_fields!() into build_sem_terms() for clarity * move set_field_type!() code into Sem() ctor since its not used outside
1 parent 5404adb commit 2ddadf0

1 file changed

Lines changed: 17 additions & 20 deletions

File tree

  • src/frontend/specification

src/frontend/specification/Sem.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,19 @@ function Sem(;
200200
) where {O, I, L}
201201
kwdict = Dict{Symbol, Any}(kwargs...)
202202

203-
set_field_type_kwargs!(kwdict, observed, implied, loss, O, I)
203+
# add kwargs with type information
204+
kwdict[:observed_type] = O <: Type ? observed : typeof(observed)
205+
kwdict[:implied_type] = I <: Type ? implied : typeof(implied)
206+
if loss isa SemLoss
207+
kwdict[:loss_types] =
208+
[aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss.functions]
209+
elseif applicable(iterate, loss)
210+
kwdict[:loss_types] = [aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss]
211+
else
212+
kwdict[:loss_types] = [loss isa SemLoss ? typeof(loss) : loss]
213+
end
204214

205-
loss = get_fields!(kwdict, specification, observed, implied, loss)
215+
loss = build_sem_terms(kwdict, specification, observed, implied, loss)
206216

207217
return Sem(loss...)
208218
end
@@ -337,19 +347,6 @@ vars(model::AbstractSem, id::Nothing = nothing) = vars(implied(model, id))
337347
observed_vars(model::AbstractSem, id::Nothing = nothing) = observed_vars(implied(model, id))
338348
latent_vars(model::AbstractSem, id::Nothing = nothing) = latent_vars(implied(model, id))
339349

340-
function set_field_type_kwargs!(kwargs, observed, implied, loss, O, I)
341-
kwargs[:observed_type] = O <: Type ? observed : typeof(observed)
342-
kwargs[:implied_type] = I <: Type ? implied : typeof(implied)
343-
if loss isa SemLoss
344-
kwargs[:loss_types] =
345-
[aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss.functions]
346-
elseif applicable(iterate, loss)
347-
kwargs[:loss_types] = [aloss isa SemLoss ? typeof(aloss) : aloss for aloss in loss]
348-
else
349-
kwargs[:loss_types] = [loss isa SemLoss ? typeof(loss) : loss]
350-
end
351-
end
352-
353350
# build ensemble/multi-group observed from the specification and Sem(...) kwargs
354351
# used by Sem(...) and replace_observed()
355352
function build_ensemble_observed(observed_type, spec::EnsembleParameterTable, kwargs)
@@ -400,8 +397,8 @@ function build_ensemble_observed(observed_type, spec::EnsembleParameterTable, kw
400397
)
401398
end
402399

403-
# construct Sem fields
404-
function get_fields!(kwargs, spec, observed, implied, loss)
400+
# called by Sem() ctor to construct its loss terms
401+
function build_sem_terms(kwargs::AbstractDict, spec, observed, implied, loss)
405402
if !isa(spec, SemSpecification)
406403
spec = spec(; kwargs...)
407404
end
@@ -430,13 +427,13 @@ function get_fields!(kwargs, spec, observed, implied, loss)
430427
# loss
431428
loss_kwargs = copy(kwargs)
432429
loss_kwargs[:nparams] = nparams(spec)
433-
loss = build_SemTerms(loss, observed, implied; loss_kwargs...)
430+
loss = build_sem_terms(loss, observed, implied; loss_kwargs...)
434431

435432
return loss
436433
end
437434

438-
# construct loss field
439-
function build_SemTerms(loss, observed, implied; kwargs...)
435+
# construct loss terms for the given observed and implied
436+
function build_sem_terms(loss, observed, implied; kwargs...)
440437
function build_SemLoss(aloss, observed, implied)
441438
if loss isa AbstractLoss
442439
return loss

0 commit comments

Comments
 (0)