Skip to content

Commit bae1314

Browse files
authored
Merge pull request #127 from Dooders/similarity-search-validation
Similarity search validation
2 parents 3e1e03a + 5d0f7ac commit bae1314

8 files changed

Lines changed: 318 additions & 166 deletions

File tree

memory/embeddings/vector_store.py

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def _ensure_index(self) -> None:
243243
indices = self.redis.execute_command("FT._LIST")
244244
# Convert both the index name and list items to strings for comparison
245245
index_name_str = str(self.index_name)
246-
indices_str = [str(idx) if isinstance(idx, bytes) else idx for idx in indices]
246+
indices_str = [
247+
str(idx) if isinstance(idx, bytes) else idx for idx in indices
248+
]
247249
if index_name_str in indices_str:
248250
self._index_exists = True
249251
return
@@ -534,11 +536,14 @@ def __init__(
534536
"Redis" if redis_client else "in-memory",
535537
)
536538

537-
def store_memory_vectors(self, memory_entry: Dict[str, Any]) -> bool:
539+
def store_memory_vectors(
540+
self, memory_entry: Dict[str, Any], tier: str = "stm"
541+
) -> bool:
538542
"""Store vectors for a memory entry in appropriate indices.
539543
540544
Args:
541545
memory_entry: Memory entry with embeddings
546+
tier: Tier to store the vectors in ("stm", "im", or "ltm")
542547
543548
Returns:
544549
True if storage was successful
@@ -553,22 +558,27 @@ def store_memory_vectors(self, memory_entry: Dict[str, Any]) -> bool:
553558

554559
success = True
555560

556-
# Store in STM index if full vector is available
557-
if "full_vector" in embeddings:
561+
# Store in STM index
562+
if tier == "stm":
563+
logger.debug("Storing STM vector for memory %s", memory_id)
558564
success = success and self.stm_index.add(
559565
memory_id, embeddings["full_vector"], metadata
560566
)
561567

562-
# Store in IM index if compressed vector is available
563-
if "compressed_vector" in embeddings:
568+
# Store in IM index
569+
if tier == "im":
570+
#! TODO: Use compressed vector
571+
logger.debug("Storing IM vector for memory %s", memory_id)
564572
success = success and self.im_index.add(
565-
memory_id, embeddings["compressed_vector"], metadata
573+
memory_id, embeddings["full_vector"], metadata
566574
)
567575

568-
# Store in LTM index if abstract vector is available
569-
if "abstract_vector" in embeddings:
576+
# Store in LTM index
577+
if tier == "ltm":
578+
#! TODO: Use abstract vector
579+
logger.debug("Storing LTM vector for memory %s", memory_id)
570580
success = success and self.ltm_index.add(
571-
memory_id, embeddings["abstract_vector"], metadata
581+
memory_id, embeddings["full_vector"], metadata
572582
)
573583

574584
return success
@@ -594,36 +604,55 @@ def find_similar_memories(
594604
# Create filter function if metadata filter is provided
595605
filter_fn = None
596606
if metadata_filter:
597-
logger.debug("Creating filter function for metadata filter: %s", metadata_filter)
607+
logger.debug(
608+
"Creating filter function for metadata filter: %s", metadata_filter
609+
)
598610

599611
def filter_fn(metadata):
600612
logger.debug("Checking metadata: %s", metadata)
613+
unmatched_keys = [] # Initialize the list before using it
601614
for key, value in metadata_filter.items():
602615
# Try direct match in top-level metadata
603616
if key in metadata and metadata[key] == value:
604617
logger.debug("Found direct match for %s: %s", key, value)
605618
continue
606-
619+
607620
# Special handling for 'type' field - also check 'memory_type'
608-
if key == 'type' and 'memory_type' in metadata and metadata['memory_type'] == value:
621+
if (
622+
key == "type"
623+
and "memory_type" in metadata
624+
and metadata["memory_type"] == value
625+
):
609626
logger.debug("Found match for type in memory_type: %s", value)
610627
continue
611-
628+
612629
# Try match in nested content.metadata
613-
if 'content' in metadata and isinstance(metadata['content'], dict):
614-
content = metadata['content']
615-
if 'metadata' in content and isinstance(content['metadata'], dict):
616-
content_metadata = content['metadata']
617-
if key in content_metadata and content_metadata[key] == value:
618-
logger.debug("Found nested match for %s: %s in content.metadata", key, value)
630+
if "content" in metadata and isinstance(metadata["content"], dict):
631+
content = metadata["content"]
632+
if "metadata" in content and isinstance(
633+
content["metadata"], dict
634+
):
635+
content_metadata = content["metadata"]
636+
if (
637+
key in content_metadata
638+
and content_metadata[key] == value
639+
):
640+
logger.debug(
641+
"Found nested match for %s: %s in content.metadata",
642+
key,
643+
value,
644+
)
619645
continue
620-
646+
621647
# No match found for this key
622648
unmatched_keys.append((key, value))
623649
return False
624-
650+
625651
if unmatched_keys:
626-
logger.debug("No matches found for the following keys and values: %s", unmatched_keys)
652+
logger.debug(
653+
"No matches found for the following keys and values: %s",
654+
unmatched_keys,
655+
)
627656
else:
628657
logger.debug("All filter criteria matched")
629658
# All keys matched

memory/search/strategies/similarity.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ def search(
111111

112112
# Generate query vector from input
113113
query_vector = self._generate_query_vector(query, current_tier)
114+
115+
# Add detailed logging for vector generation
116+
logger.debug(
117+
"Query vector generation for tier %s - Input: %s, Output: %s",
118+
current_tier,
119+
query,
120+
query_vector
121+
)
114122

115123
# Skip if vector generation failed
116124
if query_vector is None:
@@ -127,21 +135,30 @@ def search(
127135

128136
# Find similar vectors
129137
logger.debug(
130-
"Calling vector_store.find_similar_memories for tier %s", current_tier
131-
)
132-
similar_vectors = self.vector_store.find_similar_memories(
133-
query_vector,
134-
tier=current_tier,
135-
limit=limit * 2, # Get extra results to allow for score filtering
136-
metadata_filter=metadata_filter or {},
137-
)
138-
139-
logger.debug(
140-
"Vector store returned %d results for tier %s. Raw results: %s",
141-
len(similar_vectors),
138+
"About to call vector_store.find_similar_memories for tier %s with vector: %s",
142139
current_tier,
143-
similar_vectors,
140+
query_vector
144141
)
142+
try:
143+
similar_vectors = self.vector_store.find_similar_memories(
144+
query_vector,
145+
tier=current_tier,
146+
limit=limit * 2, # Get extra results to allow for score filtering
147+
metadata_filter=metadata_filter or {},
148+
)
149+
logger.debug(
150+
"Vector store returned %d results for tier %s. Raw results with scores: %s",
151+
len(similar_vectors),
152+
current_tier,
153+
[(v["id"], v["score"]) for v in similar_vectors],
154+
)
155+
except Exception as e:
156+
logger.error(
157+
"Error in vector_store.find_similar_memories for tier %s: %s",
158+
current_tier,
159+
str(e)
160+
)
161+
continue
145162

146163
# Filter by score
147164
filtered_vectors = [v for v in similar_vectors if v["score"] >= min_score]
@@ -150,7 +167,7 @@ def search(
150167
min_score,
151168
len(filtered_vectors),
152169
current_tier,
153-
filtered_vectors,
170+
[(v["id"], v["score"]) for v in filtered_vectors],
154171
)
155172

156173
# Limit results
@@ -274,22 +291,27 @@ def _generate_query_vector(
274291
logger.debug(
275292
"Encoding dictionary query for tier %s. Query dict: %s", tier, query
276293
)
277-
if tier == "stm":
278-
vector = self.embedding_engine.encode_stm(query)
279-
elif tier == "im":
280-
vector = self.embedding_engine.encode_im(query)
281-
elif tier == "ltm":
282-
vector = self.embedding_engine.encode_ltm(query)
283-
284-
if vector is not None:
285-
logger.debug(
286-
"Successfully generated vector of length %d: %s",
287-
len(vector),
288-
vector,
289-
)
290-
else:
291-
logger.warning("Failed to generate vector for tier %s", tier)
292-
return vector
294+
try:
295+
if tier == "stm":
296+
vector = self.embedding_engine.encode_stm(query)
297+
elif tier == "im":
298+
vector = self.embedding_engine.encode_im(query)
299+
elif tier == "ltm":
300+
vector = self.embedding_engine.encode_ltm(query)
301+
302+
if vector is not None:
303+
logger.debug(
304+
"Successfully generated vector of length %d for tier %s: %s",
305+
len(vector),
306+
tier,
307+
vector,
308+
)
309+
else:
310+
logger.warning("Failed to generate vector for tier %s - encoding returned None", tier)
311+
return vector
312+
except Exception as e:
313+
logger.error("Error generating vector for tier %s: %s", tier, str(e))
314+
return None
293315

294316
# If we get here, we couldn't generate a vector
295317
logger.warning("Could not generate vector for query type: %s", type(query))

memory/storage/mockredis/core.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -996,11 +996,47 @@ def execute_command(self, command, *args, **kwargs):
996996
if "_vector_idx" in index_name and has_vector_query:
997997
# Simulate a vector search by returning some random keys
998998
agent_prefix = index_name.split('_')[0]
999+
metadata_filter = {} # Initialize empty metadata filter
1000+
1001+
# Extract metadata filter from args if present
1002+
for i in range(len(args)):
1003+
if args[i].lower() == "filter":
1004+
try:
1005+
metadata_filter = json.loads(args[i + 2])
1006+
except:
1007+
pass
1008+
break
1009+
9991010
for key in self.store.keys():
10001011
if isinstance(key, str) and key.startswith(f"{agent_prefix}-"):
1001-
matching_keys.append(key)
1002-
if len(matching_keys) >= vector_k:
1003-
break
1012+
# Get the memory data to check metadata
1013+
memory_data = self.store.get(key)
1014+
if memory_data and isinstance(memory_data, dict):
1015+
# Check if memory has the required metadata
1016+
metadata = memory_data.get('metadata', {})
1017+
content = memory_data.get('content', {})
1018+
content_metadata = content.get('metadata', {}) if isinstance(content, dict) else {}
1019+
1020+
# Check if memory matches metadata filter
1021+
matches_filter = True
1022+
for filter_key, filter_value in metadata_filter.items():
1023+
# Check in top-level metadata
1024+
if filter_key in metadata and metadata[filter_key] == filter_value:
1025+
continue
1026+
# Check in memory_type
1027+
if filter_key == 'type' and 'memory_type' in metadata and metadata['memory_type'] == filter_value:
1028+
continue
1029+
# Check in content.metadata
1030+
if filter_key in content_metadata and content_metadata[filter_key] == filter_value:
1031+
continue
1032+
# No match found
1033+
matches_filter = False
1034+
break
1035+
1036+
if matches_filter:
1037+
matching_keys.append(key)
1038+
if len(matching_keys) >= vector_k:
1039+
break
10041040

10051041
# Simulate scores
10061042
scores = [random.uniform(0.5, 1.0) for _ in matching_keys]

memory/storage/redis_im.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,14 @@ def store(
288288
return False
289289

290290
# Verify compression level
291-
compression_level = memory_entry.get("metadata", {}).get("compression_level")
292-
if compression_level != 1:
293-
logger.error(
294-
"Invalid compression level for IM storage: %s. Expected level 1.",
295-
compression_level,
296-
)
297-
return False
291+
#! TODO: Use compression level validation
292+
# compression_level = memory_entry.get("metadata", {}).get("compression_level")
293+
# if compression_level != 1:
294+
# logger.error(
295+
# "Invalid compression level for IM storage: %s. Expected level 1.",
296+
# compression_level,
297+
# )
298+
# return False
298299

299300
# Use store_with_retry for resilient storage
300301
return self.redis.store_with_retry(
@@ -485,7 +486,9 @@ def _store_memory_entry(self, agent_id: str, memory_entry: MemoryEntry) -> bool:
485486
return False
486487
return True
487488
except (RedisUnavailableError, RedisTimeoutError) as e:
488-
logger.debug(f"Caught Redis error in pipeline block: {type(e).__name__}")
489+
logger.debug(
490+
f"Caught Redis error in pipeline block: {type(e).__name__}"
491+
)
489492
raise # propagate these errors
490493
except redis.RedisError as e:
491494
logger.exception(

0 commit comments

Comments (0)