-
Notifications
You must be signed in to change notification settings - Fork 502
Expand file tree
/
Copy pathprompt_handler.rs
More file actions
165 lines (143 loc) · 4.97 KB
/
prompt_handler.rs
File metadata and controls
165 lines (143 loc) · 4.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
use darling::FromMeta;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Expr, ImplItem, ItemImpl, parse_quote};
use crate::{
common::{has_method, has_sibling_handler},
tool_handler::{CallerCapability, build_get_info},
};
/// Arguments parsed from the attribute tokens handed to [`prompt_handler`].
///
/// Both fields are optional; an empty attribute list produces the `Default`
/// value (every field `None`).
#[derive(Debug, Default, FromMeta)]
#[darling(default)]
pub struct PromptHandlerAttribute {
    /// Expression the generated methods call `get_prompt` / `list_all` on.
    /// When absent, [`prompt_handler`] falls back to `Self::prompt_router()`.
    pub router: Option<Expr>,
    /// Expression emitted as `Some(<expr>)` for the `meta` field of the
    /// generated `ListPromptsResult`; `None` is emitted when absent.
    pub meta: Option<Expr>,
}
pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
let attribute = if attr.is_empty() {
Default::default()
} else {
let attr_args = darling::ast::NestedMeta::parse_meta_list(attr)?;
PromptHandlerAttribute::from_list(&attr_args)?
};
let mut impl_block = syn::parse2::<ItemImpl>(input)?;
let router_expr = attribute
.router
.unwrap_or_else(|| syn::parse2(quote! { Self::prompt_router() }).unwrap());
// Add get_prompt implementation
let get_prompt_impl: ImplItem = parse_quote! {
async fn get_prompt(
&self,
request: GetPromptRequestParams,
context: RequestContext<RoleServer>,
) -> Result<GetPromptResult, rmcp::ErrorData> {
let prompt_context = rmcp::handler::server::prompt::PromptContext::new(
self,
request.name,
request.arguments,
context,
);
#router_expr.get_prompt(prompt_context).await
}
};
let meta = if let Some(meta) = attribute.meta {
quote! { Some(#meta) }
} else {
quote! { None }
};
// Add list_prompts implementation
let list_prompts_impl: ImplItem = parse_quote! {
async fn list_prompts(
&self,
_request: Option<PaginatedRequestParams>,
_context: RequestContext<RoleServer>,
) -> Result<ListPromptsResult, rmcp::ErrorData> {
let prompts = #router_expr.list_all();
Ok(ListPromptsResult {
prompts,
meta: #meta,
next_cursor: None,
})
}
};
// Check if methods already exist and replace them if they do
let mut has_get_prompt = false;
let mut has_list_prompts = false;
for item in &mut impl_block.items {
if let ImplItem::Fn(fn_item) = item {
match fn_item.sig.ident.to_string().as_str() {
"get_prompt" => {
*item = get_prompt_impl.clone();
has_get_prompt = true;
}
"list_prompts" => {
*item = list_prompts_impl.clone();
has_list_prompts = true;
}
_ => {}
}
}
}
// Add methods if they don't exist
if !has_get_prompt {
impl_block.items.push(get_prompt_impl);
}
if !has_list_prompts {
impl_block.items.push(list_prompts_impl);
}
// Auto-generate get_info() if not already provided
if !has_method("get_info", &impl_block) {
// Detect whether tool_handler is also present — if so, it will generate get_info
// with both capabilities. Only generate here if tool_handler is NOT present.
if !has_sibling_handler(&impl_block, "tool_handler") {
let get_info_fn =
build_get_info(&impl_block, None, None, None, CallerCapability::Prompts)?;
impl_block.items.push(get_info_fn);
}
}
Ok(quote! {
#impl_block
})
}
#[cfg(test)]
mod test {
    use super::*;

    /// Runs `prompt_handler` over an empty `ServerHandler` impl using the given
    /// attribute tokens and renders the expansion as a string.
    fn expand_to_string(attr: TokenStream) -> syn::Result<String> {
        let empty_impl = quote! {
            impl ServerHandler for MyPromptHandler {}
        };
        prompt_handler(attr, empty_impl).map(|tokens| tokens.to_string())
    }

    /// With no attribute arguments, both prompt methods must be generated.
    #[test]
    fn test_prompt_handler_macro() -> syn::Result<()> {
        let expanded = expand_to_string(TokenStream::new())?;
        let has = |needle: &str| expanded.contains(needle);

        // `get_prompt` exists and constructs a PromptContext.
        assert!(has("async fn get_prompt"));
        assert!(has("PromptContext") && has("new"));

        // `list_prompts` exists and builds a ListPromptsResult.
        assert!(has("async fn list_prompts"));
        assert!(has("ListPromptsResult"));
        Ok(())
    }

    /// A `router = ...` argument must show up in the generated methods.
    #[test]
    fn test_prompt_handler_with_custom_router() -> syn::Result<()> {
        let expanded = expand_to_string(quote! { router = self.get_prompt_router() })?;

        // Pieces of the custom router expression that must appear in the output.
        let mentions_custom_router = ["self", "get_prompt_router"]
            .iter()
            .all(|needle| expanded.contains(*needle));

        assert!(mentions_custom_router && expanded.contains("get_prompt"));
        assert!(mentions_custom_router && expanded.contains("list_all"));
        Ok(())
    }
}