Skip to content

Commit bfd32b4

Browse files
author
Alexey Stukalov
committed
replace_observed(): support kwargs
1 parent 91d6f47 commit bfd32b4

3 files changed

Lines changed: 27 additions & 22 deletions

File tree

src/frontend/finite_diff.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ _unwrap(wrapper::SemFiniteDiff) = wrapper.model
22
params(wrapper::SemFiniteDiff) = params(wrapper.model)
33
loss_terms(wrapper::SemFiniteDiff) = loss_terms(wrapper.model)
44

5-
replace_observed(wrapper::SemFiniteDiff, data) =
6-
SemFiniteDiff(replace_observed(wrapper.model, data))
5+
replace_observed(wrapper::SemFiniteDiff, data; kwargs...) =
6+
SemFiniteDiff(replace_observed(wrapper.model, data; kwargs...))
77

88
FiniteDiffLossWrappers = Union{LossFiniteDiff, SemLossFiniteDiff}
99

@@ -12,16 +12,17 @@ _unwrap(wrapper::FiniteDiffLossWrappers) = wrapper.loss
1212
implied(wrapper::FiniteDiffLossWrappers) = implied(_unwrap(wrapper))
1313
observed(wrapper::FiniteDiffLossWrappers) = observed(_unwrap(wrapper))
1414

15-
replace_observed(wrapper::LossFiniteDiff, data) =
16-
LossFiniteDiff(replace_observed(_unwrap(wrapper), data))
15+
replace_observed(wrapper::LossFiniteDiff, data; kwargs...) =
16+
LossFiniteDiff(replace_observed(_unwrap(wrapper), data; kwargs...))
1717

18-
replace_observed(wrapper::SemLossFiniteDiff, new_observed::SemObserved) =
19-
SemLossFiniteDiff(replace_observed(_unwrap(wrapper), new_observed))
18+
replace_observed(wrapper::SemLossFiniteDiff, new_observed::SemObserved; kwargs...) =
19+
SemLossFiniteDiff(replace_observed(_unwrap(wrapper), new_observed; kwargs...))
2020

2121
replace_observed(
2222
wrapper::SemLossFiniteDiff,
23-
data::Union{AbstractMatrix, DataFrame},
24-
) = SemLossFiniteDiff(replace_observed(_unwrap(wrapper), data))
23+
data::Union{AbstractMatrix, DataFrame};
24+
kwargs...,
25+
) = SemLossFiniteDiff(replace_observed(_unwrap(wrapper), data; kwargs...))
2526

2627
FiniteDiffWrapper(model::AbstractSem) = SemFiniteDiff(model)
2728
FiniteDiffWrapper(loss::AbstractLoss) = LossFiniteDiff(loss)

src/frontend/specification/Sem.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -475,18 +475,19 @@ replace_observed(model, new_df; semterm_column = :group)
475475
"""
476476
function replace_observed end
477477

478-
function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix})
478+
function replace_observed(sem::Sem, data::Union{SemObserved, AbstractMatrix}; kwargs...)
479479
nsem_terms(sem) > 1 && throw(
480480
ArgumentError(
481481
"Model contains $(nsem_terms(sem)) SEM terms. " *
482482
"Use a Dict{Symbol} or a DataFrame with `semterm_column` to provide per-term data.",
483483
),
484484
)
485-
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
485+
updated_terms =
486+
Tuple(replace_observed(term, data; kwargs...) for term in loss_terms(sem))
486487
return Sem(updated_terms...)
487488
end
488489

489-
function replace_observed(sem::Sem, data::AbstractDict{Symbol})
490+
function replace_observed(sem::Sem, data::AbstractDict{Symbol}; kwargs...)
490491
term_ids = Set(
491492
if !isnothing(id(term))
492493
id(term)
@@ -507,20 +508,20 @@ function replace_observed(sem::Sem, data::AbstractDict{Symbol})
507508
term_data = get(data, tid, nothing)
508509
isnothing(term_data) &&
509510
throw(ArgumentError("No data provided for SEM term :$tid"))
510-
return replace_observed(term, term_data)
511+
return replace_observed(term, term_data; kwargs...)
511512
end
512513
return Sem(Tuple(updated_terms)...)
513514
end
514515

515-
function replace_observed(sem::Sem, data::AbstractVector)
516+
function replace_observed(sem::Sem, data::AbstractVector; kwargs...)
516517
nsem = nsem_terms(sem)
517518
nsem == length(data) || throw(
518519
ArgumentError(
519520
"Length of data ($(length(data))) does not match number of SEM terms ($nsem)",
520521
),
521522
)
522523
updated_terms = map(enumerate(loss_terms(sem))) do (i, term)
523-
issemloss(term) ? replace_observed(term, data[i]) : term
524+
issemloss(term) ? replace_observed(term, data[i]; kwargs...) : term
524525
end
525526
return Sem(Tuple(updated_terms)...)
526527
end
@@ -529,6 +530,7 @@ function replace_observed(
529530
sem::Sem,
530531
data::AbstractDataFrame;
531532
semterm_column::Union{Symbol, Nothing} = nothing,
533+
kwargs...,
532534
)
533535
if isnothing(semterm_column)
534536
# single-term shortcut
@@ -538,7 +540,8 @@ function replace_observed(
538540
"Provide `semterm_column` to specify which DataFrame column identifies the groups.",
539541
),
540542
)
541-
updated_terms = Tuple(replace_observed(term, data) for term in loss_terms(sem))
543+
updated_terms =
544+
Tuple(replace_observed(term, data; kwargs...) for term in loss_terms(sem))
542545
return Sem(updated_terms...)
543546
end
544547

@@ -547,7 +550,7 @@ function replace_observed(
547550
g[semterm_column] => group_data for
548551
(g, group_data) in pairs(groupby(data, semterm_column))
549552
)
550-
return replace_observed(sem, terms_data)
553+
return replace_observed(sem, terms_data; kwargs...)
551554
end
552555

553556
##############################################################

src/loss/abstract.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,28 @@ check_observed_vars(sem::SemLoss) = check_observed_vars(observed(sem), implied(s
4545
# replace_observed: SemLoss, AbstractLoss, LossTerm
4646
############################################################################################
4747

48-
function replace_observed(loss::SemLoss, new_observed::SemObserved)
48+
function replace_observed(loss::SemLoss, new_observed::SemObserved; kwargs...)
4949
old_obs = SEM.observed(loss)
5050
observed_vars(old_obs) == observed_vars(new_observed) || throw(
5151
ArgumentError(
5252
"observed_vars of the new data do not match the model: " *
5353
"expected $(observed_vars(old_obs)), got $(observed_vars(new_observed))",
5454
),
5555
)
56+
# the default replace_observed() does not pass through kwargs to the ctor
5657
return typeof(loss).name.wrapper(new_observed, SEM.implied(loss))
5758
end
5859

59-
function replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame})
60+
function replace_observed(loss::SemLoss, data::Union{AbstractMatrix, DataFrame}; kwargs...)
6061
old_obs = SEM.observed(loss)
6162
new_observed =
6263
typeof(old_obs).name.wrapper(data = data, observed_vars = observed_vars(old_obs))
63-
return replace_observed(loss, new_observed)
64+
return replace_observed(loss, new_observed; kwargs...)
6465
end
6566

6667
# non-SEM loss terms are unchanged
67-
replace_observed(loss::AbstractLoss, ::Any) = loss
68+
replace_observed(loss::AbstractLoss, ::Any; kwargs...) = loss
6869

6970
# LossTerm: delegate to inner loss
70-
replace_observed(term::LossTerm, data) =
71-
LossTerm(replace_observed(loss(term), data), id(term), weight(term))
71+
replace_observed(term::LossTerm, data; kwargs...) =
72+
LossTerm(replace_observed(loss(term), data; kwargs...), id(term), weight(term))

0 commit comments

Comments
 (0)