Skip to content

Commit a191de8

Browse files
committed
CommutationMatrix type
replace comm_matrix helper functions with a CommutationMatrix and overloaded linalg ops
1 parent 75d5d1c commit a191de8

4 files changed

Lines changed: 68 additions & 109 deletions

File tree

src/StructuralEquationModels.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ const SEM = StructuralEquationModels
2323
# type hierarchy
2424
include("types.jl")
2525
include("objective_gradient_hessian.jl")
26+
27+
# helper objects and functions
28+
include("additional_functions/commutation_matrix.jl")
29+
2630
# fitted objects
2731
include("frontend/fit/SemFit.jl")
2832
# specification of models
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# transpose linear indices of the square n×n matrix
2+
# i.e.
3+
# 1 4
4+
# 2 5 => 1 2 3
5+
# 3 6 4 5 6
6+
transpose_linear_indices(n::Integer, m::Integer = n) =
7+
repeat(1:n, inner = m) .+ repeat((0:(m-1)) * n, outer = n)
8+
9+
"""
10+
CommutationMatrix(n::Integer) <: AbstractMatrix{Int}
11+
12+
A *commutation matrix* *C* is a n²×n² matrix of 0s and 1s.
13+
If *vec(A)* is a vectorized form of a n×n matrix *A*,
14+
then ``C * vec(A) = vec(Aᵀ)``.
15+
"""
16+
struct CommutationMatrix <: AbstractMatrix{Int}
17+
n::Int
18+
::Int
19+
transpose_inds::Vector{Int} # maps the linear indices of n×n matrix *B* to the indices of matrix *B'*
20+
21+
CommutationMatrix(n::Integer) = new(n, n^2, transpose_linear_indices(n))
22+
end
23+
24+
Base.size(A::CommutationMatrix) = (A.n², A.n²)
25+
Base.size(A::CommutationMatrix, dim::Integer) =
26+
1 <= dim <= 2 ? A.: throw(ArgumentError("invalid matrix dimension $dim"))
27+
Base.length(A::CommutationMatrix) = A.^2
28+
Base.getindex(A::CommutationMatrix, i::Int, j::Int) = j == A.transpose_inds[i] ? 1 : 0
29+
30+
function Base.:(*)(A::CommutationMatrix, B::AbstractMatrix)
31+
size(A, 2) == size(B, 1) || throw(
32+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
33+
)
34+
return B[A.transpose_inds, :]
35+
end
36+
37+
function Base.:(*)(A::CommutationMatrix, B::SparseMatrixCSC)
38+
size(A, 2) == size(B, 1) || throw(
39+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
40+
)
41+
return SparseMatrixCSC(
42+
size(B, 1),
43+
size(B, 2),
44+
copy(B.colptr),
45+
A.transpose_inds[B.rowval],
46+
copy(B.nzval),
47+
)
48+
end
49+
50+
function LinearAlgebra.lmul!(A::CommutationMatrix, B::SparseMatrixCSC)
51+
size(A, 2) == size(B, 1) || throw(
52+
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
53+
)
54+
55+
@inbounds for (i, rowind) in enumerate(B.rowval)
56+
B.rowval[i] = A.transpose_inds[rowind]
57+
end
58+
return B
59+
end

src/additional_functions/helper.jl

Lines changed: 0 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -142,109 +142,6 @@ function elimination_matrix(nobs)
142142
return L
143143
end
144144

145-
function commutation_matrix(n; tosparse = false)
146-
M = zeros(n^2, n^2)
147-
148-
for i in 1:n
149-
for j in 1:n
150-
M[i+n*(j-1), j+n*(i-1)] = 1.0
151-
end
152-
end
153-
154-
if tosparse
155-
M = sparse(M)
156-
end
157-
158-
return M
159-
end
160-
161-
function commutation_matrix_pre_square(A)
162-
n2 = size(A, 1)
163-
n = Int(sqrt(n2))
164-
165-
ind = repeat(1:n, inner = n)
166-
indadd = (0:(n-1)) * n
167-
for i in 1:n
168-
ind[((i-1)*n+1):i*n] .+= indadd
169-
end
170-
171-
A_post = A[ind, :]
172-
173-
return A_post
174-
end
175-
176-
function commutation_matrix_pre_square_add!(B, A) # comuptes B + KₙA
177-
n2 = size(A, 1)
178-
n = Int(sqrt(n2))
179-
180-
ind = repeat(1:n, inner = n)
181-
indadd = (0:(n-1)) * n
182-
for i in 1:n
183-
ind[((i-1)*n+1):i*n] .+= indadd
184-
end
185-
186-
@views @inbounds B .+= A[ind, :]
187-
188-
return B
189-
end
190-
191-
function get_commutation_lookup(n2::Int64)
192-
n = Int(sqrt(n2))
193-
ind = repeat(1:n, inner = n)
194-
indadd = (0:(n-1)) * n
195-
for i in 1:n
196-
ind[((i-1)*n+1):i*n] .+= indadd
197-
end
198-
199-
lookup = Dict{Int64, Int64}()
200-
201-
for i in 1:n2
202-
j = findall(x -> (x == i), ind)[1]
203-
push!(lookup, i => j)
204-
end
205-
206-
return lookup
207-
end
208-
209-
function commutation_matrix_pre_square!(A::SparseMatrixCSC, lookup) # comuptes B + KₙA
210-
for (i, rowind) in enumerate(A.rowval)
211-
A.rowval[i] = lookup[rowind]
212-
end
213-
end
214-
215-
function commutation_matrix_pre_square!(A::SparseMatrixCSC) # computes KₙA
216-
lookup = get_commutation_lookup(size(A, 2))
217-
commutation_matrix_pre_square!(A, lookup)
218-
end
219-
220-
function commutation_matrix_pre_square(A::SparseMatrixCSC)
221-
B = copy(A)
222-
commutation_matrix_pre_square!(B)
223-
return B
224-
end
225-
226-
function commutation_matrix_pre_square(A::SparseMatrixCSC, lookup)
227-
B = copy(A)
228-
commutation_matrix_pre_square!(B, lookup)
229-
return B
230-
end
231-
232-
function commutation_matrix_pre_square_add_mt!(B, A) # comuptes B + KₙA # 0 allocations but slower
233-
n2 = size(A, 1)
234-
n = Int(sqrt(n2))
235-
236-
indadd = (0:(n-1)) * n
237-
238-
Threads.@threads for i in 1:n
239-
for j in 1:n
240-
row = i + indadd[j]
241-
@views @inbounds B[row, :] .+= A[row, :]
242-
end
243-
end
244-
245-
return B
246-
end
247-
248145
# returns the vector of non-unique values in the order of appearance
249146
# each non-unique values is reported once
250147
function nonunique(values::AbstractVector)

src/loss/ML/FIML.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct SemFIML{T, W} <: SemLossFunction{ExactHessian}
9191

9292
imp_inv::Matrix{T} # implied inverse
9393

94-
commutation_indices::Dict{Int, Int}
94+
commutator::CommutationMatrix
9595

9696
interaction::W
9797
end
@@ -104,7 +104,7 @@ function SemFIML(; observed::SemObservedMissing, specification, kwargs...)
104104
return SemFIML(
105105
[SemFIMLPattern(pat) for pat in observed.patterns],
106106
zeros(n_man(observed), n_man(observed)),
107-
get_commutation_lookup(nvars(specification)^2),
107+
CommutationMatrix(nvars(specification)),
108108
nothing,
109109
)
110110
end
@@ -158,13 +158,12 @@ function ∇F_fiml_outer!(G, JΣ, Jμ, fiml::SemFIML, imply::SemImplySymbolic, m
158158
end
159159

160160
function ∇F_fiml_outer!(G, JΣ, Jμ, fiml::SemFIML, imply, model)
161-
Iₙ = sparse(1.0I, size(imply.A)...)
162161
P = kron(imply.F⨉I_A⁻¹, imply.F⨉I_A⁻¹)
162+
Iₙ = sparse(1.0I, size(imply.A)...)
163163
Q = kron(imply.S * imply.I_A⁻¹', Iₙ)
164-
#commutation_matrix_pre_square_add!(Q, Q)
165-
Q2 = commutation_matrix_pre_square(Q, fiml.commutation_indices)
164+
Q .+= fiml.commutator * Q
166165

167-
∇Σ = P * (imply.∇S + (Q + Q2) * imply.∇A)
166+
∇Σ = P * (imply.∇S + Q * imply.∇A)
168167

169168
∇μ = imply.F⨉I_A⁻¹ * imply.∇M + kron((imply.I_A⁻¹ * imply.M)', imply.F⨉I_A⁻¹) * imply.∇A
170169

0 commit comments

Comments
 (0)