Skip to content

Commit b8f7a4b

Browse files
authored
feat: add langchain basic stream and adkwebserver template (#42)
2 parents 49cc7f3 + 7c69110 commit b8f7a4b

7 files changed

Lines changed: 440 additions & 83 deletions

File tree

agentkit/toolkit/cli/cli_invoke.py

Lines changed: 254 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import json
2020
import typer
2121
from rich.console import Console
22+
import time
2223
import random
2324
import uuid
2425
from agentkit.toolkit.config import get_config
@@ -29,6 +30,136 @@
2930
console = Console()
3031

3132

33+
def _extract_text_chunks_from_langchain_event(event: dict) -> list[str]:
34+
"""Extract incremental text chunks from LangChain message_to_dict-style events.
35+
36+
Expected shape (example):
37+
{"type": "AIMessageChunk", "data": {"content": "今天", ...}}
38+
"""
39+
if not isinstance(event, dict):
40+
return []
41+
42+
event_type = event.get("type")
43+
data = event.get("data")
44+
if not isinstance(event_type, str) or not isinstance(data, dict):
45+
return []
46+
47+
# Most common streaming types: AIMessageChunk / HumanMessageChunk / ToolMessageChunk
48+
if not (
49+
event_type.endswith("MessageChunk")
50+
or event_type in {"AIMessage", "HumanMessage", "ToolMessage"}
51+
):
52+
return []
53+
54+
content = data.get("content")
55+
if content is None:
56+
return []
57+
58+
# content can be a string, or a multimodal list like:
59+
# [{"type":"text","text":"..."}, ...]
60+
if isinstance(content, str):
61+
return [content] if content else []
62+
if isinstance(content, list):
63+
chunks: list[str] = []
64+
for item in content:
65+
if isinstance(item, str) and item:
66+
chunks.append(item)
67+
elif isinstance(item, dict):
68+
text = item.get("text")
69+
if isinstance(text, str) and text:
70+
chunks.append(text)
71+
return chunks
72+
73+
return []
74+
75+
76+
def _extract_reasoning_chunks_from_langchain_event(event: dict) -> list[str]:
77+
"""Extract incremental reasoning chunks from LangChain events.
78+
79+
LangChain emit reasoning in:
80+
event['data']['additional_kwargs']['reasoning_content']
81+
while leaving event['data']['content'] empty.
82+
"""
83+
if not isinstance(event, dict):
84+
return []
85+
86+
event_type = event.get("type")
87+
data = event.get("data")
88+
if not isinstance(event_type, str) or not isinstance(data, dict):
89+
return []
90+
91+
if not (
92+
event_type.endswith("MessageChunk")
93+
or event_type in {"AIMessage", "HumanMessage", "ToolMessage"}
94+
):
95+
return []
96+
97+
additional_kwargs = data.get("additional_kwargs")
98+
if not isinstance(additional_kwargs, dict):
99+
return []
100+
101+
reasoning = additional_kwargs.get("reasoning_content")
102+
if isinstance(reasoning, str):
103+
return [reasoning] if reasoning else []
104+
return []
105+
106+
107+
def _extract_text_chunks_from_adk_event(event: dict) -> list[str]:
108+
"""Extract incremental text chunks from Google ADK/AgentKit streaming events."""
109+
if not isinstance(event, dict):
110+
return []
111+
112+
parts: list[Any] = []
113+
if isinstance(event.get("parts"), list):
114+
parts = event.get("parts", [])
115+
elif isinstance(event.get("message"), dict):
116+
parts = event["message"].get("parts", [])
117+
elif isinstance(event.get("content"), dict):
118+
parts = event["content"].get("parts", [])
119+
elif isinstance(event.get("status"), dict):
120+
role = event["status"].get("message", {}).get("role")
121+
if role == "agent":
122+
parts = event["status"].get("message", {}).get("parts", [])
123+
124+
if not isinstance(parts, list) or not parts:
125+
return []
126+
127+
chunks: list[str] = []
128+
for part in parts:
129+
text: Optional[str] = None
130+
if isinstance(part, dict) and "text" in part:
131+
val = part.get("text")
132+
text = val if isinstance(val, str) else None
133+
elif isinstance(part, str):
134+
text = part
135+
if text:
136+
chunks.append(text)
137+
return chunks
138+
139+
140+
def _normalize_stream_event(event: Any) -> Optional[dict]:
141+
"""Normalize an event yielded by InvokeResult.stream() to a dict.
142+
143+
- Runner normally yields dict (already JSON-decoded).
144+
- CLI keeps a fallback path for raw SSE strings ("data: {...}").
145+
"""
146+
if isinstance(event, dict):
147+
return event
148+
if isinstance(event, str):
149+
s = event.strip()
150+
if not s.startswith("data: "):
151+
return None
152+
json_str = s[6:].strip()
153+
if not json_str:
154+
return None
155+
try:
156+
parsed = json.loads(json_str)
157+
return parsed if isinstance(parsed, dict) else None
158+
except json.JSONDecodeError:
159+
return None
160+
return None
161+
162+
32163
def build_standard_payload(message: Optional[str], payload: Optional[str]) -> dict:
33164
if message:
34165
return {"prompt": message}
@@ -88,6 +219,16 @@ def invoke_command(
88219
headers: str = typer.Option(
89220
None, "--headers", "-h", help="JSON headers for request (advanced option)"
90221
),
222+
show_reasoning: bool = typer.Option(
223+
False,
224+
"--show-reasoning",
225+
help="Print LangChain reasoning_content (if present) during streaming",
226+
),
227+
raw: bool = typer.Option(
228+
False,
229+
"--raw",
230+
help="Print raw streaming events (and raw JSON response) for debugging",
231+
),
91232
apikey: str = typer.Option(
92233
None, "--apikey", "-ak", help="API key for authentication"
93234
),
@@ -115,29 +256,37 @@ def invoke_command(
115256
"[red]Error: Cannot specify both message and payload. Use either message or --payload.[/red]"
116257
)
117258
raise typer.Exit(1)
118-
119259
# Validate parameters: must provide either message or payload
120260
if not message and not payload:
121261
console.print(
122262
"[red]Error: Must provide either a message or --payload option.[/red]"
123263
)
124264
raise typer.Exit(1)
125-
126265
config = get_config(config_path=config_file)
127266
common_config = config.get_common_config()
128267

129268
# Process headers
130-
final_headers = {
269+
default_headers = {
131270
"user_id": "agentkit_user",
132271
"session_id": "agentkit_sample_session",
133272
}
273+
final_headers = default_headers.copy()
274+
134275
if headers:
135276
try:
136-
final_headers = json.loads(headers) if isinstance(headers, str) else headers
137-
console.print(f"[blue]Using custom headers: {final_headers}[/blue]")
277+
custom_headers = (
278+
json.loads(headers) if isinstance(headers, str) else headers
279+
)
138280
except json.JSONDecodeError as e:
139281
console.print(f"[red]Error: Invalid JSON headers: {e}[/red]")
140282
raise typer.Exit(1)
283+
if not isinstance(custom_headers, dict):
284+
console.print(
285+
'[red]Error: --headers must be a JSON object (e.g. \'{"user_id": "u1"}\').[/red]'
286+
)
287+
raise typer.Exit(1)
288+
final_headers.update(custom_headers)
289+
console.print(f"[blue]Using merged headers: {final_headers}[/blue]")
141290
else:
142291
console.print(f"[blue]Using default headers: {final_headers}[/blue]")
143292

@@ -154,7 +303,9 @@ def invoke_command(
154303
)
155304
final_payload = build_a2a_payload(message, payload, final_headers)
156305

157-
# Set execution context - CLI uses ConsoleReporter (with colored output and progress)
306+
if apikey:
307+
final_headers["Authorization"] = f"Bearer {apikey}"
308+
158309
from agentkit.toolkit.context import ExecutionContext
159310

160311
reporter = ConsoleReporter()
@@ -171,7 +322,6 @@ def invoke_command(
171322
if not result.success:
172323
console.print(f"[red]❌ Invocation failed: {result.error}[/red]")
173324
raise typer.Exit(1)
174-
175325
console.print("[green]✅ Invocation successful[/green]")
176326

177327
# Get response
@@ -180,69 +330,106 @@ def invoke_command(
180330
# Handle streaming response (generator)
181331
if result.is_streaming:
182332
console.print("[cyan]📡 Streaming response detected...[/cyan]\n")
333+
if raw:
334+
console.print(
335+
"[yellow]Raw mode enabled: printing raw stream events[/yellow]\n"
336+
)
183337
result_list = []
184338
complete_text = []
339+
printed_reasoning_header = False
340+
printed_answer_header = False
341+
printed_hidden_reasoning_hint = False
342+
printed_heartbeat = False
343+
last_heartbeat_ts = time.monotonic()
185344

186345
for event in result.stream():
187346
result_list.append(event)
188347

189-
# If it's a string starting with "data: ", try to parse (fallback handling)
190-
if isinstance(event, str):
191-
if event.strip().startswith("data: "):
192-
try:
193-
json_str = event.strip()[6:].strip() # Remove "data: " prefix
194-
event = json.loads(json_str)
195-
except json.JSONDecodeError:
196-
# Parsing failed, skip this event
197-
continue
348+
if raw:
349+
# Print the event as received (before normalization), to help debugging.
350+
if isinstance(event, dict):
351+
console.print(json.dumps(event, ensure_ascii=False))
352+
elif isinstance(event, str):
353+
console.print(event.rstrip("\n"))
198354
else:
199-
# Not SSE format string, skip
200-
continue
201-
202-
# Handle A2A JSON-RPC
203-
if isinstance(event, dict) and event.get("jsonrpc") and "result" in event:
204-
event = event["result"]
205-
206-
if isinstance(event, dict):
207-
parts = []
208-
if isinstance(event.get("parts"), list):
209-
parts = event.get("parts", [])
210-
elif isinstance(event.get("message"), dict):
211-
parts = event["message"].get("parts", [])
212-
elif isinstance(event.get("content"), dict):
213-
parts = event["content"].get("parts", [])
214-
elif isinstance(event.get("status"), dict):
215-
role = event["status"].get("message", {}).get("role")
216-
if role == "agent":
217-
parts = event["status"].get("message", {}).get("parts", [])
218-
if not event.get("partial", True):
219-
logger.info("Partial event: %s", event) # Log partial events
220-
continue
221-
222-
if parts:
223-
for p in parts:
224-
text = None
225-
if isinstance(p, dict) and "text" in p:
226-
text = p["text"]
227-
elif isinstance(p, str):
228-
text = p
229-
if text:
230-
complete_text.append(text)
231-
# Incremental print (keep no newline)
232-
console.print(text, end="", style="green")
233-
234-
# Display error information in event (if any)
235-
if "error" in event:
236-
console.print(f"\n[red]Error: {event['error']}[/red]")
237-
238-
# Handle status updates (e.g., final flag or completed status)
239-
if event.get("final") is True:
240-
break
241-
242-
status = event.get("status")
243-
if isinstance(status, dict) and status.get("state") == "completed":
244-
console.print("\n[cyan]Status indicates completed[/cyan]")
245-
break
355+
console.print(repr(event))
356+
357+
normalized = _normalize_stream_event(event)
358+
if normalized is None:
359+
continue
360+
361+
# Handle A2A JSON-RPC wrapper (unwrap to the underlying result payload)
362+
if normalized.get("jsonrpc") and "result" in normalized:
363+
result_payload = normalized.get("result")
364+
normalized = result_payload if isinstance(result_payload, dict) else {}
365+
366+
# Keep existing partial-event behavior for ADK style streams.
367+
# (LangChain message events typically don't carry this field.)
368+
if not normalized.get("partial", True):
369+
logger.info("Partial event: %s", normalized)
370+
continue
371+
372+
# In raw mode, we still keep termination/error handling, but skip
373+
# extracted text printing to avoid mixing structured debug output.
374+
if not raw:
375+
# LangChain: reasoning_content
376+
reasoning_chunks = _extract_reasoning_chunks_from_langchain_event(
377+
normalized
378+
)
379+
if reasoning_chunks:
380+
if show_reasoning:
381+
if not printed_reasoning_header:
382+
console.print("[cyan]🧠 Reasoning:[/cyan]")
383+
printed_reasoning_header = True
384+
for text in reasoning_chunks:
385+
console.print(text, end="", style="yellow")
386+
else:
387+
# Default behavior: do not print reasoning, but keep the CLI responsive
388+
# with a one-time hint and a periodic heartbeat.
389+
if not printed_hidden_reasoning_hint:
390+
console.print(
391+
"[cyan]🤔 Model is thinking... (use --show-reasoning to view)[/cyan]"
392+
)
393+
printed_hidden_reasoning_hint = True
394+
now = time.monotonic()
395+
if now - last_heartbeat_ts >= 1.5:
396+
console.print(".", end="", style="cyan")
397+
printed_heartbeat = True
398+
last_heartbeat_ts = now
399+
400+
# Extract and print incremental answer text chunks
401+
text_chunks: list[str] = []
402+
text_chunks.extend(
403+
_extract_text_chunks_from_langchain_event(normalized)
404+
)
405+
if not text_chunks:
406+
text_chunks.extend(_extract_text_chunks_from_adk_event(normalized))
407+
408+
if text_chunks:
409+
# If we printed a hidden reasoning hint / heartbeat dots, separate answer on a new line.
410+
if printed_hidden_reasoning_hint or printed_heartbeat:
411+
console.print("")
412+
printed_hidden_reasoning_hint = False
413+
printed_heartbeat = False
414+
if printed_reasoning_header and not printed_answer_header:
415+
console.print("\n[cyan]📝 Answer:[/cyan]")
416+
printed_answer_header = True
417+
for text in text_chunks:
418+
complete_text.append(text)
419+
console.print(text, end="", style="green")
420+
421+
# Display error information in event (if any)
422+
if "error" in normalized:
423+
console.print(f"\n[red]Error: {normalized['error']}[/red]")
424+
425+
# Handle status updates (e.g., final flag or completed status)
426+
if normalized.get("final") is True:
427+
break
428+
429+
status = normalized.get("status")
430+
if isinstance(status, dict) and status.get("state") == "completed":
431+
console.print("\n[cyan]Status indicates completed[/cyan]")
432+
break
246433

247434
# Display complete response (commented out for now)
248435
# if complete_text:
@@ -255,7 +442,10 @@ def invoke_command(
255442
# Handle non-streaming response
256443
console.print("[cyan]📝 Response:[/cyan]")
257444
if isinstance(response, dict):
258-
console.print(json.dumps(response, indent=2, ensure_ascii=False))
445+
if raw:
446+
console.print(json.dumps(response, ensure_ascii=False))
447+
else:
448+
console.print(json.dumps(response, indent=2, ensure_ascii=False))
259449
else:
260450
console.print(response)
261451

0 commit comments

Comments
 (0)