Skip to content

Commit 9a3721b

Browse files
authored
Merge pull request #13 from ddobrin/main
Merge main into planner
2 parents 9ca4384 + 34d8c3d commit 9a3721b

31 files changed

Lines changed: 1459 additions & 739 deletions

File tree

contrib/firestore-session-service/pom.xml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
-->
17-
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
17+
<project xmlns="http://maven.apache.org/POM/4.0.0"
18+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
19+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
1820
<modelVersion>4.0.0</modelVersion>
1921

2022
<parent>
@@ -49,7 +51,6 @@
4951
<dependency>
5052
<groupId>com.google.cloud</groupId>
5153
<artifactId>google-cloud-firestore</artifactId>
52-
<version>3.30.3</version>
5354
</dependency>
5455
<dependency>
5556
<groupId>com.google.truth</groupId>

contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java

Lines changed: 154 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,19 @@
2323
import com.google.adk.models.BaseLlmConnection;
2424
import com.google.adk.models.LlmRequest;
2525
import com.google.adk.models.LlmResponse;
26+
import com.google.auto.value.AutoValue;
2627
import com.google.genai.types.Blob;
2728
import com.google.genai.types.Content;
2829
import com.google.genai.types.FunctionCall;
2930
import com.google.genai.types.FunctionCallingConfigMode;
3031
import com.google.genai.types.FunctionDeclaration;
3132
import com.google.genai.types.FunctionResponse;
3233
import com.google.genai.types.GenerateContentConfig;
34+
import com.google.genai.types.GenerateContentResponseUsageMetadata;
3335
import com.google.genai.types.Part;
3436
import com.google.genai.types.Schema;
3537
import com.google.genai.types.ToolConfig;
3638
import com.google.genai.types.Type;
37-
import dev.langchain4j.Experimental;
3839
import dev.langchain4j.agent.tool.ToolExecutionRequest;
3940
import dev.langchain4j.agent.tool.ToolSpecification;
4041
import dev.langchain4j.data.audio.Audio;
@@ -52,6 +53,7 @@
5253
import dev.langchain4j.data.pdf.PdfFile;
5354
import dev.langchain4j.data.video.Video;
5455
import dev.langchain4j.exception.UnsupportedFeatureException;
56+
import dev.langchain4j.model.TokenCountEstimator;
5557
import dev.langchain4j.model.chat.ChatModel;
5658
import dev.langchain4j.model.chat.StreamingChatModel;
5759
import dev.langchain4j.model.chat.request.ChatRequest;
@@ -65,128 +67,167 @@
6567
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
6668
import dev.langchain4j.model.chat.response.ChatResponse;
6769
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
70+
import dev.langchain4j.model.output.TokenUsage;
6871
import io.reactivex.rxjava3.core.BackpressureStrategy;
6972
import io.reactivex.rxjava3.core.Flowable;
7073
import java.util.ArrayList;
7174
import java.util.Base64;
7275
import java.util.HashMap;
7376
import java.util.List;
7477
import java.util.Map;
75-
import java.util.Objects;
7678
import java.util.UUID;
79+
import org.jspecify.annotations.Nullable;
7780

78-
@Experimental
79-
public class LangChain4j extends BaseLlm {
81+
@AutoValue
82+
public abstract class LangChain4j extends BaseLlm {
8083

8184
private static final TypeReference<Map<String, Object>> MAP_TYPE_REFERENCE =
8285
new TypeReference<>() {};
8386

84-
private final ChatModel chatModel;
85-
private final StreamingChatModel streamingChatModel;
86-
private final ObjectMapper objectMapper;
87+
LangChain4j() {
88+
super("");
89+
}
90+
91+
@Nullable
92+
public abstract ChatModel chatModel();
93+
94+
@Nullable
95+
public abstract StreamingChatModel streamingChatModel();
96+
97+
public abstract ObjectMapper objectMapper();
98+
99+
public abstract String modelName();
100+
101+
@Nullable
102+
public abstract TokenCountEstimator tokenCountEstimator();
103+
104+
@Override
105+
public String model() {
106+
return modelName();
107+
}
108+
109+
public static Builder builder() {
110+
return new AutoValue_LangChain4j.Builder().objectMapper(new ObjectMapper());
111+
}
112+
113+
@AutoValue.Builder
114+
public abstract static class Builder {
115+
public abstract Builder chatModel(ChatModel chatModel);
116+
117+
public abstract Builder streamingChatModel(StreamingChatModel streamingChatModel);
118+
119+
public abstract Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator);
120+
121+
public abstract Builder objectMapper(ObjectMapper objectMapper);
122+
123+
public abstract Builder modelName(String modelName);
124+
125+
public abstract LangChain4j build();
126+
}
87127

88128
public LangChain4j(ChatModel chatModel) {
89-
super(
90-
Objects.requireNonNull(
91-
chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null"));
92-
this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
93-
this.streamingChatModel = null;
94-
this.objectMapper = new ObjectMapper();
129+
this(chatModel, null, null, chatModel.defaultRequestParameters().modelName(), null);
95130
}
96131

97132
public LangChain4j(ChatModel chatModel, String modelName) {
98-
super(Objects.requireNonNull(modelName, "chat model name cannot be null"));
99-
this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
100-
this.streamingChatModel = null;
101-
this.objectMapper = new ObjectMapper();
133+
this(chatModel, null, null, modelName, null);
102134
}
103135

104136
public LangChain4j(StreamingChatModel streamingChatModel) {
105-
super(
106-
Objects.requireNonNull(
107-
streamingChatModel.defaultRequestParameters().modelName(),
108-
"streaming chat model name cannot be null"));
109-
this.chatModel = null;
110-
this.streamingChatModel =
111-
Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
112-
this.objectMapper = new ObjectMapper();
137+
this(
138+
null,
139+
streamingChatModel,
140+
null,
141+
streamingChatModel.defaultRequestParameters().modelName(),
142+
null);
113143
}
114144

115145
public LangChain4j(StreamingChatModel streamingChatModel, String modelName) {
116-
super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null"));
117-
this.chatModel = null;
118-
this.streamingChatModel =
119-
Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
120-
this.objectMapper = new ObjectMapper();
146+
this(null, streamingChatModel, null, modelName, null);
121147
}
122148

123149
public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) {
124-
super(Objects.requireNonNull(modelName, "model name cannot be null"));
125-
this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
126-
this.streamingChatModel =
127-
Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
128-
this.objectMapper = new ObjectMapper();
150+
this(chatModel, streamingChatModel, null, modelName, null);
151+
}
152+
153+
private LangChain4j(
154+
ChatModel chatModel,
155+
StreamingChatModel streamingChatModel,
156+
ObjectMapper objectMapper,
157+
String modelName,
158+
TokenCountEstimator tokenCountEstimator) {
159+
this();
160+
LangChain4j.builder()
161+
.chatModel(chatModel)
162+
.streamingChatModel(streamingChatModel)
163+
.objectMapper(objectMapper)
164+
.modelName(modelName)
165+
.tokenCountEstimator(tokenCountEstimator)
166+
.build();
129167
}
130168

131169
@Override
132170
public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) {
133171
if (stream) {
134-
if (this.streamingChatModel == null) {
172+
if (this.streamingChatModel() == null) {
135173
return Flowable.error(new IllegalStateException("StreamingChatModel is not configured"));
136174
}
137175

138176
ChatRequest chatRequest = toChatRequest(llmRequest);
139177

140178
return Flowable.create(
141179
emitter -> {
142-
streamingChatModel.chat(
143-
chatRequest,
144-
new StreamingChatResponseHandler() {
145-
@Override
146-
public void onPartialResponse(String s) {
147-
emitter.onNext(
148-
LlmResponse.builder().content(Content.fromParts(Part.fromText(s))).build());
149-
}
150-
151-
@Override
152-
public void onCompleteResponse(ChatResponse chatResponse) {
153-
if (chatResponse.aiMessage().hasToolExecutionRequests()) {
154-
AiMessage aiMessage = chatResponse.aiMessage();
155-
toParts(aiMessage).stream()
156-
.map(Part::functionCall)
157-
.forEach(
158-
functionCall -> {
159-
functionCall.ifPresent(
160-
function -> {
161-
emitter.onNext(
162-
LlmResponse.builder()
163-
.content(
164-
Content.fromParts(
165-
Part.fromFunctionCall(
166-
function.name().orElse(""),
167-
function.args().orElse(Map.of()))))
168-
.build());
169-
});
170-
});
171-
}
172-
emitter.onComplete();
173-
}
174-
175-
@Override
176-
public void onError(Throwable throwable) {
177-
emitter.onError(throwable);
178-
}
179-
});
180+
streamingChatModel()
181+
.chat(
182+
chatRequest,
183+
new StreamingChatResponseHandler() {
184+
@Override
185+
public void onPartialResponse(String s) {
186+
emitter.onNext(
187+
LlmResponse.builder()
188+
.content(Content.fromParts(Part.fromText(s)))
189+
.build());
190+
}
191+
192+
@Override
193+
public void onCompleteResponse(ChatResponse chatResponse) {
194+
if (chatResponse.aiMessage().hasToolExecutionRequests()) {
195+
AiMessage aiMessage = chatResponse.aiMessage();
196+
toParts(aiMessage).stream()
197+
.map(Part::functionCall)
198+
.forEach(
199+
functionCall -> {
200+
functionCall.ifPresent(
201+
function -> {
202+
emitter.onNext(
203+
LlmResponse.builder()
204+
.content(
205+
Content.fromParts(
206+
Part.fromFunctionCall(
207+
function.name().orElse(""),
208+
function.args().orElse(Map.of()))))
209+
.build());
210+
});
211+
});
212+
}
213+
emitter.onComplete();
214+
}
215+
216+
@Override
217+
public void onError(Throwable throwable) {
218+
emitter.onError(throwable);
219+
}
220+
});
180221
},
181222
BackpressureStrategy.BUFFER);
182223
} else {
183-
if (this.chatModel == null) {
224+
if (this.chatModel() == null) {
184225
return Flowable.error(new IllegalStateException("ChatModel is not configured"));
185226
}
186227

187228
ChatRequest chatRequest = toChatRequest(llmRequest);
188-
ChatResponse chatResponse = chatModel.chat(chatRequest);
189-
LlmResponse llmResponse = toLlmResponse(chatResponse);
229+
ChatResponse chatResponse = chatModel().chat(chatRequest);
230+
LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest);
190231

191232
return Flowable.just(llmResponse);
192233
}
@@ -413,7 +454,7 @@ private AiMessage toAiMessage(Content content) {
413454

414455
private String toJson(Object object) {
415456
try {
416-
return objectMapper.writeValueAsString(object);
457+
return objectMapper().writeValueAsString(object);
417458
} catch (JsonProcessingException e) {
418459
throw new RuntimeException(e);
419460
}
@@ -511,11 +552,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) {
511552
}
512553
}
513554

514-
private LlmResponse toLlmResponse(ChatResponse chatResponse) {
555+
private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) {
515556
Content content =
516557
Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build();
517558

518-
return LlmResponse.builder().content(content).build();
559+
LlmResponse.Builder builder = LlmResponse.builder().content(content);
560+
TokenUsage tokenUsage = chatResponse.tokenUsage();
561+
if (tokenCountEstimator() != null) {
562+
try {
563+
int estimatedInput =
564+
tokenCountEstimator().estimateTokenCountInMessages(chatRequest.messages());
565+
int estimatedOutput =
566+
tokenCountEstimator().estimateTokenCountInText(chatResponse.aiMessage().text());
567+
int estimatedTotal = estimatedInput + estimatedOutput;
568+
builder.usageMetadata(
569+
GenerateContentResponseUsageMetadata.builder()
570+
.promptTokenCount(estimatedInput)
571+
.candidatesTokenCount(estimatedOutput)
572+
.totalTokenCount(estimatedTotal)
573+
.build());
574+
} catch (Exception e) {
575+
e.printStackTrace();
576+
}
577+
} else if (tokenUsage != null) {
578+
builder.usageMetadata(
579+
GenerateContentResponseUsageMetadata.builder()
580+
.promptTokenCount(tokenUsage.inputTokenCount())
581+
.candidatesTokenCount(tokenUsage.outputTokenCount())
582+
.totalTokenCount(tokenUsage.totalTokenCount())
583+
.build());
584+
}
585+
586+
return builder.build();
519587
}
520588

521589
private List<Part> toParts(AiMessage aiMessage) {
@@ -539,14 +607,17 @@ private List<Part> toParts(AiMessage aiMessage) {
539607
});
540608
return parts;
541609
} else {
542-
Part part = Part.builder().text(aiMessage.text()).build();
543-
return List.of(part);
610+
String text = aiMessage.text();
611+
if (text == null) {
612+
return List.of();
613+
}
614+
return List.of(Part.builder().text(text).build());
544615
}
545616
}
546617

547618
private Map<String, Object> toArgs(ToolExecutionRequest toolExecutionRequest) {
548619
try {
549-
return objectMapper.readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE);
620+
return objectMapper().readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE);
550621
} catch (JsonProcessingException e) {
551622
throw new RuntimeException(e);
552623
}

0 commit comments

Comments
 (0)