diff --git a/src/main/java/dev/braintrust/instrumentation/openai/otel/ChatCompletionEventsHelper.java b/src/main/java/dev/braintrust/instrumentation/openai/otel/ChatCompletionEventsHelper.java index 0814c16d..f2b86314 100644 --- a/src/main/java/dev/braintrust/instrumentation/openai/otel/ChatCompletionEventsHelper.java +++ b/src/main/java/dev/braintrust/instrumentation/openai/otel/ChatCompletionEventsHelper.java @@ -7,7 +7,6 @@ import static io.opentelemetry.api.common.AttributeKey.stringKey; -import com.fasterxml.jackson.databind.ObjectMapper; import com.openai.models.chat.completions.ChatCompletion; import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; import com.openai.models.chat.completions.ChatCompletionContentPartText; @@ -24,16 +23,11 @@ import io.opentelemetry.api.logs.Logger; import io.opentelemetry.api.trace.Span; import io.opentelemetry.context.Context; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.MethodType; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.stream.Collectors; -import javax.annotation.Nullable; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; @@ -41,8 +35,6 @@ final class ChatCompletionEventsHelper { private static final AttributeKey EVENT_NAME = stringKey("event.name"); - private static final ObjectMapper JSON_MAPPER = - new com.fasterxml.jackson.databind.ObjectMapper(); @SneakyThrows public static void emitPromptLogEvents( @@ -50,10 +42,9 @@ public static void emitPromptLogEvents( Logger eventLogger, ChatCompletionCreateParams request, boolean captureMessageContent) { - Span.current() - .setAttribute( - "braintrust.input_json", - JSON_MAPPER.writeValueAsString(request.messages())); + String semconvJson = GenAiSemconvSerializer.serializeInputMessages(request.messages()); + Span span = Span.current(); + span.setAttribute("gen_ai.input.messages", semconvJson); } private static String contentToString(ChatCompletionToolMessageParam.Content content) { @@ -138,13 +129,12 @@ public static void emitCompletionLogEvents( } else if (completion.choices().size() > 1) { log.debug("multiple choices in OAI response: {}", completion.choices().size()); } else { - Span.current() - .setAttribute( - "braintrust.output_json", - JSON_MAPPER.writeValueAsString( - new ChatCompletionMessage[] { - completion.choices().get(0).message() - })); + // Set gen_ai.output.messages attribute for single choice (most common case) + ChatCompletion.Choice choice = completion.choices().get(0); + String outputJson = + GenAiSemconvSerializer.serializeOutputMessage( + choice.message(), choice.finishReason().toString()); + Span.current().setAttribute("gen_ai.output.messages", outputJson); } for (ChatCompletion.Choice choice : completion.choices()) { ChatCompletionMessage choiceMsg = choice.message(); @@ -211,7 +201,8 @@ private static LogRecordBuilder newEvent(Logger eventLogger, String name) { private static Value buildToolCallEventObject( ChatCompletionMessageToolCall call, boolean captureMessageContent) { Map> result = new HashMap<>(); - FunctionAccess functionAccess = getFunctionAccess(call); + GenAiSemconvSerializer.FunctionAccess functionAccess = + GenAiSemconvSerializer.getFunctionAccess(call); if (functionAccess != null) { result.put("id", Value.of(functionAccess.id())); result.put( @@ -223,7 +214,7 @@ private static Value buildToolCallEventObject( } private static Value buildFunctionEventObject( - FunctionAccess functionAccess, boolean captureMessageContent) { + GenAiSemconvSerializer.FunctionAccess functionAccess, boolean captureMessageContent) { Map> result = new HashMap<>(); result.put("name", Value.of(functionAccess.name())); if (captureMessageContent) { @@ -232,227 +223,5 @@ private static Value buildFunctionEventObject( return Value.of(result); } - @Nullable - private static FunctionAccess getFunctionAccess(ChatCompletionMessageToolCall call) { - if (V1FunctionAccess.isAvailable()) { - return V1FunctionAccess.create(call); - } - if (V3FunctionAccess.isAvailable()) { - return V3FunctionAccess.create(call); - } - - return null; - } - - private interface FunctionAccess { - String id(); - - String name(); - - String arguments(); - } - - private static String invokeStringHandle(@Nullable MethodHandle methodHandle, Object object) { - if (methodHandle == null) { - return ""; - } - - try { - return (String) methodHandle.invoke(object); - } catch (Throwable ignore) { - return ""; - } - } - - private static class V1FunctionAccess implements FunctionAccess { - @Nullable private static final MethodHandle idHandle; - @Nullable private static final MethodHandle functionHandle; - @Nullable private static final MethodHandle nameHandle; - @Nullable private static final MethodHandle argumentsHandle; - - static { - MethodHandle id; - MethodHandle function; - MethodHandle name; - MethodHandle arguments; - - try { - MethodHandles.Lookup lookup = MethodHandles.lookup(); - id = - lookup.findVirtual( - ChatCompletionMessageToolCall.class, - "id", - MethodType.methodType(String.class)); - Class functionClass = - Class.forName( - "com.openai.models.chat.completions.ChatCompletionMessageToolCall$Function"); - function = - lookup.findVirtual( - ChatCompletionMessageToolCall.class, - "function", - MethodType.methodType(functionClass)); - name = - lookup.findVirtual( - functionClass, "name", MethodType.methodType(String.class)); - arguments = - lookup.findVirtual( - functionClass, "arguments", MethodType.methodType(String.class)); - } catch (Exception exception) { - id = null; - function = null; - name = null; - arguments = null; - } - idHandle = id; - functionHandle = function; - nameHandle = name; - argumentsHandle = arguments; - } - - private final ChatCompletionMessageToolCall toolCall; - private final Object function; - - V1FunctionAccess(ChatCompletionMessageToolCall toolCall, Object function) { - this.toolCall = toolCall; - this.function = function; - } - - @Nullable - static FunctionAccess create(ChatCompletionMessageToolCall toolCall) { - if (functionHandle == null) { - return null; - } - - try { - return new V1FunctionAccess(toolCall, functionHandle.invoke(toolCall)); - } catch (Throwable ignore) { - return null; - } - } - - static boolean isAvailable() { - return idHandle != null; - } - - @Override - public String id() { - return invokeStringHandle(idHandle, toolCall); - } - - @Override - public String name() { - return invokeStringHandle(nameHandle, function); - } - - @Override - public String arguments() { - return invokeStringHandle(argumentsHandle, function); - } - } - - static class V3FunctionAccess implements FunctionAccess { - @Nullable private static final MethodHandle functionToolCallHandle; - @Nullable private static final MethodHandle idHandle; - @Nullable private static final MethodHandle functionHandle; - @Nullable private static final MethodHandle nameHandle; - @Nullable private static final MethodHandle argumentsHandle; - - static { - MethodHandle functionToolCall; - MethodHandle id; - MethodHandle function; - MethodHandle name; - MethodHandle arguments; - - try { - MethodHandles.Lookup lookup = MethodHandles.lookup(); - functionToolCall = - lookup.findVirtual( - ChatCompletionMessageToolCall.class, - "function", - MethodType.methodType(Optional.class)); - Class functionToolCallClass = - Class.forName( - "com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall"); - id = - lookup.findVirtual( - functionToolCallClass, "id", MethodType.methodType(String.class)); - Class functionClass = - Class.forName( - "com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall$Function"); - function = - lookup.findVirtual( - functionToolCallClass, - "function", - MethodType.methodType(functionClass)); - name = - lookup.findVirtual( - functionClass, "name", MethodType.methodType(String.class)); - arguments = - lookup.findVirtual( - functionClass, "arguments", MethodType.methodType(String.class)); - } catch (Exception exception) { - functionToolCall = null; - id = null; - function = null; - name = null; - arguments = null; - } - functionToolCallHandle = functionToolCall; - idHandle = id; - functionHandle = function; - nameHandle = name; - argumentsHandle = arguments; - } - - private final Object functionToolCall; - private final Object function; - - V3FunctionAccess(Object functionToolCall, Object function) { - this.functionToolCall = functionToolCall; - this.function = function; - } - - @Nullable - @SuppressWarnings("unchecked") - static FunctionAccess create(ChatCompletionMessageToolCall toolCall) { - if (functionToolCallHandle == null || functionHandle == null) { - return null; - } - - try { - Optional optional = - (Optional) functionToolCallHandle.invoke(toolCall); - if (!optional.isPresent()) { - return null; - } - Object functionToolCall = optional.get(); - return new V3FunctionAccess( - functionToolCall, functionHandle.invoke(functionToolCall)); - } catch (Throwable ignore) { - return null; - } - } - - static boolean isAvailable() { - return idHandle != null; - } - - @Override - public String id() { - return invokeStringHandle(idHandle, functionToolCall); - } - - @Override - public String name() { - return invokeStringHandle(nameHandle, function); - } - - @Override - public String arguments() { - return invokeStringHandle(argumentsHandle, function); - } - } - private ChatCompletionEventsHelper() {} } diff --git a/src/main/java/dev/braintrust/instrumentation/openai/otel/GenAiSemconvSerializer.java b/src/main/java/dev/braintrust/instrumentation/openai/otel/GenAiSemconvSerializer.java new file mode 100644 index 00000000..ce5aa84e --- /dev/null +++ b/src/main/java/dev/braintrust/instrumentation/openai/otel/GenAiSemconvSerializer.java @@ -0,0 +1,672 @@ +package dev.braintrust.instrumentation.openai.otel; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.module.SimpleModule; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionContentPartImage; +import com.openai.models.chat.completions.ChatCompletionContentPartText; +import com.openai.models.chat.completions.ChatCompletionDeveloperMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessage; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionSystemMessageParam; +import com.openai.models.chat.completions.ChatCompletionToolMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import dev.braintrust.trace.Base64Attachment; +import java.io.IOException; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +final class GenAiSemconvSerializer { + + private static final ObjectMapper JSON_MAPPER = createObjectMapper(); + + private static ObjectMapper createObjectMapper() { + final JsonSerializer attachmentSerializer = + Base64Attachment.createSerializer(); + ObjectMapper mapper = new ObjectMapper(); + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + SimpleModule module = new SimpleModule(); + module.addSerializer( + ChatCompletionContentPartImage.class, + new JsonSerializer<>() { + @Override + public void serialize( + ChatCompletionContentPartImage value, + JsonGenerator gen, + SerializerProvider serializers) + throws IOException { + try { + var attachment = + Base64Attachment.of( + value.validate().imageUrl().validate().url()); + attachmentSerializer.serialize(attachment, gen, serializers); + } catch (Exception e) { + JsonSerializer defaultSerializer = + serializers.findValueSerializer( + ChatCompletionContentPartImage.class, null); + defaultSerializer.serialize(value, gen, serializers); + } + } + }); + mapper.registerModule(module); + return mapper; + } + + // OTel GenAI Semantic Convention structures + static class SemconvChatMessage { + @JsonProperty("role") + public final String role; + + @JsonProperty("parts") + public final List parts; + + public SemconvChatMessage(String role, List parts) { + this.role = role; + this.parts = parts; + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + static class SemconvOutputChatMessage { + @JsonProperty("role") + public final String role; + + @JsonProperty("parts") + public final List parts; + + @JsonProperty("finish_reason") + public final String finishReason; + + public SemconvOutputChatMessage(String role, List parts, String finishReason) { + this.role = role; + this.parts = parts; + this.finishReason = finishReason; + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + static class TextPart { + @JsonProperty("type") + public final String type = "text"; + + @JsonProperty("content") + public final String content; + + public TextPart(String content) { + this.content = content; + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ToolCallRequestPart { + @JsonProperty("type") + public final String type = "tool_call"; + + @JsonProperty("name") + public final String name; + + @JsonProperty("id") + public final String id; + + @JsonProperty("arguments") + public final String arguments; + + public ToolCallRequestPart(String name, String id, String arguments) { + this.name = name; + this.id = id; + this.arguments = arguments; + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + static class ToolCallResponsePart { + @JsonProperty("type") + public final String type = "tool_call_response"; + + @JsonProperty("id") + public final String id; + + @JsonProperty("response") + public final String response; + + public ToolCallResponsePart(String id, String response) { + this.id = id; + this.response = response; + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + static class GenericPart { + @JsonProperty("type") + public final String type; + + @JsonProperty("data") + public final Object data; + + public GenericPart(String type, Object data) { + this.type = type; + this.data = data; + } + } + + // Transform OpenAI messages to OTel GenAI semconv format + static List transformToSemconvMessages( + List messages) { + return messages.stream() + .map(GenAiSemconvSerializer::transformMessage) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + } + + @Nullable + private static SemconvChatMessage transformMessage(ChatCompletionMessageParam message) { + // Try each union type accessor + var userOpt = tryGetUser(message); + if (userOpt.isPresent()) { + return transformUserMessage(userOpt.get()); + } + + var systemOpt = tryGetSystem(message); + if (systemOpt.isPresent()) { + return transformSystemMessage(systemOpt.get()); + } + + var assistantOpt = tryGetAssistant(message); + if (assistantOpt.isPresent()) { + return transformAssistantMessage(assistantOpt.get()); + } + + var toolOpt = tryGetTool(message); + if (toolOpt.isPresent()) { + return transformToolMessage(toolOpt.get()); + } + + var developerOpt = tryGetDeveloper(message); + if (developerOpt.isPresent()) { + return transformDeveloperMessage(developerOpt.get()); + } + + return null; + } + + private static Optional tryGetUser( + ChatCompletionMessageParam message) { + try { + var method = message.getClass().getMethod("user"); + @SuppressWarnings("unchecked") + var result = (Optional) method.invoke(message); + return result; + } catch (Exception e) { + return Optional.empty(); + } + } + + private static Optional tryGetSystem( + ChatCompletionMessageParam message) { + try { + var method = message.getClass().getMethod("system"); + @SuppressWarnings("unchecked") + var result = (Optional) method.invoke(message); + return result; + } catch (Exception e) { + return Optional.empty(); + } + } + + private static Optional tryGetAssistant( + ChatCompletionMessageParam message) { + try { + var method = message.getClass().getMethod("assistant"); + @SuppressWarnings("unchecked") + var result = (Optional) method.invoke(message); + return result; + } catch (Exception e) { + return Optional.empty(); + } + } + + private static Optional tryGetTool( + ChatCompletionMessageParam message) { + try { + var method = message.getClass().getMethod("tool"); + @SuppressWarnings("unchecked") + var result = (Optional) method.invoke(message); + return result; + } catch (Exception e) { + return Optional.empty(); + } + } + + private static Optional tryGetDeveloper( + ChatCompletionMessageParam message) { + try { + var method = message.getClass().getMethod("developer"); + @SuppressWarnings("unchecked") + var result = (Optional) method.invoke(message); + return result; + } catch (Exception e) { + return Optional.empty(); + } + } + + private static SemconvChatMessage transformUserMessage(ChatCompletionUserMessageParam message) { + List parts = new ArrayList<>(); + ChatCompletionUserMessageParam.Content content = message.content(); + + if (content.isText()) { + parts.add(new TextPart(content.asText())); + } else if (content.isArrayOfContentParts()) { + for (var part : content.asArrayOfContentParts()) { + if (part.isText()) { + parts.add(new TextPart(part.asText().text())); + } else { + // Try to get image part using union type accessor + var imageOpt = tryGetImagePart(part); + if (imageOpt.isPresent()) { + try { + var imageUrl = imageOpt.get().imageUrl().url(); + var attachment = Base64Attachment.of(imageUrl); + parts.add(attachment); + } catch (Exception e) { + log.debug("Failed to parse image URL", e); + } + } + } + } + } + + return new SemconvChatMessage("user", parts); + } + + private static Optional tryGetImagePart( + com.openai.models.chat.completions.ChatCompletionContentPart part) { + try { + var method = part.getClass().getMethod("imageUrl"); + @SuppressWarnings("unchecked") + var result = (Optional) method.invoke(part); + return result; + } catch (Exception e) { + return Optional.empty(); + } + } + + private static SemconvChatMessage transformSystemMessage( + ChatCompletionSystemMessageParam message) { + List parts = new ArrayList<>(); + ChatCompletionSystemMessageParam.Content content = message.content(); + + if (content.isText()) { + parts.add(new TextPart(content.asText())); + } else if (content.isArrayOfContentParts()) { + for (var part : content.asArrayOfContentParts()) { + parts.add(new TextPart(part.text())); + } + } + + return new SemconvChatMessage("system", parts); + } + + private static SemconvChatMessage transformDeveloperMessage( + ChatCompletionDeveloperMessageParam message) { + List parts = new ArrayList<>(); + ChatCompletionDeveloperMessageParam.Content content = message.content(); + + if (content.isText()) { + parts.add(new TextPart(content.asText())); + } else if (content.isArrayOfContentParts()) { + for (var part : content.asArrayOfContentParts()) { + parts.add(new TextPart(part.text())); + } + } + + return new SemconvChatMessage("system", parts); + } + + private static SemconvChatMessage transformAssistantMessage( + ChatCompletionAssistantMessageParam message) { + List parts = new ArrayList<>(); + + // Handle text content + message.content() + .ifPresent( + content -> { + if (content.isText()) { + parts.add(new TextPart(content.asText())); + } else if (content.isArrayOfContentParts()) { + for (var part : content.asArrayOfContentParts()) { + if (part.isText()) { + parts.add(new TextPart(part.asText().text())); + } else if (part.isRefusal()) { + parts.add(new TextPart(part.asRefusal().refusal())); + } + } + } + }); + + // Handle tool calls + message.toolCalls() + .ifPresent( + toolCalls -> { + for (var toolCall : toolCalls) { + FunctionAccess functionAccess = getFunctionAccess(toolCall); + if (functionAccess != null) { + parts.add( + new ToolCallRequestPart( + functionAccess.name(), + functionAccess.id(), + functionAccess.arguments())); + } + } + }); + + return new SemconvChatMessage("assistant", parts); + } + + private static SemconvChatMessage transformToolMessage(ChatCompletionToolMessageParam message) { + List parts = new ArrayList<>(); + String toolCallId = message.toolCallId(); + ChatCompletionToolMessageParam.Content content = message.content(); + + String responseContent = ""; + if (content.isText()) { + responseContent = content.asText(); + } else if (content.isArrayOfContentParts()) { + responseContent = joinContentParts(content.asArrayOfContentParts()); + } + + parts.add(new ToolCallResponsePart(toolCallId, responseContent)); + return new SemconvChatMessage("tool", parts); + } + + // Transform ChatCompletionMessage (output) to OTel GenAI semconv format + static SemconvOutputChatMessage transformOutputMessage( + ChatCompletionMessage message, String finishReason) { + List parts = new ArrayList<>(); + + // Handle text content + message.content() + .ifPresent( + content -> { + if (!content.isEmpty()) { + parts.add(new TextPart(content)); + } + }); + + // Handle tool calls + message.toolCalls() + .ifPresent( + toolCalls -> { + for (var toolCall : toolCalls) { + FunctionAccess functionAccess = getFunctionAccess(toolCall); + if (functionAccess != null) { + parts.add( + new ToolCallRequestPart( + functionAccess.name(), + functionAccess.id(), + functionAccess.arguments())); + } + } + }); + + // The role from ChatCompletionMessage is always "assistant" for output messages + return new SemconvOutputChatMessage("assistant", parts, finishReason); + } + + private static String joinContentParts(List contentParts) { + return contentParts.stream() + .map(ChatCompletionContentPartText::text) + .collect(Collectors.joining()); + } + + @SneakyThrows + static String serializeInputMessages(List messages) { + List semconvMessages = transformToSemconvMessages(messages); + return JSON_MAPPER.writeValueAsString(semconvMessages); + } + + @SneakyThrows + static String serializeOutputMessage(ChatCompletionMessage message, String finishReason) { + SemconvOutputChatMessage outputMessage = transformOutputMessage(message, finishReason); + return JSON_MAPPER.writeValueAsString(new SemconvOutputChatMessage[] {outputMessage}); + } + + @Nullable + static FunctionAccess getFunctionAccess(ChatCompletionMessageToolCall call) { + if (V1FunctionAccess.isAvailable()) { + return V1FunctionAccess.create(call); + } + if (V3FunctionAccess.isAvailable()) { + return V3FunctionAccess.create(call); + } + + return null; + } + + interface FunctionAccess { + String id(); + + String name(); + + String arguments(); + } + + private static String invokeStringHandle(@Nullable MethodHandle methodHandle, Object object) { + if (methodHandle == null) { + return ""; + } + + try { + return (String) methodHandle.invoke(object); + } catch (Throwable ignore) { + return ""; + } + } + + private static class V1FunctionAccess implements FunctionAccess { + @Nullable private static final MethodHandle idHandle; + @Nullable private static final MethodHandle functionHandle; + @Nullable private static final MethodHandle nameHandle; + @Nullable private static final MethodHandle argumentsHandle; + + static { + MethodHandle id; + MethodHandle function; + MethodHandle name; + MethodHandle arguments; + + try { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + id = + lookup.findVirtual( + ChatCompletionMessageToolCall.class, + "id", + MethodType.methodType(String.class)); + Class functionClass = + Class.forName( + "com.openai.models.chat.completions.ChatCompletionMessageToolCall$Function"); + function = + lookup.findVirtual( + ChatCompletionMessageToolCall.class, + "function", + MethodType.methodType(functionClass)); + name = + lookup.findVirtual( + functionClass, "name", MethodType.methodType(String.class)); + arguments = + lookup.findVirtual( + functionClass, "arguments", MethodType.methodType(String.class)); + } catch (Exception exception) { + id = null; + function = null; + name = null; + arguments = null; + } + idHandle = id; + functionHandle = function; + nameHandle = name; + argumentsHandle = arguments; + } + + private final ChatCompletionMessageToolCall toolCall; + private final Object function; + + V1FunctionAccess(ChatCompletionMessageToolCall toolCall, Object function) { + this.toolCall = toolCall; + this.function = function; + } + + @Nullable + static FunctionAccess create(ChatCompletionMessageToolCall toolCall) { + if (functionHandle == null) { + return null; + } + + try { + return new V1FunctionAccess(toolCall, functionHandle.invoke(toolCall)); + } catch (Throwable ignore) { + return null; + } + } + + static boolean isAvailable() { + return idHandle != null; + } + + @Override + public String id() { + return invokeStringHandle(idHandle, toolCall); + } + + @Override + public String name() { + return invokeStringHandle(nameHandle, function); + } + + @Override + public String arguments() { + return invokeStringHandle(argumentsHandle, function); + } + } + + static class V3FunctionAccess implements FunctionAccess { + @Nullable private static final MethodHandle functionToolCallHandle; + @Nullable private static final MethodHandle idHandle; + @Nullable private static final MethodHandle functionHandle; + @Nullable private static final MethodHandle nameHandle; + @Nullable private static final MethodHandle argumentsHandle; + + static { + MethodHandle functionToolCall; + MethodHandle id; + MethodHandle function; + MethodHandle name; + MethodHandle arguments; + + try { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + functionToolCall = + lookup.findVirtual( + ChatCompletionMessageToolCall.class, + "function", + MethodType.methodType(Optional.class)); + Class functionToolCallClass = + Class.forName( + "com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall"); + id = + lookup.findVirtual( + functionToolCallClass, "id", MethodType.methodType(String.class)); + Class functionClass = + Class.forName( + "com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall$Function"); + function = + lookup.findVirtual( + functionToolCallClass, + "function", + MethodType.methodType(functionClass)); + name = + lookup.findVirtual( + functionClass, "name", MethodType.methodType(String.class)); + arguments = + lookup.findVirtual( + functionClass, "arguments", MethodType.methodType(String.class)); + } catch (Exception exception) { + functionToolCall = null; + id = null; + function = null; + name = null; + arguments = null; + } + functionToolCallHandle = functionToolCall; + idHandle = id; + functionHandle = function; + nameHandle = name; + argumentsHandle = arguments; + } + + private final Object functionToolCall; + private final Object function; + + V3FunctionAccess(Object functionToolCall, Object function) { + this.functionToolCall = functionToolCall; + this.function = function; + } + + @Nullable + @SuppressWarnings("unchecked") + static FunctionAccess create(ChatCompletionMessageToolCall toolCall) { + if (functionToolCallHandle == null || functionHandle == null) { + return null; + } + + try { + Optional optional = + (Optional) functionToolCallHandle.invoke(toolCall); + if (!optional.isPresent()) { + return null; + } + Object functionToolCall = optional.get(); + return new V3FunctionAccess( + functionToolCall, functionHandle.invoke(functionToolCall)); + } catch (Throwable ignore) { + return null; + } + } + + static boolean isAvailable() { + return idHandle != null; + } + + @Override + public String id() { + return invokeStringHandle(idHandle, functionToolCall); + } + + @Override + public String name() { + return invokeStringHandle(nameHandle, function); + } + + @Override + public String arguments() { + return invokeStringHandle(argumentsHandle, function); + } + } + + private GenAiSemconvSerializer() {} +} diff --git a/src/main/java/dev/braintrust/trace/Base64Attachment.java b/src/main/java/dev/braintrust/trace/Base64Attachment.java new file mode 100644 index 00000000..aae9f37f --- /dev/null +++ b/src/main/java/dev/braintrust/trace/Base64Attachment.java @@ -0,0 +1,127 @@ +package dev.braintrust.trace; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Base64; +import java.util.Objects; +import javax.annotation.Nonnull; +import lombok.Getter; + +/** + * Utility to serialize LLM attachment data in a braintrust-friendly manner. + * + *

Users of the SDK likely don't need to use this utility directly because instrumentation will + * properly serialize messages out of the box. + * + *

The serialized json will conform to the otel input/output GenericPart schema. See + * https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-input-messages.json and + * https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-output-messages.json + */ +public class Base64Attachment { + @JsonProperty("type") + @Getter + private final String type = "base64_attachment"; + + @JsonProperty("content") + @Getter + private final String base64Data; + + private Base64Attachment(@Nonnull String base64Data) { + if (Objects.requireNonNull(base64Data).isEmpty()) { + throw new IllegalArgumentException("base64Data cannot be empty"); + } + // Check for data URL prefix (e.g., "data:image/png;base64,...") + if (!base64Data.startsWith("data:") || !base64Data.contains(";base64,")) { + throw new IllegalArgumentException( + "base64Data must be a data URL with format:" + + " data:;base64,"); + } + + this.base64Data = base64Data; + } + + /** + * Create a new attachment out of base64 data + * + * @param base64DataUri must conform to data:(content-type);base64,BYTES + */ + public static Base64Attachment of(String base64DataUri) { + return new Base64Attachment(base64DataUri); + } + + /** convenience utility to convert a file to a base64 attachment */ + public static Base64Attachment ofFile(ContentType contentType, String filePath) { + try { + Path path = Paths.get(filePath); + byte[] fileBytes = Files.readAllBytes(path); + String base64Encoded = Base64.getEncoder().encodeToString(fileBytes); + String dataUrl = "data:" + contentType.getMimeType() + ";base64," + base64Encoded; + return of(dataUrl); + } catch (IOException e) { + throw new RuntimeException("Failed to read file: " + filePath, e); + } + } + + /** create a jackson serializer for attachment data */ + public static JsonSerializer createSerializer() { + return new JsonSerializer<>() { + @Override + public void serialize( + Base64Attachment value, JsonGenerator gen, SerializerProvider serializers) + throws IOException { + gen.writeStartObject(); + try { + gen.writeStringField("type", value.type); + gen.writeStringField("content", value.base64Data); + } finally { + gen.writeEndObject(); + } + } + }; + } + + public static class ContentType { + // Common image formats + public static ContentType IMAGE_PNG = new ContentType("image/png"); + public static ContentType IMAGE_JPEG = new ContentType("image/jpeg"); + public static ContentType IMAGE_GIF = new ContentType("image/gif"); + public static ContentType IMAGE_WEBP = new ContentType("image/webp"); + public static ContentType IMAGE_SVG = new ContentType("image/svg+xml"); + + // Common document formats + public static ContentType APPLICATION_PDF = new ContentType("application/pdf"); + public static ContentType TEXT_PLAIN = new ContentType("text/plain"); + public static ContentType APPLICATION_JSON = new ContentType("application/json"); + + public static ContentType of(@Nonnull String mimeType) { + return new ContentType(mimeType); + } + + @Getter private final @Nonnull String mimeType; + + @Override + public int hashCode() { + return mimeType.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof ContentType) { + return mimeType.equals(((ContentType) obj).mimeType); + } else { + return super.equals(obj); + } + } + + private ContentType(@Nonnull String mimeType) { + Objects.requireNonNull(mimeType); + this.mimeType = mimeType.toLowerCase(); + } + } +} diff --git a/src/main/java/dev/braintrust/trace/BraintrustSpanExporter.java b/src/main/java/dev/braintrust/trace/BraintrustSpanExporter.java index a38c245f..00c4aac0 100644 --- a/src/main/java/dev/braintrust/trace/BraintrustSpanExporter.java +++ b/src/main/java/dev/braintrust/trace/BraintrustSpanExporter.java @@ -48,6 +48,7 @@ public CompletableResultCode export(Collection spans) { // Combine all results var combined = CompletableResultCode.ofAll(results); + log.debug("span export results: {}", combined.isSuccess()); return combined; } diff --git a/src/test/java/dev/braintrust/instrumentation/openai/BraintrustOpenAITest.java b/src/test/java/dev/braintrust/instrumentation/openai/BraintrustOpenAITest.java index 434adaba..79262de7 100644 --- a/src/test/java/dev/braintrust/instrumentation/openai/BraintrustOpenAITest.java +++ b/src/test/java/dev/braintrust/instrumentation/openai/BraintrustOpenAITest.java @@ -10,13 +10,19 @@ import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionContentPart; +import com.openai.models.chat.completions.ChatCompletionContentPartImage; +import com.openai.models.chat.completions.ChatCompletionContentPartText; import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.chat.completions.ChatCompletionStreamOptions; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; import dev.braintrust.config.BraintrustConfig; +import dev.braintrust.trace.Base64Attachment; import dev.braintrust.trace.BraintrustTracing; import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.sdk.OpenTelemetrySdk; +import java.util.Arrays; import java.util.concurrent.TimeUnit; import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeEach; @@ -140,10 +146,11 @@ void testWrapOpenAi() { "chatcmpl-test123", span.getAttributes().get(AttributeKey.stringKey("gen_ai.response.id"))); assertEquals( - "[{\"content\":\"You are a helpful" - + " assistant\",\"role\":\"system\",\"valid\":true},{\"content\":\"What is the" - + " capital of France?\",\"role\":\"user\",\"valid\":true}]", - span.getAttributes().get(AttributeKey.stringKey("braintrust.input_json"))); + "[{\"role\":\"system\",\"parts\":[{\"type\":\"text\",\"content\":\"You are a" + + " helpful" + + " assistant\"}]},{\"role\":\"user\",\"parts\":[{\"type\":\"text\",\"content\":\"What" + + " is the capital of France?\"}]}]", + span.getAttributes().get(AttributeKey.stringKey("gen_ai.input.messages"))); assertEquals( "project_name:unit-test-project", span.getAttributes().get(AttributeKey.stringKey("braintrust.parent"))); @@ -156,12 +163,11 @@ void testWrapOpenAi() { span.getAttributes().get(AttributeKey.doubleKey("gen_ai.request.temperature"))); String outputJson = - span.getAttributes().get(AttributeKey.stringKey("braintrust.output_json")); - assertNotNull(outputJson); - var outputMessages = JSON_MAPPER.readTree(outputJson); - assertEquals(1, outputMessages.size()); - var messageZero = outputMessages.get(0); - assertEquals("The capital of France is Paris.", messageZero.get("content").asText()); + span.getAttributes().get(AttributeKey.stringKey("gen_ai.output.messages")); + assertEquals( + "[{\"role\":\"assistant\",\"parts\":[{\"type\":\"text\",\"content\":\"The capital" + + " of France is Paris.\"}],\"finish_reason\":\"stop\"}]", + outputJson); assertEquals( "chatcmpl-test123", @@ -295,4 +301,158 @@ void testWrapOpenAiStreaming() { var messageZero = outputMessages.get(0); assertEquals("The capital of France is Paris.", messageZero.get("content").asText()); } + + @Test + @SneakyThrows + void testWrapOpenAiWithImageAttachment() { + // Mock the OpenAI API response for vision request + wireMock.stubFor( + post(urlEqualTo("/chat/completions")) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody( + """ + { + "id": "chatcmpl-test456", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "This image shows the Eiffel Tower in Paris, France." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 150, + "completion_tokens": 15, + "total_tokens": 165 + } + } + """))); + + var openTelemetry = (OpenTelemetrySdk) BraintrustTracing.of(config, true); + + // Create OpenAI client pointing to WireMock server + OpenAIClient openAIClient = + OpenAIOkHttpClient.builder() + .baseUrl("http://localhost:" + wireMock.getPort()) + .apiKey("test-api-key") + .build(); + + // Wrap with Braintrust instrumentation + openAIClient = BraintrustOpenAI.wrapOpenAI(openTelemetry, openAIClient); + + String imageDataUrl = + Base64Attachment.ofFile( + Base64Attachment.ContentType.IMAGE_JPEG, + "src/test/java/dev/braintrust/instrumentation/openai/travel-paris-france-poster.jpg") + .getBase64Data(); + + // Create text content part + ChatCompletionContentPartText textPart = + ChatCompletionContentPartText.builder().text("What's in this image?").build(); + ChatCompletionContentPart textContentPart = ChatCompletionContentPart.ofText(textPart); + + // Create image content part with base64-encoded image + ChatCompletionContentPartImage imagePart = + ChatCompletionContentPartImage.builder() + .imageUrl( + ChatCompletionContentPartImage.ImageUrl.builder() + // .url("https://example.com/eiffel-tower.jpg") + .url(imageDataUrl) + .detail(ChatCompletionContentPartImage.ImageUrl.Detail.HIGH) + .build()) + .build(); + ChatCompletionContentPart imageContentPart = + ChatCompletionContentPart.ofImageUrl(imagePart); + + // Create user message with both text and image + ChatCompletionUserMessageParam userMessage = + ChatCompletionUserMessageParam.builder() + .contentOfArrayOfContentParts( + Arrays.asList(textContentPart, imageContentPart)) + .build(); + + var request = + ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_4O_MINI) + .addSystemMessage("You are a helpful assistant that can analyze images") + .addMessage(userMessage) + .temperature(0.0) + .build(); + + var response = openAIClient.chat().completions().create(request); + + // Verify the response + assertNotNull(response); + wireMock.verify(1, postRequestedFor(urlEqualTo("/chat/completions"))); + assertEquals("chatcmpl-test456", response.id()); + assertEquals( + "This image shows the Eiffel Tower in Paris, France.", + response.choices().get(0).message().content().get()); + + // Verify spans were exported + assertTrue( + openTelemetry + .getSdkTracerProvider() + .forceFlush() + .join(10, TimeUnit.SECONDS) + .isSuccess()); + var spanData = + getExportedBraintrustSpans().get(config.getBraintrustParentValue().orElseThrow()); + assertNotNull(spanData); + assertEquals(1, spanData.size()); + var span = spanData.get(0); + + // Verify span attributes + assertEquals("openai", span.getAttributes().get(AttributeKey.stringKey("gen_ai.system"))); + assertEquals( + "gpt-4o-mini", + span.getAttributes().get(AttributeKey.stringKey("gen_ai.request.model"))); + assertEquals( + "gpt-4o-mini", + span.getAttributes().get(AttributeKey.stringKey("gen_ai.response.model"))); + assertEquals( + "[stop]", + span.getAttributes() + .get(AttributeKey.stringArrayKey("gen_ai.response.finish_reasons")) + .toString()); + assertEquals( + "chat", span.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "chatcmpl-test456", + span.getAttributes().get(AttributeKey.stringKey("gen_ai.response.id"))); + + // Verify input JSON captures both text and image content + String inputJson = + span.getAttributes().get(AttributeKey.stringKey("gen_ai.input.messages")); + assertEquals( + "[{\"role\":\"system\",\"parts\":[{\"type\":\"text\",\"content\":\"You are a helpful assistant that can analyze images\"}]},{\"role\":\"user\",\"parts\":[{\"type\":\"text\",\"content\":\"What's in this image?\"},{\"type\":\"base64_attachment\",\"content\":\"%s\"}]}]" + .formatted(imageDataUrl), + inputJson); + assertNotNull(inputJson); + var inputMessages = JSON_MAPPER.readTree(inputJson); + assertEquals(2, inputMessages.size()); // system message + user message + + // Verify usage metrics + assertEquals( + 150L, span.getAttributes().get(AttributeKey.longKey("gen_ai.usage.input_tokens"))); + assertEquals( + 15L, span.getAttributes().get(AttributeKey.longKey("gen_ai.usage.output_tokens"))); + + // Verify output JSON + String outputJson = + span.getAttributes().get(AttributeKey.stringKey("gen_ai.output.messages")); + assertEquals( + "[{\"role\":\"assistant\",\"parts\":[{\"type\":\"text\",\"content\":\"This image" + + " shows the Eiffel Tower in Paris, France.\"}],\"finish_reason\":\"stop\"}]", + outputJson); + } } diff --git a/src/test/java/dev/braintrust/instrumentation/openai/travel-paris-france-poster.jpg b/src/test/java/dev/braintrust/instrumentation/openai/travel-paris-france-poster.jpg new file mode 100644 index 00000000..f0164667 Binary files /dev/null and b/src/test/java/dev/braintrust/instrumentation/openai/travel-paris-france-poster.jpg differ diff --git a/src/test/java/dev/braintrust/trace/Base64AttachmentTest.java b/src/test/java/dev/braintrust/trace/Base64AttachmentTest.java new file mode 100644 index 00000000..bf3a7ce6 --- /dev/null +++ b/src/test/java/dev/braintrust/trace/Base64AttachmentTest.java @@ -0,0 +1,128 @@ +package dev.braintrust.trace; + +import static org.junit.jupiter.api.Assertions.*; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.module.SimpleModule; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Base64; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class Base64AttachmentTest { + private static final ObjectMapper JSON_MAPPER = createObjectMapper(); + + private static ObjectMapper createObjectMapper() { + ObjectMapper mapper = new ObjectMapper(); + SimpleModule module = new SimpleModule(); + module.addSerializer(Base64Attachment.class, Base64Attachment.createSerializer()); + mapper.registerModule(module); + return mapper; + } + + @Test + void testOfWithValidDataUrl() { + String validDataUrl = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA"; + Base64Attachment attachment = Base64Attachment.of(validDataUrl); + assertNotNull(attachment); + } + + @Test + void testBadBase64Data() { + assertThrows(Exception.class, () -> Base64Attachment.of(null)); + assertThrows(Exception.class, () -> Base64Attachment.of("")); + } + + @Test + void testOfWithoutDataPrefixThrowsException() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> Base64Attachment.of("image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA")); + assertTrue(exception.getMessage().contains("data URL with format")); + } + + @Test + void testOfBase64WithoutMarkerThrowsException() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> Base64Attachment.of("data:image/png,iVBORw0KGgoAAAANSUhEUgAAAAUA")); + assertTrue(exception.getMessage().contains("data URL with format")); + } + + @Test + void testOfFileWithNonExistentFileThrowsException() { + RuntimeException exception = + assertThrows( + RuntimeException.class, + () -> + Base64Attachment.ofFile( + Base64Attachment.ContentType.IMAGE_PNG, + "/nonexistent/path/to/file.png")); + assertTrue(exception.getMessage().contains("Failed to read file")); + } + + @Test + @SneakyThrows + void testFileCreatesBase64Content(@TempDir Path tempDir) { + // Create a test file + Path testFile = tempDir.resolve("test.jpg"); + byte[] testData = "test jpeg data".getBytes(); + Files.write(testFile, testData); + + // Create attachment from file + Base64Attachment attachment = + Base64Attachment.ofFile( + Base64Attachment.ContentType.IMAGE_JPEG, testFile.toString()); + + // Serialize to JSON to verify the data URL format + String json = JSON_MAPPER.writeValueAsString(attachment); + + // Parse JSON and verify structure + var jsonNode = JSON_MAPPER.readTree(json); + assertEquals(2, jsonNode.size()); + + assertEquals("base64_attachment", jsonNode.get("type").asText()); + + String content = jsonNode.get("content").asText(); + assertTrue(content.startsWith("data:image/jpeg;base64,")); + + // Verify the base64 data is correct + String base64Part = content.substring("data:image/jpeg;base64,".length()); + byte[] decodedData = Base64.getDecoder().decode(base64Part); + assertArrayEquals(testData, decodedData); + } + + @Test + void testContentTypeOfNormalizesToLowercase() { + var customType = Base64Attachment.ContentType.of("IMAGE/PNG"); + assertEquals("image/png", customType.getMimeType()); + } + + @Test + void testContentTypeOfWithNullThrowsException() { + assertThrows(NullPointerException.class, () -> Base64Attachment.ContentType.of(null)); + } + + @Test + void testContentTypeEquality() { + var type1 = Base64Attachment.ContentType.of("image/png"); + var type2 = Base64Attachment.ContentType.of("image/png"); + var type3 = Base64Attachment.ContentType.of("image/jpeg"); + + assertEquals(type1, type2); + assertNotEquals(type1, type3); + assertEquals(type1, Base64Attachment.ContentType.IMAGE_PNG); + } + + @Test + void testContentTypeHashCode() { + var type1 = Base64Attachment.ContentType.of("image/png"); + var type2 = Base64Attachment.ContentType.of("image/png"); + + assertEquals(type1.hashCode(), type2.hashCode()); + } +}