Skip to content

Commit 0ae972a

Browse files
authored
Merge pull request #580 from EnergySystemsModellingLab/id-tweaks
Various tweaks to ID code
2 parents 8308580 + 6e299c5 commit 0ae972a

13 files changed

Lines changed: 48 additions & 63 deletions

File tree

src/id.rs

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
//! Code for handling IDs
22
use anyhow::{Context, Result};
3-
use indexmap::IndexSet;
3+
use indexmap::{IndexMap, IndexSet};
4+
use std::borrow::Borrow;
45
use std::collections::HashSet;
6+
use std::fmt::Display;
7+
use std::hash::Hash;
58

69
/// A trait alias for ID types
7-
pub trait IDLike:
8-
Eq + std::hash::Hash + std::borrow::Borrow<str> + Clone + std::fmt::Display + From<String>
9-
{
10-
}
11-
impl<T> IDLike for T where
12-
T: Eq + std::hash::Hash + std::borrow::Borrow<str> + Clone + std::fmt::Display + From<String>
13-
{
14-
}
10+
pub trait IDLike: Eq + Hash + Borrow<str> + Clone + Display + From<String> {}
11+
impl<T> IDLike for T where T: Eq + Hash + Borrow<str> + Clone + Display + From<String> {}
1512

1613
macro_rules! define_id_type {
1714
($name:ident) => {
@@ -78,43 +75,25 @@ pub(crate) use define_id_getter;
7875

7976
/// A data structure containing a set of IDs
8077
pub trait IDCollection<ID: IDLike> {
81-
/// Get the ID from the collection by its string representation.
82-
///
83-
/// # Arguments
84-
///
85-
/// * `id` - The string representation of the ID
86-
///
87-
/// # Returns
88-
///
89-
/// A copy of the ID in `self`, or an error if not found.
90-
fn get_id_by_str(&self, id: &str) -> Result<ID>;
91-
9278
/// Check if the ID is in the collection, returning a copy of it if found.
9379
///
9480
/// # Arguments
9581
///
96-
/// * `id` - The ID to check
82+
/// * `id` - The ID to check (can be string or ID type)
9783
///
9884
/// # Returns
9985
///
10086
/// A copy of the ID in `self`, or an error if not found.
101-
fn get_id(&self, id: &ID) -> Result<ID>;
87+
fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID>;
10288
}
10389

10490
macro_rules! define_id_methods {
10591
() => {
106-
fn get_id_by_str(&self, id: &str) -> Result<ID> {
107-
let found = self
108-
.get(id)
109-
.with_context(|| format!("Unknown ID {id} found"))?;
110-
Ok(found.clone())
111-
}
112-
113-
fn get_id(&self, id: &ID) -> Result<ID> {
92+
fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
11493
let found = self
11594
.get(id.borrow())
11695
.with_context(|| format!("Unknown ID {id} found"))?;
117-
Ok(found.clone())
96+
Ok(found)
11897
}
11998
};
12099
}
@@ -126,3 +105,12 @@ impl<ID: IDLike> IDCollection<ID> for HashSet<ID> {
126105
impl<ID: IDLike> IDCollection<ID> for IndexSet<ID> {
127106
define_id_methods!();
128107
}
108+
109+
impl<ID: IDLike, V> IDCollection<ID> for IndexMap<ID, V> {
110+
fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
111+
let (found, _) = self
112+
.get_key_value(id.borrow())
113+
.with_context(|| format!("Unknown ID {id} found"))?;
114+
Ok(found)
115+
}
116+
}

src/input/agent/commodity_portion.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
use super::super::*;
33
use crate::agent::{AgentCommodityPortionsMap, AgentID, AgentMap};
44
use crate::commodity::{CommodityID, CommodityMap, CommodityType};
5+
use crate::id::IDCollection;
56
use crate::region::RegionID;
67
use crate::year::parse_year_str;
78
use anyhow::{ensure, Context, Result};
@@ -66,9 +67,7 @@ where
6667
for agent_commodity_portion_raw in iter {
6768
// Get agent ID
6869
let agent_id_raw = agent_commodity_portion_raw.agent_id.as_str();
69-
let (id, _agent) = agents
70-
.get_key_value(agent_id_raw)
71-
.with_context(|| format!("Invalid agent ID {agent_id_raw}"))?;
70+
let id = agents.get_id(agent_id_raw)?;
7271

7372
// Get/create entry for agent
7473
let entry = agent_commodity_portions
@@ -77,9 +76,7 @@ where
7776

7877
// Insert portion for the commodity/year(s)
7978
let commodity_id_raw = agent_commodity_portion_raw.commodity_id.as_str();
80-
let (commodity_id, _commodity) = commodities
81-
.get_key_value(commodity_id_raw)
82-
.with_context(|| format!("Invalid commodity ID {commodity_id_raw}"))?;
79+
let commodity_id = commodities.get_id(commodity_id_raw)?;
8380
let years = parse_year_str(&agent_commodity_portion_raw.years, milestone_years)?;
8481
for year in years {
8582
try_insert(

src/input/agent/cost_limit.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ where
6161
let years = parse_year_str(&agent_cost_limits_raw.years, milestone_years)?;
6262

6363
// Get agent ID
64-
let agent_id = agent_ids.get_id_by_str(&agent_cost_limits_raw.agent_id)?;
64+
let agent_id = agent_ids.get_id(&agent_cost_limits_raw.agent_id)?;
6565

6666
// Get or create entry in the map
6767
let entry = map.entry(agent_id.clone()).or_default();

src/input/agent/search_space.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,16 @@ impl AgentSearchSpaceRaw {
5252
let search_space = Rc::new(parse_search_space_str(&self.search_space, process_ids)?);
5353

5454
// Get commodity
55-
let commodity_id = commodity_ids.get_id_by_str(&self.commodity_id)?;
55+
let commodity_id = commodity_ids.get_id(&self.commodity_id)?;
5656

5757
// Check that the year is a valid milestone year
5858
let year = parse_year_str(&self.years, milestone_years)?;
5959

60-
let (agent_id, _) = agents
61-
.get_key_value(self.agent_id.as_str())
62-
.context("Invalid agent ID")?;
60+
let agent_id = agents.get_id(&self.agent_id)?;
6361

6462
Ok(AgentSearchSpace {
6563
agent_id: agent_id.clone(),
66-
commodity_id,
64+
commodity_id: commodity_id.clone(),
6765
years: year,
6866
search_space,
6967
})
@@ -86,7 +84,7 @@ fn parse_search_space_str(
8684
} else {
8785
search_space
8886
.split(';')
89-
.map(|id| process_ids.get_id_by_str(id.trim()))
87+
.map(|id| Ok(process_ids.get_id(id.trim())?.clone()))
9088
.try_collect()
9189
}
9290
}

src/input/asset.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,16 @@ where
6969
I: Iterator<Item = AssetRaw>,
7070
{
7171
iter.map(|asset| -> Result<_> {
72-
let agent_id = agent_ids.get_id_by_str(&asset.agent_id)?;
72+
let agent_id = agent_ids.get_id(&asset.agent_id)?;
7373
let process = processes
7474
.get(asset.process_id.as_str())
7575
.with_context(|| format!("Invalid process ID: {}", &asset.process_id))?;
76-
let region_id = region_ids.get_id_by_str(&asset.region_id)?;
76+
let region_id = region_ids.get_id(&asset.region_id)?;
7777

7878
Asset::new(
79-
agent_id,
79+
agent_id.clone(),
8080
Rc::clone(process),
81-
region_id,
81+
region_id.clone(),
8282
asset.capacity,
8383
asset.commission_year,
8484
)

src/input/commodity/cost.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ where
7878
let mut commodity_regions: HashMap<CommodityID, HashSet<RegionID>> = HashMap::new();
7979

8080
for cost in iter {
81-
let commodity_id = commodity_ids.get_id_by_str(&cost.commodity_id)?;
81+
let commodity_id = commodity_ids.get_id(&cost.commodity_id)?;
8282
let regions = parse_region_str(&cost.regions, region_ids)?;
8383
let years = parse_year_str(&cost.years, milestone_years)?;
8484
let ts_selection = time_slice_info.get_selection(&cost.time_slice)?;

src/input/commodity/demand.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,14 @@ where
110110
let mut map = AnnualDemandMap::new();
111111
for demand in iter {
112112
let commodity_id = svd_commodity_ids
113-
.get_id_by_str(&demand.commodity_id)
113+
.get_id(&demand.commodity_id)
114114
.with_context(|| {
115115
format!(
116116
"Can only provide demand data for SVD commodities. Found entry for '{}'",
117117
demand.commodity_id
118118
)
119119
})?;
120-
let region_id = region_ids.get_id_by_str(&demand.region_id)?;
120+
let region_id = region_ids.get_id(&demand.region_id)?;
121121

122122
ensure!(
123123
milestone_years.binary_search(&demand.year).is_ok(),

src/input/commodity/demand_slicing.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@ where
6464

6565
for slice in iter {
6666
let commodity_id = svd_commodity_ids
67-
.get_id_by_str(&slice.commodity_id)
67+
.get_id(&slice.commodity_id)
6868
.with_context(|| {
6969
format!(
7070
"Can only provide demand slice data for SVD commodities. Found entry for '{}'",
7171
slice.commodity_id
7272
)
7373
})?;
74-
let region_id = region_ids.get_id_by_str(&slice.region_id)?;
74+
let region_id = region_ids.get_id(&slice.region_id)?;
7575

7676
// We need to know how many time slices are covered by the current demand slice entry and
7777
// how long they are relative to one another so that we can divide up the demand for this

src/input/process/availability.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ where
109109
record.validate()?;
110110

111111
// Get process
112-
let id = process_ids.get_id_by_str(&record.process_id)?;
112+
let id = process_ids.get_id(&record.process_id)?;
113113
let process = processes
114-
.get(&id)
114+
.get(id)
115115
.with_context(|| format!("Process {id} not found"))?;
116116

117117
// Get regions
@@ -131,7 +131,9 @@ where
131131
let ts_selection = time_slice_info.get_selection(&record.time_slice)?;
132132

133133
// Insert the energy limit into the map
134-
let entry = map.entry(id).or_insert_with(ProcessEnergyLimitsMap::new);
134+
let entry = map
135+
.entry(id.clone())
136+
.or_insert_with(ProcessEnergyLimitsMap::new);
135137
for (time_slice, ts_length) in time_slice_info.iter_selection(&ts_selection) {
136138
let bounds = record.to_bounds(ts_length);
137139

src/input/process/flow.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ where
8282
record.validate()?;
8383

8484
// Get process
85-
let id = process_ids.get_id_by_str(&record.process_id)?;
85+
let id = process_ids.get_id(&record.process_id)?;
8686
let process = processes
87-
.get(&id)
87+
.get(id)
8888
.with_context(|| format!("Process {id} not found"))?;
8989

9090
// Get regions

0 commit comments

Comments
 (0)