1+ package com .plexpt .chatgpt .util ;
2+
3+ import cn .hutool .core .util .StrUtil ;
4+ import com .knuddels .jtokkit .Encodings ;
5+ import com .knuddels .jtokkit .api .Encoding ;
6+ import com .knuddels .jtokkit .api .EncodingRegistry ;
7+ import com .plexpt .chatgpt .entity .chat .ChatCompletion ;
8+ import com .plexpt .chatgpt .entity .chat .Message ;
9+ import lombok .experimental .UtilityClass ;
10+
11+ import java .util .HashMap ;
12+ import java .util .List ;
13+ import java .util .Map ;
14+ import java .util .Optional ;
15+
16+ @ UtilityClass
17+ public class TokensUtil {
18+
19+ private static final Map <String , Encoding > modelEncodingMap = new HashMap <>();
20+ private static final EncodingRegistry encodingRegistry = Encodings .newDefaultEncodingRegistry ();
21+
22+ static {
23+ for (ChatCompletion .Model model : ChatCompletion .Model .values ()) {
24+ Optional <Encoding > encodingForModel = encodingRegistry .getEncodingForModel (model .getName ());
25+ encodingForModel .ifPresent (encoding -> modelEncodingMap .put (model .getName (), encoding ));
26+ }
27+ }
28+
29+ /**
30+ * 计算tokens
31+ * @param modelName 模型名称
32+ * @param messages 消息列表
33+ * @return 计算出的tokens数量
34+ */
35+
36+ public static int tokens (String modelName , List <Message > messages ) {
37+ Encoding encoding = modelEncodingMap .get (modelName );
38+ if (encoding == null ) {
39+ throw new IllegalArgumentException ("Unsupported model: " + modelName );
40+ }
41+
42+ int tokensPerMessage = 0 ;
43+ int tokensPerName = 0 ;
44+ if (modelName .startsWith ("gpt-4" )) {
45+ tokensPerMessage = 3 ;
46+ tokensPerName = 1 ;
47+ } else if (modelName .startsWith ("gpt-3.5-turbo" )) {
48+ tokensPerMessage = 4 ; // every message follows <|start|>{role/name}\n{content}<|end|>\n
49+ tokensPerName = -1 ; // if there's a name, the role is omitted
50+ }
51+ int sum = 0 ;
52+ for (Message message : messages ) {
53+ sum += tokensPerMessage ;
54+ sum += encoding .countTokens (message .getContent ());
55+ sum += encoding .countTokens (message .getRole ());
56+ if (StrUtil .isNotBlank (message .getName ())) {
57+ sum += encoding .countTokens (message .getName ());
58+ sum += tokensPerName ;
59+ }
60+ }
61+ sum += 3 ;
62+ return sum ;
63+ }
64+ }
0 commit comments