1+ import time
2+
13import cvxpy as cp
24import numpy as np
35from scipy .optimize import minimize
@@ -82,6 +84,7 @@ def __init__(
8284 eta = 0 ,
8385 random_state = None ,
8486 show_plots = False ,
87+ verbose = False ,
8588 ):
8689 """Initialize an instance of sNMF with estimator
8790 hyperparameters.
@@ -126,6 +129,7 @@ def __init__(
126129 self .eta = eta
127130 self .random_state = random_state
128131 self .show_plots = show_plots
132+ self .verbose = verbose
129133
130134 self ._rng = np .random .default_rng (self .random_state )
131135 self ._plotter = SNMFPlotter () if self .show_plots else None
@@ -214,6 +218,7 @@ def _initialize_factors(
214218 [1 , - 2 , 1 ],
215219 offsets = [0 , 1 , 2 ],
216220 shape = (self .n_signals_ - 2 , self .n_signals_ ),
221+ dtype = float ,
217222 )
218223
219224 def fit (
@@ -303,7 +308,14 @@ def fit(
303308 self .stretch_ .copy (),
304309 ]
305310 self .objective_difference_ = None
306- self ._objective_history = [self .objective_function_ ]
311+ self .objective_log = [
312+ {
313+ "step" : "start" ,
314+ "iteration" : 0 ,
315+ "objective" : self .objective_function_ ,
316+ "timestamp" : time .time (),
317+ }
318+ ]
307319
308320 # Set up tracking variables for _update_components()
309321 self ._prev_components = None
@@ -322,13 +334,15 @@ def fit(
322334 sparsity_term = self .eta * np .sum (
323335 np .sqrt (self .components_ )
324336 ) # Square root penalty
325- objective_without_penalty = (
337+ base_obj = (
326338 self .objective_function_ - regularization_term - sparsity_term
327339 )
328- print (
329- f"Start, Objective function: { self .objective_function_ :.5e} "
330- f", Obj - reg/sparse: { objective_without_penalty :.5e} "
331- )
340+ if self .verbose :
341+ print (
342+ f"\n --- Start ---"
343+ f"\n Total Objective : { self .objective_function_ :.5e} "
344+ f"\n Base Obj (No Reg) : { base_obj :.5e} "
345+ )
332346
333347 # Main optimization loop
334348 for outiter in range (self .max_iter ):
@@ -347,35 +361,23 @@ def fit(
347361 sparsity_term = self .eta * np .sum (
348362 np .sqrt (self .components_ )
349363 ) # Square root penalty
350- objective_without_penalty = (
364+ base_obj = (
351365 self .objective_function_ - regularization_term - sparsity_term
352366 )
353- print (
354- f"Obj fun: { self .objective_function_ :.5e} , "
355- f"Obj - reg/sparse: { objective_without_penalty :.5e} , "
356- f"Iter: { self ._outer_iter } "
357- )
358- obj_diff = (
359- self .objective_function - regularization_term - sparsity_term
360- )
361- print (
362- f"Obj fun: { self .objective_function :.5e} , "
363- f", Obj - reg/sparse: { obj_diff :.5e} "
364- f"Iter: { self .outiter } "
365- )
366-
367+ convergence_threshold = self .objective_function_ * self .tol
367368 # Convergence check: Stop if diffun is small
368369 # and at least min_iter iterations have passed
369- print (
370- "Checking if " ,
371- self .objective_difference_ ,
372- " < " ,
373- self .objective_function_ * self .tol ,
374- )
370+ if self .verbose :
371+ print (
372+ f"\n --- Iteration { self ._outer_iter } ---"
373+ f"\n Total Objective : { self .objective_function_ :.5e} "
374+ f"\n Base Obj (No Reg) : { base_obj :.5e} "
375+ "\n Convergence Check : Δ "
376+ f"({ self .objective_difference_ :.2e} )"
377+ f" < Threshold ({ convergence_threshold :.2e} )\n "
378+ )
375379 if (
376- self .objective_difference_ is not None
377- and self .objective_difference_
378- < self .objective_function_ * self .tol
380+ self .objective_difference_ < convergence_threshold
379381 and outiter >= self .min_iter
380382 ):
381383 self .converged_ = True
@@ -387,6 +389,8 @@ def fit(
387389 return self
388390
389391 def _normalize_results (self ):
392+ if self .verbose :
393+ print ("\n Normalizing results after convergence..." )
390394 # Select our best results for normalization
391395 self .components_ = self .best_matrices_ [0 ]
392396 self .weights_ = self .best_matrices_ [1 ]
@@ -420,13 +424,17 @@ def _normalize_results(self):
420424 self ._update_components ()
421425 self .residuals_ = self ._get_residual_matrix ()
422426 self .objective_function_ = self ._get_objective_function ()
423- print (
424- f"Objective function after normalize_components: "
425- f"{ self .objective_function_ :.5e} "
427+ self .objective_log .append (
428+ {
429+ "step" : "c_norm" ,
430+ "iteration" : outiter ,
431+ "objective" : self .objective_function_ ,
432+ "timestamp" : time .time (),
433+ }
426434 )
427- self ._objective_history .append (self .objective_function_ )
428435 self .objective_difference_ = (
429- self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
436+ self .objective_log [- 2 ]["objective" ]
437+ - self .objective_log [- 1 ]["objective" ]
430438 )
431439 if self ._plotter is not None :
432440 self ._plotter .update (
@@ -435,27 +443,41 @@ def _normalize_results(self):
435443 stretch = self .stretch_ ,
436444 update_tag = "normalize components" ,
437445 )
446+ convergence_threshold = self .objective_function_ * self .tol
447+ if self .verbose :
448+ print (
449+ f"\n --- Iteration { outiter } after normalization---"
450+ f"\n Total Objective : { self .objective_function_ :.5e} "
451+ "\n Convergence Check : Δ "
452+ f"({ self .objective_difference_ :.2e} )"
453+ f" < Threshold ({ convergence_threshold :.2e} )\n "
454+ )
438455 if (
439- self .objective_difference_
440- < self .objective_function_ * self .tol
456+ self .objective_difference_ < convergence_threshold
441457 and outiter >= 7
442458 ):
443459 break
444460
445461 def _outer_loop (self ):
462+ if self .verbose :
463+ print ("Updating components and weights..." )
446464 for inner_iter in range (4 ):
447465 self ._inner_iter = inner_iter
448466 self ._prev_grad_components = self ._grad_components .copy ()
449467 self ._update_components ()
450468 self .residuals_ = self ._get_residual_matrix ()
451469 self .objective_function_ = self ._get_objective_function ()
452- print (
453- f"Objective function after _update_components: "
454- f"{ self .objective_function_ :.5e} "
470+ self .objective_log .append (
471+ {
472+ "step" : "c" ,
473+ "iteration" : self ._outer_iter ,
474+ "objective" : self .objective_function_ ,
475+ "timestamp" : time .time (),
476+ }
455477 )
456- self ._objective_history .append (self .objective_function_ )
457478 self .objective_difference_ = (
458- self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
479+ self .objective_log [- 2 ]["objective" ]
480+ - self .objective_log [- 1 ]["objective" ]
459481 )
460482 if self .objective_function_ < self .best_objective_ :
461483 self .best_objective_ = self .objective_function_
@@ -475,13 +497,18 @@ def _outer_loop(self):
475497 self ._update_weights ()
476498 self .residuals_ = self ._get_residual_matrix ()
477499 self .objective_function_ = self ._get_objective_function ()
478- print (
479- f"Objective function after _update_weights: "
480- f"{ self .objective_function_ :.5e} "
500+ self .objective_log .append (
501+ {
502+ "step" : "w" ,
503+ "iteration" : self ._outer_iter ,
504+ "objective" : self .objective_function_ ,
505+ "timestamp" : time .time (),
506+ }
481507 )
482- self . _objective_history . append ( self . objective_function_ )
508+
483509 self .objective_difference_ = (
484- self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
510+ self .objective_log [- 2 ]["objective" ]
511+ - self .objective_log [- 1 ]["objective" ]
485512 )
486513 if self .objective_function_ < self .best_objective_ :
487514 self .best_objective_ = self .objective_function_
@@ -499,10 +526,11 @@ def _outer_loop(self):
499526 )
500527
501528 self .objective_difference_ = (
502- self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
529+ self .objective_log [- 2 ]["objective" ]
530+ - self .objective_log [- 1 ]["objective" ]
503531 )
504532 if (
505- self ._objective_history [- 3 ] - self .objective_function_
533+ self .objective_log [- 3 ][ "objective" ] - self .objective_function_
506534 < self .objective_difference_ * 1e-3
507535 ):
508536 break
@@ -512,13 +540,17 @@ def _outer_loop(self):
512540 self ._update_stretch ()
513541 self .residuals_ = self ._get_residual_matrix ()
514542 self .objective_function_ = self ._get_objective_function ()
515- print (
516- f"Objective function after _update_stretch: "
517- f"{ self .objective_function_ :.5e} "
543+ self .objective_log .append (
544+ {
545+ "step" : "s" ,
546+ "iteration" : self ._outer_iter ,
547+ "objective" : self .objective_function_ ,
548+ "timestamp" : time .time (),
549+ }
518550 )
519- self ._objective_history .append (self .objective_function_ )
520551 self .objective_difference_ = (
521- self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
552+ self .objective_log [- 2 ]["objective" ]
553+ - self .objective_log [- 1 ]["objective" ]
522554 )
523555 if self .objective_function_ < self .best_objective_ :
524556 self .best_objective_ = self .objective_function_
@@ -804,7 +836,12 @@ def _solve_quadratic_program(self, t, m):
804836
805837 # Solve using a QP solver
806838 prob = cp .Problem (objective , constraints )
807- prob .solve (solver = cp .OSQP , verbose = False )
839+ prob .solve (
840+ solver = cp .OSQP ,
841+ verbose = False ,
842+ polish = False , # TODO keep? removes polish message
843+ # solver_verbose=False
844+ )
808845
809846 # Get the solution
810847 return np .maximum (
@@ -814,6 +851,7 @@ def _solve_quadratic_program(self, t, m):
814851 def _update_components (self ):
815852 """Updates `components` using gradient-based optimization with
816853 adaptive step size."""
854+
817855 # Compute stretched components using the interpolation function
818856 stretched_components , _ , _ = (
819857 self ._compute_stretched_components ()
@@ -878,8 +916,8 @@ def _update_components(self):
878916 )
879917 self .components_ = mask * self .components_
880918
881- objective_improvement = self ._objective_history [
882- - 1
919+ objective_improvement = self .objective_log [ - 1 ] [
920+ "objective"
883921 ] - self ._get_objective_function (
884922 residuals = self ._get_residual_matrix ()
885923 )
@@ -962,6 +1000,9 @@ def _update_stretch(self):
9621000 """Updates stretching matrix using constrained optimization
9631001 (equivalent to fmincon in MATLAB)."""
9641002
1003+ if self .verbose :
1004+ print ("Updating stretch factors..." )
1005+
9651006 # Flatten stretch for compatibility with the optimizer
9661007 # (since SciPy expects 1D input)
9671008 stretch_flat_initial = self .stretch_ .flatten ()
0 commit comments