Skip to content

Commit 90c81a8

Browse files
committed
Update snmf_class.py
Added "iteration" key to objective_log Removed redundant _objective_history attribute
1 parent f69c1a1 commit 90c81a8

1 file changed

Lines changed: 23 additions & 16 deletions

File tree

src/diffpy/stretched_nmf/snmf_class.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,10 @@ def fit(self, rho=0, eta=0, reset=True):
239239
self.stretch_.copy(),
240240
]
241241
self.objective_difference = None
242-
self._objective_history = [self.objective_function]
243242
self.objective_log = [
244243
{
245244
"step": "start",
245+
"iteration": 0,
246246
"objective": self.objective_function,
247247
"timestamp": time.time(),
248248
}
@@ -340,7 +340,6 @@ def normalize_results(self):
340340
self.residuals = self.get_residual_matrix()
341341
self.objective_function = self.get_objective_function()
342342
self.objective_difference = None
343-
self._objective_history = [self.objective_function]
344343
self.outiter = 0
345344
self.iter = 0
346345
for outiter in range(self.max_iter):
@@ -349,16 +348,17 @@ def normalize_results(self):
349348
self.update_components()
350349
self.residuals = self.get_residual_matrix()
351350
self.objective_function = self.get_objective_function()
352-
self._objective_history.append(self.objective_function)
353351
self.objective_log.append(
354352
{
355353
"step": "c_norm",
354+
"iteration": outiter,
356355
"objective": self.objective_function,
357356
"timestamp": time.time(),
358357
}
359358
)
360359
self.objective_difference = (
361-
self._objective_history[-2] - self._objective_history[-1]
360+
self.objective_log[-2]["objective"]
361+
- self.objective_log[-1]["objective"]
362362
)
363363
if self.plotter is not None:
364364
self.plotter.update(
@@ -385,13 +385,14 @@ def outer_loop(self):
385385
self.objective_log.append(
386386
{
387387
"step": "c",
388+
"iteration": self.outiter,
388389
"objective": self.objective_function,
389390
"timestamp": time.time(),
390391
}
391392
)
392-
self._objective_history.append(self.objective_function)
393393
self.objective_difference = (
394-
self._objective_history[-2] - self._objective_history[-1]
394+
self.objective_log[-2]["objective"]
395+
- self.objective_log[-1]["objective"]
395396
)
396397
if self.objective_function < self.best_objective:
397398
self.best_objective = self.objective_function
@@ -411,17 +412,18 @@ def outer_loop(self):
411412
self.update_weights()
412413
self.residuals = self.get_residual_matrix()
413414
self.objective_function = self.get_objective_function()
414-
self._objective_history.append(self.objective_function)
415415
self.objective_log.append(
416416
{
417417
"step": "w",
418+
"iteration": self.outiter,
418419
"objective": self.objective_function,
419420
"timestamp": time.time(),
420421
}
421422
)
422423

423424
self.objective_difference = (
424-
self._objective_history[-2] - self._objective_history[-1]
425+
self.objective_log[-2]["objective"]
426+
- self.objective_log[-1]["objective"]
425427
)
426428
if self.objective_function < self.best_objective:
427429
self.best_objective = self.objective_function
@@ -439,10 +441,11 @@ def outer_loop(self):
439441
)
440442

441443
self.objective_difference = (
442-
self._objective_history[-2] - self._objective_history[-1]
444+
self.objective_log[-2]["objective"]
445+
- self.objective_log[-1]["objective"]
443446
)
444447
if (
445-
self._objective_history[-3] - self.objective_function
448+
self.objective_log[-3]["objective"] - self.objective_function
446449
< self.objective_difference * 1e-3
447450
):
448451
break
@@ -452,16 +455,17 @@ def outer_loop(self):
452455
self.update_stretch()
453456
self.residuals = self.get_residual_matrix()
454457
self.objective_function = self.get_objective_function()
455-
self._objective_history.append(self.objective_function)
456458
self.objective_log.append(
457459
{
458460
"step": "s",
461+
"iteration": self.outiter,
459462
"objective": self.objective_function,
460463
"timestamp": time.time(),
461464
}
462465
)
463466
self.objective_difference = (
464-
self._objective_history[-2] - self._objective_history[-1]
467+
self.objective_log[-2]["objective"]
468+
- self.objective_log[-1]["objective"]
465469
)
466470
if self.objective_function < self.best_objective:
467471
self.best_objective = self.objective_function
@@ -820,10 +824,13 @@ def update_components(self):
820824
)
821825
self.components_ = mask * self.components_
822826

823-
objective_improvement = self._objective_history[
824-
-1
825-
] - self.get_objective_function(
826-
residuals=self.get_residual_matrix()
827+
# self.objective_function is the baseline value cached right
828+
# before we called update_components()
829+
objective_improvement = (
830+
self.objective_function
831+
- self.get_objective_function(
832+
residuals=self.get_residual_matrix()
833+
)
827834
)
828835

829836
# Check if objective function improves

0 commit comments

Comments
 (0)