-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathProximalAlgorithms.jl
More file actions
114 lines (97 loc) · 4.11 KB
/
ProximalAlgorithms.jl
File metadata and controls
114 lines (97 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
############################################################################################
### Types
############################################################################################
# Settings for the `:Proximal` optimization engine backed by ProximalAlgorithms.jl.
#   algorithm  - algorithm object; the proximal `SEM.fit` method calls it with keywords.
#   operator_g - forwarded to the algorithm as its `g` keyword.
#   operator_h - forwarded as `h` when it is not `nothing`; otherwise `h` is omitted.
mutable struct SemOptimizerProximal{Alg, G, H} <: SemOptimizer{:Proximal}
    algorithm::Alg
    operator_g::G
    operator_h::H
end

"""
    SemOptimizerProximal(;
        algorithm = ProximalAlgorithms.PANOC(),
        operator_g,
        operator_h = nothing,
        kwargs...,
    )

Use [*ProximalAlgorithms.jl*](https://github.com/JuliaFirstOrder/ProximalAlgorithms.jl)
as the optimization backend. The available algorithms and operators are described in the
online docs on [Regularization](@ref) and in the documentation of *ProximalAlgorithms.jl*
and [*ProximalOperators.jl*](https://github.com/JuliaFirstOrder/ProximalOperators.jl).

# Arguments
- `algorithm`: proximal optimization algorithm (default `ProximalAlgorithms.PANOC()`).
- `operator_g`: required proximal operator, e.g. a regularization penalty.
- `operator_h`: optional second proximal operator; `nothing` means only `g` is used.
- `kwargs...`: accepted but ignored.
"""
function SemOptimizerProximal(;
    algorithm = ProximalAlgorithms.PANOC(),
    operator_g,
    operator_h = nothing,
    kwargs...,
)
    # Extra keywords are swallowed on purpose (presumably so a generic optimizer
    # constructor can forward its whole keyword set here — confirm with callers).
    return SemOptimizerProximal(algorithm, operator_g, operator_h)
end
# Resolve the engine symbol `:Proximal` to the optimizer type defined in this file.
function SEM.sem_optimizer_subtype(::Val{:Proximal})
    return SemOptimizerProximal
end
############################################################################################
### Recommended methods
############################################################################################
# The proximal optimizer only holds an algorithm and operators, none of which depend on
# the observed data, so the same optimizer instance is handed back unchanged.
function SEM.update_observed(
    optimizer::SemOptimizerProximal,
    observed::SemObserved;
    kwargs...,
)
    return optimizer
end
############################################################################
### Model fitting
############################################################################
## connect to ProximalAlgorithms.jl
# Lets an `AbstractSem` act as the smooth term `f` of a proximal algorithm: returns the
# objective value at `params` together with its gradient, both obtained from a single
# `SEM.evaluate!` call.
function ProximalAlgorithms.value_and_gradient(model::AbstractSem, params)
    gradient = similar(params)
    # A zero of the parameter eltype is passed as the first argument and `nothing` as the
    # third (presumably: request the objective, skip any Hessian — confirm in evaluate!).
    value = SEM.evaluate!(zero(eltype(params)), gradient, nothing, model, params)
    return (value, gradient)
end
# Result wrapper for a proximal fit: stores the optimizer settings that were used, the
# objective value at the returned solution and the iteration count the algorithm reported.
struct ProximalResult{Opt <: SemOptimizer{:Proximal}} <: SEM.SemOptimizerResult{Opt}
    optimizer::Opt
    minimum::Float64
    n_iterations::Int
end
# Fit `model` with the proximal algorithm stored in `optim`, starting from `start_params`.
# The model is passed as the smooth term `f` and `operator_g` as `g`. `operator_h` is
# passed as `h` only when it is not `nothing`. Extra keyword arguments are ignored.
# The reported minimum is the model objective re-evaluated at the returned solution.
function SEM.fit(
    optim::SemOptimizerProximal,
    model::AbstractSem,
    start_params::AbstractVector;
    kwargs...,
)
    # Build the operator keywords once; `h` is included only when configured.
    operators = if isnothing(optim.operator_h)
        (g = optim.operator_g,)
    else
        (g = optim.operator_g, h = optim.operator_h)
    end
    solution, niterations = optim.algorithm(; x0 = start_params, f = model, operators...)

    optim_res = ProximalResult(optim, objective!(model, solution), niterations)
    return SemFit(optim_res.minimum, solution, start_params, model, optim_res)
end
############################################################################################
### additional methods
############################################################################################
# The algorithm name of a proximal result is that of the algorithm stored in its optimizer.
function SEM.algorithm_name(res::ProximalResult)
    return SEM.algorithm_name(res.optimizer.algorithm)
end

# For a ProximalAlgorithms.jl `IterativeAlgorithm`, report the name of its first type
# parameter (presumably the iteration type, e.g. for PANOC — confirm upstream).
function SEM.algorithm_name(
    ::ProximalAlgorithms.IterativeAlgorithm{Iter, H, S, D, K},
) where {Iter, H, S, D, K}
    return nameof(Iter)
end
# No convergence flag is recorded for proximal fits: the status is `missing` and the
# textual description says so.
function SEM.convergence(::ProximalResult)
    return "No standard convergence criteria for proximal \n algorithms available."
end

function SEM.converged(::ProximalResult)
    return missing
end

# Iteration count as reported by the proximal algorithm.
function SEM.n_iterations(res::ProximalResult)
    return res.n_iterations
end
############################################################################################
# pretty printing
############################################################################################
# Compact display of the optimizer settings via the shared SEM printing helpers:
# the type name first, then the types of its fields.
function Base.show(io::IO, optimizer::SemOptimizerProximal)
    print_type_name(io, optimizer)
    return print_field_types(io, optimizer)
end
# Summarize a proximal fit result: the minimum (rounded to 2 digits), the number of
# iterations the algorithm reported, and the type name of each configured operator.
function Base.show(io::IO, result::ProximalResult)
    print(io, "Minimum: $(round(result.minimum; digits = 2)) \n")
    # `n_iterations` is the algorithm's iteration count, not a function-evaluation count.
    print(io, "No. iterations: $(result.n_iterations) \n")
    print(io, "Operator: $(nameof(typeof(result.optimizer.operator_g))) \n")
    op_h = result.optimizer.operator_h
    # The second operator is optional; only print it when one was configured.
    isnothing(op_h) || print(io, "Second Operator: $(nameof(typeof(op_h))) \n")
    return nothing
end