@@ -12,9 +12,7 @@ impl<T> IDLike for T where T: Eq + Hash + Borrow<str> + Clone + Display + From<S
1212
1313macro_rules! define_id_type {
1414 ( $name: ident) => {
15- #[ derive(
16- Clone , std:: hash:: Hash , PartialEq , Eq , serde:: Deserialize , Debug , serde:: Serialize ,
17- ) ]
15+ #[ derive( Clone , std:: hash:: Hash , PartialEq , Eq , Debug , serde:: Serialize ) ]
1816 /// An ID type (e.g. `AgentID`, `CommodityID`, etc.)
1917 pub struct $name( pub std:: rc:: Rc <str >) ;
2018
@@ -42,6 +40,32 @@ macro_rules! define_id_type {
4240 }
4341 }
4442
43+ impl <' de> serde:: Deserialize <' de> for $name {
44+ fn deserialize<D >( deserialiser: D ) -> std:: result:: Result <Self , D :: Error >
45+ where
46+ D : serde:: Deserializer <' de>,
47+ {
48+ use serde:: de:: Error ;
49+
50+ let id: String = serde:: Deserialize :: deserialize( deserialiser) ?;
51+ let id = id. trim( ) ;
52+ if id. is_empty( ) {
53+ return Err ( D :: Error :: custom( "IDs cannot be empty" ) ) ;
54+ }
55+
56+ const FORBIDDEN_IDS : [ & str ; 2 ] = [ "all" , "annual" ] ;
57+ for forbidden in FORBIDDEN_IDS . iter( ) {
58+ if id. eq_ignore_ascii_case( forbidden) {
59+ return Err ( D :: Error :: custom( format!(
60+ "'{id}' is an invalid value for an ID"
61+ ) ) ) ;
62+ }
63+ }
64+
65+ Ok ( id. into( ) )
66+ }
67+ }
68+
4569 impl $name {
4670 /// Create a new ID from a string slice
4771 pub fn new( id: & str ) -> Self {
@@ -114,3 +138,39 @@ impl<ID: IDLike, V> IDCollection<ID> for IndexMap<ID, V> {
114138 Ok ( found)
115139 }
116140}
141+
142+ #[ cfg( test) ]
143+ mod tests {
144+ use super :: * ;
145+ use rstest:: rstest;
146+
147+ use serde:: Deserialize ;
148+
149+ #[ derive( Debug , Deserialize ) ]
150+ struct Record {
151+ id : GenericID ,
152+ }
153+
154+ fn deserialise_id ( id : & str ) -> Result < Record > {
155+ Ok ( toml:: from_str ( & format ! ( "id = \" {id}\" " ) ) ?)
156+ }
157+
158+ #[ rstest]
159+ #[ case( "commodity1" ) ]
160+ #[ case( "some commodity" ) ]
161+ #[ case( "PROCESS" ) ]
162+ #[ case( "café" ) ] // unicode supported
163+ fn test_deserialise_id_valid ( #[ case] id : & str ) {
164+ assert_eq ! ( deserialise_id( id) . unwrap( ) . id. to_string( ) , id) ;
165+ }
166+
167+ #[ rstest]
168+ #[ case( "" ) ]
169+ #[ case( "all" ) ]
170+ #[ case( "annual" ) ]
171+ #[ case( "ALL" ) ]
172+ #[ case( " ALL " ) ]
173+ fn test_deserialise_id_invalid ( #[ case] id : & str ) {
174+ assert ! ( deserialise_id( id) . is_err( ) ) ;
175+ }
176+ }
0 commit comments