From 2347428db4fb4e59257d90cd43ac931067dff80d Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Wed, 1 Jul 2026 09:29:26 -0700 Subject: [PATCH 1/2] Rebasing --- .env.template | 17 - CHANGELOG.md | 24 + Docs/concepts.md | 40 +- Docs/design_patterns.md | 1 - Docs/public_api.md | 8 +- Docs/troubleshooting.md | 10 +- Samples/Advanced/advanced_search_patterns.py | 13 +- Samples/Notebooks/Demo_async.ipynb | 2 +- azure/cosmos/agent_memory/_counters.py | 78 +- azure/cosmos/agent_memory/_utils.py | 206 ++++- azure/cosmos/agent_memory/aio/auto_trigger.py | 61 +- .../agent_memory/aio/cosmos_memory_client.py | 23 +- .../agent_memory/aio/processors/base.py | 2 + .../agent_memory/aio/processors/durable.py | 3 +- .../agent_memory/aio/processors/inprocess.py | 28 +- .../agent_memory/aio/services/pipeline.py | 750 ++++++++++++----- .../agent_memory/aio/store/memory_store.py | 27 +- azure/cosmos/agent_memory/auto_trigger.py | 62 +- .../agent_memory/cosmos_memory_client.py | 29 +- azure/cosmos/agent_memory/processors/base.py | 2 + .../cosmos/agent_memory/processors/durable.py | 3 +- .../agent_memory/processors/inprocess.py | 29 +- azure/cosmos/agent_memory/prompts/_schemas.py | 35 +- .../cosmos/agent_memory/prompts/dedup.prompty | 2 +- .../prompts/dedup_episodic.prompty | 160 ++++ .../prompts/extract_memories.prompty | 293 +------ .../services/_pipeline_helpers.py | 61 +- .../cosmos/agent_memory/services/pipeline.py | 764 ++++++++++++++---- .../agent_memory/store/_search_helpers.py | 21 +- .../cosmos/agent_memory/store/memory_store.py | 27 +- azure/cosmos/agent_memory/thresholds.py | 76 ++ function_app/local.settings.json.template | 1 + .../orchestrators/extract_memories.py | 76 +- function_app/shared/config.py | 11 - function_app/shared/counters.py | 39 + function_app/triggers/change_feed.py | 29 + tests/integration/test_async_full_pipeline.py | 274 +++++++ tests/integration/test_full_pipeline.py | 129 ++- .../integration/test_processor_integration.py | 9 +- .../test_processor_integration_async.py | 17 +- tests/unit/aio/processors/test_inprocess.py | 22 +- .../processors/test_protocol_satisfaction.py | 1 + .../aio/services/test_dedup_vector_async.py | 446 ++++++++++ tests/unit/aio/test_auto_trigger.py | 259 +++++- tests/unit/aio/test_cosmos_memory_client.py | 106 ++- tests/unit/aio/test_process_now.py | 10 +- tests/unit/aio/test_reconcile_telemetry.py | 8 + tests/unit/function_app/test_change_feed.py | 83 +- tests/unit/function_app/test_orchestrators.py | 198 ++++- tests/unit/processors/test_inprocess.py | 23 +- .../processors/test_protocol_satisfaction.py | 1 + .../services/test_chaos_extract_persist.py | 20 + tests/unit/services/test_dedup_vector.py | 383 +++++++++ tests/unit/services/test_extract_dry.py | 568 +++++-------- tests/unit/services/test_pipeline_service.py | 33 +- tests/unit/store/test_memory_store.py | 57 ++ tests/unit/test_auto_trigger.py | 214 ++++- tests/unit/test_cosmos_memory_client.py | 113 ++- tests/unit/test_pipeline_confidence.py | 218 ++--- tests/unit/test_procedural_synthesis.py | 12 + tests/unit/test_process_now.py | 10 +- tests/unit/test_reconcile.py | 404 +-------- tests/unit/test_thresholds.py | 147 ++++ tests/unit/test_utils.py | 146 +++- 64 files changed, 5029 insertions(+), 1895 deletions(-) create mode 100644 azure/cosmos/agent_memory/prompts/dedup_episodic.prompty create mode 100644 tests/integration/test_async_full_pipeline.py create mode 100644 tests/unit/aio/services/test_dedup_vector_async.py create mode 100644 tests/unit/services/test_dedup_vector.py diff --git a/.env.template b/.env.template index 68bef3c..a264e6e 100644 --- a/.env.template +++ b/.env.template @@ -18,15 +18,6 @@ COSMOS_DB_SUMMARIES_CONTAINER="memories_summaries" COSMOS_DB_TURNS_CONTAINER="memories_turns" COSMOS_DB_COUNTERS_CONTAINER=counter COSMOS_DB_LEASE_CONTAINER=leases -# Throughput mode for all required Cosmos DB containers created by the toolkit -# (memories, counter, and lease). -# - serverless: default. The toolkit does not send container RU/s settings. -# Use this only with a Cosmos DB account configured for serverless. -# - autoscale: the toolkit provisions all required containers with autoscale -# throughput using COSMOS_DB_AUTOSCALE_MAX_RU as the max RU/s cap. -# Default max RU/s is 1000. -COSMOS_DB_THROUGHPUT_MODE=serverless -COSMOS_DB_AUTOSCALE_MAX_RU=1000 # ---- Processing thresholds (set to 0 to disable) ---- THREAD_SUMMARY_EVERY_N=10 @@ -58,14 +49,6 @@ AI_FOUNDRY_ENDPOINT=https://.openai.azure.com/ AI_FOUNDRY_API_KEY= AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME=text-embedding-3-large AI_FOUNDRY_EMBEDDING_DIMENSIONS=1536 -AI_FOUNDRY_EMBEDDING_DATA_TYPE=float32 -AI_FOUNDRY_EMBEDDING_DISTANCE_FUNCTION=cosine -# Optional. Vector index type for the memories container: quantizedFlat -# (default), diskANN, or flat. quantizedFlat works on any Cosmos DB account -# (including the classic emulator); diskANN requires the DiskANN capability on -# the Cosmos DB account, so opt into it explicitly when available. -AI_FOUNDRY_EMBEDDING_VECTOR_INDEX_TYPE=quantizedFlat -COSMOS_DB_FULL_TEXT_LANGUAGE=en-US # Embed raw conversation turns on write so they can be vector-searched via # search_turns(). The turns container is always provisioned with a diff --git a/CHANGELOG.md b/CHANGELOG.md index efc4558..b822456 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,29 @@ ## Release History +## [0.3.0b1] (Unreleased) + +#### Features Added +* Contradiction-aware vector dedup and reconciliation. Extraction now runs a + vector-similarity dedup ladder: near-exact duplicates are auto-dropped and + borderline matches are tagged for review before they are written. A periodic + LLM reconcile pass (driven by `dedup.prompty` for facts and + `dedup_episodic.prompty` for episodic memories) then merges duplicate groups + and resolves contradictions, soft-deleting the losers with a supersede reason. + Reconciliation is distance-function aware (the destructive near-exact + auto-drop is disabled for `euclidean` containers). In-process reconcile now + covers both facts and episodic memories, matching the Durable backend, and its + cadence is derived from the persisted message counter. See [PR:#26](https://github.com/AzureCosmosDB/AgentMemoryToolkit/pull/26) +* Extraction watermark. The window of recent turns sent to extraction is sized + from a persisted per-thread watermark (`last_extract_count`) that only advances + after a successful extract, so under normal operation no turns are skipped when + extraction lags or transiently fails. See [PR:#26](https://github.com/AzureCosmosDB/AgentMemoryToolkit/pull/26) + +#### Other Changes +* `search_cosmos` and `search_turns` now always fuse vector similarity with + BM25 (full-text) ranking, falling back to vector-only for all-stopword + queries. The `hybrid_search` flag has been removed — hybrid ranking is the + default and requires no opt-in. See [PR:#26](https://github.com/AzureCosmosDB/AgentMemoryToolkit/pull/26) + ## [0.2.0b1] (2026-06-30) #### Features Added diff --git a/Docs/concepts.md b/Docs/concepts.md index e3b646f..b72807c 100644 --- a/Docs/concepts.md +++ b/Docs/concepts.md @@ -118,26 +118,52 @@ Prompts for summarization and fact extraction live in `azure_functions/prompts/` ## Memory Reconciliation -The `reconcile_memories(user_id, n=50)` pipeline step reads up to N most-recent active facts for a user and asks the LLM to identify two orthogonal outcomes in one pass: +Reconciliation runs in **two complementary tiers**: a cheap, LLM-free **vector-floor dedup ladder** applied to freshly-extracted memories before they persist, and a periodic **LLM reconcile** that runs in a **dual mode** (cheap candidate clusters most sweeps, a full-pool backstop occasionally). -- **Duplicates** — two or more facts that restate the same claim in different words. Resolution: collapse into one merged fact; the originals are soft-deleted with `supersede_reason="duplicate"` and `superseded_by` set to the merged fact's id. -- **Contradictions** — two facts that assert opposing claims about the same subject. Resolution: keep the winner (more recent first, higher confidence as tiebreaker), soft-delete the loser with `supersede_reason="contradict"` and `superseded_by` set to the winner. +### Vector-floor dedup ladder (write path, LLM-free) -### Why one pass +Between extraction and persist, `dedup_extracted_memories` compares each new fact/episodic memory against the user's existing active memories of the same type using Cosmos `VectorDistance` (pure vector, no hybrid). Each new memory takes one rung of a similarity ladder: -Detecting contradictions semantically requires the LLM to see the candidate pool as a whole — paraphrased ("user prefers aisle seats") and contradictory ("user is vegetarian" vs "user loves steak") facts often have very different embedding vectors and would never co-occur in any cosine cluster. Putting all N candidates into one prompt lets the LLM do the semantic reasoning across both axes simultaneously. The pipeline returns `{"kept": int, "merged": int, "contradicted": int}`. +| band | condition (cosine) | action | +|------|--------------------|--------| +| exact | `content_hash` hit | skip (Stage 0, free) | +| near-exact | `s ≥ DEDUP_SIM_HIGH` (0.97) | **auto-skip** the new memory (no LLM); logged for audit | +| borderline | `DEDUP_SIM_LOW ≤ s < DEDUP_SIM_HIGH` (0.80–0.97) | persist, tag `sys:dup-candidate` + stash `dup_of`/`dup_score` for the LLM reconcile | +| novel | `s < DEDUP_SIM_LOW` | persist clean | + +The thresholds are calibrated for **cosine/dotproduct** on normalized embeddings. On a container whose `distanceFunction` is **euclidean**, the destructive near-exact auto-skip is **disabled** (one-shot warning) and those memories fall through to borderline tagging so the LLM adjudicates — euclidean distances aren't a bounded [0,1] similarity and would mis-fire the cosine-tuned drop. + +### Dual-mode LLM reconcile + +`reconcile_memories(user_id, n=50, *, memory_type="fact", full_rebuild=False)` identifies two orthogonal outcomes: + +- **Duplicates** — facts restating the same claim. Resolution: collapse into one merged fact; originals soft-deleted with `supersede_reason="duplicate"` and `superseded_by` set to the merged fact. +- **Contradictions** — facts asserting opposing claims about the same subject. Resolution: keep the winner (more recent first, higher confidence as tiebreaker), soft-delete the loser with `supersede_reason="contradict"`. + +It runs in one of two modes: + +- **Candidate mode** (default auto sweeps) — builds connected-component clusters from the `sys:dup-candidate` seeds + their vector neighbors (edge threshold `DEDUP_CLUSTER_SIM`, 0.60) and sends **only those clusters** to the LLM. Cheap, but keyed on near-duplicate similarity. Tagged seeds that never join a cluster have their stale tag cleared so they aren't re-scanned forever. +- **Full-pool backstop** — every `DEDUP_FULL_RECLUSTER_EVERY_N`-th sweep (default 12), and on any explicit `reconcile(full_rebuild=True)`, the **entire** active pool goes into one LLM pass. This is the only path that catches **dissimilar contradictions** — paraphrased ("prefers aisle seats") and contradictory ("vegetarian" vs "loves steak") facts have very different embedding vectors and would never co-occur in a cosine cluster, so candidate mode alone can't link them. + +Both modes return `{"kept": int, "merged": int, "contradicted": int}`. In-process and durable backends reconcile **both** facts and episodic memories so episodic duplicates don't accrue forever. ### Loser preservation -Soft-deleted facts stay in the container with their `supersede_reason`, `superseded_at`, and `superseded_by` fields populated. Default reads (`get_memories`, `search_cosmos`) filter them out via `superseded_by IS NULL`. To inspect the audit trail (e.g. "show everything that ever applied to this user"), opt out of the filter at the query level. +Soft-deleted facts stay in the container with their `supersede_reason`, `superseded_at`, and `superseded_by` fields populated. Default reads (`get_memories`, `search_cosmos`) filter them out via `superseded_by IS NULL`. To inspect the audit trail, opt out of the filter at the query level. ### Write-time exact dedup Each fact written by `extract_memories` carries a `content_hash` (SHA-256 of normalized content, truncated to 32 hex chars; lowercase, whitespace-collapsed). Before upserting a freshly-extracted fact, the pipeline checks the hash against existing active facts and short-circuits if a match exists, incrementing the `exact_dedup_skipped` metric. This catches identical re-extractions cheaply without an LLM call. +### Extraction watermark (`recent_k`) + +The auto-trigger paths size `recent_k` (how many recent turns extraction reads) from a per-thread **watermark** (`last_extract_count` on the counter doc): `recent_k = current_count − last_extract_count` (with `last_extract_count` treated as `0` before the first successful extract). The newest-`recent_k` turns are exactly the turns added since the last successful extract, and the watermark advances **only after a successful extract** — so under normal operation no turns are skipped when extraction lags or transiently fails: a failed run leaves the watermark put and the full backlog is retried next sweep. The window is deliberately **not** capped by `DEDUP_POOL_SIZE` (that knob governs the reconcile prompt, not the extraction window) — capping would extract only the newest N and silently strand the oldest backlog turns. + +> **Caveat (rare):** the SDK's inline counter increment is best-effort — under sustained optimistic-concurrency contention it can drop an increment rather than block the user's write path (see `increment_counter_sync`). A dropped increment leaves `current_count` lagging the true turn count, which can in turn under-cover a later extraction window. This is the one case where the "no turns skipped" property does not hold; the Function App backend avoids it by raising to force change-feed redelivery. + ### Tunable -`DEDUP_EVERY_N` (default 5) controls how often `reconcile_memories` runs in the auto-trigger path. Set to `0` to disable. The candidate cap `n` (default 50) is tunable per call; larger values give the LLM a wider view at higher token cost. +`DEDUP_EVERY_N` (default 5) controls how often reconcile runs in the auto-trigger path. Set to `0` to disable. The candidate cap `n` (default `DEDUP_POOL_SIZE`, 50) is tunable per call; larger values give the LLM a wider view at higher token cost. `DEDUP_FULL_RECLUSTER_EVERY_N` (default 12) sets how often the full-pool backstop fires. > **Indexing note.** The reconcile pool query orders by `created_at` (matching the prompt's "more recent first" tiebreaker). Cosmos's default indexing policy includes every property, so this works out of the box. If you customize the indexing policy to reduce write RU, ensure `/created_at/?` remains indexed or the query will fail with a 400 (`Order-by over a non-indexed path`). diff --git a/Docs/design_patterns.md b/Docs/design_patterns.md index ce8e72b..0e520c3 100644 --- a/Docs/design_patterns.md +++ b/Docs/design_patterns.md @@ -170,7 +170,6 @@ facts = await mem.search_cosmos( results = await mem.search_cosmos( search_terms="PostgreSQL to Cosmos DB", user_id="user-1", - hybrid_search=True, top_k=5, ) ``` diff --git a/Docs/public_api.md b/Docs/public_api.md index dd17452..bec37e4 100644 --- a/Docs/public_api.md +++ b/Docs/public_api.md @@ -37,8 +37,8 @@ ### Retrieval -- `search_cosmos(search_terms, memory_id=None, user_id=None, role=None, memory_types=None, thread_id=None, hybrid_search=False, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, include_superseded=False, min_salience=None, min_confidence=None, created_after=None, created_before=None) -> list[dict]` — vector or hybrid search derived memories (facts/episodic/procedural). -- `search_turns(search_terms, user_id, thread_id=None, role=None, hybrid_search=False, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, created_after=None, created_before=None) -> list[dict]` — vector or hybrid search the raw conversation log instead of facts/episodic/procedural (requires turn embeddings; see `enable_turn_embeddings`). `user_id` is required so the search is scoped to one partition instead of scanning every user's turns. +- `search_cosmos(search_terms, memory_id=None, user_id=None, role=None, memory_types=None, thread_id=None, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, include_superseded=False, min_salience=None, min_confidence=None, created_after=None, created_before=None) -> list[dict]` — hybrid vector/full-text search memories, falling back to vector-only for all-stopword queries. +- `search_turns(search_terms, user_id, thread_id=None, role=None, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, created_after=None, created_before=None) -> list[dict]` — hybrid vector/full-text search the raw conversation log instead of facts/episodic/procedural (requires turn embeddings; see `enable_turn_embeddings`). `user_id` is required so the search is scoped to one partition instead of scanning every user's turns. - `get_procedural_prompt(user_id) -> Optional[str]` — read the active procedural prompt. - `get_procedural_history(user_id, limit=10) -> list[dict]` — read procedural prompt history. - `get_procedural_memories(user_id, priority=None, category=None, min_salience=None, include_superseded=False) -> list[dict]` — retrieve procedural memory documents. @@ -91,8 +91,8 @@ Local-buffer methods remain synchronous in-memory operations; Cosmos, retrieval, ### Retrieval -- `async search_cosmos(search_terms, memory_id=None, user_id=None, role=None, memory_types=None, thread_id=None, hybrid_search=False, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, include_superseded=False, min_salience=None, min_confidence=None, created_after=None, created_before=None) -> list[dict]` — vector or hybrid search derived memories (facts/episodic/procedural). -- `async search_turns(search_terms, user_id, thread_id=None, role=None, hybrid_search=False, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, created_after=None, created_before=None) -> list[dict]` — vector or hybrid search the raw conversation log instead of facts/episodic/procedural (requires turn embeddings; see `enable_turn_embeddings`). `user_id` is required so the search is scoped to one partition instead of scanning every user's turns. +- `async search_cosmos(search_terms, memory_id=None, user_id=None, role=None, memory_types=None, thread_id=None, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, include_superseded=False, min_salience=None, min_confidence=None, created_after=None, created_before=None) -> list[dict]` — hybrid vector/full-text search memories, falling back to vector-only for all-stopword queries. +- `async search_turns(search_terms, user_id, thread_id=None, role=None, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, created_after=None, created_before=None) -> list[dict]` — hybrid vector/full-text search the raw conversation log instead of facts/episodic/procedural (requires turn embeddings; see `enable_turn_embeddings`). `user_id` is required so the search is scoped to one partition instead of scanning every user's turns. - `async get_procedural_prompt(user_id) -> Optional[str]` — read the active procedural prompt. - `async get_procedural_history(user_id, limit=10) -> list[dict]` — read procedural prompt history. - `async get_procedural_memories(user_id, priority=None, category=None, min_salience=None, include_superseded=False) -> list[dict]` — retrieve procedural memory documents. diff --git a/Docs/troubleshooting.md b/Docs/troubleshooting.md index 27f500f..a64d1cd 100644 --- a/Docs/troubleshooting.md +++ b/Docs/troubleshooting.md @@ -50,15 +50,11 @@ COSMOS_DB_DATABASE=ai_memory COSMOS_DB_MEMORIES_CONTAINER=memories COSMOS_DB_COUNTERS_CONTAINER=counter COSMOS_DB_LEASE_CONTAINER=leases -COSMOS_DB_THROUGHPUT_MODE=serverless -COSMOS_DB_AUTOSCALE_MAX_RU=1000 AI_FOUNDRY_ENDPOINT=https://.openai.azure.com/ AI_FOUNDRY_API_KEY= AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME=text-embedding-3-large AI_FOUNDRY_EMBEDDING_DIMENSIONS=1536 -AI_FOUNDRY_EMBEDDING_DATA_TYPE=float32 -AI_FOUNDRY_EMBEDDING_DISTANCE_FUNCTION=cosine AI_FOUNDRY_CHAT_DEPLOYMENT_NAME= ``` @@ -77,8 +73,6 @@ The notebooks and samples pass these values into the client like this: | `AI_FOUNDRY_EMBEDDING_DEPLOYMENT_NAME` | `embedding_deployment_name` | | `AI_FOUNDRY_CHAT_DEPLOYMENT_NAME` | `chat_deployment_name` | -`AI_FOUNDRY_EMBEDDING_DIMENSIONS`, `AI_FOUNDRY_EMBEDDING_DATA_TYPE`, and `AI_FOUNDRY_EMBEDDING_DISTANCE_FUNCTION` are read by the toolkit when creating the Cosmos DB vector policy. The Function App also reads `COSMOS_DB__accountEndpoint` for its identity-based Cosmos DB trigger binding; set it to the same value as `COSMOS_DB_ENDPOINT`. - Run `az login` before using `DefaultAzureCredential`. Required roles: @@ -104,7 +98,7 @@ The memories container is created with: If vector or full-text search fails after changing dimensions or indexing settings, create a fresh container with the desired configuration. Cosmos container vector policies are creation-time infrastructure choices. -Use `COSMOS_DB_THROUGHPUT_MODE=serverless` for the default setup. Use `autoscale` with `COSMOS_DB_AUTOSCALE_MAX_RU` when you need provisioned autoscale throughput. +Pass `cosmos_throughput_mode="serverless"` (the default) when creating the client. Use `cosmos_throughput_mode="autoscale"` with `cosmos_autoscale_max_ru` when you need provisioned autoscale throughput. --- @@ -117,7 +111,7 @@ Embedding failures usually mean one of these is wrong: - `AI_FOUNDRY_EMBEDDING_DIMENSIONS` - Azure OpenAI / AI Services RBAC -For hybrid search, `search_terms` is required when `hybrid_search=True`. +Search always uses hybrid vector/full-text ranking when keyword extraction finds terms; all-stopword queries fall back to vector-only ranking. If search returns documents but scores look poor, check that records have an `embedding` field and that the query uses similar language to the stored memory content. diff --git a/Samples/Advanced/advanced_search_patterns.py b/Samples/Advanced/advanced_search_patterns.py index 2d9a518..dff042d 100644 --- a/Samples/Advanced/advanced_search_patterns.py +++ b/Samples/Advanced/advanced_search_patterns.py @@ -89,10 +89,11 @@ def seed_memories(mem: CosmosMemoryClient, user_id: str, thread_id: str) -> None # --------------------------------------------------------------------------- def vector_search(mem: CosmosMemoryClient, user_id: str) -> None: - """Pattern 1 — Pure vector (semantic similarity) search.""" - print_header("1. Vector Search (semantic similarity)") + """Pattern 1 — Semantic-style query (natural language, low keyword overlap).""" + print_header("1. Semantic Search (natural-language query)") print(" Query: 'outdoor activities'") - print(" Finds semantically related memories even without exact keyword matches.\n") + print(" Hybrid ranking leans on embedding similarity when there are few exact") + print(" keyword matches, so semantically related memories still surface.\n") results = mem.search_cosmos( search_terms="outdoor activities", @@ -103,15 +104,15 @@ def vector_search(mem: CosmosMemoryClient, user_id: str) -> None: def hybrid_search(mem: CosmosMemoryClient, user_id: str) -> None: - """Pattern 2 — Hybrid search (vector + full-text).""" + """Pattern 2 — Hybrid search (vector + full-text) is the default.""" print_header("2. Hybrid Search (vector + full-text)") print(" Query: 'hiking trails Pacific Northwest'") - print(" Combines embedding similarity with BM25 keyword matching.\n") + print(" Every search_cosmos call fuses embedding similarity with BM25 keyword") + print(" matching automatically — no flag required.\n") results = mem.search_cosmos( search_terms="hiking trails Pacific Northwest", user_id=user_id, - hybrid_search=True, top_k=5, ) print_results(results) diff --git a/Samples/Notebooks/Demo_async.ipynb b/Samples/Notebooks/Demo_async.ipynb index bf413f8..3aca2d1 100644 --- a/Samples/Notebooks/Demo_async.ipynb +++ b/Samples/Notebooks/Demo_async.ipynb @@ -872,7 +872,7 @@ "results_search_async = await memory.search_cosmos(\n", " search_terms=\"What did the user ask about the weather?\",\n", " user_id=USER_ID,\n", - " top_k=3, hybrid_search= True\n", + " top_k=3\n", ")" ] }, diff --git a/azure/cosmos/agent_memory/_counters.py b/azure/cosmos/agent_memory/_counters.py index 2223d06..65d888c 100644 --- a/azure/cosmos/agent_memory/_counters.py +++ b/azure/cosmos/agent_memory/_counters.py @@ -170,7 +170,7 @@ def increment_counter_sync( if exc.status_code == 412 and attempt < MAX_RETRIES - 1: continue logger.warning( - "Counter increment failed counter_id=%s status=%s — auto-trigger skipped", + "Counter increment failed counter_id=%s status=%s — auto-trigger skipped (increment dropped)", counter_id, exc.status_code, ) @@ -277,6 +277,10 @@ def _build_counter_doc( doc["last_batch_old_count"] = existing.get("last_batch_old_count", old_count) else: doc["last_batch_lsn"] = None + # Preserve the extraction watermark (count value at the last successful + # extract) so recent_k can cover every turn since, not just this batch. + if existing is not None and "last_extract_count" in existing: + doc["last_extract_count"] = existing.get("last_extract_count") # Carry over auto-trigger failure breadcrumbs so they aren't blown away # by a successful write. ``stamp_failure_*`` helpers refresh them on # failure; operators can monitor ``last_failure_at`` directly. @@ -346,6 +350,74 @@ async def stamp_failure_async( logger.debug("stamp_failure_async failed counter_id=%s: %s", counter_id, exc) +def read_extract_watermark_sync( + container: Any, + counter_id: str, + user_id: str, + thread_id: str, +) -> Optional[int]: + """Return the count value at the last successful extract, or ``None``. + + The watermark lets recent_k cover every turn since the previous extract + succeeded, instead of just the current batch — so turns are never skipped + when extraction lags or transiently fails. Best-effort: returns ``None`` + on any read error so callers fall back to a batch-based recent_k. + """ + try: + doc = container.read_item(item=counter_id, partition_key=[user_id, thread_id]) + value = doc.get("last_extract_count") + return int(value) if value is not None else None + except Exception as exc: # pragma: no cover - best-effort + logger.debug("read_extract_watermark_sync failed counter_id=%s: %s", counter_id, exc) + return None + + +def advance_extract_watermark_sync( + container: Any, + counter_id: str, + user_id: str, + thread_id: str, + count: int, +) -> None: + """Stamp ``last_extract_count=count`` after a successful extract (sync).""" + patch_ops = [{"op": "add", "path": "/last_extract_count", "value": int(count)}] + try: + container.patch_item(item=counter_id, partition_key=[user_id, thread_id], patch_operations=patch_ops) + except Exception as exc: # pragma: no cover - best-effort + logger.debug("advance_extract_watermark_sync failed counter_id=%s: %s", counter_id, exc) + + +async def read_extract_watermark_async( + container: Any, + counter_id: str, + user_id: str, + thread_id: str, +) -> Optional[int]: + """Async version of :func:`read_extract_watermark_sync`.""" + try: + doc = await container.read_item(item=counter_id, partition_key=[user_id, thread_id]) + value = doc.get("last_extract_count") + return int(value) if value is not None else None + except Exception as exc: # pragma: no cover - best-effort + logger.debug("read_extract_watermark_async failed counter_id=%s: %s", counter_id, exc) + return None + + +async def advance_extract_watermark_async( + container: Any, + counter_id: str, + user_id: str, + thread_id: str, + count: int, +) -> None: + """Stamp ``last_extract_count=count`` after a successful extract (async).""" + patch_ops = [{"op": "add", "path": "/last_extract_count", "value": int(count)}] + try: + await container.patch_item(item=counter_id, partition_key=[user_id, thread_id], patch_operations=patch_ops) + except Exception as exc: # pragma: no cover - best-effort + logger.debug("advance_extract_watermark_async failed counter_id=%s: %s", counter_id, exc) + + __all__ = [ "USER_COUNTER_THREAD_ID", "thread_counter_id", @@ -355,4 +427,8 @@ async def stamp_failure_async( "increment_counter_async", "stamp_failure_sync", "stamp_failure_async", + "read_extract_watermark_sync", + "advance_extract_watermark_sync", + "read_extract_watermark_async", + "advance_extract_watermark_async", ] diff --git a/azure/cosmos/agent_memory/_utils.py b/azure/cosmos/agent_memory/_utils.py index 849ef48..ac32f9e 100644 --- a/azure/cosmos/agent_memory/_utils.py +++ b/azure/cosmos/agent_memory/_utils.py @@ -224,11 +224,12 @@ def normalize_ai_foundry_endpoint(endpoint: Optional[str]) -> Optional[str]: def _resolve_embedding_data_type(val: Optional[str]) -> str: - """Resolve embedding data type from explicit value or ``AI_FOUNDRY_EMBEDDING_DATA_TYPE`` env var. + """Resolve embedding data type from the explicit value, defaulting to ``float32``. - Defaults to ``float32``. Raises :class:`ConfigurationError` for unknown values. + Provided by the caller at memory-client creation. Raises :class:`ConfigurationError` + for unknown values. """ - raw = (val if val is not None else os.environ.get("AI_FOUNDRY_EMBEDDING_DATA_TYPE") or "float32").strip() + raw = (val if val is not None else "float32").strip() if raw not in _ALLOWED_EMBEDDING_DATA_TYPES: raise ConfigurationError( message=( @@ -241,11 +242,12 @@ def _resolve_embedding_data_type(val: Optional[str]) -> str: def _resolve_distance_function(val: Optional[str]) -> str: - """Resolve distance function from explicit value or ``AI_FOUNDRY_EMBEDDING_DISTANCE_FUNCTION`` env var. + """Resolve distance function from the explicit value, defaulting to ``cosine``. - Defaults to ``cosine``. Raises :class:`ConfigurationError` for unknown values. + Provided by the caller at memory-client creation. Raises :class:`ConfigurationError` + for unknown values. """ - raw = (val if val is not None else os.environ.get("AI_FOUNDRY_EMBEDDING_DISTANCE_FUNCTION") or "cosine").strip() + raw = (val if val is not None else "cosine").strip() if raw not in _ALLOWED_DISTANCE_FUNCTIONS: raise ConfigurationError( message=( @@ -258,17 +260,16 @@ def _resolve_distance_function(val: Optional[str]) -> str: def _resolve_vector_index_type(val: Optional[str]) -> str: - """Resolve vector index type from explicit value or ``AI_FOUNDRY_EMBEDDING_VECTOR_INDEX_TYPE`` env var. + """Resolve the vector index type from the explicit value, defaulting to ``quantizedFlat``. - Defaults to ``quantizedFlat``. Raises :class:`ConfigurationError` for unknown values. + Provided by the caller at memory-client creation. Raises :class:`ConfigurationError` + for unknown values. ``quantizedFlat`` works on any Cosmos DB account (including the classic emulator). ``diskANN`` requires the Cosmos DB account to have the DiskANN vector index capability enabled; opt into it explicitly when available. """ - raw = ( - val if val is not None else os.environ.get("AI_FOUNDRY_EMBEDDING_VECTOR_INDEX_TYPE") or "quantizedFlat" - ).strip() + raw = (val if val is not None else "quantizedFlat").strip() if raw not in _ALLOWED_VECTOR_INDEX_TYPES: raise ConfigurationError( message=( @@ -280,21 +281,79 @@ def _resolve_vector_index_type(val: Optional[str]) -> str: return raw +_SIMILARITY_DESCENDING_FUNCTIONS = frozenset({"cosine", "dotproduct"}) + + +def vector_order_direction(distance_function: str) -> str: + """Return the ``ORDER BY VectorDistance(...)`` direction for most-similar-first. + + ``DESC`` for cosine/dotproduct (higher score = more similar), ``ASC`` for + euclidean (lower distance = more similar). + """ + return "DESC" if distance_function in _SIMILARITY_DESCENDING_FUNCTIONS else "ASC" + + +def vector_similarity_at_least(score: float, threshold: float, distance_function: str) -> bool: + """Return ``True`` when ``score`` meets/exceeds ``threshold`` similarity. + + For cosine/dotproduct (higher = more similar) this is ``score >= threshold``; + for euclidean (lower = more similar) it inverts to ``score <= threshold``. The + dedup thresholds (``DEDUP_SIM_*``) are calibrated for cosine/dotproduct on + normalized embeddings; euclidean gets the correct *direction* but its + thresholds would need separate calibration. + """ + if distance_function in _SIMILARITY_DESCENDING_FUNCTIONS: + return score >= threshold + return score <= threshold + + +def vector_autodrop_supported(distance_function: str) -> bool: + """Whether the cosine-calibrated near-exact auto-drop is safe to apply. + + The destructive ``DEDUP_SIM_HIGH`` auto-skip drops a new memory without an + LLM check, relying on thresholds (~0.97) calibrated for cosine/dotproduct + on normalized embeddings. Euclidean returns an *unbounded distance* (not a + [0,1] similarity), so those thresholds mis-fire — auto-drop is disabled for + euclidean and the borderline tagging path (LLM-adjudicated) is used instead. + """ + return distance_function != "euclidean" + + +def distance_function_from_container_properties(props: Any, *, default: str = "cosine") -> str: + """Read the vector embedding's ``distanceFunction`` from container properties. + + The distance function (cosine/dotproduct/euclidean) is chosen at + ``create_memory_store`` time, written immutably into the container's vector + embedding policy, and read back here from the authoritative source + (``container.read()``) so the dedup vector-floor logic matches how the + container actually ranks. This SDK provisions exactly one vector embedding; + falls back to ``default`` (cosine) when the policy is absent or malformed + (e.g. ``__new__``-built test instances with mocked containers). + """ + policy = props.get("vectorEmbeddingPolicy") if isinstance(props, dict) else None + embeddings = policy.get("vectorEmbeddings") if isinstance(policy, dict) else None + entry = embeddings[0] if isinstance(embeddings, list) and embeddings else None + fn = entry.get("distanceFunction") if isinstance(entry, dict) else None + if isinstance(fn, str) and fn in _ALLOWED_DISTANCE_FUNCTIONS: + return fn + return default + + def _resolve_full_text_language(val: Optional[str]) -> str: - """Resolve full-text language from explicit value or ``COSMOS_DB_FULL_TEXT_LANGUAGE`` env var. + """Resolve full-text language from the explicit value, defaulting to ``en-US``. - Defaults to ``en-US``. Empty values fall back to the default. + Provided by the caller at memory-client creation. Empty values fall back to the default. """ - raw = (val if val is not None else os.environ.get("COSMOS_DB_FULL_TEXT_LANGUAGE") or "en-US").strip() + raw = (val if val is not None else "en-US").strip() return raw or "en-US" def _resolve_cosmos_throughput_mode(val: Optional[str]) -> str: - """Resolve throughput mode from explicit value or env var. + """Resolve throughput mode from the explicit value, defaulting to ``serverless``. Allowed values are ``serverless`` and ``autoscale``. """ - raw = (val if val is not None else os.environ.get("COSMOS_DB_THROUGHPUT_MODE") or "serverless").strip().lower() + raw = (val if val is not None else "serverless").strip().lower() if raw not in {"serverless", "autoscale"}: raise ConfigurationError( @@ -307,28 +366,15 @@ def _resolve_cosmos_throughput_mode(val: Optional[str]) -> str: def _resolve_cosmos_autoscale_max_ru(val: Optional[int]) -> int: - """Resolve autoscale max RU from explicit value or env var.""" - if val is not None: - if not isinstance(val, int) or isinstance(val, bool) or val <= 0: - raise ConfigurationError( - message=f"Invalid configuration for cosmos_autoscale_max_ru: expected a positive integer, got '{val}'", - parameter="cosmos_autoscale_max_ru", - ) - return val - raw = (os.environ.get("COSMOS_DB_AUTOSCALE_MAX_RU") or "1000").strip() - try: - parsed = int(raw) - except ValueError as exc: + """Resolve autoscale max RU from the explicit value, defaulting to 1000.""" + if val is None: + return 1000 + if not isinstance(val, int) or isinstance(val, bool) or val <= 0: raise ConfigurationError( - message=(f"Invalid configuration for cosmos_autoscale_max_ru: expected an integer, got '{raw}'"), - parameter="cosmos_autoscale_max_ru", - ) from exc - if parsed <= 0: - raise ConfigurationError( - message=(f"Invalid configuration for cosmos_autoscale_max_ru: expected a positive integer, got '{raw}'"), + message=f"Invalid configuration for cosmos_autoscale_max_ru: expected a positive integer, got '{val}'", parameter="cosmos_autoscale_max_ru", ) - return parsed + return val def _resolve_cosmos_provisioning_autoscale_max_ru( @@ -490,10 +536,86 @@ def _container_policies( return vector_embedding_policy, indexing_policy, full_text_policy -def _validate_hybrid_search( - hybrid_search: bool, - search_terms: Optional[str], -) -> None: - """Raise :class:`ValidationError` if hybrid search is requested without search terms.""" - if hybrid_search and not search_terms: - raise ValidationError("search_terms is required when hybrid_search is True") +_FULLTEXT_STOPWORDS: frozenset[str] = frozenset( + """ + 0 1 2 3 4 5 6 7 8 9 a a's able about above according accordingly across actually + after afterwards again against ain't all allow allows almost alone along already + also although always am among amongst an and another any anybody anyhow anyone + anything anyway anyways anywhere apart appear appreciate appropriate are aren't + around as aside ask asking associated at available away awfully b be became + because become becomes becoming been before beforehand behind being believe below + beside besides best better between beyond both brief but by c c'mon c's came can + can't cannot cant cause causes certain certainly changes clearly co com come comes + concerning consequently consider considering contain containing contains + corresponding could couldn't course currently d definitely described despite did + didn't different do does doesn't doing don don't done down downwards during e each + edu eg eight either else elsewhere enough entirely especially et etc even ever + every everybody everyone everything everywhere ex exactly example except f far few + fifth first five followed following follows for former formerly forth four from + further furthermore g get gets getting given gives go goes going gone got gotten + greetings h had hadn't happens hardly has hasn't have haven't having he he's hello + help hence her here here's hereafter hereby herein hereupon hers herself hi him + himself his hither hopefully how howbeit however i i'd i'll i'm i've ie if ignored + immediate in inasmuch inc indeed indicate indicated indicates inner insofar instead + into inward is isn't it it'd it'll it's its itself j just k keep keeps kept know + known knows l last lately later latter latterly least less lest let let's like + liked likely little ll look looking looks ltd m mainly make many may maybe me mean + meanwhile merely might more moreover most mostly mr mrs ms much must my myself n + name namely nd near nearly necessary need needs neither never nevertheless new next + nine no nobody non none noone nor normally not nothing novel now nowhere o obviously + of off often oh ok okay old on once one ones only onto or other others otherwise + ought our ours ourselves out outside over overall own p particular particularly per + perhaps placed please plus possible presumably probably provides q que quite qv r + rather rd re really reasonably regarding regardless regards relatively respectively + right s said same saw say saying says second secondly see seeing seem seemed seeming + seems seen self selves sensible sent serious seriously seven several shall she + should shouldn't since six so some somebody somehow someone something sometime + sometimes somewhat somewhere soon sorry specified specify specifying still sub such + sup sure t t's take taken tell tends th than thank thanks thanx that that's thats + the their theirs them themselves then thence there there's thereafter thereby + therefore therein theres thereupon these they they'd they'll they're they've think + third this thorough thoroughly those though three through throughout thru thus to + together too took toward towards tried tries truly try trying twice two u un under + unfortunately unless unlikely until unto up upon us use used useful uses using + usually v value various ve very via viz vs w want wants was wasn't way we we'd we'll + we're we've welcome well went were weren't what what's whatever when whence whenever + where where's whereafter whereas whereby wherein whereupon wherever whether which + while whither who who's whoever whole whom whose why will willing wish with within + without won't wonder would wouldn't x y yes yet you you'd you'll you're you've your + yours yourself yourselves z zero + """.split() +) + +_KEYWORD_TOKEN_RE = re.compile(r"[a-z0-9]+") + +# Azure Cosmos DB ``FullTextScore`` accepts at most 30 search terms; a query with +# 31+ terms is rejected with ``BadRequest: One of the input values is invalid``. +# Keyword extraction is therefore capped here so the hybrid search SQL can never +# exceed the limit. The full query text is still embedded uncapped for the vector +# half of the hybrid rank, so trimming the BM25 keyword tail does not lose semantics. +MAX_FULLTEXT_TERMS = 30 + + +def extract_keywords(text: Optional[str]) -> list[str]: + """Extract de-duplicated, stopword-filtered keyword terms for full-text search. + + Lowercases, tokenizes on alphanumeric runs (apostrophes/punctuation split into + fragments that are themselves stopwords), removes stopwords, and de-duplicates + while preserving first-seen order. The result is capped at ``MAX_FULLTEXT_TERMS`` + (30) — the hard limit on terms Azure Cosmos DB ``FullTextScore`` accepts — so the + hybrid search query is always valid. Returns ``[]`` when the text is empty or all + stopwords, which the search layer treats as a signal to fall back to pure vector + ranking. + """ + if not text: + return [] + seen: set[str] = set() + keywords: list[str] = [] + for token in _KEYWORD_TOKEN_RE.findall(text.lower()): + if token in _FULLTEXT_STOPWORDS or token in seen: + continue + seen.add(token) + keywords.append(token) + if len(keywords) >= MAX_FULLTEXT_TERMS: + break + return keywords diff --git a/azure/cosmos/agent_memory/aio/auto_trigger.py b/azure/cosmos/agent_memory/aio/auto_trigger.py index b0d4c47..78f8716 100644 --- a/azure/cosmos/agent_memory/aio/auto_trigger.py +++ b/azure/cosmos/agent_memory/aio/auto_trigger.py @@ -64,6 +64,10 @@ async def maybe_trigger_steps( return n_dedup_turns = n_facts * n_dedup if n_facts > 0 and n_dedup > 0 else 0 + # Persisted-counter full-pool backstop cadence (durable-safe, mirrors the + # change-feed): every DEDUP_FULL_RECLUSTER_EVERY_N-th reconcile. + n_full_recluster = _threshold_int(thresholds, "get_dedup_full_recluster_every_n", "DEDUP_FULL_RECLUSTER_EVERY_N") + n_full_turns = n_dedup_turns * n_full_recluster if (n_dedup_turns > 0 and n_full_recluster > 0) else 0 user_batch_counts = await _trigger_thread_steps( processor, counter_container, @@ -71,6 +75,7 @@ async def maybe_trigger_steps( n_facts=n_facts, n_summary=n_summary, n_dedup_turns=n_dedup_turns, + n_full_turns=n_full_turns, thresholds=thresholds, ) await _trigger_user_steps(processor, counter_container, user_batch_counts, n_user=n_user) @@ -84,6 +89,7 @@ async def _trigger_thread_steps( n_facts: int, n_summary: int, n_dedup_turns: int, + n_full_turns: int, thresholds: Any = None, ) -> dict[str, int]: user_batch_counts: dict[str, int] = {} @@ -110,14 +116,38 @@ async def _trigger_thread_steps( counter_id, user_id, thread_id, + new_count=new_count, fire_extract=n_facts > 0 and _counters.crosses_threshold(old_count, new_count, n_facts), fire_summary=n_summary > 0 and _counters.crosses_threshold(old_count, new_count, n_summary), fire_dedup=n_dedup_turns > 0 and _counters.crosses_threshold(old_count, new_count, n_dedup_turns), + fire_full_rebuild=n_full_turns > 0 and _counters.crosses_threshold(old_count, new_count, n_full_turns), thresholds=thresholds, ) return user_batch_counts +async def _watermark_recent_k( + counter_container: Any, + counter_id: str, + user_id: str, + thread_id: str, + *, + new_count: int, +) -> int: + """Async: recent_k covering every turn since the last successful extract. + + Not capped — ``new_count - watermark`` is exactly the unextracted backlog and + the newest-``recent_k`` slice covers precisely those turns, so the watermark + can advance to ``new_count`` with no stranded turns. **Bootstrap:** with no + watermark yet the base is ``0`` (``recent_k = new_count``), so turns added + during earlier failed extracts aren't stranded when the watermark first + advances to ``new_count``. + """ + watermark = await _counters.read_extract_watermark_async(counter_container, counter_id, user_id, thread_id) + base = watermark if watermark is not None else 0 + return max(new_count - base, 1) + + async def _fire_thread_steps( processor: AsyncInProcessProcessor, counter_container: Any, @@ -125,9 +155,11 @@ async def _fire_thread_steps( user_id: str, thread_id: str, *, + new_count: int, fire_extract: bool, fire_summary: bool, fire_dedup: bool, + fire_full_rebuild: bool = False, thresholds: Any = None, ) -> None: fire_procedural = fire_dedup and bool( @@ -138,14 +170,33 @@ async def _fire_thread_steps( default=True, ) ) + if fire_extract: + recent_k = await _watermark_recent_k( + counter_container, + counter_id, + user_id, + thread_id, + new_count=new_count, + ) + try: + await _call_async_compatible( + processor.process_extract_memories, user_id=user_id, thread_id=thread_id, recent_k=recent_k + ) + await _counters.advance_extract_watermark_async( + counter_container, counter_id, user_id, thread_id, new_count + ) + except Exception as exc: + logger.warning("Async auto-trigger process_extract_memories failed for %s/%s: %s", user_id, thread_id, exc) + await _counters.stamp_failure_async( + counter_container, counter_id, user_id, thread_id, f"process_extract_memories: {exc!r}" + ) calls = ( ( - fire_extract, - "process_extract_memories", - processor.process_extract_memories, - {"user_id": user_id, "thread_id": thread_id}, + fire_dedup, + "process_reconcile", + processor.process_reconcile, + {"user_id": user_id, "full_rebuild": fire_full_rebuild}, ), - (fire_dedup, "process_reconcile", processor.process_reconcile, {"user_id": user_id}), ( fire_procedural, "synthesize_procedural", diff --git a/azure/cosmos/agent_memory/aio/cosmos_memory_client.py b/azure/cosmos/agent_memory/aio/cosmos_memory_client.py index 4c5b4a7..6bfd626 100644 --- a/azure/cosmos/agent_memory/aio/cosmos_memory_client.py +++ b/azure/cosmos/agent_memory/aio/cosmos_memory_client.py @@ -663,7 +663,6 @@ async def search_cosmos( role: Optional[str] = None, memory_types: Optional[list[str]] = None, thread_id: Optional[str] = None, - hybrid_search: bool = False, top_k: int = 5, tags_all: Optional[list[str]] = None, tags_any: Optional[list[str]] = None, @@ -681,7 +680,6 @@ async def search_cosmos( role=role, memory_types=memory_types, thread_id=thread_id, - hybrid_search=hybrid_search, top_k=top_k, tags_all=tags_all, tags_any=tags_any, @@ -699,7 +697,6 @@ async def search_turns( user_id: str, thread_id: Optional[str] = None, role: Optional[str] = None, - hybrid_search: bool = False, top_k: int = 5, tags_all: Optional[list[str]] = None, tags_any: Optional[list[str]] = None, @@ -721,7 +718,6 @@ async def search_turns( user_id=user_id, thread_id=thread_id, role=role, - hybrid_search=hybrid_search, top_k=top_k, tags_all=tags_all, tags_any=tags_any, @@ -836,12 +832,23 @@ async def search_episodic_memories( min_salience: Optional[float] = None, include_superseded: bool = False, ) -> list[dict[str, Any]]: - return await self._get_store().search_episodic(user_id, search_terms, top_k, min_salience, include_superseded) + return await self._get_store().search_episodic( + user_id, + search_terms, + top_k, + min_salience, + include_superseded, + ) async def build_procedural_context(self, user_id: str) -> str: return await self._get_pipeline().build_procedural_context(user_id) - async def build_episodic_context(self, user_id: str, query: str, top_k: int = 3) -> str: + async def build_episodic_context( + self, + user_id: str, + query: str, + top_k: int = 3, + ) -> str: return await self._get_store().build_episodic_context(user_id, query, top_k) async def extract_memories(self, user_id: str, thread_id: str, recent_k: Optional[int] = None) -> dict[str, int]: @@ -879,7 +886,9 @@ async def generate_user_summary( async def reconcile(self, user_id: str, n: Optional[int] = None) -> dict[str, int]: from azure.cosmos.agent_memory.thresholds import get_dedup_pool_size - return await self._get_pipeline().reconcile_memories(user_id, n if n is not None else get_dedup_pool_size()) + return await self._get_pipeline().reconcile_memories( + user_id, n if n is not None else get_dedup_pool_size(), full_rebuild=True + ) async def process_now(self, *, user_id: str, thread_id: str) -> "ProcessThreadResult": """Force the processor to run the full pipeline RIGHT NOW for one thread. diff --git a/azure/cosmos/agent_memory/aio/processors/base.py b/azure/cosmos/agent_memory/aio/processors/base.py index cdab0c0..421a5e5 100644 --- a/azure/cosmos/agent_memory/aio/processors/base.py +++ b/azure/cosmos/agent_memory/aio/processors/base.py @@ -33,6 +33,7 @@ async def process_extract_memories( *, user_id: str, thread_id: str, + recent_k: Optional[int] = None, ) -> dict[str, int]: ... async def process_thread_summary( @@ -53,6 +54,7 @@ async def process_reconcile( self, *, user_id: str, + full_rebuild: bool = False, ) -> int: ... async def generate_user_summary( diff --git a/azure/cosmos/agent_memory/aio/processors/durable.py b/azure/cosmos/agent_memory/aio/processors/durable.py index 677855d..59fef03 100644 --- a/azure/cosmos/agent_memory/aio/processors/durable.py +++ b/azure/cosmos/agent_memory/aio/processors/durable.py @@ -40,6 +40,7 @@ async def process_extract_memories( *, user_id: str, thread_id: str, + recent_k: Optional[int] = None, ) -> dict[str, int]: logger.debug( "AsyncDurableFunctionProcessor.process_extract_memories no-op user_id=%s thread_id=%s", @@ -73,7 +74,7 @@ async def process_user_summary( ) return UserSummaryResult(summary=None) - async def process_reconcile(self, *, user_id: str) -> int: + async def process_reconcile(self, *, user_id: str, full_rebuild: bool = False) -> int: logger.debug( "AsyncDurableFunctionProcessor.process_reconcile no-op user_id=%s", user_id, diff --git a/azure/cosmos/agent_memory/aio/processors/inprocess.py b/azure/cosmos/agent_memory/aio/processors/inprocess.py index 9e10071..62cac70 100644 --- a/azure/cosmos/agent_memory/aio/processors/inprocess.py +++ b/azure/cosmos/agent_memory/aio/processors/inprocess.py @@ -70,9 +70,7 @@ async def process_thread( thread_summary = await self._pipeline.generate_thread_summary(user_id, thread_id) extracted = await self._pipeline.extract_memories(user_id, thread_id) - reconciled = await self._pipeline.reconcile_memories(user_id, get_dedup_pool_size()) - - deduped_count = self._extract_reconcile_count(reconciled) + deduped_count = await self._reconcile_fact_and_episodic(user_id, get_dedup_pool_size()) extracted_counts: dict[str, int] = ( {k: v for k, v in extracted.items() if isinstance(v, int)} if isinstance(extracted, dict) else {} @@ -91,8 +89,9 @@ async def process_extract_memories( *, user_id: str, thread_id: str, + recent_k: Optional[int] = None, ) -> dict[str, int]: - extracted = await self._pipeline.extract_memories(user_id, thread_id) + extracted = await self._pipeline.extract_memories(user_id, thread_id, recent_k=recent_k) return {k: v for k, v in extracted.items() if isinstance(v, int)} if isinstance(extracted, dict) else {} async def process_thread_summary( @@ -113,11 +112,26 @@ async def process_user_summary( summary = await self._pipeline.generate_user_summary(user_id, thread_ids) return UserSummaryResult(summary=summary if isinstance(summary, dict) else None) - async def process_reconcile(self, *, user_id: str) -> int: + async def process_reconcile(self, *, user_id: str, full_rebuild: bool = False) -> int: from ...thresholds import get_dedup_pool_size - reconciled = await self._pipeline.reconcile_memories(user_id, get_dedup_pool_size()) - return self._extract_reconcile_count(reconciled) + return await self._reconcile_fact_and_episodic(user_id, get_dedup_pool_size(), full_rebuild=full_rebuild) + + async def _reconcile_fact_and_episodic(self, user_id: str, n: int, *, full_rebuild: bool = False) -> int: + """Reconcile facts and episodic memories; sum merged+contradicted counts. + + SDK in-process processing reconciles both types (matching the Durable + backend) so episodic dups don't accrue forever. ``full_rebuild`` (set by + the auto-trigger on its persisted-counter full-recluster cadence) forces + the full-pool LLM pass that catches dissimilar-embedding contradictions. + """ + total = 0 + for memory_type in ("fact", "episodic"): + reconciled = await self._pipeline.reconcile_memories( + user_id, n=n, memory_type=memory_type, full_rebuild=full_rebuild + ) + total += self._extract_reconcile_count(reconciled) + return total @staticmethod def _extract_reconcile_count(reconciled: Any) -> int: diff --git a/azure/cosmos/agent_memory/aio/services/pipeline.py b/azure/cosmos/agent_memory/aio/services/pipeline.py index 0db0475..4e5de1e 100644 --- a/azure/cosmos/agent_memory/aio/services/pipeline.py +++ b/azure/cosmos/agent_memory/aio/services/pipeline.py @@ -10,7 +10,6 @@ from __future__ import annotations -import asyncio import hashlib import inspect import json @@ -20,13 +19,19 @@ from typing import Any, Iterable, Literal, Optional from azure.cosmos.exceptions import ( - CosmosHttpResponseError, CosmosResourceExistsError, CosmosResourceNotFoundError, ) from azure.cosmos.agent_memory._container_routing import ContainerKey -from azure.cosmos.agent_memory._utils import DEFAULT_TTL_BY_TYPE, compute_content_hash +from azure.cosmos.agent_memory._utils import ( + DEFAULT_TTL_BY_TYPE, + compute_content_hash, + distance_function_from_container_properties, + vector_autodrop_supported, + vector_order_direction, + vector_similarity_at_least, +) from azure.cosmos.agent_memory.aio.store import AsyncMemoryStore from azure.cosmos.agent_memory.exceptions import ( LLMError, @@ -58,9 +63,6 @@ coerce_valence, parse_llm_json, ) -from azure.cosmos.agent_memory.services._pipeline_helpers import ( - format_existing_episodics as _format_existing_episodics, -) from azure.cosmos.agent_memory.services._pipeline_helpers import ( is_real_number as _is_real_number, ) @@ -68,6 +70,16 @@ max_or_none as _max_or_none, ) from azure.cosmos.agent_memory.store._search_helpers import top_literal +from azure.cosmos.agent_memory.thresholds import ( + get_dedup_candidate_topk, + get_dedup_cluster_sim, + get_dedup_context_topk, + get_dedup_context_vector_enabled, + get_dedup_reconcile_mode, + get_dedup_sim_high, + get_dedup_sim_low, + get_dedup_vector_enabled, +) logger = get_logger("azure.cosmos.agent_memory.pipeline.aio") @@ -252,6 +264,106 @@ async def _embed_one(self, text: str) -> list[float]: async def _embed_batch(self, texts: list[str]) -> list[list[float]]: return await self._embeddings.generate_batch(texts) + async def _vector_distance_function(self) -> str: + """Return the container's configured Cosmos ``distanceFunction`` (cached). + + Read from the container's vector embedding policy (``await container.read()``) + — the authoritative, immutable source set when the container was created. + Drives the ORDER BY direction and similarity-threshold comparisons so dedup + never silently assumes cosine. Falls back to cosine when the policy can't be + read (e.g. ``__new__``-built test instances with mocked containers). + """ + fn = getattr(self, "_distance_function_cache", None) + if fn is not None: + return fn + try: + props = await self._memories_container.read() + except Exception: + # Transient read failure is indistinguishable from "no policy" once we + # drop to None — so DON'T cache. An uncached cosine default self-heals on + # the next call; caching it would pin cosine and silently mis-handle a + # euclidean container (cosine bands on euclidean distances → data loss). + logger.debug( + "vector dedup: could not read container vector policy; defaulting to cosine (not cached)", + exc_info=True, + ) + return "cosine" + fn = distance_function_from_container_properties(props) + self._distance_function_cache = fn + return fn + + def _warn_euclidean_autodrop_once(self, distance_function: str) -> None: + """One-shot WARN that the near-exact vector auto-drop is disabled. + + The ``DEDUP_SIM_HIGH`` thresholds are cosine-calibrated; on euclidean + the destructive auto-drop is skipped (borderline tagging + LLM reconcile + still run). Logged once per pipeline instance to avoid hot-path spam. + """ + if getattr(self, "_warned_euclidean_autodrop", False): + return + self._warned_euclidean_autodrop = True + logger.warning( + "Container distanceFunction=%r: near-exact vector auto-drop is " + "cosine-calibrated and has been DISABLED for this distance function. " + "Duplicate detection falls back to borderline tagging + LLM reconcile. " + "Use cosine/dotproduct embeddings for vector-floor auto-dedup.", + distance_function, + ) + + async def _vector_candidates( + self, + *, + user_id: str, + embedding, + memory_type, + top_k, + exclude_ids, + ) -> list[dict]: + """Return active same-user vector candidates from Cosmos.""" + if not user_id or not embedding or not top_k or int(top_k) < 1: + return [] + excluded = set(exclude_ids or []) + capped_top = top_literal(int(top_k), name="_vector_candidates.top_k") + distance_function = await self._vector_distance_function() + order_direction = vector_order_direction(distance_function) + field = "embedding" + query = ( + f"SELECT TOP {capped_top} c.id, c.content, c.type, " + f"VectorDistance(c.{field}, @vec) AS score " + "FROM c WHERE c.user_id = @user_id " + "AND c.type = @memory_type " + f"AND {_ACTIVE_DOC_FILTER} " + f"AND IS_DEFINED(c.{field}) " + # Cosmos orders ORDER BY VectorDistance() most-similar-first per the + # container's distanceFunction; an explicit ASC/DESC is rejected (BadRequest). + f"ORDER BY VectorDistance(c.{field}, @vec)" + ) + rows = await self._query_items( + self._memories_container, + query=query, + parameters=[ + {"name": "@user_id", "value": user_id}, + {"name": "@memory_type", "value": memory_type}, + {"name": "@vec", "value": embedding}, + ], + ) + candidates = [ + { + "id": row.get("id"), + "content": row.get("content"), + "type": row.get("type"), + "score": float(row.get("score") or 0.0), + } + for row in rows + if row.get("id") and row.get("id") not in excluded + ] + # Most-similar-first: descending score for cosine/dotproduct, ascending for euclidean. + candidates.sort( + key=lambda item: item.get("score", 0.0), + reverse=order_direction == "DESC", + ) + return candidates + def _prompt_lineage(self, filename: str) -> dict[str, str]: """Return ``{prompt_id, prompt_version}`` for stamping a doc. @@ -354,58 +466,6 @@ def _stable_source_timestamp(items: list[dict[str, Any]]) -> str: return max(timestamps) return datetime.now(timezone.utc).isoformat() - async def _mark_extracted_superseded( - self, - *, - user_id: str, - thread_id: str, - supersedes_id: str, - superseder_id: str, - reason: Literal["update", "contradict"], - ) -> bool: - try: - old_mem = await self._read_item( - self._memories_container, - item=supersedes_id, - partition_key=[user_id, thread_id], - ) - if old_mem.get("superseded_by"): - logger.debug( - "extract_memories: skipping UPDATE — target %s already superseded by %s", - supersedes_id, - old_mem.get("superseded_by"), - ) - return False - return await self._mark_superseded(old_mem, superseder_id, reason=reason) - except CosmosResourceNotFoundError: - logger.debug( - "extract_memories: %s not found at (user_id, thread_id) — retrying cross-partition", - supersedes_id, - ) - except Exception as exc: - # Includes 429s, 503s, transient connection errors — surface at WARNING - # so they're not masked by the silent cross-partition fallback below. - logger.warning( - "extract_memories: read_item failed for %s (%s); retrying cross-partition", - supersedes_id, - type(exc).__name__, - ) - try: - q = f"SELECT * FROM c WHERE c.id = @id AND c.user_id = @uid AND {_ACTIVE_DOC_FILTER}" - results = await self._query_items( - self._memories_container, - query=q, - parameters=[ - {"name": "@id", "value": supersedes_id}, - {"name": "@uid", "value": user_id}, - ], - ) - if results and not results[0].get("superseded_by"): - return await self._mark_superseded(results[0], superseder_id, reason=reason) - except CosmosHttpResponseError as exc: - logger.warning("Failed to mark superseded memory %s: %s", supersedes_id, exc) - return False - async def _mark_superseded( self, old_doc: dict[str, Any], @@ -464,30 +524,35 @@ async def extract_memories_dry( logger.warning("extract_memories_dry no memories found user_id=%s thread_id=%s", user_id, thread_id) return {"facts": [], "episodic": [], "updates": [], "processed_turn_docs": []} - existing_facts, existing_episodics = await asyncio.gather( - self._load_existing_memories(user_id, ["fact"]), - self._load_existing_memories(user_id, ["episodic"]), - ) + transcript = self._build_transcript(items) + existing_for_hash = await self._load_existing_memories(user_id, ["fact"]) existing_fact_hashes: set[str] = { - m["content_hash"] for m in existing_facts if m.get("type") == "fact" and m.get("content_hash") + m["content_hash"] for m in existing_for_hash if m.get("type") == "fact" and m.get("content_hash") } - if existing_facts: + if get_dedup_context_vector_enabled(): + user_turns_text = "\n".join( + str(it.get("content", "")) for it in items if it.get("role") == "user" + ).strip() + context_query = user_turns_text or transcript + existing = await self._store.search( + search_terms=context_query, + user_id=user_id, + memory_types=["fact"], + top_k=get_dedup_context_topk(), + ) + else: + existing = existing_for_hash + if existing: existing_text = "\n".join( - f"- [ID: {mem['id']}] {mem.get('content', '')} (type=fact, salience={mem.get('salience', 'N/A')})" - for mem in existing_facts + f"- [ID: {mem['id']}] {mem.get('content', '')} " + f"(type={mem.get('type', 'fact')}, salience={mem.get('salience', 'N/A')})" + for mem in existing ) else: existing_text = "(none)" - existing_episodics_text = _format_existing_episodics(existing_episodics) - - transcript = self._build_transcript(items) response_text = await self._run_prompty( "extract_memories.prompty", - inputs={ - "existing_facts": existing_text, - "existing_episodics": existing_episodics_text, - "transcript": transcript, - }, + inputs={"existing_facts": existing_text, "transcript": transcript}, ) parsed = self._parse_llm_json(response_text) facts = parsed.get("facts", []) @@ -502,14 +567,13 @@ async def extract_memories_dry( dropped_episodic_count = 0 for fact in facts: - action = fact.get("action", "ADD").upper() text = fact.get("text") if not text: logger.warning("extract_memories: dropping malformed fact (missing 'text'): %r", fact) continue new_content_hash = compute_content_hash(text) - if action == "ADD" and new_content_hash in existing_fact_hashes: + if new_content_hash in existing_fact_hashes: logger.debug( "extract_memories: skipping exact-dup fact hash=%s user_id=%s thread_id=%s", new_content_hash, @@ -546,22 +610,6 @@ async def extract_memories_dry( "updated_at": doc_timestamp, } - if action in {"UPDATE", "CONTRADICT"} and fact.get("supersedes_id"): - reason: Literal["update", "contradict"] = "contradict" if action == "CONTRADICT" else "update" - if det_id == fact["supersedes_id"]: - logger.debug("extract_memories: skipping UPDATE — det_id == supersedes_id (%s)", det_id) - continue - doc["supersedes_ids"] = [fact["supersedes_id"]] - updates.append( - { - "op": "supersede", - "supersedes_id": fact["supersedes_id"], - "superseder_id": det_id, - "thread_id": thread_id, - "reason": reason, - } - ) - fact_docs.append(self._validate_extracted_doc(doc)) existing_fact_hashes.add(new_content_hash) @@ -581,28 +629,16 @@ async def extract_memories_dry( dropped_episodic_count += 1 continue - text_raw = ep.get("text") - text = text_raw.strip() if isinstance(text_raw, str) else None - if not text: - logger.error( - "extract_memories: dropping episodic with empty/missing text field " - "(LLM extraction did not populate the required `text` field — likely a " - "weaker extraction model that needs upgrading or a prompt-compliance issue). " - "scope_type=%s scope_value=%s user_id=%s thread_id=%s reason=missing_text", - scope_type, - scope_value, - user_id, - thread_id, - ) - dropped_episodic_count += 1 - continue - situation = ep.get("situation") action_taken = ep.get("action_taken") outcome = ep.get("outcome") + if situation and action_taken and outcome: + text = f"{situation} → {action_taken} → {outcome}" + else: + text = f"For the user's {scope_value} {scope_type}, intent recorded." content_hash = compute_content_hash(text) - seed = _ID_SEED_SEP.join((user_id, scope_type, scope_value)) + seed = _ID_SEED_SEP.join((user_id, thread_id, content_hash)) det_id = f"ep_{hashlib.sha256(seed.encode()).hexdigest()[:32]}" topic_tags = build_topic_tags(ep.get("tags", [])) confidence = ep.get("confidence") @@ -619,7 +655,7 @@ async def extract_memories_dry( doc = { "id": det_id, "user_id": user_id, - "thread_id": "__episodic__", + "thread_id": thread_id, "role": "system", "type": "episodic", "content": text, @@ -630,7 +666,6 @@ async def extract_memories_dry( "metadata": { "scope_type": scope_type, "scope_value": scope_value, - "originating_thread_id": thread_id, "situation": situation, "action_taken": action_taken, "outcome": outcome, @@ -694,7 +729,7 @@ async def extract_memories_dry( check_extracted_fact_grounding( fact_docs, items, - existing_facts, + existing, user_id=user_id, thread_id=thread_id, logger=logger, @@ -716,6 +751,100 @@ async def extract_memories_dry( ) return result + async def dedup_extracted_memories(self, user_id: str, extracted: dict) -> dict: + """Apply gated vector-floor deduplication to extracted facts/episodes.""" + if not get_dedup_vector_enabled(): + return extracted + if not user_id: + raise ValidationError("user_id is required") + if not isinstance(extracted, dict): + raise ValidationError("extracted must be a dict") + + high = get_dedup_sim_high() + low = get_dedup_sim_low() + top_k = get_dedup_candidate_topk() + distance_function = await self._vector_distance_function() + autodrop_ok = vector_autodrop_supported(distance_function) + if not autodrop_ok: + self._warn_euclidean_autodrop_once(distance_function) + result = { + "facts": [dict(doc) for doc in extracted.get("facts", [])], + "episodic": [dict(doc) for doc in extracted.get("episodic", [])], + "updates": [dict(op) for op in extracted.get("updates", [])], + } + docs = [doc for doc in result["facts"] + result["episodic"] if doc.get("content")] + missing_embeddings = [doc for doc in docs if not doc.get("embedding")] + if missing_embeddings: + embeddings = await self._embed_batch([str(doc["content"]) for doc in missing_embeddings]) + for doc, embedding in zip(missing_embeddings, embeddings): + doc["embedding"] = embedding + + vector_dedup_skipped = 0 + dup_candidates_tagged = 0 + kept_ids: set[str] = set() + dropped_ids: set[str] = set() + filtered_by_key: dict[str, list[dict[str, Any]]] = {"facts": [], "episodic": []} + for key in ("facts", "episodic"): + for doc in result[key]: + if not doc.get("content"): + filtered_by_key[key].append(doc) + continue + doc_id = str(doc.get("id") or "") + memory_type = str(doc.get("type") or "") + if not doc_id or memory_type not in {"fact", "episodic"}: + # Parity with sync: under-specified docs (no id / unknown type) + # skip dedup and pass through verbatim. + filtered_by_key[key].append(doc) + continue + exclude_ids = kept_ids | dropped_ids | {doc_id, *(doc.get("supersedes_ids") or [])} + candidates = await self._vector_candidates( + user_id=user_id, + embedding=doc.get("embedding"), + memory_type=memory_type, + top_k=top_k, + exclude_ids=exclude_ids, + ) + best: dict[str, Any] | None = candidates[0] if candidates else None + score = float(best.get("score") or 0.0) if best else 0.0 + if best and autodrop_ok and vector_similarity_at_least(score, high, distance_function): + vector_dedup_skipped += 1 + dropped_ids.add(doc_id) + logger.info( + "dedup_extracted_memories: vector skip user_id=%s dropped=%r " + "surviving_id=%s surviving=%r score=%.4f", + user_id, + doc.get("content"), + best.get("id"), + best.get("content"), + score, + ) + continue + if best and vector_similarity_at_least(score, low, distance_function): + tags = list(doc.get("tags") or []) + if "sys:dup-candidate" not in tags: + tags.append("sys:dup-candidate") + doc["tags"] = tags + metadata = dict(doc.get("metadata") or {}) + metadata["dup_of"] = best.get("id") + metadata["dup_score"] = score + doc["metadata"] = metadata + dup_candidates_tagged += 1 + + kept_ids.add(doc_id) + filtered_by_key[key].append(doc) + + if vector_dedup_skipped or dup_candidates_tagged: + result["updates"].append( + { + "op": "stats", + "vector_dedup_skipped": vector_dedup_skipped, + "dup_candidates_tagged": dup_candidates_tagged, + } + ) + result["facts"] = filtered_by_key["facts"] + result["episodic"] = filtered_by_key["episodic"] + return result + async def persist_extracted_memories( self, user_id: str, @@ -763,27 +892,14 @@ async def persist_extracted_memories( if op.get("op") == "stats": result["exact_dedup_skipped"] += int(op.get("exact_dedup_skipped") or 0) result["dropped_episodic_count"] += int(op.get("dropped_episodic_count") or 0) - continue - if op.get("op") != "supersede": - continue - reason = op.get("reason") - op_thread_id = op.get("thread_id") - supersedes_id = op.get("supersedes_id") - superseder_id = op.get("superseder_id") - if reason not in {"update", "contradict"} or not op_thread_id or not supersedes_id or not superseder_id: - continue - marked = await self._mark_extracted_superseded( - user_id=user_id, - thread_id=op_thread_id, - supersedes_id=supersedes_id, - superseder_id=superseder_id, - reason=reason, - ) - if marked: - if reason == "contradict": - result["contradicted_count"] += 1 - else: - result["updated_count"] += 1 + if "vector_dedup_skipped" in op: + result["vector_dedup_skipped"] = result.get("vector_dedup_skipped", 0) + int( + op.get("vector_dedup_skipped") or 0 + ) + if "dup_candidates_tagged" in op: + result["dup_candidates_tagged"] = result.get("dup_candidates_tagged", 0) + int( + op.get("dup_candidates_tagged") or 0 + ) logger.info("persist_extracted_memories completed user_id=%s counts=%s", user_id, result) @@ -834,6 +950,8 @@ async def extract_memories( ) -> dict[str, int]: """Extract facts and episodic memories from a thread and persist them.""" extracted = await self.extract_memories_dry(user_id, thread_id, recent_k, turns=turns) + if get_dedup_vector_enabled(): + extracted = await self.dedup_extracted_memories(user_id, extracted) return await self.persist_extracted_memories(user_id, extracted) async def synthesize_procedural( @@ -1031,7 +1149,7 @@ def _render_bullets(values: list[str]) -> str: } validated = construct_internal(ProceduralRecord, new_doc).to_doc() try: - await self._create_item(self._memories_container, body=validated) + await self._create_item(self._memories_container, body=dict(validated)) written_doc = validated break except CosmosResourceExistsError: @@ -1392,7 +1510,228 @@ def _emit_reconcile_outcome( }, ) - async def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: + async def _active_memories_for_reconcile(self, user_id: str, memory_type: str, n: int) -> list[dict[str, Any]]: + capped_n = top_literal(n, name="reconcile_memories.n") + query = ( + f"SELECT TOP {capped_n} * FROM c " + "WHERE c.user_id = @user_id " + "AND c.type = @memory_type " + f"AND {_ACTIVE_DOC_FILTER} " + "ORDER BY c.created_at DESC" + ) + return await self._query_items( + self._memories_container, + query=query, + parameters=[ + {"name": "@user_id", "value": user_id}, + {"name": "@memory_type", "value": memory_type}, + ], + ) + + async def _load_memories_by_ids( + self, + user_id: str, + memory_type: str, + ids: Iterable[str], + ) -> list[dict[str, Any]]: + id_list = [mid for mid in dict.fromkeys(ids) if mid] + if not id_list: + return [] + placeholders = ", ".join(f"@id{i}" for i in range(len(id_list))) + query = ( + "SELECT * FROM c WHERE c.user_id = @user_id " + "AND c.type = @memory_type " + f"AND c.id IN ({placeholders}) " + f"AND {_ACTIVE_DOC_FILTER}" + ) + parameters = [ + {"name": "@user_id", "value": user_id}, + {"name": "@memory_type", "value": memory_type}, + ] + parameters.extend({"name": f"@id{i}", "value": mid} for i, mid in enumerate(id_list)) + return await self._query_items(self._memories_container, query=query, parameters=parameters) + + async def _build_candidate_clusters( + self, + user_id: str, + memory_type: str, + n: int, + ) -> tuple[list[list[dict[str, Any]]], int, list[dict[str, Any]]]: + """Cluster dup-candidate seeds (+ vector neighbors) into connected components. + + Returns ``(clusters, node_count, seeds)`` where each cluster has >= 2 members, + ``node_count`` is the total distinct memories pulled into the graph (used to + report ``reconcile_llm_calls_saved``), and ``seeds`` is the tagged seed scan + (so the caller can clear stale tags on orphan seeds that never clustered). + The seed scan is bounded to ``n`` so a single cluster can never exceed the + reconcile prompt's pool cap. + """ + cluster_sim = get_dedup_cluster_sim() + top_k = get_dedup_candidate_topk() + distance_function = await self._vector_distance_function() + query = ( + f"SELECT TOP {top_literal(n, name='reconcile_memories.candidate_n')} * FROM c " + "WHERE c.user_id = @user_id " + "AND c.type = @memory_type " + "AND ARRAY_CONTAINS(c.tags, @tag) " + f"AND {_ACTIVE_DOC_FILTER} " + "ORDER BY c.created_at DESC" + ) + seeds = await self._query_items( + self._memories_container, + query=query, + parameters=[ + {"name": "@user_id", "value": user_id}, + {"name": "@memory_type", "value": memory_type}, + {"name": "@tag", "value": "sys:dup-candidate"}, + ], + ) + nodes_by_id: dict[str, dict[str, Any]] = {doc["id"]: doc for doc in seeds if doc.get("id")} + edges: set[tuple[str, str]] = set() + for seed in seeds: + sid = seed.get("id") + if not sid: + continue + dup_of = (seed.get("metadata") or {}).get("dup_of") if isinstance(seed.get("metadata"), dict) else None + if dup_of: + for doc in await self._load_memories_by_ids(user_id, memory_type, [dup_of]): + nodes_by_id[doc["id"]] = doc + edges.add(tuple(sorted((sid, doc["id"])))) + for cand in await self._vector_candidates( + user_id=user_id, + embedding=seed.get("embedding"), + memory_type=memory_type, + top_k=top_k, + exclude_ids={sid}, + ): + if vector_similarity_at_least(float(cand.get("score") or 0.0), cluster_sim, distance_function): + for doc in await self._load_memories_by_ids(user_id, memory_type, [cand.get("id")]): + nodes_by_id[doc["id"]] = doc + edges.add(tuple(sorted((sid, doc["id"])))) + node_ids = set(nodes_by_id) + for doc in list(nodes_by_id.values()): + did = doc.get("id") + if not did: + continue + for cand in await self._vector_candidates( + user_id=user_id, + embedding=doc.get("embedding"), + memory_type=memory_type, + top_k=top_k, + exclude_ids={did}, + ): + cid = cand.get("id") + if cid in node_ids and vector_similarity_at_least( + float(cand.get("score") or 0.0), cluster_sim, distance_function + ): + edges.add(tuple(sorted((did, cid)))) + + adjacency: dict[str, set[str]] = {node_id: set() for node_id in nodes_by_id} + for left, right in edges: + if left != right and left in adjacency and right in adjacency: + adjacency[left].add(right) + adjacency[right].add(left) + clusters: list[list[dict[str, Any]]] = [] + seen: set[str] = set() + for node_id in adjacency: + if node_id in seen: + continue + stack = [node_id] + component: list[str] = [] + seen.add(node_id) + while stack: + current = stack.pop() + component.append(current) + for nxt in adjacency[current]: + if nxt not in seen: + seen.add(nxt) + stack.append(nxt) + if len(component) >= 2: + # Cap cluster size at the reconcile pool limit: lowering the cluster + # threshold can chain many facts into one giant transitive component + # that would blow the prompt cap; keep the most-recent ``n``. + if len(component) > n: + component = component[:n] + clusters.append([nodes_by_id[cid] for cid in component]) + return clusters, len(nodes_by_id), seeds + + async def _reconcile_candidate_mode( + self, user_id: str, *, n: int, memory_type: str, started_at: float + ) -> dict[str, int]: + # Candidate clustering only. The periodic full-pool backstop that catches + # dissimilar-embedding contradictions ("vegetarian" vs "loves steak") is + # driven by the caller via ``full_rebuild`` on a PERSISTED-counter cadence + # (in-process auto-trigger + durable change-feed), not an in-memory sweep + # counter — the latter reset per worker/process and never fired reliably on + # the Function-App backend. + clusters, node_count, seeds = await self._build_candidate_clusters(user_id, memory_type, n) + aggregate = {"kept": 0, "merged": 0, "contradicted": 0} + clustered_ids: set[str] = set() + for cluster in clusters: + # Mark members as clustered BEFORE the LLM call so a failed cluster's + # seeds are not treated as orphans below (which would clear their tags + # and prevent a retry). A truncated/malformed LLM response on one + # cluster must not abort the sweep or starve the remaining clusters. + clustered_ids.update(doc["id"] for doc in cluster if doc.get("id")) + try: + counts, consumed = await self._reconcile_pool(user_id, memory_type, cluster) + except Exception as exc: + logger.warning( + "reconcile_memories: cluster reconcile failed user_id=%s memory_type=%s; " + "skipping cluster, tags retained for next sweep: %s", + user_id, + memory_type, + exc, + ) + continue + for key in aggregate: + aggregate[key] += int(counts.get(key, 0)) + # Clear dup-candidate tags only on survivors. Re-upserting a doc that + # was just superseded (duplicate source or contradiction loser) would + # resurrect it, since the in-memory cluster copy lacks superseded_by. + survivors = [doc for doc in cluster if doc.get("id") and doc["id"] not in consumed] + await self._clear_dup_candidate_tags(survivors) + # Orphan seeds: tagged dup-candidates that never joined a cluster have no + # near-duplicate, so clear the stale tag — otherwise every future sweep + # re-scans them as seeds and they accumulate forever. + orphan_seeds = [seed for seed in seeds if seed.get("id") and seed["id"] not in clustered_ids] + await self._clear_dup_candidate_tags(orphan_seeds) + aggregate["reconcile_clusters_sent"] = len(clusters) + aggregate["reconcile_llm_calls_saved"] = max(0, node_count - len(clusters)) + logger.info( + "reconcile_memories candidate completed user_id=%s memory_type=%s result=%s", + user_id, + memory_type, + aggregate, + ) + self._emit_reconcile_outcome( + started_at=started_at, + user_id=user_id, + candidates=node_count, + result=aggregate, + ) + return aggregate + + async def _clear_dup_candidate_tags(self, docs: Iterable[dict[str, Any]]) -> None: + for doc in docs: + tags = [tag for tag in (doc.get("tags") or []) if tag != "sys:dup-candidate"] + if tags == (doc.get("tags") or []): + continue + updated = dict(doc) + updated["tags"] = tags + metadata = dict(updated.get("metadata") or {}) + metadata.pop("dup_of", None) + metadata.pop("dup_score", None) + updated["metadata"] = metadata + updated["updated_at"] = datetime.now(timezone.utc).isoformat() + try: + await self._upsert_memory(updated) + except Exception: + logger.exception("reconcile_memories: failed to clear dup-candidate tag id=%s", doc.get("id")) + + async def reconcile_memories( + self, user_id: str, n: int = 50, *, memory_type: str = "fact", full_rebuild: bool = False + ) -> dict[str, int]: """Reconcile a user's active facts in a single LLM pass. Loads the most recent ``n`` active (non-superseded) facts for @@ -1417,44 +1756,64 @@ async def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: raise ValidationError(f"n must be a positive integer, got {n!r}") if n > 500: raise ValidationError(f"n must be <= 500 to bound prompt size and LLM cost, got {n}") + if memory_type not in {"fact", "episodic", "procedural"}: + raise ValidationError(f"memory_type must be one of fact, episodic, procedural, got {memory_type!r}") + if memory_type == "procedural": + result = { + "kept": 0, + "merged": 0, + "contradicted": 0, + "reconcile_clusters_sent": 0, + "reconcile_llm_calls_saved": 0, + } + logger.info("reconcile_memories procedural no-op user_id=%s result=%s", user_id, result) + return result started_at = time.monotonic() - logger.info("reconcile_memories started user_id=%s n=%d", user_id, n) + logger.info("reconcile_memories started user_id=%s n=%d memory_type=%s", user_id, n, memory_type) + + # Explicit user-triggered reconcile (full_rebuild) always takes the + # full-pool single-LLM-pass path: it sees every active fact together, so it + # catches contradictions that aren't vector-similar (e.g. "vegetarian" vs + # "loves steak") — which candidate clustering, keyed on near-duplicate + # similarity, would never group. Automatic sweeps use cheap candidate mode. + if get_dedup_reconcile_mode() == "candidate" and not full_rebuild: + return await self._reconcile_candidate_mode( + user_id, n=n, memory_type=memory_type, started_at=started_at + ) - # ---- 1. Load up to N most recent active facts ---- - # ORDER BY c.created_at DESC keeps the TOP cap deterministic across - # physical partitions and matches the dedup prompt's tiebreaker - # ("more recent created_at first"). Cosmos's _ts is the last-write - # timestamp, which would diverge from created_at after any UPDATE. - query = ( - f"SELECT TOP {n} * FROM c " - "WHERE c.user_id = @user_id " - "AND c.type = 'fact' " - f"AND {_ACTIVE_DOC_FILTER} " - "ORDER BY c.created_at DESC" - ) - parameters: list[dict[str, Any]] = [ - {"name": "@user_id", "value": user_id}, - ] - facts = await self._query_items( - self._memories_container, - query=query, - parameters=parameters, + facts = await self._active_memories_for_reconcile(user_id, memory_type, n) + result, consumed = await self._reconcile_pool(user_id, memory_type, facts) + # Clear dup-candidate tags on survivors so an explicit reconcile(full_rebuild=True) + # doesn't leave stale sys:dup-candidate/dup_of metadata on user-visible memories. + survivors = [doc for doc in facts if doc.get("id") and doc["id"] not in consumed] + await self._clear_dup_candidate_tags(survivors) + self._emit_reconcile_outcome( + started_at=started_at, + user_id=user_id, + candidates=len(facts), + result=result, ) + return result + async def _reconcile_pool( + self, user_id: str, memory_type: str, facts: list[dict[str, Any]] + ) -> tuple[dict[str, int], set[str]]: + """Reconcile an explicit pool of same-type memories in one LLM pass. + + Returns ``({"kept", "merged", "contradicted"}, consumed_ids)`` where + ``consumed_ids`` are the source/loser ids that were actually superseded, + so callers can skip them when clearing dup-candidate tags (re-upserting a + superseded source would resurrect it). Does not emit telemetry — the + caller owns the ``reconcile.outcome`` line. + """ if len(facts) <= 1: logger.info( - "reconcile_memories: %d facts, nothing to reconcile", + "reconcile_memories: %d %s memories, nothing to reconcile", len(facts), + memory_type, ) - early_result = {"kept": len(facts), "merged": 0, "contradicted": 0} - self._emit_reconcile_outcome( - started_at=started_at, - user_id=user_id, - candidates=len(facts), - result=early_result, - ) - return early_result + return {"kept": len(facts), "merged": 0, "contradicted": 0}, set() # ---- 2. Format the facts pool for the prompt ---- # ``json.dumps`` escapes embedded quotes and pipes inside content so @@ -1481,14 +1840,16 @@ async def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: facts_text = "\n".join(lines) # ---- 3. Single LLM call over the entire pool ---- - response_text = await self._run_prompty( - "dedup.prompty", - inputs={"facts_text": facts_text}, - ) + # Polarity keyed on an explicit predicate so a future third type (e.g. + # procedural, were it ever routed here) can't silently diverge from sync. + is_episodic = memory_type == "episodic" + prompt_name = "dedup_episodic.prompty" if is_episodic else "dedup.prompty" + prompt_inputs = {"episodics_text": facts_text} if is_episodic else {"facts_text": facts_text} + response_text = await self._run_prompty(prompt_name, inputs=prompt_inputs) parsed = self._parse_llm_json(response_text) duplicate_groups = parsed.get("duplicate_groups", []) or [] - contradicted_pairs = parsed.get("contradicted_pairs", []) or [] + contradicted_pairs = [] if is_episodic else (parsed.get("contradicted_pairs", []) or []) # ``kept_ids`` from the LLM is used below as a cross-check for # accounting drift (hallucinated IDs, double-counting). The actual # kept count is computed from facts minus consumed losers. @@ -1564,11 +1925,13 @@ async def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: seen_tags: set[str] = set() for src in source_docs: for t in src.get("tags", []) or []: + if t == "sys:dup-candidate": + continue if t not in seen_tags: seen_tags.add(t) merged_tags.append(t) if not merged_tags: - merged_tags = ["sys:fact"] + merged_tags = [f"sys:{memory_type}"] # Union source_memory_ids across all source docs (provenance chain). merged_source_memory_ids: list[str] = [] @@ -1625,16 +1988,39 @@ async def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: # see the same id rather than chaining through a new UUID. merged_content_hash = compute_content_hash(merged_content) merged_id_seed = _ID_SEED_SEP.join((user_id, "merged", merged_content_hash)) - merged_id = "fact_" + hashlib.sha256(merged_id_seed.encode()).hexdigest()[:32] - + merged_prefix = "ep_" if memory_type == "episodic" else "fact_" + merged_id = merged_prefix + hashlib.sha256(merged_id_seed.encode()).hexdigest()[:32] try: + if memory_type == "episodic": + base_metadata = dict(source_docs[0].get("metadata") or {}) + base_metadata.update( + { + "lesson": merged_content, + "merged_via": "reconcile", + "merged_from_count": len(valid_source_ids), + } + ) + base_metadata.setdefault("scope_type", base_metadata.get("scope_type") or "general") + base_metadata.setdefault("scope_value", base_metadata.get("scope_value") or "general") + base_metadata.setdefault("outcome_valence", base_metadata.get("outcome_valence") or "neutral") + record_cls = EpisodicRecord + prompt_lineage = self._prompt_lineage("dedup_episodic.prompty") + metadata = base_metadata + else: + record_cls = FactRecord + prompt_lineage = self._prompt_lineage("dedup.prompty") + metadata = { + "category": "preference", + "merged_via": "reconcile", + "merged_from_count": len(valid_source_ids), + } merged_record = construct_internal( - FactRecord, + record_cls, { "id": merged_id, "user_id": user_id, "role": "system", - "type": "fact", + "type": memory_type, "content": merged_content, "thread_id": recent_thread_id or f"__reconciled__:{user_id}", "confidence": confidence_val if confidence_val is not None else 0.5, @@ -1643,12 +2029,8 @@ async def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: "source_memory_ids": merged_source_memory_ids, "tags": merged_tags, "content_hash": merged_content_hash, - "metadata": { - "category": "preference", - "merged_via": "reconcile", - "merged_from_count": len(valid_source_ids), - }, - **self._prompt_lineage("dedup.prompty"), + "metadata": metadata, + **prompt_lineage, "created_at": datetime.now(timezone.utc).isoformat(), "updated_at": datetime.now(timezone.utc).isoformat(), }, @@ -1820,13 +2202,7 @@ async def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: ) result = {"kept": kept, "merged": merged, "contradicted": contradicted} logger.info("reconcile_memories completed: %s", result) - self._emit_reconcile_outcome( - started_at=started_at, - user_id=user_id, - candidates=len(facts), - result=result, - ) - return result + return result, consumed_ids async def build_procedural_context(self, user_id: str) -> str: """Return the active synthesized procedural prompt for system injection.""" diff --git a/azure/cosmos/agent_memory/aio/store/memory_store.py b/azure/cosmos/agent_memory/aio/store/memory_store.py index 4f725f9..b7aec2e 100644 --- a/azure/cosmos/agent_memory/aio/store/memory_store.py +++ b/azure/cosmos/agent_memory/aio/store/memory_store.py @@ -17,8 +17,8 @@ from azure.cosmos.agent_memory._utils import ( _build_memory_query_builder, _coerce_datetime_iso, - _validate_hybrid_search, compute_content_hash, + extract_keywords, new_id, ) from azure.cosmos.agent_memory.exceptions import ( @@ -804,7 +804,6 @@ async def search( role: Optional[str] = None, memory_types: Optional[list[str]] = None, thread_id: Optional[str] = None, - hybrid_search: bool = False, top_k: int = 5, tags_all: Optional[list[str]] = None, tags_any: Optional[list[str]] = None, @@ -823,9 +822,9 @@ async def search( :meth:`search_turns` to vector-search the raw conversation log instead. """ terms = require_search_terms(search_terms, query) - _validate_hybrid_search(hybrid_search, terms) top = top_literal(top_k, name="top_k") query_vector = await self._embed(terms) + keywords = extract_keywords(terms) qb = _build_memory_query_builder( memory_id=memory_id, @@ -845,11 +844,16 @@ async def search( ) add_salience_filter(qb, min_salience) - sql = build_search_sql(qb=qb, top=top, hybrid_search=hybrid_search, include_superseded=include_superseded) + sql = build_search_sql( + qb=qb, + top=top, + keyword_count=len(keywords), + include_superseded=include_superseded, + ) parameters = qb.get_parameters() parameters.append({"name": "@embedding", "value": query_vector}) - if hybrid_search: - parameters.append({"name": "@key_terms", "value": terms}) + for i, kw in enumerate(keywords): + parameters.append({"name": f"@kw{i}", "value": kw}) partition_key, _ = query_scope(user_id, thread_id) if thread_id is not None and (not memory_types or set(memory_types) & USER_SCOPED_MEMORIES_TYPES): @@ -868,7 +872,6 @@ async def search_turns( user_id: Optional[str] = None, thread_id: Optional[str] = None, role: Optional[str] = None, - hybrid_search: bool = False, top_k: int = 5, tags_all: Optional[list[str]] = None, tags_any: Optional[list[str]] = None, @@ -878,7 +881,7 @@ async def search_turns( *, query: Optional[str] = None, ) -> list[dict[str, Any]]: - """Search raw conversation turns using vector similarity with optional hybrid ranking. + """Search raw conversation turns using vector similarity with hybrid ranking. Only vector-searchable when turn embeddings were enabled at write time (see ``enable_turn_embeddings``). ``user_id`` is required and always @@ -890,9 +893,9 @@ async def search_turns( if not user_id: raise ValidationError("user_id is required for search_turns") terms = require_search_terms(search_terms, query) - _validate_hybrid_search(hybrid_search, terms) top = top_literal(top_k, name="top_k") query_vector = await self._embed(terms) + keywords = extract_keywords(terms) qb = _QueryBuilder() qb.add_filter("c.user_id", "@user_id", user_id) @@ -907,11 +910,11 @@ async def search_turns( before_param="@created_before", ) - sql = build_search_sql(qb=qb, top=top, hybrid_search=hybrid_search, include_superseded=False) + sql = build_search_sql(qb=qb, top=top, keyword_count=len(keywords), include_superseded=False) parameters = qb.get_parameters() parameters.append({"name": "@embedding", "value": query_vector}) - if hybrid_search: - parameters.append({"name": "@key_terms", "value": terms}) + for i, kw in enumerate(keywords): + parameters.append({"name": f"@kw{i}", "value": kw}) partition_key, _ = query_scope(user_id, thread_id) logger.debug("AsyncMemoryStore.search_turns query: %s", sql) diff --git a/azure/cosmos/agent_memory/auto_trigger.py b/azure/cosmos/agent_memory/auto_trigger.py index 10ef595..a494181 100644 --- a/azure/cosmos/agent_memory/auto_trigger.py +++ b/azure/cosmos/agent_memory/auto_trigger.py @@ -108,6 +108,12 @@ def maybe_trigger_steps( return n_dedup_turns = n_facts * n_dedup if n_facts > 0 and n_dedup > 0 else 0 + # Full-pool reconcile backstop cadence, derived from the PERSISTED counter so + # it fires reliably regardless of process/worker lifetime (the old in-memory + # per-instance sweep counter reset and under-fired). Every + # DEDUP_FULL_RECLUSTER_EVERY_N-th reconcile = every n_dedup_turns * N turns. + n_full_recluster = _threshold_int(thresholds, "get_dedup_full_recluster_every_n", "DEDUP_FULL_RECLUSTER_EVERY_N") + n_full_turns = n_dedup_turns * n_full_recluster if (n_dedup_turns > 0 and n_full_recluster > 0) else 0 user_batch_counts = _trigger_thread_steps( processor, counter_container, @@ -115,6 +121,7 @@ def maybe_trigger_steps( n_facts=n_facts, n_summary=n_summary, n_dedup_turns=n_dedup_turns, + n_full_turns=n_full_turns, thresholds=thresholds, ) _trigger_user_steps(processor, counter_container, user_batch_counts, n_user=n_user) @@ -128,6 +135,7 @@ def _trigger_thread_steps( n_facts: int, n_summary: int, n_dedup_turns: int, + n_full_turns: int, thresholds: Any = None, ) -> dict[str, int]: user_batch_counts: dict[str, int] = {} @@ -154,14 +162,43 @@ def _trigger_thread_steps( counter_id, user_id, thread_id, + new_count=new_count, fire_extract=n_facts > 0 and _counters.crosses_threshold(old_count, new_count, n_facts), fire_summary=n_summary > 0 and _counters.crosses_threshold(old_count, new_count, n_summary), fire_dedup=n_dedup_turns > 0 and _counters.crosses_threshold(old_count, new_count, n_dedup_turns), + fire_full_rebuild=n_full_turns > 0 and _counters.crosses_threshold(old_count, new_count, n_full_turns), thresholds=thresholds, ) return user_batch_counts +def _watermark_recent_k( + counter_container: Any, + counter_id: str, + user_id: str, + thread_id: str, + *, + new_count: int, +) -> int: + """recent_k covering every turn since the last successful extract. + + ``new_count - watermark`` is exactly the count of turns added since the last + successful extract, and the newest-``recent_k`` slice in the pipeline picks + precisely those turns — so the whole backlog is covered and the watermark can + safely advance to ``new_count``. **Not capped:** capping would extract only the + newest N and strand the oldest ``backlog - N`` turns. + + **Bootstrap:** before a thread's first successful extract there is no watermark, + so we treat it as ``0`` — ``recent_k = new_count`` covers every turn the thread + has so far. Using only the current batch size here would strand turns added + during earlier *failed* extracts, because the watermark still advances to + ``new_count`` on the first success. + """ + watermark = _counters.read_extract_watermark_sync(counter_container, counter_id, user_id, thread_id) + base = watermark if watermark is not None else 0 + return max(new_count - base, 1) + + def _fire_thread_steps( processor: InProcessProcessor, counter_container: Any, @@ -169,9 +206,11 @@ def _fire_thread_steps( user_id: str, thread_id: str, *, + new_count: int, fire_extract: bool, fire_summary: bool, fire_dedup: bool, + fire_full_rebuild: bool = False, thresholds: Any = None, ) -> None: fire_procedural = fire_dedup and bool( @@ -182,13 +221,28 @@ def _fire_thread_steps( default=True, ) ) + if fire_extract: + recent_k = _watermark_recent_k( + counter_container, + counter_id, + user_id, + thread_id, + new_count=new_count, + ) + try: + processor.process_extract_memories(user_id=user_id, thread_id=thread_id, recent_k=recent_k) + _counters.advance_extract_watermark_sync(counter_container, counter_id, user_id, thread_id, new_count) + except Exception as exc: + logger.warning("Auto-trigger process_extract_memories failed for %s/%s: %s", user_id, thread_id, exc) + _counters.stamp_failure_sync( + counter_container, counter_id, user_id, thread_id, f"process_extract_memories: {exc!r}" + ) for enabled, label, call in ( ( - fire_extract, - "process_extract_memories", - lambda: processor.process_extract_memories(user_id=user_id, thread_id=thread_id), + fire_dedup, + "process_reconcile", + lambda: processor.process_reconcile(user_id=user_id, full_rebuild=fire_full_rebuild), ), - (fire_dedup, "process_reconcile", lambda: processor.process_reconcile(user_id=user_id)), ( fire_procedural, "synthesize_procedural", diff --git a/azure/cosmos/agent_memory/cosmos_memory_client.py b/azure/cosmos/agent_memory/cosmos_memory_client.py index f4924c3..4bad143 100644 --- a/azure/cosmos/agent_memory/cosmos_memory_client.py +++ b/azure/cosmos/agent_memory/cosmos_memory_client.py @@ -621,7 +621,6 @@ def search_cosmos( role: Optional[str] = None, memory_types: Optional[list[str]] = None, thread_id: Optional[str] = None, - hybrid_search: bool = False, top_k: int = 5, tags_all: Optional[list[str]] = None, tags_any: Optional[list[str]] = None, @@ -644,7 +643,6 @@ def search_cosmos( role=role, memory_types=memory_types, thread_id=thread_id, - hybrid_search=hybrid_search, top_k=top_k, tags_all=tags_all, tags_any=tags_any, @@ -662,7 +660,6 @@ def search_turns( user_id: str, thread_id: Optional[str] = None, role: Optional[str] = None, - hybrid_search: bool = False, top_k: int = 5, tags_all: Optional[list[str]] = None, tags_any: Optional[list[str]] = None, @@ -684,7 +681,6 @@ def search_turns( user_id=user_id, thread_id=thread_id, role=role, - hybrid_search=hybrid_search, top_k=top_k, tags_all=tags_all, tags_any=tags_any, @@ -801,15 +797,30 @@ def search_episodic_memories( include_superseded: bool = False, ) -> list[dict[str, Any]]: """Semantic search across episodic memories for a user.""" - return self._get_store().search_episodic(user_id, search_terms, top_k, min_salience, include_superseded) + return self._get_store().search_episodic( + user_id=user_id, + search_terms=search_terms, + top_k=top_k, + min_salience=min_salience, + include_superseded=include_superseded, + ) def build_procedural_context(self, user_id: str) -> str: """Build formatted procedural context for prompt injection.""" return self._get_pipeline().build_procedural_context(user_id) - def build_episodic_context(self, user_id: str, query: str, top_k: int = 3) -> str: + def build_episodic_context( + self, + user_id: str, + query: str, + top_k: int = 3, + ) -> str: """Build formatted context of relevant past experiences.""" - return self._get_store().build_episodic_context(user_id, query, top_k) + return self._get_store().build_episodic_context( + user_id=user_id, + query=query, + top_k=top_k, + ) def extract_memories( self, @@ -855,7 +866,9 @@ def reconcile(self, user_id: str, n: Optional[int] = None) -> dict[str, int]: """Reconcile a user's facts via the contradiction-aware dedup pass.""" from .thresholds import get_dedup_pool_size - return self._get_pipeline().reconcile_memories(user_id, n if n is not None else get_dedup_pool_size()) + return self._get_pipeline().reconcile_memories( + user_id, n if n is not None else get_dedup_pool_size(), full_rebuild=True + ) def process_now(self, *, user_id: str, thread_id: str) -> "ProcessThreadResult": """Force the processor to run the full pipeline RIGHT NOW for one thread. diff --git a/azure/cosmos/agent_memory/processors/base.py b/azure/cosmos/agent_memory/processors/base.py index cd257b7..2fbf9b2 100644 --- a/azure/cosmos/agent_memory/processors/base.py +++ b/azure/cosmos/agent_memory/processors/base.py @@ -79,6 +79,7 @@ def process_extract_memories( *, user_id: str, thread_id: str, + recent_k: Optional[int] = None, ) -> dict[str, int]: ... def process_thread_summary( @@ -99,6 +100,7 @@ def process_reconcile( self, *, user_id: str, + full_rebuild: bool = False, ) -> int: ... def generate_user_summary( diff --git a/azure/cosmos/agent_memory/processors/durable.py b/azure/cosmos/agent_memory/processors/durable.py index d63b24f..9834dcf 100644 --- a/azure/cosmos/agent_memory/processors/durable.py +++ b/azure/cosmos/agent_memory/processors/durable.py @@ -46,6 +46,7 @@ def process_extract_memories( *, user_id: str, thread_id: str, + recent_k: Optional[int] = None, ) -> dict[str, int]: logger.debug( "DurableFunctionProcessor.process_extract_memories no-op user_id=%s thread_id=%s", @@ -79,7 +80,7 @@ def process_user_summary( ) return UserSummaryResult(summary=None) - def process_reconcile(self, *, user_id: str) -> int: + def process_reconcile(self, *, user_id: str, full_rebuild: bool = False) -> int: logger.debug( "DurableFunctionProcessor.process_reconcile no-op user_id=%s", user_id, diff --git a/azure/cosmos/agent_memory/processors/inprocess.py b/azure/cosmos/agent_memory/processors/inprocess.py index 9684ce0..aac1760 100644 --- a/azure/cosmos/agent_memory/processors/inprocess.py +++ b/azure/cosmos/agent_memory/processors/inprocess.py @@ -81,9 +81,7 @@ def process_thread( thread_summary = self._pipeline.generate_thread_summary(user_id, thread_id) extracted = self._pipeline.extract_memories(user_id, thread_id) - reconciled = self._pipeline.reconcile_memories(user_id, get_dedup_pool_size()) - - deduped_count = self._extract_reconcile_count(reconciled) + deduped_count = self._reconcile_fact_and_episodic(user_id, get_dedup_pool_size()) extracted_counts: dict[str, int] = ( {k: v for k, v in extracted.items() if isinstance(v, int)} if isinstance(extracted, dict) else {} @@ -102,8 +100,9 @@ def process_extract_memories( *, user_id: str, thread_id: str, + recent_k: Optional[int] = None, ) -> dict[str, int]: - extracted = self._pipeline.extract_memories(user_id, thread_id) + extracted = self._pipeline.extract_memories(user_id, thread_id, recent_k=recent_k) return {k: v for k, v in extracted.items() if isinstance(v, int)} if isinstance(extracted, dict) else {} def process_thread_summary( @@ -126,16 +125,32 @@ def process_user_summary( summary = self._pipeline.generate_user_summary(user_id, thread_ids) return UserSummaryResult(summary=summary if isinstance(summary, dict) else None) - def process_reconcile(self, *, user_id: str) -> int: + def process_reconcile(self, *, user_id: str, full_rebuild: bool = False) -> int: """Run reconciliation standalone. Returns count of facts merged + contradicted. Pool size is read from ``DEDUP_POOL_SIZE`` (env-tunable, default 50, capped at 500) so the auto-trigger and the standalone path agree. + ``full_rebuild`` (set by the auto-trigger on its persisted-counter + full-recluster cadence) forces the full-pool LLM pass that catches + dissimilar-embedding contradictions. """ from ..thresholds import get_dedup_pool_size - reconciled = self._pipeline.reconcile_memories(user_id, n=get_dedup_pool_size()) - return self._extract_reconcile_count(reconciled) + return self._reconcile_fact_and_episodic(user_id, get_dedup_pool_size(), full_rebuild=full_rebuild) + + def _reconcile_fact_and_episodic(self, user_id: str, n: int, *, full_rebuild: bool = False) -> int: + """Reconcile facts and episodic memories; sum merged+contradicted counts. + + SDK in-process processing reconciles both types (matching the Durable + backend) so episodic dups don't accrue forever. + """ + total = 0 + for memory_type in ("fact", "episodic"): + reconciled = self._pipeline.reconcile_memories( + user_id, n=n, memory_type=memory_type, full_rebuild=full_rebuild + ) + total += self._extract_reconcile_count(reconciled) + return total @staticmethod def _extract_reconcile_count(reconciled: Any) -> int: diff --git a/azure/cosmos/agent_memory/prompts/_schemas.py b/azure/cosmos/agent_memory/prompts/_schemas.py index 1b34973..6c3306c 100644 --- a/azure/cosmos/agent_memory/prompts/_schemas.py +++ b/azure/cosmos/agent_memory/prompts/_schemas.py @@ -65,6 +65,34 @@ } +# --------------------------------------------------------------------------- +# dedup_episodic.prompty — reconcile a pool of active episodic memories +# (MERGE-ONLY: same-event duplicates collapse; no contradiction/deletion) +# --------------------------------------------------------------------------- +DEDUP_EPISODIC_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "duplicate_groups": { + "type": "array", + "items": { + "type": "object", + "properties": { + "merged_content": {"type": "string"}, + "source_ids": {"type": "array", "items": {"type": "string"}}, + "confidence": {"type": ["number", "null"]}, + "salience": {"type": ["number", "null"]}, + }, + "required": ["merged_content", "source_ids", "confidence", "salience"], + "additionalProperties": False, + }, + }, + "kept_ids": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["duplicate_groups", "kept_ids"], + "additionalProperties": False, +} + + # --------------------------------------------------------------------------- # extract_memories.prompty — extract facts + episodic + unclassified # --------------------------------------------------------------------------- @@ -91,8 +119,6 @@ "salience": {"type": "number"}, "temporal_context": {"type": ["string", "null"]}, "tags": {"type": "array", "items": {"type": "string"}}, - "action": {"type": "string", "enum": ["ADD", "UPDATE", "CONTRADICT"]}, - "supersedes_id": {"type": ["string", "null"]}, }, "required": [ "text", @@ -104,8 +130,6 @@ "salience", "temporal_context", "tags", - "action", - "supersedes_id", ], "additionalProperties": False, } @@ -115,7 +139,6 @@ "properties": { "scope_type": {"type": "string"}, "scope_value": {"type": "string"}, - "text": {"type": "string"}, "situation": {"type": ["string", "null"]}, "action_taken": {"type": ["string", "null"]}, "outcome": {"type": ["string", "null"]}, @@ -133,7 +156,6 @@ "required": [ "scope_type", "scope_value", - "text", "situation", "action_taken", "outcome", @@ -282,6 +304,7 @@ # --------------------------------------------------------------------------- PROMPTY_SCHEMAS: dict[str, tuple[str, dict[str, Any]]] = { "dedup.prompty": ("DedupOutput", DEDUP_SCHEMA), + "dedup_episodic.prompty": ("DedupEpisodicOutput", DEDUP_EPISODIC_SCHEMA), "extract_memories.prompty": ("ExtractMemoriesOutput", EXTRACT_MEMORIES_SCHEMA), "summarize.prompty": ("SummarizeOutput", SUMMARIZE_SCHEMA), "summarize_update.prompty": ("SummarizeUpdateOutput", SUMMARIZE_UPDATE_SCHEMA), diff --git a/azure/cosmos/agent_memory/prompts/dedup.prompty b/azure/cosmos/agent_memory/prompts/dedup.prompty index c83a05a..d45bb60 100644 --- a/azure/cosmos/agent_memory/prompts/dedup.prompty +++ b/azure/cosmos/agent_memory/prompts/dedup.prompty @@ -6,7 +6,7 @@ model: apiType: chat options: seed: 43 - maxOutputTokens: 2000 + maxOutputTokens: 16384 additionalProperties: response_format: type: json_object diff --git a/azure/cosmos/agent_memory/prompts/dedup_episodic.prompty b/azure/cosmos/agent_memory/prompts/dedup_episodic.prompty new file mode 100644 index 0000000..6008dd2 --- /dev/null +++ b/azure/cosmos/agent_memory/prompts/dedup_episodic.prompty @@ -0,0 +1,160 @@ +--- +name: dedup_episodic +version: v1 +description: Reconcile a pool of active episodic memories — collapse only true same-event duplicates. +model: + apiType: chat + options: + seed: 43 + maxOutputTokens: 16384 + additionalProperties: + response_format: + type: json_object +inputs: + episodics_text: + type: string +--- + +system: +You are a precision episodic-memory reconciliation system. You receive a pool of active episodic memories (each with an ID, content, confidence, salience, and creation timestamp) and must collapse only true same-event duplicates while leaving distinct experiences untouched. + +## Your Goal +Produce a clean merge-only reconciliation that: +1. Collapses repeated extractions of the same past experience into a single merged episode. +2. Preserves distinct occurrences, even when they involve the same action, tool, project, or outcome. +3. Leaves everything else untouched. + +## What is an EPISODIC MEMORY + +An episodic memory is a **past experience**: a specific event or bounded experience in the shape `situation → action_taken → outcome`, with a scope and an `outcome_valence`. It records what happened, not an atomic claim about what is generally true. + +Examples: +- "During the Q3 launch, Redis-backed sessions timed out under load → the team switched to DynamoDB → sessions stabilized" → episodic. +- "On Monday's load test, increasing worker count to 20 still failed with timeouts" → episodic. + +## Two Orthogonal Outcomes (mutually exclusive, exhaustive) + +Every input episode must end up in exactly one of these two places: +- `duplicate_groups[*].source_ids` — the episode is a re-extraction of the same event as one or more other episodes. +- `kept_ids` — the episode is not a true same-event duplicate. + +No input episode may be omitted from the output. Do NOT emit a `contradicted_pairs` key or any deletion bucket. + +## What is a DUPLICATE + +Two or more episodic memories are duplicates only when they describe the **same event re-extracted**: the same situation, the same action taken, and the same outcome. + +The duplicate bar is HIGH. Same topic, same tool, same action pattern, or same outcome is not enough. + +Resolution: emit one entry in `duplicate_groups` with: +- `merged_content` — a clean, self-contained same-event restatement. Do NOT concatenate the originals; synthesize. +- `source_ids` — every original ID that participates in the duplicate group. It must contain at least 2 IDs. +- `confidence` — the **max** confidence among the source episodes. +- `salience` — the **max** salience among the source episodes. + +## What is KEPT + +Episodes that are not true same-event duplicates go in `kept_ids`. + +Distinct occurrences MUST NOT merge. If the user tried the same action on Monday and again on Tuesday with the same outcome, those are separate experiences. Frequency and repetition are evidence and must be preserved. + +## No Contradiction Handling + +Past events are append-only ground truth. There is no contradiction handling and no deletion of episodes in this prompt. + +- Do NOT pick winners or losers. +- Do NOT soft-delete an episode because another episode appears to conflict with it. +- Do NOT emit `contradicted_pairs` or any contradiction/deletion bucket. +- Conflicting lessons are resolved elsewhere, not here. + +## Decision Guidelines + +1. **Conservative bias.** If you cannot confidently prove two episodes are the same event, put them both in `kept_ids`. Over-merging distinct experiences is worse than retaining redundancy. + +2. **Same event test.** Merge only when the memories align on the concrete situation/scope, action taken, and outcome. If the date, run, incident, session, customer, environment, or other occurrence marker differs, keep them separate. + +3. **Don't drop information.** A `merged_content` must preserve every material detail from its sources. If you cannot preserve everything in one clean same-event restatement, the episodes probably are not duplicates. + +4. **Don't fabricate.** Never introduce dates, outcomes, causes, tools, or entities that are not present in at least one source. + +## Input Format + +You will receive a numbered list of episodic memories. Each line has the form: +``` +N. ID: | Content: "" | Confidence: 0.85 | Salience: 0.7 | Created: +``` + +## Missing-Field Handling + +Some episodes may show `Confidence: N/A`, `Salience: N/A`, or `Created: N/A`. + +- Treat `N/A` confidence/salience as **unknown** — set those fields to `null` + in the `duplicate_groups` output (do NOT omit the keys); the pipeline will + fall back to `max(source.confidence)` / `max(source.salience)`. +- Created timestamps are for traceability only. Do not use recency to delete or prefer episodes. + +## Output Format + +You must output ONLY valid JSON matching this exact schema. No preamble, no explanation, no markdown fences — just the JSON object. + +```json +{ + "duplicate_groups": [ + {"merged_content": "", "source_ids": ["", ""], "confidence": 0.0, "salience": 0.0} + ], + "kept_ids": ["", ""] +} +``` + + +The `confidence` and `salience` values **must** be real numbers between 0 and 1 in the actual output — the `0.0` above is a structural placeholder, not a literal value to echo. Compute them as the maximum across source episodes in the group. If you cannot determine confidence or salience, set the field to `null` (the pipeline will fall back to `max(source.*)` from the source records). **Do not omit the keys** — the response schema requires both fields to appear on every group, with either a number or `null` as the value. + +If a bucket is empty, emit it as an empty array (`[]`) rather than omitting the key. Never include `contradicted_pairs`. + +## Worked Examples + +### Example 1: True same-event duplicate + +**Input pool:** +``` +1. ID: E1 | Content: "During the Q3 checkout launch, Redis sessions timed out under load; the team switched session storage to DynamoDB and checkout stabilized." | Confidence: 0.9 | Salience: 0.8 | Created: 2024-01-06T00:00:00Z +2. ID: E2 | Content: "In the Q3 checkout launch incident, Redis-backed sessions failed under load, so the team moved sessions to DynamoDB, which stabilized checkout." | Confidence: 0.85 | Salience: 0.9 | Created: 2024-01-07T00:00:00Z +3. ID: E3 | Content: "On Tuesday's checkout load test, raising Redis connection limits still produced session timeouts." | Confidence: 0.8 | Salience: 0.6 | Created: 2024-01-08T00:00:00Z +``` + +**Expected output:** +```json +{ + "duplicate_groups": [ + {"merged_content": "During the Q3 checkout launch, Redis-backed sessions timed out under load; the team moved session storage to DynamoDB, and checkout stabilized.", "source_ids": ["E1", "E2"], "confidence": 0.9, "salience": 0.9} + ], + "kept_ids": ["E3"] +} +``` + +E1 and E2 describe the same situation, action, and outcome. E3 is related but a different occurrence, so it stays kept. + +### Example 2: Same action on different days is not a duplicate + +**Input pool:** +``` +1. ID: E4 | Content: "On Monday's ingestion test, the team increased workers to 20, but the pipeline still timed out." | Confidence: 0.9 | Salience: 0.7 | Created: 2024-02-01T00:00:00Z +2. ID: E5 | Content: "On Tuesday's ingestion test, the team increased workers to 20, but the pipeline still timed out." | Confidence: 0.9 | Salience: 0.7 | Created: 2024-02-02T00:00:00Z +``` + +**Expected output:** +```json +{ + "duplicate_groups": [], + "kept_ids": ["E4", "E5"] +} +``` + +These episodes share the same action and outcome, but they occurred on different days. Preserve both because repetition is evidence. + +user: +Reconcile the following pool of active episodic memories: + +{{episodics_text}} + +Return JSON exactly matching the schema above. diff --git a/azure/cosmos/agent_memory/prompts/extract_memories.prompty b/azure/cosmos/agent_memory/prompts/extract_memories.prompty index d266f1d..723b458 100644 --- a/azure/cosmos/agent_memory/prompts/extract_memories.prompty +++ b/azure/cosmos/agent_memory/prompts/extract_memories.prompty @@ -6,7 +6,7 @@ model: apiType: chat options: seed: 42 - maxOutputTokens: 2500 + maxOutputTokens: 16384 additionalProperties: response_format: type: json_object @@ -14,9 +14,6 @@ inputs: existing_facts: type: string default: '(none)' - existing_episodics: - type: string - default: '(none)' transcript: type: string --- @@ -34,11 +31,11 @@ Every memory must be explicitly grounded in the conversation — never inferred, ## Speaker Discrimination — Where Memories May Come From -The transcript below is line-tagged: `[user]:` lines are the human's own words, `[assistant]:` lines are the agent's response. These two sources are NOT interchangeable. +The transcript below is line-tagged: `[user]:` lines are the human's own words, `[agent]:` lines are the agent's response. These two sources are NOT interchangeable. -- **Facts about the user may ONLY come from `[user]:` lines.** The assistant may restate, paraphrase, confirm, or make assumptions on the user's behalf ("Got it, you don't eat meat", "I assume you want a luxury hotel") — those are the agent's response, NOT the user's assertion. Never treat assistant text as a new source of user facts. If a fact appears only in `[assistant]:` content and is not asserted by the user, do not extract it. -- **Episodic memories may use both speakers' content** — the user's stated intent or scope is the anchor (and must be present in `[user]:`), but the assistant's content may help fill in the `action_taken` or `outcome` of a `situation → action_taken → outcome` arc when the agent carried out the action on the user's behalf. -- The assistant's general world-knowledge answers (e.g. "Python 3.13 was released in October 2024") are never user facts and never user episodic memories. +- **Facts about the user may ONLY come from `[user]:` lines.** The agent may restate, paraphrase, confirm, or make assumptions on the user's behalf ("Got it, you don't eat meat", "I assume you want a luxury hotel") — those are the agent's response, NOT the user's assertion. Never treat agent text as a new source of user facts. If a fact appears only in `[agent]:` content and is not asserted by the user, do not extract it. +- **Episodic memories may use both speakers' content** — the user's stated intent or scope is the anchor (and must be present in `[user]:`), but the agent's content may help fill in the `action_taken` or `outcome` of a `situation → action_taken → outcome` arc when the agent carried out the action on the user's behalf. +- The agent's general world-knowledge answers (e.g. "Python 3.13 was released in October 2024") are never user facts and never user episodic memories. ## Confidence Scoring @@ -73,94 +70,12 @@ Extract concrete, factual statements that fall into these categories: - Each fact must be self-contained and intelligible without context — no pronouns like "it" or "they" without antecedents - Write in third person ("The user...", "The project...") - Keep each fact concise — under 40 words -- Consolidate closely related items **within the same category** into a single fact (e.g., multiple search results on the same topic) -- **Never merge across categories.** A single user turn that combines, say, a `preference` ("I don't eat meat") and a `requirement` ("I need wheelchair-accessible restaurants") MUST produce two separate facts. Compound user statements regularly cross category boundaries; extract every category that applies. Silently folding one into the other drops information. -- Only split *within* a category when the claims are about genuinely different topics +- Consolidate closely related items into a single fact (e.g., multiple search results on the same topic) +- Only split into separate facts when the claims are about genuinely different topics - Each fact will be stored as its own document with its own vector embedding — self-contained phrasing is essential for retrieval -### Fact Reconciliation -Before adding a fact, check it against the existing memories provided below. Use the `action` field: -- **ADD** — This is a genuinely new fact not present in existing memories -- **UPDATE** — This fact refines, narrows, or supplies more detail to an existing memory while remaining *compatible* with it (e.g. "user lives in Seattle" → "user lives in Seattle, WA"). Set `supersedes_id` to the ID of the memory being refined -- **CONTRADICT** — This new fact is the opposite, a negation, or otherwise mutually exclusive with an existing memory (e.g. existing: "user is vegetarian"; new: "user now eats meat"). The old fact is no longer true — emit `CONTRADICT`, **not** UPDATE. Set `supersedes_id` to the ID of the memory being contradicted -- **NONE** — This fact is already captured in existing memories with no meaningful change. Do not include NONE entries in your output — simply omit them - -**UPDATE vs CONTRADICT — the decisive test:** If the old fact and the new fact could both be true at the same moment about the same subject, use UPDATE. If they cannot both be true (one negates, reverses, or excludes the other), use CONTRADICT. - -### ADD / CONTRADICT / UPDATE — hard constraints - -These rules are absolute. Violating them silently corrupts the user's memory store. - -1. **Every emitted fact (ADD, UPDATE, or CONTRADICT) must paraphrase a claim directly stated by the user in a `[user]:` line in this extraction's transcript.** Do not invent supporting, clarifying, or "explicit-negation" facts. Do not synthesize facts by combining, restating, or "consolidating" entries from the existing-facts list. If the [user]: lines in this transcript do not assert claim X, do NOT emit claim X — regardless of what the existing-facts list contains. Merging existing facts is the job of a separate reconciliation pass, not yours. -2. **One user statement → one fact, even when it implies a contradiction.** If a single user statement both adds new information AND opposes an existing fact, emit EXACTLY ONE fact: `text` paraphrases what the user actually said, `action="CONTRADICT"`, `supersedes_id` points at the opposed fact. The `supersedes_id` field by itself encodes the semantic opposition — do NOT also emit a second "explicit-negation" fact that restates the opposite of the prior claim. See the worked examples below. -3. **Fact text must describe the world, never describe a memory operation.** Strings like `"X is contradicted by Y"`, `"The user's prior preference is no longer accurate"`, or `"Previous fact superseded"` are meta-commentary about prior reconciliations and are **never** valid fact content. If you find yourself writing one, drop the item entirely. -4. **`supersedes_id` must point at a fact that is semantically about the same subject and property as the new fact.** A new dietary fact may only contradict an existing dietary fact; a new accessibility requirement may only update an existing accessibility requirement. Cross-subject supersedes (e.g. contradicting a "wheelchair access" fact with a new "loves seafood" fact) are always wrong — emit `ADD` instead. -5. **When in doubt, emit `ADD`.** A spurious `ADD` is recoverable (exact-content-hash deduping catches it; reconciliation can later merge or supersede it). A spurious `CONTRADICT` or `UPDATE` corrupts the audit trail and silently marks a still-valid fact as superseded. - -#### Worked example A — explicit CONTRADICT (user states the negation directly) - -- Existing memory: `[ID: fact_abc123] User is vegetarian.` -- New turn: *"I started eating meat again last month."* -- Correct output (one fact, text paraphrases what the user said): - ```json - { - "text": "User eats meat.", - "category": "preference", - "action": "CONTRADICT", - "supersedes_id": "fact_abc123", - "confidence": 0.95, - "salience": 0.8 - } - ``` -- **Wrong**: emitting `"action": "UPDATE"` here would be a silent bug — the pipeline treats UPDATE as a compatible refinement and CONTRADICT as an opposing claim, and downstream telemetry / belief-revision logic depends on the distinction. - -#### Worked example B — implicit CONTRADICT (user states a claim that semantically opposes an existing fact) - -- Existing memory: `[ID: fact_xyz789] The user does not eat meat.` -- New turn: *"Actually, I love steak and seafood."* -- Correct output (ONE fact — `text` is what the user said, `supersedes_id` carries the opposition): - ```json - { - "text": "The user loves steak and seafood.", - "category": "preference", - "action": "CONTRADICT", - "supersedes_id": "fact_xyz789", - "confidence": 0.95, - "salience": 0.8 - } - ``` -- **Wrong** — emitting two facts: - ```json - // BAD: phantom explicit-negation fact alongside the literal user claim - [ - {"text": "The user loves steak and seafood.", "action": "ADD", ...}, - {"text": "The user eats meat.", "action": "CONTRADICT", "supersedes_id": "fact_xyz789", ...} - ] - ``` - The phantom `"The user eats meat"` fact was never said by the user — it is an invented restatement to make the contradiction "explicit". This pollutes the fact store with claims the user did not make. The CONTRADICT relation on the literal fact is sufficient. - -#### Worked example C — do NOT synthesize ADDs from existing facts - -- Existing memories: - - `[ID: fact_111] The user eats meat.` - - `[ID: fact_222] The user loves steak and seafood.` -- New turn: *"Normally, I prefer moderate hotels."* -- Correct output (ONE fact — only the hotel preference; the existing-facts list is reference-only): - ```json - { - "text": "The user normally prefers moderate hotels.", - "category": "preference", - "action": "ADD", - "confidence": 0.9, - "salience": 0.7 - } - ``` -- **Wrong** — emitting a synthesized "consolidation" fact: - ```json - // BAD: this fact is a paraphrase-merge of fact_111 + fact_222; the user never said this in this turn - {"text": "The user loves steak and seafood, indicating they eat meat.", "action": "ADD", ...} - ``` - Merging existing facts is the job of a separate reconciliation pass. Your job here is to extract claims from the new [user] turn only. +### Avoiding Duplicates +Before adding a fact, check it against the existing memories provided below. **Only emit facts that are genuinely new.** If a fact is already captured in existing memories with no meaningful change, simply omit it. Do not try to update, merge, or flag contradictions with existing facts — that is handled later by a separate reconciliation step. Your only job here is to extract new facts and episodic memories from the transcript. --- @@ -175,12 +90,8 @@ Extract memories that are tied to a **specific situation, scope, or context** th ### Required Fields - **scope_type** — short, free-form noun describing the kind of context (e.g. `trip`, `project`, `event`, `session`, `release`, `campaign`). Pick whatever vocabulary fits the user's domain. Do not invent a value if one is not implied — if no scope is present, the memory probably belongs in facts. - **scope_value** — the specific instance of that scope (e.g. `Paris 2025`, `Acme revamp`, `Q3 launch`). -- **text** — short, self-contained one-liner (under 25 words) describing what this memory is *about*. Required, non-empty. This is the field that gets embedded and full-text-indexed — vague text like "intent recorded" silently kills retrieval. Write it like a subject line: - - *Planned / in-flight* intents — capture the goal **plus the key constraints** the user has stated. Example: `"Planning a Tokyo trip with vegetarian and wheelchair-accessible-restaurant constraints."` - - *Past events* — summarize situation→action→outcome in one sentence. Example: `"Resolved Q3 K8s OOM outage by raising pod memory limits to 1GB."` - - *Ongoing context* — describe the current state. Example: `"User is heads-down on the Acme launch this week and wants short answers."` -All three of `scope_type`, `scope_value`, and `text` must be non-empty. (The `text` field plays the same role for episodic that it does for facts — self-contained phrasing intended for embedding.) +Both must be non-empty. ### Optional Fields (include only when applicable) - **situation** — the context or problem faced (present for past/in-flight events) @@ -191,7 +102,7 @@ All three of `scope_type`, `scope_value`, and `text` must be non-empty. (The `te - **lesson** — a transferable takeaway - **domain** — topic area -For planned/in-flight or ongoing-context memories, leave `situation`, `action_taken`, `outcome`, `outcome_valence`, `reasoning`, and `lesson` as `null`. The scope fields plus `text` carry the meaning. +For planned/in-flight or ongoing-context memories, leave `situation`, `action_taken`, `outcome`, `outcome_valence`, `reasoning`, and `lesson` as `null`. The scope fields alone carry the meaning. --- @@ -214,34 +125,18 @@ When in doubt: --- ## Existing Memories +The following facts already exist for this user. They are provided so you can avoid re-extracting facts that are already known. If a fact from the new transcript is already captured here, omit it. -### Existing facts (REFERENCE ONLY — never source ADDs from this list) +### Existing Memories Are Context, Not a Source -The list below shows facts already stored for this user. Use it ONLY for these three purposes: +The existing memories block is provided ONLY for grounding and for deciding whether a fact from the new transcript is already known and should be omitted. It is NOT part of the conversation transcript. -1. **Deduplication** — if a fact you would otherwise ADD is already captured (semantically equivalent to an existing entry), omit it (action=NONE; don't include NONE entries in your output). -2. **UPDATE** — if a NEW `[user]:` line in this transcript refines or adds detail to an existing fact in a compatible way, emit one fact with `action=UPDATE` and `supersedes_id` pointing at the existing fact. -3. **CONTRADICT** — if a NEW `[user]:` line in this transcript opposes or negates an existing fact, emit one fact with `action=CONTRADICT` and `supersedes_id` pointing at the existing fact. The new fact's `text` paraphrases what the user actually said (not an invented explicit-negation). - -**Do NOT** treat this list as source material for ADD. Never combine, merge, restate, or "consolidate" entries from this list into a new ADD fact. If the new `[user]:` lines do not directly assert claim X, do NOT emit claim X — regardless of what the existing list contains. Cross-fact merging is the job of a separate reconciliation pass, not this one. +- New memories must come exclusively from the new transcript below. +- Never extract a fact or episodic memory solely because it appears in existing memories. +- If information appears only in existing memories and not in the new transcript, do not emit it. {{existing_facts}} -### Existing episodics - -The following episodic memories already exist for this user, grouped by scope. **Each scope (the pair of `scope_type` + `scope_value`) is the unique identity of an episodic memory** — there is one episodic per scope. Storage is upsert-by-scope: whatever you emit for an existing scope **replaces** the prior record. There is no ADD/UPDATE/CONTRADICT vocabulary for episodics — emit the full current state in `text` and the pipeline does the right thing. - -Decision rule when the transcript mentions a scope: - -- **Already captured, nothing new** — if the same scope is already covered and the transcript adds no new information (e.g. the user just re-states "for the Tokyo trip, I want luxury hotels" and that intent is already in the existing record), **omit it from the output**. Re-emitting it accomplishes nothing. -- **Refinement / extension** — the transcript adds new details to a scope already present (e.g. existing: "Planning Tokyo trip with luxury hotels"; new turn: "let's also book a 5-night Shinjuku stay"). Emit the episodic with **the merged, richer `text`** that includes both the prior intent and the new detail. The new text replaces the old record by scope. -- **Reversal** — the transcript negates or replaces the prior intent for the same scope (e.g. existing: "Planning Tokyo trip with luxury hotels"; new turn: "switching to budget hostels"). Emit the episodic with the **new** `text` describing the updated intent. The reversal replaces the old record by scope. -- **New scope** — the transcript introduces a scope that is not in the list below. Emit a fresh episodic with that scope. - -If two genuinely-distinct events would share the same `(scope_type, scope_value)` (e.g. "lost wallet in Tokyo" and "booked Tokyo hotel" both as `(trip, Tokyo)`), differentiate them via `scope_value` (`Tokyo lost-wallet incident` vs `Tokyo trip`) — not by emitting two records under the same scope. - -{{existing_episodics}} - --- ## Salience Scoring Rubric @@ -283,7 +178,7 @@ A fact is a standing claim that holds outside any specific context. The test: if 1. Could someone act on or reference this fact without reading the original thread? 2. Is this fact stated explicitly, not inferred? 3. Is each fact truly atomic — one claim per entry? -4. Is every emitted fact grounded in a `[user]:` line in this transcript? (Existing-facts entries are reference-only; never source new ADDs from them.) +4. Can any facts be merged because they describe variants of the same thing? **For Episodic:** 1. Did this event actually happen, or is it hypothetical? @@ -312,9 +207,7 @@ A fact is a standing claim that holds outside any specific context. The test: if "confidence": 1.0, "salience": 0.9, "temporal_context": null, - "tags": ["topic:identity"], - "action": "ADD", - "supersedes_id": null + "tags": ["topic:identity"] }, { "text": "Alex is a data engineer at Acme Corp.", @@ -325,9 +218,7 @@ A fact is a standing claim that holds outside any specific context. The test: if "confidence": 1.0, "salience": 0.8, "temporal_context": null, - "tags": ["topic:identity", "topic:career"], - "action": "ADD", - "supersedes_id": null + "tags": ["topic:identity", "topic:career"] }, { "text": "Alex's new ETL pipeline project has a deadline at end of Q2.", @@ -338,9 +229,7 @@ A fact is a standing claim that holds outside any specific context. The test: if "confidence": 0.95, "salience": 0.9, "temporal_context": "end of Q2", - "tags": ["topic:project", "topic:ETL"], - "action": "ADD", - "supersedes_id": null + "tags": ["topic:project", "topic:ETL"] } ], "episodic": [] @@ -365,16 +254,13 @@ A fact is a standing claim that holds outside any specific context. The test: if "confidence": 0.95, "salience": 0.7, "temporal_context": "last month", - "tags": ["topic:kubernetes", "topic:infrastructure"], - "action": "ADD", - "supersedes_id": null + "tags": ["topic:kubernetes", "topic:infrastructure"] } ], "episodic": [ { "scope_type": "incident", "scope_value": "Q3 K8s OOM outage", - "text": "Resolved Q3 K8s OOM outage by raising pod memory limits to 1GB and adding per-namespace resource quotas.", "situation": "Kubernetes pods in production were repeatedly OOM-killed, causing an outage.", "action_taken": "Bumped pod memory limits from 512MB to 1GB and added resource quotas per namespace.", "outcome": "The OOM-killing stopped and production stabilized.", @@ -408,9 +294,7 @@ A fact is a standing claim that holds outside any specific context. The test: if "confidence": 0.95, "salience": 0.7, "temporal_context": null, - "tags": ["topic:database", "topic:BigQuery"], - "action": "ADD", - "supersedes_id": null + "tags": ["topic:database", "topic:BigQuery"] }, { "text": "The marketing team's budget is $50,000 for the current quarter.", @@ -421,9 +305,7 @@ A fact is a standing claim that holds outside any specific context. The test: if "confidence": 0.95, "salience": 0.9, "temporal_context": "current quarter", - "tags": ["topic:budget", "topic:marketing"], - "action": "ADD", - "supersedes_id": null + "tags": ["topic:budget", "topic:marketing"] } ], "episodic": [] @@ -451,16 +333,13 @@ The first statement is a standing preference and belongs in `facts`. The second "confidence": 0.95, "salience": 0.7, "temporal_context": null, - "tags": ["topic:travel", "topic:hotels"], - "action": "ADD", - "supersedes_id": null + "tags": ["topic:travel", "topic:hotels"] } ], "episodic": [ { "scope_type": "trip", "scope_value": "Paris", - "text": "Planning a Paris trip with a luxury-accommodations preference.", "situation": null, "action_taken": null, "outcome": null, @@ -476,126 +355,11 @@ The first statement is a standing preference and belongs in `facts`. The second } ``` -### Example 6: Episodic — same scope, same intent already captured (OMIT) - -**Existing episodics:** -``` -- trip = Tokyo (1 episodic) - - [ID: ep_a1b2c3] (salience 0.8) Planning a Tokyo trip with a luxury hotel preference. -``` - -**Conversation:** -> User: "For this Tokyo trip, I want luxury hotels." -> User: "Normally, I prefer moderate hotels." - -The first user turn re-states an intent that is already captured for `(trip, Tokyo)`. Omit it — re-emitting the same intent for the same scope accomplishes nothing (the pipeline upserts by scope). The second user turn is a standing preference and belongs in `facts`. - -**Output:** -```json -{ - "facts": [ - { - "text": "The user normally prefers moderate hotels.", - "category": "preference", - "subject": "user", - "predicate": "hotel_preference", - "object": "moderate hotels", - "confidence": 0.95, - "salience": 0.7, - "temporal_context": null, - "tags": ["topic:travel", "topic:hotels"], - "action": "ADD", - "supersedes_id": null - } - ], - "episodic": [] -} -``` - -### Example 7: Episodic — refinement (emit merged text for the same scope) - -**Existing episodics:** -``` -- trip = Tokyo (1 episodic) - - [ID: ep_a1b2c3] (salience 0.8) Planning a Tokyo trip with a luxury hotel preference. -``` - -**Conversation:** -> User: "For the Tokyo trip, let's also book a 5-night stay in Shinjuku." - -The transcript adds a new detail (5-night Shinjuku stay) to the existing Tokyo scope. Emit one episodic for `(trip, Tokyo)` with the **merged richer text** that carries both the prior luxury-hotel intent and the new accommodation detail. The pipeline upserts by scope, so this replaces the prior record. - -**Output:** -```json -{ - "facts": [], - "episodic": [ - { - "scope_type": "trip", - "scope_value": "Tokyo", - "text": "Planning a Tokyo trip with a luxury hotel preference and a 5-night stay in Shinjuku.", - "situation": null, - "action_taken": null, - "outcome": null, - "outcome_valence": null, - "reasoning": null, - "lesson": null, - "domain": "travel", - "confidence": 0.95, - "salience": 0.85, - "tags": ["topic:travel", "topic:hotels", "topic:itinerary"] - } - ] -} -``` - -### Example 5: Compound user statement crossing fact categories - -**Conversation:** -> User: "I don't eat meat and I need wheelchair-accessible restaurants." - -A single user turn that combines a `preference` (diet) and a `requirement` (accessibility). These are different categories — they MUST be emitted as separate facts. Collapsing both into one "restaurant preferences" fact silently loses the accessibility constraint, which is the more operationally critical of the two. - -**Output:** -```json -{ - "facts": [ - { - "text": "The user does not eat meat.", - "category": "preference", - "subject": "user", - "predicate": "dietary_restriction", - "object": "no meat", - "confidence": 1.0, - "salience": 0.9, - "temporal_context": null, - "tags": ["topic:diet", "topic:food"], - "action": "ADD", - "supersedes_id": null - }, - { - "text": "The user requires wheelchair-accessible restaurants.", - "category": "requirement", - "subject": "user", - "predicate": "accessibility_requirement", - "object": "wheelchair-accessible restaurants", - "confidence": 1.0, - "salience": 0.95, - "temporal_context": null, - "tags": ["topic:accessibility", "topic:restaurants"], - "action": "ADD", - "supersedes_id": null - } - ], - "episodic": [] -} -``` - --- ## Output Format -You must output ONLY valid JSON matching the schema below. No preamble, no explanation, no closing remarks — just the JSON object. Each array can be empty if no memories of that type are found. Omit entries with action=NONE entirely. +You must output ONLY valid JSON matching the schema below. No preamble, no explanation, no closing remarks — just the JSON object. Each array can be empty if no memories of that type are found. Omit any fact that is already captured in the existing memories. ```json { @@ -609,16 +373,13 @@ You must output ONLY valid JSON matching the schema below. No preamble, no expla "confidence": 0.95, "salience": 0.8, "temporal_context": "optional time ref or null", - "tags": ["topic:x"], - "action": "ADD|UPDATE|CONTRADICT", - "supersedes_id": "id or null" + "tags": ["topic:x"] } ], "episodic": [ { "scope_type": "trip|project|event|session|release|campaign|... (free-form, required, non-empty)", "scope_value": "specific instance, e.g. Paris 2025 (required, non-empty)", - "text": "self-contained one-liner describing the memory — required, non-empty", "situation": "context/problem, or null", "action_taken": "what was tried, or null", "outcome": "what happened, or null", diff --git a/azure/cosmos/agent_memory/services/_pipeline_helpers.py b/azure/cosmos/agent_memory/services/_pipeline_helpers.py index 9748168..e94ce91 100644 --- a/azure/cosmos/agent_memory/services/_pipeline_helpers.py +++ b/azure/cosmos/agent_memory/services/_pipeline_helpers.py @@ -280,47 +280,6 @@ def build_transcript( return "\n".join(parts) -def format_existing_episodics(memories: list[dict[str, Any]]) -> str: - """Render existing episodic memories for the extract_memories prompt. - - Groups by ``(scope_type, scope_value)`` so the LLM can see, per-scope, - which intent is already captured. Episodics use **scope-as-identity**: - the deterministic id is seeded from ``(user_id, scope_type, scope_value)``, - so any re-emission for the same scope (paraphrased intent, added detail, - or a reversal) collides and overwrites the prior record via upsert. The - LLM does NOT make ``ADD``/``UPDATE``/``CONTRADICT`` decisions on - episodics — that vocabulary is not in the episodic schema. - - What this rendering gives the model is per-scope context so it can: - - 1. Emit a single coherent ``text`` that reflects the *current* intent - for the scope (the upsert will overwrite the prior one). - 2. Avoid re-emitting an episodic when the new turn carries no - additional signal beyond what the existing one already records. - - Distinct events under the same umbrella (e.g. hotel booking vs lost - wallet, both under a Tokyo trip) belong under distinct ``scope_value`` - strings so they don't collide on the deterministic id. - """ - if not memories: - return "(none)" - grouped: dict[tuple[str, str], list[dict[str, Any]]] = defaultdict(list) - for mem in memories: - meta = mem.get("metadata") or {} - scope_type = (meta.get("scope_type") or "(none)").strip() or "(none)" - scope_value = (meta.get("scope_value") or "(none)").strip() or "(none)" - grouped[(scope_type, scope_value)].append(mem) - lines: list[str] = [] - for (scope_type, scope_value), bucket in grouped.items(): - lines.append(f"- {scope_type} = {scope_value} ({len(bucket)} episodic{'s' if len(bucket) != 1 else ''})") - for mem in bucket: - mem_id = mem.get("id", "(no-id)") - salience = mem.get("salience", "N/A") - content = (mem.get("content") or "").strip() or "(empty content)" - lines.append(f" - [ID: {mem_id}] (salience {salience}) {content}") - return "\n".join(lines) - - # Stopwords stripped from grounding checks. Keep this list short and focused # on tokens that carry no factual content; any word a memory might legitimately # differ on (e.g. "not", "no") must NOT be added here. @@ -502,13 +461,31 @@ def parse_llm_json(text: str | None) -> dict[str, Any]: cleaned = cleaned.lstrip("`").lstrip() if cleaned.endswith("```"): cleaned = cleaned[:-3] + cleaned = cleaned.strip() try: - return json.loads(cleaned.strip()) + return json.loads(cleaned) except json.JSONDecodeError as exc: preview = (text or "")[:200].replace("\n", " ") + if _looks_truncated(cleaned, exc): + raise LLMError( + "LLM JSON output appears TRUNCATED (decode error at the very end of a " + f"{len(cleaned)}-char body — the model almost certainly hit its output-token " + "cap mid-object). Increase 'maxOutputTokens' in the calling prompty, or reduce " + "the amount of input per call (e.g. lower the fact-extraction batch size / " + f"recent_k, or split oversized turns). Decode error: {exc}. preview={preview!r}" + ) from exc raise LLMError(f"LLM returned invalid JSON (preview={preview!r}): {exc}") from exc +def _looks_truncated(cleaned: str, exc: json.JSONDecodeError) -> bool: + """Heuristic: did the JSON fail because the model ran out of output tokens?""" + if not cleaned: + return False + unbalanced = cleaned.count("{") > cleaned.count("}") or cleaned.count("[") > cleaned.count("]") + unterminated_string = "Unterminated string" in str(exc) + return unbalanced or unterminated_string + + def default_prompts_dir() -> str: """Default ``prompts/`` directory location: under ``azure/cosmos/agent_memory/``.""" pkg_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) diff --git a/azure/cosmos/agent_memory/services/pipeline.py b/azure/cosmos/agent_memory/services/pipeline.py index 2211669..b216816 100644 --- a/azure/cosmos/agent_memory/services/pipeline.py +++ b/azure/cosmos/agent_memory/services/pipeline.py @@ -17,13 +17,20 @@ from typing import Any, Iterable, Literal, Optional from azure.cosmos.exceptions import ( - CosmosHttpResponseError, CosmosResourceExistsError, CosmosResourceNotFoundError, ) +from azure.cosmos.agent_memory import thresholds as threshold_config from azure.cosmos.agent_memory._container_routing import ContainerKey -from azure.cosmos.agent_memory._utils import DEFAULT_TTL_BY_TYPE, compute_content_hash +from azure.cosmos.agent_memory._utils import ( + DEFAULT_TTL_BY_TYPE, + compute_content_hash, + distance_function_from_container_properties, + vector_autodrop_supported, + vector_order_direction, + vector_similarity_at_least, +) from azure.cosmos.agent_memory.exceptions import ( LLMError, MemoryConflictError, @@ -55,9 +62,6 @@ coerce_valence, parse_llm_json, ) -from azure.cosmos.agent_memory.services._pipeline_helpers import ( - format_existing_episodics as _format_existing_episodics, -) from azure.cosmos.agent_memory.services._pipeline_helpers import ( is_real_number as _is_real_number, ) @@ -276,6 +280,183 @@ def _load_existing_memories( ) return items + def _vector_distance_function(self) -> str: + """Return the container's configured Cosmos ``distanceFunction`` (cached). + + Read from the container's vector embedding policy (``container.read()``) — + the authoritative, immutable source set when the container was created. + Drives the ORDER BY direction and similarity-threshold comparisons so dedup + never silently assumes cosine. Falls back to cosine when the policy can't be + read (e.g. ``__new__``-built test instances with mocked containers). + """ + fn = getattr(self, "_distance_function_cache", None) + if fn is not None: + return fn + try: + props = self._memories_container.read() + except Exception: + # Transient read failure (429/503/connection) is indistinguishable from + # "no policy" once we drop to None — so DON'T cache here. Returning an + # uncached cosine default lets the next call self-heal; caching it would + # pin cosine for the instance's life and silently mis-handle a euclidean + # container (cosine bands applied to euclidean distances → data loss). + logger.debug( + "vector dedup: could not read container vector policy; defaulting to cosine (not cached)", + exc_info=True, + ) + return "cosine" + fn = distance_function_from_container_properties(props) + self._distance_function_cache = fn + return fn + + def _warn_euclidean_autodrop_once(self, distance_function: str) -> None: + """One-shot WARN that the near-exact vector auto-drop is disabled. + + The ``DEDUP_SIM_HIGH`` thresholds are cosine-calibrated; on euclidean + the destructive auto-drop is skipped (borderline tagging + LLM reconcile + still run). Logged once per pipeline instance to avoid hot-path spam. + """ + if getattr(self, "_warned_euclidean_autodrop", False): + return + self._warned_euclidean_autodrop = True + logger.warning( + "Container distanceFunction=%r: near-exact vector auto-drop is " + "cosine-calibrated and has been DISABLED for this distance function. " + "Duplicate detection falls back to borderline tagging + LLM reconcile. " + "Use cosine/dotproduct embeddings for vector-floor auto-dedup.", + distance_function, + ) + + def _vector_candidates( + self, + *, + user_id: str, + embedding: list[float], + memory_type: str, + top_k: int, + exclude_ids: set[str], + ) -> list[dict[str, Any]]: + """Return nearest active same-type memories using Cosmos VectorDistance.""" + if not user_id or not embedding or top_k < 1: + return [] + capped_top_k = top_literal(top_k, name="_vector_candidates.top_k") + distance_function = self._vector_distance_function() + order_direction = vector_order_direction(distance_function) + field = "embedding" + query = ( + f"SELECT TOP {capped_top_k} c.id, c.content, c.type, " + f"VectorDistance(c.{field}, @vec) AS score " + "FROM c WHERE c.user_id = @user_id " + "AND c.type = @memory_type " + f"AND {_ACTIVE_DOC_FILTER} " + f"AND IS_DEFINED(c.{field}) " + # Cosmos orders ORDER BY VectorDistance() most-similar-first per the + # container's distanceFunction; an explicit ASC/DESC is rejected (BadRequest). + f"ORDER BY VectorDistance(c.{field}, @vec)" + ) + rows = list( + self._memories_container.query_items( + query=query, + parameters=[ + {"name": "@user_id", "value": user_id}, + {"name": "@memory_type", "value": memory_type}, + {"name": "@vec", "value": embedding}, + ], + enable_cross_partition_query=True, + ) + ) + excluded = set(exclude_ids or set()) + candidates = [ + { + "id": row.get("id"), + "content": row.get("content"), + "type": row.get("type"), + "score": float(row.get("score") or 0.0), + } + for row in rows + if row.get("id") and row.get("id") not in excluded + ] + # Most-similar-first: descending score for cosine/dotproduct, ascending for euclidean. + candidates.sort( + key=lambda row: row.get("score", 0.0), + reverse=order_direction == "DESC", + ) + return candidates + + def _query_active_memories( + self, + user_id: str, + memory_type: str, + *, + limit: int | None = None, + tagged_only: bool = False, + ) -> list[dict[str, Any]]: + top_clause = f"TOP {top_literal(limit, name='_query_active_memories.limit')} " if limit else "" + tag_clause = "AND ARRAY_CONTAINS(c.tags, @tag) " if tagged_only else "" + query = ( + f"SELECT {top_clause}* FROM c WHERE c.user_id = @user_id " + "AND c.type = @memory_type " + f"AND {_ACTIVE_DOC_FILTER} " + f"{tag_clause}" + "ORDER BY c.created_at DESC" + ) + parameters = [ + {"name": "@user_id", "value": user_id}, + {"name": "@memory_type", "value": memory_type}, + ] + if tagged_only: + parameters.append({"name": "@tag", "value": "sys:dup-candidate"}) + return list( + self._memories_container.query_items( + query=query, + parameters=parameters, + enable_cross_partition_query=True, + ) + ) + + def _load_memories_by_ids( + self, user_id: str, memory_type: str, ids: Iterable[str] + ) -> list[dict[str, Any]]: + ids = [mid for mid in dict.fromkeys(ids) if mid] + if not ids: + return [] + placeholders = ", ".join(f"@id{i}" for i in range(len(ids))) + parameters = [ + {"name": "@user_id", "value": user_id}, + {"name": "@memory_type", "value": memory_type}, + ] + parameters.extend({"name": f"@id{i}", "value": mid} for i, mid in enumerate(ids)) + query = ( + "SELECT * FROM c WHERE c.user_id = @user_id " + "AND c.type = @memory_type " + f"AND c.id IN ({placeholders}) " + f"AND {_ACTIVE_DOC_FILTER}" + ) + return list( + self._memories_container.query_items( + query=query, + parameters=parameters, + enable_cross_partition_query=True, + ) + ) + + def _clear_dup_candidate_tags(self, docs: Iterable[dict[str, Any]]) -> None: + for doc in docs: + tags = [tag for tag in (doc.get("tags") or []) if tag != "sys:dup-candidate"] + if tags == (doc.get("tags") or []): + continue + updated = dict(doc) + updated["tags"] = tags + metadata = dict(updated.get("metadata") or {}) + metadata.pop("dup_of", None) + metadata.pop("dup_score", None) + updated["metadata"] = metadata + updated["updated_at"] = datetime.now(timezone.utc).isoformat() + try: + self._upsert_memory(updated) + except Exception: + logger.exception("reconcile_memories: failed to clear dup-candidate tag id=%s", doc.get("id")) + def _upsert_memory(self, doc: dict[str, Any]) -> dict[str, Any]: """Upsert a fact, episodic, or procedural document to the memories container.""" response = self._memories_container.upsert_item(body=doc) @@ -314,56 +495,6 @@ def _stable_source_timestamp(items: list[dict[str, Any]]) -> str: return max(timestamps) return datetime.now(timezone.utc).isoformat() - def _mark_extracted_superseded( - self, - *, - user_id: str, - thread_id: str, - supersedes_id: str, - superseder_id: str, - reason: Literal["update", "contradict"], - ) -> bool: - try: - old_mem = self._memories_container.read_item(item=supersedes_id, partition_key=[user_id, thread_id]) - if old_mem.get("superseded_by"): - logger.debug( - "extract_memories: skipping UPDATE — target %s already superseded by %s", - supersedes_id, - old_mem.get("superseded_by"), - ) - return False - return self._mark_superseded(old_mem, superseder_id, reason=reason) - except CosmosResourceNotFoundError: - logger.debug( - "extract_memories: %s not found at (user_id, thread_id) — retrying cross-partition", - supersedes_id, - ) - except Exception as exc: - # Includes 429s, 503s, transient connection errors — surface at WARNING - # so they're not masked by the silent cross-partition fallback below. - logger.warning( - "extract_memories: read_item failed for %s (%s); retrying cross-partition", - supersedes_id, - type(exc).__name__, - ) - try: - q = f"SELECT * FROM c WHERE c.id = @id AND c.user_id = @uid AND {_ACTIVE_DOC_FILTER}" - results = list( - self._memories_container.query_items( - query=q, - parameters=[ - {"name": "@id", "value": supersedes_id}, - {"name": "@uid", "value": user_id}, - ], - enable_cross_partition_query=True, - ) - ) - if results and not results[0].get("superseded_by"): - return self._mark_superseded(results[0], superseder_id, reason=reason) - except CosmosHttpResponseError as exc: - logger.warning("Failed to mark superseded memory %s: %s", supersedes_id, exc) - return False - def _mark_superseded( self, old_doc: dict[str, Any], @@ -462,28 +593,35 @@ def extract_memories_dry( logger.warning("extract_memories_dry no memories found user_id=%s thread_id=%s", user_id, thread_id) return {"facts": [], "episodic": [], "updates": [], "processed_turn_docs": []} - existing_facts = self._load_existing_memories(user_id, ["fact"]) - existing_episodics = self._load_existing_memories(user_id, ["episodic"]) + existing_for_hashes = self._load_existing_memories(user_id, ["fact"]) existing_fact_hashes: set[str] = { - m["content_hash"] for m in existing_facts if m.get("type") == "fact" and m.get("content_hash") + m["content_hash"] for m in existing_for_hashes if m.get("type") == "fact" and m.get("content_hash") } - if existing_facts: + transcript = self._build_transcript(items) + existing = existing_for_hashes + if threshold_config.get_dedup_context_vector_enabled(): + user_turns_text = "\n".join( + str(it.get("content", "")) for it in items if it.get("role") == "user" + ).strip() + context_query = user_turns_text or transcript + existing = self._store.search( + search_terms=context_query, + user_id=user_id, + memory_types=["fact"], + top_k=threshold_config.get_dedup_context_topk(), + ) + if existing: existing_text = "\n".join( - f"- [ID: {mem['id']}] {mem.get('content', '')} (type=fact, salience={mem.get('salience', 'N/A')})" - for mem in existing_facts + f"- [ID: {mem['id']}] {mem.get('content', '')} " + f"(type={mem.get('type', 'fact')}, salience={mem.get('salience', 'N/A')})" + for mem in existing ) else: existing_text = "(none)" - existing_episodics_text = _format_existing_episodics(existing_episodics) - transcript = self._build_transcript(items) response_text = self._run_prompty( "extract_memories.prompty", - inputs={ - "existing_facts": existing_text, - "existing_episodics": existing_episodics_text, - "transcript": transcript, - }, + inputs={"existing_facts": existing_text, "transcript": transcript}, ) parsed = self._parse_llm_json(response_text) facts = parsed.get("facts", []) @@ -498,14 +636,13 @@ def extract_memories_dry( dropped_episodic_count = 0 for fact in facts: - action = fact.get("action", "ADD").upper() text = fact.get("text") if not text: logger.warning("extract_memories: dropping malformed fact (missing 'text'): %r", fact) continue new_content_hash = compute_content_hash(text) - if action == "ADD" and new_content_hash in existing_fact_hashes: + if new_content_hash in existing_fact_hashes: logger.debug( "extract_memories: skipping exact-dup fact hash=%s user_id=%s thread_id=%s", new_content_hash, @@ -542,22 +679,6 @@ def extract_memories_dry( "updated_at": doc_timestamp, } - if action in {"UPDATE", "CONTRADICT"} and fact.get("supersedes_id"): - reason: Literal["update", "contradict"] = "contradict" if action == "CONTRADICT" else "update" - if det_id == fact["supersedes_id"]: - logger.debug("extract_memories: skipping UPDATE — det_id == supersedes_id (%s)", det_id) - continue - doc["supersedes_ids"] = [fact["supersedes_id"]] - updates.append( - { - "op": "supersede", - "supersedes_id": fact["supersedes_id"], - "superseder_id": det_id, - "thread_id": thread_id, - "reason": reason, - } - ) - fact_docs.append(self._validate_extracted_doc(doc)) existing_fact_hashes.add(new_content_hash) @@ -577,28 +698,16 @@ def extract_memories_dry( dropped_episodic_count += 1 continue - text_raw = ep.get("text") - text = text_raw.strip() if isinstance(text_raw, str) else None - if not text: - logger.error( - "extract_memories: dropping episodic with empty/missing text field " - "(LLM extraction did not populate the required `text` field — likely a " - "weaker extraction model that needs upgrading or a prompt-compliance issue). " - "scope_type=%s scope_value=%s user_id=%s thread_id=%s reason=missing_text", - scope_type, - scope_value, - user_id, - thread_id, - ) - dropped_episodic_count += 1 - continue - situation = ep.get("situation") action_taken = ep.get("action_taken") outcome = ep.get("outcome") + if situation and action_taken and outcome: + text = f"{situation} → {action_taken} → {outcome}" + else: + text = f"For the user's {scope_value} {scope_type}, intent recorded." content_hash = compute_content_hash(text) - seed = _ID_SEED_SEP.join((user_id, scope_type, scope_value)) + seed = _ID_SEED_SEP.join((user_id, thread_id, content_hash)) det_id = f"ep_{hashlib.sha256(seed.encode()).hexdigest()[:32]}" topic_tags = build_topic_tags(ep.get("tags", [])) confidence = ep.get("confidence") @@ -615,7 +724,7 @@ def extract_memories_dry( doc = { "id": det_id, "user_id": user_id, - "thread_id": "__episodic__", + "thread_id": thread_id, "role": "system", "type": "episodic", "content": text, @@ -626,7 +735,6 @@ def extract_memories_dry( "metadata": { "scope_type": scope_type, "scope_value": scope_value, - "originating_thread_id": thread_id, "situation": situation, "action_taken": action_taken, "outcome": outcome, @@ -690,7 +798,7 @@ def extract_memories_dry( check_extracted_fact_grounding( fact_docs, items, - existing_facts, + existing, user_id=user_id, thread_id=thread_id, logger=logger, @@ -712,6 +820,104 @@ def extract_memories_dry( ) return result + def dedup_extracted_memories( + self, + user_id: str, + extracted: dict[str, list[dict[str, Any]]], + ) -> dict[str, list[dict[str, Any]]]: + """Apply the gated Stage-3 vector dedup ladder to extracted docs.""" + if not threshold_config.get_dedup_vector_enabled(): + return extracted + if not user_id: + raise ValidationError("user_id is required") + if not isinstance(extracted, dict): + raise ValidationError("extracted must be a dict") + + high = threshold_config.get_dedup_sim_high() + low = threshold_config.get_dedup_sim_low() + top_k = threshold_config.get_dedup_candidate_topk() + distance_function = self._vector_distance_function() + autodrop_ok = vector_autodrop_supported(distance_function) + if not autodrop_ok: + self._warn_euclidean_autodrop_once(distance_function) + vector_dedup_skipped = 0 + dup_candidates_tagged = 0 + + result: dict[str, list[dict[str, Any]]] = { + "facts": [dict(doc) for doc in extracted.get("facts", [])], + "episodic": [dict(doc) for doc in extracted.get("episodic", [])], + "updates": [dict(op) for op in extracted.get("updates", [])], + } + docs = [doc for bucket in ("facts", "episodic") for doc in result[bucket] if doc.get("content")] + if not docs: + return result + + missing_embeddings = [doc for doc in docs if not doc.get("embedding")] + if missing_embeddings: + embeddings = self._embed_batch([str(doc["content"]) for doc in missing_embeddings]) + for doc, embedding in zip(missing_embeddings, embeddings): + doc["embedding"] = embedding + + kept_ids: set[str] = set() + dropped_ids: set[str] = set() + for doc in docs: + doc_id = str(doc.get("id") or "") + memory_type = str(doc.get("type") or "") + if not doc_id or memory_type not in {"fact", "episodic"}: + continue + + best: dict[str, Any] | None = None + candidates = self._vector_candidates( + user_id=user_id, + embedding=doc.get("embedding") or [], + memory_type=memory_type, + top_k=top_k, + exclude_ids=kept_ids | dropped_ids | {doc_id} | set(doc.get("supersedes_ids") or []), + ) + if candidates: + best = candidates[0] + + score = float(best.get("score") or 0.0) if best else 0.0 + if best and autodrop_ok and vector_similarity_at_least(score, high, distance_function): + logger.info( + "vector dedup skipped new memory id=%s type=%s score=%.4f surviving_id=%s " + "new_content=%r surviving_content=%r", + doc_id, + memory_type, + score, + best.get("id"), + doc.get("content"), + best.get("content"), + ) + vector_dedup_skipped += 1 + dropped_ids.add(doc_id) + continue + + if best and vector_similarity_at_least(score, low, distance_function): + tags = list(doc.get("tags") or []) + if "sys:dup-candidate" not in tags: + tags.append("sys:dup-candidate") + doc["tags"] = tags + metadata = dict(doc.get("metadata") or {}) + metadata["dup_of"] = best.get("id") + metadata["dup_score"] = score + doc["metadata"] = metadata + dup_candidates_tagged += 1 + + kept_ids.add(doc_id) + + for bucket in ("facts", "episodic"): + result[bucket] = [doc for doc in result[bucket] if doc.get("id") not in dropped_ids] + if vector_dedup_skipped or dup_candidates_tagged: + result["updates"].append( + { + "op": "stats", + "vector_dedup_skipped": vector_dedup_skipped, + "dup_candidates_tagged": dup_candidates_tagged, + } + ) + return result + def persist_extracted_memories( self, user_id: str, @@ -764,27 +970,9 @@ def persist_extracted_memories( if op.get("op") == "stats": result["exact_dedup_skipped"] += int(op.get("exact_dedup_skipped") or 0) result["dropped_episodic_count"] += int(op.get("dropped_episodic_count") or 0) - continue - if op.get("op") != "supersede": - continue - reason = op.get("reason") - op_thread_id = op.get("thread_id") - supersedes_id = op.get("supersedes_id") - superseder_id = op.get("superseder_id") - if reason not in {"update", "contradict"} or not op_thread_id or not supersedes_id or not superseder_id: - continue - marked = self._mark_extracted_superseded( - user_id=user_id, - thread_id=op_thread_id, - supersedes_id=supersedes_id, - superseder_id=superseder_id, - reason=reason, - ) - if marked: - if reason == "contradict": - result["contradicted_count"] += 1 - else: - result["updated_count"] += 1 + for key in ("vector_dedup_skipped", "dup_candidates_tagged"): + if key in op: + result[key] = result.get(key, 0) + int(op.get(key) or 0) logger.info("persist_extracted_memories completed user_id=%s counts=%s", user_id, result) @@ -839,6 +1027,8 @@ def extract_memories( ) -> dict[str, int]: """Extract facts and episodic memories from a thread and persist them.""" extracted = self.extract_memories_dry(user_id, thread_id, recent_k, turns=turns) + if threshold_config.get_dedup_vector_enabled(): + extracted = self.dedup_extracted_memories(user_id, extracted) return self.persist_extracted_memories(user_id, extracted) def synthesize_procedural( @@ -1403,7 +1593,171 @@ def _emit_reconcile_outcome( }, ) - def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: + def _build_candidate_clusters( + self, user_id: str, memory_type: str, n: int + ) -> tuple[list[list[dict[str, Any]]], int, list[dict[str, Any]]]: + """Cluster dup-candidate seeds (+ vector neighbors) into connected components. + + Returns ``(clusters, node_count, seeds)`` where each cluster has >= 2 members, + ``node_count`` is the total distinct memories pulled into the graph (used to + report ``reconcile_llm_calls_saved``), and ``seeds`` is the tagged seed scan + (so the caller can clear stale tags on orphan seeds that never clustered). + The seed scan is bounded to ``n`` so a single cluster can never exceed the + reconcile prompt's pool cap. + """ + cluster_sim = threshold_config.get_dedup_cluster_sim() + top_k = threshold_config.get_dedup_candidate_topk() + distance_function = self._vector_distance_function() + seeds = self._query_active_memories(user_id, memory_type, limit=n, tagged_only=True) + nodes_by_id: dict[str, dict[str, Any]] = {doc["id"]: doc for doc in seeds if doc.get("id")} + edge_pairs: set[tuple[str, str]] = set() + + dup_of_ids = [ + (doc.get("metadata") or {}).get("dup_of") + for doc in seeds + if isinstance(doc.get("metadata"), dict) and (doc.get("metadata") or {}).get("dup_of") + ] + for doc in self._load_memories_by_ids(user_id, memory_type, dup_of_ids): + if doc.get("id"): + nodes_by_id[doc["id"]] = doc + + for seed in seeds: + seed_id = seed.get("id") + if not seed_id: + continue + dup_of = (seed.get("metadata") or {}).get("dup_of") if isinstance(seed.get("metadata"), dict) else None + if dup_of: + edge_pairs.add(tuple(sorted((seed_id, dup_of)))) + if not seed.get("embedding"): + continue + candidates = self._vector_candidates( + user_id=user_id, + embedding=seed.get("embedding") or [], + memory_type=memory_type, + top_k=top_k, + exclude_ids={seed_id}, + ) + candidate_ids = [ + row["id"] + for row in candidates + if row.get("id") + and vector_similarity_at_least(row.get("score", 0.0), cluster_sim, distance_function) + ] + for cid in candidate_ids: + edge_pairs.add(tuple(sorted((seed_id, cid)))) + for doc in self._load_memories_by_ids(user_id, memory_type, candidate_ids): + if doc.get("id"): + nodes_by_id[doc["id"]] = doc + + node_ids = list(nodes_by_id) + node_id_set = set(node_ids) + for node_id in node_ids: + node = nodes_by_id[node_id] + if not node.get("embedding"): + continue + for row in self._vector_candidates( + user_id=user_id, + embedding=node.get("embedding") or [], + memory_type=memory_type, + top_k=top_k, + exclude_ids={node_id}, + ): + neighbor_id = row.get("id") + if neighbor_id in node_id_set and vector_similarity_at_least( + float(row.get("score") or 0.0), cluster_sim, distance_function + ): + edge_pairs.add(tuple(sorted((node_id, neighbor_id)))) + + adjacency: dict[str, set[str]] = {node_id: set() for node_id in node_ids} + for left_id, right_id in edge_pairs: + if left_id != right_id and left_id in adjacency and right_id in adjacency: + adjacency[left_id].add(right_id) + adjacency[right_id].add(left_id) + + clusters: list[list[dict[str, Any]]] = [] + seen: set[str] = set() + for node_id in node_ids: + if node_id in seen: + continue + stack = [node_id] + component: list[str] = [] + seen.add(node_id) + while stack: + current = stack.pop() + component.append(current) + for neighbor in adjacency.get(current, set()): + if neighbor not in seen: + seen.add(neighbor) + stack.append(neighbor) + if len(component) >= 2: + # Cap cluster size at the reconcile pool limit: lowering the cluster + # threshold can chain many facts into one giant transitive component + # that would blow the prompt cap; keep the most-recent ``n``. + if len(component) > n: + component = component[:n] + clusters.append([nodes_by_id[cid] for cid in component]) + return clusters, len(nodes_by_id), seeds + + def _reconcile_candidate_mode( + self, user_id: str, *, n: int, memory_type: str, started_at: float + ) -> dict[str, int]: + # Candidate clustering only. The periodic full-pool backstop that catches + # dissimilar-embedding contradictions ("vegetarian" vs "loves steak") is + # driven by the caller via ``full_rebuild`` on a PERSISTED-counter cadence + # (in-process auto-trigger + durable change-feed), not an in-memory sweep + # counter — the latter reset per worker/process and never fired reliably on + # the Function-App backend. + clusters, node_count, seeds = self._build_candidate_clusters(user_id, memory_type, n) + aggregate = {"kept": 0, "merged": 0, "contradicted": 0} + clustered_ids: set[str] = set() + for cluster in clusters: + # Mark members as clustered BEFORE the LLM call so a failed cluster's + # seeds are not treated as orphans below (which would clear their tags + # and prevent a retry). A truncated/malformed LLM response on one + # cluster must not abort the sweep or starve the remaining clusters. + clustered_ids.update(doc["id"] for doc in cluster if doc.get("id")) + try: + counts, consumed = self._reconcile_pool(user_id, memory_type, cluster) + except Exception as exc: + logger.warning( + "reconcile_memories: cluster reconcile failed user_id=%s memory_type=%s; " + "skipping cluster, tags retained for next sweep: %s", + user_id, + memory_type, + exc, + ) + continue + for key in aggregate: + aggregate[key] += int(counts.get(key, 0)) + # Clear dup-candidate tags only on survivors. Re-upserting a doc that + # was just superseded (duplicate source or contradiction loser) would + # resurrect it, since the in-memory cluster copy lacks superseded_by. + survivors = [doc for doc in cluster if doc.get("id") and doc["id"] not in consumed] + self._clear_dup_candidate_tags(survivors) + # Orphan seeds: tagged dup-candidates that never joined a cluster have no + # near-duplicate, so clear the stale tag — otherwise every future sweep + # re-scans them as seeds and they accumulate forever. + orphan_seeds = [seed for seed in seeds if seed.get("id") and seed["id"] not in clustered_ids] + self._clear_dup_candidate_tags(orphan_seeds) + aggregate["reconcile_clusters_sent"] = len(clusters) + aggregate["reconcile_llm_calls_saved"] = max(0, node_count - len(clusters)) + logger.info( + "reconcile_memories candidate completed user_id=%s memory_type=%s result=%s", + user_id, + memory_type, + aggregate, + ) + self._emit_reconcile_outcome( + started_at=started_at, + user_id=user_id, + candidates=node_count, + result=aggregate, + ) + return aggregate + + def reconcile_memories( + self, user_id: str, n: int = 50, *, memory_type: str = "fact", full_rebuild: bool = False + ) -> dict[str, int]: """Reconcile a user's active facts in a single LLM pass. Loads the most recent ``n`` active (non-superseded) facts for @@ -1418,6 +1772,12 @@ def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: the winner. Dangling references are resolved transparently when a contradicted id was just absorbed into a duplicate group. + ``full_rebuild`` (set by the explicit public ``reconcile()`` entrypoint) + makes candidate mode cluster the whole active pool instead of only + Stage-3-tagged seeds, so a user-triggered reconcile dedups every active + fact. Automatic background sweeps leave it ``False`` for the cheap + tagged-only path. + Returns ``{"kept": int, "merged": int, "contradicted": int}`` where ``merged`` and ``contradicted`` count the *losers* that were soft-deleted (duplicates and contradictions respectively). @@ -1428,46 +1788,97 @@ def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: raise ValidationError(f"n must be a positive integer, got {n!r}") if n > 500: raise ValidationError(f"n must be <= 500 to bound prompt size and LLM cost, got {n}") + if memory_type not in {"fact", "episodic", "procedural"}: + raise ValidationError(f"memory_type must be one of fact, episodic, procedural; got {memory_type!r}") + if memory_type == "procedural": + result = { + "kept": 0, + "merged": 0, + "contradicted": 0, + "reconcile_clusters_sent": 0, + "reconcile_llm_calls_saved": 0, + } + logger.info("reconcile_memories procedural no-op user_id=%s result=%s", user_id, result) + return result started_at = time.monotonic() - logger.info("reconcile_memories started user_id=%s n=%d", user_id, n) + logger.info("reconcile_memories started user_id=%s n=%d memory_type=%s", user_id, n, memory_type) + + # Explicit user-triggered reconcile (full_rebuild) always takes the + # full-pool single-LLM-pass path: it sees every active fact together, so it + # catches contradictions that aren't vector-similar (e.g. "vegetarian" vs + # "loves steak") — which candidate clustering, keyed on near-duplicate + # similarity, would never group. Automatic sweeps use cheap candidate mode. + if threshold_config.get_dedup_reconcile_mode() == "candidate" and not full_rebuild: + return self._reconcile_candidate_mode( + user_id, n=n, memory_type=memory_type, started_at=started_at + ) - # ---- 1. Load up to N most recent active facts ---- + facts = self._active_memories_for_reconcile(user_id, memory_type, n) + result, consumed = self._reconcile_pool(user_id, memory_type, facts) + # Clear dup-candidate tags on survivors so an explicit reconcile(full_rebuild=True) + # doesn't leave stale sys:dup-candidate/dup_of metadata on user-visible memories. + survivors = [doc for doc in facts if doc.get("id") and doc["id"] not in consumed] + self._clear_dup_candidate_tags(survivors) + self._emit_reconcile_outcome( + started_at=started_at, + user_id=user_id, + candidates=len(facts), + result=result, + ) + return result + + def _active_memories_for_reconcile( + self, user_id: str, memory_type: str, n: int + ) -> list[dict[str, Any]]: + # ---- Load up to N most recent active memories ---- # ORDER BY c.created_at DESC keeps the TOP cap deterministic across # physical partitions and matches the dedup prompt's tiebreaker # ("more recent created_at first"). Cosmos's _ts is the last-write # timestamp, which would diverge from created_at after any UPDATE. query = ( - f"SELECT TOP {n} * FROM c " + f"SELECT TOP {top_literal(n, name='reconcile_memories.n')} * FROM c " "WHERE c.user_id = @user_id " - "AND c.type = 'fact' " + "AND c.type = @memory_type " f"AND {_ACTIVE_DOC_FILTER} " "ORDER BY c.created_at DESC" ) - parameters: list[dict[str, Any]] = [ - {"name": "@user_id", "value": user_id}, - ] - facts = list( + return list( self._memories_container.query_items( query=query, - parameters=parameters, + parameters=[ + {"name": "@user_id", "value": user_id}, + {"name": "@memory_type", "value": memory_type}, + ], enable_cross_partition_query=True, ) ) + def _reconcile_pool( + self, user_id: str, memory_type: str, facts: list[dict[str, Any]] + ) -> tuple[dict[str, int], set[str]]: + """Reconcile an explicit pool of same-type memories in one LLM pass. + + Returns ``({"kept", "merged", "contradicted"}, consumed_ids)`` where + ``consumed_ids`` are the source/loser ids that were actually superseded, + so callers can skip them when clearing dup-candidate tags (re-upserting a + superseded source would resurrect it). Does not emit telemetry — the + caller owns the ``reconcile.outcome`` line. + """ + # Polarity keyed on an explicit predicate so a future third type (e.g. + # procedural, were it ever routed here) can't silently diverge from async. + is_episodic = memory_type == "episodic" + prompt_filename = "dedup_episodic.prompty" if is_episodic else "dedup.prompty" + prompt_input_key = "episodics_text" if is_episodic else "facts_text" + allow_contradictions = not is_episodic + if len(facts) <= 1: logger.info( - "reconcile_memories: %d facts, nothing to reconcile", + "reconcile_memories: %d %s memories, nothing to reconcile", len(facts), + memory_type, ) - early_result = {"kept": len(facts), "merged": 0, "contradicted": 0} - self._emit_reconcile_outcome( - started_at=started_at, - user_id=user_id, - candidates=len(facts), - result=early_result, - ) - return early_result + return {"kept": len(facts), "merged": 0, "contradicted": 0}, set() # ---- 2. Format the facts pool for the prompt ---- # ``json.dumps`` escapes embedded quotes and pipes inside content so @@ -1495,13 +1906,13 @@ def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: # ---- 3. Single LLM call over the entire pool ---- response_text = self._run_prompty( - "dedup.prompty", - inputs={"facts_text": facts_text}, + prompt_filename, + inputs={prompt_input_key: facts_text}, ) parsed = self._parse_llm_json(response_text) duplicate_groups = parsed.get("duplicate_groups", []) or [] - contradicted_pairs = parsed.get("contradicted_pairs", []) or [] + contradicted_pairs = (parsed.get("contradicted_pairs", []) or []) if allow_contradictions else [] # ``kept_ids`` from the LLM is used below as a cross-check for # accounting drift (hallucinated IDs, double-counting). The actual # kept count is computed from facts minus consumed losers. @@ -1573,15 +1984,19 @@ def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: source_docs.sort(key=lambda d: d.get("_ts", 0), reverse=True) # Union tags across all source docs (preserve order, dedupe). + # Strip sys:dup-candidate so the merged doc isn't reborn as a + # permanent reconcile seed. merged_tags: list[str] = [] seen_tags: set[str] = set() for src in source_docs: for t in src.get("tags", []) or []: + if t == "sys:dup-candidate": + continue if t not in seen_tags: seen_tags.add(t) merged_tags.append(t) if not merged_tags: - merged_tags = ["sys:fact"] + merged_tags = [f"sys:{memory_type}"] # Union source_memory_ids across all source docs (provenance chain). merged_source_memory_ids: list[str] = [] @@ -1638,16 +2053,15 @@ def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: # see the same id rather than chaining through a new UUID. merged_content_hash = compute_content_hash(merged_content) merged_id_seed = _ID_SEED_SEP.join((user_id, "merged", merged_content_hash)) - merged_id = "fact_" + hashlib.sha256(merged_id_seed.encode()).hexdigest()[:32] + id_prefix = "fact_" if memory_type == "fact" else "ep_" + merged_id = id_prefix + hashlib.sha256(merged_id_seed.encode()).hexdigest()[:32] try: - merged_record = construct_internal( - FactRecord, - { + merged_payload: dict[str, Any] = { "id": merged_id, "user_id": user_id, "role": "system", - "type": "fact", + "type": memory_type, "content": merged_content, "thread_id": recent_thread_id or f"__reconciled__:{user_id}", "confidence": confidence_val if confidence_val is not None else 0.5, @@ -1656,16 +2070,32 @@ def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: "source_memory_ids": merged_source_memory_ids, "tags": merged_tags, "content_hash": merged_content_hash, - "metadata": { - "category": "preference", - "merged_via": "reconcile", - "merged_from_count": len(valid_source_ids), - }, - **self._prompt_lineage("dedup.prompty"), + **self._prompt_lineage(prompt_filename), "created_at": datetime.now(timezone.utc).isoformat(), "updated_at": datetime.now(timezone.utc).isoformat(), - }, - ) + } + if memory_type == "fact": + merged_payload["metadata"] = { + "category": "preference", + "merged_via": "reconcile", + "merged_from_count": len(valid_source_ids), + } + record_cls = FactRecord + else: + source_meta = dict(source_docs[0].get("metadata") or {}) + source_meta.update( + { + "lesson": merged_content, + "merged_via": "reconcile", + "merged_from_count": len(valid_source_ids), + } + ) + source_meta.setdefault("scope_type", source_meta.get("scope_type") or "general") + source_meta.setdefault("scope_value", source_meta.get("scope_value") or "general") + source_meta.setdefault("outcome_valence", source_meta.get("outcome_valence") or "neutral") + merged_payload["metadata"] = source_meta + record_cls = EpisodicRecord + merged_record = construct_internal(record_cls, merged_payload) except Exception: logger.exception( "reconcile_memories: failed to build merged record for group %r", @@ -1833,13 +2263,7 @@ def reconcile_memories(self, user_id: str, n: int = 50) -> dict[str, int]: ) result = {"kept": kept, "merged": merged, "contradicted": contradicted} logger.info("reconcile_memories completed: %s", result) - self._emit_reconcile_outcome( - started_at=started_at, - user_id=user_id, - candidates=len(facts), - result=result, - ) - return result + return result, consumed_ids def build_procedural_context(self, user_id: str) -> str: """Return the active synthesized procedural prompt for system injection.""" diff --git a/azure/cosmos/agent_memory/store/_search_helpers.py b/azure/cosmos/agent_memory/store/_search_helpers.py index b0f6c5e..032965c 100644 --- a/azure/cosmos/agent_memory/store/_search_helpers.py +++ b/azure/cosmos/agent_memory/store/_search_helpers.py @@ -94,16 +94,27 @@ def build_search_sql( *, qb: _QueryBuilder, top: int, - hybrid_search: bool, + keyword_count: int, include_superseded: bool, ) -> str: + """Build the search SQL. + + When ``keyword_count > 0`` the query is hybrid: ``RANK RRF`` fuses the vector + distance with ``FullTextScore`` over ``keyword_count`` individual keyword + parameters (``@kw0``..``@kw{n-1}``). When there are no keywords (e.g. an + all-stopword query) it falls back to pure vector ranking. ``similarity_score`` + is always the vector distance and is *not* the RRF ranking basis under hybrid. + """ if not include_superseded: qb.add_is_null_or_undefined("c.superseded_by") - order_by = "ORDER BY VectorDistance(c.embedding, @embedding)" - if hybrid_search: - order_by = "ORDER BY RANK RRF(VectorDistance(c.embedding, @embedding), FullTextScore(c.content, @key_terms))" + vector_distance = "VectorDistance(c.embedding, @embedding)" + if keyword_count > 0: + keyword_params = ", ".join(f"@kw{i}" for i in range(keyword_count)) + order_by = f"ORDER BY RANK RRF({vector_distance}, FullTextScore(c.content, {keyword_params}))" + else: + order_by = f"ORDER BY {vector_distance}" return ( f"SELECT TOP {top} {MEMORY_PROJECTION}, " - "VectorDistance(c.embedding, @embedding) AS similarity_score " + f"{vector_distance} AS similarity_score " f"FROM c{qb.build_where()} {order_by}" ) diff --git a/azure/cosmos/agent_memory/store/memory_store.py b/azure/cosmos/agent_memory/store/memory_store.py index 5755311..dd5d345 100644 --- a/azure/cosmos/agent_memory/store/memory_store.py +++ b/azure/cosmos/agent_memory/store/memory_store.py @@ -15,8 +15,8 @@ from azure.cosmos.agent_memory._utils import ( _build_memory_query_builder, _coerce_datetime_iso, - _validate_hybrid_search, compute_content_hash, + extract_keywords, new_id, ) from azure.cosmos.agent_memory.exceptions import ( @@ -840,7 +840,6 @@ def search( role: Optional[str] = None, memory_types: Optional[list[str]] = None, thread_id: Optional[str] = None, - hybrid_search: bool = False, top_k: int = 5, tags_all: Optional[list[str]] = None, tags_any: Optional[list[str]] = None, @@ -859,9 +858,9 @@ def search( :meth:`search_turns` to vector-search the raw conversation log instead. """ terms = require_search_terms(search_terms, query) - _validate_hybrid_search(hybrid_search, terms) top = top_literal(top_k, name="top_k") query_vector = self._embed(terms) + keywords = extract_keywords(terms) qb = _build_memory_query_builder( memory_id=memory_id, @@ -881,11 +880,16 @@ def search( ) add_salience_filter(qb, min_salience) - sql = build_search_sql(qb=qb, top=top, hybrid_search=hybrid_search, include_superseded=include_superseded) + sql = build_search_sql( + qb=qb, + top=top, + keyword_count=len(keywords), + include_superseded=include_superseded, + ) parameters = qb.get_parameters() parameters.append({"name": "@embedding", "value": query_vector}) - if hybrid_search: - parameters.append({"name": "@key_terms", "value": terms}) + for i, kw in enumerate(keywords): + parameters.append({"name": f"@kw{i}", "value": kw}) partition_key, cross_partition = query_scope(user_id, thread_id) if thread_id is not None and (not memory_types or set(memory_types) & USER_SCOPED_MEMORIES_TYPES): @@ -905,7 +909,6 @@ def search_turns( user_id: Optional[str] = None, thread_id: Optional[str] = None, role: Optional[str] = None, - hybrid_search: bool = False, top_k: int = 5, tags_all: Optional[list[str]] = None, tags_any: Optional[list[str]] = None, @@ -915,7 +918,7 @@ def search_turns( *, query: Optional[str] = None, ) -> list[dict[str, Any]]: - """Search raw conversation turns using vector similarity with optional hybrid ranking. + """Search raw conversation turns using vector similarity with hybrid ranking. Only vector-searchable when turn embeddings were enabled at write time (see ``enable_turn_embeddings``). ``user_id`` is required and always @@ -927,9 +930,9 @@ def search_turns( if not user_id: raise ValidationError("user_id is required for search_turns") terms = require_search_terms(search_terms, query) - _validate_hybrid_search(hybrid_search, terms) top = top_literal(top_k, name="top_k") query_vector = self._embed(terms) + keywords = extract_keywords(terms) qb = _QueryBuilder() qb.add_filter("c.user_id", "@user_id", user_id) @@ -944,11 +947,11 @@ def search_turns( before_param="@created_before", ) - sql = build_search_sql(qb=qb, top=top, hybrid_search=hybrid_search, include_superseded=False) + sql = build_search_sql(qb=qb, top=top, keyword_count=len(keywords), include_superseded=False) parameters = qb.get_parameters() parameters.append({"name": "@embedding", "value": query_vector}) - if hybrid_search: - parameters.append({"name": "@key_terms", "value": terms}) + for i, kw in enumerate(keywords): + parameters.append({"name": f"@kw{i}", "value": kw}) partition_key, cross_partition = query_scope(user_id, thread_id) logger.debug("MemoryStore.search_turns query: %s", sql) diff --git a/azure/cosmos/agent_memory/thresholds.py b/azure/cosmos/agent_memory/thresholds.py index 7fffc60..ca4c508 100644 --- a/azure/cosmos/agent_memory/thresholds.py +++ b/azure/cosmos/agent_memory/thresholds.py @@ -28,6 +28,23 @@ # of 500 (enforced by the pipeline) bounds prompt size and LLM cost. DEFAULT_DEDUP_POOL_SIZE = 50 +# --------------------------------------------------------------------------- +# INTERNAL dedup/search tuning — NOT customer-configurable. +# These ship as fixed feature constants (no env vars, not in any settings +# template). They are maintainer-tunable here in code only; if a knob ever +# needs to become operator-facing we add the env plumbing back deliberately. +# The dedup + hybrid-search features ship ON via these values. +# --------------------------------------------------------------------------- +DEDUP_CONTEXT_VECTOR_ENABLED = True # Stage-1 relevance-ranked extraction context +DEDUP_CONTEXT_TOPK = 10 +DEDUP_VECTOR_ENABLED = True # Stage-3 vector near-dup ladder +DEDUP_SIM_HIGH = 0.97 # >= -> auto-skip near-exact +DEDUP_SIM_LOW = 0.80 # < -> novel; between -> tag candidate +DEDUP_CANDIDATE_TOPK = 10 +DEDUP_RECONCILE_MODE = "candidate" # clustered candidate reconcile (vs legacy full_pool) +DEDUP_CLUSTER_SIM = 0.60 # Stage-5 clustering edge threshold +DEDUP_FULL_RECLUSTER_EVERY_N = 12 # full re-cluster safety net cadence + DEFAULT_TTL_BY_TYPE: dict[str, int] = { "turn": 2_592_000, "episodic": 7_776_000, @@ -138,6 +155,56 @@ def get_dedup_pool_size() -> int: return raw +# --------------------------------------------------------------------------- +# Internal dedup/search feature accessors — return fixed constants (no env). +# Kept as thin functions so call sites stay stable and the values can be +# changed in one place; NOT customer-configurable. +# --------------------------------------------------------------------------- +def get_dedup_context_vector_enabled() -> bool: + """Whether Stage-1 extraction context uses vector retrieval (internal; on).""" + return DEDUP_CONTEXT_VECTOR_ENABLED + + +def get_dedup_context_topk() -> int: + """Top-K memories to retrieve for Stage-1 extraction context (internal).""" + return DEDUP_CONTEXT_TOPK + + +def get_dedup_vector_enabled() -> bool: + """Whether Stage-3 vector deduplication is enabled (internal; on).""" + return DEDUP_VECTOR_ENABLED + + +def get_dedup_sim_high() -> float: + """Similarity at/above which near-exact memories are auto-skipped (internal).""" + return DEDUP_SIM_HIGH + + +def get_dedup_sim_low() -> float: + """Similarity below which memories are treated as novel (internal).""" + return DEDUP_SIM_LOW + + +def get_dedup_candidate_topk() -> int: + """Top-K existing memories pulled per new memory in Stage-3 (internal).""" + return DEDUP_CANDIDATE_TOPK + + +def get_dedup_reconcile_mode() -> str: + """Reconcile mode: ``candidate`` clustering (internal; the shipped feature).""" + return DEDUP_RECONCILE_MODE + + +def get_dedup_cluster_sim() -> float: + """Similarity threshold for Stage-5 clustering edges (internal).""" + return DEDUP_CLUSTER_SIM + + +def get_dedup_full_recluster_every_n() -> int: + """Full re-cluster safety-net cadence, every Nth reconcile sweep (internal).""" + return DEDUP_FULL_RECLUSTER_EVERY_N + + def get_procedural_synthesis_auto() -> bool: """Whether procedural synthesis auto-fires after extract. @@ -214,6 +281,15 @@ def get_processor_owner() -> Optional[str]: "get_user_summary_every_n", "get_dedup_every_n", "get_dedup_pool_size", + "get_dedup_context_vector_enabled", + "get_dedup_context_topk", + "get_dedup_vector_enabled", + "get_dedup_sim_high", + "get_dedup_sim_low", + "get_dedup_candidate_topk", + "get_dedup_reconcile_mode", + "get_dedup_cluster_sim", + "get_dedup_full_recluster_every_n", "get_procedural_synthesis_auto", "get_enable_turn_embeddings", "get_processor_owner", diff --git a/function_app/local.settings.json.template b/function_app/local.settings.json.template index e2f2fa6..2d65bde 100644 --- a/function_app/local.settings.json.template +++ b/function_app/local.settings.json.template @@ -22,6 +22,7 @@ "// --- Threshold knobs (set to 0 to disable that orchestrator) ---": "", "THREAD_SUMMARY_EVERY_N": "10", "FACT_EXTRACTION_EVERY_N": "1", + "DEDUP_EVERY_N": "5", "USER_SUMMARY_EVERY_N": "20", "// --- Processor ownership: 'durable' (FA owns) or 'inprocess' (SDK owns). FA skips processing if not set to 'durable' to avoid double-fire when an SDK install runs against the same Cosmos. ---": "", diff --git a/function_app/orchestrators/extract_memories.py b/function_app/orchestrators/extract_memories.py index a911f20..ec859bc 100644 --- a/function_app/orchestrators/extract_memories.py +++ b/function_app/orchestrators/extract_memories.py @@ -1,7 +1,8 @@ """Memory-extraction orchestrator + activities. -Chain: ``Extract`` → ``Persist`` followed by an optional ``ReconcileMemories`` -activity, then a best-effort ``SynthesizeProceduralOrchestrator`` sub-call. +Chain: ``Extract`` → ``Dedup`` → ``Persist`` followed by an optional +``ReconcileMemories`` activity, then a best-effort +``SynthesizeProceduralOrchestrator`` sub-call. Reconciliation is gated by the change-feed trigger (which tracks the per-user/thread turn counter) and signaled to the orchestrator via the ``reconcile`` flag on its input payload. Procedural synthesis fires only @@ -33,26 +34,44 @@ def ExtractMemoriesOrchestrator(context: df.DurableOrchestrationContext): user_id = payload["user_id"] thread_id = payload["thread_id"] should_reconcile = bool(payload.get("reconcile", False)) + full_rebuild = bool(payload.get("full_rebuild", False)) + recent_k = payload.get("recent_k") retry = default_retry_options() + extract_payload = {"user_id": user_id, "thread_id": thread_id} + if recent_k is not None: + extract_payload["recent_k"] = recent_k extracted = yield context.call_activity_with_retry( "em_Extract", retry, - {"user_id": user_id, "thread_id": thread_id, "limit": config.get_max_batch_size()}, + extract_payload, + ) + deduped = yield context.call_activity_with_retry( + "em_Dedup", + retry, + {"user_id": user_id, "extracted": extracted}, ) persisted = yield context.call_activity_with_retry( "em_Persist", retry, - {"user_id": user_id, "extracted": extracted}, + {"user_id": user_id, "extracted": deduped}, ) + count = payload.get("count") + if count is not None: + yield context.call_activity_with_retry( + "em_AdvanceExtractWatermark", + retry, + {"user_id": user_id, "thread_id": thread_id, "count": count}, + ) + reconciled = None procedural = None if should_reconcile: reconciled = yield context.call_activity_with_retry( "em_ReconcileMemories", retry, - {"user_id": user_id}, + {"user_id": user_id, "full_rebuild": full_rebuild}, ) if config.get_procedural_synthesis_auto(): count = payload.get("count") @@ -86,11 +105,13 @@ def em_Extract(payload: dict) -> dict: """Load recent turns and run LLM extraction without embeddings or writes.""" user_id = payload["user_id"] thread_id = payload["thread_id"] - limit = payload.get("limit") + recent_k = payload.get("recent_k") + if recent_k is None: + recent_k = config.get_max_batch_size() extracted = get_pipeline().extract_memories_dry( user_id=user_id, thread_id=thread_id, - recent_k=limit, + recent_k=recent_k, ) logger.info( "ExtractMemories extracted user=%s thread=%s facts=%d episodic=%d updates=%d", @@ -103,6 +124,15 @@ def em_Extract(payload: dict) -> dict: return extracted +@bp.activity_trigger(input_name="payload") +def em_Dedup(payload: dict) -> dict: + """vector-floor dedup ladder (gated; passthrough when disabled).""" + return get_pipeline().dedup_extracted_memories( + user_id=payload["user_id"], + extracted=payload["extracted"], + ) or payload["extracted"] + + @bp.activity_trigger(input_name="payload") def em_Persist(payload: dict) -> dict: """Persist extracted docs with embeddings and deterministic create semantics.""" @@ -115,12 +145,40 @@ def em_Persist(payload: dict) -> dict: return counts or {} +@bp.activity_trigger(input_name="payload") +async def em_AdvanceExtractWatermark(payload: dict) -> bool: + """Advance the extraction watermark after a successful extract→persist. + + Stamps ``last_extract_count`` on the thread counter so the next batch's + recent_k spans only turns added since this run, never skipping any. + Runs only on success (after persist) so failed extracts re-process. + """ + from shared.cosmos_clients import get_counter_container_async + from shared.counters import advance_extract_watermark, thread_counter_id + + user_id = payload["user_id"] + thread_id = payload["thread_id"] + count = payload["count"] + container = await get_counter_container_async() + await advance_extract_watermark(container, thread_counter_id(user_id, thread_id), user_id, thread_id, count) + return True + + @bp.activity_trigger(input_name="payload") def em_ReconcileMemories(payload: dict) -> dict: # GA keeps reconcile single-activity: its LLM dedup decisions and supersession - # operations are larger/more coupled than the extract→persist split handled here. + # operations are larger/more coupled than the extract→dedup→persist split handled here. user_id = payload["user_id"] + # full_rebuild forces the full-pool single-LLM-pass path (catches dissimilar + # contradictions). The change-feed sets it on a persisted-counter cadence so it + # fires reliably on FA, where the in-memory candidate-mode sweep counter can't. + full_rebuild = bool(payload.get("full_rebuild", False)) pipeline = get_pipeline() from azure.cosmos.agent_memory.thresholds import get_dedup_pool_size - return pipeline.reconcile_memories(user_id=user_id, n=get_dedup_pool_size()) or {} + n = get_dedup_pool_size() + facts = pipeline.reconcile_memories(user_id=user_id, n=n, memory_type="fact", full_rebuild=full_rebuild) or {} + episodic = ( + pipeline.reconcile_memories(user_id=user_id, n=n, memory_type="episodic", full_rebuild=full_rebuild) or {} + ) + return {"fact": facts, "episodic": episodic} diff --git a/function_app/shared/config.py b/function_app/shared/config.py index f15ed35..da4f4fa 100644 --- a/function_app/shared/config.py +++ b/function_app/shared/config.py @@ -116,17 +116,6 @@ def _parse_int(name: str, default: int) -> int: return default -def _parse_float(name: str, default: float) -> float: - raw = os.environ.get(name) - if raw is None or raw == "": - return default - try: - return float(raw) - except (ValueError, TypeError): - logger.warning("Invalid value for %s=%r, using default %f", name, raw, default) - return default - - def _parse_bool(name: str, default: bool) -> bool: raw = os.environ.get(name) if raw is None or raw == "": diff --git a/function_app/shared/counters.py b/function_app/shared/counters.py index 1ab09cc..1adb0e5 100644 --- a/function_app/shared/counters.py +++ b/function_app/shared/counters.py @@ -177,6 +177,8 @@ async def increment_counter_by( new_doc["last_failure_at"] = existing_doc.get("last_failure_at") if "last_failure_reason" in existing_doc: new_doc["last_failure_reason"] = existing_doc.get("last_failure_reason") + if "last_extract_count" in existing_doc: + new_doc["last_extract_count"] = existing_doc.get("last_extract_count") # Stamp the writing backend (advisory only — not enforced server-side). if owner is not None: new_doc["last_owner"] = owner @@ -259,3 +261,40 @@ def crosses_threshold(old_count: int, new_count: int, n: int) -> bool: if n <= 0: raise ValueError("n must be > 0") return old_count // n != new_count // n + + +async def read_extract_watermark( + container: Any, + counter_id: str, + user_id: str, + thread_id: str, +) -> int | None: + """Return the count value at the last successful extract, or ``None``. + + Lets recent_k cover every turn since the previous extract succeeded so + turns are never skipped when extraction lags or transiently fails. + Best-effort: returns ``None`` on any read error so callers fall back to a + batch-based recent_k. + """ + try: + doc = await container.read_item(item=counter_id, partition_key=[user_id, thread_id]) + value = doc.get("last_extract_count") + return int(value) if value is not None else None + except Exception as exc: # pragma: no cover - best-effort + logger.debug("read_extract_watermark failed counter_id=%s: %s", counter_id, exc) + return None + + +async def advance_extract_watermark( + container: Any, + counter_id: str, + user_id: str, + thread_id: str, + count: int, +) -> None: + """Stamp ``last_extract_count=count`` after a successful extract.""" + patch_ops = [{"op": "add", "path": "/last_extract_count", "value": int(count)}] + try: + await container.patch_item(item=counter_id, partition_key=[user_id, thread_id], patch_operations=patch_ops) + except Exception as exc: # pragma: no cover - best-effort + logger.debug("advance_extract_watermark failed counter_id=%s: %s", counter_id, exc) diff --git a/function_app/triggers/change_feed.py b/function_app/triggers/change_feed.py index c7b2c8b..4630a46 100644 --- a/function_app/triggers/change_feed.py +++ b/function_app/triggers/change_feed.py @@ -26,6 +26,7 @@ from shared.counters import ( crosses_threshold, increment_counter_by, + read_extract_watermark, thread_counter_id, user_counter_id, ) @@ -128,6 +129,16 @@ async def process_changefeed_batch( # Reconcile fires every (n_facts * n_dedup) turns, matching the SDK # auto-trigger contract. Disabled when either knob is 0. n_dedup_turns = n_facts * n_dedup if (n_facts > 0 and n_dedup > 0) else 0 + # Full-pool reconcile backstop cadence. The candidate-mode "every Nth sweep" + # gate is an in-memory per-worker counter, which on the FA backend (per-worker + # singleton pipeline) resets on recycle/scale and never reaches N for a user — + # so the full-pool pass that catches dissimilar-embedding contradictions almost + # never fires. Derive it from the PERSISTED counter instead: every + # DEDUP_FULL_RECLUSTER_EVERY_N-th reconcile = every n_dedup_turns * N turns. + from azure.cosmos.agent_memory.thresholds import get_dedup_full_recluster_every_n + + n_full_recluster = get_dedup_full_recluster_every_n() + n_full_turns = n_dedup_turns * n_full_recluster if (n_dedup_turns > 0 and n_full_recluster > 0) else 0 if n_thread == 0 and n_facts == 0 and n_user == 0: return # all orchestrators disabled @@ -206,6 +217,22 @@ async def process_changefeed_batch( if n_facts > 0 and crosses_threshold(old_count, new_count, n_facts): instance_id = f"extract:{user_id}:{thread_id}:{new_count}" should_reconcile = bool(n_dedup_turns > 0 and crosses_threshold(old_count, new_count, n_dedup_turns)) + # Persisted-counter backstop: every n_full_turns turns, force a + # full-pool reconcile so dissimilar-embedding contradictions are + # caught reliably on FA (not gated by the in-memory sweep counter). + should_full_reconcile = bool( + n_full_turns > 0 and crosses_threshold(old_count, new_count, n_full_turns) + ) + watermark = await read_extract_watermark(counter_container, cid, user_id, thread_id) + # Not capped: new_count - watermark is exactly the unextracted backlog + # and the orchestrator advances the watermark to new_count, so capping + # would strand the oldest turns (DEDUP_POOL_SIZE is the reconcile knob, + # not the extraction window). Bootstrap: with no watermark yet, base=0 + # so recent_k = new_count covers every turn so far — using only this + # batch (new_count - old_count) would strand turns added during earlier + # failed extracts once the watermark first advances to new_count. + base = watermark if watermark is not None else 0 + recent_k = max(new_count - base, 1) await _safe_start( starter, "ExtractMemoriesOrchestrator", @@ -215,6 +242,8 @@ async def process_changefeed_batch( "thread_id": thread_id, "count": new_count, "reconcile": should_reconcile, + "full_rebuild": should_full_reconcile, + "recent_k": recent_k, }, orchestration_errors, ) diff --git a/tests/integration/test_async_full_pipeline.py b/tests/integration/test_async_full_pipeline.py new file mode 100644 index 0000000..88f66c4 --- /dev/null +++ b/tests/integration/test_async_full_pipeline.py @@ -0,0 +1,274 @@ +"""Async live integration smoke for ``AsyncCosmosMemoryClient``. + +The async pipeline mirrors every sync dedup/reconcile code path line-for-line +(watermark, euclidean guard, ``exclude_ids`` parity, reconcile cadence, the +vector-floor dedup ladder). The sync suite covers all of that against a live +backend, but the async client had **no** live coverage — its processor test is +fully mocked. This module exercises the real async end-to-end flow (write turns +→ extract → reconcile → search) so the async mirror can't silently diverge from +sync without a test failing. + +The Azure Function host is **not** required: the same ``AsyncPipelineService`` +the change-feed trigger drives is exposed directly on the client. + +Enable by setting:: + + AGENT_MEMORY_RUN_INTEGRATION=true + +Auth: ``COSMOS_DB_KEY`` is used when present; otherwise ``DefaultAzureCredential``. +""" + +from __future__ import annotations + +import asyncio +import time +import uuid + +import pytest + +from azure.cosmos.agent_memory.aio.cosmos_memory_client import AsyncCosmosMemoryClient +from tests.conftest import INTEGRATION_ENABLED + +pytestmark = [ + pytest.mark.integration, + pytest.mark.skipif( + not INTEGRATION_ENABLED, + reason="Set AGENT_MEMORY_RUN_INTEGRATION=true", + ), +] + + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture +async def async_agent_memory( + cosmos_endpoint, + cosmos_key, + cosmos_database, + cosmos_container, + ai_foundry_endpoint, + ai_foundry_api_key, + embedding_deployment_name, + embedding_dimensions, + chat_deployment_name, +): + """A live AsyncCosmosMemoryClient with its containers provisioned/connected.""" + if not cosmos_endpoint or not ai_foundry_endpoint: + pytest.skip("COSMOS_DB_ENDPOINT / AI_FOUNDRY_ENDPOINT not set") + + client = AsyncCosmosMemoryClient( + cosmos_endpoint=cosmos_endpoint, + cosmos_key=cosmos_key or None, + cosmos_database=cosmos_database, + cosmos_container=cosmos_container, + ai_foundry_endpoint=ai_foundry_endpoint, + ai_foundry_api_key=ai_foundry_api_key or None, + embedding_deployment_name=embedding_deployment_name, + embedding_dimensions=embedding_dimensions, + chat_deployment_name=chat_deployment_name, + ) + # The async client cannot auto-connect in __init__; do it explicitly. + await client.create_memory_store() + try: + yield client + finally: + await client.close() + + +async def _async_add_turns( + mem: AsyncCosmosMemoryClient, + user_id: str, + thread_id: str, + turns: list[tuple[str, str]], +) -> None: + for role, content in turns: + await mem.add_cosmos( + user_id=user_id, + role=role, + content=content, + memory_type="turn", + thread_id=thread_id, + ) + + +async def _async_cleanup(mem: AsyncCosmosMemoryClient, user_id: str) -> None: + """Best-effort delete of every document for *user_id* across all containers.""" + sql = "SELECT c.id, c.thread_id FROM c WHERE c.user_id = @uid" + params = [{"name": "@uid", "value": user_id}] + for container in ( + mem._turns_container_client, + mem._memories_container_client, + mem._summaries_container_client, + ): + if container is None: + continue + try: + docs = [doc async for doc in container.query_items(query=sql, parameters=params)] + except Exception: + continue + for doc in docs: + try: + await container.delete_item( + item=doc["id"], + partition_key=[user_id, doc.get("thread_id", "")], + ) + except Exception: + pass + + +async def _async_seed_fact_with_embedding( + mem: AsyncCosmosMemoryClient, + user_id: str, + thread_id: str, + content: str, + *, + retries: int = 4, +) -> None: + """Seed a fact and confirm it stored *with* an embedding (async mirror of the + sync helper). Retries through transient embedding-service blips so the + extract-time vector floor always has a neighbour; skips honestly if the + embedding service is genuinely unavailable.""" + check = ( + "SELECT c.id FROM c WHERE c.user_id = @uid AND c.content = @content " + "AND IS_DEFINED(c.embedding)" + ) + params = [{"name": "@uid", "value": user_id}, {"name": "@content", "value": content}] + for _ in range(retries): + await mem.add_cosmos( + user_id=user_id, + role="user", + content=content, + memory_type="fact", + thread_id=thread_id, + salience=0.7, + ) + embedded = [ + doc async for doc in mem._memories_container_client.query_items(query=check, parameters=params) + ] + if embedded: + return + await asyncio.sleep(1) + pytest.skip(f"embedding service unavailable — could not seed an embedded fact for {content!r}") + + +async def _async_wait_vector_searchable( + mem: AsyncCosmosMemoryClient, + user_id: str, + search_terms: str, + *, + timeout: float = 20.0, +) -> None: + """Poll vector search until the user's seeded fact is retrievable (DiskANN + index caught up), so the subsequent ``_vector_candidates`` lookup is + deterministic rather than racing the async index.""" + deadline = time.time() + timeout + while time.time() < deadline: + try: + if await mem.search_cosmos(search_terms=search_terms, user_id=user_id, top_k=5): + return + except Exception: + pass + await asyncio.sleep(1) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestAsyncEndToEnd: + async def test_extract_reconcile_search( + self, + async_agent_memory, + unique_user_id, + unique_thread_id, + ): + """Full async round-trip: turns → extract → reconcile → search.""" + mem = async_agent_memory + try: + await _async_add_turns( + mem, + unique_user_id, + unique_thread_id, + [ + ("user", "I just adopted a golden retriever puppy named Cosmo."), + ("agent", "Congrats on Cosmo! How old is the puppy?"), + ("user", "Cosmo is 10 weeks old and already loves swimming in the lake."), + ], + ) + await asyncio.sleep(1) + + extract_stats = await mem.extract_memories( + user_id=unique_user_id, + thread_id=unique_thread_id, + ) + assert isinstance(extract_stats, dict) + + facts = await mem.get_memories(user_id=unique_user_id, memory_types=["fact"]) + assert len(facts) >= 1, "Async extraction should produce at least one fact" + + reconcile_stats = await mem.reconcile(user_id=unique_user_id) + assert isinstance(reconcile_stats, dict) + assert {"kept", "merged", "contradicted"} <= set(reconcile_stats), ( + f"reconcile should return kept/merged/contradicted, got {reconcile_stats}" + ) + + results = await mem.search_cosmos( + search_terms="golden retriever puppy", + user_id=unique_user_id, + top_k=5, + ) + assert len(results) >= 1, "Async search should return at least one result" + finally: + await _async_cleanup(mem, unique_user_id) + + async def test_dedup_extracted_memories_flags_near_duplicate_of_stored_fact( + self, + async_agent_memory, + unique_user_id, + unique_thread_id, + ): + """Async extract-time vector floor drops/tags a near-duplicate fact. + + Parity check with the sync ``TestExtractTimeVectorDedup`` — guards the + async ``dedup_extracted_memories`` mirror (``_vector_candidates`` + + similarity bands) against a live backend. Driven with a controlled + near-duplicate (no LLM variance) so the assertion is deterministic. + """ + mem = async_agent_memory + try: + await _async_seed_fact_with_embedding( + mem, unique_user_id, unique_thread_id, "The user has a cat named Whiskers." + ) + await _async_wait_vector_searchable(mem, unique_user_id, "cat named Whiskers") + + extracted = { + "facts": [ + { + "id": f"fact_{uuid.uuid4().hex}", + "type": "fact", + "user_id": unique_user_id, + "thread_id": unique_thread_id, + "content": "The user's cat is called Whiskers.", + "tags": [], + } + ], + "episodic": [], + "updates": [], + } + result = await mem._get_pipeline().dedup_extracted_memories(unique_user_id, extracted) + + stats = next((op for op in result.get("updates", []) if op.get("op") == "stats"), {}) + suppressed = int(stats.get("vector_dedup_skipped", 0)) + int(stats.get("dup_candidates_tagged", 0)) + surviving = result.get("facts", []) + was_dropped = len(surviving) == 0 + was_tagged = any("sys:dup-candidate" in (f.get("tags") or []) for f in surviving) + assert suppressed >= 1 and (was_dropped or was_tagged), ( + "Async vector floor should drop or tag the near-duplicate of the stored " + f"'cat named Whiskers' fact; surviving={surviving} stats={stats}" + ) + finally: + await _async_cleanup(mem, unique_user_id) diff --git a/tests/integration/test_full_pipeline.py b/tests/integration/test_full_pipeline.py index 6e80bd9..dcb217e 100644 --- a/tests/integration/test_full_pipeline.py +++ b/tests/integration/test_full_pipeline.py @@ -122,6 +122,70 @@ def _delete(container, doc: dict) -> None: _delete(container, doc) +def _seed_fact_with_embedding( + mem: CosmosMemoryClient, + user_id: str, + thread_id: str, + content: str, + *, + retries: int = 4, +) -> None: + """Seed a fact and confirm it was stored *with* an embedding. + + ``add_cosmos`` generates the embedding synchronously; a transient + embedding-service blip logs "proceeding without embedding" and stores the doc + without a vector, which would leave the extract-time vector floor with no + neighbour to match. Retry until an embedded copy exists (indexing is fast — + the doc is vector-searchable within ~2s), and skip honestly if the embedding + service is genuinely unavailable rather than reporting a false failure.""" + check = ( + "SELECT c.id FROM c WHERE c.user_id = @uid AND c.content = @content " + "AND IS_DEFINED(c.embedding)" + ) + params = [{"name": "@uid", "value": user_id}, {"name": "@content", "value": content}] + for _ in range(retries): + mem.add_cosmos( + user_id=user_id, + role="user", + content=content, + memory_type="fact", + thread_id=thread_id, + salience=0.7, + ) + embedded = list( + mem._memories_container_client.query_items( + query=check, parameters=params, enable_cross_partition_query=True + ) + ) + if embedded: + return + time.sleep(1) + pytest.skip(f"embedding service unavailable — could not seed an embedded fact for {content!r}") + + +def _wait_vector_searchable( + mem: CosmosMemoryClient, + user_id: str, + search_terms: str, + *, + timeout: float = 20.0, +) -> None: + """Poll vector search until the user's seeded fact is retrievable. + + ``add_cosmos`` stores the embedding synchronously, but Cosmos's DiskANN vector + index catches up asynchronously (~1-2s). Gating on a real vector search makes + the subsequent ``_vector_candidates`` lookup deterministic instead of racing + the index.""" + deadline = time.time() + timeout + while time.time() < deadline: + try: + if mem.search_cosmos(search_terms=search_terms, user_id=user_id, top_k=5): + return + except Exception: + pass + time.sleep(1) + + # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -243,7 +307,7 @@ def test_multi_thread_user_summary(self, agent_memory, unique_user_id): class TestSearchAfterExtraction: - def test_vector_and_hybrid_search(self, agent_memory, unique_user_id, unique_thread_id): + def test_search_after_extraction(self, agent_memory, unique_user_id, unique_thread_id): try: _add_turns( agent_memory, @@ -268,13 +332,12 @@ def test_vector_and_hybrid_search(self, agent_memory, unique_user_id, unique_thr ) assert len(vec) >= 1, "Vector search should return at least 1 result" - hyb = agent_memory.search_cosmos( + hybrid = agent_memory.search_cosmos( search_terms="Buddy the dog park", user_id=unique_user_id, - hybrid_search=True, top_k=5, ) - assert len(hyb) >= 1, "Hybrid search should return at least 1 result" + assert len(hybrid) >= 1, "Hybrid search should return at least 1 result" finally: _cleanup(agent_memory, unique_user_id) @@ -504,3 +567,61 @@ def test_reconcile_writes_supersede_metadata(self, agent_memory, unique_user_id, assert any(m["id"] == survivor_id for m in live), "supersede_by must point at a live record" finally: _cleanup(agent_memory, unique_user_id) + + +class TestExtractTimeVectorDedup: + """Extract-time vector floor (``dedup_extracted_memories``), distinct from the + ``reconcile`` path. A freshly-extracted fact that near-duplicates an + *already-stored* fact is either auto-dropped (``vector_dedup_skipped``, + sim >= DEDUP_SIM_HIGH) or tagged ``sys:dup-candidate`` + (``dup_candidates_tagged``, DEDUP_SIM_LOW <= sim < DEDUP_SIM_HIGH). + + The ladder is driven directly with a controlled extracted fact rather than + through the LLM: extraction phrasing varies run-to-run and often lands the + fact below the 0.80 floor (or produces unrelated facts), which is a property + of the model, not the dedup code. Feeding a fixed near-duplicate keeps the + assertion deterministic while still exercising the real embedding call, the + live Cosmos ``VectorDistance`` query, and the similarity bands.""" + + def test_dedup_extracted_memories_flags_near_duplicate_of_stored_fact( + self, agent_memory, unique_user_id, unique_thread_id + ): + try: + # Seed a stored fact (embedded + vector-indexed) to dedup against. + # Concrete, minimally-reworded facts embed ~0.93-0.98 cosine — well + # inside the DEDUP_SIM_LOW (0.80) / DEDUP_SIM_HIGH (0.97) bands. + _seed_fact_with_embedding( + agent_memory, unique_user_id, unique_thread_id, "The user has a cat named Whiskers." + ) + _wait_vector_searchable(agent_memory, unique_user_id, "cat named Whiskers") + + # A controlled "extracted" near-duplicate (not byte-identical to the + # seed, so this is the vector floor rather than an exact-hash match). + extracted = { + "facts": [ + { + "id": f"fact_{uuid.uuid4().hex}", + "type": "fact", + "user_id": unique_user_id, + "thread_id": unique_thread_id, + "content": "The user's cat is called Whiskers.", + "tags": [], + } + ], + "episodic": [], + "updates": [], + } + result = agent_memory._get_pipeline().dedup_extracted_memories(unique_user_id, extracted) + + stats = next((op for op in result.get("updates", []) if op.get("op") == "stats"), {}) + suppressed = int(stats.get("vector_dedup_skipped", 0)) + int(stats.get("dup_candidates_tagged", 0)) + surviving = result.get("facts", []) + was_dropped = len(surviving) == 0 + was_tagged = any("sys:dup-candidate" in (f.get("tags") or []) for f in surviving) + assert suppressed >= 1 and (was_dropped or was_tagged), ( + "Vector floor should drop or tag the near-duplicate of the stored " + f"'cat named Whiskers' fact; surviving={surviving} stats={stats}" + ) + finally: + _cleanup(agent_memory, unique_user_id) + diff --git a/tests/integration/test_processor_integration.py b/tests/integration/test_processor_integration.py index 1995626..b5a0be2 100644 --- a/tests/integration/test_processor_integration.py +++ b/tests/integration/test_processor_integration.py @@ -9,6 +9,7 @@ from __future__ import annotations from unittest.mock import MagicMock +from unittest.mock import call as mock_call from azure.cosmos.agent_memory.cosmos_memory_client import CosmosMemoryClient from azure.cosmos.agent_memory.processors import ( @@ -70,7 +71,13 @@ def test_process_now_invokes_pipeline_with_correct_args(self): client.get_thread.assert_called_once_with(thread_id="thread-paris", user_id="u-paris") pipeline.generate_thread_summary.assert_called_once_with("u-paris", "thread-paris") pipeline.extract_memories.assert_called_once_with("u-paris", "thread-paris") - pipeline.reconcile_memories.assert_called_once_with("u-paris", 50) + assert pipeline.reconcile_memories.call_count == 2 + pipeline.reconcile_memories.assert_has_calls( + [ + mock_call("u-paris", n=50, memory_type="fact", full_rebuild=False), + mock_call("u-paris", n=50, memory_type="episodic", full_rebuild=False), + ] + ) # --------------------------------------------------------------------------- diff --git a/tests/integration/test_processor_integration_async.py b/tests/integration/test_processor_integration_async.py index f437715..5769c3e 100644 --- a/tests/integration/test_processor_integration_async.py +++ b/tests/integration/test_processor_integration_async.py @@ -3,6 +3,7 @@ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock +from unittest.mock import call as mock_call import pytest @@ -29,9 +30,9 @@ class TestAsyncInProcessProcessNowEndToEnd: @pytest.mark.asyncio async def test_process_now_invokes_pipeline_with_correct_args(self): pipeline = MagicMock() - # ``AsyncInProcessProcessor`` awaits these three pipeline methods, - # so the mocks must be ``AsyncMock`` — a plain ``MagicMock`` returns - # the dict synchronously, which ``await`` cannot consume. + # ``AsyncInProcessProcessor`` awaits these pipeline methods, so the + # mocks must be ``AsyncMock`` — a plain ``MagicMock`` returns the dict + # synchronously, which ``await`` cannot consume. pipeline.generate_thread_summary = AsyncMock( return_value={ "id": "summary-1", @@ -47,6 +48,8 @@ async def test_process_now_invokes_pipeline_with_correct_args(self): } ) pipeline.reconcile_memories = AsyncMock(return_value={"kept": 0, "merged": 0, "contradicted": 0}) + pipeline.synthesize_procedural = AsyncMock(return_value={"id": "proc-1", "type": "procedural"}) + pipeline.generate_user_summary = AsyncMock(return_value={"id": "us-1", "type": "user_summary"}) processor = AsyncInProcessProcessor(pipeline=pipeline) client = _build_client(processor=processor) @@ -73,7 +76,13 @@ async def test_process_now_invokes_pipeline_with_correct_args(self): client.get_thread.assert_awaited_once_with(thread_id="thread-paris", user_id="u-paris") pipeline.generate_thread_summary.assert_awaited_once_with("u-paris", "thread-paris") pipeline.extract_memories.assert_awaited_once_with("u-paris", "thread-paris") - pipeline.reconcile_memories.assert_awaited_once_with("u-paris", 50) + assert pipeline.reconcile_memories.await_count == 2 + pipeline.reconcile_memories.assert_has_awaits( + [ + mock_call("u-paris", n=50, memory_type="fact", full_rebuild=False), + mock_call("u-paris", n=50, memory_type="episodic", full_rebuild=False), + ] + ) # --------------------------------------------------------------------------- diff --git a/tests/unit/aio/processors/test_inprocess.py b/tests/unit/aio/processors/test_inprocess.py index 2d0852b..ec22819 100644 --- a/tests/unit/aio/processors/test_inprocess.py +++ b/tests/unit/aio/processors/test_inprocess.py @@ -24,10 +24,11 @@ async def test_process_thread_calls_summarize_extract_reconcile_in_order(): "generate_thread_summary", "extract_memories", "reconcile_memories", + "reconcile_memories", ] assert isinstance(result, ProcessThreadResult) assert result.thread_summary == {"id": "summary", "type": "thread_summary"} - assert result.reconciled_count == 2 + assert result.reconciled_count == 4 @pytest.mark.asyncio @@ -62,8 +63,19 @@ async def test_process_reconcile_invokes_pipeline_with_env_pool_size(monkeypatch proc = AsyncInProcessProcessor(pipeline=pipeline) count = await proc.process_reconcile(user_id="u") - pipeline.reconcile_memories.assert_called_once_with("u", 37) - assert count == 5 # merged + contradicted + # Reconciles fact + episodic; both forward the env pool size. + assert pipeline.reconcile_memories.await_count == 2 + assert pipeline.reconcile_memories.await_args_list[0].kwargs == { + "n": 37, + "memory_type": "fact", + "full_rebuild": False, + } + assert pipeline.reconcile_memories.await_args_list[1].kwargs == { + "n": 37, + "memory_type": "episodic", + "full_rebuild": False, + } + assert count == 10 # (merged+contradicted) x2 types @pytest.mark.asyncio @@ -75,9 +87,9 @@ async def test_process_extract_memories_invokes_pipeline_and_filters_to_ints(): } proc = AsyncInProcessProcessor(pipeline=pipeline) - result = await proc.process_extract_memories(user_id="u", thread_id="t") + result = await proc.process_extract_memories(user_id="u", thread_id="t", recent_k=3) - pipeline.extract_memories.assert_called_once_with("u", "t") + pipeline.extract_memories.assert_called_once_with("u", "t", recent_k=3) assert result == {"fact_count": 3} diff --git a/tests/unit/aio/processors/test_protocol_satisfaction.py b/tests/unit/aio/processors/test_protocol_satisfaction.py index 04cda06..2a23515 100644 --- a/tests/unit/aio/processors/test_protocol_satisfaction.py +++ b/tests/unit/aio/processors/test_protocol_satisfaction.py @@ -32,6 +32,7 @@ async def process_extract_memories( *, user_id: str, thread_id: str, + recent_k: Optional[int] = None, ) -> dict[str, int]: return {} diff --git a/tests/unit/aio/services/test_dedup_vector_async.py b/tests/unit/aio/services/test_dedup_vector_async.py new file mode 100644 index 0000000..be62553 --- /dev/null +++ b/tests/unit/aio/services/test_dedup_vector_async.py @@ -0,0 +1,446 @@ +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from azure.cosmos.agent_memory.aio.services.pipeline import AsyncPipelineService + + +def _service() -> AsyncPipelineService: + p = AsyncPipelineService.__new__(AsyncPipelineService) + p._memories_container = MagicMock() + p._embed_batch = AsyncMock() + p._embed_one = AsyncMock(return_value=[0.1, 0.2]) + p._upsert_memory = AsyncMock(side_effect=lambda doc: doc) + p._mark_superseded = AsyncMock(return_value=True) + return p + + +def _fact(fid: str, content: str, embedding=None, tags=None, metadata=None) -> dict: + return { + "id": fid, + "user_id": "u1", + "thread_id": "t1", + "type": "fact", + "role": "system", + "content": content, + "content_hash": "0" * 32, + "confidence": 0.8, + "salience": 0.7, + "tags": list(tags or ["sys:fact"]), + "metadata": dict(metadata or {"category": "preference"}), + "created_at": "2025-01-01T00:00:00+00:00", + "embedding": embedding or [1.0, 0.0], + } + + +def _episode(eid: str, content: str) -> dict: + return { + "id": eid, + "user_id": "u1", + "thread_id": "t1", + "type": "episodic", + "role": "system", + "content": content, + "content_hash": "1" * 32, + "confidence": 0.8, + "salience": 0.7, + "tags": ["sys:episodic", "sys:dup-candidate"], + "metadata": { + "scope_type": "project", + "scope_value": "CI", + "lesson": content, + "outcome_valence": "positive", + }, + "created_at": "2025-01-01T00:00:00+00:00", + "embedding": [1.0, 0.0], + } + + +@pytest.mark.asyncio +async def test_vector_distance_function_reads_container_policy(): + # The distance function comes from the container's vector embedding policy + # (read once, cached), NOT an env var. + p = _service() + p._memories_container.read = AsyncMock( + return_value={ + "vectorEmbeddingPolicy": {"vectorEmbeddings": [{"path": "/embedding", "distanceFunction": "dotproduct"}]} + } + ) + assert await p._vector_distance_function() == "dotproduct" + assert await p._vector_distance_function() == "dotproduct" + assert p._memories_container.read.await_count == 1 + + +@pytest.mark.asyncio +async def test_vector_candidates_orders_nearest_first_by_distance_function(): + # Regression: async _vector_candidates must order most-similar-first per the + # container's distanceFunction. For cosine/dotproduct higher score = more + # similar (DESC); for euclidean lower distance = more similar (ASC). A missing + # DESC silently fetched the LEAST-similar rows when the pool exceeded top_k. + p = _service() + captured: dict[str, str] = {} + + async def fake_query_items(_container, *, query, parameters): + captured["query"] = query + return [ + {"id": "near", "content": "a", "type": "fact", "score": 0.95}, + {"id": "far", "content": "b", "type": "fact", "score": 0.10}, + ] + + p._query_items = AsyncMock(side_effect=fake_query_items) + + p._distance_function_cache = "cosine" + out = await p._vector_candidates( + user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set() + ) + # Cosmos rejects an explicit ASC/DESC on ORDER BY VectorDistance(); it orders + # most-similar-first server-side. Direction-awareness lives in the Python sort. + assert "ORDER BY VectorDistance(c.embedding, @vec)" in captured["query"] + assert "VectorDistance(c.embedding, @vec) DESC" not in captured["query"] + assert "VectorDistance(c.embedding, @vec) ASC" not in captured["query"] + assert [c["id"] for c in out] == ["near", "far"] + + p._distance_function_cache = "euclidean" + out = await p._vector_candidates( + user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set() + ) + assert "VectorDistance(c.embedding, @vec) ASC" not in captured["query"] + # euclidean: lower distance = more similar, so 0.10 ("far" label) sorts first. + # euclidean: lower distance = more similar, so 0.10 ("far" label) sorts first. + assert [c["id"] for c in out] == ["far", "near"] + + +@pytest.mark.asyncio +async def test_candidate_mode_clears_tags_only_for_survivors(): + # Latent-bug regression (async mirror): consumed sources must not be + # re-upserted by tag-clearing, which would resurrect them without superseded_by. + p = _service() + f1 = _fact("f1", "a", tags=["sys:fact", "sys:dup-candidate"]) + f2 = _fact("f2", "b", tags=["sys:fact", "sys:dup-candidate"]) + f3 = _fact("f3", "c", tags=["sys:fact", "sys:dup-candidate"]) + p._build_candidate_clusters = AsyncMock(return_value=([[f1, f2, f3]], 3, [f1, f2, f3])) + p._reconcile_pool = AsyncMock(return_value=({"kept": 1, "merged": 2, "contradicted": 0}, {"f1", "f2"})) + cleared: list[str] = [] + + async def clear(docs): + cleared.extend(d["id"] for d in docs) + + p._clear_dup_candidate_tags = AsyncMock(side_effect=clear) + p._emit_reconcile_outcome = MagicMock() + + result = await p._reconcile_candidate_mode("u1", n=50, memory_type="fact", started_at=0.0) + + assert cleared == ["f3"] + assert result == { + "kept": 1, + "merged": 2, + "contradicted": 0, + "reconcile_clusters_sent": 1, + "reconcile_llm_calls_saved": 2, + } + + +@pytest.mark.asyncio +async def test_candidate_mode_clears_tags_on_orphan_seeds(): + # Async mirror: orphan dup-candidate seeds (no cluster) get their stale tag cleared. + p = _service() + orphan = _fact("orphan", "lonely", tags=["sys:fact", "sys:dup-candidate"]) + c1 = _fact("c1", "a", tags=["sys:fact", "sys:dup-candidate"]) + c2 = _fact("c2", "b", tags=["sys:fact", "sys:dup-candidate"]) + p._build_candidate_clusters = AsyncMock(return_value=([[c1, c2]], 3, [orphan, c1, c2])) + p._reconcile_pool = AsyncMock(return_value=({"kept": 2, "merged": 0, "contradicted": 0}, set())) + cleared: list[str] = [] + + async def clear(docs): + cleared.extend(d["id"] for d in docs) + + p._clear_dup_candidate_tags = AsyncMock(side_effect=clear) + p._emit_reconcile_outcome = MagicMock() + + await p._reconcile_candidate_mode("u1", n=50, memory_type="fact", started_at=0.0) + + assert set(cleared) == {"c1", "c2", "orphan"} + + +@pytest.mark.asyncio +async def test_sweep_survives_one_cluster_failure(): + # Async mirror: one failing cluster must not abort the sweep; failed cluster's + # tags are retained, remaining cluster + orphan are still cleared. + p = _service() + c1 = [ + _fact("a1", "x", tags=["sys:fact", "sys:dup-candidate"]), + _fact("a2", "y", tags=["sys:fact", "sys:dup-candidate"]), + ] + c2 = [ + _fact("b1", "p", tags=["sys:fact", "sys:dup-candidate"]), + _fact("b2", "q", tags=["sys:fact", "sys:dup-candidate"]), + ] + orphan = _fact("o1", "lonely", tags=["sys:fact", "sys:dup-candidate"]) + p._build_candidate_clusters = AsyncMock(return_value=([c1, c2], 5, [*c1, *c2, orphan])) + p._reconcile_pool = AsyncMock( + side_effect=[RuntimeError("truncated LLM response"), ({"kept": 2, "merged": 0, "contradicted": 0}, set())] + ) + cleared: list[str] = [] + + async def clear(docs): + cleared.extend(d["id"] for d in docs) + + p._clear_dup_candidate_tags = AsyncMock(side_effect=clear) + p._emit_reconcile_outcome = MagicMock() + + result = await p._reconcile_candidate_mode("u1", n=50, memory_type="fact", started_at=0.0) + + assert "a1" not in cleared and "a2" not in cleared + assert {"b1", "b2", "o1"} <= set(cleared) + assert result["reconcile_clusters_sent"] == 2 + + +@pytest.mark.asyncio +async def test_full_rebuild_clears_survivor_tags(monkeypatch): + # Async mirror: full_rebuild full-pool path clears survivor dup-candidate tags. + monkeypatch.setattr( + "azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_reconcile_mode", lambda: "candidate" + ) + p = _service() + pool = [ + _fact("f1", "a", tags=["sys:fact", "sys:dup-candidate"]), + _fact("f2", "b", tags=["sys:fact", "sys:dup-candidate"]), + ] + p._active_memories_for_reconcile = AsyncMock(return_value=pool) + p._reconcile_pool = AsyncMock(return_value=({"kept": 1, "merged": 1, "contradicted": 0}, {"f1"})) + cleared: list[str] = [] + + async def clear(docs): + cleared.extend(d["id"] for d in docs) + + p._clear_dup_candidate_tags = AsyncMock(side_effect=clear) + p._emit_reconcile_outcome = MagicMock() + + await p.reconcile_memories("u1", n=50, memory_type="fact", full_rebuild=True) + + assert cleared == ["f2"] + + +@pytest.mark.asyncio +async def test_dedup_extracted_memories_flag_off_is_noop(monkeypatch): + monkeypatch.setattr("azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_vector_enabled", lambda: False) + p = _service() + extracted = {"facts": [_fact("f1", "User likes tea")], "episodic": [], "updates": []} + + out = await p.dedup_extracted_memories("u1", extracted) + + assert out is extracted + p._embed_batch.assert_not_called() + + +@pytest.mark.asyncio +async def test_euclidean_disables_near_exact_autodrop(): + # Async mirror: euclidean disables the cosine-calibrated near-exact auto-drop; + # the near-identical new memory is kept + tagged for LLM reconcile. + p = _service() + p._vector_distance_function = AsyncMock(return_value="euclidean") + fact = _fact("f-new", "near identical") + fact.pop("embedding", None) + p._embed_batch.return_value = [[1.0, 0.0]] + p._vector_candidates = AsyncMock( + return_value=[{"id": "existing", "content": "same", "type": "fact", "score": 0.05}] + ) + + out = await p.dedup_extracted_memories("u1", {"facts": [fact], "episodic": [], "updates": []}) + + assert [doc["id"] for doc in out["facts"]] == ["f-new"] + assert out["facts"][0]["tags"][-1] == "sys:dup-candidate" + assert out["updates"][-1]["vector_dedup_skipped"] == 0 + assert out["updates"][-1]["dup_candidates_tagged"] == 1 + + +@pytest.mark.asyncio +async def test_dedup_extracted_memories_passes_user_id_per_concurrent_call(): + p = _service() + seen: list[tuple[str, str]] = [] + + async def vector_candidates(**kwargs): + await asyncio.sleep(0) + exclude_ids = kwargs["exclude_ids"] + doc_id = next(iter(exclude_ids)) + seen.append((doc_id, kwargs["user_id"])) + return [] + + p._vector_candidates = AsyncMock(side_effect=vector_candidates) + user_a_doc = _fact("user-a-doc", "User A likes tea") + user_b_doc = _fact("user-b-doc", "User B likes coffee") + + await asyncio.gather( + p.dedup_extracted_memories("user-a", {"facts": [user_a_doc], "episodic": [], "updates": []}), + p.dedup_extracted_memories("user-b", {"facts": [user_b_doc], "episodic": [], "updates": []}), + ) + + assert dict(seen) == {"user-a-doc": "user-a", "user-b-doc": "user-b"} + + +@pytest.mark.asyncio +async def test_dedup_extracted_memories_vector_ladder_and_intra_batch(): + p = _service() + docs = [ + _fact("drop-existing", "User likes tea"), + _fact("tag-existing", "User likes coffee"), + _fact("clean", "User likes water"), + _fact("batch-keeper", "User likes green tea"), + _fact("drop-batch", "User likes green tea too"), + ] + for doc in docs: + doc.pop("embedding", None) + p._embed_batch.return_value = [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, -1.0], + [0.7, 0.7], + [0.7, 0.7], + ] + p._vector_candidates = AsyncMock( + side_effect=[ + [{"id": "old-high", "content": "User likes tea.", "type": "fact", "score": 0.99}], + [{"id": "old-mid", "content": "User enjoys coffee.", "type": "fact", "score": 0.90}], + [{"id": "old-low", "content": "Unrelated", "type": "fact", "score": 0.20}], + [], + [], + ] + ) + + out = await p.dedup_extracted_memories("u1", {"facts": docs, "episodic": [], "updates": []}) + + # Intra-batch new-vs-new dedup was removed: drop-batch (no existing match) + # is now kept; same-batch near-dups are deferred to reconcile. + ids = [doc["id"] for doc in out["facts"]] + assert ids == ["tag-existing", "clean", "batch-keeper", "drop-batch"] + assert all("embedding" in doc for doc in out["facts"]) + tagged = out["facts"][0] + assert "sys:dup-candidate" in tagged["tags"] + assert tagged["metadata"]["dup_of"] == "old-mid" + assert tagged["metadata"]["dup_score"] == 0.90 + assert "sys:dup-candidate" not in out["facts"][1]["tags"] + stats = out["updates"][-1] + assert stats["vector_dedup_skipped"] == 1 + assert stats["dup_candidates_tagged"] == 1 + + +@pytest.mark.asyncio +async def test_dedup_skips_underspecified_doc_verbatim(): + # Parity with sync: a doc with no/unknown type is passed through untouched + # and never runs vector dedup (async previously defaulted type to the bucket). + p = _service() + p._vector_candidates = AsyncMock( + return_value=[{"id": "x", "content": "c", "type": "fact", "score": 0.99}] + ) + doc = _fact("f1", "content") + doc.pop("type") + doc.pop("embedding", None) + p._embed_batch.return_value = [[1.0, 0.0]] + + out = await p.dedup_extracted_memories("u1", {"facts": [doc], "episodic": [], "updates": []}) + + assert [d["id"] for d in out["facts"]] == ["f1"] + assert "sys:dup-candidate" not in out["facts"][0].get("tags", []) + p._vector_candidates.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_reconcile_memory_type_routes_episodic_merge_only(monkeypatch): + monkeypatch.setattr("azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_reconcile_mode", lambda: "full_pool") + p = _service() + episodes = [_episode("ep_1", "CI failed then retries fixed it"), _episode("ep_2", "CI failed; retries fixed it")] + p._active_memories_for_reconcile = AsyncMock(return_value=episodes) + p._run_prompty = AsyncMock( + return_value=json.dumps( + { + "duplicate_groups": [ + { + "merged_content": "CI failed, then retries fixed it", + "source_ids": ["ep_1", "ep_2"], + "confidence": 0.9, + "salience": 0.8, + } + ], + "kept_ids": [], + "contradicted_pairs": [{"winner_id": "ep_1", "loser_id": "ep_2"}], + } + ) + ) + + result = await p.reconcile_memories("u1", memory_type="episodic") + + assert result == {"kept": 0, "merged": 2, "contradicted": 0} + assert p._run_prompty.await_args.kwargs["inputs"].keys() == {"episodics_text"} + assert p._run_prompty.await_args.args[0] == "dedup_episodic.prompty" + + +@pytest.mark.asyncio +async def test_candidate_reconcile_builds_connected_component(): + p = _service() + docs = { + "f1": _fact("f1", "User likes aisle seats", tags=["sys:fact", "sys:dup-candidate"]), + "f2": _fact("f2", "User prefers aisle seats"), + "f3": _fact("f3", "User enjoys aisle seats"), + } + + async def query_items(_container, **kwargs): + params = {p["name"]: p["value"] for p in kwargs.get("parameters", [])} + if params.get("@tag") == "sys:dup-candidate": + return [docs["f1"]] + ids = [value for name, value in params.items() if name.startswith("@id")] + return [docs[mid] for mid in ids if mid in docs] + + p._query_items = AsyncMock(side_effect=query_items) + p._vector_candidates = AsyncMock( + side_effect=[ + [ + {"id": "f2", "content": docs["f2"]["content"], "type": "fact", "score": 0.90}, + {"id": "f3", "content": docs["f3"]["content"], "type": "fact", "score": 0.88}, + ], + [ + {"id": "f2", "content": docs["f2"]["content"], "type": "fact", "score": 0.90}, + {"id": "f3", "content": docs["f3"]["content"], "type": "fact", "score": 0.88}, + ], + [ + {"id": "f1", "content": docs["f1"]["content"], "type": "fact", "score": 0.90}, + {"id": "f3", "content": docs["f3"]["content"], "type": "fact", "score": 0.89}, + ], + [ + {"id": "f1", "content": docs["f1"]["content"], "type": "fact", "score": 0.88}, + {"id": "f2", "content": docs["f2"]["content"], "type": "fact", "score": 0.89}, + ], + ] + ) + p._run_prompty = AsyncMock( + return_value=json.dumps( + { + "duplicate_groups": [ + { + "merged_content": "User prefers aisle seats", + "source_ids": ["f1", "f2", "f3"], + "confidence": 0.9, + "salience": 0.8, + } + ], + "contradicted_pairs": [], + "kept_ids": [], + } + ) + ) + + result = await p.reconcile_memories("u1") + + assert result == { + "kept": 0, + "merged": 3, + "contradicted": 0, + "reconcile_clusters_sent": 1, + "reconcile_llm_calls_saved": 2, + } + p._run_prompty.assert_awaited_once() + facts_text = p._run_prompty.await_args.kwargs["inputs"]["facts_text"] + assert all(fid in facts_text for fid in ["f1", "f2", "f3"]) diff --git a/tests/unit/aio/test_auto_trigger.py b/tests/unit/aio/test_auto_trigger.py index edcf382..adb9f88 100644 --- a/tests/unit/aio/test_auto_trigger.py +++ b/tests/unit/aio/test_auto_trigger.py @@ -8,14 +8,51 @@ from __future__ import annotations import asyncio -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from azure.cosmos.exceptions import CosmosResourceNotFoundError +from azure.cosmos.agent_memory import _counters from azure.cosmos.agent_memory.aio.cosmos_memory_client import AsyncCosmosMemoryClient from azure.cosmos.agent_memory.aio.processors import AsyncInProcessProcessor +class _AsyncFakeCounterContainer: + """Async in-memory counter container exercising the REAL increment / + watermark-read / watermark-advance helpers end-to-end (no constant mocks).""" + + def __init__(self) -> None: + self.store: dict[str, dict] = {} + self._etag = 0 + + async def read_item(self, *, item, partition_key): + if item not in self.store: + raise CosmosResourceNotFoundError(message="404") + return dict(self.store[item]) + + async def create_item(self, *, body): + self._etag += 1 + body = dict(body) + body["_etag"] = f"e{self._etag}" + self.store[body["id"]] = body + return dict(body) + + async def upsert_item(self, *, body, **_kwargs): + self._etag += 1 + body = dict(body) + body["_etag"] = f"e{self._etag}" + self.store[body["id"]] = body + return dict(body) + + async def patch_item(self, *, item, partition_key, patch_operations): + doc = self.store.setdefault(item, {"id": item}) + for op in patch_operations: + doc[op["path"].lstrip("/")] = op["value"] + return dict(doc) + + + class TestAsyncAutoTriggerNonBlocking: @pytest.mark.asyncio async def test_push_to_cosmos_does_not_await_auto_trigger(self, monkeypatch): @@ -25,7 +62,7 @@ async def test_push_to_cosmos_does_not_await_auto_trigger(self, monkeypatch): processor = AsyncInProcessProcessor(pipeline=MagicMock()) - async def slow_trigger(user_id, thread_id): + async def slow_trigger(user_id, thread_id, recent_k=None): # If push_to_cosmos awaited the trigger inline, the test would # block here for half a second before returning. await asyncio.sleep(0.5) @@ -63,6 +100,224 @@ async def fake_upsert(body): await asyncio.gather(*list(client._background_tasks), return_exceptions=True) +class TestAsyncExtractRecentK: + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("n_facts", "batch_count", "counter_result", "expected_recent_k"), + [ + (1, 1, (0, 1), 1), + (1, 3, (0, 3), 3), + (5, 1, (4, 5), 5), + ], + ) + async def test_extract_recent_k_uses_max_threshold_and_batch_count( + self, + monkeypatch, + n_facts, + batch_count, + counter_result, + expected_recent_k, + ): + monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", str(n_facts)) + monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") + + processor = AsyncInProcessProcessor(pipeline=MagicMock()) + processor.process_extract_memories = MagicMock(return_value={}) + + client = AsyncCosmosMemoryClient(use_default_credential=False, processor=processor) + + async def fake_upsert(body): + return body + + client._memories_container_client = MagicMock() + client._memories_container_client.upsert_item = MagicMock(side_effect=fake_upsert) + client._turns_container_client = client._memories_container_client + client._summaries_container_client = client._memories_container_client + client._counter_container_client = MagicMock() + + with patch( + "azure.cosmos.agent_memory._counters.increment_counter_async", + return_value=counter_result, + ): + for i in range(batch_count): + client.add_local(user_id="u1", role="user", thread_id="t1", content=f"hi {i}") + await client.push_to_cosmos() + await asyncio.gather(*list(client._background_tasks), return_exceptions=True) + + processor.process_extract_memories.assert_called_once_with( + user_id="u1", + thread_id="t1", + recent_k=expected_recent_k, + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("counter_result", "watermark", "expected_recent_k"), + [ + ((5, 10), 5, 5), # backlog = new - watermark + ((0, 1), 1, 1), # new == watermark -> floored to 1 + ((98, 100), 0, 100), # large backlog is NOT capped + ((20, 30), None, 30), # BOOTSTRAP: no watermark -> base=0 -> recent_k = new_count + ], + ) + async def test_extract_recent_k_uses_watermark( + self, monkeypatch, counter_result, watermark, expected_recent_k + ): + monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", "1") + monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") + + processor = AsyncInProcessProcessor(pipeline=MagicMock()) + processor.process_extract_memories = MagicMock(return_value={}) + client = AsyncCosmosMemoryClient(use_default_credential=False, processor=processor) + client._memories_container_client = MagicMock() + client._memories_container_client.upsert_item = AsyncMock(side_effect=lambda body: body) + client._turns_container_client = client._memories_container_client + client._summaries_container_client = client._memories_container_client + client._counter_container_client = MagicMock() + + with patch( + "azure.cosmos.agent_memory._counters.increment_counter_async", + return_value=counter_result, + ), patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_async", + new=AsyncMock(return_value=watermark), + ), patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_async", + new=AsyncMock(), + ) as advance: + client.add_local(user_id="u1", role="user", thread_id="t1", content="hi") + await client.push_to_cosmos() + await asyncio.gather(*list(client._background_tasks), return_exceptions=True) + + processor.process_extract_memories.assert_called_once_with( + user_id="u1", thread_id="t1", recent_k=expected_recent_k + ) + advance.assert_awaited_once() + + @pytest.mark.asyncio + async def test_watermark_not_advanced_when_extract_fails(self, monkeypatch): + monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", "1") + monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") + + processor = AsyncInProcessProcessor(pipeline=MagicMock()) + processor.process_extract_memories = MagicMock(side_effect=RuntimeError("llm down")) + client = AsyncCosmosMemoryClient(use_default_credential=False, processor=processor) + client._memories_container_client = MagicMock() + client._memories_container_client.upsert_item = AsyncMock(side_effect=lambda body: body) + client._turns_container_client = client._memories_container_client + client._summaries_container_client = client._memories_container_client + client._counter_container_client = MagicMock() + + with patch( + "azure.cosmos.agent_memory._counters.increment_counter_async", + return_value=(0, 1), + ), patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_async", + new=AsyncMock(return_value=None), + ), patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_async", + new=AsyncMock(), + ) as advance, patch( + "azure.cosmos.agent_memory._counters.stamp_failure_async", + new=AsyncMock(), + ) as stamp: + client.add_local(user_id="u1", role="user", thread_id="t1", content="hi") + await client.push_to_cosmos() + await asyncio.gather(*list(client._background_tasks), return_exceptions=True) + + advance.assert_not_awaited() + stamp.assert_awaited_once() + + @pytest.mark.asyncio + async def test_watermark_round_trip_fail_then_succeed_no_strand(self, monkeypatch): + """Async stateful round-trip against a REAL in-memory counter: first + extract fails, second succeeds and must cover EVERY turn so far (20), + not just its own batch (10) — the bootstrap strand regression.""" + monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", "1") + monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") + + counter = _AsyncFakeCounterContainer() + recorded: list[int] = [] + + def extract(*, user_id, thread_id, recent_k): + recorded.append(recent_k) + if len(recorded) == 1: + raise RuntimeError("transient LLM outage") + return {} + + processor = AsyncInProcessProcessor(pipeline=MagicMock()) + processor.process_extract_memories = MagicMock(side_effect=extract) + client = AsyncCosmosMemoryClient(use_default_credential=False, processor=processor) + client._memories_container_client = MagicMock() + client._memories_container_client.upsert_item = AsyncMock(side_effect=lambda body: body) + client._turns_container_client = client._memories_container_client + client._summaries_container_client = client._memories_container_client + client._counter_container_client = counter + + for i in range(10): + client.add_local(user_id="u1", role="user", thread_id="t1", content=f"a{i}") + await client.push_to_cosmos() + await asyncio.gather(*list(client._background_tasks), return_exceptions=True) + + for i in range(10): + client.add_local(user_id="u1", role="user", thread_id="t1", content=f"b{i}") + await client.push_to_cosmos() + await asyncio.gather(*list(client._background_tasks), return_exceptions=True) + + assert recorded == [10, 20] + cid = _counters.thread_counter_id("u1", "t1") + assert await _counters.read_extract_watermark_async(counter, cid, "u1", "t1") == 20 + + @pytest.mark.asyncio + async def test_reconcile_full_rebuild_on_persisted_counter_cadence(self, monkeypatch): + """Async symmetry: in-process auto-trigger requests full_rebuild on the + persisted-counter cadence (every 2 turns here), like the durable backend.""" + monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", "1") + monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("DEDUP_EVERY_N", "1") + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_full_recluster_every_n", lambda: 2 + ) + + rebuilds: list[bool] = [] + processor = AsyncInProcessProcessor(pipeline=MagicMock()) + processor.process_extract_memories = MagicMock(return_value={}) + processor.synthesize_procedural = MagicMock(return_value=None) + processor.process_reconcile = MagicMock( + side_effect=lambda *, user_id, full_rebuild=False: rebuilds.append(full_rebuild) + ) + client = AsyncCosmosMemoryClient(use_default_credential=False, processor=processor) + client._memories_container_client = MagicMock() + client._memories_container_client.upsert_item = AsyncMock(side_effect=lambda body: body) + client._turns_container_client = client._memories_container_client + client._summaries_container_client = client._memories_container_client + client._counter_container_client = MagicMock() + + with patch( + "azure.cosmos.agent_memory._counters.increment_counter_async", + new=AsyncMock(side_effect=[(0, 1), (1, 2)]), + ), patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_async", + new=AsyncMock(return_value=None), + ), patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_async", + new=AsyncMock(), + ): + client.add_local(user_id="u1", role="user", thread_id="t1", content="a") + await client.push_to_cosmos() + await asyncio.gather(*list(client._background_tasks), return_exceptions=True) + client.add_local(user_id="u1", role="user", thread_id="t1", content="b") + await client.push_to_cosmos() + await asyncio.gather(*list(client._background_tasks), return_exceptions=True) + + assert rebuilds == [False, True] + + class TestPushToCosmosUnflushedDelta: """``push_to_cosmos`` must use the unflushed-add delta, not a recount of ``local_memory``, so callers that retain the buffer don't re-fire diff --git a/tests/unit/aio/test_cosmos_memory_client.py b/tests/unit/aio/test_cosmos_memory_client.py index 0010495..8eb9f2c 100644 --- a/tests/unit/aio/test_cosmos_memory_client.py +++ b/tests/unit/aio/test_cosmos_memory_client.py @@ -9,7 +9,9 @@ import pytest +from azure.cosmos.agent_memory._container_routing import ContainerKey from azure.cosmos.agent_memory.aio.cosmos_memory_client import AsyncCosmosMemoryClient +from azure.cosmos.agent_memory.aio.store import AsyncMemoryStore from azure.cosmos.agent_memory.exceptions import ( ConfigurationError, CosmosNotConnectedError, @@ -369,7 +371,8 @@ async def test_create_memory_store_turns_container_uses_30_day_ttl(self): assert mem._turns_container_client is mock_turns_container async def test_create_memory_store_defaults_to_serverless(self): - mem = _make_client(cosmos_throughput_mode="serverless") + # serverless mode ignores autoscale config entirely, even an invalid value. + mem = _make_client(cosmos_throughput_mode="serverless", cosmos_autoscale_max_ru="not-an-int") mock_cosmos_cls = MagicMock() mock_client = MagicMock() mock_db = AsyncMock() @@ -391,29 +394,28 @@ async def test_create_memory_store_defaults_to_serverless(self): ] ) - with patch.dict("os.environ", {"COSMOS_DB_AUTOSCALE_MAX_RU": "not-an-int"}, clear=False): - with patch.dict( - "sys.modules", - { - "azure.cosmos.aio": MagicMock(CosmosClient=mock_cosmos_cls), - "azure.cosmos": MagicMock( - PartitionKey=MagicMock(), - ThroughputProperties=MagicMock(), - ), - }, - ): - await mem.create_memory_store( - endpoint="https://fake.documents.azure.com:443/", - credential="fake-key", - throughput_mode="serverless", - ) + with patch.dict( + "sys.modules", + { + "azure.cosmos.aio": MagicMock(CosmosClient=mock_cosmos_cls), + "azure.cosmos": MagicMock( + PartitionKey=MagicMock(), + ThroughputProperties=MagicMock(), + ), + }, + ): + await mem.create_memory_store( + endpoint="https://fake.documents.azure.com:443/", + credential="fake-key", + throughput_mode="serverless", + ) for call in mock_db.create_container_if_not_exists.await_args_list: assert "offer_throughput" not in call.kwargs - def test_constructor_ignores_invalid_autoscale_env_in_serverless_mode(self): - with patch.dict("os.environ", {"COSMOS_DB_AUTOSCALE_MAX_RU": "not-an-int"}, clear=False): - mem = _make_client(cosmos_throughput_mode="serverless") + def test_constructor_ignores_autoscale_in_serverless_mode(self): + # Even an invalid autoscale value is ignored in serverless mode. + mem = _make_client(cosmos_throughput_mode="serverless", cosmos_autoscale_max_ru="not-an-int") assert mem._cosmos_autoscale_max_ru is None @@ -738,6 +740,37 @@ async def test_search_cosmos(self): query = container.query_items.call_args.kwargs["query"] assert "VectorDistance" in query + async def test_search_uses_keyword_params_for_hybrid_sql(self): + mem, container = _connected_client() + container.query_items = MagicMock(return_value=AsyncIterator([_make_doc()])) + + mem._embeddings_client = AsyncMock() + mem._embeddings_client.generate = AsyncMock(return_value=[0.1]) + + await mem.search_cosmos(search_terms="weather Seattle", top_k=3) + + query = container.query_items.call_args.kwargs["query"] + parameters = container.query_items.call_args.kwargs["parameters"] + assert "ORDER BY RANK RRF" in query + assert "FullTextScore(c.content, @kw0, @kw1)" in query + assert {"name": "@kw0", "value": "weather"} in parameters + assert {"name": "@kw1", "value": "seattle"} in parameters + + async def test_search_all_stopwords_uses_vector_sql(self): + mem, container = _connected_client() + container.query_items = MagicMock(return_value=AsyncIterator([_make_doc()])) + + mem._embeddings_client = AsyncMock() + mem._embeddings_client.generate = AsyncMock(return_value=[0.1]) + + await mem.search_cosmos(search_terms="what is the", top_k=3) + + query = container.query_items.call_args.kwargs["query"] + parameters = container.query_items.call_args.kwargs["parameters"] + assert "ORDER BY VectorDistance" in query + assert "RRF" not in query + assert not any(parameter["name"].startswith("@kw") for parameter in parameters) + async def test_search_hybrid(self): mem, container = _connected_client() docs = [_make_doc()] @@ -748,7 +781,6 @@ async def test_search_hybrid(self): results = await mem.search_cosmos( search_terms="weather Seattle", - hybrid_search=True, top_k=5, ) @@ -788,6 +820,38 @@ async def test_search_whitespace_only_terms(self): with pytest.raises(ValidationError, match="search_terms must be a non-empty string"): await mem.search_cosmos(search_terms=" ") + async def test_search_episodic_forwards_search_options(self): + containers = {key: MagicMock() for key in ContainerKey} + store = AsyncMemoryStore(containers=containers) + store.search = AsyncMock(return_value=[]) + + await store.search_episodic( + user_id="u1", + search_terms="weather", + top_k=2, + min_salience=0.4, + include_superseded=True, + ) + + store.search.assert_awaited_once_with( + search_terms="weather", + user_id="u1", + memory_types=["episodic"], + top_k=2, + min_salience=0.4, + include_superseded=True, + ) + + async def test_build_episodic_context_forwards_search_options(self): + containers = {key: MagicMock() for key in ContainerKey} + store = AsyncMemoryStore(containers=containers) + store.search_episodic = AsyncMock(return_value=[]) + + context = await store.build_episodic_context("u1", "weather", top_k=2) + + assert context == "" + store.search_episodic.assert_awaited_once_with("u1", "weather", top_k=2) + # =================================================================== # Processing delegation (async) diff --git a/tests/unit/aio/test_process_now.py b/tests/unit/aio/test_process_now.py index 25ae108..d0aa22e 100644 --- a/tests/unit/aio/test_process_now.py +++ b/tests/unit/aio/test_process_now.py @@ -2,7 +2,7 @@ from __future__ import annotations -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, call, patch import pytest @@ -50,7 +50,13 @@ async def test_process_now_with_inprocess_invokes_full_pipeline(): assert isinstance(client._processor, AsyncInProcessProcessor) pipeline.generate_thread_summary.assert_awaited_once_with("u", "t") pipeline.extract_memories.assert_awaited_once_with("u", "t") - pipeline.reconcile_memories.assert_awaited_once_with("u", 50) + assert pipeline.reconcile_memories.await_count == 2 + pipeline.reconcile_memories.assert_has_awaits( + [ + call("u", n=50, memory_type="fact", full_rebuild=False), + call("u", n=50, memory_type="episodic", full_rebuild=False), + ] + ) pipeline.synthesize_procedural.assert_awaited_once_with("u", force=False) pipeline.generate_user_summary.assert_awaited_once_with("u", None) assert result.procedural == {"id": "proc1", "type": "procedural"} diff --git a/tests/unit/aio/test_reconcile_telemetry.py b/tests/unit/aio/test_reconcile_telemetry.py index 858e328..cea2326 100644 --- a/tests/unit/aio/test_reconcile_telemetry.py +++ b/tests/unit/aio/test_reconcile_telemetry.py @@ -17,6 +17,14 @@ ASYNC_LOGGER_NAME = "azure.cosmos.agent_memory.pipeline.aio" +@pytest.fixture(autouse=True) +def _pin_async_legacy_reconcile(monkeypatch): + monkeypatch.setattr( + "azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_reconcile_mode", + lambda: "full_pool", + ) + + def _make_async_pipeline() -> AsyncPipelineService: p = AsyncPipelineService.__new__(AsyncPipelineService) p._embeddings = MagicMock() diff --git a/tests/unit/function_app/test_change_feed.py b/tests/unit/function_app/test_change_feed.py index f5b7dd8..a70bcb3 100644 --- a/tests/unit/function_app/test_change_feed.py +++ b/tests/unit/function_app/test_change_feed.py @@ -73,9 +73,16 @@ async def upsert_item(*, body, **_kwargs): state[body["id"]] = dict(body) return body + async def patch_item(*, item, partition_key, patch_operations): + doc = state.setdefault(item, {"id": item}) + for op in patch_operations: + doc[op["path"].lstrip("/")] = op["value"] + return dict(doc) + container.read_item = AsyncMock(side_effect=read_item) container.create_item = AsyncMock(side_effect=create_item) container.upsert_item = AsyncMock(side_effect=upsert_item) + container.patch_item = AsyncMock(side_effect=patch_item) container._state = state # exposed for assertions return container @@ -618,10 +625,19 @@ def test_reconcile_flag_set_only_when_n_facts_times_n_dedup_threshold_crosses(): asyncio.run(process_changefeed_batch([_turn() for _ in range(4)], starter, counter_container=container)) extract_calls = [c for c in starter.start_new.await_args_list if c.args[0] == "ExtractMemoriesOrchestrator"] assert len(extract_calls) == 1 - assert _extract_payload(extract_calls[0]).get("reconcile") is False + first_payload = _extract_payload(extract_calls[0]) + assert first_payload.get("reconcile") is False + # Bootstrap: no watermark yet -> recent_k spans all 4 turns so far. + assert first_payload.get("recent_k") == 4 + + # Simulate the extract orchestrator's em_AdvanceExtractWatermark activity + # advancing the watermark to the processed count (4) after batch 1. + from shared.counters import advance_extract_watermark, thread_counter_id + + asyncio.run(advance_extract_watermark(container, thread_counter_id("u1", "t1"), "u1", "t1", 4)) # Next batch: counter 4 -> 5. Reconcile threshold crossed, so the same - # extract dispatch carries reconcile=True. + # extract dispatch carries reconcile=True, and recent_k = 5 - watermark(4) = 1. starter.start_new.reset_mock() asyncio.run(process_changefeed_batch([_turn()], starter, counter_container=container)) extract_calls = [c for c in starter.start_new.await_args_list if c.args[0] == "ExtractMemoriesOrchestrator"] @@ -629,6 +645,69 @@ def test_reconcile_flag_set_only_when_n_facts_times_n_dedup_threshold_crosses(): payload = _extract_payload(extract_calls[0]) assert payload.get("reconcile") is True assert payload.get("user_id") == "u1" + assert payload.get("recent_k") == 1 + + +@patch.dict( + os.environ, + { + "THREAD_SUMMARY_EVERY_N": "0", + "FACT_EXTRACTION_EVERY_N": "1", + "USER_SUMMARY_EVERY_N": "0", + "DEDUP_EVERY_N": "1", + }, + clear=False, +) +def test_full_rebuild_flag_set_on_persisted_counter_cadence(): + """The full-pool backstop is driven by the PERSISTED counter (durable-safe), + not the in-memory per-worker sweep counter: full_rebuild=True every + (n_facts * n_dedup * DEDUP_FULL_RECLUSTER_EVERY_N) turns. Here that's + 1 * 1 * 2 = every 2 turns.""" + with patch( + "azure.cosmos.agent_memory.thresholds.get_dedup_full_recluster_every_n", + return_value=2, + ): + starter = _make_starter() + container = _make_counter_container_starting_at() + + # Turn 1: counter 0->1. Reconcile crosses (n=1), full does NOT (n=2). + asyncio.run(process_changefeed_batch([_turn()], starter, counter_container=container)) + p1 = _extract_payload( + next(c for c in starter.start_new.await_args_list if c.args[0] == "ExtractMemoriesOrchestrator") + ) + assert p1.get("reconcile") is True + assert p1.get("full_rebuild") is False + + # Turn 2: counter 1->2. Full backstop threshold (2) crossed. + starter.start_new.reset_mock() + asyncio.run(process_changefeed_batch([_turn()], starter, counter_container=container)) + p2 = _extract_payload( + next(c for c in starter.start_new.await_args_list if c.args[0] == "ExtractMemoriesOrchestrator") + ) + assert p2.get("reconcile") is True + assert p2.get("full_rebuild") is True + + +@patch.dict( + os.environ, + { + "THREAD_SUMMARY_EVERY_N": "0", + "FACT_EXTRACTION_EVERY_N": "1", + "USER_SUMMARY_EVERY_N": "0", + "DEDUP_EVERY_N": "0", + }, + clear=False, +) +def test_extract_payload_recent_k_uses_max_threshold_and_batch_delta(): + starter = _make_starter() + container = _make_counter_container_starting_at() + + asyncio.run(process_changefeed_batch([_turn() for _ in range(3)], starter, counter_container=container)) + + extract_calls = [c for c in starter.start_new.await_args_list if c.args[0] == "ExtractMemoriesOrchestrator"] + assert len(extract_calls) == 1 + payload = _extract_payload(extract_calls[0]) + assert payload["recent_k"] == 3 @patch.dict( diff --git a/tests/unit/function_app/test_orchestrators.py b/tests/unit/function_app/test_orchestrators.py index 4f8f60f..9fc629f 100644 --- a/tests/unit/function_app/test_orchestrators.py +++ b/tests/unit/function_app/test_orchestrators.py @@ -9,7 +9,7 @@ from __future__ import annotations -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, call, patch import pytest from orchestrators import extract_memories as em_mod @@ -193,6 +193,7 @@ def test_extract_only_when_reconcile_flag_absent(self, _retry): gen, [ {"facts": [{"id": "f1"}], "episodic": [], "updates": []}, + {"facts": [{"id": "f1", "deduped": True}], "episodic": [], "updates": []}, { "fact_count": 2, "episodic_count": 0, @@ -201,7 +202,15 @@ def test_extract_only_when_reconcile_flag_absent(self, _retry): ], ) - assert [c[0] for c in ctx._yielded_calls] == ["em_Extract", "em_Persist"] + assert [c[0] for c in ctx._yielded_calls] == ["em_Extract", "em_Dedup", "em_Persist"] + assert ctx._yielded_calls[1][2] == { + "user_id": "u1", + "extracted": {"facts": [{"id": "f1"}], "episodic": [], "updates": []}, + } + assert ctx._yielded_calls[2][2] == { + "user_id": "u1", + "extracted": {"facts": [{"id": "f1", "deduped": True}], "episodic": [], "updates": []}, + } assert result["persisted"] is True assert result["extracted"]["fact_count"] == 2 assert result["reconciled"] is None @@ -213,21 +222,28 @@ def test_chains_reconcile_when_flag_true(self, _retry): result, _ = _drive( gen, [ + {"facts": [{"id": "f1"}], "episodic": [], "updates": []}, {"facts": [{"id": "f1"}], "episodic": [], "updates": []}, {"fact_count": 2, "episodic_count": 0, "updated_count": 0}, - {"kept": 0, "merged": 1, "contradicted": 0}, + { + "fact": {"kept": 0, "merged": 1, "contradicted": 0}, + "episodic": {"kept": 1, "merged": 0, "contradicted": 0}, + }, {"status": "synthesized", "version": 3}, ], ) names = [c[0] for c in ctx._yielded_calls] - assert names == ["em_Extract", "em_Persist", "em_ReconcileMemories"] - assert ctx._yielded_calls[2][2] == {"user_id": "u1"} + assert names == ["em_Extract", "em_Dedup", "em_Persist", "em_ReconcileMemories"] + assert ctx._yielded_calls[3][2] == {"user_id": "u1", "full_rebuild": False} assert [s[0] for s in ctx._yielded_sub_orchestrators] == [ "SynthesizeProceduralOrchestrator", ] assert ctx._yielded_sub_orchestrators[0][2] == {"user_id": "u1", "force": False} - assert result["reconciled"] == {"kept": 0, "merged": 1, "contradicted": 0} + assert result["reconciled"] == { + "fact": {"kept": 0, "merged": 1, "contradicted": 0}, + "episodic": {"kept": 1, "merged": 0, "contradicted": 0}, + } assert result["procedural"] == {"status": "synthesized", "version": 3} @patch.object(em_mod, "default_retry_options", return_value=MagicMock()) @@ -240,17 +256,24 @@ def boom_after_sub(name, retry, sub_payload, *args, **kwargs): return ("__sub_boom__", name, sub_payload) ctx.call_sub_orchestrator_with_retry.side_effect = boom_after_sub - # We'll send activity results normally for the first 3 yields, then throw - # an exception into the 4th yield (the sub-orchestrator call). + # We'll send activity results normally, then throw an exception into the + # sub-orchestrator yield. gen = self._orchestrator()(ctx) # Yield 1: em_Extract gen.send(None) - # Yield 2: em_Persist + # Yield 2: em_Dedup + gen.send({"facts": [{"id": "f1"}], "episodic": [], "updates": []}) + # Yield 3: em_Persist gen.send({"facts": [{"id": "f1"}], "episodic": [], "updates": []}) - # Yield 3: em_ReconcileMemories + # Yield 4: em_ReconcileMemories gen.send({"fact_count": 2, "episodic_count": 0, "updated_count": 0}) - # Yield 4: SynthesizeProceduralOrchestrator — throw an exception - gen.send({"kept": 0, "merged": 1, "contradicted": 0}) + # Yield 5: SynthesizeProceduralOrchestrator — throw an exception + gen.send( + { + "fact": {"kept": 0, "merged": 1, "contradicted": 0}, + "episodic": {"kept": 1, "merged": 0, "contradicted": 0}, + } + ) try: gen.throw(RuntimeError("procedural blew up")) except StopIteration as stop: @@ -259,7 +282,10 @@ def boom_after_sub(name, retry, sub_payload, *args, **kwargs): pytest.fail("orchestrator did not return after procedural exception") assert result["persisted"] is True - assert result["reconciled"] == {"kept": 0, "merged": 1, "contradicted": 0} + assert result["reconciled"] == { + "fact": {"kept": 0, "merged": 1, "contradicted": 0}, + "episodic": {"kept": 1, "merged": 0, "contradicted": 0}, + } assert result["procedural"] is None @patch.object(em_mod, "default_retry_options", return_value=MagicMock()) @@ -269,23 +295,44 @@ def test_procedural_not_called_when_reconcile_skipped(self, _retry): result, _ = _drive( gen, [ + {"facts": [], "episodic": [], "updates": []}, {"facts": [], "episodic": [], "updates": []}, {"fact_count": 0, "episodic_count": 0, "updated_count": 0}, ], ) - assert [c[0] for c in ctx._yielded_calls] == ["em_Extract", "em_Persist"] + assert [c[0] for c in ctx._yielded_calls] == ["em_Extract", "em_Dedup", "em_Persist"] assert ctx._yielded_sub_orchestrators == [] assert result["procedural"] is None @patch.object(em_mod, "default_retry_options", return_value=MagicMock()) - def test_extract_payload_carries_user_thread_and_limit(self, _retry): + def test_extract_payload_carries_user_thread_without_recent_k_when_absent(self, _retry): ctx = _make_context({"user_id": "u", "thread_id": "t"}) gen = self._orchestrator()(ctx) - _drive(gen, [{"facts": []}, {"fact_count": 0}]) + _drive(gen, [{"facts": []}, {"facts": []}, {"fact_count": 0}]) + + extract_payload = ctx._yielded_calls[0][2] + assert extract_payload == {"user_id": "u", "thread_id": "t"} + + @patch.object(em_mod, "default_retry_options", return_value=MagicMock()) + def test_extract_payload_carries_recent_k_when_provided(self, _retry): + ctx = _make_context({"user_id": "u", "thread_id": "t", "recent_k": 7}) + gen = self._orchestrator()(ctx) + _drive(gen, [{"facts": []}, {"facts": []}, {"fact_count": 0}]) extract_payload = ctx._yielded_calls[0][2] - assert extract_payload == {"user_id": "u", "thread_id": "t", "limit": 20} + assert extract_payload == {"user_id": "u", "thread_id": "t", "recent_k": 7} + + @patch.object(em_mod, "default_retry_options", return_value=MagicMock()) + def test_dedup_output_flows_to_persist(self, _retry): + extracted = {"facts": [{"id": "f1"}], "episodic": [], "updates": []} + deduped = {"facts": [{"id": "f1", "embedding": [0.1]}], "episodic": [], "updates": []} + ctx = _make_context({"user_id": "u", "thread_id": "t"}) + gen = self._orchestrator()(ctx) + _drive(gen, [extracted, deduped, {"fact_count": 1}]) + + assert ctx._yielded_calls[1][2] == {"user_id": "u", "extracted": extracted} + assert ctx._yielded_calls[2][2] == {"user_id": "u", "extracted": deduped} @patch.object(em_mod, "default_retry_options", return_value=MagicMock()) def test_activity_failure_propagates(self, _retry): @@ -295,6 +342,24 @@ def test_activity_failure_propagates(self, _retry): with pytest.raises(ValueError, match="kaboom"): gen.throw(ValueError("kaboom")) + @patch.object(em_mod, "default_retry_options", return_value=MagicMock()) + def test_advance_watermark_after_persist_when_count_present(self, _retry): + ctx = _make_context({"user_id": "u1", "thread_id": "t1", "count": 42}) + gen = self._orchestrator()(ctx) + _drive(gen, [{"facts": []}, {"facts": []}, {"fact_count": 0}, True]) + + names = [c[0] for c in ctx._yielded_calls] + assert names == ["em_Extract", "em_Dedup", "em_Persist", "em_AdvanceExtractWatermark"] + assert ctx._yielded_calls[3][2] == {"user_id": "u1", "thread_id": "t1", "count": 42} + + @patch.object(em_mod, "default_retry_options", return_value=MagicMock()) + def test_no_watermark_advance_when_count_absent(self, _retry): + ctx = _make_context({"user_id": "u1", "thread_id": "t1"}) + gen = self._orchestrator()(ctx) + _drive(gen, [{"facts": []}, {"facts": []}, {"fact_count": 0}]) + names = [c[0] for c in ctx._yielded_calls] + assert "em_AdvanceExtractWatermark" not in names + def test_missing_thread_id_raises(self): with patch.object(em_mod, "default_retry_options", return_value=MagicMock()): ctx = _make_context({"user_id": "u"}) @@ -304,10 +369,107 @@ def test_missing_thread_id_raises(self): # --------------------------------------------------------------------------- -# UserSummaryOrchestrator +# Extract memory activities # --------------------------------------------------------------------------- +class TestExtractMemoryActivities: + def test_em_extract_uses_payload_recent_k(self): + pipeline = MagicMock() + pipeline.extract_memories_dry.return_value = {"facts": [], "episodic": [], "updates": []} + + with patch.object(em_mod, "get_pipeline", return_value=pipeline): + result = em_mod.em_Extract({"user_id": "u1", "thread_id": "t1", "recent_k": 3}) + + pipeline.extract_memories_dry.assert_called_once_with(user_id="u1", thread_id="t1", recent_k=3) + assert result == {"facts": [], "episodic": [], "updates": []} + + def test_em_extract_falls_back_to_max_batch_size_when_recent_k_absent(self): + pipeline = MagicMock() + pipeline.extract_memories_dry.return_value = {"facts": [], "episodic": [], "updates": []} + + with patch.object(em_mod, "get_pipeline", return_value=pipeline): + em_mod.em_Extract({"user_id": "u1", "thread_id": "t1"}) + + pipeline.extract_memories_dry.assert_called_once_with(user_id="u1", thread_id="t1", recent_k=20) + + def test_em_dedup_delegates_to_pipeline_and_returns_deduped_dict(self): + extracted = {"facts": [{"id": "f1"}], "episodic": [], "updates": []} + deduped = {"facts": [{"id": "f1", "embedding": [0.1]}], "episodic": [], "updates": []} + pipeline = MagicMock() + pipeline.dedup_extracted_memories.return_value = deduped + + with patch.object(em_mod, "get_pipeline", return_value=pipeline): + result = em_mod.em_Dedup({"user_id": "u1", "extracted": extracted}) + + pipeline.dedup_extracted_memories.assert_called_once_with(user_id="u1", extracted=extracted) + assert result == deduped + + def test_em_dedup_falls_back_to_input_when_pipeline_returns_none(self): + extracted = {"facts": [], "episodic": [], "updates": []} + pipeline = MagicMock() + pipeline.dedup_extracted_memories.return_value = None + + with patch.object(em_mod, "get_pipeline", return_value=pipeline): + result = em_mod.em_Dedup({"user_id": "u1", "extracted": extracted}) + + assert result is extracted + + def test_em_reconcile_memories_reconciles_fact_and_episodic(self): + pipeline = MagicMock() + pipeline.reconcile_memories.side_effect = [ + {"kept": 2, "merged": 1, "contradicted": 0}, + {"kept": 1, "merged": 0, "contradicted": 0}, + ] + + with ( + patch.object(em_mod, "get_pipeline", return_value=pipeline), + patch("azure.cosmos.agent_memory.thresholds.get_dedup_pool_size", return_value=17), + ): + result = em_mod.em_ReconcileMemories({"user_id": "u1"}) + + # full_rebuild defaults False when the change-feed didn't request a backstop. + assert pipeline.reconcile_memories.call_args_list == [ + call(user_id="u1", n=17, memory_type="fact", full_rebuild=False), + call(user_id="u1", n=17, memory_type="episodic", full_rebuild=False), + ] + assert result == { + "fact": {"kept": 2, "merged": 1, "contradicted": 0}, + "episodic": {"kept": 1, "merged": 0, "contradicted": 0}, + } + + def test_em_reconcile_memories_forwards_full_rebuild(self): + pipeline = MagicMock() + pipeline.reconcile_memories.side_effect = [{"kept": 0}, {"kept": 0}] + with ( + patch.object(em_mod, "get_pipeline", return_value=pipeline), + patch("azure.cosmos.agent_memory.thresholds.get_dedup_pool_size", return_value=17), + ): + em_mod.em_ReconcileMemories({"user_id": "u1", "full_rebuild": True}) + + assert pipeline.reconcile_memories.call_args_list == [ + call(user_id="u1", n=17, memory_type="fact", full_rebuild=True), + call(user_id="u1", n=17, memory_type="episodic", full_rebuild=True), + ] + + def test_em_advance_extract_watermark_stamps_counter(self): + import asyncio + + from shared import cosmos_clients, counters + + container = MagicMock() + with ( + patch.object(cosmos_clients, "get_counter_container_async", new=AsyncMock(return_value=container)), + patch.object(counters, "advance_extract_watermark", new=AsyncMock()) as advance, + ): + result = asyncio.run( + em_mod.em_AdvanceExtractWatermark({"user_id": "u1", "thread_id": "t1", "count": 9}) + ) + + assert result is True + advance.assert_awaited_once_with(container, "thread:u1:t1", "u1", "t1", 9) + + class TestUserSummaryOrchestrator: def _orchestrator(self): return _user_function(us_mod.UserSummaryOrchestrator) diff --git a/tests/unit/processors/test_inprocess.py b/tests/unit/processors/test_inprocess.py index be8d841..28ddc2f 100644 --- a/tests/unit/processors/test_inprocess.py +++ b/tests/unit/processors/test_inprocess.py @@ -16,20 +16,23 @@ def test_process_thread_calls_summarize_extract_reconcile_in_order(): proc = InProcessProcessor(pipeline=pipeline) result = proc.process_thread(user_id="u1", thread_id="t1", turns=[]) - # Order of calls: summary -> extract -> reconcile + # Order of calls: summary -> extract -> reconcile (fact) -> reconcile (episodic) method_order = [c[0] for c in pipeline.method_calls] assert method_order == [ "generate_thread_summary", "extract_memories", "reconcile_memories", + "reconcile_memories", ] pipeline.generate_thread_summary.assert_called_once_with("u1", "t1") pipeline.extract_memories.assert_called_once_with("u1", "t1") - pipeline.reconcile_memories.assert_called_once_with("u1", 50) + assert pipeline.reconcile_memories.call_count == 2 + assert pipeline.reconcile_memories.call_args_list[0].kwargs["memory_type"] == "fact" + assert pipeline.reconcile_memories.call_args_list[1].kwargs["memory_type"] == "episodic" assert isinstance(result, ProcessThreadResult) assert result.thread_summary == {"id": "summary_u_t", "type": "thread_summary"} - assert result.reconciled_count == 3 + assert result.reconciled_count == 6 assert result.elapsed_ms >= 0 @@ -65,6 +68,20 @@ def test_generate_user_summary_no_summaries(): assert res.summary is None +def test_process_extract_memories_passes_recent_k_and_filters_to_ints(): + pipeline = MagicMock() + pipeline.extract_memories.return_value = { + "fact_count": 2, + "non_int_field": "skip me", + } + + proc = InProcessProcessor(pipeline=pipeline) + result = proc.process_extract_memories(user_id="u", thread_id="t", recent_k=3) + + pipeline.extract_memories.assert_called_once_with("u", "t", recent_k=3) + assert result == {"fact_count": 2} + + def test_close_is_noop(): proc = InProcessProcessor(pipeline=MagicMock()) assert proc.close() is None diff --git a/tests/unit/processors/test_protocol_satisfaction.py b/tests/unit/processors/test_protocol_satisfaction.py index dc63c98..6869f00 100644 --- a/tests/unit/processors/test_protocol_satisfaction.py +++ b/tests/unit/processors/test_protocol_satisfaction.py @@ -32,6 +32,7 @@ def process_extract_memories( *, user_id: str, thread_id: str, + recent_k: Optional[int] = None, ) -> dict[str, int]: return {} diff --git a/tests/unit/services/test_chaos_extract_persist.py b/tests/unit/services/test_chaos_extract_persist.py index d12e61b..82c61d0 100644 --- a/tests/unit/services/test_chaos_extract_persist.py +++ b/tests/unit/services/test_chaos_extract_persist.py @@ -11,6 +11,26 @@ from azure.cosmos.agent_memory.services.pipeline import PipelineService, _StoreContainerAdapter +@pytest.fixture(autouse=True) +def _pin_legacy_extract_dedup(monkeypatch): + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_vector_enabled", + lambda: False, + ) + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_context_vector_enabled", + lambda: False, + ) + monkeypatch.setattr( + "azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_vector_enabled", + lambda: False, + ) + monkeypatch.setattr( + "azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_context_vector_enabled", + lambda: False, + ) + + class _FlakyContainer: def __init__(self): self.docs: dict[str, dict[str, Any]] = {} diff --git a/tests/unit/services/test_dedup_vector.py b/tests/unit/services/test_dedup_vector.py new file mode 100644 index 0000000..a44f83a --- /dev/null +++ b/tests/unit/services/test_dedup_vector.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import MagicMock + +from azure.cosmos.agent_memory.services.pipeline import PipelineService + + +def _make_pipeline() -> PipelineService: + p = PipelineService.__new__(PipelineService) + p._memories_container = MagicMock() + p._container = p._memories_container + p._embeddings = MagicMock() + p._embed_batch = MagicMock() + p._embed_one = MagicMock(return_value=[1.0]) + p._run_prompty = MagicMock( + return_value=json.dumps({"duplicate_groups": [], "contradicted_pairs": [], "kept_ids": []}) + ) + p._upsert_memory = MagicMock(side_effect=lambda doc: doc) + p._mark_superseded = MagicMock(return_value=True) + return p + + +def _doc(mid: str, content: str, memory_type: str = "fact", **extra: Any) -> dict[str, Any]: + tags = extra.pop("tags", [f"sys:{memory_type}"]) + metadata = extra.pop("metadata", {"category": "preference"} if memory_type == "fact" else { + "scope_type": "project", + "scope_value": "demo", + "lesson": content, + "outcome_valence": "neutral", + }) + return { + "id": mid, + "user_id": "u1", + "thread_id": "t1", + "type": memory_type, + "role": "system", + "content": content, + "content_hash": mid, + "confidence": 0.8, + "salience": 0.7, + "tags": tags, + "metadata": metadata, + "prompt_id": "extract_memories.prompty", + "prompt_version": "v1", + "created_at": "2025-01-01T00:00:00+00:00", + "updated_at": "2025-01-01T00:00:00+00:00", + **extra, + } + + +def test_vector_distance_function_reads_container_policy() -> None: + # The distance function comes from the container's vector embedding policy + # (read once, cached), NOT an env var. + p = _make_pipeline() + p._memories_container.read.return_value = { + "vectorEmbeddingPolicy": {"vectorEmbeddings": [{"path": "/embedding", "distanceFunction": "euclidean"}]} + } + assert p._vector_distance_function() == "euclidean" + # Cached: a later policy change is not re-read within the instance's lifetime. + p._memories_container.read.return_value = { + "vectorEmbeddingPolicy": {"vectorEmbeddings": [{"path": "/embedding", "distanceFunction": "cosine"}]} + } + assert p._vector_distance_function() == "euclidean" + assert p._memories_container.read.call_count == 1 + + +def test_distance_function_not_cached_on_read_failure() -> None: + # A transient container.read() failure must NOT poison the cache: it returns an + # uncached cosine default so the next call self-heals to the real (euclidean) + # policy. Caching cosine here would silently mis-handle a euclidean container. + p = _make_pipeline() + euclid = {"vectorEmbeddingPolicy": {"vectorEmbeddings": [{"path": "/embedding", "distanceFunction": "euclidean"}]}} + p._memories_container.read = MagicMock(side_effect=[RuntimeError("429 throttled"), euclid]) + + # First call: transient failure -> cosine, but NOT cached. + assert p._vector_distance_function() == "cosine" + assert getattr(p, "_distance_function_cache", None) is None + + # Second call: read succeeds -> real euclidean policy, now cached. + assert p._vector_distance_function() == "euclidean" + assert p._distance_function_cache == "euclidean" + + +def test_vector_candidates_orders_nearest_first_by_distance_function() -> None: + # Parity with async: ORDER BY direction follows the container distanceFunction. + p = _make_pipeline() + captured: dict[str, str] = {} + + def query_items(*, query: str, parameters, **kwargs): + del parameters, kwargs + captured["query"] = query + return iter( + [ + {"id": "near", "content": "a", "type": "fact", "score": 0.95}, + {"id": "far", "content": "b", "type": "fact", "score": 0.10}, + ] + ) + + p._memories_container.query_items.side_effect = query_items + + p._distance_function_cache = "cosine" + out = p._vector_candidates( + user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set() + ) + # Cosmos rejects an explicit ASC/DESC on ORDER BY VectorDistance(); it orders + # most-similar-first server-side. Direction-awareness lives in the Python sort. + assert "ORDER BY VectorDistance(c.embedding, @vec)" in captured["query"] + assert "VectorDistance(c.embedding, @vec) DESC" not in captured["query"] + assert "VectorDistance(c.embedding, @vec) ASC" not in captured["query"] + assert [c["id"] for c in out] == ["near", "far"] + + p._distance_function_cache = "euclidean" + out = p._vector_candidates( + user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set() + ) + assert "VectorDistance(c.embedding, @vec) ASC" not in captured["query"] + # euclidean: lower distance = more similar, so 0.10 ("far" label) sorts first. + assert [c["id"] for c in out] == ["far", "near"] + + +def test_dedup_extracted_vector_ladder_and_intra_batch() -> None: + p = _make_pipeline() + p._embed_batch.return_value = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.99, 0.01]] + p._vector_candidates = MagicMock( + side_effect=[ + [{"id": "existing-high", "content": "same", "type": "fact", "score": 0.99}], + [{"id": "existing-mid", "content": "near", "type": "fact", "score": 0.85}], + [{"id": "existing-low", "content": "far", "type": "fact", "score": 0.20}], + [{"id": "existing-low-2", "content": "far", "type": "fact", "score": 0.10}], + ] + ) + extracted = { + "facts": [ + _doc("f-high", "drop against existing"), + _doc("f-mid", "tag against existing"), + _doc("f-clean", "keep clean"), + _doc("f-intra", "near-dup in batch is now kept (deferred to reconcile)"), + ], + "episodic": [], + "updates": [], + } + + out = p.dedup_extracted_memories("u1", extracted) + + # Intra-batch new-vs-new dedup was removed: f-intra is compared only against + # persisted memories (Cosmos) -> novel here -> kept; reconcile catches any + # same-batch near-dups later. + assert [doc["id"] for doc in out["facts"]] == ["f-mid", "f-clean", "f-intra"] + assert out["facts"][0]["tags"][-1] == "sys:dup-candidate" + assert out["facts"][0]["metadata"]["dup_of"] == "existing-mid" + assert out["facts"][0]["metadata"]["dup_score"] == 0.85 + assert "sys:dup-candidate" not in out["facts"][1]["tags"] + assert all("embedding" in doc for doc in out["facts"]) + assert out["updates"][-1]["vector_dedup_skipped"] == 1 + assert out["updates"][-1]["dup_candidates_tagged"] == 1 + + +def test_candidate_mode_clears_tags_only_for_survivors() -> None: + # Latent-bug regression: a source consumed (superseded) by a merge must NOT be + # re-upserted by tag-clearing, which would resurrect it without superseded_by. + p = _make_pipeline() + f1 = _doc("f1", "a", tags=["sys:fact", "sys:dup-candidate"]) + f2 = _doc("f2", "b", tags=["sys:fact", "sys:dup-candidate"]) + f3 = _doc("f3", "c", tags=["sys:fact", "sys:dup-candidate"]) + p._build_candidate_clusters = MagicMock(return_value=([[f1, f2, f3]], 3, [f1, f2, f3])) + p._reconcile_pool = MagicMock(return_value=({"kept": 1, "merged": 2, "contradicted": 0}, {"f1", "f2"})) + cleared: list[str] = [] + p._clear_dup_candidate_tags = MagicMock(side_effect=lambda docs: cleared.extend(d["id"] for d in docs)) + p._emit_reconcile_outcome = MagicMock() + + result = p._reconcile_candidate_mode("u1", n=50, memory_type="fact", started_at=0.0) + + assert cleared == ["f3"] + assert result == { + "kept": 1, + "merged": 2, + "contradicted": 0, + "reconcile_clusters_sent": 1, + "reconcile_llm_calls_saved": 2, + } + + +def test_candidate_mode_clears_tags_on_orphan_seeds() -> None: + # Seeds tagged sys:dup-candidate that never join a cluster (no near-duplicate) + # must have their stale tag cleared so future sweeps don't re-scan them forever. + p = _make_pipeline() + orphan = _doc("orphan", "lonely", tags=["sys:fact", "sys:dup-candidate"]) + c1 = _doc("c1", "a", tags=["sys:fact", "sys:dup-candidate"]) + c2 = _doc("c2", "b", tags=["sys:fact", "sys:dup-candidate"]) + p._build_candidate_clusters = MagicMock(return_value=([[c1, c2]], 3, [orphan, c1, c2])) + p._reconcile_pool = MagicMock(return_value=({"kept": 2, "merged": 0, "contradicted": 0}, set())) + cleared: list[str] = [] + p._clear_dup_candidate_tags = MagicMock(side_effect=lambda docs: cleared.extend(d["id"] for d in docs)) + p._emit_reconcile_outcome = MagicMock() + + p._reconcile_candidate_mode("u1", n=50, memory_type="fact", started_at=0.0) + + # Cluster survivors (c1, c2) plus the orphan all get cleared. + assert set(cleared) == {"c1", "c2", "orphan"} + + +def test_euclidean_disables_near_exact_autodrop() -> None: + # On euclidean distance the cosine-calibrated DEDUP_SIM_HIGH auto-drop is + # disabled: a near-identical existing memory must NOT silently drop the new + # one. It falls through to borderline tagging (LLM reconcile adjudicates). + p = _make_pipeline() + p._vector_distance_function = MagicMock(return_value="euclidean") + p._embed_batch.return_value = [[1.0, 0.0]] + # euclidean score = distance; 0.05 is "near-exact" (would drop under cosine rules). + p._vector_candidates = MagicMock( + return_value=[{"id": "existing", "content": "same", "type": "fact", "score": 0.05}] + ) + extracted = {"facts": [_doc("f-new", "near identical")], "episodic": [], "updates": []} + + out = p.dedup_extracted_memories("u1", extracted) + + # Not dropped — kept and tagged for LLM reconcile instead. + assert [doc["id"] for doc in out["facts"]] == ["f-new"] + assert out["facts"][0]["tags"][-1] == "sys:dup-candidate" + assert out["updates"][-1]["vector_dedup_skipped"] == 0 + assert out["updates"][-1]["dup_candidates_tagged"] == 1 + + +def test_candidate_mode_has_no_inmemory_backstop() -> None: + # The periodic full-pool backstop is no longer driven by an in-memory sweep + # counter (unreliable on the FA per-worker singleton). Candidate mode does + # ONLY clustering now; the full-pool pass is requested by the caller via + # full_rebuild on a persisted-counter cadence. + p = _make_pipeline() + assert not hasattr(p, "_next_reconcile_sweep") + p._build_candidate_clusters = MagicMock(return_value=([], 0, [])) + p._active_memories_for_reconcile = MagicMock() + p._emit_reconcile_outcome = MagicMock() + + # Many sweeps in a row never escalate to a full-pool pass on their own. + for _ in range(30): + p._reconcile_candidate_mode("u1", n=50, memory_type="fact", started_at=0.0) + p._build_candidate_clusters.assert_called() + p._active_memories_for_reconcile.assert_not_called() + + +def test_full_rebuild_bypasses_candidate_mode(monkeypatch) -> None: + # Public reconcile(full_rebuild=True) must take the full-pool single-LLM-pass + # path even under candidate mode, so it catches dissimilar contradictions. + monkeypatch.setattr("azure.cosmos.agent_memory.thresholds.get_dedup_reconcile_mode", lambda: "candidate") + p = _make_pipeline() + pool = [_doc("f1", "User is vegetarian"), _doc("f2", "User loves steak")] + p._active_memories_for_reconcile = MagicMock(return_value=pool) + p._reconcile_pool = MagicMock(return_value=({"kept": 1, "merged": 0, "contradicted": 1}, set())) + p._reconcile_candidate_mode = MagicMock() + p._emit_reconcile_outcome = MagicMock() + + result = p.reconcile_memories("u1", n=50, memory_type="fact", full_rebuild=True) + + p._reconcile_candidate_mode.assert_not_called() + p._reconcile_pool.assert_called_once_with("u1", "fact", pool) + assert result["contradicted"] == 1 + + +def test_full_rebuild_clears_survivor_tags(monkeypatch) -> None: + # full_rebuild full-pool path must clear sys:dup-candidate on survivors so it + # doesn't leave stale tags/metadata on user-visible memories. + monkeypatch.setattr("azure.cosmos.agent_memory.thresholds.get_dedup_reconcile_mode", lambda: "candidate") + p = _make_pipeline() + pool = [ + _doc("f1", "a", tags=["sys:fact", "sys:dup-candidate"]), + _doc("f2", "b", tags=["sys:fact", "sys:dup-candidate"]), + ] + p._active_memories_for_reconcile = MagicMock(return_value=pool) + # f1 consumed (superseded), f2 survives. + p._reconcile_pool = MagicMock(return_value=({"kept": 1, "merged": 1, "contradicted": 0}, {"f1"})) + cleared: list[str] = [] + p._clear_dup_candidate_tags = MagicMock(side_effect=lambda docs: cleared.extend(d["id"] for d in docs)) + p._emit_reconcile_outcome = MagicMock() + + p.reconcile_memories("u1", n=50, memory_type="fact", full_rebuild=True) + + # Survivor f2 cleared; consumed f1 not re-upserted (would resurrect it). + assert cleared == ["f2"] + + +def test_sweep_survives_one_cluster_failure() -> None: + # A truncated/malformed LLM response on one cluster must not abort the sweep: + # remaining clusters still reconcile, orphan clearing still runs, and the failed + # cluster's tags are RETAINED (not cleared) so it retries next sweep. + p = _make_pipeline() + c1 = [ + _doc("a1", "x", tags=["sys:fact", "sys:dup-candidate"]), + _doc("a2", "y", tags=["sys:fact", "sys:dup-candidate"]), + ] + c2 = [ + _doc("b1", "p", tags=["sys:fact", "sys:dup-candidate"]), + _doc("b2", "q", tags=["sys:fact", "sys:dup-candidate"]), + ] + orphan = _doc("o1", "lonely", tags=["sys:fact", "sys:dup-candidate"]) + seeds = [c1[0], c1[1], c2[0], c2[1], orphan] + p._build_candidate_clusters = MagicMock(return_value=([c1, c2], 5, seeds)) + p._reconcile_pool = MagicMock( + side_effect=[RuntimeError("truncated LLM response"), ({"kept": 2, "merged": 0, "contradicted": 0}, set())] + ) + cleared: list[str] = [] + p._clear_dup_candidate_tags = MagicMock(side_effect=lambda docs: cleared.extend(d["id"] for d in docs)) + p._emit_reconcile_outcome = MagicMock() + + result = p._reconcile_candidate_mode("u1", n=50, memory_type="fact", started_at=0.0) + + # c1 failed -> its tags retained (not cleared) for retry; c2 survivors + orphan cleared. + assert "a1" not in cleared and "a2" not in cleared + assert {"b1", "b2", "o1"} <= set(cleared) + assert result["reconcile_clusters_sent"] == 2 + assert result["kept"] == 2 + + +def test_dedup_extracted_flag_off_is_noop(monkeypatch) -> None: + monkeypatch.setattr("azure.cosmos.agent_memory.thresholds.get_dedup_vector_enabled", lambda: False) + p = _make_pipeline() + extracted = {"facts": [_doc("f1", "content")], "episodic": [], "updates": []} + + out = p.dedup_extracted_memories("u1", extracted) + + assert out is extracted + p._embed_batch.assert_not_called() + + +def test_reconcile_memory_type_routing_episodic_and_procedural(monkeypatch) -> None: + monkeypatch.setattr("azure.cosmos.agent_memory.thresholds.get_dedup_reconcile_mode", lambda: "full_pool") + p = _make_pipeline() + episodes = [_doc("e1", "episode one", "episodic"), _doc("e2", "episode two", "episodic")] + p._memories_container.query_items.return_value = iter(episodes) + p._run_prompty.return_value = json.dumps({"duplicate_groups": [], "kept_ids": ["e1", "e2"]}) + + result = p.reconcile_memories("u1", memory_type="episodic") + + assert result["contradicted"] == 0 + assert p._run_prompty.call_args.args[0] == "dedup_episodic.prompty" + assert "episodics_text" in p._run_prompty.call_args.kwargs["inputs"] + + p._run_prompty.reset_mock() + assert p.reconcile_memories("u1", memory_type="procedural")["reconcile_clusters_sent"] == 0 + p._run_prompty.assert_not_called() + + +def test_candidate_mode_connected_components() -> None: + p = _make_pipeline() + seed = _doc( + "f1", + "User likes coffee", + embedding=[1.0, 0.0], + tags=["sys:fact", "sys:dup-candidate"], + metadata={"category": "preference", "dup_of": "f2", "dup_score": 0.9}, + ) + neighbor = _doc("f2", "User loves coffee", embedding=[0.95, 0.05]) + + def query_items(*, query: str, parameters: list[dict[str, Any]], **kwargs: Any): + del kwargs + params = {p["name"]: p["value"] for p in parameters} + if "ARRAY_CONTAINS" in query: + return iter([seed]) + ids = {value for name, value in params.items() if name.startswith("@id")} + if ids: + return iter([doc for doc in (seed, neighbor) if doc["id"] in ids]) + return iter([seed, neighbor]) + + p._memories_container.query_items.side_effect = query_items + p._vector_candidates = MagicMock( + return_value=[{"id": "f2", "content": neighbor["content"], "type": "fact", "score": 0.9}] + ) + p._run_prompty.return_value = json.dumps( + { + "duplicate_groups": [{"merged_content": "User likes coffee.", "source_ids": ["f1", "f2"]}], + "contradicted_pairs": [], + "kept_ids": [], + } + ) + + result = p.reconcile_memories("u1", memory_type="fact") + + assert result["reconcile_clusters_sent"] == 1 + assert result["merged"] == 2 + p._run_prompty.assert_called_once() + assert p._run_prompty.call_args.args[0] == "dedup.prompty" diff --git a/tests/unit/services/test_extract_dry.py b/tests/unit/services/test_extract_dry.py index 32b4390..c5c9adc 100644 --- a/tests/unit/services/test_extract_dry.py +++ b/tests/unit/services/test_extract_dry.py @@ -2,6 +2,7 @@ import json from typing import Any +from unittest.mock import AsyncMock import pytest @@ -14,10 +15,12 @@ class _SyncChat: def __init__(self, responses: list[dict[str, Any]]): self.responses = list(responses) self.calls = 0 + self.messages: list[list[dict[str, Any]]] = [] def generate(self, messages: list[dict[str, Any]], **opts: Any) -> str: - del messages, opts + del opts self.calls += 1 + self.messages.append(messages) return json.dumps(self.responses.pop(0)) @@ -58,6 +61,8 @@ async def generate(self, text: str) -> list[float]: class _Store: def __init__(self, docs: list[dict[str, Any]]): self.docs = [dict(doc) for doc in docs] + self.search_calls: list[dict[str, Any]] = [] + self.search_results: list[dict[str, Any]] = [] def query(self, sql: str, parameters=None, partition_key=None, cross_partition: bool = False): del partition_key, cross_partition @@ -101,6 +106,26 @@ def mark_superseded(self, old_doc: dict[str, Any], superseder_id: str, *, reason del old_doc, superseder_id, reason return True + def search( + self, + *, + search_terms: str, + user_id: str, + memory_types: list[str], + top_k: int, + include_superseded: bool = False, + ) -> list[dict[str, Any]]: + self.search_calls.append( + { + "search_terms": search_terms, + "user_id": user_id, + "memory_types": memory_types, + "top_k": top_k, + "include_superseded": include_superseded, + } + ) + return [dict(doc) for doc in self.search_results] + class _AsyncStore(_Store): async def query(self, sql: str, parameters=None, partition_key=None, cross_partition: bool = False): @@ -221,53 +246,19 @@ def test_extract_memories_dry_is_byte_deterministic_for_same_llm_response() -> N ) -@pytest.mark.asyncio -async def test_async_extract_memories_dry_shape_is_small_and_has_no_embeddings() -> None: - chat = _AsyncChat([_response()]) - embeddings = _AsyncEmbeddings() - memories_store = _AsyncStore([]) - turns_store = _AsyncStore([_turn(i) for i in range(50)]) - service = AsyncPipelineService( - memories_store, - chat, - embeddings, - containers=_async_containers_for_store(memories_store, turns_store=turns_store), - ) - - output = await service.extract_memories_dry("u1", "t1") - - assert set(output) == {"facts", "episodic", "updates", "processed_turn_docs"} - assert len(json.dumps(output)) < 32 * 1024 - assert all("embedding" not in doc for docs in (output["facts"], output["episodic"]) for doc in docs) - assert embeddings.calls == [] - - -@pytest.mark.asyncio -async def test_async_extract_memories_dry_is_byte_deterministic_for_same_llm_response() -> None: - store = _AsyncStore([]) - turns_store = _AsyncStore([_turn(1)]) - service = AsyncPipelineService( - store, - _AsyncChat([_response(), _response()]), - _AsyncEmbeddings(), - containers=_async_containers_for_store(store, turns_store=turns_store), - ) - - first = await service.extract_memories_dry("u1", "t1") - second = await service.extract_memories_dry("u1", "t1") - - assert json.dumps(first, sort_keys=True, separators=(",", ":")) == json.dumps( - second, sort_keys=True, separators=(",", ":") - ) - - -def test_dry_returns_processed_turn_docs_for_watermarking() -> None: - """``extract_memories_dry`` must surface the turn docs it processed so - ``persist_extracted_memories`` can stamp ``extracted_at`` on them and - the next extraction call doesn't reprocess them.""" +def test_extract_memories_dry_stage1_searches_user_turn_text_by_default() -> None: chat = _SyncChat([_response()]) memories_store = _Store([]) - turns_store = _Store([_turn(i) for i in range(3)]) + memories_store.search_results = [ + { + "id": "memory-hybrid", + "content": "Existing hybrid memory from search.", + "type": "fact", + "salience": 0.7, + } + ] + turns = [_turn(1), _turn(2)] + turns_store = _Store(turns) service = PipelineService( memories_store, chat, @@ -275,410 +266,213 @@ def test_dry_returns_processed_turn_docs_for_watermarking() -> None: containers=_containers_for_store(memories_store, turns_store=turns_store), ) - output = service.extract_memories_dry("u1", "t1") + service.extract_memories_dry("u1", "t1") - assert "processed_turn_docs" in output - assert {d["id"] for d in output["processed_turn_docs"]} == {"turn-0", "turn-1", "turn-2"} + assert memories_store.search_calls == [ + { + "search_terms": "\n".join(turn["content"] for turn in turns), + "user_id": "u1", + "memory_types": ["fact"], + "top_k": 10, + "include_superseded": False, + } + ] + assert "Existing hybrid memory from search." in json.dumps(chat.messages) -def test_dry_alone_does_not_mark_turns_as_extracted() -> None: - """A dry run is read-only: it must NOT stamp ``extracted_at`` on any - turn (the wet ``extract_memories`` orchestrator handles marking only - after a successful persist).""" - chat = _SyncChat([_response()]) +def test_extract_memories_dry_stage1_falls_back_to_transcript_without_user_turns() -> None: memories_store = _Store([]) - turns_store = _Store([_turn(i) for i in range(3)]) + turns = [ + { + **_turn(1), + "id": "assistant-turn-1", + "role": "assistant", + "content": "Assistant response with no user-role content.", + } + ] service = PipelineService( memories_store, - chat, + _SyncChat([_response()]), _SyncEmbeddings(), - containers=_containers_for_store(memories_store, turns_store=turns_store), + containers=_containers_for_store(memories_store, turns_store=_Store(turns)), ) service.extract_memories_dry("u1", "t1") - for doc in turns_store.docs: - assert "extracted_at" not in doc or doc["extracted_at"] is None + assert memories_store.search_calls == [ + { + "search_terms": service._build_transcript(turns), + "user_id": "u1", + "memory_types": ["fact"], + "top_k": 10, + "include_superseded": False, + } + ] -def test_extract_memories_marks_turns_after_successful_persist() -> None: - """The wet ``extract_memories`` must stamp ``extracted_at`` on each - turn it processed. Without this, the next extraction call re-loads - the same turns and the LLM re-decides UPDATE/CONTRADICT — which is - the runaway-extraction bug this fix is designed to prevent.""" +def test_extract_memories_dry_stage1_legacy_context_does_not_call_search(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("azure.cosmos.agent_memory.thresholds.get_dedup_context_vector_enabled", lambda: False) chat = _SyncChat([_response()]) - memories_store = _Store([]) - turns_store = _Store([_turn(i) for i in range(3)]) - service = PipelineService( - memories_store, - chat, - _SyncEmbeddings(), - containers=_containers_for_store(memories_store, turns_store=turns_store), + memories_store = _Store( + [ + { + "id": "legacy-memory", + "user_id": "u1", + "type": "fact", + "content": "Existing legacy memory from load.", + "salience": 0.6, + } + ] ) - - service.extract_memories("u1", "t1") - - marked_turns = [doc for doc in turns_store.docs if doc.get("extracted_at")] - assert len(marked_turns) == 3 - marked_ids = {doc["id"] for doc in marked_turns} - assert marked_ids == {"turn-0", "turn-1", "turn-2"} - - -def test_second_extract_call_does_not_reprocess_already_extracted_turns() -> None: - """End-to-end watermarking proof: after a first ``extract_memories`` - marks the turns, a second call with no NEW turns must produce zero - work — no second LLM call, no second persist. This is the property - that prevents reversed-supersede / hallucinated-meta-fact bugs.""" - chat = _SyncChat([_response(), _response()]) # second response should never be consumed - memories_store = _Store([]) - turns_store = _Store([_turn(i) for i in range(3)]) service = PipelineService( memories_store, chat, _SyncEmbeddings(), - containers=_containers_for_store(memories_store, turns_store=turns_store), + containers=_containers_for_store(memories_store, turns_store=_Store([_turn(1)])), ) - service.extract_memories("u1", "t1") - calls_after_first = chat.calls + service.extract_memories_dry("u1", "t1") - # Second invocation with no new turns: watermarked turns are filtered - # out by the query and the dry early-returns with empty items. - service.extract_memories("u1", "t1") - assert chat.calls == calls_after_first + assert memories_store.search_calls == [] + assert "Existing legacy memory from load." in json.dumps(chat.messages) @pytest.mark.asyncio -async def test_async_extract_memories_marks_turns_after_successful_persist() -> None: +async def test_async_extract_memories_dry_shape_is_small_and_has_no_embeddings(monkeypatch) -> None: + monkeypatch.setattr( + "azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_context_vector_enabled", lambda: False + ) chat = _AsyncChat([_response()]) + embeddings = _AsyncEmbeddings() memories_store = _AsyncStore([]) - turns_store = _AsyncStore([_turn(i) for i in range(3)]) + turns_store = _AsyncStore([_turn(i) for i in range(50)]) service = AsyncPipelineService( memories_store, chat, - _AsyncEmbeddings(), + embeddings, containers=_async_containers_for_store(memories_store, turns_store=turns_store), ) - await service.extract_memories("u1", "t1") - - marked_turns = [doc for doc in turns_store.docs if doc.get("extracted_at")] - assert len(marked_turns) == 3 - - -# --------------------------------------------------------------------------- -# Grounding-check regression tests -# -# These tests pin down the two known LLM extraction-time failure modes that -# previously corrupted the fact store and required a wheel hotfix: -# -# 1. The LLM synthesizes an ADD by paraphrase-merging 2+ existing facts -# (e.g. "user eats meat" + "user loves steak" → "user loves steak, -# indicating they eat meat") even though the new user turn said nothing -# on the topic. -# 2. The LLM emits a second invented CONTRADICT fact alongside the literal -# user statement (e.g. user says "I love steak"; LLM emits both -# "loves steak" AND "user eats meat" — the second is a phantom -# explicit-negation that polluted the store with claims the user -# didn't make). -# -# The fix is a prompt change that forbids both patterns. Because we can't -# directly test prompt-following at the unit-test level (no real LLM in -# the test loop), we test the structural safety net: -# ``check_extracted_fact_grounding`` logs a WARNING when these patterns -# slip through. The four scenarios below pair a "buggy" LLM response with -# a "clean" one for each pattern, asserting the WARNING fires (or not) -# accordingly. If a future change ever regresses the prompt and the LLM -# starts emitting these patterns, the WARNING in production telemetry -# becomes the visible signal. -# --------------------------------------------------------------------------- - - -def _existing_fact(fid: str, content: str) -> dict[str, Any]: - """Build a minimal existing-fact doc shaped like what - ``_load_existing_memories`` returns from Cosmos.""" - return { - "id": fid, - "user_id": "u1", - "type": "fact", - "content": content, - "content_hash": fid, - "salience": 0.8, - "confidence": 0.9, - "metadata": {"category": "preference"}, - "tags": ["sys:fact"], - } - - -def _moderate_hotels_turn() -> dict[str, Any]: - return { - "id": "turn-new", - "user_id": "u1", - "thread_id": "t1", - "role": "user", - "type": "turn", - "content": "Normally, I prefer moderate hotels.", - "created_at": "2026-06-02T19:00:00+00:00", - } - + output = await service.extract_memories_dry("u1", "t1") -def _steak_seafood_turn() -> dict[str, Any]: - return { - "id": "turn-new", - "user_id": "u1", - "thread_id": "t1", - "role": "user", - "type": "turn", - "content": "Actually, I love steak and seafood.", - "created_at": "2026-06-02T18:00:00+00:00", - } + assert set(output) == {"facts", "episodic", "updates", "processed_turn_docs"} + assert len(json.dumps(output)) < 32 * 1024 + assert all("embedding" not in doc for docs in (output["facts"], output["episodic"]) for doc in docs) + assert embeddings.calls == [] -def test_grounding_check_warns_when_add_synthesizes_from_multiple_existing_facts(caplog) -> None: - """Scenario 1 (buggy): the LLM emits a synthesized ADD whose tokens come - from 2+ existing facts but not from the new user turn. The grounding - check must emit a WARNING naming the offending fact.""" - existing = [ - _existing_fact("fact_meat", "The user eats meat."), - _existing_fact("fact_steak", "The user loves steak and seafood."), - ] - buggy_response = { - "facts": [ - { - "text": "The user normally prefers moderate hotels.", - "action": "ADD", - "category": "preference", - "confidence": 0.9, - "salience": 0.7, - }, - { - # synthesized — tokens come from existing fact_steak + fact_meat, - # not from the new "moderate hotels" turn - "text": "The user loves steak and seafood, indicating they eat meat.", - "action": "ADD", - "category": "preference", - "confidence": 0.9, - "salience": 0.7, - }, - ], - "episodic": [], - } - chat = _SyncChat([buggy_response]) - memories_store = _Store(existing) - turns_store = _Store([_moderate_hotels_turn()]) - service = PipelineService( - memories_store, - chat, - _SyncEmbeddings(), - containers=_containers_for_store(memories_store, turns_store=turns_store), +@pytest.mark.asyncio +async def test_async_extract_memories_dry_is_byte_deterministic_for_same_llm_response(monkeypatch) -> None: + monkeypatch.setattr( + "azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_context_vector_enabled", lambda: False + ) + store = _AsyncStore([]) + turns_store = _AsyncStore([_turn(1)]) + service = AsyncPipelineService( + store, + _AsyncChat([_response(), _response()]), + _AsyncEmbeddings(), + containers=_async_containers_for_store(store, turns_store=turns_store), ) - with caplog.at_level("WARNING", logger="azure.cosmos.agent_memory.pipeline"): - service.extract_memories_dry("u1", "t1") + first = await service.extract_memories_dry("u1", "t1") + second = await service.extract_memories_dry("u1", "t1") - synthesis_warnings = [ - rec - for rec in caplog.records - if "synthesized from" in rec.getMessage() and "steak and seafood, indicating they eat meat" in rec.getMessage() - ] - assert synthesis_warnings, ( - f"expected a WARNING flagging the synthesized fact; got: {[rec.getMessage() for rec in caplog.records]}" - ) - # The grounded "moderate hotels" fact must NOT trigger a warning. - assert not any( - "moderate hotels" in rec.getMessage() and "synthesized from" in rec.getMessage() for rec in caplog.records + assert json.dumps(first, sort_keys=True, separators=(",", ":")) == json.dumps( + second, sort_keys=True, separators=(",", ":") ) -def test_grounding_check_silent_when_add_is_grounded_in_user_turn(caplog) -> None: - """Scenario 2 (clean): with the same existing-facts context, a single - grounded ADD (only the new "moderate hotels" claim) must NOT trigger - any synthesis WARNING. This is the post-fix expected behaviour.""" - existing = [ - _existing_fact("fact_meat", "The user eats meat."), - _existing_fact("fact_steak", "The user loves steak and seafood."), - ] - clean_response = { - "facts": [ +@pytest.mark.asyncio +async def test_async_extract_memories_dry_stage1_searches_user_turn_text_by_default() -> None: + store = _AsyncStore([]) + store.search = AsyncMock( + return_value=[ { - "text": "The user normally prefers moderate hotels.", - "action": "ADD", - "category": "preference", - "confidence": 0.9, + "id": "fact-1", + "content": "The user prefers dark mode.", + "type": "fact", "salience": 0.7, } - ], - "episodic": [], - } - chat = _SyncChat([clean_response]) - memories_store = _Store(existing) - turns_store = _Store([_moderate_hotels_turn()]) - service = PipelineService( - memories_store, - chat, - _SyncEmbeddings(), - containers=_containers_for_store(memories_store, turns_store=turns_store), + ] ) - - with caplog.at_level("WARNING", logger="azure.cosmos.agent_memory.pipeline"): - service.extract_memories_dry("u1", "t1") - - grounding_warnings = [ - rec - for rec in caplog.records - if "synthesized from" in rec.getMessage() or "phantom-negation/restatement" in rec.getMessage() - ] - assert grounding_warnings == [], ( - f"clean output must not emit grounding warnings; got: {[rec.getMessage() for rec in grounding_warnings]}" + turns_store = _AsyncStore([_turn(1)]) + service = AsyncPipelineService( + store, + _AsyncChat([_response()]), + _AsyncEmbeddings(), + containers=_async_containers_for_store(store, turns_store=turns_store), ) + service._vector_candidates = AsyncMock(return_value=[]) + turns = [_turn(1)] + await service.extract_memories_dry("u1", "t1") -def test_grounding_check_warns_on_phantom_explicit_negation_fact(caplog) -> None: - """Scenario 3 (buggy): user said "Actually, I love steak and seafood"; - LLM emits both a literal-paraphrase fact AND a phantom "user eats meat" - fact (an invented explicit-negation of the prior vegetarian fact). - The phantom fact's tokens are NOT in the user turn and they overlap - a single existing fact ("does not eat meat") — the single-contributor - branch of the grounding heuristic must fire a WARNING.""" - existing = [_existing_fact("fact_veg", "The user does not eat meat.")] - buggy_response = { - "facts": [ - { - # legitimate literal paraphrase of the user turn - "text": "The user loves steak and seafood.", - "action": "CONTRADICT", - "supersedes_id": "fact_veg", - "category": "preference", - "confidence": 0.95, - "salience": 0.8, - }, - { - # phantom — user never said this; tokens come from existing fact_veg - "text": "The user eats meat.", - "action": "CONTRADICT", - "supersedes_id": "fact_veg", - "category": "preference", - "confidence": 0.95, - "salience": 0.8, - }, - ], - "episodic": [], - } - chat = _SyncChat([buggy_response]) - memories_store = _Store(existing) - turns_store = _Store([_steak_seafood_turn()]) - service = PipelineService( - memories_store, - chat, - _SyncEmbeddings(), - containers=_containers_for_store(memories_store, turns_store=turns_store), + store.search.assert_awaited_once_with( + search_terms="\n".join(turn["content"] for turn in turns), + user_id="u1", + memory_types=["fact"], + top_k=10, ) + service._vector_candidates.assert_not_awaited() - with caplog.at_level("WARNING", logger="azure.cosmos.agent_memory.pipeline"): - service.extract_memories_dry("u1", "t1") - phantom_warnings = [ - rec for rec in caplog.records if "phantom-negation" in rec.getMessage() and "eats meat" in rec.getMessage() +@pytest.mark.asyncio +async def test_async_extract_memories_dry_stage1_falls_back_to_transcript_without_user_turns() -> None: + store = _AsyncStore([]) + store.search = AsyncMock(return_value=[]) + turns = [ + { + **_turn(1), + "id": "assistant-turn-1", + "role": "assistant", + "content": "Assistant response with no user-role content.", + } ] - assert phantom_warnings, ( - f"expected a WARNING flagging the phantom-negation fact; got: {[rec.getMessage() for rec in caplog.records]}" - ) - # The legitimate "loves steak and seafood" fact must NOT trigger a warning; - # its tokens are grounded in the user turn. - assert not any( - "loves steak and seafood" in rec.getMessage() - and ("phantom-negation" in rec.getMessage() or "synthesized from" in rec.getMessage()) - for rec in caplog.records - ) - - -def test_grounding_check_silent_on_clean_implicit_contradict(caplog) -> None: - """Scenario 4 (clean): the post-fix expected behaviour for an implicit - contradiction — ONE fact with literal user text and a CONTRADICT - supersedes_id. No phantom-negation fact, no WARNING.""" - existing = [_existing_fact("fact_veg", "The user does not eat meat.")] - clean_response = { - "facts": [ - { - "text": "The user loves steak and seafood.", - "action": "CONTRADICT", - "supersedes_id": "fact_veg", - "category": "preference", - "confidence": 0.95, - "salience": 0.8, - } - ], - "episodic": [], - } - chat = _SyncChat([clean_response]) - memories_store = _Store(existing) - turns_store = _Store([_steak_seafood_turn()]) - service = PipelineService( - memories_store, - chat, - _SyncEmbeddings(), - containers=_containers_for_store(memories_store, turns_store=turns_store), + turns_store = _AsyncStore(turns) + service = AsyncPipelineService( + store, + _AsyncChat([_response()]), + _AsyncEmbeddings(), + containers=_async_containers_for_store(store, turns_store=turns_store), ) + transcript = service._build_transcript(turns) - with caplog.at_level("WARNING", logger="azure.cosmos.agent_memory.pipeline"): - service.extract_memories_dry("u1", "t1") + await service.extract_memories_dry("u1", "t1") - grounding_warnings = [ - rec - for rec in caplog.records - if "synthesized from" in rec.getMessage() or "phantom-negation/restatement" in rec.getMessage() - ] - assert grounding_warnings == [], ( - "clean implicit-contradict must not emit grounding warnings; got: " - f"{[rec.getMessage() for rec in grounding_warnings]}" + store.search.assert_awaited_once_with( + search_terms=transcript, + user_id="u1", + memory_types=["fact"], + top_k=10, ) @pytest.mark.asyncio -async def test_async_grounding_check_warns_on_synthesis(caplog) -> None: - """Async-path mirror of scenario 1: confirms the grounding heuristic - is wired into both sync and async extract pipelines.""" - existing = [ - _existing_fact("fact_meat", "The user eats meat."), - _existing_fact("fact_steak", "The user loves steak and seafood."), - ] - buggy_response = { - "facts": [ - { - "text": "The user normally prefers moderate hotels.", - "action": "ADD", - "category": "preference", - "confidence": 0.9, - "salience": 0.7, - }, +async def test_async_extract_memories_dry_stage1_legacy_path_when_context_vector_unset(monkeypatch) -> None: + monkeypatch.setattr( + "azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_context_vector_enabled", lambda: False + ) + store = _AsyncStore( + [ { - "text": "The user loves steak and seafood, indicating they eat meat.", - "action": "ADD", - "category": "preference", - "confidence": 0.9, - "salience": 0.7, - }, - ], - "episodic": [], - } - chat = _AsyncChat([buggy_response]) - memories_store = _AsyncStore(existing) - turns_store = _AsyncStore([_moderate_hotels_turn()]) + "id": "fact-1", + "user_id": "u1", + "type": "fact", + "content": "Existing fact.", + "content_hash": "hash-1", + } + ] + ) + store.search = AsyncMock(return_value=[]) + turns_store = _AsyncStore([_turn(1)]) service = AsyncPipelineService( - memories_store, - chat, + store, + _AsyncChat([_response()]), _AsyncEmbeddings(), - containers=_async_containers_for_store(memories_store, turns_store=turns_store), + containers=_async_containers_for_store(store, turns_store=turns_store), ) - with caplog.at_level("WARNING", logger="azure.cosmos.agent_memory.pipeline.aio"): - await service.extract_memories_dry("u1", "t1") + await service.extract_memories_dry("u1", "t1") - synthesis_warnings = [ - rec - for rec in caplog.records - if "synthesized from" in rec.getMessage() and "steak and seafood, indicating they eat meat" in rec.getMessage() - ] - assert synthesis_warnings, ( - f"expected an async WARNING flagging the synthesized fact; got: {[rec.getMessage() for rec in caplog.records]}" - ) + store.search.assert_not_awaited() diff --git a/tests/unit/services/test_pipeline_service.py b/tests/unit/services/test_pipeline_service.py index 587f680..27518b5 100644 --- a/tests/unit/services/test_pipeline_service.py +++ b/tests/unit/services/test_pipeline_service.py @@ -10,6 +10,22 @@ from azure.cosmos.agent_memory.services.pipeline import PipelineService, _StoreContainerAdapter +@pytest.fixture(autouse=True) +def _pin_legacy_dedup_paths(monkeypatch): + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_reconcile_mode", + lambda: "full_pool", + ) + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_vector_enabled", + lambda: False, + ) + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_context_vector_enabled", + lambda: False, + ) + + class FakeLLMService: """Test helper exposing chat_client + embeddings_client pair. @@ -203,7 +219,6 @@ def test_extract_memories_happy_path_writes_fact_and_episodic() -> None: { "scope_type": "project", "scope_value": "CI", - "text": "Stabilized flaky CI tests by adding retries.", "situation": "CI tests flaked intermittently", "action_taken": "Added retries", "outcome": "Tests stabilized", @@ -226,12 +241,14 @@ def test_extract_memories_happy_path_writes_fact_and_episodic() -> None: assert llm.embed_calls == [ [ "The user prefers dark mode.", - "Stabilized flaky CI tests by adding retries.", + "CI tests flaked intermittently → Added retries → Tests stabilized", ] ] -def test_extract_memories_contradict_supersedes_existing_fact() -> None: +def test_extract_memories_creates_new_fact_without_superseding() -> None: + # Extract no longer detects updates/contradictions — it just adds new facts. + # Contradiction resolution happens later in reconcile, not at extract time. old = _fact("old_fact", "The user prefers light mode.") store = FakeStore([old]) turns_store = FakeStore([_turn("Actually, I prefer dark mode now.")]) @@ -241,8 +258,6 @@ def test_extract_memories_contradict_supersedes_existing_fact() -> None: "facts": [ { "text": "The user prefers dark mode.", - "action": "CONTRADICT", - "supersedes_id": "old_fact", "category": "preference", } ] @@ -253,12 +268,10 @@ def test_extract_memories_contradict_supersedes_existing_fact() -> None: result = _pipeline(store, llm, turns_store=turns_store).extract_memories("u1", "t1") assert result["fact_count"] == 1 - assert result["contradicted_count"] == 1 - assert store.supersede_calls[0][2] == "contradict" + assert result["contradicted_count"] == 0 + assert store.supersede_calls == [] old_doc = next(doc for doc in store.docs if doc["id"] == "old_fact") - assert old_doc["superseded_by"] == store.upserts[0]["id"] - assert old_doc["supersede_reason"] == "contradict" - assert old_doc["superseded_at"] + assert "superseded_by" not in old_doc def test_synthesize_procedural_produces_procedural_memory() -> None: diff --git a/tests/unit/store/test_memory_store.py b/tests/unit/store/test_memory_store.py index 45a5117..e4371e1 100644 --- a/tests/unit/store/test_memory_store.py +++ b/tests/unit/store/test_memory_store.py @@ -273,6 +273,63 @@ def test_search_adds_created_time_range_filters(): assert params["@created_before"] == "2026-03-01T00:00:00+00:00" +def test_search_uses_keyword_params_for_hybrid_sql(): + memories = MagicMock() + memories.query_items.return_value = [] + embeddings = MagicMock() + embeddings.generate.return_value = [0.1, 0.2] + store = MemoryStore(containers=_containers(memories=memories), embeddings_client=embeddings) + + store.search("weather in Seattle", user_id="u1") + + call_kwargs = memories.query_items.call_args.kwargs + assert "ORDER BY RANK RRF" in call_kwargs["query"] + assert "FullTextScore(c.content, @kw0, @kw1)" in call_kwargs["query"] + params = _params_by_name(call_kwargs) + assert params["@kw0"] == "weather" + assert params["@kw1"] == "seattle" + + +def test_search_all_stopwords_falls_back_to_vector_only(): + memories = MagicMock() + memories.query_items.return_value = [] + embeddings = MagicMock() + embeddings.generate.return_value = [0.1, 0.2] + store = MemoryStore(containers=_containers(memories=memories), embeddings_client=embeddings) + + store.search("what is the", user_id="u1") + + call_kwargs = memories.query_items.call_args.kwargs + assert "ORDER BY VectorDistance" in call_kwargs["query"] + assert "RANK RRF" not in call_kwargs["query"] + assert not any(name.startswith("@kw") for name in _params_by_name(call_kwargs)) + + +def test_search_episodic_forwards_search_options(): + store = MemoryStore(containers=_containers()) + store.search = MagicMock(return_value=[]) + + store.search_episodic("u1", "weather") + + store.search.assert_called_once_with( + search_terms="weather", + user_id="u1", + memory_types=["episodic"], + top_k=5, + min_salience=None, + include_superseded=False, + ) + + +def test_build_episodic_context_forwards_search_options(): + store = MemoryStore(containers=_containers()) + store.search_episodic = MagicMock(return_value=[]) + + assert store.build_episodic_context("u1", "weather") == "" + + store.search_episodic.assert_called_once_with("u1", "weather", top_k=3) + + def test_add_cosmos_routes_by_type(): turns = MagicMock() memories = MagicMock() diff --git a/tests/unit/test_auto_trigger.py b/tests/unit/test_auto_trigger.py index 03c5420..4bb73e7 100644 --- a/tests/unit/test_auto_trigger.py +++ b/tests/unit/test_auto_trigger.py @@ -10,10 +10,52 @@ from unittest.mock import MagicMock, patch +import pytest +from azure.cosmos.exceptions import CosmosResourceNotFoundError + +from azure.cosmos.agent_memory import _counters from azure.cosmos.agent_memory.cosmos_memory_client import CosmosMemoryClient from azure.cosmos.agent_memory.processors import DurableFunctionProcessor, InProcessProcessor +class _FakeCounterContainer: + """Minimal in-memory stand-in for the Cosmos counter container. + + Exercises the REAL increment / watermark-read / watermark-advance helpers + end-to-end (no mocking of counter math), so a watermark/recent_k regression + can't slip through behind constant mocks. + """ + + def __init__(self) -> None: + self.store: dict[str, dict] = {} + self._etag = 0 + + def read_item(self, *, item, partition_key): + if item not in self.store: + raise CosmosResourceNotFoundError(message="404") + return dict(self.store[item]) + + def create_item(self, *, body): + self._etag += 1 + body = dict(body) + body["_etag"] = f"e{self._etag}" + self.store[body["id"]] = body + return dict(body) + + def upsert_item(self, *, body, **_kwargs): + self._etag += 1 + body = dict(body) + body["_etag"] = f"e{self._etag}" + self.store[body["id"]] = body + return dict(body) + + def patch_item(self, *, item, partition_key, patch_operations): + doc = self.store.setdefault(item, {"id": item}) + for op in patch_operations: + doc[op["path"].lstrip("/")] = op["value"] + return dict(doc) + + def _connected(processor=None) -> CosmosMemoryClient: client = CosmosMemoryClient(use_default_credential=False, processor=processor) client._memories_container_client = MagicMock() @@ -43,7 +85,7 @@ def test_push_to_cosmos_fires_inprocess_trigger_per_turn(monkeypatch): client.add_local(user_id="u1", role="user", thread_id="t1", content="hi") client.push_to_cosmos() - pipeline.extract_memories.assert_called_once_with("u1", "t1") + pipeline.extract_memories.assert_called_once_with("u1", "t1", recent_k=1) def test_push_to_cosmos_durable_does_not_fire_trigger(monkeypatch): @@ -143,10 +185,178 @@ def test_extract_fires_independently_of_summary(self, monkeypatch): client.add_local(user_id="u1", role="user", thread_id="t1", content="hi") client.push_to_cosmos() - processor.process_extract_memories.assert_called_once_with(user_id="u1", thread_id="t1") + processor.process_extract_memories.assert_called_once_with(user_id="u1", thread_id="t1", recent_k=1) processor.process_thread_summary.assert_not_called() processor.process_user_summary.assert_not_called() + @pytest.mark.parametrize( + ("n_facts", "batch_count", "counter_result", "watermark", "expected_recent_k"), + [ + (1, 1, (0, 1), None, 1), + (1, 3, (0, 3), None, 3), + (5, 1, (4, 5), None, 5), + (1, 1, (5, 10), 5, 5), + # Large backlog is NOT capped: recent_k spans every turn since the + # watermark (newest-recent_k slice covers exactly those), so the + # watermark can advance to new_count with no stranded turns. + (1, 1, (98, 100), 0, 100), + # BOOTSTRAP regression: no watermark yet but the counter is already + # ahead of this batch (earlier extracts failed). base=0 so recent_k = + # new_count (30) covers ALL turns — the old fallback max(n_facts, + # batch_count) would return 2 and strand turns 1-28 forever. + (1, 2, (20, 30), None, 30), + ], + ) + def test_extract_recent_k_uses_watermark_then_falls_back( + self, + monkeypatch, + n_facts, + batch_count, + counter_result, + watermark, + expected_recent_k, + ): + monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", str(n_facts)) + monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") + + processor = InProcessProcessor(pipeline=MagicMock()) + processor.process_extract_memories = MagicMock(return_value={}) + + client = _connected(processor=processor) + client._counter_container_client = MagicMock() + + with patch( + "azure.cosmos.agent_memory._counters.increment_counter_sync", + return_value=counter_result, + ), patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_sync", + return_value=watermark, + ), patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_sync", + ) as advance: + for i in range(batch_count): + client.add_local(user_id="u1", role="user", thread_id="t1", content=f"hi {i}") + client.push_to_cosmos() + + processor.process_extract_memories.assert_called_once_with( + user_id="u1", + thread_id="t1", + recent_k=expected_recent_k, + ) + advance.assert_called_once() + + def test_watermark_round_trip_fail_then_succeed_no_strand(self, monkeypatch): + """End-to-end round-trip against a REAL in-memory counter (no constant + mocks): a thread's first extract fails, the second succeeds, and the + second must cover EVERY turn so far — not just its own batch — so turns + from the failed batch are never stranded. This is the bootstrap case the + constant-mock tests could not catch.""" + monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", "1") + monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") + + counter = _FakeCounterContainer() + recorded_recent_k: list[int] = [] + + def extract(*, user_id, thread_id, recent_k): + recorded_recent_k.append(recent_k) + if len(recorded_recent_k) == 1: + raise RuntimeError("transient LLM outage") # first extract fails + return {} + + processor = InProcessProcessor(pipeline=MagicMock()) + processor.process_extract_memories = MagicMock(side_effect=extract) + client = _connected(processor=processor) + client._counter_container_client = counter + + # Batch 1: 10 turns -> counter 0->10, extract FAILS (watermark not advanced). + for i in range(10): + client.add_local(user_id="u1", role="user", thread_id="t1", content=f"a{i}") + client.push_to_cosmos() + + # Batch 2: 10 turns -> counter 10->20, extract SUCCEEDS. + for i in range(10): + client.add_local(user_id="u1", role="user", thread_id="t1", content=f"b{i}") + client.push_to_cosmos() + + # First fired with 10 (all turns so far); second with 20 (ALL turns, since + # the failed first extract left the watermark unset) — NOT 10. + assert recorded_recent_k == [10, 20] + # Watermark now seeded at the full count after the successful extract. + cid = _counters.thread_counter_id("u1", "t1") + assert _counters.read_extract_watermark_sync(counter, cid, "u1", "t1") == 20 + + def test_watermark_not_advanced_when_extract_fails(self, monkeypatch): + """advance-on-success: a failing extract must NOT move the watermark, so + the skipped turns are retried next sweep; failure is stamped instead.""" + monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", "1") + monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") + + processor = InProcessProcessor(pipeline=MagicMock()) + processor.process_extract_memories = MagicMock(side_effect=RuntimeError("llm down")) + + client = _connected(processor=processor) + client._counter_container_client = MagicMock() + + with patch( + "azure.cosmos.agent_memory._counters.increment_counter_sync", + return_value=(0, 1), + ), patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_sync", + return_value=None, + ), patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_sync", + ) as advance, patch( + "azure.cosmos.agent_memory._counters.stamp_failure_sync", + ) as stamp: + client.add_local(user_id="u1", role="user", thread_id="t1", content="hi") + client.push_to_cosmos() + + advance.assert_not_called() + stamp.assert_called_once() + + def test_reconcile_full_rebuild_on_persisted_counter_cadence(self, monkeypatch): + """Symmetry with the durable backend: the in-process auto-trigger requests + a full-pool reconcile (full_rebuild=True) on a PERSISTED-counter cadence — + every DEDUP_FULL_RECLUSTER_EVERY_N-th reconcile — not via an in-memory + per-instance sweep counter. Here that's every 2 turns.""" + monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", "1") + monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") + monkeypatch.setenv("DEDUP_EVERY_N", "1") + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_full_recluster_every_n", lambda: 2 + ) + + rebuilds: list[bool] = [] + processor = InProcessProcessor(pipeline=MagicMock()) + processor.process_extract_memories = MagicMock(return_value={}) + processor.synthesize_procedural = MagicMock(return_value=None) + processor.process_reconcile = MagicMock( + side_effect=lambda *, user_id, full_rebuild=False: rebuilds.append(full_rebuild) + ) + + client = _connected(processor=processor) + client._counter_container_client = MagicMock() + + with patch( + "azure.cosmos.agent_memory._counters.increment_counter_sync", + side_effect=[(0, 1), (1, 2)], + ), patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_sync", + return_value=None, + ), patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_sync", + ): + client.add_local(user_id="u1", role="user", thread_id="t1", content="a") + client.push_to_cosmos() # counter 0->1: reconcile, full crosses 2? no + client.add_local(user_id="u1", role="user", thread_id="t1", content="b") + client.push_to_cosmos() # counter 1->2: full backstop threshold (2) crossed + + assert rebuilds == [False, True] + def test_summary_fires_independently_when_threshold_crossed(self, monkeypatch): """N_summary=10 boundary fires summary; N_facts=0 prevents extract.""" monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", "0") diff --git a/tests/unit/test_cosmos_memory_client.py b/tests/unit/test_cosmos_memory_client.py index c04b000..bae9f12 100644 --- a/tests/unit/test_cosmos_memory_client.py +++ b/tests/unit/test_cosmos_memory_client.py @@ -453,31 +453,31 @@ def test_create_memory_store_defaults_to_serverless(self): mock_lease_container, ] - mem = _make_client(cosmos_throughput_mode="serverless") - - with patch.dict("os.environ", {"COSMOS_DB_AUTOSCALE_MAX_RU": "not-an-int"}, clear=False): - with patch.dict( - "sys.modules", - { - "azure.cosmos": MagicMock( - CosmosClient=mock_cosmos_cls, - PartitionKey=MagicMock(), - ThroughputProperties=MagicMock(), - ), - }, - ): - mem.create_memory_store( - endpoint="https://fake.documents.azure.com:443/", - credential="fake-key", - throughput_mode="serverless", - ) + # serverless mode ignores autoscale config entirely, even an invalid value. + mem = _make_client(cosmos_throughput_mode="serverless", cosmos_autoscale_max_ru="not-an-int") + + with patch.dict( + "sys.modules", + { + "azure.cosmos": MagicMock( + CosmosClient=mock_cosmos_cls, + PartitionKey=MagicMock(), + ThroughputProperties=MagicMock(), + ), + }, + ): + mem.create_memory_store( + endpoint="https://fake.documents.azure.com:443/", + credential="fake-key", + throughput_mode="serverless", + ) for call in mock_db.create_container_if_not_exists.call_args_list: assert "offer_throughput" not in call.kwargs - def test_constructor_ignores_invalid_autoscale_env_in_serverless_mode(self): - with patch.dict("os.environ", {"COSMOS_DB_AUTOSCALE_MAX_RU": "not-an-int"}, clear=False): - mem = _make_client(cosmos_throughput_mode="serverless") + def test_constructor_ignores_autoscale_in_serverless_mode(self): + # Even an invalid autoscale value is ignored in serverless mode. + mem = _make_client(cosmos_throughput_mode="serverless", cosmos_autoscale_max_ru="not-an-int") assert mem._cosmos_autoscale_max_ru is None @@ -797,7 +797,7 @@ def test_search_cosmos(self): assert "VectorDistance" in call_kwargs["query"] assert len(result) == 1 - def test_search_hybrid(self): + def test_search_hybrid_uses_keyword_params(self): mem, container = _connected_client() container.query_items.return_value = [_make_doc()] @@ -806,14 +806,79 @@ def test_search_hybrid(self): mem.search_cosmos( search_terms="weather Seattle", - hybrid_search=True, top_k=5, ) call_kwargs = container.query_items.call_args.kwargs query = call_kwargs["query"] + params = {p["name"]: p["value"] for p in call_kwargs["parameters"]} assert "RANK RRF" in query - assert "FullTextScore" in query + assert "FullTextScore(c.content, @kw0, @kw1)" in query + assert params["@kw0"] == "weather" + assert params["@kw1"] == "seattle" + + def test_search_cosmos_forwards_search_options_to_store(self): + mem, _ = _connected_client() + store = MagicMock() + store._containers = mem._containers + store._embeddings_client = mem._embeddings_client + store.search.return_value = [] + mem._store = store + + mem.search_cosmos(search_terms="weather") + + store.search.assert_called_once_with( + search_terms="weather", + memory_id=None, + user_id=None, + role=None, + memory_types=None, + thread_id=None, + top_k=5, + tags_all=None, + tags_any=None, + exclude_tags=None, + include_superseded=False, + min_salience=None, + min_confidence=None, + created_after=None, + created_before=None, + ) + + def test_search_episodic_memories_forwards_search_options(self): + mem, _ = _connected_client() + store = MagicMock() + store._containers = mem._containers + store._embeddings_client = mem._embeddings_client + store.search_episodic.return_value = [] + mem._store = store + + mem.search_episodic_memories("u1", "weather") + + store.search_episodic.assert_called_once_with( + user_id="u1", + search_terms="weather", + top_k=5, + min_salience=None, + include_superseded=False, + ) + + def test_build_episodic_context_forwards_search_options(self): + mem, _ = _connected_client() + store = MagicMock() + store._containers = mem._containers + store._embeddings_client = mem._embeddings_client + store.build_episodic_context.return_value = "context" + mem._store = store + + result = mem.build_episodic_context("u1", "weather") + + assert result == "context" + store.build_episodic_context.assert_called_once_with( + user_id="u1", + query="weather", + top_k=3, + ) def test_search_turns(self): mem, container = _connected_client() diff --git a/tests/unit/test_pipeline_confidence.py b/tests/unit/test_pipeline_confidence.py index fce8308..3dcf20e 100644 --- a/tests/unit/test_pipeline_confidence.py +++ b/tests/unit/test_pipeline_confidence.py @@ -13,6 +13,18 @@ from azure.cosmos.agent_memory.store import MemoryStore +@pytest.fixture(autouse=True) +def _pin_legacy_extract_dedup(monkeypatch): + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_vector_enabled", + lambda: False, + ) + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_context_vector_enabled", + lambda: False, + ) + + def _make_pipeline(llm_response: dict): turns_container = MagicMock() memories_container = MagicMock() @@ -139,7 +151,6 @@ def test_extract_episodic_carries_confidence(): { "scope_type": "project", "scope_value": "CI revamp", - "text": "Set up CI by adding Ruff — faster lint times.", "situation": "Setup CI", "action_taken": "Added Ruff", "outcome": "Faster lint", @@ -286,17 +297,11 @@ def test_thread_ids_does_not_appear_in_query_or_parameters(self): # --------------------------------------------------------------------------- -def test_extract_drops_episodic_missing_text(caplog): - """An episodic with no ``text`` is dropped and surfaced via the return value. +def test_extract_scoped_intent_without_outcome_stores_correctly(caplog): + """An episodic with only scope fields (no situation/action/outcome) is kept. - Previously the pipeline synthesized boilerplate content like - ``"For the user's Paris trip, intent recorded."`` which was - semantically empty for embedding/recall. The fix is to require - the LLM to emit ``text`` (same field facts use) — if it doesn't, - drop the record so we don't poison the recall index. The drop is - logged at ERROR (it's data loss) and surfaced via the - ``dropped_episodic_count`` field on the return dict so callers - can monitor LLM-extraction compliance over time. + The doc must use the deterministic fallback content string, expose the + scope fields at the top level, and not emit a "dropping malformed" warning. """ pipeline, upserted = _make_pipeline( { @@ -306,34 +311,36 @@ def test_extract_drops_episodic_missing_text(caplog): "scope_value": "Paris", "confidence": 0.95, "salience": 0.8, - "tags": ["topic:travel", "topic:hotels"], } ] } ) - with caplog.at_level("ERROR", logger="azure.cosmos.agent_memory.pipeline"): - result = pipeline.extract_memories("u1", "t1") + with caplog.at_level("WARNING", logger="azure.cosmos.agent_memory.pipeline"): + pipeline.extract_memories("u1", "t1") eps = [d for d in upserted if d["type"] == "episodic"] - assert eps == [] - assert result["episodic_count"] == 0 - assert result["dropped_episodic_count"] == 1 - msgs = [rec.getMessage() for rec in caplog.records] - assert any("empty/missing text field" in m for m in msgs) - assert any("reason=missing_text" in m for m in msgs) - # Bumped from WARNING → ERROR because dropping == data loss. - assert any(rec.levelname == "ERROR" and "empty/missing text field" in rec.getMessage() for rec in caplog.records) + assert len(eps) == 1 + ep = eps[0] + assert ep["scope_type"] == "trip" + assert ep["scope_value"] == "Paris" + assert ep["metadata"]["scope_type"] == "trip" + assert ep["metadata"]["scope_value"] == "Paris" + assert ep["metadata"]["situation"] is None + assert ep["metadata"]["action_taken"] is None + assert ep["metadata"]["outcome"] is None + assert ep["content"] == "For the user's Paris trip, intent recorded." + assert ep["confidence"] == pytest.approx(0.95) + assert not any("dropping malformed episodic" in rec.getMessage() for rec in caplog.records) -def test_extract_past_event_episodic_uses_text_and_keeps_chain_in_metadata(): +def test_extract_past_event_episodic_uses_arrow_form_and_keeps_scope(): pipeline, upserted = _make_pipeline( { "episodic": [ { "scope_type": "project", "scope_value": "Acme revamp", - "text": "Migrated Acme DB by running the script — all rows migrated cleanly.", "situation": "Migrated DB", "action_taken": "Ran the script", "outcome": "All rows migrated", @@ -352,7 +359,7 @@ def test_extract_past_event_episodic_uses_text_and_keeps_chain_in_metadata(): pipeline.extract_memories("u1", "t1") [ep] = [d for d in upserted if d["type"] == "episodic"] - assert ep["content"] == "Migrated Acme DB by running the script — all rows migrated cleanly." + assert ep["content"] == "Migrated DB → Ran the script → All rows migrated" assert ep["scope_type"] == "project" assert ep["scope_value"] == "Acme revamp" md = ep["metadata"] @@ -366,13 +373,12 @@ def test_extract_past_event_episodic_uses_text_and_keeps_chain_in_metadata(): assert "topic:db" in ep["tags"] -def test_extract_episodic_uses_text_directly_no_synthesis(): - """The LLM-written ``text`` is the embedded ``content``, verbatim. +def test_extract_episodic_falls_back_to_arrow_form_when_summary_field_present(): + """The schema dropped ``summary``; pipeline now always uses arrow form. - The pipeline must NOT synthesize content from the s/a/o chain or - from scope fields — that's how we ended up with useless boilerplate - before the fix. Whatever the LLM emits in ``text`` is what gets - embedded. + Even if a non-strict LLM smuggles a ``summary`` field through, the + pipeline ignores it and builds content from + ``situation → action_taken → outcome``. """ pipeline, upserted = _make_pipeline( { @@ -380,7 +386,7 @@ def test_extract_episodic_uses_text_directly_no_synthesis(): { "scope_type": "trip", "scope_value": "Paris", - "text": "User wants luxury hotels for the Paris trip.", + "summary": "User wants luxury hotels for the Paris trip.", "situation": "Planning Paris trip", "action_taken": "Said luxury", "outcome": "Pending", @@ -392,106 +398,7 @@ def test_extract_episodic_uses_text_directly_no_synthesis(): pipeline.extract_memories("u1", "t1") [ep] = [d for d in upserted if d["type"] == "episodic"] - assert ep["content"] == "User wants luxury hotels for the Paris trip." - assert ep["metadata"]["situation"] == "Planning Paris trip" - assert ep["metadata"]["action_taken"] == "Said luxury" - assert ep["metadata"]["outcome"] == "Pending" - - -def test_extract_episodic_uses_text_alone_for_planned_intent(): - """Planned/in-flight episodics carry their meaning entirely in ``text``. - - This is the headline bug-1 scenario from the workshop: the LLM (correctly - following the prompt) emits only scope_type/scope_value/text for a - planned trip, and the pipeline must embed the text, not boilerplate. - """ - pipeline, upserted = _make_pipeline( - { - "episodic": [ - { - "scope_type": "trip", - "scope_value": "Tokyo", - "text": ("Planning a Tokyo trip with vegetarian and wheelchair-accessible-restaurant constraints."), - "confidence": 0.95, - "salience": 0.85, - "tags": ["topic:travel", "topic:accessibility"], - } - ] - } - ) - - pipeline.extract_memories("u1", "t1") - - [ep] = [d for d in upserted if d["type"] == "episodic"] - assert ep["content"] == ("Planning a Tokyo trip with vegetarian and wheelchair-accessible-restaurant constraints.") - assert ep["metadata"]["situation"] is None - assert ep["metadata"]["action_taken"] is None - assert ep["metadata"]["outcome"] is None - - -def test_extract_episodic_strips_whitespace_from_text(): - pipeline, upserted = _make_pipeline( - { - "episodic": [ - { - "scope_type": "trip", - "scope_value": "Paris", - "text": " Planning a Paris trip. ", - } - ] - } - ) - pipeline.extract_memories("u1", "t1") - [ep] = [d for d in upserted if d["type"] == "episodic"] - assert ep["content"] == "Planning a Paris trip." - - -def test_extract_compound_statement_yields_facts_across_categories(): - """Bug-2 scenario: a single user turn that combines preference + requirement - must produce two facts, not one merged "restaurant preferences" fact. - - Drives the prompt's tightened consolidation rule. We're mocking the LLM - response here so this is really a regression guard on the pipeline plumbing - (the prompt change is what makes a real LLM produce this shape). - """ - pipeline, upserted = _make_pipeline( - { - "facts": [ - { - "text": "The user does not eat meat.", - "category": "preference", - "subject": "user", - "predicate": "dietary_restriction", - "object": "no meat", - "confidence": 1.0, - "salience": 0.9, - "tags": ["topic:diet"], - "action": "ADD", - "supersedes_id": None, - }, - { - "text": "The user requires wheelchair-accessible restaurants.", - "category": "requirement", - "subject": "user", - "predicate": "accessibility_requirement", - "object": "wheelchair-accessible restaurants", - "confidence": 1.0, - "salience": 0.95, - "tags": ["topic:accessibility"], - "action": "ADD", - "supersedes_id": None, - }, - ] - } - ) - pipeline.extract_memories("u1", "t1") - - facts = [d for d in upserted if d["type"] == "fact"] - assert len(facts) == 2 - by_category = {f["metadata"]["category"]: f for f in facts} - assert set(by_category) == {"preference", "requirement"} - assert by_category["preference"]["content"] == "The user does not eat meat." - assert by_category["requirement"]["content"] == "The user requires wheelchair-accessible restaurants." + assert ep["content"] == "Planning Paris trip → Said luxury → Pending" def test_extract_drops_episodic_missing_scope_type(caplog): @@ -509,13 +416,10 @@ def test_extract_drops_episodic_missing_scope_type(caplog): ) with caplog.at_level("WARNING", logger="azure.cosmos.agent_memory.pipeline"): - result = pipeline.extract_memories("u1", "t1") + pipeline.extract_memories("u1", "t1") assert not any(d["type"] == "episodic" for d in upserted) assert any("dropping malformed episodic" in rec.getMessage() for rec in caplog.records) - assert any("reason=malformed_scope" in rec.getMessage() for rec in caplog.records) - # Malformed-scope drops also count toward the dropped_episodic_count signal. - assert result["dropped_episodic_count"] == 1 def test_extract_drops_episodic_missing_scope_value(caplog): @@ -578,7 +482,6 @@ def test_extract_strips_whitespace_from_scope_fields(): { "scope_type": " trip ", "scope_value": " Paris ", - "text": "Planning a Paris trip.", "confidence": 0.9, } ] @@ -590,4 +493,43 @@ def test_extract_strips_whitespace_from_scope_fields(): [ep] = [d for d in upserted if d["type"] == "episodic"] assert ep["scope_type"] == "trip" assert ep["scope_value"] == "Paris" - assert ep["content"] == "Planning a Paris trip." + assert ep["content"] == "For the user's Paris trip, intent recorded." + + +def test_extract_compound_statement_yields_facts_across_categories(): + """A single user turn that combines preference + requirement must produce two + facts, not one merged fact. Regression guard on the pipeline plumbing.""" + pipeline, upserted = _make_pipeline( + { + "facts": [ + { + "text": "The user does not eat meat.", + "category": "preference", + "subject": "user", + "predicate": "dietary_restriction", + "object": "no meat", + "confidence": 1.0, + "salience": 0.9, + "tags": ["topic:diet"], + }, + { + "text": "The user requires wheelchair-accessible restaurants.", + "category": "requirement", + "subject": "user", + "predicate": "accessibility_requirement", + "object": "wheelchair-accessible restaurants", + "confidence": 1.0, + "salience": 0.95, + "tags": ["topic:accessibility"], + }, + ] + } + ) + pipeline.extract_memories("u1", "t1") + + facts = [d for d in upserted if d["type"] == "fact"] + assert len(facts) == 2 + by_category = {f["metadata"]["category"]: f for f in facts} + assert set(by_category) == {"preference", "requirement"} + assert by_category["preference"]["content"] == "The user does not eat meat." + assert by_category["requirement"]["content"] == "The user requires wheelchair-accessible restaurants." diff --git a/tests/unit/test_procedural_synthesis.py b/tests/unit/test_procedural_synthesis.py index 889e3ca..5930104 100644 --- a/tests/unit/test_procedural_synthesis.py +++ b/tests/unit/test_procedural_synthesis.py @@ -15,6 +15,18 @@ from azure.cosmos.agent_memory.store import MemoryStore +@pytest.fixture(autouse=True) +def _pin_legacy_extract_dedup(monkeypatch): + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_vector_enabled", + lambda: False, + ) + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_context_vector_enabled", + lambda: False, + ) + + def _assert_iso8601(text: str) -> None: assert text datetime.fromisoformat(text) diff --git a/tests/unit/test_process_now.py b/tests/unit/test_process_now.py index d81a0f7..3d71c56 100644 --- a/tests/unit/test_process_now.py +++ b/tests/unit/test_process_now.py @@ -2,7 +2,7 @@ from __future__ import annotations -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import pytest @@ -50,7 +50,13 @@ def test_process_now_with_inprocess_invokes_full_pipeline(): assert isinstance(client._processor, InProcessProcessor) pipeline.generate_thread_summary.assert_called_once_with("u1", "t1") pipeline.extract_memories.assert_called_once_with("u1", "t1") - pipeline.reconcile_memories.assert_called_once_with("u1", 50) + assert pipeline.reconcile_memories.call_count == 2 + pipeline.reconcile_memories.assert_has_calls( + [ + call("u1", n=50, memory_type="fact", full_rebuild=False), + call("u1", n=50, memory_type="episodic", full_rebuild=False), + ] + ) pipeline.synthesize_procedural.assert_called_once_with(user_id="u1", force=False) pipeline.generate_user_summary.assert_called_once_with("u1", None) assert result.procedural == {"id": "proc1", "type": "procedural"} diff --git a/tests/unit/test_reconcile.py b/tests/unit/test_reconcile.py index be3f63b..f0a1c32 100644 --- a/tests/unit/test_reconcile.py +++ b/tests/unit/test_reconcile.py @@ -29,6 +29,23 @@ from azure.cosmos.agent_memory.services.pipeline import PipelineService +@pytest.fixture(autouse=True) +def _pin_legacy_dedup_paths(monkeypatch): + """These tests cover the pre-candidate reconcile/extract code paths.""" + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_reconcile_mode", + lambda: "full_pool", + ) + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_vector_enabled", + lambda: False, + ) + monkeypatch.setattr( + "azure.cosmos.agent_memory.thresholds.get_dedup_context_vector_enabled", + lambda: False, + ) + + def _make_pipeline() -> PipelineService: p = PipelineService.__new__(PipelineService) p._embeddings = MagicMock() @@ -649,269 +666,6 @@ def test_fact_not_dropped_when_only_procedural_has_same_hash(self): assert fact_docs[0]["content"] == text -class TestEpisodicReconciliation: - """Episodic memories use scope as identity: the deterministic ID is - seeded only on ``(user_id, scope_type, scope_value)``. Any re-emission - for the same scope (paraphrased intent, added detail, reversed intent) - collides on upsert and replaces the prior record. The LLM does NOT - make ADD/UPDATE/CONTRADICT decisions for episodics — the scope IS the - identity. Distinct events under the same umbrella belong under - distinct ``scope_value`` strings (e.g. "Tokyo trip" vs - "Tokyo lost-wallet incident"). - """ - - def _build(self) -> PipelineService: - p = PipelineService.__new__(PipelineService) - p._embeddings = MagicMock() - p._embeddings.generate.return_value = [0.1] * 8 - p._embeddings.generate_batch.return_value = [[0.1] * 8] - p._container = MagicMock() - p._memories_container = p._container - p._turns_container = p._container - p._summaries_container = p._container - p._chat = MagicMock() - p._upsert_memory = MagicMock(side_effect=lambda doc: doc) - p._create_memory = MagicMock(side_effect=lambda doc: doc) - p._mark_superseded = MagicMock(return_value=True) - return p - - def _turns(self) -> list[dict]: - return [ - { - "id": "turn-1", - "role": "user", - "content": "x", - "type": "turn", - "created_at": "2024-01-01T00:00:00+00:00", - } - ] - - def _episodic_payload(self, **overrides) -> dict: - payload = { - "scope_type": "trip", - "scope_value": "Tokyo", - "text": "Planning a Tokyo trip with a luxury hotel preference.", - "situation": None, - "action_taken": None, - "outcome": None, - "outcome_valence": None, - "reasoning": None, - "lesson": None, - "domain": "travel", - "confidence": 0.95, - "salience": 0.8, - "tags": ["topic:travel"], - } - payload.update(overrides) - return payload - - def test_existing_episodics_are_rendered_into_prompt_inputs(self): - """The extractor must pass ``existing_episodics`` to the LLM, grouped - by ``(scope_type, scope_value)``. Without this, the model has no - context for refining or reversing the existing intent for that - scope when it emits the merged text.""" - p = self._build() - existing_text = "Planning a Tokyo trip with a luxury hotel preference." - existing_ep = { - "id": "ep_existing", - "type": "episodic", - "content": existing_text, - "content_hash": compute_content_hash(existing_text), - "thread_id": "__episodic__", - "salience": 0.8, - "metadata": {"scope_type": "trip", "scope_value": "Tokyo"}, - } - p._container.query_items.return_value = iter(self._turns()) - # Two queries are issued (one for facts, one for episodics) so each - # type gets its own 100-row budget — return [] for facts, the - # existing episodic for the episodic call. - p._load_existing_memories = MagicMock( - side_effect=lambda user_id, memory_types, **kw: [existing_ep] if memory_types == ["episodic"] else [] - ) - p._run_prompty = MagicMock(return_value=json.dumps({"facts": [], "episodic": [], "unclassified": []})) - - p.extract_memories("u1", "t1") - - # Two separate calls — one per type, each with its own budget. - load_calls = [c.args for c in p._load_existing_memories.call_args_list] - assert ("u1", ["fact"]) in load_calls - assert ("u1", ["episodic"]) in load_calls - call_kwargs = p._run_prompty.call_args.kwargs - inputs = call_kwargs["inputs"] - assert "existing_episodics" in inputs - rendered = inputs["existing_episodics"] - assert "trip = Tokyo" in rendered - assert "ep_existing" in rendered - assert existing_text in rendered - - def test_same_scope_episodics_collide_on_deterministic_id(self): - """Two episodics with the same (scope_type, scope_value) but - different ``text`` MUST produce the same deterministic ID so that - the second write overwrites the first via upsert. This is the - core mechanism that prevents near-duplicate episodic storage when - a recent-turn re-extraction window paraphrases the same intent. - """ - p = self._build() - p._container.query_items.return_value = iter(self._turns()) - p._load_existing_memories = MagicMock(return_value=[]) - # LLM emits two episodics under the SAME scope but with paraphrased - # text — this is the exact failure mode the user reported. - p._run_prompty = MagicMock( - return_value=json.dumps( - { - "facts": [], - "episodic": [ - self._episodic_payload(text="Planning a Tokyo trip with a luxury hotel preference."), - self._episodic_payload(text="Planning a Tokyo trip with a preference for luxury hotels."), - ], - "unclassified": [], - } - ) - ) - - p.extract_memories("u1", "t1") - - upsert_calls = [c.args[0] for c in p._upsert_memory.call_args_list if c.args[0].get("type") == "episodic"] - # Both episodics flow through upsert (the persist path branches on - # type=episodic) and they MUST share the same det_id — that's what - # makes the second a Cosmos upsert that replaces the first. - assert len(upsert_calls) == 2 - assert upsert_calls[0]["id"] == upsert_calls[1]["id"] - # And neither went through create_item (which would 409 on the - # second write and silently lose the new richer text). - episodic_creates = [c for c in p._create_memory.call_args_list if c.args[0].get("type") == "episodic"] - assert episodic_creates == [] - - def test_different_scope_values_produce_different_ids(self): - """Two episodics with the same scope_type but different - scope_value (e.g. distinct trips, or distinct incidents within a - trip) MUST produce different deterministic IDs so they coexist. - """ - p = self._build() - p._container.query_items.return_value = iter(self._turns()) - p._load_existing_memories = MagicMock(return_value=[]) - p._run_prompty = MagicMock( - return_value=json.dumps( - { - "facts": [], - "episodic": [ - self._episodic_payload(scope_value="Tokyo", text="Trip A."), - self._episodic_payload(scope_value="Paris", text="Trip B."), - ], - "unclassified": [], - } - ) - ) - - p.extract_memories("u1", "t1") - - upsert_calls = [c.args[0] for c in p._upsert_memory.call_args_list if c.args[0].get("type") == "episodic"] - assert len(upsert_calls) == 2 - assert upsert_calls[0]["id"] != upsert_calls[1]["id"] - - def test_episodic_and_fact_with_same_content_do_not_collide(self): - """An episodic's deterministic ID is seeded on scope; a fact's is - seeded on content_hash. Even if their content text matches - verbatim, the IDs live in disjoint namespaces (ep_ vs fact_ prefix - plus different seeds) so both records persist.""" - p = self._build() - text = "Planning a Tokyo trip with a luxury hotel preference." - existing = [ - { - "id": "fact_existing", - "type": "fact", - "content": text, - "content_hash": compute_content_hash(text), - "thread_id": "t1", - "tags": ["sys:fact"], - } - ] - p._container.query_items.return_value = iter(self._turns()) - p._load_existing_memories = MagicMock(return_value=existing) - p._run_prompty = MagicMock( - return_value=json.dumps({"facts": [], "episodic": [self._episodic_payload(text=text)], "unclassified": []}) - ) - - out = p.extract_memories("u1", "t1") - - assert out["episodic_count"] == 1 - upsert_calls = [c.args[0] for c in p._upsert_memory.call_args_list if c.args[0].get("type") == "episodic"] - assert len(upsert_calls) == 1 - - def test_episodic_uses_sentinel_thread_id_for_partition_routing(self): - """Auto-extracted episodics MUST be persisted under the sentinel - ``thread_id="__episodic__"`` regardless of which thread emitted them. - - The memories container is partitioned hierarchically on - ``(user_id, thread_id)`` and Cosmos ``id`` uniqueness is per-partition - — so a deterministic ID seeded only on scope is only meaningful if - every episodic for that scope lands in the SAME partition. Writing - the live thread_id splits identical-scope episodics across two - partitions and breaks upsert dedup across threads. The originating - thread is preserved on ``metadata.originating_thread_id`` for audit. - """ - p = self._build() - p._container.query_items.return_value = iter(self._turns()) - p._load_existing_memories = MagicMock(return_value=[]) - p._run_prompty = MagicMock( - return_value=json.dumps( - { - "facts": [], - "episodic": [self._episodic_payload(text="Trip planning.")], - "unclassified": [], - } - ) - ) - - p.extract_memories("u1", "thread-alpha") - - upsert_calls = [c.args[0] for c in p._upsert_memory.call_args_list if c.args[0].get("type") == "episodic"] - assert len(upsert_calls) == 1 - doc = upsert_calls[0] - assert doc["thread_id"] == "__episodic__" - assert doc["metadata"]["originating_thread_id"] == "thread-alpha" - - def test_same_scope_episodics_collide_across_different_threads(self): - """The cross-thread regression test for the per-partition id bug. - - Same user, same scope ``(trip, Tokyo)``, but emitted from two - different threads (``thread-alpha`` and ``thread-beta``). Both - writes MUST produce the same det_id AND the same persisted - thread_id (the sentinel) so they land in one partition and the - second upsert replaces the first. Without the sentinel, the docs - live in two different partitions and you'd see duplicate - episodics for the same intent — exactly the bug the deterministic - ID was meant to prevent. - """ - p = self._build() - # Fresh iterator per call — both extract_memories calls need turns. - p._container.query_items.side_effect = lambda *a, **kw: iter(self._turns()) - p._load_existing_memories = MagicMock(return_value=[]) - p._run_prompty = MagicMock( - return_value=json.dumps( - { - "facts": [], - "episodic": [self._episodic_payload(text="Tokyo luxury hotel intent.")], - "unclassified": [], - } - ) - ) - - p.extract_memories("u1", "thread-alpha") - p.extract_memories("u1", "thread-beta") - - upsert_calls = [c.args[0] for c in p._upsert_memory.call_args_list if c.args[0].get("type") == "episodic"] - assert len(upsert_calls) == 2 - # Same det_id — scope is identity. - assert upsert_calls[0]["id"] == upsert_calls[1]["id"] - # Same partition (sentinel thread_id) — so the second upsert replaces - # the first instead of creating a duplicate in a sibling partition. - assert upsert_calls[0]["thread_id"] == upsert_calls[1]["thread_id"] == "__episodic__" - # Originating thread preserved on each metadata for audit. - assert upsert_calls[0]["metadata"]["originating_thread_id"] == "thread-alpha" - assert upsert_calls[1]["metadata"]["originating_thread_id"] == "thread-beta" - - class TestExtractEarlyReturnShape: """The no-memories early-return must include every key the success path returns; otherwise callers using ``result["exact_dedup_skipped"]`` @@ -931,7 +685,6 @@ def test_empty_thread_returns_full_dict_shape(self): "unclassified_count", "updated_count", "exact_dedup_skipped", - "dropped_episodic_count", ): assert key in out, f"missing key: {key}" assert out[key] == 0 @@ -1586,81 +1339,10 @@ def capture_prompty(name, inputs): assert "None" not in text -class TestExtractUpdateSupersedeReason: - """Extract-time UPDATE actions stamp ``supersede_reason="update"``, - distinct from reconcile-time ``"duplicate"`` (paraphrase merge) and - ``"contradict"`` (semantic conflict). The extract prompt defines - UPDATE as "contradicts or refines an existing memory" — labelling - these as ``"duplicate"`` makes audit trails ambiguous.""" - - def _build(self) -> PipelineService: - p = PipelineService.__new__(PipelineService) - p._embeddings = MagicMock() - p._embeddings.generate.return_value = [[0.1] * 8] - p._upsert_memory = MagicMock(side_effect=lambda doc: doc) - p._mark_superseded = MagicMock(return_value=True) - p._container = MagicMock() - p._memories_container = p._container - p._turns_container = p._container - p._summaries_container = p._container - p._chat = MagicMock() - p._load_existing_memories = MagicMock( - return_value=[ - { - "id": "fact_old", - "type": "fact", - "content": "User likes coffee", - "content_hash": "h_old", - } - ] - ) - p._container.read_item = MagicMock( - return_value={"id": "fact_old", "type": "fact", "content": "User likes coffee"} - ) - p._container.query_items.return_value = iter( - [ - { - "id": "turn-1", - "role": "user", - "content": "I love tea now", - "type": "turn", - "created_at": "2024-01-01T00:00:00+00:00", - } - ] - ) - return p - - def test_fact_update_uses_reason_update_not_duplicate(self): - p = self._build() - p._run_prompty = MagicMock( - return_value=json.dumps( - { - "facts": [ - { - "text": "User now prefers tea over coffee", - "confidence": 0.9, - "salience": 0.7, - "action": "UPDATE", - "supersedes_id": "fact_old", - "tags": ["sys:fact"], - } - ], - "procedural": [], - "episodic": [], - } - ) - ) - p.extract_memories("u1", "t1") - assert p._mark_superseded.called - call_kwargs = p._mark_superseded.call_args.kwargs - assert call_kwargs.get("reason") == "update" - - class TestExtractUpdateSelfCollapseGuard: - """When an LLM emits ``UPDATE`` whose new content hashes to the same - deterministic id as the target (paraphrase-equivalent text), the - upsert would overwrite the audit metadata that ``_mark_superseded`` - just stamped on the target. Treat as a no-op.""" + """Procedural synthesis self-collapse guard: when synthesis would emit a + proc id identical to the existing one, treat as a no-op. (Fact extract-time + UPDATE was removed — facts/contradictions are reconciled, not extract-tagged.)""" def _build(self) -> PipelineService: p = PipelineService.__new__(PipelineService) @@ -1676,52 +1358,6 @@ def _build(self) -> PipelineService: p._load_existing_memories = MagicMock(return_value=[]) return p - def test_fact_update_with_self_referential_id_is_skipped(self): - from azure.cosmos.agent_memory._utils import compute_content_hash - from azure.cosmos.agent_memory.services.pipeline import _ID_SEED_SEP - - p = self._build() - text = "User likes tea" - seed = _ID_SEED_SEP.join(("u1", "t1", compute_content_hash(text))) - det_id = f"fact_{hashlib.sha256(seed.encode()).hexdigest()[:32]}" - - p._container.query_items.return_value = iter( - [ - { - "id": "turn-1", - "role": "user", - "content": "tea", - "type": "turn", - "created_at": "2024-01-01T00:00:00+00:00", - } - ] - ) - p._run_prompty = MagicMock( - return_value=json.dumps( - { - "facts": [ - { - "text": text, - "confidence": 0.9, - "salience": 0.6, - "action": "UPDATE", - "supersedes_id": det_id, - "tags": ["sys:fact"], - } - ], - "procedural": [], - "episodic": [], - } - ) - ) - - out = p.extract_memories("u1", "t1") - - assert p._mark_superseded.call_count == 0 - fact_upserts = [c for c in p._upsert_memory.call_args_list if c.args[0].get("type") == "fact"] - assert fact_upserts == [] - assert out["fact_count"] == 0 - def test_procedural_update_with_self_referential_id_is_skipped(self): from azure.cosmos.agent_memory._utils import compute_content_hash from azure.cosmos.agent_memory.services.pipeline import _ID_SEED_SEP diff --git a/tests/unit/test_thresholds.py b/tests/unit/test_thresholds.py index 1add114..d7f9d3d 100644 --- a/tests/unit/test_thresholds.py +++ b/tests/unit/test_thresholds.py @@ -2,6 +2,7 @@ import pytest +from azure.cosmos.agent_memory import thresholds from azure.cosmos.agent_memory.thresholds import ( DEFAULT_ENABLE_TURN_EMBEDDINGS, DEFAULT_TTL_BY_TYPE, @@ -50,3 +51,149 @@ def test_enable_turn_embeddings_truthy_values(monkeypatch, raw) -> None: def test_enable_turn_embeddings_falsy_values(monkeypatch, raw) -> None: monkeypatch.setenv("ENABLE_TURN_EMBEDDINGS", raw) assert get_enable_turn_embeddings() is False + + +@pytest.mark.parametrize( + ("env_name", "getter_name", "expected"), + [ + ("FACT_EXTRACTION_EVERY_N", "get_fact_extraction_every_n", 1), + ("THREAD_SUMMARY_EVERY_N", "get_thread_summary_every_n", 10), + ("USER_SUMMARY_EVERY_N", "get_user_summary_every_n", 20), + ("DEDUP_EVERY_N", "get_dedup_every_n", 5), + ("DEDUP_POOL_SIZE", "get_dedup_pool_size", 50), + ("PROCEDURAL_SYNTHESIS_AUTO", "get_procedural_synthesis_auto", True), + ], +) +def test_env_config_getters_defaults( + monkeypatch: pytest.MonkeyPatch, + env_name: str, + getter_name: str, + expected: object, +) -> None: + monkeypatch.delenv(env_name, raising=False) + + assert getattr(thresholds, getter_name)() == expected + + +@pytest.mark.parametrize( + ("env_name", "getter_name", "raw", "expected"), + [ + ("FACT_EXTRACTION_EVERY_N", "get_fact_extraction_every_n", "2", 2), + ("THREAD_SUMMARY_EVERY_N", "get_thread_summary_every_n", "11", 11), + ("USER_SUMMARY_EVERY_N", "get_user_summary_every_n", "21", 21), + ("DEDUP_EVERY_N", "get_dedup_every_n", "3", 3), + ("DEDUP_POOL_SIZE", "get_dedup_pool_size", "75", 75), + ("PROCEDURAL_SYNTHESIS_AUTO", "get_procedural_synthesis_auto", "false", False), + ], +) +def test_env_config_getters_parse_env( + monkeypatch: pytest.MonkeyPatch, + env_name: str, + getter_name: str, + raw: str, + expected: object, +) -> None: + monkeypatch.setenv(env_name, raw) + + assert getattr(thresholds, getter_name)() == expected + + +@pytest.mark.parametrize( + ("env_name", "getter_name", "expected"), + [ + ("FACT_EXTRACTION_EVERY_N", "get_fact_extraction_every_n", 1), + ("THREAD_SUMMARY_EVERY_N", "get_thread_summary_every_n", 10), + ("USER_SUMMARY_EVERY_N", "get_user_summary_every_n", 20), + ("DEDUP_EVERY_N", "get_dedup_every_n", 5), + ("DEDUP_POOL_SIZE", "get_dedup_pool_size", 50), + ], +) +def test_int_getters_reject_negative( + monkeypatch: pytest.MonkeyPatch, + env_name: str, + getter_name: str, + expected: int, +) -> None: + monkeypatch.setenv(env_name, "-1") + + assert getattr(thresholds, getter_name)() == expected + + +@pytest.mark.parametrize( + ("env_name", "getter_name", "expected"), + [ + ("FACT_EXTRACTION_EVERY_N", "get_fact_extraction_every_n", 1), + ("THREAD_SUMMARY_EVERY_N", "get_thread_summary_every_n", 10), + ("USER_SUMMARY_EVERY_N", "get_user_summary_every_n", 20), + ("DEDUP_EVERY_N", "get_dedup_every_n", 5), + ("DEDUP_POOL_SIZE", "get_dedup_pool_size", 50), + ], +) +def test_int_getters_invalid_use_default( + monkeypatch: pytest.MonkeyPatch, + env_name: str, + getter_name: str, + expected: int, +) -> None: + monkeypatch.setenv(env_name, "bogus") + + assert getattr(thresholds, getter_name)() == expected + + +def test_dedup_pool_size_clamps_high(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("DEDUP_POOL_SIZE", "501") + + assert thresholds.get_dedup_pool_size() == 500 + + +def test_dedup_pool_size_rejects_zero(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("DEDUP_POOL_SIZE", "0") + + assert thresholds.get_dedup_pool_size() == 50 + + +def test_procedural_synthesis_auto_invalid_uses_default(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("PROCEDURAL_SYNTHESIS_AUTO", "bogus") + + assert thresholds.get_procedural_synthesis_auto() is True + + +def test_processor_owner_defaults_to_none(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("MEMORY_PROCESSOR_OWNER", raising=False) + + assert thresholds.get_processor_owner() is None + + +@pytest.mark.parametrize("raw", ["inprocess", "durable", "INPROCESS", "DURABLE"]) +def test_processor_owner_parse_env(monkeypatch: pytest.MonkeyPatch, raw: str) -> None: + monkeypatch.setenv("MEMORY_PROCESSOR_OWNER", raw) + + assert thresholds.get_processor_owner() == raw.lower() + + +def test_processor_owner_invalid_uses_default(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MEMORY_PROCESSOR_OWNER", "bogus") + + assert thresholds.get_processor_owner() is None + + +def test_internalized_getters_return_fixed_constants_and_ignore_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("DEDUP_CONTEXT_VECTOR_ENABLED", "false") + monkeypatch.setenv("DEDUP_CONTEXT_TOPK", "7") + monkeypatch.setenv("DEDUP_VECTOR_ENABLED", "false") + monkeypatch.setenv("DEDUP_SIM_HIGH", "0.50") + monkeypatch.setenv("DEDUP_SIM_LOW", "0.40") + monkeypatch.setenv("DEDUP_CANDIDATE_TOPK", "8") + monkeypatch.setenv("DEDUP_RECONCILE_MODE", "full_pool") + monkeypatch.setenv("DEDUP_CLUSTER_SIM", "0.10") + monkeypatch.setenv("DEDUP_FULL_RECLUSTER_EVERY_N", "4") + + assert thresholds.get_dedup_context_vector_enabled() is True + assert thresholds.get_dedup_context_topk() == 10 + assert thresholds.get_dedup_vector_enabled() is True + assert thresholds.get_dedup_sim_high() == 0.97 + assert thresholds.get_dedup_sim_low() == 0.80 + assert thresholds.get_dedup_candidate_topk() == 10 + assert thresholds.get_dedup_reconcile_mode() == "candidate" + assert thresholds.get_dedup_cluster_sim() == 0.60 + assert thresholds.get_dedup_full_recluster_every_n() == 12 diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 0749a0a..6226937 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -4,6 +4,7 @@ from azure.cosmos.agent_memory._utils import ( DEFAULT_TTL_BY_TYPE, + MAX_FULLTEXT_TERMS, _build_container_kwargs, _container_policies, _make_memory, @@ -12,7 +13,11 @@ _resolve_full_text_language, _resolve_vector_index_type, compute_content_hash, + distance_function_from_container_properties, + extract_keywords, normalize_ai_foundry_endpoint, + vector_order_direction, + vector_similarity_at_least, ) from azure.cosmos.agent_memory.exceptions import ConfigurationError, ValidationError @@ -192,65 +197,132 @@ def test_make_memory_invalid_type(): _make_memory(user_id="u1", role="user", content="test", memory_type="invalid") -def test_resolve_embedding_data_type_defaults(monkeypatch): - monkeypatch.delenv("AI_FOUNDRY_EMBEDDING_DATA_TYPE", raising=False) +def test_resolve_embedding_data_type_defaults(): assert _resolve_embedding_data_type(None) == "float32" -def test_resolve_embedding_data_type_from_env(monkeypatch): - monkeypatch.setenv("AI_FOUNDRY_EMBEDDING_DATA_TYPE", "int8") - assert _resolve_embedding_data_type(None) == "int8" - - -def test_resolve_embedding_data_type_explicit_overrides_env(monkeypatch): - monkeypatch.setenv("AI_FOUNDRY_EMBEDDING_DATA_TYPE", "int8") +def test_resolve_embedding_data_type_explicit(): + assert _resolve_embedding_data_type("int8") == "int8" assert _resolve_embedding_data_type("uint8") == "uint8" -def test_resolve_embedding_data_type_invalid_raises(monkeypatch): - monkeypatch.setenv("AI_FOUNDRY_EMBEDDING_DATA_TYPE", "bogus") +def test_resolve_embedding_data_type_invalid_raises(): with pytest.raises(ConfigurationError): - _resolve_embedding_data_type(None) + _resolve_embedding_data_type("bogus") -def test_resolve_distance_function_defaults(monkeypatch): - monkeypatch.delenv("AI_FOUNDRY_EMBEDDING_DISTANCE_FUNCTION", raising=False) +def test_resolve_distance_function_defaults(): assert _resolve_distance_function(None) == "cosine" -def test_resolve_distance_function_from_env(monkeypatch): - monkeypatch.setenv("AI_FOUNDRY_EMBEDDING_DISTANCE_FUNCTION", "dotproduct") - assert _resolve_distance_function(None) == "dotproduct" +def test_resolve_distance_function_explicit(): + assert _resolve_distance_function("dotproduct") == "dotproduct" + assert _resolve_distance_function("euclidean") == "euclidean" -def test_resolve_distance_function_invalid_raises(monkeypatch): - monkeypatch.setenv("AI_FOUNDRY_EMBEDDING_DISTANCE_FUNCTION", "manhattan") +def test_resolve_distance_function_invalid_raises(): with pytest.raises(ConfigurationError): - _resolve_distance_function(None) + _resolve_distance_function("manhattan") -def test_resolve_vector_index_type_defaults(monkeypatch): - monkeypatch.delenv("AI_FOUNDRY_EMBEDDING_VECTOR_INDEX_TYPE", raising=False) - assert _resolve_vector_index_type(None) == "quantizedFlat" +def test_vector_order_direction_per_function(): + # cosine/dotproduct: higher VectorDistance == more similar -> DESC for nearest-first. + assert vector_order_direction("cosine") == "DESC" + assert vector_order_direction("dotproduct") == "DESC" + # euclidean: lower distance == more similar -> ASC for nearest-first. + assert vector_order_direction("euclidean") == "ASC" + + +def test_vector_similarity_at_least_cosine_and_dotproduct(): + # Higher score is more similar; threshold is a floor. + for fn in ("cosine", "dotproduct"): + assert vector_similarity_at_least(0.97, 0.97, fn) is True + assert vector_similarity_at_least(0.99, 0.97, fn) is True + assert vector_similarity_at_least(0.80, 0.97, fn) is False -def test_resolve_vector_index_type_from_env(monkeypatch): - monkeypatch.setenv("AI_FOUNDRY_EMBEDDING_VECTOR_INDEX_TYPE", "quantizedFlat") +def test_vector_similarity_at_least_euclidean_inverts(): + # Lower distance is more similar; threshold is a ceiling. + assert vector_similarity_at_least(0.10, 0.20, "euclidean") is True + assert vector_similarity_at_least(0.20, 0.20, "euclidean") is True + assert vector_similarity_at_least(0.50, 0.20, "euclidean") is False + + +def test_distance_function_from_container_properties_reads_policy(): + props = { + "id": "memories", + "vectorEmbeddingPolicy": { + "vectorEmbeddings": [ + {"path": "/embedding", "dataType": "float32", "distanceFunction": "euclidean", "dimensions": 1536} + ] + }, + } + assert distance_function_from_container_properties(props) == "euclidean" + + +def test_distance_function_from_container_properties_reads_single_embedding(): + # This SDK provisions a single vector embedding; the resolver reads its + # distanceFunction directly (the path value is irrelevant here). + props = { + "vectorEmbeddingPolicy": { + "vectorEmbeddings": [ + {"path": "/embedding", "distanceFunction": "dotproduct"}, + ] + } + } + assert distance_function_from_container_properties(props) == "dotproduct" + + +@pytest.mark.parametrize( + "props", + [ + None, + {}, + {"vectorEmbeddingPolicy": {}}, + {"vectorEmbeddingPolicy": {"vectorEmbeddings": []}}, + {"vectorEmbeddingPolicy": {"vectorEmbeddings": [{"path": "/embedding", "distanceFunction": "manhattan"}]}}, + "not-a-dict", + ], +) +def test_distance_function_from_container_properties_falls_back_to_cosine(props): + assert distance_function_from_container_properties(props) == "cosine" + + +def test_extract_keywords_basic_and_stopwords(): + # Stopwords removed, lowercased, de-duplicated, first-seen order preserved. + assert extract_keywords("The user LOVES hiking and hiking trails") == [ + "user", + "loves", + "hiking", + "trails", + ] + assert extract_keywords("") == [] + assert extract_keywords("the and of a an") == [] # all stopwords + + +def test_extract_keywords_capped_at_cosmos_fulltext_limit(): + # Cosmos FullTextScore rejects >30 terms; extraction must cap at exactly 30 so + # the hybrid query is always valid even for long multi-turn context strings. + text = " ".join(f"term{i}" for i in range(100)) + kws = extract_keywords(text) + assert len(kws) == MAX_FULLTEXT_TERMS == 30 + assert kws == [f"term{i}" for i in range(30)] + + +def test_resolve_vector_index_type_defaults(): assert _resolve_vector_index_type(None) == "quantizedFlat" -def test_resolve_vector_index_type_explicit_overrides_env(monkeypatch): - monkeypatch.setenv("AI_FOUNDRY_EMBEDDING_VECTOR_INDEX_TYPE", "quantizedFlat") +def test_resolve_vector_index_type_explicit(): assert _resolve_vector_index_type("flat") == "flat" -def test_resolve_vector_index_type_invalid_raises(monkeypatch): - monkeypatch.setenv("AI_FOUNDRY_EMBEDDING_VECTOR_INDEX_TYPE", "hnsw") +def test_resolve_vector_index_type_invalid_raises(): with pytest.raises(ConfigurationError): - _resolve_vector_index_type(None) + _resolve_vector_index_type("hnsw") -def test_container_policies_defaults_to_diskann(): +def test_container_policies_defaults_vector_index_type(): _, indexing_policy, _ = _container_policies( embedding_dimensions=1536, embedding_data_type="float32", @@ -266,19 +338,17 @@ def test_container_policies_uses_supplied_vector_index_type(): embedding_data_type="float32", distance_function="cosine", full_text_language="en-US", - vector_index_type="quantizedFlat", + vector_index_type="diskANN", ) - assert indexing_policy["vectorIndexes"] == [{"path": "/embedding", "type": "quantizedFlat"}] + assert indexing_policy["vectorIndexes"] == [{"path": "/embedding", "type": "diskANN"}] -def test_resolve_full_text_language_defaults(monkeypatch): - monkeypatch.delenv("COSMOS_DB_FULL_TEXT_LANGUAGE", raising=False) +def test_resolve_full_text_language_defaults(): assert _resolve_full_text_language(None) == "en-US" -def test_resolve_full_text_language_from_env(monkeypatch): - monkeypatch.setenv("COSMOS_DB_FULL_TEXT_LANGUAGE", "fr-FR") - assert _resolve_full_text_language(None) == "fr-FR" +def test_resolve_full_text_language_explicit(): + assert _resolve_full_text_language("fr-FR") == "fr-FR" # --------------------------------------------------------------------------- From 2cf898833bec58e2302a6d9cbfe8d4da9788e9ab Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Wed, 1 Jul 2026 09:40:21 -0700 Subject: [PATCH 2/2] fixing builds --- .../agent_memory/aio/services/pipeline.py | 8 +- .../cosmos/agent_memory/services/pipeline.py | 53 +++++------ azure/cosmos/agent_memory/thresholds.py | 14 +-- .../orchestrators/extract_memories.py | 11 ++- function_app/triggers/change_feed.py | 4 +- tests/integration/test_async_full_pipeline.py | 9 +- tests/integration/test_full_pipeline.py | 6 +- .../aio/services/test_dedup_vector_async.py | 16 +--- tests/unit/aio/test_auto_trigger.py | 92 ++++++++++--------- tests/unit/function_app/test_orchestrators.py | 4 +- tests/unit/services/test_dedup_vector.py | 25 ++--- tests/unit/test_auto_trigger.py | 73 ++++++++------- 12 files changed, 151 insertions(+), 164 deletions(-) diff --git a/azure/cosmos/agent_memory/aio/services/pipeline.py b/azure/cosmos/agent_memory/aio/services/pipeline.py index 4e5de1e..ba8798f 100644 --- a/azure/cosmos/agent_memory/aio/services/pipeline.py +++ b/azure/cosmos/agent_memory/aio/services/pipeline.py @@ -530,9 +530,7 @@ async def extract_memories_dry( m["content_hash"] for m in existing_for_hash if m.get("type") == "fact" and m.get("content_hash") } if get_dedup_context_vector_enabled(): - user_turns_text = "\n".join( - str(it.get("content", "")) for it in items if it.get("role") == "user" - ).strip() + user_turns_text = "\n".join(str(it.get("content", "")) for it in items if it.get("role") == "user").strip() context_query = user_turns_text or transcript existing = await self._store.search( search_terms=context_query, @@ -1778,9 +1776,7 @@ async def reconcile_memories( # "loves steak") — which candidate clustering, keyed on near-duplicate # similarity, would never group. Automatic sweeps use cheap candidate mode. if get_dedup_reconcile_mode() == "candidate" and not full_rebuild: - return await self._reconcile_candidate_mode( - user_id, n=n, memory_type=memory_type, started_at=started_at - ) + return await self._reconcile_candidate_mode(user_id, n=n, memory_type=memory_type, started_at=started_at) facts = await self._active_memories_for_reconcile(user_id, memory_type, n) result, consumed = await self._reconcile_pool(user_id, memory_type, facts) diff --git a/azure/cosmos/agent_memory/services/pipeline.py b/azure/cosmos/agent_memory/services/pipeline.py index b216816..a44f87f 100644 --- a/azure/cosmos/agent_memory/services/pipeline.py +++ b/azure/cosmos/agent_memory/services/pipeline.py @@ -414,9 +414,7 @@ def _query_active_memories( ) ) - def _load_memories_by_ids( - self, user_id: str, memory_type: str, ids: Iterable[str] - ) -> list[dict[str, Any]]: + def _load_memories_by_ids(self, user_id: str, memory_type: str, ids: Iterable[str]) -> list[dict[str, Any]]: ids = [mid for mid in dict.fromkeys(ids) if mid] if not ids: return [] @@ -600,9 +598,7 @@ def extract_memories_dry( transcript = self._build_transcript(items) existing = existing_for_hashes if threshold_config.get_dedup_context_vector_enabled(): - user_turns_text = "\n".join( - str(it.get("content", "")) for it in items if it.get("role") == "user" - ).strip() + user_turns_text = "\n".join(str(it.get("content", "")) for it in items if it.get("role") == "user").strip() context_query = user_turns_text or transcript existing = self._store.search( search_terms=context_query, @@ -1640,8 +1636,7 @@ def _build_candidate_clusters( candidate_ids = [ row["id"] for row in candidates - if row.get("id") - and vector_similarity_at_least(row.get("score", 0.0), cluster_sim, distance_function) + if row.get("id") and vector_similarity_at_least(row.get("score", 0.0), cluster_sim, distance_function) ] for cid in candidate_ids: edge_pairs.add(tuple(sorted((seed_id, cid)))) @@ -1698,9 +1693,7 @@ def _build_candidate_clusters( clusters.append([nodes_by_id[cid] for cid in component]) return clusters, len(nodes_by_id), seeds - def _reconcile_candidate_mode( - self, user_id: str, *, n: int, memory_type: str, started_at: float - ) -> dict[str, int]: + def _reconcile_candidate_mode(self, user_id: str, *, n: int, memory_type: str, started_at: float) -> dict[str, int]: # Candidate clustering only. The periodic full-pool backstop that catches # dissimilar-embedding contradictions ("vegetarian" vs "loves steak") is # driven by the caller via ``full_rebuild`` on a PERSISTED-counter cadence @@ -1810,9 +1803,7 @@ def reconcile_memories( # "loves steak") — which candidate clustering, keyed on near-duplicate # similarity, would never group. Automatic sweeps use cheap candidate mode. if threshold_config.get_dedup_reconcile_mode() == "candidate" and not full_rebuild: - return self._reconcile_candidate_mode( - user_id, n=n, memory_type=memory_type, started_at=started_at - ) + return self._reconcile_candidate_mode(user_id, n=n, memory_type=memory_type, started_at=started_at) facts = self._active_memories_for_reconcile(user_id, memory_type, n) result, consumed = self._reconcile_pool(user_id, memory_type, facts) @@ -1828,9 +1819,7 @@ def reconcile_memories( ) return result - def _active_memories_for_reconcile( - self, user_id: str, memory_type: str, n: int - ) -> list[dict[str, Any]]: + def _active_memories_for_reconcile(self, user_id: str, memory_type: str, n: int) -> list[dict[str, Any]]: # ---- Load up to N most recent active memories ---- # ORDER BY c.created_at DESC keeps the TOP cap deterministic across # physical partitions and matches the dedup prompt's tiebreaker @@ -2058,21 +2047,21 @@ def _reconcile_pool( try: merged_payload: dict[str, Any] = { - "id": merged_id, - "user_id": user_id, - "role": "system", - "type": memory_type, - "content": merged_content, - "thread_id": recent_thread_id or f"__reconciled__:{user_id}", - "confidence": confidence_val if confidence_val is not None else 0.5, - "salience": salience_val if salience_val is not None else 0.5, - "supersedes_ids": merged_supersedes, - "source_memory_ids": merged_source_memory_ids, - "tags": merged_tags, - "content_hash": merged_content_hash, - **self._prompt_lineage(prompt_filename), - "created_at": datetime.now(timezone.utc).isoformat(), - "updated_at": datetime.now(timezone.utc).isoformat(), + "id": merged_id, + "user_id": user_id, + "role": "system", + "type": memory_type, + "content": merged_content, + "thread_id": recent_thread_id or f"__reconciled__:{user_id}", + "confidence": confidence_val if confidence_val is not None else 0.5, + "salience": salience_val if salience_val is not None else 0.5, + "supersedes_ids": merged_supersedes, + "source_memory_ids": merged_source_memory_ids, + "tags": merged_tags, + "content_hash": merged_content_hash, + **self._prompt_lineage(prompt_filename), + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), } if memory_type == "fact": merged_payload["metadata"] = { diff --git a/azure/cosmos/agent_memory/thresholds.py b/azure/cosmos/agent_memory/thresholds.py index ca4c508..f25030c 100644 --- a/azure/cosmos/agent_memory/thresholds.py +++ b/azure/cosmos/agent_memory/thresholds.py @@ -35,15 +35,15 @@ # needs to become operator-facing we add the env plumbing back deliberately. # The dedup + hybrid-search features ship ON via these values. # --------------------------------------------------------------------------- -DEDUP_CONTEXT_VECTOR_ENABLED = True # Stage-1 relevance-ranked extraction context +DEDUP_CONTEXT_VECTOR_ENABLED = True # Stage-1 relevance-ranked extraction context DEDUP_CONTEXT_TOPK = 10 -DEDUP_VECTOR_ENABLED = True # Stage-3 vector near-dup ladder -DEDUP_SIM_HIGH = 0.97 # >= -> auto-skip near-exact -DEDUP_SIM_LOW = 0.80 # < -> novel; between -> tag candidate +DEDUP_VECTOR_ENABLED = True # Stage-3 vector near-dup ladder +DEDUP_SIM_HIGH = 0.97 # >= -> auto-skip near-exact +DEDUP_SIM_LOW = 0.80 # < -> novel; between -> tag candidate DEDUP_CANDIDATE_TOPK = 10 -DEDUP_RECONCILE_MODE = "candidate" # clustered candidate reconcile (vs legacy full_pool) -DEDUP_CLUSTER_SIM = 0.60 # Stage-5 clustering edge threshold -DEDUP_FULL_RECLUSTER_EVERY_N = 12 # full re-cluster safety net cadence +DEDUP_RECONCILE_MODE = "candidate" # clustered candidate reconcile (vs legacy full_pool) +DEDUP_CLUSTER_SIM = 0.60 # Stage-5 clustering edge threshold +DEDUP_FULL_RECLUSTER_EVERY_N = 12 # full re-cluster safety net cadence DEFAULT_TTL_BY_TYPE: dict[str, int] = { "turn": 2_592_000, diff --git a/function_app/orchestrators/extract_memories.py b/function_app/orchestrators/extract_memories.py index ec859bc..db56c24 100644 --- a/function_app/orchestrators/extract_memories.py +++ b/function_app/orchestrators/extract_memories.py @@ -127,10 +127,13 @@ def em_Extract(payload: dict) -> dict: @bp.activity_trigger(input_name="payload") def em_Dedup(payload: dict) -> dict: """vector-floor dedup ladder (gated; passthrough when disabled).""" - return get_pipeline().dedup_extracted_memories( - user_id=payload["user_id"], - extracted=payload["extracted"], - ) or payload["extracted"] + return ( + get_pipeline().dedup_extracted_memories( + user_id=payload["user_id"], + extracted=payload["extracted"], + ) + or payload["extracted"] + ) @bp.activity_trigger(input_name="payload") diff --git a/function_app/triggers/change_feed.py b/function_app/triggers/change_feed.py index 4630a46..6b7ad9d 100644 --- a/function_app/triggers/change_feed.py +++ b/function_app/triggers/change_feed.py @@ -220,9 +220,7 @@ async def process_changefeed_batch( # Persisted-counter backstop: every n_full_turns turns, force a # full-pool reconcile so dissimilar-embedding contradictions are # caught reliably on FA (not gated by the in-memory sweep counter). - should_full_reconcile = bool( - n_full_turns > 0 and crosses_threshold(old_count, new_count, n_full_turns) - ) + should_full_reconcile = bool(n_full_turns > 0 and crosses_threshold(old_count, new_count, n_full_turns)) watermark = await read_extract_watermark(counter_container, cid, user_id, thread_id) # Not capped: new_count - watermark is exactly the unextracted backlog # and the orchestrator advances the watermark to new_count, so capping diff --git a/tests/integration/test_async_full_pipeline.py b/tests/integration/test_async_full_pipeline.py index 88f66c4..a47804e 100644 --- a/tests/integration/test_async_full_pipeline.py +++ b/tests/integration/test_async_full_pipeline.py @@ -131,10 +131,7 @@ async def _async_seed_fact_with_embedding( sync helper). Retries through transient embedding-service blips so the extract-time vector floor always has a neighbour; skips honestly if the embedding service is genuinely unavailable.""" - check = ( - "SELECT c.id FROM c WHERE c.user_id = @uid AND c.content = @content " - "AND IS_DEFINED(c.embedding)" - ) + check = "SELECT c.id FROM c WHERE c.user_id = @uid AND c.content = @content AND IS_DEFINED(c.embedding)" params = [{"name": "@uid", "value": user_id}, {"name": "@content", "value": content}] for _ in range(retries): await mem.add_cosmos( @@ -145,9 +142,7 @@ async def _async_seed_fact_with_embedding( thread_id=thread_id, salience=0.7, ) - embedded = [ - doc async for doc in mem._memories_container_client.query_items(query=check, parameters=params) - ] + embedded = [doc async for doc in mem._memories_container_client.query_items(query=check, parameters=params)] if embedded: return await asyncio.sleep(1) diff --git a/tests/integration/test_full_pipeline.py b/tests/integration/test_full_pipeline.py index dcb217e..ed6be30 100644 --- a/tests/integration/test_full_pipeline.py +++ b/tests/integration/test_full_pipeline.py @@ -138,10 +138,7 @@ def _seed_fact_with_embedding( neighbour to match. Retry until an embedded copy exists (indexing is fast — the doc is vector-searchable within ~2s), and skip honestly if the embedding service is genuinely unavailable rather than reporting a false failure.""" - check = ( - "SELECT c.id FROM c WHERE c.user_id = @uid AND c.content = @content " - "AND IS_DEFINED(c.embedding)" - ) + check = "SELECT c.id FROM c WHERE c.user_id = @uid AND c.content = @content AND IS_DEFINED(c.embedding)" params = [{"name": "@uid", "value": user_id}, {"name": "@content", "value": content}] for _ in range(retries): mem.add_cosmos( @@ -624,4 +621,3 @@ def test_dedup_extracted_memories_flags_near_duplicate_of_stored_fact( ) finally: _cleanup(agent_memory, unique_user_id) - diff --git a/tests/unit/aio/services/test_dedup_vector_async.py b/tests/unit/aio/services/test_dedup_vector_async.py index be62553..e20a206 100644 --- a/tests/unit/aio/services/test_dedup_vector_async.py +++ b/tests/unit/aio/services/test_dedup_vector_async.py @@ -94,9 +94,7 @@ async def fake_query_items(_container, *, query, parameters): p._query_items = AsyncMock(side_effect=fake_query_items) p._distance_function_cache = "cosine" - out = await p._vector_candidates( - user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set() - ) + out = await p._vector_candidates(user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set()) # Cosmos rejects an explicit ASC/DESC on ORDER BY VectorDistance(); it orders # most-similar-first server-side. Direction-awareness lives in the Python sort. assert "ORDER BY VectorDistance(c.embedding, @vec)" in captured["query"] @@ -105,9 +103,7 @@ async def fake_query_items(_container, *, query, parameters): assert [c["id"] for c in out] == ["near", "far"] p._distance_function_cache = "euclidean" - out = await p._vector_candidates( - user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set() - ) + out = await p._vector_candidates(user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set()) assert "VectorDistance(c.embedding, @vec) ASC" not in captured["query"] # euclidean: lower distance = more similar, so 0.10 ("far" label) sorts first. # euclidean: lower distance = more similar, so 0.10 ("far" label) sorts first. @@ -202,9 +198,7 @@ async def clear(docs): @pytest.mark.asyncio async def test_full_rebuild_clears_survivor_tags(monkeypatch): # Async mirror: full_rebuild full-pool path clears survivor dup-candidate tags. - monkeypatch.setattr( - "azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_reconcile_mode", lambda: "candidate" - ) + monkeypatch.setattr("azure.cosmos.agent_memory.aio.services.pipeline.get_dedup_reconcile_mode", lambda: "candidate") p = _service() pool = [ _fact("f1", "a", tags=["sys:fact", "sys:dup-candidate"]), @@ -333,9 +327,7 @@ async def test_dedup_skips_underspecified_doc_verbatim(): # Parity with sync: a doc with no/unknown type is passed through untouched # and never runs vector dedup (async previously defaulted type to the bucket). p = _service() - p._vector_candidates = AsyncMock( - return_value=[{"id": "x", "content": "c", "type": "fact", "score": 0.99}] - ) + p._vector_candidates = AsyncMock(return_value=[{"id": "x", "content": "c", "type": "fact", "score": 0.99}]) doc = _fact("f1", "content") doc.pop("type") doc.pop("embedding", None) diff --git a/tests/unit/aio/test_auto_trigger.py b/tests/unit/aio/test_auto_trigger.py index adb9f88..0c68e28 100644 --- a/tests/unit/aio/test_auto_trigger.py +++ b/tests/unit/aio/test_auto_trigger.py @@ -52,7 +52,6 @@ async def patch_item(self, *, item, partition_key, patch_operations): return dict(doc) - class TestAsyncAutoTriggerNonBlocking: @pytest.mark.asyncio async def test_push_to_cosmos_does_not_await_auto_trigger(self, monkeypatch): @@ -155,15 +154,13 @@ async def fake_upsert(body): @pytest.mark.parametrize( ("counter_result", "watermark", "expected_recent_k"), [ - ((5, 10), 5, 5), # backlog = new - watermark - ((0, 1), 1, 1), # new == watermark -> floored to 1 - ((98, 100), 0, 100), # large backlog is NOT capped + ((5, 10), 5, 5), # backlog = new - watermark + ((0, 1), 1, 1), # new == watermark -> floored to 1 + ((98, 100), 0, 100), # large backlog is NOT capped ((20, 30), None, 30), # BOOTSTRAP: no watermark -> base=0 -> recent_k = new_count ], ) - async def test_extract_recent_k_uses_watermark( - self, monkeypatch, counter_result, watermark, expected_recent_k - ): + async def test_extract_recent_k_uses_watermark(self, monkeypatch, counter_result, watermark, expected_recent_k): monkeypatch.setenv("FACT_EXTRACTION_EVERY_N", "1") monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") @@ -177,16 +174,20 @@ async def test_extract_recent_k_uses_watermark( client._summaries_container_client = client._memories_container_client client._counter_container_client = MagicMock() - with patch( - "azure.cosmos.agent_memory._counters.increment_counter_async", - return_value=counter_result, - ), patch( - "azure.cosmos.agent_memory._counters.read_extract_watermark_async", - new=AsyncMock(return_value=watermark), - ), patch( - "azure.cosmos.agent_memory._counters.advance_extract_watermark_async", - new=AsyncMock(), - ) as advance: + with ( + patch( + "azure.cosmos.agent_memory._counters.increment_counter_async", + return_value=counter_result, + ), + patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_async", + new=AsyncMock(return_value=watermark), + ), + patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_async", + new=AsyncMock(), + ) as advance, + ): client.add_local(user_id="u1", role="user", thread_id="t1", content="hi") await client.push_to_cosmos() await asyncio.gather(*list(client._background_tasks), return_exceptions=True) @@ -211,19 +212,24 @@ async def test_watermark_not_advanced_when_extract_fails(self, monkeypatch): client._summaries_container_client = client._memories_container_client client._counter_container_client = MagicMock() - with patch( - "azure.cosmos.agent_memory._counters.increment_counter_async", - return_value=(0, 1), - ), patch( - "azure.cosmos.agent_memory._counters.read_extract_watermark_async", - new=AsyncMock(return_value=None), - ), patch( - "azure.cosmos.agent_memory._counters.advance_extract_watermark_async", - new=AsyncMock(), - ) as advance, patch( - "azure.cosmos.agent_memory._counters.stamp_failure_async", - new=AsyncMock(), - ) as stamp: + with ( + patch( + "azure.cosmos.agent_memory._counters.increment_counter_async", + return_value=(0, 1), + ), + patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_async", + new=AsyncMock(return_value=None), + ), + patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_async", + new=AsyncMock(), + ) as advance, + patch( + "azure.cosmos.agent_memory._counters.stamp_failure_async", + new=AsyncMock(), + ) as stamp, + ): client.add_local(user_id="u1", role="user", thread_id="t1", content="hi") await client.push_to_cosmos() await asyncio.gather(*list(client._background_tasks), return_exceptions=True) @@ -280,9 +286,7 @@ async def test_reconcile_full_rebuild_on_persisted_counter_cadence(self, monkeyp monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") monkeypatch.setenv("DEDUP_EVERY_N", "1") - monkeypatch.setattr( - "azure.cosmos.agent_memory.thresholds.get_dedup_full_recluster_every_n", lambda: 2 - ) + monkeypatch.setattr("azure.cosmos.agent_memory.thresholds.get_dedup_full_recluster_every_n", lambda: 2) rebuilds: list[bool] = [] processor = AsyncInProcessProcessor(pipeline=MagicMock()) @@ -298,15 +302,19 @@ async def test_reconcile_full_rebuild_on_persisted_counter_cadence(self, monkeyp client._summaries_container_client = client._memories_container_client client._counter_container_client = MagicMock() - with patch( - "azure.cosmos.agent_memory._counters.increment_counter_async", - new=AsyncMock(side_effect=[(0, 1), (1, 2)]), - ), patch( - "azure.cosmos.agent_memory._counters.read_extract_watermark_async", - new=AsyncMock(return_value=None), - ), patch( - "azure.cosmos.agent_memory._counters.advance_extract_watermark_async", - new=AsyncMock(), + with ( + patch( + "azure.cosmos.agent_memory._counters.increment_counter_async", + new=AsyncMock(side_effect=[(0, 1), (1, 2)]), + ), + patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_async", + new=AsyncMock(return_value=None), + ), + patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_async", + new=AsyncMock(), + ), ): client.add_local(user_id="u1", role="user", thread_id="t1", content="a") await client.push_to_cosmos() diff --git a/tests/unit/function_app/test_orchestrators.py b/tests/unit/function_app/test_orchestrators.py index 9fc629f..808e9e1 100644 --- a/tests/unit/function_app/test_orchestrators.py +++ b/tests/unit/function_app/test_orchestrators.py @@ -462,9 +462,7 @@ def test_em_advance_extract_watermark_stamps_counter(self): patch.object(cosmos_clients, "get_counter_container_async", new=AsyncMock(return_value=container)), patch.object(counters, "advance_extract_watermark", new=AsyncMock()) as advance, ): - result = asyncio.run( - em_mod.em_AdvanceExtractWatermark({"user_id": "u1", "thread_id": "t1", "count": 9}) - ) + result = asyncio.run(em_mod.em_AdvanceExtractWatermark({"user_id": "u1", "thread_id": "t1", "count": 9})) assert result is True advance.assert_awaited_once_with(container, "thread:u1:t1", "u1", "t1", 9) diff --git a/tests/unit/services/test_dedup_vector.py b/tests/unit/services/test_dedup_vector.py index a44f83a..6b4c4bd 100644 --- a/tests/unit/services/test_dedup_vector.py +++ b/tests/unit/services/test_dedup_vector.py @@ -24,12 +24,17 @@ def _make_pipeline() -> PipelineService: def _doc(mid: str, content: str, memory_type: str = "fact", **extra: Any) -> dict[str, Any]: tags = extra.pop("tags", [f"sys:{memory_type}"]) - metadata = extra.pop("metadata", {"category": "preference"} if memory_type == "fact" else { - "scope_type": "project", - "scope_value": "demo", - "lesson": content, - "outcome_valence": "neutral", - }) + metadata = extra.pop( + "metadata", + {"category": "preference"} + if memory_type == "fact" + else { + "scope_type": "project", + "scope_value": "demo", + "lesson": content, + "outcome_valence": "neutral", + }, + ) return { "id": mid, "user_id": "u1", @@ -101,9 +106,7 @@ def query_items(*, query: str, parameters, **kwargs): p._memories_container.query_items.side_effect = query_items p._distance_function_cache = "cosine" - out = p._vector_candidates( - user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set() - ) + out = p._vector_candidates(user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set()) # Cosmos rejects an explicit ASC/DESC on ORDER BY VectorDistance(); it orders # most-similar-first server-side. Direction-awareness lives in the Python sort. assert "ORDER BY VectorDistance(c.embedding, @vec)" in captured["query"] @@ -112,9 +115,7 @@ def query_items(*, query: str, parameters, **kwargs): assert [c["id"] for c in out] == ["near", "far"] p._distance_function_cache = "euclidean" - out = p._vector_candidates( - user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set() - ) + out = p._vector_candidates(user_id="u1", embedding=[1.0, 0.0], memory_type="fact", top_k=2, exclude_ids=set()) assert "VectorDistance(c.embedding, @vec) ASC" not in captured["query"] # euclidean: lower distance = more similar, so 0.10 ("far" label) sorts first. assert [c["id"] for c in out] == ["far", "near"] diff --git a/tests/unit/test_auto_trigger.py b/tests/unit/test_auto_trigger.py index 4bb73e7..2f23420 100644 --- a/tests/unit/test_auto_trigger.py +++ b/tests/unit/test_auto_trigger.py @@ -226,15 +226,19 @@ def test_extract_recent_k_uses_watermark_then_falls_back( client = _connected(processor=processor) client._counter_container_client = MagicMock() - with patch( - "azure.cosmos.agent_memory._counters.increment_counter_sync", - return_value=counter_result, - ), patch( - "azure.cosmos.agent_memory._counters.read_extract_watermark_sync", - return_value=watermark, - ), patch( - "azure.cosmos.agent_memory._counters.advance_extract_watermark_sync", - ) as advance: + with ( + patch( + "azure.cosmos.agent_memory._counters.increment_counter_sync", + return_value=counter_result, + ), + patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_sync", + return_value=watermark, + ), + patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_sync", + ) as advance, + ): for i in range(batch_count): client.add_local(user_id="u1", role="user", thread_id="t1", content=f"hi {i}") client.push_to_cosmos() @@ -300,17 +304,22 @@ def test_watermark_not_advanced_when_extract_fails(self, monkeypatch): client = _connected(processor=processor) client._counter_container_client = MagicMock() - with patch( - "azure.cosmos.agent_memory._counters.increment_counter_sync", - return_value=(0, 1), - ), patch( - "azure.cosmos.agent_memory._counters.read_extract_watermark_sync", - return_value=None, - ), patch( - "azure.cosmos.agent_memory._counters.advance_extract_watermark_sync", - ) as advance, patch( - "azure.cosmos.agent_memory._counters.stamp_failure_sync", - ) as stamp: + with ( + patch( + "azure.cosmos.agent_memory._counters.increment_counter_sync", + return_value=(0, 1), + ), + patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_sync", + return_value=None, + ), + patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_sync", + ) as advance, + patch( + "azure.cosmos.agent_memory._counters.stamp_failure_sync", + ) as stamp, + ): client.add_local(user_id="u1", role="user", thread_id="t1", content="hi") client.push_to_cosmos() @@ -326,9 +335,7 @@ def test_reconcile_full_rebuild_on_persisted_counter_cadence(self, monkeypatch): monkeypatch.setenv("THREAD_SUMMARY_EVERY_N", "0") monkeypatch.setenv("USER_SUMMARY_EVERY_N", "0") monkeypatch.setenv("DEDUP_EVERY_N", "1") - monkeypatch.setattr( - "azure.cosmos.agent_memory.thresholds.get_dedup_full_recluster_every_n", lambda: 2 - ) + monkeypatch.setattr("azure.cosmos.agent_memory.thresholds.get_dedup_full_recluster_every_n", lambda: 2) rebuilds: list[bool] = [] processor = InProcessProcessor(pipeline=MagicMock()) @@ -341,14 +348,18 @@ def test_reconcile_full_rebuild_on_persisted_counter_cadence(self, monkeypatch): client = _connected(processor=processor) client._counter_container_client = MagicMock() - with patch( - "azure.cosmos.agent_memory._counters.increment_counter_sync", - side_effect=[(0, 1), (1, 2)], - ), patch( - "azure.cosmos.agent_memory._counters.read_extract_watermark_sync", - return_value=None, - ), patch( - "azure.cosmos.agent_memory._counters.advance_extract_watermark_sync", + with ( + patch( + "azure.cosmos.agent_memory._counters.increment_counter_sync", + side_effect=[(0, 1), (1, 2)], + ), + patch( + "azure.cosmos.agent_memory._counters.read_extract_watermark_sync", + return_value=None, + ), + patch( + "azure.cosmos.agent_memory._counters.advance_extract_watermark_sync", + ), ): client.add_local(user_id="u1", role="user", thread_id="t1", content="a") client.push_to_cosmos() # counter 0->1: reconcile, full crosses 2? no