Skip to content

Commit 8c510c6

Browse files
committed
initial commit
1 parent d526f5c commit 8c510c6

2 files changed

Lines changed: 43 additions & 1 deletion

File tree

diffpy/snmf/subroutines.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,33 @@
11
import numpy as np
22
from diffpy.snmf.optimizers import get_weights
33
from diffpy.snmf.factorizers import lsqnonneg
4+
from diffpy.snmf.containers import ComponentSignal
45
import numdifftools
56

67

8+
def create_components(number_of_components, number_of_signals, grid_vector):
9+
"""Creates the ComponentSignal objects
10+
11+
Parameters
12+
----------
13+
number_of_components: int
14+
The number of component signals in the NMF decomposition
15+
number_of_signals: int
16+
grid_vector: 1d array
17+
The grid of the user provided signals.
18+
19+
Returns
20+
-------
21+
tuple of ComponentSignal objects
22+
The tuple containing `number_of_components` of initialized ComponentSignal objects.
23+
"""
24+
components = list()
25+
for c in range(number_of_components):
26+
c = ComponentSignal(grid_vector,number_of_signals,c)
27+
components.append(c)
28+
return tuple(components)
29+
30+
731
def lift_data(data_input, lift=1):
832
"""Lifts values of data_input
933

diffpy/snmf/tests/test_subroutines.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import numpy as np
33
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
4-
update_weights_matrix, initialize_arrays, lift_data
4+
update_weights_matrix, initialize_arrays, lift_data, create_components
55

66
to = [
77
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -144,8 +144,26 @@ def test_reconstruct_data(trd):
144144
(([[[1.5, 2], [10.5, 1], [0.5, 2]], 1]), ([[2, 2.5], [11, 1.5], [1, 2.5]])),
145145
(([[[-10, -10.5], [-12.2, -12.2], [0, 0]], 1]), ([[2.2, 1.7], [0, 0], [12.2, 12.2]])),
146146
]
147+
148+
147149
@pytest.mark.parametrize('tld', tld)
148150
def test_lift_data(tld):
149151
actual = lift_data(tld[0][0], tld[0][1])
150152
expected = tld[1]
151153
np.testing.assert_allclose(actual, expected)
154+
155+
tcc = [(2, [0, .5, 1, 1.5], 3, 3),
156+
(3,[0,10,20,30],10,15),
157+
(0,[0],11,30),
158+
(5,[1,1,1,1,1,1],10000,40000),
159+
(3,np.arange(stop=125,step=.05),20,2500),
160+
]
161+
@pytest.mark.parametrize('tcc', tcc)
162+
def test_create_components(tcc):
163+
actual = create_components(tcc[0], tcc[1], tcc[2], tcc[3])
164+
assert len(actual) == tcc[0]
165+
for c in actual:
166+
assert len(c.iq) == tcc[3]
167+
assert len(c.weights) == tcc[2]
168+
assert len(c.stretching_factors) == tcc[2]
169+
assert (c.grid == tcc[1]).all()

0 commit comments

Comments
 (0)