|
| 1 | +import pandas as pd |
| 2 | +import pytest |
| 3 | +from jax import config |
| 4 | +from numpy.testing import assert_array_almost_equal as aaae |
| 5 | + |
| 6 | +from skillmodels.variance_decomposition import ( |
| 7 | + create_dataset_with_variance_decomposition, |
| 8 | +) |
| 9 | + |
| 10 | +config.update("jax_enable_x64", True) |
| 11 | + |
| 12 | +# ====================================================================================== |
| 13 | +# Variance decomposition |
| 14 | +# ====================================================================================== |
| 15 | + |
| 16 | + |
| 17 | +@pytest.fixture |
| 18 | +def setup_variance_decomposition(): |
| 19 | + data1 = { |
| 20 | + "fac1": [0.1, 0.1, 0.1, 0.2], |
| 21 | + "fac2": [0.1] * 4, |
| 22 | + "fac3": [0.2, 0.2, 0.2, 0.4], |
| 23 | + "mixture": [0] * 4, |
| 24 | + "period": [0] * 4, |
| 25 | + "id": [0] * 4, |
| 26 | + } |
| 27 | + setup_filtered_states = pd.DataFrame(data1) |
| 28 | + |
| 29 | + value_loadings = [1, 0, 0] + [0, 0.1, 0] + [0, 0, 2] |
| 30 | + value_meas_sds = [0.05, 1.1, 0.1] |
| 31 | + iterables1 = [[0], ["y1", "y2", "y3"], ["fac1", "fac2", "fac3"]] |
| 32 | + index1 = pd.MultiIndex.from_product(iterables1, names=["period", "name1", "name2"]) |
| 33 | + setup_loadings = pd.DataFrame(value_loadings, index=index1, columns=["value"]) |
| 34 | + iterables2 = [[0], ["y1", "y2", "y3"]] |
| 35 | + index2 = pd.MultiIndex.from_product(iterables2, names=["period", "name1"]) |
| 36 | + |
| 37 | + setup_meas = pd.DataFrame(value_meas_sds, index=index2, columns=["value"]) |
| 38 | + setup_meas["name2"] = "-" |
| 39 | + setup_meas = setup_meas.reset_index() |
| 40 | + setup_meas = setup_meas.set_index(["period", "name1", "name2"]) |
| 41 | + setup_params = pd.concat( |
| 42 | + [setup_loadings, setup_meas], keys=["loadings", "meas_sds"] |
| 43 | + ) |
| 44 | + |
| 45 | + args = {"filtered_states": setup_filtered_states, "params": setup_params} |
| 46 | + return args |
| 47 | + |
| 48 | + |
| 49 | +@pytest.fixture |
| 50 | +def expected_variance_decomposition(): |
| 51 | + value3 = [ |
| 52 | + [1, 0.0025, 0.05, 0.5, 0.5], |
| 53 | + [0.1, 0, 1.1, 1, 0], |
| 54 | + [2, 0.01, 0.1, 0.2, 0.8], |
| 55 | + ] |
| 56 | + iterables3 = [(0, "y1", "fac1"), (0, "y2", "fac2"), (0, "y3", "fac3")] |
| 57 | + index3 = pd.MultiIndex.from_tuples(iterables3, names=("period", "name1", "name2")) |
| 58 | + expected_result = pd.DataFrame( |
| 59 | + value3, |
| 60 | + index=index3, |
| 61 | + columns=[ |
| 62 | + "loadings", |
| 63 | + "variance of factor", |
| 64 | + "meas_sds", |
| 65 | + "fraction due to meas error", |
| 66 | + "fraction due to factor var", |
| 67 | + ], |
| 68 | + ) |
| 69 | + return expected_result |
| 70 | + |
| 71 | + |
| 72 | +def test_variance_decomposition( |
| 73 | + setup_variance_decomposition, expected_variance_decomposition |
| 74 | +): |
| 75 | + aaae( |
| 76 | + create_dataset_with_variance_decomposition(**setup_variance_decomposition), |
| 77 | + expected_variance_decomposition, |
| 78 | + ) |
0 commit comments