Skip to content

Commit 4659fc5

Browse files
Alexey Stukalovalyst
authored andcommitted
EM MVN: decouple from SemObsMissing
so EM MVN could be done when SemObsMissing is constructed
1 parent 55884f2 commit 4659fc5

6 files changed

Lines changed: 159 additions & 175 deletions

File tree

src/StructuralEquationModels.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ include("frontend/pretty_printing.jl")
4242
# observed
4343
include("observed/data.jl")
4444
include("observed/covariance.jl")
45-
include("observed/missing.jl")
45+
include("observed/missing_pattern.jl")
4646
include("observed/EM.jl")
47+
include("observed/missing.jl")
4748
# constructor
4849
include("frontend/specification/Sem.jl")
4950
include("frontend/specification/documentation.jl")

src/additional_functions/start_val/start_fabin3.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,10 @@ function start_fabin3(model::AbstractSemSingle; kwargs...)
1717
)
1818
end
1919

20-
function start_fabin3(observed, imply, optimizer, args...; kwargs...)
20+
function start_fabin3(observed::SemObserved, imply, optimizer, args...; kwargs...)
2121
return start_fabin3(imply.ram_matrices, obs_cov(observed), obs_mean(observed))
2222
end
2323

24-
# SemObservedMissing
25-
function start_fabin3(observed::SemObservedMissing, imply, optimizer, args...; kwargs...)
26-
if !observed.em_model.fitted
27-
em_mvn(observed; kwargs...)
28-
end
29-
30-
return start_fabin3(imply.ram_matrices, observed.em_model.Σ, observed.em_model.μ)
31-
end
32-
3324
function start_fabin3(
3425
ram_matrices::RAMMatrices,
3526
Σ::AbstractMatrix,

src/frontend/fit/fitmeasures/minus2ll.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@ end
4040
# compute likelihood for missing data - H1 -------------------------------------------------
4141
# -2ll = ∑ log(2π)*(nᵢ + mᵢ) + ln(Σᵢ) + (mᵢ - μᵢ)ᵀ Σᵢ⁻¹ (mᵢ - μᵢ)) + tr(SᵢΣᵢ)
4242
function minus2ll(observed::SemObservedMissing)
43-
observed.em_model.fitted || em_mvn(observed)
44-
45-
μ = observed.em_model.μ
46-
Σ = observed.em_model.Σ
43+
Σ, μ = obs_cov(observed), obs_mean(observed)
4744

4845
F = 0.0
4946
for pat in observed.patterns

src/observed/EM.jl

Lines changed: 98 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
# what about random restarts?
66

77
"""
8-
em_mvn(;
9-
observed::SemObservedMissing,
10-
start_em = start_em_observed,
11-
max_iter_em = 100,
12-
rtol_em = 1e-4,
13-
kwargs...)
8+
em_mvn(patterns::AbstractVector{SemObservedMissingPattern};
9+
start_em = start_em_observed,
10+
max_iter_em = 100,
11+
rtol_em = 1e-4,
12+
kwargs...)
1413
15-
Estimates the covariance matrix and mean vector of the normal distribution via expectation maximization for `observed`.
16-
Overwrites the statistics stored in `observed`.
14+
Estimates the covariance matrix and mean vector of the
15+
multivariate normal distribution (MVN)
16+
via expectation maximization (EM) for `observed`.
17+
18+
Returns the tuple of the EM covariance matrix and the EM mean vector.
1719
1820
Uses the EM algorithm for MVN-distributed data with missing values
1921
adapted from the supplementary material to the book *Machine Learning: A Probabilistic Perspective*,
@@ -22,23 +24,21 @@ copyright (2010) Kevin Murphy and Matt Dunham: see
2224
[*emAlgo.m*](https://github.com/probml/pmtk3/blob/master/toolbox/Algorithms/optimization/emAlgo.m) scripts.
2325
"""
2426
function em_mvn(
25-
observed::SemObservedMissing;
27+
patterns::AbstractVector{<:SemObservedMissingPattern};
2628
start_em = start_em_observed,
27-
max_iter_em = 100,
28-
rtol_em = 1e-4,
29+
max_iter_em::Integer = 100,
30+
rtol_em::Number = 1e-4,
2931
kwargs...,
3032
)
31-
n_man = SEM.n_man(observed)
32-
33-
# preallocate stuff?
34-
𝔼x_pre = zeros(n_man)
35-
𝔼xxᵀ_pre = zeros(n_man, n_man)
33+
n_man = SEM.n_man(patterns[1])
3634

3735
### precompute for full cases
38-
fullpat = observed.patterns[1]
39-
if nmissed_vars(fullpat) == 0
40-
sum!(reshape(𝔼x_pre, 1, n_man), fullpat.data)
41-
mul!(𝔼xxᵀ_pre, fullpat.data', fullpat.data)
36+
𝔼x_full = zeros(n_man)
37+
𝔼xxᵀ_full = zeros(n_man, n_man)
38+
if nmissed_vars(patterns[1]) == 0
39+
fullpat = patterns[1]
40+
sum!(reshape(𝔼x_full, 1, n_man), fullpat.data)
41+
mul!(𝔼xxᵀ_full, fullpat.data', fullpat.data)
4242
else
4343
@warn "No full cases pattern found"
4444
end
@@ -47,72 +47,79 @@ function em_mvn(
4747
# estepFn = (em_model, data) -> estep(em_model, data, EXsum, EXXsum, ismissing, missingRows, n_obs)
4848

4949
# initialize
50-
em_model = start_em(observed; kwargs...)
51-
em_model_prev = EmMVNModel(zeros(n_man, n_man), zeros(n_man), false)
52-
iter = 1
53-
done = false
54-
𝔼x = zeros(n_man)
55-
𝔼xxᵀ = zeros(n_man, n_man)
56-
57-
while !done
58-
step!(em_model, observed, 𝔼x, 𝔼xxᵀ, 𝔼x_pre, 𝔼xxᵀ_pre)
59-
60-
if iter > max_iter_em
61-
done = true
62-
@warn "EM Algorithm for MVN missing data did not converge. Likelihood for FIML is not interpretable.
63-
Maybe try passing different starting values via 'start_em = ...' "
64-
elseif iter > 1
65-
# done = isapprox(ll, ll_prev; rtol = rtol)
66-
done =
67-
isapprox(em_model_prev.μ, em_model.μ; rtol = rtol_em) &&
68-
isapprox(em_model_prev.Σ, em_model.Σ; rtol = rtol_em)
50+
Σ₀, μ = start_em(patterns; kwargs...)
51+
Σ = convert(Matrix, Σ₀)
52+
@assert all(isfinite, Σ) all(isfinite, μ)
53+
Σ_prev, μ_prev = copy(Σ), copy(μ)
54+
55+
iter = 0
56+
converged = false
57+
while !converged && (iter < max_iter_em)
58+
em_step!(Σ, μ, Σ_prev, μ_prev, patterns, 𝔼x_full, 𝔼xxᵀ_full)
59+
60+
if iter > 0
61+
Δμ = norm- μ_prev)
62+
ΔΣ = norm- Σ_prev)
63+
Δμ_rel = Δμ / max(norm(μ_prev), norm(μ))
64+
ΔΣ_rel = ΔΣ / max(norm(Σ_prev), norm(Σ))
65+
#@info "Iteration #$iter: ΔΣ=$(ΔΣ) ΔΣ/Σ=$(ΔΣ_rel) Δμ=$(Δμ) Δμ/μ=$(Δμ_rel)"
66+
# converged = isapprox(ll, ll_prev; rtol = rtol)
67+
converged = ΔΣ_rel <= rtol_em && Δμ_rel <= rtol_em
68+
end
69+
if !converged
70+
Σ, Σ_prev = Σ_prev, Σ
71+
μ, μ_prev = μ_prev, μ
6972
end
70-
71-
# print("$iter \n")
7273
iter += 1
73-
copyto!(em_model_prev.μ, em_model.μ)
74-
copyto!(em_model_prev.Σ, em_model.Σ)
74+
#@info "$iter\n"
7575
end
7676

77-
# update EM Mode in observed
78-
observed.em_model.Σ .= em_model.Σ
79-
observed.em_model.μ .= em_model.μ
80-
observed.em_model.fitted = true
77+
if !converged
78+
@warn "EM Algorithm for MVN missing data did not converge in $iter iterations.\n" *
79+
"Likelihood for FIML is not interpretable.\n" *
80+
"Maybe try passing different starting values via 'start_em = ...' "
81+
else
82+
@info "EM for MVN missing data converged in $iter iterations"
83+
end
8184

82-
return nothing
85+
return Σ, μ
8386
end
8487

8588
# E and M steps -----------------------------------------------------------------------------
8689

87-
# update em_model
88-
function step!(em_model::EmMVNModel, observed::SemObserved, 𝔼x, 𝔼xxᵀ, 𝔼x_pre, 𝔼xxᵀ_pre)
90+
function em_step!(
91+
Σ::AbstractMatrix,
92+
μ::AbstractVector,
93+
Σ₀::AbstractMatrix,
94+
μ₀::AbstractVector,
95+
patterns::AbstractVector{<:SemObservedMissingPattern},
96+
𝔼x_full,
97+
𝔼xxᵀ_full,
98+
)
8999
# E step, update 𝔼x and 𝔼xxᵀ
90-
fill!(𝔼x, 0)
91-
fill!(𝔼xxᵀ, 0)
92-
93-
μ = em_model.μ
94-
Σ = em_model.Σ
100+
copy!(μ, 𝔼x_full)
101+
copy!(Σ, 𝔼xxᵀ_full)
95102

96103
# Compute the expected sufficient statistics
97-
for pat in observed.patterns
104+
for pat in patterns
98105
(nmissed_vars(pat) == 0) && continue # skip full cases
99106

100107
# observed and unobserved vars
101108
u = pat.miss_mask
102109
o = pat.obs_mask
103110

104111
# precompute for pattern
105-
Σoo_chol = cholesky(Symmetric(Σ[o, o]))
106-
Σuo = Σ[u, o]
107-
μu = μ[u]
108-
μo = μ[o]
112+
Σoo_chol = cholesky(Symmetric[o, o]))
113+
Σuo = Σ[u, o]
114+
μu = μ[u]
115+
μo = μ[o]
109116

110117
𝔼xu = fill!(similar(μu), 0)
111118
𝔼xo = fill!(similar(μo), 0)
112119
𝔼xᵢu = similar(μu)
113120

114121
𝔼xxᵀuo = fill!(similar(Σuo), 0)
115-
𝔼xxᵀuu = n_obs(pat) * (Σ[u, u] - Σuo * (Σoo_chol \ Σuo'))
122+
𝔼xxᵀuu = n_obs(pat) *[u, u] - Σuo * (Σoo_chol \ Σuo'))
116123

117124
# loop trough data
118125
@inbounds for rowdata in eachrow(pat.data)
@@ -124,24 +131,21 @@ function step!(em_model::EmMVNModel, observed::SemObserved, 𝔼x, 𝔼xxᵀ,
124131
𝔼xo .+= rowdata
125132
end
126133

127-
𝔼xxᵀ[o, o] .+= pat.data' * pat.data
128-
𝔼xxᵀ[u, o] .+= 𝔼xxᵀuo
129-
𝔼xxᵀ[o, u] .+= 𝔼xxᵀuo'
130-
𝔼xxᵀ[u, u] .+= 𝔼xxᵀuu
134+
Σ[o, o] .+= pat.data' * pat.data
135+
Σ[u, o] .+= 𝔼xxᵀuo
136+
Σ[o, u] .+= 𝔼xxᵀuo'
137+
Σ[u, u] .+= 𝔼xxᵀuu
131138

132-
𝔼x[o] .+= 𝔼xo
133-
𝔼x[u] .+= 𝔼xu
139+
μ[o] .+= 𝔼xo
140+
μ[u] .+= 𝔼xu
134141
end
135142

136-
𝔼x .+= 𝔼x_pre
137-
𝔼xxᵀ .+= 𝔼xxᵀ_pre
138-
139143
# M step, update em_model
140-
em_model.μ .= 𝔼x ./ n_obs(observed)
141-
em_model.Σ .= 𝔼xxᵀ ./ n_obs(observed)
142-
mul!(em_model.Σ, em_model.μ, em_model.μ', -1, 1)
144+
k = inv(sum(n_obs, patterns))
145+
lmul!(k, Σ)
146+
lmul!(k, μ)
147+
mul!(Σ, μ, μ', -1, 1)
143148

144-
#Σ = em_model.Σ
145149
# ridge Σ
146150
# while !isposdef(Σ)
147151
# Σ += 0.5I
@@ -150,42 +154,48 @@ function step!(em_model::EmMVNModel, observed::SemObserved, 𝔼x, 𝔼xxᵀ,
150154
# diagonalization
151155
#if !isposdef(Σ)
152156
# print("Matrix not positive definite")
153-
# em_model.Σ .= 0
154-
# em_model.Σ[diagind(em_model.Σ)] .= diag(Σ)
157+
# Σ .= 0
158+
# Σ[diagind(em_model.Σ)] .= diag(Σ)
155159
#else
156-
# em_model.Σ = Σ
160+
# Σ = Σ
157161
#end
158162

159-
return em_model
163+
return Σ, μ
160164
end
161165

162166
# generate starting values -----------------------------------------------------------------
163167

164168
# use μ and Σ of full cases
165-
function start_em_observed(observed::SemObservedMissing; kwargs...)
166-
fullpat = observed.patterns[1]
169+
function start_em_observed(patterns::AbstractVector{<:SemObservedMissingPattern}; kwargs...)
170+
fullpat = patterns[1]
167171
if (nmissed_vars(fullpat) == 0) && (n_obs(fullpat) > 1)
168172
μ = copy(fullpat.obs_mean)
169-
Σ = copy(fullpat.obs_cov)
173+
Σ = copy(parent(fullpat.obs_cov))
170174
if !isposdef(Σ)
171175
Σ = Diagonal(Σ)
172176
end
173-
return EmMVNModel(convert(Matrix, Σ), μ, false)
177+
return Σ, μ
174178
else
175179
return start_em_simple(observed, kwargs...)
176180
end
177181
end
178182

179183
# use μ = O and Σ = I
180-
function start_em_simple(observed::SemObservedMissing; kwargs...)
181-
μ = zeros(n_man(observed))
182-
Σ = rand(n_man(observed), n_man(observed))
184+
function start_em_simple(patterns::AbstractVector{<:SemObservedMissingPattern}; kwargs...)
185+
nvars = n_man(first(patterns))
186+
μ = zeros(nvars)
187+
Σ = rand(nvars, nvars)
183188
Σ = Σ * Σ'
184189
# Σ = Matrix(1.0I, n_man, n_man)
185-
return EmMVNModel(Σ, μ, false)
190+
return Σ, μ
186191
end
187192

188193
# set to passed values
189-
function start_em_set(observed::SemObservedMissing; model_em, kwargs...)
190-
return em_model
194+
function start_em_set(
195+
patterns::AbstractVector{<:SemObservedMissingPattern};
196+
obs_cov::AbstractMatrix,
197+
obs_mean::AbstractVector,
198+
kwargs...,
199+
)
200+
return copy(obs_cov), copy(obs_mean)
191201
end

0 commit comments

Comments
 (0)