-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmod.rs
More file actions
127 lines (115 loc) · 4.52 KB
/
mod.rs
File metadata and controls
127 lines (115 loc) · 4.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
mod preludes;
use super::{Def, ExtensionFn};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
/// Expands a parsed program definition into the final module token stream.
///
/// The module is extended with one verbatim item containing, in this order:
/// the generated preludes, every extension function (each preceded by its call enum),
/// and the generated `pvq` entry point. The whole module is then returned as tokens.
pub fn expand(mut def: Def) -> TokenStream2 {
    // Preludes are produced from the definition before any extension-function body is
    // rewritten below.
    let preludes = preludes::generate_preludes(&def);

    let codec_path = &def.parity_scale_codec;
    let mut extension_items = Vec::new();
    for extension_fn in def.extension_fns.iter_mut() {
        extension_items.push(expand_extension_fn(extension_fn, codec_path));
    }

    let main_fn = expand_main(&def);

    let generated = quote! {
        #preludes
        #(#extension_items)*
        #main_fn
    };

    let module_items = &mut def
        .item
        .content
        .as_mut()
        .expect("This is checked by parsing")
        .1;
    module_items.push(syn::Item::Verbatim(generated));

    def.item.into_token_stream()
}
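// REVIEW: `expand_extension_fn` and `expand_main` call `unreachable!()` when a function
// argument is a receiver (`self`). Those panics happen at macro-expansion time. The parsing
// stage is said to reject such arguments already, but reporting them through a spanned
// `syn::Error::to_compile_error()` (i.e. `compile_error!`) would give users of the macro a
// clearer diagnostic than a proc-macro panic. The `expect()` calls in these functions are
// different: they are emitted into the generated code and fire at runtime when SCALE decoding
// fails, so they cannot be turned into compile errors.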
/// Rewrites one extension function so that calling it forwards the call to the host.
///
/// Returns the tokens for two items:
/// * `<fn_name>Call`: an enum deriving `Encode`/`Decode` (from `parity_scale_codec`) with a
///   single variant named after the function. The variant is tagged
///   `#[codec(index = fn_index)]` and carries the function's arguments as named fields.
/// * The function itself, whose body is replaced in place on `extension_fn`. The new body
///   encodes the call enum and passes it to `host_call(extension_id, ptr, len)`. It then
///   decodes the return value from the memory the host's packed `u64` result points to.
///
/// # Panics
///
/// Panics at expansion time if an argument is a receiver; the parser is expected to have
/// rejected those already.
fn expand_extension_fn(extension_fn: &mut ExtensionFn, parity_scale_codec: &syn::Path) -> TokenStream2 {
    let extension_id = extension_fn.extension_id;
    let fn_index = extension_fn.fn_index;
    let sig = &extension_fn.item_fn.sig;
    let ident = &sig.ident;
    let inputs = &sig.inputs;
    let call_enum = format_ident!("{}Call", ident);

    // The variant's field names are the argument patterns themselves. This lets the
    // generated body construct the variant with field-init shorthand.
    let mut field_names = Vec::new();
    for input in inputs {
        match input {
            syn::FnArg::Typed(typed) => field_names.push(&typed.pat),
            syn::FnArg::Receiver(_) => unreachable!("Checked in parse stage"),
        }
    }

    let call_enum_def = quote! {
        #[allow(non_camel_case_types)]
        #[derive(#parity_scale_codec::Encode, #parity_scale_codec::Decode)]
        enum #call_enum {
            #[codec(index = #fn_index)]
            #ident {
                #inputs
            }
        }
    };

    let ident_str = ident.to_string();
    // `host_call` returns the result location packed into a u64: the low 32 bits are the
    // pointer and the high 32 bits are the length.
    extension_fn.item_fn.block = Box::new(syn::parse_quote! {
        {
            let encoded_call = #parity_scale_codec::Encode::encode(&#call_enum::#ident {
                #(#field_names),*
            });
            let res = unsafe {
                host_call(#extension_id, encoded_call.as_ptr() as u32, encoded_call.len() as u32)
            };
            let res_ptr = res as u32 as *const u8;
            let res_len = (res >> 32) as usize;
            let mut res_bytes = unsafe { core::slice::from_raw_parts(res_ptr, res_len) };
            #parity_scale_codec::Decode::decode(&mut res_bytes).expect(concat!("Failed to decode result of ", #ident_str))
        }
    });

    let rewritten_fn = &extension_fn.item_fn;
    quote! {
        #call_enum_def
        #rewritten_fn
    }
}
/// Generates the guest entry point `pvq`, exported via `polkavm_derive::polkavm_export`.
///
/// The generated function reads `size` bytes starting at `arg_ptr`:
/// * The first byte selects an entrypoint by its position in `def.entrypoints`.
/// * The remaining bytes are SCALE-decoded, in declaration order, into that entrypoint's
///   arguments.
///
/// The entrypoint's return value is SCALE-encoded. `pvq` returns it as a packed `u64`:
/// the length in the upper 32 bits and the pointer in the lower 32 bits. The encoded buffer is
/// intentionally leaked, because the caller reads it after `pvq` has returned.
///
/// The generated code panics at runtime, in the guest, when:
/// * the input is empty;
/// * the index matches no entrypoint;
/// * an argument fails to decode.
///
/// # Panics
///
/// Panics at expansion time if an entrypoint argument is a receiver; the parser is expected to
/// have rejected those already.
fn expand_main(def: &Def) -> TokenStream2 {
    let parity_scale_codec = &def.parity_scale_codec;

    // One match arm per entrypoint, keyed by its position in `def.entrypoints`.
    let match_arms = def.entrypoints.iter().enumerate().map(|(index, entrypoint)| {
        let entrypoint_ident = &entrypoint.item_fn.sig.ident;
        // Full `pat: Type` arguments, reused verbatim as typed `let` bindings.
        let arg_pats = entrypoint.item_fn.sig.inputs.iter().collect::<Vec<_>>();
        // Bare patterns, used to pass the decoded bindings on to the entrypoint.
        let arg_identifiers = entrypoint
            .item_fn
            .sig
            .inputs
            .iter()
            .map(|arg| {
                let syn::FnArg::Typed(pat_type) = arg else {
                    unreachable!("Checked in parse stage")
                };
                &pat_type.pat
            })
            .collect::<Vec<_>>();
        quote! {
            #index => {
                #(let #arg_pats = #parity_scale_codec::Decode::decode(&mut arg_bytes)
                    .expect(concat!("Failed to decode arguments for ", stringify!(#entrypoint_ident)));)*
                let res = #entrypoint_ident(#(#arg_identifiers),*);
                // The caller reads this buffer only after `pvq` returns, so it must not be
                // freed at the end of this arm. Leak it on purpose.
                // NOTE(review): assumes the guest instance is discarded after the call — confirm.
                let encoded_res = core::mem::ManuallyDrop::new(
                    #parity_scale_codec::Encode::encode(&res)
                );
                ((encoded_res.len() as u64) << 32) | (encoded_res.as_ptr() as u64)
            }
        }
    });

    quote! {
        #[polkavm_derive::polkavm_export]
        extern "C" fn pvq(arg_ptr: u32, size: u32) -> u64 {
            // At least the function-index byte is required. Rejecting empty input up front
            // avoids an out-of-bounds read and the `size - 1` underflow.
            if size == 0 {
                panic!("Missing function index");
            }
            // SAFETY (assumed): the host passes `size` readable bytes starting at `arg_ptr`.
            let input = unsafe { core::slice::from_raw_parts(arg_ptr as *const u8, size as usize) };
            // First stage: read fn_index
            let fn_index = input[0] as usize;
            // Second stage: the remaining bytes are the SCALE-encoded arguments
            let mut arg_bytes = &input[1..];
            match fn_index {
                #(#match_arms,)*
                _ => panic!("Invalid function index"),
            }
        }
    }
}