Skip to content

Commit c2c3f3e

Browse files
committed
fix: Address review feedback for configurable batch_size
Fixes for issues identified by Cursor bugbot:

1. Missing batch_size validation in embed method (Medium):
   - Added validation to raise ValueError if batch_size < 1
   - Applied to both sync and async embed methods

2. IndexError when using multiple embedding types with embed_stream (High):
   - Fixed index calculation to use text position from parser
   - Parser correctly tracks text index per embedding type

3. Fallback causes duplicate embeddings after partial ijson failure (Low):
   - Collect all ijson embeddings into list before yielding
   - Reset embeddings_yielded counter before fallback
   - Only yield after successful complete parsing
1 parent 13b57e6 commit c2c3f3e

4 files changed

Lines changed: 37 additions & 17 deletions

File tree

src/cohere/base_client.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1222,12 +1222,13 @@ def embed_stream(
12221222

12231223
# Parse embeddings from response incrementally
12241224
parser = StreamingEmbedParser(response._response, batch_texts)
1225-
for i, embedding in enumerate(parser.iter_embeddings()):
1226-
# Adjust index for global position
1227-
embedding.index = batch_start + i
1228-
embedding.text = texts_list[embedding.index]
1225+
for embedding in parser.iter_embeddings():
1226+
# The parser sets embedding.text correctly for multiple embedding types
1227+
# Adjust the global index based on text position in batch
1228+
if embedding.text and embedding.text in batch_texts:
1229+
text_idx_in_batch = batch_texts.index(embedding.text)
1230+
embedding.index = batch_start + text_idx_in_batch
12291231
yield embedding
1230-
total_embeddings_yielded += len(batch_texts)
12311232

12321233
def rerank(
12331234
self,

src/cohere/client.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,10 @@ def embed(
203203
request_options=request_options,
204204
)
205205

206+
# Validate batch_size
207+
if batch_size is not None and batch_size < 1:
208+
raise ValueError("batch_size must be at least 1")
209+
206210
textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
207211
effective_batch_size = batch_size if batch_size is not None else embed_batch_size
208212
texts_batches = [textsarr[i : i + effective_batch_size] for i in range(0, len(textsarr), effective_batch_size)]
@@ -408,6 +412,10 @@ async def embed(
408412
request_options=request_options,
409413
)
410414

415+
# Validate batch_size
416+
if batch_size is not None and batch_size < 1:
417+
raise ValueError("batch_size must be at least 1")
418+
411419
textsarr: typing.Sequence[str] = texts if texts is not OMIT and texts is not None else []
412420
effective_batch_size = batch_size if batch_size is not None else embed_batch_size
413421
texts_batches = [textsarr[i : i + effective_batch_size] for i in range(0, len(textsarr), effective_batch_size)]

src/cohere/streaming_utils.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -50,21 +50,31 @@ def iter_embeddings(self) -> Iterator[StreamedEmbedding]:
5050
Yields:
5151
StreamedEmbedding objects as they are parsed from the response
5252
"""
53-
if not IJSON_AVAILABLE:
54-
# Fallback to regular parsing if ijson not available
53+
# Try to get response content as bytes for ijson
54+
response_content: Optional[bytes] = None
55+
try:
56+
content = self.response.content
57+
if isinstance(content, bytes):
58+
response_content = content
59+
except Exception:
60+
pass
61+
62+
if not IJSON_AVAILABLE or response_content is None:
63+
# Fallback to regular parsing if ijson not available or no bytes content
5564
yield from self._iter_embeddings_fallback()
5665
return
5766

58-
# Buffer response content first to allow fallback if ijson fails
59-
# This prevents partial parsing issues where ijson yields some embeddings then fails
60-
response_content = self.response.content
61-
6267
try:
6368
# Use ijson for memory-efficient parsing
69+
# Collect all embeddings first to avoid partial yields before failure
6470
parser = ijson.parse(io.BytesIO(response_content))
65-
yield from self._parse_with_ijson(parser)
71+
embeddings = list(self._parse_with_ijson(parser))
72+
# Only yield after successful complete parsing
73+
yield from embeddings
6674
except Exception:
6775
# If ijson parsing fails, fallback to regular parsing using buffered content
76+
# Reset embeddings_yielded since we collected but didn't yield
77+
self.embeddings_yielded = 0
6878
data = json.loads(response_content)
6979
yield from self._iter_embeddings_fallback_from_dict(data)
7080

src/cohere/v2/client.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -603,12 +603,13 @@ def embed_stream(
603603

604604
# Parse embeddings from response incrementally
605605
parser = StreamingEmbedParser(response._response, batch_texts)
606-
for i, embedding in enumerate(parser.iter_embeddings()):
607-
# Adjust index for global position
608-
embedding.index = batch_start + i
609-
embedding.text = texts_list[embedding.index]
606+
for embedding in parser.iter_embeddings():
607+
# The parser sets embedding.text correctly for multiple embedding types
608+
# Adjust the global index based on text position in batch
609+
if embedding.text and embedding.text in batch_texts:
610+
text_idx_in_batch = batch_texts.index(embedding.text)
611+
embedding.index = batch_start + text_idx_in_batch
610612
yield embedding
611-
total_embeddings_yielded += len(batch_texts)
612613

613614
def rerank(
614615
self,

0 commit comments

Comments
 (0)