Skip to content

Commit 54490a2

Browse files
committed
⚡ Llama3Tokenizer, context window conversation fitting
1 parent bd52ff4 commit 54490a2

3 files changed

Lines changed: 154 additions & 63 deletions

File tree

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package io.graversen.replicate.llama3;
2+
3+
import io.graversen.replicate.common.TextConversation;
4+
import io.graversen.replicate.common.TextMessage;
5+
import jakarta.annotation.Nullable;
6+
import lombok.NonNull;
7+
import lombok.experimental.UtilityClass;
8+
9+
import java.util.LinkedList;
10+
import java.util.Objects;
11+
import java.util.function.Consumer;
12+
import java.util.stream.Collectors;
13+
14+
@UtilityClass
15+
public class Llama3Tokenizer {
16+
private static final String BEGIN_OF_TEXT = "<|begin_of_text|>";
17+
private static final String START_HEADER_ID = "<|start_header_id|>";
18+
private static final String END_HEADER_ID = "<|end_header_id|>";
19+
private static final String END_OF_TEXT_ID = "<|eot_id|>";
20+
21+
public static final Integer DEFAULT_CONTEXT_WINDOW_SIZE = 8000;
22+
public static final Integer APPROXIMATE_CHARACTERS_PER_TOKEN = 4;
23+
public static final String ROLE_USER = "user";
24+
public static final String ROLE_ASSISTANT = "assistant";
25+
public static final String ROLE_SYSTEM = "system";
26+
27+
public static String beginOfText(@NonNull String text) {
28+
return String.format("%s%s", BEGIN_OF_TEXT, text);
29+
}
30+
31+
public static String endOfText(@NonNull String text) {
32+
return String.format("%s%s", text, END_OF_TEXT_ID);
33+
}
34+
35+
public static String header(@NonNull String text) {
36+
return String.format("%s%s%s", START_HEADER_ID, text, END_HEADER_ID);
37+
}
38+
39+
public static String userHeader() {
40+
return Llama3Tokenizer.header(ROLE_USER);
41+
}
42+
43+
public static String assistantHeader() {
44+
return Llama3Tokenizer.header(ROLE_ASSISTANT);
45+
}
46+
47+
public static String systemHeader() {
48+
return Llama3Tokenizer.header(ROLE_SYSTEM);
49+
}
50+
51+
public static Llama3TextCompletion generateTextCompletion(@NonNull TextConversation conversation) {
52+
final var textCompletionBuilder = new StringBuilder();
53+
textCompletionBuilder
54+
.append(BEGIN_OF_TEXT)
55+
.append(systemHeader())
56+
.append(conversation.getSystemMessage());
57+
58+
conversation.getMessages().forEach(addMessageToTextCompletion(textCompletionBuilder));
59+
final var textCompletion = textCompletionBuilder.toString();
60+
return new Llama3TextCompletion(textCompletion);
61+
}
62+
63+
public static Integer approximateConversationContextSize(@NonNull TextConversation conversation, @Nullable Integer tokenSize) {
64+
final var conversationPlainText = conversation.getMessages().stream()
65+
.map(TextMessage::getText)
66+
.collect(Collectors.joining(System.lineSeparator()));
67+
68+
return getTokens(conversationPlainText, tokenSize);
69+
}
70+
71+
public static TextConversation fitToContextWindow(@NonNull TextConversation conversation, @Nullable Integer contextWindowSize) {
72+
contextWindowSize = Objects.requireNonNullElse(contextWindowSize, DEFAULT_CONTEXT_WINDOW_SIZE);
73+
74+
final var systemMessage = conversation.getSystemMessage();
75+
final var systemMessageTokens = getTokens(systemMessage, null);
76+
int remainingTokens = contextWindowSize - systemMessageTokens;
77+
78+
final var messages = conversation.getMessages();
79+
final var fittedMessages = new LinkedList<TextMessage>();
80+
81+
for (int i = messages.size() - 1; i >= 0; i--) {
82+
final var message = messages.get(i);
83+
final var messageTokens = getTokens(message.getText(), null);
84+
85+
if (remainingTokens - messageTokens >= 0) {
86+
fittedMessages.addFirst(message);
87+
remainingTokens -= messageTokens;
88+
} else {
89+
break;
90+
}
91+
}
92+
93+
return new TextConversation(systemMessage, fittedMessages);
94+
}
95+
96+
Consumer<TextMessage> addMessageToTextCompletion(@NonNull StringBuilder textCompletionBuilder) {
97+
return textMessage -> textCompletionBuilder
98+
.append(header(textMessage.getRole()))
99+
.append(textMessage.getText())
100+
.append(END_OF_TEXT_ID);
101+
}
102+
103+
private Integer getTokens(@NonNull String string, @Nullable Integer tokenSize) {
104+
tokenSize = Objects.requireNonNullElse(tokenSize, APPROXIMATE_CHARACTERS_PER_TOKEN);
105+
return string.length() / tokenSize;
106+
}
107+
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/llama3/Llama3Tokens.java

Lines changed: 0 additions & 63 deletions
This file was deleted.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package io.graversen.replicate.llama3;
2+
3+
import io.graversen.replicate.common.TextConversation;
4+
import io.graversen.replicate.common.TextMessage;
5+
import org.junit.jupiter.api.Test;
6+
7+
import java.util.List;
8+
9+
import static org.junit.jupiter.api.Assertions.*;
10+
11+
class Llama3TokenizerTest {
12+
private final String systemMessage = "You are a helpful assistant";
13+
private final List<TextMessage> messages = List.of(
14+
new TextMessage("user", "Hello"),
15+
new TextMessage("assistant", "Hi there!"),
16+
new TextMessage("user", "How are you?"),
17+
new TextMessage("assistant", "I'm fine, thank you! How can I assist you today?"),
18+
new TextMessage("user", "Tell me a joke."),
19+
new TextMessage("assistant", "Why don't scientists trust atoms? Because they make up everything!")
20+
);
21+
22+
@Test
23+
public void fitToContextWindow_defaultWindow() {
24+
final var conversation = new TextConversation(systemMessage, messages);
25+
final var fittedConversation = Llama3Tokenizer.fitToContextWindow(conversation, null);
26+
27+
assertNotNull(fittedConversation);
28+
assertTrue(fittedConversation.getMessages().size() <= messages.size());
29+
assertEquals(systemMessage, fittedConversation.getSystemMessage());
30+
assertEquals(messages.get(0), fittedConversation.getMessages().get(0));
31+
assertEquals(messages.get(5), fittedConversation.getMessages().get(5));
32+
fittedConversation.getMessages().forEach(message ->
33+
assertTrue(messages.contains(message), "Fitted conversation should only contain messages from the original conversation")
34+
);
35+
}
36+
37+
@Test
38+
public void fitToContextWindow_windowTooSmall() {
39+
final var conversation = new TextConversation(systemMessage, messages);
40+
final var fittedConversation = Llama3Tokenizer.fitToContextWindow(conversation, 32);
41+
42+
assertNotNull(fittedConversation);
43+
assertTrue(fittedConversation.getMessages().size() < messages.size());
44+
assertNotEquals(messages.get(0), fittedConversation.getMessages().get(0));
45+
assertEquals(messages.get(5), fittedConversation.getMessages().get(1));
46+
}
47+
}

0 commit comments

Comments
 (0)