Skip to content

Commit 38939b7

Browse files
author
Alexey Stukalov
committed
SemOptimizerResult: streamline optim results
1 parent 309c578 commit 38939b7

10 files changed

Lines changed: 71 additions & 95 deletions

File tree

docs/src/developer/optimizer.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ update_observed(optimizer::SemOptimizerName, observed::SemObserved; kwargs...) =
3030
### additional methods
3131
############################################################################################
3232

33-
algorithm(optimizer::SemOptimizerName) = optimizer.algorithm
3433
options(optimizer::SemOptimizerName) = optimizer.options
3534
```
3635

@@ -68,7 +67,7 @@ The method has to return a `SemFit` object that consists of the minimum of the o
6867
In addition, you might want to provide methods to access properties of your optimization result:
6968

7069
```julia
71-
optimizer(res::MyOptimizationResult) = ...
70+
algorithm_name(res::MyOptimizationResult) = ...
7271
n_iterations(res::MyOptimizationResult) = ...
7372
convergence(res::MyOptimizationResult) = ...
7473
```

ext/SEMNLOptExt/NLopt.jl

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -118,31 +118,32 @@ SEM.update_observed(optimizer::SemOptimizerNLopt, observed::SemObserved; kwargs.
118118
### additional methods
119119
############################################################################################
120120

121-
SEM.algorithm(optimizer::SemOptimizerNLopt) = optimizer.algorithm
122121
local_algorithm(optimizer::SemOptimizerNLopt) = optimizer.local_algorithm
123122
SEM.options(optimizer::SemOptimizerNLopt) = optimizer.options
124123
local_options(optimizer::SemOptimizerNLopt) = optimizer.local_options
125124
equality_constraints(optimizer::SemOptimizerNLopt) = optimizer.equality_constraints
126125
inequality_constraints(optimizer::SemOptimizerNLopt) = optimizer.inequality_constraints
127126

128-
struct NLoptResult
127+
# wrapper for the NLopt optimization result
128+
struct NLoptResult <: SEM.SemOptimizerResult{SemOptimizerNLopt}
129+
optimizer::SemOptimizerNLopt
129130
result::Any
130131
problem::Any
131132
end
132133

133-
SEM.optimizer(res::NLoptResult) = res.problem.algorithm
134+
SEM.algorithm_name(res::NLoptResult) = res.problem.algorithm
134135
SEM.n_iterations(res::NLoptResult) = res.problem.numevals
135136
SEM.convergence(res::NLoptResult) = res.result[3]
136137

137-
# construct SemFit from fitted NLopt object
138-
function SemFit_NLopt(optimization_result, model::AbstractSem, start_val, opt)
139-
return SemFit(
140-
optimization_result[1],
141-
optimization_result[2],
142-
start_val,
143-
model,
144-
NLoptResult(optimization_result, opt),
145-
)
138+
# construct NLopt.jl problem
139+
function NLopt_problem(algorithm, options, npar)
140+
problem = Opt(algorithm, npar)
141+
142+
for (key, val) in pairs(options)
143+
setproperty!(problem, key, val)
144+
end
145+
146+
return problem
146147
end
147148

148149
# fit method
@@ -152,8 +153,8 @@ function SEM.fit(
152153
start_params::AbstractVector;
153154
kwargs...,
154155
)
155-
opt = construct_NLopt(optim.algorithm, optim.options, nparams(model))
156-
opt.min_objective =
156+
problem = NLopt_problem(optim.algorithm, optim.options, nparams(model))
157+
problem.min_objective =
157158
(par, G) -> SEM.evaluate!(
158159
zero(eltype(par)),
159160
!isnothing(G) && !isempty(G) ? G : nothing,
@@ -162,36 +163,27 @@ function SEM.fit(
162163
par,
163164
)
164165
for (f, tol) in optim.inequality_constraints
165-
inequality_constraint!(opt, f, tol)
166+
inequality_constraint!(problem, f, tol)
166167
end
167168
for (f, tol) in optim.equality_constraints
168-
equality_constraint!(opt, f, tol)
169+
equality_constraint!(problem, f, tol)
169170
end
170171

171172
if !isnothing(optim.local_algorithm)
172-
opt_local =
173-
construct_NLopt(optim.local_algorithm, optim.local_options, nparams(model))
174-
opt.local_optimizer = opt_local
173+
problem.local_optimizer =
174+
NLopt_problem(optim.local_algorithm, optim.local_options, nparams(model))
175175
end
176176

177177
# fit
178-
result = NLopt.optimize(opt, start_params)
178+
result = NLopt.optimize(problem, start_params)
179179

180-
return SemFit_NLopt(result, model, start_params, opt)
181-
end
182-
183-
############################################################################################
184-
### additional functions
185-
############################################################################################
186-
187-
function construct_NLopt(algorithm, options, npar)
188-
opt = Opt(algorithm, npar)
189-
190-
for (key, val) in pairs(options)
191-
setproperty!(opt, key, val)
192-
end
193-
194-
return opt
180+
return SemFit(
181+
result[1], # minimum
182+
result[2], # optimal params
183+
start_val,
184+
model,
185+
NLoptResult(optim, result, problem),
186+
)
195187
end
196188

197189
############################################################################################

ext/SEMProximalOptExt/ProximalAlgorithms.jl

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,6 @@ SEM.sem_optimizer_subtype(::Val{:Proximal}) = SemOptimizerProximal
4141
SEM.update_observed(optimizer::SemOptimizerProximal, observed::SemObserved; kwargs...) =
4242
optimizer
4343

44-
############################################################################################
45-
### additional methods
46-
############################################################################################
47-
48-
SEM.algorithm(optimizer::SemOptimizerProximal) = optimizer.algorithm
49-
5044
############################################################################
5145
### Model fitting
5246
############################################################################
@@ -58,8 +52,11 @@ function ProximalAlgorithms.value_and_gradient(model::AbstractSem, params)
5852
return obj, grad
5953
end
6054

61-
mutable struct ProximalResult
62-
result::Any
55+
# wrapper for the Proximal optimization result
56+
struct ProximalResult{O <: SemOptimizer{:Proximal}} <: SEM.SemOptimizerResult{O}
57+
optimizer::O
58+
minimum::Float64
59+
n_iterations::Int
6360
end
6461

6562
function SEM.fit(
@@ -69,36 +66,20 @@ function SEM.fit(
6966
kwargs...,
7067
)
7168
if isnothing(optim.operator_h)
72-
solution, iterations =
69+
solution, niterations =
7370
optim.algorithm(x0 = start_params, f = model, g = optim.operator_g)
7471
else
75-
solution, iterations = optim.algorithm(
72+
solution, niterations = optim.algorithm(
7673
x0 = start_params,
7774
f = model,
7875
g = optim.operator_g,
7976
h = optim.operator_h,
8077
)
8178
end
8279

83-
minimum = objective!(model, solution)
80+
optim_res = ProximalResult(optim, objective!(model, solution), niterations)
8481

85-
optimization_result = Dict(
86-
:minimum => minimum,
87-
:iterations => iterations,
88-
:algorithm => optim.algorithm,
89-
:operator_g => optim.operator_g,
90-
)
91-
92-
isnothing(optim.operator_h) ||
93-
push!(optimization_result, :operator_h => optim.operator_h)
94-
95-
return SemFit(
96-
minimum,
97-
solution,
98-
start_params,
99-
model,
100-
ProximalResult(optimization_result),
101-
)
82+
return SemFit(optim_res.minimum, solution, start_params, model, optim_res)
10283
end
10384

10485
############################################################################################
@@ -125,10 +106,9 @@ function Base.show(io::IO, struct_inst::SemOptimizerProximal)
125106
end
126107

127108
function Base.show(io::IO, result::ProximalResult)
128-
print(io, "Minimum: $(round(result.result[:minimum]; digits = 2)) \n")
129-
print(io, "No. evaluations: $(result.result[:iterations]) \n")
130-
print(io, "Operator: $(nameof(typeof(result.result[:operator_g]))) \n")
131-
if haskey(result.result, :operator_h)
132-
print(io, "Second Operator: $(nameof(typeof(result.result[:operator_h]))) \n")
133-
end
109+
print(io, "Minimum: $(round(result.minimum; digits = 2)) \n")
110+
print(io, "No. evaluations: $(result.n_iterations) \n")
111+
print(io, "Operator: $(nameof(typeof(result.optimizer.operator_g))) \n")
112+
op_h = result.optimizer.operator_h
113+
isnothing(op_h) || print(io, "Second Operator: $(nameof(typeof(op_h))) \n")
134114
end

src/frontend/fit/SemFit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Fitted structural equation model.
1313
- `model(::SemFit)`
1414
- `optimization_result(::SemFit)`
1515
16-
- `optimizer(::SemFit)` -> optimization algorithm
16+
- `algorithm_name(::SemFit)` -> optimization algorithm
1717
- `n_iterations(::SemFit)` -> number of iterations
1818
- `convergence(::SemFit)` -> convergence properties
1919
"""
@@ -63,6 +63,6 @@ optimization_result(sem_fit::SemFit) = sem_fit.optimization_result
6363

6464
# optimizer properties
6565
optimizer_engine(sem_fit::SemFit) = optimizer_engine(optimization_result(sem_fit))
66-
optimizer(sem_fit::SemFit) = optimizer(optimization_result(sem_fit))
66+
algorithm_name(sem_fit::SemFit) = algorithm_name(optimization_result(sem_fit))
6767
n_iterations(sem_fit::SemFit) = n_iterations(optimization_result(sem_fit))
6868
convergence(sem_fit::SemFit) = convergence(optimization_result(sem_fit))

src/frontend/fit/summary.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ function details(sem_fit::SemFit; show_fitmeasures = false, color = :light_cyan,
77
color = color,
88
)
99
print("\n")
10-
println("Optimization algorithm: $(optimizer(sem_fit))")
10+
println("Optimization engine: $(optimizer_engine(sem_fit))")
11+
println("Optimization algorithm: $(algorithm_name(sem_fit))")
1112
println("Convergence: $(convergence(sem_fit))")
1213
println("No. iterations/evaluations: $(n_iterations(sem_fit))")
1314
print("\n")

src/optimizer/abstract.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ For a list of available engines, call [`optimizer_engines`](@ref).
8484
"""
8585
optimizer_engine_doc(engine) = Base.Docs.doc(sem_optimizer_subtype(engine))
8686

87+
optimizer(result::SemOptimizerResult) = result.optimizer
88+
89+
optimizer_engine(result::SemOptimizerResult) = optimizer_engine(result.optimizer)
90+
8791
"""
8892
fit([optim::SemOptimizer], model::AbstractSem;
8993
[engine::Symbol], start_val = start_val, kwargs...)

src/optimizer/optim.jl

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,26 +67,17 @@ update_observed(optimizer::SemOptimizerOptim, observed::SemObserved; kwargs...)
6767
### additional methods
6868
############################################################################################
6969

70-
algorithm(optimizer::SemOptimizerOptim) = optimizer.algorithm
7170
options(optimizer::SemOptimizerOptim) = optimizer.options
7271

73-
function SemFit(
74-
optimization_result::Optim.MultivariateOptimizationResults,
75-
model::AbstractSem,
76-
start_val,
77-
)
78-
return SemFit(
79-
optimization_result.minimum,
80-
optimization_result.minimizer,
81-
start_val,
82-
model,
83-
optimization_result,
84-
)
72+
# wrapper for the Optim.jl result
73+
struct SemOptimResult{O <: SemOptimizerOptim} <: SemOptimizerResult{O}
74+
optimizer::O
75+
result::Optim.MultivariateOptimizationResults
8576
end
8677

87-
optimizer(res::Optim.MultivariateOptimizationResults) = Optim.summary(res)
88-
n_iterations(res::Optim.MultivariateOptimizationResults) = Optim.iterations(res)
89-
convergence(res::Optim.MultivariateOptimizationResults) = Optim.converged(res)
78+
algorithm_name(res::SemOptimResult) = Optim.summary(res.result)
79+
n_iterations(res::SemOptimResult) = Optim.iterations(res.result)
80+
convergence(res::SemOptimResult) = Optim.converged(res.result)
9081

9182
function fit(
9283
optim::SemOptimizerOptim,
@@ -133,5 +124,11 @@ function fit(
133124
optim.options,
134125
)
135126
end
136-
return SemFit(result, model, start_params)
127+
return SemFit(
128+
result.minimum,
129+
result.minimizer,
130+
start_params,
131+
model,
132+
SemOptimResult(optim, result),
133+
)
137134
end

src/types.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ Base.:*(x::SemWeight, y) = x.w * y
8181

8282
abstract type SemOptimizer{E} end
8383

84+
# wrapper around optimization result
85+
abstract type SemOptimizerResult{O <: SemOptimizer} end
86+
8487
"""
8588
Supertype of all objects that can serve as the observed field of a SEM.
8689
Pre-processes data and computes sufficient statistics for example.

test/examples/proximal/l0.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ model_prox = Sem(specification = partable, data = dat, loss = SemML)
4545
fit_prox = fit(model_prox, engine = :Proximal, operator_g = prox_operator)
4646

4747
@testset "l0 | solution_unregularized" begin
48-
@test fit_prox.optimization_result.result[:iterations] < 1000
48+
@test n_iterations(fit_prox.optimization_result) < 1000
4949
@test maximum(abs.(solution(sem_fit) - solution(fit_prox))) < 0.002
5050
end
5151

@@ -57,7 +57,7 @@ model_prox = Sem(specification = partable, data = dat, loss = SemML)
5757
fit_prox = fit(model_prox, engine = :Proximal, operator_g = prox_operator)
5858

5959
@testset "l0 | solution_regularized" begin
60-
@test fit_prox.optimization_result.result[:iterations] < 1000
60+
@test n_iterations(fit_prox.optimization_result) < 1000
6161
@test solution(fit_prox)[31] == 0.0
6262
@test abs(
6363
StructuralEquationModels.minimum(fit_prox) -

test/examples/proximal/lasso.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ model_prox = Sem(specification = partable, data = dat, loss = SemML)
4343
fit_prox = fit(model_prox, engine = :Proximal, operator_g = NormL1(λ))
4444

4545
@testset "lasso | solution_unregularized" begin
46-
@test fit_prox.optimization_result.result[:iterations] < 1000
46+
@test n_iterations(fit_prox.optimization_result) < 1000
4747
@test maximum(abs.(solution(sem_fit) - solution(fit_prox))) < 0.002
4848
end
4949

@@ -55,7 +55,7 @@ model_prox = Sem(specification = partable, data = dat, loss = SemML)
5555
fit_prox = fit(model_prox, engine = :Proximal, operator_g = NormL1(λ))
5656

5757
@testset "lasso | solution_regularized" begin
58-
@test fit_prox.optimization_result.result[:iterations] < 1000
58+
@test n_iterations(fit_prox.optimization_result) < 1000
5959
@test all(solution(fit_prox)[16:20] .< solution(sem_fit)[16:20])
6060
@test StructuralEquationModels.minimum(fit_prox) -
6161
StructuralEquationModels.minimum(sem_fit) < 0.03

0 commit comments

Comments
 (0)