Skip to content

Commit c01f8cc

Browse files
Alexey Stukalovalyst
authored andcommitted
predict_latent_vars()
1 parent 127573a commit c01f8cc

2 files changed

Lines changed: 119 additions & 0 deletions

File tree

src/StructuralEquationModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include("frontend/specification/RAMMatrices.jl")
3232
include("frontend/specification/EnsembleParameterTable.jl")
3333
include("frontend/specification/StenoGraphs.jl")
3434
include("frontend/fit/summary.jl")
35+
include("frontend/predict.jl")
3536
# pretty printing
3637
include("frontend/pretty_printing.jl")
3738
# observed

src/frontend/predict.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
abstract type SemScoresPredictMethod end
2+
3+
struct SemRegressionScores <: SemScoresPredictMethod end
4+
struct SemBartlettScores <: SemScoresPredictMethod end
5+
struct SemAndersonRubinScores <: SemScoresPredictMethod end
6+
7+
function SemScoresPredictMethod(method::Symbol)
8+
if method == :regression
9+
return SemRegressionScores()
10+
elseif method == :Bartlett
11+
return SemBartlettScores()
12+
elseif method == :AndersonRubin
13+
return SemAndersonRubinScores()
14+
else
15+
throw(ArgumentError("Unsupported prediction method: $method"))
16+
end
17+
end
18+
19+
predict_latent_scores(
20+
fit::SemFit,
21+
data::SemObserved = fit.model.observed;
22+
method::Symbol = :regression,
23+
) = predict_latent_scores(SemScoresPredictMethod(method), fit, data)
24+
25+
predict_latent_scores(
26+
method::SemScoresPredictMethod,
27+
fit::SemFit,
28+
data::SemObserved = fit.model.observed,
29+
) = predict_latent_scores(method, fit.model, fit.solution, data)
30+
31+
function inv_cov!(A::AbstractMatrix)
32+
if istril(A)
33+
A = LowerTriangular(A)
34+
elseif istriu(A)
35+
A = UpperTriangular(A)
36+
else
37+
end
38+
A_chol = Cholesky(A)
39+
return inv!(A_chol)
40+
end
41+
42+
function latent_scores_operator(
43+
::SemRegressionScores,
44+
model::AbstractSemSingle,
45+
params::AbstractVector,
46+
)
47+
implied = model.imply
48+
ram = implied.ram_matrices
49+
lv_inds = latent_var_indices(ram)
50+
51+
A = materialize(ram.A, params)
52+
lv_FA = ram.F * A[:, lv_inds]
53+
lv_I_A⁻¹ = inv(I - A)[lv_inds, :]
54+
55+
S = materialize(ram.S, params)
56+
57+
cov_lv = lv_I_A⁻¹ * S * lv_I_A⁻¹'
58+
Σ = implied.Σ
59+
Σ⁻¹ = inv(Σ)
60+
return cov_lv * lv_FA' * Σ⁻¹
61+
end
62+
63+
function latent_scores_operator(
64+
::SemBartlettScores,
65+
model::AbstractSemSingle,
66+
params::AbstractVector,
67+
)
68+
implied = model.imply
69+
ram = implied.ram_matrices
70+
lv_inds = latent_var_indices(ram)
71+
A = materialize(ram.A, params)
72+
lv_FA = ram.F * A[:, lv_inds]
73+
74+
S = materialize(ram.S, params)
75+
obs_inds = observed_var_indices(ram)
76+
ov_S⁻¹ = inv(S[obs_inds, obs_inds])
77+
78+
return inv(lv_FA' * ov_S⁻¹ * lv_FA) * lv_FA' * ov_S⁻¹
79+
end
80+
81+
function predict_latent_scores(
82+
method::SemScoresPredictMethod,
83+
model::AbstractSemSingle,
84+
params::AbstractVector,
85+
data::SemObserved,
86+
)
87+
n_man(data) == nobserved_vars(model) || throw(
88+
DimensionMismatch(
89+
"Number of variables in data ($(n_obs(data))) does not match the number of observed variables in the model ($(nobserved_vars(model)))",
90+
),
91+
)
92+
length(params) == nparams(model) || throw(
93+
DimensionMismatch(
94+
"The length of parameters vector ($(length(params))) does not match the number of parameters in the model ($(nparams(model)))",
95+
),
96+
)
97+
98+
implied = model.imply
99+
hasmeanstruct = MeanStructure(implied) === HasMeanStructure
100+
101+
update!(EvaluationTargets(0.0, nothing, nothing), model.imply, model, params)
102+
ram = implied.ram_matrices
103+
lv_inds = latent_var_indices(ram)
104+
A = materialize(ram.A, params)
105+
lv_I_A⁻¹ = inv(I - A)[lv_inds, :]
106+
107+
lv_scores_op = latent_scores_operator(method, model, params)
108+
109+
data =
110+
data.data .- (isnothing(data.obs_mean) ? mean(data.data, dims = 1) : data.obs_mean')
111+
lv_scores = data * lv_scores_op'
112+
if hasmeanstruct
113+
M = materialize(ram.M, params)
114+
lv_scores .+= (lv_I_A⁻¹ * M)'
115+
end
116+
117+
return lv_scores
118+
end

0 commit comments

Comments
 (0)