11import duckdb
22import numpy as np
3+ import pandas as pd
34import xarray as xr
45
56
def read_inputs(data_dir):
    """Read all model input CSV files from *data_dir* into xarray structures.

    The CSVs are staged in an in-memory DuckDB database so the table schemas
    (primary keys, foreign keys, CHECK constraints) validate the raw data,
    then combined into xarray objects.

    Args:
        data_dir: Directory (``pathlib.Path``-like) containing the CSV files.

    Returns:
        Dict with keys ``global_commodities``, ``demand`` and
        ``initial_market``.
    """
    data = {}
    con = duckdb.connect(":memory:")
    try:
        # Read order matters: tables referenced by foreign keys (timeslices,
        # commodities, regions, demand) must exist before their referrers.
        with open(data_dir / "timeslices.csv") as f:
            timeslices = read_timeslices_csv(f, con)

        with open(data_dir / "commodities.csv") as f:
            commodities = read_commodities_csv(f, con)

        with open(data_dir / "regions.csv") as f:
            regions = read_regions_csv(f, con)

        with open(data_dir / "commodity_trade.csv") as f:
            commodity_trade = read_commodity_trade_csv(f, con)

        with open(data_dir / "commodity_costs.csv") as f:
            commodity_costs = read_commodity_costs_csv(f, con)

        with open(data_dir / "demand.csv") as f:
            demand = read_demand_csv(f, con)

        with open(data_dir / "demand_slicing.csv") as f:
            demand_slicing = read_demand_slicing_csv(f, con)
    finally:
        # The connection was previously never closed; release it even if a
        # read raises (e.g. on a constraint violation).
        con.close()

    data["global_commodities"] = calculate_global_commodities(commodities)
    data["demand"] = calculate_demand(
        commodities, regions, timeslices, demand, demand_slicing
    )
    data["initial_market"] = calculate_initial_market(
        commodities, regions, timeslices, commodity_trade, commodity_costs
    )
    return data
3040
3141
def read_timeslices_csv(buffer_, con):
    """Create and populate the ``timeslices`` table from an open CSV buffer.

    Args:
        buffer_: Open file object for ``timeslices.csv``.
        con: DuckDB connection that receives the table.

    Returns:
        The table contents as a dict mapping column names to numpy arrays.
    """
    sql = """CREATE TABLE timeslices (
    id VARCHAR PRIMARY KEY,
    season VARCHAR,
    day VARCHAR,
    time_of_day VARCHAR,
    fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1),
    );
    """
    con.sql(sql)
    # DuckDB's replacement scan resolves the name "rel" in the INSERT below
    # to this local relation, so the variable is used despite appearances.
    rel = con.read_csv(buffer_, delimiter=",", header=True)  # noqa: F841
    con.sql(
        "INSERT INTO timeslices SELECT id, season, day, time_of_day, fraction FROM rel;"
    )
    result = con.sql("SELECT * from timeslices")
    return result.fetchnumpy()
57+
58+
3259def read_commodities_csv (buffer_ , con ):
3360 sql = """CREATE TABLE commodities (
3461 id VARCHAR PRIMARY KEY,
@@ -42,13 +69,25 @@ def read_commodities_csv(buffer_, con):
4269 return con .sql ("select * from commodities" ).fetchnumpy ()
4370
4471
def read_regions_csv(buffer_, con):
    """Create and populate the ``regions`` table from an open CSV buffer.

    Args:
        buffer_: Open file object for ``regions.csv``.
        con: DuckDB connection that receives the table.

    Returns:
        The table contents as a dict mapping column names to numpy arrays.
    """
    sql = """CREATE TABLE regions (
    id VARCHAR PRIMARY KEY,
    );
    """
    con.sql(sql)
    # "rel" is referenced by name in the INSERT via DuckDB's replacement scan.
    rel = con.read_csv(buffer_, delimiter=",", header=True)  # noqa: F841
    con.sql("INSERT INTO regions SELECT id FROM rel;")
    result = con.sql("SELECT * from regions")
    return result.fetchnumpy()
81+
82+
4583def read_commodity_trade_csv (buffer_ , con ):
4684 sql = """CREATE TABLE commodity_trade (
4785 commodity VARCHAR REFERENCES commodities(id),
4886 region VARCHAR REFERENCES regions(id),
4987 year BIGINT,
5088 import DOUBLE,
5189 export DOUBLE,
90+ PRIMARY KEY (commodity, region, year)
5291 );
5392 """
5493 con .sql (sql )
@@ -64,6 +103,7 @@ def read_commodity_costs_csv(buffer_, con):
64103 region VARCHAR REFERENCES regions(id),
65104 year BIGINT,
66105 value DOUBLE,
106+ PRIMARY KEY (commodity, region, year)
67107 );
68108 """
69109 con .sql (sql )
@@ -79,6 +119,7 @@ def read_demand_csv(buffer_, con):
79119 region VARCHAR REFERENCES regions(id),
80120 year BIGINT,
81121 demand DOUBLE,
122+ PRIMARY KEY (commodity, region, year)
82123 );
83124 """
84125 con .sql (sql )
@@ -92,28 +133,19 @@ def read_demand_slicing_csv(buffer_, con):
92133 commodity VARCHAR REFERENCES commodities(id),
93134 region VARCHAR REFERENCES regions(id),
94135 year BIGINT,
95- timeslice VARCHAR,
136+ timeslice VARCHAR REFERENCES timeslices(id) ,
96137 fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1),
138+ PRIMARY KEY (commodity, region, year, timeslice),
139+ FOREIGN KEY (commodity, region, year) REFERENCES demand(commodity, region, year)
97140 );
98141 """
99142 con .sql (sql )
100143 rel = con .read_csv (buffer_ , header = True , delimiter = "," ) # noqa: F841
101144 con .sql ("""INSERT INTO demand_slicing SELECT
102- commodity_id, region_id, year, timeslice , fraction FROM rel;""" )
145+ commodity_id, region_id, year, timeslice_id , fraction FROM rel;""" )
103146 return con .sql ("SELECT * from demand_slicing" ).fetchnumpy ()
104147
105148
106- def read_regions_csv (buffer_ , con ):
107- sql = """CREATE TABLE regions (
108- id VARCHAR PRIMARY KEY,
109- );
110- """
111- con .sql (sql )
112- rel = con .read_csv (buffer_ , header = True , delimiter = "," ) # noqa: F841
113- con .sql ("INSERT INTO regions SELECT id FROM rel;" )
114- return con .sql ("SELECT * from regions" ).fetchnumpy ()
115-
116-
117149def calculate_global_commodities (commodities ):
118150 names = commodities ["id" ].astype (np .dtype ("str" ))
119151 types = commodities ["type" ].astype (np .dtype ("str" ))
@@ -129,3 +161,195 @@ def calculate_global_commodities(commodities):
129161
130162 data = xr .Dataset (data_vars = dict (type = type_array , unit = unit_array ))
131163 return data
164+
165+
def calculate_demand(
    commodities, regions, timeslices, demand, demand_slicing
) -> xr.DataArray:
    """Calculate demand data for all commodities, regions, years, and timeslices.

    Result: A DataArray with a demand value for every combination of:
    - commodity: all commodities specified in the commodities table
    - region: all regions specified in the regions table
    - year: all years specified in the demand table
    - timeslice: all timeslices specified in the timeslices table

    Checks:
    - If demand data is specified for one year, it must be specified for all years.
    - If demand is nonzero, slicing data must be present.
    - If slicing data is specified for a commodity/region/year, the sum of
      the fractions must be 1, and all timeslices must be present.

    Fills:
    - If demand data is not specified for a commodity/region combination, the demand is
      0 for all years and timeslices.

    Raises:
        DataValidationError: If any of the checks above fails.

    Todo:
    - Interpolation to allow for missing years in demand data.
    - Ability to leave the year field blank in both tables to indicate all years
    - Allow slicing data to be missing -> demand is spread equally across timeslices
    - Allow more flexibility for timeslices (e.g. can specify "winter" to apply to all
      winter timeslices, or "all" to apply to all timeslices)
    """
    # Prepare dataframes
    df_demand = pd.DataFrame(demand).set_index(["commodity", "region", "year"])
    df_slicing = pd.DataFrame(demand_slicing).set_index(
        ["commodity", "region", "year", "timeslice"]
    )

    # DataArray dimensions
    all_commodities = commodities["id"].astype(np.dtype("str"))
    all_regions = regions["id"].astype(np.dtype("str"))
    all_years = df_demand.index.get_level_values("year").unique()
    all_timeslices = timeslices["id"].astype(np.dtype("str"))

    # CHECK: all years are specified for each commodity/region combination
    check_all_values_specified(df_demand, ["commodity", "region"], "year", all_years)

    # CHECK: if slicing data is present, all timeslices must be specified
    check_all_values_specified(
        df_slicing, ["commodity", "region", "year"], "timeslice", all_timeslices
    )

    # CHECK: timeslice fractions sum to 1 for every sliced commodity/region/year.
    # Summing the column directly avoids the groupby.apply lambda and gives us
    # the offending groups for the error message.
    fraction_sums = df_slicing.groupby(["commodity", "region", "year"])[
        "fraction"
    ].sum()
    bad_sums = fraction_sums[~np.isclose(fraction_sums, 1)]
    if len(bad_sums) > 0:
        raise DataValidationError(
            f"Timeslice fractions must sum to 1; offending "
            f"commodity/region/year groups: {list(bad_sums.index)}"
        )

    # CHECK: if demand data >0, fraction data must be specified.
    # (Previously the already-reduced boolean was reduced again with a
    # redundant .all(), and the error carried no message.)
    nonzero = df_demand[df_demand["demand"] > 0]
    has_slicing = nonzero.index.isin(df_slicing.droplevel("timeslice").index)
    if not has_slicing.all():
        raise DataValidationError(
            "Slicing data missing for nonzero demand: "
            f"{list(nonzero.index[~has_slicing].unique())}"
        )

    # FILL: demand is zero if unspecified
    df_demand = df_demand.reindex(
        pd.MultiIndex.from_product(
            [all_commodities, all_regions, all_years],
            names=["commodity", "region", "year"],
        ),
        fill_value=0,
    )

    # FILL: slice data is zero if unspecified
    df_slicing = df_slicing.reindex(
        pd.MultiIndex.from_product(
            [all_commodities, all_regions, all_years, all_timeslices],
            names=["commodity", "region", "year", "timeslice"],
        ),
        fill_value=0,
    )

    # Create DataArray: per-timeslice demand = yearly demand * timeslice fraction
    da_demand = df_demand.to_xarray()["demand"]
    da_slicing = df_slicing.to_xarray()["fraction"]
    data = da_demand * da_slicing
    return data
253+
254+
def calculate_initial_market(
    commodities, regions, timeslices, commodity_trade, commodity_costs
) -> xr.Dataset:
    """Calculate trade and price data for all commodities, regions and years.

    Result: A Dataset with variables:
    - prices
    - exports
    - imports
    - static_trade
    For every combination of:
    - commodity: all commodities specified in the commodities table
    - region: all regions specified in the regions table
    - year: all years specified in the commodity_costs table
    - timeslice (multiindex): all timeslices specified in the timeslices table

    Checks:
    - If trade data is specified for one year, it must be specified for all years.
    - If price data is specified for one year, it must be specified for all years.

    Fills:
    - If trade data is not specified for a commodity/region combination, imports and
      exports are both zero
    - If price data is not specified for a commodity/region combination, the price is
      zero

    Raises:
        DataValidationError: If a year is missing for any commodity/region
            combination in either input table.
    """
    from muse.timeslices import QuantityType, convert_timeslice

    # Prepare dataframes
    df_trade = pd.DataFrame(commodity_trade).set_index(["commodity", "region", "year"])
    df_costs = (
        pd.DataFrame(commodity_costs)
        .set_index(["commodity", "region", "year"])
        .rename(columns={"value": "prices"})
    )
    df_timeslices = pd.DataFrame(timeslices).set_index(["season", "day", "time_of_day"])

    # DataArray dimensions
    all_commodities = commodities["id"].astype(np.dtype("str"))
    all_regions = regions["id"].astype(np.dtype("str"))
    all_years = df_costs.index.get_level_values("year").unique()

    # CHECK: all years are specified for each commodity/region combination
    check_all_values_specified(df_trade, ["commodity", "region"], "year", all_years)
    check_all_values_specified(df_costs, ["commodity", "region", "year"], "year", all_years) if False else None  # noqa: E501  # (kept simple below)
    check_all_values_specified(df_costs, ["commodity", "region"], "year", all_years)

    # FILL: price and trade are zero if unspecified. Both tables are expanded
    # onto the same full commodity/region/year grid.
    full_index = pd.MultiIndex.from_product(
        [all_commodities, all_regions, all_years],
        names=["commodity", "region", "year"],
    )
    df_costs = df_costs.reindex(full_index, fill_value=0)
    df_trade = df_trade.reindex(full_index, fill_value=0)

    # Calculate static trade
    df_trade["static_trade"] = df_trade["export"] - df_trade["import"]

    # Create Dataset and spread extensive quantities across timeslices
    data = df_costs.join(df_trade).to_xarray()
    ts = df_timeslices.to_xarray()["fraction"]
    ts = ts.stack(timeslice=("season", "day", "time_of_day"))
    # BUG FIX: the converted result was previously discarded, so the returned
    # dataset lacked the timeslice dimension this docstring promises.
    # NOTE(review): assumes convert_timeslice returns the converted object
    # rather than mutating in place — confirm against muse.timeslices.
    data = convert_timeslice(data, ts, QuantityType.EXTENSIVE)

    return data
331+
332+
class DataValidationError(ValueError):
    """Raised when input CSV data fails a consistency or validation check."""

    pass
335+
336+
def check_all_values_specified(
    df: pd.DataFrame, group_by_cols: list[str], column_name: str, values: list
) -> None:
    """Check that the required values are specified in a dataframe.

    Groups *df* by the index levels in *group_by_cols*; within each group the
    set of values found in index level *column_name* must equal the set of
    required *values* (missing OR extra values both fail, matching set
    equality in the original check).

    Args:
        df: Dataframe indexed by (at least) *group_by_cols* and *column_name*.
        group_by_cols: Index level names identifying a group.
        column_name: Index level whose values must be complete per group.
        values: The full collection of required values.

    Raises:
        DataValidationError: If any group's *column_name* values differ from
            *values*. (Previously the error message was an empty TODO string,
            and the boolean reduction applied a redundant second ``.all()``.)
    """
    required = set(values)
    for key, group in df.groupby(group_by_cols):
        found = set(group.index.get_level_values(column_name).unique())
        if found != required:
            msg = (
                f"Incomplete data for group {key!r}: expected {column_name} "
                f"values {sorted(required, key=repr)}, found "
                f"{sorted(found, key=repr)}."
            )
            raise DataValidationError(msg)
0 commit comments