Skip to content

Commit 6ce86ac

Browse files
author
RobJellinghaus
committed
Streaming works!
1 parent d261a96 commit 6ce86ac

4 files changed

Lines changed: 445 additions & 66 deletions

File tree

Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,22 @@ path = "backend/main.rs"
3232

3333
[dependencies]
3434
actix-files = "0.6.0"
35+
bytes = "1.0"
3536
actix-http = "3.6"
3637
actix-multipart = "0.6.0"
3738
actix-web = "4.5"
39+
actix-web-lab = "0.20"
3840
async-graphql = { version = "3.0.38", features = ["chrono"] }
3941
async-graphql-actix-web = "3.0.38"
4042
fang = "0.10.4"
4143
futures-util = "0.3.30"
4244
jsonwebtoken = "8.1.0"
43-
ollama-rs = "0.2.0"
45+
ollama-rs = { version = "0.2.0", features = ["stream"] }
46+
tokio-stream = { version = "0.1", features = ["time"] }
4447
utoipa-swagger-ui = { version="4", features=["actix-web"]}
4548
serde_json = "1"
4649
simple_logger = "5.0"
50+
tracing = "0.1"
4751
tsync = "3"
4852

4953
[dependencies.chrono]

backend/services/chat.rs

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
use actix_web::{post, HttpResponse, web::{Data, Json}};
1+
use actix_web::{post, get, HttpResponse, web::{Data, Json, Query}};
2+
use actix_web_lab::sse::{self, Sse};
23
use ollama_rs::{Ollama, generation::completion::request::GenerationRequest};
4+
use tokio_stream::StreamExt;
5+
use futures_util::stream::Stream;
6+
use tracing::debug;
7+
use std::pin::Pin;
38
use crate::models::suppliers::Suppliers;
49
use create_rust_app::Database;
510

@@ -35,6 +40,44 @@ impl ChatService {
3540
Ok(response.response)
3641
}
3742

43+
pub async fn chat_with_suppliers_stream(
44+
&self,
45+
user_message: String,
46+
db: &Database,
47+
) -> Result<Pin<Box<dyn Stream<Item = Result<String, Box<dyn std::error::Error + Send>>> + Send>>, Box<dyn std::error::Error>> {
48+
// 1. Fetch all current suppliers
49+
let suppliers = self.get_all_suppliers(db).await?;
50+
51+
// 2. Build context with supplier data
52+
let context = self.build_supplier_context(&suppliers);
53+
54+
// 3. Create procurement expert prompt
55+
let full_prompt = self.build_prompt(&context, &user_message);
56+
57+
// 4. Send to Ollama with streaming
58+
let request = GenerationRequest::new("mistral-small3.2:24b".to_string(), full_prompt);
59+
let stream = self.ollama.generate_stream(request).await?;
60+
61+
// 5. Transform the stream to extract text responses
62+
let text_stream = stream.map(|result| {
63+
match result {
64+
Ok(responses) => {
65+
let text = responses.iter()
66+
.map(|resp| resp.response.clone())
67+
.collect::<Vec<String>>()
68+
.join("");
69+
70+
debug!("streaming text response: {text}");
71+
72+
Ok(text)
73+
},
74+
Err(e) => Err(Box::new(e) as Box<dyn std::error::Error + Send>)
75+
}
76+
});
77+
78+
Ok(Box::pin(text_stream))
79+
}
80+
3881
async fn get_all_suppliers(&self, db: &Database) -> Result<Vec<Suppliers>, Box<dyn std::error::Error>> {
3982
use crate::models::suppliers::*;
4083
let mut conn = db.get_connection()?;
@@ -99,6 +142,12 @@ pub struct ChatRequest {
99142
pub message: String,
100143
}
101144

145+
#[tsync::tsync]
146+
#[derive(serde::Deserialize)]
147+
pub struct ChatStreamRequest {
148+
pub message: String,
149+
}
150+
102151
#[tsync::tsync]
103152
#[derive(serde::Serialize)]
104153
pub struct ChatResponse {
@@ -123,6 +172,60 @@ async fn chat(
123172
}
124173
}
125174

175+
#[get("/stream")]
176+
async fn chat_stream(
177+
db: Data<Database>,
178+
Query(request): Query<ChatStreamRequest>,
179+
) -> Sse<impl Stream<Item = Result<sse::Event, actix_web::Error>>> {
180+
let chat_service = ChatService::new();
181+
182+
let result_stream = match chat_service.chat_with_suppliers_stream(request.message, &db).await {
183+
Ok(stream) => {
184+
use futures_util::stream::StreamExt as FuturesStreamExt;
185+
186+
// Convert to proper SSE events
187+
let mapped_stream = FuturesStreamExt::map(stream, |result| {
188+
match result {
189+
Ok(text) => {
190+
if text.is_empty() {
191+
// Skip empty chunks
192+
Ok(sse::Event::Data(sse::Data::new("")))
193+
} else {
194+
// Send text as SSE data event
195+
let json_text = serde_json::to_string(&text).unwrap_or_else(|_| "\"\"".to_string());
196+
Ok(sse::Event::Data(sse::Data::new(json_text)))
197+
}
198+
},
199+
Err(e) => {
200+
eprintln!("Stream error: {}", e);
201+
let error_msg = serde_json::to_string(&format!("Error: {}", e)).unwrap_or_else(|_| "\"Error occurred\"".to_string());
202+
Ok(sse::Event::Data(sse::Data::new(error_msg).event("error")))
203+
}
204+
}
205+
});
206+
207+
let sse_stream = FuturesStreamExt::chain(mapped_stream, futures_util::stream::once(async {
208+
// Send completion event
209+
Ok(sse::Event::Data(sse::Data::new("").event("end")))
210+
}));
211+
212+
Box::pin(sse_stream) as Pin<Box<dyn Stream<Item = Result<sse::Event, actix_web::Error>> + Send>>
213+
},
214+
Err(e) => {
215+
eprintln!("Chat stream error: {}", e);
216+
let error_stream = futures_util::stream::once(async move {
217+
let error_msg = "Sorry, I'm having trouble processing your request right now.";
218+
Ok(sse::Event::Data(sse::Data::new(error_msg).event("error")))
219+
});
220+
Box::pin(error_stream) as Pin<Box<dyn Stream<Item = Result<sse::Event, actix_web::Error>> + Send>>
221+
}
222+
};
223+
224+
Sse::from_stream(result_stream)
225+
}
226+
126227
pub fn endpoints(scope: actix_web::Scope) -> actix_web::Scope {
127-
scope.service(chat)
228+
scope
229+
.service(chat)
230+
.service(chat_stream)
128231
}

0 commit comments

Comments
 (0)