|
1 | 1 | import asyncio |
2 | 2 | import json |
| 3 | +import os |
3 | 4 | import socket |
4 | 5 | import subprocess |
5 | 6 | import sys |
|
9 | 10 | from queue import Queue, Empty |
10 | 11 |
|
11 | 12 | import gradio as gr |
| 13 | +import requests |
12 | 14 | import websockets |
13 | 15 |
|
14 | 16 | WS_URL = "ws://127.0.0.1:8000/ws/test" |
15 | 17 | DEFAULT_AGENT_DIR = Path(__file__).parent.parent |
| 18 | +ALCF_API_BASE = "https://inference-api.alcf.anl.gov" |
| 19 | +ALCF_ENDPOINTS_URL = f"{ALCF_API_BASE}/resource_server/list-endpoints" |
16 | 20 |
|
def _fetch_models():
    """Fetch available chat models for the configured endpoint.

    Uses the ALCF ``list-endpoints`` API when OPENAI_BASE_URL points at ALCF,
    otherwise queries the OpenAI-compatible ``/v1/models`` endpoint.

    Returns:
        tuple: ``(choices, model_map, error_or_none)`` where
            choices (list[str]): dropdown labels (sorted for ALCF),
            model_map (dict): label -> (model_name, base_url),
            error_or_none (str | None): human-readable failure reason.
    """
    api_key = os.environ.get("OPENAI_API_KEY", "")
    base_url = os.environ.get("OPENAI_BASE_URL", "")

    if not api_key:
        return [], {}, "No OPENAI_API_KEY set"

    # ALCF: use their list-endpoints API (returns models grouped by cluster)
    if "alcf" in base_url.lower():
        try:
            resp = requests.get(
                ALCF_ENDPOINTS_URL,
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,
            )
            if resp.status_code in (401, 403):
                return [], {}, "Auth failed — token may be expired"
            resp.raise_for_status()
            data = resp.json()
        except requests.RequestException as e:
            return [], {}, f"Cannot reach ALCF API: {e}"
        except ValueError as e:
            # resp.json() raises ValueError when the body is not valid JSON
            # (older requests versions raise plain ValueError, not a
            # RequestException subclass) — report it instead of crashing.
            return [], {}, f"Cannot reach ALCF API: {e}"

        skip = {"embed", "genslm"}  # model families not usable for chat
        choices = []
        model_map = {}
        for cluster, info in data.get("clusters", {}).items():
            cluster_base = info.get("base_url")
            if not cluster_base:
                # Defensive: skip malformed cluster entries rather than
                # raising KeyError on an unexpected response shape.
                continue
            cluster_url = f"{ALCF_API_BASE}{cluster_base}/api/v1"
            for fw in info.get("frameworks", {}).values():
                # Only offer frameworks exposing a chat-completions endpoint.
                if "/v1/chat/completions" not in fw.get("endpoints", []):
                    continue
                for model in fw.get("models", []):
                    if any(s in model.lower() for s in skip):
                        continue
                    label = f"{model} ({cluster})"
                    choices.append(label)
                    model_map[label] = (model, cluster_url)
        return sorted(choices), model_map, None

    # Any other OpenAI-compatible endpoint: query /v1/models
    try:
        from openai import OpenAI
        client = OpenAI(api_key=api_key, base_url=base_url or None)
        models = client.models.list()
        # Label models with the host they come from ("OpenAI" for default).
        source = "OpenAI" if not base_url else base_url.split("//")[-1].split("/")[0]
        skip = {"embed", "tts", "whisper", "dall-e", "davinci", "babbage", "moderation"}
        choices = []
        model_map = {}
        for m in sorted(models.data, key=lambda x: x.id):
            if any(s in m.id.lower() for s in skip):
                continue
            label = f"{m.id} ({source})"
            choices.append(label)
            model_map[label] = (m.id, base_url)
        return choices, model_map, None
    except Exception as e:
        msg = str(e)
        if "401" in msg or "invalid" in msg.lower() or "api key" in msg.lower():
            return [], {}, "Invalid API key for this endpoint"
        return [], {}, f"Cannot fetch models: {type(e).__name__}"
| 85 | + |
| 86 | + |
def _current_model_label():
    """Return a dropdown label for the currently configured model.

    Combines the LLM_MODEL env var with a short tag derived from
    OPENAI_BASE_URL so the user can tell which endpoint will be hit.
    """
    model = os.environ.get("LLM_MODEL", "gpt-4o-mini")
    base = os.environ.get("OPENAI_BASE_URL", "")
    # Match host names case-insensitively throughout — the original only
    # lowercased for the "alcf" check, so "Metis"/"Sophia" URLs were missed.
    base_lower = base.lower()
    if "metis" in base_lower:
        return f"{model} (metis)"
    elif "sophia" in base_lower:
        return f"{model} (sophia)"
    elif "alcf" in base_lower:
        return f"{model} (alcf)"
    elif base:
        return f"{model} (custom)"
    return f"{model} (OpenAI)"
26 | 100 |
|
27 | 101 |
|
def _check_api(model=None, base_url=None):
    """Quick API check. Returns None on success, or an error message string.

    Args:
        model: Model name to probe; defaults to the LLM_MODEL env var.
        base_url: OpenAI-compatible base URL; defaults to OPENAI_BASE_URL.
    """
    from openai import OpenAI
    model = model or os.environ.get("LLM_MODEL", "gpt-4o-mini")
    base_url = base_url or os.environ.get("OPENAI_BASE_URL")
    try:
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url=base_url)
        # Cheapest possible round-trip: one token on a one-word prompt.
        client.chat.completions.create(
            model=model, messages=[{"role": "user", "content": "hi"}], max_tokens=1
        )
        return None
    except Exception as e:
        msg = str(e)
        # BUG FIX: the original tested `"Permission" in msg.lower()`, which can
        # never match — a lowercased string contains no uppercase "P".
        if "403" in msg or "permission" in msg.lower():
            return ("⚠️ API auth failed. Token likely expired.\n\n"
                    "```\npython3 inference_auth_token.py authenticate --force\n"
                    "export OPENAI_API_KEY=$(python inference_auth_token.py get_access_token)\n```\n\n"
                    "Then restart the UI.")
        elif "401" in msg or "invalid" in msg.lower():
            return f"⚠️ Invalid API key for {model}. Check OPENAI_API_KEY."
        else:
            return f"⚠️ API check failed ({model}): {e}"
| 124 | + |
| 125 | + |
53 | 126 | DEFAULT_TESTS_DIR = DEFAULT_AGENT_DIR / "tests" |
54 | 127 | DEFAULT_AGENT_PATTERN = "libe_agent*.py" |
55 | 128 | NONE_OPTION = "(none)" |
@@ -142,15 +215,46 @@ async def _run(): |
142 | 215 | _init_tests = scan_script_dirs(str(DEFAULT_TESTS_DIR)) |
143 | 216 | _init_versions = scan_versions(str(DEFAULT_AGENT_DIR)) |
144 | 217 |
|
# Determine endpoint label for title (checked in specificity order:
# known ALCF clusters first, then any ALCF URL, then a custom host).
_cur_base = os.environ.get("OPENAI_BASE_URL", "")
_cur_base_lower = _cur_base.lower()  # case-insensitive host matching
if "metis" in _cur_base_lower:
    _endpoint_label = "ALCF Metis"
elif "sophia" in _cur_base_lower:
    _endpoint_label = "ALCF Sophia"
elif "alcf" in _cur_base_lower:
    _endpoint_label = "ALCF"
elif _cur_base:
    # Show just the host part of a custom base URL.
    _endpoint_label = _cur_base.split("//")[-1].split("/")[0]
else:
    _endpoint_label = "OpenAI"

# Fetch available models (one quick call at startup)
_init_model_label = _current_model_label()
_cur_model = os.environ.get("LLM_MODEL", "gpt-4o-mini")
_init_model_choices, _init_model_map, _init_model_err = _fetch_models()
# Always offer the currently configured model, even if the fetch missed it.
if _init_model_label not in _init_model_map:
    _init_model_choices = [_init_model_label] + _init_model_choices
    _init_model_map[_init_model_label] = (_cur_model, _cur_base)
if not _init_model_choices:
    _init_model_choices = [_init_model_label]
if _init_model_err:
    print(f"⚠ Model fetch: {_init_model_err}")
    # f-prefix removed: this literal has no placeholders (ruff F541).
    print("  Check OPENAI_API_KEY and OPENAI_BASE_URL match the same service.")
| 243 | + |
145 | 244 | with gr.Blocks() as demo: |
146 | 245 | with gr.Row(): |
147 | | - gr.Markdown(f"### libEnsemble Agent · `{_get_model_name()}`") |
| 246 | + gr.Markdown(f"### libEnsemble Agent · `{_endpoint_label}`") |
| 247 | + model_dropdown = gr.Dropdown( |
| 248 | + choices=_init_model_choices, value=_init_model_label, |
| 249 | + show_label=False, allow_custom_value=True, scale=2, min_width=300 |
| 250 | + ) |
148 | 251 | with gr.Column(scale=0, min_width=60): |
149 | | - settings_btn = gr.Button("⚙️", size="sm") |
| 252 | + settings_btn = gr.Button("⚙️") |
150 | 253 |
|
151 | 254 | agent_dir_state = gr.State(value=str(DEFAULT_AGENT_DIR)) |
152 | 255 | scripts_dir_state = gr.State(value=str(DEFAULT_TESTS_DIR)) |
153 | 256 | agent_pattern_state = gr.State(value=DEFAULT_AGENT_PATTERN) |
| 257 | + model_map_state = gr.State(value=_init_model_map) |
154 | 258 | settings_visible = gr.State(value=False) |
155 | 259 |
|
156 | 260 | with gr.Column(visible=False) as settings_modal: |
@@ -210,14 +314,22 @@ def start_websocket(): |
210 | 314 |
|
211 | 315 | # --- Core event handlers --- |
212 | 316 |
|
213 | | - def start_run(agent_script, scripts_dir, history, agent_dir_val, scripts_dir_val): |
| 317 | + def start_run(agent_script, scripts_dir, history, agent_dir_val, scripts_dir_val, |
| 318 | + model_label, model_map): |
214 | 319 | """Send run command and add user message to chat""" |
215 | 320 | if not agent_script: |
216 | 321 | history = history + [{"role": "assistant", "content": "⚠️ No agent script selected"}] |
217 | 322 | return history |
218 | 323 |
|
219 | | - # Preflight API check |
220 | | - api_err = _check_api() |
| 324 | + # Resolve model from dropdown selection |
| 325 | + if model_label and model_label in model_map: |
| 326 | + sel_model, sel_base_url = model_map[model_label] |
| 327 | + else: |
| 328 | + sel_model = os.environ.get("LLM_MODEL", "gpt-4o-mini") |
| 329 | + sel_base_url = os.environ.get("OPENAI_BASE_URL", "") |
| 330 | + |
| 331 | + # Preflight API check with selected model |
| 332 | + api_err = _check_api(model=sel_model, base_url=sel_base_url or None) |
221 | 333 | if api_err: |
222 | 334 | history = history + [{"role": "assistant", "content": api_err}] |
223 | 335 | return history |
@@ -245,7 +357,9 @@ def start_run(agent_script, scripts_dir, history, agent_dir_val, scripts_dir_val |
245 | 357 | "type": "run", |
246 | 358 | "agent_script": agent_script, |
247 | 359 | "scripts_dir": resolved, |
248 | | - "agent_dir": str(agent_dir) |
| 360 | + "agent_dir": str(agent_dir), |
| 361 | + "llm_model": sel_model, |
| 362 | + "openai_base_url": sel_base_url, |
249 | 363 | })) |
250 | 364 | return history |
251 | 365 |
|
@@ -370,7 +484,8 @@ def reset_ui(): |
370 | 484 | # Run button: start script → stream output → refresh versions → load scripts |
371 | 485 | run_btn.click( |
372 | 486 | start_run, |
373 | | - inputs=[agent_dropdown, scripts_dropdown, chatbot, agent_dir_state, scripts_dir_state], |
| 487 | + inputs=[agent_dropdown, scripts_dropdown, chatbot, agent_dir_state, scripts_dir_state, |
| 488 | + model_dropdown, model_map_state], |
374 | 489 | outputs=[chatbot] |
375 | 490 | ).then( |
376 | 491 | stream_output, inputs=[chatbot], outputs=[chatbot] |
|
0 commit comments