Skip to content

Commit 51a8633

Browse files
committed
Add support for gpt-4-vision
support the content in chat completion with format https://platform.openai.com/docs/guides/vision
1 parent 99162e0 commit 51a8633

12 files changed

Lines changed: 333 additions & 18 deletions

File tree

api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionChoice.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ public class ChatCompletionChoice {
1515
Integer index;
1616

1717
/**
18-
* The {@link ChatMessageRole#assistant} message or delta (when streaming) which was generated
18+
* The {@link ChatMessageRole#ASSISTANT} message or delta (when streaming) which was generated
1919
*/
2020
@JsonAlias("delta")
21-
ChatMessage message;
21+
ChatMessage<String> message;
2222

2323
/**
2424
* The reason why GPT stopped generating, for example "length".

api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import com.fasterxml.jackson.annotation.JsonInclude;
44
import com.fasterxml.jackson.annotation.JsonProperty;
5-
import lombok.*;
5+
import lombok.AllArgsConstructor;
6+
import lombok.Data;
7+
import lombok.NoArgsConstructor;
68

79
/**
810
* <p>Each object has a role (either "system", "user", or "assistant") and content (the content of the message). Conversations can be as short as 1 message or fill many pages.</p>
@@ -16,32 +18,34 @@
1618
*/
1719
@Data
1820
@NoArgsConstructor(force = true)
19-
@RequiredArgsConstructor
2021
@AllArgsConstructor
21-
public class ChatMessage {
22+
public class ChatMessage<T> {
2223

2324
/**
2425
* Must be either 'system', 'user', 'assistant' or 'function'.<br>
2526
* You may use {@link ChatMessageRole} enum.
2627
*/
27-
@NonNull
2828
String role;
29+
/**
30+
* An array of content parts with a defined type, each can be of type text or image_url when passing in images. You
31+
* can pass multiple images by adding multiple image_url content parts. Image input is only supported when using the
32+
* gpt-4-visual-preview model.
33+
*/
2934
@JsonInclude() // content should always exist in the call, even if it is null
30-
String content;
35+
T content;
3136
//name is optional, The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
3237
String name;
3338
@JsonProperty("function_call")
3439
ChatFunctionCall functionCall;
3540

36-
public ChatMessage(String role, String content) {
41+
public ChatMessage(String role, T content) {
3742
this.role = role;
3843
this.content = content;
3944
}
4045

41-
public ChatMessage(String role, String content, String name) {
46+
public ChatMessage(String role, T content, String name) {
4247
this.role = role;
4348
this.content = content;
4449
this.name = name;
4550
}
46-
4751
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import com.fasterxml.jackson.annotation.JsonProperty;
4+
import lombok.AllArgsConstructor;
5+
import lombok.Data;
6+
import lombok.NoArgsConstructor;
7+
8+
@Data
9+
@NoArgsConstructor
10+
public class ChatMessageContent {
11+
12+
/**
13+
* The type of the content part
14+
*
15+
* @see ChatMessageContentType
16+
*/
17+
private String type;
18+
19+
/**
20+
* The text content.
21+
*/
22+
private String text;
23+
24+
/**
25+
* Image input is only supported when using the gpt-4-visual-preview model.
26+
*/
27+
@JsonProperty("image_url")
28+
private ImageUrl imageUrl;
29+
30+
public ChatMessageContent(String text) {
31+
this.type = ChatMessageContentType.TEXT.value();
32+
this.text = text;
33+
}
34+
35+
public ChatMessageContent(ImageUrl imageUrl) {
36+
this.type = ChatMessageContentType.IMAGE_URL.value();
37+
this.imageUrl = imageUrl;
38+
}
39+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
/**
4+
* see {@link ChatMessage} documentation.
5+
*/
6+
public enum ChatMessageContentType {
7+
8+
TEXT("text"),
9+
IMAGE_URL("image_url");
10+
11+
private final String value;
12+
13+
ChatMessageContentType(final String value) {
14+
this.value = value;
15+
}
16+
17+
public String value() {
18+
return value;
19+
}
20+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import lombok.*;
4+
5+
@Data
6+
@AllArgsConstructor
7+
@NoArgsConstructor
8+
@RequiredArgsConstructor
9+
public class ImageUrl {
10+
11+
/**
12+
* Either a URL of the image or the base64 encoded image data.
13+
*/
14+
@NonNull
15+
private String url;
16+
17+
/**
18+
* Specifies the detail level of the image. Learn more in the
19+
* <a href="https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding">
20+
* Vision guide</a>.
21+
*/
22+
private String detail;
23+
}

api/src/main/java/com/theokanning/openai/utils/TikTokensUtil.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ public static int tokens(String modelName, List<ChatMessage> messages) {
186186
int sum = 0;
187187
for (ChatMessage msg : messages) {
188188
sum += tokensPerMessage;
189-
sum += tokens(encoding, msg.getContent());
189+
if(msg.getContent() instanceof String){
190+
sum += tokens(encoding, msg.getContent().toString());
191+
}
190192
sum += tokens(encoding, msg.getRole());
191193
sum += tokens(encoding, msg.getName());
192194
if (isNotBlank(msg.getName())) {
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package com.theokanning.openai.utils;
2+
3+
import com.theokanning.openai.completion.chat.ChatMessage;
4+
import com.theokanning.openai.completion.chat.ChatMessageContent;
5+
import com.theokanning.openai.completion.chat.ImageUrl;
6+
7+
import java.util.ArrayList;
8+
import java.util.List;
9+
import java.util.regex.Matcher;
10+
import java.util.regex.Pattern;
11+
12+
/**
13+
* Vision tool class
14+
*
15+
* @author cong
16+
* @since 2023/11/17
17+
*/
18+
public class VisionUtil {
19+
20+
private static final Pattern pattern = Pattern.compile("(https?://\\S+)");
21+
22+
public static ChatMessage<List<ChatMessageContent>> convertForVision(ChatMessage<String> msg) {
23+
List<ChatMessageContent> content = new ArrayList<>();
24+
String sourceText = msg.getContent();
25+
// Regular expression to match image URLs
26+
Matcher matcher = pattern.matcher(sourceText);
27+
// Find image URLs and split the string
28+
int lastIndex = 0;
29+
while (matcher.find()) {
30+
String url = matcher.group();
31+
// Add the text before the image URL
32+
if (matcher.start() > lastIndex) {
33+
String text = sourceText.substring(lastIndex, matcher.start()).trim();
34+
content.add(new ChatMessageContent(text));
35+
}
36+
// Add the image URL
37+
ImageUrl imageUrl = new ImageUrl();
38+
imageUrl.setUrl(url);
39+
content.add(new ChatMessageContent(imageUrl));
40+
lastIndex = matcher.end();
41+
}
42+
// Add the remaining text
43+
if (lastIndex < sourceText.length()) {
44+
String text = sourceText.substring(lastIndex).trim();
45+
content.add(new ChatMessageContent(text));
46+
}
47+
return new ChatMessage<>(msg.getRole(), content, msg.getName());
48+
}
49+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package example;
2+
3+
import com.theokanning.openai.completion.chat.*;
4+
import com.theokanning.openai.service.OpenAiService;
5+
import com.theokanning.openai.utils.VisionUtil;
6+
7+
import java.time.Duration;
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
11+
class OpenAiApiVisionExample {
12+
public static void main(String... args) {
13+
String token = System.getenv("OPENAI_TOKEN");
14+
OpenAiService service = new OpenAiService(token, Duration.ofSeconds(30));
15+
16+
System.out.println("Streaming chat completion...");
17+
final List<ChatMessage> messages = new ArrayList<>();
18+
List<ChatMessageContent> content = new ArrayList<>();
19+
content.add(new ChatMessageContent("What’s in this image?"));
20+
content.add(new ChatMessageContent(new ImageUrl(
21+
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg")));
22+
messages.add(new ChatMessage<>(ChatMessageRole.USER.value(), content));
23+
24+
// use VisionUtil to convert image prompt to OpenAI format
25+
System.out.println("Converting image to OpenAI format...");
26+
ChatMessage<List<ChatMessageContent>> visionChatMessage = VisionUtil.convertForVision(
27+
new ChatMessage<>(ChatMessageRole.USER.value(),
28+
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg "
29+
+ "What are in these images? Is there any difference between them?"));
30+
messages.add(visionChatMessage);
31+
32+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
33+
.builder()
34+
.model("gpt-4-vision-preview")
35+
.messages(messages)
36+
.maxTokens(300)
37+
.build();
38+
39+
service.streamChatCompletion(chatCompletionRequest).blockingForEach(System.out::println);
40+
service.shutdownExecutor();
41+
}
42+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.theokanning.openai.service;
2+
3+
import com.fasterxml.jackson.annotation.JsonProperty;
4+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
5+
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
6+
7+
/**
8+
* @author cong
9+
* @since 2023/11/17
10+
*/
11+
public abstract class ChatMessageMixIn {
12+
@JsonProperty("content")
13+
@JsonSerialize(using = ChatMessageSerializerAndDeserializer.ChatMessageContentSerializer.class)
14+
@JsonDeserialize(using = ChatMessageSerializerAndDeserializer.ChatMessageContentDeserializer.class)
15+
abstract Object getContent();
16+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package com.theokanning.openai.service;
2+
3+
import com.fasterxml.jackson.core.JsonGenerator;
4+
import com.fasterxml.jackson.core.JsonParser;
5+
import com.fasterxml.jackson.databind.*;
6+
import com.theokanning.openai.completion.chat.ChatMessageContent;
7+
import com.theokanning.openai.completion.chat.ChatMessageContentType;
8+
import com.theokanning.openai.completion.chat.ImageUrl;
9+
10+
import java.io.IOException;
11+
import java.util.ArrayList;
12+
import java.util.List;
13+
import java.util.Optional;
14+
15+
public class ChatMessageSerializerAndDeserializer {
16+
17+
public static class ChatMessageContentSerializer extends JsonSerializer<Object> {
18+
@Override
19+
public void serialize(Object content, JsonGenerator gen, SerializerProvider serializers) throws IOException {
20+
if (content == null) {
21+
gen.writeNull();
22+
return;
23+
}
24+
if (content instanceof String) {
25+
gen.writeString((String)content);
26+
return;
27+
}
28+
if (content instanceof List) {
29+
gen.writeStartArray();
30+
List<?> contentList = (List<?>)content;
31+
for (Object item : contentList) {
32+
if (item instanceof ChatMessageContent) {
33+
ChatMessageContent contentItem = (ChatMessageContent)item;
34+
gen.writeStartObject();
35+
gen.writeStringField("type", contentItem.getType());
36+
if (ChatMessageContentType.TEXT.value().equals(contentItem.getType())) {
37+
gen.writeStringField("text", contentItem.getText());
38+
} else if (ChatMessageContentType.IMAGE_URL.value().equals(contentItem.getType())) {
39+
gen.writeObjectFieldStart("image_url");
40+
gen.writeStringField("url", contentItem.getImageUrl().getUrl());
41+
gen.writeStringField("detail", contentItem.getImageUrl().getDetail());
42+
gen.writeEndObject();
43+
}
44+
gen.writeEndObject();
45+
}
46+
}
47+
gen.writeEndArray();
48+
}
49+
}
50+
}
51+
52+
public static class ChatMessageContentDeserializer extends JsonDeserializer<Object> {
53+
@Override
54+
public Object deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
55+
JsonNode contentNode = p.readValueAsTree();
56+
if (contentNode.isTextual()) {
57+
return contentNode.asText();
58+
}
59+
if (contentNode.isArray()) {
60+
List<Object> contentList = new ArrayList<>();
61+
for (JsonNode itemNode : contentNode) {
62+
String type = itemNode.get("type").asText();
63+
if (ChatMessageContentType.TEXT.value().equals(type)) {
64+
contentList.add(new ChatMessageContent(itemNode.get("text").asText()));
65+
} else if (ChatMessageContentType.IMAGE_URL.value().equals(type)) {
66+
JsonNode imageUrlJsonNode = itemNode.get("image_url");
67+
ImageUrl imageUrl = new ImageUrl();
68+
imageUrl.setUrl(Optional.ofNullable(imageUrlJsonNode.get("url"))
69+
.map(JsonNode::asText).orElse(null));
70+
imageUrl.setDetail(Optional.ofNullable(imageUrlJsonNode.get("detail"))
71+
.map(JsonNode::asText).orElse(null));
72+
contentList.add(new ChatMessageContent(imageUrl));
73+
}
74+
}
75+
return contentList;
76+
}
77+
return null;
78+
}
79+
}
80+
81+
}

0 commit comments

Comments
 (0)