Skip to content

Commit df5ea75

Browse files
committed
EM: nobs -> nsamples
1 parent 4c15328 commit df5ea75

1 file changed

Lines changed: 31 additions & 18 deletions

File tree

src/observed/EM.jl

Lines changed: 31 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -28,23 +28,23 @@ function em_mvn(
2828
start_em = start_em_observed,
2929
max_iter_em::Integer = 100,
3030
rtol_em::Number = 1e-4,
31-
max_nobs_em::Union{Integer, Nothing} = nothing,
31+
max_nsamples_em::Union{Integer, Nothing} = nothing,
3232
kwargs...,
3333
)
3434
nvars = SEM.nobserved_vars(patterns[1])
3535

3636
# precompute for full cases
3737
𝔼x_full = zeros(nvars)
3838
𝔼xxᵀ_full = zeros(nvars, nvars)
39-
nobs_full = 0
39+
nsamples_full = 0
4040
for pat in patterns
4141
if nmissed_vars(pat) == 0
4242
𝔼x_full .+= sum(pat.data, dims = 2)
4343
mul!(𝔼xxᵀ_full, pat.data, pat.data', 1, 1)
44-
nobs_full += nsamples(pat)
44+
nsamples_full += nsamples(pat)
4545
end
4646
end
47-
if nobs_full == 0
47+
if nsamples_full == 0
4848
@warn "No full cases in data"
4949
end
5050

@@ -59,7 +59,17 @@ function em_mvn(
5959
Δμ_rel = NaN
6060
ΔΣ_rel = NaN
6161
while !converged && (iter < max_iter_em)
62-
em_step!(Σ, μ, Σ_prev, μ_prev, patterns, 𝔼xxᵀ_full, 𝔼x_full, nobs_full; max_nobs_em)
62+
em_step!(
63+
Σ,
64+
μ,
65+
Σ_prev,
66+
μ_prev,
67+
patterns,
68+
𝔼xxᵀ_full,
69+
𝔼x_full,
70+
nsamples_full;
71+
max_nsamples_em,
72+
)
6373

6474
if iter > 0
6575
Δμ = norm(μ - μ_prev)
@@ -99,15 +109,15 @@ function em_step!(
99109
patterns::AbstractVector{<:SemObservedMissingPattern},
100110
𝔼xxᵀ_full::AbstractMatrix,
101111
𝔼x_full::AbstractVector,
102-
nobs_full::Integer;
103-
max_nobs_em::Union{Integer, Nothing} = nothing,
112+
nsamples_full::Integer;
113+
max_nsamples_em::Union{Integer, Nothing} = nothing,
104114
)
105115
# E step, update 𝔼x and 𝔼xxᵀ
106116
copy!(μ, 𝔼x_full)
107117
copy!(Σ, 𝔼xxᵀ_full)
108-
nobs_used = nobs_full
109-
mul!(Σ, μ₀, μ₀', -nobs_used, 1)
110-
axpy!(-nobs_used, μ₀, μ)
118+
nsamples_used = nsamples_full
119+
mul!(Σ, μ₀, μ₀', -nsamples_used, 1)
120+
axpy!(-nsamples_used, μ₀, μ)
111121

112122
# Compute the expected sufficient statistics
113123
for pat in patterns
@@ -124,18 +134,21 @@ function em_step!(
124134
μ₀o = μ₀[o]
125135

126136
# get pattern observations
127-
nobs = !isnothing(max_nobs_em) ? min(max_nobs_em, nsamples(pat)) : nsamples(pat)
137+
nsamples_pat =
138+
!isnothing(max_nsamples_em) ? min(max_nsamples_em, nsamples(pat)) :
139+
nsamples(pat)
128140
zo =
129-
nobs < nsamples(pat) ?
130-
pat.data[:, sort!(sample(1:nsamples(pat), nobs, replace = false))] : copy(pat.data)
141+
nsamples_pat < nsamples(pat) ?
142+
pat.data[:, sort!(sample(1:nsamples(pat), nsamples_pat, replace = false))] :
143+
copy(pat.data)
131144
zo .-= μ₀o # subtract current mean from observations
132145

133146
𝔼zo = sum(zo, dims = 2)
134147
𝔼zu = fill!(similar(μ₀u), 0)
135148

136149
𝔼zzᵀuo = fill!(similar(Σ₀uo), 0)
137-
𝔼zzᵀuu = nobs * Σ₀[u, u]
138-
mul!(𝔼zzᵀuu, Σ₀uo, Σ₀oo_chol \ Σ₀uo', -nobs, 1)
150+
𝔼zzᵀuu = nsamples_pat * Σ₀[u, u]
151+
mul!(𝔼zzᵀuu, Σ₀uo, Σ₀oo_chol \ Σ₀uo', -nsamples_pat, 1)
139152

140153
# loop through observations
141154
yᵢo = similar(μ₀o)
@@ -167,12 +180,12 @@ function em_step!(
167180
μ[o] .+= 𝔼zo
168181
μ[u] .+= 𝔼zu
169182

170-
nobs_used += nobs
183+
nsamples_used += nsamples_pat
171184
end
172185

173186
# M step, update em_model
174-
lmul!(1 / nobs_used, Σ)
175-
lmul!(1 / nobs_used, μ)
187+
lmul!(1 / nsamples_used, Σ)
188+
lmul!(1 / nsamples_used, μ)
176189
# at this point μ = μ - μ₀
177190
# and Σ = Σ + (μ - μ₀)×(μ - μ₀)' - μ₀×μ₀'
178191
mul!(Σ, μ, μ₀', -1, 1)

0 commit comments

Comments
 (0)