|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Interactive Agent Tester (Docker **or** Singularity backend) |
| 4 | +========================================================== |
| 5 | +A unified interactive tester that can drive either the **Docker sandbox** (`benchmarking_sandbox_management.py`) |
| 6 | +or the **Apptainer/Singularity sandbox** (`benchmarking_sandbox_management_singularity.py`). |
| 7 | +
|
| 8 | +At launch you choose a backend: |
| 9 | + • *docker* – requires Docker daemon on this machine. |
| 10 | + • *singularity* – requires `apptainer`/`singularity`; no Docker needed. |
| 11 | +
|
| 12 | +The rest of the behaviour (multi‑turn GPT orchestration, FastAPI kernel execution, |
| 13 | +resource upload, unlimited chat loop) is unchanged. |
| 14 | +""" |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import argparse |
| 18 | +import base64 |
| 19 | +import json |
| 20 | +import os |
| 21 | +import re |
| 22 | +import shlex |
| 23 | +import subprocess |
| 24 | +import sys |
| 25 | +import textwrap |
| 26 | +import time |
| 27 | +from datetime import datetime |
| 28 | +from pathlib import Path |
| 29 | +from typing import List, Tuple |
| 30 | + |
| 31 | +# ── Third‑party deps ───────────────────────────────────────────────────────── |
| 32 | +try: |
| 33 | + from dotenv import load_dotenv |
| 34 | + from openai import OpenAI, APIError |
| 35 | + import requests |
| 36 | + from rich.console import Console |
| 37 | + from rich.panel import Panel |
| 38 | + from rich.prompt import Prompt |
| 39 | + from rich.syntax import Syntax |
| 40 | + from rich.table import Table |
| 41 | +except ImportError as e: |
| 42 | + print(f"Missing dependency: {e}. Install required packages.", file=sys.stderr) |
| 43 | + sys.exit(1) |
| 44 | + |
| 45 | +console = Console() |
| 46 | + |
# ── Runtime‑backend selection (ask the user **before** importing managers) ──
# Import-time side effect: the operator is prompted for a backend here so that
# only the matching sandbox-manager module is imported below.
backend = Prompt.ask("Choose sandbox backend", choices=["docker", "singularity"], default="docker")

SCRIPT_DIR = Path(__file__).resolve().parent

if backend == "docker":
    # Put ./sandbox on sys.path just long enough to import the Docker
    # manager, then restore sys.path in the finally block.
    sandbox_dir = SCRIPT_DIR / "sandbox"
    sys.path.insert(0, str(sandbox_dir))
    try:
        from benchmarking_sandbox_management import (
            SandboxManager as _BackendManager,
            CONTAINER_NAME as _SANDBOX_HANDLE,
            API_PORT_HOST as _API_PORT,
        )
    finally:
        sys.path.pop(0)
    # Host→container file copy delegates to `docker cp`; raises on failure.
    COPY_CMD = lambda src, dst: subprocess.run(["docker", "cp", src, dst], check=True)

elif backend == "singularity":
    sandbox_dir = SCRIPT_DIR / "sandbox"
    # NOTE(review): unlike the docker branch, sandbox_dir is never popped off
    # sys.path here — confirm whether that is intentional.
    sys.path.insert(0, str(sandbox_dir))
    try:
        import benchmarking_sandbox_management_singularity as sing
    except ImportError as e:
        console.print(f"[red]Failed to import Singularity manager: {e}[/red]")
        sys.exit(1)

    class _SingWrapper:  # thin adapter to mimic the Docker SandboxManager API
        def __init__(self):
            pass
        def start_container(self):
            # Delegates to the module-level instance starter.
            return sing.start_instance()
        def stop_container(self, remove: bool = True, container_obj=None):
            # `remove` / `container_obj` exist only for Docker signature
            # compatibility; the Singularity backend ignores them.
            return sing.stop_instance()
    _BackendManager = _SingWrapper
    _SANDBOX_HANDLE = sing.INSTANCE_NAME
    _API_PORT = sing.API_PORT_HOST

    # Apptainer/Singularity lacks a simple cp, so we issue a warning and rely on bind‑mounts
    def COPY_CMD(src, dst):  # noqa: N802
        console.print(f"[yellow]File copy inside Singularity instance not automated.\n"
                      f"Ensure the file {src} is reachable at {dst} via bind mount or in the definition file.[/yellow]")

else:
    # Unreachable in practice (Prompt.ask restricts choices) — kept as a guard.
    console.print("[red]Unknown backend choice.[/red]")
    sys.exit(1)
| 93 | + |
# ── Constants (after backend choice) ─────────────────────────────────────────
DATASETS_DIR = SCRIPT_DIR / "datasets"    # host dir of *.h5ad files + sidecar *.json metadata
OUTPUTS_DIR = SCRIPT_DIR / "outputs"      # host dir where decoded kernel images are saved
ENV_FILE = SCRIPT_DIR / ".env"            # expected to define OPENAI_API_KEY
SANDBOX_DATA_PATH = "/home/sandboxuser/data.h5ad"      # dataset path inside the sandbox
SANDBOX_RESOURCES_DIR = "/home/sandboxuser/resources"  # upload target inside the sandbox
API_BASE_URL = f"http://localhost:{_API_PORT}"         # kernel API root; port comes from the chosen backend
EXECUTE_ENDPOINT = f"{API_BASE_URL}/execute"
STATUS_ENDPOINT = f"{API_BASE_URL}/status"
| 103 | + |
| 104 | + |
| 105 | +# ── Helper utilities ──────────────────────────────────────────────────────── |
| 106 | + |
| 107 | +def extract_python_code(txt: str) -> str | None: |
| 108 | + m = re.search(r"```python\s*([\s\S]+?)\s*```", txt) |
| 109 | + return m.group(1).strip() if m else None |
| 110 | + |
| 111 | + |
def display(role: str, content: str) -> None:
    """Pretty-print one chat message; assistant replies get separate text/code panels."""
    role_titles = {"system": "SYSTEM", "user": "USER", "assistant": "ASSISTANT"}
    role_styles = {"system": "dim blue", "user": "cyan", "assistant": "green"}
    title = role_titles.get(role, role.upper())
    style = role_styles.get(role, "white")

    if role != "assistant":
        console.print(Panel(content, title=title, border_style=style))
        return

    # Split the assistant reply into prose (first fenced block removed) and code.
    code = extract_python_code(content)
    prose = re.sub(r"```python[\s\S]+?```", "", content, count=1).strip()
    if prose:
        console.print(Panel(prose, title=f"{title} (text)", border_style=style))
    if code:
        console.print(Panel(Syntax(code, "python", line_numbers=True), title=f"{title} (code)", border_style=style))
| 127 | + |
| 128 | + |
| 129 | +# ── Dataset & prompt helpers ──────────────────────────────────────────────── |
| 130 | + |
def get_initial_prompt() -> str:
    """Read the opening user prompt from stdin (terminated with Ctrl+D); exit if empty."""
    console.print("[bold cyan]Enter the initial user prompt (Ctrl+D to finish):[/bold cyan]")
    try:
        raw = sys.stdin.read()
    except EOFError:
        raw = ""
    prompt_text = raw.strip()
    if not prompt_text:
        console.print("[red]Empty prompt. Aborting.[/red]")
        sys.exit(1)
    return prompt_text
| 141 | + |
| 142 | + |
def select_dataset() -> Tuple[Path, dict]:
    """List *.h5ad datasets that have sidecar JSON metadata and let the user pick one.

    Returns the chosen (h5ad_path, metadata_dict) pair; exits on missing dir
    or when no dataset/metadata pairs are found.
    """
    if not DATASETS_DIR.exists():
        console.print(f"[red]Datasets dir not found: {DATASETS_DIR}[/red]")
        sys.exit(1)
    entries: List[Tuple[Path, dict]] = []
    for h5ad_path in DATASETS_DIR.glob("*.h5ad"):
        meta_path = h5ad_path.with_suffix(".json")
        if meta_path.exists():
            entries.append((h5ad_path, json.loads(meta_path.read_text())))
    if not entries:
        console.print("[red]No datasets found.[/red]")
        sys.exit(1)
    table = Table(title="Datasets")
    table.add_column("Idx", justify="right")
    table.add_column("Name")
    table.add_column("Cells", justify="right")
    for n, (h5ad_path, meta) in enumerate(entries, 1):
        table.add_row(str(n), meta.get("dataset_title", h5ad_path.stem), str(meta.get("cell_count", "?")))
    console.print(table)
    valid_choices = [str(n) for n in range(1, len(entries) + 1)]
    chosen = int(Prompt.ask("Choose index", choices=valid_choices)) - 1
    return entries[chosen]
| 160 | + |
| 161 | + |
def collect_resources() -> List[Tuple[Path, str]]:
    """Interactively gather (host_path, sandbox_path) pairs to upload.

    Prompts repeatedly for host paths; a blank entry finishes. Non-existent
    paths are warned about and skipped.
    """
    console.print("\n[bold cyan]Optional: list files/folders to copy into sandbox[/bold cyan] (blank line to finish)")
    collected: List[Tuple[Path, str]] = []
    while True:
        answer = Prompt.ask("Path", default="").strip()
        if not answer:
            return collected
        host_path = Path(answer).expanduser().resolve()
        if host_path.exists():
            collected.append((host_path, f"{SANDBOX_RESOURCES_DIR}/{host_path.name}"))
        else:
            console.print(f"[yellow]Path does not exist: {host_path}[/yellow]")
| 175 | + |
| 176 | + |
| 177 | +# ── FastAPI kernel helpers ────────────────────────────────────────────────── |
| 178 | + |
def api_alive(max_retries: int = 10, delay: float = 1.5) -> bool:
    """Poll the kernel's /status endpoint until it reports ``{"status": "ok"}``.

    Args:
        max_retries: number of polling attempts before giving up.
        delay: seconds to wait between attempts.

    Returns:
        True as soon as the API answers with status "ok", False otherwise.

    Fixes vs. previous version: the loop only slept on RequestException, so a
    server that was reachable but not yet reporting "ok" was polled in a tight
    busy-loop, and a non-JSON body raised an uncaught ValueError. Now every
    failed attempt (unreachable, non-JSON, or not-"ok") waits `delay` seconds.
    """
    for attempt in range(max_retries):
        try:
            if requests.get(STATUS_ENDPOINT, timeout=2).json().get("status") == "ok":
                return True
        except (requests.RequestException, ValueError):
            # ValueError also covers requests' JSONDecodeError on a
            # half-started server that returns a non-JSON body.
            pass
        if attempt < max_retries - 1:  # no pointless sleep after the last try
            time.sleep(delay)
    return False
| 187 | + |
| 188 | + |
def format_execute_response(resp: dict) -> str:
    """Render a kernel /execute response dict as a compact text report.

    Collects stdout/stderr stream output (each truncated to 1500 chars),
    flattens error outputs into the stderr section, and decodes any
    base64-encoded image payloads into timestamped files under OUTPUTS_DIR.
    Ends with the response's final_status.
    """
    out_chunks: list = []
    err_chunks: list = []
    saved_images: list = []
    for output in resp.get("outputs", []):
        kind = output["type"]
        if kind == "stream":
            target = out_chunks if output.get("name") == "stdout" else err_chunks
            target.append(output.get("text", ""))
        elif kind == "error":
            err_chunks.append("Error: " + output.get("evalue", ""))
            err_chunks.extend(output.get("traceback", []))
        elif kind == "display_data":
            for mime, payload in output.get("data", {}).items():
                if not mime.startswith("image/"):
                    continue
                # e.g. "image/svg+xml" -> extension "svg"
                ext = mime.split("/")[1].split("+")[0]
                image_path = OUTPUTS_DIR / f"{datetime.now():%Y%m%d_%H%M%S_%f}.{ext}"
                image_path.parent.mkdir(exist_ok=True)
                image_path.write_bytes(base64.b64decode(payload))
                saved_images.append(str(image_path))

    report = ["Code execution result:"]
    if out_chunks:
        report += ["--- STDOUT ---", "".join(out_chunks)[:1500]]
    if err_chunks:
        report += ["--- STDERR ---", "".join(err_chunks)[:1500]]
    if saved_images:
        report.append("Saved images: " + ", ".join(saved_images))
    report.append(f"Final Status: {resp.get('final_status')}")
    return "\n".join(report)
| 214 | + |
| 215 | + |
| 216 | +# ── Chat‑runner ───────────────────────────────────────────────────────────── |
| 217 | + |
def run_interactive(prompt: str, dataset: Path, metadata: dict, resources: List[Tuple[Path, str]]) -> None:
    """Drive the multi-turn GPT chat loop against a running sandbox kernel.

    Starts the selected backend's sandbox, copies the dataset and any extra
    resources into it (Docker only — the Singularity COPY_CMD merely warns),
    then alternates: OpenAI completion -> execute any ```python block via the
    kernel API -> feed the result back to the model as a user message. The
    loop runs until the operator types 'exit'/'quit' (or sends EOF/Ctrl+C at
    the input prompt). The sandbox is always stopped on the way out.

    Args:
        prompt: initial user message for the model.
        dataset: host path of the .h5ad file, copied to SANDBOX_DATA_PATH.
        metadata: dataset metadata dict, embedded in the system prompt.
        resources: (host_path, container_path) pairs to copy into the sandbox.
    """
    mgr = _BackendManager()
    console.print(f"Starting sandbox ({backend}) …")
    if not mgr.start_container():
        console.print("[red]Failed to start sandbox.[/red]")
        return

    try:
        if not api_alive():
            console.print("[red]Kernel API not responsive.[/red]")
            return
        # dataset copy (Docker only, Singularity warns via COPY_CMD)
        COPY_CMD(str(dataset), f"{_SANDBOX_HANDLE}:{SANDBOX_DATA_PATH}")
        for h, c in resources:
            COPY_CMD(str(h), f"{_SANDBOX_HANDLE}:{c}")

        resource_lines = [f"- {c} (from {h})" for h, c in resources] or ["- (none)"]
        # NOTE(review): the outer dedent operates on mixed indentation once the
        # resource lines are concatenated, so leading spaces may survive in the
        # rendered system prompt — confirm the output looks as intended.
        sys_prompt = textwrap.dedent(
            f"""
            You are an AI assistant analysing a single‑cell dataset. The file lives inside the sandbox at **{SANDBOX_DATA_PATH}**.
            Additional resources:\n""" + "\n".join(resource_lines) + "\n\n" + textwrap.dedent(
                f"Dataset metadata:\n{json.dumps(metadata, indent=2)}\n\nWrap runnable Python in triple‑backtick ```python blocks. Imports & vars persist."""
            )
        )

        history = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": prompt},
        ]
        display("system", sys_prompt)
        display("user", prompt)

        openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        turn = 0
        while True:  # unlimited chat loop; exits via the breaks below
            turn += 1
            console.print(f"\n[bold]OpenAI call (turn {turn})…[/bold]")
            try:
                rsp = openai.chat.completions.create(model="gpt-4o", messages=history, temperature=0.7)
            except APIError as e:
                console.print(f"[red]OpenAI error: {e}[/red]")
                break
            assistant_msg = rsp.choices[0].message.content
            history.append({"role": "assistant", "content": assistant_msg})
            display("assistant", assistant_msg)

            # If the reply contains a ```python block, run it in the kernel and
            # report the outcome back to the model as a user-role message.
            code = extract_python_code(assistant_msg)
            if code:
                console.print("[cyan]Executing code…[/cyan]")
                try:
                    # HTTP timeout (130s) deliberately exceeds the kernel-side
                    # execution timeout (120s) so the kernel can answer first.
                    api_r = requests.post(EXECUTE_ENDPOINT, json={"code": code, "timeout": 120}, timeout=130).json()
                    feedback = format_execute_response(api_r)
                except Exception as exc:
                    feedback = f"Code execution result:\n[Execution error: {exc}]"
                history.append({"role": "user", "content": feedback})
                display("user", feedback)

            console.print("\n[bold]Next message (blank = continue, 'exit' to quit):[/bold]")
            try:
                user_in = input().strip()
            except (EOFError, KeyboardInterrupt):
                user_in = "exit"
            if user_in.lower() in {"exit", "quit"}:
                break
            if user_in:
                history.append({"role": "user", "content": user_in})
                display("user", user_in)
    finally:
        # Always tear the sandbox down, even on errors or operator exit.
        console.print("Stopping sandbox…")
        mgr.stop_container(remove=True)
| 288 | + |
| 289 | + |
| 290 | +# ── CLI entry ─────────────────────────────────────────────────────────────── |
| 291 | + |
def main():
    """CLI entry point: load credentials, gather inputs, then run the chat loop."""
    load_dotenv(Path(ENV_FILE))
    if not os.getenv("OPENAI_API_KEY"):
        console.print(f"[red]OPENAI_API_KEY not set in {ENV_FILE}.[/red]")
        sys.exit(1)

    initial_prompt = get_initial_prompt()
    dataset_path, dataset_meta = select_dataset()
    resource_list = collect_resources()
    run_interactive(initial_prompt, dataset_path, dataset_meta, resource_list)
| 302 | + |
| 303 | + |
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Graceful Ctrl+C at top level; sandbox cleanup is handled inside
        # run_interactive's finally block.
        console.print("\nInterrupted.")