Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/calculus.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Regularize
Postcompose
Precompose
PrecomposeDiagonal
PrecomposedSlicedSeparableSum
Tilt
Translate
ReshapeInput
Expand Down
1 change: 1 addition & 0 deletions src/ProximalOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ include("calculus/regularize.jl")
include("calculus/separableSum.jl")
include("calculus/slicedSeparableSum.jl")
include("calculus/reshapeInput.jl")
include("calculus/precomposedSlicedSeparableSum.jl")
include("calculus/sqrDistL2.jl")
include("calculus/tilt.jl")
include("calculus/translate.jl")
Expand Down
152 changes: 152 additions & 0 deletions src/calculus/precomposedSlicedSeparableSum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Separable sum, using slices of an array as variables

export PrecomposedSlicedSeparableSum

"""
precomposedSlicedSeparableSum((f_1, ..., f_k), (J_1, ..., J_k), (L_1, ..., L_k))

Return the function
```math
g(x) = \\sum_{i=1}^k f_i(L_i * x_{J_i}).
```

precomposedSlicedSeparableSum(f, (J_1, ..., J_k), (L_1, ..., L_k))

Analogous to the previous one, but apply the same function `f` to all slices
of the variable `x`:
```math
g(x) = \\sum_{i=1}^k f(L_i * x_{J_i}).
```
"""
struct PrecomposedSlicedSeparableSum{S <: Tuple, T <: AbstractArray, U <: AbstractArray, V <: AbstractArray, N}
fs::S # Tuple, where each element is a Vector with elements of the same type; the functions to prox on
# Example: S = Tuple{Array{ProximalOperators.NormL1{Float64},1}, Array{ProximalOperators.NormL2{Float64},1}}
idxs::T # Vector, where each element is a Vector containing the indices to prox on
# Example: T = Array{Array{Tuple{Colon,UnitRange{Int64}},1},1}
ops::U # Vector of operations (matrices or AbstractOperators) to apply to the function
# Example: U = Array{Array{Matrix{Float64},1},1}
μs::V # Vector of mu values for each function
end

function PrecomposedSlicedSeparableSum(fs::Tuple, idxs::Tuple, ops::Tuple, μs::Tuple)
@assert length(fs) == length(idxs)
@assert length(fs) == length(ops)
ftypes = DataType[]
fsarr = Array{Any,1}[]
indarr = Array{eltype(idxs),1}[]
opsarr = Array{Any,1}[]
μsarr = Array{Any,1}[]
for (i,f) in enumerate(fs)
t = typeof(f)
fi = findfirst(isequal(t), ftypes)
if fi === nothing
push!(ftypes, t)
push!(fsarr, Any[f])
push!(indarr, eltype(idxs)[idxs[i]])
push!(opsarr, Any[ops[i]])
push!(μsarr, Any[μs[i]])
else
push!(fsarr[fi], f)
push!(indarr[fi], idxs[i])
push!(opsarr[fi], ops[i])
push!(μsarr[fi], μs[i])
end
end
fsnew = ((Array{typeof(fs[1]),1}(fs) for fs in fsarr)...,)
@assert typeof(fsnew) == Tuple{(Array{ft,1} for ft in ftypes)...}
PrecomposedSlicedSeparableSum{typeof(fsnew),typeof(indarr),typeof(opsarr),typeof(μsarr),length(fsnew)}(fsnew, indarr, opsarr, μsarr)
end

# Constructor for the case where the same function is applied to all slices
PrecomposedSlicedSeparableSum(f::F, idxs::T, ops::U, μs::V) where {F, T <: Tuple, U <: Tuple, V <: Tuple} =
PrecomposedSlicedSeparableSum(Tuple(f for k in eachindex(idxs)), idxs, ops, μs)

# Unroll the loop over the different types of functions to evaluate
function (f::PrecomposedSlicedSeparableSum)(x::Tuple)
v = zero(eltype(x[1]))
for (fs_group, idxs_group, ops_group) = zip(f.fs, f.idxs, f.ops) # For each function type
for (fun, idx_group, hcat_op) in zip(fs_group, idxs_group, ops_group) # For each function of that type
for (var_index, (x_var, idx)) in enumerate(zip(x, idx_group))
if idx isa Tuple
v += fun(hcat_op[var_index] * view(x_var, idx...))
elseif idx isa Colon
v += fun(hcat_op[var_index] * x_var)
elseif idx isa Nothing
# do nothing
else
v += fun(hcat_op[var_index] * view(x_var, idx))
end
end
end
end
return v
end

function slice_var(x, idx)
if idx isa Tuple
return view(x, idx...)
elseif idx isa Colon
return x
else
return view(x, idx)
end
end

# Unroll the loop over the different types of functions to prox on
function prox!(y::Tuple, f::PrecomposedSlicedSeparableSum, x::Tuple, gamma)
v = zero(eltype(x[1]))
for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type
for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type
for (idx, op, x_var, y_var) in zip(idx_group, hcat_op, x, y)
if idx isa Nothing
continue
end
sliced_x = slice_var(x_var, idx)
sliced_y = slice_var(y_var, idx)
res = op * sliced_x
prox_res, g = prox(fun, res, μ.*gamma)
prox_res .-= res
prox_res ./= μ
mul!(sliced_y, adjoint(op), prox_res)
sliced_y .+= sliced_x
v += g
end
end
end
return v
end

component_types(::Type{PrecomposedSlicedSeparableSum{S, T, N}}) where {S, T, N} = Tuple(A.parameters[1] for A in fieldtypes(S))

@generated is_proximable(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_proximable, component_types(T)) ? :(true) : :(false)
@generated is_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_convex, component_types(T)) ? :(true) : :(false)
@generated is_set_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_set_indicator, component_types(T)) ? :(true) : :(false)
@generated is_singleton_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_singleton_indicator, component_types(T)) ? :(true) : :(false)
@generated is_cone_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_cone_indicator, component_types(T)) ? :(true) : :(false)
@generated is_affine_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_affine_indicator, component_types(T)) ? :(true) : :(false)
@generated is_smooth(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_smooth, component_types(T)) ? :(true) : :(false)
@generated is_generalized_quadratic(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false)
@generated is_strongly_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_strongly_convex, component_types(T)) ? :(true) : :(false)

function prox_naive(f::PrecomposedSlicedSeparableSum, x, gamma)
fy = 0
y = similar.(x)
for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type
for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type
for (idx, op, x_var, y_var) in zip(idx_group, hcat_op, x, y)
if idx isa Nothing
continue
end
sliced_x = slice_var(x_var, idx)
sliced_y = slice_var(y_var, idx)
res = op * sliced_x
prox_res, _fy = prox_naive(fun, res, μ.*gamma)
prox_res = (prox_res .- res) ./ μ
mul!(sliced_y, adjoint(op), prox_res)
fy += _fy
sliced_y .+= sliced_x
end
end
end
return y, fy
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ end
include("test_regularize.jl")
include("test_separableSum.jl")
include("test_slicedSeparableSum.jl")
include("test_precomposedSlicedSeparableSum.jl")
include("test_sum.jl")
include("test_reshapeInput.jl")
end
Expand Down
48 changes: 48 additions & 0 deletions test/test_precomposedSlicedSeparableSum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using Test
using Random
using ProximalOperators
using LinearAlgebra

Random.seed!(1234)

# x = (randn(10), randn(10))
# norm(x[1], 1) + norm(A2[1:5, 1:5] * x[2][1:5], 2) + norm(A2[6:10, 6:10] * x[2][6:10], 2)^2

@testset "PrecomposedSlicedSeparableSum" begin

fs = (NormL1(), NormL2(), SqrNormL2())

A1 = (Diagonal(ones(10)), nothing)
F = qr(randn(5, 5))
A2 = (nothing, Matrix(F.Q))
F = qr(randn(5, 5))
A3 = (nothing, Matrix(F.Q))
mu = rand(5)
A3[2] .*= reshape(mu, 5, 1)
ops = (A1, A2, A3)

idxs = ((Colon(), nothing), (nothing, 1:5), (nothing, 6:10))
μs = (1.0, 1.0, mu)

AAc2 = A2[2] * A2[2]'
@test AAc2 ≈ I
AAc3 = A3[2] * A3[2]'
@test AAc3 ≈ Diagonal(mu) .^ 2

f = PrecomposedSlicedSeparableSum(fs, idxs, ops, μs)
x = (randn(10), rand(10))
y = (zeros(10), zeros(10))
fy = prox!(y, f, x, 1.0)
yn, fyn = ProximalOperators.prox_naive(f, x, 1.0)
y1, fy1 = prox(NormL1(), x[1], 1.0)
y2, fy2 = prox(Precompose(NormL2(), A2[2], 1), x[2][1:5], 1.0)
y3, fy3 = prox(Precompose(SqrNormL2(), A3[2], mu), x[2][6:10], 1.0)

@test abs(fyn-fy)<1e-11
@test norm(yn[1]-y[1])+norm(yn[2]-y[2])<1e-11
@test abs((fy1+fy2+fy3)-fy)<1e-11
@test norm(y[1] - y1) < 1e-11
@test norm(y[2][1:5] - y2) < 1e-11
@test norm(y[2][6:10] - y3) < 1e-11

end
Loading