@@ -29,15 +29,15 @@ function em_mvn(
2929 rtol_em = 1e-4 ,
3030 kwargs... ,
3131)
32- n_man = observed. n_man
32+ nvars = nobserved_vars ( observed)
3333 nsamps = nsamples (observed)
3434
3535 # preallocate stuff?
36- 𝔼x_pre = zeros (n_man )
37- 𝔼xxᵀ_pre = zeros (n_man, n_man )
36+ 𝔼x_pre = zeros (nvars )
37+ 𝔼xxᵀ_pre = zeros (nvars, nvars )
3838
3939 # ## precompute for full cases
40- if length (observed. patterns[1 ]) == observed . n_man
40+ if length (observed. patterns[1 ]) == nvars
4141 for row in observed. rows[1 ]
4242 row = observed. data_rowwise[row]
4343 𝔼x_pre += row
@@ -50,11 +50,11 @@ function em_mvn(
5050
5151 # initialize
5252 em_model = start_em (observed; kwargs... )
53- em_model_prev = EmMVNModel (zeros (n_man, n_man ), zeros (n_man ), false )
53+ em_model_prev = EmMVNModel (zeros (nvars, nvars ), zeros (nvars ), false )
5454 iter = 1
5555 done = false
56- 𝔼x = zeros (n_man )
57- 𝔼xxᵀ = zeros (n_man, n_man )
56+ 𝔼x = zeros (nvars )
57+ 𝔼xxᵀ = zeros (nvars, nvars )
5858
5959 while ! done
6060 em_mvn_Estep! (𝔼x, 𝔼xxᵀ, em_model, observed, 𝔼x_pre, 𝔼xxᵀ_pre)
153153
154154# use μ and Σ of full cases
155155function start_em_observed (observed:: SemObservedMissing ; kwargs... )
156- if (length (observed. patterns[1 ]) == observed. n_man ) & (observed. pattern_nsamples[1 ] > 1 )
156+ if (length (observed. patterns[1 ]) == nobserved_vars ( observed) ) & (observed. pattern_nsamples[1 ] > 1 )
157157 μ = copy (observed. obs_mean[1 ])
158158 Σ = copy (Symmetric (observed. obs_cov[1 ]))
159159 if ! isposdef (Σ)
@@ -167,11 +167,11 @@ end
167167
168168# use μ = O and Σ = I
169169function start_em_simple (observed:: SemObservedMissing ; kwargs... )
170- n_man = Int (observed. n_man )
171- μ = zeros (n_man )
172- Σ = rand (n_man, n_man )
170+ nvars = nobserved_vars (observed)
171+ μ = zeros (nvars )
172+ Σ = rand (nvars, nvars )
173173 Σ = Σ * Σ'
174- # Σ = Matrix(1.0I, n_man, n_man )
174+ # Σ = Matrix(1.0I, nvars, nvars )
175175 return EmMVNModel (Σ, μ, false )
176176end
177177
0 commit comments