forked from a-agmon/rs-graph-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcomplex_example.rs
More file actions
228 lines (193 loc) · 8.71 KB
/
Copy pathcomplex_example.rs
File metadata and controls
228 lines (193 loc) · 8.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
use async_trait::async_trait;
use graph_flow::{
Context, ExecutionStatus, FlowRunner, GraphBuilder, GraphStorage, InMemoryGraphStorage,
InMemorySessionStorage, NextAction, Session, SessionStorage, Task, TaskResult,
};
use rig::completion::{Chat, Message};
use rig::prelude::*;
use serde::Deserialize;
use std::sync::Arc;
use tracing::{Level, info};
// --- Sentiment analysis helpers -------------------------------------------------------------
/// JSON object the model is asked to return (see `SENTIMENT_PROMPT`),
/// e.g. `{ "sentiment": "positive" }`.
#[derive(Deserialize)]
struct SentimentResponse {
    /// Label reported by the model; the prompt requests `"positive"` or `"negative"`.
    sentiment: String,
}
/// System preamble installed on the sentiment agent (via `.preamble(...)` in `get_llm_agent`).
/// It asks the model to answer with exactly one of two JSON objects, or with a short
/// clarifying question when it is unsure. NOTE: the raw-string body is runtime text and its
/// line layout is sent to the model verbatim — do not re-indent it.
const SENTIMENT_PROMPT: &str = r#"You are a helpful sentiment analysis assistant.
ANALYZE THE USER INPUT AND RESPOND **ONLY** WITH ONE OF THE FOLLOWING JSON OBJECTS:
{ "sentiment": "positive" }
{ "sentiment": "negative" }
If you are not sure, ask a short clarifying question **instead** of returning JSON. Do not add any additional text around the JSON.
"#;
/// Builds the OpenRouter-backed `rig` agent used for sentiment classification.
///
/// The agent runs `openai/gpt-4o-mini` with [`SENTIMENT_PROMPT`] as its preamble.
///
/// # Errors
///
/// Returns an error when `OPENROUTER_API_KEY` is unset **or blank**, or when the client
/// cannot be constructed. Callers use this error as the signal to fall back to the
/// offline heuristic.
fn get_llm_agent() -> anyhow::Result<impl rig::completion::Chat> {
    // A blank key is as unusable as a missing one. Treating both the same keeps the
    // offline fallback working instead of failing later on the first request.
    let api_key = std::env::var("OPENROUTER_API_KEY")
        .ok()
        .filter(|key| !key.trim().is_empty())
        .ok_or_else(|| anyhow::anyhow!("OPENROUTER_API_KEY not set"))?;

    let client = rig::providers::openrouter::Client::new(&api_key)
        .map_err(|e| anyhow::anyhow!("Failed to create client: {}", e))?;

    Ok(client
        .agent("openai/gpt-4o-mini")
        .preamble(SENTIMENT_PROMPT)
        .build())
}
// --- Task A: run sentiment analysis ---------------------------------------------------------
/// Task A: classifies the sentiment of the `user_input` value stored in the session context.
///
/// When the model answers with the expected JSON, the label is normalised and stored under
/// `sentiment`, and the flow continues straight into the next task. Any other reply is
/// treated as a clarifying question: it is recorded as an assistant message, returned as the
/// task response, and the flow waits for input. If no LLM agent can be built, it falls back
/// to `dummy_sentiment`.
struct SentimentAnalysisTask;

#[async_trait]
impl Task for SentimentAnalysisTask {
    async fn run(&self, context: Context) -> graph_flow::Result<TaskResult> {
        // Pull the user input we stored in the session context (empty if absent).
        let user_input: String = context.get_sync("user_input").unwrap_or_default();

        let agent = match get_llm_agent() {
            Ok(agent) => agent,
            Err(e) => {
                // If the agent cannot be created (for example, the API key is missing) we fall back
                // to a dummy implementation so that this example can still be executed without an LLM.
                info!(error = %e, "Falling back to dummy sentiment detection (LLM not available)");
                return self.dummy_sentiment(context, user_input).await;
            }
        };

        // No chat history is used here for simplicity; rig still expects a buffer, so pass an empty one.
        let response = agent
            .chat(&user_input, &mut Vec::<Message>::new())
            .await
            .map_err(|e| graph_flow::GraphError::TaskExecutionFailed(e.to_string()))?;

        if let Ok(parsed) = serde_json::from_str::<SentimentResponse>(strip_code_fence(&response))
        {
            // The conditional edge compares against the exact lowercase label. Normalise casing
            // and whitespace so that e.g. "Positive" does not silently route to the negative branch.
            let sentiment = parsed.sentiment.trim().to_lowercase();
            info!(%sentiment, "Sentiment detected – continuing");
            // Persist the sentiment in the context so that the conditional edge can read it.
            context.set("sentiment", sentiment).await;
            // Proceed straight to the next task and execute it immediately.
            return Ok(TaskResult::new(None, NextAction::ContinueAndExecute));
        }

        // The model did not return the expected JSON: treat its reply as a clarifying question.
        context.add_assistant_message(response.clone()).await;
        Ok(TaskResult::new(
            Some(response),
            NextAction::WaitForInput, // Wait for the user to answer the clarifying question.
        ))
    }
}

/// Removes an optional Markdown code fence (```` ```json … ``` ```` or ```` ``` … ``` ````)
/// around `reply` and trims surrounding whitespace.
///
/// Chat models frequently fence JSON even when told not to. Without this step, a well-formed
/// answer would fail to parse and be misread as a clarifying question. Input that is not
/// fully fenced is returned trimmed but otherwise unchanged.
fn strip_code_fence(reply: &str) -> &str {
    let trimmed = reply.trim();
    let Some(inner) = trimmed.strip_prefix("```") else {
        return trimmed;
    };
    let Some(inner) = inner.strip_suffix("```") else {
        return trimmed;
    };
    // Drop an optional language tag such as `json` that follows the opening fence.
    inner
        .trim_start_matches(|c: char| c.is_ascii_alphanumeric())
        .trim()
}
impl SentimentAnalysisTask {
    /// Offline fallback used when no LLM agent can be built.
    ///
    /// Labels the input `"positive"` if it contains "good" or "love" (case-insensitive),
    /// otherwise `"negative"`. The label is stored under `sentiment` for the conditional edge.
    /// Like the LLM path, it then continues straight into the selected branch task.
    async fn dummy_sentiment(
        &self,
        context: Context,
        user_input: String,
    ) -> graph_flow::Result<TaskResult> {
        let lowered = user_input.to_lowercase();
        let sentiment = if ["good", "love"].iter().any(|&word| lowered.contains(word)) {
            "positive"
        } else {
            "negative"
        };
        context.set("sentiment", sentiment.to_string()).await;
        // Use the same next action as the LLM path so both code paths drive the flow identically.
        Ok(TaskResult::new(None, NextAction::ContinueAndExecute))
    }
}
// --- Task B: positive branch ----------------------------------------------------------------
/// Terminal task for the positive branch: replies with an upbeat message and ends the flow.
struct PositiveResponseTask;

#[async_trait]
impl Task for PositiveResponseTask {
    async fn run(&self, _context: Context) -> graph_flow::Result<TaskResult> {
        const CHEERFUL_REPLY: &str = "That is awesome to hear! Keep up the good vibes.";
        let outcome = TaskResult::new(Some(String::from(CHEERFUL_REPLY)), NextAction::End);
        Ok(outcome)
    }
}
// --- Task C: negative branch ----------------------------------------------------------------
/// Terminal task for the negative branch: replies with a sympathetic message and ends the flow.
struct NegativeResponseTask;

#[async_trait]
impl Task for NegativeResponseTask {
    async fn run(&self, _context: Context) -> graph_flow::Result<TaskResult> {
        const SYMPATHETIC_REPLY: &str =
            "I am sorry to hear that. Let me know if there is anything I can do to help.";
        let outcome = TaskResult::new(Some(String::from(SYMPATHETIC_REPLY)), NextAction::End);
        Ok(outcome)
    }
}
// --------------------------------------------------------------------------------------------
/// Prompts on stdout and reads one line of user input from stdin.
///
/// Returns `Ok(None)` when stdin is closed or the line is blank. The caller can then stop
/// instead of re-running the flow with unchanged input.
async fn read_user_reply() -> Result<Option<String>, Box<dyn std::error::Error>> {
    // Reading stdin blocks the calling thread, so do it off the async runtime's workers.
    let line = tokio::task::spawn_blocking(|| -> std::io::Result<String> {
        use std::io::Write;
        print!("You: ");
        std::io::stdout().flush()?;
        let mut buf = String::new();
        std::io::stdin().read_line(&mut buf)?;
        Ok(buf)
    })
    .await??;
    let reply = line.trim();
    Ok((!reply.is_empty()).then(|| reply.to_string()))
}

/// Builds the sentiment-routing graph, seeds a session with the command-line text, and drives
/// the flow until it completes.
///
/// If the sentiment task asks a clarifying question, the user's answer is read from stdin and
/// appended to `user_input` before the flow is resumed.
///
/// # Errors
///
/// Propagates storage/runner errors, and returns an error if the workflow reports
/// `ExecutionStatus::Error`.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // A little bit of logging so that the flow is easier to follow when the example is run.
    tracing_subscriber::fmt()
        .with_max_level(Level::INFO)
        .compact()
        .init();

    // Everything after the program name is the text to analyse; fall back to a default.
    let cli_words: Vec<String> = std::env::args().skip(1).collect();
    let user_input = if cli_words.is_empty() {
        "I feel good today".to_string()
    } else {
        cli_words.join(" ")
    };
    info!(%user_input, "Starting complex example");

    // --- Storage ---------------------------------------------------------------------------------
    let session_storage: Arc<dyn SessionStorage> = Arc::new(InMemorySessionStorage::new());
    let graph_storage: Arc<dyn GraphStorage> = Arc::new(InMemoryGraphStorage::new());

    // --- Build graph -----------------------------------------------------------------------------
    let sentiment_task: Arc<dyn Task> = Arc::new(SentimentAnalysisTask);
    let positive_task: Arc<dyn Task> = Arc::new(PositiveResponseTask);
    let negative_task: Arc<dyn Task> = Arc::new(NegativeResponseTask);
    let sentiment_id = sentiment_task.id().to_string();
    let positive_id = positive_task.id().to_string();
    let negative_id = negative_task.id().to_string();
    let graph = Arc::new(
        GraphBuilder::new("sentiment_flow")
            .add_task(sentiment_task)
            .add_task(positive_task)
            .add_task(negative_task)
            // Conditional routing based on the sentiment detected in the first task.
            .add_conditional_edge(
                sentiment_id.clone(),
                |context| {
                    context
                        .get_sync::<String>("sentiment")
                        .map(|s| s == "positive")
                        .unwrap_or(false)
                },
                positive_id.clone(),
                negative_id.clone(),
            )
            .build(),
    );
    graph_storage
        .save("sentiment_flow".to_string(), graph.clone())
        .await?;

    // --- Session ---------------------------------------------------------------------------------
    let session_id = "sentiment_session_001".to_string();
    let session = Session::new_from_task(session_id.clone(), &sentiment_id);
    // Seed the session context with the user input gathered on the command line.
    session.context.set("user_input", user_input.clone()).await;
    // Persist the session before we start executing the graph.
    session_storage.save(session.clone()).await?;
    info!(%session_id, "Session created");

    // --- Execute ---------------------------------------------------------------------------------
    let runner = FlowRunner::new(graph.clone(), session_storage.clone());
    loop {
        let execution_result = runner.run(&session_id).await?;
        if let Some(resp) = execution_result.response {
            println!("Assistant: {}", resp);
        }
        match execution_result.status {
            ExecutionStatus::Completed => {
                info!("Workflow completed successfully");
                break;
            }
            ExecutionStatus::Paused { next_task_id, reason } => {
                info!("Workflow paused, will continue to task: {} (reason: {})", next_task_id, reason);
                continue;
            }
            ExecutionStatus::WaitingForInput => {
                // The task is parked on a clarifying question. Re-running it with unchanged
                // input would repeat the same LLM call indefinitely, so get an answer first.
                let Some(reply) = read_user_reply().await? else {
                    info!("No reply provided; stopping the workflow");
                    break;
                };
                let session = session_storage
                    .get(&session_id)
                    .await?
                    .ok_or("session disappeared from storage")?;
                // The sentiment task sends only `user_input` to the model (no chat history).
                // Append the answer so the original statement stays in view.
                let previous: String = session.context.get_sync("user_input").unwrap_or_default();
                session
                    .context
                    .set("user_input", format!("{previous}\n{reply}"))
                    .await;
                session_storage.save(session).await?;
                continue;
            }
            ExecutionStatus::Error(e) => {
                // Surface the failure as a non-zero exit instead of reporting success.
                return Err(format!("workflow failed: {e}").into());
            }
        }
    }
    Ok(())
}