Skip to content

Commit b188af2

Browse files
committed
fix: Enable embedding models via base_url's rather than hard coded embedding models
1 parent 65cbbe8 commit b188af2

3 files changed

Lines changed: 16 additions & 26 deletions

File tree

openevolve/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ class DatabaseConfig:
350350

351351
novelty_llm: Optional["LLMInterface"] = None
352352
embedding_model: Optional[str] = None
353+
embedding_base_url: Optional[str] = None
353354
similarity_threshold: float = 0.99
354355

355356

openevolve/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(self, config: DatabaseConfig):
204204

205205
self.novelty_llm = config.novelty_llm
206206
self.embedding_client = (
207-
EmbeddingClient(config.embedding_model) if config.embedding_model else None
207+
EmbeddingClient(config.embedding_model, base_url=config.embedding_base_url) if config.embedding_model else None
208208
)
209209
self.similarity_threshold = config.similarity_threshold
210210

openevolve/embedding.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,51 +10,40 @@
1010

1111
logger = logging.getLogger(__name__)
1212

13-
M = 1_000_000
14-
15-
OPENAI_EMBEDDING_MODELS = [
16-
"text-embedding-3-small",
17-
"text-embedding-3-large",
18-
]
19-
2013
AZURE_EMBEDDING_MODELS = [
2114
"azure-text-embedding-3-small",
2215
"azure-text-embedding-3-large",
2316
]
2417

25-
OPENAI_EMBEDDING_COSTS = {
26-
"text-embedding-3-small": 0.02 / M,
27-
"text-embedding-3-large": 0.13 / M,
28-
}
29-
3018

3119
class EmbeddingClient:
32-
def __init__(self, model_name: str = "text-embedding-3-small"):
20+
def __init__(self, model_name: str = "text-embedding-3-small", base_url: str | None = None):
3321
"""
3422
Initialize the EmbeddingClient.
3523
3624
Args:
37-
model (str): The OpenAI embedding model name to use.
25+
model_name: The embedding model name to use.
26+
base_url: Optional base URL for the embedding API endpoint.
3827
"""
39-
self.client, self.model = self._get_client_model(model_name)
28+
self.client, self.model = self._get_client_model(model_name, base_url)
4029

41-
def _get_client_model(self, model_name: str) -> tuple[openai.OpenAI, str]:
42-
if model_name in OPENAI_EMBEDDING_MODELS:
43-
# Use OPENAI_EMBEDDING_API_KEY if set, otherwise fall back to OPENAI_API_KEY
44-
# This allows users to use OpenRouter for LLMs while using OpenAI for embeddings
45-
embedding_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY")
46-
client = openai.OpenAI(api_key=embedding_api_key)
47-
model_to_use = model_name
48-
elif model_name in AZURE_EMBEDDING_MODELS:
30+
def _get_client_model(
31+
self, model_name: str, base_url: str | None = None
32+
) -> tuple[openai.OpenAI, str]:
33+
if model_name in AZURE_EMBEDDING_MODELS:
4934
# get rid of the azure- prefix
5035
model_to_use = model_name.split("azure-")[-1]
5136
client = openai.AzureOpenAI(
5237
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
5338
api_version=os.getenv("AZURE_API_VERSION"),
54-
azure_endpoint=os.getenv("AZURE_API_ENDPOINT"),
39+
azure_endpoint=os.environ["AZURE_API_ENDPOINT"],
5540
)
5641
else:
57-
raise ValueError(f"Invalid embedding model: {model_name}")
42+
# Use OPENAI_EMBEDDING_API_KEY if set, otherwise fall back to OPENAI_API_KEY
43+
# This allows users to use OpenRouter for LLMs while using OpenAI for embeddings
44+
embedding_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY")
45+
client = openai.OpenAI(api_key=embedding_api_key, base_url=base_url)
46+
model_to_use = model_name
5847

5948
return client, model_to_use
6049

0 commit comments

Comments
 (0)