Skip to content

Commit 679d3ae

Browse files
committed
Split new tests into new file
1 parent 8a58379 commit 679d3ae

2 files changed

Lines changed: 222 additions & 173 deletions

File tree

tests/test_new_readers.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
from io import StringIO
2+
3+
import duckdb
4+
import numpy as np
5+
import xarray as xr
6+
from pytest import approx, fixture, mark, raises
7+
8+
9+
@fixture
10+
def default_new_input(tmp_path):
11+
from muse.examples import copy_model
12+
13+
copy_model("default_new_input", tmp_path)
14+
return tmp_path / "model"
15+
16+
17+
@fixture
18+
def con():
19+
return duckdb.connect(":memory:")
20+
21+
22+
@fixture
23+
def populate_commodities(default_new_input, con):
24+
from muse.new_input.readers import read_commodities_csv
25+
26+
with open(default_new_input / "commodities.csv") as f:
27+
return read_commodities_csv(f, con)
28+
29+
30+
@fixture
31+
def populate_demand(default_new_input, con, populate_regions, populate_commodities):
32+
from muse.new_input.readers import read_demand_csv
33+
34+
with open(default_new_input / "demand.csv") as f:
35+
return read_demand_csv(f, con)
36+
37+
38+
@fixture
39+
def populate_regions(default_new_input, con):
40+
from muse.new_input.readers import read_regions_csv
41+
42+
with open(default_new_input / "regions.csv") as f:
43+
return read_regions_csv(f, con)
44+
45+
46+
def test_read_regions(populate_regions):
47+
assert populate_regions["id"] == np.array(["R1"])
48+
49+
50+
def test_read_new_global_commodities(populate_commodities):
51+
data = populate_commodities
52+
assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"]
53+
assert list(data["type"]) == ["energy"] * 5
54+
assert list(data["unit"]) == ["PJ"] * 4 + ["kt"]
55+
56+
57+
def test_calculate_global_commodities(populate_commodities):
58+
from muse.new_input.readers import calculate_global_commodities
59+
60+
data = calculate_global_commodities(populate_commodities)
61+
62+
assert isinstance(data, xr.Dataset)
63+
assert set(data.dims) == {"commodity"}
64+
for dt in data.dtypes.values():
65+
assert np.issubdtype(dt, np.dtype("str"))
66+
67+
assert list(data.coords["commodity"].values) == list(populate_commodities["id"])
68+
assert list(data.data_vars["type"].values) == list(populate_commodities["type"])
69+
assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"])
70+
71+
72+
def test_read_new_global_commodities_type_constraint(default_new_input, con):
73+
from muse.new_input.readers import read_commodities_csv
74+
75+
csv = StringIO("id,type,unit\nfoo,invalid,bar\n")
76+
with raises(duckdb.ConstraintException):
77+
read_commodities_csv(csv, con)
78+
79+
80+
def test_new_read_demand_csv(populate_demand):
81+
data = populate_demand
82+
assert np.all(data["year"] == np.array([2020, 2050]))
83+
assert np.all(data["commodity"] == np.array(["heat", "heat"]))
84+
assert np.all(data["region"] == np.array(["R1", "R1"]))
85+
assert np.all(data["demand"] == np.array([10, 30]))
86+
87+
88+
def test_new_read_demand_csv_commodity_constraint(
89+
default_new_input, con, populate_commodities, populate_regions
90+
):
91+
from muse.new_input.readers import read_demand_csv
92+
93+
csv = StringIO("year,commodity_id,region_id,demand\n2020,invalid,R1,0\n")
94+
with raises(duckdb.ConstraintException, match=".*foreign key.*"):
95+
read_demand_csv(csv, con)
96+
97+
98+
def test_new_read_demand_csv_region_constraint(
99+
default_new_input, con, populate_commodities, populate_regions
100+
):
101+
from muse.new_input.readers import read_demand_csv
102+
103+
csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n")
104+
with raises(duckdb.ConstraintException, match=".*foreign key.*"):
105+
read_demand_csv(csv, con)
106+
107+
108+
@mark.xfail
109+
def test_demand_dataset(default_new_input):
110+
import duckdb
111+
112+
from muse.new_input.readers import read_commodities, read_demand, read_regions
113+
114+
con = duckdb.connect(":memory:")
115+
116+
read_regions(default_new_input, con)
117+
read_commodities(default_new_input, con)
118+
data = read_demand(default_new_input, con)
119+
120+
assert isinstance(data, xr.DataArray)
121+
assert data.dtype == np.float64
122+
123+
assert set(data.dims) == {"year", "commodity", "region", "timeslice"}
124+
assert list(data.coords["region"].values) == ["R1"]
125+
assert list(data.coords["timeslice"].values) == list(range(1, 7))
126+
assert list(data.coords["year"].values) == [2020, 2050]
127+
assert set(data.coords["commodity"].values) == {
128+
"electricity",
129+
"gas",
130+
"heat",
131+
"wind",
132+
"CO2f",
133+
}
134+
135+
assert data.sel(year=2020, commodity="electricity", region="R1", timeslice=0) == 1
136+
137+
138+
@mark.xfail
139+
def test_new_read_initial_market(default_new_input):
140+
from muse.new_input.readers import read_inputs
141+
142+
all_data = read_inputs(default_new_input)
143+
data = all_data["initial_market"]
144+
145+
assert isinstance(data, xr.Dataset)
146+
assert set(data.dims) == {"region", "year", "commodity", "timeslice"}
147+
assert dict(data.dtypes) == dict(
148+
prices=np.float64,
149+
exports=np.float64,
150+
imports=np.float64,
151+
static_trade=np.float64,
152+
)
153+
assert list(data.coords["region"].values) == ["R1"]
154+
assert list(data.coords["year"].values) == list(range(2010, 2105, 5))
155+
assert list(data.coords["commodity"].values) == [
156+
"electricity",
157+
"gas",
158+
"heat",
159+
"CO2f",
160+
"wind",
161+
]
162+
month_values = ["all-year"] * 6
163+
day_values = ["all-week"] * 6
164+
hour_values = [
165+
"night",
166+
"morning",
167+
"afternoon",
168+
"early-peak",
169+
"late-peak",
170+
"evening",
171+
]
172+
173+
assert list(data.coords["timeslice"].values) == list(
174+
zip(month_values, day_values, hour_values)
175+
)
176+
assert list(data.coords["month"]) == month_values
177+
assert list(data.coords["day"]) == day_values
178+
assert list(data.coords["hour"]) == hour_values
179+
180+
assert all(var.coords.equals(data.coords) for var in data.data_vars.values())
181+
182+
prices = data.data_vars["prices"]
183+
assert approx(
184+
prices.sel(
185+
year=2010,
186+
region="R1",
187+
commodity="electricity",
188+
timeslice=("all-year", "all-week", "night"),
189+
)
190+
- 14.81481,
191+
abs=1e-4,
192+
)
193+
194+
exports = data.data_vars["exports"]
195+
assert (
196+
exports.sel(
197+
year=2010,
198+
region="R1",
199+
commodity="electricity",
200+
timeslice=("all-year", "all-week", "night"),
201+
)
202+
) == 0
203+
204+
imports = data.data_vars["imports"]
205+
assert (
206+
imports.sel(
207+
year=2010,
208+
region="R1",
209+
commodity="electricity",
210+
timeslice=("all-year", "all-week", "night"),
211+
)
212+
) == 0
213+
214+
static_trade = data.data_vars["static_trade"]
215+
assert (
216+
static_trade.sel(
217+
year=2010,
218+
region="R1",
219+
commodity="electricity",
220+
timeslice=("all-year", "all-week", "night"),
221+
)
222+
) == 0

0 commit comments

Comments
 (0)