|
1 | 1 | use proc_macro::TokenStream; |
2 | | -use quote::quote; |
3 | | -use syn::Attribute; |
| 2 | +use quote::{format_ident, quote}; |
| 3 | +use syn::{ |
| 4 | + Data, DeriveInput, Fields, Type, Variant, parse_macro_input, punctuated::Punctuated, |
| 5 | + token::Comma, |
| 6 | +}; |
4 | 7 |
|
5 | | -#[proc_macro_derive(EnumDiscriminate)] |
6 | | -pub fn enum_discriminate_derive(input: TokenStream) -> TokenStream { |
7 | | - // Construct a representation of Rust code as a syntax tree |
8 | | - // that we can manipulate. |
9 | | - let ast = syn::parse(input).unwrap(); |
| 8 | +#[proc_macro_derive(ByteCodec)] |
| 9 | +pub fn derive_byte_codec(input: TokenStream) -> TokenStream { |
| 10 | + let input = parse_macro_input!(input as DeriveInput); |
| 11 | + let type_name = &input.ident; |
10 | 12 |
|
11 | | - // Build the trait implementation. |
12 | | - impl_enum_discriminate_derive(&ast) |
| 13 | + let expanded = match &input.data { |
| 14 | + Data::Enum(data_enum) => { |
| 15 | + let variants = &data_enum.variants; |
| 16 | + let max_size_impl = build_enum_max_serialized_size(variants); |
| 17 | + let serialize_impl = build_serialize(type_name, variants); |
| 18 | + let deserialize_impl = build_deserialize(type_name, variants); |
| 19 | + |
| 20 | + quote! { |
| 21 | + impl liquidcan_rust_macros::byte_codec::ByteCodec for #type_name { |
| 22 | + #max_size_impl |
| 23 | + #serialize_impl |
| 24 | + #deserialize_impl |
| 25 | + } |
| 26 | + } |
| 27 | + } |
| 28 | + Data::Struct(data_struct) => { |
| 29 | + let max_size_impl = build_struct_max_serialized_size(&data_struct.fields); |
| 30 | + let serialize_impl = build_struct_serialize(type_name, &data_struct.fields); |
| 31 | + let deserialize_impl = build_struct_deserialize(type_name, &data_struct.fields); |
| 32 | + |
| 33 | + quote! { |
| 34 | + impl liquidcan_rust_macros::byte_codec::ByteCodec for #type_name { |
| 35 | + #max_size_impl |
| 36 | + #serialize_impl |
| 37 | + #deserialize_impl |
| 38 | + } |
| 39 | + } |
| 40 | + } |
| 41 | + Data::Union(_) => panic!("ByteCodec cannot be derived for unions"), |
| 42 | + }; |
| 43 | + |
| 44 | + TokenStream::from(expanded) |
13 | 45 | } |
14 | 46 |
|
15 | | -fn has_repr_u8(attrs: &[Attribute]) -> bool { |
16 | | - let mut is_u8 = false; |
17 | | - for attr in attrs { |
18 | | - if attr.path().is_ident("repr") { |
19 | | - attr.parse_nested_meta(|meta| { |
20 | | - if meta.path.is_ident("u8") { |
21 | | - is_u8 = true; |
| 47 | +fn build_serialize( |
| 48 | + enum_name: &syn::Ident, |
| 49 | + variants: &Punctuated<Variant, Comma>, |
| 50 | +) -> proc_macro2::TokenStream { |
| 51 | + let match_arms = variants.iter().map(|variant| { |
| 52 | + let variant_name = &variant.ident; |
| 53 | + let (_, variant_discriminant) = variant |
| 54 | + .discriminant |
| 55 | + .as_ref() |
| 56 | + .expect("Must explicitly specify discriminant"); |
| 57 | + |
| 58 | + match &variant.fields { |
| 59 | + Fields::Named(fields) => { |
| 60 | + let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect(); |
| 61 | + |
| 62 | + quote! { |
| 63 | + #enum_name::#variant_name { #(#field_names),* } => { |
| 64 | + out.push(#variant_discriminant); |
| 65 | + #( #field_names.serialize(out); )* |
| 66 | + } |
| 67 | + } |
| 68 | + } |
| 69 | + Fields::Unnamed(fields) => { |
| 70 | + // Generate dummy identifiers for the tuple fields (e.g., f0, f1, f2) |
| 71 | + let field_idents: Vec<_> = (0..fields.unnamed.len()) |
| 72 | + .map(|i| format_ident!("f{}", i)) |
| 73 | + .collect(); |
| 74 | + |
| 75 | + quote! { |
| 76 | + #enum_name::#variant_name( #(#field_idents),* ) => { |
| 77 | + out.push(#variant_discriminant); |
| 78 | + #( #field_idents.serialize(out); )* |
| 79 | + } |
| 80 | + } |
| 81 | + } |
| 82 | + Fields::Unit => { |
| 83 | + quote! { |
| 84 | + #enum_name::#variant_name => { |
| 85 | + out.push(#variant_discriminant); |
| 86 | + } |
22 | 87 | } |
23 | | - Ok(()) |
24 | | - }) |
25 | | - .unwrap() |
| 88 | + } |
26 | 89 | } |
| 90 | + }); |
| 91 | + |
| 92 | + let expanded = quote! { fn serialize(&self, out: &mut Vec<u8>) { |
| 93 | + let out_len_before = out.len(); |
| 94 | + match self { |
| 95 | + #(#match_arms)* |
| 96 | + } |
| 97 | + assert!(out.len() - out_len_before <= Self::MAX_SERIALIZED_SIZE, "Serialized data exceeds MAX_SERIALIZED_SIZE"); |
| 98 | + }}; |
| 99 | + |
| 100 | + return expanded; |
| 101 | +} |
| 102 | + |
| 103 | +fn max_size_for_type(ty: &Type) -> proc_macro2::TokenStream { |
| 104 | + quote! { |
| 105 | + <#ty as liquidcan_rust_macros::byte_codec::ByteCodec>::MAX_SERIALIZED_SIZE |
27 | 106 | } |
28 | | - is_u8 |
29 | 107 | } |
30 | 108 |
|
31 | | -fn impl_enum_discriminate_derive(ast: &syn::DeriveInput) -> TokenStream { |
32 | | - let name = &ast.ident; |
33 | | - if !has_repr_u8(&ast.attrs) { |
34 | | - panic!("EnumDiscriminate can only be derived for enums which have the u8 repr"); |
| 109 | +fn sum_max_sizes(types: Vec<&Type>) -> proc_macro2::TokenStream { |
| 110 | + let field_sizes: Vec<_> = types.into_iter().map(max_size_for_type).collect(); |
| 111 | + quote! { |
| 112 | + 0usize #( + #field_sizes )* |
35 | 113 | } |
36 | | - let generated = quote! { |
37 | | - impl #name { |
38 | | - pub const fn discriminant(&self) -> u8 { |
39 | | - // SAFETY: Because we require the enum to be marked as `repr(u8)`, its layout is a `repr(C)` `union` |
40 | | - // between `repr(C)` structs, each of which has the `u8` discriminant as its first |
41 | | - // field, so we can read the discriminant without offsetting the pointer. |
42 | | - unsafe { |
43 | | - let ptr = self as *const Self; |
44 | | - let discriminant_ptr = ptr.cast::<u8>(); |
45 | | - *discriminant_ptr |
| 114 | +} |
| 115 | + |
| 116 | +fn build_struct_max_serialized_size(fields: &Fields) -> proc_macro2::TokenStream { |
| 117 | + let payload_size = match fields { |
| 118 | + Fields::Named(named) => { |
| 119 | + let types: Vec<_> = named.named.iter().map(|f| &f.ty).collect(); |
| 120 | + sum_max_sizes(types) |
| 121 | + } |
| 122 | + Fields::Unnamed(unnamed) => { |
| 123 | + let types: Vec<_> = unnamed.unnamed.iter().map(|f| &f.ty).collect(); |
| 124 | + sum_max_sizes(types) |
| 125 | + } |
| 126 | + Fields::Unit => quote! { 0usize }, |
| 127 | + }; |
| 128 | + |
| 129 | + quote! { |
| 130 | + const MAX_SERIALIZED_SIZE: usize = #payload_size; |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +fn build_enum_max_serialized_size( |
| 135 | + variants: &Punctuated<Variant, Comma>, |
| 136 | +) -> proc_macro2::TokenStream { |
| 137 | + let variant_sizes: Vec<_> = variants |
| 138 | + .iter() |
| 139 | + .map(|variant| { |
| 140 | + let payload_size = match &variant.fields { |
| 141 | + Fields::Named(named) => { |
| 142 | + let types: Vec<_> = named.named.iter().map(|f| &f.ty).collect(); |
| 143 | + sum_max_sizes(types) |
| 144 | + } |
| 145 | + Fields::Unnamed(unnamed) => { |
| 146 | + let types: Vec<_> = unnamed.unnamed.iter().map(|f| &f.ty).collect(); |
| 147 | + sum_max_sizes(types) |
| 148 | + } |
| 149 | + Fields::Unit => quote! { 0usize }, |
| 150 | + }; |
| 151 | + |
| 152 | + quote! { |
| 153 | + 1usize + (#payload_size) |
| 154 | + } |
| 155 | + }) |
| 156 | + .collect(); |
| 157 | + |
| 158 | + let mut max_expr = variant_sizes |
| 159 | + .first() |
| 160 | + .cloned() |
| 161 | + .expect("ByteCodec cannot be derived for enums without variants"); |
| 162 | + |
| 163 | + for size_expr in variant_sizes.iter().skip(1) { |
| 164 | + max_expr = quote! { |
| 165 | + { |
| 166 | + let a = #max_expr; |
| 167 | + let b = #size_expr; |
| 168 | + if a > b { a } else { b } |
| 169 | + } |
| 170 | + }; |
| 171 | + } |
| 172 | + |
| 173 | + quote! { |
| 174 | + const MAX_SERIALIZED_SIZE: usize = #max_expr; |
| 175 | + } |
| 176 | +} |
| 177 | + |
| 178 | +fn build_deserialize( |
| 179 | + enum_name: &syn::Ident, |
| 180 | + variants: &Punctuated<Variant, Comma>, |
| 181 | +) -> proc_macro2::TokenStream { |
| 182 | + let match_arms = variants.iter().map(|variant| { |
| 183 | + let variant_name = &variant.ident; |
| 184 | + let (_, variant_discriminant) = variant |
| 185 | + .discriminant |
| 186 | + .as_ref() |
| 187 | + .expect("Must explicitly specify discriminant"); |
| 188 | + |
| 189 | + match &variant.fields { |
| 190 | + Fields::Named(fields) => { |
| 191 | + let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect(); |
| 192 | + |
| 193 | + quote! { |
| 194 | + #variant_discriminant => { |
| 195 | + #( |
| 196 | + let (#field_names, input) = liquidcan_rust_macros::byte_codec::ByteCodec::deserialize(input)?; |
| 197 | + )* |
| 198 | + Ok((#enum_name::#variant_name { #( #field_names ),* }, input)) |
| 199 | + } |
| 200 | + } |
| 201 | + } |
| 202 | + Fields::Unnamed(fields) => { |
| 203 | + let field_idents: Vec<_> = (0..fields.unnamed.len()) |
| 204 | + .map(|i| format_ident!("f{}", i)) |
| 205 | + .collect(); |
| 206 | + |
| 207 | + quote! { |
| 208 | + #variant_discriminant => { |
| 209 | + #( |
| 210 | + let (#field_idents, input) = liquidcan_rust_macros::byte_codec::ByteCodec::deserialize(input)?; |
| 211 | + )* |
| 212 | + Ok((#enum_name::#variant_name( #( #field_idents ),* ), input)) |
| 213 | + } |
| 214 | + } |
| 215 | + } |
| 216 | + Fields::Unit => { |
| 217 | + quote! { |
| 218 | + #variant_discriminant => Ok((#enum_name::#variant_name, input)) |
46 | 219 | } |
47 | 220 | } |
48 | 221 | } |
49 | | - }; |
50 | | - generated.into() |
| 222 | + }); |
| 223 | + |
| 224 | + let expanded = quote! { fn deserialize(input: &[u8]) -> Result<(Self, &[u8]), liquidcan_rust_macros::byte_codec::DeserializationError> { |
| 225 | + let (discriminant, input) = input |
| 226 | + .split_first() |
| 227 | + .ok_or(liquidcan_rust_macros::byte_codec::DeserializationError::NotEnoughData)?; |
| 228 | + |
| 229 | + match discriminant { |
| 230 | + #(#match_arms,)* |
| 231 | + _ => Err(liquidcan_rust_macros::byte_codec::DeserializationError::InvalidDiscriminant(*discriminant)), |
| 232 | + } |
| 233 | + }}; |
| 234 | + |
| 235 | + return expanded; |
| 236 | +} |
| 237 | + |
| 238 | +fn build_struct_serialize(type_name: &syn::Ident, fields: &Fields) -> proc_macro2::TokenStream { |
| 239 | + match fields { |
| 240 | + Fields::Named(fields) => { |
| 241 | + let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect(); |
| 242 | + quote! { |
| 243 | + fn serialize(&self, out: &mut Vec<u8>) { |
| 244 | + let out_len_before = out.len(); |
| 245 | + let #type_name { #( #field_names ),* } = self; |
| 246 | + #( #field_names.serialize(out); )* |
| 247 | + assert!(out.len() - out_len_before <= Self::MAX_SERIALIZED_SIZE, "Serialized data exceeds MAX_SERIALIZED_SIZE"); |
| 248 | + } |
| 249 | + } |
| 250 | + } |
| 251 | + Fields::Unnamed(fields) => { |
| 252 | + let field_idents: Vec<_> = (0..fields.unnamed.len()) |
| 253 | + .map(|i| format_ident!("f{}", i)) |
| 254 | + .collect(); |
| 255 | + |
| 256 | + quote! { |
| 257 | + fn serialize(&self, out: &mut Vec<u8>) { |
| 258 | + let out_len_before = out.len(); |
| 259 | + let #type_name( #( #field_idents ),* ) = self; |
| 260 | + #( #field_idents.serialize(out); )* |
| 261 | + assert!(out.len() - out_len_before <= Self::MAX_SERIALIZED_SIZE, "Serialized data exceeds MAX_SERIALIZED_SIZE"); |
| 262 | + } |
| 263 | + } |
| 264 | + } |
| 265 | + Fields::Unit => { |
| 266 | + quote! { |
| 267 | + fn serialize(&self, _out: &mut Vec<u8>) { |
| 268 | + } |
| 269 | + } |
| 270 | + } |
| 271 | + } |
| 272 | +} |
| 273 | + |
| 274 | +fn build_struct_deserialize(type_name: &syn::Ident, fields: &Fields) -> proc_macro2::TokenStream { |
| 275 | + match fields { |
| 276 | + Fields::Named(fields) => { |
| 277 | + let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect(); |
| 278 | + |
| 279 | + quote! { |
| 280 | + fn deserialize(input: &[u8]) -> Result<(Self, &[u8]), liquidcan_rust_macros::byte_codec::DeserializationError> { |
| 281 | + #( |
| 282 | + let (#field_names, input) = liquidcan_rust_macros::byte_codec::ByteCodec::deserialize(input)?; |
| 283 | + )* |
| 284 | + Ok((#type_name { #( #field_names ),* }, input)) |
| 285 | + } |
| 286 | + } |
| 287 | + } |
| 288 | + Fields::Unnamed(fields) => { |
| 289 | + let field_idents: Vec<_> = (0..fields.unnamed.len()) |
| 290 | + .map(|i| format_ident!("f{}", i)) |
| 291 | + .collect(); |
| 292 | + |
| 293 | + quote! { |
| 294 | + fn deserialize(input: &[u8]) -> Result<(Self, &[u8]), liquidcan_rust_macros::byte_codec::DeserializationError> { |
| 295 | + #( |
| 296 | + let (#field_idents, input) = liquidcan_rust_macros::byte_codec::ByteCodec::deserialize(input)?; |
| 297 | + )* |
| 298 | + Ok((#type_name( #( #field_idents ),* ), input)) |
| 299 | + } |
| 300 | + } |
| 301 | + } |
| 302 | + Fields::Unit => { |
| 303 | + quote! { |
| 304 | + fn deserialize(input: &[u8]) -> Result<(Self, &[u8]), liquidcan_rust_macros::byte_codec::DeserializationError> { |
| 305 | + Ok((#type_name, input)) |
| 306 | + } |
| 307 | + } |
| 308 | + } |
| 309 | + } |
51 | 310 | } |
0 commit comments