forked from InftyAI/AMRS
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.rs
More file actions
109 lines (97 loc) · 3.3 KB
/
client.rs
File metadata and controls
109 lines (97 loc) · 3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use dotenvy::from_filename;
use arms::client;
use arms::types::chat;
use arms::types::responses;
#[cfg(test)]
mod tests {
    use super::*;

    /// Integration tests for `Client::create_response` (Responses API).
    ///
    /// Exercises three routing scenarios:
    /// 1. a single configured model on the "faker" provider,
    /// 2. a request that names a model explicitly against the "openai"
    ///    provider — expected to error in this test environment, and
    /// 3. multiple models selected via the weighted-round-robin router.
    #[tokio::test]
    async fn test_response() {
        // Load test-environment variables if the file exists; a missing
        // file is deliberately tolerated (`.ok()` — best-effort).
        from_filename(".env.integration-test").ok();

        // Case 1: one model. The request omits `model`, so the client's
        // single configured model ("fake-model") is used.
        let config = client::Config::builder()
            .provider("faker")
            .model(
                client::ModelConfig::builder()
                    .name("fake-model")
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        let mut client = client::Client::new(config);
        let request = responses::CreateResponseArgs::default()
            .input("tell me the weather today")
            .build()
            .unwrap();
        let response = client.create_response(request).await.unwrap();
        // Include the actual id in the failure message so a broken CI run
        // shows what the faker returned.
        assert!(
            response.id.starts_with("fake-response-id"),
            "unexpected response id: {}",
            response.id
        );
        // assert_eq! over assert!(a == b): prints both sides on failure.
        assert_eq!(response.model, "fake-model");

        // Case 2: specify model in request. Uses the real "openai" provider;
        // the call is expected to fail here (presumably no valid credentials
        // in the integration environment — NOTE(review): confirm intent).
        let config = client::Config::builder()
            .provider("openai")
            .model(
                client::ModelConfig::builder()
                    .name("gpt-3.5-turbo")
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        let mut client = client::Client::new(config);
        let request = responses::CreateResponseArgs::default()
            .model("gpt-3.5-turbo")
            .input("tell me a joke")
            .build()
            .unwrap();
        let response = client.create_response(request).await;
        assert!(response.is_err(), "expected an error, got: {:?}", response.ok().map(|r| r.id));

        // Case 3: multiple models with a weighted-round-robin router; equal
        // weights, so either model may serve the request. Only success is
        // asserted — the routed model is intentionally unchecked.
        let config = client::Config::builder()
            .provider("faker")
            .router_mode(client::RouterMode::WRR)
            .model(
                client::ModelConfig::builder()
                    .name("gpt-3.5-turbo")
                    .weight(1)
                    .build()
                    .unwrap(),
            )
            .model(
                client::ModelConfig::builder()
                    .name("gpt-4")
                    .weight(1)
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        let mut client = client::Client::new(config);
        let request = responses::CreateResponseArgs::default()
            .input("give me a poem about nature")
            .build()
            .unwrap();
        let _ = client.create_response(request).await.unwrap();
    }

    /// Integration test for `Client::create_completion` (Chat Completions
    /// API) against the "faker" provider with a single configured model.
    #[tokio::test]
    async fn test_completion() {
        // Best-effort load of test-environment variables (see test_response).
        from_filename(".env.integration-test").ok();

        let config = client::Config::builder()
            .provider("faker")
            .model(
                client::ModelConfig::builder()
                    .name("fake-completion-model")
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        let mut client = client::Client::new(config);
        // An empty/default request: the faker should still answer, echoing
        // its synthetic id and the configured model name.
        let request = chat::CreateChatCompletionRequestArgs::default()
            .build()
            .unwrap();
        let response = client.create_completion(request).await.unwrap();
        assert!(
            response.id.starts_with("fake-completion-id"),
            "unexpected completion id: {}",
            response.id
        );
        assert_eq!(response.model, "fake-completion-model");
    }
}