Skip to content

Commit 2e41bfd

Browse files
committed
CommutationMatrix type
replace comm_matrix helper functions with a CommutationMatrix and overloaded linalg ops
1 parent 50029a8 commit 2e41bfd

4 files changed

Lines changed: 68 additions & 109 deletions

File tree

src/StructuralEquationModels.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ const SEM = StructuralEquationModels
2323
# type hierarchy
2424
include("types.jl")
2525
include("objective_gradient_hessian.jl")
26+
27+
# helper objects and functions
28+
include("additional_functions/commutation_matrix.jl")
29+
2630
# fitted objects
2731
include("frontend/fit/SemFit.jl")
2832
# specification of models
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# transpose linear indices of the square n×n matrix
2+
# i.e.
3+
# 1 4
4+
# 2 5 => 1 2 3
5+
# 3 6 4 5 6
6+
transpose_linear_indices(n::Integer, m::Integer = n) =
7+
repeat(1:n, inner = m) .+ repeat((0:(m-1)) * n, outer = n)
8+
9+
"""
10+
CommutationMatrix(n::Integer) <: AbstractMatrix{Int}
11+
12+
A *commutation matrix* *C* is a n²×n² matrix of 0s and 1s.
13+
If *vec(A)* is a vectorized form of a n×n matrix *A*,
14+
then ``C * vec(A) = vec(Aᵀ)``.
15+
"""
16+
struct CommutationMatrix <: AbstractMatrix{Int}
17+
n::Int
18+
::Int
19+
transpose_inds::Vector{Int} # maps the linear indices of n×n matrix *B* to the indices of matrix *B'*
20+
21+
CommutationMatrix(n::Integer) = new(n, n^2, transpose_linear_indices(n))
22+
end
23+
24+
Base.size(A::CommutationMatrix) = (A.n², A.n²)
25+
Base.size(A::CommutationMatrix, dim::Integer) =
26+
1 <= dim <= 2 ? A.: throw(ArgumentError("invalid matrix dimension $dim"))
27+
Base.length(A::CommutationMatrix) = A.^2
28+
Base.getindex(A::CommutationMatrix, i::Int, j::Int) = j == A.transpose_inds[i] ? 1 : 0
29+
30+
function Base.:(*)(A::CommutationMatrix, B::AbstractMatrix)
31+
size(A, 2) == size(B, 1) || throw(
32+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
33+
)
34+
return B[A.transpose_inds, :]
35+
end
36+
37+
function Base.:(*)(A::CommutationMatrix, B::SparseMatrixCSC)
38+
size(A, 2) == size(B, 1) || throw(
39+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
40+
)
41+
return SparseMatrixCSC(
42+
size(B, 1),
43+
size(B, 2),
44+
copy(B.colptr),
45+
A.transpose_inds[B.rowval],
46+
copy(B.nzval),
47+
)
48+
end
49+
50+
function LinearAlgebra.lmul!(A::CommutationMatrix, B::SparseMatrixCSC)
51+
size(A, 2) == size(B, 1) || throw(
52+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
53+
)
54+
55+
@inbounds for (i, rowind) in enumerate(B.rowval)
56+
B.rowval[i] = A.transpose_inds[rowind]
57+
end
58+
return B
59+
end

src/additional_functions/helper.jl

Lines changed: 0 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -144,109 +144,6 @@ function elimination_matrix(nobs)
144144
return L
145145
end
146146

147-
function commutation_matrix(n; tosparse = false)
148-
M = zeros(n^2, n^2)
149-
150-
for i in 1:n
151-
for j in 1:n
152-
M[i+n*(j-1), j+n*(i-1)] = 1.0
153-
end
154-
end
155-
156-
if tosparse
157-
M = sparse(M)
158-
end
159-
160-
return M
161-
end
162-
163-
function commutation_matrix_pre_square(A)
164-
n2 = size(A, 1)
165-
n = Int(sqrt(n2))
166-
167-
ind = repeat(1:n, inner = n)
168-
indadd = (0:(n-1)) * n
169-
for i in 1:n
170-
ind[((i-1)*n+1):i*n] .+= indadd
171-
end
172-
173-
A_post = A[ind, :]
174-
175-
return A_post
176-
end
177-
178-
function commutation_matrix_pre_square_add!(B, A) # comuptes B + KₙA
179-
n2 = size(A, 1)
180-
n = Int(sqrt(n2))
181-
182-
ind = repeat(1:n, inner = n)
183-
indadd = (0:(n-1)) * n
184-
for i in 1:n
185-
ind[((i-1)*n+1):i*n] .+= indadd
186-
end
187-
188-
@views @inbounds B .+= A[ind, :]
189-
190-
return B
191-
end
192-
193-
function get_commutation_lookup(n2::Int64)
194-
n = Int(sqrt(n2))
195-
ind = repeat(1:n, inner = n)
196-
indadd = (0:(n-1)) * n
197-
for i in 1:n
198-
ind[((i-1)*n+1):i*n] .+= indadd
199-
end
200-
201-
lookup = Dict{Int64, Int64}()
202-
203-
for i in 1:n2
204-
j = findall(x -> (x == i), ind)[1]
205-
push!(lookup, i => j)
206-
end
207-
208-
return lookup
209-
end
210-
211-
function commutation_matrix_pre_square!(A::SparseMatrixCSC, lookup) # comuptes B + KₙA
212-
for (i, rowind) in enumerate(A.rowval)
213-
A.rowval[i] = lookup[rowind]
214-
end
215-
end
216-
217-
function commutation_matrix_pre_square!(A::SparseMatrixCSC) # computes KₙA
218-
lookup = get_commutation_lookup(size(A, 2))
219-
commutation_matrix_pre_square!(A, lookup)
220-
end
221-
222-
function commutation_matrix_pre_square(A::SparseMatrixCSC)
223-
B = copy(A)
224-
commutation_matrix_pre_square!(B)
225-
return B
226-
end
227-
228-
function commutation_matrix_pre_square(A::SparseMatrixCSC, lookup)
229-
B = copy(A)
230-
commutation_matrix_pre_square!(B, lookup)
231-
return B
232-
end
233-
234-
function commutation_matrix_pre_square_add_mt!(B, A) # comuptes B + KₙA # 0 allocations but slower
235-
n2 = size(A, 1)
236-
n = Int(sqrt(n2))
237-
238-
indadd = (0:(n-1)) * n
239-
240-
Threads.@threads for i in 1:n
241-
for j in 1:n
242-
row = i + indadd[j]
243-
@views @inbounds B[row, :] .+= A[row, :]
244-
end
245-
end
246-
247-
return B
248-
end
249-
250147
# returns the vector of non-unique values in the order of appearance
251148
# each non-unique values is reported once
252149
function nonunique(values::AbstractVector)

src/loss/ML/FIML.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct SemFIML{T, W} <: SemLossFunction{ExactHessian}
9191

9292
imp_inv::Matrix{T} # implied inverse
9393

94-
commutation_indices::Dict{Int, Int}
94+
commutator::CommutationMatrix
9595

9696
interaction::W
9797
end
@@ -104,7 +104,7 @@ function SemFIML(; observed::SemObservedMissing, specification, kwargs...)
104104
return SemFIML(
105105
[SemFIMLPattern(pat) for pat in observed.patterns],
106106
zeros(n_man(observed), n_man(observed)),
107-
get_commutation_lookup(nvars(specification)^2),
107+
CommutationMatrix(nvars(specification)),
108108
nothing,
109109
)
110110
end
@@ -158,13 +158,12 @@ function ∇F_fiml_outer!(G, JΣ, Jμ, fiml::SemFIML, imply::SemImplySymbolic, m
158158
end
159159

160160
function ∇F_fiml_outer!(G, JΣ, Jμ, fiml::SemFIML, imply, model)
161-
Iₙ = sparse(1.0I, size(imply.A)...)
162161
P = kron(imply.F⨉I_A⁻¹, imply.F⨉I_A⁻¹)
162+
Iₙ = sparse(1.0I, size(imply.A)...)
163163
Q = kron(imply.S * imply.I_A⁻¹', Iₙ)
164-
#commutation_matrix_pre_square_add!(Q, Q)
165-
Q2 = commutation_matrix_pre_square(Q, fiml.commutation_indices)
164+
Q .+= fiml.commutator * Q
166165

167-
∇Σ = P * (imply.∇S + (Q + Q2) * imply.∇A)
166+
∇Σ = P * (imply.∇S + Q * imply.∇A)
168167

169168
∇μ = imply.F⨉I_A⁻¹ * imply.∇M + kron((imply.I_A⁻¹ * imply.M)', imply.F⨉I_A⁻¹) * imply.∇A
170169

0 commit comments

Comments
 (0)