1818from basic_memory .repository .search_index_row import SearchIndexRow
1919from basic_memory .repository .search_repository_base import (
2020 SearchRepositoryBase ,
21- VectorSyncBatchResult ,
22- _EntitySyncRuntime ,
23- _PendingEmbeddingJob ,
2421 _PreparedEntityVectorSync ,
2522)
2623from basic_memory .repository .metadata_filters import parse_metadata_filters
@@ -494,248 +491,22 @@ async def _run_vector_query(
494491 )
495492 return [dict (row ) for row in vector_result .mappings ().all ()]
496493
497- async def sync_entity_vectors_batch (
498- self ,
499- entity_ids : list [int ],
500- progress_callback = None ,
501- ) -> VectorSyncBatchResult :
502- """Sync semantic vectors with concurrent Postgres preparation windows.
503-
504- Trigger: cloud indexing uses Neon Postgres where network latency dominates
505- thousands of per-entity prepare queries.
506- Why: preparing a small config-driven window of entities concurrently hides
507- round-trip latency without exhausting the tenant connection pool.
508- Outcome: Postgres vector sync keeps the existing flush semantics while reducing
509- wall-clock time on large cloud projects.
510- """
511- self ._assert_semantic_available ()
512- await self ._ensure_vector_tables ()
513- assert self ._embedding_provider is not None
514-
515- total_entities = len (entity_ids )
516- result = VectorSyncBatchResult (
517- entities_total = total_entities ,
518- entities_synced = 0 ,
519- entities_failed = 0 ,
520- )
521- if total_entities == 0 :
522- return result
523-
524- logger .info (
525- "Vector batch sync start: project_id={project_id} entities_total={entities_total} "
526- "sync_batch_size={sync_batch_size} prepare_concurrency={prepare_concurrency}" ,
527- project_id = self .project_id ,
528- entities_total = total_entities ,
529- sync_batch_size = self ._semantic_embedding_sync_batch_size ,
530- prepare_concurrency = self ._semantic_postgres_prepare_concurrency ,
494+ def _vector_prepare_window_size (self ) -> int :
495+ """Use a bounded config-driven prepare window for Postgres vector sync."""
496+ return self ._semantic_postgres_prepare_concurrency
497+
498+ async def _prepare_entity_vector_jobs_window (
499+ self , entity_ids : list [int ]
500+ ) -> list [_PreparedEntityVectorSync | BaseException ]:
501+ """Prepare one Postgres window concurrently to hide DB round-trip latency."""
502+ prepared_window = await asyncio .gather (
503+ * (self ._prepare_entity_vector_jobs (entity_id ) for entity_id in entity_ids ),
504+ return_exceptions = True ,
531505 )
532-
533- pending_jobs : list [_PendingEmbeddingJob ] = []
534- entity_runtime : dict [int , _EntitySyncRuntime ] = {}
535- failed_entity_ids : set [int ] = set ()
536- deferred_entity_ids : set [int ] = set ()
537- synced_entity_ids : set [int ] = set ()
538-
539- for window_start in range (0 , total_entities , self ._semantic_postgres_prepare_concurrency ):
540- window_entity_ids = entity_ids [
541- window_start : window_start + self ._semantic_postgres_prepare_concurrency
542- ]
543-
544- if progress_callback is not None :
545- for offset , entity_id in enumerate (window_entity_ids , start = window_start ):
546- progress_callback (entity_id , offset , total_entities )
547-
548- prepared_window = await asyncio .gather (
549- * (self ._prepare_entity_vector_jobs (entity_id ) for entity_id in window_entity_ids ),
550- return_exceptions = True ,
551- )
552-
553- for entity_id , prepared in zip (window_entity_ids , prepared_window , strict = True ):
554- if isinstance (prepared , BaseException ):
555- failed_entity_ids .add (entity_id )
556- logger .warning (
557- "Vector batch sync entity prepare failed: project_id={project_id} "
558- "entity_id={entity_id} error={error}" ,
559- project_id = self .project_id ,
560- entity_id = entity_id ,
561- error = str (prepared ),
562- )
563- continue
564-
565- prepared_sync = cast (_PreparedEntityVectorSync , prepared )
566-
567- embedding_jobs_count = len (prepared_sync .embedding_jobs )
568- result .chunks_total += prepared_sync .chunks_total
569- result .chunks_skipped += prepared_sync .chunks_skipped
570- if prepared_sync .entity_skipped :
571- result .entities_skipped += 1
572- result .embedding_jobs_total += embedding_jobs_count
573- result .prepare_seconds_total += prepared_sync .prepare_seconds
574-
575- if embedding_jobs_count == 0 :
576- if prepared_sync .entity_complete :
577- synced_entity_ids .add (entity_id )
578- else :
579- deferred_entity_ids .add (entity_id )
580- total_seconds = time .perf_counter () - prepared_sync .sync_start
581- queue_wait_seconds = max (0.0 , total_seconds - prepared_sync .prepare_seconds )
582- result .queue_wait_seconds_total += queue_wait_seconds
583- self ._log_vector_sync_complete (
584- entity_id = entity_id ,
585- total_seconds = total_seconds ,
586- prepare_seconds = prepared_sync .prepare_seconds ,
587- queue_wait_seconds = queue_wait_seconds ,
588- embed_seconds = 0.0 ,
589- write_seconds = 0.0 ,
590- source_rows_count = prepared_sync .source_rows_count ,
591- chunks_total = prepared_sync .chunks_total ,
592- chunks_skipped = prepared_sync .chunks_skipped ,
593- embedding_jobs_count = 0 ,
594- entity_skipped = prepared_sync .entity_skipped ,
595- entity_complete = prepared_sync .entity_complete ,
596- oversized_entity = prepared_sync .oversized_entity ,
597- pending_jobs_total = prepared_sync .pending_jobs_total ,
598- shard_index = prepared_sync .shard_index ,
599- shard_count = prepared_sync .shard_count ,
600- remaining_jobs_after_shard = prepared_sync .remaining_jobs_after_shard ,
601- )
602- continue
603-
604- entity_runtime [entity_id ] = _EntitySyncRuntime (
605- sync_start = prepared_sync .sync_start ,
606- source_rows_count = prepared_sync .source_rows_count ,
607- embedding_jobs_count = embedding_jobs_count ,
608- remaining_jobs = embedding_jobs_count ,
609- chunks_total = prepared_sync .chunks_total ,
610- chunks_skipped = prepared_sync .chunks_skipped ,
611- entity_skipped = prepared_sync .entity_skipped ,
612- entity_complete = prepared_sync .entity_complete ,
613- oversized_entity = prepared_sync .oversized_entity ,
614- pending_jobs_total = prepared_sync .pending_jobs_total ,
615- shard_index = prepared_sync .shard_index ,
616- shard_count = prepared_sync .shard_count ,
617- remaining_jobs_after_shard = prepared_sync .remaining_jobs_after_shard ,
618- prepare_seconds = prepared_sync .prepare_seconds ,
619- )
620- pending_jobs .extend (
621- _PendingEmbeddingJob (
622- entity_id = entity_id ,
623- chunk_row_id = row_id ,
624- chunk_text = chunk_text ,
625- )
626- for row_id , chunk_text in prepared_sync .embedding_jobs
627- )
628-
629- while len (pending_jobs ) >= self ._semantic_embedding_sync_batch_size :
630- flush_jobs = pending_jobs [: self ._semantic_embedding_sync_batch_size ]
631- pending_jobs = pending_jobs [self ._semantic_embedding_sync_batch_size :]
632- try :
633- embed_seconds , write_seconds = await self ._flush_embedding_jobs (
634- flush_jobs = flush_jobs ,
635- entity_runtime = entity_runtime ,
636- synced_entity_ids = synced_entity_ids ,
637- )
638- result .embed_seconds_total += embed_seconds
639- result .write_seconds_total += write_seconds
640- (result .queue_wait_seconds_total ) += self ._finalize_completed_entity_syncs (
641- entity_runtime = entity_runtime ,
642- synced_entity_ids = synced_entity_ids ,
643- deferred_entity_ids = deferred_entity_ids ,
644- )
645- except Exception as exc :
646- affected_entity_ids = sorted ({job .entity_id for job in flush_jobs })
647- failed_entity_ids .update (affected_entity_ids )
648- synced_entity_ids .difference_update (affected_entity_ids )
649- deferred_entity_ids .difference_update (affected_entity_ids )
650- for failed_entity_id in affected_entity_ids :
651- entity_runtime .pop (failed_entity_id , None )
652- logger .warning (
653- "Vector batch sync flush failed: project_id={project_id} "
654- "affected_entities={affected_entities} chunk_count={chunk_count} "
655- "error={error}" ,
656- project_id = self .project_id ,
657- affected_entities = affected_entity_ids ,
658- chunk_count = len (flush_jobs ),
659- error = str (exc ),
660- )
661-
662- if pending_jobs :
663- flush_jobs = list (pending_jobs )
664- pending_jobs = []
665- try :
666- embed_seconds , write_seconds = await self ._flush_embedding_jobs (
667- flush_jobs = flush_jobs ,
668- entity_runtime = entity_runtime ,
669- synced_entity_ids = synced_entity_ids ,
670- )
671- result .embed_seconds_total += embed_seconds
672- result .write_seconds_total += write_seconds
673- (result .queue_wait_seconds_total ) += self ._finalize_completed_entity_syncs (
674- entity_runtime = entity_runtime ,
675- synced_entity_ids = synced_entity_ids ,
676- deferred_entity_ids = deferred_entity_ids ,
677- )
678- except Exception as exc :
679- affected_entity_ids = sorted ({job .entity_id for job in flush_jobs })
680- failed_entity_ids .update (affected_entity_ids )
681- synced_entity_ids .difference_update (affected_entity_ids )
682- deferred_entity_ids .difference_update (affected_entity_ids )
683- for failed_entity_id in affected_entity_ids :
684- entity_runtime .pop (failed_entity_id , None )
685- logger .warning (
686- "Vector batch sync final flush failed: project_id={project_id} "
687- "affected_entities={affected_entities} chunk_count={chunk_count} error={error}" ,
688- project_id = self .project_id ,
689- affected_entities = affected_entity_ids ,
690- chunk_count = len (flush_jobs ),
691- error = str (exc ),
692- )
693-
694- if entity_runtime :
695- orphan_runtime_entities = sorted (entity_runtime .keys ())
696- failed_entity_ids .update (orphan_runtime_entities )
697- synced_entity_ids .difference_update (orphan_runtime_entities )
698- deferred_entity_ids .difference_update (orphan_runtime_entities )
699- logger .warning (
700- "Vector batch sync left unfinished entities after flushes: "
701- "project_id={project_id} unfinished_entities={unfinished_entities}" ,
702- project_id = self .project_id ,
703- unfinished_entities = orphan_runtime_entities ,
704- )
705-
706- synced_entity_ids .difference_update (failed_entity_ids )
707- deferred_entity_ids .difference_update (failed_entity_ids )
708- deferred_entity_ids .difference_update (synced_entity_ids )
709- result .failed_entity_ids = sorted (failed_entity_ids )
710- result .entities_failed = len (result .failed_entity_ids )
711- result .entities_deferred = len (deferred_entity_ids )
712- result .entities_synced = len (synced_entity_ids )
713-
714- logger .info (
715- "Vector batch sync complete: project_id={project_id} entities_total={entities_total} "
716- "entities_synced={entities_synced} entities_failed={entities_failed} "
717- "entities_deferred={entities_deferred} "
718- "entities_skipped={entities_skipped} chunks_total={chunks_total} "
719- "chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} "
720- "prepare_seconds_total={prepare_seconds_total:.3f} "
721- "queue_wait_seconds_total={queue_wait_seconds_total:.3f} "
722- "embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}" ,
723- project_id = self .project_id ,
724- entities_total = result .entities_total ,
725- entities_synced = result .entities_synced ,
726- entities_failed = result .entities_failed ,
727- entities_deferred = result .entities_deferred ,
728- entities_skipped = result .entities_skipped ,
729- chunks_total = result .chunks_total ,
730- chunks_skipped = result .chunks_skipped ,
731- embedding_jobs_total = result .embedding_jobs_total ,
732- prepare_seconds_total = result .prepare_seconds_total ,
733- queue_wait_seconds_total = result .queue_wait_seconds_total ,
734- embed_seconds_total = result .embed_seconds_total ,
735- write_seconds_total = result .write_seconds_total ,
736- )
737-
738- return result
506+ return [
507+ cast (_PreparedEntityVectorSync | BaseException , prepared )
508+ for prepared in prepared_window
509+ ]
739510
740511 async def _prepare_entity_vector_jobs (self , entity_id : int ) -> _PreparedEntityVectorSync :
741512 """Prepare chunk mutations with Postgres-specific bulk upserts."""
@@ -783,6 +554,7 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
783554 source_rows_count = source_rows_count ,
784555 )
785556 await self ._delete_entity_chunks (session , entity_id )
557+ await session .commit ()
786558 prepare_seconds = time .perf_counter () - sync_start
787559 return _PreparedEntityVectorSync (
788560 entity_id = entity_id ,
@@ -807,6 +579,7 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
807579 )
808580 if not chunk_records :
809581 await self ._delete_entity_chunks (session , entity_id )
582+ await session .commit ()
810583 prepare_seconds = time .perf_counter () - sync_start
811584 return _PreparedEntityVectorSync (
812585 entity_id = entity_id ,
@@ -830,12 +603,12 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
830603 existing_rows = existing_rows_result .mappings ().all ()
831604 existing_by_key = {str (row ["chunk_key" ]): row for row in existing_rows }
832605 existing_chunks_count = len (existing_by_key )
833- incoming_by_key = {record ["chunk_key" ]: record for record in chunk_records }
606+ incoming_chunk_keys = {record ["chunk_key" ] for record in chunk_records }
834607
835608 stale_ids = [
836609 int (row ["id" ])
837610 for chunk_key , row in existing_by_key .items ()
838- if chunk_key not in incoming_by_key
611+ if chunk_key not in incoming_chunk_keys
839612 ]
840613 stale_chunks_count = len (stale_ids )
841614 if stale_ids :
@@ -934,6 +707,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe
934707 "entity_id" : entity_id ,
935708 }
936709 upsert_values : list [str ] = []
710+ # The SQL template is built from integer enumerate() indices only.
711+ # No user-controlled text is interpolated into the statement.
937712 for index , record in enumerate (upsert_records ):
938713 upsert_params [f"chunk_key_{ index } " ] = record ["chunk_key" ]
939714 upsert_params [f"chunk_text_{ index } " ] = record ["chunk_text" ]
@@ -1028,6 +803,8 @@ async def _write_embeddings(
1028803 params : dict [str , object ] = {"project_id" : self .project_id }
1029804 value_rows : list [str ] = []
1030805
806+ # The SQL template is built from integer enumerate() indices only.
807+ # No user-controlled text is interpolated into the statement.
1031808 for index , ((row_id , _ ), vector ) in enumerate (zip (jobs , embeddings , strict = True )):
1032809 params [f"chunk_id_{ index } " ] = row_id
1033810 params [f"embedding_{ index } " ] = self ._format_pgvector_literal (vector )
0 commit comments