|
1 | 1 | import pytest |
2 | 2 | import numpy as np |
3 | 3 | from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \ |
4 | | - update_weights_matrix, initialize_arrays, lift_data |
| 4 | + update_weights_matrix, initialize_arrays, lift_data, initialize_components |
5 | 5 |
|
6 | 6 | to = [ |
7 | 7 | ([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14), |
@@ -144,8 +144,20 @@ def test_reconstruct_data(trd): |
144 | 144 | (([[[1.5, 2], [10.5, 1], [0.5, 2]], 1]), ([[2, 2.5], [11, 1.5], [1, 2.5]])), |
145 | 145 | (([[[-10, -10.5], [-12.2, -12.2], [0, 0]], 1]), ([[2.2, 1.7], [0, 0], [12.2, 12.2]])), |
146 | 146 | ] |
| 147 | + |
| 148 | + |
147 | 149 | @pytest.mark.parametrize('tld', tld) |
148 | 150 | def test_lift_data(tld): |
149 | 151 | actual = lift_data(tld[0][0], tld[0][1]) |
150 | 152 | expected = tld[1] |
151 | 153 | np.testing.assert_allclose(actual, expected) |
| 154 | + |
| 155 | +tcc = [(2, 3,[0, .5, 1, 1.5]), # Regular usage |
| 156 | + #(0, 3,[0, .5, 1, 1.5]), # Zero components raise an exception. Not tested |
| 157 | + ] |
| 158 | +@pytest.mark.parametrize('tcc', tcc) |
| 159 | +def test_initialize_components(tcc): |
| 160 | + actual = initialize_components(tcc[0], tcc[1], tcc[2]) |
| 161 | + assert len(actual) == tcc[0] |
| 162 | + assert len(actual[0].weights) == tcc[1] |
| 163 | + assert (actual[0].grid == np.array(tcc[2])).all() |
0 commit comments