Skip to content
Open
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ version = "1.0.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[compat]
TimerOutputs = "0.5.26"
julia = "1.7"

[extras]
Expand Down
117 changes: 55 additions & 62 deletions src/QRupdate.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
module QRupdate

using LinearAlgebra
using LinearAlgebra, TimerOutputs
# Module-wide timer that the `@timeit to "..."` sections in this module record
# into; callers reach it as `QRupdate.to` (e.g. to reset or print it).
# Declared `const` so the binding has a fixed type: a non-const global is
# untyped inside functions, which adds dynamic dispatch and allocations to
# every timed section. The timer object itself stays mutable.
const to = TimerOutput()

export qraddcol, qraddcol!, qraddrow, qrdelcol, qrdelcol!, csne, csne!

"""
    swapcols!(M, i, j)

Exchange columns `i` and `j` of `M` in place.

The swap is expressed as a column permutation: the identity permutation over
`axes(M, 2)` with the entries `i` and `j` exchanged, applied through
`Base.permutecols!!`. The value returned is whatever `Base.permutecols!!`
returns.
"""
function swapcols!(M::Matrix{T}, i::Int, j::Int) where {T}
    # `replace` on the column range yields a fresh vector, which
    # `permutecols!!` is free to consume as scratch space.
    perm = replace(axes(M, 2), i => j, j => i)
    return Base.permutecols!!(M, perm)
end

# Module-level settings read by the iterative-refinement loops in this file.
#
# ORTHO_TOL: threshold the relative error `err` is compared against; refinement
# is skipped when `err < ORTHO_TOL` and the loop runs while `err > ORTHO_TOL`.
# NOTE(review): the visible loops compute `err` as norm(A'r) / norm(a), which is
# not literally the ratio written in the trailing comment below -- confirm which
# definition is intended.
ORTHO_TOL = 1e-6 # err = |unew|^2 / |uold|^2 < ORTHO_TOL
# ORTHO_MAX_IT: upper bound on the number of refinement passes per call
# (the loops also stop once `i` reaches this value).
ORTHO_MAX_IT = 1
# verbose: when true, the refinement loops print a "*" for every pass.
# NOTE(review): these three are non-const globals, so reads inside functions
# are untyped -- consider `const` or keyword arguments if nothing reassigns them.
verbose = true

"""
Triangular solve `Rx = b`, where `R` is upper triangular of size `realSize x realSize`. The storage of `R` is described in the documentation for `qraddcol!`.
Expand All @@ -23,8 +21,7 @@ function solveR!(R::AbstractMatrix{T}, b::AbstractVector{T}, sol::AbstractVector
# provide a view only, so they do not incur further memory allocations. I
# verified this with BenchmarkTools.
# N Barnafi 06/11/24
sol .= b
ldiv!(UpperTriangular(R), sol)
ldiv!(sol, UpperTriangular(R), b)
end

"""
Expand All @@ -33,8 +30,7 @@ Triangular solve `R'x = b`, where `R` is upper triangular of size `realSize x re
function solveRT!(R::AbstractMatrix{T}, b::AbstractVector{T}, sol::AbstractVector{T}) where {T}
    # Solve R' x = b for x, writing the result into `sol` and returning it.
    #   R   : upper triangular matrix; only its upper triangle is read.
    #   b   : right-hand side, left unmodified.
    #   sol : output vector, overwritten with R' \ b.
    #
    # R is upper triangular, so its conjugate transpose R' is lower triangular.
    # The system is solved exactly once: the three-argument `ldiv!` writes the
    # solution straight into `sol`, so the former "copy b into sol, then solve
    # in place" pair was redundant work whose result was overwritten (and it
    # applied the solve twice whenever `sol` aliased `b`).
    return ldiv!(sol, LowerTriangular(R'), b)
end


Expand Down Expand Up @@ -132,19 +128,18 @@ R = [0 0 0 R = [r11 0 0 R = [r11 r12 0
0 0 0 0 0 0 0 r22 0
0 0 0] 0 0 0] 0 0 0]
"""
#function qraddcol!(A::AT, R::RT, a::aT, N::Int64, work::wT, work2::w2T, u::uT, z::zT, r::rT) where {AT,RT,aT,wT,w2T,uT,zT,rT,T}
function qraddcol!(A::AbstractMatrix{T}, R::AbstractMatrix{T}, a::AbstractVector{T}, N::Int64, work::AbstractVector{T}, work2::AbstractVector{T}, u::AbstractVector{T}, z::AbstractVector{T}, r::AbstractVector{T}) where {T}
function qraddcol!(A::AbstractMatrix{T}, R::AbstractMatrix{T}, a::AbstractVector{T}, N::Int64, work::AbstractVector{T}, work2::AbstractVector{T}, u::AbstractVector{T}, z::AbstractVector{T}, r::AbstractVector{T}; updateMat::Bool=true, verbose::Bool=false, log::Bool=false, ortho_tol::Float64=1e-14, ortho_max_it::Int=1) where {T}
#c,u,z,du,dz are R^n. Only r is R^m
#c -> work; du -> work2. dz is redundant

#@timeit "get views" begin
m, n = size(A)
@assert size(work,1) == n "Expected "*string(n)*", actual size: " * string(size(work))
@assert size(work2,1) == n
@assert size(u,1) == n
@assert size(z,1) == n
@assert size(r,1) == m

@timeit to "Extract views" begin
if N < n
cols = 1:N
Atr = view(A, :, cols) #truncated
Expand All @@ -161,68 +156,67 @@ function qraddcol!(A::AbstractMatrix{T}, R::AbstractMatrix{T}, a::AbstractVector
u_tr = u
z_tr = z
end
#end #timeit get views
end # Extract views"

#@timeit "norms" begin
anorm = norm(a)
anorm2 = anorm^2

@timeit to "Base case" begin
if N == 0
# First iteration is simpler
anorm = sqrt(anorm2)
R[1,1] = anorm
view(A,:,N+1) .= a
return
updateMat && view(A,:,N+1) .= a
return 0
end
#end #timeit norms
end #timeit Base case


# work := c = A'a
mul!(work_tr, Atr', a)
solveRT!(Rtr, work_tr, u_tr) #u = R'\c = R'\work
solveR!(Rtr, u_tr, z_tr) #z = R\u
copy!(r, a)
mul!(r, Atr, z_tr, -1, 1) #r = a - A*z
@timeit to "Default iteration" begin
@timeit to "A'a" mul!(work_tr, Atr', a)
@timeit to "solveRT" solveRT!(Rtr, work_tr, u_tr) #u = R'\c = R'\work
@timeit to "solveR" solveR!(Rtr, u_tr, z_tr) #z = R\u
@timeit to "copy" copy!(r, a)
@timeit to "mul!" mul!(r, Atr, z_tr, -1, 1) #r = a - A*z
γ = norm(r)
mul!(work_tr, Atr', r) # r := c = A'r
@timeit to "mul: c=A'r" mul!(work_tr, Atr', r) # r := c = A'r
err = norm(work_tr) / sqrt(anorm2)
end # timeit Default iteration

# Iterative refinement
if err < ORTHO_TOL
view(R,1:N,N+1) .= view(u, 1:N)
i = 0
if err < ortho_tol
@timeit to "No refinement update" begin
view(R,1:N,N+1) .= u_tr
R[N+1,N+1] = γ
view(A,:,N+1) .= a
return
updateMat && view(A,:,N+1) .= a
end # timeit No refinement
else

i = 0
while err > ORTHO_TOL && i < ORTHO_MAX_IT

solveRT!(Rtr, work_tr, work2_tr) # work2 := du = R'\c
solveR!(Rtr, work2_tr, work_tr) # work := dz = R\du
axpy!(1.0, work_tr, z_tr) #z += dz # Refine z
#@timeit "residual 2" begin

copy!(r, a)
mul!(r, Atr, z_tr, -1.0, 1.0) #r = a - A*z
γ = norm(r)
work .= 0.0
mul!(work_tr, Atr', r) # work := c = A'r

err = norm(work_tr) / sqrt(anorm2)
#verbose && println(" *** Reorthogonalize ",string(i)," . Error:", err)
verbose && print("*")
i += 1


#if !iszero(β)
#γ = sqrt(γ^2 + β2*norm(z)^2 + β2)
#end
end # while
@timeit "Reorthogonalize" while err > ortho_tol && i < ortho_max_it

solveRT!(Rtr, work_tr, work2_tr) # work2 := du = R'\c
axpy!(1.0, work2_tr, u_tr) # Refine u
solveR!(Rtr, work2_tr, work_tr) # work := dz = R\du
axpy!(1.0, work_tr, z_tr) #z += dz # Refine z

copy!(r, a)
mul!(r, Atr, z_tr, -1.0, 1.0) #r = a - A*z
γ = norm(r)
mul!(work_tr, Atr', r) # work := c = A'r

err = norm(work_tr) / sqrt(anorm2)
i += 1
end # while
verbose && println(" *** $(i) reorthogonalization steps. Error:", err)

@timeit to axpy!(1, work2_tr, u_tr)
@timeit to "Update R" view(R,1:N,N+1) .= u_tr
@timeit to "Update R" R[N+1,N+1] = γ
@timeit "Update A" updateMat && view(A,:,N+1) .= a
log && return i
end # if

axpy!(1, work2_tr, u_tr)
view(R,1:N,N+1) .= view(u, 1:N)
R[N+1,N+1] = γ
view(A,:,N+1) .= a
return 0
end

"""
Expand Down Expand Up @@ -323,8 +317,6 @@ function qrdelcol!(A::AbstractMatrix{T}, R::AbstractMatrix{T}, k::Integer) where
end
R[j+1,j] = zero(T)
end
#end # timeit shift row
#end #timeit all
end

"""
Expand Down Expand Up @@ -358,7 +350,7 @@ function csne(Rin::AbstractMatrix{T}, A::AbstractMatrix{T}, b::Vector{T}) where
return (x, r)
end

function csne!(R::RT, A::AT, b::bT, sol::solT, work::wT, work2::w2T, u::uT, r::rT, N::Int) where {RT,AT,bT,solT,wT,w2T,uT,rT}
function csne!(R::RT, A::AT, b::bT, sol::solT, work::wT, work2::w2T, u::uT, r::rT, N::Int; verbose::Bool=false, log::Bool=false, ortho_tol::Float64=1e-14, ortho_max_it::Int=1) where {RT,AT,bT,solT,wT,w2T,uT,rT}
#c,u,sol,du are R^n. Only r is R^m
#c -> work; du -> work2. dsol is redundant.

Expand Down Expand Up @@ -397,7 +389,7 @@ function csne!(R::RT, A::AT, b::bT, sol::solT, work::wT, work2::w2T, u::uT, r::
err = norm(work_tr) / bnorm

i = 0
while err > ORTHO_TOL && i < ORTHO_MAX_IT
while err > ortho_tol && i < ortho_max_it

solveRT!(Rtr, work_tr, work2_tr) # work2 := du = R'\c
solveR!(Rtr, work2_tr, work_tr) # work := dz = R\du
Expand All @@ -408,11 +400,12 @@ function csne!(R::RT, A::AT, b::bT, sol::solT, work::wT, work2::w2T, u::uT, r::
mul!(work_tr, Atr', r) # work := c = A'r

err = norm(work_tr) / bnorm
#verbose && println(" *** Reorthogonalize ",string(i), " CSNE. Error:", err)
verbose && print("*")
i += 1

end

i > 0 && verbose && println(" *** $(i) CSNE reorthogonalization steps. Error:", err)
log && return i
end

end # module
24 changes: 24 additions & 0 deletions test/memory-usage-test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using QRupdate, TimerOutputs

# Timing / allocation smoke script: grows a QR factorization one column at a
# time with `qraddcol!` and prints the timings accumulated in `QRupdate.to`.

# Problem size: `m` rows, `n` columns appended one by one.
m = 1000
n = 10
T = ComplexF64

# Scratch buffers handed positionally to `qraddcol!`: four of length n and one
# of length m, matching the size assertions at the top of that function.
work = rand(T, n)
work2 = rand(T, n)
work3 = rand(T, n)
work4 = rand(T, n)
work5 = rand(T, m)

# Preallocated storage for the triangular factor and for the matrix itself.
Rin = zeros(T,n,n)
Ain = zeros(T,m,n)

reset_timer!(QRupdate.to)
for i in 1:n
    a = randn(T, m)
    # The fourth argument is the number of columns already present.
    @timeit QRupdate.to "add col" qraddcol!(Ain, Rin, a, i-1, work, work2, work3, work4, work5)
    # Q = A * R^{-1}, obtained by right division rather than by forming the
    # explicit inverse of R (cheaper and numerically better behaved).
    Qin = view(Ain, :, 1:i) / view(Rin, 1:i, 1:i)
end

println(QRupdate.to)
println("N of rows: $(m)")