Skip to content

Commit 59a3ed8

Browse files
author
Alexey Stukalov
committed
SemOptimizerResult: optim wrapper
1 parent 39eb8ff commit 59a3ed8

8 files changed

Lines changed: 66 additions & 68 deletions

File tree

ext/SEMNLOptExt/NLopt.jl

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ local_options(optimizer::SemOptimizerNLopt) = optimizer.local_options
128128
equality_constraints(optimizer::SemOptimizerNLopt) = optimizer.equality_constraints
129129
inequality_constraints(optimizer::SemOptimizerNLopt) = optimizer.inequality_constraints
130130

131-
struct NLoptResult
131+
# wrapper for the NLopt optimization result
132+
struct NLoptResult <: SEM.SemOptimizerResult{SemOptimizerNLopt}
133+
optimizer::SemOptimizerNLopt
132134
result::Any
133135
problem::Any
134136
end
@@ -137,27 +139,15 @@ SEM.algorithm_name(res::NLoptResult) = res.problem.algorithm
137139
SEM.n_iterations(res::NLoptResult) = res.problem.numevals
138140
SEM.convergence(res::NLoptResult) = res.result[3]
139141

140-
# construct SemFit from fitted NLopt object
141-
function SemFit_NLopt(optimization_result, model::AbstractSem, start_val, optim, opt)
142-
return SemFit(
143-
optimization_result[1],
144-
optimization_result[2],
145-
start_val,
146-
model,
147-
optim,
148-
NLoptResult(optimization_result, opt),
149-
)
150-
end
151-
152142
# fit method
153143
function SEM.fit(
154144
optim::SemOptimizerNLopt,
155145
model::AbstractSem,
156146
start_params::AbstractVector;
157147
kwargs...,
158148
)
159-
opt = construct_NLopt(optim.algorithm, optim.options, nparams(model))
160-
opt.min_objective =
149+
problem = NLopt_problem(optim.algorithm, optim.options, nparams(model))
150+
problem.min_objective =
161151
(par, G) -> SEM.evaluate!(
162152
zero(eltype(par)),
163153
!isnothing(G) && !isempty(G) ? G : nothing,
@@ -166,36 +156,42 @@ function SEM.fit(
166156
par,
167157
)
168158
for (f, tol) in optim.inequality_constraints
169-
inequality_constraint!(opt, f, tol)
159+
inequality_constraint!(problem, f, tol)
170160
end
171161
for (f, tol) in optim.equality_constraints
172-
equality_constraint!(opt, f, tol)
162+
equality_constraint!(problem, f, tol)
173163
end
174164

175165
if !isnothing(optim.local_algorithm)
176-
opt_local =
177-
construct_NLopt(optim.local_algorithm, optim.local_options, nparams(model))
178-
opt.local_optimizer = opt_local
166+
problem.local_optimizer =
167+
NLopt_problem(optim.local_algorithm, optim.local_options, nparams(model))
179168
end
180169

181170
# fit
182-
result = NLopt.optimize(opt, start_params)
171+
result = NLopt.optimize(problem, start_params)
183172

184-
return SemFit_NLopt(result, model, start_params, optim, opt)
173+
return SemFit(
174+
result[1], # minimum
175+
result[2], # optimal params
176+
start_val,
177+
model,
178+
NLoptResult(optim, result, problem),
179+
)
185180
end
186181

187182
############################################################################################
188183
### additional functions
189184
############################################################################################
190185

191-
function construct_NLopt(algorithm, options, npar)
192-
opt = Opt(algorithm, npar)
186+
# construct NLopt.jl problem
187+
function NLopt_problem(algorithm, options, npar)
188+
problem = Opt(algorithm, npar)
193189

194190
for (key, val) in pairs(options)
195-
setproperty!(opt, key, val)
191+
setproperty!(problem, key, val)
196192
end
197193

198-
return opt
194+
return problem
199195
end
200196

201197
############################################################################################

ext/SEMProximalOptExt/ProximalAlgorithms.jl

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ SEM.update_observed(optimizer::SemOptimizerProximal, observed::SemObserved; kwar
4747
### Model fitting
4848
############################################################################
4949

50-
mutable struct ProximalResult
51-
result::Any
50+
# wrapper for the Proximal optimization result
51+
struct ProximalResult{O <: SemOptimizer{:Proximal}} <: SEM.SemOptimizerResult{O}
52+
optimizer::O
53+
n_iterations::Int
5254
end
5355

5456
## connect to ProximalAlgorithms.jl
@@ -65,49 +67,39 @@ function SEM.fit(
6567
kwargs...,
6668
)
6769
if isnothing(optim.operator_h)
68-
solution, iterations =
70+
solution, niterations =
6971
optim.algorithm(x0 = start_params, f = model, g = optim.operator_g)
7072
else
71-
solution, iterations = optim.algorithm(
73+
solution, niterations = optim.algorithm(
7274
x0 = start_params,
7375
f = model,
7476
g = optim.operator_g,
7577
h = optim.operator_h,
7678
)
7779
end
7880

79-
minimum = objective!(model, solution)
80-
81-
optimization_result = Dict(
82-
:minimum => minimum,
83-
:iterations => iterations,
84-
:algorithm => optim.algorithm,
85-
:operator_g => optim.operator_g,
86-
)
87-
88-
isnothing(optim.operator_h) ||
89-
push!(optimization_result, :operator_h => optim.operator_h)
90-
9181
return SemFit(
92-
minimum,
82+
objective!(model, solution), # minimum
9383
solution,
9484
start_params,
9585
model,
96-
optim,
97-
ProximalResult(optimization_result),
86+
ProximalResult(optim, niterations),
9887
)
9988
end
10089

10190
############################################################################################
10291
### additional methods
10392
############################################################################################
10493

105-
SEM.algorithm_name(res::ProximalResult) = SEM.algorithm_name(res.result[:algorithm])
106-
SEM.algorithm_name(::ProximalAlgorithms.IterativeAlgorithm{I,H,S,D,K}) where
107-
{I, H, S, D, K} = nameof(I)
94+
SEM.algorithm_name(res::ProximalResult) = SEM.algorithm_name(res.optimizer.algorithm)
95+
SEM.algorithm_name(
96+
::ProximalAlgorithms.IterativeAlgorithm{I, H, S, D, K},
97+
) where {I, H, S, D, K} = nameof(I)
10898

109-
SEM.convergence(::ProximalResult) = "No standard convergence criteria for proximal \n algorithms available."
110-
SEM.n_iterations(res::ProximalResult) = res.result[:iterations]
99+
SEM.convergence(
100+
::ProximalResult,
101+
) = "No standard convergence criteria for proximal \n algorithms available."
102+
SEM.n_iterations(res::ProximalResult) = res.n_iterations
111103

112104
############################################################################################
113105
# pretty printing
@@ -119,10 +111,8 @@ function Base.show(io::IO, struct_inst::SemOptimizerProximal)
119111
end
120112

121113
function Base.show(io::IO, result::ProximalResult)
122-
print(io, "Minimum: $(round(result.result[:minimum]; digits = 2)) \n")
123-
print(io, "No. evaluations: $(result.result[:iterations]) \n")
124-
print(io, "Operator: $(nameof(typeof(result.result[:operator_g]))) \n")
125-
if haskey(result.result, :operator_h)
126-
print(io, "Second Operator: $(nameof(typeof(result.result[:operator_h]))) \n")
127-
end
114+
print(io, "No. evaluations: $(result.n_iterations) \n")
115+
print(io, "Operator: $(nameof(typeof(result.optimizer.operator_g))) \n")
116+
op_h = result.optimizer.operator_h
117+
isnothing(op_h) || print(io, "Second Operator: $(nameof(typeof(op_h))) \n")
128118
end

src/frontend/fit/SemFit.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ Fitted structural equation model.
1717
- `n_iterations(::SemFit)` -> number of iterations
1818
- `convergence(::SemFit)` -> convergence properties
1919
"""
20-
mutable struct SemFit{Mi, So, St, Mo, Op, O}
20+
mutable struct SemFit{Mi, So, St, Mo, O}
2121
minimum::Mi
2222
solution::So
2323
start_val::St
2424
model::Mo
25-
optimizer::Op
2625
optimization_result::O
2726
end
2827

@@ -64,6 +63,6 @@ optimization_result(sem_fit::SemFit) = sem_fit.optimization_result
6463

6564
# optimizer properties
6665
algorithm_name(sem_fit::SemFit) = algorithm_name(sem_fit.optimization_result)
67-
optimizer_engine(sem_fit::SemFit) = optimizer_engine(sem_fit.optimizer)
66+
optimizer_engine(sem_fit::SemFit) = optimizer_engine(sem_fit.optimization_result)
6867
n_iterations(sem_fit::SemFit) = n_iterations(optimization_result(sem_fit))
6968
convergence(sem_fit::SemFit) = convergence(optimization_result(sem_fit))

src/optimizer/abstract.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ For a list of available engines, call [`optimizer_engines`](@ref).
7979
"""
8080
optimizer_engine_doc(engine) = doc(SemOptimizer_impltype(engine))
8181

82+
optimizer(result::SemOptimizerResult) = result.optimizer
83+
84+
optimizer_engine(result::SemOptimizerResult) = optimizer_engine(result.optimizer)
85+
8286
"""
8387
fit([optim::SemOptimizer], model::AbstractSem;
8488
[engine::Symbol], start_val = start_val, kwargs...)

src/optimizer/optim.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,15 @@ update_observed(optimizer::SemOptimizerOptim, observed::SemObserved; kwargs...)
7575

7676
options(optimizer::SemOptimizerOptim) = optimizer.options
7777

78-
algorithm_name(res::Optim.MultivariateOptimizationResults) = Optim.summary(res)
79-
n_iterations(res::Optim.MultivariateOptimizationResults) = Optim.iterations(res)
80-
convergence(res::Optim.MultivariateOptimizationResults) = Optim.converged(res)
78+
# wrapper for the Optim.jl result
79+
struct SemOptimResult{O <: SemOptimizerOptim} <: SemOptimizerResult{O}
80+
optimizer::O
81+
result::Optim.MultivariateOptimizationResults
82+
end
83+
84+
algorithm_name(res::SemOptimResult) = Optim.summary(res.result)
85+
n_iterations(res::SemOptimResult) = Optim.iterations(res.result)
86+
convergence(res::SemOptimResult) = Optim.converged(res.result)
8187

8288
function fit(
8389
optim::SemOptimizerOptim,
@@ -129,6 +135,6 @@ function fit(
129135
result.minimizer,
130136
start_params,
131137
model,
132-
optim,
133-
result)
138+
SemOptimResult(optim, result),
139+
)
134140
end

src/types.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ Base.:*(x::SemWeight, y) = x.w * y
8181

8282
abstract type SemOptimizer{E} end
8383

84+
# wrapper around optimization result
85+
abstract type SemOptimizerResult{O <: SemOptimizer} end
86+
8487
"""
8588
Supertype of all objects that can serve as the observed field of a SEM.
8689
Pre-processes data and computes sufficient statistics for example.

test/examples/proximal/l0.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ model_prox = Sem(specification = partable, data = dat, loss = SemML)
4545
fit_prox = fit(model_prox, engine = :Proximal, operator_g = prox_operator)
4646

4747
@testset "l0 | solution_unregularized" begin
48-
@test fit_prox.optimization_result.result[:iterations] < 1000
48+
@test n_iterations(fit_prox.optimization_result) < 1000
4949
@test maximum(abs.(solution(sem_fit) - solution(fit_prox))) < 0.002
5050
end
5151

@@ -57,7 +57,7 @@ model_prox = Sem(specification = partable, data = dat, loss = SemML)
5757
fit_prox = fit(model_prox, engine = :Proximal, operator_g = prox_operator)
5858

5959
@testset "l0 | solution_regularized" begin
60-
@test fit_prox.optimization_result.result[:iterations] < 1000
60+
@test n_iterations(fit_prox.optimization_result) < 1000
6161
@test solution(fit_prox)[31] == 0.0
6262
@test abs(
6363
StructuralEquationModels.minimum(fit_prox) -

test/examples/proximal/lasso.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ model_prox = Sem(specification = partable, data = dat, loss = SemML)
4343
fit_prox = fit(model_prox, engine = :Proximal, operator_g = NormL1(λ))
4444

4545
@testset "lasso | solution_unregularized" begin
46-
@test fit_prox.optimization_result.result[:iterations] < 1000
46+
@test n_iterations(fit_prox.optimization_result) < 1000
4747
@test maximum(abs.(solution(sem_fit) - solution(fit_prox))) < 0.002
4848
end
4949

@@ -55,7 +55,7 @@ model_prox = Sem(specification = partable, data = dat, loss = SemML)
5555
fit_prox = fit(model_prox, engine = :Proximal, operator_g = NormL1(λ))
5656

5757
@testset "lasso | solution_regularized" begin
58-
@test fit_prox.optimization_result.result[:iterations] < 1000
58+
@test n_iterations(fit_prox.optimization_result) < 1000
5959
@test all(solution(fit_prox)[16:20] .< solution(sem_fit)[16:20])
6060
@test StructuralEquationModels.minimum(fit_prox) -
6161
StructuralEquationModels.minimum(sem_fit) < 0.03

0 commit comments

Comments
 (0)