
Commit 90e8df9

authored by @Anai-Guo
fix(_internals): use n_tokens0 offset when enabling last-token logits in add_sequence (abetlen#2205)
Fix batched embedding output flags for multi-sequence embed calls. Closes abetlen#2199.
1 parent 14d7846 commit 90e8df9
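For context, the call path this commit fixes is the batched form of Llama.embed(), where several prompts go into one llama.cpp batch as separate sequences. A minimal usage sketch, assuming a placeholder GGUF model path (the constructor and embed() calls shown are the library's usual ones, not part of this diff):

from llama_cpp import Llama

# Placeholder path; any GGUF embedding model is used the same way here.
model = Llama(model_path="./models/embedding-model.gguf", embedding=True)

# Multi-sequence call: each prompt is appended to the batch as its own sequence.
# Before this fix, the last-token logits flag for sequences after the first was
# set at the wrong batch index, producing incorrect batched embedding outputs.
embeddings = model.embed(["Hello World", "A different prompt"])
assert len(embeddings) == 2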

3 files changed

Lines changed: 18 additions & 1 deletion


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+- fix: Correct batched embedding outputs for multi-sequence `embed()` calls by @Anai-Guo in #2205
+
 ## [0.3.22]
 
 - feat: Update llama.cpp to ggerganov/llama.cpp@63d93d173

llama_cpp/_internals.py

Lines changed: 1 addition & 1 deletion
@@ -522,7 +522,7 @@ def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
             self.batch.seq_id[j][0] = seq_id
             self.batch.n_seq_id[j] = 1
             self.batch.logits[j] = logits_all
-        self.batch.logits[n_tokens - 1] = True
+        self.batch.logits[n_tokens0 + n_tokens - 1] = True
 
 
 class LlamaTokenDataArray:
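Why the offset matters: add_sequence appends tokens after n_tokens0 tokens that are already in the batch, so an index relative to the batch start no longer points at the last token of the sequence being added. A minimal sketch, using a simplified stand-in for the llama.cpp batch (FakeBatch and this free-standing add_sequence are illustrative only, not the library's API):

# Minimal, self-contained model of the logits flags; only the indexing is simulated.
class FakeBatch:
    def __init__(self, capacity: int):
        self.n_tokens = 0
        self.logits = [False] * capacity


def add_sequence(batch: FakeBatch, tokens: list, use_offset: bool) -> None:
    n_tokens = len(tokens)
    n_tokens0 = batch.n_tokens  # tokens already in the batch from earlier sequences
    batch.n_tokens += n_tokens
    for i in range(n_tokens):
        batch.logits[n_tokens0 + i] = False  # per-token flag (logits_all=False case)
    if use_offset:
        # Fixed behavior: flag the last token of *this* sequence.
        batch.logits[n_tokens0 + n_tokens - 1] = True
    else:
        # Old behavior: index is relative to the start of the whole batch,
        # so it lands inside the first sequence once n_tokens0 > 0.
        batch.logits[n_tokens - 1] = True


for use_offset in (False, True):
    b = FakeBatch(capacity=8)
    add_sequence(b, [1, 2, 3], use_offset)  # first sequence fills slots 0..2
    add_sequence(b, [4, 5], use_offset)     # second sequence fills slots 3..4
    print(use_offset, b.logits[:5])
# False -> [False, True, True, False, False]  (slot 1 wrongly flagged, slot 4 never flagged)
# True  -> [False, False, True, False, True]  (last token of each sequence flagged)

Since the old indexing only misfires when n_tokens0 > 0, single-prompt embed() calls were unaffected; only multi-sequence batches hit the bug, which matches the commit's scope.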

tests/test_llama.py

Lines changed: 15 additions & 0 deletions
@@ -247,3 +247,18 @@ def test_real_llama_embeddings(llama_cpp_embedding_model_path):
     )
     embedding = model.embed("Hello World")
     assert len(embedding) > 0
+
+    prompts = ["Hello World", "A different prompt"]
+    individual_embeddings = [model.embed(prompt) for prompt in prompts]
+    batched_embeddings = model.embed(prompts)
+
+    assert len(batched_embeddings) == len(prompts)
+    for individual, batched in zip(individual_embeddings, batched_embeddings):
+        np.testing.assert_allclose(batched, individual, rtol=1e-4, atol=1e-4)
+
+    repeated_embeddings = model.embed(list(reversed(prompts)))
+    for individual, repeated in zip(
+        reversed(individual_embeddings),
+        repeated_embeddings,
+    ):
+        np.testing.assert_allclose(repeated, individual, rtol=1e-4, atol=1e-4)
