diff --git a/CHANGELOG.md b/CHANGELOG.md index efc4558..61781b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ ## Release History +## [Unreleased] + +#### Features Added +* Embeddings and chat clients can now be injected via the new `embeddings_client` + and `chat_client` constructor arguments on `CosmosMemoryClient` and + `AsyncCosmosMemoryClient`. When supplied, the toolkit uses the provided client + instead of building an Azure-backed one, and does not close it (the caller owns + its lifecycle, mirroring the existing credential-ownership behavior). This enables + OpenAI-compatible / self-hosted embedding and chat backends, reuse of a + caller-configured client, and deterministic offline testing (for example against + the Cosmos DB emulator). + ## [0.2.0b1] (2026-06-30) #### Features Added diff --git a/azure/cosmos/agent_memory/aio/cosmos_memory_client.py b/azure/cosmos/agent_memory/aio/cosmos_memory_client.py index 4c5b4a7..7caa832 100644 --- a/azure/cosmos/agent_memory/aio/cosmos_memory_client.py +++ b/azure/cosmos/agent_memory/aio/cosmos_memory_client.py @@ -86,6 +86,8 @@ def __init__( chat_deployment_name: str = "gpt-4o-mini", use_default_credential: bool = True, enable_turn_embeddings: Optional[bool] = None, + embeddings_client: Optional[Any] = None, + chat_client: Optional[Any] = None, processor: Optional[AsyncMemoryProcessor] = None, transcript_metadata_keys: Optional[Iterable[str]] = None, ) -> None: @@ -113,19 +115,33 @@ def __init__( ) self._background_tasks: set[asyncio.Task[Any]] = set() self._pipeline_init_error: Exception | None = None - self._embeddings_client = AsyncEmbeddingsClient( - endpoint=self._ai_foundry_endpoint, - credential=self._ai_foundry_credential, - api_key=self._ai_foundry_api_key, - model=self._embedding_deployment_name, - dimensions=self._embedding_dimensions, - ) - self._chat_client = AsyncChatClient( - endpoint=self._ai_foundry_endpoint, - credential=self._ai_foundry_credential, - api_key=self._ai_foundry_api_key, - model=self._chat_deployment_name, - ) + # Embeddings/chat clients may be injected (e.g. an OpenAI-compatible backend, a + # caller-configured client, or a deterministic fake for offline tests). When a client + # is injected the caller owns its lifecycle, so the toolkit does not close it; otherwise + # the toolkit builds the Azure-backed client and closes it in ``close()``. + if embeddings_client is not None: + self._embeddings_client = embeddings_client + self._owns_embeddings_client = False + else: + self._embeddings_client = AsyncEmbeddingsClient( + endpoint=self._ai_foundry_endpoint, + credential=self._ai_foundry_credential, + api_key=self._ai_foundry_api_key, + model=self._embedding_deployment_name, + dimensions=self._embedding_dimensions, + ) + self._owns_embeddings_client = True + if chat_client is not None: + self._chat_client = chat_client + self._owns_chat_client = False + else: + self._chat_client = AsyncChatClient( + endpoint=self._ai_foundry_endpoint, + credential=self._ai_foundry_credential, + api_key=self._ai_foundry_api_key, + model=self._chat_deployment_name, + ) + self._owns_chat_client = True self._pipeline: Optional[AsyncPipelineService] = None self._processor: Optional[AsyncMemoryProcessor] = processor self._processor_explicit = processor is not None @@ -157,8 +173,10 @@ async def close(self) -> None: if self._processor is not None and not self._processor_explicit: await self._close_maybe_async(self._processor) self._processor = None - await self._embeddings_client.close() - await self._close_maybe_async(self._chat_client) + if self._owns_embeddings_client: + await self._close_maybe_async(self._embeddings_client) + if self._owns_chat_client: + await self._close_maybe_async(self._chat_client) for owns, cred in ( (self._owns_cosmos_credential, self._cosmos_credential), (self._owns_ai_foundry_credential, self._ai_foundry_credential), diff --git a/azure/cosmos/agent_memory/cosmos_memory_client.py b/azure/cosmos/agent_memory/cosmos_memory_client.py index f4924c3..4bd03a6 100644 --- a/azure/cosmos/agent_memory/cosmos_memory_client.py +++ b/azure/cosmos/agent_memory/cosmos_memory_client.py @@ -81,6 +81,8 @@ def __init__( chat_deployment_name: str = "gpt-4o-mini", use_default_credential: bool = True, enable_turn_embeddings: Optional[bool] = None, + embeddings_client: Optional[Any] = None, + chat_client: Optional[Any] = None, processor: Optional[MemoryProcessor] = None, transcript_metadata_keys: Optional[Iterable[str]] = None, ) -> None: @@ -105,19 +107,33 @@ def __init__( use_default_credential=use_default_credential, enable_turn_embeddings=enable_turn_embeddings, ) - self._embeddings_client = EmbeddingsClient( - endpoint=self._ai_foundry_endpoint, - credential=self._ai_foundry_credential, - api_key=self._ai_foundry_api_key, - model=self._embedding_deployment_name, - dimensions=self._embedding_dimensions, - ) - self._chat_client = ChatClient( - endpoint=self._ai_foundry_endpoint, - credential=self._ai_foundry_credential, - api_key=self._ai_foundry_api_key, - model=self._chat_deployment_name, - ) + # Embeddings/chat clients may be injected (e.g. an OpenAI-compatible backend, a + # caller-configured client, or a deterministic fake for offline tests). When a client + # is injected the caller owns its lifecycle, so the toolkit does not close it; otherwise + # the toolkit builds the Azure-backed client and closes it in ``close()``. + if embeddings_client is not None: + self._embeddings_client = embeddings_client + self._owns_embeddings_client = False + else: + self._embeddings_client = EmbeddingsClient( + endpoint=self._ai_foundry_endpoint, + credential=self._ai_foundry_credential, + api_key=self._ai_foundry_api_key, + model=self._embedding_deployment_name, + dimensions=self._embedding_dimensions, + ) + self._owns_embeddings_client = True + if chat_client is not None: + self._chat_client = chat_client + self._owns_chat_client = False + else: + self._chat_client = ChatClient( + endpoint=self._ai_foundry_endpoint, + credential=self._ai_foundry_credential, + api_key=self._ai_foundry_api_key, + model=self._chat_deployment_name, + ) + self._owns_chat_client = True self._pipeline: Optional[PipelineService] = None self._processor: Optional[MemoryProcessor] = processor self._processor_explicit = processor is not None @@ -146,8 +162,10 @@ def close(self) -> None: if self._processor is not None and not self._processor_explicit: self._close_sync_closeable(self._processor) self._processor = None - self._close_sync_closeable(self._chat_client) - self._close_sync_closeable(self._embeddings_client) + if self._owns_chat_client: + self._close_sync_closeable(self._chat_client) + if self._owns_embeddings_client: + self._close_sync_closeable(self._embeddings_client) for owns, cred in ( (self._owns_cosmos_credential, self._cosmos_credential), (self._owns_ai_foundry_credential, self._ai_foundry_credential), diff --git a/tests/unit/aio/test_cosmos_memory_client.py b/tests/unit/aio/test_cosmos_memory_client.py index 0010495..2b7904a 100644 --- a/tests/unit/aio/test_cosmos_memory_client.py +++ b/tests/unit/aio/test_cosmos_memory_client.py @@ -103,6 +103,70 @@ def test_default_credential_enabled(self): assert mem._cosmos_credential is not None +# =================================================================== +# Injected embeddings / chat clients +# =================================================================== + + +class _FakeEmbeddings: + """Minimal stand-in for AsyncEmbeddingsClient used to verify injection.""" + + def __init__(self) -> None: + self.closed = False + + async def close(self) -> None: + self.closed = True + + +class _FakeChat: + """Minimal stand-in for AsyncChatClient used to verify injection.""" + + def __init__(self) -> None: + self.closed = False + + async def close(self) -> None: + self.closed = True + + +class TestInjectedModelClients: + def test_injected_clients_are_used_and_not_owned(self): + emb = _FakeEmbeddings() + chat = _FakeChat() + mem = _make_client(embeddings_client=emb, chat_client=chat) + + assert mem._embeddings_client is emb + assert mem._chat_client is chat + assert mem._owns_embeddings_client is False + assert mem._owns_chat_client is False + + def test_default_clients_are_built_and_owned(self): + mem = _make_client() + + assert mem._embeddings_client is not None + assert mem._chat_client is not None + assert mem._owns_embeddings_client is True + assert mem._owns_chat_client is True + + def test_clients_can_be_injected_independently(self): + emb = _FakeEmbeddings() + mem = _make_client(embeddings_client=emb) + + assert mem._embeddings_client is emb + assert mem._owns_embeddings_client is False + # Chat client was not injected, so the toolkit builds and owns it. + assert mem._owns_chat_client is True + + async def test_close_does_not_close_injected_clients(self): + emb = _FakeEmbeddings() + chat = _FakeChat() + mem = _make_client(embeddings_client=emb, chat_client=chat) + + await mem.close() + + assert emb.closed is False + assert chat.closed is False + + # =================================================================== # Local CRUD (synchronous) # =================================================================== diff --git a/tests/unit/test_cosmos_memory_client.py b/tests/unit/test_cosmos_memory_client.py index c04b000..f84277d 100644 --- a/tests/unit/test_cosmos_memory_client.py +++ b/tests/unit/test_cosmos_memory_client.py @@ -90,6 +90,70 @@ def test_no_credential_when_flag_false(self): assert mem._ai_foundry_credential is None +# =================================================================== +# Injected embeddings / chat clients +# =================================================================== + + +class _FakeEmbeddings: + """Minimal stand-in for EmbeddingsClient used to verify injection.""" + + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + +class _FakeChat: + """Minimal stand-in for ChatClient used to verify injection.""" + + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + +class TestInjectedModelClients: + def test_injected_clients_are_used_and_not_owned(self): + emb = _FakeEmbeddings() + chat = _FakeChat() + mem = _make_client(embeddings_client=emb, chat_client=chat) + + assert mem._embeddings_client is emb + assert mem._chat_client is chat + assert mem._owns_embeddings_client is False + assert mem._owns_chat_client is False + + def test_default_clients_are_built_and_owned(self): + mem = _make_client() + + assert mem._embeddings_client is not None + assert mem._chat_client is not None + assert mem._owns_embeddings_client is True + assert mem._owns_chat_client is True + + def test_clients_can_be_injected_independently(self): + emb = _FakeEmbeddings() + mem = _make_client(embeddings_client=emb) + + assert mem._embeddings_client is emb + assert mem._owns_embeddings_client is False + # Chat client was not injected, so the toolkit builds and owns it. + assert mem._owns_chat_client is True + + def test_close_does_not_close_injected_clients(self): + emb = _FakeEmbeddings() + chat = _FakeChat() + mem = _make_client(embeddings_client=emb, chat_client=chat) + + mem.close() + + assert emb.closed is False + assert chat.closed is False + + # =================================================================== # Local CRUD # ===================================================================