@@ -28,23 +28,23 @@ function em_mvn(
2828 start_em = start_em_observed,
2929 max_iter_em:: Integer = 100 ,
3030 rtol_em:: Number = 1e-4 ,
31- max_nobs_em :: Union{Integer, Nothing} = nothing ,
31+ max_nsamples_em :: Union{Integer, Nothing} = nothing ,
3232 kwargs... ,
3333)
3434 nvars = SEM. nobserved_vars (patterns[1 ])
3535
3636 # precompute for full cases
3737 𝔼x_full = zeros (nvars)
3838 𝔼xxᵀ_full = zeros (nvars, nvars)
39- nobs_full = 0
39+ nsamples_full = 0
4040 for pat in patterns
4141 if nmissed_vars (pat) == 0
4242 𝔼x_full .+ = sum (pat. data, dims = 2 )
4343 mul! (𝔼xxᵀ_full, pat. data, pat. data' , 1 , 1 )
44- nobs_full += nsamples (pat)
44+ nsamples_full += nsamples (pat)
4545 end
4646 end
47- if nobs_full == 0
47+ if nsamples_full == 0
4848 @warn " No full cases in data"
4949 end
5050
@@ -59,7 +59,17 @@ function em_mvn(
5959 Δμ_rel = NaN
6060 ΔΣ_rel = NaN
6161 while ! converged && (iter < max_iter_em)
62- em_step! (Σ, μ, Σ_prev, μ_prev, patterns, 𝔼xxᵀ_full, 𝔼x_full, nobs_full; max_nobs_em)
62+ em_step! (
63+ Σ,
64+ μ,
65+ Σ_prev,
66+ μ_prev,
67+ patterns,
68+ 𝔼xxᵀ_full,
69+ 𝔼x_full,
70+ nsamples_full;
71+ max_nsamples_em,
72+ )
6373
6474 if iter > 0
6575 Δμ = norm (μ - μ_prev)
@@ -99,15 +109,15 @@ function em_step!(
99109 patterns:: AbstractVector{<:SemObservedMissingPattern} ,
100110 𝔼xxᵀ_full:: AbstractMatrix ,
101111 𝔼x_full:: AbstractVector ,
102- nobs_full :: Integer ;
103- max_nobs_em :: Union{Integer, Nothing} = nothing ,
112+ nsamples_full :: Integer ;
113+ max_nsamples_em :: Union{Integer, Nothing} = nothing ,
104114)
105115 # E step, update 𝔼x and 𝔼xxᵀ
106116 copy! (μ, 𝔼x_full)
107117 copy! (Σ, 𝔼xxᵀ_full)
108- nobs_used = nobs_full
109- mul! (Σ, μ₀, μ₀' , - nobs_used , 1 )
110- axpy! (- nobs_used , μ₀, μ)
118+ nsamples_used = nsamples_full
119+ mul! (Σ, μ₀, μ₀' , - nsamples_used , 1 )
120+ axpy! (- nsamples_used , μ₀, μ)
111121
112122 # Compute the expected sufficient statistics
113123 for pat in patterns
@@ -124,18 +134,21 @@ function em_step!(
124134 μ₀o = μ₀[o]
125135
126136 # get pattern observations
127- nobs = ! isnothing (max_nobs_em) ? min (max_nobs_em, nsamples (pat)) : nsamples (pat)
137+ nsamples_pat =
138+ ! isnothing (max_nsamples_em) ? min (max_nsamples_em, nsamples (pat)) :
139+ nsamples (pat)
128140 zo =
129- nobs < nsamples (pat) ?
130- pat. data[:, sort! (sample (1 : nsamples (pat), nobs, replace = false ))] : copy (pat. data)
141+ nsamples_pat < nsamples (pat) ?
142+ pat. data[:, sort! (sample (1 : nsamples (pat), nsamples_pat, replace = false ))] :
143+ copy (pat. data)
131144 zo .- = μ₀o # subtract current mean from observations
132145
133146 𝔼zo = sum (zo, dims = 2 )
134147 𝔼zu = fill! (similar (μ₀u), 0 )
135148
136149 𝔼zzᵀuo = fill! (similar (Σ₀uo), 0 )
137- 𝔼zzᵀuu = nobs * Σ₀[u, u]
138- mul! (𝔼zzᵀuu, Σ₀uo, Σ₀oo_chol \ Σ₀uo' , - nobs , 1 )
150+ 𝔼zzᵀuu = nsamples_pat * Σ₀[u, u]
151+ mul! (𝔼zzᵀuu, Σ₀uo, Σ₀oo_chol \ Σ₀uo' , - nsamples_pat , 1 )
139152
140153 # loop through observations
141154 yᵢo = similar (μ₀o)
@@ -167,12 +180,12 @@ function em_step!(
167180 μ[o] .+ = 𝔼zo
168181 μ[u] .+ = 𝔼zu
169182
170- nobs_used += nobs
183+ nsamples_used += nsamples_pat
171184 end
172185
173186 # M step, update em_model
174- lmul! (1 / nobs_used , Σ)
175- lmul! (1 / nobs_used , μ)
187+ lmul! (1 / nsamples_used , Σ)
188+ lmul! (1 / nsamples_used , μ)
176189 # at this point μ = μ - μ₀
177190 # and Σ = Σ + (μ - μ₀)×(μ - μ₀)' - μ₀×μ₀'
178191 mul! (Σ, μ, μ₀' , - 1 , 1 )
0 commit comments