Skip to content

Commit 9f8da0d

Browse files
cc-atsmbland
authored and committed
First pass at duckdb data interface
1 parent 01760a2 commit 9f8da0d

3 files changed

Lines changed: 225 additions & 2 deletions

File tree

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
id,description,type,unit
1+
name,description,type,unit
22
electricity,Electricity,energy,PJ
33
gas,Gas,energy,PJ
44
heat,Heat,energy,PJ
55
wind,Wind,energy,PJ
6-
C02f,Carbon dioxide,energy,kt
6+
CO2f,Carbon dioxide,energy,kt

src/muse/new_input/readers.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import duckdb
2+
import numpy as np
3+
import xarray as xr
4+
5+
6+
def read_inputs(data_dir):
    """Read all model input CSV files from *data_dir*.

    Parameters
    ----------
    data_dir : pathlib.Path
        Directory containing ``regions.csv``, ``commodities.csv`` and
        ``demand.csv``.

    Returns
    -------
    dict
        Mapping of dataset name to processed data; currently only
        ``"global_commodities"`` (an ``xarray.Dataset``) is populated.
    """
    data = {}
    con = duckdb.connect(":memory:")
    try:
        # Regions and demand are loaded for their table side effects (the
        # demand table declares foreign keys into regions/commodities);
        # their return values are not yet used downstream.
        with open(data_dir / "regions.csv") as f:
            regions = read_regions_csv(f, con)  # noqa: F841

        with open(data_dir / "commodities.csv") as f:
            commodities = read_commodities_csv(f, con)

        with open(data_dir / "demand.csv") as f:
            demand = read_demand_csv(f, con)  # noqa: F841

        data["global_commodities"] = calculate_global_commodities(commodities)
        return data
    finally:
        # Fix: the connection was previously never closed, leaking the
        # native DuckDB handle on every call. All returned data is
        # materialized (numpy/xarray) before this point, so closing is safe.
        con.close()
21+
22+
23+
def read_regions_csv(buffer_, con):
    """Create and populate the ``regions`` table from an open CSV buffer.

    Returns the region names as a dict of numpy arrays.
    """
    con.sql(
        """CREATE TABLE regions (
    name VARCHAR PRIMARY KEY,
    );
    """
    )
    # ``rel`` is resolved by name via DuckDB's replacement scan in the
    # INSERT statement below, so this binding must not be renamed.
    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
    con.sql("INSERT INTO regions SELECT name FROM rel;")
    return con.sql("SELECT name from regions").fetchnumpy()
32+
33+
34+
def read_commodities_csv(buffer_, con):
    """Create and populate the ``commodities`` table from an open CSV buffer.

    Raises ``duckdb.ConstraintException`` if a row's type is outside the
    allowed set enforced by the CHECK constraint.
    """
    con.sql(
        """CREATE TABLE commodities (
    name VARCHAR PRIMARY KEY,
    type VARCHAR CHECK (type IN ('energy', 'service', 'material', 'environmental')),
    unit VARCHAR,
    );
    """
    )
    # ``rel`` is resolved by name via DuckDB's replacement scan below.
    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
    con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;")

    return con.sql("select name, type, unit from commodities").fetchnumpy()
46+
47+
48+
def calculate_global_commodities(commodities):
    """Build an ``xarray.Dataset`` of commodity metadata.

    Parameters
    ----------
    commodities : dict
        Mapping with ``name``/``type``/``unit`` numpy arrays, as returned
        by :func:`read_commodities_csv`.

    Returns
    -------
    xr.Dataset
        ``type`` and ``unit`` variables indexed by a ``commodity``
        coordinate holding the names.
    """
    # Coerce to fixed-width unicode so the Dataset carries string dtypes.
    names = commodities["name"].astype(np.dtype("str"))
    data_vars = {
        field: xr.DataArray(
            data=commodities[field].astype(np.dtype("str")),
            dims=["commodity"],
            coords=dict(commodity=names),
        )
        for field in ("type", "unit")
    }
    return xr.Dataset(data_vars=data_vars)
63+
64+
65+
def read_demand_csv(buffer_, con):
    """Create and populate the ``demand`` table from an open CSV buffer.

    The ``regions`` and ``commodities`` tables must already exist on
    ``con`` — the foreign keys reference them, and an unknown region or
    commodity raises ``duckdb.ConstraintException``.
    """
    con.sql(
        """CREATE TABLE demand (
    year BIGINT,
    commodity VARCHAR REFERENCES commodities(name),
    region VARCHAR REFERENCES regions(name),
    demand DOUBLE,
    );
    """
    )
    # ``rel`` is resolved by name via DuckDB's replacement scan below.
    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
    # The CSV column is ``commodity_name``; it maps onto ``commodity``.
    con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;")
    return con.sql("SELECT * from demand").fetchnumpy()

tests/test_readers.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from io import StringIO
12
from itertools import chain, permutations
23
from pathlib import Path
34

45
import pandas as pd
6+
import duckdb
7+
import numpy as np
58
import toml
69
import xarray as xr
710
from pytest import fixture, mark, raises
@@ -314,3 +317,147 @@ def test_get_nan_coordinates():
314317
dataset3 = xr.Dataset.from_dataframe(df3.set_index(["region", "year"]))
315318
nan_coords3 = get_nan_coordinates(dataset3)
316319
assert nan_coords3 == []
320+
321+
322+
@fixture
def default_new_input(tmp_path):
    """Copy the ``default_new_input`` example model into a temp directory."""
    from muse.examples import copy_model

    model_dir = tmp_path / "model"
    copy_model("default_new_input", tmp_path)
    return model_dir
328+
329+
330+
@fixture
def con():
    """Provide a fresh in-memory DuckDB connection for each test."""
    connection = duckdb.connect(":memory:")
    return connection
333+
334+
335+
@fixture
def populate_regions(default_new_input, con):
    """Load ``regions.csv`` into ``con`` and return the resulting arrays."""
    from muse.new_input.readers import read_regions_csv

    csv_path = default_new_input / "regions.csv"
    with open(csv_path) as f:
        return read_regions_csv(f, con)
341+
342+
343+
@fixture
def populate_commodities(default_new_input, con):
    """Load ``commodities.csv`` into ``con`` and return the resulting arrays."""
    from muse.new_input.readers import read_commodities_csv

    csv_path = default_new_input / "commodities.csv"
    with open(csv_path) as f:
        return read_commodities_csv(f, con)
349+
350+
351+
@fixture
def populate_demand(default_new_input, con, populate_regions, populate_commodities):
    """Load ``demand.csv`` into ``con``; regions/commodities must load first
    so the foreign-key targets exist."""
    from muse.new_input.readers import read_demand_csv

    csv_path = default_new_input / "demand.csv"
    with open(csv_path) as f:
        return read_demand_csv(f, con)
357+
358+
359+
def test_read_regions(populate_regions):
    """The default model defines exactly one region, R1."""
    names = populate_regions["name"]
    assert names == np.array(["R1"])
361+
362+
363+
def test_read_new_global_commodities(populate_commodities):
    """Names, types and units round-trip through the commodities table."""
    data = populate_commodities
    expected_names = ["electricity", "gas", "heat", "wind", "CO2f"]
    assert list(data["name"]) == expected_names
    assert list(data["type"]) == ["energy"] * 5
    assert list(data["unit"]) == ["PJ"] * 4 + ["kt"]
368+
369+
370+
def test_calculate_global_commodities(populate_commodities):
    """The xarray Dataset mirrors the raw commodities table contents."""
    from muse.new_input.readers import calculate_global_commodities

    data = calculate_global_commodities(populate_commodities)

    assert isinstance(data, xr.Dataset)
    assert set(data.dims) == {"commodity"}
    # All variables must carry string dtypes.
    assert all(np.issubdtype(dt, np.dtype("str")) for dt in data.dtypes.values())

    assert list(data.coords["commodity"].values) == list(populate_commodities["name"])
    for field in ("type", "unit"):
        assert list(data.data_vars[field].values) == list(populate_commodities[field])
383+
384+
385+
def test_read_new_global_commodities_type_constraint(default_new_input, con):
    """A commodity type outside the allowed set trips the CHECK constraint."""
    from muse.new_input.readers import read_commodities_csv

    bad_csv = StringIO("name,type,unit\nfoo,invalid,bar\n")
    with raises(duckdb.ConstraintException):
        read_commodities_csv(bad_csv, con)
391+
392+
393+
def test_new_read_demand_csv(populate_demand):
    """Demand rows load with the expected values in every column."""
    data = populate_demand
    expected = {
        "year": np.array([2020, 2050]),
        "commodity": np.array(["heat", "heat"]),
        "region": np.array(["R1", "R1"]),
        "demand": np.array([10, 30]),
    }
    for column, values in expected.items():
        assert np.all(data[column] == values)
399+
400+
401+
def test_new_read_demand_csv_commodity_constraint(
    default_new_input, con, populate_commodities, populate_regions
):
    """An unknown commodity violates the foreign-key constraint."""
    from muse.new_input.readers import read_demand_csv

    bad_csv = StringIO("year,commodity_name,region,demand\n2020,invalid,R1,0\n")
    with raises(duckdb.ConstraintException, match=".*foreign key.*"):
        read_demand_csv(bad_csv, con)
409+
410+
411+
def test_new_read_demand_csv_region_constraint(
    default_new_input, con, populate_commodities, populate_regions
):
    """An unknown region violates the foreign-key constraint."""
    from muse.new_input.readers import read_demand_csv

    bad_csv = StringIO("year,commodity_name,region,demand\n2020,heat,invalid,0\n")
    with raises(duckdb.ConstraintException, match=".*foreign key.*"):
        read_demand_csv(bad_csv, con)
419+
420+
421+
@mark.xfail
def test_demand_dataset(default_new_input):
    # NOTE(review): deliberately xfail — ``read_regions``/``read_commodities``/
    # ``read_demand`` do not exist in readers.py yet (only the *_csv variants
    # do); this test sketches the intended future API and expected DataArray.
    import duckdb
    from muse.new_input.readers import read_commodities, read_demand, read_regions

    con = duckdb.connect(":memory:")

    # Order matters: demand's foreign keys reference regions and commodities.
    read_regions(default_new_input, con)
    read_commodities(default_new_input, con)
    data = read_demand(default_new_input, con)

    # Expected shape/coordinates of the eventual demand DataArray.
    assert isinstance(data, xr.DataArray)
    assert data.dtype == np.float64

    assert set(data.dims) == {"year", "commodity", "region", "timeslice"}
    assert list(data.coords["region"].values) == ["R1"]
    assert list(data.coords["timeslice"].values) == list(range(1, 7))
    assert list(data.coords["year"].values) == [2020, 2050]
    assert set(data.coords["commodity"].values) == {
        "electricity",
        "gas",
        "heat",
        "wind",
        "CO2f",
    }

    # NOTE(review): timeslice=0 here conflicts with the 1..6 coordinate
    # asserted above — confirm the intended timeslice indexing when the
    # API is implemented.
    assert data.sel(year=2020, commodity="electricity", region="R1", timeslice=0) == 1
448+
449+
450+
@mark.xfail
451+
def test_new_read_initial_market(default_new_input):
452+
from muse.new_input.readers import read_inputs
453+
454+
all_data = read_inputs(default_new_input)
455+
data = all_data["initial_market"]
456+
457+
assert isinstance(data, xr.Dataset)
458+
assert set(data.dims) == {"region", "year", "commodity", "timeslice"}
459+
assert dict(data.dtypes) == dict(
460+
prices=np.float64,
461+
exports=np.float64,
462+
imports=np.float64,
463+
static_trade=np.float64,

0 commit comments

Comments
 (0)