1+ import asyncio
2+ import concurrent .futures
13from typing import List
24
35from cachetools import LRUCache
@@ -22,6 +24,16 @@ def from_config(cls, config: EmbeddingConfig) -> "EmbeddingEngine":
2224
2325 logger .success (f"[EmbeddingEngine] load model from { config .path } " )
2426 return cls (model = model , cache = cache , ** config .model_dump ())
27+
28+ @property
29+ def executor (self ) -> concurrent .futures .ThreadPoolExecutor :
30+ """延迟初始化线程池执行器"""
31+ if self ._executor is None :
32+ self ._executor = concurrent .futures .ThreadPoolExecutor (
33+ max_workers = 4 , # 单个模型使用一个worker避免GPU内存冲突
34+ thread_name_prefix = f"gpu_reranker_{ self .alias } "
35+ )
36+ return self ._executor
2537
2638 def invoke (self , sentences : List [str ]) -> List [EmbeddingResult ]:
2739 not_cached_ids = []
@@ -60,6 +72,15 @@ def invoke(self, sentences: List[str]) -> List[EmbeddingResult]:
6072 embeddings = no_cached_embeds
6173
6274 return [EmbeddingResult (embedding = embedding , index = i , object = "embedding" ) for i , embedding in enumerate (embeddings )]
75+
76+ async def async_invoke (self , sentences : List [str ]) -> List [EmbeddingResult ]:
77+ loop = asyncio .get_event_loop ()
78+ return await loop .run_in_executor (
79+ self .executor ,
80+ self .invoke ,
81+ sentences
82+ )
83+
6384
6485 def stream (self , sentences : List [str ]) -> List [EmbeddingResult ]:
65- raise NotImplementedError (f"{ self .__class__ .__name__ } does not implement `stream` method" )
86+ raise NotImplementedError (f"{ self .__class__ .__name__ } does not implement `stream` method" )
0 commit comments