22
33import os
44import logging
5- import backoff
5+ import random
6+ import time
67from typing import Dict , Any , Optional , List , Sequence
78
89from adalflow .core .model_client import ModelClient
@@ -54,16 +55,21 @@ def __init__(
5455 self ,
5556 api_key : Optional [str ] = None ,
5657 env_api_key_name : str = "GOOGLE_API_KEY" ,
58+ inter_batch_delay : float = 0.2 ,
5759 ):
5860 """Initialize Google AI Embeddings client.
59-
61+
6062 Args:
6163 api_key: Google AI API key. If not provided, uses environment variable.
6264 env_api_key_name: Name of environment variable containing API key.
65+ inter_batch_delay: Seconds to sleep after each successful embedding
66+ API call to avoid burst-hitting rate limits (default: 0.2s).
67+ Set to 0 to disable.
6368 """
6469 super ().__init__ ()
6570 self ._api_key = api_key
6671 self ._env_api_key_name = env_api_key_name
72+ self ._inter_batch_delay = inter_batch_delay
6773 self ._initialize_client ()
6874
6975 def _initialize_client (self ):
@@ -205,27 +211,57 @@ def convert_inputs_to_api_kwargs(
205211
206212 return final_model_kwargs
207213
208- @backoff .on_exception (
209- backoff .expo ,
210- (Exception ,), # Google AI may raise various exceptions
211- max_time = 5 ,
212- )
214+ # Retry configuration for rate-limit and transient errors
215+ _MAX_RETRIES = 5
216+ _BASE_DELAY = 1.0 # seconds
217+ _MAX_DELAY = 16.0 # seconds (cap for exponential backoff)
218+ _JITTER_MAX = 1.0 # max random jitter in seconds
219+
220+ @staticmethod
221+ def _is_retryable (exc : Exception ) -> bool :
222+ """Return True if the exception indicates a retryable error (429 / 503)."""
223+ exc_str = str (exc ).lower ()
224+ # google.api_core.exceptions.ResourceExhausted (429)
225+ if "resourceexhausted" in type (exc ).__name__ .lower ():
226+ return True
227+ # google.api_core.exceptions.ServiceUnavailable (503)
228+ if "serviceunavailable" in type (exc ).__name__ .lower ():
229+ return True
230+ # Catch by HTTP status code mentions in the message
231+ if "429" in exc_str or "resource exhausted" in exc_str :
232+ return True
233+ if "503" in exc_str or "service unavailable" in exc_str :
234+ return True
235+ # google.generativeai may raise a generic exception wrapping these
236+ if hasattr (exc , "code" ):
237+ code = getattr (exc , "code" , None )
238+ if code in (429 , 503 ):
239+ return True
240+ if hasattr (exc , "status_code" ):
241+ status = getattr (exc , "status_code" , None )
242+ if status in (429 , 503 ):
243+ return True
244+ return False
245+
213246 def call (self , api_kwargs : Dict = {}, model_type : ModelType = ModelType .UNDEFINED ):
214- """Call Google AI embedding API.
215-
247+ """Call Google AI embedding API with retry + exponential backoff.
248+
249+ Retries on 429 (ResourceExhausted) and 503 (ServiceUnavailable) errors
250+ with exponential backoff: 1s, 2s, 4s, 8s, 16s plus random jitter.
251+
216252 Args:
217253 api_kwargs: API parameters
218254 model_type: Should be ModelType.EMBEDDER
219-
255+
220256 Returns:
221257 Google AI embedding response
222258 """
223259 if model_type != ModelType .EMBEDDER :
224260 raise ValueError (f"GoogleEmbedderClient only supports EMBEDDER model type" )
225-
261+
226262 # DEBUG LOGGING (Simplified)
227263 log .info (f"DEBUG: GoogleEmbedderClient.call received api_kwargs keys: { list (api_kwargs .keys ())} " )
228-
264+
229265 safe_log_kwargs = {k : v for k , v in api_kwargs .items () if k not in {"content" , "contents" }}
230266 if "content" in api_kwargs :
231267 safe_log_kwargs ["content_chars" ] = len (str (api_kwargs .get ("content" , "" )))
@@ -236,28 +272,59 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
236272 except Exception :
237273 safe_log_kwargs ["contents_count" ] = None
238274 log .info ("Google AI Embeddings call kwargs (sanitized): %s" , safe_log_kwargs )
239-
240- try :
241- # Use embed_content for single text or batch embedding
242- # CRITICAL FIX: Do not modify api_kwargs in place as it breaks backoff retries!
243- call_kwargs = api_kwargs .copy ()
244-
245- if "content" in call_kwargs :
246- # Single embedding
247- response = genai .embed_content (** call_kwargs )
248- elif "contents" in call_kwargs :
249- # Batch embedding - Google AI supports batch natively
250- contents = call_kwargs .pop ("contents" )
251- # pass as 'content' argument which handles both single and batch in newer SDKs
252- response = genai .embed_content (content = contents , ** call_kwargs )
253- else :
254- raise ValueError (f"Either 'content' or 'contents' must be provided. Got kwargs: { list (api_kwargs .keys ())} " )
255-
256- return response
257-
258- except Exception as e :
259- log .error (f"Error calling Google AI Embeddings API: { e } " )
260- raise
275+
276+ last_exception : Optional [Exception ] = None
277+
278+ for attempt in range (self ._MAX_RETRIES + 1 ):
279+ try :
280+ # CRITICAL FIX: Do not modify api_kwargs in place as it breaks retries!
281+ call_kwargs = api_kwargs .copy ()
282+
283+ if "content" in call_kwargs :
284+ # Single embedding
285+ response = genai .embed_content (** call_kwargs )
286+ elif "contents" in call_kwargs :
287+ # Batch embedding - Google AI supports batch natively
288+ contents = call_kwargs .pop ("contents" )
289+ # pass as 'content' argument which handles both single and batch
290+ response = genai .embed_content (content = contents , ** call_kwargs )
291+ else :
292+ raise ValueError (
293+ f"Either 'content' or 'contents' must be provided. "
294+ f"Got kwargs: { list (api_kwargs .keys ())} "
295+ )
296+
297+ # Inter-batch cooldown to avoid burst-hitting rate limits
298+ if self ._inter_batch_delay > 0 :
299+ time .sleep (self ._inter_batch_delay )
300+
301+ return response
302+
303+ except Exception as e :
304+ last_exception = e
305+
306+ if not self ._is_retryable (e ) or attempt >= self ._MAX_RETRIES :
307+ log .error (
308+ "Google AI Embeddings API call failed (attempt %d/%d, non-retryable or max retries): %s" ,
309+ attempt + 1 , self ._MAX_RETRIES + 1 , e ,
310+ )
311+ raise
312+
313+ # Exponential backoff with jitter
314+ delay = min (self ._BASE_DELAY * (2 ** attempt ), self ._MAX_DELAY )
315+ jitter = random .uniform (0 , self ._JITTER_MAX )
316+ sleep_time = delay + jitter
317+
318+ log .warning (
319+ "Google AI Embeddings API returned retryable error (attempt %d/%d): %s. "
320+ "Retrying in %.1fs ..." ,
321+ attempt + 1 , self ._MAX_RETRIES + 1 , e , sleep_time ,
322+ )
323+ time .sleep (sleep_time )
324+
325+ # Should not be reached, but just in case
326+ if last_exception :
327+ raise last_exception
261328
262329 async def acall (self , api_kwargs : Dict = {}, model_type : ModelType = ModelType .UNDEFINED ):
263330 """Async call to Google AI embedding API.
0 commit comments