Skip to content

Commit 87bacb5

Browse files
committed
New llm first interactive script generator
1 parent 8af5077 commit 87bacb5

1 file changed

Lines changed: 375 additions & 0 deletions

File tree

Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Interactive LangChain agent for libEnsemble script generation and execution.
4+
5+
The agent has tools and the chat IS the agent loop:
6+
- Each user message is a real user message in the conversation
7+
- The agent responds with tool calls and text
8+
- No fake 'ask_user' workaround
9+
10+
Without --interactive, it runs autonomously in a single invocation.
11+
12+
Requirements: pip install langchain langchain-openai mcp openai
13+
For options: python libe_agent_interactive.py -h
14+
"""
15+
16+
import os
17+
import sys
18+
import asyncio
19+
import re
20+
import subprocess
21+
import argparse
22+
import shutil
23+
from pathlib import Path
24+
from pydantic import BaseModel, Field
25+
from langchain_openai import ChatOpenAI
26+
from langchain.agents import create_agent
27+
from langchain_core.tools import StructuredTool
28+
from mcp import ClientSession, StdioServerParameters
29+
from mcp.client.stdio import stdio_client
30+
31+
32+
DEFAULT_MODEL = "gpt-4o-mini"
33+
MODEL = os.environ.get("LLM_MODEL", DEFAULT_MODEL)
34+
SHOW_PROMPTS = False
35+
36+
# Marker so the web UI knows the script is waiting for input
37+
INPUT_MARKER = "[INPUT_REQUESTED]"
38+
39+
DEFAULT_PROMPT = """Create six_hump_camel APOSMM scripts:
40+
- Executable: /home/shudson/test_mcp/script-creator/six_hump_camel/six_hump_camel.x
41+
- Input: /home/shudson/test_mcp/script-creator/six_hump_camel/input.txt
42+
- Template vars: X0, X1
43+
- 4 workers, 100 sims.
44+
- The output file for each simulation is output.txt
45+
- The bounds should be 0,1 and -1,2 for X0 and X1 respectively"""
46+
47+
SYSTEM_PROMPT = """You are a libEnsemble script assistant. You have tools to generate, read, write, run, and list scripts.
48+
49+
IMPORTANT RULES:
50+
- Only use CreateLibEnsembleScripts ONCE to generate initial scripts. NEVER call it again.
51+
- For ANY modifications the user requests, use read_file to see the current file, then write_file to save the edited version.
52+
- If the user asks to see something, use read_file and show them the content.
53+
- Don't run scripts unless the user explicitly asks you to run them.
54+
- When reviewing scripts, highlight key configuration: generator bounds/parameters and the objective function.
55+
- After running, if scripts fail, explain the error and offer to fix using write_file."""
56+
57+
ARCHIVE_ITEMS = [
58+
"ensemble", "ensemble.log", "libE_stats.txt",
59+
"*.npy", "*.pickle",
60+
]
61+
62+
# Global state
63+
mcp_session = None
64+
WORK_DIR = None
65+
ARCHIVE_COUNTER = 1
66+
CURRENT_ARCHIVE = None
67+
68+
69+
# ── Archiving ────────────────────────────────────────────────
70+
71+
def start_new_archive(action):
72+
global ARCHIVE_COUNTER, CURRENT_ARCHIVE
73+
CURRENT_ARCHIVE = f"{ARCHIVE_COUNTER}_{action}"
74+
(WORK_DIR / "versions" / CURRENT_ARCHIVE).mkdir(parents=True, exist_ok=True)
75+
ARCHIVE_COUNTER += 1
76+
77+
78+
def archive_current_scripts():
79+
if not CURRENT_ARCHIVE:
80+
return
81+
dest = WORK_DIR / "versions" / CURRENT_ARCHIVE
82+
for f in WORK_DIR.glob("*.py"):
83+
shutil.copy(f, dest / f.name)
84+
85+
86+
def archive_run_output(error_msg=""):
87+
if not CURRENT_ARCHIVE:
88+
return
89+
output_dir = WORK_DIR / "versions" / CURRENT_ARCHIVE / "output"
90+
output_dir.mkdir(parents=True, exist_ok=True)
91+
if error_msg:
92+
(output_dir / "error.txt").write_text(error_msg)
93+
for item in ARCHIVE_ITEMS:
94+
item_path = WORK_DIR / item
95+
if item_path.exists() and item_path.is_dir():
96+
shutil.copytree(str(item_path), str(output_dir / item), dirs_exist_ok=True)
97+
shutil.rmtree(str(item_path))
98+
else:
99+
for fp in WORK_DIR.glob(item):
100+
if fp.is_file():
101+
shutil.copy(str(fp), str(output_dir / fp.name))
102+
fp.unlink()
103+
104+
105+
# ── Tool schemas ─────────────────────────────────────────────
106+
107+
class RunScriptInput(BaseModel):
108+
script_name: str = Field(description="Name of the Python script to run")
109+
110+
class ReadFileInput(BaseModel):
111+
filepath: str = Field(description="Path to file relative to work directory")
112+
113+
class WriteFileInput(BaseModel):
114+
filepath: str = Field(description="Path to file relative to work directory")
115+
content: str = Field(description="Full content to write")
116+
117+
class ListFilesInput(BaseModel):
118+
pass
119+
120+
121+
# ── Tool implementations ────────────────────────────────────
122+
123+
async def run_script_tool(script_name: str) -> str:
124+
script_path = WORK_DIR / script_name
125+
if not script_path.exists():
126+
return f"ERROR: Script '{script_name}' not found"
127+
128+
print(f"\nRunning {script_name}...", flush=True)
129+
try:
130+
result = subprocess.run(
131+
["python", script_name], cwd=WORK_DIR,
132+
capture_output=True, text=True, timeout=300
133+
)
134+
if result.returncode == 0:
135+
print("✓ Script ran successfully", flush=True)
136+
return f"SUCCESS\nOutput:\n{result.stdout[:500]}"
137+
else:
138+
error_msg = f"Return code {result.returncode}\nStderr: {result.stderr}\nStdout: {result.stdout}"
139+
print(f"✗ Failed (code {result.returncode})", flush=True)
140+
archive_run_output(error_msg)
141+
return f"FAILED (code {result.returncode})\nStderr:\n{result.stderr}\nStdout:\n{result.stdout[:500]}"
142+
except subprocess.TimeoutExpired:
143+
return "ERROR: Script timed out (300s)"
144+
except Exception as e:
145+
return f"ERROR: {e}"
146+
147+
148+
async def read_file_tool(filepath: str) -> str:
149+
file_path = WORK_DIR / filepath
150+
if not file_path.exists():
151+
return f"ERROR: File '{filepath}' not found"
152+
return file_path.read_text()
153+
154+
155+
async def write_file_tool(filepath: str, content: str) -> str:
156+
try:
157+
(WORK_DIR / filepath).write_text(content)
158+
start_new_archive("fix")
159+
archive_current_scripts()
160+
print(f"- Saved: {WORK_DIR / filepath}", flush=True)
161+
return f"SUCCESS: Wrote {filepath}"
162+
except Exception as e:
163+
return f"ERROR: {e}"
164+
165+
166+
async def list_files_tool() -> str:
167+
py_files = list(WORK_DIR.glob("*.py"))
168+
if not py_files:
169+
return "No Python files found"
170+
return "Files:\n" + "\n".join(f"- {f.name}" for f in py_files)
171+
172+
173+
async def generate_scripts_mcp(**kwargs):
174+
"""Call MCP tool to generate scripts, auto-save to work dir"""
175+
if 'custom_set_objective' in kwargs:
176+
del kwargs['custom_set_objective']
177+
if 'set_objective_code' in kwargs:
178+
del kwargs['set_objective_code']
179+
180+
result = await mcp_session.call_tool("CreateLibEnsembleScripts", kwargs)
181+
scripts_text = result.content[0].text if result.content else ""
182+
183+
if scripts_text and "===" in scripts_text:
184+
WORK_DIR.mkdir(exist_ok=True)
185+
pattern = r"=== (.+?) ===\n(.*?)(?=\n===|$)"
186+
for filename, content in re.findall(pattern, scripts_text, re.DOTALL):
187+
(WORK_DIR / filename.strip()).write_text(content.strip() + "\n")
188+
print(f"- Saved: {WORK_DIR / filename.strip()}", flush=True)
189+
start_new_archive("generated")
190+
archive_current_scripts()
191+
192+
return scripts_text
193+
194+
195+
# ── MCP server discovery ────────────────────────────────────
196+
197+
def find_mcp_server(user_path=None):
198+
locations = []
199+
if user_path:
200+
locations.append(Path(user_path))
201+
env_path = os.environ.get('GENERATOR_MCP_SERVER')
202+
if env_path:
203+
locations.append(Path(env_path))
204+
locations.extend([
205+
Path(__file__).parent.parent / "mcp_server.mjs",
206+
Path.cwd() / "mcp_server.mjs"
207+
])
208+
for loc in locations:
209+
if loc.exists():
210+
return loc
211+
print("Error: Cannot find mcp_server.mjs")
212+
sys.exit(1)
213+
214+
215+
# ── Main ─────────────────────────────────────────────────────
216+
217+
async def main():
218+
global mcp_session, WORK_DIR, SHOW_PROMPTS
219+
220+
parser = argparse.ArgumentParser(
221+
description="Interactive agent for libEnsemble scripts",
222+
formatter_class=argparse.RawDescriptionHelpFormatter,
223+
epilog="""
224+
Examples:
225+
python libe_agent_interactive.py --interactive
226+
python libe_agent_interactive.py --interactive --scripts my_scripts/
227+
python libe_agent_interactive.py --prompt "Create APOSMM scripts..."
228+
"""
229+
)
230+
parser.add_argument("--interactive", action="store_true", help="Enable interactive chat mode")
231+
parser.add_argument("--scripts", help="Use existing scripts from directory")
232+
parser.add_argument("--prompt", help="Prompt for script generation")
233+
parser.add_argument("--prompt-file", help="Read prompt from file")
234+
parser.add_argument("--show-prompts", action="store_true")
235+
parser.add_argument("--mcp-server", help="Path to mcp_server.mjs")
236+
parser.add_argument("--generate-only", action="store_true")
237+
parser.add_argument("--max-iterations", type=int, default=15)
238+
args = parser.parse_args()
239+
240+
SHOW_PROMPTS = args.show_prompts
241+
interactive = args.interactive
242+
WORK_DIR = Path("generated_scripts")
243+
WORK_DIR.mkdir(exist_ok=True)
244+
245+
# Connect to MCP server
246+
mcp_server = find_mcp_server(args.mcp_server)
247+
print(f"Generator MCP: {mcp_server}")
248+
server_params = StdioServerParameters(command="node", args=[str(mcp_server)])
249+
250+
async with stdio_client(server_params) as (read, write):
251+
async with ClientSession(read, write) as session:
252+
await session.initialize()
253+
mcp_session = session
254+
print("✓ Connected to MCP server")
255+
256+
# Get MCP tool schema
257+
mcp_tools = await session.list_tools()
258+
mcp_tool = mcp_tools.tools[0]
259+
260+
# Build tools
261+
tools = [
262+
StructuredTool(
263+
name=mcp_tool.name, description=mcp_tool.description,
264+
args_schema=mcp_tool.inputSchema, coroutine=generate_scripts_mcp
265+
),
266+
StructuredTool(name="run_script", description="Run a Python script. Returns SUCCESS or FAILED with error details.", args_schema=RunScriptInput, coroutine=run_script_tool),
267+
StructuredTool(name="read_file", description="Read a file to inspect its contents.", args_schema=ReadFileInput, coroutine=read_file_tool),
268+
StructuredTool(name="write_file", description="Write/overwrite a file to fix scripts.", args_schema=WriteFileInput, coroutine=write_file_tool),
269+
StructuredTool(name="list_files", description="List Python files in working directory.", args_schema=ListFilesInput, coroutine=list_files_tool),
270+
]
271+
272+
llm = ChatOpenAI(model=MODEL, temperature=0, base_url=os.environ.get("OPENAI_BASE_URL"))
273+
agent = create_agent(llm, tools)
274+
print("✓ Agent initialized\n")
275+
276+
# Build initial message
277+
messages = [("system", SYSTEM_PROMPT)]
278+
279+
if args.scripts:
280+
# Load existing scripts
281+
scripts_dir = Path(args.scripts)
282+
for f in sorted(scripts_dir.glob("*.py")):
283+
shutil.copy(f, WORK_DIR)
284+
print(f"Copied: {f.name}")
285+
start_new_archive("copied_scripts")
286+
archive_current_scripts()
287+
288+
run_scripts = list(WORK_DIR.glob("run_*.py"))
289+
run_name = run_scripts[0].name if run_scripts else "run_libe.py"
290+
initial_msg = f"I have libEnsemble scripts. The main script is '{run_name}'. Please review them and highlight the key configuration."
291+
292+
elif args.prompt:
293+
initial_msg = args.prompt
294+
elif args.prompt_file:
295+
initial_msg = Path(args.prompt_file).read_text()
296+
elif interactive:
297+
print("Describe the scripts you want to generate (or press Enter for default demo):", flush=True)
298+
print(INPUT_MARKER, flush=True)
299+
user_input = input().strip()
300+
initial_msg = user_input if user_input else DEFAULT_PROMPT
301+
if not user_input:
302+
print("Using default demo prompt")
303+
else:
304+
initial_msg = DEFAULT_PROMPT
305+
306+
if not interactive:
307+
# Autonomous mode: single invocation, agent does everything
308+
goal = f"""{initial_msg}
309+
310+
After generating/loading scripts: review them, run them, fix errors and retry (max 3 attempts). Report the result."""
311+
messages.append(("user", goal))
312+
313+
if SHOW_PROMPTS:
314+
print(f"Goal: {goal}\n")
315+
print("Starting agent...\n")
316+
317+
result = await agent.ainvoke({"messages": messages})
318+
print(f"\n{'='*60}")
319+
print("✓ Agent completed")
320+
print(f"{'='*60}")
321+
print(result["messages"][-1].content)
322+
323+
else:
324+
# Interactive mode: chat loop with automatic refine cycle
325+
goal = f"""User request: {initial_msg}
326+
327+
Instructions:
328+
1. Use CreateLibEnsembleScripts to generate the initial scripts.
329+
2. Read each generated script using read_file.
330+
3. Check the scripts match the user's request (bounds, sims, paths, parameters, etc).
331+
4. If anything doesn't match, fix it using write_file. Common issues: wrong bounds, wrong sim count, missing paths.
332+
5. Present a concise summary of the scripts and what you fixed (if anything).
333+
6. Then wait for the user's feedback."""
334+
messages.append(("user", goal))
335+
print("Starting agent...\n")
336+
337+
while True:
338+
try:
339+
# Agent turn
340+
result = await agent.ainvoke({"messages": messages})
341+
messages = result["messages"]
342+
343+
# Print agent's response
344+
response = messages[-1].content
345+
if response:
346+
print(f"\n{response}", flush=True)
347+
except Exception as e:
348+
print(f"\n⚠️ Agent error: {e}", flush=True)
349+
350+
# Wait for user input
351+
print(INPUT_MARKER, flush=True)
352+
user_input = input().strip()
353+
354+
if not user_input or user_input.lower() in ('quit', 'exit', 'done'):
355+
print("\n✓ Session ended")
356+
break
357+
358+
# Add as a proper HumanMessage to match LangGraph's message format
359+
from langchain_core.messages import HumanMessage, SystemMessage
360+
# Remind the model to respond to the user, not continue previous task
361+
messages.append(SystemMessage(content="STOP. Read the user's next message carefully and respond to exactly what they ask. Do not continue previous tasks."))
362+
messages.append(HumanMessage(content=user_input))
363+
364+
365+
if __name__ == "__main__":
366+
try:
367+
asyncio.run(main())
368+
except KeyboardInterrupt:
369+
print("\n\nInterrupted by user")
370+
sys.exit(0)
371+
except Exception as e:
372+
print(f"\n✗ Error: {e}")
373+
import traceback
374+
traceback.print_exc()
375+
sys.exit(1)

0 commit comments

Comments
 (0)