-
Notifications
You must be signed in to change notification settings - Fork 177
Expand file tree
/
Copy pathtest_gcp_integration.py
More file actions
512 lines (445 loc) · 17.1 KB
/
test_gcp_integration.py
File metadata and controls
512 lines (445 loc) · 17.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
"""
Integration tests for GCP SDK.
These tests require GCP credentials and make real API calls.
Skip if GCP_PROJECT_ID env var is not set.
Prerequisites:
1. Authenticate with GCP: gcloud auth application-default login
2. Have "Vertex AI User" role on the project
The SDK automatically:
- Detects credentials via google.auth.default()
- Auto-refreshes tokens when they expire
- Builds the Vertex AI URL from project_id and region
Available models:
- Chat: mistral-small-2503, mistral-large-2501, ...
- FIM: codestral-2
See: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/mistral
Usage:
GCP_PROJECT_ID=<your-project-id> pytest tests/test_gcp_integration.py -v
Environment variables:
GCP_PROJECT_ID: GCP project ID (required, or auto-detected from credentials)
GCP_REGION: Vertex AI region (default: us-central1)
GCP_MODEL: Model name for chat (default: mistral-small-2503)
GCP_FIM_MODEL: Model name for FIM (default: codestral-2)
"""
import json
import os
import pytest
# Test configuration, sourced from the environment.
GCP_PROJECT_ID = os.getenv("GCP_PROJECT_ID")  # required; no default
GCP_REGION = os.getenv("GCP_REGION", "us-central1")
GCP_MODEL = os.getenv("GCP_MODEL", "mistral-small-2503")
GCP_FIM_MODEL = os.getenv("GCP_FIM_MODEL", "codestral-2")

SKIP_REASON = "GCP_PROJECT_ID env var required"

# Skip every test in this module unless a project is configured.
pytestmark = pytest.mark.skipif(not GCP_PROJECT_ID, reason=SKIP_REASON)
# Shared tool definition for tool-call tests.
# Function-tool spec sent verbatim to the API: the model may call
# get_weather with a single required string argument, "city".
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the weather in a city",
        "parameters": {
            # JSON Schema describing the function's argument object.
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}
@pytest.fixture
def gcp_client():
    """Yield a GCP client for chat tests, closing it after each test.

    The SDK automatically:
    - Detects credentials via google.auth.default()
    - Auto-refreshes tokens when they expire
    - Builds the Vertex AI URL from project_id and region

    Fix: the client is now opened as a context manager and yielded, so
    its underlying HTTP resources are released when the test finishes
    (previously the fixture returned the client without ever closing it;
    the SDK's context-manager support is exercised elsewhere in this
    module).
    """
    from mistralai.gcp.client import MistralGCP
    with MistralGCP(
        project_id=GCP_PROJECT_ID,
        region=GCP_REGION,
    ) as client:
        yield client
class TestGCPChatComplete:
    """Synchronous chat completion against the live Vertex AI endpoint."""

    def test_basic_completion(self, gcp_client):
        """A plain user message produces a non-empty assistant reply."""
        response = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Say 'hello' and nothing else."}
            ],
        )
        assert response is not None
        assert response.choices is not None
        assert len(response.choices) > 0
        message = response.choices[0].message
        assert message is not None
        assert message.content is not None
        assert len(message.content) > 0

    def test_completion_with_system_message(self, gcp_client):
        """A system prompt plus a user message still yields a reply."""
        response = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "system", "content": "You are a pirate. Respond in pirate speak."},
                {"role": "user", "content": "Say hello."},
            ],
        )
        assert response is not None
        reply = response.choices[0].message.content
        assert reply is not None
        assert len(reply) > 0

    def test_completion_with_max_tokens(self, gcp_client):
        """max_tokens truncates generation (finish_reason length or stop)."""
        response = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Count from 1 to 100."}
            ],
            max_tokens=10,
        )
        assert response is not None
        assert response.choices[0].finish_reason in ("length", "stop")

    def test_completion_with_temperature(self, gcp_client):
        """The temperature parameter is accepted without error."""
        response = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Say 'test'."}
            ],
            temperature=0.0,
        )
        assert response is not None
        assert response.choices[0].message.content is not None

    def test_completion_with_stop_sequence(self, gcp_client):
        """Generation halts at (or before) the first stop sequence."""
        response = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Write three sentences about the sky."}
            ],
            stop=["."],
        )
        assert response is not None
        reply = response.choices[0].message.content
        assert reply is not None
        # Stopping on "." means at most one period can appear in the reply.
        assert reply.count(".") <= 1

    def test_completion_with_random_seed(self, gcp_client):
        """random_seed is accepted; both seeded calls return content.

        Equality of the two replies is deliberately not asserted because
        the backend does not guarantee determinism.
        """
        first = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Say 'deterministic'."}
            ],
            random_seed=42,
        )
        second = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Say 'deterministic'."}
            ],
            random_seed=42,
        )
        assert first.choices[0].message.content is not None
        assert second.choices[0].message.content is not None

    def test_multi_turn_conversation(self, gcp_client):
        """Replaying the assistant turn preserves conversation state."""
        opening = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "My name is Alice."}
            ],
        )
        first_reply = opening.choices[0].message.content
        assert first_reply is not None
        follow_up = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "My name is Alice."},
                {"role": "assistant", "content": first_reply},
                {"role": "user", "content": "What is my name?"},
            ],
        )
        answer = follow_up.choices[0].message.content
        assert answer is not None
        assert "Alice" in answer

    def test_tool_call(self, gcp_client):
        """With tools and tool_choice='any' the model emits a tool call."""
        response = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "What is the weather in Paris?"}
            ],
            tools=[WEATHER_TOOL],
            tool_choice="any",
        )
        assert response is not None
        tool_calls = response.choices[0].message.tool_calls
        assert tool_calls is not None
        assert len(tool_calls) > 0
        first_call = tool_calls[0]
        assert first_call.function.name == "get_weather"
        arguments = json.loads(first_call.function.arguments)
        assert "city" in arguments

    def test_json_response_format(self, gcp_client):
        """response_format=json_object yields parseable JSON content."""
        response = gcp_client.chat.complete(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Return a JSON object with a key 'greeting' and value 'hello'."}
            ],
            response_format={"type": "json_object"},
        )
        assert response is not None
        reply = response.choices[0].message.content
        assert reply is not None
        assert isinstance(json.loads(reply), dict)
class TestGCPChatStream:
    """Streaming chat completion against the live Vertex AI endpoint."""

    @staticmethod
    def _finish_reasons(chunks):
        # Collect every non-None finish_reason, in stream order.
        reasons = []
        for chunk in chunks:
            if chunk.data.choices and chunk.data.choices[0].finish_reason is not None:
                reasons.append(chunk.data.choices[0].finish_reason)
        return reasons

    def test_basic_stream(self, gcp_client):
        """Streaming yields chunks whose deltas concatenate to a reply."""
        stream = gcp_client.chat.stream(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Say 'hello' and nothing else."}
            ],
        )
        chunks = list(stream)
        assert len(chunks) > 0
        pieces = [
            chunk.data.choices[0].delta.content
            for chunk in chunks
            if chunk.data.choices and chunk.data.choices[0].delta.content
        ]
        assert len("".join(pieces)) > 0

    def test_stream_with_max_tokens(self, gcp_client):
        """Streaming honors max_tokens and reports a finish_reason."""
        stream = gcp_client.chat.stream(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Count from 1 to 100."}
            ],
            max_tokens=10,
        )
        chunks = list(stream)
        assert len(chunks) > 0
        reasons = self._finish_reasons(chunks)
        assert len(reasons) > 0
        assert reasons[-1] in ("length", "stop")

    def test_stream_finish_reason(self, gcp_client):
        """A short, unconstrained stream finishes with reason 'stop'."""
        stream = gcp_client.chat.stream(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Say 'hi'."}
            ],
        )
        chunks = list(stream)
        assert len(chunks) > 0
        reasons = self._finish_reasons(chunks)
        assert len(reasons) > 0
        assert reasons[-1] == "stop"

    def test_stream_tool_call(self, gcp_client):
        """Forced tool calls surface as tool_calls deltas while streaming."""
        stream = gcp_client.chat.stream(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "What is the weather in Paris?"}
            ],
            tools=[WEATHER_TOOL],
            tool_choice="any",
        )
        chunks = list(stream)
        assert len(chunks) > 0
        saw_tool_call = any(
            chunk.data.choices and chunk.data.choices[0].delta.tool_calls
            for chunk in chunks
        )
        assert saw_tool_call, "Expected tool_call delta chunks in stream"
class TestGCPChatCompleteAsync:
    """Async chat completion against the live Vertex AI endpoint."""

    @pytest.mark.asyncio
    async def test_basic_completion_async(self, gcp_client):
        """The async API returns a populated response."""
        response = await gcp_client.chat.complete_async(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Say 'hello' and nothing else."}
            ],
        )
        assert response is not None
        assert response.choices is not None
        assert len(response.choices) > 0
        assert response.choices[0].message.content is not None

    @pytest.mark.asyncio
    async def test_completion_with_system_message_async(self, gcp_client):
        """Async completion works with a system + user message pair."""
        response = await gcp_client.chat.complete_async(
            model=GCP_MODEL,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say 'hello'."},
            ],
        )
        assert response is not None
        assert response.choices[0].message.content is not None

    @pytest.mark.asyncio
    async def test_tool_call_async(self, gcp_client):
        """Async completion emits a get_weather tool call when forced."""
        response = await gcp_client.chat.complete_async(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "What is the weather in Paris?"}
            ],
            tools=[WEATHER_TOOL],
            tool_choice="any",
        )
        assert response is not None
        tool_calls = response.choices[0].message.tool_calls
        assert tool_calls is not None
        assert len(tool_calls) > 0
        assert tool_calls[0].function.name == "get_weather"
class TestGCPChatStreamAsync:
    """Async streaming chat completion against the live endpoint."""

    @pytest.mark.asyncio
    async def test_basic_stream_async(self, gcp_client):
        """Async streaming yields deltas that concatenate to a reply."""
        stream = await gcp_client.chat.stream_async(
            model=GCP_MODEL,
            messages=[
                {"role": "user", "content": "Say 'hello' and nothing else."}
            ],
        )
        pieces = []
        async for chunk in stream:
            if chunk.data.choices and chunk.data.choices[0].delta.content:
                pieces.append(chunk.data.choices[0].delta.content)
        assert len("".join(pieces)) > 0
class TestGCPContextManager:
    """Context-manager support for MistralGCP, sync and async."""

    def test_sync_context_manager(self):
        """The client is usable inside a plain `with` block."""
        from mistralai.gcp.client import MistralGCP
        with MistralGCP(
            project_id=GCP_PROJECT_ID,
            region=GCP_REGION,
        ) as client:
            response = client.chat.complete(
                model=GCP_MODEL,
                messages=[
                    {"role": "user", "content": "Say 'context'."}
                ],
            )
            assert response is not None
            assert response.choices[0].message.content is not None

    @pytest.mark.asyncio
    async def test_async_context_manager(self):
        """The client is usable inside an `async with` block."""
        from mistralai.gcp.client import MistralGCP
        async with MistralGCP(
            project_id=GCP_PROJECT_ID,
            region=GCP_REGION,
        ) as client:
            response = await client.chat.complete_async(
                model=GCP_MODEL,
                messages=[
                    {"role": "user", "content": "Say 'async context'."}
                ],
            )
            assert response is not None
            assert response.choices[0].message.content is not None
class TestGCPFIM:
    """Test FIM (Fill-in-the-middle) completion.

    Fix: every test now uses its client as a context manager
    (`with` / `async with`) so the underlying HTTP resources are released
    deterministically — previously each test created a client via
    _make_fim_client() and never closed it. The SDK's context-manager
    support is demonstrated elsewhere in this module.
    """

    def _make_fim_client(self):
        """Create a GCP client configured for the FIM model.

        Callers should use the returned client as a (sync or async)
        context manager so it is closed after use.
        """
        from mistralai.gcp.client import MistralGCP
        return MistralGCP(project_id=GCP_PROJECT_ID, region=GCP_REGION)

    def test_fim_complete(self):
        """FIM completion returns a non-empty response."""
        with self._make_fim_client() as client:
            res = client.fim.complete(
                model=GCP_FIM_MODEL,
                prompt="def fib():",
                suffix=" return result",
                timeout_ms=10000,
            )
        assert res is not None
        assert res.choices is not None
        assert len(res.choices) > 0
        assert res.choices[0].message.content is not None

    def test_fim_stream(self):
        """FIM streaming yields chunks whose deltas form text content."""
        with self._make_fim_client() as client:
            # Drain the stream while the client is still open.
            chunks = list(client.fim.stream(
                model=GCP_FIM_MODEL,
                prompt="def hello():",
                suffix=" return greeting",
                timeout_ms=10000,
            ))
        assert len(chunks) > 0
        content = ""
        for chunk in chunks:
            if chunk.data.choices and chunk.data.choices[0].delta.content:
                delta_content = chunk.data.choices[0].delta.content
                # Deltas may carry non-string content; accumulate text only.
                if isinstance(delta_content, str):
                    content += delta_content
        assert len(content) > 0

    def test_fim_with_max_tokens(self):
        """FIM completion honors max_tokens truncation."""
        with self._make_fim_client() as client:
            res = client.fim.complete(
                model=GCP_FIM_MODEL,
                prompt="def add(a, b):",
                suffix=" return result",
                max_tokens=10,
                timeout_ms=10000,
            )
        assert res is not None
        assert res.choices[0].finish_reason in ("length", "stop")

    @pytest.mark.asyncio
    async def test_fim_complete_async(self):
        """Async FIM completion returns a non-empty response."""
        async with self._make_fim_client() as client:
            res = await client.fim.complete_async(
                model=GCP_FIM_MODEL,
                prompt="def fib():",
                suffix=" return result",
                timeout_ms=10000,
            )
        assert res is not None
        assert res.choices is not None
        assert len(res.choices) > 0
        assert res.choices[0].message.content is not None

    @pytest.mark.asyncio
    async def test_fim_stream_async(self):
        """Async FIM streaming yields chunks whose deltas form content."""
        async with self._make_fim_client() as client:
            stream = await client.fim.stream_async(
                model=GCP_FIM_MODEL,
                prompt="def hello():",
                suffix=" return greeting",
                timeout_ms=10000,
            )
            # Drain the stream while the client is still open.
            chunks = []
            async for chunk in stream:
                chunks.append(chunk)
        assert len(chunks) > 0
        content = ""
        for chunk in chunks:
            if chunk.data.choices and chunk.data.choices[0].delta.content:
                delta_content = chunk.data.choices[0].delta.content
                if isinstance(delta_content, str):
                    content += delta_content
        assert len(content) > 0