Skip to content

Commit 8530ff6

Browse files
Fix DekuRead for #[repr(...)] enums constructed using id via ctx and whose variants assign discriminant values (sharksforarms#577)
* deku-derive: Derive Clone, Copy for ReprType ReprType is a simple enum; copy and clone operations have the same cost as forming pointers, so let's allow them to reduce some source noise. Signed-off-by: Andrew Jeffery <andrew@codeconstruct.com.au> * deku-derive: Prune discriminant error paths with let-else Reduce indentation and improve readability by eliminating the match expression. Also, in my opinion, describing the constraints on the goal in a linear fashion makes the code easier to reason about. Signed-off-by: Andrew Jeffery <andrew@codeconstruct.com.au> * deku-derive: Drop redundant <T>::try_from() for discriminant access The unsafe expression already dereferences the pointer cast to the target type. Signed-off-by: Andrew Jeffery <andrew@codeconstruct.com.au> * deku-derive: emit_enum(): Prefer repr over id_type for conversions Use of the `id` attribute in enum contexts precludes use of `id_type`, under which condition id_type is None. A None id_type breaks the generated token stream at the point of the cast in the unsafe expression accessing the discriminant value. However, we must have a valid `repr` and that repr must match id_type where it's present, by elimination of paths to the contrary in the existing error handling. Emit the token for the repr type in place of id_type for the purpose of the type conversion logic. Signed-off-by: Andrew Jeffery <andrew@codeconstruct.com.au> --------- Signed-off-by: Andrew Jeffery <andrew@codeconstruct.com.au> Co-authored-by: Emmanuel Thompson <ethompson@fastly.com>
1 parent 01b9f05 commit 8530ff6

4 files changed

Lines changed: 82 additions & 24 deletions

File tree

deku-derive/src/lib.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ struct DekuData {
184184
bit_order: Option<syn::LitStr>,
185185
}
186186

187-
#[derive(Debug, PartialEq, Eq)]
187+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
188188
enum ReprType {
189189
U8,
190190
U16,
@@ -198,6 +198,23 @@ enum ReprType {
198198
I128,
199199
}
200200

201+
impl From<ReprType> for TokenStream {
202+
fn from(value: ReprType) -> Self {
203+
match value {
204+
ReprType::U8 => quote! { u8 },
205+
ReprType::U16 => quote! { u16 },
206+
ReprType::U32 => quote! { u32 },
207+
ReprType::U64 => quote! { u64 },
208+
ReprType::U128 => quote! { u128 },
209+
ReprType::I8 => quote! { i8 },
210+
ReprType::I16 => quote! { i16 },
211+
ReprType::I32 => quote! { i32 },
212+
ReprType::I64 => quote! { i64 },
213+
ReprType::I128 => quote! { i128 },
214+
}
215+
}
216+
}
217+
201218
fn from_token(ts: TokenStream) -> Option<ReprType> {
202219
let mut repr_value = None;
203220

deku-derive/src/macros/deku_read.rs

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -257,36 +257,32 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
257257
variant_id_pat.clone()
258258
}
259259
} else if has_discriminant {
260-
match &input.repr {
261-
None => {
260+
let Some(repr) = input.repr else {
261+
return Err(syn::Error::new(
262+
variant.ident.span(),
263+
"DekuRead: `id_type` must be specified on non-unit variants",
264+
));
265+
};
266+
if let Some(id_type) = id_type {
267+
let Some(id_type_repr) = from_token(id_type.clone()) else {
262268
return Err(syn::Error::new(
263269
variant.ident.span(),
264-
"DekuRead: `id_type` must be specified on non-unit variants",
270+
"DekuRead: `repr` must be specified on non-unit variants",
271+
));
272+
};
273+
if id_type_repr != repr {
274+
return Err(syn::Error::new(
275+
variant.ident.span(),
276+
"DekuRead: `repr` must match `id_type`",
265277
));
266-
}
267-
Some(repr) => {
268-
if let Some(id_type) = id_type {
269-
if let Some(id_type_repr) = from_token(id_type.clone()) {
270-
if id_type_repr != *repr {
271-
return Err(syn::Error::new(
272-
variant.ident.span(),
273-
"DekuRead: `repr` must match `id_type`",
274-
));
275-
}
276-
} else {
277-
return Err(syn::Error::new(
278-
variant.ident.span(),
279-
"DekuRead: `repr` must be specified on non-unit variants",
280-
));
281-
}
282-
}
283278
}
284279
}
280+
let repr_type: TokenStream = repr.into();
285281
let ident = &variant.ident;
286282
let internal_ident = gen_internal_field_ident(&quote!(#ident));
287283
pre_match_tokens.push(quote! {
288284
// https://doc.rust-lang.org/reference/items/enumerations.html#r-items.enum.discriminant.access-memory
289-
let #internal_ident = <#id_type>::try_from(unsafe { *(&Self::#ident as *const Self as *const #id_type) })?;
285+
let #internal_ident = unsafe { *(&Self::#ident as *const Self as *const #repr_type) };
290286
});
291287
quote! { _ if __deku_variant_id == #internal_ident }
292288
} else {

deku-derive/src/macros/deku_write.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
267267
// if the variant has fields, the first must be storing the id
268268
quote! {}
269269
} else if has_discriminant {
270-
match &input.repr {
270+
match input.repr {
271271
None => {
272272
return Err(syn::Error::new(
273273
variant.ident.span(),
@@ -277,7 +277,7 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
277277
Some(repr) => {
278278
if let Some(id_type) = id_type {
279279
if let Some(id_type_repr) = from_token(id_type.clone()) {
280-
if id_type_repr != *repr {
280+
if id_type_repr != repr {
281281
return Err(syn::Error::new(
282282
variant.ident.span(),
283283
"DekuWrite: `repr` must match `id_type`",

tests/test_enum.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,48 @@ fn test_variable_endian_enum(input: &[u8], expected: VariableEndian) {
362362
let ret_write: Vec<u8> = ret_read.try_into().unwrap();
363363
assert_eq!(input.to_vec(), ret_write);
364364
}
365+
366+
#[test]
367+
fn test_repr_assignment_with_id_via_ctx() {
368+
use deku::ctx::Endian;
369+
370+
#[derive(Debug, DekuRead, DekuWrite, Eq, PartialEq)]
371+
#[deku(ctx = "endian: Endian, mid: u8", id = "mid", endian = "endian")]
372+
#[repr(u8)]
373+
enum Body {
374+
First = 0x00,
375+
#[deku(id = "0x01")]
376+
Second(u8),
377+
}
378+
379+
#[derive(Debug, DekuRead, DekuWrite, Eq, PartialEq)]
380+
#[deku(endian = "little")]
381+
struct Message {
382+
id: u8,
383+
header: u16,
384+
#[deku(ctx = "*id")]
385+
body: Body,
386+
}
387+
388+
let input = [0u8, 1u8, 0u8];
389+
let mut cursor = Cursor::new(input);
390+
assert_eq!(
391+
Message {
392+
id: 0,
393+
header: 1,
394+
body: Body::First,
395+
},
396+
Message::from_reader((&mut cursor, 0)).unwrap().1
397+
);
398+
399+
let input = [1u8, 2u8, 0u8, 3u8];
400+
let mut cursor = Cursor::new(input);
401+
assert_eq!(
402+
Message {
403+
id: 1,
404+
header: 2,
405+
body: Body::Second(3),
406+
},
407+
Message::from_reader((&mut cursor, 0)).unwrap().1
408+
);
409+
}

0 commit comments

Comments
 (0)