From 86c455cadede72201192a33c6c3b52023a02a156 Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Mon, 14 Apr 2025 19:22:19 +0200 Subject: [PATCH 1/2] Add PrecomposedSlicedSeparableSum implementation and corresponding tests --- src/ProximalOperators.jl | 1 + src/calculus/precomposedSlicedSeparableSum.jl | 152 ++++++++++++++++++ test/runtests.jl | 1 + test/test_precomposedSlicedSeparableSum.jl | 48 ++++++ 4 files changed, 202 insertions(+) create mode 100644 src/calculus/precomposedSlicedSeparableSum.jl create mode 100644 test/test_precomposedSlicedSeparableSum.jl diff --git a/src/ProximalOperators.jl b/src/ProximalOperators.jl index a1f11a7..65c99ed 100644 --- a/src/ProximalOperators.jl +++ b/src/ProximalOperators.jl @@ -92,6 +92,7 @@ include("calculus/regularize.jl") include("calculus/separableSum.jl") include("calculus/slicedSeparableSum.jl") include("calculus/reshapeInput.jl") +include("calculus/precomposedSlicedSeparableSum.jl") include("calculus/sqrDistL2.jl") include("calculus/tilt.jl") include("calculus/translate.jl") diff --git a/src/calculus/precomposedSlicedSeparableSum.jl b/src/calculus/precomposedSlicedSeparableSum.jl new file mode 100644 index 0000000..09796e2 --- /dev/null +++ b/src/calculus/precomposedSlicedSeparableSum.jl @@ -0,0 +1,152 @@ +# Separable sum, using slices of an array as variables + +export PrecomposedSlicedSeparableSum + +""" + precomposedSlicedSeparableSum((f_1, ..., f_k), (J_1, ..., J_k), (L_1, ..., L_k)) + +Return the function +```math +g(x) = \\sum_{i=1}^k f_i(L_i * x_{J_i}). +``` + + precomposedSlicedSeparableSum(f, (J_1, ..., J_k), (L_1, ..., L_k)) + +Analogous to the previous one, but apply the same function `f` to all slices +of the variable `x`: +```math +g(x) = \\sum_{i=1}^k f(L_i * x_{J_i}). +``` +""" +struct PrecomposedSlicedSeparableSum{S <: Tuple, T <: AbstractArray, U <: AbstractArray, V <: AbstractArray, N} + fs::S # Tuple, where each element is a Vector with elements of the same type; the functions to prox on + # Example: S = Tuple{Array{ProximalOperators.NormL1{Float64},1}, Array{ProximalOperators.NormL2{Float64},1}} + idxs::T # Vector, where each element is a Vector containing the indices to prox on + # Example: T = Array{Array{Tuple{Colon,UnitRange{Int64}},1},1} + ops::U # Vector of operations (matrices or AbstractOperators) to apply to the function + # Example: U = Array{Array{Matrix{Float64},1},1} + μs::V # Vector of mu values for each function +end + +function PrecomposedSlicedSeparableSum(fs::Tuple, idxs::Tuple, ops::Tuple, μs::Tuple) + @assert length(fs) == length(idxs) + @assert length(fs) == length(ops) + ftypes = DataType[] + fsarr = Array{Any,1}[] + indarr = Array{eltype(idxs),1}[] + opsarr = Array{Any,1}[] + μsarr = Array{Any,1}[] + for (i,f) in enumerate(fs) + t = typeof(f) + fi = findfirst(isequal(t), ftypes) + if fi === nothing + push!(ftypes, t) + push!(fsarr, Any[f]) + push!(indarr, eltype(idxs)[idxs[i]]) + push!(opsarr, Any[ops[i]]) + push!(μsarr, Any[μs[i]]) + else + push!(fsarr[fi], f) + push!(indarr[fi], idxs[i]) + push!(opsarr[fi], ops[i]) + push!(μsarr[fi], μs[i]) + end + end + fsnew = ((Array{typeof(fs[1]),1}(fs) for fs in fsarr)...,) + @assert typeof(fsnew) == Tuple{(Array{ft,1} for ft in ftypes)...} + PrecomposedSlicedSeparableSum{typeof(fsnew),typeof(indarr),typeof(opsarr),typeof(μsarr),length(fsnew)}(fsnew, indarr, opsarr, μsarr) +end + +# Constructor for the case where the same function is applied to all slices +PrecomposedSlicedSeparableSum(f::F, idxs::T, ops::U, μs::V) where {F, T <: Tuple, U <: Tuple, V <: Tuple} = + PrecomposedSlicedSeparableSum(Tuple(f for k in eachindex(idxs)), idxs, ops, μs) + +# Unroll the loop over the different types of functions to evaluate +function (f::PrecomposedSlicedSeparableSum)(x::Tuple) + v = zero(eltype(x[1])) + for (fs_group, idxs_group, ops_group) = zip(f.fs, f.idxs, f.ops) # For each function type + for (fun, idx_group, hcat_op) in zip(fs_group, idxs_group, ops_group) # For each function of that type + for (var_index, (x_var, idx)) in enumerate(zip(x, idx_group)) + if idx isa Tuple + v += fun(hcat_op[var_index] * view(x_var, idx...)) + elseif idx isa Colon + v += fun(hcat_op[var_index] * x_var) + elseif idx isa Nothing + # do nothing + else + v += fun(hcat_op[var_index] * view(x_var, idx)) + end + end + end + end + return v +end + +function slice_var(x, idx) + if idx isa Tuple + return view(x, idx...) + elseif idx isa Colon + return x + else + return view(x, idx) + end +end + +# Unroll the loop over the different types of functions to prox on +function prox!(y::Tuple, f::PrecomposedSlicedSeparableSum, x::Tuple, gamma) + v = zero(eltype(x[1])) + for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type + for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type + for (idx, op, x_var, y_var) in zip(idx_group, hcat_op, x, y) + if idx isa Nothing + continue + end + sliced_x = slice_var(x_var, idx) + sliced_y = slice_var(y_var, idx) + res = op * sliced_x + prox_res, g = prox(fun, res, μ.*gamma) + prox_res .-= res + prox_res ./= μ + mul!(sliced_y, adjoint(op), prox_res) + sliced_y .+= sliced_x + v += g + end + end + end + return v +end + +component_types(::Type{PrecomposedSlicedSeparableSum{S, T, N}}) where {S, T, N} = Tuple(A.parameters[1] for A in fieldtypes(S)) + +@generated is_proximable(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_proximable, component_types(T)) ? :(true) : :(false) +@generated is_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_convex, component_types(T)) ? :(true) : :(false) +@generated is_set_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_set_indicator, component_types(T)) ? :(true) : :(false) +@generated is_singleton_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_singleton_indicator, component_types(T)) ? :(true) : :(false) +@generated is_cone_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_cone_indicator, component_types(T)) ? :(true) : :(false) +@generated is_affine_indicator(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_affine_indicator, component_types(T)) ? :(true) : :(false) +@generated is_smooth(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_smooth, component_types(T)) ? :(true) : :(false) +@generated is_generalized_quadratic(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false) +@generated is_strongly_convex(::Type{T}) where T <: PrecomposedSlicedSeparableSum = return all(is_strongly_convex, component_types(T)) ? :(true) : :(false) + +function prox_naive(f::PrecomposedSlicedSeparableSum, x, gamma) + fy = 0 + y = similar.(x) + for (fs_group, idxs_group, ops_group, μ_group) = zip(f.fs, f.idxs, f.ops, f.μs) # For each function type + for (fun, idx_group, hcat_op, μ) in zip(fs_group, idxs_group, ops_group, μ_group) # For each function of that type + for (idx, op, x_var, y_var) in zip(idx_group, hcat_op, x, y) + if idx isa Nothing + continue + end + sliced_x = slice_var(x_var, idx) + sliced_y = slice_var(y_var, idx) + res = op * sliced_x + prox_res, _fy = prox_naive(fun, res, μ.*gamma) + prox_res = (prox_res .- res) ./ μ + mul!(sliced_y, adjoint(op), prox_res) + fy += _fy + sliced_y .+= sliced_x + end + end + end + return y, fy +end diff --git a/test/runtests.jl b/test/runtests.jl index 58a4e07..0447404 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -155,6 +155,7 @@ end include("test_regularize.jl") include("test_separableSum.jl") include("test_slicedSeparableSum.jl") + include("test_precomposedSlicedSeparableSum.jl") include("test_sum.jl") include("test_reshapeInput.jl") end diff --git a/test/test_precomposedSlicedSeparableSum.jl b/test/test_precomposedSlicedSeparableSum.jl new file mode 100644 index 0000000..dad1576 --- /dev/null +++ b/test/test_precomposedSlicedSeparableSum.jl @@ -0,0 +1,48 @@ +using Test +using Random +using ProximalOperators +using LinearAlgebra + +Random.seed!(1234) + +# x = (randn(10), randn(10)) +# norm(x[1], 1) + norm(A2[1:5, 1:5] * x[2][1:5], 2) + norm(A2[6:10, 6:10] * x[2][6:10], 2)^2 + +@testset "PrecomposedSlicedSeparableSum" begin + +fs = (NormL1(), NormL2(), SqrNormL2()) + +A1 = (Diagonal(ones(10)), nothing) +F = qr(randn(5, 5)) +A2 = (nothing, Matrix(F.Q)) +F = qr(randn(5, 5)) +A3 = (nothing, Matrix(F.Q)) +mu = rand(5) +A3[2] .*= reshape(mu, 5, 1) +ops = (A1, A2, A3) + +idxs = ((Colon(), nothing), (nothing, 1:5), (nothing, 6:10)) +μs = (1.0, 1.0, mu) + +AAc2 = A2[2] * A2[2]' +@test AAc2 ≈ I +AAc3 = A3[2] * A3[2]' +@test AAc3 ≈ Diagonal(mu) .^ 2 + +f = PrecomposedSlicedSeparableSum(fs, idxs, ops, μs) +x = (randn(10), rand(10)) +y = (zeros(10), zeros(10)) +fy = prox!(y, f, x, 1.0) +yn, fyn = ProximalOperators.prox_naive(f, x, 1.0) +y1, fy1 = prox(NormL1(), x[1], 1.0) +y2, fy2 = prox(Precompose(NormL2(), A2[2], 1), x[2][1:5], 1.0) +y3, fy3 = prox(Precompose(SqrNormL2(), A3[2], mu), x[2][6:10], 1.0) + +@test abs(fyn-fy)<1e-11 +@test norm(yn[1]-y[1])+norm(yn[2]-y[2])<1e-11 +@test abs((fy1+fy2+fy3)-fy)<1e-11 +@test norm(y[1] - y1) < 1e-11 +@test norm(y[2][1:5] - y2) < 1e-11 +@test norm(y[2][6:10] - y3) < 1e-11 + +end From 845ea58703951940ef1d60dbc1762c91ebdf3697 Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Mon, 13 Apr 2026 14:41:31 +0200 Subject: [PATCH 2/2] add PrecomposedSlicedSeparableSum to documentation --- docs/src/calculus.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/calculus.md b/docs/src/calculus.md index c0e3296..18aae25 100644 --- a/docs/src/calculus.md +++ b/docs/src/calculus.md @@ -30,6 +30,7 @@ Regularize Postcompose Precompose PrecomposeDiagonal +PrecomposedSlicedSeparableSum Tilt Translate ReshapeInput