Skip to content

Commit 9e923e2

Browse files
committed
🎈 perf: improve embedding performance
1 parent 2e012f5 commit 9e923e2

1 file changed

Lines changed: 22 additions & 1 deletion

File tree

src/engine/embedding/sentence_transformer.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import concurrent.futures
13
from typing import List
24

35
from cachetools import LRUCache
@@ -22,6 +24,16 @@ def from_config(cls, config: EmbeddingConfig) -> "EmbeddingEngine":
2224

2325
logger.success(f"[EmbeddingEngine] load model from {config.path}")
2426
return cls(model=model, cache=cache, **config.model_dump())
27+
28+
@property
def executor(self) -> concurrent.futures.ThreadPoolExecutor:
    """Return this engine's thread pool, creating it on first access.

    The pool is built lazily and stored on ``self._executor`` so that
    every later access returns the same executor.

    Returns:
        The ``ThreadPoolExecutor`` owned by this instance.
    """
    # getattr tolerates instances on which ``_executor`` was never assigned,
    # so the first access creates the pool instead of raising AttributeError.
    pool = getattr(self, "_executor", None)
    if pool is None:
        pool = concurrent.futures.ThreadPoolExecutor(
            # NOTE(review): the original comment on this line said a model
            # should use a single worker to avoid GPU memory conflicts, yet
            # the value is 4 -- confirm which one is intended.
            max_workers=4,
            # NOTE(review): the prefix says "reranker" although this property
            # lives in the embedding engine -- presumably copy-pasted; confirm
            # before renaming, since thread names may appear in logs.
            thread_name_prefix=f"gpu_reranker_{self.alias}",
        )
        self._executor = pool
    return pool
2537

2638
def invoke(self, sentences: List[str]) -> List[EmbeddingResult]:
2739
not_cached_ids = []
@@ -60,6 +72,15 @@ def invoke(self, sentences: List[str]) -> List[EmbeddingResult]:
6072
embeddings = no_cached_embeds
6173

6274
return [EmbeddingResult(embedding=embedding, index=i, object="embedding") for i, embedding in enumerate(embeddings)]
75+
76+
async def async_invoke(self, sentences: List[str]) -> "List[EmbeddingResult]":
    """Run the blocking ``invoke`` call in this engine's thread pool.

    Offloading to ``self.executor`` keeps the event loop free while the
    synchronous ``invoke`` runs.

    Args:
        sentences: The sentences to pass through to ``invoke``.

    Returns:
        Whatever ``invoke`` returns for ``sentences``.
    """
    # A coroutine always executes inside a running loop, so ask for that loop
    # directly rather than going through the get_event_loop() policy lookup.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self.executor, self.invoke, sentences)
83+
6384

6485
def stream(self, sentences: List[str]) -> "List[EmbeddingResult]":
    """Reject streaming, which this engine does not implement.

    Args:
        sentences: Accepted for signature compatibility; never used.

    Raises:
        NotImplementedError: Always, naming the concrete class.
    """
    raise NotImplementedError(f"{self.__class__.__name__} does not implement `stream` method")

0 commit comments

Comments
 (0)