Skip to content

Commit c1f0f18

Browse files
committed
Improve model status alerts
1 parent 823e955 commit c1f0f18

6 files changed

Lines changed: 111 additions & 20 deletions

File tree

align_app/adm/decider/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from align_utils.models import ADMResult, Decision, ChoiceInfo
22
from .decider import MultiprocessDecider
3-
from .client import get_decision, is_model_cached
3+
from .client import get_decision, get_model_cache_status
44
from .types import DeciderParams
55

66
__all__ = [
77
"MultiprocessDecider",
88
"get_decision",
9-
"is_model_cached",
9+
"get_model_cache_status",
1010
"DeciderParams",
1111
"ADMResult",
1212
"Decision",

align_app/adm/decider/client.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Dict, Any
88
from align_utils.models import ADMResult
99
from .decider import MultiprocessDecider
10+
from .worker import CacheQueryResult
1011
from .types import DeciderParams
1112

1213
_decider = None
@@ -33,10 +34,12 @@ async def get_decision(params: DeciderParams) -> ADMResult:
3334
return await process_manager.get_decision(params)
3435

3536

36-
async def get_model_cache_status(
    resolved_config: Dict[str, Any]
) -> CacheQueryResult | None:
    """Get best-effort model cache status (memory + disk).

    Delegates the query to the long-lived worker process manager; the
    result is None when the worker could not answer.
    """
    manager = _get_process_manager()
    status = await manager.get_model_cache_status(resolved_config)
    return status
4043

4144

4245
def cleanup():

align_app/adm/decider/decider.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ class MultiprocessDecider:
1414
def __init__(self):
1515
self.worker: WorkerHandle = create_worker(decider_worker_func)
1616

17-
async def get_model_cache_status(
    self, resolved_config: Dict[str, Any]
) -> CacheQueryResult | None:
    """Ask the worker whether the model for this config is cached.

    Returns the worker's CacheQueryResult, or None when the reply is
    not a cache-status message.
    """
    self.worker, reply = await send(self.worker, CacheQuery(resolved_config))
    return reply if isinstance(reply, CacheQueryResult) else None
2224

2325
async def get_decision(self, params: DeciderParams) -> ADMResult:
2426
self.worker, result = await send(self.worker, params)

align_app/adm/decider/worker.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import hashlib
33
import json
44
import logging
5+
import os
56
import traceback
67
from dataclasses import dataclass
7-
from typing import Dict, Tuple, Callable, Any
8+
from typing import Dict, Tuple, Callable, Any, Optional
89
from multiprocessing import Queue
910
from align_utils.models import ADMResult
1011
from .executor import instantiate_adm
@@ -24,11 +25,57 @@ class CacheQuery:
2425
@dataclass
class CacheQueryResult:
    """Worker reply describing local availability of a model."""

    # True when the model is already instantiated in the worker's memory cache.
    is_cached: bool
    # Disk presence: True/False when determinable, None when unknown
    # (e.g. no model name found or huggingface_hub unavailable).
    is_downloaded: Optional[bool]
29+
30+
31+
def _extract_model_name(resolved_config: Dict[str, Any]) -> Optional[str]:
32+
if not isinstance(resolved_config, dict):
33+
return None
34+
35+
if isinstance(resolved_config.get("model_name"), str):
36+
return resolved_config["model_name"]
37+
38+
structured = resolved_config.get("structured_inference_engine")
39+
if isinstance(structured, dict) and isinstance(structured.get("model_name"), str):
40+
return structured["model_name"]
41+
42+
for value in resolved_config.values():
43+
if isinstance(value, dict):
44+
found = _extract_model_name(value)
45+
if found:
46+
return found
47+
elif isinstance(value, list):
48+
for item in value:
49+
if isinstance(item, dict):
50+
found = _extract_model_name(item)
51+
if found:
52+
return found
53+
return None
54+
55+
56+
def _is_model_downloaded(model_name: Optional[str]) -> Optional[bool]:
57+
if not model_name:
58+
return None
59+
60+
if os.path.exists(model_name):
61+
return True
62+
63+
try:
64+
from huggingface_hub import snapshot_download
65+
except Exception:
66+
return None
67+
68+
try:
69+
snapshot_download(model_name, local_files_only=True)
70+
return True
71+
except Exception:
72+
return False
2773

2874

2975
def decider_worker_func(task_queue: Queue, result_queue: Queue):
3076
root_logger = logging.getLogger()
3177
root_logger.setLevel("WARNING")
78+
logger = logging.getLogger(__name__)
3279

3380
model_cache: Dict[str, Tuple[Callable, Callable]] = {}
3481

@@ -37,8 +84,15 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue):
3784
try:
3885
if isinstance(task, CacheQuery):
3986
cache_key = extract_cache_key(task.resolved_config)
87+
is_cached = cache_key in model_cache
88+
is_downloaded = True if is_cached else _is_model_downloaded(
89+
_extract_model_name(task.resolved_config)
90+
)
4091
result_queue.put(
41-
CacheQueryResult(is_cached=cache_key in model_cache)
92+
CacheQueryResult(
93+
is_cached=is_cached,
94+
is_downloaded=is_downloaded,
95+
)
4296
)
4397
continue
4498

@@ -72,11 +126,30 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue):
72126
except (KeyboardInterrupt, SystemExit):
73127
break
74128
except Exception as e:
75-
error_msg = f"{str(e)}\n{traceback.format_exc()}"
129+
logger.error("Worker error:\n%s", traceback.format_exc())
130+
error_msg = _format_worker_error(e)
76131
result_queue.put(Exception(error_msg))
77132
finally:
78133
for _, (_, cleanup_func) in model_cache.items():
79134
try:
80135
cleanup_func()
81136
except Exception:
82137
pass
138+
139+
140+
def _format_worker_error(error: Exception) -> str:
141+
error_text = str(error)
142+
gated_tokens = (
143+
"GatedRepoError",
144+
"gated repo",
145+
"401 Client Error",
146+
"Access to model",
147+
"restricted",
148+
"Please log in",
149+
)
150+
if any(token in error_text for token in gated_tokens):
151+
return (
152+
"Model access denied. Authenticate with Hugging Face or request access "
153+
"to the gated repo."
154+
)
155+
return f"{error_text}\n{traceback.format_exc()}"

align_app/app/runs_state_adapter.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .runs_registry import RunsRegistry
88
from .runs_table_filter import RunsTableFilter
99
from ..adm.decider.types import DeciderParams
10-
from ..adm.decider import is_model_cached
10+
from ..adm.decider import get_model_cache_status
1111
from ..adm.system_adm_discovery import discover_system_adms
1212
from ..utils.utils import get_id
1313
from .runs_presentation import extract_base_scenarios
@@ -614,16 +614,18 @@ async def _execute_run_decision(self, run_id: str):
614614

615615
run = self.runs_registry.get_run(run_id)
616616
is_cached_decision = self.runs_registry.has_cached_decision(run_id)
617-
is_model_loaded = False
617+
status = None
618618
if run:
619-
is_model_loaded = await is_model_cached(run.decider_params.resolved_config)
619+
status = await get_model_cache_status(run.decider_params.resolved_config)
620620

621-
if is_cached_decision or is_model_loaded:
622-
alert_id = self._alerts.create_info_alert(title="Deciding...", timeout=0)
621+
if is_cached_decision or (status and status.is_cached):
622+
alert_title = "Deciding..."
623+
elif status and status.is_downloaded is False:
624+
alert_title = "Downloading model and deciding..."
623625
else:
624-
alert_id = self._alerts.create_info_alert(
625-
title="Loading model and deciding...", timeout=0
626-
)
626+
alert_title = "Loading model and deciding..."
627+
628+
alert_id = self._alerts.create_info_alert(title=alert_title, timeout=0)
627629
await self.server.network_completion
628630

629631
try:
@@ -632,7 +634,15 @@ async def _execute_run_decision(self, run_id: str):
632634
self._alerts.create_info_alert(title="Decision complete", timeout=3000)
633635
except Exception as e:
634636
self._alerts.remove_alert(alert_id)
635-
self._alerts.create_info_alert(title=f"Decision failed: {e}", timeout=5000)
637+
error_text = str(e)
638+
if "Model access denied" in error_text:
639+
message = (
640+
"Decision failed: Model access denied. "
641+
"Authenticate with Hugging Face or request access to the model."
642+
)
643+
else:
644+
message = f"Decision failed: {e}"
645+
self._alerts.create_info_alert(title=message, timeout=8000)
636646

637647
with self.state:
638648
self._rebuild_comparison_runs()

align_app/app/ui.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,6 +1564,9 @@ def __init__(
15641564
".drop-zone-active { outline: 3px dashed #1976d2 !important; outline-offset: -3px; }"
15651565
".alert-popup-container { left: auto; right: 0; transform: none; width: fit-content; }"
15661566
".alert-popup-container .v-alert { --v-theme-info: 66, 66, 66; }"
1567+
".alert-popup-container .v-alert__icon { display: none; }"
1568+
".alert-popup-container .v-alert__prepend { display: none; }"
1569+
".alert-popup-container .v-alert__prepend .v-icon { display: none; }"
15671570
"</style>'"
15681571
)
15691572
)

0 commit comments

Comments
 (0)