22import hashlib
33import json
44import logging
5+ import os
56import traceback
67from dataclasses import dataclass
7- from typing import Dict , Tuple , Callable , Any
8+ from typing import Dict , Tuple , Callable , Any , Optional
89from multiprocessing import Queue
910from align_utils .models import ADMResult
1011from .executor import instantiate_adm
@@ -24,11 +25,57 @@ class CacheQuery:
@dataclass
class CacheQueryResult:
    """Reply sent back on the result queue in response to a ``CacheQuery``.

    ``is_downloaded`` defaults to ``None`` so that code which constructs this
    result with only ``is_cached`` keeps working.
    """

    # True when the query's cache key is already present in the worker's
    # in-process model cache.
    is_cached: bool
    # True / False when the worker could determine whether the model is
    # available locally; None when that is unknown (no model name found in
    # the config, or the check could not be performed).
    is_downloaded: Optional[bool] = None
29+
30+
def _extract_model_name(resolved_config: Dict[str, Any]) -> Optional[str]:
    """Find a model name inside a (possibly nested) resolved config.

    Lookup order:
      1. a top-level ``"model_name"`` entry;
      2. ``"model_name"`` inside a ``"structured_inference_engine"`` dict;
      3. a depth-first search through nested dict values, and through dicts
         contained in list values, in mapping iteration order.

    Only non-empty strings count as a match, at every level, so an empty
    ``"model_name"`` near the top does not hide a real one deeper down.

    Args:
        resolved_config: The config mapping to search. Any non-dict value is
            tolerated and yields ``None``.

    Returns:
        The first non-empty model name found, or ``None`` if there is none.
    """
    if not isinstance(resolved_config, dict):
        return None

    direct = resolved_config.get("model_name")
    if isinstance(direct, str) and direct:
        return direct

    structured = resolved_config.get("structured_inference_engine")
    if isinstance(structured, dict):
        engine_name = structured.get("model_name")
        if isinstance(engine_name, str) and engine_name:
            return engine_name

    for value in resolved_config.values():
        if isinstance(value, dict):
            candidates = [value]
        elif isinstance(value, list):
            # Only dict items are searched; other list items (including
            # nested lists) are ignored.
            candidates = [item for item in value if isinstance(item, dict)]
        else:
            continue
        for candidate in candidates:
            found = _extract_model_name(candidate)
            if found:
                return found
    return None
54+
55+
def _is_model_downloaded(model_name: Optional[str]) -> Optional[bool]:
    """Best-effort check of whether ``model_name`` is available locally.

    Args:
        model_name: A filesystem path or a model identifier; may be ``None``
            or empty.

    Returns:
        ``None`` when no model name was given or ``huggingface_hub`` cannot
        be imported; ``True`` when ``model_name`` is an existing filesystem
        path or ``snapshot_download(model_name, local_files_only=True)``
        completes without raising; ``False`` when that call raises.
    """
    if not model_name:
        return None

    # A path that exists on disk is treated as an already-available model.
    if os.path.exists(model_name):
        return True

    # Imported lazily so a missing/broken huggingface_hub install only makes
    # the answer "unknown" instead of failing the caller.
    try:
        from huggingface_hub import snapshot_download as _local_snapshot
    except Exception:
        return None

    # NOTE(review): presumably local_files_only=True makes this a cache-only
    # lookup that raises when the snapshot is absent -- confirm against the
    # huggingface_hub documentation.
    try:
        _local_snapshot(model_name, local_files_only=True)
    except Exception:
        return False
    else:
        return True
2773
2874
2975def decider_worker_func (task_queue : Queue , result_queue : Queue ):
3076 root_logger = logging .getLogger ()
3177 root_logger .setLevel ("WARNING" )
78+ logger = logging .getLogger (__name__ )
3279
3380 model_cache : Dict [str , Tuple [Callable , Callable ]] = {}
3481
@@ -37,8 +84,15 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue):
3784 try :
3885 if isinstance (task , CacheQuery ):
3986 cache_key = extract_cache_key (task .resolved_config )
87+ is_cached = cache_key in model_cache
88+ is_downloaded = True if is_cached else _is_model_downloaded (
89+ _extract_model_name (task .resolved_config )
90+ )
4091 result_queue .put (
41- CacheQueryResult (is_cached = cache_key in model_cache )
92+ CacheQueryResult (
93+ is_cached = is_cached ,
94+ is_downloaded = is_downloaded ,
95+ )
4296 )
4397 continue
4498
@@ -72,11 +126,30 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue):
72126 except (KeyboardInterrupt , SystemExit ):
73127 break
74128 except Exception as e :
75- error_msg = f"{ str (e )} \n { traceback .format_exc ()} "
129+ logger .error ("Worker error:\n %s" , traceback .format_exc ())
130+ error_msg = _format_worker_error (e )
76131 result_queue .put (Exception (error_msg ))
77132 finally :
78133 for _ , (_ , cleanup_func ) in model_cache .items ():
79134 try :
80135 cleanup_func ()
81136 except Exception :
82137 pass
138+
139+
def _format_worker_error(error: Exception) -> str:
    """Turn a worker exception into the message sent back to the caller.

    If the exception text contains any of the gated-access marker
    substrings, a short fixed "access denied" message is returned instead of
    the raw error. Otherwise the message is the exception text followed by
    its formatted traceback.

    The traceback is rendered from ``error`` itself rather than from the
    interpreter's "currently handled exception" state, so the result is
    correct even when this function is called outside an ``except`` block.

    Args:
        error: The exception raised while processing a task.

    Returns:
        The message string to report for this error.
    """
    error_text = str(error)
    # Substrings (matched case-sensitively) treated as signs of a gated or
    # unauthorized model repository.
    # NOTE(review): "restricted" is a broad token and may also match
    # unrelated errors, hiding their traceback -- confirm this is intended.
    gated_tokens = (
        "GatedRepoError",
        "gated repo",
        "401 Client Error",
        "Access to model",
        "restricted",
        "Please log in",
    )
    if any(token in error_text for token in gated_tokens):
        return (
            "Model access denied. Authenticate with Hugging Face or request access "
            "to the gated repo."
        )
    # Three-argument form works on all Python 3 versions.
    trace_text = "".join(
        traceback.format_exception(type(error), error, error.__traceback__)
    )
    return f"{error_text}\n{trace_text}"
0 commit comments