55# what about random restarts?
66
77"""
8- em_mvn(;
9- observed::SemObservedMissing,
10- start_em = start_em_observed,
11- max_iter_em = 100,
12- rtol_em = 1e-4,
13- kwargs...)
8+ em_mvn(patterns::AbstractVector{SemObservedMissingPattern};
9+ start_em = start_em_observed,
10+ max_iter_em = 100,
11+ rtol_em = 1e-4,
12+ kwargs...)
1413
15- Estimates the covariance matrix and mean vector of the normal distribution via expectation maximization for `observed`.
16- Overwrites the statistics stored in `observed`.
14+ Estimates the covariance matrix and mean vector of the
15+ multivariate normal distribution (MVN)
16+ via expectation maximization (EM) for `observed`.
17+
18+ Returns the tuple of the EM covariance matrix and the EM mean vector.
1719
1820Uses the EM algorithm for MVN-distributed data with missing values
1921adapted from the supplementary material to the book *Machine Learning: A Probabilistic Perspective*,
@@ -22,23 +24,21 @@ copyright (2010) Kevin Murphy and Matt Dunham: see
2224[*emAlgo.m*](https://github.com/probml/pmtk3/blob/master/toolbox/Algorithms/optimization/emAlgo.m) scripts.
2325"""
2426function em_mvn (
25- observed :: SemObservedMissing ;
27+ patterns :: AbstractVector{<:SemObservedMissingPattern} ;
2628 start_em = start_em_observed,
27- max_iter_em = 100 ,
28- rtol_em = 1e-4 ,
29+ max_iter_em:: Integer = 100 ,
30+ rtol_em:: Number = 1e-4 ,
2931 kwargs... ,
3032)
31- n_man = SEM. n_man (observed)
32-
33- # preallocate stuff?
34- 𝔼x_pre = zeros (n_man)
35- 𝔼xxᵀ_pre = zeros (n_man, n_man)
33+ n_man = SEM. n_man (patterns[1 ])
3634
3735 # ## precompute for full cases
38- fullpat = observed. patterns[1 ]
39- if nmissed_vars (fullpat) == 0
40- sum! (reshape (𝔼x_pre, 1 , n_man), fullpat. data)
41- mul! (𝔼xxᵀ_pre, fullpat. data' , fullpat. data)
36+ 𝔼x_full = zeros (n_man)
37+ 𝔼xxᵀ_full = zeros (n_man, n_man)
38+ if nmissed_vars (patterns[1 ]) == 0
39+ fullpat = patterns[1 ]
40+ sum! (reshape (𝔼x_full, 1 , n_man), fullpat. data)
41+ mul! (𝔼xxᵀ_full, fullpat. data' , fullpat. data)
4242 else
4343 @warn " No full cases pattern found"
4444 end
@@ -47,72 +47,79 @@ function em_mvn(
4747 # estepFn = (em_model, data) -> estep(em_model, data, EXsum, EXXsum, ismissing, missingRows, n_obs)
4848
4949 # initialize
50- em_model = start_em (observed; kwargs... )
51- em_model_prev = EmMVNModel (zeros (n_man, n_man), zeros (n_man), false )
52- iter = 1
53- done = false
54- 𝔼x = zeros (n_man)
55- 𝔼xxᵀ = zeros (n_man, n_man)
56-
57- while ! done
58- step! (em_model, observed, 𝔼x, 𝔼xxᵀ, 𝔼x_pre, 𝔼xxᵀ_pre)
59-
60- if iter > max_iter_em
61- done = true
62- @warn " EM Algorithm for MVN missing data did not converge. Likelihood for FIML is not interpretable.
63- Maybe try passing different starting values via 'start_em = ...' "
64- elseif iter > 1
65- # done = isapprox(ll, ll_prev; rtol = rtol)
66- done =
67- isapprox (em_model_prev. μ, em_model. μ; rtol = rtol_em) &&
68- isapprox (em_model_prev. Σ, em_model. Σ; rtol = rtol_em)
50+ Σ₀, μ = start_em (patterns; kwargs... )
51+ Σ = convert (Matrix, Σ₀)
52+ @assert all (isfinite, Σ) all (isfinite, μ)
53+ Σ_prev, μ_prev = copy (Σ), copy (μ)
54+
55+ iter = 0
56+ converged = false
57+ while ! converged && (iter < max_iter_em)
58+ em_step! (Σ, μ, Σ_prev, μ_prev, patterns, 𝔼x_full, 𝔼xxᵀ_full)
59+
60+ if iter > 0
61+ Δμ = norm (μ - μ_prev)
62+ ΔΣ = norm (Σ - Σ_prev)
63+ Δμ_rel = Δμ / max (norm (μ_prev), norm (μ))
64+ ΔΣ_rel = ΔΣ / max (norm (Σ_prev), norm (Σ))
65+ # @info "Iteration #$iter: ΔΣ=$(ΔΣ) ΔΣ/Σ=$(ΔΣ_rel) Δμ=$(Δμ) Δμ/μ=$(Δμ_rel)"
66+ # converged = isapprox(ll, ll_prev; rtol = rtol)
67+ converged = ΔΣ_rel <= rtol_em && Δμ_rel <= rtol_em
68+ end
69+ if ! converged
70+ Σ, Σ_prev = Σ_prev, Σ
71+ μ, μ_prev = μ_prev, μ
6972 end
70-
71- # print("$iter \n")
7273 iter += 1
73- copyto! (em_model_prev. μ, em_model. μ)
74- copyto! (em_model_prev. Σ, em_model. Σ)
74+ # @info "$iter\n"
7575 end
7676
77- # update EM Mode in observed
78- observed. em_model. Σ .= em_model. Σ
79- observed. em_model. μ .= em_model. μ
80- observed. em_model. fitted = true
77+ if ! converged
78+ @warn " EM Algorithm for MVN missing data did not converge in $iter iterations.\n " *
79+ " Likelihood for FIML is not interpretable.\n " *
80+ " Maybe try passing different starting values via 'start_em = ...' "
81+ else
82+ @info " EM for MVN missing data converged in $iter iterations"
83+ end
8184
82- return nothing
85+ return Σ, μ
8386end
8487
8588# E and M steps -----------------------------------------------------------------------------
8689
87- # update em_model
88- function step! (em_model:: EmMVNModel , observed:: SemObserved , 𝔼x, 𝔼xxᵀ, 𝔼x_pre, 𝔼xxᵀ_pre)
90+ function em_step! (
91+ Σ:: AbstractMatrix ,
92+ μ:: AbstractVector ,
93+ Σ₀:: AbstractMatrix ,
94+ μ₀:: AbstractVector ,
95+ patterns:: AbstractVector{<:SemObservedMissingPattern} ,
96+ 𝔼x_full,
97+ 𝔼xxᵀ_full,
98+ )
8999 # E step, update 𝔼x and 𝔼xxᵀ
90- fill! (𝔼x, 0 )
91- fill! (𝔼xxᵀ, 0 )
92-
93- μ = em_model. μ
94- Σ = em_model. Σ
100+ copy! (μ, 𝔼x_full)
101+ copy! (Σ, 𝔼xxᵀ_full)
95102
96103 # Compute the expected sufficient statistics
97- for pat in observed . patterns
104+ for pat in patterns
98105 (nmissed_vars (pat) == 0 ) && continue # skip full cases
99106
100107 # observed and unobserved vars
101108 u = pat. miss_mask
102109 o = pat. obs_mask
103110
104111 # precompute for pattern
105- Σoo_chol = cholesky (Symmetric (Σ[o, o]))
106- Σuo = Σ[u, o]
107- μu = μ[u]
108- μo = μ[o]
112+ Σoo_chol = cholesky (Symmetric (Σ₀ [o, o]))
113+ Σuo = Σ₀ [u, o]
114+ μu = μ₀ [u]
115+ μo = μ₀ [o]
109116
110117 𝔼xu = fill! (similar (μu), 0 )
111118 𝔼xo = fill! (similar (μo), 0 )
112119 𝔼xᵢu = similar (μu)
113120
114121 𝔼xxᵀuo = fill! (similar (Σuo), 0 )
115- 𝔼xxᵀuu = n_obs (pat) * (Σ[u, u] - Σuo * (Σoo_chol \ Σuo' ))
122+ 𝔼xxᵀuu = n_obs (pat) * (Σ₀ [u, u] - Σuo * (Σoo_chol \ Σuo' ))
116123
117124 # loop trough data
118125 @inbounds for rowdata in eachrow (pat. data)
@@ -124,24 +131,21 @@ function step!(em_model::EmMVNModel, observed::SemObserved, 𝔼x, 𝔼xxᵀ,
124131 𝔼xo .+ = rowdata
125132 end
126133
127- 𝔼xxᵀ [o, o] .+ = pat. data' * pat. data
128- 𝔼xxᵀ [u, o] .+ = 𝔼xxᵀuo
129- 𝔼xxᵀ [o, u] .+ = 𝔼xxᵀuo'
130- 𝔼xxᵀ [u, u] .+ = 𝔼xxᵀuu
134+ Σ [o, o] .+ = pat. data' * pat. data
135+ Σ [u, o] .+ = 𝔼xxᵀuo
136+ Σ [o, u] .+ = 𝔼xxᵀuo'
137+ Σ [u, u] .+ = 𝔼xxᵀuu
131138
132- 𝔼x [o] .+ = 𝔼xo
133- 𝔼x [u] .+ = 𝔼xu
139+ μ [o] .+ = 𝔼xo
140+ μ [u] .+ = 𝔼xu
134141 end
135142
136- 𝔼x .+ = 𝔼x_pre
137- 𝔼xxᵀ .+ = 𝔼xxᵀ_pre
138-
139143 # M step, update em_model
140- em_model. μ .= 𝔼x ./ n_obs (observed)
141- em_model. Σ .= 𝔼xxᵀ ./ n_obs (observed)
142- mul! (em_model. Σ, em_model. μ, em_model. μ' , - 1 , 1 )
144+ k = inv (sum (n_obs, patterns))
145+ lmul! (k, Σ)
146+ lmul! (k, μ)
147+ mul! (Σ, μ, μ' , - 1 , 1 )
143148
144- # Σ = em_model.Σ
145149 # ridge Σ
146150 # while !isposdef(Σ)
147151 # Σ += 0.5I
@@ -150,42 +154,48 @@ function step!(em_model::EmMVNModel, observed::SemObserved, 𝔼x, 𝔼xxᵀ,
150154 # diagonalization
151155 # if !isposdef(Σ)
152156 # print("Matrix not positive definite")
153- # em_model. Σ .= 0
154- # em_model. Σ[diagind(em_model.Σ)] .= diag(Σ)
157+ # Σ .= 0
158+ # Σ[diagind(em_model.Σ)] .= diag(Σ)
155159 # else
156- # em_model. Σ = Σ
160+ # Σ = Σ
157161 # end
158162
159- return em_model
163+ return Σ, μ
160164end
161165
162166# generate starting values -----------------------------------------------------------------
163167
164168# use μ and Σ of full cases
165- function start_em_observed (observed :: SemObservedMissing ; kwargs... )
166- fullpat = observed . patterns[1 ]
169+ function start_em_observed (patterns :: AbstractVector{<:SemObservedMissingPattern} ; kwargs... )
170+ fullpat = patterns[1 ]
167171 if (nmissed_vars (fullpat) == 0 ) && (n_obs (fullpat) > 1 )
168172 μ = copy (fullpat. obs_mean)
169- Σ = copy (fullpat. obs_cov)
173+ Σ = copy (parent ( fullpat. obs_cov) )
170174 if ! isposdef (Σ)
171175 Σ = Diagonal (Σ)
172176 end
173- return EmMVNModel ( convert (Matrix, Σ) , μ, false )
177+ return Σ , μ
174178 else
175179 return start_em_simple (observed, kwargs... )
176180 end
177181end
178182
179183# use μ = O and Σ = I
180- function start_em_simple (observed:: SemObservedMissing ; kwargs... )
181- μ = zeros (n_man (observed))
182- Σ = rand (n_man (observed), n_man (observed))
184+ function start_em_simple (patterns:: AbstractVector{<:SemObservedMissingPattern} ; kwargs... )
185+ nvars = n_man (first (patterns))
186+ μ = zeros (nvars)
187+ Σ = rand (nvars, nvars)
183188 Σ = Σ * Σ'
184189 # Σ = Matrix(1.0I, n_man, n_man)
185- return EmMVNModel ( Σ, μ, false )
190+ return Σ, μ
186191end
187192
188193# set to passed values
189- function start_em_set (observed:: SemObservedMissing ; model_em, kwargs... )
190- return em_model
194+ function start_em_set (
195+ patterns:: AbstractVector{<:SemObservedMissingPattern} ;
196+ obs_cov:: AbstractMatrix ,
197+ obs_mean:: AbstractVector ,
198+ kwargs... ,
199+ )
200+ return copy (obs_cov), copy (obs_mean)
191201end
0 commit comments