Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,37 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.JsonBaseModel;
import com.google.adk.models.LlmRequest;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Part;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Data Transfer Objects for Chat Completion API requests.
*
* <p>Can be used to translate from a {@link LlmRequest} into a {@link ChatCompletionsRequest} using
* {@link #fromLlmRequest(LlmRequest, boolean)}.
*
* <p>See
* https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create
*/
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
final class ChatCompletionsRequest {
public final class ChatCompletionsRequest {

/**
* See
Expand Down Expand Up @@ -249,6 +268,319 @@ final class ChatCompletionsRequest {
@JsonProperty("extra_body")
public Map<String, Object> extraBody;

private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsRequest.class);
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();

/**
* Converts a standard {@link LlmRequest} into a {@link ChatCompletionsRequest} for
* /chat/completions compatible endpoints.
*
* @param llmRequest The internal source request containing contents, configuration, and tool
* definitions.
* @param responseStreaming True if the request asks for a streaming response.
* @return A populated ChatCompletionsRequest ready for JSON serialization.
*/
public static ChatCompletionsRequest fromLlmRequest(
LlmRequest llmRequest, boolean responseStreaming) {
ChatCompletionsRequest request = new ChatCompletionsRequest();
request.model = llmRequest.model().orElse("");
request.stream = responseStreaming;
if (responseStreaming) {
StreamOptions options = new StreamOptions();
options.includeUsage = true;
request.streamOptions = options;
}

boolean isOSeries = request.model.matches("^o\\d+(?:-.*)?$");

List<Message> messages = new ArrayList<>();

llmRequest
.config()
.flatMap(config -> processSystemInstruction(config, isOSeries))
.ifPresent(messages::add);

for (Content content : llmRequest.contents()) {
messages.addAll(processContent(content));
}

request.messages = ImmutableList.copyOf(messages);

llmRequest
.config()
.ifPresent(
config -> {
handleConfigOptions(config, request);
handleTools(config, request);
});

return request;
}

/**
* Processes the system instruction configuration and returns a mapped Message if present.
*
* @param config The content generation configuration that may contain a system instruction.
* @param isOSeries True if the target model belongs to the OpenAI o-series (e.g., o1, o3), which
* requires the "developer" role instead of the standard "system" role.
* @return An Optional containing the mapped instruction, or empty if none exists.
*/
private static Optional<Message> processSystemInstruction(
GenerateContentConfig config, boolean isOSeries) {
if (config.systemInstruction().isPresent()) {
Message systemMsg = new Message();
systemMsg.role = isOSeries ? "developer" : "system";
systemMsg.content = new MessageContent(config.systemInstruction().get().text());
return Optional.of(systemMsg);
}
return Optional.empty();
}

/**
* Processes incoming content and returns a list of messages resulting from it.
*
* @param content The incoming content containing parts to map.
* @return A list of mapped messages.
*/
private static List<Message> processContent(Content content) {
Message msg = new Message();
String role = content.role().orElse("user");
msg.role = role.equals("model") ? "assistant" : role;

List<ContentPart> contentParts = new ArrayList<>();
List<ChatCompletionsCommon.ToolCall> toolCalls = new ArrayList<>();
List<Message> toolResponses = new ArrayList<>();

content
.parts()
.ifPresent(
parts -> {
for (Part part : parts) {
if (part.text().isPresent()) {
contentParts.add(processTextPart(part));
} else if (part.inlineData().isPresent()) {
contentParts.add(processInlineDataPart(part));
} else if (part.fileData().isPresent()) {
contentParts.add(processFileDataPart(part));
} else if (part.functionCall().isPresent()) {
toolCalls.add(processFunctionCallPart(part));
} else if (part.functionResponse().isPresent()) {
toolResponses.add(processFunctionResponsePart(part));
} else if (part.executableCode().isPresent()) {
logger.warn("Executable code is not supported in Chat Completion conversion");
} else if (part.codeExecutionResult().isPresent()) {
logger.warn(
"Code execution result is not supported in Chat Completion conversion");
}
}
});

if (!toolResponses.isEmpty()) {
return toolResponses;
} else {
if (!toolCalls.isEmpty()) {
msg.toolCalls = ImmutableList.copyOf(toolCalls);
}
if (!contentParts.isEmpty()) {
if (contentParts.size() == 1 && Objects.equals(contentParts.get(0).type, "text")) {
msg.content = new MessageContent(contentParts.get(0).text);
} else {
msg.content = new MessageContent(ImmutableList.copyOf(contentParts));
}
}
List<Message> messages = new ArrayList<>();
messages.add(msg);
return messages;
}
}

/**
* Processes a text part and returns a mapped ContentPart.
*
* @param part The input part containing simple text.
* @return The mapped text part.
*/
private static ContentPart processTextPart(Part part) {
ContentPart textPart = new ContentPart();
textPart.type = "text";
textPart.text = part.text().get();
return textPart;
}

/**
* Processes an inline data part and returns a mapped ContentPart.
*
* @param part The input part containing base64 inline data.
* @return The mapped inline data part.
*/
private static ContentPart processInlineDataPart(Part part) {
ContentPart imgPart = new ContentPart();
imgPart.type = "image_url";
ImageUrl imageUrl = new ImageUrl();
imageUrl.url =
"data:"
+ part.inlineData().get().mimeType().orElse("image/jpeg")
+ ";base64,"
+ Base64.getEncoder().encodeToString(part.inlineData().get().data().get());
imgPart.imageUrl = imageUrl;
return imgPart;
}

/**
* Processes a file data part and returns a mapped ContentPart.
*
* @param part The input part referencing a stored file via URI.
* @return The mapped file data part.
*/
private static ContentPart processFileDataPart(Part part) {
ContentPart imgPart = new ContentPart();
imgPart.type = "image_url";
ImageUrl imageUrl = new ImageUrl();
imageUrl.url = part.fileData().get().fileUri().orElse("");
imgPart.imageUrl = imageUrl;
return imgPart;
}

/**
* Processes a function call part and returns a mapped ToolCall.
*
* @param part The input part containing a requested function call or invocation.
* @return The mapped function call tool call.
*/
private static ChatCompletionsCommon.ToolCall processFunctionCallPart(Part part) {
com.google.genai.types.FunctionCall fc = part.functionCall().get();
ChatCompletionsCommon.ToolCall toolCall = new ChatCompletionsCommon.ToolCall();
toolCall.id = fc.id().orElse("call_" + fc.name().orElse("unknown"));
toolCall.type = "function";
ChatCompletionsCommon.Function function = new ChatCompletionsCommon.Function();
function.name = fc.name().orElse("");
if (fc.args().isPresent()) {
try {
function.arguments = objectMapper.writeValueAsString(fc.args().get());
} catch (Exception e) {
logger.warn("Failed to serialize function arguments", e);
}
}
toolCall.function = function;
return toolCall;
}

/**
* Processes a function response part and returns a mapped Message.
*
* @param part The input part containing the execution results of a function.
* @return The mapped tool response message.
*/
private static Message processFunctionResponsePart(Part part) {
FunctionResponse fr = part.functionResponse().get();
Message toolResp = new Message();
toolResp.role = "tool";
toolResp.toolCallId = fr.id().orElse("");
if (fr.response().isPresent()) {
try {
toolResp.content = new MessageContent(objectMapper.writeValueAsString(fr.response().get()));
} catch (Exception e) {
logger.warn("Failed to serialize tool response", e);
}
}
return toolResp;
}

/**
* Updates the request based on the provided configuration options.
*
* @param config The content generation configuration containing parameters such as temperature.
* @param request The chat completions request to populate with matching options.
*/
private static void handleConfigOptions(
GenerateContentConfig config, ChatCompletionsRequest request) {
config.temperature().ifPresent(v -> request.temperature = v.doubleValue());
config.topP().ifPresent(v -> request.topP = v.doubleValue());
config
.maxOutputTokens()
.ifPresent(
v -> {
request.maxCompletionTokens = Math.toIntExact(v);
});
config.stopSequences().ifPresent(v -> request.stop = new StopCondition(v));
config.candidateCount().ifPresent(v -> request.n = Math.toIntExact(v));
config.presencePenalty().ifPresent(v -> request.presencePenalty = v.doubleValue());
config.frequencyPenalty().ifPresent(v -> request.frequencyPenalty = v.doubleValue());
config.seed().ifPresent(v -> request.seed = v.longValue());

if (config.responseJsonSchema().isPresent()) {
ResponseFormatJsonSchema format = new ResponseFormatJsonSchema();
ResponseFormatJsonSchema.JsonSchema schema = new ResponseFormatJsonSchema.JsonSchema();
schema.name = "response_schema";
schema.schema =
objectMapper.convertValue(
config.responseJsonSchema().get(), new TypeReference<Map<String, Object>>() {});
schema.strict = true;
format.jsonSchema = schema;
request.responseFormat = format;
} else if (config.responseMimeType().isPresent()
&& config.responseMimeType().get().equals("application/json")) {
request.responseFormat = new ResponseFormatJsonObject();
}

if (config.responseLogprobs().isPresent() && config.responseLogprobs().get()) {
request.logprobs = true;
config.logprobs().ifPresent(v -> request.topLogprobs = Math.toIntExact(v));
}
}

/**
* Updates the request tools list based on the provided tools configuration.
*
* @param config The content generation configuration defining available tools.
* @param request The chat completions request to populate with mapped tool definitions.
*/
private static void handleTools(GenerateContentConfig config, ChatCompletionsRequest request) {
if (config.tools().isPresent()) {
List<Tool> tools = new ArrayList<>();
for (com.google.genai.types.Tool t : config.tools().get()) {
if (t.functionDeclarations().isPresent()) {
for (FunctionDeclaration fd : t.functionDeclarations().get()) {
Tool tool = new Tool();
tool.type = "function";
FunctionDefinition def = new FunctionDefinition();
def.name = fd.name().orElse("");
def.description = fd.description().orElse("");
fd.parameters()
.ifPresent(
params ->
def.parameters =
objectMapper.convertValue(
params, new TypeReference<Map<String, Object>>() {}));
tool.function = def;
tools.add(tool);
}
}
}
if (!tools.isEmpty()) {
request.tools = ImmutableList.copyOf(tools);
if (config.toolConfig().isPresent()
&& config.toolConfig().get().functionCallingConfig().isPresent()) {
config
.toolConfig()
.get()
.functionCallingConfig()
.get()
.mode()
.ifPresent(
mode -> {
switch (mode.knownEnum()) {
case ANY -> request.toolChoice = new ToolChoiceMode("required");
case NONE -> request.toolChoice = new ToolChoiceMode("none");
case AUTO -> request.toolChoice = new ToolChoiceMode("auto");
default -> {}
}
});
}
}
}
}

/**
* A catch-all class for message parameters. See
* https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public final class ChatCompletionsResponse {

private ChatCompletionsResponse() {}

static @Nullable FinishReason mapFinishReason(String reason) {
static @Nullable FinishReason mapFinishReason(@Nullable String reason) {
if (reason == null) {
return null;
}
Expand All @@ -62,7 +62,7 @@ private ChatCompletionsResponse() {}
};
}

static @Nullable GenerateContentResponseUsageMetadata mapUsage(Usage usage) {
static @Nullable GenerateContentResponseUsageMetadata mapUsage(@Nullable Usage usage) {
if (usage == null) {
return null;
}
Expand Down Expand Up @@ -188,8 +188,15 @@ private ImmutableList<Part> mapMessageToParts(Message message) {
return parts.build();
}

/**
* Maps a list of tool calls to a list of {@link Part} objects.
*
* @param toolCalls the list of tool calls to map (non-null).
* @return a list of parts containing converted tool calls.
*/
private ImmutableList<Part> mapToolCallsToParts(
List<ChatCompletionsCommon.ToolCall> toolCalls) {

ImmutableList.Builder<Part> parts = ImmutableList.builder();
for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) {
Part part = toolCall.toPart();
Expand Down
Loading
Loading