Skip to content

Commit e448b09

Browse files
Enforce consistency in objective dimensions (#584)
* Enforce that all objectives have an asset dimension * Check inputs to objectives * Temporarily suppress tests * Bump version * [pre-commit.ci] pre-commit autoupdate (#579) * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.4 → v0.8.0](astral-sh/ruff-pre-commit@v0.7.4...v0.8.0) - [github.com/igorshubovych/markdownlint-cli: v0.42.0 → v0.43.0](igorshubovych/markdownlint-cli@v0.42.0...v0.43.0) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Revert "Bump version" This reverts commit 3d223e6. * Add check_dimensions function * Add checks for demand_share dimensions * Add dst_region to check * Fix more checks * And more * Fix function default args * Remove xfail mark * More descriptive comment * Address reviewer comments * Allow function to work with any iterable of strings --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f94ee1b commit e448b09

5 files changed

Lines changed: 140 additions & 60 deletions

File tree

src/muse/demand_share.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def demand_share(
6363
RetrofitAgentInStandardDemandShare,
6464
)
6565
from muse.registration import registrator
66+
from muse.utilities import check_dimensions
6667

6768
DEMAND_SHARE_SIGNATURE = Callable[
6869
[Sequence[AbstractAgent], xr.Dataset, xr.Dataset, KwArg(Any)], xr.DataArray
@@ -102,7 +103,27 @@ def demand_share(
102103

103104
keyword_args = copy(keywords)
104105
keyword_args.update(**kwargs)
105-
return function(agents, market, technologies, **keyword_args)
106+
107+
# Check inputs
108+
check_dimensions(
109+
market,
110+
["commodity", "year", "timeslice", "region"],
111+
optional=["dst_region"],
112+
)
113+
check_dimensions(
114+
technologies,
115+
["technology", "year", "region"],
116+
optional=["timeslice", "commodity", "dst_region"],
117+
)
118+
119+
# Calculate demand share
120+
result = function(agents, market, technologies, **keyword_args)
121+
122+
# Check result
123+
check_dimensions(
124+
result, ["timeslice", "commodity"], optional=["asset", "region"]
125+
) # TODO: asset should be required, but trade model is failing
126+
return result
106127

107128
return cast(DEMAND_SHARE_SIGNATURE, demand_share)
108129

src/muse/objectives.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@ def comfort(
4242
these parameters.
4343
4444
Returns:
45-
A DataArray with at least one dimension corresponding to ``replacement``.
46-
Other dimensions can be present, as long as the subsequent decision function knows
47-
how to reduce them.
45+
A DataArray with at least two dimension corresponding to `replacement` and `asset`.
46+
A `timeslice` dimension may also be present.
4847
"""
4948

5049
__all__ = [
@@ -72,7 +71,7 @@ def comfort(
7271
from muse.outputs.cache import cache_quantity
7372
from muse.registration import registrator
7473
from muse.timeslices import broadcast_timeslice, distribute_timeslice, drop_timeslice
75-
from muse.utilities import filter_input
74+
from muse.utilities import check_dimensions, filter_input
7675

7776
OBJECTIVE_SIGNATURE = Callable[
7877
[xr.Dataset, xr.DataArray, xr.DataArray, KwArg(Any)], xr.DataArray
@@ -168,25 +167,30 @@ def register_objective(function: OBJECTIVE_SIGNATURE):
168167
from functools import wraps
169168

170169
@wraps(function)
171-
def decorated_objective(technologies: xr.Dataset, *args, **kwargs) -> xr.DataArray:
170+
def decorated_objective(
171+
technologies: xr.Dataset, demand: xr.DataArray, *args, **kwargs
172+
) -> xr.DataArray:
172173
from logging import getLogger
173174

174-
result = function(technologies, *args, **kwargs)
175+
# Check inputs
176+
check_dimensions(
177+
demand, ["asset", "timeslice", "commodity"], optional=["region"]
178+
)
179+
check_dimensions(
180+
technologies, ["replacement", "commodity"], optional=["timeslice"]
181+
)
182+
183+
# Calculate objective
184+
result = function(technologies, demand, *args, **kwargs)
185+
result.name = function.__name__
175186

187+
# Check result
176188
dtype = result.values.dtype
177189
if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
178190
msg = f"dtype of objective {function.__name__} is not a number ({dtype})"
179191
getLogger(function.__module__).warning(msg)
192+
check_dimensions(result, ["replacement", "asset"], optional=["timeslice"])
180193

181-
if "replacement" not in result.dims:
182-
raise RuntimeError("Objective should return a dimension 'replacement'")
183-
if "technology" in result.dims:
184-
raise RuntimeError("Objective should not return a dimension 'technology'")
185-
if "technology" in result.coords:
186-
raise RuntimeError("Objective should not return a coordinate 'technology'")
187-
if "year" in result.dims:
188-
raise RuntimeError("Objective should not return a dimension 'year'")
189-
result.name = function.__name__
190194
cache_quantity(**{result.name: result})
191195
return result
192196

@@ -196,21 +200,25 @@ def decorated_objective(technologies: xr.Dataset, *args, **kwargs) -> xr.DataArr
196200
@register_objective
197201
def comfort(
198202
technologies: xr.Dataset,
203+
demand: xr.DataArray,
199204
*args,
200205
**kwargs,
201206
) -> xr.DataArray:
202207
"""Comfort value provided by technologies."""
203-
return technologies.comfort
208+
result = xr.broadcast(technologies.comfort, demand.asset)[0]
209+
return result
204210

205211

206212
@register_objective
207213
def efficiency(
208214
technologies: xr.Dataset,
215+
demand: xr.DataArray,
209216
*args,
210217
**kwargs,
211218
) -> xr.DataArray:
212219
"""Efficiency of the technologies."""
213-
return technologies.efficiency
220+
result = xr.broadcast(technologies.efficiency, demand.asset)[0]
221+
return result
214222

215223

216224
@register_objective(name="capacity")
@@ -292,6 +300,7 @@ def fixed_costs(
292300
@register_objective
293301
def capital_costs(
294302
technologies: xr.Dataset,
303+
demand: xr.Dataset,
295304
*args,
296305
**kwargs,
297306
) -> xr.DataArray:
@@ -303,6 +312,7 @@ def capital_costs(
303312
simulation for each technology.
304313
"""
305314
result = technologies.cap_par * (technologies.scaling_size**technologies.cap_exp)
315+
result = xr.broadcast(result, demand.asset)[0]
306316
return result
307317

308318

@@ -373,10 +383,12 @@ def annual_levelized_cost_of_energy(
373383
"""
374384
from muse.costs import annual_levelized_cost_of_energy as aLCOE
375385

376-
return filter_input(
386+
result = filter_input(
377387
aLCOE(technologies=technologies, prices=prices).max("timeslice"),
378388
year=demand.year.item(),
379389
)
390+
result = xr.broadcast(result, demand.asset)[0]
391+
return result
380392

381393

382394
@register_objective(name=["LCOE", "LLCOE"])

src/muse/utilities.py

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
"""Collection of functions and stand-alone algorithms."""
22

3+
from __future__ import annotations
4+
35
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
46
from typing import (
57
Any,
68
Callable,
79
NamedTuple,
8-
Optional,
9-
Union,
1010
cast,
1111
)
1212

1313
import numpy as np
1414
import xarray as xr
1515

1616

17-
def multiindex_to_coords(
18-
data: Union[xr.Dataset, xr.DataArray], dimension: str = "asset"
19-
):
17+
def multiindex_to_coords(data: xr.Dataset | xr.DataArray, dimension: str = "asset"):
2018
"""Flattens multi-index dimension into multi-coord dimension."""
2119
from pandas import MultiIndex
2220

@@ -33,8 +31,8 @@ def multiindex_to_coords(
3331

3432

3533
def coords_to_multiindex(
36-
data: Union[xr.Dataset, xr.DataArray], dimension: str = "asset"
37-
) -> Union[xr.Dataset, xr.DataArray]:
34+
data: xr.Dataset | xr.DataArray, dimension: str = "asset"
35+
) -> xr.Dataset | xr.DataArray:
3836
"""Creates a multi-index from flattened multiple coords."""
3937
from pandas import MultiIndex
4038

@@ -47,11 +45,11 @@ def coords_to_multiindex(
4745

4846

4947
def reduce_assets(
50-
assets: Union[xr.DataArray, xr.Dataset, Sequence[Union[xr.Dataset, xr.DataArray]]],
51-
coords: Optional[Union[str, Sequence[str], Iterable[str]]] = None,
48+
assets: xr.DataArray | xr.Dataset | Sequence[xr.Dataset | xr.DataArray],
49+
coords: str | Sequence[str] | Iterable[str] | None = None,
5250
dim: str = "asset",
53-
operation: Optional[Callable] = None,
54-
) -> Union[xr.DataArray, xr.Dataset]:
51+
operation: Callable | None = None,
52+
) -> xr.DataArray | xr.Dataset:
5553
r"""Combine assets along given asset dimension.
5654
5755
This method simplifies combining assets across multiple agents, or combining assets
@@ -182,13 +180,13 @@ def operation(x):
182180

183181

184182
def broadcast_techs(
185-
technologies: Union[xr.Dataset, xr.DataArray],
186-
template: Union[xr.DataArray, xr.Dataset],
183+
technologies: xr.Dataset | xr.DataArray,
184+
template: xr.DataArray | xr.Dataset,
187185
dimension: str = "asset",
188186
interpolation: str = "linear",
189187
installed_as_year: bool = True,
190188
**kwargs,
191-
) -> Union[xr.Dataset, xr.DataArray]:
189+
) -> xr.Dataset | xr.DataArray:
192190
"""Broadcasts technologies to the shape of template in given dimension.
193191
194192
The dimensions of the technologies are fully explicit, in that each concept
@@ -246,7 +244,7 @@ def broadcast_techs(
246244
return techs.sel(second_sel)
247245

248246

249-
def clean_assets(assets: xr.Dataset, years: Union[int, Sequence[int]]):
247+
def clean_assets(assets: xr.Dataset, years: int | Sequence[int]):
250248
"""Cleans up and prepares asset for current iteration.
251249
252250
- adds current and forecast year by backfilling missing entries
@@ -265,11 +263,11 @@ def clean_assets(assets: xr.Dataset, years: Union[int, Sequence[int]]):
265263

266264

267265
def filter_input(
268-
dataset: Union[xr.Dataset, xr.DataArray],
269-
year: Optional[Union[int, Iterable[int]]] = None,
266+
dataset: xr.Dataset | xr.DataArray,
267+
year: int | Iterable[int] | None = None,
270268
interpolation: str = "linear",
271269
**kwargs,
272-
) -> Union[xr.Dataset, xr.DataArray]:
270+
) -> xr.Dataset | xr.DataArray:
273271
"""Filter inputs, taking care to interpolate years."""
274272
if year is None:
275273
setyear: set[int] = set()
@@ -300,8 +298,8 @@ def filter_input(
300298

301299

302300
def filter_with_template(
303-
data: Union[xr.Dataset, xr.DataArray],
304-
template: Union[xr.DataArray, xr.Dataset],
301+
data: xr.Dataset | xr.DataArray,
302+
template: xr.DataArray | xr.Dataset,
305303
asset_dimension: str = "asset",
306304
**kwargs,
307305
):
@@ -350,7 +348,7 @@ def tupled_dimension(array: np.ndarray, axis: int):
350348
def lexical_comparison(
351349
objectives: xr.Dataset,
352350
binsize: xr.Dataset,
353-
order: Optional[Sequence[Hashable]] = None,
351+
order: Sequence[Hashable] | None = None,
354352
bin_last: bool = True,
355353
) -> xr.DataArray:
356354
"""Lexical comparison over the objectives.
@@ -438,7 +436,7 @@ def avoid_repetitions(data: xr.DataArray, dim: str = "year") -> xr.DataArray:
438436
return data.year[years]
439437

440438

441-
def nametuple_to_dict(nametup: Union[Mapping, NamedTuple]) -> Mapping:
439+
def nametuple_to_dict(nametup: Mapping | NamedTuple) -> Mapping:
442440
"""Transforms a nametuple of type GenericDict into an OrderDict."""
443441
from collections import OrderedDict
444442
from dataclasses import asdict, is_dataclass
@@ -537,11 +535,11 @@ def future_propagation(
537535

538536

539537
def agent_concatenation(
540-
data: Mapping[Hashable, Union[xr.DataArray, xr.Dataset]],
538+
data: Mapping[Hashable, xr.DataArray | xr.Dataset],
541539
dim: str = "asset",
542540
name: str = "agent",
543541
fill_value: Any = 0,
544-
) -> Union[xr.DataArray, xr.Dataset]:
542+
) -> xr.DataArray | xr.Dataset:
545543
"""Concatenates input map along given dimension.
546544
547545
Example:
@@ -613,10 +611,10 @@ def agent_concatenation(
613611

614612

615613
def aggregate_technology_model(
616-
data: Union[xr.DataArray, xr.Dataset],
614+
data: xr.DataArray | xr.Dataset,
617615
dim: str = "asset",
618-
drop: Union[str, Sequence[str]] = "installed",
619-
) -> Union[xr.DataArray, xr.Dataset]:
616+
drop: str | Sequence[str] = "installed",
617+
) -> xr.DataArray | xr.Dataset:
620618
"""Aggregate together assets with the same installation year.
621619
622620
The assets of a given agent, region, and technology but different installation year
@@ -659,3 +657,27 @@ def aggregate_technology_model(
659657
data,
660658
[cast(str, u) for u in data.coords if u not in drop and data[u].dims == (dim,)],
661659
)
660+
661+
662+
def check_dimensions(
663+
data: xr.DataArray | xr.Dataset,
664+
required: Iterable[str] = (),
665+
optional: Iterable[str] = (),
666+
):
667+
"""Ensure that an array has the required dimensions.
668+
669+
This will check that all required dimensions are present, and that no other
670+
dimensions are present, apart from those listed as optional.
671+
672+
Args:
673+
data: DataArray or Dataset to check dimensions of
674+
required: List of dimension names that must be present
675+
optional: List of dimension names that may be present
676+
"""
677+
present = set(data.dims)
678+
missing = set(required) - present
679+
if missing:
680+
raise ValueError(f"Missing required dimensions: {missing}")
681+
extra = present - set(required) - set(optional)
682+
if extra:
683+
raise ValueError(f"Extra dimensions: {extra}")

0 commit comments

Comments
 (0)