2323import com .google .adk .models .BaseLlmConnection ;
2424import com .google .adk .models .LlmRequest ;
2525import com .google .adk .models .LlmResponse ;
26+ import com .google .auto .value .AutoValue ;
2627import com .google .genai .types .Blob ;
2728import com .google .genai .types .Content ;
2829import com .google .genai .types .FunctionCall ;
2930import com .google .genai .types .FunctionCallingConfigMode ;
3031import com .google .genai .types .FunctionDeclaration ;
3132import com .google .genai .types .FunctionResponse ;
3233import com .google .genai .types .GenerateContentConfig ;
34+ import com .google .genai .types .GenerateContentResponseUsageMetadata ;
3335import com .google .genai .types .Part ;
3436import com .google .genai .types .Schema ;
3537import com .google .genai .types .ToolConfig ;
3638import com .google .genai .types .Type ;
37- import dev .langchain4j .Experimental ;
3839import dev .langchain4j .agent .tool .ToolExecutionRequest ;
3940import dev .langchain4j .agent .tool .ToolSpecification ;
4041import dev .langchain4j .data .audio .Audio ;
5253import dev .langchain4j .data .pdf .PdfFile ;
5354import dev .langchain4j .data .video .Video ;
5455import dev .langchain4j .exception .UnsupportedFeatureException ;
56+ import dev .langchain4j .model .TokenCountEstimator ;
5557import dev .langchain4j .model .chat .ChatModel ;
5658import dev .langchain4j .model .chat .StreamingChatModel ;
5759import dev .langchain4j .model .chat .request .ChatRequest ;
6567import dev .langchain4j .model .chat .request .json .JsonStringSchema ;
6668import dev .langchain4j .model .chat .response .ChatResponse ;
6769import dev .langchain4j .model .chat .response .StreamingChatResponseHandler ;
70+ import dev .langchain4j .model .output .TokenUsage ;
6871import io .reactivex .rxjava3 .core .BackpressureStrategy ;
6972import io .reactivex .rxjava3 .core .Flowable ;
7073import java .util .ArrayList ;
7174import java .util .Base64 ;
7275import java .util .HashMap ;
7376import java .util .List ;
7477import java .util .Map ;
75- import java .util .Objects ;
7678import java .util .UUID ;
79+ import org .jspecify .annotations .Nullable ;
7780
/**
 * ADK {@link BaseLlm} adapter backed by LangChain4j models.
 *
 * <p>Wraps a LangChain4j {@link ChatModel} (blocking) and/or {@link StreamingChatModel}
 * (streaming) so they can be used through the ADK model interface. Instances are created via
 * {@link #builder()}; state lives in the {@code @AutoValue}-generated subclass.
 */
@AutoValue
public abstract class LangChain4j extends BaseLlm {

  // Jackson target type used to deserialize tool-call argument JSON into a Map.
  private static final TypeReference<Map<String, Object>> MAP_TYPE_REFERENCE =
      new TypeReference<>() {};

  // Package-private no-arg constructor invoked by the generated AutoValue subclass.
  // BaseLlm's constructor requires a name, but the real model name is surfaced through
  // modelName()/model(), so an empty placeholder is passed here.
  LangChain4j() {
    super("");
  }
90+
  /** Blocking chat model; null when only a streaming model was configured. */
  @Nullable
  public abstract ChatModel chatModel();

  /** Streaming chat model; null when only a blocking model was configured. */
  @Nullable
  public abstract StreamingChatModel streamingChatModel();

  /** Jackson mapper used to serialize/deserialize tool-call arguments. */
  public abstract ObjectMapper objectMapper();

  /** Model name reported to the ADK framework. */
  public abstract String modelName();

  /** Optional estimator used to compute usage metadata; null when not configured. */
  @Nullable
  public abstract TokenCountEstimator tokenCountEstimator();

  /** Returns the model name; overrides BaseLlm, which otherwise reports the placeholder "". */
  @Override
  public String model() {
    return modelName();
  }

  /** Creates a builder preconfigured with a default {@link ObjectMapper}. */
  public static Builder builder() {
    return new AutoValue_LangChain4j.Builder().objectMapper(new ObjectMapper());
  }
112+
  /**
   * Builder for {@link LangChain4j}. At least one of {@code chatModel} or
   * {@code streamingChatModel} should be set, along with {@code modelName}; {@code objectMapper}
   * is defaulted by {@link LangChain4j#builder()}.
   */
  @AutoValue.Builder
  public abstract static class Builder {
    public abstract Builder chatModel(ChatModel chatModel);

    public abstract Builder streamingChatModel(StreamingChatModel streamingChatModel);

    public abstract Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator);

    public abstract Builder objectMapper(ObjectMapper objectMapper);

    public abstract Builder modelName(String modelName);

    public abstract LangChain4j build();
  }
87127
  /**
   * Creates a wrapper around a blocking {@link ChatModel}, taking the model name from the model's
   * default request parameters.
   *
   * <p>NOTE(review): since this class became an {@code @AutoValue} abstract class, these public
   * constructors can no longer be invoked via {@code new LangChain4j(...)} by external callers
   * (the class is abstract), and the private constructor they delegate to cannot persist the
   * arguments — prefer {@link #builder()}.
   */
  public LangChain4j(ChatModel chatModel) {
    this(chatModel, null, null, chatModel.defaultRequestParameters().modelName(), null);
  }

  /** Creates a wrapper around a blocking {@link ChatModel} with an explicit model name. */
  public LangChain4j(ChatModel chatModel, String modelName) {
    this(chatModel, null, null, modelName, null);
  }

  /**
   * Creates a wrapper around a {@link StreamingChatModel}, taking the model name from the model's
   * default request parameters.
   */
  public LangChain4j(StreamingChatModel streamingChatModel) {
    this(
        null,
        streamingChatModel,
        null,
        streamingChatModel.defaultRequestParameters().modelName(),
        null);
  }

  /** Creates a wrapper around a {@link StreamingChatModel} with an explicit model name. */
  public LangChain4j(StreamingChatModel streamingChatModel, String modelName) {
    this(null, streamingChatModel, null, modelName, null);
  }

  /** Creates a wrapper around both a blocking and a streaming model with an explicit name. */
  public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) {
    this(chatModel, streamingChatModel, null, modelName, null);
  }
152+
153+ private LangChain4j (
154+ ChatModel chatModel ,
155+ StreamingChatModel streamingChatModel ,
156+ ObjectMapper objectMapper ,
157+ String modelName ,
158+ TokenCountEstimator tokenCountEstimator ) {
159+ this ();
160+ LangChain4j .builder ()
161+ .chatModel (chatModel )
162+ .streamingChatModel (streamingChatModel )
163+ .objectMapper (objectMapper )
164+ .modelName (modelName )
165+ .tokenCountEstimator (tokenCountEstimator )
166+ .build ();
129167 }
130168
  /**
   * Generates content for the given request.
   *
   * @param llmRequest the ADK request to translate into a LangChain4j {@link ChatRequest}
   * @param stream if true, uses the configured {@link StreamingChatModel} and emits partial text
   *     chunks followed by any tool-call parts; if false, performs one blocking call on the
   *     configured {@link ChatModel}
   * @return a {@link Flowable} of responses; errors with {@link IllegalStateException} when the
   *     required model for the chosen mode is not configured
   */
  @Override
  public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) {
    if (stream) {
      if (this.streamingChatModel() == null) {
        return Flowable.error(new IllegalStateException("StreamingChatModel is not configured"));
      }

      ChatRequest chatRequest = toChatRequest(llmRequest);

      // BUFFER: partial chunks may arrive faster than downstream consumes them.
      return Flowable.create(
          emitter -> {
            streamingChatModel()
                .chat(
                    chatRequest,
                    new StreamingChatResponseHandler() {
                      @Override
                      public void onPartialResponse(String s) {
                        // Each partial text chunk is surfaced as its own LlmResponse.
                        emitter.onNext(
                            LlmResponse.builder()
                                .content(Content.fromParts(Part.fromText(s)))
                                .build());
                      }

                      @Override
                      public void onCompleteResponse(ChatResponse chatResponse) {
                        // Text was already streamed via onPartialResponse; only tool-call
                        // requests from the final message still need to be emitted here.
                        if (chatResponse.aiMessage().hasToolExecutionRequests()) {
                          AiMessage aiMessage = chatResponse.aiMessage();
                          toParts(aiMessage).stream()
                              .map(Part::functionCall)
                              .forEach(
                                  functionCall -> {
                                    functionCall.ifPresent(
                                        function -> {
                                          emitter.onNext(
                                              LlmResponse.builder()
                                                  .content(
                                                      Content.fromParts(
                                                          Part.fromFunctionCall(
                                                              function.name().orElse(""),
                                                              function.args().orElse(Map.of()))))
                                                  .build());
                                        });
                                  });
                        }
                        emitter.onComplete();
                      }

                      @Override
                      public void onError(Throwable throwable) {
                        emitter.onError(throwable);
                      }
                    });
          },
          BackpressureStrategy.BUFFER);
    } else {
      if (this.chatModel() == null) {
        return Flowable.error(new IllegalStateException("ChatModel is not configured"));
      }

      ChatRequest chatRequest = toChatRequest(llmRequest);
      ChatResponse chatResponse = chatModel().chat(chatRequest);
      // chatRequest is passed along so toLlmResponse can estimate input token usage.
      LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest);

      return Flowable.just(llmResponse);
    }
@@ -413,7 +454,7 @@ private AiMessage toAiMessage(Content content) {
413454
  /**
   * Serializes {@code object} to a JSON string using the configured {@link ObjectMapper},
   * wrapping any {@link JsonProcessingException} in an unchecked {@link RuntimeException}.
   */
  private String toJson(Object object) {
    try {
      return objectMapper().writeValueAsString(object);
    } catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
@@ -511,11 +552,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) {
511552 }
512553 }
513554
514- private LlmResponse toLlmResponse (ChatResponse chatResponse ) {
555+ private LlmResponse toLlmResponse (ChatResponse chatResponse , ChatRequest chatRequest ) {
515556 Content content =
516557 Content .builder ().role ("model" ).parts (toParts (chatResponse .aiMessage ())).build ();
517558
518- return LlmResponse .builder ().content (content ).build ();
559+ LlmResponse .Builder builder = LlmResponse .builder ().content (content );
560+ TokenUsage tokenUsage = chatResponse .tokenUsage ();
561+ if (tokenCountEstimator () != null ) {
562+ try {
563+ int estimatedInput =
564+ tokenCountEstimator ().estimateTokenCountInMessages (chatRequest .messages ());
565+ int estimatedOutput =
566+ tokenCountEstimator ().estimateTokenCountInText (chatResponse .aiMessage ().text ());
567+ int estimatedTotal = estimatedInput + estimatedOutput ;
568+ builder .usageMetadata (
569+ GenerateContentResponseUsageMetadata .builder ()
570+ .promptTokenCount (estimatedInput )
571+ .candidatesTokenCount (estimatedOutput )
572+ .totalTokenCount (estimatedTotal )
573+ .build ());
574+ } catch (Exception e ) {
575+ e .printStackTrace ();
576+ }
577+ } else if (tokenUsage != null ) {
578+ builder .usageMetadata (
579+ GenerateContentResponseUsageMetadata .builder ()
580+ .promptTokenCount (tokenUsage .inputTokenCount ())
581+ .candidatesTokenCount (tokenUsage .outputTokenCount ())
582+ .totalTokenCount (tokenUsage .totalTokenCount ())
583+ .build ());
584+ }
585+
586+ return builder .build ();
519587 }
520588
521589 private List <Part > toParts (AiMessage aiMessage ) {
@@ -539,14 +607,17 @@ private List<Part> toParts(AiMessage aiMessage) {
539607 });
540608 return parts ;
541609 } else {
542- Part part = Part .builder ().text (aiMessage .text ()).build ();
543- return List .of (part );
610+ String text = aiMessage .text ();
611+ if (text == null ) {
612+ return List .of ();
613+ }
614+ return List .of (Part .builder ().text (text ).build ());
544615 }
545616 }
546617
  /**
   * Deserializes the JSON argument string of a {@link ToolExecutionRequest} into a map, wrapping
   * any {@link JsonProcessingException} in an unchecked {@link RuntimeException}.
   */
  private Map<String, Object> toArgs(ToolExecutionRequest toolExecutionRequest) {
    try {
      return objectMapper().readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE);
    } catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
0 commit comments