Skip to content

Commit f03ac68

Browse files
authored
Merge pull request #38 from aajayi-21/construct_stretching
function construct_stretching_matrix
2 parents 712a6fd + 8517d76 commit f03ac68

2 files changed

Lines changed: 57 additions & 2 deletions

File tree

diffpy/snmf/subroutines.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def initialize_components(number_of_components, number_of_signals, grid_vector):
2525
raise ValueError(f"Number of components = {number_of_components}. Number_of_components must be >= 1.")
2626
components = list()
2727
for component in range(number_of_components):
28-
component = ComponentSignal(grid_vector,number_of_signals,component)
28+
component = ComponentSignal(grid_vector, number_of_signals, component)
2929
components.append(component)
3030
return tuple(components)
3131

@@ -54,6 +54,36 @@ def lift_data(data_input, lift=1):
5454
return data_input + np.abs(np.min(data_input) * lift)
5555

5656

57+
def construct_stretching_matrix(components, number_of_components, number_of_signals):
58+
"""Constructs the stretching factor matrix
59+
60+
Parameters
61+
----------
62+
components: tuple of ComponentSignal objects
63+
The tuple containing the component signals in ComponentSignal objects.
64+
number_of_signals: int
65+
The number of signals in the data provided by the user.
66+
67+
Returns
68+
-------
69+
2d array
70+
The matrix containing the stretching factors for the component signals for each of the signals in the raw data.
71+
Has dimensions `component_signal` x `number_of_signals`
72+
73+
"""
74+
if (len(components)) == 0:
75+
raise ValueError(f"Number of components = {number_of_components}. Number_of_components must be >= 1.")
76+
number_of_components = len(components)
77+
78+
if number_of_signals <= 0:
79+
raise ValueError(f"Number of signals = {number_of_signals}. Number_of_signals must be >= 1.")
80+
81+
stretching_factor_matrix = np.zeros((number_of_components, number_of_signals))
82+
for i, component in enumerate(components):
83+
stretching_factor_matrix[i] = component.stretching_factors
84+
return stretching_factor_matrix
85+
86+
5787
def initialize_arrays(number_of_components, number_of_moments, signal_length):
5888
"""Generates the initial guesses for the weight, stretching, and component matrices
5989

diffpy/snmf/tests/test_subroutines.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
22
import numpy as np
3+
from diffpy.snmf.containers import ComponentSignal
34
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
4-
update_weights_matrix, initialize_arrays, lift_data, initialize_components
5+
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix
56

67
to = [
78
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
@@ -161,3 +162,27 @@ def test_initialize_components(tcc):
161162
assert len(actual) == tcc[0]
162163
assert len(actual[0].weights) == tcc[1]
163164
assert (actual[0].grid == np.array(tcc[2])).all()
165+
166+
tcso =[([ComponentSignal([0,.5,1,1.5],20,0)],1,20),
167+
([ComponentSignal([0,.5,1,1.5],20,0)],4,20),
168+
# ([ComponentSignal([0,.5,1,1.5],20,0)],0,20), # Raises an exception
169+
# ([ComponentSignal([0,.5,1,1.5],20,0)],-2,20), # Raises an exception
170+
# ([ComponentSignal([0,.5,1,1.5],20,0)],1,0), # Raises an Exception
171+
# ([ComponentSignal([0,.5,1,1.5],20,0)],1,-3), # Raises an exception
172+
([ComponentSignal([0,.5,1,1.5],20,0),ComponentSignal([0,.5,1,1.5],20,1)],2,20),
173+
([ComponentSignal([0,.5,1,1.5],20,0),ComponentSignal([0,.5,1,21.5],20,1)],2,20),
174+
([ComponentSignal([0,1,1.5],20,0),ComponentSignal([0,.5,1,21.5],20,1)],2,20),
175+
# ([ComponentSignal([0,.5,1,1.5],20,0),ComponentSignal([0,.5,1,1.5],20,1)],1,-3), # Negative signal length. Raises an exception
176+
#([],1,20), # Empty components. Raises an Exception
177+
#([],-1,20), # Empty components with negative number of components. Raises an exception
178+
#([],0,20), # Empty components with zero number of components. Raises an exception
179+
#([],1,0), # Empty components with zero signal length. Raises an exception.
180+
#([],-1,-2), # Empty components with negative number of components and signal length Raises an exception.
181+
182+
]
183+
@pytest.mark.parametrize('tcso',tcso)
184+
def test_construct_stretching_matrix(tcso):
185+
actual = construct_stretching_matrix(tcso[0],tcso[1],tcso[2])
186+
for component in tcso[0]:
187+
np.testing.assert_allclose(actual[component.id,:], component.stretching_factors)
188+
#assert actual[component.id, :] == component.stretching_factors

0 commit comments

Comments
 (0)