Skip to content

Commit 5c8a4ec

Browse files
committed
fix: Address review feedback for embed_stream
1. V2 embed_stream mishandles duplicate texts (High): - Added used_batch_indices tracking like base_client - Now correctly assigns unique indices to duplicate texts 2. Unused variable total_embeddings_yielded (Low): - Removed from both base_client.py and v2/client.py
1 parent a1955f7 commit 5c8a4ec

2 files changed

Lines changed: 14 additions & 5 deletions

File tree

src/cohere/base_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,6 @@ def embed_stream(
12071207

12081208
# Process texts in batches
12091209
texts_list = list(texts)
1210-
total_embeddings_yielded = 0
12111210

12121211
for batch_start in range(0, len(texts_list), batch_size):
12131212
batch_end = min(batch_start + batch_size, len(texts_list))

src/cohere/v2/client.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,6 @@ def embed_stream(
583583

584584
# Process texts in batches
585585
texts_list = list(texts)
586-
total_embeddings_yielded = 0
587586

588587
for batch_start in range(0, len(texts_list), batch_size):
589588
batch_end = min(batch_start + batch_size, len(texts_list))
@@ -600,15 +599,26 @@ def embed_stream(
600599
truncate=truncate,
601600
request_options=request_options,
602601
)
603-
602+
604603
# Parse embeddings from response incrementally
605604
parser = StreamingEmbedParser(response._response, batch_texts)
605+
# Track used indices to handle duplicate texts correctly
606+
used_batch_indices: set[int] = set()
607+
606608
for embedding in parser.iter_embeddings():
607609
# The parser sets embedding.text correctly for multiple embedding types
608610
# Adjust the global index based on text position in batch
609611
if embedding.text and embedding.text in batch_texts:
610-
text_idx_in_batch = batch_texts.index(embedding.text)
611-
embedding.index = batch_start + text_idx_in_batch
612+
# Find the next unused occurrence of this text in the batch
613+
# This handles duplicate texts correctly
614+
text_idx_in_batch = None
615+
for idx, text in enumerate(batch_texts):
616+
if text == embedding.text and idx not in used_batch_indices:
617+
text_idx_in_batch = idx
618+
used_batch_indices.add(idx)
619+
break
620+
if text_idx_in_batch is not None:
621+
embedding.index = batch_start + text_idx_in_batch
612622
yield embedding
613623

614624
def rerank(

0 commit comments

Comments
 (0)