Skip to content

Commit 6f6bda8

Browse files
authored
Merge pull request #34 from aajayi-21/create_components
function create_components
2 parents 2025d70 + df19205 commit 6f6bda8

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

diffpy/snmf/subroutines.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,35 @@
11
import numpy as np
22
from diffpy.snmf.optimizers import get_weights
33
from diffpy.snmf.factorizers import lsqnonneg
4+
from diffpy.snmf.containers import ComponentSignal
45
import numdifftools
56

67

8+
def initialize_components(number_of_components, number_of_signals, grid_vector):
9+
"""Initializes ComponentSignals for each of the components in the decomposition
10+
11+
Parameters
12+
----------
13+
number_of_components: int
14+
The number of component signals in the NMF decomposition
15+
number_of_signals: int
16+
grid_vector: 1d array
17+
The grid of the user provided signals.
18+
19+
Returns
20+
-------
21+
tuple of ComponentSignal objects
22+
The tuple containing `number_of_components` of initialized ComponentSignal objects.
23+
"""
24+
if number_of_components <= 0:
25+
raise ValueError(f"Number of components = {number_of_components}. Number_of_components must be >= 1.")
26+
components = list()
27+
for component in range(number_of_components):
28+
component = ComponentSignal(grid_vector,number_of_signals,component)
29+
components.append(component)
30+
return tuple(components)
31+
32+
733
def lift_data(data_input, lift=1):
834
"""Lifts values of data_input
935

diffpy/snmf/tests/test_subroutines.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import numpy as np
33
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
4-
update_weights_matrix, initialize_arrays, lift_data
4+
update_weights_matrix, initialize_arrays, lift_data, initialize_components
55

66
to = [
77
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -144,8 +144,20 @@ def test_reconstruct_data(trd):
144144
(([[[1.5, 2], [10.5, 1], [0.5, 2]], 1]), ([[2, 2.5], [11, 1.5], [1, 2.5]])),
145145
(([[[-10, -10.5], [-12.2, -12.2], [0, 0]], 1]), ([[2.2, 1.7], [0, 0], [12.2, 12.2]])),
146146
]
147+
148+
147149
@pytest.mark.parametrize('tld', tld)
148150
def test_lift_data(tld):
149151
actual = lift_data(tld[0][0], tld[0][1])
150152
expected = tld[1]
151153
np.testing.assert_allclose(actual, expected)
154+
155+
tcc = [(2, 3,[0, .5, 1, 1.5]), # Regular usage
156+
#(0, 3,[0, .5, 1, 1.5]), # Zero components raise an exception. Not tested
157+
]
158+
@pytest.mark.parametrize('tcc', tcc)
159+
def test_initialize_components(tcc):
160+
actual = initialize_components(tcc[0], tcc[1], tcc[2])
161+
assert len(actual) == tcc[0]
162+
assert len(actual[0].weights) == tcc[1]
163+
assert (actual[0].grid == np.array(tcc[2])).all()

0 commit comments

Comments
 (0)