Skip to content

Commit dca8fca

Browse files
effieHANhmgaudecker
authored andcommitted
add test for variance decomposition
1 parent e8789ce commit dca8fca

1 file changed

Lines changed: 78 additions & 0 deletions

File tree

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import pandas as pd
2+
import pytest
3+
from jax import config
4+
from numpy.testing import assert_array_almost_equal as aaae
5+
6+
from skillmodels.variance_decomposition import (
7+
create_dataset_with_variance_decomposition,
8+
)
9+
10+
config.update("jax_enable_x64", True)
11+
12+
# ======================================================================================
13+
# Variance decomposition
14+
# ======================================================================================
15+
16+
17+
@pytest.fixture
18+
def setup_variance_decomposition():
19+
data1 = {
20+
"fac1": [0.1, 0.1, 0.1, 0.2],
21+
"fac2": [0.1] * 4,
22+
"fac3": [0.2, 0.2, 0.2, 0.4],
23+
"mixture": [0] * 4,
24+
"period": [0] * 4,
25+
"id": [0] * 4,
26+
}
27+
setup_filtered_states = pd.DataFrame(data1)
28+
29+
value_loadings = [1, 0, 0] + [0, 0.1, 0] + [0, 0, 2]
30+
value_meas_sds = [0.05, 1.1, 0.1]
31+
iterables1 = [[0], ["y1", "y2", "y3"], ["fac1", "fac2", "fac3"]]
32+
index1 = pd.MultiIndex.from_product(iterables1, names=["period", "name1", "name2"])
33+
setup_loadings = pd.DataFrame(value_loadings, index=index1, columns=["value"])
34+
iterables2 = [[0], ["y1", "y2", "y3"]]
35+
index2 = pd.MultiIndex.from_product(iterables2, names=["period", "name1"])
36+
37+
setup_meas = pd.DataFrame(value_meas_sds, index=index2, columns=["value"])
38+
setup_meas["name2"] = "-"
39+
setup_meas = setup_meas.reset_index()
40+
setup_meas = setup_meas.set_index(["period", "name1", "name2"])
41+
setup_params = pd.concat(
42+
[setup_loadings, setup_meas], keys=["loadings", "meas_sds"]
43+
)
44+
45+
args = {"filtered_states": setup_filtered_states, "params": setup_params}
46+
return args
47+
48+
49+
@pytest.fixture
50+
def expected_variance_decomposition():
51+
value3 = [
52+
[1, 0.0025, 0.05, 0.5, 0.5],
53+
[0.1, 0, 1.1, 1, 0],
54+
[2, 0.01, 0.1, 0.2, 0.8],
55+
]
56+
iterables3 = [(0, "y1", "fac1"), (0, "y2", "fac2"), (0, "y3", "fac3")]
57+
index3 = pd.MultiIndex.from_tuples(iterables3, names=("period", "name1", "name2"))
58+
expected_result = pd.DataFrame(
59+
value3,
60+
index=index3,
61+
columns=[
62+
"loadings",
63+
"variance of factor",
64+
"meas_sds",
65+
"fraction due to meas error",
66+
"fraction due to factor var",
67+
],
68+
)
69+
return expected_result
70+
71+
72+
def test_variance_decomposition(
73+
setup_variance_decomposition, expected_variance_decomposition
74+
):
75+
aaae(
76+
create_dataset_with_variance_decomposition(**setup_variance_decomposition),
77+
expected_variance_decomposition,
78+
)

0 commit comments

Comments
 (0)