Skip to content

Commit 2e5c9b3

Browse files
Maximilian-Stefan-ErnstAlexey Stukalov
authored andcommitted
fix proximal extension
1 parent 6ba91f4 commit 2e5c9b3

1 file changed

Lines changed: 20 additions & 6 deletions

File tree

ext/SEMProximalOptExt/ProximalAlgorithms.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,9 @@ SEM.update_observed(optimizer::SemOptimizerProximal, observed::SemObserved; kwar
4848
SEM.algorithm(optimizer::SemOptimizerProximal) = optimizer.algorithm
4949

5050
############################################################################
51-
### Pretty Printing
51+
### Model fitting
5252
############################################################################
5353

54-
function Base.show(io::IO, struct_inst::SemOptimizerProximal)
55-
print_type_name(io, struct_inst)
56-
print_field_types(io, struct_inst)
57-
end
58-
5954
## connect to ProximalAlgorithms.jl
6055
function ProximalAlgorithms.value_and_gradient(model::AbstractSem, params)
6156
grad = similar(params)
@@ -106,10 +101,29 @@ function SEM.fit(
106101
)
107102
end
108103

104+
############################################################################################
105+
### additional methods
106+
############################################################################################
107+
108+
SEM.algorithm_name(res::ProximalResult) = SEM.algorithm_name(res.optimizer.algorithm)
109+
SEM.algorithm_name(
110+
::ProximalAlgorithms.IterativeAlgorithm{I, H, S, D, K},
111+
) where {I, H, S, D, K} = nameof(I)
112+
113+
SEM.convergence(
114+
::ProximalResult,
115+
) = "No standard convergence criteria for proximal \n algorithms available."
116+
SEM.n_iterations(res::ProximalResult) = res.n_iterations
117+
109118
############################################################################################
110119
# pretty printing
111120
############################################################################################
112121

122+
function Base.show(io::IO, struct_inst::SemOptimizerProximal)
123+
print_type_name(io, struct_inst)
124+
print_field_types(io, struct_inst)
125+
end
126+
113127
function Base.show(io::IO, result::ProximalResult)
114128
print(io, "Minimum: $(round(result.result[:minimum]; digits = 2)) \n")
115129
print(io, "No. evaluations: $(result.result[:iterations]) \n")

0 commit comments

Comments
 (0)