Skip to content

Commit 8ea3998

Browse files
committed
Update snmf_class.py
1 parent 128cb5b commit 8ea3998

1 file changed

Lines changed: 15 additions & 6 deletions

File tree

src/diffpy/stretched_nmf/snmf_class.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,14 @@ def fit(self, rho=0, eta=0, reset=True):
264264
sparsity_term = self.eta * np.sum(
265265
np.sqrt(self.components_)
266266
) # Square root penalty
267-
obj_diff = (
267+
base_obj = (
268268
self.objective_function - regularization_term - sparsity_term
269269
)
270270
if self.verbose:
271271
print(
272272
f"\n--- Start ---"
273273
f"\nTotal Objective : {self.objective_function:.5e}"
274-
f"\nBase Obj (No Reg) : {obj_diff:.5e}"
274+
f"\nBase Obj (No Reg) : {base_obj:.5e}"
275275
)
276276

277277
# Main optimization loop
@@ -290,7 +290,7 @@ def fit(self, rho=0, eta=0, reset=True):
290290
sparsity_term = self.eta * np.sum(
291291
np.sqrt(self.components_)
292292
) # Square root penalty
293-
obj_diff = (
293+
base_obj = (
294294
self.objective_function - regularization_term - sparsity_term
295295
)
296296
convergence_threshold = self.objective_function * self.tol
@@ -300,7 +300,7 @@ def fit(self, rho=0, eta=0, reset=True):
300300
print(
301301
f"\n--- Iteration {self.outiter} ---"
302302
f"\nTotal Objective : {self.objective_function:.5e}"
303-
f"\nBase Obj (No Reg) : {obj_diff:.5e}"
303+
f"\nBase Obj (No Reg) : {base_obj:.5e}"
304304
"\nConvergence Check : Δ "
305305
f"({self.objective_difference:.2e})"
306306
f" < Threshold ({convergence_threshold:.2e})\n"
@@ -367,15 +367,24 @@ def normalize_results(self):
367367
stretch=self.stretch_,
368368
update_tag="normalize components",
369369
)
370+
convergence_threshold = self.objective_function * self.tol
371+
if self.verbose:
372+
print(
373+
f"\n--- Iteration {outiter} after normalization---"
374+
f"\nTotal Objective : {self.objective_function:.5e}"
375+
"\nConvergence Check : Δ "
376+
f"({self.objective_difference:.2e})"
377+
f" < Threshold ({convergence_threshold:.2e})\n"
378+
)
370379
if (
371-
self.objective_difference < self.objective_function * self.tol
380+
self.objective_difference < convergence_threshold
372381
and outiter >= 7
373382
):
374383
break
375384

376385
def outer_loop(self):
377386
if self.verbose:
378-
print("Updating components and weights in outer loop...")
387+
print("Updating components and weights...")
379388
for iter in range(4):
380389
self.iter = iter
381390
self._prev_grad_components = self._grad_components.copy()

0 commit comments

Comments
 (0)