Skip to content

Commit 742e167

Browse files
authored
Merge pull request #124 from jiaw3i/main
2 parents 20b3643 + f254c67 commit 742e167

5 files changed

Lines changed: 107 additions & 2 deletions

File tree

pom.xml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,11 @@
135135
<version>4.13.2</version>
136136
<scope>test</scope>
137137
</dependency>
138-
138+
<dependency>
139+
<groupId>com.knuddels</groupId>
140+
<artifactId>jtokkit</artifactId>
141+
<version>0.4.0</version>
142+
</dependency>
139143
</dependencies>
140144

141145
<!-- 下面这个标签里的不能改 -->

src/main/java/com/plexpt/chatgpt/entity/chat/ChatCompletion.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import lombok.NoArgsConstructor;
1515
import lombok.NonNull;
1616
import lombok.extern.slf4j.Slf4j;
17+
import com.plexpt.chatgpt.util.TokensUtil;
1718

1819
/**
1920
* chat
@@ -24,7 +25,7 @@
2425
@Builder
2526
@Slf4j
2627
@AllArgsConstructor
27-
@NoArgsConstructor
28+
@NoArgsConstructor(force = true)
2829
@JsonInclude(JsonInclude.Include.NON_NULL)
2930
public class ChatCompletion implements Serializable {
3031

@@ -125,6 +126,9 @@ public enum Model {
125126
private String name;
126127
}
127128

129+
public int countTokens() {
130+
return TokensUtil.tokens(this.model, this.messages);
131+
}
128132
}
129133

130134

src/main/java/com/plexpt/chatgpt/entity/chat/Message.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.plexpt.chatgpt.entity.chat;
22

3+
import com.fasterxml.jackson.annotation.JsonInclude;
34
import lombok.AllArgsConstructor;
45
import lombok.Builder;
56
import lombok.Data;
@@ -13,12 +14,19 @@
1314
@AllArgsConstructor
1415
@NoArgsConstructor
1516
@Builder
17+
@JsonInclude(JsonInclude.Include.NON_NULL)
1618
public class Message {
1719
/**
1820
* 目前支持三种角色参考官网,进行情景输入:https://platform.openai.com/docs/guides/chat/introduction
1921
*/
2022
private String role;
2123
private String content;
24+
private String name;
25+
26+
/**
 * Creates a message from a role and its content; {@code name} is not assigned here.
 *
 * @param role    the role of the message author
 * @param content the message text
 */
public Message(String role, String content) {
    this.content = content;
    this.role = role;
}
2230

2331
public static Message of(String content) {
2432

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package com.plexpt.chatgpt.util;
2+
3+
import cn.hutool.core.util.StrUtil;
4+
import com.knuddels.jtokkit.Encodings;
5+
import com.knuddels.jtokkit.api.Encoding;
6+
import com.knuddels.jtokkit.api.EncodingRegistry;
7+
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
8+
import com.plexpt.chatgpt.entity.chat.Message;
9+
import lombok.experimental.UtilityClass;
10+
11+
import java.util.HashMap;
12+
import java.util.List;
13+
import java.util.Map;
14+
import java.util.Optional;
15+
16+
@UtilityClass
17+
public class TokensUtil {
18+
19+
private static final Map<String, Encoding> modelEncodingMap = new HashMap<>();
20+
private static final EncodingRegistry encodingRegistry = Encodings.newDefaultEncodingRegistry();
21+
22+
static {
23+
for (ChatCompletion.Model model : ChatCompletion.Model.values()) {
24+
Optional<Encoding> encodingForModel = encodingRegistry.getEncodingForModel(model.getName());
25+
encodingForModel.ifPresent(encoding -> modelEncodingMap.put(model.getName(), encoding));
26+
}
27+
}
28+
29+
/**
30+
* 计算tokens
31+
* @param modelName 模型名称
32+
* @param messages 消息列表
33+
* @return 计算出的tokens数量
34+
*/
35+
36+
public static int tokens(String modelName, List<Message> messages) {
37+
Encoding encoding = modelEncodingMap.get(modelName);
38+
if (encoding == null) {
39+
throw new IllegalArgumentException("Unsupported model: " + modelName);
40+
}
41+
42+
int tokensPerMessage = 0;
43+
int tokensPerName = 0;
44+
if (modelName.startsWith("gpt-4")) {
45+
tokensPerMessage = 3;
46+
tokensPerName = 1;
47+
} else if (modelName.startsWith("gpt-3.5-turbo")) {
48+
tokensPerMessage = 4; // every message follows <|start|>{role/name}\n{content}<|end|>\n
49+
tokensPerName = -1; // if there's a name, the role is omitted
50+
}
51+
int sum = 0;
52+
for (Message message : messages) {
53+
sum += tokensPerMessage;
54+
sum += encoding.countTokens(message.getContent());
55+
sum += encoding.countTokens(message.getRole());
56+
if (StrUtil.isNotBlank(message.getName())) {
57+
sum += encoding.countTokens(message.getName());
58+
sum += tokensPerName;
59+
}
60+
}
61+
sum += 3;
62+
return sum;
63+
}
64+
}

src/test/java/com/plexpt/chatgpt/Test.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,29 @@ public void chatmsg() {
6161
System.out.println(res);
6262
}
6363

64+
/**
65+
* 测试tokens数量计算
66+
*/
67+
@org.junit.Test
68+
public void tokens() {
69+
Message system = Message.ofSystem("你现在是一个诗人,专门写七言绝句");
70+
Message message = Message.of("写一段七言绝句诗,题目是:火锅!");
71+
72+
ChatCompletion chatCompletion1 = ChatCompletion.builder()
73+
.model(ChatCompletion.Model.GPT_3_5_TURBO.getName())
74+
.messages(Arrays.asList(system, message))
75+
.maxTokens(3000)
76+
.temperature(0.9)
77+
.build();
78+
ChatCompletion chatCompletion2 = ChatCompletion.builder()
79+
.model(ChatCompletion.Model.GPT_4.getName())
80+
.messages(Arrays.asList(system, message))
81+
.maxTokens(3000)
82+
.temperature(0.9)
83+
.build();
84+
85+
log.info("{} tokens: {}", chatCompletion1.getModel(), chatCompletion1.countTokens());
86+
log.info("{} tokens: {}", chatCompletion2.getModel(), chatCompletion2.countTokens());
87+
}
88+
6489
}

0 commit comments

Comments
 (0)