Skip to content

Commit 7eca89d

Browse files
committed
Update snmf_class.py
Optimized print messages Added verbose option Added objective_log attribute to track objective function updates, with associated time stamps and specifying what matrix has been updated
1 parent f24c5e3 commit 7eca89d

1 file changed

Lines changed: 86 additions & 34 deletions

File tree

src/diffpy/stretched_nmf/snmf_class.py

Lines changed: 86 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import time
2+
13
import cvxpy as cp
24
import numpy as np
35
from scipy.optimize import minimize
@@ -82,6 +84,7 @@ def __init__(
8284
n_components=None,
8385
random_state=None,
8486
show_plots=False,
87+
verbose=True,
8588
):
8689
"""Initialize an instance of sNMF.
8790
@@ -131,6 +134,7 @@ def __init__(
131134
self.num_updates = 0
132135
self._rng = np.random.default_rng(random_state)
133136
self.plotter = SNMFPlotter() if show_plots else None
137+
self.verbose = verbose
134138

135139
# Enforce exclusive specification of n_components or init_weights
136140
if (n_components is None and init_weights is None) or (
@@ -183,6 +187,7 @@ def __init__(
183187
[1, -2, 1],
184188
offsets=[0, 1, 2],
185189
shape=(self.n_signals - 2, self.n_signals),
190+
dtype=float,
186191
)
187192

188193
def fit(self, rho=0, eta=0, reset=True):
@@ -235,6 +240,13 @@ def fit(self, rho=0, eta=0, reset=True):
235240
]
236241
self.objective_difference = None
237242
self._objective_history = [self.objective_function]
243+
self.objective_log = [
244+
{
245+
"step": "start",
246+
"objective": self.objective_function,
247+
"timestamp": time.time(),
248+
}
249+
]
238250

239251
# Set up tracking variables for update_components()
240252
self._prev_components = None
@@ -255,10 +267,11 @@ def fit(self, rho=0, eta=0, reset=True):
255267
obj_diff = (
256268
self.objective_function - regularization_term - sparsity_term
257269
)
258-
print(
259-
f"Start, Objective function: {self.objective_function:.5e}"
260-
f", Obj - reg/sparse: {obj_diff:.5e}"
261-
)
270+
if self.verbose:
271+
print(
272+
f"Start, Objective function: {self.objective_function:.5e}"
273+
f", Obj - reg/sparse: {obj_diff:.5e}"
274+
)
262275

263276
# Main optimization loop
264277
for outiter in range(self.max_iter):
@@ -279,22 +292,19 @@ def fit(self, rho=0, eta=0, reset=True):
279292
obj_diff = (
280293
self.objective_function - regularization_term - sparsity_term
281294
)
282-
print(
283-
f"Obj fun: {self.objective_function:.5e}, "
284-
f", Obj - reg/sparse: {obj_diff:.5e}"
285-
f"Iter: {self.outiter}"
286-
)
287-
295+
convergence_threshold = self.objective_function * self.tol
288296
# Convergence check: Stop if diffun is small
289297
# and at least min_iter iterations have passed
290-
print(
291-
"Checking if ",
292-
self.objective_difference,
293-
" < ",
294-
self.objective_function * self.tol,
295-
)
298+
if self.verbose:
299+
print(
300+
f"\n--- Iteration {self.outiter} ---"
301+
f"\nTotal Objective : {self.objective_function:.5e}"
302+
f"\nBase Obj (No Reg) : {obj_diff:.5e}"
303+
f"\nConvergence Check : Δ {self.objective_difference:.5e}"
304+
f" < {convergence_threshold:.5e} (Threshold)\n"
305+
)
296306
if (
297-
self.objective_difference < self.objective_function * self.tol
307+
self.objective_difference < convergence_threshold
298308
and outiter >= self.min_iter
299309
):
300310
self.converged_ = True
@@ -305,6 +315,8 @@ def fit(self, rho=0, eta=0, reset=True):
305315
return self
306316

307317
def normalize_results(self):
318+
if self.verbose:
319+
print("\nNormalizing results after convergence...")
308320
# Select our best results for normalization
309321
self.components_ = self.best_matrices[0]
310322
self.weights_ = self.best_matrices[1]
@@ -335,11 +347,18 @@ def normalize_results(self):
335347
self.update_components()
336348
self.residuals = self.get_residual_matrix()
337349
self.objective_function = self.get_objective_function()
338-
print(
339-
f"Objective function after normalize_components: "
340-
f"{self.objective_function:.5e}"
341-
)
350+
# print(
351+
# f"Objective function after normalize_components: "
352+
# f"{self.objective_function:.5e}"
353+
# )
342354
self._objective_history.append(self.objective_function)
355+
self.objective_log = [
356+
{
357+
"step": "c_norm",
358+
"objective": self.objective_function,
359+
"timestamp": time.time(),
360+
}
361+
]
343362
self.objective_difference = (
344363
self._objective_history[-2] - self._objective_history[-1]
345364
)
@@ -357,16 +376,25 @@ def normalize_results(self):
357376
break
358377

359378
def outer_loop(self):
379+
if self.verbose:
380+
print("Updating components and weights in outer loop...")
360381
for iter in range(4):
361382
self.iter = iter
362383
self._prev_grad_components = self._grad_components.copy()
363384
self.update_components()
364385
self.residuals = self.get_residual_matrix()
365386
self.objective_function = self.get_objective_function()
366-
print(
367-
f"Objective function after update_components: "
368-
f"{self.objective_function:.5e}"
369-
)
387+
self.objective_log = [
388+
{
389+
"step": "c",
390+
"objective": self.objective_function,
391+
"timestamp": time.time(),
392+
}
393+
]
394+
# print(
395+
# f"Objective function after update_components: "
396+
# f"{self.objective_function:.5e}"
397+
# )
370398
self._objective_history.append(self.objective_function)
371399
self.objective_difference = (
372400
self._objective_history[-2] - self._objective_history[-1]
@@ -389,11 +417,19 @@ def outer_loop(self):
389417
self.update_weights()
390418
self.residuals = self.get_residual_matrix()
391419
self.objective_function = self.get_objective_function()
392-
print(
393-
f"Objective function after update_weights: "
394-
f"{self.objective_function:.5e}"
395-
)
420+
# print(
421+
# f"Objective function after update_weights: "
422+
# f"{self.objective_function:.5e}"
423+
# )
396424
self._objective_history.append(self.objective_function)
425+
self.objective_log = [
426+
{
427+
"step": "w",
428+
"objective": self.objective_function,
429+
"timestamp": time.time(),
430+
}
431+
]
432+
397433
self.objective_difference = (
398434
self._objective_history[-2] - self._objective_history[-1]
399435
)
@@ -426,11 +462,18 @@ def outer_loop(self):
426462
self.update_stretch()
427463
self.residuals = self.get_residual_matrix()
428464
self.objective_function = self.get_objective_function()
429-
print(
430-
f"Objective function after update_stretch: "
431-
f"{self.objective_function:.5e}"
432-
)
465+
# print(
466+
# f"Objective function after update_stretch: "
467+
# f"{self.objective_function:.5e}"
468+
# )
433469
self._objective_history.append(self.objective_function)
470+
self.objective_log = [
471+
{
472+
"step": "s",
473+
"objective": self.objective_function,
474+
"timestamp": time.time(),
475+
}
476+
]
434477
self.objective_difference = (
435478
self._objective_history[-2] - self._objective_history[-1]
436479
)
@@ -712,7 +755,12 @@ def solve_quadratic_program(self, t, m):
712755

713756
# Solve using a QP solver
714757
prob = cp.Problem(objective, constraints)
715-
prob.solve(solver=cp.OSQP, verbose=False)
758+
prob.solve(
759+
solver=cp.OSQP,
760+
verbose=False,
761+
polish=False, # TODO keep? removes polish message
762+
# solver_verbose=False
763+
)
716764

717765
# Get the solution
718766
return np.maximum(
@@ -722,6 +770,7 @@ def solve_quadratic_program(self, t, m):
722770
def update_components(self):
723771
"""Updates `components` using gradient-based optimization with
724772
adaptive step size."""
773+
725774
# Compute stretched components using the interpolation function
726775
stretched_components, _, _ = (
727776
self.compute_stretched_components()
@@ -868,6 +917,9 @@ def update_stretch(self):
868917
"""Updates stretching matrix using constrained optimization
869918
(equivalent to fmincon in MATLAB)."""
870919

920+
if self.verbose:
921+
print("Updating stretch factors...")
922+
871923
# Flatten stretch for compatibility with the optimizer
872924
# (since SciPy expects 1D input)
873925
stretch_flat_initial = self.stretch_.flatten()

0 commit comments

Comments
 (0)