Skip to content

Commit efb5dc0

Browse files
committed
Add model selector to UI
1 parent 72df35a commit efb5dc0

2 files changed

Lines changed: 153 additions & 24 deletions

File tree

agentic/web_ui/app.py

Lines changed: 18 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -45,8 +45,11 @@ def send_input(self, text):
4545
except Exception as e:
4646
self.output_queue.put(("error", f"Failed to send input: {e}"))
4747

48-
def _subprocess_thread(self, cmd, cwd):
48+
def _subprocess_thread(self, cmd, cwd, env_overrides=None):
4949
try:
50+
env = {**os.environ, "PYTHONUNBUFFERED": "1"}
51+
if env_overrides:
52+
env.update(env_overrides)
5053
self.process = subprocess.Popen(
5154
cmd,
5255
cwd=cwd,
@@ -55,7 +58,7 @@ def _subprocess_thread(self, cmd, cwd):
5558
stderr=subprocess.STDOUT,
5659
text=True,
5760
bufsize=1,
58-
env={**os.environ, "PYTHONUNBUFFERED": "1"}
61+
env=env,
5962
)
6063
for line in self.process.stdout:
6164
self.output_queue.put(("line", line.rstrip()))
@@ -66,7 +69,8 @@ def _subprocess_thread(self, cmd, cwd):
6669
finally:
6770
self.process = None
6871

69-
async def run_agent(self, agent_script, scripts_dir, ws, agent_dir=None):
72+
async def run_agent(self, agent_script, scripts_dir, ws, agent_dir=None,
73+
llm_model=None, openai_base_url=None):
7074
run_dir = Path(agent_dir) if agent_dir else AGENT_DIR
7175
cmd = [sys.executable, agent_script]
7276

@@ -86,9 +90,17 @@ async def run_agent(self, agent_script, scripts_dir, ws, agent_dir=None):
8690
except Empty:
8791
break
8892

93+
# Pass selected model to subprocess
94+
env_overrides = {}
95+
if llm_model:
96+
env_overrides["LLM_MODEL"] = llm_model
97+
if openai_base_url:
98+
env_overrides["OPENAI_BASE_URL"] = openai_base_url
99+
89100
thread = threading.Thread(
90101
target=self._subprocess_thread,
91102
args=(cmd, str(run_dir)),
103+
kwargs={"env_overrides": env_overrides or None},
92104
daemon=True
93105
)
94106
thread.start()
@@ -138,7 +150,9 @@ async def ws_endpoint(ws: WebSocket, session_id: str):
138150
msg.get("agent_script", ""),
139151
msg.get("scripts_dir") or "",
140152
ws,
141-
agent_dir=msg.get("agent_dir")
153+
agent_dir=msg.get("agent_dir"),
154+
llm_model=msg.get("llm_model"),
155+
openai_base_url=msg.get("openai_base_url"),
142156
)
143157
)
144158
except WebSocketDisconnect:

agentic/web_ui/gradio_chat.py

Lines changed: 135 additions & 20 deletions
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
import os
34
import socket
45
import subprocess
56
import sys
@@ -9,47 +10,119 @@
910
from queue import Queue, Empty
1011

1112
import gradio as gr
13+
import requests
1214
import websockets
1315

1416
WS_URL = "ws://127.0.0.1:8000/ws/test"
1517
DEFAULT_AGENT_DIR = Path(__file__).parent.parent
18+
ALCF_API_BASE = "https://inference-api.alcf.anl.gov"
19+
ALCF_ENDPOINTS_URL = f"{ALCF_API_BASE}/resource_server/list-endpoints"
1620

17-
def _get_model_name():
18-
import os
21+
22+
def _fetch_models():
23+
"""Fetch available models. Uses ALCF list-endpoints for ALCF, or /v1/models for OpenAI.
24+
Returns (choices, model_map, error_or_none).
25+
model_map: label -> (model_name, base_url)
26+
"""
27+
api_key = os.environ.get("OPENAI_API_KEY", "")
28+
base_url = os.environ.get("OPENAI_BASE_URL", "")
29+
30+
if not api_key:
31+
return [], {}, "No OPENAI_API_KEY set"
32+
33+
# ALCF: use their list-endpoints API (returns models grouped by cluster)
34+
if "alcf" in base_url.lower():
35+
try:
36+
resp = requests.get(
37+
ALCF_ENDPOINTS_URL,
38+
headers={"Authorization": f"Bearer {api_key}"},
39+
timeout=10,
40+
)
41+
if resp.status_code in (401, 403):
42+
return [], {}, "Auth failed — token may be expired"
43+
resp.raise_for_status()
44+
data = resp.json()
45+
except requests.RequestException as e:
46+
return [], {}, f"Cannot reach ALCF API: {e}"
47+
48+
skip = {"embed", "genslm"}
49+
choices = []
50+
model_map = {}
51+
for cluster, info in data.get("clusters", {}).items():
52+
cluster_url = f"{ALCF_API_BASE}{info['base_url']}/api/v1"
53+
for fw in info.get("frameworks", {}).values():
54+
if "/v1/chat/completions" not in fw.get("endpoints", []):
55+
continue
56+
for model in fw.get("models", []):
57+
if any(s in model.lower() for s in skip):
58+
continue
59+
label = f"{model} ({cluster})"
60+
choices.append(label)
61+
model_map[label] = (model, cluster_url)
62+
return sorted(choices), model_map, None
63+
64+
# Any other OpenAI-compatible endpoint: query /v1/models
65+
try:
66+
from openai import OpenAI
67+
client = OpenAI(api_key=api_key, base_url=base_url or None)
68+
models = client.models.list()
69+
source = "OpenAI" if not base_url else base_url.split("//")[-1].split("/")[0]
70+
skip = {"embed", "tts", "whisper", "dall-e", "davinci", "babbage", "moderation"}
71+
choices = []
72+
model_map = {}
73+
for m in sorted(models.data, key=lambda x: x.id):
74+
if any(s in m.id.lower() for s in skip):
75+
continue
76+
label = f"{m.id} ({source})"
77+
choices.append(label)
78+
model_map[label] = (m.id, base_url)
79+
return choices, model_map, None
80+
except Exception as e:
81+
msg = str(e)
82+
if "401" in msg or "invalid" in msg.lower() or "api key" in msg.lower():
83+
return [], {}, "Invalid API key for this endpoint"
84+
return [], {}, f"Cannot fetch models: {type(e).__name__}"
85+
86+
87+
def _current_model_label():
88+
"""Label for the currently configured model."""
1989
model = os.environ.get("LLM_MODEL", "gpt-4o-mini")
2090
base = os.environ.get("OPENAI_BASE_URL", "")
21-
if "alcf" in base.lower():
22-
return f"{model} (ALCF)"
91+
if "metis" in base:
92+
return f"{model} (metis)"
93+
elif "sophia" in base:
94+
return f"{model} (sophia)"
95+
elif "alcf" in base.lower():
96+
return f"{model} (alcf)"
2397
elif base:
24-
return f"{model} ({base.split('//')[1].split('/')[0]})"
98+
return f"{model} (custom)"
2599
return f"{model} (OpenAI)"
26100

27101

28-
def _check_api():
102+
def _check_api(model=None, base_url=None):
29103
"""Quick API check. Returns None on success, or an error message string."""
30-
import os
31104
from openai import OpenAI
32-
model = os.environ.get("LLM_MODEL", "gpt-4o-mini")
105+
model = model or os.environ.get("LLM_MODEL", "gpt-4o-mini")
106+
base_url = base_url or os.environ.get("OPENAI_BASE_URL")
33107
try:
34-
client = OpenAI(
35-
api_key=os.environ.get("OPENAI_API_KEY"),
36-
base_url=os.environ.get("OPENAI_BASE_URL"),
37-
)
108+
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url=base_url)
38109
client.chat.completions.create(
39110
model=model, messages=[{"role": "user", "content": "hi"}], max_tokens=1
40111
)
41112
return None
42113
except Exception as e:
43114
msg = str(e)
44115
if "403" in msg or "Permission" in msg.lower():
45-
return (f"⚠️ API auth failed ({model}). Token likely expired.\n\n"
116+
return ("⚠️ API auth failed. Token likely expired.\n\n"
46117
"```\npython3 inference_auth_token.py authenticate --force\n"
47118
"export OPENAI_API_KEY=$(python inference_auth_token.py get_access_token)\n```\n\n"
48119
"Then restart the UI.")
49120
elif "401" in msg or "invalid" in msg.lower():
50121
return f"⚠️ Invalid API key for {model}. Check OPENAI_API_KEY."
51122
else:
52123
return f"⚠️ API check failed ({model}): {e}"
124+
125+
53126
DEFAULT_TESTS_DIR = DEFAULT_AGENT_DIR / "tests"
54127
DEFAULT_AGENT_PATTERN = "libe_agent*.py"
55128
NONE_OPTION = "(none)"
@@ -142,15 +215,46 @@ async def _run():
142215
_init_tests = scan_script_dirs(str(DEFAULT_TESTS_DIR))
143216
_init_versions = scan_versions(str(DEFAULT_AGENT_DIR))
144217

218+
# Determine endpoint label for title
219+
_cur_base = os.environ.get("OPENAI_BASE_URL", "")
220+
if "metis" in _cur_base:
221+
_endpoint_label = "ALCF Metis"
222+
elif "sophia" in _cur_base:
223+
_endpoint_label = "ALCF Sophia"
224+
elif "alcf" in _cur_base.lower():
225+
_endpoint_label = "ALCF"
226+
elif _cur_base:
227+
_endpoint_label = _cur_base.split("//")[-1].split("/")[0]
228+
else:
229+
_endpoint_label = "OpenAI"
230+
231+
# Fetch available models (one quick call at startup)
232+
_init_model_label = _current_model_label()
233+
_cur_model = os.environ.get("LLM_MODEL", "gpt-4o-mini")
234+
_init_model_choices, _init_model_map, _init_model_err = _fetch_models()
235+
if _init_model_label not in _init_model_map:
236+
_init_model_choices = [_init_model_label] + _init_model_choices
237+
_init_model_map[_init_model_label] = (_cur_model, _cur_base)
238+
if not _init_model_choices:
239+
_init_model_choices = [_init_model_label]
240+
if _init_model_err:
241+
print(f"⚠ Model fetch: {_init_model_err}")
242+
print(f" Check OPENAI_API_KEY and OPENAI_BASE_URL match the same service.")
243+
145244
with gr.Blocks() as demo:
146245
with gr.Row():
147-
gr.Markdown(f"### libEnsemble Agent   ·   `{_get_model_name()}`")
246+
gr.Markdown(f"### libEnsemble Agent   ·   `{_endpoint_label}`")
247+
model_dropdown = gr.Dropdown(
248+
choices=_init_model_choices, value=_init_model_label,
249+
show_label=False, allow_custom_value=True, scale=2, min_width=300
250+
)
148251
with gr.Column(scale=0, min_width=60):
149-
settings_btn = gr.Button("⚙️", size="sm")
252+
settings_btn = gr.Button("⚙️")
150253

151254
agent_dir_state = gr.State(value=str(DEFAULT_AGENT_DIR))
152255
scripts_dir_state = gr.State(value=str(DEFAULT_TESTS_DIR))
153256
agent_pattern_state = gr.State(value=DEFAULT_AGENT_PATTERN)
257+
model_map_state = gr.State(value=_init_model_map)
154258
settings_visible = gr.State(value=False)
155259

156260
with gr.Column(visible=False) as settings_modal:
@@ -210,14 +314,22 @@ def start_websocket():
210314

211315
# --- Core event handlers ---
212316

213-
def start_run(agent_script, scripts_dir, history, agent_dir_val, scripts_dir_val):
317+
def start_run(agent_script, scripts_dir, history, agent_dir_val, scripts_dir_val,
318+
model_label, model_map):
214319
"""Send run command and add user message to chat"""
215320
if not agent_script:
216321
history = history + [{"role": "assistant", "content": "⚠️ No agent script selected"}]
217322
return history
218323

219-
# Preflight API check
220-
api_err = _check_api()
324+
# Resolve model from dropdown selection
325+
if model_label and model_label in model_map:
326+
sel_model, sel_base_url = model_map[model_label]
327+
else:
328+
sel_model = os.environ.get("LLM_MODEL", "gpt-4o-mini")
329+
sel_base_url = os.environ.get("OPENAI_BASE_URL", "")
330+
331+
# Preflight API check with selected model
332+
api_err = _check_api(model=sel_model, base_url=sel_base_url or None)
221333
if api_err:
222334
history = history + [{"role": "assistant", "content": api_err}]
223335
return history
@@ -245,7 +357,9 @@ def start_run(agent_script, scripts_dir, history, agent_dir_val, scripts_dir_val
245357
"type": "run",
246358
"agent_script": agent_script,
247359
"scripts_dir": resolved,
248-
"agent_dir": str(agent_dir)
360+
"agent_dir": str(agent_dir),
361+
"llm_model": sel_model,
362+
"openai_base_url": sel_base_url,
249363
}))
250364
return history
251365

@@ -370,7 +484,8 @@ def reset_ui():
370484
# Run button: start script → stream output → refresh versions → load scripts
371485
run_btn.click(
372486
start_run,
373-
inputs=[agent_dropdown, scripts_dropdown, chatbot, agent_dir_state, scripts_dir_state],
487+
inputs=[agent_dropdown, scripts_dropdown, chatbot, agent_dir_state, scripts_dir_state,
488+
model_dropdown, model_map_state],
374489
outputs=[chatbot]
375490
).then(
376491
stream_output, inputs=[chatbot], outputs=[chatbot]

0 commit comments

Comments (0)