Skip to content

Commit 2d337a3

Browse files
committed
fix: Address remaining Copilot review comments
- Add batch_size validation (must be >= 1)
- Handle OMIT sentinel properly in both v1 and v2 clients
- Remove images parameter from v2 embed_stream (text-only support)
- Document that embed_stream is for texts only; use embed() for images

All tests passing (5/6, 1 skipped — requires API key)
1 parent 8ef4bdc commit 2d337a3

2 files changed

Lines changed: 27 additions & 12 deletions

File tree

src/cohere/base_client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,15 +1190,22 @@ def embed_stream(
11901190
print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...")
11911191
# Process/save embedding immediately
11921192
"""
1193+
# Validate batch_size
1194+
if batch_size < 1:
1195+
raise ValueError("batch_size must be at least 1")
1196+
1197+
# Handle OMIT sentinel and empty texts
1198+
if texts is None or texts is OMIT:
1199+
return
11931200
if not texts:
11941201
return
1195-
1202+
11961203
from .streaming_utils import StreamingEmbedParser
1197-
1204+
11981205
# Process texts in batches
1199-
texts_list = list(texts) if texts else []
1206+
texts_list = list(texts)
12001207
total_embeddings_yielded = 0
1201-
1208+
12021209
for batch_start in range(0, len(texts_list), batch_size):
12031210
batch_end = min(batch_start + batch_size, len(texts_list))
12041211
batch_texts = texts_list[batch_start:batch_end]

src/cohere/v2/client.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,6 @@ def embed_stream(
498498
model: str,
499499
input_type: EmbedInputType,
500500
texts: typing.Optional[typing.Sequence[str]] = OMIT,
501-
images: typing.Optional[typing.Sequence[str]] = OMIT,
502501
max_tokens: typing.Optional[int] = OMIT,
503502
output_dimension: typing.Optional[int] = OMIT,
504503
embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
@@ -508,11 +507,14 @@ def embed_stream(
508507
) -> typing.Iterator[typing.Any]: # Returns Iterator[StreamedEmbedding]
509508
"""
510509
Memory-efficient streaming version of embed that yields embeddings one at a time.
511-
510+
512511
This method processes texts in batches and yields individual embeddings as they are
513512
parsed from the response, without loading all embeddings into memory at once.
514513
Ideal for processing large datasets where memory usage is a concern.
515514
515+
Note: This method only supports text embeddings. For image embeddings, use the
516+
regular embed() method.
517+
516518
Parameters
517519
----------
518520
model : str
@@ -570,25 +572,31 @@ def embed_stream(
570572
print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...")
571573
# Process/save embedding immediately
572574
"""
575+
# Validate batch_size
576+
if batch_size < 1:
577+
raise ValueError("batch_size must be at least 1")
578+
579+
# Handle OMIT sentinel and empty texts
580+
if texts is None or texts is OMIT:
581+
return
573582
if not texts:
574583
return
575-
584+
576585
from ..streaming_utils import StreamingEmbedParser
577-
586+
578587
# Process texts in batches
579-
texts_list = list(texts) if texts else []
588+
texts_list = list(texts)
580589
total_embeddings_yielded = 0
581-
590+
582591
for batch_start in range(0, len(texts_list), batch_size):
583592
batch_end = min(batch_start + batch_size, len(texts_list))
584593
batch_texts = texts_list[batch_start:batch_end]
585-
594+
586595
# Get response for this batch
587596
response = self._raw_client.embed(
588597
model=model,
589598
input_type=input_type,
590599
texts=batch_texts,
591-
images=images if batch_start == 0 else None, # Only include images in first batch
592600
max_tokens=max_tokens,
593601
output_dimension=output_dimension,
594602
embedding_types=embedding_types,

0 commit comments

Comments (0)