We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 726a7e2 commit db1c456Copy full SHA for db1c456
2 files changed
diffpy/snmf/subroutines.py
@@ -449,7 +449,7 @@ def get_residual_matrix(component_matrix, weights_matrix, stretching_matrix, dat
449
return residual_matrx
450
451
452
-def reconstruct_data(components, input_data):
+def reconstruct_data(components):
453
"""Reconstructs the `input_data` matrix
454
455
Reconstructs the `input_data` matrix from calculated component signals, weights, and stretching factors.
@@ -458,13 +458,16 @@ def reconstruct_data(components, input_data):
458
----------
459
components: tuple of ComponentSignal objects
460
The tuple containing the component signals.
461
- input_data: 2d array
462
- The 2d array containing the user provided signals.
463
464
Returns
465
-------
466
2d array
467
The 2d array containing the reconstruction of input_data.
468
469
"""
470
- pass
+ signal_length = len(components[0].iq)
+ number_of_signal = len(components[0].weights)
+ data_reconstruction = np.zeros((signal_length, number_of_signal))
471
+ for signal in range(number_of_signal):
472
+ data_reconstruction[:, signal] = reconstruct_signal(components, signal)
473
+ return data_reconstruction
diffpy/snmf/tests/test_subroutines.py
@@ -108,13 +108,22 @@ def test_get_residual_matrix(tgrm):
108
109
110
trd = [
111
+ ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
112
+ ComponentSignal([0, .25, .5, .75, 1], 2, 2)]),
113
+ ([ComponentSignal([0, .25, .5, .75, 1], 2, 0)]),
114
115
+ ComponentSignal([0, .25, .5, .75, 1], 2, 2), ComponentSignal([0, .25, .5, .75, 1], 2, 3),
116
+ ComponentSignal([0, .25, .5, .75, 1], 2, 4)]),
117
+ #([]) # Exception expected
118
119
]
120
121
122
@pytest.mark.parametrize('trd', trd)
123
def test_reconstruct_data(trd):
- assert False
124
+ actual = reconstruct_data(trd)
125
+ assert actual.shape == (len(trd[0].iq),len(trd[0].weights))
126
+ print(actual)
127
128
129
tld = [(([[[1, -1, 1], [0, 0, 0], [2, 10, -3]], 1]), ([[4, 2, 4], [3, 3, 3], [5, 13, 0]])),
0 commit comments