Commit 4529bfa
refactor(core): share vector sync prepare orchestration
Signed-off-by: phernandez <paul@basicmachines.co>
1 parent: 2f34747

2 files changed

Lines changed: 146 additions & 343 deletions

File tree

src/basic_memory/repository/postgres_search_repository.py

Lines changed: 23 additions & 246 deletions
@@ -18,9 +18,6 @@
 from basic_memory.repository.search_index_row import SearchIndexRow
 from basic_memory.repository.search_repository_base import (
     SearchRepositoryBase,
-    VectorSyncBatchResult,
-    _EntitySyncRuntime,
-    _PendingEmbeddingJob,
     _PreparedEntityVectorSync,
 )
 from basic_memory.repository.metadata_filters import parse_metadata_filters
@@ -494,248 +491,22 @@ async def _run_vector_query(
         )
         return [dict(row) for row in vector_result.mappings().all()]
 
-    async def sync_entity_vectors_batch(
-        self,
-        entity_ids: list[int],
-        progress_callback=None,
-    ) -> VectorSyncBatchResult:
-        """Sync semantic vectors with concurrent Postgres preparation windows.
-
-        Trigger: cloud indexing uses Neon Postgres where network latency dominates
-        thousands of per-entity prepare queries.
-        Why: preparing a small config-driven window of entities concurrently hides
-        round-trip latency without exhausting the tenant connection pool.
-        Outcome: Postgres vector sync keeps the existing flush semantics while reducing
-        wall-clock time on large cloud projects.
-        """
-        self._assert_semantic_available()
-        await self._ensure_vector_tables()
-        assert self._embedding_provider is not None
-
-        total_entities = len(entity_ids)
-        result = VectorSyncBatchResult(
-            entities_total=total_entities,
-            entities_synced=0,
-            entities_failed=0,
-        )
-        if total_entities == 0:
-            return result
-
-        logger.info(
-            "Vector batch sync start: project_id={project_id} entities_total={entities_total} "
-            "sync_batch_size={sync_batch_size} prepare_concurrency={prepare_concurrency}",
-            project_id=self.project_id,
-            entities_total=total_entities,
-            sync_batch_size=self._semantic_embedding_sync_batch_size,
-            prepare_concurrency=self._semantic_postgres_prepare_concurrency,
+    def _vector_prepare_window_size(self) -> int:
+        """Use a bounded config-driven prepare window for Postgres vector sync."""
+        return self._semantic_postgres_prepare_concurrency
+
+    async def _prepare_entity_vector_jobs_window(
+        self, entity_ids: list[int]
+    ) -> list[_PreparedEntityVectorSync | BaseException]:
+        """Prepare one Postgres window concurrently to hide DB round-trip latency."""
+        prepared_window = await asyncio.gather(
+            *(self._prepare_entity_vector_jobs(entity_id) for entity_id in entity_ids),
+            return_exceptions=True,
         )
-
-        pending_jobs: list[_PendingEmbeddingJob] = []
-        entity_runtime: dict[int, _EntitySyncRuntime] = {}
-        failed_entity_ids: set[int] = set()
-        deferred_entity_ids: set[int] = set()
-        synced_entity_ids: set[int] = set()
-
-        for window_start in range(0, total_entities, self._semantic_postgres_prepare_concurrency):
-            window_entity_ids = entity_ids[
-                window_start : window_start + self._semantic_postgres_prepare_concurrency
-            ]
-
-            if progress_callback is not None:
-                for offset, entity_id in enumerate(window_entity_ids, start=window_start):
-                    progress_callback(entity_id, offset, total_entities)
-
-            prepared_window = await asyncio.gather(
-                *(self._prepare_entity_vector_jobs(entity_id) for entity_id in window_entity_ids),
-                return_exceptions=True,
-            )
-
-            for entity_id, prepared in zip(window_entity_ids, prepared_window, strict=True):
-                if isinstance(prepared, BaseException):
-                    failed_entity_ids.add(entity_id)
-                    logger.warning(
-                        "Vector batch sync entity prepare failed: project_id={project_id} "
-                        "entity_id={entity_id} error={error}",
-                        project_id=self.project_id,
-                        entity_id=entity_id,
-                        error=str(prepared),
-                    )
-                    continue
-
-                prepared_sync = cast(_PreparedEntityVectorSync, prepared)
-
-                embedding_jobs_count = len(prepared_sync.embedding_jobs)
-                result.chunks_total += prepared_sync.chunks_total
-                result.chunks_skipped += prepared_sync.chunks_skipped
-                if prepared_sync.entity_skipped:
-                    result.entities_skipped += 1
-                result.embedding_jobs_total += embedding_jobs_count
-                result.prepare_seconds_total += prepared_sync.prepare_seconds
-
-                if embedding_jobs_count == 0:
-                    if prepared_sync.entity_complete:
-                        synced_entity_ids.add(entity_id)
-                    else:
-                        deferred_entity_ids.add(entity_id)
-                    total_seconds = time.perf_counter() - prepared_sync.sync_start
-                    queue_wait_seconds = max(0.0, total_seconds - prepared_sync.prepare_seconds)
-                    result.queue_wait_seconds_total += queue_wait_seconds
-                    self._log_vector_sync_complete(
-                        entity_id=entity_id,
-                        total_seconds=total_seconds,
-                        prepare_seconds=prepared_sync.prepare_seconds,
-                        queue_wait_seconds=queue_wait_seconds,
-                        embed_seconds=0.0,
-                        write_seconds=0.0,
-                        source_rows_count=prepared_sync.source_rows_count,
-                        chunks_total=prepared_sync.chunks_total,
-                        chunks_skipped=prepared_sync.chunks_skipped,
-                        embedding_jobs_count=0,
-                        entity_skipped=prepared_sync.entity_skipped,
-                        entity_complete=prepared_sync.entity_complete,
-                        oversized_entity=prepared_sync.oversized_entity,
-                        pending_jobs_total=prepared_sync.pending_jobs_total,
-                        shard_index=prepared_sync.shard_index,
-                        shard_count=prepared_sync.shard_count,
-                        remaining_jobs_after_shard=prepared_sync.remaining_jobs_after_shard,
-                    )
-                    continue
-
-                entity_runtime[entity_id] = _EntitySyncRuntime(
-                    sync_start=prepared_sync.sync_start,
-                    source_rows_count=prepared_sync.source_rows_count,
-                    embedding_jobs_count=embedding_jobs_count,
-                    remaining_jobs=embedding_jobs_count,
-                    chunks_total=prepared_sync.chunks_total,
-                    chunks_skipped=prepared_sync.chunks_skipped,
-                    entity_skipped=prepared_sync.entity_skipped,
-                    entity_complete=prepared_sync.entity_complete,
-                    oversized_entity=prepared_sync.oversized_entity,
-                    pending_jobs_total=prepared_sync.pending_jobs_total,
-                    shard_index=prepared_sync.shard_index,
-                    shard_count=prepared_sync.shard_count,
-                    remaining_jobs_after_shard=prepared_sync.remaining_jobs_after_shard,
-                    prepare_seconds=prepared_sync.prepare_seconds,
-                )
-                pending_jobs.extend(
-                    _PendingEmbeddingJob(
-                        entity_id=entity_id,
-                        chunk_row_id=row_id,
-                        chunk_text=chunk_text,
-                    )
-                    for row_id, chunk_text in prepared_sync.embedding_jobs
-                )
-
-            while len(pending_jobs) >= self._semantic_embedding_sync_batch_size:
-                flush_jobs = pending_jobs[: self._semantic_embedding_sync_batch_size]
-                pending_jobs = pending_jobs[self._semantic_embedding_sync_batch_size :]
-                try:
-                    embed_seconds, write_seconds = await self._flush_embedding_jobs(
-                        flush_jobs=flush_jobs,
-                        entity_runtime=entity_runtime,
-                        synced_entity_ids=synced_entity_ids,
-                    )
-                    result.embed_seconds_total += embed_seconds
-                    result.write_seconds_total += write_seconds
-                    result.queue_wait_seconds_total += self._finalize_completed_entity_syncs(
-                        entity_runtime=entity_runtime,
-                        synced_entity_ids=synced_entity_ids,
-                        deferred_entity_ids=deferred_entity_ids,
-                    )
-                except Exception as exc:
-                    affected_entity_ids = sorted({job.entity_id for job in flush_jobs})
-                    failed_entity_ids.update(affected_entity_ids)
-                    synced_entity_ids.difference_update(affected_entity_ids)
-                    deferred_entity_ids.difference_update(affected_entity_ids)
-                    for failed_entity_id in affected_entity_ids:
-                        entity_runtime.pop(failed_entity_id, None)
-                    logger.warning(
-                        "Vector batch sync flush failed: project_id={project_id} "
-                        "affected_entities={affected_entities} chunk_count={chunk_count} "
-                        "error={error}",
-                        project_id=self.project_id,
-                        affected_entities=affected_entity_ids,
-                        chunk_count=len(flush_jobs),
-                        error=str(exc),
-                    )
-
-        if pending_jobs:
-            flush_jobs = list(pending_jobs)
-            pending_jobs = []
-            try:
-                embed_seconds, write_seconds = await self._flush_embedding_jobs(
-                    flush_jobs=flush_jobs,
-                    entity_runtime=entity_runtime,
-                    synced_entity_ids=synced_entity_ids,
-                )
-                result.embed_seconds_total += embed_seconds
-                result.write_seconds_total += write_seconds
-                result.queue_wait_seconds_total += self._finalize_completed_entity_syncs(
-                    entity_runtime=entity_runtime,
-                    synced_entity_ids=synced_entity_ids,
-                    deferred_entity_ids=deferred_entity_ids,
-                )
-            except Exception as exc:
-                affected_entity_ids = sorted({job.entity_id for job in flush_jobs})
-                failed_entity_ids.update(affected_entity_ids)
-                synced_entity_ids.difference_update(affected_entity_ids)
-                deferred_entity_ids.difference_update(affected_entity_ids)
-                for failed_entity_id in affected_entity_ids:
-                    entity_runtime.pop(failed_entity_id, None)
-                logger.warning(
-                    "Vector batch sync final flush failed: project_id={project_id} "
-                    "affected_entities={affected_entities} chunk_count={chunk_count} error={error}",
-                    project_id=self.project_id,
-                    affected_entities=affected_entity_ids,
-                    chunk_count=len(flush_jobs),
-                    error=str(exc),
-                )
-
-        if entity_runtime:
-            orphan_runtime_entities = sorted(entity_runtime.keys())
-            failed_entity_ids.update(orphan_runtime_entities)
-            synced_entity_ids.difference_update(orphan_runtime_entities)
-            deferred_entity_ids.difference_update(orphan_runtime_entities)
-            logger.warning(
-                "Vector batch sync left unfinished entities after flushes: "
-                "project_id={project_id} unfinished_entities={unfinished_entities}",
-                project_id=self.project_id,
-                unfinished_entities=orphan_runtime_entities,
-            )
-
-        synced_entity_ids.difference_update(failed_entity_ids)
-        deferred_entity_ids.difference_update(failed_entity_ids)
-        deferred_entity_ids.difference_update(synced_entity_ids)
-        result.failed_entity_ids = sorted(failed_entity_ids)
-        result.entities_failed = len(result.failed_entity_ids)
-        result.entities_deferred = len(deferred_entity_ids)
-        result.entities_synced = len(synced_entity_ids)
-
-        logger.info(
-            "Vector batch sync complete: project_id={project_id} entities_total={entities_total} "
-            "entities_synced={entities_synced} entities_failed={entities_failed} "
-            "entities_deferred={entities_deferred} "
-            "entities_skipped={entities_skipped} chunks_total={chunks_total} "
-            "chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} "
-            "prepare_seconds_total={prepare_seconds_total:.3f} "
-            "queue_wait_seconds_total={queue_wait_seconds_total:.3f} "
-            "embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}",
-            project_id=self.project_id,
-            entities_total=result.entities_total,
-            entities_synced=result.entities_synced,
-            entities_failed=result.entities_failed,
-            entities_deferred=result.entities_deferred,
-            entities_skipped=result.entities_skipped,
-            chunks_total=result.chunks_total,
-            chunks_skipped=result.chunks_skipped,
-            embedding_jobs_total=result.embedding_jobs_total,
-            prepare_seconds_total=result.prepare_seconds_total,
-            queue_wait_seconds_total=result.queue_wait_seconds_total,
-            embed_seconds_total=result.embed_seconds_total,
-            write_seconds_total=result.write_seconds_total,
-        )
-
-        return result
+        return [
+            cast(_PreparedEntityVectorSync | BaseException, prepared)
+            for prepared in prepared_window
+        ]
 
     async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync:
         """Prepare chunk mutations with Postgres-specific bulk upserts."""
@@ -783,6 +554,7 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync:
                 source_rows_count=source_rows_count,
             )
             await self._delete_entity_chunks(session, entity_id)
+            await session.commit()
             prepare_seconds = time.perf_counter() - sync_start
             return _PreparedEntityVectorSync(
                 entity_id=entity_id,
@@ -807,6 +579,7 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync:
             )
             if not chunk_records:
                 await self._delete_entity_chunks(session, entity_id)
+                await session.commit()
                 prepare_seconds = time.perf_counter() - sync_start
                 return _PreparedEntityVectorSync(
                     entity_id=entity_id,
@@ -830,12 +603,12 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync:
             existing_rows = existing_rows_result.mappings().all()
             existing_by_key = {str(row["chunk_key"]): row for row in existing_rows}
             existing_chunks_count = len(existing_by_key)
-            incoming_by_key = {record["chunk_key"]: record for record in chunk_records}
+            incoming_chunk_keys = {record["chunk_key"] for record in chunk_records}
 
             stale_ids = [
                 int(row["id"])
                 for chunk_key, row in existing_by_key.items()
-                if chunk_key not in incoming_by_key
+                if chunk_key not in incoming_chunk_keys
             ]
             stale_chunks_count = len(stale_ids)
             if stale_ids:
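
The set-based rewrite in this hunk reflects that stale detection only ever tests key membership; the record values in the old incoming_by_key dict were never read. Distilled into a runnable fragment, with hypothetical in-memory rows standing in for the chunk table:

# Hypothetical rows standing in for existing vector-chunk records.
existing_rows = [
    {"id": 1, "chunk_key": "a"},
    {"id": 2, "chunk_key": "b"},
    {"id": 3, "chunk_key": "c"},
]
chunk_records = [{"chunk_key": "a"}, {"chunk_key": "c"}, {"chunk_key": "d"}]

existing_by_key = {str(row["chunk_key"]): row for row in existing_rows}
# A set suffices: the stale check tests membership, never looks up the record.
incoming_chunk_keys = {record["chunk_key"] for record in chunk_records}

stale_ids = [
    int(row["id"])
    for chunk_key, row in existing_by_key.items()
    if chunk_key not in incoming_chunk_keys
]
assert stale_ids == [2]  # "b" no longer appears among the incoming chunks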
@@ -934,6 +707,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync:
                 "entity_id": entity_id,
             }
             upsert_values: list[str] = []
+            # The SQL template is built from integer enumerate() indices only.
+            # No user-controlled text is interpolated into the statement.
             for index, record in enumerate(upsert_records):
                 upsert_params[f"chunk_key_{index}"] = record["chunk_key"]
                 upsert_params[f"chunk_text_{index}"] = record["chunk_text"]
@@ -1028,6 +803,8 @@ async def _write_embeddings(
         params: dict[str, object] = {"project_id": self.project_id}
         value_rows: list[str] = []
 
+        # The SQL template is built from integer enumerate() indices only.
+        # No user-controlled text is interpolated into the statement.
         for index, ((row_id, _), vector) in enumerate(zip(jobs, embeddings, strict=True)):
             params[f"chunk_id_{index}"] = row_id
             params[f"embedding_{index}"] = self._format_pgvector_literal(vector)
