Skip to content

Commit 059f323

Browse files
committed
⚡ Refine Llama3 token approximation
1 parent 3ccfedc commit 059f323

3 files changed

Lines changed: 36 additions & 22 deletions

File tree

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/common/TextConversation.java

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,19 @@ public class TextConversation {
1313
public static TextConversation of(@NonNull String systemMessage, @NonNull String userMessage) {
1414
return new TextConversation(
1515
systemMessage,
16-
List.of(
17-
TextMessage.user(userMessage)
18-
)
16+
List.of(TextMessage.user(userMessage))
1917
);
2018
}
19+
20+
public List<TextMessage> getLastMessages(@NonNull Integer conversationSize) {
21+
return messages.stream()
22+
.skip(Math.max(0, messages.size() - conversationSize))
23+
.toList();
24+
}
25+
26+
public List<TextMessage> getFirstMessages(@NonNull Integer conversationSize) {
27+
return messages.stream()
28+
.limit(conversationSize)
29+
.toList();
30+
}
2131
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/llama3/Llama3Tokenizer.java

Lines changed: 21 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import java.util.LinkedList;
1010
import java.util.Objects;
1111
import java.util.function.Consumer;
12-
import java.util.stream.Collectors;
1312

1413
@UtilityClass
1514
public class Llama3Tokenizer {
@@ -48,39 +47,38 @@ public static String systemHeader() {
4847
return Llama3Tokenizer.header(ROLE_SYSTEM);
4948
}
5049

51-
public static Llama3TextCompletion generateTextCompletion(@NonNull TextConversation conversation) {
52-
final var textCompletionBuilder = new StringBuilder();
53-
textCompletionBuilder
54-
.append(BEGIN_OF_TEXT)
55-
.append(systemHeader())
56-
.append(conversation.getSystemMessage());
50+
public static String tokenizeMessage(@NonNull TextMessage message) {
51+
final var messageBuilder = new StringBuilder();
52+
addMessageToTextCompletion(messageBuilder).accept(message);
53+
return messageBuilder.toString();
54+
}
5755

56+
public static Llama3TextCompletion generateTextCompletion(@NonNull TextConversation conversation) {
57+
final var textCompletionBuilder = createBeginOfText(conversation.getSystemMessage());
5858
conversation.getMessages().forEach(addMessageToTextCompletion(textCompletionBuilder));
5959
final var textCompletion = textCompletionBuilder.toString();
6060
return new Llama3TextCompletion(textCompletion);
6161
}
6262

6363
public static Integer approximateConversationContextSize(@NonNull TextConversation conversation, @Nullable Integer tokenSize) {
64-
final var conversationPlainText = conversation.getMessages().stream()
65-
.map(TextMessage::getText)
66-
.collect(Collectors.joining(System.lineSeparator()));
67-
68-
return getTokens(conversationPlainText, tokenSize);
64+
final var conversationTextCompletion = generateTextCompletion(conversation);
65+
return getTokens(conversationTextCompletion.getText(), tokenSize);
6966
}
7067

7168
public static TextConversation fitToContextWindow(@NonNull TextConversation conversation, @Nullable Integer contextWindowSize) {
7269
contextWindowSize = Objects.requireNonNullElse(contextWindowSize, DEFAULT_CONTEXT_WINDOW_SIZE);
7370

74-
final var systemMessage = conversation.getSystemMessage();
75-
final var systemMessageTokens = getTokens(systemMessage, null);
71+
final var systemMessage = createBeginOfText(conversation.getSystemMessage());
72+
final var systemMessageTokens = getTokens(systemMessage.toString(), null);
7673
int remainingTokens = contextWindowSize - systemMessageTokens;
7774

7875
final var messages = conversation.getMessages();
7976
final var fittedMessages = new LinkedList<TextMessage>();
8077

8178
for (int i = messages.size() - 1; i >= 0; i--) {
8279
final var message = messages.get(i);
83-
final var messageTokens = getTokens(message.getText(), null);
80+
final var tokenizedMessage = tokenizeMessage(message);
81+
final var messageTokens = getTokens(tokenizedMessage, null);
8482

8583
if (remainingTokens - messageTokens >= 0) {
8684
fittedMessages.addFirst(message);
@@ -90,7 +88,7 @@ public static TextConversation fitToContextWindow(@NonNull TextConversation conv
9088
}
9189
}
9290

93-
return new TextConversation(systemMessage, fittedMessages);
91+
return new TextConversation(conversation.getSystemMessage(), fittedMessages);
9492
}
9593

9694
Consumer<TextMessage> addMessageToTextCompletion(@NonNull StringBuilder textCompletionBuilder) {
@@ -100,8 +98,14 @@ Consumer<TextMessage> addMessageToTextCompletion(@NonNull StringBuilder textComp
10098
.append(END_OF_TEXT_ID);
10199
}
102100

103-
private Integer getTokens(@NonNull String string, @Nullable Integer tokenSize) {
101+
Integer getTokens(@NonNull String string, @Nullable Integer tokenSize) {
104102
tokenSize = Objects.requireNonNullElse(tokenSize, APPROXIMATE_CHARACTERS_PER_TOKEN);
105103
return string.length() / tokenSize;
106104
}
105+
106+
StringBuilder createBeginOfText(@NonNull String systemMessage) {
107+
return new StringBuilder().append(BEGIN_OF_TEXT)
108+
.append(systemHeader())
109+
.append(systemMessage);
110+
}
107111
}

spring-boot-starter-replicate/src/test/java/io/graversen/replicate/llama3/Llama3TokenizerTest.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,11 +37,11 @@ public void fitToContextWindow_defaultWindow() {
3737
@Test
3838
public void fitToContextWindow_windowTooSmall() {
3939
final var conversation = new TextConversation(systemMessage, messages);
40-
final var fittedConversation = Llama3Tokenizer.fitToContextWindow(conversation, 32);
40+
final var fittedConversation = Llama3Tokenizer.fitToContextWindow(conversation, 128);
4141

4242
assertNotNull(fittedConversation);
4343
assertTrue(fittedConversation.getMessages().size() < messages.size());
4444
assertNotEquals(messages.get(0), fittedConversation.getMessages().get(0));
45-
assertEquals(messages.get(5), fittedConversation.getMessages().get(1));
45+
assertEquals(messages.get(2), fittedConversation.getMessages().get(1));
4646
}
4747
}

0 commit comments

Comments
 (0)