Skip to content

Commit df19205

Browse files
committed
fixed components initializer
Fixed tests, refactored name, added raised error.
1 parent 564213c commit df19205

2 files changed

Lines changed: 14 additions & 18 deletions

File tree

diffpy/snmf/subroutines.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import numdifftools
66

77

8-
def create_components(number_of_components, number_of_signals, grid_vector):
9-
"""Creates the ComponentSignal objects
8+
def initialize_components(number_of_components, number_of_signals, grid_vector):
9+
"""Initializes ComponentSignals for each of the components in the decomposition
1010
1111
Parameters
1212
----------
@@ -21,10 +21,12 @@ def create_components(number_of_components, number_of_signals, grid_vector):
2121
tuple of ComponentSignal objects
2222
The tuple containing `number_of_components` of initialized ComponentSignal objects.
2323
"""
24+
if number_of_components <= 0:
25+
raise ValueError(f"Number of components = {number_of_components}. Number_of_components must be >= 1.")
2426
components = list()
25-
for c in range(number_of_components):
26-
c = ComponentSignal(grid_vector,number_of_signals,c)
27-
components.append(c)
27+
for component in range(number_of_components):
28+
component = ComponentSignal(grid_vector,number_of_signals,component)
29+
components.append(component)
2830
return tuple(components)
2931

3032

diffpy/snmf/tests/test_subroutines.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import numpy as np
33
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
4-
update_weights_matrix, initialize_arrays, lift_data, create_components
4+
update_weights_matrix, initialize_arrays, lift_data, initialize_components
55

66
to = [
77
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -152,18 +152,12 @@ def test_lift_data(tld):
152152
expected = tld[1]
153153
np.testing.assert_allclose(actual, expected)
154154

155-
tcc = [(2, [0, .5, 1, 1.5], 3, 3),
156-
(3,[0,10,20,30],10,15),
157-
(0,[0],11,30),
158-
(5,[1,1,1,1,1,1],10000,40000),
159-
(3,np.arange(stop=125,step=.05),20,2500),
155+
tcc = [(2, 3,[0, .5, 1, 1.5]), # Regular usage
156+
#(0, 3,[0, .5, 1, 1.5]), # Zero components raise an exception. Not tested
160157
]
161158
@pytest.mark.parametrize('tcc', tcc)
162-
def test_create_components(tcc):
163-
actual = create_components(tcc[0], tcc[1], tcc[2], tcc[3])
159+
def test_initialize_components(tcc):
160+
actual = initialize_components(tcc[0], tcc[1], tcc[2])
164161
assert len(actual) == tcc[0]
165-
for c in actual:
166-
assert len(c.iq) == tcc[3]
167-
assert len(c.weights) == tcc[2]
168-
assert len(c.stretching_factors) == tcc[2]
169-
assert (c.grid == tcc[1]).all()
162+
assert len(actual[0].weights) == tcc[1]
163+
assert (actual[0].grid == np.array(tcc[2])).all()

0 commit comments

Comments
 (0)