Skip to content

Commit 27be1e2

Browse files
Workflow Payload Encoding: Add missing storage providers (#488)
* Workflow Payload Encoding: Add missing storage providers
* Install extra dependencies for custom code lint
1 parent 2d01b47 commit 27be1e2

6 files changed

Lines changed: 405 additions & 2 deletions

File tree

scripts/lint_custom_code.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ echo "-> running on examples"
2727
uv run mypy examples/ \
2828
--exclude 'audio/' || ERRORS=1
2929
echo "-> running on extra"
30-
uv run mypy src/mistralai/extra/ || ERRORS=1
30+
uv run --all-extras mypy src/mistralai/extra/ || ERRORS=1
3131
echo "-> running on hooks"
3232
uv run mypy src/mistralai/client/_hooks/ \
3333
--exclude __init__.py --exclude sdkhooks.py --exclude types.py || ERRORS=1
@@ -48,7 +48,7 @@ echo "Running pyright..."
4848
# TODO: Uncomment once the examples are fixed
4949
# uv run pyright examples/ || ERRORS=1
5050
echo "-> running on extra"
51-
uv run pyright src/mistralai/extra/ || ERRORS=1
51+
uv run --all-extras pyright src/mistralai/extra/ || ERRORS=1
5252
echo "-> running on hooks"
5353
uv run pyright src/mistralai/client/_hooks/ || ERRORS=1
5454
echo "-> running on azure hooks"

src/mistralai/extra/py.typed

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Marker file for PEP 561. The package enables type hints.
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Tests for workflow encoding configuration lifecycle."""
2+
3+
import gc
4+
5+
import pytest
6+
from pydantic import SecretStr
7+
8+
from mistralai.client import Mistral
9+
from mistralai.client._hooks.workflow_encoding_hook import (
10+
_workflow_configs,
11+
_ENCODING_CONFIG_ID_ATTR,
12+
configure_workflow_encoding,
13+
)
14+
from mistralai.extra.workflows import (
15+
WorkflowEncodingConfig,
16+
PayloadEncryptionConfig,
17+
PayloadEncryptionMode,
18+
)
19+
20+
21+
@pytest.fixture
def encryption_config() -> WorkflowEncodingConfig:
    """Provide an encoding config with full payload encryption enabled."""
    encryption = PayloadEncryptionConfig(
        mode=PayloadEncryptionMode.FULL,
        main_key=SecretStr("0" * 64),  # 256-bit key in hex
    )
    return WorkflowEncodingConfig(payload_encryption=encryption)
30+
31+
32+
def test_payload_encoder_cleanup_on_client_gc(encryption_config: WorkflowEncodingConfig):
    """A client's registry entry must vanish once the client is collected."""
    baseline = len(_workflow_configs)

    client = Mistral(api_key="test-key")
    configure_workflow_encoding(
        encryption_config,
        namespace="test-namespace",
        sdk_config=client.sdk_configuration,
    )

    # Registration should have attached a config id to the SDK configuration
    # and added exactly one entry to the global registry.
    config_id = getattr(client.sdk_configuration, _ENCODING_CONFIG_ID_ATTR)
    assert config_id is not None
    assert config_id in _workflow_configs
    assert len(_workflow_configs) == baseline + 1

    # Dropping the last reference must remove the registry entry.
    del client
    gc.collect()

    assert config_id not in _workflow_configs
    assert len(_workflow_configs) == baseline
57+
58+
59+
def test_multiple_clients_independent_configs(encryption_config: WorkflowEncodingConfig):
    """Each client gets its own entry; collecting one leaves the other intact."""
    baseline = len(_workflow_configs)

    first = Mistral(api_key="test-key-1")
    second = Mistral(api_key="test-key-2")

    configure_workflow_encoding(
        encryption_config,
        namespace="namespace-1",
        sdk_config=first.sdk_configuration,
    )
    configure_workflow_encoding(
        encryption_config,
        namespace="namespace-2",
        sdk_config=second.sdk_configuration,
    )

    first_id = getattr(first.sdk_configuration, _ENCODING_CONFIG_ID_ATTR)
    second_id = getattr(second.sdk_configuration, _ENCODING_CONFIG_ID_ATTR)

    # Distinct clients must never share a config id, and each entry keeps
    # the namespace it was configured with.
    assert first_id != second_id
    assert len(_workflow_configs) == baseline + 2
    assert _workflow_configs[first_id].namespace == "namespace-1"
    assert _workflow_configs[second_id].namespace == "namespace-2"

    # Collecting the first client removes only its own entry.
    del first
    gc.collect()

    assert first_id not in _workflow_configs
    assert second_id in _workflow_configs
    assert len(_workflow_configs) == baseline + 1

    # Collecting the second client brings the registry back to the baseline.
    del second
    gc.collect()

    assert second_id not in _workflow_configs
    assert len(_workflow_configs) == baseline
104+
105+
106+
def test_reconfigure_same_client(encryption_config: WorkflowEncodingConfig):
    """Reconfiguring a client reuses its config id and overwrites the entry."""
    client = Mistral(api_key="test-key")

    configure_workflow_encoding(
        encryption_config,
        namespace="namespace-v1",
        sdk_config=client.sdk_configuration,
    )
    config_id = getattr(client.sdk_configuration, _ENCODING_CONFIG_ID_ATTR)
    assert _workflow_configs[config_id].namespace == "namespace-v1"

    # A second call must update the existing entry in place rather than
    # allocate a fresh config id.
    configure_workflow_encoding(
        encryption_config,
        namespace="namespace-v2",
        sdk_config=client.sdk_configuration,
    )
    assert getattr(client.sdk_configuration, _ENCODING_CONFIG_ID_ATTR) == config_id
    assert _workflow_configs[config_id].namespace == "namespace-v2"

    # The entry is still cleaned up when the client goes away.
    del client
    gc.collect()
    assert config_id not in _workflow_configs
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, cast
4+
5+
from azure.core.exceptions import ResourceNotFoundError
6+
from azure.storage.blob.aio import BlobServiceClient
7+
from .blob_storage import BlobNotFoundError, BlobStorage
8+
9+
10+
class AzureBlobStorage(BlobStorage):
    """``BlobStorage`` implementation backed by an Azure Blob Storage container.

    The service and container clients are created lazily in ``__aenter__``,
    so instances must be used as async context managers before any blob
    operation is attempted.
    """

    def __init__(
        self,
        container_name: str,
        azure_connection_string: str,
        prefix: str | None = None,
    ):
        self.container_name = container_name
        self.connection_string = azure_connection_string
        # Normalize to "" so _get_full_key never has to branch on None.
        self.prefix = prefix or ""
        self._service_client: BlobServiceClient | None = None
        self._container_client: Any = None

    def _get_full_key(self, key: str) -> str:
        """Return *key* namespaced under the configured prefix, if any."""
        if not self.prefix or key.startswith(self.prefix):
            return key
        return f"{self.prefix}/{key}"

    async def __aenter__(self) -> "AzureBlobStorage":
        service = BlobServiceClient.from_connection_string(self.connection_string)
        self._service_client = service
        self._container_client = service.get_container_client(self.container_name)
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self._service_client:
            await self._service_client.close()

    async def upload_blob(self, key: str, content: bytes) -> str:
        """Upload *content* under *key*, overwriting any existing blob.

        Returns the uploaded blob's URL.
        """
        blob = self._container_client.get_blob_client(self._get_full_key(key))
        await blob.upload_blob(content, overwrite=True)
        return cast(str, blob.url)

    async def get_blob(self, key: str) -> bytes:
        """Download and return the blob stored under *key*.

        Raises:
            BlobNotFoundError: if no blob exists under *key*.
        """
        blob = self._container_client.get_blob_client(self._get_full_key(key))
        try:
            downloader = await blob.download_blob()
            return cast(bytes, await downloader.readall())
        except ResourceNotFoundError as e:
            raise BlobNotFoundError(f"Blob not found: {key}") from e

    async def get_blob_properties(self, key: str) -> dict[str, Any] | None:
        """Return size/last_modified for *key*, or None when the blob is absent."""
        blob = self._container_client.get_blob_client(self._get_full_key(key))
        try:
            props = await blob.get_blob_properties()
            return {"size": props.size, "last_modified": props.last_modified}
        except ResourceNotFoundError:
            return None

    async def delete_blob(self, key: str) -> None:
        """Delete the blob stored under *key*."""
        blob = self._container_client.get_blob_client(self._get_full_key(key))
        await blob.delete_blob()

    async def blob_exists(self, key: str) -> bool:
        """Return True when a blob exists under *key*."""
        blob = self._container_client.get_blob_client(self._get_full_key(key))
        return cast(bool, await blob.exists())
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, cast
4+
5+
import aiohttp
6+
from gcloud.aio.storage import Storage
7+
8+
from .blob_storage import BlobNotFoundError, BlobStorage
9+
10+
11+
class GCSBlobStorage(BlobStorage):
    """``BlobStorage`` implementation backed by a Google Cloud Storage bucket.

    The aiohttp session and Storage client are created in ``__aenter__``,
    so instances must be used as async context managers.
    """

    def __init__(self, bucket_id: str, prefix: str | None = None):
        self.bucket_id = bucket_id
        # Normalize to "" so _get_full_key never has to branch on None.
        self.prefix = prefix or ""
        self._storage: Storage | None = None
        self._session: aiohttp.ClientSession | None = None

    def _get_full_key(self, key: str) -> str:
        """Return *key* namespaced under the configured prefix, if any."""
        if not self.prefix or key.startswith(self.prefix):
            return key
        return f"{self.prefix}/{key}"

    @staticmethod
    def _is_not_found(error: Exception) -> bool:
        # NOTE(review): a missing blob is detected by inspecting the error
        # text rather than a typed exception — fragile if the client library
        # ever changes its message format; confirm no typed 404 is available.
        text = str(error)
        return "404" in text or "Not Found" in text

    async def __aenter__(self) -> "GCSBlobStorage":
        self._session = aiohttp.ClientSession()
        self._storage = Storage(session=self._session)
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self._storage:
            await self._storage.close()
        if self._session:
            await self._session.close()

    async def upload_blob(self, key: str, content: bytes) -> str:
        """Upload *content* under *key* and return the object's selfLink."""
        assert self._storage is not None
        response = await self._storage.upload(
            self.bucket_id, self._get_full_key(key), content
        )
        return str(response.get("selfLink"))

    async def get_blob(self, key: str) -> bytes:
        """Download and return the blob stored under *key*.

        Raises:
            BlobNotFoundError: if no blob exists under *key*.
        """
        assert self._storage is not None
        try:
            data = await self._storage.download(
                self.bucket_id, self._get_full_key(key)
            )
        except Exception as e:
            if self._is_not_found(e):
                raise BlobNotFoundError(f"Blob not found: {key}") from e
            raise
        return cast(bytes, data)

    async def get_blob_properties(self, key: str) -> dict[str, Any] | None:
        """Return size/last_modified metadata for *key*, or None when absent."""
        assert self._storage is not None
        try:
            metadata = await self._storage.download_metadata(
                self.bucket_id, self._get_full_key(key)
            )
            return {
                "size": int(metadata.get("size", 0)),
                "last_modified": metadata.get("updated"),
            }
        except Exception as e:
            if self._is_not_found(e):
                return None
            raise

    async def delete_blob(self, key: str) -> None:
        """Delete the blob stored under *key*."""
        assert self._storage is not None
        await self._storage.delete(self.bucket_id, self._get_full_key(key))

    async def blob_exists(self, key: str) -> bool:
        """Return True when metadata for *key* can be fetched."""
        assert self._storage is not None
        try:
            await self._storage.download_metadata(
                self.bucket_id, self._get_full_key(key)
            )
            return True
        except Exception as e:
            if self._is_not_found(e):
                return False
            raise

0 commit comments

Comments
 (0)