Skip to content

Commit 47b6c26

Browse files
authored
Merge pull request #572 from EnergySystemsModellingLab/prohibit-dubious-ids
Disallow certain values for IDs
2 parents d1dfee1 + 5e025fc commit 47b6c26

1 file changed

Lines changed: 63 additions & 3 deletions

File tree

src/id.rs

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ impl<T> IDLike for T where T: Eq + Hash + Borrow<str> + Clone + Display + From<S
1212

1313
macro_rules! define_id_type {
1414
($name:ident) => {
15-
#[derive(
16-
Clone, std::hash::Hash, PartialEq, Eq, serde::Deserialize, Debug, serde::Serialize,
17-
)]
15+
#[derive(Clone, std::hash::Hash, PartialEq, Eq, Debug, serde::Serialize)]
1816
/// An ID type (e.g. `AgentID`, `CommodityID`, etc.)
1917
pub struct $name(pub std::rc::Rc<str>);
2018

@@ -42,6 +40,32 @@ macro_rules! define_id_type {
4240
}
4341
}
4442

43+
impl<'de> serde::Deserialize<'de> for $name {
44+
fn deserialize<D>(deserialiser: D) -> std::result::Result<Self, D::Error>
45+
where
46+
D: serde::Deserializer<'de>,
47+
{
48+
use serde::de::Error;
49+
50+
let id: String = serde::Deserialize::deserialize(deserialiser)?;
51+
let id = id.trim();
52+
if id.is_empty() {
53+
return Err(D::Error::custom("IDs cannot be empty"));
54+
}
55+
56+
const FORBIDDEN_IDS: [&str; 2] = ["all", "annual"];
57+
for forbidden in FORBIDDEN_IDS.iter() {
58+
if id.eq_ignore_ascii_case(forbidden) {
59+
return Err(D::Error::custom(format!(
60+
"'{id}' is an invalid value for an ID"
61+
)));
62+
}
63+
}
64+
65+
Ok(id.into())
66+
}
67+
}
68+
4569
impl $name {
4670
/// Create a new ID from a string slice
4771
pub fn new(id: &str) -> Self {
@@ -114,3 +138,39 @@ impl<ID: IDLike, V> IDCollection<ID> for IndexMap<ID, V> {
114138
Ok(found)
115139
}
116140
}
141+
142+
#[cfg(test)]
143+
mod tests {
144+
use super::*;
145+
use rstest::rstest;
146+
147+
use serde::Deserialize;
148+
149+
#[derive(Debug, Deserialize)]
150+
struct Record {
151+
id: GenericID,
152+
}
153+
154+
fn deserialise_id(id: &str) -> Result<Record> {
155+
Ok(toml::from_str(&format!("id = \"{id}\""))?)
156+
}
157+
158+
#[rstest]
159+
#[case("commodity1")]
160+
#[case("some commodity")]
161+
#[case("PROCESS")]
162+
#[case("café")] // unicode supported
163+
fn test_deserialise_id_valid(#[case] id: &str) {
164+
assert_eq!(deserialise_id(id).unwrap().id.to_string(), id);
165+
}
166+
167+
#[rstest]
168+
#[case("")]
169+
#[case("all")]
170+
#[case("annual")]
171+
#[case("ALL")]
172+
#[case(" ALL ")]
173+
fn test_deserialise_id_invalid(#[case] id: &str) {
174+
assert!(deserialise_id(id).is_err());
175+
}
176+
}

0 commit comments

Comments
 (0)