diff --git a/README.md b/README.md index 4604ecd..a8aea71 100644 --- a/README.md +++ b/README.md @@ -44,9 +44,12 @@ fn main() { assert_eq!(json, "\"cat\""); let value: Animal = serde_json::from_str(&json).unwrap(); assert_eq!(value, Animal::Cat); - let json = serde_json::to_string(&Animal::from(100u8)).unwrap(); + assert_eq!(value, serde_json::from_str("\"kitty\"").unwrap()); + assert!(serde_json::from_str::("\"Rat\"").is_err()); + let value = Animal::from(2i16); + let json = serde_json::to_string(&value).unwrap(); assert_eq!(json, "\"\""); - let value: Animal = serde_json::from_str(&json).unwrap(); - assert_eq!(value, Animal::UNKNOWN); + assert_eq!(usize::from(value), 2); + assert_eq!(u8::try_from(value).unwrap(), 2); } ``` diff --git a/examples/derive_serde.rs b/examples/derive_serde.rs index 782f126..5a7ccc3 100644 --- a/examples/derive_serde.rs +++ b/examples/derive_serde.rs @@ -8,6 +8,7 @@ use thiserror::Error; #[repr(i16)] pub enum Animal { #[error("unknown animal")] + #[serde(alias = "kitty")] Cat, #[error("unknown animal")] Dog, @@ -18,10 +19,11 @@ fn main() { assert_eq!(json, "\"cat\""); let value: Animal = serde_json::from_str(&json).unwrap(); assert_eq!(value, Animal::Cat); - let json = serde_json::to_string(&Animal::from(100i16)).unwrap(); + assert_eq!(value, serde_json::from_str("\"kitty\"").unwrap()); + assert!(serde_json::from_str::("\"Rat\"").is_err()); + let value = Animal::from(2i16); + let json = serde_json::to_string(&value).unwrap(); assert_eq!(json, "\"\""); - let value: Animal = serde_json::from_str(&json).unwrap(); - assert_eq!(value, Animal::UNKNOWN); assert_eq!(usize::from(value), 2); assert_eq!(u8::try_from(value).unwrap(), 2); } diff --git a/macros/src/delegate.rs b/macros/src/delegate.rs index ae2cf2e..080cea3 100644 --- a/macros/src/delegate.rs +++ b/macros/src/delegate.rs @@ -1,8 +1,9 @@ use proc_macro2::{Span, TokenStream}; -use quote::quote_spanned; +use quote::{quote_spanned, ToTokens}; use syn::{ - Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path, PathArguments, PathSegment, - Token, Type, + parse::Parser as _, punctuated::Punctuated, token::Comma, Attribute, Expr, ExprLit, Ident, Lit, + LitByteStr, LitStr, Meta, MetaNameValue, Path, PathArguments, PathSegment, Token, Type, + Variant, }; use crate::utils::extract_meta_from_lists; @@ -13,6 +14,8 @@ struct Context<'a, 'b> { target: &'a Type, origin: &'a Type, attrs: &'a [Attribute], + variants: &'a Punctuated, + repr: &'a Type, impls: &'b mut TokenStream, } @@ -83,7 +86,6 @@ impl Context<'_, '_> { let Self { span, target, - origin, impls, .. } = self; @@ -92,11 +94,7 @@ impl Context<'_, '_> { impl ::core::hash::Hash for #target { #[inline(always)] fn hash(&self, state: &mut H) { - if let Ok(origin) = <#origin as ::core::convert::TryFrom<#target>>::try_from(*self) { - <#origin as ::core::hash::Hash>::hash(&origin, state) - } else { - ::core::hash::Hash::hash(&<#target as ::ffi_enum::FfiEnum>::UNKNOWN.repr, state) - } + self.repr.hash(state) } } }) @@ -109,18 +107,21 @@ impl Context<'_, '_> { target, origin, impls, + repr, .. } = self; impls.extend(quote_spanned! { span => impl ::core::str::FromStr for #target { - type Err = ::core::convert::Infallible; + type Err = <#repr as ::core::str::FromStr>::Err; #[inline(always)] fn from_str(s: &str) -> Result { Ok( - <#origin as ::core::str::FromStr>::from_str(s) - .map_or_else(|_| ::UNKNOWN, |v| v.into()), + match <#origin as ::core::str::FromStr>::from_str(s) { + Ok(value) => Ok(value.into()) + Err(err) => #repr::from_str(s), + } ) } } @@ -147,10 +148,10 @@ impl Context<'_, '_> { where S: #serde::Serializer, { - if let Ok(origin) = <#origin as ::core::convert::TryFrom<#target>>::try_from(*self) { - <#origin as #serde::Serialize>::serialize(&origin, serializer) + if let Ok(val) = <#origin as ::core::convert::TryFrom<#target>>::try_from(*self) { + <#origin as #serde::Serialize>::serialize(&val, serializer) } else { - ::serialize_unit_variant(serializer, #name, <#target as ::ffi_enum::FfiEnum>::UNKNOWN.repr as _, "") + serializer.serialize_unit_variant(#name, self.repr as _, "") } } } @@ -160,13 +161,67 @@ impl Context<'_, '_> { fn deserialize(self) { let Self { target, - origin, span, attrs, impls, + repr, + variants, .. } = self; let serde = locate_serde(attrs); + let renameall = rename_with( + attrs + .iter() + .find_map(list_attr_filter_map("serde", |path, val| { + if !path.is_ident("rename_all") { + return None; + } + Some(val.value()) + })) + .unwrap_or("PascalCase".to_string()), + ); + let variants = variants + .iter() + .map(|v| { + (v.ident.clone(), { + let mut name = renameall(v.ident.to_string().as_ref()); + let mut names = v + .attrs + .iter() + .filter_map(list_attr_filter_map("serde", |path, val| { + if path.is_ident("rename") { + name = rename_with(val.value())(v.ident.to_string().as_ref()); + None + } else if path.is_ident("alias") { + Some(val.value()) + } else { + None + } + })) + .collect::>(); + names.push(name); + names + }) + }) + .collect::>(); + + let visit_repr = Ident::new( + &format!("visit_{}", repr.to_token_stream()), + Span::mixed_site(), + ); + + let names = variants.iter().map(|v| &v.0).collect::>(); + let aliases = variants.iter().map(|v| &v.1).map(|v| { + v.iter() + .map(|v| Lit::Str(LitStr::new(v, Span::call_site()))) + .collect::>() + }); + + let baliases = variants.iter().map(|v| &v.1).map(|v| { + v.iter() + .map(|v| Lit::ByteStr(LitByteStr::new(v.as_bytes(), Span::call_site()))) + .collect::>() + }); impls.extend(quote_spanned! { span => impl<'de> #serde::Deserialize<'de> for #target { @@ -175,7 +230,44 @@ impl Context<'_, '_> { where D: #serde::Deserializer<'de>, { - Ok(<#origin as #serde::Deserialize>::deserialize(deserializer).map(Into::into).unwrap_or_else(|_|<#target as ::ffi_enum::FfiEnum>::UNKNOWN)) + struct Visitor; + impl<'de> #serde::de::Visitor<'de> for Visitor { + type Value = #target; + + fn expecting( + &self, + formatter: &mut ::core::fmt::Formatter + ) -> ::core::fmt::Result { + formatter.write_str("variant identifier") + } + + fn #visit_repr(self,v:#repr)->::core::result::Result{ + Ok(match v{ + #(v if v==#target::#names.repr => #target::#names,)* + v => #target{repr:v} + }) + } + + fn visit_str(self,v:&str)->::core::result::Result{ + match v{ + #(#aliases => Ok(#target::#names),)* + v => Err(#serde::de::Error::unknown_variant(v,VARIANTS)), + } + } + + fn visit_bytes(self,v:&[u8])->::core::result::Result{ + match v{ + #(#baliases => Ok(#target::#names),)* + v => Err(#serde::de::Error::unknown_variant(&::std::string::String::from_utf8_lossy(v),VARIANTS)), + } + } + } + const VARIANTS:&'static[&'static str]=&[#(stringify!(#names),)*]; + deserializer.deserialize_identifier( + // #name, + // VARIANTS, + Visitor, + ) } } }); @@ -208,18 +300,84 @@ fn locate_serde(attrs: &[Attribute]) -> Path { }) } +fn list_attr_filter_map<'a, T>( + name: &'a str, + mut f: impl FnMut(Path, syn::LitStr) -> Option + 'a, +) -> impl FnMut(&'a Attribute) -> Option + 'a { + move |attr| { + let Attribute { + style: syn::AttrStyle::Outer, + meta: Meta::List(syn::MetaList { path, tokens, .. }), + .. + } = attr + else { + return None; + }; + if !path.is_ident(name) { + return None; + }; + Punctuated::::parse_terminated + .parse2(tokens.clone()) + .ok()? + .into_iter() + .find_map(|meta| { + let Meta::NameValue(syn::MetaNameValue { + path, + value: + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(val), + .. + }), + .. + }) = meta + else { + return None; + }; + f(path, val) + }) + } +} + +fn rename_with(fmt: String) -> impl Fn(&str) -> String { + move |variant| match fmt.as_str() { + "PascalCase" => variant.to_owned(), + "lowercase" => variant.to_ascii_lowercase(), + "UPPERCASE" => variant.to_ascii_uppercase(), + "camelCase" => variant[..1].to_ascii_lowercase() + &variant[1..], + "snake_case" => { + let mut snake = String::new(); + for (i, ch) in variant.char_indices() { + if i > 0 && ch.is_uppercase() { + snake.push('_'); + } + snake.push(ch.to_ascii_lowercase()); + } + snake + } + "SCREAMING_SNAKE_CASE" => { + rename_with("snake_case".to_string())(variant).to_ascii_uppercase() + } + "kebab-case" => rename_with("snake_case".to_string())(variant).replace('_', "-"), + "SCREAMING-KEBAB-CASE" => { + rename_with("SCREAMING_SNAKE_CASE".to_string())(variant).replace('_', "-") + } + _ => panic!("error fmt with {}", fmt), + } +} + pub fn delegate<'a, 'b>( name: &'a str, target: &'a Type, origin: &'a Type, attrs: &'a [Attribute], + repr: &'a Type, + variants: &'a Punctuated, impls: &'b mut TokenStream, ) -> impl FnMut(Meta) + use<'a, 'b> { |meta| { let Meta::Path(path) = meta else { return; }; - let Some(ident) = path.get_ident() else { return; }; @@ -230,6 +388,8 @@ pub fn delegate<'a, 'b>( target, origin, attrs, + repr, + variants, impls, }; diff --git a/macros/src/ffi_enum.rs b/macros/src/ffi_enum.rs index 6d2f62a..581a8ed 100644 --- a/macros/src/ffi_enum.rs +++ b/macros/src/ffi_enum.rs @@ -85,7 +85,7 @@ pub fn handle(Args(args): Args, input: ItemEnum) -> Result { let mut impls = Default::default(); extract_meta_from_lists(attrs, "derive").for_each(super::delegate::delegate( - &name, &target, &origin, attrs, &mut impls, + &name, &target, &origin, attrs, &repr, variants, &mut impls, )); let types = [ @@ -133,13 +133,6 @@ pub fn handle(Args(args): Args, input: ItemEnum) -> Result { impl ::ffi_enum::FfiEnum for #target { type Enum = #origin; type Repr = #repr; - const UNKNOWN: Self = { - let mut result: #repr = 0; - while #(result == #target::#variant_ids.repr ||)* false { - result += 1; - } - #target { repr: result } - }; } #[allow(dead_code, non_upper_case_globals)] diff --git a/src/lib.rs b/src/lib.rs index 0737f08..71fee4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,12 +20,6 @@ pub trait FfiEnum: Copy + Eq { /// The representation type of the `ffi_enum` type type Repr: Copy + From + Into; - - /// A sample of an unknown values - /// - /// Note that a `ffi_enum` accepts any value of the representation type, so `some_value != FfiEnum::UNKNOWN` **does not** - /// indicate that `some_value` is known - const UNKNOWN: Self; } /// Convenient operations on `FfiEnum`